From 6c5df3dfed228851ce4e93731424f981751364ec Mon Sep 17 00:00:00 2001 From: kongziyi Date: Sat, 8 Nov 2025 11:44:10 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90master=E3=80=91=E3=80=90bugfix?= =?UTF-8?q?=E3=80=91=E4=BF=AE=E5=A4=8Dmcore=E8=AE=AD=E7=BB=83=E5=BC=80swap?= =?UTF-8?q?=E6=8A=A5=E9=94=99=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../training_graph/transformer/utils.py | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/mindformers/parallel_core/training_graph/transformer/utils.py b/mindformers/parallel_core/training_graph/transformer/utils.py index c44c100f7..46243be00 100644 --- a/mindformers/parallel_core/training_graph/transformer/utils.py +++ b/mindformers/parallel_core/training_graph/transformer/utils.py @@ -32,7 +32,7 @@ class AttnMaskFill(nn.Cell): """Applies a mask to attention scores by filling masked positions with a specified value.""" def __init__(self, config): - super(AttnMaskFill, self).__init__() + super().__init__() self.shape = P.Shape() self.reshape = P.Reshape() self.tile = P.Tile() @@ -87,9 +87,8 @@ def get_attn_mask_func(mask_func_type): Function, the attention mask function. """ if mask_func_type not in ATTNMASK_FUNC_MAP: - raise KeyError("Invalid attention mask function. Supported attention " - "mask function are ['attn_mask_fill'] " - ", but got {}.".format(mask_func_type)) + raise KeyError(f"Invalid attention mask function. Supported attention " + f"mask function are ['attn_mask_fill'], but got {mask_func_type}.") return ATTNMASK_FUNC_MAP[mask_func_type] @@ -142,7 +141,7 @@ class LayerSetting: try: use_pp_interleave = ms.get_auto_parallel_context("pipeline_interleave") except ValueError: - logger.warning(f"Current MindSpore version do not pipeline interleave. `pp_interleave_num` is set to 1.") + logger.warning("Current MindSpore version do not pipeline interleave. `pp_interleave_num` is set to 1.") use_pp_interleave = False self.num_layers = num_layers if stage_num != 0: @@ -208,7 +207,7 @@ class LayerSetting: if self.swap.swap: self.layer_swap = [] - self.op_swap = dict() + self.op_swap = {} self.layer_swap = self._initialize_swap_list(self.swap.layer_swap) for key in self.swap.op_swap: self.op_swap[key] = self._initialize_swap_list(self.swap.op_swap[key]) @@ -262,7 +261,7 @@ class LayerSetting: continue for layer_idx in recompute_layers: self._check_layer_swap_recompute_conflict(layer_idx, self.layer_swap, op_name) - if op_name in self.op_swap.keys(): + if op_name in self.op_swap: self._check_op_swap_recompute_conflict(op_name, recompute_layers) def _check_layer_swap_recompute_conflict(self, layer_idx, layer_list, op_name="All"): @@ -328,8 +327,8 @@ class LayerSetting: def _set_op_swap(self, layer, layer_id): """Set swap for operations in the layer based on patterns and layer_id.""" log_ops = [] - for pattern in self.op_swap: - for layer_swap in self.op_swap[pattern]: + for pattern, layer_swaps in self.op_swap.items(): + for layer_swap in layer_swaps: layers_id = layer_swap.get(self.layers) is_valid_bool = isinstance(layers_id, bool) and layers_id is_valid_list = isinstance(layers_id, list) and layer_id in layers_id @@ -384,22 +383,8 @@ class LayerSetting: return " " + ", ".join(log_list) return log - def set(self, layer, layer_id): - """Set pipeline stage and recompute for each layer with a layer_id.""" - pp_id = int(self.pp_ids[layer_id]) + self.start_stage - if self.is_zbv: - if self.interleave_ids[layer_id] == 1: - pp_id = self.pp - 1 - pp_id - layer.pipeline_segment = 1 - else: - layer.pipeline_segment = 0 - layer.pipeline_stage = pp_id - dis = max(int((self.num_layers + 1) / self.gradient_aggregation_group), 1) - if self.pp > 1: - layer.set_comm_fusion(2) - else: - layer.set_comm_fusion(int((layer_id + self.offset[0, 0]) / dis) + 1) - + def set_recompute(self, layer, layer_id): + """Set recompute for each layer with a layer_id.""" if isinstance(self.recompute, bool): if self.recompute: layer.recompute() @@ -423,6 +408,24 @@ class LayerSetting: self._set_select_recompute(layer, layer_id, False, set_on=False) self._set_select_recompute(layer, layer_id, True, set_on=False) + def set_pipeline_stage(self, layer, layer_id): + """Set pipeline stage for each layer with a layer_id.""" + pp_id = int(self.pp_ids[layer_id]) + self.start_stage + if self.is_zbv: + if self.interleave_ids[layer_id] == 1: + pp_id = self.pp - 1 - pp_id + layer.pipeline_segment = 1 + else: + layer.pipeline_segment = 0 + layer.pipeline_stage = pp_id + + def set_communication_fusion(self, layer, layer_id): + """Set communication fusion for each layer with a layer_id.""" + dis = max(int((self.num_layers + 1) / self.gradient_aggregation_group), 1) + if self.pp > 1: + layer.set_comm_fusion(2) + else: + layer.set_comm_fusion(int((layer_id + self.offset[0, 0]) / dis) + 1) def set_recompute_select_layer_index(self, layer, layer_id): """Set swap for specific layer based on its layer_id when.""" @@ -627,8 +630,10 @@ class LayerSetting: return log def __call__(self, layer, layer_id): + self.set_pipeline_stage(layer, layer_id) + self.set_communication_fusion(layer, layer_id) if self.swap.swap: self.set_recompute_select_layer_index(layer, layer_id) self.set_swap(layer, layer_id) else: - self.set(layer, layer_id) + self.set_recompute(layer, layer_id) -- Gitee