diff --git a/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc
new file mode 100644
index 0000000000000000000000000000000000000000..65aa36b42a268b13bcb114d736681f11fdf3e2d8
--- /dev/null
+++ b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.cc
@@ -0,0 +1,281 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// =============================================================================
+// GRAPH MODE IMPLEMENTATION
+// =============================================================================
+
+#include <map>
+#include <set>
+#include <string>
+#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h"
+#include "ops/framework/utils.h"
+
+namespace ms_custom_ops {
+enum KvRmsNormRopeCacheInputIndex : size_t {
+  kKvRmsNormRopeCacheKvIndex = 0,
+  kKvRmsNormRopeCacheGammaIndex,
+  kKvRmsNormRopeCacheCosIndex,
+  kKvRmsNormRopeCacheSinIndex,
+  kKvRmsNormRopeCacheIdxIndex,
+  kKvRmsNormRopeCacheKCacheIndex,
+  kKvRmsNormRopeCacheCKvCacheIndex,
+  kKvRmsNormRopeCacheKRopeScaleIndex,
+  kKvRmsNormRopeCacheCKvScaleIndex,
+  kKvRmsNormRopeCacheKRopeOffsetIndex,
+  kKvRmsNormRopeCacheCKvOffsetIndex,
+  kKvRmsNormRopeCacheEpsilonIndex,
+  kKvRmsNormRopeCacheCacheModeIndex,
+  kKvRmsNormRopeCacheIsOutputKvIndex,
+  kKvRmsNormRopeCacheInputNums,
+};
+
+enum KvRmsNormRopeCacheSize : size_t {
+  kKvRmsNormRopeCacheSize0 = 0,
+  kKvRmsNormRopeCacheSize1,
+  kKvRmsNormRopeCacheSize2,
+  kKvRmsNormRopeCacheSize3,
+  kKvRmsNormRopeCacheSize4,
+  kKvRmsNormRopeCacheSize5,
+};
+
+enum KvRmsNormRopeCacheModeOutput : size_t {
+  kKvRmsNormRopeCacheKCacheOutIndex = 0,
+  kKvRmsNormRopeCacheCKvCacheOutIndex,
+  kKvRmsNormRopeCacheKRopeOutIndex,
+  kKvRmsNormRopeCacheCKvOutIndex,
+  kKvRmsNormRopeCacheOutNums,
+};
+
+enum KvRmsNormRopeCacheMode : size_t {
+  kKvRmsNormRopeCacheModeNorm = 0,
+  kKvRmsNormRopeCacheModePA,
+  kKvRmsNormRopeCacheModePA_BNSD,
+  kKvRmsNormRopeCacheModePA_NZ,
+  kKvRmsNormRopeCacheModePA_BLK_BNSD,
+  kKvRmsNormRopeCacheModePA_BLK_NZ,
+};
+
+static std::map<KvRmsNormRopeCacheMode, std::string> KvRmsNormRopeCacheModeMap = {
+  {kKvRmsNormRopeCacheModeNorm, "Norm"},
+  {kKvRmsNormRopeCacheModePA, "PA"},
+  {kKvRmsNormRopeCacheModePA_BNSD, "PA_BNSD"},
+  {kKvRmsNormRopeCacheModePA_NZ, "PA_NZ"},
+  {kKvRmsNormRopeCacheModePA_BLK_BNSD, "PA_BLK_BNSD"},
+  {kKvRmsNormRopeCacheModePA_BLK_NZ, "PA_BLK_NZ"},
+};
+
+static std::string get_kv_rmsnorm_rope_cache_mode(KvRmsNormRopeCacheMode cache_mode) {
+  auto it = KvRmsNormRopeCacheModeMap.find(cache_mode);
+  if (it == KvRmsNormRopeCacheModeMap.end()) {
+    MS_EXCEPTION(ValueError)
+      << "For kv_rmsnorm_rope_cache, the cache mode should be Norm/PA/PA_BNSD/PA_NZ/PA_BLK_BNSD/PA_BLK_NZ, but got: "
+      << cache_mode;
+  }
+  return it->second;
+}
+
+static constexpr size_t kKvRmsNormRopeCacheNormOutNums = 2;
+
+class OPS_API KvRmsNormRopeCacheCustomOpFuncImpl : public OpFuncImpl {
+ public:
+  ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
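+    // Output order is {k_cache_out, c_kv_cache_out, k_rope_out, c_kv_out}; the last axes of the
+    // two optional outputs come from cos (Dk) and gamma (Dv). When is_output_kv is false the
+    // framework still expects four entries, so empty shapes are returned for the last two.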
+    MS_EXCEPTION_IF_NULL(primitive);
+
+    auto kv_shape = input_infos[kKvRmsNormRopeCacheKvIndex]->GetShape();
+    auto k_cache_shape = input_infos[kKvRmsNormRopeCacheKCacheIndex]->GetShape();
+    auto c_kv_cache_shape = input_infos[kKvRmsNormRopeCacheCKvCacheIndex]->GetShape();
+    auto gamma_shape = input_infos[kKvRmsNormRopeCacheGammaIndex]->GetShape();
+    auto cos_shape = input_infos[kKvRmsNormRopeCacheCosIndex]->GetShape();
+    auto dv = gamma_shape.back();
+    auto dk = cos_shape.back();
+    auto is_output_kv = input_infos[kKvRmsNormRopeCacheIsOutputKvIndex]->GetScalarValueWithCheck<bool>();
+    auto k_rope_shape = kv_shape;
+    auto c_kv_out_shape = kv_shape;
+    k_rope_shape[k_rope_shape.size() - 1] = dk;
+    c_kv_out_shape[c_kv_out_shape.size() - 1] = dv;
+    ShapeVector empty_shape;
+    if (input_infos[kKvRmsNormRopeCacheKvIndex]->IsDynamicRank()) {
+      if (is_output_kv) {
+        return {k_cache_shape, c_kv_cache_shape, k_rope_shape, c_kv_out_shape};
+      } else {
+        return {k_cache_shape, c_kv_cache_shape, empty_shape, empty_shape};
+      }
+    }
+
+    // Dynamic-shape or static-shape path: ranks are known, so validate them.
+    if (kv_shape.size() != kKvRmsNormRopeCacheSize4) {
+      MS_LOG(EXCEPTION) << "For kv_rmsnorm_rope_cache, the rank of kv should be " << kKvRmsNormRopeCacheSize4
+                        << ", but got " << kv_shape.size();
+    }
+
+    if (gamma_shape.size() != kKvRmsNormRopeCacheSize1) {
+      MS_LOG(EXCEPTION) << "For kv_rmsnorm_rope_cache, the rank of gamma should be " << kKvRmsNormRopeCacheSize1
+                        << ", but got " << gamma_shape.size();
+    }
+    auto sin_shape = input_infos[kKvRmsNormRopeCacheSinIndex]->GetShape();
+    if ((sin_shape.size() != kKvRmsNormRopeCacheSize4) || (cos_shape.size() != kKvRmsNormRopeCacheSize4)) {
+      MS_LOG(EXCEPTION) << "For kv_rmsnorm_rope_cache, the rank of cos and sin should be " << kKvRmsNormRopeCacheSize4
+                        << ", but got cos: " << cos_shape.size() << ", sin: " << sin_shape.size();
+    }
+    auto cache_mode = static_cast<KvRmsNormRopeCacheMode>(
+      input_infos[kKvRmsNormRopeCacheCacheModeIndex]->GetScalarValueWithCheck<int64_t>());
+    auto index_shape = input_infos[kKvRmsNormRopeCacheIdxIndex]->GetShape();
+    if (cache_mode == kKvRmsNormRopeCacheModeNorm) {
+      if (index_shape.size() != kKvRmsNormRopeCacheSize2) {
+        MS_LOG(EXCEPTION) << "For kv_rmsnorm_rope_cache, the rank of index should be " << kKvRmsNormRopeCacheSize2
+                          << " in Norm mode, but got " << index_shape.size();
+      }
+    } else {
+      if (index_shape.size() != kKvRmsNormRopeCacheSize1) {
+        MS_LOG(EXCEPTION) << "For kv_rmsnorm_rope_cache, the rank of index should be " << kKvRmsNormRopeCacheSize1
+                          << " in PA modes, but got " << index_shape.size();
+      }
+    }
+
+    if ((k_cache_shape.size() != kKvRmsNormRopeCacheSize4) || (c_kv_cache_shape.size() != kKvRmsNormRopeCacheSize4)) {
+      MS_LOG(EXCEPTION) << "For kv_rmsnorm_rope_cache, the rank of k_cache and c_kv_cache should be "
+                        << kKvRmsNormRopeCacheSize4 << ", but got k_cache: " << k_cache_shape.size()
+                        << ", c_kv_cache: " << c_kv_cache_shape.size();
+    }
+    if (is_output_kv) {
+      return {k_cache_shape, c_kv_cache_shape, k_rope_shape, c_kv_out_shape};
+    } else {
+      return {k_cache_shape, c_kv_cache_shape, empty_shape, empty_shape};
+    }
+  }
+
+  std::vector<TypeId> InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto op_name = primitive->name();
+    auto kv_dtype = input_infos[kKvRmsNormRopeCacheKvIndex]->GetType();
+    const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeBFloat16};
+    CheckAndConvertUtils::CheckTypeIdValid("kv", kv_dtype, valid_types, op_name);
+    auto k_cache_dtype = input_infos[kKvRmsNormRopeCacheKCacheIndex]->GetType();
+    auto c_kv_cache_dtype = input_infos[kKvRmsNormRopeCacheCKvCacheIndex]->GetType();
+    // k_rope_out and c_kv_out always report the kv dtype; they are only materialized
+    // when is_output_kv is true.
+    return {k_cache_dtype, c_kv_cache_dtype, kv_dtype, kv_dtype};
+  }
+
+  bool GeneralInferRegistered() const override { return true; }
+};
+
+class KvRmsNormRopeCacheCustomAscend : public AclnnCustomKernelMod {
+ public:
+  KvRmsNormRopeCacheCustomAscend() : AclnnCustomKernelMod("aclnnKvRmsNormRopeCache") {}
+  ~KvRmsNormRopeCacheCustomAscend() = default;
+
+  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
+              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
+    MS_EXCEPTION_IF_NULL(stream_ptr);
+    RunOp(stream_ptr, workspace, inputs[kKvRmsNormRopeCacheKvIndex], inputs[kKvRmsNormRopeCacheGammaIndex],
+          inputs[kKvRmsNormRopeCacheCosIndex], inputs[kKvRmsNormRopeCacheSinIndex], inputs[kKvRmsNormRopeCacheIdxIndex],
+          inputs[kKvRmsNormRopeCacheKCacheIndex], inputs[kKvRmsNormRopeCacheCKvCacheIndex],
+          inputs[kKvRmsNormRopeCacheKRopeScaleIndex], inputs[kKvRmsNormRopeCacheCKvScaleIndex],
+          inputs[kKvRmsNormRopeCacheKRopeOffsetIndex], inputs[kKvRmsNormRopeCacheCKvOffsetIndex], epsilon_,
+          cache_mode_str_, is_output_kv_, outputs[kKvRmsNormRopeCacheKRopeOutIndex],
+          outputs[kKvRmsNormRopeCacheCKvOutIndex]);
+    return true;
+  }
+
+  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
+                        const std::vector<KernelTensor *> &outputs) override {
+    cache_mode_ = device::ascend::ConvertKernelTensor<int64_t>(inputs[kKvRmsNormRopeCacheCacheModeIndex]);
+    cache_mode_str_ = get_kv_rmsnorm_rope_cache_mode(static_cast<KvRmsNormRopeCacheMode>(cache_mode_));
+    epsilon_ =
+      static_cast<double>(device::ascend::ConvertKernelTensor<float>(inputs[kKvRmsNormRopeCacheEpsilonIndex]));
+    is_output_kv_ = device::ascend::ConvertKernelTensor<bool>(inputs[kKvRmsNormRopeCacheIsOutputKvIndex]);
+    GetWorkspaceForResize(inputs[kKvRmsNormRopeCacheKvIndex], inputs[kKvRmsNormRopeCacheGammaIndex],
+                          inputs[kKvRmsNormRopeCacheCosIndex], inputs[kKvRmsNormRopeCacheSinIndex],
+                          inputs[kKvRmsNormRopeCacheIdxIndex], inputs[kKvRmsNormRopeCacheKCacheIndex],
+                          inputs[kKvRmsNormRopeCacheCKvCacheIndex], inputs[kKvRmsNormRopeCacheKRopeScaleIndex],
+                          inputs[kKvRmsNormRopeCacheCKvScaleIndex], inputs[kKvRmsNormRopeCacheKRopeOffsetIndex],
+                          inputs[kKvRmsNormRopeCacheCKvOffsetIndex], epsilon_, cache_mode_str_, is_output_kv_,
+                          outputs[kKvRmsNormRopeCacheKRopeOutIndex], outputs[kKvRmsNormRopeCacheCKvOutIndex]);
+  }
+
+ private:
+  DEFINE_GET_WORKSPACE_FOR_RESIZE();
+  double epsilon_{1e-5};
+  int64_t cache_mode_{0};
+  bool is_output_kv_{false};
+  std::string cache_mode_str_;
+};
+}  // namespace ms_custom_ops
+
+REG_GRAPH_MODE_OP(kv_rmsnorm_rope_cache, ms_custom_ops::KvRmsNormRopeCacheCustomOpFuncImpl,
+                  ms_custom_ops::KvRmsNormRopeCacheCustomAscend);
+
+// =============================================================================
+// PYBOOST MODE IMPLEMENTATION
+// =============================================================================
+
+namespace ms_custom_ops {
+using namespace mindspore;
+using namespace mindspore::device::ascend;
+
+std::vector<ms::Tensor> kv_rmsnorm_rope_cache_custom(
+  const ms::Tensor &kv, const ms::Tensor &gamma, const ms::Tensor &cos, const ms::Tensor &sin,
+  const ms::Tensor &index, const ms::Tensor &k_cache, const ms::Tensor &c_kv_cache,
+  const std::optional<ms::Tensor> &k_rope_scale, const std::optional<ms::Tensor> &c_kv_scale,
+  const std::optional<ms::Tensor> &k_rope_offset, const std::optional<ms::Tensor> &c_kv_offset, const float epsilon,
+  const int64_t cache_mode, const bool is_output_kv) {
+  auto kv_shape = kv.shape();
+  auto dv = gamma.shape().back();
+  auto dk = cos.shape().back();
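+  // If is_output_kv is false both shapes stay empty, so the two extra outputs handed to
+  // the runner are zero-sized placeholders rather than real buffers.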
+  ShapeVector k_rope_shape;
+  ShapeVector c_kv_out_shape;
+
+  if (is_output_kv) {
+    k_rope_shape = kv_shape;
+    c_kv_out_shape = kv_shape;
+    k_rope_shape[k_rope_shape.size() - 1] = dk;
+    c_kv_out_shape[c_kv_out_shape.size() - 1] = dv;
+  }
+
+  std::vector<ms::Tensor> outputs = {
+    ms::Tensor(kv.data_type(), k_rope_shape),
+    ms::Tensor(kv.data_type(), c_kv_out_shape),
+  };
+  auto runner = std::make_shared<ms::pynative::PyboostRunner>("aclnnKvRmsNormRopeCache");
+  auto cache_mode_str = get_kv_rmsnorm_rope_cache_mode(static_cast<KvRmsNormRopeCacheMode>(cache_mode));
+  runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(
+    aclnnKvRmsNormRopeCache, kv, gamma, cos, sin, index, k_cache, c_kv_cache, k_rope_scale, c_kv_scale, k_rope_offset,
+    c_kv_offset, static_cast<double>(epsilon), cache_mode_str, is_output_kv, outputs[0], outputs[1]));
+  // The runner only tracks tensors; the scalar attributes are already bound in the launch func.
+  runner->Run(
+    {
+      kv,
+      gamma,
+      cos,
+      sin,
+      index,
+      k_cache,
+      c_kv_cache,
+      GetTensorOrEmpty(k_rope_scale),
+      GetTensorOrEmpty(c_kv_scale),
+      GetTensorOrEmpty(k_rope_offset),
+      GetTensorOrEmpty(c_kv_offset),
+    },
+    outputs);
+  return {k_cache, c_kv_cache, outputs[0], outputs[1]};
+}
+}  // namespace ms_custom_ops
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  m.def("kv_rmsnorm_rope_cache",
+        PYBOOST_CALLER(ms_custom_ops::kKvRmsNormRopeCacheOutNums, ms_custom_ops::kv_rmsnorm_rope_cache_custom));
+}
\ No newline at end of file
diff --git a/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.md b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.md
new file mode 100644
index 0000000000000000000000000000000000000000..06161c10b15a6b44fa1ed3b384b6f0ca86c041c7
--- /dev/null
+++ b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.md
@@ -0,0 +1,270 @@
+# kv_rmsnorm_rope_cache operator
+
+## Description
+
+kv_rmsnorm_rope_cache splits the last axis of the input tensor kv: the left part (Dv) goes through RMSNorm and the right part (Dk) goes through RoPE, after which the two results are scattered into two pre-allocated caches. Internally the operator calls aclnnKvRmsNormRopeCache.
+
+## Inputs
+
+| Name | DType | Shape | Optional | Inplace | Format | Description |
+|---------------|--------|-----------------------------------------------|----------|---------|--------|-------------|
+| kv            | Tensor | 4-D [B_kv, N, S_kv, D]                        | No       | No      | ND     | Input from which the RMSNorm part (Dv) and the RoPE part (Dk) are split |
+| gamma         | Tensor | 1-D [D_v]                                     | No       | No      | ND     | Scale weights for the RMSNorm computation |
+| cos           | Tensor | 4-D [B_kv, 1, S_kv, D_k] or [B_kv, 1, 1, D_k] | No       | No      | ND     | Cosine table for the RoPE computation |
+| sin           | Tensor | 4-D [B_kv, 1, S_kv, D_k] or [B_kv, 1, 1, D_k] | No       | No      | ND     | Sine table for the RoPE computation |
+| index         | Tensor | multi-dim, see constraints                    | No       | No      | ND     | Cache positions to write; a value of -1 skips the update |
+| k_cache       | Tensor | 4-D, see constraints                          | No       | No      | ND     | Pre-allocated cache for the RoPE result |
+| c_kv_cache    | Tensor | 4-D, see constraints                          | No       | No      | ND     | Pre-allocated cache for the RMSNorm result |
+| k_rope_scale  | Tensor | multi-dim, see constraints                    | Yes      | No      | ND     | Quantization scale when k_cache is int8 |
+| c_kv_scale    | Tensor | multi-dim, see constraints                    | Yes      | No      | ND     | Quantization scale when c_kv_cache is int8 |
+| k_rope_offset | Tensor | multi-dim, see constraints                    | Yes      | No      | ND     | Quantization offset for asymmetric quantization when k_cache is int8 |
+| c_kv_offset   | Tensor | multi-dim, see constraints                    | Yes      | No      | ND     | Quantization offset for asymmetric quantization when c_kv_cache is int8 |
+| epsilon       | float  | -                                             | No       | No      | -      | Epsilon added in RMSNorm to avoid division by zero |
+| cache_mode    | int    | -                                             | No       | No      | -      | Cache layout enum; see constraints for the meaning of each value |
+| is_output_kv  | bool   | -                                             | No       | No      | -      | Whether k_rope_out and c_kv_out are also produced |
+
+Note:
+
+Shape constraints
+
++ kv: shape [B_kv, N, S_kv, D], where D = D_k + D_v
++ gamma: shape [D_v]
++ cos/sin: shape [B_kv, 1, S_kv, D_k] or [B_kv, 1, 1, D_k]; sin must match cos
++ index:
+  + cache_mode "Norm": 2-D [B_kv, S_kv]
+  + cache_mode "PA_BNSD"/"PA_NZ": 1-D [B_kv * S_kv]
+  + cache_mode "PA_BLK_BNSD"/"PA_BLK_NZ": 1-D [B_kv * ceil_div(S_kv, BlockSize)]
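+
+As a quick illustration of the index shapes above (a minimal sketch; the helper name is illustrative, not part of the API):
+
+```python
+import math
+
+def expected_index_shape(cache_mode, b_kv, s_kv, block_size=None):
+    """Illustrative only: expected shape of `index` for each cache_mode."""
+    if cache_mode == "Norm":
+        return (b_kv, s_kv)                            # one slot per (batch, position)
+    if "BLK" in cache_mode:                            # PA_BLK_BNSD / PA_BLK_NZ
+        return (b_kv * math.ceil(s_kv / block_size),)  # one slot per block
+    return (b_kv * s_kv,)                              # PA / PA_BNSD / PA_NZ
+
+assert expected_index_shape("PA_BNSD", b_kv=64, s_kv=1) == (64,)
+```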
+
+Dtype constraints
+
++ kv/gamma/cos/sin: FLOAT16 and BFLOAT16 are supported
++ index: INT64
++ caches: either the same dtype as kv, or INT8
+
+Value constraints
+
++ N: only N = 1 is supported (tied to the DeepSeekV3 network structure)
++ D_k: must be even (required by the RoPE rule)
+
+Alignment requirements
+
++ In NZ modes, D_k and D_v must be 32B-aligned
++ In PA modes, BlockSize must be 32B-aligned
++ The alignment requirement differs per dtype (INT8 requires 32B alignment, FLOAT16 requires 16B alignment)
+
+index constraints
+
++ "Norm" mode: values lie in [-1, S_cache); values may repeat across different B_kv
++ "PA_BNSD"/"PA_NZ" modes: values lie in [-1, BlockNum * BlockSize) and must not repeat
++ "PA_BLK_BNSD"/"PA_BLK_NZ" modes: values lie in [-1, BlockNum * BlockSize) and value/BlockSize must not repeat
+
+cache_mode enum values:
+
++ "Norm" => 0
++ "PA" => 1
++ "PA_BNSD" => 2
++ "PA_NZ" => 3
++ "PA_BLK_BNSD" => 4
++ "PA_BLK_NZ" => 5
++ Any other value is invalid and is rejected.
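+
+In Python the mode strings map to the integer `cache_mode` argument as follows (a minimal sketch mirroring the enum; the dict name is ours):
+
+```python
+# Mirrors the cache_mode enum accepted by ms_custom_ops.kv_rmsnorm_rope_cache.
+CACHE_MODE = {"Norm": 0, "PA": 1, "PA_BNSD": 2, "PA_NZ": 3, "PA_BLK_BNSD": 4, "PA_BLK_NZ": 5}
+```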
+
+## Outputs
+
+| Name | DType | Shape | Description |
+|----------------|--------|----------------------|-------------|
+| k_cache_out    | Tensor | same as k_cache      | cache updated with the RoPE result |
+| c_kv_cache_out | Tensor | same as c_kv_cache   | cache updated with the RMSNorm result |
+| k_rope_out     | Tensor | [B_kv, N, S_kv, D_k] | RoPE result (produced when is_output_kv is True) |
+| c_kv_out       | Tensor | [B_kv, N, S_kv, D_v] | RMSNorm result (produced when is_output_kv is True) |
+
+k_cache_out has the same dtype and shape as k_cache.
+c_kv_cache_out has the same dtype and shape as c_kv_cache.
+k_rope_out has the same dtype as kv.
+c_kv_out has the same dtype as kv.
+
+For more details see: [aclnnKvRmsNormRopeCache](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha002/API/aolapi/context/aclnnKvRmsNormRopeCache.md)
+
+## Special notes
+
+## Usage example
+
+### Basic usage
+
+```python
+import numpy as np
+import mindspore as ms
+from mindspore import Tensor
+import ms_custom_ops
+
+ms.set_device("Ascend")
+
+def get_ms_dtype(np_dtype):
+    return {np.float16: ms.float16, np.float32: ms.float32, np.int64: ms.int64}[np_dtype]
+
+@ms.jit
+def kv_rmsnorm_rope_cache_func(kv, gamma, cos, sin, index, k_cache_ref, ckv_cache_ref,
+                               k_rope_scale=None, ckv_scale=None, k_rope_offset=None,
+                               c_kv_offset=None, epsilon=1e-5, cache_mode=0, is_output_kv=False):
+    return ms_custom_ops.kv_rmsnorm_rope_cache(kv, gamma, cos, sin, index, k_cache_ref, ckv_cache_ref,
+                                               k_rope_scale, ckv_scale, k_rope_offset, c_kv_offset,
+                                               epsilon, cache_mode, is_output_kv)
+
+def to_tensor(arr, dtype=None):
+    if arr is None:
+        return None
+    if dtype == np.int8:
+        return Tensor(arr, dtype=ms.int8)
+    return Tensor(arr, dtype=get_ms_dtype(dtype))
+
+def generate_inputs(batch_size, seq_len, page_num, page_size, quant_mode, cache_mode,
+                    output_mode, input_dtype, epsilon):
+    # Base input tensors.
+    kv = np.random.randn(batch_size, 1, seq_len, 576).astype(input_dtype)
+    gamma = np.random.randn(512).astype(input_dtype)
+    cos = np.random.randn(batch_size, 1, seq_len, 64).astype(input_dtype)
+    sin = np.random.randn(batch_size, 1, seq_len, 64).astype(input_dtype)
+
+    # Cache-related variables.
+    k_cache = None
+    ckv_cache = None
+    index = None
+    k_rope_scale = None
+    c_kv_scale = None
+
+    # Cache mode handling.
+    if cache_mode != "Norm":
+        # Initial caches filled with 9s so updated slots are easy to spot.
+        k_cache = np.ones((page_num, page_size, 1, 64), dtype=input_dtype) * 9
+        ckv_cache = np.ones((page_num, page_size, 1, 512), dtype=input_dtype) * 9
+
+        # Index array.
+        if "BLK" in cache_mode:
+            total_blocks = batch_size * ((seq_len + page_size - 1) // page_size)
+            index = np.arange(0, total_blocks * page_size, page_size, dtype=np.int64)
+        else:
+            index = np.arange(0, batch_size * seq_len, 1, dtype=np.int64)
+
+    # Quantization handling.
+    if quant_mode == 1:
+        if k_cache is not None:
+            k_cache = k_cache.astype(np.int8)
+        if ckv_cache is not None:
+            ckv_cache = ckv_cache.astype(np.int8)
+        k_rope_scale = np.random.randn(64).astype(np.float32)
+        c_kv_scale = np.random.randn(512).astype(np.float32)
+
+    # Rescale the random data (kept identical to the reference test inputs).
+    kv = 8 * kv - 10
+    gamma = 990 * gamma - 1000
+    sin = 0.02 * sin - 0.01
+    return (kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+            cache_mode, output_mode, input_dtype, epsilon)
+
+def get_kv_rmsnorm_rope_cache_mode_enum(cache_mode_str):
+    modes = {"Norm": 0, "PA": 1, "PA_BNSD": 2, "PA_NZ": 3, "PA_BLK_BNSD": 4, "PA_BLK_NZ": 5}
+    return modes[cache_mode_str]
+
+batch_size = 64
+seq_len = 1
+page_num = 576
+page_size = 128
+quant_mode = 0
+cache_mode = "PA_BNSD"
+output_mode = False
+input_dtype = np.float16
+epsilon = 1e-5
+(kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+ cache_mode, output_mode, _, _) = generate_inputs(batch_size, seq_len, page_num, page_size,
+                                                  quant_mode, cache_mode, output_mode,
+                                                  input_dtype, epsilon)
+
+kv_tensor = to_tensor(kv, input_dtype)
+gamma_tensor = to_tensor(gamma, input_dtype)
+cos_tensor = to_tensor(cos, input_dtype)
+sin_tensor = to_tensor(sin, input_dtype)
+index_tensor = to_tensor(index, np.int64)
+k_rope_scale_tensor = None
+c_kv_scale_tensor = None
+if quant_mode == 1:
+    k_cache_tensor = to_tensor(k_cache, np.int8)
+    ckv_cache_tensor = to_tensor(ckv_cache, np.int8)
+    k_rope_scale_tensor = to_tensor(k_rope_scale, np.float32)
+    c_kv_scale_tensor = to_tensor(c_kv_scale, np.float32)
+else:
+    k_cache_tensor = to_tensor(k_cache, input_dtype)
+    ckv_cache_tensor = to_tensor(ckv_cache, input_dtype)
+
+cache_mode_enum = get_kv_rmsnorm_rope_cache_mode_enum(cache_mode)
+k_rope_offset = None
+c_kv_offset = None
+k_cache, ckv_cache, k_rope, c_kv = kv_rmsnorm_rope_cache_func(
+    kv_tensor, gamma_tensor, cos_tensor, sin_tensor, index_tensor, k_cache_tensor, ckv_cache_tensor,
+    k_rope_scale_tensor, c_kv_scale_tensor, k_rope_offset, c_kv_offset, epsilon, cache_mode_enum, output_mode)
+print(k_cache)
+print(ckv_cache)
+```
diff --git a/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache_op.yaml b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache_op.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..98fd958aa873383b62261c5570ff966dd194d7ca
--- /dev/null
+++ b/ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache_op.yaml
@@ -0,0 +1,53 @@
+#operator kv_rmsnorm_rope_cache
+kv_rmsnorm_rope_cache:
+  args:
+    kv:
+      dtype: tensor
+    gamma:
+      dtype: tensor
+    cos:
+      dtype: tensor
+    sin:
+      dtype: tensor
+    index:
+      dtype: tensor
+    k_cache:
+      dtype: tensor
+    ckv_cache:
+      dtype: tensor
+    k_rope_scale:
+      dtype: tensor
+      default: None
+    c_kv_scale:
+      dtype: tensor
+      default: None
+    k_rope_offset:
+      dtype: tensor
+      default: None
+    c_kv_offset:
+      dtype: tensor
+      default: None
+    epsilon:
+      dtype: float
+      default: 1e-5
+    cache_mode:
+      dtype: int
+      default: 0
+    is_output_kv:
+      dtype: bool
+      default: False
+  args_signature:
+    rw_write: k_cache, ckv_cache
+  labels:
+    side_effect_mem: True
+  returns:
+    k_cache_out:
+      dtype: tensor
+      inplace: k_cache
+    ckv_cache_out:
+      dtype: tensor
+      inplace: ckv_cache
+    k_rope_out:
+      dtype: tensor
+    c_kv_out:
+      dtype: tensor
\ No newline at end of file
diff --git a/tests/st/test_custom_kv_rmsnorm_rope_cache.py b/tests/st/test_custom_kv_rmsnorm_rope_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c682d57fedd9ae3566068a4b504ebd465b09ff1
--- /dev/null
+++ b/tests/st/test_custom_kv_rmsnorm_rope_cache.py
@@ -0,0 +1,657 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import copy
+
+import numpy as np
+import pytest
+
+import mindspore as ms
+from mindspore import context, Tensor
+from mindspore.common.np_dtype import bfloat16
+import ms_custom_ops
+
+
+def get_ms_dtype(query_dtype):
+    if query_dtype == np.float32:
+        return ms.float32
+    if query_dtype == np.float16:
+        return ms.float16
+    if query_dtype == bfloat16:
+        return ms.bfloat16
+    if query_dtype == np.int64:
+        return ms.int64
+    raise ValueError(f"unsupported dtype: {query_dtype}")
+
+
+def get_kv_rmsnorm_rope_cache_mode_enum(cache_mode_str):
+    if cache_mode_str == "PA":
+        return 1
+    if cache_mode_str == "PA_BNSD":
+        return 2
+    if cache_mode_str == "PA_NZ":
+        return 3
+    if cache_mode_str == "PA_BLK_BNSD":
+        return 4
+    if cache_mode_str == "PA_BLK_NZ":
+        return 5
+    return 0  # "Norm"
+
+
+def generate_inputs(batch_size, seq_len, page_num, page_size, quant_mode, cache_mode,
+                    output_mode, input_dtype, epsilon):
+    # Base input tensors.
+    kv = np.random.randn(batch_size, 1, seq_len, 576).astype(input_dtype)
+    gamma = np.random.randn(512).astype(input_dtype)
+    cos = np.random.randn(batch_size, 1, seq_len, 64).astype(input_dtype)
+    sin = np.random.randn(batch_size, 1, seq_len, 64).astype(input_dtype)
+
+    # Cache-related variables.
+    k_cache = None
+    ckv_cache = None
+    index = None
+    k_rope_scale = None
+    c_kv_scale = None
+
+    # Cache mode handling.
+    if cache_mode != "Norm":
+        # Initial caches filled with 9s so updated slots are easy to spot.
+        k_cache = np.ones((page_num, page_size, 1, 64), dtype=input_dtype) * 9
+        ckv_cache = np.ones((page_num, page_size, 1, 512), dtype=input_dtype) * 9
+
+        # Index array.
+        if "BLK" in cache_mode:
+            total_blocks = batch_size * ((seq_len + page_size - 1) // page_size)
+            index = np.arange(0, total_blocks * page_size, page_size, dtype=np.int64)
+        else:
+            index = np.arange(0, batch_size * seq_len, 1, dtype=np.int64)
+
+    # Quantization handling.
+    if quant_mode == 1:
+        if k_cache is not None:
+            k_cache = k_cache.astype(np.int8)
+        if ckv_cache is not None:
+            ckv_cache = ckv_cache.astype(np.int8)
+        k_rope_scale = np.random.randn(64).astype(np.float32)
+        c_kv_scale = np.random.randn(512).astype(np.float32)
+
+    # Rescale the random data (same transforms as the original golden data).
+    kv = 8 * kv - 10
+    gamma = 990 * gamma - 1000
+    sin = 0.02 * sin - 0.01
+
+    return (kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+            cache_mode, output_mode, input_dtype, epsilon)
+
+
+def generate_inputs_mindspore(batch_size, seq_len, page_num, page_size, quant_mode, cache_mode,
+                              output_mode, input_dtype, epsilon):
+    # Generate the same data with the NumPy helper.
+    results = generate_inputs(batch_size, seq_len, page_num, page_size, quant_mode, cache_mode,
+                              output_mode, input_dtype, epsilon)
+
+    # Unpack the results.
+    (kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+     cache_mode, output_mode, dtype, _) = results
+
+    def to_tensor(arr, dtype=None):
+        if arr is None:
+            return None
+        if dtype == np.int8:
+            return Tensor(arr, dtype=ms.int8)
+        return Tensor(arr, dtype=get_ms_dtype(dtype))
+
+    kv_tensor = to_tensor(kv, dtype)
+    gamma_tensor = to_tensor(gamma, dtype)
+    cos_tensor = to_tensor(cos, dtype)
+    sin_tensor = to_tensor(sin, dtype)
+    index_tensor = to_tensor(index, np.int64)
+    k_rope_scale_tensor = None
+    c_kv_scale_tensor = None
+    if quant_mode == 1:
+        k_cache_tensor = to_tensor(k_cache, np.int8)
+        ckv_cache_tensor = to_tensor(ckv_cache, np.int8)
+        k_rope_scale_tensor = to_tensor(k_rope_scale, np.float32)
+        c_kv_scale_tensor = to_tensor(c_kv_scale, np.float32)
+    else:
+        k_cache_tensor = to_tensor(k_cache, dtype)
+        ckv_cache_tensor = to_tensor(ckv_cache, dtype)
+
+    return (kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+            kv_tensor, gamma_tensor, cos_tensor, sin_tensor, index_tensor,
+            k_cache_tensor, ckv_cache_tensor, k_rope_scale_tensor, c_kv_scale_tensor,
+            cache_mode, output_mode, input_dtype, epsilon)
+
+
+def supported_op_exec_numpy(kv, gamma, cos, sin, index, k_cache, ckv_cache,
+                            k_rope_scale=None, c_kv_scale=None, k_rope_offset=None,
+                            c_kv_offset=None, epsilon=1e-05, cache_mode="Norm",
+                            is_output_kv=False):
+    """
+    NumPy reference implementation of aclnnKvRmsNormRopeCache.
+
+    Args:
+        kv: input tensor [batch_size, N, seq_len, D]
+        gamma: RMSNorm scale [Dv]
+        cos: RoPE cosine part [batch_size, 1, seq_len, Dk] or [batch_size, 1, 1, Dk]
+        sin: RoPE sine part [batch_size, 1, seq_len, Dk] or [batch_size, 1, 1, Dk]
+        index: index tensor, shape depends on cache_mode
+        k_cache: k cache (input/output)
+        ckv_cache: ckv cache (input/output)
+        k_rope_scale: quantization scale for k (optional)
+        c_kv_scale: quantization scale for ckv (optional)
+        k_rope_offset: quantization offset for k (optional)
+        c_kv_offset: quantization offset for ckv (optional)
+        epsilon: small constant for RMSNorm
+        cache_mode: cache mode ("Norm", "PA", "PA_BNSD", "PA_NZ", "PA_BLK_NZ", "PA_BLK_BNSD")
+        is_output_kv: whether to also return the intermediate results
+    """
+
+    # Helpers.
+    def round_float_to_int8(src_array):
+        """Round a float array and clip it into the int8 range."""
+        rounded_array = np.round(src_array)
+        clipped_array = np.clip(rounded_array, -128, 127)
+        return clipped_array.astype(np.int8)
+
+    def rotate_half(x):
+        """Rotate half of the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return np.concatenate((-x2, x1), axis=-1)
+
+    # Input shapes.
+    batch_size, _, seq_len, D = kv.shape
+    Dv = gamma.shape[0]  # RMSNorm part
+    Dk = D - Dv  # RoPE part
+
+    # Quantization mode; d0 is the NZ fractal width in elements.
+    quant_mode = 0 if k_rope_scale is None else 1
+    d0 = 32 if quant_mode == 1 else 16
+
+    # Cache shape for the PA modes.
+    if "PA" in cache_mode:
+        block_num, block_size, _, _ = k_cache.shape
+
+    # Index page length for the BLK modes.
+    if "BLK" in cache_mode:
+        index_page_id_length = index.shape[0] // batch_size
+
+    # Split the input.
+    rms_in = kv[..., :Dv].astype(np.float32)  # RMSNorm part
+    rope_in = kv[..., Dv:].astype(np.float32)  # RoPE part
+
+    # RMSNorm: x / sqrt(mean(x^2) + eps) * gamma.
+    square_x = np.square(rms_in)
+    mean_square_x = np.mean(square_x, axis=-1, keepdims=True)
+    rms = np.sqrt(mean_square_x + epsilon)
+    gamma_f32 = gamma.astype(np.float32)
+    y = rms_in / rms * gamma_f32
+
+    # RoPE: reshape the input to match the rotation layout.
+    k = rope_in.reshape(batch_size, 1, seq_len, Dk // 2, 2)
+    k = np.transpose(k, (0, 1, 2, 4, 3))  # swap the last two axes
+    k = k.reshape(batch_size, 1, seq_len, Dk)
+
+    # Broadcast cos and sin to the full sequence length.
+    if cos.shape[2] == 1:  # [batch_size, 1, 1, Dk]
+        cos_broadcast = np.tile(cos, (1, 1, seq_len, 1))
+    else:  # [batch_size, 1, seq_len, Dk]
+        cos_broadcast = cos
+
+    if sin.shape[2] == 1:  # [batch_size, 1, 1, Dk]
+        sin_broadcast = np.tile(sin, (1, 1, seq_len, 1))
+    else:  # [batch_size, 1, seq_len, Dk]
+        sin_broadcast = sin
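+    # The reshape/transpose above de-interleaves the (even, odd) channel pairs so that
+    # rotate_half's "first half / second half" split matches the interleaved RoPE layout
+    # the kernel consumes.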
+    # Apply the rotary position embedding.
+    k_embed = k * cos_broadcast.astype(np.float32) + rotate_half(k) * sin_broadcast.astype(np.float32)
+
+    # Keep copies of the intermediate results in the input dtype.
+    k_embed_out = copy.deepcopy(k_embed).astype(kv.dtype)
+    y_out = copy.deepcopy(y).astype(kv.dtype)
+
+    # Prepare the data that goes into the caches.
+    if quant_mode == 1:
+        # Quantize.
+        if k_rope_offset is not None:
+            k_embed = (k_embed - k_rope_offset) * k_rope_scale
+        else:
+            k_embed = k_embed * k_rope_scale
+        k_embed = round_float_to_int8(k_embed)
+
+        if c_kv_offset is not None:
+            y = (y - c_kv_offset) * c_kv_scale
+        else:
+            y = y * c_kv_scale
+        y = round_float_to_int8(y)
+    else:
+        # Cast to the cache dtype.
+        k_embed = k_embed.astype(k_cache.dtype)
+        y = y.astype(ckv_cache.dtype)
+
+    # Update the caches according to the cache mode.
+    if cache_mode == "Norm":
+        # Norm mode: write straight into the cache.
+        for b in range(batch_size):
+            for s in range(seq_len):
+                pos = index[b, s]
+                if pos != -1:  # -1 means skip the update
+                    k_cache[b, 0, pos] = k_embed[b, 0, s]
+                    ckv_cache[b, 0, pos] = y[b, 0, s]
+
+    elif cache_mode in ("PA", "PA_BNSD"):
+        # PA and PA_BNSD modes.
+        k_cache_flat = k_cache.reshape(-1, Dk)
+        ckv_cache_flat = ckv_cache.reshape(-1, Dv)
+
+        for b in range(batch_size):
+            for s in range(seq_len):
+                offset = index[b * seq_len + s]
+                if offset >= 0:
+                    k_cache_flat[offset] = k_embed[b, 0, s]
+                    ckv_cache_flat[offset] = y[b, 0, s]
+
+        # Restore the cache shapes.
+        k_cache[:] = k_cache_flat.reshape(block_num, block_size, 1, Dk)
+        ckv_cache[:] = ckv_cache_flat.reshape(block_num, block_size, 1, Dv)
+
+    elif cache_mode == "PA_BLK_NZ":
+        # PA_BLK_NZ mode.
+        k_cache_reshaped = k_cache.reshape(block_num, 1, -1, block_size, d0)
+        ckv_cache_reshaped = ckv_cache.reshape(block_num, 1, -1, block_size, d0)
+
+        for b in range(batch_size):
+            for s in range(seq_len):
+                index_page_id = s // block_size
+                page_offset = index[b * index_page_id_length + index_page_id]
+
+                if page_offset >= 0:
+                    page_id = page_offset // block_size
+                    token_offset = s % block_size
+
+                    # Update the caches.
+                    k_embed_reshaped = k_embed[b, 0, s].reshape(-1, d0)
+                    k_cache_reshaped[page_id, 0, :, token_offset] = k_embed_reshaped
+
+                    y_reshaped = y[b, 0, s].reshape(-1, d0)
+                    ckv_cache_reshaped[page_id, 0, :, token_offset] = y_reshaped
+
+        # Restore the cache shapes.
+        k_cache[:] = k_cache_reshaped.reshape(block_num, block_size, 1, Dk)
+        ckv_cache[:] = ckv_cache_reshaped.reshape(block_num, block_size, 1, Dv)
+
+    elif cache_mode == "PA_BLK_BNSD":
+        # PA_BLK_BNSD mode.
+        k_cache_reshaped = k_cache.reshape(block_num, block_size, 1, -1)
+        ckv_cache_reshaped = ckv_cache.reshape(block_num, block_size, 1, -1)
+
+        for b in range(batch_size):
+            for s in range(seq_len):
+                index_page_id = s // block_size
+                page_offset = index[b * index_page_id_length + index_page_id]
+
+                if page_offset >= 0:
+                    page_id = page_offset // block_size
+                    token_offset = s % block_size
+
+                    # Update the caches.
+                    k_cache_reshaped[page_id, token_offset, 0] = k_embed[b, 0, s]
+                    ckv_cache_reshaped[page_id, token_offset, 0] = y[b, 0, s]
+
+        # Restore the cache shapes.
+        k_cache[:] = k_cache_reshaped.reshape(block_num, block_size, 1, Dk)
+        ckv_cache[:] = ckv_cache_reshaped.reshape(block_num, block_size, 1, Dv)
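+    # In the NZ ("fractal") layout the last axis is stored in tiles of d0 elements
+    # (32 for int8 caches, 16 for fp16), so the cache is viewed as
+    # (block_num, 1, D/d0, block_size, d0) and one token writes a (D/d0, d0) slice.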
+    elif cache_mode == "PA_NZ":
+        # PA_NZ mode (d0 was already derived from the quantization mode above).
+        k_cache_reshaped = k_cache.reshape(block_num, 1, -1, block_size, d0)
+        ckv_cache_reshaped = ckv_cache.reshape(block_num, 1, -1, block_size, d0)
+
+        for b in range(batch_size):
+            for s in range(seq_len):
+                page_offset = index[b * seq_len + s]
+                if page_offset >= 0:
+                    page_id = page_offset // block_size
+                    token_offset = page_offset % block_size
+
+                    # Update the caches.
+                    k_embed_reshaped = k_embed[b, 0, s].reshape(-1, d0)
+                    k_cache_reshaped[page_id, 0, :, token_offset] = k_embed_reshaped
+
+                    y_reshaped = y[b, 0, s].reshape(-1, d0)
+                    ckv_cache_reshaped[page_id, 0, :, token_offset] = y_reshaped
+
+        # Restore the cache shapes.
+        k_cache[:] = k_cache_reshaped.reshape(block_num, block_size, 1, Dk)
+        ckv_cache[:] = ckv_cache_reshaped.reshape(block_num, block_size, 1, Dv)
+
+    # Results.
+    if is_output_kv:
+        return k_cache, ckv_cache, k_embed_out, y_out
+    return k_cache, ckv_cache
+
+
+class KvRmsNormRopeCacheNet(ms.nn.Cell):
+    def __init__(self):
+        super().__init__()
+
+    def construct(self, kv, gamma, cos, sin, index, k_cache_ref, ckv_cache_ref,
+                  k_rope_scale=None, ckv_scale=None, k_rope_offset=None, c_kv_offset=None,
+                  epsilon=1e-5, cache_mode=0, is_output_kv=False):
+        return ms_custom_ops.kv_rmsnorm_rope_cache(kv, gamma, cos, sin, index, k_cache_ref, ckv_cache_ref,
+                                                   k_rope_scale, ckv_scale, k_rope_offset, c_kv_offset,
+                                                   epsilon, cache_mode, is_output_kv)
+
+
+def run(net, batch_size, seq_len, page_num, page_size, quant_mode, cache_mode,
+        output_mode, input_dtype, epsilon):
+    (kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+     kv_tensor, gamma_tensor, cos_tensor, sin_tensor, index_tensor,
+     k_cache_tensor, ckv_cache_tensor, k_rope_scale_tensor, c_kv_scale_tensor,
+     cache_mode, output_mode, input_dtype, epsilon) = generate_inputs_mindspore(
+        batch_size, seq_len, page_num, page_size, quant_mode, cache_mode,
+        output_mode, input_dtype, epsilon)
+    k_rope_offset = None
+    c_kv_offset = None
+    k_rope_golden = None
+    ckv_out_golden = None
+    if not output_mode:
+        k_cache_golden, ckv_cache_golden = supported_op_exec_numpy(
+            kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+            k_rope_offset, c_kv_offset, epsilon, cache_mode, is_output_kv=output_mode)
+    else:
+        k_cache_golden, ckv_cache_golden, k_rope_golden, ckv_out_golden = supported_op_exec_numpy(
+            kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale,
+            k_rope_offset, c_kv_offset, epsilon, cache_mode, is_output_kv=output_mode)
+
+    cache_mode_enum = get_kv_rmsnorm_rope_cache_mode_enum(cache_mode)
+    k_cache_out, ckv_cache_out, k_rope, c_kv = net(
+        kv_tensor, gamma_tensor, cos_tensor, sin_tensor, index_tensor, k_cache_tensor, ckv_cache_tensor,
+        k_rope_scale_tensor, c_kv_scale_tensor, k_rope_offset, c_kv_offset, epsilon, cache_mode_enum, output_mode)
+    k_cache_np = k_cache_out.asnumpy()
+    ckv_cache_np = ckv_cache_out.asnumpy()
+
+    if input_dtype == np.float16:
+        atol = 0.01
+        rtol = 0.01
+    else:
+        atol = 1e-5
+        rtol = 1e-5
+
+    np.testing.assert_allclose(k_cache_golden, k_cache_np, rtol=rtol, atol=atol)
+    np.testing.assert_allclose(ckv_cache_golden, ckv_cache_np, rtol=rtol, atol=atol)
+    if output_mode:
+        np.testing.assert_allclose(k_rope_golden, k_rope.asnumpy(), rtol=rtol, atol=atol)
+        np.testing.assert_allclose(ckv_out_golden, c_kv.asnumpy(), rtol=rtol, atol=atol)
+
+
+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_ascend910b
+@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize("batch_size", [64])
+@pytest.mark.parametrize("seq_len", [1])
+@pytest.mark.parametrize("page_num", [576])
+@pytest.mark.parametrize("page_size", [128])
+@pytest.mark.parametrize("quant_mode", [0])
+@pytest.mark.parametrize(
+    "cache_mode",
+    ["PA_BLK_NZ", "PA_BLK_BNSD", "PA_NZ", "PA_BNSD"],
+)
+@pytest.mark.parametrize("output_mode", [False])
+@pytest.mark.parametrize("input_dtype", [np.float16])
+def test_kv_rmsnorm_rope_cache(exec_mode, batch_size, seq_len, page_num, page_size,
+                               quant_mode, cache_mode, output_mode, input_dtype):
+    """
+    Feature: kv_rmsnorm_rope_cache custom op.
+    Description: compare the Ascend kernel against the NumPy reference for the PA cache modes.
+    Expectation: outputs match the golden results within tolerance.
+    """
+    epsilon = 1e-5
+    ms.set_context(device_target="Ascend", mode=exec_mode)
+    net = KvRmsNormRopeCacheNet()
+    run(net, batch_size, seq_len, page_num, page_size, quant_mode, cache_mode,
+        output_mode, input_dtype, epsilon)