diff --git a/omni/models/common/config/model_config.py b/omni/models/common/config/model_config.py index 4bd1378b76701c141dd2e909c0383cbcc21f41ee..7471c55f97d7fe460680ba29b9507418810a5cdd 100644 --- a/omni/models/common/config/model_config.py +++ b/omni/models/common/config/model_config.py @@ -40,6 +40,7 @@ class ModelOperatorOptConfig: opt_w2_scale_cast: bool = False enable_mc2_v2: bool = False decode_gear_list: list[int] = field(default_factory=lambda: [1]) + enable_combine_addrmsnorm_fusion: bool = False control_accept_rate: float = -1 # <0 or >1 不控制, >=0 and <=1 控制MTP开启时接受率为该值,几乎必然导致输出结果异常,仅保证只投机1个token时满足这一数值 use_prefetch: bool = True # 是否开启预取 @@ -130,6 +131,7 @@ def update_model_extra_config(**kwargs): operator_opt_config.use_super_kernel = False operator_opt_config.use_prefetch = False operator_opt_config.use_mlaprolog = False + operator_opt_config.enable_combine_addrmsnorm_fusion = False model_extra_config = get_model_extra_config() diff --git a/omni/models/common/layers/layernorm.py b/omni/models/common/layers/layernorm.py index 27449c29a108eb73256438f7ab323056b16538af..8bd039add506ca548932684d4dabf1f988af5050 100644 --- a/omni/models/common/layers/layernorm.py +++ b/omni/models/common/layers/layernorm.py @@ -18,7 +18,10 @@ class RMSNorm(RMSNormGPU): quant_symbol: bool = False, ) -> Union[tuple[dict[str, Any], Any], Any]: if residual is not None: - x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon) + if model_extra_config.operator_opt_config.enable_combine_addrmsnorm_fusion: + x, _, residual = torch_npu.npu_add_rms_norm(residual, x, self.weight, self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon) if model_extra_config.operator_opt_config.use_w8a8_dynamic_quant and quant_symbol: x_int8, pertoken_scale = torch_npu.npu_dynamic_quant(x) x = {"x_int8": x_int8, "pertoken_scale": pertoken_scale} diff --git a/omni/models/deepseek/deepseek_v3.py b/omni/models/deepseek/deepseek_v3.py index 50257140108271e3732ffae4618b1ce6ae7ce4a7..6593b699428ed4759fcb1bc5c54ba44ba4252d49 100644 --- a/omni/models/deepseek/deepseek_v3.py +++ b/omni/models/deepseek/deepseek_v3.py @@ -148,6 +148,7 @@ class DeepseekDecoderLayer(nn.Module): self.layer_name = f"{prefix}.self_attn.attn" self.hidden_size = config.hidden_size self.quant_symbol = quant_config is not None + self.end_layer_idx = config.num_hidden_layers - config.first_k_dense_replace - 1 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", @@ -216,6 +217,9 @@ class DeepseekDecoderLayer(nn.Module): hidden_states, residual = self.input_layernorm( hidden_states, residual, quant_symbol=(not model_extra_config.operator_opt_config.use_mlaprolog and self.quant_symbol)) # Adapt end. + if isinstance(hidden_states, torch.Tensor) and model_extra_config.operator_opt_config.enable_combine_addrmsnorm_fusion: + hidden_states = hidden_states.view(-1, self.hidden_size) + residual = residual.view(-1, self.hidden_size) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -262,7 +266,12 @@ class DeepseekDecoderLayer(nn.Module): hidden_states, residual = self.mlp(hidden_states, residual, attn_metadata, layer_id, next_attention_weights) if isinstance(hidden_states, (tuple, list)): assert len(hidden_states) == 2 - hidden_states = hidden_states[0] + hidden_states[1] + is_moe_layer_and_not_end_layer = layer_id and layer_id < self.end_layer_idx + if model_extra_config.operator_opt_config.enable_combine_addrmsnorm_fusion and is_moe_layer_and_not_end_layer: + residual = residual.view(-1, 1, self.hidden_size) + hidden_states = hidden_states[0].view(-1, 1, self.hidden_size) + hidden_states[1].view(-1, 1, self.hidden_size) + else: + hidden_states = hidden_states[0] + hidden_states[1] else: hidden_states, residual = self.mlp(hidden_states, residual, attn_metadata) diff --git a/tests/test_config/test_config_decode.json b/tests/test_config/test_config_decode.json index 132ee20f5ba7a0cf2cdba11ccd71eec881b0a3c9..5bdf3ee17d0e2323445c84a040418e45d6ec8335 100644 --- a/tests/test_config/test_config_decode.json +++ b/tests/test_config/test_config_decode.json @@ -28,6 +28,7 @@ "use_prefetch": true, "expert_gate_up_prefetch": 50, "expert_down_prefetch": 28, - "attn_prefetch": 96 + "attn_prefetch": 96, + "enable_combine_addrmsnorm_fusion": false } } \ No newline at end of file diff --git a/tests/test_config/test_config_decode_dp288.json b/tests/test_config/test_config_decode_dp288.json index 5ccc46d815fb3ac687cb76814898cb9f7fb02208..d001ef4a9e141d0e5d1d2629aebbfb1496e810e6 100644 --- a/tests/test_config/test_config_decode_dp288.json +++ b/tests/test_config/test_config_decode_dp288.json @@ -30,6 +30,7 @@ "use_prefetch": true, "expert_gate_up_prefetch": 50, "expert_down_prefetch": 28, - "attn_prefetch": 96 + "attn_prefetch": 96, + "enable_combine_addrmsnorm_fusion": false } } diff --git a/tests/test_config/test_config_prefill.json b/tests/test_config/test_config_prefill.json index 6d5a60ca5cd797b152375ebf997e4fda6e13a95b..71939dcce3b9dd313358362dbc0a410ba85f678f 100644 --- a/tests/test_config/test_config_prefill.json +++ b/tests/test_config/test_config_prefill.json @@ -26,6 +26,7 @@ "control_accept_rate": -1, "enable_mc2_v2": false, "enable_prefill_micro_batch": false, - "use_prefetch": false + "use_prefetch": false, + "enable_combine_addrmsnorm_fusion": false } } \ No newline at end of file