diff --git a/docs/map_from_buildin_to_custom.md b/docs/map_from_buildin_to_custom.md index 6cf29fd4023046ab670641c8ecbada930bc2ed05..54a7daa00d4ab6ab8d073b52f73275e42f171720 100644 --- a/docs/map_from_buildin_to_custom.md +++ b/docs/map_from_buildin_to_custom.md @@ -2,3 +2,4 @@ |------------------------------------|-------------------------------------|-----------------------------------| | ops.auto_generate.mla | [ms_custom_ops.mla](../ops/c_api/mla/mla_doc.md) | 新增了input_format参数,用于指定输入参数的format | | ops.moe_init_routing_v2 | [ms_custom_ops.moe_init_routing_v2](../ops/c_api/moe_init_routing_v2/moe_init_routing_v2.md) | 接口一致,仅支持 Atlas 推理系列 | +|ops.auto_generate.moe_gating_group_topk|[ms_custom_ops.moe_gating_group_topk](../ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md) |接口一致| diff --git a/docs/op_list.md b/docs/op_list.md index 4026f18fa13ec8256ac0517e27228c7a52f81ea9..6eda4efefe18af5e0cd0be7b1a3c0af818f4bce4 100644 --- a/docs/op_list.md +++ b/docs/op_list.md @@ -2,7 +2,7 @@ 1. [apply_rotary_pos_emb](../ops/c_api/apply_rotary_pos_emb/apply_rotary_pos_emb.md) 1. [apply_rotary_pos_emb_ext](../ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md) -1. [apply_rotary_pos_emb_v3](../ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md) +1. [apply_rotary_pos_emb_ms](../ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.md) 1. [fa_update](../ops/c_api/fa_update/fa_update_doc.md) 1. [flash_attention_encoder](../ops/c_api/flash_attention_encoder/flash_attention_encoder.md) 1. 
[fused_add_topk_div](../ops/c_api/fused_add_topk_div/fused_add_topk_div.md) diff --git a/docs/ops_develop_guide.md b/docs/ops_develop_guide.md index bfebe1fdaccf56d6c19f911f265607cf131a4c31..4b49c5e4d60eea486fac2fbc42ceac64c0b44776 100644 --- a/docs/ops_develop_guide.md +++ b/docs/ops_develop_guide.md @@ -22,17 +22,17 @@ ## 3 ascendc类型算子实现 -以[`ops/ascendc/apply_rotary_pos_emb_v3`](../ops/ascendc/apply_rotary_pos_emb_v3)为例。 +以[`ops/ascendc/apply_rotary_pos_emb_ms`](../ops/ascendc/apply_rotary_pos_emb_ms)为例。 ### 3.1 目录结构 ```text -apply_rotary_pos_emb_v3/ +apply_rotary_pos_emb_ms/ ├── op_host/ # 算子逻辑host侧实现,包括类型注册、InferShape、Tiling等实现代码 ├── op_kernel/ # 算子逻辑kernel侧实现 -├── apply_rotary_pos_emb_v3_op.yaml # 算子在MindSpore侧原型定义 -├── apply_rotary_pos_emb_v3.md # 算子通过MindSpore对外提供的接口说明文档 -└── apply_rotary_pos_emb_v3.cc # 算子在MindSpore中的接入,包括InferShape(与op_host中的InferShape逻辑相同,但是实现接口不一致)、静态图KernelMod接入、动态图等代码 +├── apply_rotary_pos_emb_ms_op.yaml # 算子在MindSpore侧原型定义 +├── apply_rotary_pos_emb_ms.md # 算子通过MindSpore对外提供的接口说明文档 +└── apply_rotary_pos_emb_ms.cc # 算子在MindSpore中的接入,包括InferShape(与op_host中的InferShape逻辑相同,但是实现接口不一致)、静态图KernelMod接入、动态图等代码 ``` ### 3.2 kernel开发 @@ -49,10 +49,10 @@ apply_rotary_pos_emb_v3/ #### 3.3.2 Infer实现 继承类`OpFuncImpl`,一般只要重写`InferShape`、`InferType`和`GeneralInferRegistered`方法。其中`GeneralInferRegistered`方法固定返回`true`。 -`ApplyRotaryPosEmbV3OpFuncImpl`类名需要满足规则:`op_name` + `OpFuncImpl`,`op_name`与yaml中原型定义指定的算子名称保持一致。 +`ApplyRotaryPosEmbMSOpFuncImpl`类名需要满足规则:`op_name` + `OpFuncImpl`,`op_name`与yaml中原型定义指定的算子名称保持一致。 ```c++ -class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { +class OPS_API ApplyRotaryPosEmbMSOpFuncImpl : public OpFuncImpl { public: ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { ... 
@@ -77,29 +77,29 @@ class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { 继承类`AclnnCustomKernelMod`。需要实现构造函数,并重写`Launch`和`GetWorkSpaceInfo`函数。 ```c++ -class ApplyRotaryPosEmbV3Ascend : public AclnnCustomKernelMod { +class ApplyRotaryPosEmbMSAscend : public AclnnCustomKernelMod { public: - ApplyRotaryPosEmbV3Ascend() : AclnnCustomKernelMod(std::move("aclnnApplyRotaryPosEmbV3")) {} - ~ApplyRotaryPosEmbV3Ascend() = default; + ApplyRotaryPosEmbMSAscend() : AclnnCustomKernelMod(std::move("aclnnApplyRotaryPosEmbMS")) {} + ~ApplyRotaryPosEmbMSAscend() = default; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { ... RunOp( - stream_ptr, workspace, inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3KeyIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3CosIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)], layout_, rotary_mode_); + stream_ptr, workspace, inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSQueryIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSKeyIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSCosIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSSinIndex)], layout_, rotary_mode_); return true; } void GetWorkSpaceInfo(const std::vector &inputs, const std::vector &outputs) override { ... 
- GetWorkspaceForResize(inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3KeyIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3CosIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)], + GetWorkspaceForResize(inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSQueryIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSKeyIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSCosIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSSinIndex)], layout_, rotary_mode_); } @@ -114,8 +114,8 @@ class ApplyRotaryPosEmbV3Ascend : public AclnnCustomKernelMod { **REG_GRAPH_MODE_OP(***python接口名称***, ***Infer实现类***, ***KernelMod类***)** ```c++ -REG_GRAPH_MODE_OP(apply_rotary_pos_emb_v3, ms_custom_ops::ApplyRotaryPosEmbV3OpFuncImpl, - ms_custom_ops::ApplyRotaryPosEmbV3Ascend); +REG_GRAPH_MODE_OP(apply_rotary_pos_emb_ms, ms_custom_ops::ApplyRotaryPosEmbMSOpFuncImpl, + ms_custom_ops::ApplyRotaryPosEmbMSAscend); ``` ### 3.4 MindSpore动态图接入 @@ -123,24 +123,24 @@ REG_GRAPH_MODE_OP(apply_rotary_pos_emb_v3, ms_custom_ops::ApplyRotaryPosEmbV3OpF #### 3.4.1 实现C++侧调用函数 ```c++ -std::vector apply_rotary_pos_emb_v3_custom(const ms::Tensor &query, const ms::Tensor &key, +std::vector apply_rotary_pos_emb_ms_custom(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, const ms::Tensor &sin, const std::string layout_str, const std::string rotary_mode) { - std::string op_name = "ApplyRotaryPosEmbV3"; + std::string op_name = "ApplyRotaryPosEmbMS"; // 1). 
创建runner,AscendC算子采用aclnn两段式接口,所以只需要继承预定义的`ms::pynative::AclnnOpRunner`类即可 auto runner = std::make_shared(op_name); // 输入shape检查 - ApplyRotaryPosEmbV3CheckInputsShape(op_name, query.shape(), key.shape(), cos.shape(), sin.shape()); + ApplyRotaryPosEmbMSCheckInputsShape(op_name, query.shape(), key.shape(), cos.shape(), sin.shape()); // 输入dtype检查 - ApplyRotaryPosEmbV3CheckInputsType(op_name, query.data_type(), key.data_type(), cos.data_type(), sin.data_type()); + ApplyRotaryPosEmbMSCheckInputsType(op_name, query.data_type(), key.data_type(), cos.data_type(), sin.data_type()); - // 2). 推导输出Tensor,包括shape和dtype信息。`apply_rotary_pos_emb_v3`属于原地更新算子,不需要推导输出 + // 2). 推导输出Tensor,包括shape和dtype信息。`apply_rotary_pos_emb_ms`属于原地更新算子,不需要推导输出 // 3). 设置launch Function - // 此处"aclnnApplyRotaryPosEmbV3", 是算字库函数表中名字前面加上aclnn - // 可通过 nm -D ./build/xxx/xxx/ms_custom_ops.xxx.so | grep "ApplyRotaryPosEmbV3"来确认 + // 此处"aclnnApplyRotaryPosEmbMS", 是算子库函数表中名字前面加上aclnn + // 可通过 nm -D ./build/xxx/xxx/ms_custom_ops.xxx.so | grep "ApplyRotaryPosEmbMS"来确认 // 如果是复写算子(inplace), 不必添加输出参数 - runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbV3, query, key, cos, sin, layout_str, rotary_mode)); + runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbMS, query, key, cos, sin, layout_str, rotary_mode)); // 4). 执行算子。如果是复写算子(inplace), 输出参数为空 runner->Run({query, key, cos, sin}, {}); // 5). 
返回输出 @@ -152,9 +152,9 @@ std::vector apply_rotary_pos_emb_v3_custom(const ms::Tensor &query, ```c++ MS_CUSTOM_OPS_EXTENSION_MODULE(m) { - m.def("apply_rotary_pos_emb_v3", // python接口名称 - &pyboost_apply_rotary_pos_emb_v3, // 绑定到python的c++接口 - "ApplyRotaryPosEmbV3", // 算子描述 + m.def("apply_rotary_pos_emb_ms", // python接口名称 + &pyboost_apply_rotary_pos_emb_ms, // 绑定到python的c++接口 + "ApplyRotaryPosEmbMS", // 算子描述 pybind11::arg("query"), // 以下为参数 pybind11::arg("key"), pybind11::arg("cos"), pybind11::arg("sin"), pybind11::arg("layout"), pybind11::arg("rotary_mode")); @@ -163,7 +163,7 @@ MS_CUSTOM_OPS_EXTENSION_MODULE(m) { ### 3.5 编写算子文档 -参考[`apply_rotary_pos_emb_v3.md`](../ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md)。新增算子后需要同步更新[`op_list.md`](op_list.md),调用脚本可以自动生成: +参考[`apply_rotary_pos_emb_ms.md`](../ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.md)。新增算子后需要同步更新[`op_list.md`](op_list.md),调用脚本可以自动生成: ``` python python scripts/generate_op_list.py diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc b/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.cc similarity index 70% rename from ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc rename to ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.cc index 746b3dbeda230dc0795b4dc3723246b7a78d5707..0d5ad31e2ebb52f5972fe0d747f942dae31b5119 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc +++ b/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.cc @@ -35,17 +35,17 @@ constexpr uint32_t ROTARY_DIM_FACTOR = 2; constexpr int32_t LAYOUT_BSH = 1; constexpr const char *LAYOUT_BSH_STR = "BSH"; constexpr const char *ROTARY_INTERLEAVE_STR = "interleave"; -enum class ApplyRotaryPosEmbV3InputIndex : size_t { - kApplyRotaryPosEmbV3QueryIndex = 0, - kApplyRotaryPosEmbV3KeyIndex, - kApplyRotaryPosEmbV3CosIndex, - kApplyRotaryPosEmbV3SinIndex, - kApplyRotaryPosEmbV3LayoutIndex, - kApplyRotaryPosEmbV3RotaryModeIndex, - 
kApplyRotaryPosEmbV3InputsNum, +enum class ApplyRotaryPosEmbMSInputIndex : size_t { + kApplyRotaryPosEmbMSQueryIndex = 0, + kApplyRotaryPosEmbMSKeyIndex, + kApplyRotaryPosEmbMSCosIndex, + kApplyRotaryPosEmbMSSinIndex, + kApplyRotaryPosEmbMSLayoutIndex, + kApplyRotaryPosEmbMSRotaryModeIndex, + kApplyRotaryPosEmbMSInputsNum, }; -static void ApplyRotaryPosEmbV3CheckInputsShape(const std::string &op_name, const std::vector &query_shape, +static void ApplyRotaryPosEmbMSCheckInputsShape(const std::string &op_name, const std::vector &query_shape, const std::vector &key_shape, const std::vector &cos_shape, const std::vector &sin_shape) { @@ -74,7 +74,7 @@ static void ApplyRotaryPosEmbV3CheckInputsShape(const std::string &op_name, cons op_name, ", query/key's dim0 should be equal cos/sin's dim0, but got query's shape is ", query_shape, ", cos's shape is ", cos_shape)); } -static void ApplyRotaryPosEmbV3CheckInputsType(const std::string &op_name, const TypeId &query_dtype, +static void ApplyRotaryPosEmbMSCheckInputsType(const std::string &op_name, const TypeId &query_dtype, const TypeId &key_dtype, const TypeId &cos_dtype, const TypeId &sin_dtype) { const std::unordered_set valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; @@ -91,40 +91,40 @@ static void ApplyRotaryPosEmbV3CheckInputsType(const std::string &op_name, const << TypeIdToString(cos_dtype) << ", " << TypeIdToString(sin_dtype) << "'"; } } -class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { +class OPS_API ApplyRotaryPosEmbMSOpFuncImpl : public OpFuncImpl { public: ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { MS_EXCEPTION_IF_NULL(primitive); - if (input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)] + if (input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSQueryIndex)] ->IsDynamicRank() || - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3KeyIndex)] + 
input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSKeyIndex)] ->IsDynamicRank() || - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3CosIndex)] + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSCosIndex)] ->IsDynamicRank() || - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)] + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSSinIndex)] ->IsDynamicRank()) { return kFakeOutTensorShapes; } auto op_name = primitive->name(); auto query_shape = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)]->GetShape(); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSQueryIndex)]->GetShape(); auto key_shape = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3KeyIndex)]->GetShape(); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSKeyIndex)]->GetShape(); auto cos_shape = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3CosIndex)]->GetShape(); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSCosIndex)]->GetShape(); auto sin_shape = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)]->GetShape(); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSSinIndex)]->GetShape(); auto rotary_mode = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3RotaryModeIndex)] + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSRotaryModeIndex)] ->GetScalarValueWithCheck(); - auto layout = input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3LayoutIndex)] + auto layout = input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSLayoutIndex)] ->GetScalarValueWithCheck(); MS_CHECK_VALUE(layout == LAYOUT_BSH_STR, 
CheckAndConvertUtils::FormatCommMsg(op_name, " layout should be 'BSH', but got ", layout)); MS_CHECK_VALUE( rotary_mode == ROTARY_INTERLEAVE_STR, CheckAndConvertUtils::FormatCommMsg(op_name, " rotary_mode should be 'interleave', but got ", rotary_mode)); - ApplyRotaryPosEmbV3CheckInputsShape(op_name, query_shape, key_shape, cos_shape, sin_shape); + ApplyRotaryPosEmbMSCheckInputsShape(op_name, query_shape, key_shape, cos_shape, sin_shape); // 复写算子, 无输出. 此处的输出与yaml中算子定义的return相对应, 无实际意义. return kFakeOutTensorShapes; } @@ -132,14 +132,14 @@ class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { auto op_name = primitive->name(); auto query_dtype = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)]->GetType(); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSQueryIndex)]->GetType(); auto key_dtype = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3KeyIndex)]->GetType(); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSKeyIndex)]->GetType(); auto cos_dtype = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3CosIndex)]->GetType(); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSCosIndex)]->GetType(); auto sin_dtype = - input_infos[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)]->GetType(); - ApplyRotaryPosEmbV3CheckInputsType(op_name, query_dtype, key_dtype, cos_dtype, sin_dtype); + input_infos[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSSinIndex)]->GetType(); + ApplyRotaryPosEmbMSCheckInputsType(op_name, query_dtype, key_dtype, cos_dtype, sin_dtype); // 复写算子, 无输出. 此处的输出与yaml中算子定义的return相对应, 无实际意义. 
return kFakeOutTensorTypes; } @@ -147,36 +147,36 @@ class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl { bool GeneralInferRegistered() const override { return true; } }; -class ApplyRotaryPosEmbV3Ascend : public AclnnCustomKernelMod { +class ApplyRotaryPosEmbMSAscend : public AclnnCustomKernelMod { public: - ApplyRotaryPosEmbV3Ascend() : AclnnCustomKernelMod(std::move("aclnnApplyRotaryPosEmbV3")) {} - ~ApplyRotaryPosEmbV3Ascend() = default; + ApplyRotaryPosEmbMSAscend() : AclnnCustomKernelMod(std::move("aclnnApplyRotaryPosEmbMS")) {} + ~ApplyRotaryPosEmbMSAscend() = default; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { MS_EXCEPTION_IF_NULL(stream_ptr); RunOp( - stream_ptr, workspace, inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3KeyIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3CosIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)], layout_, rotary_mode_); + stream_ptr, workspace, inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSQueryIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSKeyIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSCosIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSSinIndex)], layout_, rotary_mode_); return true; } void GetWorkSpaceInfo(const std::vector &inputs, const std::vector &outputs) override { - auto layout_str = inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3LayoutIndex)] + auto layout_str = inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSLayoutIndex)] ->GetValueWithCheck(); if (layout_str == LAYOUT_BSH_STR) { layout_ = LAYOUT_BSH; } else { MS_LOG(EXCEPTION) << "layout should be 'BSH', but got " << 
layout_str; } - rotary_mode_ = inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3RotaryModeIndex)] + rotary_mode_ = inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSRotaryModeIndex)] ->GetValueWithCheck(); - GetWorkspaceForResize(inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3QueryIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3KeyIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3CosIndex)], - inputs[static_cast(ApplyRotaryPosEmbV3InputIndex::kApplyRotaryPosEmbV3SinIndex)], + GetWorkspaceForResize(inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSQueryIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSKeyIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSCosIndex)], + inputs[static_cast(ApplyRotaryPosEmbMSInputIndex::kApplyRotaryPosEmbMSSinIndex)], layout_, rotary_mode_); return; } @@ -188,27 +188,27 @@ class ApplyRotaryPosEmbV3Ascend : public AclnnCustomKernelMod { }; } // namespace ms_custom_ops -REG_GRAPH_MODE_OP(apply_rotary_pos_emb_v3, ms_custom_ops::ApplyRotaryPosEmbV3OpFuncImpl, - ms_custom_ops::ApplyRotaryPosEmbV3Ascend); +REG_GRAPH_MODE_OP(apply_rotary_pos_emb_ms, ms_custom_ops::ApplyRotaryPosEmbMSOpFuncImpl, + ms_custom_ops::ApplyRotaryPosEmbMSAscend); // ============================================================================= // PYBOOST MODE IMPLEMENTATION // ============================================================================= namespace ms_custom_ops { -void npu_apply_rotary_pos_emb_v3(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, +void npu_apply_rotary_pos_emb_ms(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, const ms::Tensor &sin, const std::string &layout_str, const std::string &rotary_mode) { - std::string op_name = "ApplyRotaryPosEmbV3"; + std::string op_name = "ApplyRotaryPosEmbMS"; 
// 此处op_name是给人看的, 跟算子命名没有直接关联 auto runner = std::make_shared(op_name); // 输入shape检查 - ApplyRotaryPosEmbV3CheckInputsShape(op_name, query.shape(), key.shape(), cos.shape(), sin.shape()); + ApplyRotaryPosEmbMSCheckInputsShape(op_name, query.shape(), key.shape(), cos.shape(), sin.shape()); // 输入dtype检查 - ApplyRotaryPosEmbV3CheckInputsType(op_name, query.data_type(), key.data_type(), cos.data_type(), sin.data_type()); - // 此处"aclnnApplyRotaryPosEmbV3", 是算字库函数表中名字前面加上aclnn - // 可通过 nm -D ./build/xxx/xxx/ms_custom_ops.xxx.so | grep "ApplyRotaryPosEmbV3"来确认 + ApplyRotaryPosEmbMSCheckInputsType(op_name, query.data_type(), key.data_type(), cos.data_type(), sin.data_type()); + // 此处"aclnnApplyRotaryPosEmbMS", 是算子库函数表中名字前面加上aclnn + // 可通过 nm -D ./build/xxx/xxx/ms_custom_ops.xxx.so | grep "ApplyRotaryPosEmbMS"来确认 // 如果是复写算子(inplace), 不必添加输出参数 - runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbV3, query, key, cos, sin, layout_str, rotary_mode)); + runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbMS, query, key, cos, sin, layout_str, rotary_mode)); // 如果是复写算子(inplace), 输出参数为空 runner->Run({query, key, cos, sin}, {}); // 无输出的算子返回值用void(不同于静态图) @@ -216,15 +216,15 @@ void npu_apply_rotary_pos_emb_v3(const ms::Tensor &query, const ms::Tensor &key, } } // namespace ms_custom_ops -auto pyboost_apply_rotary_pos_emb_v3(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, +auto pyboost_apply_rotary_pos_emb_ms(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &cos, const ms::Tensor &sin, const std::string &layout_str, const std::string &rotary_mode) { - return ms::pynative::PyboostRunner::Call(ms_custom_ops::npu_apply_rotary_pos_emb_v3, query, + return ms::pynative::PyboostRunner::Call(ms_custom_ops::npu_apply_rotary_pos_emb_ms, query, key, cos, sin, layout_str, rotary_mode); } MS_CUSTOM_OPS_EXTENSION_MODULE(m) { - m.def("apply_rotary_pos_emb_v3", &pyboost_apply_rotary_pos_emb_v3, "ApplyRotaryPosEmbV3", pybind11::arg("query"), + 
m.def("apply_rotary_pos_emb_ms", &pyboost_apply_rotary_pos_emb_ms, "ApplyRotaryPosEmbMS", pybind11::arg("query"), pybind11::arg("key"), pybind11::arg("cos"), pybind11::arg("sin"), pybind11::arg("layout"), pybind11::arg("rotary_mode")); } diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md b/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.md similarity index 81% rename from ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md rename to ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.md index d800dbcb482a2ccb770a6f07eb2070e69d3c22d9..bb5fa917576eca56f483c09414d8cc100256d8f4 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md +++ b/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.md @@ -1,8 +1,8 @@ -# apply_rotary_pos_emb_v3算子 +# apply_rotary_pos_emb_ms算子 ## 描述 -apply_rotary_pos_emb_v3算子用于计算旋转编码操作。且支持部分数据参与选择位置编码计算。 +apply_rotary_pos_emb_ms算子用于计算旋转编码操作。且支持部分数据参与旋转位置编码计算。 ## 输入参数 @@ -15,7 +15,7 @@ apply_rotary_pos_emb_v3算子用于计算旋转编码操作。且支持部分数 | layout | string | No | No | No | string | 表示输入Tensor的布局格式 | | rotary_mode | string | No | No | No | string | 表示支持计算公式中的旋转模式 | -Note: +### 约束说明 - 产品支持: Atlas推理系列产品AI Core。 - rotary_mode: 当前仅支持'interleave'模式。 @@ -23,6 +23,13 @@ Note: - dtype: query/key/cos/sin数据类型支持FP16/FP32,且四个输入参数类型一致。 - head_dim: 令`rotary_head_dim = 2 * cos_sin_head_dim`。要求`qk_head_dim >= rotary_head_dim`, qk_head_dim 不能小于rotary_head_dim。当`qk_head_dim > rotary_head_dim`时,只对`query/key[...:rotary_head_dim]` 做旋转位置编码。且`(qk_head_dim - rotary_head_dim)* size(dtype)`必须能被32整除。`cos_sin_head_dim * sizeof(dtype)`必须能被32整除。 +### 另见 + +|算子|特点| +|----|----| +|[rope](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/rope/rope.md)【推荐】|输入参数query/key/cos/sin要求两维, 支持half、interleave、quarter模式| +|[apply_rotary_pos_emb_ext](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md)|输入参数query/key/cos/sin要求四维, 
当前仅支持head_size=128和half模式| + ## 输出参数 ## 特殊说明 @@ -55,7 +62,7 @@ query = Tensor(np_query, dtype=query_dtype) key = Tensor(np_key , dtype=query_dtype) cos = Tensor(np_cos, dtype=query_dtype) sin = Tensor(np_sin, dtype=query_dtype) -ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, layout, rotary_mode) +ms_custom_ops.apply_rotary_pos_emb_ms(query, key, cos, sin, layout, rotary_mode) print("result query:", query) print("result key:", key) ``` diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml b/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms_op.yaml similarity index 87% rename from ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml rename to ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms_op.yaml index 02119d4ca4ede83b70e7298173127291d58e0376..0648385931587810e6a2021c56da8a4ea5874759 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml +++ b/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms_op.yaml @@ -1,5 +1,5 @@ -#operator apply_rotary_pos_emb_v3 -apply_rotary_pos_emb_v3: +#operator apply_rotary_pos_emb_ms +apply_rotary_pos_emb_ms: args: query: dtype: tensor diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp b/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms.cpp similarity index 73% rename from ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp rename to ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms.cpp index 463ccb90afe1d4e53b4fdfc33aa6a512dbb76900..33f807f04881e960bc0b7a977dee09cea40fabf8 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp +++ b/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms.cpp @@ -14,18 +14,19 @@ * limitations under the License. 
*/ #include -#include "apply_rotary_pos_emb_v3_tiling.h" // NOLINT(build/include_subdir) +#include "apply_rotary_pos_emb_ms_tiling.h" // NOLINT(build/include_subdir) +#include "utils/log/asc_cpu_log.h" #include "register/op_def_registry.h" #include "graph/utils/type_utils.h" #include "tiling/platform/platform_ascendc.h" namespace optiling { -constexpr uint32_t ROPE_V3_USE_TBUF_COUNT = 6; -constexpr uint32_t ROPE_V3_USE_TBUF_COSSIN_COUNT = 2; -constexpr uint32_t ROPE_V3_TILINGKEY_FP16 = 1; -constexpr uint32_t ROPE_V3_TILINGKEY_FP32 = 2; -constexpr uint32_t ROPE_V3_TILINGKEY_FACTOR = 10; -constexpr uint32_t ROPE_V3_ROTARY_DIM_FACTOR = 2; +constexpr uint32_t ApplyRotaryPosEmbMS_USE_TBUF_COUNT = 6; +constexpr uint32_t ApplyRotaryPosEmbMS_USE_TBUF_COSSIN_COUNT = 2; +constexpr uint32_t ApplyRotaryPosEmbMS_TILINGKEY_FP16 = 1; +constexpr uint32_t ApplyRotaryPosEmbMS_TILINGKEY_FP32 = 2; +constexpr uint32_t ApplyRotaryPosEmbMS_TILINGKEY_FACTOR = 10; +constexpr uint32_t ApplyRotaryPosEmbMS_ROTARY_DIM_FACTOR = 2; constexpr uint32_t kIndex0 = 0; constexpr uint32_t kIndex1 = 1; constexpr uint32_t kIndex2 = 2; @@ -33,8 +34,8 @@ constexpr uint32_t kDim0 = 0; constexpr uint32_t kDim1 = 1; constexpr uint32_t kDim2 = 2; -static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { - ApplyRotaryPosEmbV3TilingData tiling; +static ge::graphStatus ApplyRotaryPosEmbMSTiling(gert::TilingContext *context) { + ApplyRotaryPosEmbMSTilingData tiling; uint32_t tiling_key{0}; uint64_t ub_total_size; auto ascendc_platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); @@ -53,7 +54,7 @@ static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { uint32_t query_head_dim = query_shape.GetDim(kDim2); uint32_t cos_head_dim = cos_shape.GetDim(kDim1); - uint32_t rotary_dim = cos_head_dim * ROPE_V3_ROTARY_DIM_FACTOR; + uint32_t rotary_dim = cos_head_dim * ApplyRotaryPosEmbMS_ROTARY_DIM_FACTOR; uint32_t is_split = (rotary_dim == query_head_dim ? 
0 : 1); tiling.set_queryHeadDim(query_head_dim); @@ -73,25 +74,24 @@ static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { const uint32_t *rotaryMode = context->GetAttrs()->GetAttrPointer(kIndex1); tiling.set_layout(*layout); tiling.set_rotaryMode(*rotaryMode); - uint32_t ub_use = - query_head_dim * (key_head_num + query_head_num) * ge::GetSizeByDataType(query_type) + - ROPE_V3_USE_TBUF_COSSIN_COUNT * cos_head_dim * ge::GetSizeByDataType(cos_type) + - (tiling.get_qHiddenSize() + tiling.get_kHiddenSize()) * ROPE_V3_USE_TBUF_COUNT * ge::GetSizeByDataType(query_type); + uint32_t ub_use = query_head_dim * (key_head_num + query_head_num) * ge::GetSizeByDataType(query_type) + + ApplyRotaryPosEmbMS_USE_TBUF_COSSIN_COUNT * cos_head_dim * ge::GetSizeByDataType(cos_type) + + (tiling.get_qHiddenSize() + tiling.get_kHiddenSize()) * ApplyRotaryPosEmbMS_USE_TBUF_COUNT * + ge::GetSizeByDataType(query_type); if (ub_use > ub_total_size) { - // 这里临时用std::cerr输出错误日志,等框架支持log功能后,再改用log - std::cerr << "ERROR: " - << "(cos.head_dim * 2 + query.hidden_size + key.hidden_size + (query.head_num + key.head_num) * 6) " - << "* sizeof(type) should be less than or equal to UB size, but got " << std::to_string(ub_use) << " > " - << std::to_string(ub_total_size); + ASC_CPU_LOG_ERROR( + "Not support, (cos.head_dim * 2 + query.hidden_size + key.hidden_size + (query.head_num + key.head_num) * 6) * " + "sizeof(type) should be less than or equal to UB size, but got %lld > %lld", + std::to_string(ub_use), std::to_string(ub_total_size)); return ge::GRAPH_FAILED; } if (query_type == ge::DataType::DT_FLOAT16) { - tiling_key = ROPE_V3_TILINGKEY_FP16; + tiling_key = ApplyRotaryPosEmbMS_TILINGKEY_FP16; } else if (query_type == ge::DataType::DT_FLOAT) { - tiling_key = ROPE_V3_TILINGKEY_FP32; + tiling_key = ApplyRotaryPosEmbMS_TILINGKEY_FP32; } - tiling_key = tiling_key * ROPE_V3_TILINGKEY_FACTOR + is_split; + tiling_key = tiling_key * ApplyRotaryPosEmbMS_TILINGKEY_FACTOR + is_split; 
context->SetBlockDim(coreNum); context->SetTilingKey(tiling_key); @@ -103,7 +103,7 @@ static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) { } } // namespace optiling namespace ge { -static ge::graphStatus ApplyRotaryPosEmbV3InferShape(gert::InferShapeContext *context) { +static ge::graphStatus ApplyRotaryPosEmbMSInferShape(gert::InferShapeContext *context) { const gert::Shape *query_shape = context->GetInputShape(0); const gert::Shape *key_shape = context->GetInputShape(1); gert::Shape *out_query_shape = context->GetOutputShape(0); @@ -112,7 +112,7 @@ static ge::graphStatus ApplyRotaryPosEmbV3InferShape(gert::InferShapeContext *co *out_key_shape = *key_shape; return GRAPH_SUCCESS; } -static graphStatus ApplyRotaryPosEmbV3InferDataType(gert::InferDataTypeContext *context) { +static graphStatus ApplyRotaryPosEmbMSInferDataType(gert::InferDataTypeContext *context) { const auto inputDataType = context->GetInputDataType(0); context->SetOutputDataType(0, inputDataType); context->SetOutputDataType(1, inputDataType); @@ -121,9 +121,9 @@ static graphStatus ApplyRotaryPosEmbV3InferDataType(gert::InferDataTypeContext * } // namespace ge namespace ops { -class ApplyRotaryPosEmbV3 : public OpDef { +class ApplyRotaryPosEmbMS : public OpDef { public: - explicit ApplyRotaryPosEmbV3(const char *name) : OpDef(name) { + explicit ApplyRotaryPosEmbMS(const char *name) : OpDef(name) { this->Input("query") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) @@ -161,9 +161,9 @@ class ApplyRotaryPosEmbV3 : public OpDef { this->Attr("layout").AttrType(OPTIONAL).Int(1); this->Attr("rotary_mode").AttrType(OPTIONAL).String("interleave"); - this->SetInferShape(ge::ApplyRotaryPosEmbV3InferShape).SetInferDataType(ge::ApplyRotaryPosEmbV3InferDataType); - this->AICore().SetTiling(optiling::ApplyRotaryPosEmbV3Tiling).AddConfig("ascend310p"); + this->SetInferShape(ge::ApplyRotaryPosEmbMSInferShape).SetInferDataType(ge::ApplyRotaryPosEmbMSInferDataType); + 
this->AICore().SetTiling(optiling::ApplyRotaryPosEmbMSTiling).AddConfig("ascend310p"); } }; -OP_ADD(ApplyRotaryPosEmbV3); +OP_ADD(ApplyRotaryPosEmbMS); } // namespace ops diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3_tiling.h b/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms_tiling.h similarity index 49% rename from ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3_tiling.h rename to ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms_tiling.h index bf90e3c9abd78fc63656bedf6e912296e4a9b678..f5c6f018969b67e058ce2ff6626482616ec19ace 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3_tiling.h +++ b/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms_tiling.h @@ -18,22 +18,22 @@ #include "register/tilingdata_base.h" namespace optiling { -BEGIN_TILING_DATA_DEF(ApplyRotaryPosEmbV3TilingData) - TILING_DATA_FIELD_DEF(uint32_t, tilingId); - TILING_DATA_FIELD_DEF(uint32_t, useCoreNum); - TILING_DATA_FIELD_DEF(uint32_t, tokensPerCore); - TILING_DATA_FIELD_DEF(uint32_t, tokensTail); - TILING_DATA_FIELD_DEF(uint32_t, qHeadNum); - TILING_DATA_FIELD_DEF(uint32_t, kHeadNum); - TILING_DATA_FIELD_DEF(uint32_t, qHiddenSize); - TILING_DATA_FIELD_DEF(uint32_t, kHiddenSize); - TILING_DATA_FIELD_DEF(uint32_t, queryHeadDim); - TILING_DATA_FIELD_DEF(uint32_t, cosHeadDim); - TILING_DATA_FIELD_DEF(uint32_t, rotaryDim); - TILING_DATA_FIELD_DEF(uint32_t, layout); - TILING_DATA_FIELD_DEF(uint32_t, rotaryMode); +BEGIN_TILING_DATA_DEF(ApplyRotaryPosEmbMSTilingData) +TILING_DATA_FIELD_DEF(uint32_t, tilingId); +TILING_DATA_FIELD_DEF(uint32_t, useCoreNum); +TILING_DATA_FIELD_DEF(uint32_t, tokensPerCore); +TILING_DATA_FIELD_DEF(uint32_t, tokensTail); +TILING_DATA_FIELD_DEF(uint32_t, qHeadNum); +TILING_DATA_FIELD_DEF(uint32_t, kHeadNum); +TILING_DATA_FIELD_DEF(uint32_t, qHiddenSize); +TILING_DATA_FIELD_DEF(uint32_t, kHiddenSize); +TILING_DATA_FIELD_DEF(uint32_t, 
queryHeadDim); +TILING_DATA_FIELD_DEF(uint32_t, cosHeadDim); +TILING_DATA_FIELD_DEF(uint32_t, rotaryDim); +TILING_DATA_FIELD_DEF(uint32_t, layout); +TILING_DATA_FIELD_DEF(uint32_t, rotaryMode); END_TILING_DATA_DEF; -REGISTER_TILING_DATA_CLASS(ApplyRotaryPosEmbV3, ApplyRotaryPosEmbV3TilingData) -} -#endif // ADD_CUSTOM_TILING_H +REGISTER_TILING_DATA_CLASS(ApplyRotaryPosEmbMS, ApplyRotaryPosEmbMSTilingData) +} // namespace optiling +#endif // ADD_CUSTOM_TILING_H diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_kernel/apply_rotary_pos_emb_v3.cpp b/ops/ascendc/apply_rotary_pos_emb_ms/op_kernel/apply_rotary_pos_emb_ms.cpp similarity index 95% rename from ops/ascendc/apply_rotary_pos_emb_v3/op_kernel/apply_rotary_pos_emb_v3.cpp rename to ops/ascendc/apply_rotary_pos_emb_ms/op_kernel/apply_rotary_pos_emb_ms.cpp index b1e2927256f64418218923d3d8dbfe1b477670d6..9f94e9f3adfc70affe8cea758941222fd1a7e7f8 100644 --- a/ops/ascendc/apply_rotary_pos_emb_v3/op_kernel/apply_rotary_pos_emb_v3.cpp +++ b/ops/ascendc/apply_rotary_pos_emb_ms/op_kernel/apply_rotary_pos_emb_ms.cpp @@ -16,11 +16,11 @@ #include "kernel_operator.h" constexpr int32_t BUFFER_NUM = 1; template -class KernelApplyRotaryPosEmbV3 { +class KernelApplyRotaryPosEmbMS { public: - __aicore__ inline KernelApplyRotaryPosEmbV3() {} + __aicore__ inline KernelApplyRotaryPosEmbMS() {} __aicore__ inline void Init(GM_ADDR query, GM_ADDR key, GM_ADDR cos, GM_ADDR sin, GM_ADDR oquery, GM_ADDR okey, - GM_ADDR workspace, ApplyRotaryPosEmbV3TilingData *tiling, AscendC::TPipe *tPipe) { + GM_ADDR workspace, ApplyRotaryPosEmbMSTilingData *tiling, AscendC::TPipe *tPipe) { pipe = tPipe; tilingData = tiling; if constexpr (IS_SPLIT) { @@ -217,38 +217,38 @@ class KernelApplyRotaryPosEmbV3 { uint32_t originKeyHiddenSize{0}; uint32_t queryKeyCalHiddenSize{0}; uint32_t queryKeyInHiddenSize{0}; - ApplyRotaryPosEmbV3TilingData *tilingData = nullptr; + ApplyRotaryPosEmbMSTilingData *tilingData = nullptr; }; -extern "C" __global__ __aicore__ void 
apply_rotary_pos_emb_v3(GM_ADDR query, GM_ADDR key, GM_ADDR cos, GM_ADDR sin, +extern "C" __global__ __aicore__ void apply_rotary_pos_emb_ms(GM_ADDR query, GM_ADDR key, GM_ADDR cos, GM_ADDR sin, GM_ADDR outq, GM_ADDR outk, GM_ADDR workspace, GM_ADDR tiling) { GET_TILING_DATA(tilingData, tiling); GM_ADDR usrWorkspace = AscendC::GetUserWorkspace(workspace); AscendC::TPipe pipe; if (TILING_KEY_IS(20)) { - KernelApplyRotaryPosEmbV3 op; + KernelApplyRotaryPosEmbMS op; op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe); op.Process(); } else if (TILING_KEY_IS(10)) { - KernelApplyRotaryPosEmbV3 op; + KernelApplyRotaryPosEmbMS op; op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe); op.Process(); } else if (TILING_KEY_IS(21)) { - KernelApplyRotaryPosEmbV3 op; + KernelApplyRotaryPosEmbMS op; op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe); op.Process(); } else if (TILING_KEY_IS(11)) { - KernelApplyRotaryPosEmbV3 op; + KernelApplyRotaryPosEmbMS op; op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe); op.Process(); } } #ifndef ASCENDC_CPU_DEBUG -void apply_rotary_pos_emb_v3_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *query, uint8_t *key, +void apply_rotary_pos_emb_ms_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *query, uint8_t *key, uint8_t *cos, uint8_t *sin, uint8_t *outq, uint8_t *outk, uint8_t *workspace, uint8_t *tiling) { - apply_rotary_pos_emb_v3<<>>(query, key, cos, sin, outq, outk, workspace, tiling); + apply_rotary_pos_emb_ms<<>>(query, key, cos, sin, outq, outk, workspace, tiling); } #endif diff --git a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md index 20c5bc120f3d105e2071ab3ad3777df971118c8f..beea8b2d25f639e884c143558d5d7be61f0fb6bf 100644 --- a/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md +++ 
b/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md @@ -15,7 +15,8 @@ apply_rotary_pos_emb_ext算子用于计算旋转编码操作。该算子底层 | layout | string | No | No | No | string | 表示输入Tensor的布局格式 | | rotary_mode | string | No | No | No | string | 表示支持计算公式中的旋转模式 | -Note: +## 约束说明 + head_dim当前只支持128. Atlas推理系列产品A2, Atlas推理系列产品A3: rotary_mode只支持"half". @@ -34,6 +35,13 @@ cos/sin shape大小为[batch_size, seq_len, 1, head_dim].支持类型为:FP16/FP 此外注意,`ub_required = (q_n + k_n) * 128 * castSize * 2 + 128 * DtypeSize * 4 + (q_n + k_n) * 128 * castSize + (q_n + k_n) * 128 * castSize * 2 + cast * (128 * 4 * 2)`,当计算出ub_required的大小超过当前AI处理器的UB空间总大小时,不支持使用该融合算子. 不支持空tensor场景. +### 另见 + +|算子|特点| +|----|----| +|[rope](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/rope/rope.md)【推荐】|输入参数query/key/cos/sin要求两维, 支持half、interleave、quarter模式| +|[apply_rotary_pos_emb_ms](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.md) |输入参数query/key要求三维, cos/sin要求两维, 支持query_head_size >= cos_head_size,当前仅支持Atlas推理系列产品下的interleave模式| + ## 输出参数 ## 特殊说明 diff --git a/ops/c_api/rope/rope.md b/ops/c_api/rope/rope.md index 55641c3809c8749bc028d46997eac75b76a95aec..30bc3003d2da075f01e2cc1a9f9730dfb4c1bcab 100644 --- a/ops/c_api/rope/rope.md +++ b/ops/c_api/rope/rope.md @@ -39,6 +39,13 @@ - Decoder阶段要取cos和sin表中seqlen对应的cos/sin值输入。 - 多batch场景需要组合使用gather算子。 +### 另见 + +|算子|特点| +|----|----| +|[apply_rotary_pos_emb_ext](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/apply_rotary_pos_emb_ext/apply_rotary_pos_emb_ext.md)|输入参数query/key/cos/sin要求四维, 当前仅支持head_size=128和half模式| +|[apply_rotary_pos_emb_ms](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/ascendc/apply_rotary_pos_emb_ms/apply_rotary_pos_emb_ms.md) |输入参数query/key要求三维, cos/sin要求两维, 支持query_head_size >= cos_head_size,当前仅支持Atlas推理系列产品下的interleave模式| + ## 使用示例 ```python diff --git a/tests/st/test_custom_rope_v3.py b/tests/st/test_custom_apply_rotary_pos_emb_ms.py 
similarity index 95% rename from tests/st/test_custom_rope_v3.py rename to tests/st/test_custom_apply_rotary_pos_emb_ms.py index 07132807776c56ae703048cd8158d432c088a1df..170d08223c0909b5d9356bcf4c415cf90c06200b 100644 --- a/tests/st/test_custom_rope_v3.py +++ b/tests/st/test_custom_apply_rotary_pos_emb_ms.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""test rove_v3 cases""" +"""test apply_rotary_pos_emb_ms cases""" import time from functools import wraps @@ -93,12 +93,12 @@ def golden_apply_rotary_emb_split(net, query_dtype, layout, rotary_mode, tokens, profiler.analyse() -class ApplyRotaryEmbV3Net(nn.Cell): +class ApplyRotaryPosEmbMSNet(nn.Cell): """Reshape and cache operation for NZ/ND format with all parameters""" @jit_for_graph_mode def construct(self, query, key, cos, sin, layout, rotary_mode): - return ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, layout, rotary_mode) + return ms_custom_ops.apply_rotary_pos_emb_ms(query, key, cos, sin, layout, rotary_mode) def run_rope_interleave(net, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, @@ -209,7 +209,7 @@ def run_rope_interleave_split(net, query_dtype, layout, rotary_mode, tokens, hea @pytest.mark.parametrize("head_num_q", [32]) @pytest.mark.parametrize("head_num_k", [2]) @pytest.mark.parametrize("head_dim", [64]) -def test_rope_v3_interleave(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim): +def test_apply_rotary_pos_emb_ms_interleave(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim): """ Feature:aclnnApplyRotaryPosEmb kernel. Description: test for ApplyRotaryPosEmbExt ops. 
@@ -217,7 +217,7 @@ def test_rope_v3_interleave(exec_mode, query_dtype, layout, rotary_mode, tokens, """ ms.set_context(device_target="Ascend", mode=exec_mode) ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - net = ApplyRotaryEmbV3Net() + net = ApplyRotaryPosEmbMSNet() run_rope_interleave(net, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim) @@ -234,7 +234,7 @@ def test_rope_v3_interleave(exec_mode, query_dtype, layout, rotary_mode, tokens, @pytest.mark.parametrize("head_num_k", [2]) @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("rotary_dim", [64]) -def test_rope_v3_interleave_split(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, +def test_apply_rotary_pos_emb_ms_interleave_split(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim): """ Feature:aclnnApplyRotaryPosEmb kernel. @@ -243,6 +243,6 @@ def test_rope_v3_interleave_split(exec_mode, query_dtype, layout, rotary_mode, t """ ms.set_context(device_target="Ascend", mode=exec_mode) ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - net = ApplyRotaryEmbV3Net() + net = ApplyRotaryPosEmbMSNet() run_rope_interleave_split(net, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim)