diff --git a/docs/op_list.md b/docs/op_list.md index bf5223a71c657f17c74591b9be9237a8bfb142f2..da7e33f2a726e979ba7801ab73ffdd7cbcd943f4 100644 --- a/docs/op_list.md +++ b/docs/op_list.md @@ -12,12 +12,13 @@ 10. [mla](../ops/c_api/mla/mla_doc.md) 11. [mla_preprocess](../ops/c_api/mla_preprocess/mla_preprocess_doc.md) 12. [moe_gating_group_topk](../ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md) -13. [paged_cache_load](../ops/c_api/paged_cache_load/paged_cache_load_doc.md) -14. [quant_batch_matmul](../ops/c_api/quant_batch_matmul/quant_batch_matmul.md) -15. [reshape_and_cache](../ops/c_api/reshape_and_cache/reshape_and_cache.md) -16. [reshape_and_cache_npd](../ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.md) -17. [rope](../ops/c_api/rope/rope.md) -18. [scatter_nd_update](../ops/c_api/scatter_nd_update/scatter_nd_update.md) -19. [sparse_flash_attention](../ops/c_api/sparse_flash_attention/sparse_flash_attention_doc.md) -20. [trans_data](../ops/c_api/trans_data/trans_data.md) -21. [type_cast](../ops/c_api/type_cast/type_cast.md) +13. [moe_token_unpermute](../ops/c_api/moe_token_unpermute/moe_token_unpermute.md) +14. [paged_cache_load](../ops/c_api/paged_cache_load/paged_cache_load_doc.md) +15. [quant_batch_matmul](../ops/c_api/quant_batch_matmul/quant_batch_matmul.md) +16. [reshape_and_cache](../ops/c_api/reshape_and_cache/reshape_and_cache.md) +17. [reshape_and_cache_npd](../ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.md) +18. [rope](../ops/c_api/rope/rope.md) +19. [scatter_nd_update](../ops/c_api/scatter_nd_update/scatter_nd_update.md) +20. [sparse_flash_attention](../ops/c_api/sparse_flash_attention/sparse_flash_attention_doc.md) +21. [trans_data](../ops/c_api/trans_data/trans_data.md) +22. [type_cast](../ops/c_api/type_cast/type_cast.md) diff --git a/ops/c_api/moe_token_unpermute/moe_token_unpermute.cc b/ops/c_api/moe_token_unpermute/moe_token_unpermute.cc new file mode 100644 index 0000000000000000000000000000000000000000..77f7b656216b5b820fd6b8a659f0bc03eae70a80 --- /dev/null +++ b/ops/c_api/moe_token_unpermute/moe_token_unpermute.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" +#include "ops/framework/utils.h" + +// ============================================================================= +// COMMON FUNCTION +// ============================================================================= + +namespace ms_custom_ops { + +enum class InputIndex : size_t { + kInputPermutedTokensIndex = 0, + kInputSortedIndicesIndex = 1, + kInputProbsIndex = 2, + kInputPaddedModeIndex = 3, + kInputRestoreShapeIndex = 4, +}; + +enum class OutputIndex : size_t { kOutputIndex = 0 }; + +ShapeVector MoeTokenUnpermuteMakeShape(const ShapeVector &permuted_tokens_shape, const ShapeVector &probs_shape) { + if (permuted_tokens_shape.size() != kDim2) { + MS_LOG(EXCEPTION) << "For MoeTokenUnpermute, permuted_tokens must be a 2D tensor, but got dimension: " + << permuted_tokens_shape.size(); + } + if (probs_shape.size() != kDim2) { + MS_LOG(EXCEPTION) << "For MoeTokenUnpermute, probs must be a 2D tensor, but got dimension: " << probs_shape.size(); + } + ShapeVector out_shape{probs_shape[0], permuted_tokens_shape[1]}; + return out_shape; +} +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +class MoeTokenUnpermuteFuncImpl : public OpFuncImpl { + public: + MoeTokenUnpermuteFuncImpl() : OpFuncImpl() {} + ~MoeTokenUnpermuteFuncImpl() = default; + + ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + auto &permuted_tokens = input_infos[static_cast(InputIndex::kInputPermutedTokensIndex)]; + auto &probs = input_infos[static_cast(InputIndex::kInputProbsIndex)]; + ShapeVector out_shape = {abstract::Shape::kShapeRankAny}; + if (permuted_tokens->IsDynamicRank() || probs->IsDynamicRank()) { + return {out_shape}; + } + out_shape = MoeTokenUnpermuteMakeShape(permuted_tokens->GetShape(), probs->GetShape()); + return {out_shape}; + } + + std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + auto permuted_tokens_type = input_infos[static_cast(InputIndex::kInputPermutedTokensIndex)]->GetType(); + if (permuted_tokens_type != TypeId::kNumberTypeFloat16) { + MS_LOG(EXCEPTION) << "For MoeTokenUnpermute, permuted_tokens must be a float16 tensor, but got: " + << permuted_tokens_type; + } + return {permuted_tokens_type}; + } + + bool GeneralInferRegistered() const override { return true; } +}; + +class MoeTokenUnpermute : public InternalKernelMod { + public: + MoeTokenUnpermute() : InternalKernelMod() {} + ~MoeTokenUnpermute() = default; + + void InitKernelInputsOutputsIndex() override { + kernel_inputs_index_ = { + static_cast(InputIndex::kInputPermutedTokensIndex), + static_cast(InputIndex::kInputSortedIndicesIndex), + static_cast(InputIndex::kInputProbsIndex), static_cast(InputIndex::kInputPaddedModeIndex), + static_cast(InputIndex::kInputRestoreShapeIndex)}; + kernel_outputs_index_ = {static_cast(OutputIndex::kOutputIndex)}; + } + + protected: + internal_v2::InternalOpPtr CreateKernel(const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override { + return internal_v2::CreateMoeTokenUnpermuteOp(inputs, outputs, internal_v2::kInternalMoeTokenUnpermuteOpName); + } +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(moe_token_unpermute, ms_custom_ops::MoeTokenUnpermuteFuncImpl, ms_custom_ops::MoeTokenUnpermute); +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +#include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" + +namespace ms_custom_ops { +class MoeTokenUnpermuteRunner : public InternalPyboostRunner { + public: + explicit MoeTokenUnpermuteRunner(const std::string &op_name) : InternalPyboostRunner(op_name) {} + ~MoeTokenUnpermuteRunner() = default; + + protected: + internal_v2::InternalOpPtr CreateKernel(const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs) override { + return internal_v2::CreateMoeTokenUnpermuteOp(inputs, outputs, internal_v2::kInternalMoeTokenUnpermuteOpName); + } +}; + +std::vector npu_moe_token_unpermute(const ms::Tensor &permuted_tokens, const ms::Tensor &sorted_indices, + const ms::Tensor &probs, const std::optional &padded_mode, + const std::optional> &restore_shape) { + auto op_name = "MoeTokenUnpermute"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + // Setup the runner with all parameters (including hash calculation) + runner->Setup(op_name, permuted_tokens, sorted_indices, probs, padded_mode, restore_shape); + + // if you need infer shape and type, you need create output tensors. + ShapeVector output_shape = MoeTokenUnpermuteMakeShape(permuted_tokens.shape(), probs.shape()); + auto output = ms::Tensor(permuted_tokens.data_type(), output_shape); + std::vector inputs = {permuted_tokens, sorted_indices, probs}; + std::vector outputs = {output}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs; +} +} // namespace ms_custom_ops + +auto pyboost_moe_token_unpermute(const ms::Tensor &permuted_tokens, const ms::Tensor &sorted_indices, + const ms::Tensor &probs, const std::optional &padded_mode, + const std::optional> &restore_shape) { + return ms::pynative::PyboostRunner::Call<1>(ms_custom_ops::npu_moe_token_unpermute, permuted_tokens, sorted_indices, + probs, padded_mode, restore_shape); +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("moe_token_unpermute", &pyboost_moe_token_unpermute, "MoeTokenUnpermute", pybind11::arg("permuted_tokens"), + pybind11::arg("sorted_indices"), pybind11::arg("probs"), pybind11::arg("padded_mode") = false, + pybind11::arg("restore_shape") = std::nullopt); +} diff --git a/ops/c_api/moe_token_unpermute/moe_token_unpermute.md b/ops/c_api/moe_token_unpermute/moe_token_unpermute.md new file mode 100644 index 0000000000000000000000000000000000000000..4f5ea1736242588fbb0fbfd88c607ea59613b379 --- /dev/null +++ b/ops/c_api/moe_token_unpermute/moe_token_unpermute.md @@ -0,0 +1,50 @@ +# moe_token_unpermute 算子 + +## 描述 + +根据sorted_indices存储的下标,获取permuted_tokens中存储的输入数据,permuted_tokens会与probs相乘;最后进行累加求和,并输出计算结果。 + +## 输入参数 + +| Name | DType | Shape | Optional | Inplace | Format | Description | +|------|-------|-------|----------|---------|--------|-------------| +| permuted_tokens | Tensor(float16) | [tokens_num * topK_num, hidden_size] | No | No | ND | 要进行反排列的已排列标记的Tensor | +| sorted_indices | Tensor(int32) | [tokens_num * topK_num] | No | No | ND | 表示需要计算的数据在permuted_tokens中的位置 | +| probs | Tensor(float16) | [tokens_num, topK_num] | No | No | ND | 与已排列标记对应的概率Tensor | +| padded_mode | bool | | Yes | No | | true表示开启paddedMode,false表示关闭paddedMode,paddedMode解释见restoreShape参数。
目前仅支持 False, 不对输出的shape进行变换。默认False | +| restore_shape | Tuple(int) | | Yes | No | | padded_mode=true时生效,否则不会对其进行操作。paddedMode=true时,out的shape将表征为restoreShape。目前仅支持None。默认None | + +## 输出参数 + +| Name | DType | Shape | Description | +|------|-------|-------|-------------| +| out | Tensor(float16) | [tokens_num, hidden_size] | 加权反排列后的Tensor | + +## 支持产品 + +- Atlas 推理系列产品 + +## 特殊说明 + +- 当前仅支持:`padded_mode = false, restore_shape = None`。 +- topK 支持 1,2,4,8 +- hidden_size 支持 2048,5120,7168 + +## 使用示例 + +### 基本使用示例(常规模式) + +```python +import numpy as np +from mindspore import Tensor +import ms_custom_ops + +token_num = 128 +hidden_size = 7168 +top_k = 8 +permuted_token = np.random.randn(token_num * top_k, hidden_size).astype(np.float16) +sorted_idx = np.arange(token_num * top_k, dtype=np.int32) +np.random.shuffle(sorted_idx) +probs = np.random.randn(token_num, top_k).astype(np.float16) +out = ms_custom_ops.moe_token_unpermute(Tensor(permuted_token), Tensor(sorted_idx), Tensor(probs)) +``` diff --git a/ops/c_api/moe_token_unpermute/moe_token_unpermute_op.yaml b/ops/c_api/moe_token_unpermute/moe_token_unpermute_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87cc336a1a3cc1d239694fcad3251a8d91da0a88 --- /dev/null +++ b/ops/c_api/moe_token_unpermute/moe_token_unpermute_op.yaml @@ -0,0 +1,18 @@ +# operator moe_token_unpermute +moe_token_unpermute: + args: + permuted_tokens: + dtype: tensor + sorted_indices: + dtype: tensor + probs: + dtype: tensor + padded_mode: + dtype: Bool + default: False + restore_shape: + dtype: tuple[int] + default: None + returns: + unpermuted_tokens: + dtype: tensor diff --git a/tests/st/test_custom_moe_token_unpermute.py b/tests/st/test_custom_moe_token_unpermute.py new file mode 100644 index 0000000000000000000000000000000000000000..06113a17daa117d99e329c1dadc199d71023076f --- /dev/null +++ b/tests/st/test_custom_moe_token_unpermute.py @@ -0,0 +1,176 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" tests_custom_moe_token_unpermute_pyboost_ascend """ + +import numpy as np +import pytest +import logging +from mindspore import nn +from mindspore import Tensor, context + +# Local imports +import ms_custom_ops + +np.set_printoptions(precision=2, suppress=True, linewidth=200) + + +class MoeTokenUnpermuteNet(nn.Cell): + """MoeTokenUnpermuteNet""" + def construct(self, permuted_tokens, sorted_indices, probs, padded_mode=False, restore_shape=None): + return ms_custom_ops.moe_token_unpermute( + permuted_tokens, sorted_indices, probs, padded_mode, restore_shape + ) + + +def moe_token_unpermute_op_impl(permute_token, sorted_idx, probs): + token_num = probs.shape[0] + top_k = probs.shape[1] + hidden = permute_token.shape[1] + out = np.zeros((token_num, hidden), dtype=np.float16) + + for i in range(token_num): + for k in range(top_k): + dst_row = permute_token[sorted_idx[i * top_k + k], :] + out[i, :] += probs[i, k] * dst_row + return out + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("token_num", [1, 16, 1001]) +@pytest.mark.parametrize("hidden", [7168, 2048]) +@pytest.mark.parametrize("top_k", [8]) +def test_moe_token_unpermute(exec_mode, token_num, hidden, top_k): + """ + Feature: test moe_token_unpermute operator + Description: test different mode and input dimension of the operator correctness + Expectation: the result is correct + """ + context.set_context(mode=exec_mode, device_target="Ascend") + context.set_context(jit_config={"jit_level": "O0"}) + net = MoeTokenUnpermuteNet() + + # 生成输入数据 + permute_token = np.random.randn(token_num * top_k, hidden).astype(np.float16) + sorted_idx = np.arange(token_num * top_k, dtype=np.int32) + np.random.shuffle(sorted_idx) + probs = np.random.randn(token_num, top_k).astype(np.float16) + + # 计算期望输出 + expected = moe_token_unpermute_op_impl(permute_token, sorted_idx, probs) + + # 运行算子 + output = net(Tensor(permute_token), Tensor(sorted_idx), Tensor(probs)) + + # 验证结果 + assert np.allclose(output.asnumpy(), expected, rtol=1e-3, atol=1e-3) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("token_num", [16]) +@pytest.mark.parametrize("hidden", [1024]) +@pytest.mark.parametrize("top_k", [8]) +def test_moe_token_unpermute_unsupported_hidden_size(exec_mode, token_num, hidden, top_k): + """ + Feature: test moe_token_unpermute operator + Description: test unsupported hidden_size + Expectation: Unsupported hidden size correctly rejected + """ + context.set_context(mode=exec_mode, device_target="Ascend") + context.set_context(jit_config={"jit_level": "O0"}) + net = MoeTokenUnpermuteNet() + + # 生成输入数据 + permute_token = np.random.randn(token_num * top_k, hidden).astype(np.float16) + sorted_idx = np.arange(token_num * top_k, dtype=np.int32) + np.random.shuffle(sorted_idx) + probs = np.random.randn(token_num, top_k).astype(np.float16) + + # 运行算子 + with pytest.raises(RuntimeError, match="Tiling error"): + output = net(Tensor(permute_token), Tensor(sorted_idx), Tensor(probs)) + out = output.asnumpy() + logging.info( + f"Unsupported hidden_size correctly rejected: hidden_size={hidden}", + ) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("token_num", [16]) +@pytest.mark.parametrize("hidden", [7168]) +@pytest.mark.parametrize("top_k", [7]) +def test_moe_token_unpermute_unsupported_top_k(exec_mode, token_num, hidden, top_k): + """ + Feature: test moe_token_unpermute operator + Description: test unsupported top_k + Expectation: Unsupported top_k correctly rejected + """ + context.set_context(mode=exec_mode, device_target="Ascend") + context.set_context(jit_config={"jit_level": "O0"}) + net = MoeTokenUnpermuteNet() + + # 生成输入数据 + permute_token = np.random.randn(token_num * top_k, hidden).astype(np.float16) + sorted_idx = np.arange(token_num * top_k, dtype=np.int32) + np.random.shuffle(sorted_idx) + probs = np.random.randn(token_num, top_k).astype(np.float16) + + # 运行算子 + with pytest.raises(RuntimeError, match="Tiling error"): + output = net(Tensor(permute_token), Tensor(sorted_idx), Tensor(probs)) + out = output.asnumpy() + logging.info( + f"Unsupported top_k correctly rejected: top_k={top_k}", + ) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("token_num", [16]) +@pytest.mark.parametrize("hidden", [7168]) +@pytest.mark.parametrize("top_k", [8]) +def test_moe_token_unpermute_1d_input(exec_mode, token_num, hidden, top_k): + """ + Feature: test moe_token_unpermute operator + Description: test 1d input + Expectation: 1d input correctly rejected + """ + context.set_context(mode=exec_mode, device_target="Ascend") + context.set_context(jit_config={"jit_level": "O0"}) + net = MoeTokenUnpermuteNet() + + # 生成输入数据 + permute_token = np.random.randn(hidden).astype(np.float16) + sorted_idx = np.arange(token_num * top_k, dtype=np.int32) + np.random.shuffle(sorted_idx) + probs = np.random.randn(token_num, top_k).astype(np.float16) + + # 运行算子 + with pytest.raises(RuntimeError, match="must be a 2D tensor"): + output = net(Tensor(permute_token), Tensor(sorted_idx), Tensor(probs)) + out = output.asnumpy() + logging.info( + f"1d input correctly rejected", + )