diff --git a/mindformers/parallel_core/mf_model_config.py b/mindformers/parallel_core/mf_model_config.py index be1a97b497004b7874a2a5be7c49cc2db48604b5..ab57f4561d2e25a548889573247fef35c929093c 100644 --- a/mindformers/parallel_core/mf_model_config.py +++ b/mindformers/parallel_core/mf_model_config.py @@ -193,12 +193,6 @@ class MFModelConfig: use_flash_attention: bool = True """If true, use flash attention for the attention layer.""" - attention_pre_tokens: int = None - """Pre-tokens for flash attention.""" - - attention_next_tokens: int = None - """Next tokens for flash attention.""" - rotary_seq_len_interpolation_factor: float = None """ RoPE scaling used for linear interpolation of longer sequences. @@ -308,9 +302,6 @@ class MFModelConfig: callback_moe_droprate: bool = False """Whether to print each expert's load information through callback.""" - moe_init_method_std: float = 0.01 - """Standard deviation of the zero mean normal for the MoE initialization method.""" - first_k_dense_replace: int = None r""" Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). 
diff --git a/mindformers/parallel_core/model_parallel_config.py b/mindformers/parallel_core/model_parallel_config.py index 2b31f3090b2a946b0d6ac7b83cec905239a5ed95..141728b69f250eb4ff105451b095076ca8fd74c9 100644 --- a/mindformers/parallel_core/model_parallel_config.py +++ b/mindformers/parallel_core/model_parallel_config.py @@ -4,7 +4,7 @@ """Model Parallel Config""" from dataclasses import dataclass -from typing import Callable, Optional, Union +from typing import Optional, Union from mindformers.parallel_core.mf_model_config import convert_str_to_mstype @@ -103,48 +103,9 @@ class ModelParallelConfig: # Training ################### - fp16: bool = False - """If true, train with fp16 mixed precision training.""" - - bf16: bool = False - """If true, train with bf16 mixed precision training.""" - params_dtype: str = "float32" """dtype used when initializing the weights.""" - finalize_model_grads_func: Optional[Callable] = None - """ - Function that finalizes gradients on all workers. - Could include ensuring that grads are all-reduced across data parallelism, pipeline parallelism, - and sequence parallelism dimensions. - """ - - grad_scale_func: Optional[Callable] = None - """ - If using loss scaling, this function should take the loss and return the scaled loss. - If None, no function is called on the loss. - """ - - grad_sync_func: Optional[Callable] = None - """ - Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient reduce-scatters). - The function should take one argument: an iterable of parameters whose gradients are to be synchronized. - """ - - param_sync_func: Optional[Callable] = None - """ - Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer parameter all-gathers). - The function should take one argument: an iterable of parameters to be synchronized. 
- """ - - num_microbatches_with_partial_activation_checkpoints: Optional[int] = None - """ - If int, set the number of microbatches where not all of the layers will be checkpointed and recomputed. - The rest of the microbatches within the window of maximum outstanding - microbatches will recompute all layers (either full recompute or selective recompute). - If None, the checkpoint and recompute will be left up to the forward_step function. - """ - ################### # CPU Offloading ################### diff --git a/mindformers/parallel_core/training_graph/transformer/attention.py b/mindformers/parallel_core/training_graph/transformer/attention.py index bf50cfa94e5e32c45619ee64553379c93772df31..ac6098698edab8b0a06e917a47c96fa5c47efbef 100644 --- a/mindformers/parallel_core/training_graph/transformer/attention.py +++ b/mindformers/parallel_core/training_graph/transformer/attention.py @@ -149,8 +149,8 @@ class Attention(nn.Cell): self.layer_number = max(1, layer_number) self.norm_factor = math.sqrt(self.head_dim) self.seq_length = config.seq_length - self.pre_tokens = 2147483647 if self.config.attention_pre_tokens is None else self.config.attention_pre_tokens - self.next_tokens = 0 if self.config.attention_next_tokens is None else self.config.attention_next_tokens + self.pre_tokens = 2147483647 + self.next_tokens = 0 self.keep_prob = 1.0 if self.config.attention_dropout is None else 1 - self.config.attention_dropout self.use_attention_mask = True if self.config.use_attention_mask is None else self.config.use_attention_mask self.is_nope_layer = ( diff --git a/mindformers/parallel_core/transformer_config.py b/mindformers/parallel_core/transformer_config.py index 74f84c368818d6361f9c2509e9e7b4cef8916c21..297cdc5da60794b8938848b266fc86bf79ff2bb1 100644 --- a/mindformers/parallel_core/transformer_config.py +++ b/mindformers/parallel_core/transformer_config.py @@ -76,12 +76,6 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig): layernorm_epsilon: float = 1e-5 
"""Epsilon value for any LayerNorm operations.""" - layernorm_zero_centered_gamma: bool = False - """ - If set to True, the LayerNorm is adjusted to center the gamma values around 0. - This improves numerical stability. - """ - add_qkv_bias: bool = False """Add a bias term only for QKV projections.""" @@ -152,12 +146,6 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig): param_init_std_rules: List[dict[str, Union[str, float]]] = None """Configuration for decoupled weight initialization.""" - init_model_with_meta_device: bool = False - """ - If True, initializes the model with the meta device. This is helpful for - training of very large models. This feature is only works when custom fsdp is turned on. - """ - #################### # Mixed-Precision #################### @@ -177,32 +165,10 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig): softmax_compute_dtype: str = 'float32' """Data type for computing softmax during attention computation.""" - disable_bf16_reduced_precision_matmul: bool = False - """If True, prevent matmul from using reduced precision accumulation when using BF16.""" - #################### # Fusion #################### - bias_activation_fusion: bool = False - """If True, fuses bias addition and the activation function when possible.""" - - masked_softmax_fusion: bool = False - """If True, uses softmax fusion.""" - - persist_layer_norm: bool = False - """ - If True, uses the persistent fused layer norm kernel. - This kernel only supports a fixed set of hidden sizes. - """ - - memory_efficient_layer_norm: bool = False - """ - If True, and using local layers (not from TransformerEngine), - tells Apex to use the memory efficient fused LayerNorm kernel. - Ignored if not using LayerNorm. 
- """ - bias_dropout_fusion: bool = False """If True, uses bias dropout fusion.""" @@ -342,9 +308,6 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig): moe_z_loss_coeff: Optional[float] = None # 1e-3 would be a good start value for z-loss """Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended.""" - moe_input_jitter_eps: Optional[float] = None - """Add noise to the input tensor by applying jitter with a specified epsilon value.""" - group_wise_a2a: bool = False """ Whether to enable group-wise alltoall communication, @@ -357,25 +320,12 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig): """The type of token dispatcher to use. The default is 'alltoall'. Options are 'alltoall', 'alltoall_deredundency' and 'alltoall_zero_redundancy'.""" - moe_enable_deepep: bool = False - """[Experimental] Enable DeepEP for efficient token dispatching and combine in MoE models.""" - - moe_per_layer_logging: bool = False - """Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.""" - moe_expert_capacity_factor: Optional[float] = None """ The capacity factor for each expert, None means no token will be dropped. The default is None. """ - moe_pad_expert_input_to_capacity: bool = False - """ - If True, pads the input for each expert to match the expert capacity length, - effective only after the moe_expert_capacity_factor is set. - The default setting is False. - """ - moe_token_drop_policy: str = 'probs' """ The policy to drop tokens. Can be either "probs" or "position". 
@@ -386,9 +336,6 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig): moe_permute_fusion: bool = False """Fuse token rearrangement ops during token dispatching.""" - moe_apply_probs_on_input: bool = False - """Apply probs on input of experts instead of applying after activation and glu.""" - # MindFormers New shared_expert_num: int = 0 """Number of shared experts.""" @@ -576,12 +523,6 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig): 'sub_seq_aux_loss, seq_aux_loss, gbs_aux_loss' ) - if self.moe_pad_expert_input_to_capacity: - if self.moe_expert_capacity_factor is None: - raise ValueError( - 'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity' - ) - if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True diff --git a/mindformers/parallel_core/transformer_config_utils.py b/mindformers/parallel_core/transformer_config_utils.py index 46d30170415b98a9eaeafe7334ceefda5c8b6193..21b8407fc5aad4df89fa0573545c0edf06f0b268 100644 --- a/mindformers/parallel_core/transformer_config_utils.py +++ b/mindformers/parallel_core/transformer_config_utils.py @@ -192,6 +192,8 @@ DEFAULT_WHITE_KEY.update({ 'scaling_factor', 'input_sliced_sig', 'return_extra_loss', 'moe_config' }) +ERROR_LOG = {} + COMMON_CONFIG_MAPPING = { ##################################################################################### # Maps the configuration keys on the left to the TransformerConfig keys on the right. 
@@ -220,11 +222,6 @@ COMMON_CONFIG_MAPPING = { # Training ("param_init_type", "params_dtype"): "params_dtype", # not changes - "finalize_model_grads_func": "finalize_model_grads_func", - "grad_scale_func": "grad_scale_func", - "grad_sync_func": "grad_sync_func", - "param_sync_func": "param_sync_func", - "num_microbatches_with_partial_activation_checkpoints": "num_microbatches_with_partial_activation_checkpoints", "print_separate_loss": "print_separate_loss", "disable_lazy_inline": "disable_lazy_inline", "batch_size": "batch_size", @@ -248,7 +245,6 @@ COMMON_CONFIG_MAPPING = { ("qkv_has_bias", "attention_bias", "add_qkv_bias"): "add_qkv_bias", ("expert_num", "n_routed_experts", "num_experts", "num_moe_experts"): "num_moe_experts", ("num_layers", "num_hidden_layers", "n_layer"): "num_layers", - ("rope_interleave", "rotary_interleaved"): "rotary_interleaved", ("use_qk_norm", "qk_layernorm"): "qk_layernorm", # not changes "hidden_size": "hidden_size", @@ -256,7 +252,6 @@ COMMON_CONFIG_MAPPING = { "hidden_dropout": "hidden_dropout", "attention_dropout": "attention_dropout", "apply_residual_connection_post_layernorm": "apply_residual_connection_post_layernorm", - "layernorm_zero_centered_gamma": "layernorm_zero_centered_gamma", "add_bias_linear": "add_bias_linear", "gated_linear_unit": "gated_linear_unit", "activation_func": "activation_func", @@ -280,6 +275,7 @@ COMMON_CONFIG_MAPPING = { "post_process": "post_process", "add_mlp_fc1_bias_linear": "add_mlp_fc1_bias_linear", "add_mlp_fc2_bias_linear": "add_mlp_fc2_bias_linear", + "rotary_interleaved": "rotary_interleaved", "fp32_residual_connection": "fp32_residual_connection", "window_size": "window_size", "window_attn_skip_freq": "window_attn_skip_freq", @@ -290,8 +286,6 @@ COMMON_CONFIG_MAPPING = { # Flash Attention # not changes "use_flash_attention": "use_flash_attention", - "attention_pre_tokens": "attention_pre_tokens", - "attention_next_tokens": "attention_next_tokens", "rotary_seq_len_interpolation_factor": 
"rotary_seq_len_interpolation_factor", "use_rope_scaling": "use_rope_scaling", "rope_scaling": "rope_scaling", @@ -316,7 +310,6 @@ COMMON_CONFIG_MAPPING = { # not changes "init_method": "init_method", "output_layer_init_method": "output_layer_init_method", - "init_model_with_meta_device": "init_model_with_meta_device", # Mixed-Precision ("softmax_compute_dtype", "attention_softmax_in_fp32"): ( @@ -324,17 +317,12 @@ COMMON_CONFIG_MAPPING = { ), # not changes "apply_query_key_layer_scaling": "apply_query_key_layer_scaling", - "disable_bf16_reduced_precision_matmul": "disable_bf16_reduced_precision_matmul", # Fusion ("use_fused_rope", "apply_rope_fusion"): "apply_rope_fusion", ("use_fused_swiglu", "bias_swiglu_fusion"): "bias_swiglu_fusion", ("use_fused_ops_permute", "moe_permute_fusion"): "moe_permute_fusion", # not changes - "bias_activation_fusion": "bias_activation_fusion", - "masked_softmax_fusion": "masked_softmax_fusion", - "persist_layer_norm": "persist_layer_norm", - "memory_efficient_layer_norm": "memory_efficient_layer_norm", "bias_dropout_fusion": "bias_dropout_fusion", # not changes @@ -373,13 +361,8 @@ COMMON_CONFIG_MAPPING = { "first_k_dense_replace": "first_k_dense_replace", "moe_shared_expert_overlap": "moe_shared_expert_overlap", "moe_router_pre_softmax": "moe_router_pre_softmax", - "moe_input_jitter_eps": "moe_input_jitter_eps", "moe_token_dispatcher_type": "moe_token_dispatcher_type", "group_wise_a2a": "group_wise_a2a", - "moe_enable_deepep": "moe_enable_deepep", - "moe_per_layer_logging": "moe_per_layer_logging", - "moe_pad_expert_input_to_capacity": "moe_pad_expert_input_to_capacity", - "moe_apply_probs_on_input": "moe_apply_probs_on_input", "comp_comm_parallel": "comp_comm_parallel", "comp_comm_parallel_degree": "comp_comm_parallel_degree", "norm_topk_prob": "norm_topk_prob", @@ -388,7 +371,6 @@ COMMON_CONFIG_MAPPING = { "topk_method": "topk_method", "npu_nums_per_device": "npu_nums_per_device", "callback_moe_droprate": "callback_moe_droprate", 
- "moe_init_method_std": "moe_init_method_std", "moe_router_force_expert_balance": "moe_router_force_expert_balance", "moe_router_fusion": "moe_router_fusion", "print_expert_load": "print_expert_load", @@ -458,8 +440,8 @@ def convert_to_transformer_config( """ # Check whether the type of `model_config` is legal if model_config is None or not isinstance(model_config, (dict, PretrainedConfig)): - raise ValueError(f"The ModelConfig should be an instance of 'PretrainedConfig' or 'dict', " - f"but got '{type(model_config)}'.") + raise TypeError(f"The ModelConfig should be an instance of 'PretrainedConfig' or 'dict', " + f"but got '{type(model_config)}'.") if isinstance(model_config, PretrainedConfig): model_config = model_config.to_dict() @@ -474,23 +456,31 @@ def convert_to_transformer_config( if not_convert_whitelist is None: not_convert_whitelist = set() not_convert_whitelist.update(DEFAULT_WHITE_KEY) - logger.info(f"These Keys of this model will do not need to be mapped: {not_convert_whitelist}") # Record the new Config after conversion update_dict = {} - # Record the keys of `model_config` outside the mapping rules in conversion - not_convert_keys_list = [] def mapping_config(key, value): """Map the model_config's key and add it to 'update_dict'.""" mapping_key = convert_map[key] if not isinstance(mapping_key, str): (mapping_key, trans_func) = mapping_key - value = trans_func(value) + try: + value = trans_func(value) + except Exception as e: + ERROR_LOG.setdefault('Function Convert Error', []).append( + f"The key `{key}` converted to `{mapping_key}` failed. " + f"Please check the function `{trans_func.__name__}`. " + f"And the Error info is: {e}" + ) if mapping_key in update_dict: - raise KeyError(f"Multiple configurations provided for the same setting. 
" -                         f"Please check these conflicting configs: {list(reversed_mapping[mapping_key])}") -        update_dict[mapping_key] = value +        ERROR_LOG.setdefault('Multiple Config', []).append( +            f"The key `{mapping_key}`({value}) is already converted into TransformerConfig. " +            f"And the value is: {update_dict[mapping_key]}. " +            f"Please check these conflicting configs: `{list(reversed_mapping[mapping_key])}`." +        ) +        else: +            update_dict[mapping_key] = value      # Start converting parameters     if 'parallel_config' in model_config: @@ -509,12 +499,19 @@          if model_config_key in convert_map:             mapping_config(model_config_key, model_config_value)         else: -            not_convert_keys_list.append(model_config_key) - -    # If there are any unconverted key values, print them out to inform the user to check the configuration -    if not_convert_keys_list: -        raise ValueError(f"Keys: {not_convert_keys_list} dose not be converted! " -                         f"Please check your config parameters.") +            ERROR_LOG.setdefault('Unexpected Keys', []).append( +                f"`{model_config_key}` is unexpected in TransformerConfig, therefore it will not be converted. " +                f"Please check your config parameters." +            ) + +    # If any conversion errors were recorded, raise them as one formatted message. +    if ERROR_LOG: +        error_str = "Errors occurred during the mapping process: \n" +        for title_idx, (key, value) in enumerate(ERROR_LOG.items()): +            error_str = error_str + f"    ({title_idx + 1}) {key}:\n" +            for idx, err in enumerate(value): +                error_str = error_str + f"        {idx + 1}. " + err + "\n" +        raise RuntimeError(f"Configuration conversion failed. 
Please check the error log:\n{error_str}") # If it is an MLA model, use MLATransformerConfig for initialization if is_mla_model: diff --git a/tests/st/test_ut/test_models/test_model_config_utils.py b/tests/st/test_ut/test_models/test_model_config_utils.py index 07024ed18505ab798e485d2075a43c75fb216d54..79f841981f6fff0373255199aea563cba2ac3611 100644 --- a/tests/st/test_ut/test_models/test_model_config_utils.py +++ b/tests/st/test_ut/test_models/test_model_config_utils.py @@ -98,7 +98,7 @@ class HuggingFaceModelConfigExample3(PretrainedConfig): @register_mf_model_parameter( mf_model_kwargs=MFModelConfig( compute_dtype="bfloat16", - attention_pre_tokens=2025 + batch_size=32 ) ) def __init__( @@ -182,4 +182,4 @@ class TestModelConfigUtils: assert ignore_n_shared_experts == "useless" assert config.compute_dtype == "float32" assert config.layernorm_compute_dtype == "bf16" - assert config.attention_pre_tokens == 2025 + assert config.batch_size == 32 diff --git a/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py b/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py index 477652a4d73254bcfc79f96172415426aef0ba11..28aa1a949a995634d094b45f04ea2d886c498b58 100644 --- a/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py +++ b/tests/st/test_ut/test_parallel_core/test_transformer_config/test_convert_to_transformer_config.py @@ -16,10 +16,14 @@ import pytest -# pylint: disable=import-outside-toplevel +import numpy as np +from mindspore import dtype as mstype + +from mindformers.parallel_core.mf_model_config import convert_str_to_mstype from mindformers.parallel_core.transformer_config_utils import convert_to_transformer_config from mindformers.parallel_core.transformer_config import TransformerConfig, MLATransformerConfig + class DummyConfig(dict): """A dummy config that behaves like a dict and supports attribute access.""" @@ -110,10 
+114,10 @@ def test_is_mla_model_case(): def test_passed_in_none_model_config_case(): """ Feature: Test none model config case. - Description: Input None config to raise ValueError. - Expectation: Capture the raised ValueError. + Description: Input None config to raise TypeError. + Expectation: Capture the raised TypeError. """ - with pytest.raises(ValueError): + with pytest.raises(TypeError): convert_to_transformer_config(None) @@ -156,7 +160,7 @@ def test_not_exist_key_in_mapping_case(): """ Feature: Test not exist keys in mapping case. Description: Input config has key that do not need to be converted. - Expectation: Capture the raised ValueError. + Expectation: Capture the raised RuntimeError. """ config = DummyConfig({ 'num_layers': 2, @@ -167,7 +171,7 @@ def test_not_exist_key_in_mapping_case(): 'context_parallel': 1, 'not_exist_key': 999 }) - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): convert_to_transformer_config(config) @@ -180,7 +184,6 @@ def test_empty_str_to_convert_str_to_mstype_case(): Description: Input str is empty. Expectation: Capture the raised ValueError. """ - from mindformers.parallel_core.mf_model_config import convert_str_to_mstype with pytest.raises(ValueError): convert_str_to_mstype('') @@ -192,16 +195,20 @@ def test_passed_in_dtype_case(): """ Feature: Test dtype passed in convert_str_to_mstype. Description: Input a string of dtype, mindspore dtype, and a numpy dtype. - Expectation: No interception for the string of dtype, but for mindspore dtype and numpy dtype. + Expectation: No interception for the expected string of dtype, + but for pass in an unexpected string, mindspore dtype, numpy dtype, or an integer. 
""" - from mindformers.parallel_core.mf_model_config import convert_str_to_mstype - from mindspore import dtype as mstype result = convert_str_to_mstype('bf16') assert result == mstype.bfloat16 with pytest.raises(TypeError): convert_str_to_mstype(mstype.float16) - import numpy as np with pytest.raises(TypeError): convert_str_to_mstype(np.float32) + + with pytest.raises(TypeError): + convert_str_to_mstype(32) + + with pytest.raises(ValueError): + convert_str_to_mstype('fp8')