From 08fb9ec692be98b5cc76511ce463fd749c629ebb Mon Sep 17 00:00:00 2001
From: zhanghanLeo
Date: Thu, 20 Nov 2025 15:45:15 +0800
Subject: [PATCH] [OpsSupport]: moe_gating_topk_softmax

---
 inferrt/python/mrt/torch/fx_backend.py        |  1 +
 .../aclnn/aclnn_moe_gating_topk_softmax.cc    | 62 ++++++++++++++
 .../aclnn/aclnn_moe_gating_topk_softmax.h     | 43 ++++++++++
 inferrt/src/ops/op_def/ops.list               |  1 +
 mopt/include/mopt/Dialect/Mrt/MrtOps.td       | 25 ++++++
 .../ops/test_aclnn_moe_gating_topk_softmax.py | 84 +++++++++++++++++++
 6 files changed, 216 insertions(+)
 create mode 100644 inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.cc
 create mode 100644 inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.h
 create mode 100644 tests/st/inferrt/ops/test_aclnn_moe_gating_topk_softmax.py

diff --git a/inferrt/python/mrt/torch/fx_backend.py b/inferrt/python/mrt/torch/fx_backend.py
index 1bee3bc5..fbfcb994 100644
--- a/inferrt/python/mrt/torch/fx_backend.py
+++ b/inferrt/python/mrt/torch/fx_backend.py
@@ -140,6 +140,7 @@ _OP_MAP = {
     torch.ops.npu.npu_moe_init_routing_v2: Op.moe_init_routing_v3,
     torch.ops.npu.npu_add_rms_norm: Op.add_rms_norm,
     torch.ops.npu.npu_rms_norm: Op.rms_norm,
+    torch.ops.npu.npu_moe_gating_top_k_softmax: Op.moe_gating_top_k_softmax,
     # operator functions
     operator.getitem: Op.tuple_getitem,
     operator.add: Op.add,
diff --git a/inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.cc b/inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.cc
new file mode 100644
index 00000000..86195ac9
--- /dev/null
+++ b/inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.cc
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+
+#include "ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.h"
+#include "ops/ascend/aclnn/utils/opapi_utils.h"
+#include "ops/op_register.h"
+
+namespace mrt {
+namespace ops {
+
+OpsErrorCode AclnnMoeGatingTopKSoftmax::InferShape(const std::vector<ir::Value *> &input, ir::Value *output) {
+  const auto x_shape = input[kIndex0]->ToTensor()->Shape();
+  auto k = input[kIndex2]->ToInt();
+  std::vector<int64_t> output_shape = x_shape;
+  output_shape[output_shape.size() - 1] = k;
+  ir::VisitAllTensors(output, [&](const ir::TensorPtr &tensor) {
+    tensor->Shape() = output_shape;
+    tensor->Resize();
+  });
+  return SUCCESS;
+}
+
+OpsErrorCode AclnnMoeGatingTopKSoftmax::CalcWorkspace(const std::vector<ir::Value *> &input,
+                                                      const ir::Value *output, size_t *workspaceSize) {
+  LOG_OUT << "Begin CalcWorkspace for op [moe_gating_top_k_softmax]";
+  auto &output_tuple = output->ToTuple();
+  executor_->GetWorkspaceSize(static_cast<uint64_t *>(workspaceSize), input[kIndex0]->ToTensor(),
+                              input[kIndex1]->IsTensor() ? input[kIndex1]->ToTensor() : nullptr,
+                              input[kIndex2]->ToInt(), (*output_tuple)[kIndex0]->ToTensor(),
+                              (*output_tuple)[kIndex1]->ToTensor(), (*output_tuple)[kIndex2]->ToTensor());
+  return SUCCESS;
+}
+
+OpsErrorCode AclnnMoeGatingTopKSoftmax::Launch(const std::vector<ir::Value *> &input, void *workspace,
+                                               size_t workspaceSize, ir::Value *output, void *stream) {
+  LOG_OUT << "Begin Launch op [moe_gating_top_k_softmax]";
+  auto &output_tuple = output->ToTuple();
+  executor_->Launch(workspace, workspaceSize, stream, input[kIndex0]->ToTensor(),
+                    input[kIndex1]->IsTensor() ? input[kIndex1]->ToTensor() : nullptr, input[kIndex2]->ToInt(),
+                    (*output_tuple)[kIndex0]->ToTensor(), (*output_tuple)[kIndex1]->ToTensor(),
+                    (*output_tuple)[kIndex2]->ToTensor());
+  return SUCCESS;
+}
+
+MRT_REG_OP(moe_gating_top_k_softmax, AclnnMoeGatingTopKSoftmax, Ascend);
+} // namespace ops
+} // namespace mrt
diff --git a/inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.h b/inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.h
new file mode 100644
index 00000000..c0b19572
--- /dev/null
+++ b/inferrt/src/ops/ascend/aclnn/aclnn_moe_gating_topk_softmax.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OPS_ASCEND_ACLNN_ACLNN_MOE_GATING_TOPK_SOFTMAX_H__
+#define __OPS_ASCEND_ACLNN_ACLNN_MOE_GATING_TOPK_SOFTMAX_H__
+
+#include <memory>
+#include "ops/operator.h"
+#include "ops/ascend/aclnn/utils/aclnn_executor.h"
+
+namespace mrt {
+namespace ops {
+class AclnnMoeGatingTopKSoftmax : public Operator {
+ public:
+  AclnnMoeGatingTopKSoftmax() { executor_ = std::make_unique<AclnnExecutor>("aclnnMoeGatingTopKSoftmax"); }
+  ~AclnnMoeGatingTopKSoftmax() override = default;
+
+  OpsErrorCode InferShape(const std::vector<ir::Value *> &input, ir::Value *output) override;
+  OpsErrorCode CalcWorkspace(const std::vector<ir::Value *> &input, const ir::Value *output,
+                             size_t *workspaceSize) override;
+  OpsErrorCode Launch(const std::vector<ir::Value *> &input, void *workspace, size_t workspaceSize,
+                      ir::Value *output, void *stream) override;
+
+ private:
+  std::unique_ptr<AclnnExecutor> executor_{nullptr};
+};
+
+} // namespace ops
+} // namespace mrt
+#endif // __OPS_ASCEND_ACLNN_ACLNN_MOE_GATING_TOPK_SOFTMAX_H__
diff --git a/inferrt/src/ops/op_def/ops.list b/inferrt/src/ops/op_def/ops.list
index 0fb3706e..6cebe2c3 100644
--- a/inferrt/src/ops/op_def/ops.list
+++ b/inferrt/src/ops/op_def/ops.list
@@ -55,3 +55,4 @@ OP(reduce_scatter)
 OP(all_to_all)
 OP(moe_init_routing_v3)
 OP(custom_call)
+OP(moe_gating_top_k_softmax)
diff --git a/mopt/include/mopt/Dialect/Mrt/MrtOps.td b/mopt/include/mopt/Dialect/Mrt/MrtOps.td
index 100d12e8..23266bb4 100644
--- a/mopt/include/mopt/Dialect/Mrt/MrtOps.td
+++ b/mopt/include/mopt/Dialect/Mrt/MrtOps.td
@@ -357,4 +357,29 @@ def Mrt_SplitOp : Mrt_Op<"split", [Pure]> {
   }];
 }
 
+
+def Mrt_MoeGatingTopKSoftmaxOp : Mrt_Op<"moe_gating_top_k_softmax", [Pure]> {
+  let summary = "Softmax over MoE gating inputs followed by top-k expert selection";
+  let description = [{
+    Applies a softmax over the last dimension of x, then selects the k largest values together with their expert indices and row indices.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$x,
+    MrtOptTensor:$finished,
+    Mrt_I64Type:$k
+  );
+
+  let results = (outs
+    AnyRankedTensor:$y,
+    AnyRankedTensor:$expert_idx,
+    AnyRankedTensor:$row_idx
+  );
+
+  let assemblyFormat = [{
+    $x (`,` $finished^)? `,` $k
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // MRT_DIALECT_MRT_OPS_TD
diff --git a/tests/st/inferrt/ops/test_aclnn_moe_gating_topk_softmax.py b/tests/st/inferrt/ops/test_aclnn_moe_gating_topk_softmax.py
new file mode 100644
index 00000000..2fecb5ca
--- /dev/null
+++ b/tests/st/inferrt/ops/test_aclnn_moe_gating_topk_softmax.py
@@ -0,0 +1,84 @@
+import pytest
+import numpy as np
+import torch
+
+from tests.mark_utils import arg_mark
+from tests.ops_utils import AssertRtolEqual
+from mrt.torch import backend
+
+
+
+
+def softmax_func(x):
+    is_fp16 = x.dtype == torch.float16
+    x = x.to(torch.float32)
+    x_max = x.max(dim=-1, keepdim=True).values
+    x_sub = x - x_max
+    y = torch.exp(x_sub)
+    x_sum = y.sum(dim=-1, keepdim=True)
+
+    # Handle the division-by-zero case when a row sums to zero
+    zero_mask = (x_sum == 0)
+    ans = torch.where(zero_mask, torch.tensor(0.0, device=x.device), y / x_sum)
+    if is_fp16:
+        ans = ans.to(torch.float16)
+        x_max = x_max.to(torch.float16)
+        x_sum = x_sum.to(torch.float16)
+    return ans, x_max, x_sum
+
+def op_func(x, finished_optional, k):
+    num_expert = x.shape[-1]
+    softmax, _, _ = softmax_func(x)
+
+    # Use a stable sort so tied values keep their original expert order
+    _, expert_idx = torch.sort(softmax, dim=-1, descending=True, stable=True)
+    expert_idx = expert_idx[:, :k]
+    y = torch.gather(softmax, -1, expert_idx)
+
+    if finished_optional is not None:
+        finished_optional = finished_optional.reshape(finished_optional.shape[0], 1)
+        finished_optional = finished_optional.repeat(1, k)
+        expert_idx = torch.where(finished_optional, num_expert, expert_idx)
+
+    batch_size, k_size = y.shape[0], y.shape[1]
+    row_idx = torch.arange(batch_size * k_size, device=x.device).reshape(k_size, batch_size).t()
+
+    if x.dtype == torch.float16:
+        y = y.to(torch.float16)
+
+    return y, expert_idx.to(torch.int32), row_idx.to(torch.int32)
+
+
+
+
+def get_op_func_compiled():
+    def custom_op_func(x, finished, k):
+        return torch.ops.npu.npu_moe_gating_top_k_softmax(x, finished, k)
+    return torch.compile(custom_op_func, backend=backend)
+
+
+
+@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential")
+@pytest.mark.parametrize("pipeline", (True, False))
+@pytest.mark.parametrize("dtype", (torch.float16, torch.float32))
+@pytest.mark.parametrize("n", [10, 420, 520])
+@pytest.mark.parametrize("k", [2, 4, 5, 9])
+@pytest.mark.parametrize("col", [200, 1256, 5120])
+def test_moe_gating_topk_softmax(pipeline, monkeypatch, dtype, n, k, col):
+    """
+    Feature: Test aclnn moe_gating_topk_softmax
+    Description: Test aclnn moe_gating_topk_softmax with fp32/fp16 inputs
+    Expectation: The result is correct
+    """
+    if pipeline:
+        monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on")
+
+    x = torch.rand(n, col, dtype=dtype)
+    finished = torch.randint(0, 2, (n,)).to(torch.bool)
+
+    x_npu = x.npu()
+    finished_npu = finished.npu()
+
+    y_golden, expert_idx_golden, row_idx_golden = op_func(x, finished, k)
+    op_func_compiled = get_op_func_compiled()
+    y, expert_idx, row_idx = [npu_output.detach().cpu() for npu_output in op_func_compiled(x_npu, finished_npu, k)]
+    AssertRtolEqual(y, y_golden)
+    AssertRtolEqual(expert_idx, expert_idx_golden)
+    AssertRtolEqual(row_idx, row_idx_golden)
--
Gitee
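
Usage note (illustrative, not part of the commit): a minimal sketch of driving the new
op through the mrt backend once this patch is applied, mirroring the test above. The
shapes and values below are made up for illustration; it assumes torch_npu is installed
and an Ascend device is available.

    import torch
    from mrt.torch import backend

    def gating(x, finished, k):
        # Lowered to Op.moe_gating_top_k_softmax via the _OP_MAP entry added in this patch.
        return torch.ops.npu.npu_moe_gating_top_k_softmax(x, finished, k)

    compiled = torch.compile(gating, backend=backend)
    x = torch.rand(16, 64, dtype=torch.float16).npu()   # [rows, num_experts] gating values
    finished = torch.zeros(16, dtype=torch.bool).npu()  # no row marked finished
    y, expert_idx, row_idx = compiled(x, finished, 4)   # each output has shape [16, 4]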