diff --git a/mindformers/core/optim/muon.py b/mindformers/core/optim/muon.py index b6a61f672296b96d1e5981ab4bd4ca04843eb051..0b65b44424f58625ac5db7b7d31ab1a7e2ed624d 100644 --- a/mindformers/core/optim/muon.py +++ b/mindformers/core/optim/muon.py @@ -442,9 +442,7 @@ class Muon(Optimizer): def _initialize_op_groups(self, model): """Initialize optimizer parallel groups for parameters.""" - self.ops, self.op_groups = model.get_op_groups_info( - self._parameters, self.op, self.op_group, self.op_in_tp_group - ) + self.ops, self.op_groups = model.get_op_groups_info(self._parameters, self.op) def _create_communication_group(self, rank_list): """ diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py index 3ea7260200ea324cb84dd2e46955807ee3424f6c..f81fab230c2acba5c9c4929985fc49f3d71d04d9 100644 --- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py @@ -15,10 +15,12 @@ """mindformers GPT model""" __all__ = ['GPTModel'] +import hashlib from typing import Literal, Optional, Union import numpy as np import mindspore as ms +from mindspore.communication import create_group, get_group_size, get_rank from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops import auto_generate as aclnn_ops @@ -55,6 +57,60 @@ from mindformers.tools.logger import logger from mindformers.models.utils import get_current_rank_stage, get_model_parameters from mindformers.version_control import get_lazy_inline as lazy_inline from mindformers.core.optim.muon_utils import make_muon_fns +from mindformers.checkpoint.sharded_tensor import ShardedTensor + + +def compute_repeat_num_and_model_parallel_size(sharded_info: ShardedTensor, world_size: int, pp: int, op: int): + """Compute real op size.""" + axis_fragmentations = sharded_info.axis_fragmentations + flag = False + weight_sharded_size = 1 + for axis in axis_fragmentations: + if axis == 1: + continue + if flag: + raise ValueError("Only one axis can be fragmented in Muon optimizer.") + flag = True + weight_sharded_size *= axis + repeat_num = world_size // pp // weight_sharded_size + real_op_size = min(op, repeat_num) + if sharded_info.local_shape[0] % real_op_size != 0: + real_op_size = 1 + return real_op_size, weight_sharded_size + + +def create_communication_group(rank_list): + """ + Create a communication group with a hashed name. + + Args: + rank_list: List of ranks in the communication group + + Returns: + str: The created group name + """ + rank_list_str = "-".join([str(i) for i in rank_list]) + hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48] + group_name = str(hashed) + create_group(group_name, rank_list) + return group_name + + +OP_GROUP_NAME = {} + + +def get_op_group_name(rank_id: int, real_op_size: int, model_parallel_size: int): + """Get op group name.""" + if (rank_id, real_op_size, model_parallel_size) in OP_GROUP_NAME: + return OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] + dp_range = model_parallel_size + op_range = model_parallel_size * real_op_size + rank_start = rank_id % dp_range + rank_id // op_range * op_range + rank_end = rank_start + op_range + rank_list = list(range(rank_start, rank_end, dp_range)) + op_group_name = create_communication_group(rank_list) + OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] = (op_group_name, rank_list) + return op_group_name, rank_list class PreprocessLabelsAndMasks(nn.Cell): @@ -704,6 +760,14 @@ class GPTModel(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Get all sharded state dict.""" + sharded_state_dict = {} + for _, sub_cell in self.cells_and_names(): + if sub_cell != self and hasattr(sub_cell, "sharded_state_dict"): + sharded_state_dict.update(sub_cell.sharded_state_dict()) + return sharded_state_dict + def get_model_parameters(self): """Get current rank trainable parameters in gpt model .""" params = set() @@ -830,7 +894,7 @@ class GPTModel(nn.Cell): tp_dims.append(0) return tuple(tp_dims) - def get_op_groups_info(self, params, op, op_group, op_in_tp_group): + def get_op_groups_info(self, params, op): """Return optimizer parallel group information for each parameter. Args: @@ -849,16 +913,14 @@ class GPTModel(nn.Cell): "self_attention.linear_q_down_proj.weight", "self_attention.linear_kv_up_proj.weight", "self_attention.linear_kv_down_proj.weight", - "eh_proj" - ] - - use_tp_group_list = [ - "mlp.router.weight", - "mlp.shared_experts.linear_fc", - "self_attention.linear_q_down_proj.weight", "eh_proj", + "max_logits_val" ] + sharded_state_dict = self.sharded_state_dict() + world_size = get_group_size() + pp = self.config.pipeline_model_parallel_size + def name_filter(param_name, full_name_list): for full_name in full_name_list: if full_name in param_name: @@ -878,13 +940,22 @@ class GPTModel(nn.Cell): logger.warning( f"Parameter {param.name}: parallel_optimizer was set to False due to the use of Muon optimizer." ) - else: - op_list.append(op) + continue + + # compute real op size + sharded_info = sharded_state_dict.get(param.name) + real_op_size, weight_sharded_size = compute_repeat_num_and_model_parallel_size(sharded_info, world_size, pp, + op) + if real_op_size == 1: + op_list.append(1) + op_groups.append("") + logger.info(f"Parameter {param.name} : No op group.") + continue - if name_filter(param.name, use_tp_group_list): - op_groups.append(op_in_tp_group) - else: - op_groups.append(op_group) + op_list.append(real_op_size) + op_group_name, rank_list = get_op_group_name(get_rank(), real_op_size, weight_sharded_size) + logger.info(f"Parameter {param.name} : Muon op group list is: {rank_list}") + op_groups.append(op_group_name) return tuple(op_list), tuple(op_groups) diff --git a/mindformers/parallel_core/training_graph/tensor_parallel/layers.py b/mindformers/parallel_core/training_graph/tensor_parallel/layers.py index 9979dabc3d40b36546225f6204766d4df2ee8029..5bba189dbb947d5842ef6afc92f161ab19daec18 100644 --- a/mindformers/parallel_core/training_graph/tensor_parallel/layers.py +++ b/mindformers/parallel_core/training_graph/tensor_parallel/layers.py @@ -36,6 +36,7 @@ from mindspore.ops.operations import Morph from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation from mindspore import mint +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.utils.init_method import init_method_zero from mindformers.parallel_core.inference.utils import divide @@ -207,6 +208,28 @@ class VocabParallelEmbedding(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + weight_shape = (self.num_embeddings, self.embedding_dim) + + if self.enable_embedding_tp: + axis_fragmentations = (self.tp, 1) + local_shape = (self.num_embeddings // self.tp, self.embedding_dim) + else: + axis_fragmentations = (1, 1) + local_shape = (self.num_embeddings, self.embedding_dim) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=axis_fragmentations) + return sharded_state_dict + class ColumnParallelLinear(nn.Cell): """Linear layer with column parallelism. @@ -427,6 +450,33 @@ class ColumnParallelLinear(nn.Cell): matmul_in_strategy = ((dp * cp, 1), weight_strategy) self.matmul.shard(in_strategy=matmul_in_strategy) + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + tp = self.config.tensor_model_parallel_size + + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size // tp, self.input_size) + if not self.skip_weight_param_allocation: + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=(tp, 1)) + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_shape=(self.output_size,), + global_offset=(0,), + axis_fragmentations=(1,)) + return sharded_state_dict + class RowParallelLinear(nn.Cell): """Linear layer with row parallelism. @@ -663,6 +713,33 @@ class RowParallelLinear(nn.Cell): matmul_in_strategy = ((dp * cp, tp), weight_strategy) self.matmul.shard(in_strategy=matmul_in_strategy) + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + tp = self.config.tensor_model_parallel_size + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size, self.input_size // tp) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=(1, tp)) + + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_shape=(self.output_size,), + global_offset=(0,), + axis_fragmentations=(1,)) + return sharded_state_dict + class LinearNoTP(ColumnParallelLinear): """Linear layer without tensor parallelism. @@ -712,6 +789,32 @@ class LinearNoTP(ColumnParallelLinear): ) ) + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size, self.input_size) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=(1, 1)) + + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_shape=(self.output_size,), + global_offset=(0,), + axis_fragmentations=(1,)) + return sharded_state_dict + class SequenceParallelLinear(ColumnParallelLinear): """Linear layer without tensor parallelism. @@ -761,3 +864,29 @@ class SequenceParallelLinear(ColumnParallelLinear): layout(("cp", "tp"), "dp", "None"), # output [S, B, H] ) ) + + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size, self.input_size) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_offset=(0, 0), + global_shape=weight_shape, + axis_fragmentations=(1, 1)) + + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_offset=(0,), + global_shape=(self.output_size,), + axis_fragmentations=(1,)) + return sharded_state_dict diff --git a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py index 6ff7e9d6a45a1b243a5504597992bdb6f77f82d8..c4830af1d954f5b79e150fb365cb3a2b6ba1153e 100644 --- a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py +++ b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py @@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import Shape, Cast, GroupedMatmul, Reshape, Swi from mindspore.ops.operations import Morph from mindspore.parallel._utils import _get_parallel_mode +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout from mindformers.parallel_core.training_graph.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher, MoEAlltoAllDeredundencyTokenDispatcher, MoEAlltoAllZeroRedundancyTokenDispatcher from mindformers.parallel_core.transformer_config import TransformerConfig @@ -215,3 +216,27 @@ class FFNGroupedGEMM(nn.Cell): layout(dp, sp, mp0), # output [B, S, h] ) ) + + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + ep = self.config.expert_model_parallel_size + sharded_state_dict = {} + sharded_state_dict[self.weight1.name] = ShardedTensor( + key=self.weight1.name, + org_key=self.weight1.name, + dtype=self.weight1.dtype, + local_shape=(self.num_local_experts // ep * self.hidden_size, self.moe_ffn_hidden_size * 2), + global_shape=(self.num_local_experts * self.hidden_size, self.moe_ffn_hidden_size * 2), + global_offset=(0, 0), + axis_fragmentations=(ep, 1), + ) + sharded_state_dict[self.weight2.name] = ShardedTensor( + key=self.weight2.name, + org_key=self.weight2.name, + dtype=self.weight2.dtype, + local_shape=(self.num_local_experts // ep * self.moe_ffn_hidden_size, self.hidden_size), + global_shape=(self.num_local_experts * self.moe_ffn_hidden_size, self.hidden_size), + global_offset=(0, 0), + axis_fragmentations=(ep, 1), + ) + return sharded_state_dict diff --git a/mindformers/parallel_core/training_graph/transformer/moe/router.py b/mindformers/parallel_core/training_graph/transformer/moe/router.py index 87f682e16d51515ed43442f9a31f734bf6e9bf73..c9bbf817ad35453a120ee789f4dce1b18c62672f 100644 --- a/mindformers/parallel_core/training_graph/transformer/moe/router.py +++ b/mindformers/parallel_core/training_graph/transformer/moe/router.py @@ -26,6 +26,7 @@ from mindspore.ops.auto_generate import AddExt, AssignAdd, Cast, Div, Mul, Resha from mindspore.ops.operations import Shape, ReduceSum, ReduceMean from mindspore.parallel._utils import _get_parallel_mode +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.tools.utils import get_real_group_size, get_real_rank @@ -117,6 +118,20 @@ class Router(ABC, nn.Cell): router_logits = self.linear(inputs.astype(self.moe_router_dtype), weight) return router_logits + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=(self.expert_dim, self.hidden_size), + global_shape=(self.expert_dim, self.hidden_size), + global_offset=(0, 0), + axis_fragmentations=(1, 1), + ) + return sharded_state_dict + class TopKRouter(Router): """Route each token to the top-k experts.""" diff --git a/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py b/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py index e00c7f27e09c930341aca554cd3a390d5ed15328..09ff7100a9bea2ae3c0229cb888d962d744faee4 100644 --- a/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py +++ b/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py @@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import Cast, Mul, Sigmoid from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation from mindspore.context import ParallelMode +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.training_graph.transformer.mlp import MLP, MLPSubmodules, MLPInterleaved from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.training_graph.device_matrix import layout @@ -108,6 +109,20 @@ class SharedExpertMLP(MLP): def expert_sharding_propagation(self, config: TransformerConfig): super().sharding_propagation(config) + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + sharded_state_dict[self.shared_experts_gate.weight.name] = ShardedTensor( + key=self.shared_experts_gate.weight.name, + org_key=self.shared_experts_gate.weight.name, + dtype=self.shared_experts_gate.weight.dtype, + local_shape=(1, self.hidden_size), + global_shape=(1, self.hidden_size), + global_offset=(0, 0), + axis_fragmentations=(1, 1), + ) + return sharded_state_dict + class SharedExpertMLPInterleaved(MLPInterleaved): """ diff --git a/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py b/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py index a04f0ff112dcca6da4d36394d1291c16b4bfa3a5..e0dde20f18f66be5026d48ef73900fe78d7ac9f1 100644 --- a/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py +++ b/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py @@ -541,6 +541,9 @@ class MtpSharedVocabParallelEmbedding(VocabParallelEmbedding): output = self.embedding_morph(input_ids, weight) return output + def sharded_state_dict(self): + return {} + class MtpSharedLanguageModelEmbedding(LanguageModelEmbedding): """Embedding layer used in Multi-Token Prediction module, same to standard LanguageModelEmbedding.""" diff --git a/mindformers/parallel_core/training_graph/transformer/norm.py b/mindformers/parallel_core/training_graph/transformer/norm.py index 1003a2f760296908ae11adad134f3d3225fc595f..c9a2081f0b331c818cf0b69e9c016f9118125cf1 100644 --- a/mindformers/parallel_core/training_graph/transformer/norm.py +++ b/mindformers/parallel_core/training_graph/transformer/norm.py @@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import MeanExt, Sqrt, Rsqrt, SubExt, AddExt, Mu from mindspore.common.initializer import initializer from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.training_graph.device_matrix import layout @@ -55,6 +56,7 @@ class LayerNorm(nn.Cell): super().__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype + self.dim = dim self.gamma = Parameter(initializer('ones', dim, self.params_dtype), name="gamma", parallel_optimizer=False) @@ -115,6 +117,29 @@ class LayerNorm(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for LayerNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.gamma.name] = ShardedTensor( + key=self.gamma.name, + org_key=self.gamma.name, + dtype=self.gamma.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + sharded_state_dict[self.beta.name] = ShardedTensor( + key=self.beta.name, + org_key=self.beta.name, + dtype=self.beta.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class FusedLayerNorm(nn.Cell): """ @@ -136,7 +161,7 @@ class FusedLayerNorm(nn.Cell): super().__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype - + self.dim = dim self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=eps) @@ -177,6 +202,29 @@ class FusedLayerNorm(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for FusedLayerNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.gamma.name] = ShardedTensor( + key=self.gamma.name, + org_key=self.gamma.name, + dtype=self.gamma.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + sharded_state_dict[self.beta.name] = ShardedTensor( + key=self.beta.name, + org_key=self.beta.name, + dtype=self.beta.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class RMSNorm(nn.Cell): """ @@ -198,7 +246,7 @@ class RMSNorm(nn.Cell): super().__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype - + self.dim = dim self.eps = eps self.weight = Parameter(initializer('ones', (dim), self.params_dtype)) @@ -248,6 +296,20 @@ class RMSNorm(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for RMSNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class FusedRMSNorm(nn.Cell): """ @@ -269,7 +331,7 @@ class FusedRMSNorm(nn.Cell): super().__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype - + self.dim = dim self.eps = eps self.weight = Parameter(initializer('ones', (dim), self.params_dtype)) @@ -293,11 +355,25 @@ class FusedRMSNorm(nn.Cell): if in_strategy: self.norm.shard(in_strategy) else: - self.norm.shard((layout("cp", "dp", "None"), layout("None",))) + self.norm.shard((layout("cp", "dp", "None"), layout("None", ))) def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for FusedRMSNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class Norm: """ diff --git a/mindformers/parallel_core/utils/model_mixin.py b/mindformers/parallel_core/utils/model_mixin.py index 08861a4224edb4fa7a04943f1c5b2e54d4c3538a..0aed37810fbb577bce3616a4beb8ebb304b3c4a2 100644 --- a/mindformers/parallel_core/utils/model_mixin.py +++ b/mindformers/parallel_core/utils/model_mixin.py @@ -577,10 +577,10 @@ class TrainModelMixin: model = self.check_and_get_model() return model.get_tp_dims(parameters) - def get_op_groups_info(self, parameters, op_size, tp_group, op_group): + def get_op_groups_info(self, parameters, op_size): """Get operation groups information for parameters.""" model = self.check_and_get_model() - return model.get_op_groups_info(parameters, op_size, tp_group, op_group) + return model.get_op_groups_info(parameters, op_size) def get_parallel_config_for_muon(self): """Get parallel configuration for Muon optimizer.""" diff --git a/tests/st/test_ut/test_core/test_optim/test_get_op_group.py b/tests/st/test_ut/test_core/test_optim/test_get_op_group.py new file mode 100644 index 0000000000000000000000000000000000000000..05728ab5128650e4f695680300cd3be40df2cc95 --- /dev/null +++ b/tests/st/test_ut/test_core/test_optim/test_get_op_group.py @@ -0,0 +1,178 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test get op groups info for GPT model.""" + +from unittest.mock import patch + +import mindspore as ms +import pytest + +from mindformers import build_context +from mindformers.checkpoint.sharded_tensor import build_sharded_tensor +from mindformers.parallel_core.training_graph.base_models.gpt import gpt_model +from mindformers.parallel_core.training_graph.base_models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, \ + get_gpt_mtp_block_spec +from mindformers.parallel_core.training_graph.base_models.gpt.gpt_model import GPTModel, \ + compute_repeat_num_and_model_parallel_size, get_op_group_name +from mindformers.parallel_core.transformer_config import TransformerConfig + + +def build_transformer_config() -> TransformerConfig: + """Create a minimal transformer config for tensor-parallel unit tests.""" + return TransformerConfig( + data_parallel_size=1, + pipeline_model_parallel_size=1, + tensor_model_parallel_size=1, + # model architecture + vocab_size=1024, + position_embedding_type="rope", + num_attention_heads=2, + num_layers=2, + hidden_size=128, + ffn_hidden_size=512, + # moe architecture + num_moe_experts=4, + first_k_dense_replace=1, + mtp_num_layers=1, + add_bias_linear=False, + moe_grouped_gemm=True + ) + + +def build_gpt_model(): + """Construct a GPTModel instance with the default test configuration.""" + config = build_transformer_config() + transformer_layer_spec = get_gpt_decoder_block_spec(config) + mtp_block_spec = None + if config.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec) + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=config.vocab_size, + max_sequence_length=config.max_position_embeddings, + position_embedding_type=config.position_embedding_type, + rotary_percent=1.0, + rotary_base=config.rotary_base, + rope_scaling=False, + mtp_block_spec=mtp_block_spec + ) + return model + + +def build_sharded_info(local_shape, axis_fragmentations): + """Helper to create a simple ShardedTensor descriptor.""" + return build_sharded_tensor( + param_name="test", + param_dtype=ms.float32, + local_shape=local_shape, + global_shape=local_shape, + axis_fragmentations=axis_fragmentations, + global_offset=(0,) * len(local_shape), + ) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gpt_model_sharded_state_dict(): + """ + Feature: GPTModel + Description: Test the sharded state dict of GPT model. + Expectation: The sharded state dict has all the trainable parameters and the shape is correct. + """ + build_context({"use_legacy": False}) + model = build_gpt_model() + sharded_state_dict = model.sharded_state_dict() + + params = model.trainable_params() + for param in params: + assert param.name in sharded_state_dict + assert param.shape == sharded_state_dict[param.name].global_shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize( + "axis_fragmentations, world_size, pipeline_parallel, opt_group_size, local_shape, expected", + [ + # case 0: real_op_size == opt_group_size + ((1, 1), 12, 2, 4, (12, 4), (4, 1)), + # case 1: real_op_size < opt_group_size + ((2, 1), 16, 2, 8, (12, 4), (4, 2)), + # case 2: real_op_size = 1 due to local shape not divisible by real_op_size + ((4, 1), 32, 2, 4, (10, 4), (1, 4)), + ], +) +def test_compute_repeat_num_and_model_parallel_size(axis_fragmentations, world_size, pipeline_parallel, + opt_group_size, local_shape, expected): + """ + Feature: compute_repeat_num_and_model_parallel_size() + Description: Test the compute repeat num and model parallel size. + Expectation: The compute repeat num and model parallel size should be correct. + """ + sharded_info = build_sharded_info(local_shape, axis_fragmentations) + assert compute_repeat_num_and_model_parallel_size( + sharded_info, + world_size=world_size, + pp=pipeline_parallel, + op=opt_group_size, + ) == expected + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_compute_repeat_num_and_model_parallel_size_multiple_axis_error(): + """ + Feature: compute_repeat_num_and_model_parallel_size() + Description: Test the error of compute repeat num and model parallel size. + Expectation: The ValueError should be raised. + """ + sharded_info = build_sharded_info((8, 8), (2, 2)) + with pytest.raises(ValueError): + compute_repeat_num_and_model_parallel_size(sharded_info, world_size=16, pp=1, op=2) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@patch("mindformers.parallel_core.training_graph.base_models.gpt.gpt_model.create_communication_group") +def test_get_op_group_name_with_mock(mock_create_group): + """ + Feature: get_op_group_name() + Description: Test the get op group name with mock. + Expectation: The get op group name with mock should be correct. + """ + mock_create_group.return_value = "mock_group" + gpt_model.OP_GROUP_NAME.clear() + + # case 0: model_parallel_size > 1 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) + assert result == ("mock_group", [1, 3]) + mock_create_group.assert_called_once_with([1, 3]) + + second_result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) + assert second_result == result + mock_create_group.assert_called_once() + + # case 1: model_parallel_size = 1 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=1) + assert result == ("mock_group", [2, 3]) + + # case 2: model_parallel_size = 4 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=4) + assert result == ("mock_group", [3, 7]) diff --git a/tests/st/test_ut/test_model_mixin.py b/tests/st/test_ut/test_model_mixin.py index 30bc1338c68baa0e24adef67908c2f2ec19be14d..294a4995ef809aafc9f2025a3cd15a2186c92204 100644 --- a/tests/st/test_ut/test_model_mixin.py +++ b/tests/st/test_ut/test_model_mixin.py @@ -480,7 +480,7 @@ class TestTrainModelMixin: # Create a mock model with get_op_groups_info method class MockModel: # pylint: disable=W0613 - def get_op_groups_info(self, parameters, op_size, tp_group, op_group): + def get_op_groups_info(self, parameters, op_size): return f"info_{op_size}" class TestModel(TrainModelMixin): @@ -489,7 +489,7 @@ class TestTrainModelMixin: self.model = MockModel() mixin = TestModel() - assert mixin.get_op_groups_info(None, 2, None, None) == "info_2" + assert mixin.get_op_groups_info(None, 2) == "info_2" @pytest.mark.level0 @pytest.mark.platform_x86_cpu