diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
index 695c481f38bdd68c1e9feade3e40cacfa35ed37e..764de873882cb1eee8890a945c62d018e7c89cf0 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
@@ -24,6 +24,7 @@ from typing import Optional
 import mindspore as ms
 import numpy as np
 from mindspore import Tensor, mint, nn, ops
+import ms_custom_ops
 from mindspore.ops.auto_generate import (GroupedMatmulV4, MoeDistributeCombine,
                                          MoeDistributeDispatch,
                                          MoeGatingTopKSoftmax,
@@ -92,8 +93,8 @@ class FusedExperts(nn.Cell):
         if moe_config.moe_parallel_config.ep_size > 1 and \
                 moe_config.moe_parallel_config.tp_size == 1:
             self.moe_mode = MoeMode.PURE_EP
-            self.dispatch = MoeDistributeDispatch()
-            self.combine = MoeDistributeCombine()
+            self.dispatch = ms_custom_ops.moe_distribute_dispatch_v3
+            self.combine = ms_custom_ops.moe_distribute_combine_v3
             self.dispatch_tp_world_size = 0 if is_910b() else 1
             self.dispatch_shared_expert_num = 0 if is_910b() else 1
             self.max_bs = 256 if is_910b() else 512
@@ -245,4 +246,36 @@ class FusedExperts(nn.Cell):
 
     def run_ep(self, hidden_states, w1, w2, topk_ids, topk_weights,
                activation, global_num_experts):
-        raise NotImplementedError("ep mode not implemented yet.")
+        topk_ids = topk_ids.astype(ms.int32)
+        topk_weights = topk_weights.astype(ms.float32)
+        # Dispatch
+        sorted_input_tensor, _, assist_info_for_combine, group_list, ep_recv_counts, tp_recv_counts, _ = self.dispatch(
+            x=hidden_states,
+            expert_ids=topk_ids,
+            ep_world_size=self.ep_size,
+            ep_rank_id=self.ep_rank,
+            moe_expert_num=global_num_experts,
+            group_ep=self.ep_group,
+            tp_world_size=self.dispatch_tp_world_size,
+            global_bs=self.max_bs)
+
+        # GroupMatmul
+        expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list,
+                                  activation)
+
+        # Combine
+        moe_output = self.combine(
+            expand_x=expert_output,
+            expert_ids=topk_ids,
+            assist_info_for_combine=assist_info_for_combine,
+            ep_send_counts=ep_recv_counts,
+            expert_scales=topk_weights,
+            ep_world_size=self.ep_size,
+            ep_rank_id=self.ep_rank,
+            moe_expert_num=global_num_experts,
+            tp_send_counts=tp_recv_counts,
+            group_ep=self.ep_group,
+            tp_world_size=self.dispatch_tp_world_size,
+            shared_expert_num=self.dispatch_shared_expert_num,
+            global_bs=self.max_bs)
+        return moe_output
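
Below is a minimal single-process NumPy sketch of the dispatch -> grouped matmul -> combine data flow that the new run_ep path performs. It is illustrative only: the real ms_custom_ops kernels are collective (all-to-all across the EP group, with the ep_recv_counts returned by dispatch fed back into combine as ep_send_counts), and the helper name reference_moe, the ReLU stand-in activation, and the dense per-expert weight layout are assumptions, not part of this patch.

# Single-process sketch of the dispatch -> grouped matmul -> combine flow
# implemented by run_ep.  Illustrative only: the real ops are collective
# (all-to-all across the EP group); reference_moe, the ReLU stand-in and the
# dense weight layout are assumptions, not part of the patch.
import numpy as np


def reference_moe(hidden_states, w1, w2, topk_ids, topk_weights, num_experts):
    num_tokens, hidden = hidden_states.shape
    top_k = topk_ids.shape[1]

    # "Dispatch": replicate each token once per selected expert and sort the
    # copies by expert id.  `order` plays the role of assist_info_for_combine
    # (the permutation needed to undo the sort) and `group_list` the role of
    # the per-expert token counts consumed by the grouped matmul.
    flat_expert = topk_ids.reshape(-1)                 # (num_tokens * top_k,)
    order = np.argsort(flat_expert, kind="stable")
    sorted_input = np.repeat(hidden_states, top_k, axis=0)[order]
    group_list = np.bincount(flat_expert, minlength=num_experts)

    # "GroupedMatmul": run each expert's FFN over its contiguous token slice.
    expert_output = np.empty_like(sorted_input)
    start = 0
    for e in range(num_experts):
        end = start + group_list[e]
        h = np.maximum(sorted_input[start:end] @ w1[e], 0.0)  # ReLU stand-in
        expert_output[start:end] = h @ w2[e]
        start = end

    # "Combine": undo the expert sort, then sum the top_k expert outputs per
    # token weighted by topk_weights (the role of expert_scales in the patch).
    unsorted = np.empty_like(expert_output)
    unsorted[order] = expert_output
    unsorted = unsorted.reshape(num_tokens, top_k, hidden)
    return (unsorted * topk_weights[..., None]).sum(axis=1)


# Example: 4 tokens, hidden size 8, 4 experts, top-2 routing.
x = np.random.randn(4, 8).astype(np.float32)
w1 = np.random.randn(4, 8, 16).astype(np.float32)
w2 = np.random.randn(4, 16, 8).astype(np.float32)
ids = np.array([[0, 1], [2, 3], [0, 2], [1, 3]], dtype=np.int32)
weights = np.full((4, 2), 0.5, dtype=np.float32)
print(reference_moe(x, w1, w2, ids, weights, num_experts=4).shape)  # (4, 8)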