From e42d82f7ec36bf5b44b8dbf617f4bc38e0f23c81 Mon Sep 17 00:00:00 2001
From: zhutianyi <1589841300@qq.com>
Date: Thu, 25 Sep 2025 01:56:21 +0800
Subject: [PATCH] Support TP sharding of shared experts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../vllm/distributed/communication_op.py   |  12 ++
 .../vllm/distributed/parallel_state.py     |  32 +++++
 omni/layers/moe/deepseek_moe.py            | 113 +++++++++++++++---
 omni/models/common/config/model_config.py  |   1 +
 .../test_prefill_prefill_bf16.json         |   1 +
 5 files changed, 145 insertions(+), 14 deletions(-)

diff --git a/omni/adaptors/vllm/distributed/communication_op.py b/omni/adaptors/vllm/distributed/communication_op.py
index 5ea906cd2..54d86fd6b 100644
--- a/omni/adaptors/vllm/distributed/communication_op.py
+++ b/omni/adaptors/vllm/distributed/communication_op.py
@@ -14,6 +14,7 @@ from omni.adaptors.vllm.distributed.parallel_state import (
     get_round_cross_group_from_list,
     get_near_cross_group_from_list,
     get_mlp_tp_group,
+    get_share_expert_tp_group,
     GroupCoordinator,
 )
 from omni.models.common.config.model_config import model_extra_config
@@ -54,6 +55,17 @@ def all_gather_local(input_: torch.Tensor, idx: int, dim=-1) -> torch.Tensor:
 def all_gather_cross(input_: torch.Tensor, idx: int, dim=-1) -> torch.Tensor:
     return get_cross_group_from_list(idx).all_gather(input_, dim)
 
+def share_expert_all_gather(input_: torch.Tensor, dim=-1, comm_group: Optional[GroupCoordinator] = None):
+    if comm_group is None:
+        return get_share_expert_tp_group().all_gather(input_, dim)
+    else:
+        return comm_group.all_gather(input_, dim)
+
+def share_expert_reduce_scatter(input_: torch.Tensor, dim=-1, comm_group: Optional[GroupCoordinator] = None):
+    if comm_group is None:
+        return get_share_expert_tp_group().reduce_scatter(input_, dim)
+    else:
+        return comm_group.reduce_scatter(input_, dim)
 
 def mlp_all_gather(input_: torch.Tensor, dim=-1, comm_group: Optional[GroupCoordinator] = None):
     if comm_group is None:
diff --git a/omni/adaptors/vllm/distributed/parallel_state.py b/omni/adaptors/vllm/distributed/parallel_state.py
index 103f9cf0d..0954b8061 100644
--- a/omni/adaptors/vllm/distributed/parallel_state.py
+++ b/omni/adaptors/vllm/distributed/parallel_state.py
@@ -135,6 +135,7 @@ _CROSS_ROUND_COMM_LIST = None
 # kept for backward compatibility
 _LOCAL_WORLD: Optional[GroupCoordinator] = None
 _MLP_TP: Optional[GroupCoordinator] = None
+_SHARE_EXPERT_TP: Optional[GroupCoordinator] = None
 _STREAM1_ATTN_GROUP: Optional[GroupCoordinator] = None
 _STREAM1_MLP_GROUP: Optional[GroupCoordinator] = None
 _STREAM1_MOE_GROUP: Optional[GroupCoordinator] = None
@@ -159,6 +160,7 @@ def initialize_model_parallel(
     )
     initialize_mlp_tp_group(backend)
     initialize_local_world_group(backend)
+    initialize_share_expert_tp_group(backend)
 
     if model_extra_config.operator_opt_config.enable_prefill_micro_batch:
         initialize_stream1_attn_group(backend)
@@ -324,6 +326,34 @@ def initialize_mlp_tp_group(backend) -> None:
         group_name="mlp_tp_group",
     )
 
+def initialize_share_expert_tp_group(backend) -> None:
+    # Get world size and rank. Ensure some consistencies.
+    if not torch.distributed.is_initialized():
+        raise RuntimeError("torch.distributed must be initialized")
+    world_size: int = torch.distributed.get_world_size()
+    share_expert_tp_size = model_extra_config.parall_config.share_expert_tp_size
+    if world_size % share_expert_tp_size != 0:
+        raise RuntimeError(f"world size ({world_size}) must be divisible by share expert TP size ({share_expert_tp_size})")
+    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
+
+    num_local_groups: int = world_size // share_expert_tp_size
+    global _SHARE_EXPERT_TP
+    if _SHARE_EXPERT_TP is not None:
+        raise RuntimeError("_SHARE_EXPERT_TP must be None")
+    group_ranks = []
+    for i in range(num_local_groups):
+        ranks = list(range(i * share_expert_tp_size, (i + 1) * share_expert_tp_size))
+        group_ranks.append(ranks)
+
+    # message queue broadcaster is only used in tensor model parallel group
+    _SHARE_EXPERT_TP = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        use_message_queue_broadcaster=False,
+        group_name="share_expert_tp_group",
+    )
+
 
 def initialize_o_proj_tp_group(backend) -> None:
     # Get world size and rank. Ensure some consistencies.
@@ -506,6 +536,8 @@ def initialize_world_comm_group_list(backend) -> None:
 def get_mlp_tp_group() -> GroupCoordinator:
     return _MLP_TP
 
+def get_share_expert_tp_group() -> GroupCoordinator:
+    return _SHARE_EXPERT_TP
 
 def get_o_proj_tp_group() -> GroupCoordinator:
     return _O_PROJ_TP
diff --git a/omni/layers/moe/deepseek_moe.py b/omni/layers/moe/deepseek_moe.py
index 1879da464..db6f3ed74 100644
--- a/omni/layers/moe/deepseek_moe.py
+++ b/omni/layers/moe/deepseek_moe.py
@@ -45,6 +45,8 @@ from vllm.model_executor.layers.linear import (
 )
 from omni.layers.linear import (
     MergedReplicatedLinear,
+    AscendMergedColumnParallelLinear,
+    AscendRowParallelLinear
 )
 from omni.layers.activation import SiluAndMul
 from omni.layers.moe.fused_moe.layer import FusedMoE, UNQUANT_MODE, DYNAMIC_QUANT_MODE
@@ -62,6 +64,7 @@ from omni.layers.moe.fused_moe.layer import FusedMoE
 from omni.models.common.config.model_config import model_extra_config
 from omni.layers.moe.fused_moe.fused_moe import fused_experts_moe_dispatch_combine
 from omni.adaptors.vllm.utils import get_attr_by_names
+from omni.adaptors.vllm.distributed.parallel_state import get_share_expert_tp_group
 
 if model_extra_config.operator_opt_config.use_omni_placement:
     from omni.accelerators.placement.omni_placement.omni_planner import OmniPlanner
@@ -143,6 +146,56 @@ class ReplicatedDeepseekMLP(nn.Module):
         x, _ = self.down_proj.forward(x)
         return x
 
+class ParallelDeepSeekMLP(nn.Module):
+    def __init__(self,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 reduce_results: bool = True,
+                 prefix: str = "",
+                 ) -> None:
+        super().__init__()
+        self.prefix = prefix
+        self.gate_up_proj = AscendMergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            tp_size=get_share_expert_tp_group().world_size,
+            tp_rank=get_share_expert_tp_group().rank_in_group,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
+        self.down_proj = AscendRowParallelLinear(intermediate_size,
+                                                 hidden_size,
+                                                 tp_size=get_share_expert_tp_group().world_size,
+                                                 tp_rank=get_share_expert_tp_group().rank_in_group,
+                                                 bias=False,
+                                                 quant_config=quant_config,
+                                                 reduce_results=False,
+                                                 prefix=f"{prefix}.down_proj")
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
" + "Only silu is supported for now.") + self.act_fn_obj = SiluAndMul() + self.quant_symbol = True if quant_config else False + self.tp_size = get_share_expert_tp_group().world_size + + def act_fn(self, x, quant_symbol): + if quant_symbol and isinstance(x, tuple): + x = dict(zip(['x_int8', 'pertoken_scale'], x)) + x['out_scale'] = self.gate_up_proj.weight_scale + return self.act_fn_obj(x, quant_symbol) + + def forward(self,x): + x = get_share_expert_tp_group().all_gather(x, dim=0) + + gate_up, _ = self.gate_up_proj.forward(x) + x = self.act_fn(gate_up, self.quant_symbol) + x, _ = self.down_proj.forward(x) + + x = get_share_expert_tp_group().reduce_scatter(x) + return x + + def DynamicPruningUnsorted(topk_weights: torch.Tensor, topk_ids: torch.Tensor, thresholds: torch.Tensor): @@ -175,6 +228,7 @@ class DeepseekMoE(nn.Module): n_routed_experts_names = ['num_routed_experts', 'n_routed_experts'] self.n_routed_experts = get_attr_by_names(config, n_routed_experts_names, 256) self.redundancy_shared_expert_num = model_extra_config.parall_config.redundancy_shared_expert_num + self.shared_experts_tp_size = model_extra_config.parall_config.share_expert_tp_size self.quant_symbol = quant_config is not None self.is_init_gate = False if os.getenv("ASCEND_PLATFORM", "A3")=="A2": @@ -268,14 +322,24 @@ class DeepseekMoE(nn.Module): self.moe_layer_idx = OmniPlanner.get_deepseek_v3_moe_layer_idx(f"{prefix}.share_experts", first_k_dense_replace=self.first_k_dense_replace) self.expert_mapping = self.planner.expert_mapping_on_current_layer(self.moe_layer_idx, is_prefill=False) - self.shared_experts = ReplicatedDeepseekMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.shared_experts", - ) + if self.shared_experts_tp_size > 1: + self.shared_experts = ParallelDeepSeekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = ReplicatedDeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) if self.experts is not None: self.w13_prefetch_size = model_extra_config.operator_opt_config.expert_gate_up_prefetch * 1024 * 1024 @@ -360,7 +424,7 @@ class DeepseekMoE(nn.Module): attn_metadata=attn_metadata ) - if model_extra_config.operator_opt_config.prefill_moe_all_to_all: + if not self.quant_symbol: if len(final_hidden_states_list) != 4: raise RuntimeError("len(final_hidden_states_list) != 4") final_hidden_states = final_hidden_states_list[0] @@ -369,10 +433,10 @@ class DeepseekMoE(nn.Module): else: final_hidden_states = final_hidden_states_list - if not model_extra_config.operator_opt_config.prefill_moe_all_to_all: - final_hidden_states = reduce_scatter_two_stage(final_hidden_states, idx=0) + if not model_extra_config.operator_opt_config.decode_moe_dispatch_combine: + final_hidden_states = get_world_group().reduce_scatter(final_hidden_states) - if model_extra_config.operator_opt_config.prefill_moe_all_to_all: + if not self.quant_symbol: final_hidden_states = torch_npu.npu_moe_finalize_routing( gathered_tokens, skip1=shared_output, @@ -436,7 +500,7 @@ class DeepseekMoE(nn.Module): topk_ids = torch.round(topk_ids).to(torch.int32) global_pertoken_scale = 
 
-        final_hidden_states = self.experts(
+        final_hidden_states_list = self.experts(
             hidden_states=global_hidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
@@ -444,10 +508,31 @@
             attn_metadata=attn_metadata
         )
 
+        if not self.quant_symbol:
+            if len(final_hidden_states_list) != 4:
+                raise RuntimeError("len(final_hidden_states_list) != 4")
+            final_hidden_states = final_hidden_states_list[0]
+            gathered_tokens = final_hidden_states_list[1]
+            expanded_row_idx = final_hidden_states_list[3]
+        else:
+            final_hidden_states = final_hidden_states_list
+
         if not model_extra_config.operator_opt_config.decode_moe_dispatch_combine:
             final_hidden_states = get_world_group().reduce_scatter(final_hidden_states)
 
-        final_hidden_states = final_hidden_states + shared_output
+        if not self.quant_symbol:
+            final_hidden_states = torch_npu.npu_moe_finalize_routing(
+                gathered_tokens,
+                skip1=shared_output,
+                skip2=None,
+                bias=None,
+                scales=topk_weights.to(gathered_tokens.dtype),
+                expanded_src_to_dst_row=expanded_row_idx,
+                export_for_source_row=None,
+                drop_pad_mode=2
+            )
+        else:
+            final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states, residual
 
diff --git a/omni/models/common/config/model_config.py b/omni/models/common/config/model_config.py
index cea8bbccd..ee4eeecf1 100644
--- a/omni/models/common/config/model_config.py
+++ b/omni/models/common/config/model_config.py
@@ -12,6 +12,7 @@ class ModelParallelConfig:
     dense_mlp_tp_size: int = 1
     dp_size: int = 1
     o_proj_tp_size: int = 1
+    share_expert_tp_size: int = 1
     redundancy_shared_expert_num: int = 0
 
 
diff --git a/tests/test_config/test_prefill_prefill_bf16.json b/tests/test_config/test_prefill_prefill_bf16.json
index d1698b59f..8103a0585 100644
--- a/tests/test_config/test_prefill_prefill_bf16.json
+++ b/tests/test_config/test_prefill_prefill_bf16.json
@@ -2,6 +2,7 @@
     "model_parallel_config": {
         "dense_mlp_tp_size": 4,
         "o_proj_tp_size": 2,
+        "share_expert_tp_size": 1,
         "dp_size": 1
     },
     "operator_optimizition_config": {
--
Gitee
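
Below is a minimal usage sketch, not part of the patch itself. It assumes a world size of 8 with "share_expert_tp_size": 2 set under "model_parallel_config" (the world size must be divisible by this value, as checked in initialize_share_expert_tp_group); the rank-layout snippet mirrors the loop in that function, and the shape comments are illustrative.

    # Rank layout built by initialize_share_expert_tp_group for
    # world_size=8, share_expert_tp_size=2 (same loop as in the patch):
    world_size, share_expert_tp_size = 8, 2
    group_ranks = [
        list(range(i * share_expert_tp_size, (i + 1) * share_expert_tp_size))
        for i in range(world_size // share_expert_tp_size)
    ]
    assert group_ranks == [[0, 1], [2, 3], [4, 5], [6, 7]]

    # Inside ParallelDeepSeekMLP.forward, each rank's [num_tokens, hidden_size]
    # shard is all-gathered along dim 0 to [num_tokens * 2, hidden_size], run
    # through the column/row-sharded gate_up_proj / down_proj, and then
    # reduce-scattered back to [num_tokens, hidden_size], so callers see the
    # same shapes as with ReplicatedDeepseekMLP.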