From 0b3a3b5576361b510cdc5a5b8f870f71e2548995 Mon Sep 17 00:00:00 2001 From: zyw_hw Date: Mon, 1 Dec 2025 18:01:46 +0800 Subject: [PATCH] fix huge cc --- mindformers/__init__.py | 1 - mindformers/core/optim/__init__.py | 77 ++--- mindformers/core/optim/fused_pma_adamw.py | 4 +- mindformers/modules/__init__.py | 1 - mindformers/modules/layers.py | 58 ++-- mindformers/modules/quantizers/base.py | 6 +- .../modules/quantizers/ptq_quantizer.py | 12 +- .../modules/quantizers/rtn_quantizer.py | 11 - mindformers/modules/transformer/__init__.py | 1 - .../modules/transformer/transformer.py | 299 ++---------------- .../tensor_parallel/grouped_layers.py | 53 ++-- mindformers/parallel_core/inference/utils.py | 20 +- .../base_models/gpt/gpt_model.py | 23 +- .../parallel_core/training_graph/loss_func.py | 49 ++- .../training_graph/transformer/norm.py | 19 +- .../training_graph/transformer/utils.py | 2 +- mindformers/wrapper/wrapper.py | 2 +- tests/st/test_ut/test_api_compatibility.py | 17 +- tests/st/test_ut/test_transformer_apis.py | 20 +- 19 files changed, 173 insertions(+), 502 deletions(-) diff --git a/mindformers/__init__.py b/mindformers/__init__.py index 021c16e5b..3951c5fc6 100644 --- a/mindformers/__init__.py +++ b/mindformers/__init__.py @@ -157,7 +157,6 @@ from mindformers.modules import ( AlibiTensorV2, Dropout, EmbeddingOpParallelConfig, - FeedForward, FixedSparseAttention, LayerNorm, Linear, diff --git a/mindformers/core/optim/__init__.py b/mindformers/core/optim/__init__.py index 74b46a495..7f8741aa6 100644 --- a/mindformers/core/optim/__init__.py +++ b/mindformers/core/optim/__init__.py @@ -28,47 +28,50 @@ __all__ = ['AdamW', 'PmaAdamW', 'Muon'] @MindFormerRegister.register(MindFormerModuleType.OPTIMIZER) class AdamW: - r""" + """ This is the implementation of AdamW. .. 
math:: - \begin{array}{l} - &\newline - &\hline \\ - &\textbf{Parameters}: \: 1^{\text {st }}\text {moment vector} \: m , \: 2^{\text {nd}} \: - \text{moment vector} \: v , \\ - &\: gradients \: g, \: \text{learning rate} \: \gamma, - \text {exponential decay rates for the moment estimates} \: \beta_{1} \: \beta_{2} , \\ - &\:\text {parameter vector} \: w_{0}, \:\text{timestep} \: t, \: \text{weight decay} \: \lambda \\ - &\textbf{Init}: m_{0} \leftarrow 0, \: v_{0} \leftarrow 0, \: t \leftarrow 0, \: - \text{init parameter vector} \: w_{0} \\[-1.ex] - &\newline - &\hline \\ - &\textbf{repeat} \\ - &\hspace{5mm} t \leftarrow t+1 \\ - &\hspace{5mm}\boldsymbol{g}_{t} \leftarrow \nabla f_{t}\left(\boldsymbol{w}_{t-1}\right) \\ - &\hspace{5mm}\boldsymbol{w}_{t} \leftarrow \boldsymbol{w}_{t-1}-\gamma\lambda\boldsymbol{w}_{t-1} \\ - &\hspace{5mm}\boldsymbol{m}_{t} \leftarrow \beta_{1} \boldsymbol{m}_{t-1}+\left(1-\beta_{1}\right) - \boldsymbol{g}_{t} \\ - &\hspace{5mm}\boldsymbol{v}_{t} \leftarrow \beta_{2} \boldsymbol{v}_{t-1}+\left(1-\beta_{2}\right) - \boldsymbol{g}_{t}^{2} \\ - &\hspace{5mm}\widehat{\boldsymbol{m}_{t}} \leftarrow \boldsymbol{m}_{t}/\big(1-\beta_{1}^{t} \big) \\ - &\hspace{5mm}\widehat{\boldsymbol{v}_{t}} \leftarrow \boldsymbol{v}_{t}/\big(1-\beta_{2}^{t} \big) \\ - &\hspace{5mm}\boldsymbol{w}_{t} \leftarrow \boldsymbol{w}_{t-1}-\gamma\widehat{\boldsymbol{m}_{t}} - /\left(\sqrt{\widehat{\boldsymbol{v}_{t}}}+\epsilon\right) \\ - &\textbf{until}\text { stopping criterion is met } \\[-1.ex] - &\newline - &\hline \\[-1.ex] - &\textbf{return} \: \boldsymbol{w}_{t} \\[-1.ex] - &\newline - &\hline \\[-1.ex] - \end{array} + \\begin{array}{l} + &\\newline + &\\hline \\\\ + &\\textbf{Parameters}: \\: 1^{\\text {st }}\\text {moment vector} \\: m , \\: 2^{\\text {nd}} \\: + \\text{moment vector} \\: v , \\\\ + &\\: gradients \\: g, \\: \\text{learning rate} \\: \\gamma, + \\text {exponential decay rates for the moment estimates} \\: \\beta_{1} \\: \\beta_{2} , \\\\ + &\\:\\text {parameter vector} \\: w_{0}, \\:\\text{timestep} \\: t, \\: \\text{weight decay} \\: \\lambda \\\\ + &\\textbf{Init}: m_{0} \\leftarrow 0, \\: v_{0} \\leftarrow 0, \\: t \\leftarrow 0, \\: + \\text{init parameter vector} \\: w_{0} \\\\[-1.ex] + &\\newline + &\\hline \\\\ + &\\textbf{repeat} \\\\ + &\\hspace{5mm} t \\leftarrow t+1 \\\\ + &\\hspace{5mm}\\boldsymbol{g}_{t} \\leftarrow \\nabla f_{t}\\left(\\boldsymbol{w}_{t-1}\\right) \\\\ + &\\hspace{5mm}\\boldsymbol{w}_{t} \\leftarrow \\boldsymbol{w}_{t-1}-\\gamma\\lambda + \\boldsymbol{w}_{t-1} \\\\ + &\\hspace{5mm}\\boldsymbol{m}_{t} \\leftarrow \\beta_{1} \\boldsymbol{m}_{t-1}+\\left(1-\\beta_{1}\\right) + \\boldsymbol{g}_{t} \\\\ + &\\hspace{5mm}\\boldsymbol{v}_{t} \\leftarrow \\beta_{2} \\boldsymbol{v}_{t-1}+\\left(1-\\beta_{2}\\right) + \\boldsymbol{g}_{t}^{2} \\\\ + &\\hspace{5mm}\\widehat{\\boldsymbol{m}_{t}} \\leftarrow \\boldsymbol{m}_{t}/ + \\big(1-\\beta_{1}^{t} \\big) \\\\ + &\\hspace{5mm}\\widehat{\\boldsymbol{v}_{t}} \\leftarrow \\boldsymbol{v}_{t}/ + \\big(1-\\beta_{2}^{t} \\big) \\\\ + &\\hspace{5mm}\\boldsymbol{w}_{t} \\leftarrow \\boldsymbol{w}_{t-1}-\\gamma\\widehat{\\boldsymbol{m}_{t}} + /\\left(\\sqrt{\\widehat{\\boldsymbol{v}_{t}}}+\\epsilon\\right) \\\\ + &\\textbf{until}\\text { stopping criterion is met } \\\\[-1.ex] + &\\newline + &\\hline \\\\[-1.ex] + &\\textbf{return} \\: \\boldsymbol{w}_{t} \\\\[-1.ex] + &\\newline + &\\hline \\\\[-1.ex] + \\end{array} :math:`m` represents the first moment vector moment1, :math:`v` represents the 
second moment vector moment2,
-    :math:`\widehat{m}` represents the bias-corrected first moment vector, :math:`\widehat{v}` represents
-    the bias-corrected second moment vector, :math:`g` represents gradients, :math:`\gamma` represents
-    learning_rate, :math:`\beta_1`, `\beta_2` represent beta1 and beta2, :math:`t` represents the current step,
-    :math:`w` represents params, and :math:`\lambda` represents weight_decay.
+    :math:`\\widehat{m}` represents the bias-corrected first moment vector, :math:`\\widehat{v}` represents
+    the bias-corrected second moment vector, :math:`g` represents gradients, :math:`\\gamma` represents
+    learning_rate, :math:`\\beta_1`, :math:`\\beta_2` represent beta1 and beta2, :math:`t` represents the
+    current step, :math:`w` represents params, and :math:`\\lambda` represents weight_decay.
 
     Args:
         params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
@@ -218,7 +221,7 @@ class AdamW:
 
 @MindFormerRegister.register(MindFormerModuleType.OPTIMIZER)
 class PmaAdamW:
-    r"""
+    """
     This is the implementation of PmAdamW.
 
     Args:
diff --git a/mindformers/core/optim/fused_pma_adamw.py b/mindformers/core/optim/fused_pma_adamw.py
index a64e5eff4..0ed11c373 100644
--- a/mindformers/core/optim/fused_pma_adamw.py
+++ b/mindformers/core/optim/fused_pma_adamw.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """FusedPmaAdamW implementation"""
-import mindspore.ops as ops
+from mindspore import ops
 
 from mindspore._checkparam import GT, INC_NEITHER
 from mindspore import _checkparam as validator
@@ -76,7 +76,7 @@ def _check_param_value(fused_num, interleave_step, fused_algo, ema_alpha, prim_n
 
 
 class FusedPmaAdamW(FusedAdamW):
-    r"""
+    """
     This is the implementation of PmaAdamW that uses fused operators.
 
     Args:
diff --git a/mindformers/modules/__init__.py b/mindformers/modules/__init__.py
index a7f18cf6d..1312007e5 100644
--- a/mindformers/modules/__init__.py
+++ b/mindformers/modules/__init__.py
@@ -15,7 +15,6 @@
 """MindFormers Transformers API."""
 from .transformer import (
     EmbeddingOpParallelConfig,
-    FeedForward,
     LowerTriangularMaskWithDynamic,
     MoEConfig,
     OpParallelConfig,
diff --git a/mindformers/modules/layers.py b/mindformers/modules/layers.py
index fc0ef5667..a695d6039 100644
--- a/mindformers/modules/layers.py
+++ b/mindformers/modules/layers.py
@@ -49,6 +49,8 @@ from mindformers.tools.logger import logger
 from mindformers.tools.utils import is_pynative
 from mindformers.modules.activation import get_activation
 from mindformers.modules.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
+from mindformers.parallel_core.training_graph.base_models.common.embeddings.yarn_rotary_pos_embedding import \
+    _yarn_find_correction_range
 
 __all__ = [
     "FixedSparseAttention",
@@ -177,7 +179,6 @@ class _LayerInputCheck:
     Check the input shape's is equal to the expected shape, the value on 0-th is viewed as batch, and the
     batch size will not be checked.
""" - target_shape = target_shape length, hidden = target_shape if isinstance(input_shape, tuple): input_shape = list(input_shape) @@ -244,11 +245,9 @@ class Dropout(nn.Cell): """ def __init__(self, keep_prob=0.5, dtype=mstype.float32): - super(Dropout, self).__init__() + super().__init__() if keep_prob <= 0 or keep_prob > 1: - raise ValueError( - "dropout probability should be a number in range (0, 1], but got {}".format( - keep_prob)) + raise ValueError(f"dropout probability should be a number in range (0, 1], but got {keep_prob}") Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) self.keep_prob = keep_prob @@ -269,7 +268,7 @@ class Dropout(nn.Cell): return out def extend_repr(self): - return 'keep_prob={}'.format(self.keep_prob) + return f'keep_prob={self.keep_prob}' def shard(self, strategy): self.dropout.shard(strategy) @@ -291,10 +290,10 @@ class LayerNorm(Cell): """ def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32, is_self_defined=False): - super(LayerNorm, self).__init__() + super().__init__() if param_init_type not in [mstype.float32, mstype.float16, mstype.bfloat16]: - raise TypeError("The type of parameter 'param_init_type' should in [float32, float16], " - "but got the type : {}.".format(type(param_init_type))) + raise TypeError(f"The type of parameter 'param_init_type' should in [float32, float16], " + f"but got the type : {type(param_init_type)}.") # Since the mindspore 1.10 version, the layernorm has been changed to P.LayerNorm self.is_self_defined = is_self_defined if not self.is_self_defined: @@ -441,7 +440,7 @@ class Linear(Cell): use_gmm=False, param_init_type=mstype.float32, compute_dtype=mstype.float16): - super(Linear, self).__init__() + super().__init__() self.in_channels = in_channels self.out_channels = out_channels if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)): @@ -465,6 +464,7 @@ class Linear(Cell): self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type), name="weight") if self.use_gmm: + # pylint: disable=import-outside-toplevel from mindspore.ops.auto_generate import GroupedMatmul # split_item only supports 0 and 3 now, 0 means the size of tensorlist not equal to 1, # 3 means the size of tensorlist is 1. 
@@ -676,7 +676,7 @@ class FixedSparseAttention(nn.Cell): seq_length=1024, num_different_global_patterns=4, parallel_config=default_dpmp_config): - super(FixedSparseAttention, self).__init__() + super().__init__() dp, mp = parallel_config.data_parallel, parallel_config.model_parallel if num_heads % mp != 0: raise ValueError(f"The number of heads {num_heads} must be a " @@ -700,17 +700,17 @@ class FixedSparseAttention(nn.Cell): self.parallel_config = parallel_config size_per_head_list = [64, 128] if self.seq_length != 1024: - raise ValueError("For 'FixedSparseAttention', the class variable 'seq_length' must be 1024, " - "but got the value : {}.".format(seq_length)) + raise ValueError(f"For 'FixedSparseAttention', the class variable 'seq_length' must be 1024, " + f"but got the value : {seq_length}.") if self.block_size != 64: - raise ValueError("For 'FixedSparseAttention', the class variable 'block_size' must be 64, " - "but got the value : {}.".format(block_size)) + raise ValueError(f"For 'FixedSparseAttention', the class variable 'block_size' must be 64, " + f"but got the value : {block_size}.") if num_different_global_patterns != 4: - raise ValueError("For 'FixedSparseAttention', the class variable 'num_different_global_patterns' " - "must be 4, but got the value : {}".format(num_different_global_patterns)) + raise ValueError(f"For 'FixedSparseAttention', the class variable 'num_different_global_patterns' " + f"must be 4, but got the value : {num_different_global_patterns}") if self.size_per_head not in size_per_head_list: - raise ValueError("For 'FixedSparseAttention', the class variable 'size_per_head' only supports {}, " - "but got the value : {}.".format(size_per_head_list, self.size_per_head)) + raise ValueError(f"For 'FixedSparseAttention', the class variable 'size_per_head' " + f"only supports {size_per_head_list}, but got the value : {self.size_per_head}.") local_ones = np.ones((self.block_size, self.block_size), dtype=np.float16) global_mask_original = np.ones((self.seq_length, self.global_size), dtype=np.float16) @@ -851,7 +851,7 @@ class AlibiTensor(nn.Cell): """ def __init__(self, seq_length, num_heads, parallel_config=default_dpmp_config): - super(AlibiTensor, self).__init__() + super().__init__() dp = parallel_config.data_parallel self.seq_length = seq_length @@ -915,7 +915,7 @@ class AlibiTensorV2(nn.Cell): """ def __init__(self, num_heads): - super(AlibiTensorV2, self).__init__() + super().__init__() self.num_heads = num_heads self.expand_2d = P.ExpandDims() @@ -1124,22 +1124,6 @@ def _check_linear_scaling_factor(scaling_factor): raise ValueError(f"`scaling_factor`'s factor field must be a float >= 1, got {factor}") -def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - """Inverse dim formula to find dim based on number of rotations""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - -def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - """Find dim range bounds based on rotations""" - low = math.floor( - _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - def _yarn_get_mscale(scale=1, mscale=1): if scale <= 1: return 1.0 diff --git a/mindformers/modules/quantizers/base.py b/mindformers/modules/quantizers/base.py index 
6d17e6639..2efb9f5ce 100644
--- a/mindformers/modules/quantizers/base.py
+++ b/mindformers/modules/quantizers/base.py
@@ -221,14 +221,11 @@ class Quantizer(ABC):
 
-    @abstractmethod
     def _process_model_after_weight_loading(self, model, **kwargs):
-        pass
+        return model
 
     @property
-    @abstractmethod
     def is_serializable(self):
-        pass
+        return False
 
     @property
-    @abstractmethod
     def is_trainable(self):
-        pass
+        return False
diff --git a/mindformers/modules/quantizers/ptq_quantizer.py b/mindformers/modules/quantizers/ptq_quantizer.py
index 832cf6aba..2b441c1e5 100644
--- a/mindformers/modules/quantizers/ptq_quantizer.py
+++ b/mindformers/modules/quantizers/ptq_quantizer.py
@@ -52,19 +52,9 @@ class PtqQuantizer(Quantizer):
     def _process_model_before_weight_loading(
             self, model: "PreTrainedModel", **kwargs
     ):
+        # pylint: disable=import-outside-toplevel
         from mindspore_gs.ptq import PTQ
         ptq = PTQ(config=self.quant_config, layer_policies=self.layer_policies)
         model = ptq.apply(model)
         model = ptq.convert(model)
         return model
-
-    def _process_model_after_weight_loading(self, model, **kwargs):
-        return model
-
-    @property
-    def is_serializable(self):
-        return False
-
-    @property
-    def is_trainable(self):
-        return False
diff --git a/mindformers/modules/quantizers/rtn_quantizer.py b/mindformers/modules/quantizers/rtn_quantizer.py
index 36486ae8d..f0d063050 100644
--- a/mindformers/modules/quantizers/rtn_quantizer.py
+++ b/mindformers/modules/quantizers/rtn_quantizer.py
@@ -65,14 +65,3 @@ class RtnQuantizer(Quantizer):
         model = ptq.apply(model)
         model = ptq.convert(model)
         return model
-
-    def _process_model_after_weight_loading(self, model, **kwargs):
-        return model
-
-    @property
-    def is_serializable(self):
-        return False
-
-    @property
-    def is_trainable(self):
-        return False
diff --git a/mindformers/modules/transformer/__init__.py b/mindformers/modules/transformer/__init__.py
index 7872e84d5..68a4c3ba4 100644
--- a/mindformers/modules/transformer/__init__.py
+++ b/mindformers/modules/transformer/__init__.py
@@ -21,7 +21,6 @@
 This is an experimental interface that is subject to change or deletion.
from .transformer import ( EmbeddingOpParallelConfig, - FeedForward, LowerTriangularMaskWithDynamic, TransformerOpParallelConfig, TransformerRecomputeConfig, diff --git a/mindformers/modules/transformer/transformer.py b/mindformers/modules/transformer/transformer.py index aa04d8682..02e076aed 100644 --- a/mindformers/modules/transformer/transformer.py +++ b/mindformers/modules/transformer/transformer.py @@ -26,7 +26,6 @@ import mindspore as ms from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.common.initializer import Zero -from mindspore import nn import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.ops import functional as F @@ -40,18 +39,14 @@ except ImportError: from mindspore import log as logger from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation from mindspore.context import ParallelMode -from mindformers.modules.layers import Linear, _args_type_validator_check, _valid_type_checks, _valid_value_checks, \ - _check_input_dtype -from mindformers.modules.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \ - _Config, _check_config, MoEParallelConfig -from mindformers.version_control import get_dropout +from mindformers.modules.transformer.op_parallel_config import _PipeLineConfig, OpParallelConfig, \ + _Config, MoEParallelConfig from mindformers.tools.logger import _LogActionOnce from mindformers.tools.utils import is_pynative __all__ = [ "LowerTriangularMaskWithDynamic", - "FeedForward", "TransformerOpParallelConfig", "EmbeddingOpParallelConfig", "TransformerRecomputeConfig", @@ -211,7 +206,7 @@ class TransformerSwapConfig(_Config): if isinstance(layer_swap, dict): layer_swap = [layer_swap] if self._validate_layers_consistency(layer_swap): - return [dict(backward_prefetch=layer_swap[0][self.backward_prefetch], layers=True)] + return [{"backward_prefetch": layer_swap[0][self.backward_prefetch], "layers": True}] return layer_swap def _initialize_op_swap(self, op_swap): @@ -225,7 +220,7 @@ class TransformerSwapConfig(_Config): op_swap_dict = self.op_swap_to_dict(op_swap) for k, v in op_swap_dict.items(): if self._validate_layers_consistency(v, mode=f'op_swap: {k}'): - op_swap_dict[k] = [dict(backward_prefetch=v[0][self.backward_prefetch], layers=True)] + op_swap_dict[k] = [{"backward_prefetch": v[0][self.backward_prefetch], "layers": True}] return op_swap_dict def _validate_layers_consistency(self, layer_swap, mode='layer_swap'): @@ -283,17 +278,17 @@ class TransformerSwapConfig(_Config): """Adds an operation swap configuration to the dictionary.""" if key in dic: dic[key].append( - dict( - layers=item.get(self.layers), - backward_prefetch=item.get(self.backward_prefetch) - ) + { + 'layers': item.get(self.layers), + 'backward_prefetch': item.get(self.backward_prefetch) + } ) else: dic[key] = [ - dict( - layers=item.get(self.layers), - backward_prefetch=item.get(self.backward_prefetch) - ) + { + 'layers': item.get(self.layers), + 'backward_prefetch': item.get(self.backward_prefetch) + } ] return dic @@ -507,9 +502,9 @@ class ContextParallelAlgo(Enum): Args: Enum (str): chosses context parallel type """ - colossalai_cp = "colossalai_cp" - ulysses_cp = "ulysses_cp" - hybrid_cp = "hybrid_cp" + COLOSSALAI_CP = "colossalai_cp" + ULYSSES_CP = "ulysses_cp" + HYBRID_CP = "hybrid_cp" default_transformer_swap_config = TransformerSwapConfig() @@ -601,7 +596,7 @@ class TransformerOpParallelConfig(_Config): ValueError: in 
hybrid_cp algorithm, context_parallel should be divisible by ulysses_degree_in_cp """ if self.context_parallel == 1: - if self.context_parallel_algo != ContextParallelAlgo.colossalai_cp: + if self.context_parallel_algo != ContextParallelAlgo.COLOSSALAI_CP: logger.warning(f"context_parallel_algo {self.context_parallel_algo.value} will not take effect " "when context_parallel == 1.") if self.ulysses_degree_in_cp > 1: @@ -610,10 +605,10 @@ class TransformerOpParallelConfig(_Config): return # here context parallel > 1 - if self.context_parallel_algo != ContextParallelAlgo.hybrid_cp and self.ulysses_degree_in_cp > 1: + if self.context_parallel_algo != ContextParallelAlgo.HYBRID_CP and self.ulysses_degree_in_cp > 1: logger.warning(f"ulysses_degree_in_cp {self.ulysses_degree_in_cp} will not take effect when " f"context_parallel_algo {self.context_parallel_algo.value} is not `hybrid_cp`.") - if (self.context_parallel_algo == ContextParallelAlgo.hybrid_cp and + if (self.context_parallel_algo == ContextParallelAlgo.HYBRID_CP and self.context_parallel % self.ulysses_degree_in_cp != 0): raise ValueError(f"When using hybrid_cp algorithm, context_parallel {self.context_parallel} " f"should be divisible by ulysses_degree_in_cp {self.ulysses_degree_in_cp}. " @@ -627,9 +622,9 @@ class TransformerOpParallelConfig(_Config): """ if self.context_parallel == 1: return 1 - if self.context_parallel_algo == ContextParallelAlgo.colossalai_cp: + if self.context_parallel_algo == ContextParallelAlgo.COLOSSALAI_CP: return 1 - if self.context_parallel_algo == ContextParallelAlgo.ulysses_cp: + if self.context_parallel_algo == ContextParallelAlgo.ULYSSES_CP: return self.context_parallel # hybird return self.ulysses_degree_in_cp @@ -786,260 +781,6 @@ class TransformerOpParallelConfig(_Config): default_transformer_config = TransformerOpParallelConfig() -class FeedForward(Cell): - r""" - The multilayer perceptron with two linear layers with dropout applied at final output. The first linear - will project the input dimension from hidden_size to ffn_hidden_size. The second linear will project the - dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension, - and the second linear is sharded on the output dimension. The overview process can be: - - .. math:: - Dropout((xW_1+b_1)W_2 + b_2) - - where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters. - - Args: - hidden_size (int): The dimension of the inputs. - ffn_hidden_size (int): The intermediate hidden size. - dropout_rate (float): The dropout rate for the second linear's output. - hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu', - 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', - 'hsigmoid', 'logsigmoid' and so on. User can provide custom activition to the argument. - If user wants to run the net in the parallel mode, the custom activation must also provide - the `activation_shard` function. Please see examples. Default: gelu. - expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used - and the first dimension in BatchMatMul indicate expert_num. Default: 1. - expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is - effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION. - param_init_type (dtype.Number): The parameter initialization type. Should be mstype.float32 or - mstype.float16. Default: mstype.float32. 
- parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see - `OpParallelConfig` or `MoEParallelConfig`. When MoE is applied, MoEParallelConfig is effective, - otherwise OpParallelConfig is effective. Default `default_dpmp_config`, - an instance of `OpParallelConfig` with default args. - - Inputs: - - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`. - Float tensor. - - Outputs: - Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or - [batch * seq_length, hidden_size]`. - - Raises: - TypeError: `hidden_act` is not a string or nn.Cell. - TypeError: `parallel_config` is not a subclass of OpParallelConfig. - ValueError: `ffn_hidden_size` is not a multiple of the model parallel way. - ValueError: `hidden_size` is not a multiple of the model parallel way. - - Supported Platforms: - ``Ascend`` ``GPU`` - - Examples: - >>> import numpy as np - >>> from mindformers.modules.transformer import FeedForward - >>> from mindspore import dtype as mstype - >>> from mindspore import Tensor, nn - >>> import mindspore.ops as ops - >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1) - >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32) - >>> output = model(tensor) - >>> print(output.shape) - (2, 20, 15) - >>> # Example 2 using custom hidden activation - >>> class MyActivationNoShard(nn.Cell): - ... def __init__(self): - ... super(MyActivationNoShard, self).__init__() - ... self.add = ops.Add() - ... def construct(self, x): - ... return self.add(x, 0.1) - >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1, - ... hidden_act=MyActivationNoShard) - >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32) - >>> output = model(tensor) - >>> print(output.shape) - (2, 20, 15) - >>> # Example 3 using custom hidden activation with activation_shard - >>> # If user wantss to run on the SEMI/AUTO parallel mode, the custom activation must provide - >>> # a class function named activation_shard. It accepts the argument parallel_config (OpParallelConfig, - >>> # MoEParallelConfig) and set the shard for the primitives used in the construct. - >>> class MyActivationWithShard(nn.Cell): - ... def __init__(self): - ... super(MyActivationWithShard, self).__init__() - ... self.add = ops.Add() - ... def construct(self, x): - ... return self.add(x, 0.1) - ... def activation_shard(self, parallel_config): - ... self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ())) - >>> - >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1, - ... 
hidden_act=MyActivationWithShard) - >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32) - >>> output = model(tensor) - >>> print(output.shape) - (2, 20, 15) - """ - - @_LogActionOnce(m_logger=logger, key='FeedForward', - no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,)) - @_args_type_validator_check(hidden_size=Validator.check_positive_int, - ffn_hidden_size=Validator.check_positive_int, - dropout_rate=Validator.check_non_negative_float, - param_init_type=_valid_value_checks([mstype.float32, mstype.bfloat16, mstype.float16], - "FeedForward"), - compute_dtype=_valid_value_checks([mstype.float32, mstype.bfloat16, mstype.float16], - "FeedForward"), - parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig], - "FeedForward")) - def __init__(self, hidden_size, - ffn_hidden_size, - dropout_rate, - hidden_act='gelu', - expert_num=1, - expert_group_size=None, - param_init_type=mstype.float32, - parallel_config=default_dpmp_config, - compute_dtype=mstype.float16): - super(FeedForward, self).__init__() - self.dtype = compute_dtype - if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)): - raise TypeError(f"For FeedForward cell, the hidden_act should str type or nn.Cell type, " - f"but got {hidden_act}.") - if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,): - _check_config(parallel_config) - mp = parallel_config.model_parallel - if expert_num > 1: - ep = parallel_config.expert_parallel - else: - ep = 1 - # ffn use less dp than other ops when use_moe, due to there are ops use dp and ep. - dp = parallel_config.data_parallel // ep - if ffn_hidden_size % mp != 0: - raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the" - "num of model parallel, but got the ffn_hidden_size is {} and the num of model " - "parallel is {}.".format(ffn_hidden_size, mp)) - if hidden_size % mp != 0: - raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of " - "model parallel, but got the hidden_size is {} and the num of model parallel is {}." 
- .format(hidden_size, mp)) - if dropout_rate < 0 or dropout_rate >= 1: - raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), " - "but got the value : {}.".format(dropout_rate)) - input_size = hidden_size - output_size = ffn_hidden_size - - # Project to ffn_hidden_size - self.mapping = Linear(in_channels=input_size, - out_channels=output_size, - activation=hidden_act, - transpose_b=False, - expert_num=expert_num, - expert_group_size=expert_group_size, - outer_batch=dp, - param_init_type=param_init_type, - compute_dtype=compute_dtype) - - # Project back to hidden_size - self.projection = Linear(in_channels=output_size, - out_channels=input_size, - transpose_b=False, - expert_num=expert_num, - expert_group_size=expert_group_size, - outer_batch=dp, - param_init_type=param_init_type, - compute_dtype=compute_dtype) - if expert_num > 1: - self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1))) - else: - self.projection.shard(strategy_matmul=((dp, mp), (mp, 1))) - self.projection.bias.parallel_optimizer = False - self.dropout = get_dropout(dropout_rate) - self.dropout_3d = get_dropout(dropout_rate) - self.dropout_4d = get_dropout(dropout_rate) - self.cast = P.Cast() - else: - _check_config(parallel_config) - mp = parallel_config.model_parallel - if expert_num > 1: - ep = parallel_config.expert_parallel - else: - ep = 1 - # ffn use less dp than other ops when use_moe, due to there are ops use dp and ep. - dp = parallel_config.data_parallel // ep - if ffn_hidden_size % mp != 0: - raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the" - "num of model parallel, but got the ffn_hidden_size is {} and the num of model " - "parallel is {}.".format(ffn_hidden_size, mp)) - if hidden_size % mp != 0: - raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of " - "model parallel, but got the hidden_size is {} and the num of model parallel is {}." 
- .format(hidden_size, mp)) - if dropout_rate < 0 or dropout_rate >= 1: - raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), " - "but got the value : {}.".format(dropout_rate)) - input_size = hidden_size - output_size = ffn_hidden_size - - # Project to ffn_hidden_size - self.mapping = Linear(in_channels=input_size, - out_channels=output_size, - activation=hidden_act, - transpose_b=False, - expert_num=expert_num, - expert_group_size=expert_group_size, - outer_batch=dp, - param_init_type=param_init_type, - compute_dtype=compute_dtype) - - if expert_num > 1: - self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)), - strategy_bias=((dp, ep, 1, mp), (1, ep, 1, mp)), - strategy_activation=((dp, ep, 1, mp),)) - else: - self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)), - strategy_bias=((dp, mp), (mp,)), - strategy_activation=((dp, mp),)) - # Project back to hidden_size - self.projection = Linear(in_channels=output_size, - out_channels=input_size, - transpose_b=False, - expert_num=expert_num, - expert_group_size=expert_group_size, - outer_batch=dp, - param_init_type=param_init_type, - compute_dtype=compute_dtype) - if expert_num > 1: - self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)), - strategy_bias=((dp, ep, 1, 1), (1, ep, 1, 1))) - else: - self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)), - strategy_bias=((dp, 1), (1,))) - self.projection.bias.parallel_optimizer = False - self.dropout = get_dropout(dropout_rate) - self.dropout_3d = get_dropout(dropout_rate) - self.dropout_4d = get_dropout(dropout_rate) - self.dropout.dropout.shard(((dp, 1),)) - self.dropout_3d.dropout.shard(((dp, 1, 1),)) - self.dropout_4d.dropout.shard(((dp, ep, 1, 1),)) - self.cast = P.Cast() - - def construct(self, x): - """Forward process of the FeedForward""" - _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) - x = self.cast(x, self.dtype) - # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size] - hidden = self.mapping(x) - output = self.projection(hidden) - # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size] - if len(F.shape(output)) == 3: - output = self.dropout_3d(output) - elif len(F.shape(output)) == 2: - output = self.dropout(output) - else: - output = self.dropout_4d(output) - return output - - class LowerTriangularMaskWithDynamic(Cell): r""" Get the Strictly Lower triangular matrix from the input_ids. 
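
The section above deletes `FeedForward` from the public API. Code that still needs the plain two-linear MLP it implemented, Dropout((xW1 + b1)W2 + b2), can be rebuilt in a few lines. A minimal single-device sketch with no sharding or expert routing, assuming MindSpore 2.x `nn.Dense` and `nn.Dropout(p=...)`; `SimpleFeedForward` is an illustrative name, not a replacement API:

```python
import numpy as np
from mindspore import nn, Tensor, dtype as mstype

class SimpleFeedForward(nn.Cell):
    """Dropout((x @ W1 + b1) @ W2 + b2): the mapping/projection pair FeedForward wrapped."""
    def __init__(self, hidden_size, ffn_hidden_size, dropout_rate=0.1, hidden_act='gelu'):
        super().__init__()
        self.mapping = nn.Dense(hidden_size, ffn_hidden_size, activation=hidden_act)
        self.projection = nn.Dense(ffn_hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout_rate)

    def construct(self, x):
        return self.dropout(self.projection(self.mapping(x)))

# Mirrors the example from the removed docstring: a [2, 20, 15] input stays [2, 20, 15].
model = SimpleFeedForward(hidden_size=15, ffn_hidden_size=30)
out = model(Tensor(np.ones((2, 20, 15)), mstype.float32))
print(out.shape)  # (2, 20, 15)
```
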
diff --git a/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py b/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py index 39e51ebd0..fab73dd74 100644 --- a/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py +++ b/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py @@ -84,6 +84,7 @@ class UnquantizedGroupedLinearMethod(GroupedLinearMethodBase): self.cast = P.Cast() self.matmul = ops.auto_generate.GroupedMatmulV4() + # pylint: disable=W0237 def create_weights(self, layer: nn.Cell, num_local_experts: int, input_size_per_partition: int, output_partition_sizes: list[int], params_dtype, **extra_weight_attrs): @@ -216,17 +217,17 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): quant_config: Optional[QuantizationConfig] = None, prefix: str = "" ): - super(ColumnParallelGroupedLinear, self).__init__(num_local_experts, - input_size, - output_size, - skip_bias_add, - config.params_dtype, - quant_config=quant_config, - prefix=prefix) + super().__init__(num_local_experts, + input_size, + output_size, + skip_bias_add, + config.params_dtype, + quant_config=quant_config, + prefix=prefix) if stride > 1: raise NotImplementedError( - "For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, " - "but got `stride={}`".format(stride)) + f"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, " + f"but got `stride={stride}`") if skip_bias_add: raise NotImplementedError( "For ColumnParallelGroupedLinear, `skip_bias_add=True` is not supported for now." @@ -275,6 +276,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): else: self.bias = None + # pylint: disable=W0237 def construct(self, input_parallel, weight=None, group_list=None): """Forward of ColumnParallelGroupedLinear.""" if weight is None: @@ -386,15 +388,15 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): class RowParallelGroupedLinear(GroupedLinearBase): - r""" + """ The group linear layer with weight sliced on first dimension by tensor parallel size. This layer implements the operation as: .. math:: - \text{outputs} = \text{inputs} * \text{weight} + \text{bias}, + \\text{outputs} = \\text{inputs} * \\text{weight} + \\text{bias}, - where :math:`inputs` is the input tensors, :math:`\text{weight}` is a weight matrix created by the layer, - and :math:`\text{bias}` is a bias vector created by the layer (only if has_bias is True). + where :math:`inputs` is the input tensors, :math:`\\text{weight}` is a weight matrix created by the layer, + and :math:`\\text{bias}` is a bias vector created by the layer (only if has_bias is True). Args: num_local_experts (int): The number of local expert. @@ -416,11 +418,11 @@ class RowParallelGroupedLinear(GroupedLinearBase): prefix (str): The prefix string for this linear layer. Default: empty string(""). Inputs: - - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `input_size` in `Args` should be equal - to :math:`in\_channels` in `Inputs`. + - **x** (Tensor) - Tensor of shape :math:`(*, in\\_channels)`. The `input_size` in `Args` should be equal + to :math:`in\\_channels` in `Inputs`. Outputs: - Tensor of shape :math:`(*, out\_channels)`. + Tensor of shape :math:`(*, out\\_channels)`. 
Supported Platforms: ``Ascend`` @@ -445,17 +447,17 @@ class RowParallelGroupedLinear(GroupedLinearBase): quant_config: Optional[QuantizationConfig] = None, prefix: str = "" ): - super(RowParallelGroupedLinear, self).__init__(num_local_experts, - input_size, - output_size, - skip_bias_add, - config.params_dtype, - quant_config=quant_config, - prefix=prefix) + super().__init__(num_local_experts, + input_size, + output_size, + skip_bias_add, + config.params_dtype, + quant_config=quant_config, + prefix=prefix) if stride > 1: raise NotImplementedError( - "For RowParallelGroupedLinear, `stride > 1` is not supported for now, " - "but got `stride={}`".format(stride)) + f"For RowParallelGroupedLinear, `stride > 1` is not supported for now, " + f"but got `stride={stride}`") if not is_expert: raise NotImplementedError( "For RowParallelGroupedLinear, `is_expert=False` is not supported for now.") @@ -502,6 +504,7 @@ class RowParallelGroupedLinear(GroupedLinearBase): else: self.bias = None + # pylint: disable=W0237 def construct(self, input_, weight=None, group_list=None): """Forward of RowParallelGroupedLinear.""" if weight is None: diff --git a/mindformers/parallel_core/inference/utils.py b/mindformers/parallel_core/inference/utils.py index 852b26e12..f012d2eba 100644 --- a/mindformers/parallel_core/inference/utils.py +++ b/mindformers/parallel_core/inference/utils.py @@ -20,12 +20,15 @@ __all__ = [ "update_comm_config", ] +import os +import stat from contextlib import contextmanager import numpy as np import mindspore as ms from mindspore import Tensor, ops, Parameter, mint from mindspore.communication import get_group_size +from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy from mindformers.version_control import is_310p from mindformers.parallel_core.transformer_config import TransformerConfig @@ -65,7 +68,7 @@ ATTNMASK_FUNC_MAP = { def get_attn_mask_func(mask_func_type): - r""" + """ Get attention mask function. Args: @@ -75,9 +78,9 @@ def get_attn_mask_func(mask_func_type): Function, the attention mask function. """ if mask_func_type not in ATTNMASK_FUNC_MAP: - raise KeyError("Invalid attention mask function. Supported attention " - "mask function are ['attn_mask_fill', 'attn_mask_add'] " - ", but got {}.".format(mask_func_type)) + raise KeyError(f"Invalid attention mask function. 
Supported attention " + f"mask function are ['attn_mask_fill', 'attn_mask_add'] " + f", but got {mask_func_type}.") return ATTNMASK_FUNC_MAP[mask_func_type] @@ -158,7 +161,7 @@ def create_empty_parameter(shape, *, dtype=None, device=None, **kwargs): def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" if numerator % denominator != 0: - raise ValueError("{} is not divisible by {}".format(numerator, denominator)) + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator, denominator): @@ -178,10 +181,6 @@ def save_strategy_file(state_dict, strategy_file_name): Supported Platforms: ``Ascend`` """ - import os - import stat - from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy - stra = ckpt_strategy() stage_rank_size = state_dict["stage_rank_size"] @@ -361,12 +360,13 @@ def get_num_layers_and_offset(config): return int(layer_list[pp_rank]), int(sum(layer_list[:pp_rank])) return num_layers, 0 + def use_ms_custom_ops(): """ Determine whether has custom ops """ try: - # pylint: disable=W0611 + # pylint: disable=W0611, C0415 import ms_custom_ops except ModuleNotFoundError: # environment need install ms_custom_ops package diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py index e4e89a69b..3ea726020 100644 --- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py @@ -29,7 +29,8 @@ from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagati from mindspore import ops from mindformers.parallel_core.training_graph.loss_func import CrossEntropyLoss -from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock +from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock, \ + func_infer_dtype, func_infer_shape, func_infer_shape_labels_and_masks from mindformers.parallel_core.training_graph.device_matrix import layout from mindformers.parallel_core.utils.spec_utils import ModuleSpec from mindformers.parallel_core.training_graph.transformer.mask_generate import CausalMaskGenerate @@ -56,26 +57,6 @@ from mindformers.version_control import get_lazy_inline as lazy_inline from mindformers.core.optim.muon_utils import make_muon_fns -def func_infer_dtype(*args): - """infer_dtype for Morph.""" - return args[0] - - -def func_infer_shape(*args): - """infer_shape for Morph.""" - input_shape = args[0] - shape_value = np.prod(input_shape[:-1]) - output_shape = [int(shape_value), args[0][-1]] - return output_shape - -def func_infer_shape_labels_and_masks(*args): - """infer_shape for Morph.""" - input_shape = args[0] - shape_value = np.prod(input_shape) - output_shape = [int(shape_value)] - return output_shape - - class PreprocessLabelsAndMasks(nn.Cell): """Preprocess input_ids and generate labels and masks. 
""" diff --git a/mindformers/parallel_core/training_graph/loss_func.py b/mindformers/parallel_core/training_graph/loss_func.py index c2dc69e9b..d03b1677b 100644 --- a/mindformers/parallel_core/training_graph/loss_func.py +++ b/mindformers/parallel_core/training_graph/loss_func.py @@ -40,26 +40,23 @@ _device_local_loss = {} def get_device_local_loss(tag="lm"): """Get `_device_local_loss` Parameter after init""" - global _device_local_loss if tag is None: return _device_local_loss if _device_local_loss.get(tag, None) is None: _device_local_loss[tag] = Parameter( - Tensor([0.0], mstype.float32), name=f"_device_local_loss", requires_grad=False + Tensor([0.0], mstype.float32), name="_device_local_loss", requires_grad=False ) return _device_local_loss[tag] def reset_device_local_loss(): """Reset `_device_local_loss` parameter to zero""" - global _device_local_loss for _, loss in _device_local_loss.items(): F.assign(loss, Tensor([0.0], mstype.float32)) def check_device_local_loss(): """check if Nan or Inf in `_device_local_loss` parameter then terminate training""" - global _device_local_loss if not _device_local_loss: return for tag, device_local_loss in _device_local_loss.items(): @@ -88,7 +85,7 @@ class _LogSoftmax(nn.Cell): The corresponding log softmax results. """ def __init__(self, config: TransformerConfig = default_transformer_config): - super(_LogSoftmax, self).__init__() + super().__init__() dp = config.data_parallel_size mp = config.tensor_model_parallel_size cp = config.context_parallel_size @@ -143,7 +140,7 @@ class _NLLLoss(nn.Cell): The corresponding loss results. """ def __init__(self, config: TransformerConfig = default_transformer_config): - super(_NLLLoss, self).__init__() + super().__init__() dp = config.data_parallel_size mp = config.tensor_model_parallel_size cp = config.context_parallel_size @@ -176,7 +173,7 @@ class _NLLLoss(nn.Cell): class CrossEntropyLoss(nn.Cell): - r""" + """ Calculate the cross entropy loss. CrossEntropyLoss supports two different types of targets: @@ -185,9 +182,9 @@ class CrossEntropyLoss(nn.Cell): When reduction is set to 'none', the cross-entropy loss is computed as follows: .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} - \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\} + \\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\top, \\quad + l_n = - w_{y_n} \\log \\frac{\\exp(x_{n,y_n})}{\\sum_{c=1}^C \\exp(x_{n,c})} + \\cdot \\mathbb{1}\\{y_n \\not= \\text{ignore_index}\\} where :math:`x` denotes the predicted values, :math:`t` denotes the target values, :math:`w` denotes the weights, and :math:`N` is the batch size. The index :math:`c` ranges from [0, C-1], representing the class indices, @@ -196,19 +193,19 @@ class CrossEntropyLoss(nn.Cell): If reduction is not set to 'none' (the default is 'mean'), the loss is computed as: .. math:: - \ell(x, y) = \begin{cases} - \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}} l_n, & - \text{if reduction} = \text{'mean',}\\ - \sum_{n=1}^N l_n, & - \text{if reduction} = \text{'sum'.} - \end{cases} + \\ell(x, y) = \\begin{cases} + \\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n} \\cdot \\mathbb{1}\\{y_n \\not= + \\text{ignore_index}\\}} l_n, &\\text{if reduction} = \\text{'mean',}\\\\ + \\sum_{n=1}^N l_n, & + \\text{if reduction} = \\text{'sum'.} + \\end{cases} - Class probabilities (float), used when the target is a probability distribution over multiple class labels. 
When reduction is set to 'none', the cross-entropy loss is computed as follows:
 
      .. math::
-          \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
-          l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
+          \\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad
+          l_n = - \\sum_{c=1}^C w_c \\log \\frac{\\exp(x_{n,c})}{\\sum_{i=1}^C \\exp(x_{n,i})} y_{n,c}
 
      where :math:`x` denotes the predicted values, :math:`t` denotes the target values, :math:`w` denotes the
      weights, and :math:`N` is the batch size. The index :math:`c` ranges from [0, C-1], representing the class indices,
@@ -217,12 +214,12 @@
      If reduction is not set to 'none' (the default is 'mean'), the loss is computed as:
 
      .. math::
-          \ell(x, y) = \begin{cases}
-          \frac{\sum_{n=1}^N l_n}{N}, &
-          \text{if reduction} = \text{'mean',}\\
-          \sum_{n=1}^N l_n, &
-          \text{if reduction} = \text{'sum'.}
-          \end{cases}
+          \\ell(x, y) = \\begin{cases}
+          \\frac{\\sum_{n=1}^N l_n}{N}, &
+          \\text{if reduction} = \\text{'mean',}\\\\
+          \\sum_{n=1}^N l_n, &
+          \\text{if reduction} = \\text{'sum'.}
+          \\end{cases}
 
     Args:
         config (TransformerConfig): The parallel configuration. Default: default_transformer_config,
@@ -258,7 +255,7 @@ class CrossEntropyLoss(nn.Cell):
     @_LogActionOnce(m_logger=logger, key='CrossEntropyLoss',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     def __init__(self, config: TransformerConfig = default_transformer_config, loss_tag='lm', **kwargs):
-        super(CrossEntropyLoss, self).__init__()
+        super().__init__()
         dp = config.data_parallel_size
         mp = config.tensor_model_parallel_size
         cp = config.context_parallel_size
@@ -347,7 +344,7 @@ class VocabParallelCrossEntropy(nn.Cell):
     """calculate cross entropy loss"""
 
     def __init__(self, config: TransformerConfig = default_transformer_config, **kwargs):
-        super(VocabParallelCrossEntropy, self).__init__()
+        super().__init__()
         self.cross_entropy = CrossEntropyLoss(config, **kwargs)
 
     def construct(self, vocab_parallel_logits, target, input_mask=None, label_smoothing=None):
diff --git a/mindformers/parallel_core/training_graph/transformer/norm.py b/mindformers/parallel_core/training_graph/transformer/norm.py
index 01adc1d01..1003a2f76 100644
--- a/mindformers/parallel_core/training_graph/transformer/norm.py
+++ b/mindformers/parallel_core/training_graph/transformer/norm.py
@@ -36,7 +36,7 @@ def get_strategy(config: TransformerConfig):
 
 class LayerNorm(nn.Cell):
-    r"""
+    """
     Layer norm operation.
 
     Args:
@@ -52,7 +52,7 @@ class LayerNorm(nn.Cell):
     """
 
     def __init__(self, config, dim, eps=1e-5):
-        super(LayerNorm, self).__init__()
+        super().__init__()
         self.params_dtype = config.params_dtype
         self.compute_type = config.layernorm_compute_dtype
 
@@ -117,7 +117,7 @@ class LayerNorm(nn.Cell):
 
 class FusedLayerNorm(nn.Cell):
-    r"""
+    """
     Layer norm operation.
 
     Args:
@@ -133,7 +133,7 @@ class FusedLayerNorm(nn.Cell):
     """
 
     def __init__(self, config, dim, eps=1e-5):
-        super(FusedLayerNorm, self).__init__()
+        super().__init__()
         self.params_dtype = config.params_dtype
         self.compute_type = config.layernorm_compute_dtype
 
@@ -170,8 +170,7 @@ class FusedLayerNorm(nn.Cell):
             strategy = (cp, dp, 1)
 
         if strategy[-1] != 1:
-            raise TypeError(
-                'The last dim in FusedLayerNorm can not equal to 1! Strategy {} not supported!'.format(strategy))
+            raise TypeError(f'The last dim of the shard strategy for FusedLayerNorm must be 1! 
Strategy {strategy} not supported!') self.layer_norm.shard((strategy, (strategy[-1],), (strategy[-1],))) @@ -180,7 +179,7 @@ class FusedLayerNorm(nn.Cell): class RMSNorm(nn.Cell): - r""" + """ A self-defined RMSNorm operation using reduce mean. Args: @@ -196,7 +195,7 @@ class RMSNorm(nn.Cell): """ def __init__(self, config, dim, eps=1e-6): - super(RMSNorm, self).__init__() + super().__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype @@ -251,7 +250,7 @@ class RMSNorm(nn.Cell): class FusedRMSNorm(nn.Cell): - r""" + """ FusedRMSNorm operation Args: @@ -267,7 +266,7 @@ class FusedRMSNorm(nn.Cell): """ def __init__(self, config, dim, eps=1e-6): - super(FusedRMSNorm, self).__init__() + super().__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype diff --git a/mindformers/parallel_core/training_graph/transformer/utils.py b/mindformers/parallel_core/training_graph/transformer/utils.py index 61bc60b41..72e20ee8f 100644 --- a/mindformers/parallel_core/training_graph/transformer/utils.py +++ b/mindformers/parallel_core/training_graph/transformer/utils.py @@ -77,7 +77,7 @@ ATTNMASK_FUNC_MAP = { def get_attn_mask_func(mask_func_type): - r""" + """ Get attention mask function. Args: diff --git a/mindformers/wrapper/wrapper.py b/mindformers/wrapper/wrapper.py index 6f009d021..9447616a5 100644 --- a/mindformers/wrapper/wrapper.py +++ b/mindformers/wrapper/wrapper.py @@ -207,7 +207,7 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell): **kwargs (Any): Additional parameters. Inputs: - - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. + - **\\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \\ldots)`. 
Outputs:
         Tuple of 5 or 7 Tensor, the loss, overflow flag, current loss scale value, learning rate,
diff --git a/tests/st/test_ut/test_api_compatibility.py b/tests/st/test_ut/test_api_compatibility.py
index 492bb23bb..2f2514cf4 100644
--- a/tests/st/test_ut/test_api_compatibility.py
+++ b/tests/st/test_ut/test_api_compatibility.py
@@ -101,7 +101,7 @@ def is_not_compatibility(base_str, new_str):
 def set_failure_list(api_str, value, signature, failure_list):
     """set failure info list"""
     failure_list.append(f"# {api_str}:")
-    failure_list.append(f"  - function signature is different: ")
+    failure_list.append("  - function signature is different: ")
     failure_list.append(f"    - the base signature is {value}.")
     failure_list.append(f"    - now it is {signature}.")
 
@@ -170,12 +170,12 @@ def api_signature(obj, api_str, content, base_schema, failure_list, is_update=Fa
         else:
             tmp_len = -1
             signature = None
-            for i in range(len(signature_list)):
-                if signature_list[i] == "(*args, **kwargs)":
+            for sig in signature_list:
+                if sig == "(*args, **kwargs)":
                     continue
-                if len(signature_list[i]) > tmp_len:
-                    tmp_len = len(signature_list[i])
-                    signature = signature_list[i]
+                if len(sig) > tmp_len:
+                    tmp_len = len(sig)
+                    signature = sig
     else:
         signature = str(inspect.signature(obj))
 
@@ -293,7 +293,8 @@ class TestApiStability:
     def check_one_element(elem, mod_name, mod, is_public):
         obj = getattr(mod, elem)
         if hasattr(obj, "__module__"):
-            if obj.__module__ not in ['sentencepiece_model_pb2']:  # cannot use __import__ module list
+            # cannot use __import__ module list
+            if obj.__module__ not in ['sentencepiece_model_pb2', 'node_strategy_pb2']:
                 mod_source = str(__import__(obj.__module__))
                 if "mindformers" not in mod_source:
                     return
@@ -337,4 +338,4 @@ class TestApiStability:
         with open(self.api_json_path, "w", encoding="utf-8") as w:
             w.write(json.dumps(self.content, ensure_ascii=False, indent=4))
 
-        assert not self.is_update, f"self.is_update should be set to False"
+        assert not self.is_update, "self.is_update should be set to False"
diff --git a/tests/st/test_ut/test_transformer_apis.py b/tests/st/test_ut/test_transformer_apis.py
index 062673a73..8ead9d11b 100644
--- a/tests/st/test_ut/test_transformer_apis.py
+++ b/tests/st/test_ut/test_transformer_apis.py
@@ -22,14 +22,14 @@ from mindspore.ops import operations as ops
 from mindspore.common.api import _cell_graph_executor
 
 from mindformers.core import CrossEntropyLoss
-from mindformers.modules import FeedForward, FixedSparseAttention, LowerTriangularMaskWithDynamic
+from mindformers.modules import FixedSparseAttention, LowerTriangularMaskWithDynamic
 
 
 class MyActivation(mindspore.nn.Cell):
     """An example of custom activation"""
 
     def __init__(self):
-        super(MyActivation, self).__init__()
+        super().__init__()
         self.add = ops.Add()
 
     def construct(self, x):
@@ -43,27 +43,13 @@ class MyActivationNoShard(mindspore.nn.Cell):
     """An example of custom activation without shard"""
 
     def __init__(self):
-        super(MyActivationNoShard, self).__init__()
+        super().__init__()
         self.add = ops.Add()
 
     def construct(self, x):
         return self.add(x, 0.1)
 
 
-def test_feedforward():
-    """
-    Feature: Feedforward
-    Description: Test Feedforward module
-    Expectation: No exception
-    """
-    model = FeedForward(hidden_size=15,
-                        ffn_hidden_size=30,
-                        dropout_rate=0.1,
-                        hidden_act='relu')
-    tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
-    _cell_graph_executor.compile(model, tensor)
-
-
 def test_cross_entropy_loss():
     """
     Feature: CrossEntropyLoss
-- 
Gitee
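
A note on the docstring conversions in this patch: turning `r"""` docstrings into plain `"""` requires doubling every LaTeX backslash, because a non-raw string interprets sequences such as `\t` (tab), silently corrupting markup like `\top` or `\text`, and unrecognized escapes raise SyntaxWarning on Python 3.12+. A quick self-check of the invariant:

```python
# "\t" is a tab, so a non-raw "\top" is 3 characters (tab + "op"), not 4.
assert len("\top") == 3
# Doubling the backslash reproduces what the raw string used to contain.
assert "\\top" == r"\top" and len("\\top") == 4
```
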
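A note on the `ContextParallelAlgo` rename: the enum members move to upper case while their string values stay unchanged, so value-based construction (for example, parsing an algorithm name from a config string) keeps working; only attribute-style references need updating. A minimal sketch of the invariant:

```python
from enum import Enum

class ContextParallelAlgo(Enum):
    COLOSSALAI_CP = "colossalai_cp"
    ULYSSES_CP = "ulysses_cp"
    HYBRID_CP = "hybrid_cp"

# Value lookup is unaffected by the rename:
assert ContextParallelAlgo("hybrid_cp") is ContextParallelAlgo.HYBRID_CP
# Attribute references are the only call sites that change:
algo = ContextParallelAlgo.ULYSSES_CP  # previously ContextParallelAlgo.ulysses_cp
```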