diff --git a/configs/glm4/predict_glm4_9b_chat_800I_A2.yaml b/configs/glm4/predict_glm4_9b_chat_800I_A2.yaml
index 9c493e5f69b7bb8ff0daaff4a041bbef1af3bc0e..ec3ff637d3af166c1ab7c5ecd78577ddfa6186a9 100644
--- a/configs/glm4/predict_glm4_9b_chat_800I_A2.yaml
+++ b/configs/glm4/predict_glm4_9b_chat_800I_A2.yaml
@@ -47,7 +47,6 @@ model:
     param_init_type: "bfloat16"
     compute_dtype: "bfloat16"
     layernorm_compute_type: "float32"
-    residual_dtype: "bfloat16"
     use_past: True
     is_dynamic: True
     qkv_concat: True
diff --git a/configs/qwen3/finetune_qwen3.yaml b/configs/qwen3/finetune_qwen3.yaml
index d1044dd67c29fcff9663b84f06644605c7b7aaa0..3486a1b3a0bf99641696609415e0a9de82a1d705 100644
--- a/configs/qwen3/finetune_qwen3.yaml
+++ b/configs/qwen3/finetune_qwen3.yaml
@@ -144,7 +144,7 @@ model:
     layernorm_compute_dtype: "float32"
     softmax_compute_dtype: "float32"
     rotary_dtype: "float32"
-    residual_dtype: "float32"
+    fp32_residual_connection: True

 # Callbacks configuration, reference: https://www.mindspore.cn/mindformers/docs/en/r1.5.0/appendix/conf_files.html?highlight=enable_alltoall#callbacks-configuration
 callbacks:
diff --git a/configs/qwen3/pretrain_qwen3_32b_4k.yaml b/configs/qwen3/pretrain_qwen3_32b_4k.yaml
index 5d955afae017353f384a92a3e18dfa14df9284ee..f55c686f8b4fbbe329ddbafa391fee91507abf3e 100644
--- a/configs/qwen3/pretrain_qwen3_32b_4k.yaml
+++ b/configs/qwen3/pretrain_qwen3_32b_4k.yaml
@@ -159,7 +159,7 @@ model:
     layernorm_compute_dtype: "float32"
     softmax_compute_dtype: "float32"
     rotary_dtype: "float32"
-    residual_dtype: "float32"
+    fp32_residual_connection: True
     model_type: "qwen3"
     architectures: ["Qwen3ForCausalLM"]

diff --git a/mindformers/parallel_core/mf_model_config.py b/mindformers/parallel_core/mf_model_config.py
index dd7ba2557fdb62e6e3d8625ee98303b543209fa4..069b177d67dd74e3cc38b431475b1d0372e31bb1 100644
--- a/mindformers/parallel_core/mf_model_config.py
+++ b/mindformers/parallel_core/mf_model_config.py
@@ -345,12 +345,6 @@ class MFModelConfig:
     hidden_dropout: float = 0.0
     """Dropout probability for transformer hidden state."""

-    residual_dtype: str = None
-    """
-    Data type computed in residual connections.
-    It will be converted to `fp32_residual_connection` in `TransformerConfig`.
-    """
-
     print_separate_loss: bool = True
     """Print lm_loss, extra_loss and mtp_loss separately."""

@@ -428,6 +422,3 @@

     def __post_init__(self):
         self.parallel_config = default_transformer_config
-
-        if self.residual_dtype is None:
-            self.residual_dtype = self.compute_dtype
diff --git a/mindformers/parallel_core/training_graph/base_models/common/embeddings/rope_utils.py b/mindformers/parallel_core/training_graph/base_models/common/embeddings/rope_utils.py
index 90c380846747786725cf8e0955f422628ba5b733..409d9f6c7f16f5c38fd1a464d9c21897b2caf3e7 100644
--- a/mindformers/parallel_core/training_graph/base_models/common/embeddings/rope_utils.py
+++ b/mindformers/parallel_core/training_graph/base_models/common/embeddings/rope_utils.py
@@ -22,7 +22,7 @@ __all__ = [
 from mindspore import nn, Tensor, ops
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
-from mindspore.ops.auto_generate import AddExt, Reshape, Mul, Cos, Sin, Split, Neg, Concat, StackExt, StridedSlice
+from mindspore.ops.auto_generate import AddExt, Reshape, Mul, Cos, Sin, Split, Neg, Concat, StackExt, StridedSlice, Cast
 from mindformers.parallel_core.training_graph.device_matrix import layout
 from mindformers.parallel_core.transformer_config import TransformerConfig, MLATransformerConfig
 from mindformers.parallel_core.training_graph.base_models.common.embeddings.rotary_pos_embedding import (
@@ -75,7 +75,7 @@ class ApplyRotaryPosEmb(nn.Cell):
             config: TransformerConfig,
             for_k_pos_emb=False
     ):
-        super(ApplyRotaryPosEmb, self).__init__()
+        super().__init__()
         self.append_eod = config.use_eod_reset
         self.add = AddExt()
         self.mul = Mul()
@@ -89,8 +89,10 @@
         self.slice = StridedSlice()
         self.strideslice = StridedSlice()
         self.reshape = Reshape()
+        self.cast = Cast()
         self.apply_rope_fusion = config.apply_rope_fusion
         self.for_k_pos_emb = for_k_pos_emb
+        self.rotary_dtype = config.rotary_dtype

         if self.apply_rope_fusion:
             self.rope = ops.auto_generate.gen_ops_prim.RotaryPositionEmbedding()
@@ -118,6 +120,7 @@
             Returns:
                 Tensor: Output tensor after applying rotary position embedding
         """
+        origin_dtype = t.dtype
         seq_len, bs, n_heads, head_dim = t.shape
         freqs, m_scale = freqs
         rot_dim = freqs.shape[-1]
@@ -132,8 +135,9 @@
             x2 = self.strideslice(t, (0, 0, 0, 1), (seq_len, bs, n_heads, head_dim), (1, 1, 1, 2))
             t = self.cat((x1, x2))

-        cos_ = self.cos(self.mul_mscale(freqs, m_scale)).astype(t.dtype)
-        sin_ = self.sin(self.mul_mscale(freqs, m_scale)).astype(t.dtype)
+        t = self.cast(t, self.rotary_dtype)
+        cos_ = self.cast(self.cos(self.mul_mscale(freqs, m_scale)), self.rotary_dtype)
+        sin_ = self.cast(self.sin(self.mul_mscale(freqs, m_scale)), self.rotary_dtype)

         if self.apply_rope_fusion:
             output = self.rope(t, cos_, sin_, 0)
@@ -143,7 +147,7 @@
         if t_not_rotary is not None:
             output = self.cat((output, t_not_rotary))

-        return output
+        return self.cast(output, origin_dtype)

     def _rotate_half(self, t: Tensor, rotary_interleaved: bool = False) -> Tensor:
         """Rotates half of the input tensor for rotary position embeddings.
diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
index 342d901d62b296fcf112974939e387e61e49d73e..5fcb7be3fe626d6aee10610e434719622696a75e 100644
--- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
+++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
@@ -324,6 +324,10 @@
             raise NotImplementedError("position_embedding_type = mrope is not supported now.")
         elif self.position_embedding_type == 'none':
             self.rotary_pos_emb = None
+        if config.rotary_dtype == dtype.float16:
+            raise ValueError("rotary_dtype `float16` is not supported now.")
+        if config.rotary_dtype != dtype.float32:
+            logger.warning("For training stability, rotary_dtype is recommended to `float32`.")

         if self.use_rotary_position_embeddings:
             self.rotary_pos_emb.shard(config)
diff --git a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py
index a9672a9df733cc03a2dff0e4cfab2344a35ef920..6ff7e9d6a45a1b243a5504597992bdb6f77f82d8 100644
--- a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py
+++ b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py
@@ -140,12 +140,14 @@
         if self.moe_token_dispatcher_type == "alltoall_deredundency":
             tokens = self.stride_slice(tokens, (0, 0, 0), new_input_tensor_shape, (1, 1, 1))
         dtype = tokens.dtype
-        w1 = self.cast_op(self.weight1, dtype)
-        w2 = self.cast_op(self.weight2, dtype)
+        tokens = self.cast_op(tokens, self.compute_dtype)
+        w1 = self.cast_op(self.weight1, self.compute_dtype)
+        w2 = self.cast_op(self.weight2, self.compute_dtype)
         # reshape w1 and w2 to [E, h, H] and [E, H, h]
         output = self.morphed_forward(tokens, probs, routing_map, w1, w2, self.num_tokens_per_expert)
         if self.moe_token_dispatcher_type == "alltoall_deredundency":
             output = self.stride_slice(output, (0, 0, 0), new_input_tensor_shape, (1, 1, 1))
+        output = self.cast_op(output, dtype)
         return output

     def forward_func(self, tokens, probs, routing_map, w1, w2, num_tokens_per_expert=None):
diff --git a/mindformers/parallel_core/training_graph/transformer/transformer_layer.py b/mindformers/parallel_core/training_graph/transformer/transformer_layer.py
index 86983b7f4358bedf3162c9808757d70fb4430ee3..2070dd2ffbbded0ac8f7e8061d3de507ba4592f7 100644
--- a/mindformers/parallel_core/training_graph/transformer/transformer_layer.py
+++ b/mindformers/parallel_core/training_graph/transformer/transformer_layer.py
@@ -15,6 +15,7 @@
 """Transformer Layer"""
 from dataclasses import dataclass
 from typing import Union
+import mindspore as ms
 from mindspore.ops import auto_generate as aclnn_ops
 from mindspore import nn
 from mindspore.context import ParallelMode
@@ -92,6 +93,8 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer):
         self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
         self.use_eod_attn_mask_compression = config.use_eod_attn_mask_compression
         self.sequence_parallel = config.sequence_parallel
+        self.fp32_residual_connection = config.fp32_residual_connection
+        self.compute_dtype = config.compute_dtype

         self.input_layernorm = build_module(
             submodules.input_layernorm,
@@ -149,6 +152,7 @@
         self.hidden_states_dropout = Dropout(drop_prob=self.hidden_dropout)
         self.add = aclnn_ops.AddExt()
         self.add_bias = aclnn_ops.AddExt()
+        self.cast = aclnn_ops.Cast()

         # NOTE: Recompute configuration is managed by
         # training_graph.transformer.utils.LayerSetting
@@ -224,6 +228,9 @@
         # Dropout
         dropout_output = self.hidden_states_dropout(attention_output)

+        if self.fp32_residual_connection:
+            residual = self.cast(residual, ms.float32)
+            dropout_output = self.cast(dropout_output, ms.float32)
         norm_input = self.add(residual, dropout_output)

         # Layer norm post the self attention
@@ -246,7 +253,12 @@
         # Dropout
         dropout_output = self.hidden_states_dropout(mlp_output)

+        if self.fp32_residual_connection:
+            residual = self.cast(residual, ms.float32)
+            dropout_output = self.cast(dropout_output, ms.float32)
         output = self.add(residual, dropout_output)
+        output = self.cast(output, self.compute_dtype)
+
         # 'return context' is useless, this param may be deprecated later.
         return output, context, extra_loss

diff --git a/mindformers/parallel_core/transformer_config_utils.py b/mindformers/parallel_core/transformer_config_utils.py
index 7747b3a3064c8826e0eb8bae401fb85899183c6b..bce3b218cddcb21ec52002defef12c201e786890 100644
--- a/mindformers/parallel_core/transformer_config_utils.py
+++ b/mindformers/parallel_core/transformer_config_utils.py
@@ -244,9 +244,6 @@
     ("n_kv_heads", "num_key_value_heads", "num_query_groups"): "num_query_groups",
    ("intermediate_size", "ffn_hidden_size"): "ffn_hidden_size",
     ("head_dim", "kv_channels"): "kv_channels",
-    ("residual_dtype", "fp32_residual_connection"): (
-        "fp32_residual_connection", is_float_32
-    ),
     ("rms_norm_eps", "layernorm_epsilon", "layer_norm_epsilon"): "layernorm_epsilon",
     ("qkv_has_bias", "attention_bias", "add_qkv_bias"): "add_qkv_bias",
     ("expert_num", "n_routed_experts", "num_experts", "num_moe_experts"): "num_moe_experts",
@@ -283,6 +280,7 @@
     "post_process": "post_process",
     "add_mlp_fc1_bias_linear": "add_mlp_fc1_bias_linear",
     "add_mlp_fc2_bias_linear": "add_mlp_fc2_bias_linear",
+    "fp32_residual_connection": "fp32_residual_connection",

     # Flash Attention
     # not changes
diff --git a/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/run_apply_rotary_pos_emb.py b/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/run_apply_rotary_pos_emb.py
index 4eeedc49d7955406bb9fcc38bfaf6d507aea2b9a..c07d3a6015c0c53221fdbdca83868c9d7fb8d1ca 100644
--- a/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/run_apply_rotary_pos_emb.py
+++ b/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/run_apply_rotary_pos_emb.py
@@ -19,9 +19,10 @@ from pathlib import Path
 import numpy as np
 import mindspore as ms
 from mindspore.communication import init
+from data_gen_utils import get_init_params
 from mindformers.parallel_core.training_graph.base_models.common.embeddings.rope_utils import ApplyRotaryPosEmb
 from mindformers.parallel_core.transformer_config import TransformerConfig
-from data_gen_utils import get_init_params
+


 SCRIPT_DIR = Path(__file__).parent.resolve()

@@ -63,6 +64,7 @@
             tensor_model_parallel_size=self.args.tensor_parallel,
             compute_dtype='bf16',
             params_dtype='fp32',
+            rotary_dtype='bf16',
             num_attention_heads=self.args.tensor_parallel,
             num_layers=1
         )
diff --git a/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/test_fused_rope.py b/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/test_fused_rope.py
index eb645cc8b33968e7433a27fa1ce2f9205cdf41be..cf755dd1668887dea44c0602609820637767af29 100644
--- a/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/test_fused_rope.py
+++ b/tests/st/test_ut/test_parallel_core/test_training_graph/test_base_models/test_embeddings/test_apply_rotary_pos_emb/test_fused_rope.py
@@ -15,9 +15,9 @@
 """Test module for testing fused RoPE in ApplyRotaryPosEmb used for mindformers."""
 import pytest
 import mindspore as ms
+from tests.utils.double_benchmark import DoubleBenchmarkComparator, DoubleBenchmarkStandard
 from mindformers.parallel_core.training_graph.base_models.common.embeddings.rope_utils import ApplyRotaryPosEmb
 from mindformers.parallel_core.transformer_config import TransformerConfig
-from tests.utils.double_benchmark import DoubleBenchmarkComparator, DoubleBenchmarkStandard
 from .data_gen_utils import get_init_params


@@ -34,10 +34,12 @@ class TestFusedRoPE:
         self.freqs = (self.input_freqs, self.mscale)
         self.no_fused_rope_config = TransformerConfig(num_attention_heads=1,
                                                       num_layers=1,
-                                                      apply_rope_fusion=False)
+                                                      apply_rope_fusion=False,
+                                                      rotary_dtype='bf16')
         self.fused_rope_config = TransformerConfig(num_attention_heads=1,
                                                    num_layers=1,
-                                                   apply_rope_fusion=True)
+                                                   apply_rope_fusion=True,
+                                                   rotary_dtype='bf16')

     def run_test(self):
         """Helper function to run test"""
diff --git a/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py b/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py
index 0d144ee09e08809e6b6c53730d1a16afd99d377a..477652a4d73254bcfc79f96172415426aef0ba11 100644
--- a/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py
+++ b/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py
@@ -16,6 +16,7 @@

 import pytest

+# pylint: disable=import-outside-toplevel
 from mindformers.parallel_core.transformer_config_utils import convert_to_transformer_config
 from mindformers.parallel_core.transformer_config import TransformerConfig, MLATransformerConfig

@@ -123,11 +124,10 @@ def test_trans_func_case():
     """
     Feature: Test the function `trans_func` can run normally.
     Description: Input config contains the key that will trigger `trans_func`,
-        such as 'residual_dtype', 'softmax_compute_dtype', 'first_k_dense_replace', 'use_gating_sigmoid'.
+        such as 'softmax_compute_dtype', 'first_k_dense_replace', 'use_gating_sigmoid'.
     Expectation: `trans_func` can convert the mapping of special keys, and instantiate TransformerConfig successfully.
     """
     config = DummyConfig({
-        'residual_dtype': 'fp32',
         'softmax_compute_dtype': 'float16',
         'first_k_dense_replace': 2,
         'use_gating_sigmoid': True,
@@ -138,14 +138,12 @@
         'num_layers': 2,
         'hidden_size': 128,
         'num_attention_heads': 8,
         'pipeline_stage': 1,
         'context_parallel': 1,
     })
     result = convert_to_transformer_config(config)
-    used_parameter = True
     unused_parameter = False
     assert result.num_layers == 2
     assert result.hidden_size == 128
     assert result.num_attention_heads == 8
     assert result.pipeline_model_parallel_size == 1
     assert result.context_parallel_size == 1
-    assert result.fp32_residual_connection == used_parameter
     assert result.attention_softmax_in_fp32 == unused_parameter
     assert result.first_k_dense_replace == 2
     assert result.moe_router_score_function == "sigmoid"