From 432a5c56a438654ff0d55457b9335aebc1dc078f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E4=B8=80=E9=A3=9E?=
Date: Tue, 18 Nov 2025 16:46:11 +0800
Subject: [PATCH] enable qwen3_moe to turn off FlashComm1

---
 omni/models/config_loader/loader.py           |   1 +
 .../configs/qwen3_30b_a3b_bf16_a3_1p1d_d.json |   3 +-
 omni/models/qwen/fused_moe/layer.py           | 154 ++++++++++++++++--
 omni/models/qwen/qwen3_moe.py                 |  41 ++++-
 4 files changed, 177 insertions(+), 22 deletions(-)

diff --git a/omni/models/config_loader/loader.py b/omni/models/config_loader/loader.py
index 3d1f7c0ca0..32cd4c4d0a 100644
--- a/omni/models/config_loader/loader.py
+++ b/omni/models/config_loader/loader.py
@@ -81,6 +81,7 @@ class ModelOperatorOptConfig:
     gmm_nz: bool = False
     unquant_bmm_nz: bool = False
     decode_moe_dispatch_combine: bool = True
+    decode_flash_comm_1: bool = True  # enable the FlashComm1 optimization on decode nodes
     use_super_kernel: bool = False
     enable_prefill_micro_batch: bool = False
     use_mlaprolog: bool = False
diff --git a/omni/models/configs/qwen3_30b_a3b_bf16_a3_1p1d_d.json b/omni/models/configs/qwen3_30b_a3b_bf16_a3_1p1d_d.json
index 1f15b2a454..caad5bad58 100644
--- a/omni/models/configs/qwen3_30b_a3b_bf16_a3_1p1d_d.json
+++ b/omni/models/configs/qwen3_30b_a3b_bf16_a3_1p1d_d.json
@@ -9,7 +9,8 @@
     "merge_qkv": false,
     "gmm_nz": true,
     "unquant_bmm_nz": true,
-    "decode_moe_dispatch_combine": true,
+    "decode_moe_dispatch_combine": false,
+    "decode_flash_comm_1": false,
     "use_super_kernel": true,
     "use_mlaprolog": false,
     "control_accept_rate": -1,
diff --git a/omni/models/qwen/fused_moe/layer.py b/omni/models/qwen/fused_moe/layer.py
index c25a908a71..a3a363ec64 100644
--- a/omni/models/qwen/fused_moe/layer.py
+++ b/omni/models/qwen/fused_moe/layer.py
@@ -492,7 +492,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=None
+            indices_type=None,
+            finished=~layer.mc2_mask if layer.mc2_mask is not None else None
         )
         topk_ids = topk_ids.int()
         topk_ids = layer.apply_expert_load_balance(
@@ -581,6 +582,104 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):

         return output_combine

+    def apply_allreduce_decode(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_range: List[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = 'softmax',
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=None,
+            finished=~layer.mc2_mask if layer.mc2_mask is not None else None
+        )
+        topk_ids = layer.apply_expert_load_balance(
+            topk_ids=topk_ids,
+            best_topk_ids=None
+        )
+        sorted_tokens, expanded_x_idx, expert_tokens, _ = torch_npu.npu_moe_init_routing_v2(
+            x,
+            topk_ids,
+            scale=None,
+            offset=None,
+            active_num=topk_ids.numel(),
+            expert_num=global_num_experts,
+            expert_capacity=-1,
+            drop_pad_mode=0,
+            expert_tokens_num_type=1,
+            expert_tokens_num_flag=True,
+            active_expert_range=expert_range,
+            quant_mode=-1,
+            row_idx_type=0
+        )
+
+        gate_up_proj = torch_npu.npu_grouped_matmul(
+            [sorted_tokens],
+            [layer.w13_weight],
+            bias=None,
+            group_list=expert_tokens,
+            split_item=3,
+            output_dtype=sorted_tokens.dtype,
+            group_type=0,
+            group_list_type=1
+        )[0]
+        x = torch_npu.npu_swiglu(gate_up_proj)
+        y = torch_npu.npu_grouped_matmul(
+            [x],
+            [layer.w2_weight],
+            bias=None,
+            group_list=expert_tokens,
+            split_item=3,
+            output_dtype=x.dtype,
+            group_type=0,
+            group_list_type=1
+        )[0]
+
+        # Zero out topk_weights for experts that do not live on this rank
+        valid_mask = (topk_ids >= expert_range[0]) & (topk_ids < expert_range[1])
+        topk_weights = topk_weights * valid_mask.to(topk_weights.dtype)
+
+        # Older CANN versions have accuracy issues when expanded_x_idx contains negative values, so eliminate them
+        expanded_x_idx = (expanded_x_idx + expanded_x_idx.shape[0]) % expanded_x_idx.shape[0]
+
+        y = torch_npu.npu_moe_finalize_routing(
+            y, None, None, None,
+            topk_weights,  # dtype must match y
+            expanded_x_idx,
+            topk_ids,
+            drop_pad_mode=2
+        )
+
+        y = get_tp_group().all_reduce(y)
+
+        if model_extra_config.task_config.enable_omni_placement:
+            layer.planner.record_activation(
+                layer.moe_layer_idx,
+                expert_tokens,
+                support_multi_stream=model_extra_config.operator_opt_config.moe_multi_stream_tune
+            )
+
+        return y
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -645,6 +744,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                 scoring_func,
                 e_score_correction_bias
             )
+        elif not model_extra_config.operator_opt_config.decode_flash_comm_1:
+            return self.apply_allreduce_decode(
+                layer,
+                x,
+                router_logits,
+                top_k,
+                renormalize,
+                use_grouped_topk,
+                topk_group,
+                num_expert_group,
+                global_num_experts,
+                expert_range,
+                custom_routing_function,
+                scoring_func,
+                e_score_correction_bias
+            )
         else:
             x = get_ep_group().all_gather(x, dim=0)
             router_logits = get_ep_group().all_gather(router_logits, dim=0)
@@ -819,7 +934,7 @@ class FusedMoE(torch.nn.Module):
         self.planner = kwargs.get("planner", None)
         self.moe_layer_idx = kwargs.get("moe_layer_idx", None)
         self.expert_mapping = kwargs.get("expert_mapping", None)
-        self.is_prefill_instance = os.environ.get("ROLE", "") == "prefill"
+        self.is_prefill = os.environ.get("ROLE", "") == "prefill"
         self.prefix = prefix

         if params_dtype is None:
@@ -853,10 +968,18 @@ class FusedMoE(torch.nn.Module):

         # Determine expert maps
         if self.use_ep:
-            self.local_num_experts, self.expert_range = determine_expert_range(
-                ep_size=self.ep_size,
-                ep_rank=self.ep_rank,
-                global_num_experts=self.global_num_experts)
+            if self.is_prefill or \
+                model_extra_config.operator_opt_config.decode_moe_dispatch_combine or \
+                model_extra_config.operator_opt_config.decode_flash_comm_1:
+                self.local_num_experts, self.expert_range = determine_expert_range(
+                    ep_size=self.ep_size,
+                    ep_rank=self.ep_rank,
+                    global_num_experts=self.global_num_experts)
+            else:
+                self.local_num_experts, self.expert_range = determine_expert_range(
+                    ep_size=get_tp_group().world_size,
+                    ep_rank=get_tp_group().rank_in_group,
+                    global_num_experts=self.global_num_experts)
         else:
             self.local_num_experts, self.expert_range = (self.global_num_experts, None)

@@ -1126,7 +1249,7 @@ class FusedMoE(torch.nn.Module):

         # Forced load balance
         if model_extra_config.operator_opt_config.best_ep:
-            if self.is_prefill_instance:
+            if self.is_prefill:
                 t = (topk_ids.shape[0] * 8) // 256
                 topk_ids = torch.arange(256, device=current_platform.device_type, dtype=torch.int32).unsqueeze(
                     0).repeat(t + 1, 1).view(-1, 8)[:topk_ids.shape[0]]
@@ -1150,7 +1273,8 @@ class FusedMoE(torch.nn.Module):
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None,
                        routed_scaling_factor: Optional[torch.Tensor] = None,
-                       indices_type: Optional[torch.dtype] = None):
+                       indices_type: Optional[torch.dtype] = None,
+                       finished: Optional[torch.Tensor] = None):
         # DeepSeekV2 uses grouped_top_k
         if e_score_correction_bias is None:
             e_score_correction_bias = FusedMoE.ZERO_CORRECTION_BIAS
@@ -1181,7 +1305,8 @@ class FusedMoE(torch.nn.Module):

             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
                 router_logits.float(),
-                k=top_k
+                k=top_k,
+                finished=finished
             )
             if renormalize:
                 topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
@@ -1210,9 +1335,14 @@ class FusedMoE(torch.nn.Module):
         if isinstance(attn_metadata, dict):
             attn_metadata = attn_metadata[next(iter(attn_metadata))]
         if attn_metadata is not None and attn_metadata.mc2_mask is not None:
-            mc2_mask_slice_size = hidden_states.shape[0]
-            mc2_mask_slice_id = get_tp_group().rank_in_group
-            self.mc2_mask = attn_metadata.mc2_mask[mc2_mask_slice_id * mc2_mask_slice_size : (mc2_mask_slice_id + 1) * mc2_mask_slice_size]
+            if self.is_prefill or \
+                model_extra_config.operator_opt_config.decode_moe_dispatch_combine or \
+                model_extra_config.operator_opt_config.decode_flash_comm_1:
+                mc2_mask_slice_size = hidden_states.shape[0]
+                mc2_mask_slice_id = get_tp_group().rank_in_group
+                self.mc2_mask = attn_metadata.mc2_mask[mc2_mask_slice_id * mc2_mask_slice_size : (mc2_mask_slice_id + 1) * mc2_mask_slice_size]
+            else:
+                self.mc2_mask = attn_metadata.mc2_mask

         final_hidden_states = self.quant_method.apply(
             layer=self,
diff --git a/omni/models/qwen/qwen3_moe.py b/omni/models/qwen/qwen3_moe.py
index 02fec18352..737492d851 100644
--- a/omni/models/qwen/qwen3_moe.py
+++ b/omni/models/qwen/qwen3_moe.py
@@ -23,6 +23,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
 from typing import Any, Optional, Union, List, Tuple
+import os

 import torch
 import torch_npu
@@ -130,7 +131,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                                            is_prefill=is_prefill)

         final_hidden_states = final_hidden_states
-        return final_hidden_states.view(orig_shape)
+        if is_prefill or \
+            model_extra_config.operator_opt_config.decode_moe_dispatch_combine or \
+            model_extra_config.operator_opt_config.decode_flash_comm_1:
+            return final_hidden_states.view(orig_shape)
+        return final_hidden_states


 class Qwen3MoeAttention(nn.Module):
@@ -231,8 +236,13 @@ class Qwen3MoeAttention(nn.Module):
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        is_prefill = attn_metadata is None or not attn_metadata.is_pd_seperate_d
-        qkv, _ = self.qkv_proj(hidden_states, x_transform='AG', is_prefill = is_prefill)
+        is_prefill = os.environ.get("ROLE", "") == "prefill"
+        if is_prefill or \
+            model_extra_config.operator_opt_config.decode_moe_dispatch_combine or \
+            model_extra_config.operator_opt_config.decode_flash_comm_1:
+            qkv, _ = self.qkv_proj(hidden_states, x_transform='AG', is_prefill = is_prefill)
+        else:
+            qkv, _ = self.qkv_proj(hidden_states, is_prefill = is_prefill)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
@@ -275,7 +285,12 @@ class Qwen3MoeAttention(nn.Module):
             ).transpose(0, 1).contiguous().view(local_s, -1)
             output,_ = self.o_proj.forward(attn_output)
         else:
-            output, _ = self.o_proj(attn_output, reduce_type="RS")
+            if is_prefill or \
+                model_extra_config.operator_opt_config.decode_moe_dispatch_combine or \
+                model_extra_config.operator_opt_config.decode_flash_comm_1:
+                output, _ = self.o_proj(attn_output, reduce_type="RS")
+            else:
+                output, _ = self.o_proj(attn_output, reduce_type="AR")

         return output

@@ -361,7 +376,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attn_metadata = forward_context.attn_metadata
         if isinstance(attn_metadata, dict):
             attn_metadata = attn_metadata[next(iter(attn_metadata))]
-        is_prefill = attn_metadata is None or not attn_metadata.is_pd_seperate_d
+        is_prefill = os.environ.get("ROLE", "") == "prefill"
         if is_prefill and model_extra_config.operator_opt_config.enable_mlp_seq_split:
             local_length = hidden_states.shape[0]
             reduce_length = torch.tensor(local_length, dtype=torch.int64, device="npu")
@@ -454,10 +469,13 @@ class Qwen3MoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        # FlashComm 1.0: convert hidden_states to DP by slicing
         aux_hidden_states = []
-        hidden_states = self.get_tp_slice(hidden_states)
-
+        is_prefill = os.environ.get("ROLE", "") == "prefill"
+        if is_prefill or \
+            model_extra_config.operator_opt_config.decode_moe_dispatch_combine or \
+            model_extra_config.operator_opt_config.decode_flash_comm_1:
+            # FlashComm 1.0: convert hidden_states to DP by slicing
+            hidden_states = self.get_tp_slice(hidden_states)
         for i in range(self.start_layer, self.end_layer):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(hidden_states + residual)
@@ -473,7 +491,12 @@ class Qwen3MoeModel(nn.Module):
                 "hidden_states": hidden_states,
                 "residual": residual
             })
-        hidden_states, _ = self.norm(hidden_states, residual, y_transform='AG')
+        if is_prefill or \
+            model_extra_config.operator_opt_config.decode_moe_dispatch_combine or \
+            model_extra_config.operator_opt_config.decode_flash_comm_1:
+            hidden_states, _ = self.norm(hidden_states, residual, y_transform='AG')
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         if len(aux_hidden_states) > 0:
             return aux_hidden_states, hidden_states
         return hidden_states
-- 
Gitee
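
Note on the gating condition: every call site in this patch tests the same
three-way predicate (prefill role, decode_moe_dispatch_combine, or
decode_flash_comm_1). Only when all three are false does the model fall back
to the plain tensor-parallel layout: hidden states are not sliced per rank,
qkv_proj skips the 'AG' transform, o_proj reduces with "AR" instead of "RS",
and the MoE block takes the new apply_allreduce_decode path, which ends in a
tensor-parallel all-reduce. The sketch below summarizes that predicate,
assuming only the ModelOperatorOptConfig fields shown above; the helper name
is hypothetical and does not exist in the repository.

    import os

    def uses_flash_comm1_layout(opt_config) -> bool:
        """Hypothetical helper summarizing the condition repeated in this patch.

        True keeps activations in the sliced (FlashComm1 / DP) layout: always
        on prefill instances, and on decode instances unless both
        decode_moe_dispatch_combine and decode_flash_comm_1 are turned off, as
        done for the decode node in qwen3_30b_a3b_bf16_a3_1p1d_d.json above.
        """
        is_prefill = os.environ.get("ROLE", "") == "prefill"
        return (is_prefill
                or opt_config.decode_moe_dispatch_combine
                or opt_config.decode_flash_comm_1)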