From 340c6d794fd7b4df232a514c7f1fc1256b72aac0 Mon Sep 17 00:00:00 2001 From: zhangshucheng Date: Mon, 20 Oct 2025 10:22:29 +0800 Subject: [PATCH 1/3] rope_v3 tiling check Signed-off-by: zhangshucheng --- .../op_host/apply_rotary_pos_emb_v3.cpp | 71 +++++++++++++------ 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp index 7b82b44..7ebb683 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp +++ b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp @@ -19,27 +19,43 @@ #include "tiling/platform/platform_ascendc.h" namespace optiling { +constexpr uint32_t ROPE_V3_USE_TBUF_COUNT = 6; +constexpr uint32_t ROPE_V3_USE_TBUF_COSSIN_COUNT = 2; +constexpr uint32_t ROPE_V3_TILINGKEY_FP16 = 1; +constexpr uint32_t ROPE_V3_TILINGKEY_FP32 = 2; +constexpr uint32_t ROPE_V3_TILINGKEY_FACTOR = 10; +constexpr uint32_t ROPE_V3_ROTARY_DIM_FACTOR = 2; +constexpr uint32_t kIndex0 = 0; +constexpr uint32_t kIndex1 = 1; +constexpr uint32_t kIndex2 = 2; +constexpr uint32_t kDim0 = 0; +constexpr uint32_t kDim1 = 1; +constexpr uint32_t kDim2 = 2; +constexpr char NODE_NAME[] = "apply_rotary_pos_emb_v3"; + static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { ApplyRotaryPosEmbV3TilingData tiling; uint32_t tiling_key{0}; - uint64_t ub_size; + uint64_t ub_total_size; auto ascendc_platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - ascendc_platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size); + ascendc_platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_total_size); auto coreNum = ascendc_platform.GetCoreNum(); - auto query_shape = context->GetInputShape(0)->GetOriginShape(); - auto key_shape = context->GetInputShape(1)->GetOriginShape(); - auto cos_shape = context->GetInputShape(2)->GetOriginShape(); + auto query_shape = context->GetInputShape(kIndex0)->GetOriginShape(); + auto key_shape = context->GetInputShape(kIndex1)->GetOriginShape(); + auto cos_shape = context->GetInputShape(kIndex2)->GetOriginShape(); + ge::DataType query_type = context->GetInputDesc(kIndex0)->GetDataType(); + ge::DataType cos_type = context->GetInputDesc(kIndex2)->GetDataType(); - uint32_t tokens = query_shape.GetDim(0); - uint32_t query_head_num = query_shape.GetDim(1); - uint32_t key_head_num = key_shape.GetDim(1); + uint32_t tokens = query_shape.GetDim(kDim0); + uint32_t query_head_num = query_shape.GetDim(kDim1); + uint32_t key_head_num = key_shape.GetDim(kDim1); - uint32_t query_head_dim = query_shape.GetDim(2); - uint32_t cos_head_dim = cos_shape.GetDim(1); - uint32_t rotary_dim = cos_head_dim *2; + uint32_t query_head_dim = query_shape.GetDim(kDim2); + uint32_t cos_head_dim = cos_shape.GetDim(kDim1); + uint32_t rotary_dim = cos_head_dim * ROPE_V3_ROTARY_DIM_FACTOR; - uint32_t is_split = (rotary_dim == query_head_dim ? 0: 1); + uint32_t is_split = (rotary_dim == query_head_dim ? 0 : 1); tiling.set_queryHeadDim(query_head_dim); tiling.set_qHeadNum(query_head_num); tiling.set_kHeadNum(key_head_num); @@ -53,18 +69,34 @@ static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { } tiling.set_tokensPerCore(static_cast(tokens / coreNum)); tiling.set_tokensTail(tokens % coreNum); - const uint32_t *layout = context->GetAttrs()->GetAttrPointer(0); - const uint32_t *rotaryMode = context->GetAttrs()->GetAttrPointer(1); + const uint32_t *layout = context->GetAttrs()->GetAttrPointer(kIndex0); + const uint32_t *rotaryMode = context->GetAttrs()->GetAttrPointer(kIndex1); tiling.set_layout(*layout); tiling.set_rotaryMode(*rotaryMode); + uint32_t ub_use = + query_head_dim * (key_head_num + query_head_num) * ge::GetSizeByDataType(query_type) + + ROPE_V3_USE_TBUF_COSSIN_COUNT * cos_head_dim * ge::GetSizeByDataType(cos_type) + + (tiling.get_qHiddenSize() + tiling.get_kHiddenSize()) * ROPE_V3_USE_TBUF_COUNT * ge::GetSizeByDataType(query_type); - ge::DataType query_type = context->GetInputDesc(0)->GetDataType(); + if (ub_use > ub_total_size) { + std::cerr << "ERROR: " + << "(cos.head_dim * 2 + query.hidden_size + key.hidden_size + (query.head_num + key.head_num) * 6) " + << "* sizeof(type) should be less than or equal to UB size, but got " << std::to_string(ub_use) << " > " + << std::to_string(ub_total_size); + return ge::GRAPH_FAILED; + } + // OP_CHECK(ub_use > ub_total_size, + // OP_LOGE(context->GetNodeName(), + // "(cos.head_dim * 2 + query.hidden_size + key.hidden_size + (query.head_num + key.head_num) * 6) " + // "* sizeof(type) should be less than or equal to UB size, but got " + + // std::to_string(ub_use) + " > " + std::to_string(ub_total_size)), + // return ge::GRAPH_FAILED); if (query_type == ge::DataType::DT_FLOAT16) { - tiling_key = 1; - }else if(query_type == ge::DataType::DT_FLOAT) { - tiling_key = 2; + tiling_key = ROPE_V3_TILINGKEY_FP16; + } else if (query_type == ge::DataType::DT_FLOAT) { + tiling_key = ROPE_V3_TILINGKEY_FP32; } - tiling_key = tiling_key * 10 + is_split; + tiling_key = tiling_key * ROPE_V3_TILINGKEY_FACTOR + is_split; context->SetBlockDim(coreNum); context->SetTilingKey(tiling_key); @@ -75,7 +107,6 @@ static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { return ge::GRAPH_SUCCESS; } } // namespace optiling - namespace ge { static ge::graphStatus ApplyRotaryPosEmbV3InferShape(gert::InferShapeContext *context) { const gert::Shape *query_shape = context->GetInputShape(0); -- Gitee From da9de1876bbb4f3e342987fb0f69c88d9cfbc443 Mon Sep 17 00:00:00 2001 From: zhangshucheng Date: Wed, 22 Oct 2025 10:14:49 +0800 Subject: [PATCH 2/3] doc for apply_rotary_pos_emb Signed-off-by: zhangshucheng --- .../apply_rotary_pos_emb_v3.cc | 50 ++++--- .../apply_rotary_pos_emb_v3.md | 26 ++-- .../apply_rotary_pos_emb_v3_op.yaml | 8 +- .../op_host/apply_rotary_pos_emb_v3.cpp | 11 +- .../apply_rotary_pos_emb.cc | 4 +- .../apply_rotary_pos_emb.md | 62 +++++++++ .../apply_rotary_pos_emb_doc.yaml | 47 ------- .../moe_gating_group_topk.md | 66 +++++++++ .../moe_gating_group_topk_doc.yaml | 47 ------- tests/st/test_custom_rope_v3.py | 129 ++++++++++-------- 10 files changed, 250 insertions(+), 200 deletions(-) create mode 100644 ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md delete mode 100644 ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb_doc.yaml create mode 100644 ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md delete mode 100644 ops/c_api/moe_gating_group_topk/moe_gating_group_topk_doc.yaml diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc index 63caab0..670c2b7 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc +++ b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc @@ -21,6 +21,9 @@ #include #include #include +#include +#include +#include #include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h" #include "ops/framework/utils.h" #include "mindspore/include/custom_op_api.h" @@ -28,6 +31,9 @@ namespace ms_custom_ops { // 约束条件: rotary_dim = 2 * cos_head_dim, query_head_dim >= rotary_dim constexpr uint32_t ROTARY_DIM_FACTOR = 2; +constexpr int32_t LAYOUT_BSH = 1; +constexpr const char *LAYOUT_BSH_STR = "BSH"; +constexpr const char *ROTARY_INTERLEAVE_STR = "interleave"; enum class ApplyRotaryPosEmbV3InputIndex : size_t { kApplyRotaryPosEmbV3QueryIndex = 0, kApplyRotaryPosEmbV3KeyIndex, @@ -114,10 +120,10 @@ class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { ->GetScalarValueWithCheck(); auto layout = input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3LayoutIndex)] ->GetScalarValueWithCheck(); - MS_CHECK_VALUE(layout == "BSH", + MS_CHECK_VALUE(layout == LAYOUT_BSH_STR, CheckAndConvertUtils::FormatCommMsg(op_name, " layout should be 'BSH', but got ", layout)); MS_CHECK_VALUE( - rotary_mode == "interleave", + rotary_mode == ROTARY_INTERLEAVE_STR, CheckAndConvertUtils::FormatCommMsg(op_name, " rotary_mode should be 'interleave', but got ", rotary_mode)); ApplyRotaryPosEmbV3CheckInputsShape(op_name, query_shape, key_shape, cos_shape, sin_shape); return {query_shape, key_shape}; @@ -157,8 +163,13 @@ class ApplyRotaryPosEmbV3Ascend : public AclnnCustomKernelMod { } void GetWorkSpaceInfo(const std::vector &inputs, const std::vector &outputs) override { - layout_ = inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3LayoutIndex)] - ->GetValueWithCheck(); + auto layout_str = inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3LayoutIndex)] + ->GetValueWithCheck(); + if (layout_str == LAYOUT_BSH_STR) { + layout_ = LAYOUT_BSH; + } else { + MS_LOG(EXCEPTION) << "layout should be 'BSH', but got " << layout_str; + } rotary_mode_ = inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3RotaryModeIndex)] ->GetValueWithCheck(); GetWorkspaceForResize(inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)], @@ -171,8 +182,8 @@ class ApplyRotaryPosEmbV3Ascend : public AclnnCustomKernelMod { private: DEFINE_GET_WORKSPACE_FOR_RESIZE(); - std::string layout_ = "BSH"; - std::string rotary_mode_ = "interleave"; + int32_t layout_ = LAYOUT_BSH; + std::string rotary_mode_ = ROTARY_INTERLEAVE_STR; }; } // namespace ms_custom_ops @@ -184,13 +195,11 @@ REG_GRAPH_MODE_OP(apply_rotary_pos_emb_v3, ms_custom_ops::ApplyRotaryPosEmbV3OpF // ============================================================================= namespace ms_custom_ops { -using namespace mindspore; -using namespace mindspore::device::ascend; -constexpr size_t kApplyRotaryPosEmbV3OutputNum = 2; +constexpr size_t kApplyRotaryPosEmbV3OutputNum = 0; -std::vector apply_rotary_pos_emb_v3_custom(const ms::Tensor &query, const ms::Tensor &key, - const ms::Tensor &cos, const ms::Tensor &sin, - const std::string layout_str, const std::string rotary_mode) { +std::vector npu_apply_rotary_pos_emb_v3(const ms::Tensor &query, const ms::Tensor &key, + const ms::Tensor &cos, const ms::Tensor &sin, + const std::string &layout_str, const std::string &rotary_mode) { std::string op_name = "ApplyRotaryPosEmbV3"; // 此处op_name是给人看的, 跟算子命名没有直接关联 auto runner = std::make_shared(op_name); @@ -204,12 +213,19 @@ std::vector apply_rotary_pos_emb_v3_custom(const ms::Tensor &query, runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbV3, query, key, cos, sin, layout_str, rotary_mode)); // 如果是复写算子(inplace), 输出参数为空 runner->Run({query, key, cos, sin}, {}); - // 如果是复写算子(inplace), 将复写的input参数作为output返回 - return {query, key}; + return {}; } } // namespace ms_custom_ops +auto pyboost_apply_rotary_pos_emb_v3(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, + const ms::Tensor &sin, const std::string &layout_str, + const std::string &rotary_mode) { + return ms::pynative::PyboostRunner::Call( + ms_custom_ops::npu_apply_rotary_pos_emb_v3, query, key, cos, sin, layout_str, rotary_mode); +} + MS_CUSTOM_OPS_EXTENSION_MODULE(m) { - m.def("apply_rotary_pos_emb_v3", - PYBOOST_CALLER(ms_custom_ops::kApplyRotaryPosEmbV3OutputNum, ms_custom_ops::apply_rotary_pos_emb_v3_custom)); -} \ No newline at end of file + m.def("apply_rotary_pos_emb_v3", &pyboost_apply_rotary_pos_emb_v3, "ApplyRotaryPosEmbV3", pybind11::arg("query"), + pybind11::arg("key"), pybind11::arg("cos"), pybind11::arg("sin"), pybind11::arg("layout"), + pybind11::arg("rotary_mode")); +} diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md index 4688814..6ae245f 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md +++ b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md @@ -12,24 +12,19 @@ apply_rotary_pos_emb_v3算子用于计算旋转编码操作。且支持部分数 | key | Tensor(dtype=FP16/FP32) | 3维[tokens, k_head_num, qk_head_dim] | No | Yes | ND | 执行旋转位置编码的第二个变量 | | cos | Tensor(dtype=FP16/FP32) | 2维[tokens, cos_sin_head_dim] | No | No | ND | 表示参与计算的位置编码张量 | | sin | Tensor(dtype=FP16/FP32) | 2维[tokens, cos_sin_head_dim] | No | No | ND | 表示参与计算的位置编码张量 | -| layout | string | No | Yes | No | string | 表示输入Tensor的布局格式 | -| rotary_mode | string | No | Yes | No | string | 表示支持计算公式中的旋转模式 | +| layout | string | No | No | No | string | 表示输入Tensor的布局格式 | +| rotary_mode | string | No | No | No | string | 表示支持计算公式中的旋转模式 | Note: -- 产品支持: Atlas推理系列产品AI Core -- rotary_mode: 当前仅支持'interleave'模式 -- layout: 当前仅支持'BSH' -- dtype: query/key/cos/sin数据类型支持FP16/FP32,且四个输入参数类型一致。 -- head_dim: 令`rotary_head_dim = 2 * cos_sin_head_dim`。 - - 要求`qk_head_dim >= rotary_head_dim`, qk_head_dim 不能小于rotary_head_dim。 - - 当`qk_head_dim > rotary_head_dim`时,只对`query/key[...:rotary_head_dim]` 做旋转位置编码。且(qk_head_dim - rotary_head_dim)* size(dtype)必须能被32整除 - - cos_sin_head_dim * sizeof(dtype) 必须能被32整除 +- 产品支持: Atlas推理系列产品AI Core。 +- rotary_mode: 当前仅支持'interleave'模式。 +- layout: 当前仅支持'BSH'。 +- dtype: query/key/cos/sin数据类型支持FP16/FP32,且四个输入参数类型一致。 +- head_dim: 令`rotary_head_dim = 2 * cos_sin_head_dim`。要求`qk_head_dim >= rotary_head_dim`, qk_head_dim 不能小于rotary_head_dim。当`qk_head_dim > rotary_head_dim`时,只对`query/key[...:rotary_head_dim]` 做旋转位置编码。且`(qk_head_dim - rotary_head_dim)* size(dtype)`必须能被32整除。`cos_sin_head_dim * sizeof(dtype)`必须能被32整除。 ## 输出参数 - - ## 特殊说明 ## 使用示例 @@ -43,9 +38,6 @@ import numpy as np import ms_custom_ops from mindspore import context, Tensor -ms.set_context(device_target="Ascend", mode=context.GRAPH_MODE) -ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - tokens = 4096 head_num_q = 32 head_num_k =2 @@ -63,5 +55,7 @@ query = Tensor(np_query, dtype=query_dtype) key = Tensor(np_key , dtype=query_dtype) cos = Tensor(np_cos, dtype=query_dtype) sin = Tensor(np_sin, dtype=query_dtype) -out_query, out_key = ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, layout, rotary_mode) +ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, layout, rotary_mode) +print("result query:", query) +print("result key:", key) ``` diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml index d628ce9..8f21785 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml +++ b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml @@ -18,9 +18,5 @@ apply_rotary_pos_emb_v3: labels: side_effect_mem: True returns: - query_embed: - dtype: tensor - inplace: query - key_embed: - dtype: tensor - inplace: key \ No newline at end of file + out: + dtype: tensor \ No newline at end of file diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp index 7ebb683..7b4b04b 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp +++ b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp @@ -13,7 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "apply_rotary_pos_emb_v3_tiling.h" +#include +#include "apply_rotary_pos_emb_v3_tiling.h" // NOLINT(build/include) #include "register/op_def_registry.h" #include "graph/utils/type_utils.h" #include "tiling/platform/platform_ascendc.h" @@ -31,7 +32,6 @@ constexpr uint32_t kIndex2 = 2; constexpr uint32_t kDim0 = 0; constexpr uint32_t kDim1 = 1; constexpr uint32_t kDim2 = 2; -constexpr char NODE_NAME[] = "apply_rotary_pos_emb_v3"; static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { ApplyRotaryPosEmbV3TilingData tiling; @@ -79,18 +79,13 @@ static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { (tiling.get_qHiddenSize() + tiling.get_kHiddenSize()) * ROPE_V3_USE_TBUF_COUNT * ge::GetSizeByDataType(query_type); if (ub_use > ub_total_size) { + // 这里临时用std::cerr输出错误日志,等框架支持log功能后,再改用log std::cerr << "ERROR: " << "(cos.head_dim * 2 + query.hidden_size + key.hidden_size + (query.head_num + key.head_num) * 6) " << "* sizeof(type) should be less than or equal to UB size, but got " << std::to_string(ub_use) << " > " << std::to_string(ub_total_size); return ge::GRAPH_FAILED; } - // OP_CHECK(ub_use > ub_total_size, - // OP_LOGE(context->GetNodeName(), - // "(cos.head_dim * 2 + query.hidden_size + key.hidden_size + (query.head_num + key.head_num) * 6) " - // "* sizeof(type) should be less than or equal to UB size, but got " + - // std::to_string(ub_use) + " > " + std::to_string(ub_total_size)), - // return ge::GRAPH_FAILED); if (query_type == ge::DataType::DT_FLOAT16) { tiling_key = ROPE_V3_TILINGKEY_FP16; } else if (query_type == ge::DataType::DT_FLOAT) { diff --git a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc index bcc96e2..c6fea98 100644 --- a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc +++ b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc @@ -118,7 +118,7 @@ class ApplyRotaryPosEmbRunner : public InternalPyboostRunner { } private: - int32_t cos_format_{0}; + int32_t cos_format_{2}; }; std::vector npu_apply_rotary_pos_emb(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, @@ -156,5 +156,5 @@ auto pyboost_apply_rotary_pos_emb(const ms::Tensor &query, const ms::Tensor &key MS_CUSTOM_OPS_EXTENSION_MODULE(m) { m.def("apply_rotary_pos_emb", &pyboost_apply_rotary_pos_emb, "ApplyRotaryPosEmb", pybind11::arg("query"), pybind11::arg("key"), pybind11::arg("cos"), pybind11::arg("sin"), pybind11::arg("position_ids"), - pybind11::arg("cos_format") = std::nullopt); + pybind11::arg("cos_format") = 2); } diff --git a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md new file mode 100644 index 0000000..05fbd5b --- /dev/null +++ b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md @@ -0,0 +1,62 @@ + +# apply_rotary_pos_emb 算子 + +## 描述 + +旋转位置编码(Rotary Position Embedding,RoPE),以旋转矩阵的方式在q、k中注入位置信息,使得attention计算时能感受到token的位置关系,在各大模型中,RoPE被广泛应用。RoPE以绝对位置编码的方式实现了相对位置编码,能有效保持位置信息相对关系,并且可以通过编码外推的方式支持超过训练长度的位置编码。支持query/key输入为2/3/4维,其中2维仅支持TH unpad方案。 + +## 输入参数 + +| Name | DType | Shape | Optional | Inplace | Format | Description | +|---------------------|-----------------|----------------------------------------|----------|---------|--------|--------------------------------------------------------| +| query | Tensor(float16/bf16) |[batch, seqlen, hidden_size_q]/ [ntokens, hidden_size_q]/ [batch, seqlen, head_num_q, head_size]| No | No | ND |当前step多个token的query。支持2维TH, 3维BSH, 4维BSND 。 | +| key | Tensor(float16/bf16) |[batch, seqlen, hidden_size_k]/ [ntokens, hidden_size_k]/ [batch, seqlen, head_num_K, head_size]| No | No | ND | 当前step多个token的key。。支持2维TH, 3维BSH, 4维BSND 。 | +| cos | Tensor(float16/float/bf16) | [ntokens, head_size]/ [max_seqlen, head_size] | No | No | ND | ROPE高精度模式,需要输入cos的数据类型为float时生效。 | +| sin | Tensor(float16/float/bf16) | [ntokens, head_size]/ [max_seqlen, head_size] | No | No | ND | ROPE高精度模式,需要输入sin的数据类型为float时生效。 | +| position_ids | Tensor(uint32) | [batch] | No | No | Nd | 在推理prefill阶段表示每个batch的sequence length,在推理decode阶段表示每个batch递推的index。 | +| cos_format | int | | Yes | No | |默认值为2,可取值0,1,2,3。推荐使用2或3,当取值为0或1时sin/cos的shape为[max_seqlen, head_size],当取值为2或3时sin/cos的shape为[ntokens, head_size],当取值为0或2时表示half模式,当取值为1或3时为interleave模式 | + +## 输出参数 + +| Name | DType | Shape | Description | +|--------|------------|------------|-----------------------| +| out_query| Tensor(float16/bf16) | [batch, seqlen, hidden_size_q]/ [ntokens, hidden_size_q]/ [batch, seqlen, head_num_q, head_size]| query旋转位置编码后的结果 | +| out_key| Tensor(float16/bf16) | [batch, seqlen, hidden_size_k]/ [ntokens, hidden_size_k]/ [batch, seqlen, head_num_K, head_size]| key旋转位置编码后的结果 | + +## 约束说明 + +- 输入tensor数据类型需保持一致,高精度模式例外。 +- cos、sin传入数据类型为float时,中间计算结果以float保存。 +- hidden_size_q和hidden_size_k必须是head_size的整数倍,满足`hidden_size_q = head_size * head_num_q、 hidden_size_k = head_size * head_num_k`,其中head_num_q可以大于head_num_k,hidden_size_q和hidden_size_k需要32bytes对齐。 +- ntokens = sum(seqlen[i]),i=0...batch-1。 +- query和key支持2/3/4维,Format均为ND,[ntokens, hidden_size]/[batch, seqlen, hidden_size]/[batch, seqlen, head_num, head_size];当query和key为2维时,仅支持TH unpad方案 +- Decoder阶段要取cos和sin表中seqlen对应的cos/sin值输入。 +- 多batch场景需要组合使用gather算子。 + +## 使用示例 + +```python +import numpy as np +import mindspore as ms +import ms_custom_ops + +inv_freq = 1.0 / (10000 ** (np.arange(0, 128, 2).astype(np.float32) * (1 / 128))) +t = np.arange(2048, dtype=inv_freq.dtype) +freqs = np.outer(t, inv_freq) +emb = np.concatenate((freqs, freqs), axis=-1) +cos = np.cos(emb).astype(np.float16) +sin = np.sin(emb).astype(np.float16) +query = np.random.rand(2, 1, 128).astype(np.float16) +key = np.random.rand(2, 1, 128).astype(np.float16) +position_ids = np.random.randint(0, 2048, [2], dtype=np.int32) +cos = cos[position_ids] +sin = sin[position_ids] +query_tensor = ms.Tensor(query, dtype=ms.float16) +key_tensor = ms.Tensor(key, dtype=ms.float16) +cos_tensor = ms.Tensor(cos, dtype=ms.float16) +sin_tensor = ms.Tensor(sin, dtype=ms.float16) +pos_tensor = ms.Tensor(position_ids, dtype=ms.float16) +out_query, out_key = ms_custom_ops.apply_rotary_pos_emb(query_tensor, key_tensor, cos_tensor, sin_tensor, pos_tensor, 2) +print("query out: ", out_query) +print("key out: ", out_key) +``` diff --git a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb_doc.yaml b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb_doc.yaml deleted file mode 100644 index 1232b75..0000000 --- a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb_doc.yaml +++ /dev/null @@ -1,47 +0,0 @@ -apply_rotary_pos_emb: - description: | - 推理网络为了提升性能,将query和key两路算子融合成一路。执行旋转位置编码计算。 - - Args: - query (Tensor): 表示要执行旋转位置编码的第一个张量,公式中的query,仅支持连续Tensor,数据格式支持BSH和TH,数据类型支持float16、float32、bfloat16。 - key (Tensor): 表示要执行旋转位置编码的第一个张量,公式中的key,仅支持连续Tensor,数据格式支持BSH和TH,数据类型支持float16、float32、bfloat16。 - cos (Tensor): 表示参与计算的位置编码张量,公式中的cos。 - sin (Tensor): 表示参与计算的位置编码张量,公式中的sin。 - position_ids (Tensor): 这是个一维Tensor,在推理网络prefill阶段表示每个batch的sequence length,在推理网络的decode阶段表示每个batch递推的下标。 - cos_format (Tensor): 此参数是cos/sin形态的配置,默认值为2,取值范围(0,1,2,3)。目前推理网络中,基本采用cos_format=2. - cos_format等于0或1时,表示cos/sin采用max sequence length构造Tensor的shape是(max_seqlen, head_dim),0表示cos/sin的值不交替,1则表示交替。 - cos_format等于2或3时,表示cos/sin采用tokens length构造Tensor的shape是(tokens_len, head_dim),2表示cos/sin的值不交替,3则表示交替。 - - Returns: - - Tensor, query经过旋转位置编码后的结果,数据类型和大小于输入相同。 - - Tensor, query经过旋转位置编码后的结果,数据类型和大小于输入相同。 - - Supported Platforms: - ``Atlas 800I A2 推理产品/Atlas 800I A3 推理产品`` - - Examples: - >>> import numpy as np - >>> import mindspore as ms - >>> import ms_custom_ops - >>> ms.set_device("Ascend") - >>> ms.set_context(mode=ms.context.PYNATIVE_MODE) - >>> ms.set_context(jit_config={"jit_level": "O0"}) - >>> inv_freq = 1.0 / (10000 ** (np.arange(0, 128, 2).astype(np.float32) * (1 / 128))) - >>> t = np.arange(2048, dtype=inv_freq.dtype) - >>> freqs = np.outer(t, inv_freq) - >>> emb = np.concatenate((freqs, freqs), axis=-1) - >>> cos = np.cos(emb).astype(np.float16) - >>> sin = np.sin(emb).astype(np.float16) - >>> query = np.random.rand(2, 1, 128).astype(np.float16) - >>> key = np.random.rand(2, 1, 128).astype(np.float16) - >>> position_ids = np.random.randint(0, 2048, [2], dtype=np.int32) - >>> cos = cos[position_ids] - >>> sin = sin[position_ids] - >>> query_tensor = ms.Tensor(query, dtype=ms.float16) - >>> key_tensor = ms.Tensor(key, dtype=ms.float16) - >>> cos_tensor = ms.Tensor(cos, dtype=ms.float16) - >>> sin_tensor = ms.Tensor(sin, dtype=ms.float16) - >>> pos_tensor = ms.Tensor(position_ids, dtype=ms.float16) - >>> out_query, out_key = ms_custom_ops.apply_rotary_pos_emb(query_tensor, key_tensor, cos_tensor, sin_tensor, pos_tensor, 2) - >>> print("query out: ", out_query) - >>> print("key out: ", out_key) \ No newline at end of file diff --git a/ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md b/ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md new file mode 100644 index 0000000..3e8f2aa --- /dev/null +++ b/ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md @@ -0,0 +1,66 @@ +# moe_gating_group_topk算子 + +## 描述 + +moe_gating_group_topk算子实现了MoE专家分组计算,对输入x做Sigmoid计算,对计算结果分组进行排序,最后根据分组排序的结果选取前k个专家。主要应用于Pangu模型。 + +## 输入参数 + +| Name | DType | Shape | Optional | Inplace | Format | Description | +|------|-------|-------|----------|---------|--------|-------------| +| x | Tensor(float16/float32/bf16) | [n,expert] | No | No | ND | 专家组分数 | +| bias | Tensor(float16/float32/bf16) | [expert] | No | No | ND | 用于与x相加,数据类型和格式与x一致,bias当前仅支持None | +| k | int | | No | No | | topk的k值。要求1 <= k <= x_shape[-1] / groupCount * kGroup。 | +| k_group| int | | No | No | | 分组排序后取的group个数。要求1 <= kGroup <= groupCount,并且kGroup * x_shape[-1] / groupCount的值要大于等于k。 | +| group_count | int | | No | No | | 分组的总个数。要求groupCount > 0,x_shape[-1]能够被groupCount整除且整除后的结果大于2,并且整除的结果按照32个数对齐后乘groupCount的结果不大于2048。 | +| group_select_mode | int | | No | No | | 0表示使用最大值对group进行排序, 1表示使用topk2的sum值对group进行排序。仅支持0| +| renorm | int | | No | No | | renorm标记,当前仅支持0,表示先进行norm操作,再计算topk。 | +| norm_type | int | | No | No | | 0表示使用Softmax函数,1表示使用Sigmoid函数。仅支持0。 | +| out_flag | bool | | No | No | | true表示输出,false表示不输出。仅支持false。| +| routed_scaling_factor | double | | No| No | | 用于计算yOut使用的routedScalingFactor系数 | +| eps | double | | No| No | | 用于计算yOut使用的eps系数 | + +注意: + +- enableExpertMapping参数控制是否启用逻辑专家模式。当enableExpertMapping为false时,输入只有x和add_num;当为true时,输入包括x、add_num、mapping_num和mapping_table。 + +- a表示batch大小,b表示专家数量,c表示最大冗余专家数(最多128)。 + +## 输出参数 + +| Name | DType | Shape | Description | +|------|-------|-------|-------------| +| y_out | Tensor(float16/float32/bf16) | [n, topk] | 分组排序topk后计算的结果 | +| expert_idx_out | Tensor(int32) | [n, k] | 专家的序号 | +| norm_out | Tensor(float32) | [n,expert] | norm计算的输出结果, 当前无输出 | + +## 特殊说明 + +- 当前仅支持:`group_select_mode = 0, renorm = 0, norm_type = 0, out_flag = false`。 +- expert能够被group_count整除,expert不超过2048,当前仅支持`1 < group_count < 32,1 < expert/group_count < 64,expert/group_count能被8整除,group_count = k_group = k`。 + +## 使用示例 + +### 基本使用示例(常规模式) + +```python +import numpy as np +import mindspore as ms +import ms_custom_ops + +x = np.random.uniform(-2, 2, (8, 64)).astype(np.float16) +x_tensor = ms.Tensor(x, dtype=ms.float16) +bias = None +k = 4 +k_group = 4 +group_count = 4 +group_select_mode = 0 +renorm = 0 +norm_type = 0 +out_flag = False +routed_scaling_factor = 1.0 +eps = 1e-20 +y_out, expert_idx_out, _ = ms_custom_ops.moe_gating_group_topk(x_tensor, bias, k, k_group, group_count, group_select_mode, renorm, norm_type, out_flag, routed_scaling_factor, eps) +print("y_out:", y_out) +print("expert_idx_out:", expert_idx_out) +``` diff --git a/ops/c_api/moe_gating_group_topk/moe_gating_group_topk_doc.yaml b/ops/c_api/moe_gating_group_topk/moe_gating_group_topk_doc.yaml deleted file mode 100644 index de8723c..0000000 --- a/ops/c_api/moe_gating_group_topk/moe_gating_group_topk_doc.yaml +++ /dev/null @@ -1,47 +0,0 @@ -moe_gating_group_topk: - description: | - MoE计算中,对输入x做Sigmoid计算,对计算结果分组进行排序,最后根据分组排序的结果选取前k个专家. - - Args: - x (Tensor): 两维专家分数Tensor,数据类型支持float16、bfloat16、float32,仅支持连续Tensor - bias (Tensor, optional): 要求是1D的Tensor,要求shape值与x的最后一维相等。数据类型支持float16、bfloat16、float32,数据类型需要与x保持一致。 - k (整型): 每个token最终筛选得到的专家个数,数据类型为int64。要求1≤k≤x.shape[-1]/group_count*k_group。k取值范围为[1, 32]。 - k_groupk (整型): 个token组筛选过程中,选出的专家组个数,数据类型为int64。 - group_count (整型): 表示将全部专家划分的组数,数据类型为int64,当前仅支持group_count = k_groupk = k。 - group_select_mode (整型): 表示一个专家组的总得分计算方式。默认值为0,表示取组内Top2的专家进行得分累加,作为专家组得分。当前仅支持默认值0。 - renorm (整型): renorm标记,当前仅只支持0,表示先进行norm再进行topk计算 - norm_type (整型): 表示norm函数类型,当前仅支持0,表示使用Softmax函数。 - out_flag (布尔类型): 是否输出norm函数中间结果。当前仅支持False,表示不输出。 - routed_scaling_factor (float类型): routed_scaling_factor系数,默认值1.0 - eps (float类型): eps系数,默认值1e-20 - - Returns: - - y_out:Tensor类型,表示对x做norm操作和分组排序topk后计算的结果。要求是一个2D的Tensor,数据类型支持float16、bfloat16、float32, - 数据类型与x需要保持一致,数据格式要求为ND,第一维的大小要求与x的第一维相同,最后一维的大小与k相同。不支持非连续Tensor。 - - expert_idx_out:Tensor类型,表示对x做norm操作和分组排序topk后的索引,即专家的序号。shape要求与yOut一致,数据类型支持int32,数据格式要求为ND。不支持非连续Tensor。 - - norm_out:Tensor类型,norm计算的输出结果。shape要求与x保持一致,数据类型为float32,数据格式要求为ND。不支持非连续Tensor。 - - Supported Platforms: - ``Atlas 800I A2 推理产品/Atlas A3 推理系列产品/Atlas 推理系列产品AI Core`` - - Examples: - >>> import numpy as np - >>> import mindspore as ms - >>> import ms_custom_ops - >>> ms.set_device("Ascend") - >>> ms.set_context(mode=ms.context.PYNATIVE_MODE) - >>> x = np.random.uniform(-2, 2, (8, 64)).astype(np.float16) - >>> x_tensor = ms.Tensor(x, dtype=ms.float16) - >>> bias = None - >>> k = 4 - >>> k_group = 4 - >>> group_count = 4 - >>> group_select_mode = 0 - >>> renorm = 0 - >>> norm_type = 0 - >>> out_flag = False - >>> routed_scaling_factor = 1.0 - >>> eps = 1e-20 - >>> y_out, expert_idx_out, _ = ms_custom_ops.moe_gating_group_topk(x_tensor, bias, k, k_group, group_count, group_select_mode, renorm, norm_type, out_flag, routed_scaling_factor, eps) - >>> print("y_out:", y_out) - >>> print("expert_idx_out:", expert_idx_out) diff --git a/tests/st/test_custom_rope_v3.py b/tests/st/test_custom_rope_v3.py index fb46b33..c490c83 100644 --- a/tests/st/test_custom_rope_v3.py +++ b/tests/st/test_custom_rope_v3.py @@ -12,27 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import os +"""test rove_v3 cases""" import time +from functools import wraps import numpy as np import pytest -from functools import wraps - import ms_custom_ops -import mindspore.ops as ops -import mindspore.nn as nn import mindspore as ms from mindspore.common.api import jit -from mindspore import Tensor, mint, nn, ops, context, Profiler +from mindspore import Tensor, mint, nn, context, Profiler from mindspore.profiler import ProfilerLevel, ProfilerActivity, AicoreMetrics -# from mindspore.common.np_dtype import bfloat16 -from mindspore._c_expression import MSContext + def jit_for_graph_mode(fn): """ A decorator that conditionally applies jit to a function at runtime based on the context mode. """ jitted_fn = jit(fn) + @wraps(fn) def wrapper(*args, **kwargs): if context.get_context("mode") == context.GRAPH_MODE: @@ -40,12 +37,8 @@ def jit_for_graph_mode(fn): return fn(*args, **kwargs) return wrapper -def golden_apply_rotary_emb( - x: Tensor, - cos: Tensor, - sin: Tensor, - is_neox_style: bool, -) -> Tensor: + +def golden_apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool) -> Tensor: """ Args: x: [num_tokens, num_heads, head_size] @@ -65,76 +58,90 @@ def golden_apply_rotary_emb( o2 = x2 * cos + x1 * sin if is_neox_style: return mint.cat((o1, o2), dim=-1) - else: - return mint.stack((o1, o2), dim=-1).flatten(-2) + return mint.stack((o1, o2), dim=-1).flatten(-2) + -def golden_apply_rotary_emb_split(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim, is_profiler=False): +def golden_apply_rotary_emb_split(net, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, + rotary_dim, is_profiler=False): + """golden_apply_rotary_emb_split""" cos_head_dim = rotary_dim // 2 np_query = np.random.random((tokens, head_num_q, head_dim)) np_key = np.random.random((tokens, head_num_k, head_dim)) np_cos = np.random.random((tokens, cos_head_dim)) np_sin = np.random.random((tokens, cos_head_dim)) query = Tensor(np_query, dtype=query_dtype) - key = Tensor(np_key , dtype=query_dtype) + key = Tensor(np_key, dtype=query_dtype) cos = Tensor(np_cos, dtype=query_dtype) sin = Tensor(np_sin, dtype=query_dtype) golden_q = golden_apply_rotary_emb(query, cos, sin, False) golden_k = golden_apply_rotary_emb(key, cos, sin, False) - if is_profiler == False: + if not is_profiler: out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) - np.testing.assert_allclose(golden_q.asnumpy(), out_query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") - np.testing.assert_allclose(golden_k.asnumpy(), out_key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") + np.testing.assert_allclose(golden_q.asnumpy( + ), out_query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") + np.testing.assert_allclose( + golden_k.asnumpy(), out_key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") else: profiler = Profiler(profiler_level=ProfilerLevel.Level2, - activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], - aic_metrics=AicoreMetrics.AiCoreNone) - for i in range(50): + activities=[ProfilerActivity.CPU, + ProfilerActivity.NPU], + aic_metrics=AicoreMetrics.AiCoreNone) + for _ in range(50): out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) profiler.analyse() + class ApplyRotaryEmbV3Net(nn.Cell): """Reshape and cache operation for NZ/ND format with all parameters""" - + @jit_for_graph_mode def construct(self, query, key, cos, sin, layout, rotary_mode): return ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, layout, rotary_mode) -def run_rope_interleave(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, is_profiler=False): + +def run_rope_interleave(net, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, + is_profiler=False): + """rotary_dim same as head_size""" cos_head_dim = head_dim // 2 np_query = np.random.random((tokens, head_num_q, head_dim)) np_key = np.random.random((tokens, head_num_k, head_dim)) np_cos = np.random.random((tokens, cos_head_dim)) np_sin = np.random.random((tokens, cos_head_dim)) query = Tensor(np_query, dtype=query_dtype) - key = Tensor(np_key , dtype=query_dtype) + key = Tensor(np_key, dtype=query_dtype) cos = Tensor(np_cos, dtype=query_dtype) sin = Tensor(np_sin, dtype=query_dtype) golden_q = golden_apply_rotary_emb(query, cos, sin, False) golden_k = golden_apply_rotary_emb(key, cos, sin, False) - if is_profiler == False: - out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) - np.testing.assert_allclose(golden_q.asnumpy(), out_query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") - np.testing.assert_allclose(golden_k.asnumpy(), out_key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") + if not is_profiler: + net(query, key, cos, sin, layout, rotary_mode) + np.testing.assert_allclose( + golden_q.asnumpy(), query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") + np.testing.assert_allclose( + golden_k.asnumpy(), key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") else: profiler = Profiler(profiler_level=ProfilerLevel.Level2, - activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], - aic_metrics=AicoreMetrics.AiCoreNone) - for i in range(50): - out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + activities=[ProfilerActivity.CPU, + ProfilerActivity.NPU], + aic_metrics=AicoreMetrics.AiCoreNone) + for _ in range(50): + net(query, key, cos, sin, layout, rotary_mode) profiler.analyse() -def run_rope_interleave_split(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim, is_profiler=False): +def run_rope_interleave_split(net, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, + rotary_dim, is_profiler=False): + """head_size greater rotary_dim""" cos_head_dim = rotary_dim // 2 np_query = np.random.random((tokens, head_num_q, head_dim)) np_key = np.random.random((tokens, head_num_k, head_dim)) np_cos = np.random.random((tokens, cos_head_dim)) np_sin = np.random.random((tokens, cos_head_dim)) query = Tensor(np_query, dtype=query_dtype) - key = Tensor(np_key , dtype=query_dtype) + key = Tensor(np_key, dtype=query_dtype) cos = Tensor(np_cos, dtype=query_dtype) sin = Tensor(np_sin, dtype=query_dtype) - if is_profiler == False: + if not is_profiler: query_rot = query[..., :rotary_dim] query_pass = query[..., rotary_dim:] @@ -143,12 +150,12 @@ def run_rope_interleave_split(net, exec_mode, query_dtype, layout, rotary_mode, query_rot = golden_apply_rotary_emb(query_rot, cos, sin, False) key_rot = golden_apply_rotary_emb(key_rot, cos, sin, False) - + golden_q = mint.cat((query_rot, query_pass), dim=-1) golden_k = mint.cat((key_rot, key_pass), dim=-1) else: start_time = time.perf_counter() - for i in range(50): + for _ in range(50): query_rot = query[..., :rotary_dim] query_pass = query[..., rotary_dim:] @@ -157,7 +164,7 @@ def run_rope_interleave_split(net, exec_mode, query_dtype, layout, rotary_mode, query_rot = golden_apply_rotary_emb(query_rot, cos, sin, False) key_rot = golden_apply_rotary_emb(key_rot, cos, sin, False) - + query = mint.cat((query_rot, query_pass), dim=-1) key = mint.cat((key_rot, key_pass), dim=-1) end_time = time.perf_counter() @@ -165,27 +172,31 @@ def run_rope_interleave_split(net, exec_mode, query_dtype, layout, rotary_mode, print("tokens:", tokens) print("小算子50次平均耗时(毫秒):", total_time * 1000/50) - if is_profiler == False: - out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) - np.testing.assert_allclose(golden_q.asnumpy(), out_query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") - np.testing.assert_allclose(golden_k.asnumpy(), out_key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") + if not is_profiler: + net(query, key, cos, sin, layout, rotary_mode) + np.testing.assert_allclose( + golden_q.asnumpy(), query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") + np.testing.assert_allclose( + golden_k.asnumpy(), key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") else: profiler = Profiler(profiler_level=ProfilerLevel.Level2, - activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], - aic_metrics=AicoreMetrics.AiCoreNone) - for i in range(50): - out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + activities=[ProfilerActivity.CPU, + ProfilerActivity.NPU], + aic_metrics=AicoreMetrics.AiCoreNone) + for _ in range(50): + net(query, key, cos, sin, layout, rotary_mode) profiler.analyse() start_time = time.perf_counter() - for i in range(50): - out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + for _ in range(50): + net(query, key, cos, sin, layout, rotary_mode) end_time = time.perf_counter() total_time = end_time - start_time print("tokens:", tokens) print("大算子50次平均耗时(毫秒):", total_time * 1000/50) -@pytest.mark.level0 + +@pytest.mark.level0 @pytest.mark.env_onecard @pytest.mark.platform_ascend310p @pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) @@ -204,8 +215,10 @@ def test_rope_v3_interleave(exec_mode, query_dtype, layout, rotary_mode, tokens, """ ms.set_context(device_target="Ascend", mode=exec_mode) ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - net = ApplyRotaryEmbV3Net() - run_rope_interleave(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim) + net = ApplyRotaryEmbV3Net() + run_rope_interleave(net, query_dtype, layout, rotary_mode, + tokens, head_num_q, head_num_k, head_dim) + @pytest.mark.level0 @pytest.mark.env_onecard @@ -219,7 +232,8 @@ def test_rope_v3_interleave(exec_mode, query_dtype, layout, rotary_mode, tokens, @pytest.mark.parametrize("head_num_k", [2]) @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("rotary_dim", [64]) -def test_rope_v3_interleave_split(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim): +def test_rope_v3_interleave_split(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, + head_dim, rotary_dim): """ Feature:aclnnApplyRotaryPosEmb kernel. Description: test for ApplyRotaryPosEmbExt ops. @@ -227,5 +241,6 @@ def test_rope_v3_interleave_split(exec_mode, query_dtype, layout, rotary_mode, t """ ms.set_context(device_target="Ascend", mode=exec_mode) ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - net = ApplyRotaryEmbV3Net() - run_rope_interleave_split(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim) + net = ApplyRotaryEmbV3Net() + run_rope_interleave_split(net, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, + rotary_dim) -- Gitee From 2f9dea648831b110db54130d95e975b6090f430e Mon Sep 17 00:00:00 2001 From: zhangshucheng Date: Fri, 24 Oct 2025 17:13:39 +0800 Subject: [PATCH 3/3] Unify the dynamic and static shape interfaces of all operators Signed-off-by: zhangshucheng --- .../apply_rotary_pos_emb_v3.cc | 9 +- .../op_host/apply_rotary_pos_emb_v3.cpp | 2 +- .../apply_rotary_pos_emb.cc | 4 +- .../apply_rotary_pos_emb.md | 3 +- .../apply_rotary_pos_emb_ext.cc | 34 ++++--- .../apply_rotary_pos_emb_ext.md | 19 +--- .../apply_rotary_pos_emb_ext_op.yaml | 8 +- ops/c_api/fa_update/fa_update_doc.md | 4 +- ops/c_api/fa_update/fa_update_pynative.cc | 31 ++++--- .../mla_preprocess/mla_preprocess_graph.cc | 3 +- .../mla_preprocess/mla_preprocess_pynative.cc | 4 +- .../moe_gating_group_topk.cc | 9 +- .../paged_cache_load_pynative.cc | 21 ++--- .../reshape_and_cache/reshape_and_cache.cc | 18 ++-- ops/c_api/ring_mla/ring_mla_runner.cc | 10 +- ops/c_api/rope/rope.md | 4 +- ops/c_api/trans_data/trans_data.cc | 18 ++-- tests/st/test_apply_rotary_pos_emb_ext.py | 91 ++++++------------- tests/st/test_custom_rope_v3.py | 2 +- 19 files changed, 131 insertions(+), 163 deletions(-) diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc index 670c2b7..83a5404 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc +++ b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc @@ -126,7 +126,8 @@ class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { rotary_mode == ROTARY_INTERLEAVE_STR, CheckAndConvertUtils::FormatCommMsg(op_name, " rotary_mode should be 'interleave', but got ", rotary_mode)); ApplyRotaryPosEmbV3CheckInputsShape(op_name, query_shape, key_shape, cos_shape, sin_shape); - return {query_shape, key_shape}; + // 复写算子, 无输出. 此处的输出与yaml中算子定义的return相对应, 无实际意义. + return {ShapeVector{1}}; } std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { @@ -140,7 +141,8 @@ class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { auto sin_dtype = input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)]->GetType(); ApplyRotaryPosEmbV3CheckInputsType(op_name, query_dtype, key_dtype, cos_dtype, sin_dtype); - return {query_dtype, key_dtype}; + // 复写算子, 无输出. 此处的输出与yaml中算子定义的return相对应, 无实际意义. + return {query_dtype}; } bool GeneralInferRegistered() const override { return true; } @@ -213,7 +215,8 @@ std::vector npu_apply_rotary_pos_emb_v3(const ms::Tensor &query, con runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbV3, query, key, cos, sin, layout_str, rotary_mode)); // 如果是复写算子(inplace), 输出参数为空 runner->Run({query, key, cos, sin}, {}); - return {}; + // 复写算子, 无输出. 此处的输出与yaml中算子定义的return相对应, 无实际意义. + return std::vector{ms::Tensor()}; } } // namespace ms_custom_ops diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp index 7b4b04b..463ccb9 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp +++ b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ #include -#include "apply_rotary_pos_emb_v3_tiling.h" // NOLINT(build/include) +#include "apply_rotary_pos_emb_v3_tiling.h" // NOLINT(build/include_subdir) #include "register/op_def_registry.h" #include "graph/utils/type_utils.h" #include "tiling/platform/platform_ascendc.h" diff --git a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc index c6fea98..c564304 100644 --- a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc +++ b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.cc @@ -22,11 +22,13 @@ #include #include #include +#include #include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" #include "ops/framework/utils.h" namespace ms_custom_ops { +constexpr uint32_t COS_FORMAT_HALF_MODE = 2; enum class ApplyRotaryPosEmbQueryInputIndex : size_t { kApplyRotaryPosEmbQueryIndex = 0, kApplyRotaryPosEmbKeyIndex, @@ -118,7 +120,7 @@ class ApplyRotaryPosEmbRunner : public InternalPyboostRunner { } private: - int32_t cos_format_{2}; + int32_t cos_format_{COS_FORMAT_HALF_MODE}; }; std::vector npu_apply_rotary_pos_emb(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, diff --git a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md index 05fbd5b..8d2793c 100644 --- a/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md +++ b/ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md @@ -14,7 +14,7 @@ | cos | Tensor(float16/float/bf16) | [ntokens, head_size]/ [max_seqlen, head_size] | No | No | ND | ROPE高精度模式,需要输入cos的数据类型为float时生效。 | | sin | Tensor(float16/float/bf16) | [ntokens, head_size]/ [max_seqlen, head_size] | No | No | ND | ROPE高精度模式,需要输入sin的数据类型为float时生效。 | | position_ids | Tensor(uint32) | [batch] | No | No | Nd | 在推理prefill阶段表示每个batch的sequence length,在推理decode阶段表示每个batch递推的index。 | -| cos_format | int | | Yes | No | |默认值为2,可取值0,1,2,3。推荐使用2或3,当取值为0或1时sin/cos的shape为[max_seqlen, head_size],当取值为2或3时sin/cos的shape为[ntokens, head_size],当取值为0或2时表示half模式,当取值为1或3时为interleave模式 | +| cos_format | int | | No | No | |可取值0,1,2,3。推荐使用2或3,当取值为0或1时sin/cos的shape为[max_seqlen, head_size],当取值为2或3时sin/cos的shape为[ntokens, head_size],当取值为0或2时表示half模式,当取值为1或3时为interleave模式 | ## 输出参数 @@ -25,6 +25,7 @@ ## 约束说明 +- 支持产品:Atlas 800I A2 推理产品 - 输入tensor数据类型需保持一致,高精度模式例外。 - cos、sin传入数据类型为float时,中间计算结果以float保存。 - hidden_size_q和hidden_size_k必须是head_size的整数倍,满足`hidden_size_q = head_size * head_num_q、 hidden_size_k = head_size * head_num_k`,其中head_num_q可以大于head_num_k,hidden_size_q和hidden_size_k需要32bytes对齐。 diff --git a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.cc b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.cc index f639d89..fef6fc9 100644 --- a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.cc +++ b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.cc @@ -21,6 +21,9 @@ #include #include #include +#include +#include + #include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h" #include "ops/framework/utils.h" @@ -59,7 +62,7 @@ static std::set apply_rotary_pos_emb_layout_mode_set = { "SBND", }; -static size_t GetRopeLayout(std::string layout_str) { +static size_t GetRopeLayout(const std::string &layout_str) { if (layout_str == "BSH" || layout_str == "BSND") { return static_cast(ApplyRotaryPosEmbExtLayoutMode::LAYOUT_BSND_BSH); } else if (layout_str == "BNSD") { @@ -80,7 +83,7 @@ ShapeArray ApplyRotaryPosEmbExtMakeShape(const ShapeVector query_shape, const Sh "For ApplyRotaryPosEmbExt, cos must be a 4D tensor, but got shape " + ShapeVectorToStr(cos_shape)); MS_CHECK_VALUE(sin_shape.size() == kApplyRotaryPosEmbExtShapeSize, "For ApplyRotaryPosEmbExt, sin must be a 4D tensor, but got shape " + ShapeVectorToStr(sin_shape)); - return {query_shape, key_shape}; + return {ShapeVector{1}}; } class OPS_API ApplyRotaryPosEmbExtCustomOpFuncImpl : public OpFuncImpl { @@ -127,7 +130,7 @@ class OPS_API ApplyRotaryPosEmbExtCustomOpFuncImpl : public OpFuncImpl { MS_LOG(EXCEPTION) << "'ApplyRotaryPosEmbExt' only support [" << kAscendVersion910b << ", " << kAscendVersion910_93 << ", " << kAscendVersion310p << "], but got " << soc_version; } - return {query_dtype, key_dtype}; + return {query_dtype}; } bool GeneralInferRegistered() const override { return true; } @@ -173,25 +176,30 @@ REG_GRAPH_MODE_OP(apply_rotary_pos_emb_ext, ms_custom_ops::ApplyRotaryPosEmbExtC // ============================================================================= namespace ms_custom_ops { -using namespace mindspore; -using namespace mindspore::device::ascend; -constexpr size_t kApplyRotaryPosEmbExtOutputNum = 2; +constexpr size_t kApplyRotaryPosEmbExtOutputNum = 0; std::vector apply_rotary_pos_emb_ext_custom(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, const ms::Tensor &sin, - const std::string layout_str, const std::string rotary_mode) { + const std::string &layout_str, const std::string &rotary_mode) { (void)ApplyRotaryPosEmbExtMakeShape(query.shape(), key.shape(), cos.shape(), sin.shape()); auto layout_mode = GetRopeLayout(layout_str); - auto outputs = {ms::Tensor(query.data_type(), query.shape()), ms::Tensor(key.data_type(), key.shape())}; auto runner = std::make_shared("aclnnApplyRotaryPosEmbV2"); runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbV2, query, key, cos, sin, layout_mode, rotary_mode)); // only set tensor. - runner->Run({query, key, cos, sin}, outputs); - return outputs; + runner->Run({query, key, cos, sin}, {}); + return std::vector{ms::Tensor()}; } } // namespace ms_custom_ops +auto pyboost_apply_rotary_pos_emb_ext(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, + const ms::Tensor &sin, const std::string &layout_str, + const std::string &rotary_mode) { + return ms::pynative::PyboostRunner::Call( + ms_custom_ops::apply_rotary_pos_emb_ext_custom, query, key, cos, sin, layout_str, rotary_mode); +} + MS_CUSTOM_OPS_EXTENSION_MODULE(m) { - m.def("apply_rotary_pos_emb_ext", - PYBOOST_CALLER(ms_custom_ops::kApplyRotaryPosEmbExtOutputNum, ms_custom_ops::apply_rotary_pos_emb_ext_custom)); -} \ No newline at end of file + m.def("apply_rotary_pos_emb_ext", &pyboost_apply_rotary_pos_emb_ext, "ApplyRotaryPosEmbExt", pybind11::arg("query"), + pybind11::arg("key"), pybind11::arg("cos"), pybind11::arg("sin"), pybind11::arg("layout"), + pybind11::arg("rotary_mode")); +} diff --git a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md index 3e1e766..57115c8 100644 --- a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md +++ b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md @@ -12,12 +12,12 @@ apply_rotary_pos_emb_ext算子用于计算旋转编码操作。该算子底层 | key | Tensor | 4维[batch_size, seq_len, k_head_num, head_dim] | No | No | ND | 执行旋转位置编码的第二个变量 | | cos | Tensor | 4维[batch_size, seq_len, 1, head_dim] | No | No | ND | 表示参与计算的位置编码张量 | | sin | Tensor | 4维[batch_size, seq_len, 1, head_dim] | No | No | ND | 表示参与计算的位置编码张量 | -| layout | string | No | Yes | No | string | 表示输入Tensor的布局格式 | -| rotary_mode | string | No | Yes | No | string | 表示支持计算公式中的旋转模式 | +| layout | string | No | No | No | string | 表示输入Tensor的布局格式 | +| rotary_mode | string | No | No | No | string | 表示支持计算公式中的旋转模式 | Note: head_dim当前只支持128. -910B/910C机器上: +Atlas推理系列产品A2, Atlas推理系列产品A3: rotary_mode只支持"half". layout只支持"BSND". query shape为[batch_size, seq_len, q_head_num, head_dim]. 支持类型为:BF16/FP16/FP32. @@ -31,24 +31,15 @@ query shape为[batch_size, seq_len, q_head_num, head_dim]. 支持类型为:FP16/ key shape大小为[batch_size, seq_len, k_head_num, head_dim].支持类型为:FP16/FP32. cos/sin shape大小为[batch_size, seq_len, 1, head_dim].支持类型为:FP16/FP32. -此外注意,ub_required = (q_n + k_n) * 128 * castSize * 2 + 128 * DtypeSize * 4 + (q_n + k_n) * 128 * castSize + (q_n + k_n) * 128 * castSize * 2 + cast * (128 * 4 * 2), 当计算出ub_required的大小超过当前AI处理器的UB空间总大小时,不支持使用该融合算子. +此外注意,`ub_required = (q_n + k_n) * 128 * castSize * 2 + 128 * DtypeSize * 4 + (q_n + k_n) * 128 * castSize + (q_n + k_n) * 128 * castSize * 2 + cast * (128 * 4 * 2)`,当计算出ub_required的大小超过当前AI处理器的UB空间总大小时,不支持使用该融合算子. 不支持空tensor场景. ## 输出参数 -| Name | DType | Shape | Description | -|--------|------------|------------|-----------------------| -| query_emb| Tensor | [batch_size, seq_len, q_head_num, head_dim] | query旋转位置编码后的结果 | -| key_emb | Tensor | [batch_size, seq_len, k_head_num, head_dim] | key旋转位置编码后的结果 | - -query_emb数据类型和query相同,shape大小一样。 -key_emb数据类型和key相同,shape大小一样。 +## 特殊说明 更多详细信息请参考:[aclnnApplyRotaryPosEmbV2](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/API/aolapi/context/aclnnApplyRotaryPosEmbV2.md) - -## 特殊说明 - ## 使用示例 ### 基本使用示例 diff --git a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext_op.yaml b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext_op.yaml index ebff55c..da37cc1 100644 --- a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext_op.yaml +++ b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext_op.yaml @@ -18,9 +18,5 @@ apply_rotary_pos_emb_ext: labels: side_effect_mem: True returns: - query_embed: - dtype: tensor - inplace: query - key_embed: - dtype: tensor - inplace: key \ No newline at end of file + out: + dtype: tensor \ No newline at end of file diff --git a/ops/c_api/fa_update/fa_update_doc.md b/ops/c_api/fa_update/fa_update_doc.md index 84c30e0..cd33668 100644 --- a/ops/c_api/fa_update/fa_update_doc.md +++ b/ops/c_api/fa_update/fa_update_doc.md @@ -10,8 +10,8 @@ fa_update算子实现了attention部分中间结果卡间同步。 |------|-------|-------|----------|---------|--------|-------------| | lse | Tensor(fp32) | [sp, batch * seqLen * headNum] | No | No | ND | 输入tensor,数据类型为float32, 各SP域计算的lse | | local_out | Tensor(fp32) | [sp, batch * seqLen * headNum, head_size] | No | No | ND | 输入tensor,数据类型为float32, 各SP域计算的output | -| fa_update_type | int | | No | No | | 指定下标需要执行的操作类型,目前只支持0:DECODE_UPDATE | -| sp | int | | No | No | | 序列并行的并行度SP,取值范围[1, 8] | +| fa_update_type | int | | No | No | |指定下标需要执行的操作类型,目前只支持0:DECODE_UPDATE | +| sp | int | | No | No | |序列并行的并行度SP,取值范围[1, 8] | 注意: diff --git a/ops/c_api/fa_update/fa_update_pynative.cc b/ops/c_api/fa_update/fa_update_pynative.cc index b1a7bbe..f8ed745 100644 --- a/ops/c_api/fa_update/fa_update_pynative.cc +++ b/ops/c_api/fa_update/fa_update_pynative.cc @@ -17,7 +17,8 @@ // ============================================================================= // PYBOOST MODE IMPLEMENTATION // ============================================================================= - +#include +#include #include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" #include "ops/c_api/fa_update/fa_update_common.h" #include "ops/framework/utils.h" @@ -41,8 +42,8 @@ class FaUpdateRunner : public InternalPyboostRunner { bool sp_{1}; }; -std::vector npu_fa_update(const ms::Tensor &lse, const ms::Tensor &local_out, - const uint64_t fa_update_type, const uint64_t sp) { +std::vector npu_fa_update(const ms::Tensor &lse, const ms::Tensor &local_out, const uint64_t fa_update_type, + const uint64_t sp) { auto op_name = "FaUpdate"; auto runner = std::make_shared(op_name); MS_EXCEPTION_IF_NULL(runner); @@ -56,13 +57,14 @@ std::vector npu_fa_update(const ms::Tensor &lse, const ms::Tensor &l // Setup the runner with all parameters (including hash calculation) runner->Setup(op_name, lse, local_out, fa_update_type, sp); std::vector inputs = {lse, local_out}; - MS_CHECK_VALUE(local_out.shape().size() == kFaUpdateLocalOutShapeRank, - CheckAndConvertUtils::FormatCommMsg( - "For FaUpdate, local_out dim must be 3, but got : ", local_out.shape().size())); + MS_CHECK_VALUE( + local_out.shape().size() == kFaUpdateLocalOutShapeRank, + CheckAndConvertUtils::FormatCommMsg("For FaUpdate, local_out dim must be 3, but got : ", local_out.shape().size())); auto head_size = local_out.shape()[kIndex2]; - MS_CHECK_VALUE(head_size >= 8 && head_size <= 512 && ALIGN_8(head_size), - CheckAndConvertUtils::FormatCommMsg( - "For FaUpdate, head_size must be in range [8, 512] and be the multiple of 8, but got : ", head_size)); + MS_CHECK_VALUE( + head_size >= 8 && head_size <= 512 && ALIGN_8(head_size), + CheckAndConvertUtils::FormatCommMsg( + "For FaUpdate, head_size must be in range [8, 512] and be the multiple of 8, but got : ", head_size)); ShapeVector output_shape{local_out.shape()[kIndex1], local_out.shape()[kIndex2]}; auto output_tensor = ms::Tensor(local_out.data_type(), output_shape); std::vector outputs = {output_tensor}; @@ -71,14 +73,13 @@ std::vector npu_fa_update(const ms::Tensor &lse, const ms::Tensor &l return outputs; } -auto pyboost_fa_update(const ms::Tensor &lse, const ms::Tensor &local_out, - const uint64_t fa_update_type, const uint64_t sp) { - return ms::pynative::PyboostRunner::Call<1>( - npu_fa_update, lse, local_out, fa_update_type, sp); +auto pyboost_fa_update(const ms::Tensor &lse, const ms::Tensor &local_out, const uint64_t fa_update_type, + const uint64_t sp) { + return ms::pynative::PyboostRunner::Call<1>(npu_fa_update, lse, local_out, fa_update_type, sp); } } // namespace ms_custom_ops MS_CUSTOM_OPS_EXTENSION_MODULE(m) { - m.def("fa_update", &ms_custom_ops::pyboost_fa_update, "FaUpdate", - pybind11::arg("lse"), pybind11::arg("local_out"), pybind11::arg("fa_update_type"), pybind11::arg("sp")); + m.def("fa_update", &ms_custom_ops::pyboost_fa_update, "FaUpdate", pybind11::arg("lse"), pybind11::arg("local_out"), + pybind11::arg("fa_update_type") = 0, pybind11::arg("sp") = 1); } diff --git a/ops/c_api/mla_preprocess/mla_preprocess_graph.cc b/ops/c_api/mla_preprocess/mla_preprocess_graph.cc index bdaa263..c874f98 100644 --- a/ops/c_api/mla_preprocess/mla_preprocess_graph.cc +++ b/ops/c_api/mla_preprocess/mla_preprocess_graph.cc @@ -13,7 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#include +#include #include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" #include "ops/c_api/mla_preprocess/mla_preprocess_common.h" diff --git a/ops/c_api/mla_preprocess/mla_preprocess_pynative.cc b/ops/c_api/mla_preprocess/mla_preprocess_pynative.cc index 04155fd..43e2ec0 100644 --- a/ops/c_api/mla_preprocess/mla_preprocess_pynative.cc +++ b/ops/c_api/mla_preprocess/mla_preprocess_pynative.cc @@ -17,6 +17,8 @@ // ============================================================================= // PYBOOST MODE IMPLEMENTATION // ============================================================================= +#include +#include #include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" #include "ops/c_api/mla_preprocess/mla_preprocess_common.h" @@ -121,5 +123,5 @@ MS_CUSTOM_OPS_EXTENSION_MODULE(m) { pybind11::arg("cos1"), pybind11::arg("sin2"), pybind11::arg("cos2"), pybind11::arg("key_cache"), pybind11::arg("slot_mapping"), pybind11::arg("wuq"), pybind11::arg("bias2"), pybind11::arg("wuk"), pybind11::arg("de_scale1"), pybind11::arg("de_scale2"), pybind11::arg("ctkv_scale"), - pybind11::arg("qnope_scale"), pybind11::arg("krope_cache"), pybind11::arg("param_cache_mode")); + pybind11::arg("qnope_scale"), pybind11::arg("krope_cache"), pybind11::arg("param_cache_mode") = 0); } diff --git a/ops/c_api/moe_gating_group_topk/moe_gating_group_topk.cc b/ops/c_api/moe_gating_group_topk/moe_gating_group_topk.cc index 196ed9a..255a81d 100644 --- a/ops/c_api/moe_gating_group_topk/moe_gating_group_topk.cc +++ b/ops/c_api/moe_gating_group_topk/moe_gating_group_topk.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" @@ -224,8 +225,8 @@ auto pyboost_moe_gating_group_topk(const ms::Tensor &x, const std::optional +#include +#include "ops/c_api/paged_cache_load/paged_cache_load_common.h" namespace ms_custom_ops { class PagedCacheLoadRunner : public InternalPyboostRunner { @@ -70,21 +70,19 @@ std::vector npu_paged_cache_load(const ms::Tensor &key_cache, const int64_t sum_context_lens = abstract::Shape::kShapeDimAny; if (seq_lens.data_type() != mindspore::TypeId::kNumberTypeInt32) { - MS_LOG(EXCEPTION) << "For " << op_name - << ", the seq_lens dtype must be int32, but got : " - << seq_lens.data_type(); + MS_LOG(EXCEPTION) << "For " << op_name << ", the seq_lens dtype must be int32, but got : " << seq_lens.data_type(); } if (seq_lens.GetDataPtr() != nullptr) { if (is_seq_lens_cumsum_type.value()) { - int32_t * seq_lens_ptr = static_cast(seq_lens.GetDataPtr()); - for (size_t i = 0; i < seq_lens.numel(); i ++) { + int32_t *seq_lens_ptr = static_cast(seq_lens.GetDataPtr()); + for (size_t i = 0; i < seq_lens.numel(); i++) { sum_context_lens = seq_lens_ptr[seq_lens.numel() - 1]; } } else { sum_context_lens = 0; - int32_t * seq_lens_ptr = static_cast(seq_lens.GetDataPtr()); - for (size_t i = 0; i < seq_lens.numel(); i ++) { + int32_t *seq_lens_ptr = static_cast(seq_lens.GetDataPtr()); + for (size_t i = 0; i < seq_lens.numel(); i++) { sum_context_lens += seq_lens_ptr[i]; } } @@ -129,7 +127,6 @@ auto pyboost_paged_cache_load(const ms::Tensor &key_cache, const ms::Tensor &val MS_CUSTOM_OPS_EXTENSION_MODULE(m) { m.def("paged_cache_load", &ms_custom_ops::pyboost_paged_cache_load, "Paged Cache Load", pybind11::arg("key_cache"), pybind11::arg("value_cache"), pybind11::arg("block_table"), pybind11::arg("seq_lens"), - pybind11::arg("seq_starts") = std::nullopt, pybind11::arg("kv_cache_cfg") = std::nullopt, - pybind11::arg("is_seq_lens_cumsum_type") = std::nullopt, - pybind11::arg("has_seq_starts") = std::nullopt); + pybind11::arg("seq_starts") = std::nullopt, pybind11::arg("kv_cache_cfg") = 0, + pybind11::arg("is_seq_lens_cumsum_type") = false, pybind11::arg("has_seq_starts") = false); } diff --git a/ops/c_api/reshape_and_cache/reshape_and_cache.cc b/ops/c_api/reshape_and_cache/reshape_and_cache.cc index 1879bca..af9a7b6 100644 --- a/ops/c_api/reshape_and_cache/reshape_and_cache.cc +++ b/ops/c_api/reshape_and_cache/reshape_and_cache.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" @@ -91,14 +92,13 @@ class CustomReshapeAndCache : public InternalKernelMod { for (const auto &input : inputs) { if (input == nullptr) continue; auto shape = input->GetShapeVector(); - for (const auto &dim : shape) { - if (dim == 0) { - MS_LOG(INFO) << "ReshapeAndCache: Skipping execution due to zero " - "dimension in input shape: " - << shape; - skip_execution_ = true; - return KernelMod::Resize(inputs, outputs); // Skip execution - } + bool has_zero = std::any_of(shape.begin(), shape.end(), [](const auto &dim) { return dim == 0; }); + if (has_zero) { + MS_LOG(INFO) << "ReshapeAndCache: Skipping execution due to zero " + "dimension in input shape: " + << shape; + skip_execution_ = true; + return KernelMod::Resize(inputs, outputs); // Skip execution } } @@ -213,5 +213,5 @@ MS_CUSTOM_OPS_EXTENSION_MODULE(m) { m.def("reshape_and_cache", &pyboost_reshape_and_cache, "Reshape And Cache", pybind11::arg("key"), pybind11::arg("value") = std::nullopt, pybind11::arg("key_cache") = std::nullopt, pybind11::arg("value_cache") = std::nullopt, pybind11::arg("slot_mapping") = std::nullopt, - pybind11::arg("cache_mode"), pybind11::arg("head_num")); + pybind11::arg("cache_mode") = 0, pybind11::arg("head_num") = 0); } diff --git a/ops/c_api/ring_mla/ring_mla_runner.cc b/ops/c_api/ring_mla/ring_mla_runner.cc index 8e587c5..87c4aad 100644 --- a/ops/c_api/ring_mla/ring_mla_runner.cc +++ b/ops/c_api/ring_mla/ring_mla_runner.cc @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include +#include #include "ops/c_api/ring_mla/ring_mla_runner.h" #include "ops/framework/utils.h" -using namespace ms_custom_ops; namespace ms_custom_ops { namespace { @@ -179,7 +181,7 @@ MS_CUSTOM_OPS_EXTENSION_MODULE(m) { pybind11::arg("deq_offset_qk") = std::nullopt, pybind11::arg("deq_scale_pv") = std::nullopt, pybind11::arg("deq_offset_pv") = std::nullopt, pybind11::arg("quant_p") = std::nullopt, pybind11::arg("log_n") = std::nullopt, pybind11::arg("o_prev") = std::nullopt, - pybind11::arg("lse_prev") = std::nullopt, pybind11::arg("q_seq_lens"), pybind11::arg("context_lens"), - pybind11::arg("head_num"), pybind11::arg("scale_value"), pybind11::arg("kv_head_num"), - pybind11::arg("mask_type"), pybind11::arg("calc_type")); + pybind11::arg("lse_prev") = std::nullopt, pybind11::arg("q_seq_lens") = std::nullopt, + pybind11::arg("context_lens") = std::nullopt, pybind11::arg("head_num") = 0, pybind11::arg("scale_value") = 1.0, + pybind11::arg("kv_head_num") = 0, pybind11::arg("mask_type") = 0, pybind11::arg("calc_type") = 0); } diff --git a/ops/c_api/rope/rope.md b/ops/c_api/rope/rope.md index 1ff37dd..55641c3 100644 --- a/ops/c_api/rope/rope.md +++ b/ops/c_api/rope/rope.md @@ -18,8 +18,8 @@ | cos | Tensor(float16/float/bf16) | [ntokens, head_size] / [ntokens, head_size / 2] | No | No | ND | 当cos的第二个维度与参数rotaryCoeff不相等时,其值为head_size。ROPE高精度模式,需要输入cos的数据类型为float时生效。cos的第二个维度需要是参数rotaryCoeff的整数倍。 | | sin | Tensor(float16/float/bf16) | [ntokens, head_size] / [ntokens, head_size / 2] | No | No | ND | 当sin的第二个维度与参数rotaryCoeff不相等时,其值为head_size。ROPE高精度模式,需要输入sin的数据类型为float时生效。sin的第二个维度需要是参数rotaryCoeff的整数倍。 | | seq_len | Tensor(uint32/int32) | [batch] | No | No | Nd | | -| rotary_coeff | int | | No | No | | rope,旋转系数,对半旋转是2,支持配置2、4、head_size / 2、head_size。 | -| cos_format | int | | Yes | No | | 训练用参数,支持配置0或1,推理采用默认值0 | +| rotary_coeff | int | | No | No | | 旋转系数,对半旋转是2,支持配置2、4、head_size/2 | +| cos_format | int | | Yes | No | | 训练用参数,支持配置0或1,推理采用默认值0 | ## 输出参数 diff --git a/ops/c_api/trans_data/trans_data.cc b/ops/c_api/trans_data/trans_data.cc index 3347c40..e8fe62f 100644 --- a/ops/c_api/trans_data/trans_data.cc +++ b/ops/c_api/trans_data/trans_data.cc @@ -18,6 +18,8 @@ #include #include #include +#include + #include "ops/framework/utils.h" #include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" @@ -101,12 +103,11 @@ class CustomTransData : public InternalKernelMod { for (const auto &input : inputs) { if (input == nullptr) continue; auto shape = input->GetShapeVector(); - for (const auto &dim : shape) { - if (dim == 0) { - MS_LOG(INFO) << "TransData: Skipping execution due to zero dimension in input shape: " << shape; - skip_execution_ = true; - return KernelMod::Resize(inputs, outputs); // Skip execution - } + bool has_zero = std::any_of(shape.begin(), shape.end(), [](const auto &dim) { return dim == 0; }); + if (has_zero) { + MS_LOG(INFO) << "TransData: Skipping execution due to zero dimension in input shape: " << shape; + skip_execution_ = true; + return KernelMod::Resize(inputs, outputs); // Skip execution } } @@ -204,6 +205,5 @@ auto pyboost_trans_data(const ms::Tensor &input, std::optional transdat } MS_CUSTOM_OPS_EXTENSION_MODULE(m) { - m.def("trans_data", &pyboost_trans_data, "Trans Data", pybind11::arg("input"), - pybind11::arg("transdata_type") = std::nullopt); -} \ No newline at end of file + m.def("trans_data", &pyboost_trans_data, "Trans Data", pybind11::arg("input"), pybind11::arg("transdata_type") = 0); +} diff --git a/tests/st/test_apply_rotary_pos_emb_ext.py b/tests/st/test_apply_rotary_pos_emb_ext.py index 8fb8d97..55c9732 100644 --- a/tests/st/test_apply_rotary_pos_emb_ext.py +++ b/tests/st/test_apply_rotary_pos_emb_ext.py @@ -12,21 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import os +"""apply_rotary_pos_emb_ext test case""" +from functools import wraps import numpy as np import pytest -from functools import wraps -import mindspore.ops as ops -import mindspore.nn as nn import mindspore as ms from mindspore import context, Tensor from mindspore.common.np_dtype import bfloat16 -from mindspore._c_expression import MSContext import ms_custom_ops def get_ms_dtype(query_dtype): + ms_dtype = ms.float32 if query_dtype == np.float32: ms_dtype = ms.float32 elif query_dtype == np.float16: @@ -50,27 +48,27 @@ def apply_rotary_pos_emb_ext(query, key, cos, sin, layout, rotary_mode="half"): q_embed: 旋转位置编码后的query k_embed: 旋转位置编码后的key """ + print("layout:", layout) if rotary_mode == "half": return apply_rotary_pos_emb_half(query, key, cos, sin) - elif rotary_mode == "quarter": + if rotary_mode == "quarter": return apply_rotary_pos_emb_quarter(query, key, cos, sin) - elif rotary_mode == "interleave": + if rotary_mode == "interleave": return apply_rotary_pos_emb_interleave(query, key, cos, sin) - else: - raise ValueError(f"Unsupported rotary mode: {rotary_mode}") + raise ValueError(f"Unsupported rotary mode: {rotary_mode}") def apply_rotary_pos_emb_half(query, key, cos, sin): """Half模式旋转位置编码的numpy实现(Golden函数)""" # 处理query query_q1 = query[..., : query.shape[-1] // 2] - query_q2 = query[..., query.shape[-1] // 2 :] + query_q2 = query[..., query.shape[-1] // 2:] query_rotate = np.concatenate((-query_q2, query_q1), axis=-1) q_embed = query * cos + query_rotate * sin # 处理key key_k1 = key[..., : key.shape[-1] // 2] - key_k2 = key[..., key.shape[-1] // 2 :] + key_k2 = key[..., key.shape[-1] // 2:] key_rotate = np.concatenate((-key_k2, key_k1), axis=-1) k_embed = key * cos + key_rotate * sin @@ -89,7 +87,8 @@ def apply_rotary_pos_emb_quarter(query, key, cos, sin): query_q3 = query[..., half_idx:three_quarter_idx] query_q4 = query[..., three_quarter_idx:] - query_rotate = np.concatenate((-query_q2, query_q1, -query_q4, query_q3), axis=-1) + query_rotate = np.concatenate( + (-query_q2, query_q1, -query_q4, query_q3), axis=-1) q_embed = query * cos + query_rotate * sin # 处理key @@ -115,7 +114,8 @@ def apply_rotary_pos_emb_interleave(query, key, cos, sin): query_q1_flat = query_q1.reshape(-1, 1) query_q2_flat = query_q2.reshape(-1, 1) - query_rotate_flat = np.concatenate((-query_q2_flat, query_q1_flat), axis=-1) + query_rotate_flat = np.concatenate( + (-query_q2_flat, query_q1_flat), axis=-1) query_rotate = query_rotate_flat.reshape(orig_shape) q_embed = query * cos + query_rotate * sin @@ -136,6 +136,9 @@ def apply_rotary_pos_emb_interleave(query, key, cos, sin): def jit(func): + """ + A decorator that conditionally applies jit to a function at runtime based on the context mode. + """ @wraps(func) def decorator(*args, **kwargs): if ms.get_context("mode") == "PYNATIVE_MODE": @@ -146,6 +149,7 @@ def jit(func): class ApplyRotaryPosEmbNet(ms.nn.Cell): + """ApplyRotaryPosEmbNet""" def _init__(self): super().__init__() @@ -157,21 +161,8 @@ class ApplyRotaryPosEmbNet(ms.nn.Cell): return query_embed, key_embed -def run( - net, - base, - cos_dtype, - seq_len, - batch_size, - num_head, - hidden_dim, - max_seq_len, - query_dtype, - pos_dtype, - ndim, - cos_format, - rotary_mode="half", -): +def run(net, cos_dtype, seq_len, batch_size, num_head, hidden_dim, query_dtype, rotary_mode="half"): + """run case""" query_data = np.random.uniform( 0, 1, [batch_size, seq_len, num_head, hidden_dim] ).astype(query_dtype) @@ -202,14 +193,12 @@ def run( query2 = Tensor(query2, dtype=get_ms_dtype(query_dtype)) key2 = Tensor(key2, dtype=get_ms_dtype(query_dtype)) - cos2 = Tensor(cos2, dtype=get_ms_dtype(query_dtype)) - sin2 = Tensor(sin2, dtype=get_ms_dtype(query_dtype)) + cos2 = Tensor(cos2, dtype=get_ms_dtype(cos_dtype)) + sin2 = Tensor(sin2, dtype=get_ms_dtype(cos_dtype)) - custom_query_emb, custom_key_emb = net( - query2, key2, cos2, sin2, "BSND", rotary_mode - ) - np.testing.assert_allclose(golden_query_emb, custom_query_emb, rtol=1e-2, atol=1e-2) - np.testing.assert_allclose(golden_key_emb, custom_key_emb, rtol=1e-2, atol=1e-2) + net(query2, key2, cos2, sin2, "BSND", rotary_mode) + np.testing.assert_allclose(golden_query_emb, query2, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(golden_key_emb, key2, rtol=1e-2, atol=1e-2) @pytest.mark.level0 @@ -219,45 +208,19 @@ def run( @pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) @pytest.mark.parametrize("query_dtype", [np.float16]) @pytest.mark.parametrize("cos_dtype", [np.float16]) -@pytest.mark.parametrize("cos_format", [2]) @pytest.mark.parametrize("rotary_mode", ["half"]) @pytest.mark.parametrize("batch_size", [1, 16]) @pytest.mark.parametrize("seq_len", [1, 256, 512, 1024]) @pytest.mark.parametrize("num_head", [16, 32]) -def test_rope_float16( - exec_mode, - query_dtype, - cos_dtype, - cos_format, - rotary_mode, - batch_size, - seq_len, - num_head, -): +def test_rope_float16(exec_mode, query_dtype, cos_dtype, rotary_mode, batch_size, seq_len, num_head): """ Feature:aclnnApplyRotaryPosEmb kernel. Description: test for ApplyRotaryPosEmbExt ops. Expectation:should pass for all testcases. """ - ndim = 4 hidden_dim = 128 - base = 10000 - max_seq_len = seq_len ms.set_context(device_target="Ascend", mode=exec_mode) ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) net = ApplyRotaryPosEmbNet() - run( - net, - base, - cos_dtype, - seq_len, - batch_size, - num_head, - hidden_dim, - max_seq_len, - query_dtype, - np.int32, - ndim, - cos_format, - rotary_mode, - ) + run(net, cos_dtype, seq_len, batch_size, num_head, + hidden_dim, query_dtype, rotary_mode) diff --git a/tests/st/test_custom_rope_v3.py b/tests/st/test_custom_rope_v3.py index c490c83..37db31f 100644 --- a/tests/st/test_custom_rope_v3.py +++ b/tests/st/test_custom_rope_v3.py @@ -17,11 +17,11 @@ import time from functools import wraps import numpy as np import pytest -import ms_custom_ops import mindspore as ms from mindspore.common.api import jit from mindspore import Tensor, mint, nn, context, Profiler from mindspore.profiler import ProfilerLevel, ProfilerActivity, AicoreMetrics +import ms_custom_ops def jit_for_graph_mode(fn): -- Gitee