From 786621cb19ade07aad2d60acf7092a6a7b5a802f Mon Sep 17 00:00:00 2001 From: JavaZero <2487163254@qq.com> Date: Sat, 13 Dec 2025 17:24:34 +0800 Subject: [PATCH] Refactor Muon optimizer to remove micro_batch_num parameter and adjust logit threshold calculation. Update FlashAttention operations to replace AssignAdd with Assign and Maximum for improved functionality. Clean up BaseTrainer to eliminate micro_batch_num references in optimizer arguments. --- mindformers/core/callback/callback.py | 3 +- mindformers/core/optim/muon.py | 6 +- mindformers/parallel_core/mf_model_config.py | 2 +- .../base_models/gpt/gpt_model.py | 132 ++++++++-------- .../training_graph/communication.py | 147 ++++++++++++++++++ .../transformer/flash_attention.py | 40 +++-- .../parallel_core/transformer_config_utils.py | 2 +- mindformers/trainer/base_trainer.py | 2 - mindformers/trainer/trainer.py | 4 +- tests/st/test_ut/base_schema.json | 2 +- .../test_communication/test_communication.py | 143 +++++++++++++++++ .../test_core/test_optim/test_get_op_group.py | 38 +---- 12 files changed, 390 insertions(+), 131 deletions(-) create mode 100644 mindformers/parallel_core/training_graph/communication.py create mode 100644 tests/st/test_ut/test_core/test_communication/test_communication.py diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py index feb94e28a..9c0c1f7c4 100644 --- a/mindformers/core/callback/callback.py +++ b/mindformers/core/callback/callback.py @@ -1280,8 +1280,7 @@ class TrainingStateMonitor(Callback): step = cb_params.cur_step_num vals = [] for param_name, param in params.items(): - v = param.asnumpy().squeeze() - v = v / max(1, self.micro_batch_num) + v = param.asnumpy() tag = f"max_attention_logit/{param_name}" if 'log' in self.max_attention_logit_format: diff --git a/mindformers/core/optim/muon.py b/mindformers/core/optim/muon.py index 4ebe5924a..af0239364 100644 --- a/mindformers/core/optim/muon.py +++ b/mindformers/core/optim/muon.py @@ -256,8 +256,7 @@ class Muon(Optimizer): ns_steps (int): Number of Newton-Schulz steps. Default: ``5``. adamw_betas (tuple): Beta parameters for AdamW. Default: ``(0.95, 0.95)``. adamw_eps (float): Epsilon for AdamW. Default: ``1e-8``. - micro_batch_num (int): Number of micro batches. Default: ``1``. - qk_clip_threshold (float): QK clip threshold. Default: ``4``. + qk_clip_threshold (float): QK clip threshold. Default: ``100``. model: The model model. Default: ``None``. """ @@ -272,7 +271,6 @@ class Muon(Optimizer): ns_steps=5, adamw_betas=(0.95, 0.95), adamw_eps=1e-8, - micro_batch_num=1, qk_clip_threshold=100, model=None, **kwargs, @@ -296,7 +294,7 @@ class Muon(Optimizer): self.ones = Tensor([1.0], mstype.float32) self.rank_id = get_rank() self.rank_ids = tuple(self.rank_id for _ in self._parameters) - self.logit_threshold = Tensor([qk_clip_threshold * micro_batch_num], dtype=mstype.float32) + self.logit_threshold = Tensor([qk_clip_threshold], dtype=mstype.float32) # Initialize Muon momentum self._initialize_muon_moments(model) diff --git a/mindformers/parallel_core/mf_model_config.py b/mindformers/parallel_core/mf_model_config.py index ab57f4561..1616d2fd2 100644 --- a/mindformers/parallel_core/mf_model_config.py +++ b/mindformers/parallel_core/mf_model_config.py @@ -262,7 +262,7 @@ class MFModelConfig: mask_func_type: str = "attn_mask_fill" """Mask function type to use for the attention layer.""" - monitor_max_attention_logit: bool = False + track_max_attention_logit: bool = False """Whether to monitor the maximum attention logit value during training.""" #################################################### diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py index 09a09f723..ecfae9b73 100644 --- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py @@ -15,12 +15,11 @@ """mindformers GPT model""" __all__ = ['GPTModel'] -import hashlib from typing import Literal, Optional, Union import numpy as np import mindspore as ms -from mindspore.communication import create_group, get_group_size, get_rank +from mindspore.communication import get_group_size, get_rank from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops import auto_generate as aclnn_ops @@ -34,6 +33,12 @@ from mindformers.parallel_core.training_graph.loss_func import CrossEntropyLoss from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock, \ func_infer_dtype, func_infer_shape, func_infer_shape_labels_and_masks from mindformers.parallel_core.training_graph.device_matrix import layout +from mindformers.parallel_core.training_graph.communication import ( + compute_repeat_num_and_model_parallel_size, + get_cp_group_name, + get_dp_group_name, + get_op_group_name +) from mindformers.parallel_core.utils.spec_utils import ModuleSpec from mindformers.parallel_core.training_graph.transformer.mask_generate import CausalMaskGenerate from mindformers.parallel_core.transformer_config import TransformerConfig @@ -57,61 +62,6 @@ from mindformers.tools.logger import logger from mindformers.models.utils import get_current_rank_stage, get_model_parameters from mindformers.version_control import get_lazy_inline as lazy_inline from mindformers.core.optim.muon_utils import make_muon_fns -from mindformers.checkpoint.sharded_tensor import ShardedTensor - - -def compute_repeat_num_and_model_parallel_size(sharded_info: ShardedTensor, world_size: int, pp: int, op: int): - """Compute real op size.""" - axis_fragmentations = sharded_info.axis_fragmentations - flag = False - weight_sharded_size = 1 - for axis in axis_fragmentations: - if axis == 1: - continue - if flag: - raise ValueError("Only one axis can be fragmented in Muon optimizer.") - flag = True - weight_sharded_size *= axis - repeat_num = world_size // pp // weight_sharded_size - real_op_size = min(op, repeat_num) - if sharded_info.local_shape[0] % real_op_size != 0: - real_op_size = 1 - return real_op_size, weight_sharded_size - - -def create_communication_group(rank_list): - """ - Create a communication group with a hashed name. - - Args: - rank_list: List of ranks in the communication group - - Returns: - str: The created group name - """ - rank_list_str = "-".join([str(i) for i in rank_list]) - hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48] - group_name = str(hashed) - create_group(group_name, rank_list) - return group_name - - -OP_GROUP_NAME = {} - - -def get_op_group_name(rank_id: int, real_op_size: int, model_parallel_size: int): - """Get op group name.""" - if (rank_id, real_op_size, model_parallel_size) in OP_GROUP_NAME: - return OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] - dp_range = model_parallel_size - op_range = model_parallel_size * real_op_size - rank_start = rank_id % dp_range + rank_id // op_range * op_range - rank_end = rank_start + op_range - rank_list = list(range(rank_start, rank_end, dp_range)) - op_group_name = create_communication_group(rank_list) - OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] = (op_group_name, rank_list) - return op_group_name, rank_list - class PreprocessLabelsAndMasks(nn.Cell): """Preprocess input_ids and generate labels and masks. @@ -316,6 +266,17 @@ class GPTModel(nn.Cell): initialize_model_parallel(tensor_model_parallel_size=self.tp, data_parallel_size=self.dp, pipeline_model_parallel_size=self.pp, context_parallel_size=self.cp) + if self.config.track_max_attention_logit: + self.rank_id = get_rank() + self.allreduce_max_in_dp = ( + None if self.dp == 1 + else P.AllReduce(op=P.ReduceOp.MAX, group=get_dp_group_name(self.rank_id, self.dp, self.tp, self.cp)[0]) + ) + self.allreduce_max_in_cp = ( + None if self.cp == 1 + else P.AllReduce(op=P.ReduceOp.MAX, group=get_cp_group_name(self.rank_id, self.dp, self.tp, self.cp)[0]) + ) + self.preprocess_labels_and_masks = PreprocessLabelsAndMasks(config) # Embeddings @@ -731,6 +692,32 @@ class GPTModel(nn.Cell): max_logits[f"{param_name}.max_logits_val"] = param return max_logits + def allreduce_max_attention_logit(self): + """ + Perform AllReduce-Max operation across DP and CP dimensions for max attention logits. + + This method aggregates the maximum attention logit values from all data parallel + and context parallel ranks to ensure consistent max logit values across the model. + """ + num_layers = self.config.num_layers + mtp_num_layers = 0 if self.config.mtp_num_layers is None else self.config.mtp_num_layers + + def _allreduce_max_param(max_logits): + param = max_logits.value() + if self.allreduce_max_in_dp is not None: + param = self.allreduce_max_in_dp(param) + if self.allreduce_max_in_cp is not None: + param = self.allreduce_max_in_cp(param) + self.assign(max_logits, param) + + for i in range(num_layers): + max_logits = self.decoder.layers[i].self_attention.core_attention.max_logits_val + _allreduce_max_param(max_logits) + + for i in range(mtp_num_layers): + max_logits = self.mtp.layers[i].transformer_layer.self_attention.core_attention.max_logits_val + _allreduce_max_param(max_logits) + def reset_max_attention_logit(self): """Reset max attention logit to zeros for all layers.""" for _, core_attn in self._iter_core_attentions(): @@ -783,6 +770,9 @@ class GPTModel(nn.Cell): layout("dp_cp", "tp"), ) ) + if self.config.track_max_attention_logit: + self.allreduce_max_in_dp.shard((layout("tp"),)) + self.allreduce_max_in_cp.shard((layout("tp"),)) def sharding_propagation(self, config: TransformerConfig): pass @@ -1047,6 +1037,7 @@ class GPTModel(nn.Cell): Returns: List of (param_idx, scaled_weights) tuples to be updated. """ + self.allreduce_max_attention_logit() if not self.config.multi_latent_attention: return [] ones = ms.Tensor([1.0], dtype.float32) @@ -1058,9 +1049,6 @@ class GPTModel(nn.Cell): scale_broadcast = ops.expand_dims(scale_broadcast, 1) return scale_broadcast - # Build param name to index mapping - param_idx_in_opt = {name: idx for idx, name in enumerate(param_names)} - updates = [] for idx, param_name in enumerate(param_names): if ( @@ -1076,16 +1064,24 @@ class GPTModel(nn.Cell): # Compute per-head scale factor logit_threshold_f32 = ops.cast(logit_threshold, dtype=dtype.float32) if layer_idx >= 0: - max_logits_name = (f"decoder.layers.{layer_idx}.self_attention." - "core_attention.max_logits_val") + logits_row = ( + self.decoder.layers[layer_idx] + .self_attention + .core_attention + .max_logits_val + .value() + ) else: - max_logits_name = (f"mtp.layers.{-(layer_idx+1)}.transformer_layer." - "self_attention.core_attention.max_logits_val") - - if max_logits_name not in param_idx_in_opt: - continue + logits_row = ( + self.mtp.layers[-(layer_idx + 1)] + .transformer_layer + .self_attention + .core_attention + .max_logits_val + .value() + ) - logits_row = params[param_idx_in_opt[max_logits_name]].reshape(-1) + logits_row = logits_row.reshape(-1) mask = ops.greater_equal(logits_row, logit_threshold_f32) safe_den = ops.where(mask, logits_row, ones) scales = ops.where(mask, logit_threshold_f32 / safe_den, ones) diff --git a/mindformers/parallel_core/training_graph/communication.py b/mindformers/parallel_core/training_graph/communication.py new file mode 100644 index 000000000..d3f0910be --- /dev/null +++ b/mindformers/parallel_core/training_graph/communication.py @@ -0,0 +1,147 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Communication utilities for parallel training.""" +import hashlib +from typing import Tuple, List + +from mindspore.communication import create_group + +from mindformers.checkpoint.sharded_tensor import ShardedTensor +from mindformers.tools.logger import logger + + +def compute_repeat_num_and_model_parallel_size(sharded_info: ShardedTensor, world_size: int, pp: int, op: int): + """Compute real op size.""" + axis_fragmentations = sharded_info.axis_fragmentations + flag = False + weight_sharded_size = 1 + for axis in axis_fragmentations: + if axis == 1: + continue + if flag: + raise ValueError("Only one axis can be fragmented in Muon optimizer.") + flag = True + weight_sharded_size *= axis + repeat_num = world_size // pp // weight_sharded_size + real_op_size = min(op, repeat_num) + if sharded_info.local_shape[0] % real_op_size != 0: + real_op_size = 1 + return real_op_size, weight_sharded_size + + +def create_communication_group(rank_list): + """ + Create a communication group with a hashed name. + + Args: + rank_list: List of ranks in the communication group + + Returns: + str: The created group name + """ + rank_list_str = "-".join([str(i) for i in rank_list]) + hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48] + group_name = str(hashed) + create_group(group_name, rank_list) + return group_name + + +OP_GROUP_NAME = {} +CP_GROUP_NAME = {} +DP_GROUP_NAME = {} + + +def get_cp_group_name(rank_id: int, dp: int, tp: int, cp: int) -> Tuple[str, List[int]]: + """ + Get the CP (Context Parallel) communication group name and rank list. + + Under the rank encoding where DP is the highest bit, CP is the second bit, + and TP is the lowest bit, return the CP communication domain rank_list for the current rank. + + Args: + rank_id (int): Current rank ID. + dp (int): Data parallel size. + tp (int): Tensor parallel size. + cp (int): Context parallel size. + + Returns: + Tuple[str, List[int]]: Communication group name and rank list. + """ + cache_key = (rank_id, dp, tp, cp) + if cache_key in CP_GROUP_NAME: + return CP_GROUP_NAME[cache_key] + + pp_block = dp * cp * tp + inner = cp * tp + + pp_base = (rank_id // pp_block) * pp_block + local = rank_id % pp_block + + dp_id = local // inner + tp_id = local % tp + + base = pp_base + dp_id * inner + tp_id + rank_list = [base + c * tp for c in range(cp)] + logger.info(f"Get cp rank list: {rank_list}") + result = (create_communication_group(rank_list), rank_list) + CP_GROUP_NAME[cache_key] = result + return result + + +def get_dp_group_name(rank_id: int, dp: int, tp: int, cp: int) -> Tuple[str, List[int]]: + """ + Get the DP (Data Parallel) communication group name and rank list. + + Under the rank encoding where DP is the highest bit, CP is the second bit, + and TP is the lowest bit, return the DP communication domain rank_list for the current rank. + + Args: + rank_id (int): Current rank ID. + dp (int): Data parallel size. + tp (int): Tensor parallel size. + cp (int): Context parallel size. + + Returns: + Tuple[str, List[int]]: Communication group name and rank list. + """ + cache_key = (rank_id, dp, tp, cp) + if cache_key in DP_GROUP_NAME: + return DP_GROUP_NAME[cache_key] + + pp_block = dp * cp * tp + inner = cp * tp + + pp_base = (rank_id // pp_block) * pp_block + local = rank_id % pp_block + base_low = local % inner + rank_list = [pp_base + base_low + d * inner for d in range(dp)] + logger.info(f"Get dp rank list: {rank_list}") + result = (create_communication_group(rank_list), rank_list) + DP_GROUP_NAME[cache_key] = result + return result + + +def get_op_group_name(rank_id: int, real_op_size: int, model_parallel_size: int): + """Get op group name.""" + if (rank_id, real_op_size, model_parallel_size) in OP_GROUP_NAME: + return OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] + dp_range = model_parallel_size + op_range = model_parallel_size * real_op_size + rank_start = rank_id % dp_range + rank_id // op_range * op_range + rank_end = rank_start + op_range + rank_list = list(range(rank_start, rank_end, dp_range)) + op_group_name = create_communication_group(rank_list) + OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] = (op_group_name, rank_list) + return op_group_name, rank_list diff --git a/mindformers/parallel_core/training_graph/transformer/flash_attention.py b/mindformers/parallel_core/training_graph/transformer/flash_attention.py index d308b134a..785038ff1 100644 --- a/mindformers/parallel_core/training_graph/transformer/flash_attention.py +++ b/mindformers/parallel_core/training_graph/transformer/flash_attention.py @@ -173,17 +173,23 @@ class FlashAttention(Cell): self.reshape = aclnn_ops.Reshape() self.fa_out_transpose = aclnn_ops.Transpose() - self.monitor_max_attention_logit = self.config.monitor_max_attention_logit + self.track_max_attention_logit = self.config.track_max_attention_logit - if self.monitor_max_attention_logit: + if self.track_max_attention_logit: + # Parameter to store the maximum attention logit value per head. + # Note: This is a local max within each device's partition. Cross-device + # synchronization (AllReduce-Max across DP/CP dimensions) is performed + # later in GPTModel.allreduce_max_attention_logit() to obtain the global max. self.max_logits_val = Parameter( Tensor(np.zeros((self.head_num)), dtype=mstype.float32), parallel_optimizer=False, requires_grad=False ) self.reduce_max = aclnn_ops.ReduceMax() self.reduce_max.add_prim_attr("self_define_shard", True) - self.assign_add = ops.AssignAdd() - self.assign_add.add_prim_attr("self_define_shard", True) + self.assign = ops.Assign() + self.assign.add_prim_attr("self_define_shard", True) + self.maximum = ops.Maximum() + self.maximum.add_prim_attr("self_define_shard", True) if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation(): self.sharding_propagation(config) @@ -300,9 +306,10 @@ class FlashAttention(Cell): prefix, actual_seq_qlen, actual_seq_kvlen) - if self.monitor_max_attention_logit: + if self.track_max_attention_logit: max_logits = self.reduce_max(softmax_val, (0, 2)) - output = F.depend(output, self.assign_add(self.max_logits_val, max_logits)) + # Update local maximum; global sync happens in GPTModel.allreduce_max_attention_logit() + self.assign(self.max_logits_val, self.maximum(self.max_logits_val, max_logits)) return output q_seq_len, bsz = query.shape[:2] @@ -335,9 +342,10 @@ class FlashAttention(Cell): padding_mask, attention_mask, prefix) - if self.monitor_max_attention_logit: + if self.track_max_attention_logit: max_logits = self.reduce_max(softmax_val, (0, 2, 3)) - output = F.depend(output, self.assign_add(self.max_logits_val, max_logits)) + # Update local maximum; global sync happens in GPTModel.allreduce_max_attention_logit() + self.assign(self.max_logits_val, self.maximum(self.max_logits_val, max_logits)) if self.input_layout == "BNSD": output = self._merge_heads(output) @@ -379,19 +387,23 @@ class FlashAttention(Cell): if self.use_alibi_mask: self.alibi_rescale_mul.shard(((dp, tp, cp, 1), (1,))) - if self.monitor_max_attention_logit: - self.assign_add.shard( + if self.track_max_attention_logit: + self.assign.shard( in_strategy=(layout("tp"), layout("tp")), out_strategy=(layout("tp"),) ) - if self.input_layout == "BNSD": + self.maximum.shard( + in_strategy=(layout("tp"), layout("tp")), + out_strategy=(layout("tp"),) + ) + if self.input_layout == "TND": self.reduce_max.shard( - in_strategy=(layout("None", "tp", "None", "None"),), + in_strategy=(layout("dp_cp", "tp", "None"),), out_strategy=(layout("tp"),) ) - elif self.input_layout == "TND": + else: self.reduce_max.shard( - in_strategy=(layout("None", "tp", "None"),), + in_strategy=(layout("dp", "tp", "cp", "None"),), out_strategy=(layout("tp"),) ) return self diff --git a/mindformers/parallel_core/transformer_config_utils.py b/mindformers/parallel_core/transformer_config_utils.py index 91bb45956..a7bfb7d48 100644 --- a/mindformers/parallel_core/transformer_config_utils.py +++ b/mindformers/parallel_core/transformer_config_utils.py @@ -303,7 +303,7 @@ COMMON_CONFIG_MAPPING = { "hidden_act": "hidden_act", "mask_func_type": "mask_func_type", "param_init_std_rules": "param_init_std_rules", - "monitor_max_attention_logit": "monitor_max_attention_logit", + "track_max_attention_logit": "track_max_attention_logit", ("extend_method", "position_embedding_type"): "position_embedding_type", ("init_method_std", "initializer_range"): "init_method_std", diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 0672c679b..ce98b2fdb 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -645,7 +645,6 @@ class BaseTrainer: if self.lr_scheduler is not None: default_args = {"params": group_params, "learning_rate": self.lr_scheduler} if self.config.optimizer.type == "Muon": - default_args["micro_batch_num"] = self.config.parallel_config.micro_batch_num default_args["model"] = None if not hasattr(self, 'real_model') else self.real_model self.optimizer = build_optim( self.config.optimizer, @@ -662,7 +661,6 @@ class BaseTrainer: default_args = {"params": group_params} if self.config.optimizer.type == "Muon": - default_args["micro_batch_num"] = self.config.parallel_config.micro_batch_num default_args["model"] = None if not hasattr(self, 'real_model') else self.real_model # Build optimizer with fixed learning rate self.optimizer = build_optim( diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index 9042fe166..45fbcec3e 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -272,7 +272,7 @@ class Trainer: self.config = self._config_init(args, task_config) if self.config.get('optimizer') and self.config.optimizer.type == "Muon": - self.config.model.model_config.monitor_max_attention_logit = True + self.config.model.model_config.track_max_attention_logit = True self._reassign_monitor_config() # build parallel config build_parallel_config(self.config) @@ -347,7 +347,7 @@ class Trainer: dump_device_local_norm=bool(monitor_config.get('device_local_norm_format')) ) if monitor_config.max_attention_logit_format: - self.config.model.model_config.monitor_max_attention_logit = True + self.config.model.model_config.track_max_attention_logit = True if monitor_config.local_loss_format: set_context(monitor_local_loss=True) if monitor_config.device_local_loss_format: diff --git a/tests/st/test_ut/base_schema.json b/tests/st/test_ut/base_schema.json index deb984083..38ec577b3 100644 --- a/tests/st/test_ut/base_schema.json +++ b/tests/st/test_ut/base_schema.json @@ -1065,7 +1065,7 @@ "signature": "(use_fused)" }, "mindformers.core.optim.Muon": { - "signature": "(params, learning_rate=0.02, weight_decay=0.1, matched_adamw_rms=0.2, momentum=0.95, nesterov=True, ns_steps=5, adamw_betas=(0.95, 0.95), adamw_eps=1e-08, micro_batch_num=1, qk_clip_threshold=100, model=None, **kwargs)" + "signature": "(params, learning_rate=0.02, weight_decay=0.1, matched_adamw_rms=0.2, momentum=0.95, nesterov=True, ns_steps=5, adamw_betas=(0.95, 0.95), adamw_eps=1e-08, qk_clip_threshold=100, model=None, **kwargs)" }, "mindformers.core.optim.Muon._verify_model": { "signature": "(self, model)" diff --git a/tests/st/test_ut/test_core/test_communication/test_communication.py b/tests/st/test_ut/test_core/test_communication/test_communication.py new file mode 100644 index 000000000..b1c159721 --- /dev/null +++ b/tests/st/test_ut/test_core/test_communication/test_communication.py @@ -0,0 +1,143 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test communication utilities for parallel training.""" + +from unittest.mock import patch + +import pytest + +from mindformers.parallel_core.training_graph import communication +from mindformers.parallel_core.training_graph.communication import ( + get_op_group_name, + get_cp_group_name, + get_dp_group_name, +) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@patch("mindformers.parallel_core.training_graph.communication.create_communication_group") +def test_get_op_group_name_with_mock(mock_create_group): + """ + Feature: get_op_group_name() + Description: Test the get op group name with mock. + Expectation: The get op group name with mock should be correct. + """ + mock_create_group.return_value = "mock_group" + communication.OP_GROUP_NAME.clear() + + # case 0: model_parallel_size > 1 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) + assert result == ("mock_group", [1, 3]) + mock_create_group.assert_called_once_with([1, 3]) + + second_result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) + assert second_result == result + mock_create_group.assert_called_once() + + # case 1: model_parallel_size = 1 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=1) + assert result == ("mock_group", [2, 3]) + + # case 2: model_parallel_size = 4 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=4) + assert result == ("mock_group", [3, 7]) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@patch("mindformers.parallel_core.training_graph.communication.create_communication_group") +def test_get_cp_group_name_with_mock(mock_create_group): + """ + Feature: get_cp_group_name() + Description: Test the get cp group name with mock. + Expectation: The get cp group name with mock should be correct. + + For pp=2, dp=2, cp=2, tp=2 (16 cards): + PP stage 0: ranks 0-7 + PP stage 1: ranks 8-15 (mirrors stage 0) + + Within each PP stage, layout is same as 8-card case: + rank | dp | cp | tp + 0/8 | 0 | 0 | 0 + 1/9 | 0 | 0 | 1 + 2/10 | 0 | 1 | 0 + 3/11 | 0 | 1 | 1 + 4/12 | 1 | 0 | 0 + 5/13 | 1 | 0 | 1 + 6/14 | 1 | 1 | 0 + 7/15 | 1 | 1 | 1 + + CP group: ranks with same (dp, tp), different cp + """ + mock_create_group.return_value = "mock_group" + communication.CP_GROUP_NAME.clear() + + # PP stage 0 (8 cards): dp=2, tp=2, cp=2 + assert get_cp_group_name(rank_id=0, dp=2, tp=2, cp=2) == ("mock_group", [0, 2]) + assert get_cp_group_name(rank_id=1, dp=2, tp=2, cp=2) == ("mock_group", [1, 3]) + assert get_cp_group_name(rank_id=4, dp=2, tp=2, cp=2) == ("mock_group", [4, 6]) + assert get_cp_group_name(rank_id=5, dp=2, tp=2, cp=2) == ("mock_group", [5, 7]) + + # PP stage 1 (16 cards): ranks 8-15 mirror stage 0 + assert get_cp_group_name(rank_id=8, dp=2, tp=2, cp=2) == ("mock_group", [8, 10]) + assert get_cp_group_name(rank_id=9, dp=2, tp=2, cp=2) == ("mock_group", [9, 11]) + assert get_cp_group_name(rank_id=12, dp=2, tp=2, cp=2) == ("mock_group", [12, 14]) + assert get_cp_group_name(rank_id=13, dp=2, tp=2, cp=2) == ("mock_group", [13, 15]) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@patch("mindformers.parallel_core.training_graph.communication.create_communication_group") +def test_get_dp_group_name_with_mock(mock_create_group): + """ + Feature: get_dp_group_name() + Description: Test the get dp group name with mock. + Expectation: The get dp group name with mock should be correct. + + For pp=2, dp=2, cp=2, tp=2 (16 cards): + PP stage 0: ranks 0-7 + PP stage 1: ranks 8-15 (mirrors stage 0) + + Within each PP stage, layout is same as 8-card case: + rank | dp | cp | tp + 0/8 | 0 | 0 | 0 + 1/9 | 0 | 0 | 1 + 2/10 | 0 | 1 | 0 + 3/11 | 0 | 1 | 1 + 4/12 | 1 | 0 | 0 + 5/13 | 1 | 0 | 1 + 6/14 | 1 | 1 | 0 + 7/15 | 1 | 1 | 1 + + DP group: ranks with same (cp, tp), different dp + """ + mock_create_group.return_value = "mock_group" + communication.DP_GROUP_NAME.clear() + + # PP stage 0 (8 cards): dp=2, tp=2, cp=2 + assert get_dp_group_name(rank_id=0, dp=2, tp=2, cp=2) == ("mock_group", [0, 4]) + assert get_dp_group_name(rank_id=1, dp=2, tp=2, cp=2) == ("mock_group", [1, 5]) + assert get_dp_group_name(rank_id=2, dp=2, tp=2, cp=2) == ("mock_group", [2, 6]) + assert get_dp_group_name(rank_id=3, dp=2, tp=2, cp=2) == ("mock_group", [3, 7]) + + # PP stage 1 (16 cards): ranks 8-15 mirror stage 0 + assert get_dp_group_name(rank_id=8, dp=2, tp=2, cp=2) == ("mock_group", [8, 12]) + assert get_dp_group_name(rank_id=9, dp=2, tp=2, cp=2) == ("mock_group", [9, 13]) + assert get_dp_group_name(rank_id=10, dp=2, tp=2, cp=2) == ("mock_group", [10, 14]) + assert get_dp_group_name(rank_id=11, dp=2, tp=2, cp=2) == ("mock_group", [11, 15]) diff --git a/tests/st/test_ut/test_core/test_optim/test_get_op_group.py b/tests/st/test_ut/test_core/test_optim/test_get_op_group.py index 05728ab51..6fe55f4f2 100644 --- a/tests/st/test_ut/test_core/test_optim/test_get_op_group.py +++ b/tests/st/test_ut/test_core/test_optim/test_get_op_group.py @@ -14,18 +14,15 @@ # ============================================================================ """Test get op groups info for GPT model.""" -from unittest.mock import patch - import mindspore as ms import pytest from mindformers import build_context from mindformers.checkpoint.sharded_tensor import build_sharded_tensor -from mindformers.parallel_core.training_graph.base_models.gpt import gpt_model from mindformers.parallel_core.training_graph.base_models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, \ get_gpt_mtp_block_spec -from mindformers.parallel_core.training_graph.base_models.gpt.gpt_model import GPTModel, \ - compute_repeat_num_and_model_parallel_size, get_op_group_name +from mindformers.parallel_core.training_graph.base_models.gpt.gpt_model import GPTModel +from mindformers.parallel_core.training_graph.communication import compute_repeat_num_and_model_parallel_size from mindformers.parallel_core.transformer_config import TransformerConfig @@ -145,34 +142,3 @@ def test_compute_repeat_num_and_model_parallel_size_multiple_axis_error(): sharded_info = build_sharded_info((8, 8), (2, 2)) with pytest.raises(ValueError): compute_repeat_num_and_model_parallel_size(sharded_info, world_size=16, pp=1, op=2) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -@patch("mindformers.parallel_core.training_graph.base_models.gpt.gpt_model.create_communication_group") -def test_get_op_group_name_with_mock(mock_create_group): - """ - Feature: get_op_group_name() - Description: Test the get op group name with mock. - Expectation: The get op group name with mock should be correct. - """ - mock_create_group.return_value = "mock_group" - gpt_model.OP_GROUP_NAME.clear() - - # case 0: model_parallel_size > 1 - result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) - assert result == ("mock_group", [1, 3]) - mock_create_group.assert_called_once_with([1, 3]) - - second_result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) - assert second_result == result - mock_create_group.assert_called_once() - - # case 1: model_parallel_size = 1 - result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=1) - assert result == ("mock_group", [2, 3]) - - # case 2: model_parallel_size = 4 - result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=4) - assert result == ("mock_group", [3, 7]) -- Gitee