From 4a580e13bc815617a71b4c4cd898a7567bf37888 Mon Sep 17 00:00:00 2001
From: zhanghanLeo
Date: Wed, 29 Oct 2025 18:03:57 +0800
Subject: [PATCH] [Ops Support]: moe_distribute_dispatch_v3 supported on
 Ascend.

---
 .jenkins/check/config/filter_cpplint.txt                |   2 -
 ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc |  34 +-
 .../moe_distribute_dispatch_op.yaml                     |  82 +++
 .../moe_distribute_dispatch_v3.cc                       | 674 ++++++++++++++++++
 .../moe_distribute_dispatch_v3.md                       | 139 ++++
 ops/c_api/utils/check_utils.h                           |  10 +-
 ops/framework/utils.h                                   |   2 +
 7 files changed, 918 insertions(+), 25 deletions(-)
 create mode 100644 ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_op.yaml
 create mode 100644 ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.cc
 create mode 100644 ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.md

diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt
index 9888e13..e69de29 100644
--- a/.jenkins/check/config/filter_cpplint.txt
+++ b/.jenkins/check/config/filter_cpplint.txt
@@ -1,2 +0,0 @@
-"ms_custom_ops/ops/c_api/scatter_nd_update/scatter_nd_update.cc" "build/namespaces"
-"ms_custom_ops/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc" "build/namespace"
diff --git a/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc
index 0403225..09a11b4 100644
--- a/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc
+++ b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc
@@ -133,26 +133,26 @@ class OPS_API KvRmsNormRopeCacheCustomOpFuncImpl : public OpFuncImpl {
   // dynamic shape or static shape;
   MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
-    (kv_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)),
-    op_name + " kv input size should be " +
+    (kv_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)), op_name,
+    " kv input size should be " +
       std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)) + ", but now got " +
       std::to_string(kv_shape.size()));
   MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
-    (gamma_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize1)),
-    op_name + " gamma input size should be " +
+    (gamma_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize1)), op_name,
+    " gamma input size should be " +
       std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize1)) + ", but now got " +
       std::to_string(gamma_shape.size()));
   MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
-    ((cos_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4))),
-    op_name + std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)) +
-      ", but now got cos " + std::to_string(cos_shape.size()));
+    ((cos_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4))), op_name,
+    std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)) + ", but now got cos " +
+      std::to_string(cos_shape.size()));
   MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
-    ((sin_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4))),
-    op_name + std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)) +
-      ", but now got sin " + std::to_string(sin_shape.size()));
+    ((sin_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4))), op_name,
+    std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)) + ", but now got sin " +
+      std::to_string(sin_shape.size()));
   auto cache_mode = static_cast<KvRmsNormRopeCacheMode>(
     input_infos[static_cast<size_t>(KvRmsNormRopeCacheInputIndex::kKvRmsNormRopeCacheCacheModeIndex)]
@@ -161,14 +161,14 @@
     input_infos[static_cast<size_t>(KvRmsNormRopeCacheInputIndex::kKvRmsNormRopeCacheIdxIndex)]->GetShape();
   if (cache_mode == KvRmsNormRopeCacheMode::kKvRmsNormRopeCacheModeNorm) {
     MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
-      (index_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize2)),
-      op_name + " index input size should be " +
+      (index_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize2)), op_name,
+      " index input size should be " +
         std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize2)) + ", but now got" +
         std::to_string(index_shape.size()));
   } else {
     MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
-      (index_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize1)),
-      op_name + " index input size should be " +
+      (index_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize1)), op_name,
+      " index input size should be " +
         std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize1)) + ", but now got" +
         std::to_string(index_shape.size()));
   }
@@ -176,7 +176,8 @@
   MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
     ((k_cache_shape.size() == static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)) ||
      (c_kv_cache_shape.size() != static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4))),
-    op_name + " kCache or CKvCache input size should be " +
+    op_name,
+    " kCache or CKvCache input size should be " +
       std::to_string(static_cast<size_t>(KvRmsNormRopeCacheSize::kKvRmsNormRopeCacheSize4)) +
       ", but now got k_cache:" + std::to_string(k_cache_shape.size()) +
       ", c_kv_cache:" + std::to_string(c_kv_cache_shape.size()));
@@ -269,9 +270,6 @@ REG_GRAPH_MODE_OP(kv_rmsnorm_rope_cache, ms_custom_ops::KvRmsNormRopeCacheCustom
 // =============================================================================
 
 namespace ms_custom_ops {
-using namespace mindspore;
-using namespace mindspore::device::ascend;
-
 std::vector<ms::Tensor> kv_rmsnorm_rope_cache_custom(
   const ms::Tensor &kv, const ms::Tensor &gamma, const ms::Tensor &cos, const ms::Tensor &sin,
   const ms::Tensor &index, const ms::Tensor k_cache, const ms::Tensor &c_kv_cache,
   const std::optional<ms::Tensor> &k_rope_scale,
diff --git a/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_op.yaml b/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_op.yaml
new file mode 100644
index 0000000..a22ba5d
--- /dev/null
+++ b/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_op.yaml
@@ -0,0 +1,82 @@
+#operator moe_distribute_dispatch_v3
+moe_distribute_dispatch_v3:
+  args:
+    x:
+      dtype: tensor
+    expert_ids:
+      dtype: tensor
+    ep_world_size:
+      dtype: int
+    ep_rank_id:
+      dtype: int
+    moe_expert_num:
+      dtype: int
+    scales:
+      dtype: tensor
+      default: None
+    x_active_mask:
+      dtype: tensor
+      default: None
+    expert_scales:
+      dtype: tensor
+      default: None
+    elastic_info:
+      dtype: tensor
+      default: None
+    group_ep:
+      dtype: str
+      default: None
+    group_tp:
+      dtype: str
+      default: None
+    tp_world_size:
+      dtype: int
+      default: 0
+    tp_rank_id:
+      dtype: int
+      default: 0
+    expert_shard_type:
+      dtype: int
+      default: 0
+    shared_expert_num:
+      dtype: int
+      default: 1
+    shared_expert_rank_num:
+      dtype: int
+      default: 0
+    quant_mode:
+      dtype: int
+      default: 0
+    global_bs:
+      dtype: int
+      default: 0
+    expert_token_nums_type:
+      dtype: int
+      default: 1
+    comm_alg:
+      dtype: str
+      default: None
+    zero_expert_num:
+      dtype: int
+      default: 0
+    copy_expert_num:
+      dtype: int
+      default: 0
+    const_expert_num:
+      dtype: int
+      default: 0
+  returns:
+    expand_x:
+      dtype: tensor
+    dynamic_scales:
+      dtype: tensor
+    assist_info_for_combine:
+      dtype: tensor
+    expert_token_nums:
+      dtype: tensor
+    ep_recv_counts:
+      dtype: tensor
+    tp_recv_counts:
+      dtype: tensor
+    expand_scales:
+      dtype: tensor
\ No newline at end of file
diff --git a/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.cc b/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.cc
new file mode 100644
index 0000000..0cd96cd
--- /dev/null
+++ b/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.cc
@@ -0,0 +1,674 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// =============================================================================
+// GRAPH MODE IMPLEMENTATION
+// =============================================================================
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h"
+#include "ops/framework/utils.h"
+#include "ops/c_api/utils/common_utils.h"
+#include "ops/c_api/utils/check_utils.h"
+#include "mindspore/ops/kernel/ascend/acl_ir/op_api_util.h"
+
+namespace ms_custom_ops {
+enum class MoeDistributeDispatchV3InputIndex : size_t {
+  kMoeDistributeDispatchV3XIndex = 0,
+  kMoeDistributeDispatchV3ExpertIdsIndex,
+  kMoeDistributeDispatchV3EpWorldSizeIndex,
+  kMoeDistributeDispatchV3EpRankIdIndex,
+  kMoeDistributeDispatchV3MoeExpertNumIndex,
+  kMoeDistributeDispatchV3ScalesIndex,
+  kMoeDistributeDispatchV3XActiveMaskIndex,
+  kMoeDistributeDispatchV3ExpertScalesIndex,
+  kMoeDistributeDispatchV3ElasticInfoIndex,
+  kMoeDistributeDispatchV3GroupEpIndex,
+  kMoeDistributeDispatchV3GroupTpIndex,
+  kMoeDistributeDispatchV3TpWorldSizeIndex,
+  kMoeDistributeDispatchV3TpRankIdIndex,
+  kMoeDistributeDispatchV3ExpertShardTypeIndex,
+  kMoeDistributeDispatchV3SharedExpertNumIndex,
+  kMoeDistributeDispatchV3SharedExpertRankNumIndex,
+  kMoeDistributeDispatchV3QuantModeIndex,
+  kMoeDistributeDispatchV3GlobalBsIndex,
+  kMoeDistributeDispatchV3ExpertTokenNumsTypeIndex,
+  kMoeDistributeDispatchV3CommAlgIndex,
+  kMoeDistributeDispatchV3ZeroExpertNumIndex,
+  kMoeDistributeDispatchV3CopyExpertNumIndex,
+  kMoeDistributeDispatchV3ConstExpertNumIndex,
+  kMoeDistributeDispatchV3InputNums,
+};
+
+enum class MoeDistributeDispatchV3OutputIndex : size_t {
+  kMoeDistributeDispatchV3ExpandXOutputIndex = 0,
+  kMoeDistributeDispatchV3DynamicScalesOutputIndex,
+  kMoeDistributeDispatchV3AssistInfoForCombineOutputIndex,
+  kMoeDistributeDispatchV3ExpertTokenNumsOutputIndex,
+  kMoeDistributeDispatchV3EpRecvCountsOutputIndex,
+  kMoeDistributeDispatchV3TpRecvCountsOutputIndex,
+  kMoeDistributeDispatchV3ExpandScalesOutputIndex,
+  kMoeDistributeDispatchV3OutputNums,
+};
+
+struct InputParam {
+  ShapeVector x_shape_;
+  ShapeVector expert_ids_shape_;
+  int64_t ep_world_size_{0};
+  int64_t ep_rank_id_{0};
+  int64_t moe_expert_num_{0};
+  int64_t tp_world_size_{0};
+  int64_t tp_rank_id_{0};
+  int64_t expert_shard_type_{0};
+  int64_t shared_expert_num_{0};
+  int64_t shared_expert_rank_num_{0};
+  int64_t quant_mode_{0};
+  int64_t global_bs_{0};
+  int64_t expert_token_nums_type_{0};
+  int64_t bs_{0};
+  int64_t h_{0};
+  int64_t k_{0};
+};
+
+struct OutputShapes {
+  ShapeVector expand_x;
+  ShapeVector dynamic_scales;
+  ShapeVector assist_info;
+  ShapeVector expert_token_nums;
+  ShapeVector ep_recv_counts;
+  ShapeVector tp_recv_counts;
+  ShapeVector expand_scales;
+};
+
+void ValidateParamsComm(const int64_t shared_expert_num, const int64_t shared_expert_rank_num,
+                        const int64_t expert_token_nums_type, const std::string &op_name) {
+  const bool is_shared_default = ((shared_expert_num == 1) && (shared_expert_rank_num == 0));
+  const bool is_no_shared = ((shared_expert_num == 0) && (shared_expert_rank_num == 0));
+  const bool is_valid_shared = ((shared_expert_num > 0) && ((shared_expert_rank_num / shared_expert_num) > 0) &&
+                                ((shared_expert_rank_num % shared_expert_num) == 0));
+  MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
+    (is_shared_default || is_no_shared || is_valid_shared), op_name,
+    std::string("shared_expert_num and shared_expert_rank_num only support the following combinations: ") +
+      std::string("1. shared_expert_num is 1 and shared_expert_rank_num is 0; ") +
+      std::string("2. shared_expert_num is 0 and shared_expert_rank_num is 0; ") +
+      std::string("3. shared_expert_num is in (0, shared_expert_rank_num] and ") +
+      std::string("shared_expert_rank_num % shared_expert_num == 0. But got shared_expert_num: ") +
+      std::to_string(shared_expert_num) + ", shared_expert_rank_num: " + std::to_string(shared_expert_rank_num));
+
+  MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(
+    ((expert_token_nums_type == 0) || (expert_token_nums_type == 1)), op_name,
+    "expert_token_nums_type should be 0 or 1, but got " + std::to_string(expert_token_nums_type));
+}
+
+class OPS_API MoeDistributeDispatchV3CustomOpFuncImpl : public OpFuncImpl {
+ public:
+  void ValidateInputsSize(const InferInfoPtrList &input_infos, const std::string &op_name) const {
+    constexpr uint32_t kRequiredInputNums =
+      static_cast<uint32_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3InputNums);
+    MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL((input_infos.size() == kRequiredInputNums), op_name,
+                                          " the input size should be equal to " + std::to_string(kRequiredInputNums) +
+                                            ", but got " + std::to_string(input_infos.size()));
+  }
+
+  template <typename T>
+  T GetInputValue(const InferInfoPtrList &input_infos, MoeDistributeDispatchV3InputIndex index) const {
+    return input_infos[static_cast<size_t>(index)]->GetScalarValueWithCheck<T>();
+  }
+
+  InputParam ExtractInputParams(const InferInfoPtrList &input_infos, const std::string &op_name) const {
+    InputParam params;
+    params.x_shape_ =
+      input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3XIndex)]->GetShape();
+    params.expert_ids_shape_ =
+      input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertIdsIndex)]
+        ->GetShape();
+    params.ep_world_size_ =
+      GetInputValue<int64_t>(input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3EpWorldSizeIndex);
+    params.ep_rank_id_ =
+      GetInputValue<int64_t>(input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3EpRankIdIndex);
+    params.moe_expert_num_ =
+      GetInputValue<int64_t>(input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3MoeExpertNumIndex);
+    params.tp_world_size_ =
+      GetInputValue<int64_t>(input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3TpWorldSizeIndex);
+    params.tp_rank_id_ =
+      GetInputValue<int64_t>(input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3TpRankIdIndex);
+    params.expert_shard_type_ = GetInputValue<int64_t>(
+      input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertShardTypeIndex);
+    params.shared_expert_num_ = GetInputValue<int64_t>(
+      input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3SharedExpertNumIndex);
+    params.shared_expert_rank_num_ = GetInputValue<int64_t>(
+      input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3SharedExpertRankNumIndex);
+    params.quant_mode_ =
+      GetInputValue<int64_t>(input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3QuantModeIndex);
+    params.global_bs_ =
+      GetInputValue<int64_t>(input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3GlobalBsIndex);
+    params.expert_token_nums_type_ = GetInputValue<int64_t>(
+      input_infos, MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertTokenNumsTypeIndex);
+
+    params.bs_ = params.x_shape_[kDim0];
+    params.h_ = params.x_shape_[kDim1];
+    params.k_ = params.expert_ids_shape_[kDim1];
+    return params;
+  }
+
+  void ValidateParams(const InputParam &params, const std::string &op_name) const {
+    ValidateParamsComm(params.shared_expert_num_, params.shared_expert_rank_num_, params.expert_token_nums_type_,
+                       op_name);
+  }
+
+  bool HasDynamicRank(const InferInfoPtrList &input_infos) const {
+    return input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3XIndex)]
+             ->IsDynamicRank() ||
+           input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertIdsIndex)]
+             ->IsDynamicRank();
+  }
+
+  ShapeArray CreateDynamicRanks() const {
+    const ShapeVector dynamic_rank_shape{abstract::Shape::kShapeRankAny};
+    return {dynamic_rank_shape, dynamic_rank_shape, dynamic_rank_shape, dynamic_rank_shape,
+            dynamic_rank_shape, dynamic_rank_shape, dynamic_rank_shape};
+  }
+
+  size_t CalculateLocalExpertNum(const InferInfoPtrList &input_infos, const InputParam &params) const {
+    const bool shared_front = (params.expert_shard_type_ == 0);
+    size_t local_moe_expert_num = 0;
+
+    if (shared_front) {
+      if (params.ep_rank_id_ < params.shared_expert_rank_num_) {
+        local_moe_expert_num = 1;
+      } else {
+        local_moe_expert_num = params.moe_expert_num_ / (params.ep_world_size_ - params.shared_expert_rank_num_);
+      }
+
+      if (!input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ElasticInfoIndex)]
+             ->IsNone()) {
+        local_moe_expert_num = std::max(
+          local_moe_expert_num,
+          static_cast<size_t>(params.moe_expert_num_ / (params.ep_world_size_ - params.shared_expert_rank_num_)));
+      }
+    }
+    return local_moe_expert_num;
+  }
+
+  size_t CalcEpRecvCount(const InputParam &params, size_t local_moe_expert_num) const {
+    return (params.tp_world_size_ == kDim2)
+             ? params.ep_world_size_ * local_moe_expert_num * params.tp_world_size_
+             : params.ep_world_size_ * local_moe_expert_num;
+  }
+
+  void SetPartiallyDynamicShapes(OutputShapes &outputs, const InputParam &params, size_t local_moe_expert_num,
+                                 size_t ep_recv_counts) const {
+    outputs.expand_x = {abstract::Shape::kShapeDimAny, params.h_};
+    outputs.dynamic_scales = {abstract::Shape::kShapeDimAny};
+    outputs.assist_info = {abstract::Shape::kShapeDimAny};
+    outputs.expert_token_nums = {static_cast<int64_t>(local_moe_expert_num)};
+
+    if (IsSoc910b()) {
+      outputs.ep_recv_counts = {abstract::Shape::kShapeDimAny};
+      outputs.expand_scales = {abstract::Shape::kShapeDimAny};
+    } else if (IsSoc910_93()) {
+      outputs.ep_recv_counts = {static_cast<int64_t>(ep_recv_counts)};
+      outputs.expand_scales = kFakeOutShapes;
+    }
+    outputs.tp_recv_counts = {params.tp_world_size_};
+  }
+
+  ShapeArray ConvertToShapeArray(const OutputShapes &outputs) const {
+    return {outputs.expand_x, outputs.dynamic_scales, outputs.assist_info, outputs.expert_token_nums,
+            outputs.ep_recv_counts, outputs.tp_recv_counts, outputs.expand_scales};
+  }
+
+  ShapeArray CreatePartiallyDynamicShapes(const InferInfoPtrList &input_infos, const InputParam &params) const {
+    OutputShapes outputs;
+    const size_t local_moe_expert_num = CalculateLocalExpertNum(input_infos, params);
+    const size_t ep_recv_counts = CalcEpRecvCount(params, local_moe_expert_num);
+    SetPartiallyDynamicShapes(outputs, params, local_moe_expert_num, ep_recv_counts);
+    return ConvertToShapeArray(outputs);
+  }
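+
+  // Buffer sizing below mirrors the rules documented for aclnnMoeDistributeDispatchV3:
+  // a shared-expert rank receives at most max_bs tokens from each shared group, while
+  // a MoE rank receives at most global_bs_real * min(local_moe_expert_num, k) tokens;
+  // when elastic_info is present, the worst case across both layouts is taken.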
+  std::pair<size_t, size_t> CalculateLocalExpertAndBufferSize(const InferInfoPtrList &input_infos,
+                                                              const InputParam &params) const {
+    const bool shared_front = (params.expert_shard_type_ == 0);
+    const bool is_shared_default = ((params.shared_expert_num_ == 1) && (params.shared_expert_rank_num_ == 0));
+    const bool is_no_shared = ((params.shared_expert_num_ == 0) && (params.shared_expert_rank_num_ == 0));
+    size_t local_moe_expert_num = 0;
+    size_t a = 0;
+    const size_t global_bs_real = (params.global_bs_ == 0) ? (params.bs_ * params.ep_world_size_) : params.global_bs_;
+    if (shared_front) {
+      if (params.ep_rank_id_ < params.shared_expert_rank_num_) {
+        local_moe_expert_num = 1;
+        const size_t max_bs = global_bs_real / params.ep_world_size_;
+        const size_t rank_num_per_shared_expert = params.shared_expert_rank_num_ / params.shared_expert_num_;
+        const size_t max_shared_group_num =
+          (params.ep_world_size_ + rank_num_per_shared_expert - 1) / rank_num_per_shared_expert;
+        a = max_bs * max_shared_group_num;
+      } else {
+        local_moe_expert_num = params.moe_expert_num_ / (params.ep_world_size_ - params.shared_expert_rank_num_);
+        a = global_bs_real * std::min(local_moe_expert_num, static_cast<size_t>(params.k_));
+      }
+
+      if (!input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ElasticInfoIndex)]
+             ->IsNone()) {
+        if (is_shared_default || is_no_shared) {
+          local_moe_expert_num = std::max(
+            local_moe_expert_num,
+            static_cast<size_t>(params.moe_expert_num_ / (params.ep_world_size_ - params.shared_expert_rank_num_)));
+          a = global_bs_real * std::min(local_moe_expert_num, static_cast<size_t>(params.k_));
+        } else {
+          const size_t max_bs = global_bs_real / params.ep_world_size_;
+          const size_t rank_num_per_shared_expert = params.shared_expert_rank_num_ / params.shared_expert_num_;
+          const size_t max_shared_group_num =
+            (params.ep_world_size_ + rank_num_per_shared_expert - 1) / rank_num_per_shared_expert;
+          a = std::max(
+            max_bs * max_shared_group_num,
+            static_cast<size_t>(global_bs_real *
+                                std::min(static_cast<size_t>(params.moe_expert_num_ /
+                                                             (params.ep_world_size_ - params.shared_expert_rank_num_)),
+                                         static_cast<size_t>(params.k_))));
+          local_moe_expert_num = std::max(
+            local_moe_expert_num,
+            static_cast<size_t>(params.moe_expert_num_ / (params.ep_world_size_ - params.shared_expert_rank_num_)));
+        }
+      }
+    }
+    return {local_moe_expert_num, a};
+  }
+
+  void SetPreciseOutputShapes(OutputShapes &outputs, const InputParam &params, size_t local_moe_expert_num,
+                              size_t ep_recv_cnt_nums, size_t a) const {
+    if (params.tp_world_size_ == kDim0) {
+      outputs.expand_x = {static_cast<int64_t>(a), params.h_};
+      outputs.dynamic_scales = {static_cast<int64_t>(a)};
+    } else {
+      outputs.expand_x = {static_cast<int64_t>(a * params.tp_world_size_), params.h_};
+      outputs.dynamic_scales = {static_cast<int64_t>(a * params.tp_world_size_)};
+    }
+
+    outputs.assist_info = {
+      static_cast<int64_t>(std::max(static_cast<size_t>(params.bs_ * params.k_), static_cast<size_t>(a * 128)))};
+    outputs.expert_token_nums = {static_cast<int64_t>(local_moe_expert_num)};
+    outputs.ep_recv_counts = {static_cast<int64_t>(ep_recv_cnt_nums)};
+    outputs.tp_recv_counts = {params.tp_world_size_};
+    outputs.expand_scales = {static_cast<int64_t>(a)};
+  }
+
+  ShapeArray CalcPreciseShapes(const InferInfoPtrList &input_infos, const InputParam &params) const {
+    OutputShapes outputs;
+    const auto [local_moe_expert_num, a] = CalculateLocalExpertAndBufferSize(input_infos, params);
+    size_t ep_recv_cnt_nums = CalcEpRecvCount(params, local_moe_expert_num);
+    if (!input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertScalesIndex)]
+           ->IsNone()) {
+      const size_t global_bs_real = (params.global_bs_ == 0) ? (params.bs_ * params.ep_world_size_) : params.global_bs_;
+      ep_recv_cnt_nums =
+        params.ep_world_size_ * local_moe_expert_num +
+        2 * global_bs_real * params.k_ * (params.ep_world_size_ / 8);  // 2: double buffer, 8: ranks per server
+    }
+    SetPreciseOutputShapes(outputs, params, local_moe_expert_num, ep_recv_cnt_nums, a);
+    return ConvertToShapeArray(outputs);
+  }
+
+  ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto op_name = primitive->name();
+
+    ValidateInputsSize(input_infos, op_name);
+    InputParam params = ExtractInputParams(input_infos, op_name);
+    ValidateParams(params, op_name);
+    if (HasDynamicRank(input_infos)) {
+      return CreateDynamicRanks();
+    }
+
+    if ((params.bs_ == abstract::Shape::kShapeDimAny) || (params.k_ == abstract::Shape::kShapeDimAny)) {
+      return CreatePartiallyDynamicShapes(input_infos, params);
+    }
+
+    return CalcPreciseShapes(input_infos, params);
+  }
+
+  std::vector<TypeId> InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    std::vector<TypeId> outputs_dtypes;
+
+    const auto quant_mode =
+      input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3QuantModeIndex)]
+        ->GetScalarValueWithCheck<int64_t>();
+
+    auto x_dtype =
+      input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3XIndex)]->GetType();
+    if (input_infos[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ScalesIndex)]
+          ->IsNone() &&
+        quant_mode == 0) {
+      (void)outputs_dtypes.emplace_back(x_dtype);
+    } else {
+      (void)outputs_dtypes.emplace_back(kNumberTypeInt8);
+    }
+    (void)outputs_dtypes.emplace_back(kNumberTypeFloat32);  // dynamic_scales
+    (void)outputs_dtypes.emplace_back(kNumberTypeInt32);    // assist_info_for_combine
+    (void)outputs_dtypes.emplace_back(kNumberTypeInt64);    // expert_token_nums
+    (void)outputs_dtypes.emplace_back(kNumberTypeInt32);    // ep_recv_counts
+    (void)outputs_dtypes.emplace_back(kNumberTypeInt32);    // tp_recv_counts
+    (void)outputs_dtypes.emplace_back(kNumberTypeFloat32);  // expand_scales
+
+    return outputs_dtypes;
+  }
+
+  bool GeneralInferRegistered() const override { return true; }
+};
+
+class MoeDistributeDispatchV3CustomAscend : public AclnnCustomKernelMod {
+ public:
+  MoeDistributeDispatchV3CustomAscend() : AclnnCustomKernelMod("aclnnMoeDistributeDispatchV3") {}
+  ~MoeDistributeDispatchV3CustomAscend() = default;
+
+  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
+              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
+    MS_EXCEPTION_IF_NULL(stream_ptr);
+    RunOp(
+      stream_ptr, workspace,
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3XIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertIdsIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ScalesIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3XActiveMaskIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertScalesIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ElasticInfoIndex)],
+      group_ep_, ep_world_size_, ep_rank_id_, moe_expert_num_, group_tp_, tp_world_size_, tp_rank_id_,
+      expert_shard_type_, shared_expert_num_, shared_expert_rank_num_, quant_mode_, global_bs_,
+      expert_token_nums_type_, comm_alg_, zero_expert_num_, copy_expert_num_, const_expert_num_,
+      outputs[static_cast<size_t>(MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3ExpandXOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3DynamicScalesOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3AssistInfoForCombineOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3ExpertTokenNumsOutputIndex)],
+      outputs[static_cast<size_t>(MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3EpRecvCountsOutputIndex)],
+      outputs[static_cast<size_t>(MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3TpRecvCountsOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3ExpandScalesOutputIndex)]);
+    return true;
+  }
+
+  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
+                        const std::vector<KernelTensor *> &outputs) override {
+    auto group_ep = inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3GroupEpIndex)]
+                      ->GetOptionalValueWithCheck<std::string>();
+    group_ep_ = group_ep.has_value() ? mindspore::device::ascend::OpApiUtil::GetCommName(group_ep.value())
+                                     : mindspore::device::ascend::OpApiUtil::GetCommName(kHcclWorldGroup);
+    ep_world_size_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3EpWorldSizeIndex)]
+        ->GetValueWithCheck<int64_t>();
+    ep_rank_id_ = inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3EpRankIdIndex)]
+                    ->GetValueWithCheck<int64_t>();
+    moe_expert_num_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3MoeExpertNumIndex)]
+        ->GetValueWithCheck<int64_t>();
+    auto group_tp = inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3GroupTpIndex)]
+                      ->GetOptionalValueWithCheck<std::string>();
+    group_tp_ = group_tp.has_value() ? mindspore::device::ascend::OpApiUtil::GetCommName(group_tp.value()) : "";
+
+    tp_world_size_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3TpWorldSizeIndex)]
+        ->GetValueWithCheck<int64_t>();
+    tp_rank_id_ = inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3TpRankIdIndex)]
+                    ->GetValueWithCheck<int64_t>();
+    expert_shard_type_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertShardTypeIndex)]
+        ->GetValueWithCheck<int64_t>();
+    shared_expert_num_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3SharedExpertNumIndex)]
+        ->GetValueWithCheck<int64_t>();
+    shared_expert_rank_num_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3SharedExpertRankNumIndex)]
+        ->GetValueWithCheck<int64_t>();
+
+    quant_mode_ = inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3QuantModeIndex)]
+                    ->GetValueWithCheck<int64_t>();
+    global_bs_ = inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3GlobalBsIndex)]
+                   ->GetValueWithCheck<int64_t>();
+    expert_token_nums_type_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertTokenNumsTypeIndex)]
+        ->GetValueWithCheck<int64_t>();
+    auto comm_alg = inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3CommAlgIndex)]
+                      ->GetOptionalValueWithCheck<std::string>();
+    comm_alg_ = (comm_alg.has_value()) ? comm_alg.value() : "";
+    zero_expert_num_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ZeroExpertNumIndex)]
+        ->GetValueWithCheck<int64_t>();
+    copy_expert_num_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3CopyExpertNumIndex)]
+        ->GetValueWithCheck<int64_t>();
+    const_expert_num_ =
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ConstExpertNumIndex)]
+        ->GetValueWithCheck<int64_t>();
+
+    GetWorkspaceForResize(
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3XIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertIdsIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ScalesIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3XActiveMaskIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ExpertScalesIndex)],
+      inputs[static_cast<size_t>(MoeDistributeDispatchV3InputIndex::kMoeDistributeDispatchV3ElasticInfoIndex)],
+      group_ep_, ep_world_size_, ep_rank_id_, moe_expert_num_, group_tp_, tp_world_size_, tp_rank_id_,
+      expert_shard_type_, shared_expert_num_, shared_expert_rank_num_, quant_mode_, global_bs_,
+      expert_token_nums_type_, comm_alg_, zero_expert_num_, copy_expert_num_, const_expert_num_,
+      outputs[static_cast<size_t>(MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3ExpandXOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3DynamicScalesOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3AssistInfoForCombineOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3ExpertTokenNumsOutputIndex)],
+      outputs[static_cast<size_t>(MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3EpRecvCountsOutputIndex)],
+      outputs[static_cast<size_t>(MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3TpRecvCountsOutputIndex)],
+      outputs[static_cast<size_t>(
+        MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3ExpandScalesOutputIndex)]);
+    return;
+  }
+
+ private:
+  DEFINE_GET_WORKSPACE_FOR_RESIZE();
+  std::string group_ep_;
+  int64_t ep_world_size_ = 0;
+  int64_t ep_rank_id_ = 0;
+  int64_t moe_expert_num_ = 0;
+  std::string group_tp_;
+  int64_t tp_world_size_ = 0;
+  int64_t tp_rank_id_ = 0;
+  int64_t expert_shard_type_ = 0;
+  int64_t shared_expert_num_ = 0;
+  int64_t shared_expert_rank_num_ = 0;
+  int64_t quant_mode_ = 0;
+  int64_t global_bs_ = 0;
+  int64_t expert_token_nums_type_ = 0;
+  std::string comm_alg_;
+  int64_t zero_expert_num_ = 0;
+  int64_t copy_expert_num_ = 0;
+  int64_t const_expert_num_ = 0;
+};
+}  // namespace ms_custom_ops
+
+REG_GRAPH_MODE_OP(moe_distribute_dispatch_v3, ms_custom_ops::MoeDistributeDispatchV3CustomOpFuncImpl,
+                  ms_custom_ops::MoeDistributeDispatchV3CustomAscend);
+
+// =============================================================================
+// PYBOOST MODE IMPLEMENTATION
+// =============================================================================
+
+namespace ms_custom_ops {
+std::vector<ms::Tensor> get_moe_distribute_v3_output_tensor(
+  const ms::Tensor &x, const ms::Tensor &expert_ids, const int64_t ep_world_size, const int64_t ep_rank_id,
+  const int64_t moe_expert_num, const std::optional<ms::Tensor> &scales,
+  const std::optional<ms::Tensor> &x_active_mask, const std::optional<ms::Tensor> &expert_scales,
+  const std::optional<ms::Tensor> &elastic_info, const std::string &group_ep,
+  const std::optional<std::string> &group_tp, const int64_t tp_world_size, const int64_t tp_rank_id,
+  const int64_t expert_shard_type, const int64_t shared_expert_num, const int64_t shared_expert_rank_num,
+  const int64_t quant_mode, const int64_t global_bs, const int64_t expert_token_nums_type,
+  const std::optional<std::string> &comm_alg, const int64_t zero_expert_num, const int64_t copy_expert_num,
+  const int64_t const_expert_num) {
+  ValidateParamsComm(shared_expert_num, shared_expert_rank_num, expert_token_nums_type, "moe_distribute_dispatch_v3");
+
+  auto x_shape = x.shape();
+  auto expert_ids_shape = expert_ids.shape();
+  auto bs = x_shape[kDim0];
+  auto h = x_shape[kDim1];
+  auto k = expert_ids_shape[kDim1];
+
+  const bool shared_front = (expert_shard_type == 0);
+  const bool is_shared_default = ((shared_expert_num == 1) && (shared_expert_rank_num == 0));
+  const bool is_no_shared = ((shared_expert_num == 0) && (shared_expert_rank_num == 0));
+
+  int64_t local_moe_expert_num = 1;
+  int64_t global_bs_real = (global_bs == 0) ? (bs * ep_world_size) : global_bs;
+  int64_t a = 0;
+  int64_t ep_recv_cnt_num = 0;
+  if (shared_front) {
+    if (ep_rank_id < shared_expert_rank_num) {
+      local_moe_expert_num = 1;
+      int64_t max_bs = global_bs_real / ep_world_size;
+      int64_t rank_num_per_shared_expert = shared_expert_rank_num / shared_expert_num;
+      int64_t max_shared_group_num = (ep_world_size + rank_num_per_shared_expert - 1) / rank_num_per_shared_expert;
+      a = max_bs * max_shared_group_num;
+    } else {
+      local_moe_expert_num = moe_expert_num / (ep_world_size - shared_expert_rank_num);
+      a = global_bs_real * std::min(local_moe_expert_num, static_cast<int64_t>(k));
+    }
+
+    if (elastic_info.has_value()) {
+      if ((is_shared_default) || (is_no_shared)) {
+        local_moe_expert_num = std::max(
+          local_moe_expert_num, static_cast<int64_t>(moe_expert_num / (ep_world_size - shared_expert_rank_num)));
+        a = global_bs_real * std::min(local_moe_expert_num, static_cast<int64_t>(k));
+      } else {
+        int64_t max_bs = global_bs_real / ep_world_size;
+        int64_t rank_num_per_shared_expert = shared_expert_rank_num / shared_expert_num;
+        int64_t max_shared_group_num = (ep_world_size + rank_num_per_shared_expert - 1) / rank_num_per_shared_expert;
+        a = std::max(
+          max_bs * max_shared_group_num,
+          static_cast<int64_t>(global_bs_real * std::min(moe_expert_num / (ep_world_size - shared_expert_rank_num),
+                                                         static_cast<int64_t>(k))));
+        local_moe_expert_num = std::max(
+          local_moe_expert_num, static_cast<int64_t>(moe_expert_num / (ep_world_size - shared_expert_rank_num)));
+      }
+    }
+  }
+
+  if (tp_world_size == kDim2) {
+    ep_recv_cnt_num = ep_world_size * local_moe_expert_num * tp_world_size;
+  } else {
+    ep_recv_cnt_num = ep_world_size * local_moe_expert_num;
+  }
+
+  TypeId expand_x_dtype = (!scales.has_value() && quant_mode == 0) ? x.data_type() : kNumberTypeInt8;
+  TypeId dynamic_scales_dtype = kNumberTypeFloat32;
+  TypeId assist_info_for_combine_dtype = kNumberTypeInt32;
+  TypeId expert_token_nums_dtype = kNumberTypeInt64;
+  TypeId ep_recv_counts_dtype = kNumberTypeInt32;
+  TypeId tp_recv_counts_dtype = kNumberTypeInt32;
+  TypeId expand_scales_dtype = kNumberTypeFloat32;
+
+  ShapeVector expand_x_shape = (tp_world_size == 0) ? ShapeVector{a, h} : ShapeVector{a * tp_world_size, h};
+  // Keep the pyboost shape consistent with graph-mode inference: dynamic_scales is (A,)
+  // without TP communication, and (A * tp_world_size,) with it.
+  ShapeVector dynamic_scales_shape = (tp_world_size == 0) ? ShapeVector{a} : ShapeVector{a * tp_world_size};
+  ShapeVector assist_info_for_combine_shape = {std::max(bs * k, a * 128)};
+  ShapeVector expert_token_nums_shape = {local_moe_expert_num};
+  if (expert_scales.has_value()) {
+    ep_recv_cnt_num = ep_world_size * local_moe_expert_num +
+                      2 * global_bs_real * k * (ep_world_size / 8);  // 2: double buffer, 8: ranks per server
+  }
+  ShapeVector ep_recv_counts_shape = {ep_recv_cnt_num};
+  ShapeVector tp_recv_counts_shape = {tp_world_size};
+  ShapeVector expand_scales_shape = {a};
+
+  std::vector<ms::Tensor> outputs = {
+    ms::Tensor(expand_x_dtype, expand_x_shape),
+    ms::Tensor(dynamic_scales_dtype, dynamic_scales_shape),
+    ms::Tensor(assist_info_for_combine_dtype, assist_info_for_combine_shape),
+    ms::Tensor(expert_token_nums_dtype, expert_token_nums_shape),
+    ms::Tensor(ep_recv_counts_dtype, ep_recv_counts_shape),
+    ms::Tensor(tp_recv_counts_dtype, tp_recv_counts_shape),
+    ms::Tensor(expand_scales_dtype, expand_scales_shape),
+  };
+
+  return outputs;
+}
+
+std::vector<ms::Tensor> moe_distribute_dispatch_v3_custom(
+  const ms::Tensor &x, const ms::Tensor &expert_ids, const int64_t ep_world_size, const int64_t ep_rank_id,
+  const int64_t moe_expert_num, const std::optional<ms::Tensor> &scales,
+  const std::optional<ms::Tensor> &x_active_mask, const std::optional<ms::Tensor> &expert_scales,
+  const std::optional<ms::Tensor> &elastic_info, const std::string &group_ep,
+  const std::optional<std::string> &group_tp, const int64_t tp_world_size, const int64_t tp_rank_id,
+  const int64_t expert_shard_type, const int64_t shared_expert_num, const int64_t shared_expert_rank_num,
+  const int64_t quant_mode, const int64_t global_bs, const int64_t expert_token_nums_type,
+  const std::optional<std::string> &comm_alg, const int64_t zero_expert_num, const int64_t copy_expert_num,
+  const int64_t const_expert_num) {
+  std::vector<ms::Tensor> outputs = get_moe_distribute_v3_output_tensor(
+    x, expert_ids, ep_world_size, ep_rank_id, moe_expert_num, scales, x_active_mask, expert_scales, elastic_info,
+    group_ep, group_tp, tp_world_size, tp_rank_id, expert_shard_type, shared_expert_num, shared_expert_rank_num,
+    quant_mode, global_bs, expert_token_nums_type, comm_alg, zero_expert_num, copy_expert_num, const_expert_num);
+  auto runner = std::make_shared<ms::pynative::PyboostRunner>("aclnnMoeDistributeDispatchV3");
+
+  auto new_group_ep = (group_ep == "") ? mindspore::device::ascend::OpApiUtil::GetCommName(kHcclWorldGroup)
+                                       : mindspore::device::ascend::OpApiUtil::GetCommName(group_ep);
+  auto group_tp_str = (group_tp.has_value() ? group_tp.value() : "");
+  auto new_group_tp = (group_tp_str.empty()) ? "" : mindspore::device::ascend::OpApiUtil::GetCommName(group_tp_str);
+  auto new_comm_alg = (comm_alg.has_value() ? comm_alg.value() : "");
+  runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(
+    aclnnMoeDistributeDispatchV3, x, expert_ids, scales, x_active_mask, expert_scales, elastic_info, new_group_ep,
+    ep_world_size, ep_rank_id, moe_expert_num, new_group_tp, tp_world_size, tp_rank_id, expert_shard_type,
+    shared_expert_num, shared_expert_rank_num, quant_mode, global_bs, expert_token_nums_type, new_comm_alg,
+    zero_expert_num, copy_expert_num, const_expert_num, outputs[kDim0], outputs[kDim1], outputs[kDim2],
+    outputs[kDim3], outputs[kDim4], outputs[kDim5], outputs[kDim6]));
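+  // Two-phase PyboostRunner launch: the full aclnn argument list (including the
+  // scalar and string attributes) was captured by SetLaunchFunc above, so Run()
+  // only needs to register the tensor inputs and outputs.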
+  runner->Run({x, expert_ids, GetTensorOrEmpty(scales), GetTensorOrEmpty(x_active_mask),
+               GetTensorOrEmpty(expert_scales), GetTensorOrEmpty(elastic_info)},
+              outputs);
+  return outputs;
+}
+}  // namespace ms_custom_ops
+
+auto pyboost_moe_distribute_dispatch_v3(
+  const ms::Tensor &x, const ms::Tensor &expert_ids, const int64_t ep_world_size, const int64_t ep_rank_id,
+  const int64_t moe_expert_num, const std::optional<ms::Tensor> &scales,
+  const std::optional<ms::Tensor> &x_active_mask, const std::optional<ms::Tensor> &expert_scales,
+  const std::optional<ms::Tensor> &elastic_info, const std::string &group_ep,
+  const std::optional<std::string> &group_tp, const int64_t tp_world_size, const int64_t tp_rank_id,
+  const int64_t expert_shard_type, const int64_t shared_expert_num, const int64_t shared_expert_rank_num,
+  const int64_t quant_mode, const int64_t global_bs, const int64_t expert_token_nums_type,
+  const std::optional<std::string> &comm_alg, const int64_t zero_expert_num, const int64_t copy_expert_num,
+  const int64_t const_expert_num) {
+  return ms::pynative::PyboostRunner::Call<static_cast<size_t>(
+    ms_custom_ops::MoeDistributeDispatchV3OutputIndex::kMoeDistributeDispatchV3OutputNums)>(
+    ms_custom_ops::moe_distribute_dispatch_v3_custom, x, expert_ids, ep_world_size, ep_rank_id, moe_expert_num,
+    scales, x_active_mask, expert_scales, elastic_info, group_ep, group_tp, tp_world_size, tp_rank_id,
+    expert_shard_type, shared_expert_num, shared_expert_rank_num, quant_mode, global_bs, expert_token_nums_type,
+    comm_alg, zero_expert_num, copy_expert_num, const_expert_num);
+}
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  m.def("moe_distribute_dispatch_v3", &pyboost_moe_distribute_dispatch_v3, "MoeDistributeDispatchV3",
+        pybind11::arg("x"), pybind11::arg("expert_ids"), pybind11::arg("ep_world_size"), pybind11::arg("ep_rank_id"),
+        pybind11::arg("moe_expert_num"), pybind11::arg("scales") = std::nullopt,
+        pybind11::arg("x_active_mask") = std::nullopt, pybind11::arg("expert_scales") = std::nullopt,
+        pybind11::arg("elastic_info") = std::nullopt, pybind11::arg("group_ep") = "",
+        pybind11::arg("group_tp") = std::nullopt, pybind11::arg("tp_world_size") = 0, pybind11::arg("tp_rank_id") = 0,
+        pybind11::arg("expert_shard_type") = 0, pybind11::arg("shared_expert_num") = 1,
+        pybind11::arg("shared_expert_rank_num") = 0, pybind11::arg("quant_mode") = 0, pybind11::arg("global_bs") = 0,
+        pybind11::arg("expert_token_nums_type") = 1, pybind11::arg("comm_alg") = std::nullopt,
+        pybind11::arg("zero_expert_num") = 0, pybind11::arg("copy_expert_num") = 0,
+        pybind11::arg("const_expert_num") = 0);
+}
diff --git a/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.md b/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.md
new file mode 100644
index 0000000..adef27d
--- /dev/null
+++ b/ops/c_api/moe_distribute_dispatch_v3/moe_distribute_dispatch_v3.md
@@ -0,0 +1,139 @@
+# moe_distribute_dispatch_v3 operator
+
+## Description
+
+The moe_distribute_dispatch_v3 operator optionally quantizes token data. When a TP communication domain exists, it first performs AllToAll communication in the EP (Expert Parallelism) domain and then AllGatherV communication in the TP (Tensor Parallelism) domain; when no TP communication domain exists, it performs AllToAllV communication in the EP domain. Internally this operator calls the aclnnMoeDistributeDispatchV3 operator.
+
+## Compared with the aclnnMoeDistributeDispatchV2 operator
+
+- Adds support for dynamic scale-in: after the communication domain has been created, faulty cards can be removed from the domain and the operator keeps running correctly without recompilation. This feature is enabled by passing the elastic_info parameter.
+- Adds support for special experts (a small classification sketch follows this list):
+  - Enabled when zeroExpertNum is non-zero: $MoE(oriXOptional) = 0$
+  - Enabled when copyExpertNum is non-zero; a valid oriXOptional must also be passed: $MoE(oriXOptional) = oriXOptional$
+  - Enabled when constExpertNum is non-zero; valid oriXOptional, constExpertAlpha1Optional, constExpertAlpha2Optional and constExpertVOptional must also be passed: $MoE(oriXOptional) = constExpertAlpha1Optional \times oriXOptional + constExpertAlpha2Optional \times constExpertVOptional$
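+
+The special-expert IDs follow the ordinary MoE expert IDs (see zero_expert_num / copy_expert_num / const_expert_num in the parameter notes below). A minimal sketch of how a routed expert ID would be classified, using purely hypothetical counts:
+
+```python
+# Hypothetical counts: 64 MoE experts, then 2 zero, 2 copy and 1 const expert.
+moe_n, zero_n, copy_n, const_n = 64, 2, 2, 1
+
+def expert_kind(eid):
+    if eid < moe_n:
+        return "moe"
+    if eid < moe_n + zero_n:
+        return "zero"    # MoE(x) = 0
+    if eid < moe_n + zero_n + copy_n:
+        return "copy"    # MoE(x) = x
+    if eid < moe_n + zero_n + copy_n + const_n:
+        return "const"   # MoE(x) = a1 * x + a2 * v
+    raise ValueError("expert id out of range")
+
+assert expert_kind(65) == "zero"
+```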
+
+## Input parameters
+
+| Name | DType | Shape | Optional | Inplace | Format | Description |
+|------|-------|-------|----------|---------|--------|-------------|
+| x | Tensor | 2-D [Bs, H] | No | No | ND | Token data sent from this card |
+| expert_ids | Tensor | 2-D [Bs, K] | No | No | ND | Top-K expert indices for each token |
+| ep_world_size | int | - | No | No | - | Size of the EP communication domain |
+| ep_rank_id | int | - | No | No | - | This card's rank ID within the EP domain |
+| moe_expert_num | int | - | No | No | - | Number of MoE experts |
+| scales | Tensor | 2-D [sharedExpertNum + moeExpertNum, H] | Yes | No | ND | Per-expert quantization smoothing parameters |
+| x_active_mask | Tensor | 1-D or 2-D, see constraint notes | Yes | No | ND | Whether each token participates in communication |
+| expert_scales | Tensor | 2-D, see constraint notes | Yes | No | ND | Top-K expert weights for each token |
+| elastic_info | Tensor | 1-D, see constraint notes | Yes | No | ND | Dynamic scale-in information for the EP communication domain; when cards are removed from the domain because of faults, the number of cards actually participating can be read from this parameter |
+| group_ep | str | string length in [1, 128) | No | No | - | Name of the EP communication domain (expert parallelism) |
+| group_tp | str | string, see constraint notes | Yes | No | - | Name of the TP communication domain (tensor parallelism) |
+| tp_world_size | int | - | No | No | - | Size of the TP communication domain |
+| tp_rank_id | int | - | No | No | - | This card's rank ID within the TP domain |
+| expert_shard_type | int | - | No | No | - | Distribution type of the shared-expert cards |
+| shared_expert_num | int | - | No | No | - | Number of shared experts; one shared expert can be replicated onto multiple cards |
+| shared_expert_rank_num | int | - | No | No | - | Number of shared-expert cards |
+| quant_mode | int | - | No | No | - | Quantization mode. 0: no quantization; 2: dynamic quantization |
+| global_bs | int | - | No | No | - | Global batch size across the EP domain |
+| expert_token_nums_type | int | - | No | No | - | Semantics of the values in expert_token_nums. 0: expert_token_nums holds prefix sums of the token counts handled by each expert; 1: expert_token_nums holds the token count handled by each expert |
+| comm_alg | str | - | Yes | No | - | Communication-affinity memory layout algorithm, see constraint notes |
+| zero_expert_num | int | - | No | No | - | Number of zero experts |
+| copy_expert_num | int | - | No | No | - | Number of copy experts |
+| const_expert_num | int | - | No | No | - | Number of const experts |
+
+## Parameter notes and constraints
+
+- **x**: token data sent from this card, shape=(Bs, H) where Bs is the batch size and H is the hidden size. Supported dtypes: Float16/BFloat16.
+- **expert_ids**: top-K expert indices per token, 2-D tensor, shape=(Bs, K).
+- **ep_world_size**: on A2 machines the supported values are 16/32/64; on A3 machines the range is [2, 768].
+- **ep_rank_id**: range [0, epWorldSize); ep_rank_id must be unique across the cards of one EP communication domain.
+- **moe_expert_num**: must satisfy moe_expert_num % (ep_world_size - shared_expert_rank_num) = 0. On A2 machines the range is (0, 512] and moe_expert_num / (ep_world_size - shared_expert_rank_num) <= 24 must hold. On A3 machines the range is (0, 1024].
+- **scales**: per-expert quantization smoothing parameters, 2-D tensor, shape=(shared_expert_num + moe_expert_num, H). Pass None in the non-quantized scenario; in the dynamic-quantization scenario either valid data or None may be passed.
+  - On A2 machines, None is required when comm_alg is configured as "hierarchy" or when HCCL_INTRA_PCIE_ENABLE=1 && HCCL_INTRA_ROCE_ENABLE=0 is configured.
+  - No special requirement on A3 machines.
+- **x_active_mask**: currently unsupported on A2 machines; pass None.
+  - On A3 machines a 1-D or 2-D tensor may be passed: 1-D with shape (Bs,) or 2-D with shape (Bs, K). The dtype is Bool; either None or valid data may be passed. For 1-D input, True means the corresponding token participates in communication, and all True values must precede the False values; an illegal example is [True, False, True]. For 2-D input, True means the token participates in communication for the corresponding entry of expert_ids; if all K bool values of a token are False, that token does not participate. By default all tokens participate. When the Bs values differ across cards, all tokens must be valid.
+- **expert_scales**: on A2 machines a 2-D tensor with shape=(Bs, K). Currently unsupported on A3 machines; pass the default None.
+- **elastic_info** (a layout sketch follows this list):
+  - Currently unsupported on A2 machines; pass None.
+  - On A3, either None or valid data may be passed. None disables the dynamic scale-in feature. Valid data must be a 1-D tensor of size (4 + 2 x ep_world_size). The first four numbers are (whether scaled in, actual rank count after scale-in, rank count used by shared experts after scale-in, number of MoE experts after scale-in). The following 2 x ep_world_size numbers form two rank-mapping tables built after faulty ranks are removed from the EP domain: Table1[epRankId] = localEpRankId or -1, where localEpRankId is the rank index in the new EP domain and -1 means the card epRankId was removed from the domain; Table2[localEpRankId] = epRankId.
+- **group_ep**: string length in [1, 128); must not be the same as groupTp.
+- **group_tp**: unsupported in the current version on A2; pass an empty string. On A3, string length in [1, 128); must not be the same as groupEp.
+- **tp_world_size**: unsupported in the current version on A2; pass 0. On A3, range [0, 2]; 0 and 1 mean no TP-domain communication, and only 2 is supported when TP-domain communication is used.
+- **tp_rank_id**: unsupported in the current version on A2; pass 0. On A3, range [0, 1]; tpRankId must be unique across the cards of one TP communication domain. Pass 0 when there is no TP-domain communication.
+- **expert_shard_type**: unsupported in the current version on A2; pass 0. On A3, only 0 is currently supported, meaning the shared-expert cards precede the MoE expert cards.
+- **shared_expert_num**: unsupported in the current version on A2; pass 0. On A3, the current range is [0, 4].
+- **shared_expert_rank_num**: unsupported in the current version on A2; pass 0. On A3, the current range is [0, epWorldSize); when 0, sharedExpertNum must be 0 or 1; when non-zero, sharedExpertRankNum % sharedExpertNum = 0 must hold.
+- **quant_mode**: 0: no quantization; 2: dynamic quantization.
+- **global_bs**: when Bs is identical on every rank, globalBs = Bs x epWorldSize or globalBs = 0; when Bs differs across ranks, globalBs = maxBs x epWorldSize, where maxBs is the largest single-card Bs.
+- **expert_token_nums_type**: 0: expertTokenNums outputs prefix sums of the token counts handled by each expert; 1: expertTokenNums outputs the token count handled by each expert.
+- **comm_alg**: unsupported on A3 machines in the current version. On A2 machines the current version supports the four inputs nullptr, "", "fullmesh" and "hierarchy". "hierarchy" together with driver version 25.0.RC1.1 or later is recommended.
+  - nullptr and "": only in this case do the HCCL_INTRA_PCIE_ENABLE and HCCL_INTRA_ROCE_ENABLE settings take effect. When HCCL_INTRA_PCIE_ENABLE=1 && HCCL_INTRA_ROCE_ENABLE=0, the "hierarchy" algorithm is used, otherwise "fullmesh". This style is not recommended.
+  - "fullmesh": token data is sent via RDMA directly to the cards hosting the top-K target experts.
+  - "hierarchy": token data is sent in two hops, inter-server then intra-server; RDMA is used only between same-index cards of different servers, and HCCS is used within a server.
+- **zero_expert_num**: unsupported on A2; pass 0. On A3, range [0, MAX_INT32) with MAX_INT32 = 2^31 - 1; legal zero-expert IDs lie in [moeExpertNum, moeExpertNum + zeroExpertNum).
+- **copy_expert_num**: unsupported on A2; pass 0. On A3, range [0, MAX_INT32) with MAX_INT32 = 2^31 - 1; legal copy-expert IDs lie in [moeExpertNum + zeroExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum).
+- **const_expert_num**: unsupported on A2; pass 0. On A3, range [0, MAX_INT32) with MAX_INT32 = 2^31 - 1; legal const-expert IDs lie in [moeExpertNum + zeroExpertNum + copyExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum).
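+
+The elastic_info layout can be illustrated with a small, purely hypothetical example (ep_world_size = 4, no shared experts, rank 2 removed after a fault, 6 MoE experts remaining; the dtype and the padding value of the unused Table2 tail are assumptions):
+
+```python
+import numpy as np
+
+ep_world_size = 4
+header = [1, 3, 0, 6]   # (scaled-in flag, ranks after scale-in, shared-expert ranks, MoE experts)
+table1 = [0, 1, -1, 2]  # Table1[epRankId] -> localEpRankId; -1 means the card was removed
+table2 = [0, 1, 3, -1]  # Table2[localEpRankId] -> epRankId; tail entry is assumed padding
+elastic_info = np.array(header + table1 + table2, dtype=np.int32)
+assert elastic_info.size == 4 + 2 * ep_world_size
+```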
+
+## Output parameters
+
+| Name | DType | Shape | Description |
+|------|-------|-------|-------------|
+| expand_x | Tensor | [max(tp_world_size, 1) * A, H] | Token features expanded according to expert_ids |
+| dynamic_scales | Tensor | [A] | Computed dynamic quantization parameters |
+| assist_info_for_combine | Tensor | [A * 128] | Number of tokens sent to the same expert; corresponds to assistInfoForCombine of the Combine operator |
+| expert_token_nums | Tensor | [local_expert_num] | Number of tokens received by each expert |
+| ep_recv_counts | Tensor | 1-D, shape see constraint notes | Token counts received from each card in the EP communication domain; corresponds to epSendCounts of the Combine operator |
+| tp_recv_counts | Tensor | 1-D, shape see constraint notes | Token counts received from each card in the TP communication domain; corresponds to tpSendCounts of the Combine operator |
+| expand_scales | Tensor | 1-D, shape see constraint notes | Weights of the tokens output by this card; corresponds to expertScalesOptional of the Combine operator |
+
+## Output constraint notes
+
+- **dynamic_scales**: produced only when quant_mode is 2.
+- **ep_recv_counts**:
+  - On A2 machines the required shape is (moeExpertNum + 2 x globalBs x K x serverNum, ). The first moeExpertNum values are the token counts received from each card in the EP domain; the remaining 2 x globalBs x K x serverNum values store the token counts that combine can pre-reduce before inter/intra-server communication, together with the token offsets in the communication area. When globalBs is passed as 0, it should be computed here as Bs x epWorldSize.
+  - On A3 machines the required shape is (epWorldSize x max(tpWorldSize, 1) x localExpertNum, ).
+- **tp_recv_counts**: TP-domain communication is currently unsupported on A2 machines. On A3, when there is TP-domain communication, a 1-D tensor with shape (tpWorldSize, ) is required.
+- **expand_scales**: on A2 a 1-D tensor with shape (A, ) is required. On A3 this output is unsupported in the current version.
+
+Computation of A in expand_x (the maximum number of tokens this card dispatches; a worked example follows):
+
+- Without dynamic scale-in:
+  - For shared experts: $A = Bs \times epWorldSize \times \frac{sharedExpertNum}{sharedExpertRankNum}$
+  - For MoE experts: when global_bs is 0, $A \ge Bs \times epWorldSize \times \min(localExpertNum, K)$; when global_bs is non-zero, $A \ge globalBs \times \min(localExpertNum, K)$
+- With dynamic scale-in:
+  - When global_bs = 0: $A \ge \max(Bs \times epWorldSize \times \frac{sharedExpertNum}{sharedExpertRankNum}, Bs \times epWorldSize \times \min(localExpertNum, K))$
+  - When global_bs is non-zero: $A \ge \max(Bs \times epWorldSize \times \frac{sharedExpertNum}{sharedExpertRankNum}, globalBs \times \min(localExpertNum, K))$
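+
+A quick numeric check of the MoE-expert bound above, for a hypothetical configuration with Bs = 8, K = 8, ep_world_size = 16, no shared experts, moe_expert_num = 64 and global_bs = 0:
+
+```python
+Bs, K, ep_world_size, moe_expert_num, shared_expert_rank_num = 8, 8, 16, 64, 0
+local_expert_num = moe_expert_num // (ep_world_size - shared_expert_rank_num)  # 4
+A_min = Bs * ep_world_size * min(local_expert_num, K)  # 8 * 16 * 4 = 512
+print(local_expert_num, A_min)  # -> 4 512
+```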
+
+H: the hidden_size of the hidden layer
+
+- On A2 machines the range is (0, 7168] and the value must be a multiple of 32
+- On A3 machines the range is [1024, 8192]
+
+Bs: the batch sequence size, i.e. the number of tokens this card finally outputs
+
+- On A2 machines the range is (0, 256]
+- On A3 machines the range is (0, 512]
+
+K: the number of selected top-K experts; $K \in (0, 16]$ and $K \in (0, moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum]$ must hold
+
+localExpertNum: the number of experts on this card
+
+- For shared-expert cards, localExpertNum = 1
+- For MoE expert cards, $localExpertNum = \frac{moeExpertNum}{epWorldSize - sharedExpertRankNum}$; when localExpertNum > 1, TP-domain communication is not supported
+
+For more details see: [aclnnMoeDistributeDispatchV3](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/API/aolapi/context/aclnnMoeDistributeDispatchV3.md)
+
+## Special notes
+
+## Usage examples
+
+### Basic usage example
+
+A minimal sketch; it assumes a 16-rank A3 distributed job launched with a tool such as msrun, with HCCL available, and uses the shapes and dtypes constrained above:
+
+```python
+import mindspore as ms
+import numpy as np
+import ms_custom_ops
+from mindspore.communication import init, get_rank
+
+ms.set_device("Ascend")
+init()  # the HCCL world group doubles as the EP group here
+x = ms.Tensor(np.random.randn(8, 7168).astype(np.float16))                  # [Bs, H]
+expert_ids = ms.Tensor(np.random.randint(0, 64, (8, 8)).astype(np.int32))   # [Bs, K]
+outs = ms_custom_ops.moe_distribute_dispatch_v3(
+    x, expert_ids, ep_world_size=16, ep_rank_id=get_rank(), moe_expert_num=64)
+```
diff --git a/ops/c_api/utils/check_utils.h b/ops/c_api/utils/check_utils.h
index f8b9c24..b31d367 100644
--- a/ops/c_api/utils/check_utils.h
+++ b/ops/c_api/utils/check_utils.h
@@ -50,11 +50,11 @@ namespace ms_custom_ops {
     }                                                                                 \
   } while (0)
 
-#define MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(condition, error_info)                  \
-  do {                                                                                \
-    if (!(condition)) {                                                               \
-      MS_LOG(EXCEPTION) << "Failure info [" << error_info << "].";                    \
-    }                                                                                 \
+#define MS_CUSTOM_OPS_EXCEPTION_IF_CHECK_FAIL(condition, op_name, error_info)         \
+  do {                                                                                \
+    if (!(condition)) {                                                               \
+      MS_LOG(EXCEPTION) << "Failure info [" << op_name << ":" << error_info << "].";  \
+    }                                                                                 \
   } while (0)
 
 #define MS_CUSTOM_OPS_CHECK_VALUE(cond, msg) \
diff --git a/ops/framework/utils.h b/ops/framework/utils.h
index ef1ff00..9ecf9ce 100644
--- a/ops/framework/utils.h
+++ b/ops/framework/utils.h
@@ -54,6 +54,8 @@ constexpr size_t kNumber6{6};
 static const mindspore::ShapeArray kFakeOutTensorShapes{mindspore::ShapeVector{1}};
 // 用于静态图下没有返回值的算子InferType占位输出
 static const std::vector<mindspore::TypeId> kFakeOutTensorTypes{mindspore::TypeId::kNumberTypeInt8};
+static const mindspore::ShapeVector kFakeOutShapes{1};
+
 // Helper function to convert optional tensor to tensor or empty tensor
 inline ms::Tensor GetTensorOrEmpty(const std::optional<ms::Tensor> &opt_tensor) {
   return opt_tensor.has_value() ? opt_tensor.value() : ms::Tensor();
-- 
Gitee