diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc
new file mode 100644
index 0000000000000000000000000000000000000000..63caab0c50b2b7c3c4de2ba6ff05f5d5c86c789f
--- /dev/null
+++ b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.cc
@@ -0,0 +1,231 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// =============================================================================
+// GRAPH MODE IMPLEMENTATION
+// =============================================================================
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h"
+#include "ops/framework/utils.h"
+#include "mindspore/include/custom_op_api.h"
+
+namespace ms_custom_ops {
+// Constraint: rotary_dim = 2 * cos_head_dim and query_head_dim >= rotary_dim.
+constexpr uint32_t ROTARY_DIM_FACTOR = 2;
+enum class ApplyRotaryPosEmbV3InputIndex : size_t {
+  kApplyRotaryPosEmbV3QueryIndex = 0,
+  kApplyRotaryPosEmbV3KeyIndex,
+  kApplyRotaryPosEmbV3CosIndex,
+  kApplyRotaryPosEmbV3SinIndex,
+  kApplyRotaryPosEmbV3LayoutIndex,
+  kApplyRotaryPosEmbV3RotaryModeIndex,
+  kApplyRotaryPosEmbV3InputsNum,
+};
+
+// Converts an input-slot enumerator to the position used to index the operator's input list.
+static constexpr size_t ApplyRotaryPosEmbV3Idx(ApplyRotaryPosEmbV3InputIndex index) {
+  return static_cast<size_t>(index);
+}
+
+// Rejects attribute values other than layout "BSH" and rotary_mode "interleave". Shared by the graph-mode and
+// pyboost entry points so that both enforce the same restriction.
+static void ApplyRotaryPosEmbV3CheckAttrs(const std::string &op_name, const std::string &layout,
+                                          const std::string &rotary_mode) {
+  MS_CHECK_VALUE(layout == "BSH",
+                 CheckAndConvertUtils::FormatCommMsg(op_name, " layout should be 'BSH', but got ", layout));
+  MS_CHECK_VALUE(
+    rotary_mode == "interleave",
+    CheckAndConvertUtils::FormatCommMsg(op_name, " rotary_mode should be 'interleave', but got ", rotary_mode));
+}
+
+// Validates ranks and dimensions: query/key are 3-D, cos/sin are 2-D with identical shapes, query and key agree
+// on dim0 and dim2, dim0 matches cos/sin, and the head dim is at least ROTARY_DIM_FACTOR * cos.dim1.
+// NOTE(review): dimensions are compared as plain integers; confirm how unknown (dynamic) dimensions should be
+// treated when this is reached from InferShape.
+static void ApplyRotaryPosEmbV3CheckInputsShape(const std::string &op_name, const std::vector<int64_t> &query_shape,
+                                                const std::vector<int64_t> &key_shape,
+                                                const std::vector<int64_t> &cos_shape,
+                                                const std::vector<int64_t> &sin_shape) {
+  if (query_shape.size() != kDim3 || key_shape.size() != kDim3 || cos_shape.size() != kDim2 ||
+      sin_shape.size() != kDim2) {
+    MS_LOG(EXCEPTION) << op_name << ", the dim of inputs should be query.dim=key.dim=3, "
+                      << "cos.dim=sin.dim=2, but got query.dim=" << query_shape.size()
+                      << ", key.dim=" << key_shape.size() << ", cos.dim=" << cos_shape.size()
+                      << ", sin.dim=" << sin_shape.size();
+  }
+  MS_CHECK_VALUE(query_shape[kIndex2] == key_shape[kIndex2] && query_shape[kIndex0] == key_shape[kIndex0],
+                 CheckAndConvertUtils::FormatCommMsg(
+                   op_name, ", query.dim0 should be equal key.dim0, query.dim2 should be equal key.dim2,",
+                   " but got query.shape=", query_shape, ", key.shape=", key_shape));
+  MS_CHECK_VALUE(
+    cos_shape == sin_shape,
+    CheckAndConvertUtils::FormatCommMsg(
+      op_name, ", cos.shape should be equals sin.shape, but got cos.shape=", cos_shape, ", sin.shape=", sin_shape));
+  MS_CHECK_VALUE(
+    query_shape[kIndex2] >= ROTARY_DIM_FACTOR * cos_shape[kIndex1],
+    CheckAndConvertUtils::FormatCommMsg(
+      op_name, ", the head_dim of query and key should be greater than or equal to twice head_dim of cos or sin,",
+      " but got query.shape=", query_shape, ", cos.shape=", cos_shape));
+  MS_CHECK_VALUE(query_shape[kIndex0] == cos_shape[kIndex0],
+                 CheckAndConvertUtils::FormatCommMsg(
+                   op_name, ", query/key's dim0 should be equal cos/sin's dim0, but got query's shape is ", query_shape,
+                   ", cos's shape is ", cos_shape));
+}
+
+// Requires query/key/cos/sin to share a single dtype, which must be float16 or float32.
+static void ApplyRotaryPosEmbV3CheckInputsType(const std::string &op_name, const TypeId &query_dtype,
+                                               const TypeId &key_dtype, const TypeId &cos_dtype,
+                                               const TypeId &sin_dtype) {
+  const std::unordered_set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::unordered_set<TypeId> input_types = {query_dtype, key_dtype, cos_dtype, sin_dtype};
+  if (input_types.size() > 1) {
+    MS_LOG(EXCEPTION) << op_name << ", the dtype of 'query, key, cos, sin' should be same, but got '"
+                      << TypeIdToString(query_dtype) << ", " << TypeIdToString(key_dtype) << ", "
+                      << TypeIdToString(cos_dtype) << ", " << TypeIdToString(sin_dtype) << "'";
+  }
+  if (valid_types.find(query_dtype) == valid_types.end()) {
+    MS_LOG(EXCEPTION) << op_name << ", the dtype of 'query, key, cos, sin' should be "
+                      << TypeIdToString(kNumberTypeFloat16) << " or " << TypeIdToString(kNumberTypeFloat32)
+                      << ", but got '" << TypeIdToString(query_dtype) << ", " << TypeIdToString(key_dtype) << ", "
+                      << TypeIdToString(cos_dtype) << ", " << TypeIdToString(sin_dtype) << "'";
+  }
+}
+
+// Graph-mode shape and dtype inference. Both outputs mirror the query and key inputs.
+class OPS_API ApplyRotaryPosEmbV3OpFuncImpl : public OpFuncImpl {
+ public:
+  ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    using Input = ApplyRotaryPosEmbV3InputIndex;
+    const auto &query_info = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3QueryIndex)];
+    const auto &key_info = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3KeyIndex)];
+    const auto &cos_info = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3CosIndex)];
+    const auto &sin_info = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3SinIndex)];
+    // With an unknown rank nothing is validated; the outputs still take the query and key shapes.
+    if (query_info->IsDynamicRank() || key_info->IsDynamicRank() || cos_info->IsDynamicRank() ||
+        sin_info->IsDynamicRank()) {
+      return {query_info->GetShape(), key_info->GetShape()};
+    }
+    auto op_name = primitive->name();
+    auto query_shape = query_info->GetShape();
+    auto key_shape = key_info->GetShape();
+    auto cos_shape = cos_info->GetShape();
+    auto sin_shape = sin_info->GetShape();
+    auto rotary_mode = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3RotaryModeIndex)]
+                         ->GetScalarValueWithCheck<std::string>();
+    auto layout = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3LayoutIndex)]
+                    ->GetScalarValueWithCheck<std::string>();
+    ApplyRotaryPosEmbV3CheckAttrs(op_name, layout, rotary_mode);
+    ApplyRotaryPosEmbV3CheckInputsShape(op_name, query_shape, key_shape, cos_shape, sin_shape);
+    return {query_shape, key_shape};
+  }
+
+  std::vector<TypeId> InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    using Input = ApplyRotaryPosEmbV3InputIndex;
+    auto op_name = primitive->name();
+    auto query_dtype = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3QueryIndex)]->GetType();
+    auto key_dtype = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3KeyIndex)]->GetType();
+    auto cos_dtype = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3CosIndex)]->GetType();
+    auto sin_dtype = input_infos[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3SinIndex)]->GetType();
+    ApplyRotaryPosEmbV3CheckInputsType(op_name, query_dtype, key_dtype, cos_dtype, sin_dtype);
+    return {query_dtype, key_dtype};
+  }
+
+  bool GeneralInferRegistered() const override { return true; }
+};
+
+// Graph-mode kernel mod that forwards query/key/cos/sin and the two string attributes to aclnnApplyRotaryPosEmbV3.
+class ApplyRotaryPosEmbV3Ascend : public AclnnCustomKernelMod {
+ public:
+  ApplyRotaryPosEmbV3Ascend() : AclnnCustomKernelMod("aclnnApplyRotaryPosEmbV3") {}
+  ~ApplyRotaryPosEmbV3Ascend() = default;
+
+  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
+              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
+    MS_EXCEPTION_IF_NULL(stream_ptr);
+    using Input = ApplyRotaryPosEmbV3InputIndex;
+    RunOp(stream_ptr, workspace, inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3QueryIndex)],
+          inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3KeyIndex)],
+          inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3CosIndex)],
+          inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3SinIndex)], layout_, rotary_mode_);
+    return true;
+  }
+
+  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
+                        const std::vector<KernelTensor *> &outputs) override {
+    using Input = ApplyRotaryPosEmbV3InputIndex;
+    layout_ = inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3LayoutIndex)]
+                ->GetValueWithCheck<std::string>();
+    rotary_mode_ = inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3RotaryModeIndex)]
+                     ->GetValueWithCheck<std::string>();
+    GetWorkspaceForResize(inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3QueryIndex)],
+                          inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3KeyIndex)],
+                          inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3CosIndex)],
+                          inputs[ApplyRotaryPosEmbV3Idx(Input::kApplyRotaryPosEmbV3SinIndex)], layout_, rotary_mode_);
+  }
+
+ private:
+  DEFINE_GET_WORKSPACE_FOR_RESIZE();
+  std::string layout_ = "BSH";
+  std::string rotary_mode_ = "interleave";
+};
+}  // namespace ms_custom_ops
+
+REG_GRAPH_MODE_OP(apply_rotary_pos_emb_v3, ms_custom_ops::ApplyRotaryPosEmbV3OpFuncImpl,
+                  ms_custom_ops::ApplyRotaryPosEmbV3Ascend);
+
+// =============================================================================
+// PYBOOST MODE IMPLEMENTATION
+// =============================================================================
+
+namespace ms_custom_ops {
+using namespace mindspore;
+using namespace mindspore::device::ascend;
+constexpr size_t kApplyRotaryPosEmbV3OutputNum = 2;
+
+// Pyboost entry point. Applies the same attribute, shape and dtype checks as graph mode, launches the aclnn
+// kernel in place and returns the overwritten query and key tensors.
+std::vector<ms::Tensor> apply_rotary_pos_emb_v3_custom(const ms::Tensor &query, const ms::Tensor &key,
+                                                       const ms::Tensor &cos, const ms::Tensor &sin,
+                                                       const std::string layout_str, const std::string rotary_mode) {
+  // op_name is for human-readable diagnostics only; it is not directly tied to the operator's registered name.
+  std::string op_name = "ApplyRotaryPosEmbV3";
+  auto runner = std::make_shared<ms::pynative::AclnnOpRunner>(op_name);
+  // Reject attribute values the kernel does not implement (graph mode performs this check in InferShape).
+  ApplyRotaryPosEmbV3CheckAttrs(op_name, layout_str, rotary_mode);
+  // Validate the input shapes.
+  ApplyRotaryPosEmbV3CheckInputsShape(op_name, query.shape(), key.shape(), cos.shape(), sin.shape());
+  // Validate the input dtypes.
+  ApplyRotaryPosEmbV3CheckInputsType(op_name, query.data_type(), key.data_type(), cos.data_type(), sin.data_type());
+  // "aclnnApplyRotaryPosEmbV3" is the name from the operator library's function table prefixed with aclnn.
+  // It can be confirmed with: nm -D ./build/xxx/xxx/ms_custom_ops.xxx.so | grep "ApplyRotaryPosEmbV3"
+  // For an in-place operator no output argument needs to be added.
+  runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnApplyRotaryPosEmbV3, query, key, cos, sin, layout_str, rotary_mode));
+  // For an in-place operator the output list is empty.
+  runner->Run({query, key, cos, sin}, {});
+  // For an in-place operator the overwritten inputs are returned as the outputs.
+  return {query, key};
+}
+}  // namespace ms_custom_ops
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  m.def("apply_rotary_pos_emb_v3",
+        PYBOOST_CALLER(ms_custom_ops::kApplyRotaryPosEmbV3OutputNum, ms_custom_ops::apply_rotary_pos_emb_v3_custom));
+}
\ No newline at end of file
diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md
new file mode 100644
index 0000000000000000000000000000000000000000..4688814a90d6c31e30bc04710e2d6d693b6645d9
--- /dev/null
+++ b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3.md
@@ -0,0 +1,67 @@
+# apply_rotary_pos_emb_v3算子
+
+## 描述
+
+apply_rotary_pos_emb_v3算子用于计算旋转编码操作。且支持部分数据参与旋转位置编码计算。
+
+## 输入参数
+
+| Name | DType | Shape | Optional | Inplace | Format | Description |
+|---------------------|-----------------|----------------------------------------|----------|---------|--------|--------------------------------------------------------|
+| query | Tensor(dtype=FP16/FP32) | 3维[tokens, q_head_num, qk_head_dim] | No | Yes | ND | 执行旋转位置编码的第一个变量 |
+| key | Tensor(dtype=FP16/FP32) | 3维[tokens, k_head_num, qk_head_dim] | No | Yes | ND | 执行旋转位置编码的第二个变量 |
+| cos | Tensor(dtype=FP16/FP32) | 2维[tokens,
cos_sin_head_dim] | No | No | ND | 表示参与计算的位置编码张量 | +| sin | Tensor(dtype=FP16/FP32) | 2维[tokens, cos_sin_head_dim] | No | No | ND | 表示参与计算的位置编码张量 | +| layout | string | No | Yes | No | string | 表示输入Tensor的布局格式 | +| rotary_mode | string | No | Yes | No | string | 表示支持计算公式中的旋转模式 | + +Note: +- 产品支持: Atlas推理系列产品AI Core +- rotary_mode: 当前仅支持'interleave'模式 +- layout: 当前仅支持'BSH' +- dtype: query/key/cos/sin数据类型支持FP16/FP32,且四个输入参数类型一致。 +- head_dim: 令`rotary_head_dim = 2 * cos_sin_head_dim`。 + - 要求`qk_head_dim >= rotary_head_dim`, qk_head_dim 不能小于rotary_head_dim。 + - 当`qk_head_dim > rotary_head_dim`时,只对`query/key[...:rotary_head_dim]` 做旋转位置编码。且(qk_head_dim - rotary_head_dim)* size(dtype)必须能被32整除 + - cos_sin_head_dim * sizeof(dtype) 必须能被32整除 + + +## 输出参数 + + + +## 特殊说明 + +## 使用示例 + +### 基本使用示例 + +```python + +import mindspore as ms +import numpy as np +import ms_custom_ops +from mindspore import context, Tensor + +ms.set_context(device_target="Ascend", mode=context.GRAPH_MODE) +ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + +tokens = 4096 +head_num_q = 32 +head_num_k =2 +qk_head_dim = 128 +rotary_head_dim = 64 +layout='BSH' +rotary_mode='interleave' +cos_head_dim = rotary_head_dim // 2 +query_dtype=ms.float16 +np_query = np.random.random((tokens, head_num_q, qk_head_dim)) +np_key = np.random.random((tokens, head_num_k, qk_head_dim)) +np_cos = np.random.random((tokens, cos_head_dim)) +np_sin = np.random.random((tokens, cos_head_dim)) +query = Tensor(np_query, dtype=query_dtype) +key = Tensor(np_key , dtype=query_dtype) +cos = Tensor(np_cos, dtype=query_dtype) +sin = Tensor(np_sin, dtype=query_dtype) +out_query, out_key = ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, layout, rotary_mode) +``` diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..d628ce9aec9f6321ac2eebfb298af70b21b42c62
--- /dev/null
+++ b/ops/ascendc/apply_rotary_pos_emb_v3/apply_rotary_pos_emb_v3_op.yaml
@@ -0,0 +1,26 @@
+#operator apply_rotary_pos_emb_v3
+apply_rotary_pos_emb_v3:
+  args:
+    query:
+      dtype: tensor
+    key:
+      dtype: tensor
+    cos:
+      dtype: tensor
+    sin:
+      dtype: tensor
+    layout:
+      dtype: str
+    rotary_mode:
+      dtype: str
+  args_signature:
+    rw_write: query, key
+  labels:
+    side_effect_mem: True
+  returns:
+    query_embed:
+      dtype: tensor
+      inplace: query
+    key_embed:
+      dtype: tensor
+      inplace: key
\ No newline at end of file
diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7b82b4490ece81c9d22bc9b8a6115cafc92e6ca2
--- /dev/null
+++ b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3.cpp
@@ -0,0 +1,211 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "apply_rotary_pos_emb_v3_tiling.h"
+#include "register/op_def_registry.h"
+#include "graph/utils/type_utils.h"
+#include "tiling/platform/platform_ascendc.h"
+
+namespace optiling {
+/**
+ * @brief Host-side tiling for ApplyRotaryPosEmbV3.
+ *
+ * Reads the query/key/cos shapes, fills ApplyRotaryPosEmbV3TilingData, spreads the tokens over the available
+ * cores and selects the kernel variant through the tiling key: dtype code * 10 + is_split, where FP16 = 1,
+ * FP32 = 2 and is_split = 1 when rotary_dim is smaller than the query head dim.
+ *
+ * @param context Tiling context supplied by the framework.
+ * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context is incomplete, an input has an
+ *         unexpected rank, rotary_dim exceeds the query head dim, or the dtype has no kernel variant.
+ */
+static ge::graphStatus ApplyRotaryPosEmbV3Tiling(gert::TilingContext *context) {
+  if (context == nullptr) {
+    return ge::GRAPH_FAILED;
+  }
+  const auto *query_storage_shape = context->GetInputShape(0);
+  const auto *key_storage_shape = context->GetInputShape(1);
+  const auto *cos_storage_shape = context->GetInputShape(2);
+  const auto *query_desc = context->GetInputDesc(0);
+  const auto *attrs = context->GetAttrs();
+  auto *raw_tiling_data = context->GetRawTilingData();
+  if (query_storage_shape == nullptr || key_storage_shape == nullptr || cos_storage_shape == nullptr ||
+      query_desc == nullptr || attrs == nullptr || raw_tiling_data == nullptr) {
+    return ge::GRAPH_FAILED;
+  }
+
+  const auto &query_shape = query_storage_shape->GetOriginShape();
+  const auto &key_shape = key_storage_shape->GetOriginShape();
+  const auto &cos_shape = cos_storage_shape->GetOriginShape();
+  constexpr size_t kQueryKeyRank = 3;
+  constexpr size_t kCosSinRank = 2;
+  if (query_shape.GetDimNum() != kQueryKeyRank || key_shape.GetDimNum() != kQueryKeyRank ||
+      cos_shape.GetDimNum() != kCosSinRank) {
+    return ge::GRAPH_FAILED;
+  }
+
+  const uint32_t tokens = static_cast<uint32_t>(query_shape.GetDim(0));
+  const uint32_t query_head_num = static_cast<uint32_t>(query_shape.GetDim(1));
+  const uint32_t key_head_num = static_cast<uint32_t>(key_shape.GetDim(1));
+  const uint32_t query_head_dim = static_cast<uint32_t>(query_shape.GetDim(2));
+  const uint32_t cos_head_dim = static_cast<uint32_t>(cos_shape.GetDim(1));
+  const uint32_t rotary_dim = cos_head_dim * 2;
+  // The kernel evaluates (queryHeadDim - rotaryDim) in unsigned arithmetic, so a larger rotary_dim would wrap.
+  if (rotary_dim > query_head_dim) {
+    return ge::GRAPH_FAILED;
+  }
+  const uint32_t is_split = (rotary_dim == query_head_dim ? 0 : 1);
+
+  // Only the FP16 / FP32 keys have a TILING_KEY_IS branch in the kernel. Any other dtype used to yield key 0 or 1,
+  // for which the kernel does nothing, so it is rejected here instead.
+  uint32_t dtype_code = 0;
+  const ge::DataType query_type = query_desc->GetDataType();
+  if (query_type == ge::DataType::DT_FLOAT16) {
+    dtype_code = 1;
+  } else if (query_type == ge::DataType::DT_FLOAT) {
+    dtype_code = 2;
+  } else {
+    return ge::GRAPH_FAILED;
+  }
+  const uint32_t tiling_key = dtype_code * 10 + is_split;
+
+  // NOTE(review): the op definition below registers "layout" as an Int attribute and "rotary_mode" as a String
+  // attribute, yet both are read through a uint32_t pointer, so rotaryMode presumably ends up holding raw string
+  // bytes. The kernel does not read either tiling field today; confirm the intended encoding.
+  const uint32_t *layout = attrs->GetAttrPointer<uint32_t>(0);
+  const uint32_t *rotaryMode = attrs->GetAttrPointer<uint32_t>(1);
+  if (layout == nullptr || rotaryMode == nullptr) {
+    return ge::GRAPH_FAILED;
+  }
+
+  // Use at most one core per token but always at least one: with tokens == 0 (or a platform reporting zero
+  // cores) the divisor below used to be zero.
+  auto ascendc_platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+  uint32_t core_num = ascendc_platform.GetCoreNum();
+  if (tokens < core_num) {
+    core_num = tokens;
+  }
+  if (core_num == 0) {
+    core_num = 1;
+  }
+
+  ApplyRotaryPosEmbV3TilingData tiling;
+  tiling.set_queryHeadDim(query_head_dim);
+  tiling.set_qHeadNum(query_head_num);
+  tiling.set_kHeadNum(key_head_num);
+  tiling.set_rotaryDim(rotary_dim);
+  tiling.set_qHiddenSize(query_head_num * rotary_dim);
+  tiling.set_kHiddenSize(key_head_num * rotary_dim);
+  tiling.set_cosHeadDim(cos_head_dim);
+  // The first tokensTail cores each process one token more than tokensPerCore.
+  tiling.set_tokensPerCore(tokens / core_num);
+  tiling.set_tokensTail(tokens % core_num);
+  tiling.set_layout(*layout);
+  tiling.set_rotaryMode(*rotaryMode);
+
+  size_t *workspace_sizes = context->GetWorkspaceSizes(1);
+  if (workspace_sizes == nullptr) {
+    return ge::GRAPH_FAILED;
+  }
+  workspace_sizes[0] = 0;
+
+  context->SetBlockDim(core_num);
+  context->SetTilingKey(tiling_key);
+  tiling.SaveToBuffer(raw_tiling_data->GetData(), raw_tiling_data->GetCapacity());
+  raw_tiling_data->SetDataSize(tiling.GetDataSize());
+  return ge::GRAPH_SUCCESS;
+}
+}  // namespace optiling
+
+namespace ge {
+// Output shapes are copies of the query (output 0) and key (output 1) input shapes.
+static ge::graphStatus ApplyRotaryPosEmbV3InferShape(gert::InferShapeContext *context) {
+  if (context == nullptr) {
+    return ge::GRAPH_FAILED;
+  }
+  const gert::Shape *query_shape = context->GetInputShape(0);
+  const gert::Shape *key_shape = context->GetInputShape(1);
+  gert::Shape *out_query_shape = context->GetOutputShape(0);
+  gert::Shape *out_key_shape = context->GetOutputShape(1);
+  if (query_shape == nullptr || key_shape == nullptr || out_query_shape == nullptr || out_key_shape == nullptr) {
+    return ge::GRAPH_FAILED;
+  }
+  *out_query_shape = *query_shape;
+  *out_key_shape = *key_shape;
+  return ge::GRAPH_SUCCESS;
+}
+
+// Both outputs take the data type of input 0 (query).
+static graphStatus ApplyRotaryPosEmbV3InferDataType(gert::InferDataTypeContext *context) {
+  if (context == nullptr) {
+    return ge::GRAPH_FAILED;
+  }
+  const auto inputDataType = context->GetInputDataType(0);
+  context->SetOutputDataType(0, inputDataType);
+  context->SetOutputDataType(1, inputDataType);
+  return ge::GRAPH_SUCCESS;
+}
+}  // namespace ge
+
+namespace ops {
+// Operator prototype: four required ND inputs (query, key, cos, sin) and two ND outputs (query, key), each in
+// FP32 or FP16, registered for the "ascend310p" AI Core configuration.
+// NOTE(review): "layout" is declared as an Int attribute here, while the operator YAML declares it as str and the
+// MindSpore-side caller passes a string; confirm which type is intended.
+class ApplyRotaryPosEmbV3 : public OpDef {
+ public:
+  explicit ApplyRotaryPosEmbV3(const char *name) : OpDef(name) {
+    this->Input("query")
+      .ParamType(REQUIRED)
+      .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+      .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+      .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
+      .AutoContiguous();
+    this->Input("key")
+      .ParamType(REQUIRED)
+      .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+      .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+      .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
+      .AutoContiguous();
+    this->Input("cos")
+      .ParamType(REQUIRED)
+      .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+      .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+      .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
+      .AutoContiguous();
+    this->Input("sin")
+      .ParamType(REQUIRED)
+      .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+      .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+      .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
+      .AutoContiguous();
+    this->Output("query")
+      .ParamType(REQUIRED)
+      .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+      .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+      .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+    this->Output("key")
+      .ParamType(REQUIRED)
+      .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+      .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+      .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+    this->Attr("layout").AttrType(OPTIONAL).Int(1);
+    this->Attr("rotary_mode").AttrType(OPTIONAL).String("interleave");
+
+    this->SetInferShape(ge::ApplyRotaryPosEmbV3InferShape).SetInferDataType(ge::ApplyRotaryPosEmbV3InferDataType);
+    this->AICore().SetTiling(optiling::ApplyRotaryPosEmbV3Tiling).AddConfig("ascend310p");
+  }
+};
+OP_ADD(ApplyRotaryPosEmbV3);
+}  // namespace ops
diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3_tiling.h b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3_tiling.h
new file mode 100644
index
0000000000000000000000000000000000000000..bf90e3c9abd78fc63656bedf6e912296e4a9b678
--- /dev/null
+++ b/ops/ascendc/apply_rotary_pos_emb_v3/op_host/apply_rotary_pos_emb_v3_tiling.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APPLY_ROTARY_POS_EMB_V3_TILING_H
+#define APPLY_ROTARY_POS_EMB_V3_TILING_H
+#include "register/tilingdata_base.h"
+
+namespace optiling {
+// Tiling parameters passed from the host tiling function to the ApplyRotaryPosEmbV3 kernel. The include guard is
+// named after this operator so that it cannot collide with another operator's tiling header.
+BEGIN_TILING_DATA_DEF(ApplyRotaryPosEmbV3TilingData)
+  TILING_DATA_FIELD_DEF(uint32_t, tilingId);       // not set by the host tiling function in this change
+  TILING_DATA_FIELD_DEF(uint32_t, useCoreNum);     // not set by the host tiling function in this change
+  TILING_DATA_FIELD_DEF(uint32_t, tokensPerCore);  // tokens / number of cores used
+  TILING_DATA_FIELD_DEF(uint32_t, tokensTail);     // tokens % number of cores used
+  TILING_DATA_FIELD_DEF(uint32_t, qHeadNum);       // query.dim1
+  TILING_DATA_FIELD_DEF(uint32_t, kHeadNum);       // key.dim1
+  TILING_DATA_FIELD_DEF(uint32_t, qHiddenSize);    // qHeadNum * rotaryDim
+  TILING_DATA_FIELD_DEF(uint32_t, kHiddenSize);    // kHeadNum * rotaryDim
+  TILING_DATA_FIELD_DEF(uint32_t, queryHeadDim);   // query.dim2
+  TILING_DATA_FIELD_DEF(uint32_t, cosHeadDim);     // cos.dim1
+  TILING_DATA_FIELD_DEF(uint32_t, rotaryDim);      // 2 * cosHeadDim
+  TILING_DATA_FIELD_DEF(uint32_t, layout);         // value read from attribute 0
+  TILING_DATA_FIELD_DEF(uint32_t, rotaryMode);     // value read from attribute 1
+END_TILING_DATA_DEF;
+
+REGISTER_TILING_DATA_CLASS(ApplyRotaryPosEmbV3, ApplyRotaryPosEmbV3TilingData)
+}  // namespace optiling
+#endif  // APPLY_ROTARY_POS_EMB_V3_TILING_H
diff --git a/ops/ascendc/apply_rotary_pos_emb_v3/op_kernel/apply_rotary_pos_emb_v3.cpp b/ops/ascendc/apply_rotary_pos_emb_v3/op_kernel/apply_rotary_pos_emb_v3.cpp
new file mode 100644
index
0000000000000000000000000000000000000000..b1e2927256f64418218923d3d8dbfe1b477670d6
--- /dev/null
+++ b/ops/ascendc/apply_rotary_pos_emb_v3/op_kernel/apply_rotary_pos_emb_v3.cpp
@@ -0,0 +1,291 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "kernel_operator.h"
+// Queue depth passed to InitBuffer for each TQue below.
+constexpr int32_t BUFFER_NUM = 1;
+// NOTE(review): the template parameter list below and the other angle-bracket arguments in this file (casts,
+// tensor/queue/buffer element types, PipeBarrier, the <<<...>>> launch) are empty in this copy of the patch.
+// The body relies on a type T and a compile-time flag IS_SPLIT; confirm the exact text against the original.
+//
+// Per-core kernel applying interleaved rotary position embedding to query and key, one token per iteration.
+// Assuming GatherMask patterns 1 and 2 pick the even- and odd-indexed elements (confirm in the API docs), each
+// pair (x[2j], x[2j+1]) within the first rotaryDim elements of a head becomes
+//   out[2j]     = x[2j] * cos[j] - x[2j+1] * sin[j]
+//   out[2j + 1] = x[2j+1] * cos[j] + x[2j] * sin[j]
+// With IS_SPLIT, the elements of a head beyond rotaryDim are presumably left untouched by the strided CopyOut.
+template
+class KernelApplyRotaryPosEmbV3 {
+ public:
+  __aicore__ inline KernelApplyRotaryPosEmbV3() {}
+  // Stores the tiling and pipe pointers, derives this core's token range, then sets up the GM views, the
+  // strided-copy parameters and the UB buffers. `workspace` is accepted but not used here.
+  __aicore__ inline void Init(GM_ADDR query, GM_ADDR key, GM_ADDR cos, GM_ADDR sin, GM_ADDR oquery, GM_ADDR okey,
+                              GM_ADDR workspace, ApplyRotaryPosEmbV3TilingData *tiling, AscendC::TPipe *tPipe) {
+    pipe = tPipe;
+    tilingData = tiling;
+    // Elements per token in global memory; with IS_SPLIT this is the full head size, not just the rotated part.
+    if constexpr (IS_SPLIT) {
+      this->originQueryHiddenSize = tilingData->queryHeadDim * tilingData->qHeadNum;
+      this->originKeyHiddenSize = tilingData->queryHeadDim * tilingData->kHeadNum;
+    } else {
+      this->originQueryHiddenSize = tilingData->qHiddenSize;
+      this->originKeyHiddenSize = tilingData->kHiddenSize;
+    }
+    this->blockId = AscendC::GetBlockIdx();
+    // The first tokensTail cores handle one extra token; startOffset is the index of this core's first token.
+    this->loop = this->blockId < tilingData->tokensTail ? tilingData->tokensPerCore + 1 : tilingData->tokensPerCore;
+    this->startOffset = this->blockId < tilingData->tokensTail
+                          ? (this->blockId * tilingData->tokensPerCore + this->blockId)
+                          : (this->blockId * tilingData->tokensPerCore + tilingData->tokensTail);
+    // Cal: elements per token taking part in the rotation (q + k); In: elements per token loaded from GM.
+    queryKeyCalHiddenSize = tilingData->qHiddenSize + tilingData->kHiddenSize;
+    queryKeyInHiddenSize = this->originQueryHiddenSize + this->originKeyHiddenSize;
+
+    // NOTE(review): presumably DataCopyParams{blockCount, blockLen, srcStride, dstStride} in 32-byte blocks
+    // (hence "/ 32") - confirm. queryInParams / keyInParams are assigned here but never read in this class.
+    queryInParams = {static_cast(tilingData->qHeadNum),
+                     static_cast((tilingData->rotaryDim * sizeof(T)) / 32),
+                     static_cast(((tilingData->queryHeadDim - tilingData->rotaryDim) * sizeof(T)) / 32),
+                     static_cast(0)};
+    keyInParams = {static_cast(tilingData->kHeadNum),
+                   static_cast((tilingData->rotaryDim * sizeof(T)) / 32),
+                   static_cast(((tilingData->queryHeadDim - tilingData->rotaryDim) * sizeof(T)) / 32),
+                   static_cast(0)};
+    queryOutParams = {static_cast(tilingData->qHeadNum),
+                      static_cast((tilingData->rotaryDim * sizeof(T)) / 32), static_cast(0),
+                      static_cast(((tilingData->queryHeadDim - tilingData->rotaryDim) * sizeof(T)) / 32)};
+    keyOutParams = {static_cast(tilingData->kHeadNum),
+                    static_cast((tilingData->rotaryDim * sizeof(T)) / 32), static_cast(0),
+                    static_cast(((tilingData->queryHeadDim - tilingData->rotaryDim) * sizeof(T)) / 32)};
+    uint16_t blockLen = tilingData->cosHeadDim * sizeof(T) / 32;
+    uint16_t blockCount = tilingData->qHeadNum + tilingData->kHeadNum;
+    qkPrepareDataCopyParams = {blockCount, blockLen, blockLen, blockLen};
+    // Every GM view starts at this core's first token.
+    qGm.SetGlobalBuffer((__gm__ T *)query + startOffset * this->originQueryHiddenSize, this->originQueryHiddenSize);
+    kGm.SetGlobalBuffer((__gm__ T *)key + startOffset * this->originKeyHiddenSize, this->originKeyHiddenSize);
+    qOutGm.SetGlobalBuffer((__gm__ T *)oquery + startOffset * this->originQueryHiddenSize, this->originQueryHiddenSize);
+    kOutGm.SetGlobalBuffer((__gm__ T *)okey + startOffset * this->originKeyHiddenSize, this->originKeyHiddenSize);
+    cosGm.SetGlobalBuffer((__gm__ T *)cos + startOffset * tilingData->cosHeadDim, tilingData->cosHeadDim);
+    sinGm.SetGlobalBuffer((__gm__ T *)sin + startOffset * tilingData->cosHeadDim, tilingData->cosHeadDim);
+
+    pipe->InitBuffer(qInQueue, BUFFER_NUM, (queryKeyInHiddenSize) * sizeof(T));
+    pipe->InitBuffer(cosInQueue, BUFFER_NUM, tilingData->cosHeadDim * sizeof(T));
+    pipe->InitBuffer(sinInQueue, BUFFER_NUM, tilingData->cosHeadDim * sizeof(T));
+    pipe->InitBuffer(qOutQueue, BUFFER_NUM, (queryKeyCalHiddenSize) * sizeof(T));
+    pipe->InitBuffer(originBuf, queryKeyCalHiddenSize * sizeof(T));
+    pipe->InitBuffer(rotaryBuf, queryKeyCalHiddenSize * sizeof(T));
+    pipe->InitBuffer(cosBuf, queryKeyCalHiddenSize * sizeof(T));
+    pipe->InitBuffer(sinBuf, queryKeyCalHiddenSize * sizeof(T));
+    // NOTE(review): scatterBuf later holds the Scatter offsets but is sized with sizeof(T); confirm that this is
+    // large enough for the offset element type when T is a 2-byte type.
+    pipe->InitBuffer(scatterBuf, queryKeyCalHiddenSize * sizeof(T));
+  }
+  // Builds the scatter offsets once, then runs load -> prepare -> compute -> store for each token of this core.
+  __aicore__ inline void Process() {
+    PrePareScatterOffset();
+    for (uint32_t i = 0; i < loop; i++) {
+      CopyIn(i);
+      PrepareCosSin();
+      PrepareQK();
+      Compute();
+      CopyOut(i);
+    }
+  }
+
+ private:
+  // Loads token `index`: query and key back-to-back into one tensor, plus that token's cos and sin rows.
+  __aicore__ inline void CopyIn(int32_t index) {
+    AscendC::LocalTensor qLocal = qInQueue.AllocTensor();
+    AscendC::LocalTensor cosLocal = cosInQueue.AllocTensor();
+    AscendC::LocalTensor sinLocal = sinInQueue.AllocTensor();
+    AscendC::DataCopy(qLocal, qGm[index * this->originQueryHiddenSize], this->originQueryHiddenSize);
+    AscendC::DataCopy(qLocal[this->originQueryHiddenSize], kGm[index * this->originKeyHiddenSize],
+                      this->originKeyHiddenSize);
+    AscendC::DataCopy(cosLocal, cosGm[index * tilingData->cosHeadDim], tilingData->cosHeadDim);
+    AscendC::DataCopy(sinLocal, sinGm[index * tilingData->cosHeadDim], tilingData->cosHeadDim);
+    qInQueue.EnQue(qLocal);
+    cosInQueue.EnQue(cosLocal);
+    sinInQueue.EnQue(sinLocal);
+  }
+  // Expands cos / sin to the layout Compute expects: cosBuf = cos repeated 2 * (qHeadNum + kHeadNum) times,
+  // sinBuf = [-sin, sin] repeated once per head.
+  __aicore__ inline void PrepareCosSin() {
+    AscendC::LocalTensor cos = cosInQueue.DeQue();
+    AscendC::LocalTensor cosTmp = cosBuf.Get();
+    for (uint32_t i = 0; i < 2 * (tilingData->qHeadNum + tilingData->kHeadNum); ++i) {
+      AscendC::DataCopy(cosTmp[i * tilingData->cosHeadDim], cos, tilingData->cosHeadDim);
+    }
+    AscendC::LocalTensor sin = sinInQueue.DeQue();
+    AscendC::LocalTensor sinTmp = sinBuf.Get();
+    AscendC::DataCopy(sinTmp[tilingData->cosHeadDim], sin, tilingData->cosHeadDim);
+    AscendC::Muls(sinTmp[0], sin, static_cast(-1), static_cast(tilingData->cosHeadDim));
+    AscendC::PipeBarrier();
+    for (uint32_t i = 1; i < tilingData->qHeadNum + tilingData->kHeadNum; ++i) {
+      AscendC::DataCopy(sinTmp[i * tilingData->rotaryDim], sinTmp, tilingData->rotaryDim);
+    }
+    AscendC::PipeBarrier();
+
+    cosInQueue.FreeTensor(cos);
+    sinInQueue.FreeTensor(sin);
+  }
+
+  // De-interleaves every head into originBuf (GatherMask pattern 1 into the first half, pattern 2 into the
+  // second half) and fills rotaryBuf with the same data with the two halves of each head swapped.
+  __aicore__ inline void PrepareQK() {
+    AscendC::LocalTensor queryKey = qInQueue.DeQue();
+    AscendC::LocalTensor origin = originBuf.Get();
+    AscendC::LocalTensor rotary = rotaryBuf.Get();
+
+    uint64_t rsvdCnt = 0;
+    uint32_t maskIndex{0};
+    if constexpr (IS_SPLIT) {
+      for (uint32_t i = 0; i < tilingData->qHeadNum + tilingData->kHeadNum; ++i) {
+        maskIndex = i * tilingData->rotaryDim;
+        AscendC::GatherMask(origin[maskIndex], queryKey[i * tilingData->queryHeadDim], 1, true,
+                            tilingData->rotaryDim, gatherMask, rsvdCnt);
+        AscendC::GatherMask(origin[maskIndex + tilingData->cosHeadDim], queryKey[i * tilingData->queryHeadDim], 2,
+                            true, tilingData->rotaryDim, gatherMask, rsvdCnt);
+      }
+    } else {
+      for (uint32_t i = 0; i < tilingData->qHeadNum + tilingData->kHeadNum; ++i) {
+        maskIndex = i * tilingData->rotaryDim;
+        AscendC::GatherMask(origin[maskIndex], queryKey[maskIndex], 1, true, tilingData->rotaryDim, gatherMask,
+                            rsvdCnt);
+        AscendC::GatherMask(origin[maskIndex + tilingData->cosHeadDim], queryKey[maskIndex], 2, true,
+                            tilingData->rotaryDim, gatherMask, rsvdCnt);
+      }
+    }
+    AscendC::PipeBarrier();
+
+    AscendC::DataCopy(rotary, origin[tilingData->cosHeadDim], qkPrepareDataCopyParams);
+    AscendC::DataCopy(rotary[tilingData->cosHeadDim], origin, qkPrepareDataCopyParams);
+    AscendC::PipeBarrier();
+
+    qInQueue.FreeTensor(queryKey);
+  }
+  // Precomputes the Scatter destination offsets in bytes. For head i, with base = i * rotaryDim * sizeof(T):
+  // first-half entries start at base, second-half entries at base + sizeof(T), both advancing by 2 * sizeof(T)
+  // (argument roles of ArithProgression assumed to be start, step, count - confirm).
+  __aicore__ inline void PrePareScatterOffset() {
+    AscendC::LocalTensor dstOffset = scatterBuf.Get();
+    for (uint32_t i = 0; i < (tilingData->qHeadNum + tilingData->kHeadNum); ++i) {
+      AscendC::ArithProgression(dstOffset[i * tilingData->rotaryDim],
+                                static_cast(i * tilingData->rotaryDim * sizeof(T)), 2 * sizeof(T),
+                                static_cast(tilingData->cosHeadDim));
+      AscendC::ArithProgression(dstOffset[i * tilingData->rotaryDim + tilingData->cosHeadDim],
+                                static_cast((i * tilingData->rotaryDim + 1) * sizeof(T)),
+                                2 * sizeof(T), static_cast(tilingData->cosHeadDim));
+    }
+    AscendC::PipeBarrier();
+  }
+  // origin = origin * cosBuf + rotary * sinBuf, then Scatter writes the result to the output tensor at the
+  // precomputed offsets.
+  __aicore__ inline void Compute() {
+    AscendC::LocalTensor qOutLocal = qOutQueue.AllocTensor();
+    AscendC::LocalTensor origin = originBuf.Get();
+    AscendC::LocalTensor rotary = rotaryBuf.Get();
+    AscendC::LocalTensor cos = cosBuf.Get();
+    AscendC::LocalTensor sin = sinBuf.Get();
+    AscendC::Mul(origin, origin, cos, queryKeyCalHiddenSize);
+    AscendC::Mul(rotary, rotary, sin, queryKeyCalHiddenSize);
+    AscendC::PipeBarrier();
+    AscendC::Add(origin, origin, rotary, queryKeyCalHiddenSize);
+    AscendC::PipeBarrier();
+
+    AscendC::LocalTensor dstOffset = scatterBuf.Get();
+    AscendC::Scatter(qOutLocal, origin, dstOffset.ReinterpretCast(), static_cast(0),
+                     (queryKeyCalHiddenSize));
+    qOutQueue.EnQue(qOutLocal);
+  }
+  // Writes token `index` back to GM. IS_SPLIT: strided copies driven by queryOutParams / keyOutParams;
+  // otherwise two contiguous copies of qHiddenSize and kHiddenSize elements.
+  __aicore__ inline void CopyOut(int32_t index) {
+    AscendC::LocalTensor out = qOutQueue.DeQue();
+    if constexpr (IS_SPLIT) {
+      AscendC::DataCopy(qOutGm[index * this->originQueryHiddenSize], out, queryOutParams);
+      AscendC::DataCopy(kOutGm[index * this->originKeyHiddenSize], out[tilingData->qHiddenSize], keyOutParams);
+    } else {
+      AscendC::DataCopy(qOutGm[index * tilingData->qHiddenSize], out, tilingData->qHiddenSize);
+      AscendC::DataCopy(kOutGm[index * tilingData->kHiddenSize], out[tilingData->qHiddenSize], tilingData->kHiddenSize);
+    }
+    qOutQueue.FreeTensor(out);
+  }
+
+ private:
+  AscendC::TPipe *pipe;
+  AscendC::TQue qInQueue;
+  AscendC::TQue cosInQueue;
+  AscendC::TQue sinInQueue;
+  AscendC::TQue qOutQueue;
+  AscendC::TBuf originBuf;
+  AscendC::TBuf rotaryBuf;
+  AscendC::TBuf cosBuf;
+  AscendC::TBuf sinBuf;
+  AscendC::TBuf scatterBuf;
+  AscendC::GlobalTensor qGm;
+  AscendC::GlobalTensor kGm;
+  AscendC::GlobalTensor cosGm;
+  AscendC::GlobalTensor sinGm;
+  AscendC::GlobalTensor qOutGm;
+  AscendC::GlobalTensor kOutGm;
+  AscendC::GatherMaskParams gatherMask{1, 1, 8, 0};
+  AscendC::DataCopyParams qkPrepareDataCopyParams;
+  AscendC::DataCopyParams queryOutParams, keyOutParams, queryInParams, keyInParams;
+  uint32_t blockId{0};
+  uint32_t startOffset{0};
+  uint32_t loop{0};
+  uint32_t originQueryHiddenSize{0};
+  uint32_t originKeyHiddenSize{0};
+  uint32_t queryKeyCalHiddenSize{0};
+  uint32_t queryKeyInHiddenSize{0};
+  ApplyRotaryPosEmbV3TilingData *tilingData = nullptr;
+};
+
+// Kernel entry point: the tiling key (20, 10, 21 or 11) selects which instantiation runs.
+// NOTE(review): usrWorkspace is computed but not used below.
+extern "C" __global__ __aicore__ void apply_rotary_pos_emb_v3(GM_ADDR query, GM_ADDR key, GM_ADDR cos, GM_ADDR sin,
+                                                              GM_ADDR outq, GM_ADDR outk, GM_ADDR workspace,
+                                                              GM_ADDR tiling) {
+  GET_TILING_DATA(tilingData, tiling);
+  GM_ADDR usrWorkspace = AscendC::GetUserWorkspace(workspace);
+  AscendC::TPipe pipe;
+  if (TILING_KEY_IS(20)) {
+    KernelApplyRotaryPosEmbV3 op;
+    op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe);
+    op.Process();
+  } else if (TILING_KEY_IS(10)) {
+    KernelApplyRotaryPosEmbV3 op;
+    op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe);
+    op.Process();
+  } else if (TILING_KEY_IS(21)) {
+    KernelApplyRotaryPosEmbV3 op;
+    op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe);
+    op.Process();
+  } else if (TILING_KEY_IS(11)) {
+    KernelApplyRotaryPosEmbV3 op;
+    op.Init(query, key, cos, sin, outq, outk, workspace, &tilingData, &pipe);
+    op.Process();
+  }
+}
+
+#ifndef ASCENDC_CPU_DEBUG
+// Host-side launch wrapper; compiled only when ASCENDC_CPU_DEBUG is not defined.
+void apply_rotary_pos_emb_v3_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *query, uint8_t *key,
+                                uint8_t *cos, uint8_t *sin, uint8_t *outq, uint8_t *outk, uint8_t *workspace,
+                                uint8_t *tiling) {
+  apply_rotary_pos_emb_v3<<>>(query, key, cos, sin, outq, outk, workspace, tiling);
+}
+#endif
diff --git a/ops/framework/utils.h
b/ops/framework/utils.h index 9e897a6e023f735b2c367717812f2db9098ef2e8..53e088a74d6bbd5af92a9678aadda80b74a3ddca 100644 --- a/ops/framework/utils.h +++ b/ops/framework/utils.h @@ -23,6 +23,22 @@ #include "mindspore/include/custom_op_api.h" namespace ms_custom_ops { +constexpr size_t kIndex0{0}; +constexpr size_t kIndex1{1}; +constexpr size_t kIndex2{2}; +constexpr size_t kIndex3{3}; +constexpr size_t kIndex4{4}; +constexpr size_t kIndex5{5}; +constexpr size_t kIndex6{6}; +constexpr size_t kIndex7{7}; +constexpr size_t kIndex8{8}; +constexpr size_t kIndex9{9}; +constexpr size_t kDim0{0}; +constexpr size_t kDim1{1}; +constexpr size_t kDim2{2}; +constexpr size_t kDim3{3}; +constexpr size_t kDim4{4}; +constexpr size_t kDim5{5}; // Helper function to convert optional tensor to tensor or empty tensor inline ms::Tensor GetTensorOrEmpty(const std::optional &opt_tensor) { return opt_tensor.has_value() ? opt_tensor.value() : ms::Tensor(); diff --git a/tests/st/test_custom_rope_v3.py b/tests/st/test_custom_rope_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..fb46b3355be87f76d7a124c77ae9e3c97acb1a71 --- /dev/null +++ b/tests/st/test_custom_rope_v3.py @@ -0,0 +1,231 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import os +import time +import numpy as np +import pytest +from functools import wraps + +import ms_custom_ops +import mindspore.ops as ops +import mindspore.nn as nn +import mindspore as ms +from mindspore.common.api import jit +from mindspore import Tensor, mint, nn, ops, context, Profiler +from mindspore.profiler import ProfilerLevel, ProfilerActivity, AicoreMetrics +# from mindspore.common.np_dtype import bfloat16 +from mindspore._c_expression import MSContext + +def jit_for_graph_mode(fn): + """ + A decorator that conditionally applies jit to a function at runtime based on the context mode. + """ + jitted_fn = jit(fn) + @wraps(fn) + def wrapper(*args, **kwargs): + if context.get_context("mode") == context.GRAPH_MODE: + return jitted_fn(*args, **kwargs) + return fn(*args, **kwargs) + return wrapper + +def golden_apply_rotary_emb( + x: Tensor, + cos: Tensor, + sin: Tensor, + is_neox_style: bool, +) -> Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = mint.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return mint.cat((o1, o2), dim=-1) + else: + return mint.stack((o1, o2), dim=-1).flatten(-2) + +def golden_apply_rotary_emb_split(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim, is_profiler=False): + cos_head_dim = rotary_dim // 2 + np_query = np.random.random((tokens, head_num_q, head_dim)) + np_key = np.random.random((tokens, head_num_k, head_dim)) + np_cos = np.random.random((tokens, cos_head_dim)) + np_sin = np.random.random((tokens, cos_head_dim)) + query = Tensor(np_query, dtype=query_dtype) + key = Tensor(np_key , dtype=query_dtype) + cos = Tensor(np_cos, dtype=query_dtype) + sin = Tensor(np_sin, dtype=query_dtype) + golden_q = golden_apply_rotary_emb(query, cos, sin, False) + golden_k = golden_apply_rotary_emb(key, cos, sin, False) + if is_profiler == False: + out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + np.testing.assert_allclose(golden_q.asnumpy(), out_query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") + np.testing.assert_allclose(golden_k.asnumpy(), out_key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") + else: + profiler = Profiler(profiler_level=ProfilerLevel.Level2, + activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], + aic_metrics=AicoreMetrics.AiCoreNone) + for i in range(50): + out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + profiler.analyse() + +class ApplyRotaryEmbV3Net(nn.Cell): + """Reshape and cache operation for NZ/ND format with all parameters""" + + @jit_for_graph_mode + def construct(self, query, key, cos, sin, layout, rotary_mode): + return ms_custom_ops.apply_rotary_pos_emb_v3(query, key, cos, sin, layout, rotary_mode) + +def run_rope_interleave(net, exec_mode, 
query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, is_profiler=False): + cos_head_dim = head_dim // 2 + np_query = np.random.random((tokens, head_num_q, head_dim)) + np_key = np.random.random((tokens, head_num_k, head_dim)) + np_cos = np.random.random((tokens, cos_head_dim)) + np_sin = np.random.random((tokens, cos_head_dim)) + query = Tensor(np_query, dtype=query_dtype) + key = Tensor(np_key , dtype=query_dtype) + cos = Tensor(np_cos, dtype=query_dtype) + sin = Tensor(np_sin, dtype=query_dtype) + golden_q = golden_apply_rotary_emb(query, cos, sin, False) + golden_k = golden_apply_rotary_emb(key, cos, sin, False) + if is_profiler == False: + out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + np.testing.assert_allclose(golden_q.asnumpy(), out_query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") + np.testing.assert_allclose(golden_k.asnumpy(), out_key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") + else: + profiler = Profiler(profiler_level=ProfilerLevel.Level2, + activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], + aic_metrics=AicoreMetrics.AiCoreNone) + for i in range(50): + out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + profiler.analyse() + + +def run_rope_interleave_split(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim, is_profiler=False): + cos_head_dim = rotary_dim // 2 + np_query = np.random.random((tokens, head_num_q, head_dim)) + np_key = np.random.random((tokens, head_num_k, head_dim)) + np_cos = np.random.random((tokens, cos_head_dim)) + np_sin = np.random.random((tokens, cos_head_dim)) + query = Tensor(np_query, dtype=query_dtype) + key = Tensor(np_key , dtype=query_dtype) + cos = Tensor(np_cos, dtype=query_dtype) + sin = Tensor(np_sin, dtype=query_dtype) + if is_profiler == False: + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + + key_rot = key[..., :rotary_dim] + key_pass = key[..., 
rotary_dim:] + + query_rot = golden_apply_rotary_emb(query_rot, cos, sin, False) + key_rot = golden_apply_rotary_emb(key_rot, cos, sin, False) + + golden_q = mint.cat((query_rot, query_pass), dim=-1) + golden_k = mint.cat((key_rot, key_pass), dim=-1) + else: + start_time = time.perf_counter() + for i in range(50): + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + + query_rot = golden_apply_rotary_emb(query_rot, cos, sin, False) + key_rot = golden_apply_rotary_emb(key_rot, cos, sin, False) + + query = mint.cat((query_rot, query_pass), dim=-1) + key = mint.cat((key_rot, key_pass), dim=-1) + end_time = time.perf_counter() + total_time = end_time - start_time + print("tokens:", tokens) + print("小算子50次平均耗时(毫秒):", total_time * 1000/50) + + if is_profiler == False: + out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + np.testing.assert_allclose(golden_q.asnumpy(), out_query.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" query ") + np.testing.assert_allclose(golden_k.asnumpy(), out_key.asnumpy(), rtol=1e-4, atol=1e-4, err_msg=" key ") + else: + profiler = Profiler(profiler_level=ProfilerLevel.Level2, + activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], + aic_metrics=AicoreMetrics.AiCoreNone) + for i in range(50): + out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + profiler.analyse() + + start_time = time.perf_counter() + for i in range(50): + out_query, out_key = net(query, key, cos, sin, layout, rotary_mode) + end_time = time.perf_counter() + total_time = end_time - start_time + print("tokens:", tokens) + print("大算子50次平均耗时(毫秒):", total_time * 1000/50) + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("query_dtype", [ms.float32, ms.float16]) +@pytest.mark.parametrize("layout", ["BSH"]) 
+@pytest.mark.parametrize("rotary_mode", ["interleave"]) +@pytest.mark.parametrize("tokens", [10, 40960]) +@pytest.mark.parametrize("head_num_q", [32]) +@pytest.mark.parametrize("head_num_k", [2]) +@pytest.mark.parametrize("head_dim", [64]) +def test_rope_v3_interleave(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim): + """ + Feature:aclnnApplyRotaryPosEmb kernel. + Description: test for ApplyRotaryPosEmbExt ops. + Expectation:should pass for all testcases. + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + net = ApplyRotaryEmbV3Net() + run_rope_interleave(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim) + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("query_dtype", [ms.float32, ms.float16]) +@pytest.mark.parametrize("layout", ["BSH"]) +@pytest.mark.parametrize("rotary_mode", ["interleave"]) +@pytest.mark.parametrize("tokens", [10, 40960]) +@pytest.mark.parametrize("head_num_q", [32]) +@pytest.mark.parametrize("head_num_k", [2]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("rotary_dim", [64]) +def test_rope_v3_interleave_split(exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim): + """ + Feature:aclnnApplyRotaryPosEmb kernel. + Description: test for ApplyRotaryPosEmbExt ops. + Expectation:should pass for all testcases. + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + net = ApplyRotaryEmbV3Net() + run_rope_interleave_split(net, exec_mode, query_dtype, layout, rotary_mode, tokens, head_num_q, head_num_k, head_dim, rotary_dim)