From 67315253c463c799d7bc38b7f69760a425c22b82 Mon Sep 17 00:00:00 2001 From: tianxiaodong3 Date: Thu, 16 Oct 2025 16:54:58 +0800 Subject: [PATCH 1/2] support custom paged_attention op within ms_custom_ops --- .jenkins/check/config/filter_cpplint.txt | 2 + docs/map_from_buildin_to_custom.md | 1 + docs/op_list.md | 1 + ops/c_api/paged_attention/paged_attention.md | 139 + .../paged_attention/paged_attention_common.h | 120 + .../paged_attention/paged_attention_graph.cc | 885 ++++++ .../paged_attention/paged_attention_op.yaml | 88 + .../paged_attention_pynative.cc | 229 ++ .../pyboost/internal_pyboost_runner.h | 3 +- tests/st/test_custom_paged_attention.py | 2793 +++++++++++++++++ tests/st/test_custom_paged_attention_nz.py | 1500 +++++++++ 11 files changed, 5759 insertions(+), 2 deletions(-) create mode 100644 ops/c_api/paged_attention/paged_attention.md create mode 100644 ops/c_api/paged_attention/paged_attention_common.h create mode 100644 ops/c_api/paged_attention/paged_attention_graph.cc create mode 100644 ops/c_api/paged_attention/paged_attention_op.yaml create mode 100644 ops/c_api/paged_attention/paged_attention_pynative.cc create mode 100644 tests/st/test_custom_paged_attention.py create mode 100644 tests/st/test_custom_paged_attention_nz.py diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt index e69de29..c80b002 100644 --- a/.jenkins/check/config/filter_cpplint.txt +++ b/.jenkins/check/config/filter_cpplint.txt @@ -0,0 +1,2 @@ +# ms_custom_ops +"ms_custom_ops/ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h "build/include_subdir" \ No newline at end of file diff --git a/docs/map_from_buildin_to_custom.md b/docs/map_from_buildin_to_custom.md index 40cbe07..7c381d1 100644 --- a/docs/map_from_buildin_to_custom.md +++ b/docs/map_from_buildin_to_custom.md @@ -11,6 +11,7 @@ | ops.auto_generate.paged_cache_load | [ms_custom_ops.paged_cache_load](../ops/c_api/paged_cache_load/paged_cache_load_doc.md) | 新增支持key、value支持不同dtype;取消inplace更新的输出key、value,直接改为输出 | | ops.auto_generate.quant_batch_matmul | [ms_custom_ops.quant_batch_matmul](../ops/c_api/quant_batch_matmul/quant_batch_matmul.md) | 新增了x2_format参数,用于指定x2的format; 入参名称`pertokenScaleOptional`修改为`pertoken_scale`; 入参名称`dtype`修改为`output_dtype` | | ops.auto_generate.apply_rotary_pos_emb | [ms_custom_ops.apply_rotary_pos_emb_atb](../ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) | 新增atb的apply_rotary_pos_emb_atb算子,代替ops.auto_generate.apply_rotary_pos_emb,注意rotary_coeff和cos_format有变化,详见[API](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) | +| ops.auto_generate.PagedAttention | [ms_custom_ops.paged_attention](../ops/c_api/paged_attention/paged_attention.md) | 接口统一封装 ATB PagedAttentionOperation,支持 ND/NZ、量化、MTP、MLA等特性,具体约束见文档 | | ops.moe_token_unpermute | [ms_custom_ops.moe_token_unpermute](../ops/c_api/moe_token_unpermute/moe_token_unpermute.md) | 接口参数一致, 但需注意:ops接口只支持A2训练芯片,ms_custom_ops场景下只支持Atlas推理系列产品, 并且ms_custom_ops场景下当前仅支持:`padded_mode = false, restore_shape = None`, topK 支持 1、2,、4、8, hidden_size 支持 2048、5120、7168。 | | ops.auto_generate.GroupedMatmul | [ms_custom_ops.grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul.md) | 接口变更:输入从tuple[tensor]改为tensor,group_list前移到weight之后并改为必传参数,移除了offset、antiquant_offset、split_item、group_type等参数,返回从tuple[tensor]改为tensor,基于Internal框架实现,仅支持 Atlas 推理系列 | | ops.auto_generate.GroupedMatmulV4 | 
[ms_custom_ops.grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul.md) | 接口变更:输入从tuple[tensor]改为tensor,group_list前移到weight之后并改为必传参数,per_token_scale改为per_token_scale,移除了offset、antiquant_offset、activation相关参数、split_item、group_type、group_list_type、act_type、output_dtype等参数,返回从tuple[tensor]改为tensor,增加transpose_a、transpose_b参数,基于Internal框架实现,仅支持 Atlas 推理系列 | \ No newline at end of file diff --git a/docs/op_list.md b/docs/op_list.md index a7a8ed8..39abb44 100644 --- a/docs/op_list.md +++ b/docs/op_list.md @@ -19,6 +19,7 @@ 1. [moe_init_routing_v2](../ops/c_api/moe_init_routing_v2/moe_init_routing_v2.md) 1. [moe_token_unpermute](../ops/c_api/moe_token_unpermute/moe_token_unpermute.md) 1. [paged_cache_load](../ops/c_api/paged_cache_load/paged_cache_load_doc.md) +1. [paged_attention](../ops/c_api/paged_attention/paged_attention.md) 1. [quant_batch_matmul](../ops/c_api/quant_batch_matmul/quant_batch_matmul.md) 1. [reshape_and_cache](../ops/c_api/reshape_and_cache/reshape_and_cache.md) 1. [reshape_and_cache_npd](../ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.md) diff --git a/ops/c_api/paged_attention/paged_attention.md b/ops/c_api/paged_attention/paged_attention.md new file mode 100644 index 0000000..b51903d --- /dev/null +++ b/ops/c_api/paged_attention/paged_attention.md @@ -0,0 +1,139 @@ +# PagedAttention + +## 描述 + +PagedAttention 是针对增量推理/分页 KV-Cache 的注意力算子,支持未量化、反量化融合、QKV 全量化(Offline/Online),多 Token 推理(MTP)、多种掩码、不同输入布局与 ND/NZ 输入格式。 + +## 接口与输入输出 + +### 名称 + +- 算子名:`paged_attention` + +### 输入参数 + +| Name | DType | Shape | Optional (Default) | Format | Description | +|--------------------|---------------------------------------|-----------------------------------------|--------------------|------------|-------------| +| query | float16/bfloat16 或 int8 | ND: [num_tokens, q_head_num, head_size];Atlas 推理系列产品/NZ: [num_tokens, q_head_num*head_size] | No | ND/NZ | 查询张量;全量化时为 int8。| +| key_cache | float16/bfloat16 或 int8 | ND: [num_blocks, block_size, kv_head_num, head_size_k];NZ: [num_blocks, block_size, kv_head_num*head_size_k]| No | ND/NZ | 分页 KV 的 K 缓存;量化场景为 int8。 | +| value_cache | float16/bfloat16 或 int8 | ND: [num_blocks, block_size, kv_head_num, head_size_v];NZ: [num_blocks, block_size, kv_head_num*head_size_v] | No | ND/NZ | 分页 KV 的 V 缓存;量化场景为 int8。 | +| block_tables | int32 | [num_tokens, max_num_blocks_per_query] | No | ND | 每个 token 对应的 block 索引表。 | +| context_lens | int32 | [batch] | No | ND | 每 batch 的 KV token 数。Atlas A2 训练系列产品/Atlas A2 推理系列产品和Atlas A3 推理系列产品/Atlas A3 训练系列产品:作为 param 从 CPU 读取;Atlas 推理系列产品:保持为 NPU 张量输入。 | +| attn_mask | float16/bfloat16 | 较复杂,参考用例 | Yes (None) | ND/NZ | 注意力掩码;不同 `mask_type` 对应不同构造。 | +| batch_run_status | int32 | [batch] | Yes (None) | ND | 控制可计算 batch 的标志;需与 `batch_run_status_enable` 配合。 | +| k_descale | float32 或 int64 | [kv_head_num*head_size] | Yes (None) | ND | 反量化融合/全量化。 | +| k_offset | int32 | [kv_head_num*head_size] | Yes (None) | ND | 非对称反量化偏移,`has_quant_offset=True` 时使用。 | +| v_descale | float32 或 int64 | [kv_head_num*head_size] | Yes (None) | ND | 反量化融合/全量化步长。 | +| v_offset | int32 | [kv_head_num*head_size] | Yes (None) | ND | 非对称反量化偏移,`has_quant_offset=True` 时使用。 | +| razor_offset | float32 | [num_blocks, block_size] | Yes (None) | ND | Razor Rope 场景偏移。 | +| p_scale | float32 | [q_head_num] | Yes (None) | ND | 离线全量化时传入 P 矩阵量化 scale。 | +| log_n | float32 | [batch] | Yes (None) | ND | `scale_type=LOGN` 时为各 batch 缩放系数。 | +| q_seq_lens | int32 | [batch] | Cond (None) | ND(CPU) | 并行解码/MTP(`calc_type=1`)时必需,始终按 param 绑定(CPU)。 | +| 
q_head_num | int | - | Yes (0) | - | Q 头数;不允许为 0。 | +| qk_scale | float | - | Yes (1.0) | - | QK 缩放(TOR)。 | +| kv_head_num | int | - | Yes (0) | - | KV 头数。 | +| mask_type | int | - | Yes (0) | - | 掩码类型,见下文枚举。 | +| batch_run_status_enable | bool | - | Yes (False) | - | 是否启用 `batch_run_status`。 | +| quant_type | int | - | Yes (0) | - | 量化类型,见下文枚举。 | +| out_data_type | int | - | Yes (-1) | - | 全量化输出类型:1=fp16,27=bf16。 | +| has_quant_offset | bool | - | Yes (False) | - | 是否使用非对称反量化偏置。 | +| compress_type | int | - | Yes (0) | - | 压缩类型,见下文枚举。 | +| calc_type | int | - | Yes (0) | - | 计算模式(MTP),见下文枚举。 | +| scale_type | int | - | Yes (0) | - | 缩放类型(TOR/LOGN)。 | +| input_layout | int | - | Yes (0) | - | 输入布局:0=BSND,1=BNSD。 | +| mla_v_dim | int | - | Yes (0) | - | MLA 输出 head 维度(0 关闭)。 | +| input_format | int | - | Yes (0) | - | 0=ND(Atlas A2 训练系列产品/Atlas A2 推理系列产品和Atlas A3 推理系列产品/Atlas A3 训练系列产品),1=NZ(Atlas 推理系列产品)。 | + +说明:`context_lens` 与 `q_seq_lens` 虽为输入张量,但以 param 方式绑定到 host;在 Atlas 推理系列产品/NZ 路径下 `context_lens` 保持为 NPU 输入张量,而 `q_seq_lens` 始终通过 CPU 侧 param 传入。 + +### 输出参数 + +| Name | DType | Shape | Description | +|----------------|-----------------------------------------|------------------------------------|-------------| +| attention_out | float16/bfloat16(或由 `out_data_type` 指定) | [num_tokens, q_head_num, head_size_out] | 注意力输出;当 `mla_v_dim>0` 时 `head_size_out=mla_v_dim`,否则与输入 `head_size` 一致。 | + +## 参数含义与枚举值 + +- mask_type(掩码类型) + - 0:PA_MASK_UNDEFINED + - 1:PA_MASK_TYPE_NORM(倒三角mask) + - 2:PA_MASK_TYPE_ALIBI + - 3:PA_MASK_TYPE_SPEC(并行解码mask) + - 4:PA_MASK_TYPE_MASK_FREE(仅 fp16 支持) +- quant_type(量化类型) + - 0:UNQUANT/UNDEFINED(未量化) + - 1:DEQUANT_FUSION(KV int8 反量化融合) + - 2:QUANT_QKV_OFFLINE(全量化离线) + - 3:QUANT_QKV_ONLINE(全量化在线) +- out_data_type(全量化输出类型) + - -1:未指定(默认) + - 1:float16 + - 27:bfloat16 +- compress_type(压缩类型) + - 0:UNDEFINED + - 1:KVHEAD + - 2:KVHEAD_ROPE + - 3:MAX(非法) +- calc_type(计算模式) + - 0:UNDEFINED(常规单 token 解码) + - 1:SPEC(MTP,多 Token 推理;需提供 `q_seq_lens`) +- scale_type(缩放类型) + - 0:TOR + - 1:LOGN + - 2:MAX +- input_layout(输入布局) + - 0:BSND + - 1:BNSD +- input_format(输入格式) + - 0:ND(Atlas A2 训练系列产品/Atlas A2 推理系列产品和Atlas A3 推理系列产品/Atlas A3 训练系列产品 默认) + - 1:NZ(Atlas 推理系列产品,需对 Q/K/V/Mask 使用 `trans_data(..., 1)` 转换) + +更多关于输入/输出约束说明,请参考 ATB 文档:[PagedAttentionOperation 输入输出列表](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/ascendtbapi/ascendtb_01_0197.html)。 + +## Python 使用示例 + +```python +import numpy as np +from mindspore import Tensor, context, ops +import ms_custom_ops + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +batch_size = 4 +head_num = 32 +kv_head_num = 32 +head_dim = 128 +block_size = 128 + +kv_seq_lens = [192, 193, 194, 195] +q_seq_lens = [1] * batch_size + +num_tokens = sum(q_seq_lens) +max_kv_len = max(kv_seq_lens) +max_blocks_per_query = (max_kv_len + block_size - 1) // block_size +num_blocks = batch_size * max_blocks_per_query + +query = Tensor(np.random.randn(num_tokens, head_num, head_dim).astype(np.float16)) +key_cache = Tensor(np.random.randn(num_blocks, block_size, kv_head_num, head_dim).astype(np.float16)) +value_cache = Tensor(np.random.randn(num_blocks, block_size, kv_head_num, head_dim).astype(np.float16)) +block_tables_np = np.stack([ + np.arange(i * max_blocks_per_query, (i + 1) * max_blocks_per_query, dtype=np.int32) + for i in range(batch_size) +]) +block_tables = Tensor(block_tables_np) +context_lens = Tensor(np.array(kv_seq_lens, dtype=np.int32)) + +context_lens_cpu = ops.move_to(context_lens, "CPU") + 
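+# 说明:ND 路径(Atlas A2 训练系列产品/Atlas A2 推理系列产品和 Atlas A3 训练/推理系列产品)下,
+# context_lens 以 host 侧 param 方式读取,因此示例中先用 ops.move_to 将其搬运到 CPU;
+# NZ 路径(Atlas 推理系列产品,input_format=1)下 context_lens 保持为 NPU 张量输入,无需该步骤。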
+qk_scale = float(1.0 / np.sqrt(head_dim)) +out = ms_custom_ops.paged_attention( + query, key_cache, value_cache, block_tables, context_lens_cpu, + attn_mask=None, q_seq_lens=None, + q_head_num=head_num, qk_scale=qk_scale, kv_head_num=kv_head_num, + mask_type=0, batch_run_status_enable=False, + quant_type=0, out_data_type=-1, has_quant_offset=False, + compress_type=0, calc_type=0, scale_type=0, + input_layout=0, mla_v_dim=0, input_format=0 +) +print(out.shape) # (num_tokens, head_num, head_dim) +``` diff --git a/ops/c_api/paged_attention/paged_attention_common.h b/ops/c_api/paged_attention/paged_attention_common.h new file mode 100644 index 0000000..d0a6be1 --- /dev/null +++ b/ops/c_api/paged_attention/paged_attention_common.h @@ -0,0 +1,120 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_CUSTOM_OPS_OPS_C_API_PAGED_ATTENTION_PAGED_ATTENTION_COMMON_H_ +#define MS_CUSTOM_OPS_OPS_C_API_PAGED_ATTENTION_PAGED_ATTENTION_COMMON_H_ + +#include + +namespace ms_custom_ops { + +enum PagedAttentionInputIndex : int32_t { + kPagedAttentionInputQueryIndex = 0, // 0 + kPagedAttentionInputKeyCacheIndex, // 1 + kPagedAttentionInputValueCacheIndex, // 2 + kPagedAttentionInputBlockTablesIndex, // 3 + kPagedAttentionInputContextLensIndex, // 4 + kPagedAttentionInputAttnMaskIndex, // 5 + kPagedAttentionInputBatchRunStatusIndex, // 6 + kPagedAttentionInputKDescalekIndex, // 7 + kPagedAttentionInputKOffsetIndex, // 8 + kPagedAttentionInputVDescaleIndex, // 9 + kPagedAttentionInputVOffsetIndex, // 10 + kPagedAttentionInputRazorOffsetIndex, // 11 + kPagedAttentionInputPScaleIndex, // 12 + kPagedAttentionInputLogNIndex, // 13 + kPagedAttentionInputQSeqLenIndex, // 14 + kPagedAttentionInputQHeadNumIndex, // 15 + kPagedAttentionInputQKScaleIndex, // 16 + kPagedAttentionInputKVHeadNumIndex, // 17 + kPagedAttentionInputMaskTypeIndex, // 18 + kPagedAttentionInputBatchRunStatusEnableIndex, // 19 + kPagedAttentionInputQuantTypeIndex, // 20 + kPagedAttentionInputOutDataTypeIndex, // 21 + kPagedAttentionInputHasQuantOffsetIndex, // 22 + kPagedAttentionInputCompressTypeIndex, // 23 + kPagedAttentionInputCalcTypeIndex, // 24 + kPagedAttentionInputScaleTypeIndex, // 25 + kPagedAttentionInputInputLayoutIndex, // 26 + kPagedAttentionInputMlaVDimHeadSizeIndex, // 27 + kPagedAttentionInputInputFormatIndex, // 28 + kPagedAttentionInputsNum // 29 +}; + +enum PAOutputIndex : int32_t { + kPagedAttentionOutputIndex = 0, + kPagedAttentionOutputNum, +}; + +enum PAMaskType : int32_t { + kPA_MASK_UNDEFINED = 0, + kPA_MASK_TYPE_NORM, + kPA_MASK_TYPE_ALIBI, + kPA_MASK_TYPE_SPEC, + kPA_MASK_TYPE_MASK_FREE, +}; + +enum PAQuantType : int32_t { + kPA_TYPE_QUANT_UNDEFINED = 0, + kPA_TYPE_QUANT_UNQUANT = 0, + kPA_TYPE_DEQUANT_FUSION, + kPA_TYPE_QUANT_QKV_OFFLINE, + kPA_TYPE_QUANT_QKV_ONLINE, +}; + +enum PACompressType : int32_t { + kPA_COMPRESS_TYPE_UNDEFINED = 0, + kPA_COMPRESS_TYPE_KVHEAD, + kPA_COMPRESS_TYPE_KVHEAD_ROPE, + kPA_COMPRESS_TYPE_MAX, +}; + +enum 
PACalcType : int32_t { + kPA_CALC_TYPE_UNDEFINED = 0, + kPA_CALC_TYPE_SPEC, +}; + +enum PAOutDataType : int32_t { + kPA_ACL_DT_UNDEFINED = -1, + kPA_ACL_FLOAT16 = 1, + kPA_ACL_BF16 = 27, +}; + +enum PAInputLayout : int32_t { + kPA_INPUT_LAYOUT_BSND = 0, + kPA_INPUT_LAYOUT_BNSD = 1, +}; + +enum PAScaleType : int32_t { + kPA_SCALE_TYPE_TOR = 0, + kPA_SCALE_TYPE_LOGN, + kPA_SCALE_TYPE_MAX +}; + +enum PAInputFormat : int8_t { + kKVFormatND = 0, + kKVFormatNZ +}; + +static constexpr auto kPAQShapeRank = 3; +static constexpr auto kPAKVCacheRank = 4; +static constexpr auto kPAKVCacheRankAltas = 3; +static constexpr auto kPABlockTableRank = 2; +static constexpr auto kPAContextLenRank = 1; + +} // namespace ms_custom_ops + +#endif // MS_CUSTOM_OPS_OPS_C_API_PAGED_ATTENTION_PAGED_ATTENTION_COMMON_H_ diff --git a/ops/c_api/paged_attention/paged_attention_graph.cc b/ops/c_api/paged_attention/paged_attention_graph.cc new file mode 100644 index 0000000..838badb --- /dev/null +++ b/ops/c_api/paged_attention/paged_attention_graph.cc @@ -0,0 +1,885 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/c_api/paged_attention/paged_attention_common.h" + +#include +#include +#include +#include + +#include "ops/c_api/utils/attention_utils.h" +#include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" + +namespace ms_custom_ops { +namespace { + +inline bool IsKnownDim(int64_t dim) { + return dim != abstract::Shape::kShapeDimAny; +} + +} // namespace + +static void CheckHeadNumbers(int64_t q_head_num, int64_t kv_head_num) { + MS_CHECK_VALUE(q_head_num != 0, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the q_head_num should not be 0, but got 0.")); + MS_CHECK_VALUE(kv_head_num >= 0, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the kv_head_num should be greater than or equal to 0, but got ", + kv_head_num)); + if (kv_head_num != 0) { + MS_CHECK_VALUE( + q_head_num % kv_head_num == 0, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the q_head_num must be divisible by kv_head_num, but got q_head_num=", + q_head_num, ", kv_head_num=", kv_head_num)); + } +} + +static void CheckScaleTypeLogN(int64_t scale_type, int64_t quant_type, + int64_t calc_type, int64_t compress_type) { + MS_CHECK_VALUE( + (scale_type >= PAScaleType::kPA_SCALE_TYPE_TOR && + scale_type < PAScaleType::kPA_SCALE_TYPE_MAX), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the scale_type is invalid, got ", scale_type)); + + if (scale_type == PAScaleType::kPA_SCALE_TYPE_LOGN) { + MS_CHECK_VALUE((quant_type == PAQuantType::kPA_TYPE_QUANT_UNQUANT), + CheckAndConvertUtils::FormatCommMsg( + "In PA scale type logn mode, quant_type must be 0(TYPE_QUANT_UNQUANT/" + "TYPE_QUANT_UNDEFINED), but got ", quant_type)); + MS_CHECK_VALUE( + (calc_type == PACalcType::kPA_CALC_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "In PA scale type logn mode, calc_type feature is not supported, but got calc_type=", 
+ calc_type)); + MS_CHECK_VALUE( + (compress_type == PACompressType::kPA_COMPRESS_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "In PA scale type logn mode, compress_type feature is not supported, but got compress_type=", + compress_type)); + } +} + +static void CheckMLAVHeadSize(int64_t mla_v_head_size) { + constexpr int64_t kMaxMLAVHeadSize = 576; + MS_CHECK_VALUE(((mla_v_head_size >= 0 && mla_v_head_size <= kMaxMLAVHeadSize)), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention(MLA mode) the value head size should be [0, 576], but got ", + mla_v_head_size)); +} + +static void CheckInputLayoutAndCalcType(int64_t input_layout, int64_t calc_type, + int64_t quant_type, int64_t compress_type, + bool batch_run_status_enable) { + MS_CHECK_VALUE( + ((input_layout == PAInputLayout::kPA_INPUT_LAYOUT_BNSD || + input_layout == PAInputLayout::kPA_INPUT_LAYOUT_BSND)), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the input layout should be 0(BSND)/1(BNSD), but got ", + input_layout)); + + MS_CHECK_VALUE( + ((calc_type == PACalcType::kPA_CALC_TYPE_SPEC || + calc_type == PACalcType::kPA_CALC_TYPE_UNDEFINED)), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the calc_type should be 0(disable MTP)/1(enable MTP), but got ", + calc_type)); + + if (calc_type == PACalcType::kPA_CALC_TYPE_SPEC) { + MS_CHECK_VALUE( + (quant_type == PAQuantType::kPA_TYPE_QUANT_UNQUANT), + CheckAndConvertUtils::FormatCommMsg( + "In PA MTP scene, quant mode should be " + "0(TYPE_QUANT_UNQUANT/TYPE_QUANT_UNDEFINED), but now got ", + quant_type)); + MS_CHECK_VALUE( + (!batch_run_status_enable), + CheckAndConvertUtils::FormatCommMsg( + "In PA MTP scene, batch_run_status_enable should be false, but now got true.")); + MS_CHECK_VALUE( + (compress_type == PACompressType::kPA_COMPRESS_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "In PA MTP scene, compress_type should be 0(kPA_COMPRESS_TYPE_UNDEFINED), but now got ", + compress_type)); + } +} + +static void CheckBNSDLayout(int64_t input_layout, int64_t calc_type, + int64_t compress_type, int64_t quant_type, + int64_t scale_type) { + if (input_layout == PAInputLayout::kPA_INPUT_LAYOUT_BNSD) { + MS_CHECK_VALUE( + (calc_type == PACalcType::kPA_CALC_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention when input layout is BNSD, calc_type feature is not supported," + " but got ", calc_type)); + MS_CHECK_VALUE( + (compress_type == PACompressType::kPA_COMPRESS_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention when input layout is BNSD, compress_type feature is not supported," + " but got ", compress_type)); + MS_CHECK_VALUE( + (quant_type == PAQuantType::kPA_TYPE_QUANT_UNQUANT), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention when input layout is BNSD, quant_type must be 0(TYPE_QUANT_UNQUANT)," + " but got ", quant_type)); + MS_CHECK_VALUE( + (scale_type == PAScaleType::kPA_SCALE_TYPE_TOR), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention when input layout is BNSD, scale_type must be 0(kPA_SCALE_TYPE_TOR)," + " but got ", scale_type)); + } +} + +static void CheckCompressType(int64_t compress_type, int64_t quant_type, + bool batch_run_status_enable, int64_t mask_type) { + MS_CHECK_VALUE( + (compress_type != PACompressType::kPA_COMPRESS_TYPE_MAX), + CheckAndConvertUtils::FormatCommMsg( + "In PA compress scene, compress type should not be 3(kPA_COMPRESS_TYPE_MAX).")); + + if (compress_type == PACompressType::kPA_COMPRESS_TYPE_KVHEAD || + compress_type == 
PACompressType::kPA_COMPRESS_TYPE_KVHEAD_ROPE) { + MS_CHECK_VALUE( + (quant_type == PAQuantType::kPA_TYPE_QUANT_UNQUANT), + CheckAndConvertUtils::FormatCommMsg( + "In PA compress scene, quant_type must be 0(TYPE_QUANT_UNQUANT), but now got ", + quant_type)); + } + + if (compress_type == PACompressType::kPA_COMPRESS_TYPE_KVHEAD_ROPE) { + MS_CHECK_VALUE( + (!batch_run_status_enable), + CheckAndConvertUtils::FormatCommMsg( + "In PA COMPRESS_TYPE_KVHEAD_ROPE scene, batch_run_status_enable must be false, but now got true.")); + MS_CHECK_VALUE( + (mask_type == PAMaskType::kPA_MASK_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "In PA COMPRESS_TYPE_KVHEAD_ROPE scene, mask type should not be " + "0(PA_MASK_UNDEFINED), but now got ", mask_type)); + } +} + +static void CheckQuantType(int64_t quant_type, int64_t calc_type, + int64_t out_data_type, bool has_quant_offset, + int64_t compress_type) { + if (quant_type == PAQuantType::kPA_TYPE_DEQUANT_FUSION) { + MS_CHECK_VALUE( + (calc_type == PACalcType::kPA_CALC_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention when quant_type is DEQUANT_FUSION, calc_type feature is not supported," + " but got ", calc_type)); + } + + if (quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_OFFLINE || + quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_ONLINE) { + MS_CHECK_VALUE( + (out_data_type == PAOutDataType::kPA_ACL_FLOAT16 || + out_data_type == PAOutDataType::kPA_ACL_BF16), + CheckAndConvertUtils::FormatCommMsg( + "In PA full quant scene, out_data_type must be 1(Float16) or 27(BFloat16), but got ", + out_data_type)); + MS_CHECK_VALUE( + (!has_quant_offset), + CheckAndConvertUtils::FormatCommMsg( + "In PA full quant scene, has_quant_offset must be false, but now got true.")); + MS_CHECK_VALUE( + (compress_type == PACompressType::kPA_COMPRESS_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "In PA full quant scene, compress_type should be 0(kPA_COMPRESS_TYPE_UNDEFINED)," + " but now got ", compress_type)); + MS_CHECK_VALUE( + (calc_type == PACalcType::kPA_CALC_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "In PA full quant scene, calc_type feature is not supported, but got ", + calc_type)); + } +} + +static void CheckMLAMode(int64_t mla_v_head_size, int64_t mask_type, + int64_t compress_type, int64_t quant_type, + int64_t scale_type, int64_t input_layout, + int64_t kv_head_num, int64_t calc_type, + bool batch_run_status_enable) { + if (mla_v_head_size > 0) { + MS_CHECK_VALUE( + (mask_type != PAMaskType::kPA_MASK_TYPE_ALIBI), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA mode, mask type kPA_MASK_TYPE_ALIBI is not supported, but now got ", + mask_type)); + MS_CHECK_VALUE( + (compress_type == PACompressType::kPA_COMPRESS_TYPE_UNDEFINED), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA mode, compress_type should be 0(kPA_COMPRESS_TYPE_UNDEFINED), but now got ", + compress_type)); + MS_CHECK_VALUE( + (quant_type != PAQuantType::kPA_TYPE_DEQUANT_FUSION), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA mode, quant_type kPA_TYPE_DEQUANT_FUSION is not supported, but now got ", + quant_type)); + MS_CHECK_VALUE( + (scale_type == PAScaleType::kPA_SCALE_TYPE_TOR), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA mode, scale_type must be 0(kPA_SCALE_TYPE_TOR), but now got ", + scale_type)); + MS_CHECK_VALUE( + (input_layout == PAInputLayout::kPA_INPUT_LAYOUT_BSND), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA mode, input layout must be 0(BSND), but now got ", + input_layout)); + MS_CHECK_VALUE( + (kv_head_num == 
1), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA mode, kv_head_num should be 1 (MQA), but now got ", + kv_head_num)); + + if (calc_type == PACalcType::kPA_CALC_TYPE_SPEC) { + MS_CHECK_VALUE( + (mask_type != PAMaskType::kPA_MASK_TYPE_NORM), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA MTP scene, mask type kPA_MASK_TYPE_NORM is not supported, but now got ", + mask_type)); + MS_CHECK_VALUE( + (quant_type == PAQuantType::kPA_TYPE_QUANT_UNQUANT), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA MTP scene, quant_type must be 0(TYPE_QUANT_UNQUANT), but now got ", + quant_type)); + MS_CHECK_VALUE( + (!batch_run_status_enable), + CheckAndConvertUtils::FormatCommMsg( + "In PA MLA MTP scene, batch_run_status_enable should be false, but now got true.")); + } + } +} + +static void CheckParams(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) { + (void)primitive; + + // Extract all parameters + auto q_head_num = + input_infos[kPagedAttentionInputQHeadNumIndex]->GetScalarValueWithCheck(); + auto kv_head_num = + input_infos[kPagedAttentionInputKVHeadNumIndex]->GetScalarValueWithCheck(); + auto scale_type = + input_infos[kPagedAttentionInputScaleTypeIndex]->GetScalarValueWithCheck(); + auto quant_type = + input_infos[kPagedAttentionInputQuantTypeIndex]->GetScalarValueWithCheck(); + auto mla_v_head_size = + input_infos[kPagedAttentionInputMlaVDimHeadSizeIndex]->GetScalarValueWithCheck(); + auto input_layout = + input_infos[kPagedAttentionInputInputLayoutIndex]->GetScalarValueWithCheck(); + auto calc_type = + input_infos[kPagedAttentionInputCalcTypeIndex]->GetScalarValueWithCheck(); + auto compress_type = + input_infos[kPagedAttentionInputCompressTypeIndex]->GetScalarValueWithCheck(); + auto mask_type = + input_infos[kPagedAttentionInputMaskTypeIndex]->GetScalarValueWithCheck(); + auto batch_run_status_enable = + input_infos[kPagedAttentionInputBatchRunStatusEnableIndex]->GetScalarValueWithCheck(); + auto has_quant_offset = + input_infos[kPagedAttentionInputHasQuantOffsetIndex]->GetScalarValueWithCheck(); + auto out_data_type = + input_infos[kPagedAttentionInputOutDataTypeIndex]->GetScalarValueWithCheck(); + + // Perform validation checks by calling specialized sub-functions + CheckHeadNumbers(q_head_num, kv_head_num); + CheckScaleTypeLogN(scale_type, quant_type, calc_type, compress_type); + CheckMLAVHeadSize(mla_v_head_size); + CheckInputLayoutAndCalcType(input_layout, calc_type, quant_type, compress_type, + batch_run_status_enable); + CheckBNSDLayout(input_layout, calc_type, compress_type, quant_type, scale_type); + CheckCompressType(compress_type, quant_type, batch_run_status_enable, mask_type); + CheckQuantType(quant_type, calc_type, out_data_type, has_quant_offset, compress_type); + CheckMLAMode(mla_v_head_size, mask_type, compress_type, quant_type, scale_type, + input_layout, kv_head_num, calc_type, batch_run_status_enable); +} + +static void CheckKVCacheConsistency(int64_t mla_v_dim, + int64_t num_blocks_k, int64_t num_blocks_v, + int64_t block_size_k, int64_t block_size_v, + int64_t head_num_k, int64_t head_num_v) { + if (mla_v_dim == 0) { + if (IsKnownDim(num_blocks_k) && IsKnownDim(num_blocks_v)) { + MS_CHECK_VALUE( + num_blocks_k == num_blocks_v, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the num_blocks of key_cache and value_cache must be same, but got ", + num_blocks_k, " and ", num_blocks_v)); + } + if (IsKnownDim(block_size_k) && IsKnownDim(block_size_v)) { + MS_CHECK_VALUE( + block_size_k == block_size_v, + 
CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the block_size of key_cache and value_cache must be same, but got ", + block_size_k, " and ", block_size_v)); + } + if (IsKnownDim(head_num_k) && IsKnownDim(head_num_v)) { + MS_CHECK_VALUE( + head_num_k == head_num_v, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the head_num of key_cache and value_cache must be same, but got ", + head_num_k, " and ", head_num_v)); + } + } +} + +static void CheckQueryKVHeadSizeConsistency(int64_t head_size_k, int64_t head_size_q) { + if (IsKnownDim(head_size_k) && IsKnownDim(head_size_q)) { + MS_CHECK_VALUE( + head_size_k == head_size_q, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the head_size of key_cache and query must be same, but got ", + head_size_k, " and ", head_size_q)); + } +} + +static void CheckNonMLAHeadSizeLimits(int64_t mla_v_dim, int64_t head_size_k, int64_t block_size_k) { + constexpr int64_t kMaxHeadSize910B = 256; + constexpr int64_t kMaxHeadSizeProd = 128 * 128; + + if (mla_v_dim == 0 && IsKnownDim(head_size_k) && IsKnownDim(block_size_k)) { + MS_CHECK_VALUE( + head_size_k > 0 && head_size_k <= kMaxHeadSize910B, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention on ND path, head_size of key_cache must be in (0, 256], but got ", + head_size_k)); + MS_CHECK_VALUE( + block_size_k * head_size_k <= kMaxHeadSizeProd, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention on ND path, block_size * head_size must be <= 128 * 128, but got ", + block_size_k * head_size_k)); + } +} + +static void CheckMLAHeadSizeConstraints(int64_t mla_v_dim, int64_t head_size_k, + int64_t head_size_v, int64_t block_size_k) { + if (mla_v_dim > 0) { + if (IsKnownDim(head_size_k) && IsKnownDim(head_size_v)) { + constexpr int64_t kMaxMLAHeadSize = 576; + MS_CHECK_VALUE( + head_size_k <= kMaxMLAHeadSize && head_size_v <= kMaxMLAHeadSize, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention MLA mode, head_size of key_cache and value_cache must be <= 576, but got ", + head_size_k, " and ", head_size_v)); + } + if (IsKnownDim(head_size_k) && IsKnownDim(head_size_v) && IsKnownDim(block_size_k)) { + constexpr int64_t kHeadSizeThreshold = 256; + constexpr int64_t kBlockSizeLimit = 128; + if ((head_size_k > kHeadSizeThreshold || head_size_v > kHeadSizeThreshold) && + block_size_k > kBlockSizeLimit) { + MS_CHECK_VALUE( + false, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention MLA mode, when head_size > 256, block_size must be <= 128, but got " + "head_size_k=", head_size_k, ", head_size_v=", head_size_v, ", block_size=", block_size_k)); + } + } + } +} + +static void CheckKVCacheShapeND(const InferInfoPtrList &input_infos) { + auto query_shape = input_infos[kPagedAttentionInputQueryIndex]->GetShape(); + auto key_cache_shape = input_infos[kPagedAttentionInputKeyCacheIndex]->GetShape(); + auto value_cache_shape = input_infos[kPagedAttentionInputValueCacheIndex]->GetShape(); + auto mla_v_dim = + input_infos[kPagedAttentionInputMlaVDimHeadSizeIndex]->GetScalarValueWithCheck(); + + if (input_infos[kPagedAttentionInputQueryIndex]->IsDynamic() || + input_infos[kPagedAttentionInputKeyCacheIndex]->IsDynamic() || + input_infos[kPagedAttentionInputValueCacheIndex]->IsDynamic()) { + return; + } + if (query_shape.size() != kPAQShapeRank || key_cache_shape.size() != kPAKVCacheRank || + value_cache_shape.size() != kPAKVCacheRank) { + // Rank mismatch is handled in basic rank checks. 
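+    // Returning early keeps the fixed-index dim extraction below safe (query rank 3, caches rank 4).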
+ return; + } + + // Extract shape dimensions + auto num_blocks_k = key_cache_shape[0]; + auto block_size_k = key_cache_shape[1]; + auto head_num_k = key_cache_shape[2]; + auto head_size_k = key_cache_shape[3]; + + auto num_blocks_v = value_cache_shape[0]; + auto block_size_v = value_cache_shape[1]; + auto head_num_v = value_cache_shape[2]; + auto head_size_v = value_cache_shape[3]; + + auto head_size_q = query_shape[2]; + + // Perform validation checks by calling specialized sub-functions + CheckKVCacheConsistency(mla_v_dim, num_blocks_k, num_blocks_v, + block_size_k, block_size_v, head_num_k, head_num_v); + CheckQueryKVHeadSizeConsistency(head_size_k, head_size_q); + CheckNonMLAHeadSizeLimits(mla_v_dim, head_size_k, block_size_k); + CheckMLAHeadSizeConstraints(mla_v_dim, head_size_k, head_size_v, block_size_k); +} + +static void CheckMaskFreeShape310P(const InferInfoPtrList &input_infos, int64_t input_format) { + auto mask_type = + input_infos[kPagedAttentionInputMaskTypeIndex]->GetScalarValueWithCheck(); + if (mask_type != PAMaskType::kPA_MASK_TYPE_MASK_FREE) { + return; + } + + // Mask-free only supports 310P (NZ format). + if (input_format != PAInputFormat::kKVFormatNZ) { + MS_CHECK_VALUE( + false, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention, MASK_FREE is only supported on 310P (NZ format).")); + } + + const auto &mask_info = input_infos[kPagedAttentionInputAttnMaskIndex]; + if (!mask_info->IsDynamic()) { + auto mask_shape = mask_info->GetShape(); + constexpr size_t kMaskRank = 3; + MS_CHECK_VALUE( + mask_shape.size() == kMaskRank, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention MASK_FREE on 310P, the rank of mask must be 3, but got shape: ", + mask_shape)); + + constexpr int64_t kExpectedBatch = 1; + constexpr int64_t kExpectedBlockSize = 128; + + if (IsKnownDim(mask_shape[0])) { + MS_CHECK_VALUE( + mask_shape[0] == kExpectedBatch, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention MASK_FREE on 310P, mask dim[0] must be 1, but got ", mask_shape[0])); + } + if (IsKnownDim(mask_shape[1])) { + MS_CHECK_VALUE( + mask_shape[1] == kExpectedBlockSize, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention MASK_FREE on 310P, mask dim[1] must be 128, but got ", mask_shape[1])); + } + if (IsKnownDim(mask_shape[2])) { + MS_CHECK_VALUE( + mask_shape[2] == kExpectedBlockSize, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention MASK_FREE on 310P, mask dim[2] must be 128, but got ", mask_shape[2])); + } + } +} + +static void CheckShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) { + (void)primitive; + auto query_shape = input_infos[kPagedAttentionInputQueryIndex]->GetShape(); + auto key_cache_shape = input_infos[kPagedAttentionInputKeyCacheIndex]->GetShape(); + auto value_cache_shape = input_infos[kPagedAttentionInputValueCacheIndex]->GetShape(); + auto block_tables_shape = input_infos[kPagedAttentionInputBlockTablesIndex]->GetShape(); + auto context_len_shape = input_infos[kPagedAttentionInputContextLensIndex]->GetShape(); + // Input format: ND (910B) vs NZ (310P) + auto input_format = + input_infos[kPagedAttentionInputInputFormatIndex]->GetScalarValueWithCheck(); + + if (!input_infos[kPagedAttentionInputQueryIndex]->IsDynamic()) { + // 910B/ND: rank must be 3 (T, H, Dh) + // 310P/NZ: rank must be 2 (T, H*Dh) + auto q_rank = static_cast(query_shape.size()); + constexpr int64_t kPA310PQueryRank = 2; + if (input_format == PAInputFormat::kKVFormatNZ) { + MS_CHECK_VALUE( + q_rank == kPA310PQueryRank, + 
CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention (310P/NZ) the rank of query must be 2 (T, H*Dh), but got shape: ", + query_shape)); + } else { + MS_CHECK_VALUE( + q_rank == kPAQShapeRank, + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the rank of query must be ", kPAQShapeRank, + ", but got shape: ", query_shape)); + } + } + + if (!input_infos[kPagedAttentionInputKeyCacheIndex]->IsDynamic()) { + // 910B/ND: rank must be 4 (num_blocks, block_size, num_heads, head_size) + // 310P/NZ: rank must be 3 (num_blocks, block_size, num_heads*head_size) + auto expected_rank = (input_format == PAInputFormat::kKVFormatNZ + ? kPAKVCacheRankAltas + : kPAKVCacheRank); + MS_CHECK_VALUE( + key_cache_shape.size() == static_cast(expected_rank), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the rank of key_cache must be ", expected_rank, + " (3 for 310P/NZ, 4 for 910B/ND), but got shape: ", key_cache_shape)); + } + + if (!input_infos[kPagedAttentionInputValueCacheIndex]->IsDynamic()) { + auto expected_rank = (input_format == PAInputFormat::kKVFormatNZ + ? kPAKVCacheRankAltas + : kPAKVCacheRank); + MS_CHECK_VALUE( + value_cache_shape.size() == static_cast(expected_rank), + CheckAndConvertUtils::FormatCommMsg( + "For PagedAttention the rank of value_cache must be ", expected_rank, + " (3 for 310P/NZ, 4 for 910B/ND), but got shape: ", value_cache_shape)); + } + + if (!input_infos[kPagedAttentionInputBlockTablesIndex]->IsDynamic()) { + MS_CHECK_VALUE( + block_tables_shape.size() == kPABlockTableRank, + CheckAndConvertUtils::FormatCommMsg( + "For PA The rank of block table must be ", kPABlockTableRank, + ", but got shape: ", block_tables_shape)); + } + + if (!input_infos[kPagedAttentionInputContextLensIndex]->IsDynamic()) { + MS_CHECK_VALUE( + context_len_shape.size() == kPAContextLenRank, + CheckAndConvertUtils::FormatCommMsg( + "For PA The rank of context len must be ", kPAContextLenRank, + ", but got shape: ", context_len_shape)); + } + + // Head_size / block_size / MLA-specific shape checks for ND path (910B). 
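+  // On the NZ (310P) path the head dims are merged into the last cache axis, so the
+  // per-dimension head_size/block_size checks below only apply to the ND layout.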
+ if (input_format == PAInputFormat::kKVFormatND) { + CheckKVCacheShapeND(input_infos); + } + + CheckMaskFreeShape310P(input_infos, input_format); +} + +static void CheckType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) { + auto quant_type = + input_infos[kPagedAttentionInputQuantTypeIndex]->GetScalarValueWithCheck(); + auto query_dtype = input_infos[kPagedAttentionInputQueryIndex]->GetType(); + auto key_cache_dtype = input_infos[kPagedAttentionInputKeyCacheIndex]->GetType(); + auto value_cache_dtype = input_infos[kPagedAttentionInputValueCacheIndex]->GetType(); + if (quant_type == PAQuantType::kPA_TYPE_QUANT_UNQUANT) { + MS_CHECK_VALUE( + ((query_dtype == kNumberTypeFloat16) || (query_dtype == kNumberTypeBFloat16)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in unquant mode, query dtype must be float16/bfloat16, but got type: ", + query_dtype)); + MS_CHECK_VALUE( + ((key_cache_dtype == kNumberTypeFloat16) || (key_cache_dtype == kNumberTypeBFloat16)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in unquant mode, key cache dtype must be float16/bfloat16, but got type: ", + key_cache_dtype)); + MS_CHECK_VALUE( + ((value_cache_dtype == kNumberTypeFloat16) || + (value_cache_dtype == kNumberTypeBFloat16)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in unquant mode,value cache dtype must be float16/bfloat16, but got type: ", + value_cache_dtype)); + } else if (quant_type == PAQuantType::kPA_TYPE_DEQUANT_FUSION) { + MS_CHECK_VALUE( + ((query_dtype == kNumberTypeFloat16) || (query_dtype == kNumberTypeBFloat16)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in dequant mode, query dtype must be float16/bfloat16, but got type: ", + query_dtype)); + MS_CHECK_VALUE( + ((key_cache_dtype == kNumberTypeInt8)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in dequant mode, key cache dtype must be int8, but got type: ", + key_cache_dtype)); + MS_CHECK_VALUE( + ((value_cache_dtype == kNumberTypeInt8)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in dequant mode,value cache dtype must be int8, but got type: ", + value_cache_dtype)); + } else if (quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_OFFLINE || + quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_ONLINE) { + MS_CHECK_VALUE( + ((query_dtype == kNumberTypeInt8)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in full quant mode, query dtype must be int8, but got type: ", + query_dtype)); + MS_CHECK_VALUE( + ((key_cache_dtype == kNumberTypeInt8)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in full quant mode, key cache dtype must be int8, but got type: ", + key_cache_dtype)); + MS_CHECK_VALUE( + ((value_cache_dtype == kNumberTypeInt8)), + CheckAndConvertUtils::FormatCommMsg( + "For PA in full quant mode, value cache dtype must be int8, but got type: ", + value_cache_dtype)); + } +} + +class OPS_API PagedAttentionFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + if (input_infos.size() != kPagedAttentionInputsNum) { + MS_LOG(EXCEPTION) << "Paged Attention input args should be equal to " + << kPagedAttentionInputsNum + << ",but now get " << input_infos.size(); + } + + auto &query_info = input_infos[kPagedAttentionInputQueryIndex]; + auto query_shape = query_info->GetShape(); + + CheckParams(primitive, input_infos); + CheckShape(primitive, input_infos); + + if (query_info->IsDynamic() || + input_infos[kPagedAttentionInputKeyCacheIndex]->IsDynamic() || + 
input_infos[kPagedAttentionInputValueCacheIndex]->IsDynamic()) { + return {query_shape}; + } + + auto mla_v_dim = + input_infos[kPagedAttentionInputMlaVDimHeadSizeIndex]->GetScalarValueWithCheck(); + if (mla_v_dim == 0) { + return {query_shape}; + } + + auto key_cache_shape = input_infos[kPagedAttentionInputKeyCacheIndex]->GetShape(); + auto key_head_dim = key_cache_shape[key_cache_shape.size() - 1]; + if ((key_head_dim == abstract::Shape::kShapeDimAny) || + (query_shape[query_shape.size() - 1] == abstract::Shape::kShapeDimAny)) { + query_shape[query_shape.size() - 1] = abstract::Shape::kShapeDimAny; + } else { + query_shape[query_shape.size() - 1] = + query_shape[query_shape.size() - 1] / key_head_dim * mla_v_dim; + } + + return {query_shape}; + } + + std::vector InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + CheckType(primitive, input_infos); + auto quant_type = + input_infos[kPagedAttentionInputQuantTypeIndex]->GetScalarValueWithCheck(); + auto query_type = input_infos[kPagedAttentionInputQueryIndex]->GetType(); + if ((quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_ONLINE) || + (quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_OFFLINE)) { + auto out_data_type = + input_infos[kPagedAttentionInputOutDataTypeIndex]->GetScalarValueWithCheck(); + + switch (out_data_type) { + case PAOutDataType::kPA_ACL_FLOAT16: + return {kNumberTypeFloat16}; + case PAOutDataType::kPA_ACL_BF16: + return {kNumberTypeBFloat16}; + default: + MS_LOG(EXCEPTION) + << "In PA full quant scene, we should set the output data type:1(Float16) or 27(BFloat16)"; + } + } + + return {query_type}; + } + + bool GeneralInferRegistered() const override { return true; } + + std::set GetValueDependArgIndices() const override { + return { + kPagedAttentionInputContextLensIndex, + kPagedAttentionInputQSeqLenIndex, + kPagedAttentionInputQHeadNumIndex, + kPagedAttentionInputQKScaleIndex, + kPagedAttentionInputKVHeadNumIndex, + kPagedAttentionInputMaskTypeIndex, + kPagedAttentionInputBatchRunStatusEnableIndex, + kPagedAttentionInputQuantTypeIndex, + kPagedAttentionInputOutDataTypeIndex, + kPagedAttentionInputHasQuantOffsetIndex, + kPagedAttentionInputCompressTypeIndex, + kPagedAttentionInputCalcTypeIndex, + kPagedAttentionInputScaleTypeIndex, + kPagedAttentionInputInputLayoutIndex, + kPagedAttentionInputMlaVDimHeadSizeIndex, + kPagedAttentionInputInputFormatIndex + }; + } +}; + +class PagedAttention : public InternalKernelMod { + public: + PagedAttention() : InternalKernelMod() {} + ~PagedAttention() override = default; + + protected: + internal_v2::InternalOpPtr CreateKernel( + const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override { + param_.q_head_num = + static_cast( + ms_inputs[kPagedAttentionInputQHeadNumIndex]->GetValueWithCheck()); + param_.tor = ms_inputs[kPagedAttentionInputQKScaleIndex]->GetValueWithCheck(); + param_.kv_head_num = + static_cast( + ms_inputs[kPagedAttentionInputKVHeadNumIndex]->GetValueWithCheck()); + param_.mask_type = + static_cast( + ms_inputs[kPagedAttentionInputMaskTypeIndex]->GetValueWithCheck()); + param_.batch_run_status_enable = + ms_inputs[kPagedAttentionInputBatchRunStatusEnableIndex]->GetValueWithCheck(); + param_.quant_type = + static_cast( + ms_inputs[kPagedAttentionInputQuantTypeIndex]->GetValueWithCheck()); + param_.out_data_type = + static_cast( + ms_inputs[kPagedAttentionInputOutDataTypeIndex]->GetValueWithCheck()); + 
param_.has_quant_offset = + ms_inputs[kPagedAttentionInputHasQuantOffsetIndex]->GetValueWithCheck(); + param_.compress_type = + static_cast( + ms_inputs[kPagedAttentionInputCompressTypeIndex]->GetValueWithCheck()); + param_.calc_type = + static_cast( + ms_inputs[kPagedAttentionInputCalcTypeIndex]->GetValueWithCheck()); + param_.scale_type = + static_cast( + ms_inputs[kPagedAttentionInputScaleTypeIndex]->GetValueWithCheck()); + param_.input_layout = + static_cast( + ms_inputs[kPagedAttentionInputInputLayoutIndex]->GetValueWithCheck()); + param_.mla_v_dim = + static_cast( + ms_inputs[kPagedAttentionInputMlaVDimHeadSizeIndex]->GetValueWithCheck()); + if (param_.calc_type == + internal_v2::CustomASDPagedAttentionParam::CaclType::kPACalcTypeSpec) { + param_.q_seq_len = + ms_inputs[kPagedAttentionInputQSeqLenIndex]->GetValueWithCheck>(); + } + + // Get input_format parameter: 0 for ND (910B), 1 for NZ (310P) + auto input_format = + static_cast( + ms_inputs[kPagedAttentionInputInputFormatIndex]->GetValueWithCheck()); + + // On 310P (input_format==1), context_lens is passed as NPU tensor input, not from param + // On 910B (input_format==0), context_lens is passed through param + constexpr int32_t kInputFormatND = 0; + constexpr int32_t kInputFormatNZ = 1; + if (input_format == kInputFormatND) { + param_.kv_seq_len = + ms_inputs[kPagedAttentionInputContextLensIndex]->GetValueWithCheck>(); + } + + if (param_.batch_run_status_enable) { + param_.batch_run_status = + ms_inputs[kPagedAttentionInputBatchRunStatusIndex]->GetValueWithCheck>(); + } + created_flag_ = true; + + // NZ format routing for 310P platform when input_format == 1 + if (input_format == kInputFormatNZ) { + auto inputs_clone = inputs; + auto outputs_clone = outputs; + // Set Query, Key Cache, Value Cache and Mask to FRACTAL_NZ format + inputs_clone[static_cast(kPagedAttentionInputQueryIndex)] + .SetFormat(internal_v2::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kPagedAttentionInputKeyCacheIndex)] + .SetFormat(internal_v2::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kPagedAttentionInputValueCacheIndex)] + .SetFormat(internal_v2::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kPagedAttentionInputAttnMaskIndex)] + .SetFormat(internal_v2::kFormatFRACTAL_NZ); + outputs_clone[static_cast(kPagedAttentionOutputIndex)] + .SetFormat(internal_v2::kFormatFRACTAL_NZ); + return internal_v2::CreateCustomPagedAttentionOp( + inputs_clone, outputs_clone, param_, internal_v2::kInternalCustomPagedAttention); + } + + // Default ND format for 910B + return internal_v2::CreateCustomPagedAttentionOp( + inputs, outputs, param_, internal_v2::kInternalCustomPagedAttention); + } + + bool UpdateParam( + const std::vector &inputs, + const std::vector &outputs) override { + if (created_flag_) { + // the q_seq_len and batch_valid_length are inited in CreateKernel, + // so there is no need to load them again + created_flag_ = false; + return true; + } + + // Get input_format to determine platform (0=910B, 1=310P) + auto input_format = static_cast( + inputs[kPagedAttentionInputInputFormatIndex]->GetValueWithCheck()); + + bool need_recreate = false; + if (param_.calc_type == + internal_v2::CustomASDPagedAttentionParam::CaclType::kPACalcTypeSpec) { + auto q_need_recreate = + GetSeqLenAndCheckUpdate(inputs[kPagedAttentionInputQSeqLenIndex], ¶m_.q_seq_len); + need_recreate |= q_need_recreate; + } + + // On 310P (input_format==1), context_lens is NPU tensor input, not in param + // On 910B (input_format==0), context_lens is in param and needs update check + constexpr 
int32_t kInputFormatND = 0; + if (input_format == kInputFormatND) { + auto kv_need_recreate = GetSeqLenAndCheckUpdate( + inputs[kPagedAttentionInputContextLensIndex], ¶m_.kv_seq_len); + need_recreate |= kv_need_recreate; + } + + if (need_recreate) { + auto ret = internal_op_->UpdateParam(¶m_); + if (ret != internal_v2::kInternalOk) { + MS_LOG(ERROR) + << "ASD PagedAttention UpdateParam failed, kernel_name: " << kernel_name_; + return false; + } + return true; + } + return true; + } + + uint64_t GenerateTilingKey(const std::vector &inputs) override { + // User defined CacheKey, the inputs should include all the factors which + // will affect tiling result. + return InternalTilingCache::GenerateKey( + kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len, param_.mla_v_dim); + } + + void InitKernelInputsOutputsIndex() override { + kernel_inputs_index_ = { + kPagedAttentionInputQueryIndex, + kPagedAttentionInputKeyCacheIndex, + kPagedAttentionInputValueCacheIndex, + kPagedAttentionInputBlockTablesIndex, + kPagedAttentionInputContextLensIndex, + kPagedAttentionInputAttnMaskIndex, + kPagedAttentionInputKDescalekIndex, + kPagedAttentionInputKOffsetIndex, + kPagedAttentionInputVDescaleIndex, + kPagedAttentionInputVOffsetIndex, + kPagedAttentionInputRazorOffsetIndex, + kPagedAttentionInputPScaleIndex, + kPagedAttentionInputLogNIndex, + }; + kernel_outputs_index_ = {kPagedAttentionOutputIndex}; + } + + private: + bool created_flag_{false}; + internal_v2::CustomASDPagedAttentionParam param_; +}; + +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP( + paged_attention, ms_custom_ops::PagedAttentionFuncImpl, ms_custom_ops::PagedAttention); diff --git a/ops/c_api/paged_attention/paged_attention_op.yaml b/ops/c_api/paged_attention/paged_attention_op.yaml new file mode 100644 index 0000000..0b97df4 --- /dev/null +++ b/ops/c_api/paged_attention/paged_attention_op.yaml @@ -0,0 +1,88 @@ +#operator paged_attention +paged_attention: + args: + query: + dtype: tensor + key_cache: + dtype: tensor + value_cache: + dtype: tensor + block_tables: + dtype: tensor + context_lens: + dtype: tensor + attn_mask: + dtype: tensor + default: None + batch_run_status: + dtype: tensor + default: None + k_descale: + dtype: tensor + default: None + k_offset: + dtype: tensor + default: None + v_descale: + dtype: tensor + default: None + v_offset: + dtype: tensor + default: None + razor_offset: + dtype: tensor + default: None + p_scale: + dtype: tensor + default: None + log_n: + dtype: tensor + default: None + q_seq_lens: + dtype: tensor + default: None + q_head_num: + dtype: int + default: 0 + qk_scale: + dtype: float + default: 1.0 + kv_head_num: + dtype: int + default: 0 + mask_type: + dtype: int + default: 0 + batch_run_status_enable: + dtype: bool + default: False + quant_type: + dtype: int + default: 0 + out_data_type: + dtype: int + default: -1 + has_quant_offset: + dtype: bool + default: False + compress_type: + dtype: int + default: 0 + calc_type: + dtype: int + default: 0 + scale_type: + dtype: int + default: 0 + input_layout: + dtype: int + default: 0 + mla_v_dim: + dtype: int + default: 0 + input_format: + dtype: int + default: 0 + returns: + attention_out: + dtype: tensor diff --git a/ops/c_api/paged_attention/paged_attention_pynative.cc b/ops/c_api/paged_attention/paged_attention_pynative.cc new file mode 100644 index 0000000..4377f89 --- /dev/null +++ b/ops/c_api/paged_attention/paged_attention_pynative.cc @@ -0,0 +1,229 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ops/c_api/paged_attention/paged_attention_common.h" + +#include +#include +#include +#include +#include + +#include "mindspore/include/custom_op_api.h" +#include "ops/c_api/utils/attention_utils.h" +#include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" +#include "ops/framework/utils.h" + +namespace ms_custom_ops { +class PagedAttentionRunner : public InternalPyboostRunner { + public: + explicit PagedAttentionRunner(const std::string &op_name) : InternalPyboostRunner(op_name) {} + ~PagedAttentionRunner() = default; + + void SetParam(int32_t q_head_num, float qk_scale, int32_t kv_head_num, int32_t mask_type, + bool batch_run_status_enable, int32_t quant_type, int32_t out_data_type, bool has_quant_offset, + int32_t compress_type, int32_t calc_type, int32_t scale_type, int32_t input_layout, uint32_t mla_v_dim, + const std::vector &q_seq_len, const std::vector &kv_seq_len, + std::vector &batch_run_status) { + param_.q_head_num = q_head_num; + param_.tor = qk_scale; + param_.kv_head_num = kv_head_num; + param_.batch_run_status_enable = batch_run_status_enable; + param_.has_quant_offset = has_quant_offset; + param_.mla_v_dim = mla_v_dim; + param_.mask_type = static_cast(mask_type); + param_.quant_type = static_cast(quant_type); + param_.out_data_type = static_cast(out_data_type); + param_.compress_type = static_cast(compress_type); + param_.calc_type = static_cast(calc_type); + param_.scale_type = static_cast(scale_type); + param_.input_layout = static_cast(input_layout); + auto is_q_changed = CheckAndUpdate(q_seq_len, &(param_.q_seq_len)); + auto is_kv_changed = CheckAndUpdate(kv_seq_len, &(param_.kv_seq_len)); + (void)CheckAndUpdate(batch_run_status, &(param_.batch_run_status)); + + need_update_param_ = is_q_changed | is_kv_changed; + } + + void SetInputFormat(PAInputFormat input_format) { input_format_ = input_format; } + + protected: + bool UpdateParam() override { + if (created_flag_) { + // the q_seq_len and kv_seq_len are inited in CreatedKernel, so there is no need to load them again + created_flag_ = false; + return true; + } + + if (need_update_param_) { + auto ret = internal_op_->UpdateParam(¶m_); + if (ret != internal_v2::kInternalOk) { + MS_LOG(ERROR) << "ASD PagedAttention UpdateParam failed in MlaRunner."; + return false; + } + return true; + } + return true; + } + + internal_v2::InternalOpPtr CreateKernel(const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs) override { + created_flag_ = true; + // NZ format routing for 310P platform when input_format == 1 + if (input_format_ == PAInputFormat::kKVFormatNZ) { + auto inputs_new = inputs; + auto outputs_new = outputs; + // Set Query, Key Cache, Value Cache and Mask to FRACTAL_NZ format + inputs_new[kPagedAttentionInputQueryIndex].SetFormat(internal_v2::kFormatFRACTAL_NZ); + inputs_new[kPagedAttentionInputKeyCacheIndex].SetFormat(internal_v2::kFormatFRACTAL_NZ); + 
inputs_new[kPagedAttentionInputValueCacheIndex].SetFormat(internal_v2::kFormatFRACTAL_NZ); + inputs_new[kPagedAttentionInputAttnMaskIndex].SetFormat(internal_v2::kFormatFRACTAL_NZ); + // Set output to FRACTAL_NZ format + outputs_new[kPagedAttentionOutputIndex].SetFormat(internal_v2::kFormatFRACTAL_NZ); + return internal_v2::CreateCustomPagedAttentionOp(inputs_new, outputs_new, param_, + internal_v2::kInternalCustomPagedAttention); + } + return internal_v2::CreateCustomPagedAttentionOp(inputs, outputs, param_, + internal_v2::kInternalCustomPagedAttention); + } + + private: + internal_v2::CustomASDPagedAttentionParam param_; + bool created_flag_{true}; + bool need_update_param_{false}; + PAInputFormat input_format_{kKVFormatND}; +}; + +std::vector paged_attention_atb( + const ms::Tensor &query, const ms::Tensor &key_cache, const ms::Tensor &value_cache, + const ms::Tensor &block_tables, const ms::Tensor &context_lens, const std::optional &attn_mask, + const std::optional &batch_run_status, const std::optional &k_descale, + const std::optional &k_offset, const std::optional &v_descale, + const std::optional &v_offset, const std::optional &razor_offset, + const std::optional &p_scale, const std::optional &log_n, + const std::optional &q_seq_lens, int64_t q_head_num, double qk_scale, int64_t kv_head_num, + int64_t mask_type, bool batch_run_status_enable, int64_t quant_type, int64_t out_data_type, + bool has_quant_offset, int64_t compress_type, int64_t calc_type, int64_t scale_type, int64_t input_layout, + int64_t mla_v_dim, int64_t input_format) { + static auto op_name = "PagedAttention"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + if (input_format != PAInputFormat::kKVFormatND && input_format != PAInputFormat::kKVFormatNZ) { + MS_LOG(EXCEPTION) << "For " << op_name << ", the input_format is invalid: " << input_format; + } + + // q_seq_lens: always passed through param (move to CPU) for both 910B and 310P + std::vector q_seq_lens_value; + constexpr int64_t kPACalcTypeSpec = 1; + if (calc_type == kPACalcTypeSpec) { + if (!q_seq_lens.has_value()) { + MS_LOG(EXCEPTION) << "For " << op_name << ", calc_type is SPEC(MTP), q_seq_lens must be provided."; + } + q_seq_lens_value = GetValueFromTensor>(q_seq_lens.value(), op_name, "q_seq_lens"); + } + + // - 910B (ND format): GetValueFromTensor -> passed through param + // - 310P (NZ format): keep as NPU tensor, DON'T read value + std::vector context_lens_value; + if (input_format == PAInputFormat::kKVFormatND) { + // 910B: safe to read context_lens from CPU tensor + context_lens_value = GetValueFromTensor>(context_lens, op_name, "context_lens"); + } + + std::vector batch_run_status_value; + if (batch_run_status_enable && batch_run_status.has_value()) { + batch_run_status_value = + GetValueFromTensor>(batch_run_status.value(), op_name, "batch_run_status"); + } + runner->SetInputFormat(static_cast(input_format)); + runner->SetParam(static_cast(q_head_num), static_cast(qk_scale), static_cast(kv_head_num), + static_cast(mask_type), batch_run_status_enable, static_cast(quant_type), + static_cast(out_data_type), has_quant_offset, static_cast(compress_type), + static_cast(calc_type), static_cast(scale_type), + static_cast(input_layout), static_cast(mla_v_dim), q_seq_lens_value, + context_lens_value, batch_run_status_value); + + runner->Setup(op_name, query, key_cache, value_cache, block_tables, context_lens, attn_mask, batch_run_status, + k_descale, k_offset, v_descale, v_offset, razor_offset, p_scale, log_n, q_seq_lens, 
q_head_num, + qk_scale, kv_head_num, mask_type, batch_run_status_enable, quant_type, out_data_type, + has_quant_offset, compress_type, calc_type, scale_type, input_layout, mla_v_dim, input_format); + + auto output_data_type = query.data_type(); + if (query.data_type() == kNumberTypeInt8 && out_data_type != PAOutDataType::kPA_ACL_DT_UNDEFINED) { + if (out_data_type == PAOutDataType::kPA_ACL_FLOAT16) { + output_data_type = kNumberTypeFloat16; + } else if (out_data_type == PAOutDataType::kPA_ACL_BF16) { + output_data_type = kNumberTypeBFloat16; + } + } + auto attn_out = ms::Tensor(output_data_type, query.shape()); + + // Construct inputs: include context_lens as tensor input for 310P + // On 910B, it will be mapped to empty in ms_kernels_internal layer + std::vector inputs = {query, + key_cache, + value_cache, + block_tables, + context_lens, // Add context_lens as NPU tensor input + GetTensorOrEmpty(attn_mask), + GetTensorOrEmpty(k_descale), + GetTensorOrEmpty(k_offset), + GetTensorOrEmpty(v_descale), + GetTensorOrEmpty(v_offset), + GetTensorOrEmpty(razor_offset), + GetTensorOrEmpty(p_scale), + GetTensorOrEmpty(log_n)}; + std::vector outputs = {attn_out}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs; +} + +auto pyboost_paged_attention(const ms::Tensor &query, const ms::Tensor &key_cache, const ms::Tensor &value_cache, + const ms::Tensor &block_tables, const ms::Tensor &context_lens, + const std::optional &attn_mask, + const std::optional &batch_run_status, + const std::optional &k_descale, const std::optional &k_offset, + const std::optional &v_descale, const std::optional &v_offset, + const std::optional &razor_offset, const std::optional &p_scale, + const std::optional &log_n, const std::optional &q_seq_lens, + int64_t q_head_num, double qk_scale, int64_t kv_head_num, int64_t mask_type, + bool batch_run_status_enable, int64_t quant_type, int64_t out_data_type, + bool has_quant_offset, int64_t compress_type, int64_t calc_type, int64_t scale_type, + int64_t input_layout, int64_t mla_v_dim, int64_t input_format) { + return ms::pynative::PyboostRunner::Call( + paged_attention_atb, query, key_cache, value_cache, block_tables, context_lens, attn_mask, batch_run_status, + k_descale, k_offset, v_descale, v_offset, razor_offset, p_scale, log_n, q_seq_lens, q_head_num, qk_scale, + kv_head_num, mask_type, batch_run_status_enable, quant_type, out_data_type, has_quant_offset, compress_type, + calc_type, scale_type, input_layout, mla_v_dim, input_format); +} +} // namespace ms_custom_ops + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("paged_attention", &ms_custom_ops::pyboost_paged_attention, "PagedAttention", pybind11::arg("query"), + pybind11::arg("key_cache"), pybind11::arg("value_cache"), pybind11::arg("block_tables"), + pybind11::arg("context_lens"), pybind11::arg("attn_mask") = std::nullopt, + pybind11::arg("batch_run_status") = std::nullopt, pybind11::arg("k_descale") = std::nullopt, + pybind11::arg("k_offset") = std::nullopt, pybind11::arg("v_descale") = std::nullopt, + pybind11::arg("v_offset") = std::nullopt, pybind11::arg("razor_offset") = std::nullopt, + pybind11::arg("p_scale") = std::nullopt, pybind11::arg("log_n") = std::nullopt, + pybind11::arg("q_seq_lens") = std::nullopt, pybind11::arg("q_head_num") = 0, pybind11::arg("qk_scale") = 1.0, + pybind11::arg("kv_head_num") = 0, pybind11::arg("mask_type") = 0, + pybind11::arg("batch_run_status_enable") = false, pybind11::arg("quant_type") = 0, + pybind11::arg("out_data_type") = -1, 
pybind11::arg("has_quant_offset") = false, + pybind11::arg("compress_type") = 0, pybind11::arg("calc_type") = 0, pybind11::arg("scale_type") = 0, + pybind11::arg("input_layout") = 0, pybind11::arg("mla_v_dim") = 0, pybind11::arg("input_format") = 0); +} diff --git a/ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h b/ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h index 1420776..d00fa29 100644 --- a/ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h +++ b/ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h @@ -31,7 +31,6 @@ #include "internal.h" namespace ms_custom_ops { -using namespace mindspore; using TensorList = std::vector; class InternalPyboostRunner : public ms::pynative::PyboostRunner { @@ -85,7 +84,7 @@ class InternalPyboostRunner : public ms::pynative::PyboostRunner { uint64_t op_key_{0}; uint64_t tiling_key_{0}; internal_v2::InternalOpPtr internal_op_{nullptr}; - inline static std::unordered_map hash_map_; + std::unordered_map hash_map_; internal_v2::DtypeInfoList internal_inputs_dtype_; internal_v2::DtypeInfoList internal_outputs_dtype_; internal_v2::ShapeInfoList internal_inputs_shape_; diff --git a/tests/st/test_custom_paged_attention.py b/tests/st/test_custom_paged_attention.py new file mode 100644 index 0000000..c91627f --- /dev/null +++ b/tests/st/test_custom_paged_attention.py @@ -0,0 +1,2793 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +System tests for the ms_custom_ops.paged_attention operator. + +This module validates functional correctness, masking modes, GQA/MLA variants, +quantization paths, layouts, and MTP/lookahead decoding behaviour. +""" + +import math +import random +import numpy as np +import pytest +import mindspore as ms +from mindspore import Tensor, context, ops, nn +import ms_custom_ops + + +# Mask type enumerations +MASK_UNDEFINED = 0 +MASK_NORM = 1 +MASK_ALIBI = 2 +MASK_SPEC = 3 +MASK_FREE = 4 + +# Quantization type enumerations +QUANT_UNQUANT = 0 +DEQUANT_FUSION = 1 +QUANT_QKV_OFFLINE = 2 +QUANT_QKV_ONLINE = 3 + +# Input layout enumerations +INPUT_LAYOUT_BSND = 0 +INPUT_LAYOUT_BNSD = 1 + +# Input format enumerations +INPUT_FORMAT_ND = 0 +INPUT_FORMAT_NZ = 1 + + +class PagedAttentionDataGenerator: + """Data generator and golden reference calculator for paged attention tests. + + Handles mask construction, golden computation, and accuracy validation. 
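+
+    A minimal usage sketch (illustrative only; the seed, shapes and dtypes
+    below are arbitrary, not required values):
+
+        gen = PagedAttentionDataGenerator(rng_seed=2025)
+        data = gen.generate_inputs(
+            num_heads=8, kv_heads=1, head_size=128, block_size=128,
+            num_blocks=64, q_seq_lens=[1, 1], context_lens=[32, 48],
+            q_dtype=ms.float16, kv_dtype=ms.float16, mask_type=MASK_UNDEFINED)
+        golden = gen.compute_golden_reference(
+            data['query'], data['key_cache'], data['value_cache'],
+            data['block_tables'], data['q_seq_lens'], data['kv_seq_lens'],
+            scale=1.0 / math.sqrt(128), mask=data['mask'])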
+ """ + + def __init__(self, rng_seed: int = 2025): + """Initialize with random seed.""" + self.rng = np.random.default_rng(rng_seed) + random.seed(rng_seed) + np.random.seed(rng_seed) + + def _np_dtype_for(self, ms_dtype: ms.dtype) -> np.dtype: + """Convert MindSpore dtype to NumPy dtype.""" + if ms_dtype == ms.float16: + return np.float16 + if ms_dtype == ms.bfloat16: + return np.float32 # bfloat16 not native in numpy + if ms_dtype == ms.int8: + return np.int8 + return np.float32 + + def generate_inputs(self, num_heads: int, kv_heads: int, head_size: int, + block_size: int, num_blocks: int, q_seq_lens: list, context_lens: list, + q_dtype: ms.dtype, kv_dtype: ms.dtype, mask_type: int, + quant_type: int = QUANT_UNQUANT, has_quant_offset: bool = False, + mla_v_dim: int = 0, mask_out_dtype: ms.dtype = None): + """Generate query, key_cache, value_cache, block_tables, masks, and quantization parameters. + + Directly accepts q_seq_lens and context_lens (kv_seq_lens) for flexible configuration, + especially important for MTP/lookahead scenarios. + + **Returns numpy arrays** for golden calculation; convert to Tensors before network input. + **Masks are generated in their natural shapes (without NZ padding)** for golden calculation. + NZ padding (if needed) should be applied later in _prepare_mask_for_network. + + Args: + num_heads: Number of query heads + kv_heads: Number of KV heads + head_size: Dimension per head (for Q/K in MLA mode, or all in + non-MLA mode) + block_size: Block size for KV cache + num_blocks: Total number of blocks + q_seq_lens: List of query sequence lengths per batch + (e.g., [1, 15, 30, 6] for MTP) + context_lens: List of context (KV) lengths per batch + (e.g., [10, 64, 64, 64]) + q_dtype: Query tensor dtype (for reference, actual data is numpy) + kv_dtype: Key/Value cache dtype (for reference, actual data is numpy) + mask_type: Type of attention mask (MASK_UNDEFINED, MASK_NORM, + MASK_SPEC, etc.) + quant_type: Quantization type. 
+ Supported values: QUANT_UNQUANT, DEQUANT_FUSION, + QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE + has_quant_offset: Whether to generate quantization offsets + mla_v_dim: MLA V/O head dimension (0 for non-MLA, >0 for MLA mode) + mask_out_dtype: Mask output dtype (for reference) + + Returns: + Dictionary containing all generated numpy arrays and metadata: + - query, key_cache, value_cache, block_tables, q_seq_lens, + kv_seq_lens, mask (all numpy arrays or None) + - k_descale, v_descale (for DEQUANT_FUSION, per-element, numpy + arrays) + - k_offset, v_offset (ONLY for DEQUANT_FUSION if has_quant_offset, + numpy arrays) + - k_scale_per_head, v_scale_per_head (for full QKV quant, + per-head, numpy arrays) + - p_scale (ONLY for QUANT_QKV_OFFLINE, per-head, numpy array) + - q_dtype, kv_dtype, mask_dtype (MindSpore dtypes for later Tensor + conversion) + Note: Q scale is NOT used (only K and V scales are used) + """ + # Determine if this is full quantization + is_full_quant = quant_type in (QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE) + is_dequant_fusion = quant_type == DEQUANT_FUSION + + # Calculate total number of query tokens from q_seq_lens + num_tokens = sum(q_seq_lens) + batch_size = len(q_seq_lens) + + # Generate query array + q_np_dtype = self._np_dtype_for(q_dtype) + if q_dtype == ms.int8: + # For quantized int8: use range -2.0~2.0 before quantization + q_range = 2.0 if is_full_quant else 1.0 + query_np = self.rng.uniform( + -q_range, q_range, (num_tokens, num_heads, head_size) + ).astype(np.float32) + query_np = np.clip( + np.rint(query_np * 127.0 / q_range), -127, 127 + ).astype(np.int8) + else: + # Generate query in the target numpy dtype grid. + # For fp16: use np.float16 directly. + # For bf16: generate float32 and then quantize through MindSpore + # Tensor(bfloat16) + query_np = self.rng.uniform( + -1.0, 1.0, (num_tokens, num_heads, head_size) + ).astype(q_np_dtype) + if q_dtype == ms.bfloat16: + # Simulate bf16 quantization so that golden and operator see the + # same value grid. + # Tensor(..., ms.bfloat16).asnumpy() returns float32 values that + # lie on the bf16 grid. 
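+                # For example, a float32 draw of 0.1 lands on the bf16 grid at
+                # roughly 0.10009765625 (bf16 stores 7 mantissa bits), and both
+                # the golden path and the operator start from that rounded value.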
+ query_np = Tensor(query_np, dtype=ms.bfloat16).asnumpy().astype( + np.float32 + ) + + # Determine head dimensions for MLA + # MLA: Q/K use head_size (head_size_qk), V/O use mla_v_dim (head_size_vo) + # Non-MLA: all use head_size + head_size_qk = head_size + head_size_vo = mla_v_dim if mla_v_dim > 0 else head_size + is_mla = mla_v_dim > 0 + + # Generate key/value cache arrays + # IMPORTANT: For MLA with KV combined mode, V cache views the first + # head_size_vo dimensions of K cache + kv_np_dtype = self._np_dtype_for(kv_dtype) + if kv_dtype == ms.int8: + # Full quant uses 2.0, dequant_fusion uses 4.0 + kv_range = 2.0 if is_full_quant else 4.0 + # For MLA combined mode, K cache must have at least head_size_qk dimensions + # (which includes head_size_vo) + cache_shape = (num_blocks, block_size, kv_heads, head_size_qk) + key_cache_np = self.rng.uniform( + -kv_range, kv_range, cache_shape + ).astype(np.float32) + key_cache_np = np.clip( + np.rint(key_cache_np * 127.0 / kv_range), -127, 127 + ).astype(np.int8) + # For MLA: V cache is the first head_size_vo dimensions of K cache (KV combined) + if is_mla: + value_cache_np = key_cache_np[:, :, :, :head_size_vo] + else: + v_cache_shape = (num_blocks, block_size, kv_heads, head_size_vo) + value_cache_np = self.rng.uniform( + -kv_range, kv_range, v_cache_shape + ).astype(np.float32) + value_cache_np = np.clip( + np.rint(value_cache_np * 127.0 / kv_range), -127, 127 + ).astype(np.int8) + else: + cache_shape = (num_blocks, block_size, kv_heads, head_size_qk) + key_cache_np = self.rng.uniform( + -1.0, 1.0, cache_shape + ).astype(kv_np_dtype) + # For MLA: V cache is the first head_size_vo dimensions of K cache (KV combined) + if is_mla: + value_cache_np = key_cache_np[:, :, :, :head_size_vo] + else: + v_cache_shape = (num_blocks, block_size, kv_heads, head_size_vo) + value_cache_np = self.rng.uniform( + -1.0, 1.0, v_cache_shape + ).astype(kv_np_dtype) + + # Generate block tables (numpy array) + max_context_len = max(context_lens) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables_list = [] + for _ in range(batch_size): + block_table = [random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)] + block_tables_list.append(block_table) + block_tables_np = np.array(block_tables_list, dtype=np.int32) + + # Use provided q_seq_lens and context_lens (kv_seq_lens) directly + # Keep as numpy arrays + kv_seq_lens = context_lens + q_seq_lens_np = np.array(q_seq_lens, dtype=np.int32) + kv_seq_lens_np = np.array(kv_seq_lens, dtype=np.int32) + + # Generate mask based on mask_type (returns numpy array without NZ padding) + # Mask dtype must match operator output dtype in quant scenes + mask_dtype = mask_out_dtype if mask_out_dtype is not None else q_dtype + is_mla = mla_v_dim > 0 + mask_np = self._generate_mask( + mask_type, + num_tokens, + max_context_len, + q_seq_lens, + kv_seq_lens, + mask_dtype, + num_heads, + is_mla, + ) + + # Generate quantization parameters based on quant_type + # All as numpy arrays + k_descale_np = None + v_descale_np = None + k_offset_np = None + v_offset_np = None + # Note: Q scale is NOT used in paged attention (only K/V scales are used) + k_scale_per_head_np = None + v_scale_per_head_np = None + p_scale_np = None # For full quant: softmax P matrix quantization scale + + if is_dequant_fusion: + # Dequant Fusion: generate per-element descale for KV + # Shape: (kv_heads * head_size_qk/vo) - per-element quantization + k_descale_np = self.rng.integers( + -1, 2, size=(kv_heads * 
head_size_qk,) + ).astype(np.float32) + v_descale_np = self.rng.integers( + -1, 2, size=(kv_heads * head_size_vo,) + ).astype(np.float32) + + # Generate offsets if requested + # Offsets are ONLY for Dequant Fusion, NOT for full quantization + # Shape: (kv_heads * head_size_qk/vo) - per-element offset + if has_quant_offset: + k_offset_np = self.rng.integers( + -20, 20, size=(kv_heads * head_size_qk,) + ).astype(np.int32) + v_offset_np = self.rng.integers( + -20, 20, size=(kv_heads * head_size_vo,) + ).astype(np.int32) + + elif is_full_quant: + # Full QKV Quant: generate per-head scales + # Range [-1, 2] for K/V scales + # NOTE: Q scale is NOT used (only K and V scales) + # NOTE: Full quantization does NOT use offsets + k_scale_per_head_np = self.rng.uniform( + -1.0, 2.0, size=(num_heads,) + ).astype(np.float32) + v_scale_per_head_np = self.rng.uniform( + -1.0, 2.0, size=(num_heads,) + ).astype(np.float32) + + # Generate P matrix quantization scale ONLY for offline quantization + # Online quantization (quantType=3) doesn't need p_scale + if quant_type == QUANT_QKV_OFFLINE: + # Shape: (num_heads,), range [-1, 2] + p_scale_np = self.rng.uniform( + -1.0, 2.0, size=(num_heads,) + ).astype(np.float32) + + # Return all parameters as a dictionary of numpy arrays and metadata + return { + 'query': query_np, + 'key_cache': key_cache_np, + 'value_cache': value_cache_np, + 'block_tables': block_tables_np, + 'q_seq_lens': q_seq_lens_np, + 'kv_seq_lens': kv_seq_lens_np, + 'mask': mask_np, + 'k_descale': k_descale_np, + 'v_descale': v_descale_np, + 'k_offset': k_offset_np, + 'v_offset': v_offset_np, + 'k_scale_per_head': k_scale_per_head_np, + 'v_scale_per_head': v_scale_per_head_np, + 'p_scale': p_scale_np, + 'q_dtype': q_dtype, + 'kv_dtype': kv_dtype, + 'mask_dtype': mask_dtype, + } + + def _get_alibi_slopes(self, num_heads: int) -> np.ndarray: + """Generate ALIBI slopes for positional bias. + + Args: + num_heads: Number of attention heads + + Returns: + ALIBI slopes array of shape (num_heads,) + """ + nearest_power_of_two = 2 ** int(np.floor(np.log2(num_heads))) + m0 = 2.0 ** (-8.0 / nearest_power_of_two) + slopes = np.array( + [m0 ** i for i in range(1, nearest_power_of_two + 1)], + dtype=np.float32, + ) + + if nearest_power_of_two < num_heads: + m1 = 2.0 ** (-4.0 / nearest_power_of_two) + # Generate additional slopes with step size 2 + additional_count = num_heads - nearest_power_of_two + mm = np.array( + [m1 ** i for i in range(1, 1 + 2 * additional_count, 2)], + dtype=np.float32, + ) + slopes = np.concatenate([slopes, mm], axis=0) + + return slopes + + def _generate_mask(self, mask_type: int, num_tokens: int, max_context_len: int, + q_seq_lens: list, kv_seq_lens: list, dtype: ms.dtype, num_heads: int = 0, + is_mla: bool = False) -> np.ndarray: + """Generate attention mask (without NZ padding). + + This method generates masks in their natural shapes for golden calculation. + NZ padding (if needed) should be applied later in _prepare_mask_for_network. + + Args: + mask_type: MASK_UNDEFINED, MASK_NORM, MASK_ALIBI, MASK_SPEC, etc. 
+ num_tokens: Total query tokens + max_context_len: Maximum context length + q_seq_lens: List of query sequence lengths + kv_seq_lens: List of KV sequence lengths + dtype: Data type for mask (MindSpore dtype, used for determining numpy dtype) + num_heads: Number of attention heads (required for ALIBI) + is_mla: Whether this is MLA (Multi-Head Latent Attention) mode + + Returns: + Mask numpy array (without NZ padding) or None + """ + if mask_type == MASK_UNDEFINED: + return None + + if mask_type == MASK_FREE: + # MASK_FREE: Generate global causal mask for golden calculation + # Shape: (num_tokens, max_context_len) + # Only mask the last q_seqlen window within k_seqlen context + batch_size = len(q_seq_lens) + np_dtype = self._np_dtype_for(dtype) + mask_np = np.zeros((num_tokens, max_context_len), dtype=np_dtype) + prev_qseq = 0 + for i in range(batch_size): + qseq = q_seq_lens[i] + kseq = kv_seq_lens[i] + start = kseq - qseq + tri = np.ones((qseq, qseq), dtype=np_dtype) + tri = np.triu(tri, 1) # Upper triangular (exclude diagonal) + tri *= -60000.0 + mask_np[prev_qseq:(prev_qseq + qseq), start:kseq] = tri + prev_qseq += qseq + return mask_np + + # Determine pre_mask_factor for SPEC mask construction + # Note: + # - ALIBI mask doesn't use pre_mask_factor (position bias, not boolean) + # - NORM mask always uses -10000.0 regardless of dtype + # - SPEC mask: pre_mask_factor=1.0 ONLY for MLA + bf16, otherwise -10000.0 + # - post_mask_factor only applies to MLA + SPEC + bf16 paths + if mask_type == MASK_SPEC: + if is_mla and dtype == ms.bfloat16: + # For MLA + bf16 SPEC mask: use 1.0; golden multiplies by -10000.0 + pre_mask_factor = 1.0 + else: + # For non-MLA or fp16 SPEC mask: use -10000.0 directly + pre_mask_factor = -10000.0 + else: + # For ALIBI/NORM, pre_mask_factor is not used + pre_mask_factor = None + + if mask_type == MASK_NORM: + # Normal causal mask: upper triangular + # NORM mask ALWAYS uses -10000.0, regardless of dtype! 
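+            # Worked example (illustrative): in the prefill branch below, with
+            # qseq = 3 the upper-triangular tile placed at [-3:, -3:] is
+            #     [[     0., -10000., -10000.],
+            #      [     0.,      0., -10000.],
+            #      [     0.,      0.,      0.]]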
+ batch_size = len(q_seq_lens) + max_q_len = max(q_seq_lens) + + if max_q_len == 1: + # Decode mode: shape (num_tokens, 1, max_context_len) + mask_np = np.zeros((num_tokens, 1, max_context_len), dtype=np.float32) + for i in range(num_tokens): + # Mask out positions before current token index + if i > 0: + mask_np[i, :, :i] = -10000.0 + + return mask_np + + # Prefill mode: shape (batch_size, max_q_len, max_context_len) + np_dtype = self._np_dtype_for(dtype) + mask_np = np.zeros((batch_size, max_q_len, max_context_len), dtype=np_dtype) + for i in range(batch_size): + qseq = q_seq_lens[i] + # Create upper triangular mask (qseq x qseq) for causal attention + tri = np.ones((qseq, qseq), dtype=np_dtype) + tri = np.triu(tri, 1) # Upper triangular (exclude diagonal) + tri *= -10000.0 + # Place mask at the end: [-qseq:, -qseq:] + mask_np[i, -qseq:, -qseq:] = tri + return mask_np + + if mask_type == MASK_ALIBI: + # ALIBI mask: positional bias + # ALIBI is NOT a boolean mask but a bias added to attention scores + # Decode shape: (batch, num_heads, 1, max_context_len) + # Prefill shape: (batch, num_heads, max_q_len, max_context_len) + # ALIBI values are NOT multiplied by a mask factor + batch_size = len(q_seq_lens) + max_q_len = max(q_seq_lens) + + # For decode (all q_len=1), use shape (batch, num_heads, 1, max_context_len) + # For prefill/MTP, use shape (batch, num_heads, max_q_len, max_context_len) + if max_q_len == 1: + mask_np = np.zeros( + (batch_size, num_heads, 1, max_context_len), + dtype=np.float32, + ) + else: + mask_np = np.zeros( + (batch_size, num_heads, max_q_len, max_context_len), + dtype=np.float32, + ) + + alibi_slopes = self._get_alibi_slopes(num_heads) + + for i, (ql, kl) in enumerate(zip(q_seq_lens, kv_seq_lens)): + if kl == 0: + continue + # position_ids - context_len + 1 + # Generates negative bias values: [-kl+1, ..., -1, 0] + position_ids = np.arange(kl, dtype=np.int32) + alibi_bias = (position_ids - kl + 1).astype( + np.float32 + ) # [-kl+1, ..., -1, 0] + # Shape: (num_heads, 1, kl) + alibi_bias = ( + alibi_slopes.reshape((-1, 1, 1)) + * alibi_bias.reshape((1, 1, -1)) + ) + + if max_q_len == 1: + mask_np[i, :, :, :kl] = alibi_bias + # Direct assignment, no mask factor needed + else: + # For prefill/MTP: repeat for all query positions + # Each query position gets the same ALIBI bias per key position + mask_np[i, :, :ql, :kl] = alibi_bias + # Direct assignment, no mask factor needed + + return mask_np + + if mask_type == MASK_SPEC: + # SPEC mask for parallel decoding (MTP) + # Only mask within the last q_len window + np_dtype = self._np_dtype_for(dtype) + mask_np = np.zeros((num_tokens, max_context_len), dtype=np_dtype) + pre_q = 0 + for ql, kl in zip(q_seq_lens, kv_seq_lens): + if ql == 0 or kl == 0: + pre_q += ql + continue + # Create upper triangular mask for the last q_len tokens in context + start = max(0, kl - ql) + tri = np.triu(np.ones((ql, ql), dtype=np_dtype), 1) * pre_mask_factor + mask_np[pre_q: pre_q + ql, start: kl] = tri + pre_q += ql + return mask_np + + return None + + def compute_golden_reference( + self, + query: np.ndarray, + key_cache: np.ndarray, + value_cache: np.ndarray, + block_tables: np.ndarray, + q_seq_lens: np.ndarray, + kv_seq_lens: np.ndarray, + scale: float, + mask: np.ndarray, + mask_type: int = MASK_UNDEFINED, + k_descale: np.ndarray = None, + v_descale: np.ndarray = None, + k_offset: np.ndarray = None, + v_offset: np.ndarray = None, + k_scale_per_head: np.ndarray = None, + v_scale_per_head: np.ndarray = None, + p_scale: np.ndarray = None, + 
output_dtype: ms.dtype = None, + mla_v_dim: int = 0, + mask_dtype: ms.dtype = None, + ) -> np.ndarray: + """Compute golden reference output using NumPy. + + Supports group matmul with GQA support and quantization. + Supports MLA (Multi-Head Latent Attention) with different Q/K and V/O dimensions. + + Args: + query: Query numpy array + key_cache: Key cache numpy array + value_cache: Value cache numpy array + block_tables: Block table mapping numpy array + q_seq_lens: Query sequence lengths numpy array + kv_seq_lens: KV sequence lengths numpy array + scale: QK scaling factor + mask: Attention mask numpy array (optional) + mask_type: Mask type constant + (MASK_UNDEFINED, MASK_NORM, MASK_ALIBI, MASK_SPEC, MASK_FREE) + k_descale: Key dequantization scales (for Dequant Fusion int8 KV) + v_descale: Value dequantization scales (for Dequant Fusion int8 KV) + k_offset: Key dequantization offsets + (ONLY for Dequant Fusion with offset) + v_offset: Value dequantization offsets + (ONLY for Dequant Fusion with offset) + k_scale_per_head: Key per-head scales (for full QKV quant) + v_scale_per_head: Value per-head scales (for full QKV quant) + p_scale: P matrix quantization scale (for offline full QKV quant) + Note: Q scale is NOT used + output_dtype: Output data type (for quantization scenarios) + mla_v_dim: MLA V/O head dimension (0 for non-MLA, >0 for MLA mode) + mask_dtype: Original mask dtype (for determining post_mask_factor + in MLA + SPEC + bf16) + + Returns: + Golden reference output numpy array + """ + # Convert to float32 for computation (inputs are already numpy arrays) + q_np = query.astype(np.float32) + kc_np = key_cache.astype(np.float32) + vc_np = value_cache.astype(np.float32) + + # Determine head dimensions (MLA support) + num_tokens, num_heads, head_size_qk = q_np.shape + kv_heads = kc_np.shape[2] + head_size_k = kc_np.shape[3] # Key head dimension + head_size_v = vc_np.shape[3] # Value head dimension (may differ in MLA) + output_head_dim = head_size_v if mla_v_dim > 0 else head_size_qk + + # Apply dequantization for Dequant Fusion (int8 KV, fp16 Q) + # Add offset (if provided) then multiply by per-element descale + # Shapes for per-element params: (kv_heads * head_size_qk/vo) + if k_descale is not None and v_descale is not None: + k_scale_np = k_descale.astype(np.float32) + v_scale_np = v_descale.astype(np.float32) + # Reshape scale from (kv_heads * head_size) to (kv_heads, head_size) + k_scale_reshaped = k_scale_np.reshape(kv_heads, head_size_k) + v_scale_reshaped = v_scale_np.reshape(kv_heads, head_size_v) + # Offsets are optional (only for Dequant Fusion with offset) + if k_offset is not None: + k_off_np = k_offset.astype(np.float32).reshape(kv_heads, head_size_k) + kc_np += k_off_np[np.newaxis, np.newaxis, :, :] + if v_offset is not None: + v_off_np = v_offset.astype(np.float32).reshape(kv_heads, head_size_v) + vc_np += v_off_np[np.newaxis, np.newaxis, :, :] + # Apply per-element scale: broadcast to (num_blocks, block_size, kv_heads, head_size) + kc_np *= k_scale_reshaped[np.newaxis, np.newaxis, :, :] + vc_np *= v_scale_reshaped[np.newaxis, np.newaxis, :, :] + + # For full QKV quantization: keep int8 data, scales will be applied after matmul + # We'll apply dequantization AFTER int32 matmul results + is_full_quant = k_scale_per_head is not None and v_scale_per_head is not None + if is_full_quant: + k_scale_np = k_scale_per_head.astype(np.float32) + v_scale_np = v_scale_per_head.astype(np.float32) + # Keep data as int8 for now (stored in float32 array but with int8 values) + # 
Convert to int32 for matmul computation + q_np = q_np.astype(np.int32) + kc_np = kc_np.astype(np.int32) + vc_np = vc_np.astype(np.int32) + + tables_np = block_tables.astype(np.int32) + q_lens = q_seq_lens.tolist() + k_lens = kv_seq_lens.tolist() + + block_size = kc_np.shape[1] + + # Initialize output with correct dimensions (MLA: output uses head_size_v) + # Non-MLA: [num_tokens, num_heads, head_size_qk] + # MLA: [num_tokens, num_heads, head_size_v] + out_np = np.zeros( + (num_tokens, num_heads, output_head_dim), + dtype=np.float32, + ) + + q_idx = 0 + for batch_idx, (ql, kl) in enumerate(zip(q_lens, k_lens)): + if ql == 0 or kl == 0: + q_idx += ql + continue + + # Extract KV from cache and prepare tensors + key_seq, value_seq = self._extract_kv_from_cache( + kc_np, vc_np, tables_np[batch_idx], kl, block_size + ) + + q_slice = q_np[q_idx: q_idx + ql] + query_t = np.transpose(q_slice, (1, 0, 2)) + key_t = np.transpose(key_seq, (1, 2, 0)) + value_t = np.transpose(value_seq, (1, 0, 2)) + + # Compute attention output based on quantization type + if is_full_quant: + out_slice = self._compute_full_quant_attention( + query_t, key_t, value_t, num_heads, kv_heads, scale, + k_scale_np, v_scale_np, p_scale, mask, mask_type, + batch_idx, q_idx, ql, kl, mla_v_dim, mask_dtype + ) + else: + out_slice = self._compute_float_attention( + query_t, key_t, value_t, num_heads, kv_heads, scale, + mask, mask_type, batch_idx, q_idx, ql, kl, + mla_v_dim, mask_dtype + ) + + # Transpose output from [num_heads, ql, head_size] to [ql, num_heads, head_size] + out_np[q_idx: q_idx + ql] = np.transpose(out_slice, (1, 0, 2)) + + q_idx += ql + + # Convert to specified output dtype (for quantization) or keep as float32 + if output_dtype is not None: + out_dtype_np = self._np_dtype_for(output_dtype) + return out_np.astype(out_dtype_np) + + # Return as float32 or convert to original query dtype + return out_np.astype(query.dtype) + + def _extract_kv_from_cache( + self, + key_cache: np.ndarray, + value_cache: np.ndarray, + block_table: np.ndarray, + kv_len: int, + block_size: int, + ) -> tuple: + """Extract keys and values from paged cache using block table. + + Args: + key_cache: Key cache array [num_blocks, block_size, kv_heads, head_size] + value_cache: Value cache array [num_blocks, block_size, kv_heads, head_size] + block_table: Block table for current sequence + kv_len: KV sequence length + block_size: Block size + + Returns: + Tuple of (key_seq, value_seq) arrays + """ + keys_list = [] + values_list = [] + for j in range(kv_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + keys_list.append(key_cache[block_number, block_offset]) + values_list.append(value_cache[block_number, block_offset]) + + key_seq = np.stack(keys_list, axis=0) + value_seq = np.stack(values_list, axis=0) + return key_seq, value_seq + + def _compute_full_quant_attention( + self, + query_t: np.ndarray, + key_t: np.ndarray, + value_t: np.ndarray, + num_heads: int, + kv_heads: int, + scale: float, + k_scale: np.ndarray, + v_scale: np.ndarray, + p_scale: np.ndarray, + mask: np.ndarray, + mask_type: int, + batch_idx: int, + q_idx: int, + ql: int, + kl: int, + mla_v_dim: int, + mask_dtype: ms.dtype, + ) -> np.ndarray: + """Compute attention with full QKV quantization (int8 Q/K/V). 
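+
+        Reference pipeline (summarizing the code below): Q @ K^T is accumulated
+        in int32 and dequantized with the per-head k_scale, the scaled and
+        masked scores go through a float32 softmax, P is re-quantized to int8
+        (offline via p_scale, or online per row), and the int32 P @ V result is
+        dequantized with the per-head v_scale and divided by the softmax row sum.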
+
+        Returns:
+            Output slice [num_heads, ql, head_size]
+        """
+        # Q@K^T with int32, then apply K scale
+        scores_int32 = self._group_matmul(query_t, key_t, num_heads, kv_heads)
+
+        scores = np.zeros_like(scores_int32, dtype=np.float32)
+        for h in range(num_heads):
+            scores[h] = scores_int32[h].astype(np.float32) * k_scale[h]
+
+        scores = scores * scale
+
+        # Apply mask
+        if mask is not None:
+            scores = self._apply_mask_simple(
+                scores, mask, mask_type, batch_idx, q_idx, ql, kl,
+                mla_v_dim, mask_dtype
+            )
+
+        # Softmax
+        scores_max = np.max(scores, axis=-1, keepdims=True)
+        exp_scores = np.exp(scores - scores_max)
+        row_sum = np.sum(exp_scores, axis=-1, keepdims=True)
+
+        # Quantize P and apply V
+        has_p_scale = p_scale is not None
+        if has_p_scale:
+            out_slice = self._apply_offline_quant_pv(
+                exp_scores, value_t, num_heads, kv_heads, v_scale, p_scale
+            )
+        else:
+            out_slice = self._apply_online_quant_pv(
+                exp_scores, value_t, num_heads, kv_heads, v_scale
+            )
+
+        return out_slice / row_sum
+
+    def _compute_float_attention(
+        self,
+        query_t: np.ndarray,
+        key_t: np.ndarray,
+        value_t: np.ndarray,
+        num_heads: int,
+        kv_heads: int,
+        scale: float,
+        mask: np.ndarray,
+        mask_type: int,
+        batch_idx: int,
+        q_idx: int,
+        ql: int,
+        kl: int,
+        mla_v_dim: int,
+        mask_dtype: ms.dtype,
+    ) -> np.ndarray:
+        """Compute standard float attention (unquant or dequant fusion).
+
+        Returns:
+            Output slice [num_heads, ql, head_size]
+        """
+        scores = self._group_matmul(query_t, key_t, num_heads, kv_heads) * scale
+
+        if mask is not None:
+            scores = self._apply_mask_full(
+                scores, mask, mask_type, batch_idx, q_idx, ql, kl,
+                mla_v_dim, mask_dtype
+            )
+
+        scores_max = np.max(scores, axis=-1, keepdims=True)
+        exp_scores = np.exp(scores - scores_max)
+        probs = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
+
+        return self._group_matmul_pv(probs, value_t, num_heads, kv_heads)
+
+    def _apply_mask_simple(
+        self,
+        scores: np.ndarray,
+        mask: np.ndarray,
+        mask_type: int,
+        batch_idx: int,
+        q_idx: int,
+        ql: int,
+        kl: int,
+        mla_v_dim: int,
+        mask_dtype: ms.dtype,
+    ) -> np.ndarray:
+        """Apply mask for full quant path (MASK_FREE and SPEC only).
+
+        batch_idx is accepted to keep the signature aligned with the call
+        site and with _apply_mask_full, but is not used on this path.
+        """
+        mask_np = mask.astype(np.float32)
+        post_factor = 1.0
+        if mla_v_dim > 0 and mask_type == MASK_SPEC and mask_dtype == ms.bfloat16:
+            post_factor = -10000.0
+
+        if mask_type == MASK_FREE:
+            mask_slice = mask_np[q_idx: q_idx + ql, :kl]
+            scores += mask_slice[np.newaxis, :, :]
+        else:
+            mask_slice = mask_np[q_idx: q_idx + ql, :kl]
+            scores += mask_slice[np.newaxis, :, :] * post_factor
+
+        return scores
+
+    def _apply_mask_full(
+        self,
+        scores: np.ndarray,
+        mask: np.ndarray,
+        mask_type: int,
+        batch_idx: int,
+        q_idx: int,
+        ql: int,
+        kl: int,
+        mla_v_dim: int,
+        mask_dtype: ms.dtype,
+    ) -> np.ndarray:
+        """Apply mask for float attention (all mask types)."""
+        mask_np = mask.astype(np.float32)
+        post_factor = 1.0
+        if mla_v_dim > 0 and mask_type == MASK_SPEC and mask_dtype == ms.bfloat16:
+            post_factor = -10000.0
+
+        if mask_type == MASK_FREE:
+            mask_slice = mask_np[q_idx: q_idx + ql, :kl]
+            scores += mask_slice[np.newaxis, :, :]
+        elif mask_type == MASK_ALIBI:
+            if mask_np.ndim == 4:
+                mask_slice = mask_np[batch_idx, :, :ql, :kl]
+            else:
+                mask_slice = mask_np[batch_idx, :, :, :kl][:, :ql, :]
+            scores += mask_slice
+        elif mask_type == MASK_SPEC:
+            mask_slice = mask_np[q_idx: q_idx + ql, :kl]
+            scores += mask_slice[np.newaxis, :, :] * post_factor
+        else:
+            # MASK_NORM
+            if mask_np.ndim == 3:
+                if mask_np.shape[1] == 1:
+                    mask_slice = mask_np[batch_idx, 0, :kl]
+                    scores +=
mask_slice[np.newaxis, np.newaxis, :] + else: + mask_slice = mask_np[batch_idx, -ql:, :kl] + scores += mask_slice[np.newaxis, :, :] + + return scores + + def _apply_offline_quant_pv( + self, + exp_scores: np.ndarray, + value_t: np.ndarray, + num_heads: int, + kv_heads: int, + v_scale: np.ndarray, + p_scale: np.ndarray, + ) -> np.ndarray: + """Apply offline P quantization and V matmul.""" + p_scale_np = p_scale.astype(np.float32) + probs_fp = exp_scores * p_scale_np.reshape(num_heads, 1, 1) + probs_int8 = np.rint(probs_fp.astype(np.float16)).astype(np.int8) + probs_int8 = probs_int8.astype(np.int32) + + out_int32 = self._group_matmul_pv(probs_int8, value_t, num_heads, kv_heads) + out_slice = np.zeros_like(out_int32, dtype=np.float32) + for h in range(num_heads): + out_slice[h, :, :] = out_int32[h, :, :].astype(np.float32) * v_scale[h] + + return out_slice + + def _apply_online_quant_pv( + self, + exp_scores: np.ndarray, + value_t: np.ndarray, + num_heads: int, + kv_heads: int, + v_scale: np.ndarray, + ) -> np.ndarray: + """Apply online P quantization and V matmul.""" + row_maxp = np.max(exp_scores, axis=-1, keepdims=True) + p_scale_dynamic = row_maxp / 127.0 + probs_fp = exp_scores / p_scale_dynamic + probs_int8 = np.rint(probs_fp.astype(np.float16)).astype(np.int8) + probs_int8 = probs_int8.astype(np.int32) + + out_int32 = self._group_matmul_pv(probs_int8, value_t, num_heads, kv_heads) + out_slice = np.zeros_like(out_int32, dtype=np.float32) + for h in range(num_heads): + de_scalev = v_scale[h] * row_maxp[h, 0, 0] / 127.0 + out_slice[h, :, :] = out_int32[h, :, :].astype(np.float32) * de_scalev + + return out_slice + + def _group_matmul(self, query_block: np.ndarray, key_block: np.ndarray, + num_heads: int, kv_heads: int) -> np.ndarray: + """Group matmul for Q @ K^T with GQA support. + + Query and key blocks should be pre-transposed before calling this method. + - query_block: [num_heads, ql, head_size] + - key_block: [kv_heads, head_size, kl] + + Args: + query_block: Query [num_heads, ql, head_size] (pre-transposed) + key_block: Key [kv_heads, head_size, kl] (pre-transposed) + num_heads: Number of query heads + kv_heads: Number of KV heads + + Returns: + Scores [num_heads, ql, kl] + """ + # Always use group loop + # When kv_heads == num_heads, group_size = 1, loop runs num_heads times + group_size = num_heads // kv_heads + scores_list = [] + for kv_h in range(kv_heads): + query_group = query_block[ + kv_h * group_size: (kv_h + 1) * group_size, :, : + ] # [group_size, ql, head_size] + key_head_block = key_block[kv_h: kv_h + 1, :, :] # [1, head_size, kl] + # query_group @ key_head_block -> + # [group_size, ql, head_size] @ [1, head_size, kl] = [group_size, ql, kl] + scores_group = np.matmul( + query_group.astype(np.float32), key_head_block.astype(np.float32) + ) + scores_list.append(scores_group) + return np.concatenate(scores_list, axis=0) # [num_heads, ql, kl] + + def _group_matmul_pv(self, prob_block: np.ndarray, value_block: np.ndarray, + num_heads: int, kv_heads: int) -> np.ndarray: + """Group matmul for P @ V with GQA support. + + Value block should be pre-transposed before calling this method. 
+        - prob_block: [num_heads, ql, kl]
+        - value_block: [kv_heads, kl, head_size]
+
+        Args:
+            prob_block: Attention probabilities [num_heads, ql, kl]
+            value_block: Value [kv_heads, kl, head_size] (pre-transposed)
+            num_heads: Number of query heads
+            kv_heads: Number of KV heads
+
+        Returns:
+            Output [num_heads, ql, head_size]
+        """
+        # Always use group loop
+        # When kv_heads == num_heads, group_size = 1, loop runs num_heads times
+        group_size = num_heads // kv_heads
+        out_list = []
+        for kv_h in range(kv_heads):
+            prob_group = prob_block[
+                kv_h * group_size: (kv_h + 1) * group_size, :, :
+            ]  # [group_size, ql, kl]
+            value_head_block = value_block[kv_h: kv_h + 1, :, :]  # [1, kl, head_size]
+            # prob_group @ value_head_block ->
+            # [group_size, ql, kl] @ [1, kl, head_size] = [group_size, ql, head_size]
+            out_group = np.matmul(
+                prob_group.astype(np.float32), value_head_block.astype(np.float32)
+            )
+            out_list.append(out_group)
+        return np.concatenate(out_list, axis=0)  # [num_heads, ql, head_size]
+
+    def validate_accuracy(
+        self,
+        output: np.ndarray,
+        golden: np.ndarray,
+        dtype: ms.dtype,
+        num_heads: int,
+        max_context_len: int,
+        is_quant: bool = False,
+    ) -> bool:
+        """Validate output accuracy against golden reference.
+
+        Uses both legacy (ratio-based) and adaptive (complexity-based) thresholds.
+        Quantization paths use stricter thresholds.
+
+        Args:
+            output: Operator output numpy array
+            golden: Golden reference numpy array
+            dtype: Data type (output dtype for quantization)
+            num_heads: Number of heads
+            max_context_len: Maximum context length
+            is_quant: Whether this is a quantization test (affects threshold selection)
+
+        Returns:
+            True if accuracy check passes
+        """
+        out_np = output.astype(np.float32)
+        golden_np = golden.astype(np.float32)
+
+        out_flat = out_np.flatten()
+        golden_flat = golden_np.flatten()
+        diff = np.abs(out_flat - golden_flat)
+        max_diff = np.max(diff)
+
+        # Legacy ratio-based validation (slightly relaxed for bf16+quant)
+        ratios = [0.001, 0.001, 0.005, 0.005]  # [rel_loose, abs_loose, rel_strict, abs_strict]
+        rel_loose, abs_loose, rel_strict, abs_strict = ratios
+
+        # Relax the strict thresholds only in the "bf16 AND quantized" case;
+        # - pure bf16 (no quantization) keeps the original strict thresholds
+        # - other quantized precisions (e.g. fp16 + int8) also keep the
+        #   original strict thresholds
+        strict_scale = 2.0 if (dtype == ms.bfloat16 and is_quant) else 1.0
+        rel_strict_eff = rel_strict * strict_scale
+        abs_strict_eff = abs_strict * strict_scale
+
+        limit_error = np.maximum(np.abs(golden_flat) * rel_loose, abs_loose)
+        strict_limit_error = np.maximum(np.abs(golden_flat) * rel_strict_eff, abs_strict_eff)
+        error_count = np.sum(diff > limit_error)
+        strict_error_count = np.sum(diff > strict_limit_error)
+
+        out_len = max(1, out_flat.shape[0])
+        accuracy_loose = 1.0 - float(error_count) / out_len
+        accuracy_strict = 1.0 - float(strict_error_count) / out_len
+
+        print(f"Max difference: {max_diff:.6e}")
+        print(f"Loose accuracy (1/1000): {accuracy_loose:.6f}")
+        print(f"Strict accuracy (5/1000): {accuracy_strict:.6f}")
+
+        # Quantization uses stricter threshold
+        # For "bf16 with quantization" scenario, relax strict error ratio;
+        # Other scenarios:
+        #   - Quantization (non bf16) still uses stricter ratio
+        #   - Non-quantization uses looser ratio
+        error_ratio = float(strict_error_count) / out_len
+        if dtype == ms.bfloat16 or is_quant:
+            # Allowed strict error ratio is rel_strict (0.005), doubled to
+            # 0.01 when bf16 is combined with quantization.
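+            # e.g. bf16 + int8 KV: rel_strict_eff = 0.005 * 2.0 = 0.01, so the
+            # check passes when at most 1% of elements exceed the strict
+            # per-element limit.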
+ legacy_pass = error_ratio <= rel_strict_eff + else: + legacy_pass = error_ratio <= ratios[0] + + # Adaptive validation based on computation complexity + calc_times = num_heads * max_context_len + 4 + if dtype == ms.bfloat16: + # For bf16, especially with quantization (int8 KV / full QKV quant), + # relax adaptive threshold by one bit compared to the original setting. + base = 2 ** (-7) if calc_times < 2048 else 2 ** (-6) + error_factor = base * (2.0 if is_quant else 1.0) + elif dtype == ms.float16: + error_factor = 2 ** (-8) if calc_times < 2048 else 2 ** (-7) + else: # float32 + if calc_times < 2048: + error_factor = 2 ** (-11) + elif calc_times < 16384: + error_factor = 2 ** (-10) + else: + error_factor = 2 ** (-9) + + error_threshold = np.maximum(np.abs(golden_flat), 1.0) * error_factor + adaptive_pass = np.all(diff <= error_threshold) + + print(f"Calculation complexity: {calc_times}") + print(f"Error factor: {error_factor:.6e}") + print(f"Adaptive test: {'PASS' if adaptive_pass else 'FAIL'}") + print(f"Legacy test: {'PASS' if legacy_pass else 'FAIL'}") + + return bool(adaptive_pass or legacy_pass) + + +class PagedAttentionNet(nn.Cell): + """MindSpore network wrapper for paged_attention operator. + + Handles CPU transfer of sequence length tensors internally. + """ + + def __init__(self, q_head_num: int, qk_scale: float, kv_head_num: int, mask_type: int, + batch_run_status_enable: bool = False, quant_type: int = QUANT_UNQUANT, + out_data_type: int = -1, has_quant_offset: bool = False, compress_type: int = 0, + calc_type: int = 0, scale_type: int = 0, input_layout: int = INPUT_LAYOUT_BSND, + mla_v_dim: int = 0, input_format: int = INPUT_FORMAT_ND): + super().__init__() + self.q_head_num = int(q_head_num) + self.qk_scale = float(qk_scale) + self.kv_head_num = int(kv_head_num) + self.mask_type = int(mask_type) + self.batch_run_status_enable = bool(batch_run_status_enable) + self.quant_type = int(quant_type) + self.out_data_type = int(out_data_type) + self.has_quant_offset = bool(has_quant_offset) + self.compress_type = int(compress_type) + self.calc_type = int(calc_type) + self.scale_type = int(scale_type) + self.input_layout = int(input_layout) + self.mla_v_dim = int(mla_v_dim) + self.input_format = int(input_format) + self._is_pynative = context.get_context("mode") == context.PYNATIVE_MODE + + def construct(self, query, key_cache, value_cache, block_tables, + attn_mask, batch_run_status, + k_descale, k_offset, v_descale, v_offset, + razor_offset, p_scale, log_n, + q_seq_lens, kv_seq_lens): + """Forward pass with CPU tensor transfer for sequence lengths.""" + # Transfer sequence length tensors to CPU (required by internal runner) + needs_q_seq = self.calc_type == 1 and q_seq_lens is not None + if self._is_pynative: + kv_seq_cpu = kv_seq_lens.move_to("CPU") + q_seq_cpu = q_seq_lens.move_to("CPU") if needs_q_seq else None + else: + kv_seq_cpu = ops.move_to(kv_seq_lens, "CPU") + q_seq_cpu = ops.move_to(q_seq_lens, "CPU") if needs_q_seq else None + + return ms_custom_ops.paged_attention( + query, + key_cache, + value_cache, + block_tables, + kv_seq_cpu, + attn_mask, + batch_run_status, + k_descale, + k_offset, + v_descale, + v_offset, + razor_offset, + p_scale, + log_n, + q_seq_cpu if self.calc_type == 1 else None, + self.q_head_num, + self.qk_scale, + self.kv_head_num, + self.mask_type, + self.batch_run_status_enable, + self.quant_type, + self.out_data_type, + self.has_quant_offset, + self.compress_type, + self.calc_type, + self.scale_type, + self.input_layout, + self.mla_v_dim, + 
self.input_format, + ) + + +def _set_dynamic_shapes_for_pa(net: PagedAttentionNet, test_config: dict) -> None: + """Configure dynamic input shapes for PagedAttentionNet.""" + # Extract key params + num_heads = int(test_config['num_heads']) + kv_heads = int(test_config['kv_heads']) + head_size_qk = int(test_config['head_size']) + head_size_vo = int(test_config.get('mla_v_dim', 0) or head_size_qk) + q_dtype = test_config['q_dtype'] + kv_dtype = test_config['kv_dtype'] + mask_type = int(test_config['mask_type']) + quant_type = int(test_config.get('quant_type', QUANT_UNQUANT)) + has_quant_offset = bool(test_config.get('has_quant_offset', False)) + # Dynamic shapes (None for variable dims) + query_dyn = Tensor(shape=[None, num_heads, head_size_qk], dtype=q_dtype) + key_dyn = Tensor(shape=[None, None, kv_heads, head_size_qk], dtype=kv_dtype) + value_dyn = Tensor(shape=[None, None, kv_heads, head_size_vo], dtype=kv_dtype) + block_tables_dyn = Tensor(shape=[None, None], dtype=ms.int32) + + # Mask shapes by type + mask_dyn = None + mask_dtype = test_config.get('expected_dtype', q_dtype) + if mask_type == MASK_ALIBI: + mask_dyn = Tensor(shape=[None, num_heads, None, None], dtype=mask_dtype) + elif mask_type == MASK_NORM: + # NORM mask: shape varies by mode + # Decode: (num_tokens, 1, max_context_len) + # Prefill: (batch_size, max_q_len, max_context_len) + mask_dyn = Tensor(shape=[None, None, None], dtype=mask_dtype) + elif mask_type == MASK_SPEC: + mask_dyn = Tensor(shape=[None, None], dtype=mask_dtype) + elif mask_type in (MASK_FREE, MASK_UNDEFINED): + mask_dyn = None + + # Optional tensors + batch_run_status_enable = test_config.get('batch_run_status_enable', False) + if batch_run_status_enable: + batch_run_status_dyn = Tensor(shape=[None], dtype=ms.int32) + else: + batch_run_status_dyn = None + + # Descale/offset shapes depend on quantization type + # We only need correct rank/dtype here; lengths are dynamic + if quant_type in [QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE]: + # per-head scales in descale slots, and optional p_scale (offline) + k_descale_dyn = Tensor(shape=[None], dtype=ms.float32) + v_descale_dyn = Tensor(shape=[None], dtype=ms.float32) + if quant_type == QUANT_QKV_OFFLINE: + p_scale_dyn = Tensor(shape=[None], dtype=ms.float32) + else: + p_scale_dyn = None + k_offset_dyn = None + v_offset_dyn = None + elif quant_type == DEQUANT_FUSION: + # per-element descale; use dynamic length 1-D + k_descale_dyn = Tensor(shape=[None], dtype=ms.float32) + v_descale_dyn = Tensor(shape=[None], dtype=ms.float32) + p_scale_dyn = None + k_offset_dyn = Tensor(shape=[None], dtype=ms.int32) if has_quant_offset else None + v_offset_dyn = Tensor(shape=[None], dtype=ms.int32) if has_quant_offset else None + else: + k_descale_dyn = None + v_descale_dyn = None + k_offset_dyn = None + v_offset_dyn = None + p_scale_dyn = None + + # q_seq_lens / kv_seq_lens must be static shapes in set_inputs + batch_size = len(test_config.get('context_lens', [])) + q_seq_dyn = Tensor(shape=[batch_size], dtype=ms.int32) + kv_seq_dyn = Tensor(shape=[batch_size], dtype=ms.int32) + + # razor_offset/log_n are unused in our tests; keep None + net.set_inputs( + query_dyn, key_dyn, value_dyn, block_tables_dyn, + mask_dyn, batch_run_status_dyn, + k_descale_dyn, k_offset_dyn, v_descale_dyn, v_offset_dyn, + None, p_scale_dyn, None, + q_seq_dyn, kv_seq_dyn, + ) + + +def _run_paged_attention_test(generator: PagedAttentionDataGenerator, test_config: dict, + run_mode: int, validate_accuracy: bool = True, dynamic: bool = False): + """Execute paged 
attention test with given configuration. + + Following the refactored workflow: + 1. Generate numpy arrays for golden calculation + 2. Compute golden reference using numpy arrays + 3. Convert numpy arrays to Tensors for network input + 4. Execute network + 5. Validate output against golden + + Args: + generator: Data generator instance + test_config: Test configuration dictionary + run_mode: GRAPH_MODE or PYNATIVE_MODE + validate_accuracy: Whether to validate accuracy against golden reference + dynamic: Whether to use dynamic shapes + """ + context.set_context(device_target="Ascend", mode=run_mode) + + # Extract configuration + num_heads = test_config['num_heads'] + kv_heads = test_config['kv_heads'] + head_size = test_config['head_size'] + block_size = test_config['block_size'] + num_blocks = test_config['num_blocks'] + context_lens = test_config['context_lens'] + q_dtype = test_config['q_dtype'] + kv_dtype = test_config['kv_dtype'] + mask_type = test_config['mask_type'] + qk_scale = test_config.get('qk_scale', 1.0 / math.sqrt(head_size)) + + # Get q_seq_lens from test_config (directly configure q_seq_lens) + # If not provided, default to [1] * batch_size for basic decode mode + q_seq_lens_config = test_config.get('q_seq_lens', [1] * len(context_lens)) + + # Determine quantization type + quant_type = test_config.get('quant_type', QUANT_UNQUANT) + has_quant_offset = test_config.get('has_quant_offset', False) + is_full_quant = quant_type in (QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE) + + # MLA configuration + mla_v_dim = test_config.get('mla_v_dim', 0) + + # Step 1: Generate numpy arrays for golden calculation + # Determine mask dtype to match expected output dtype when provided + mask_out_dtype = test_config.get('expected_dtype') + inputs_dict = generator.generate_inputs( + num_heads, + kv_heads, + head_size, + block_size, + num_blocks, + q_seq_lens_config, + context_lens, + q_dtype, + kv_dtype, + mask_type, + quant_type, + has_quant_offset, + mla_v_dim, + mask_out_dtype, + ) + + # Step 2: Compute golden reference using numpy arrays + golden_output_dtype = test_config.get('expected_dtype') if is_full_quant else None + golden = generator.compute_golden_reference( + inputs_dict['query'], + inputs_dict['key_cache'], + inputs_dict['value_cache'], + inputs_dict['block_tables'], + inputs_dict['q_seq_lens'], + inputs_dict['kv_seq_lens'], + qk_scale, + inputs_dict['mask'], + mask_type, + inputs_dict['k_descale'], + inputs_dict['v_descale'], + inputs_dict['k_offset'], + inputs_dict['v_offset'], + inputs_dict['k_scale_per_head'], + inputs_dict['v_scale_per_head'], + inputs_dict['p_scale'], + output_dtype=golden_output_dtype, + mla_v_dim=mla_v_dim, + mask_dtype=inputs_dict['mask_dtype'] + ) + + # Step 3: Convert numpy arrays to Tensors for network input + query = Tensor(inputs_dict['query'], dtype=q_dtype) + key_cache = Tensor(inputs_dict['key_cache'], dtype=kv_dtype) + value_cache = Tensor(inputs_dict['value_cache'], dtype=kv_dtype) + block_tables = Tensor(inputs_dict['block_tables'], dtype=ms.int32) + q_seq_lens = Tensor(inputs_dict['q_seq_lens'], dtype=ms.int32) + kv_seq_lens = Tensor(inputs_dict['kv_seq_lens'], dtype=ms.int32) + mask = ( + Tensor(inputs_dict['mask'], dtype=inputs_dict['mask_dtype']) + if inputs_dict['mask'] is not None + else None + ) + + # Convert quantization parameters to Tensors + k_descale = (Tensor(inputs_dict['k_descale'], dtype=ms.float32) + if inputs_dict['k_descale'] is not None else None) + v_descale = (Tensor(inputs_dict['v_descale'], dtype=ms.float32) + if 
inputs_dict['v_descale'] is not None else None) + k_offset = (Tensor(inputs_dict['k_offset'], dtype=ms.int32) + if inputs_dict['k_offset'] is not None else None) + v_offset = (Tensor(inputs_dict['v_offset'], dtype=ms.int32) + if inputs_dict['v_offset'] is not None else None) + k_scale_per_head = (Tensor(inputs_dict['k_scale_per_head'], dtype=ms.float32) + if inputs_dict['k_scale_per_head'] is not None else None) + v_scale_per_head = (Tensor(inputs_dict['v_scale_per_head'], dtype=ms.float32) + if inputs_dict['v_scale_per_head'] is not None else None) + p_scale = (Tensor(inputs_dict['p_scale'], dtype=ms.float32) + if inputs_dict['p_scale'] is not None else None) + + # Optional batch_run_status (not generated by generate_inputs) + batch_run_status = test_config.get('batch_run_status') + + # Optional format conversion (NZ format in PyNative) + input_format = test_config.get('input_format', INPUT_FORMAT_ND) + # Create network + net = PagedAttentionNet( + q_head_num=num_heads, + qk_scale=qk_scale, + kv_head_num=kv_heads, + mask_type=mask_type, + batch_run_status_enable=test_config.get('batch_run_status_enable', False), + quant_type=test_config.get('quant_type', QUANT_UNQUANT), + out_data_type=test_config.get('out_data_type', -1), + has_quant_offset=test_config.get('has_quant_offset', False), + compress_type=test_config.get('compress_type', 0), + calc_type=test_config.get('calc_type', 0), + scale_type=test_config.get('scale_type', 0), + input_layout=test_config.get('input_layout', INPUT_LAYOUT_BSND), + mla_v_dim=test_config.get('mla_v_dim', 0), + input_format=input_format, + ) + + # Select correct descale parameters based on quantization type + # For Dequant Fusion: use k_descale, v_descale (per-element) + # For Full QKV Quant: use k_scale_per_head, v_scale_per_head (per-head) in descale position + quant_type = test_config.get('quant_type', QUANT_UNQUANT) + if quant_type in [QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE]: + # Full quantization: pass per-head scales in descale position + final_k_descale = k_scale_per_head + final_v_descale = v_scale_per_head + else: + # Dequant Fusion or Unquant: use original descale + final_k_descale = k_descale + final_v_descale = v_descale + + # Configure dynamic shapes if requested + if dynamic: + _set_dynamic_shapes_for_pa(net, test_config) + + # Execute operator + output = net( + query, key_cache, value_cache, block_tables, + mask, batch_run_status, + final_k_descale, k_offset, final_v_descale, v_offset, + None, p_scale, None, # razor_offset, p_scale, log_n + q_seq_lens, kv_seq_lens + ) + + # Validate output shape and dtype + num_tokens = sum(q_seq_lens_config) # Calculate from q_seq_lens + # MLA: output dimension is mla_v_dim (head_size_vo), not head_size (head_size_qk) + output_head_dim = mla_v_dim if mla_v_dim > 0 else head_size + expected_shape = (num_tokens, num_heads, output_head_dim) + assert tuple(output.shape) == expected_shape, \ + f"Shape mismatch: {output.shape} vs {expected_shape}" + + if test_config.get('expected_dtype'): + expected_dtype = test_config['expected_dtype'] + assert output.dtype == expected_dtype, \ + f"Dtype mismatch: {output.dtype} vs {expected_dtype}" + + # Step 5: Validate accuracy if requested + if validate_accuracy: + # Convert output to numpy for validation + output_np = output.asnumpy() + + # For quantization, use output dtype for validation and set is_quant flag + validate_dtype = test_config.get('expected_dtype', q_dtype) + is_quant_test = quant_type != QUANT_UNQUANT + max_ctx_len = max(context_lens) + assert 
generator.validate_accuracy(output_np, golden, validate_dtype, + num_heads, max_ctx_len, is_quant_test) + + +# ======================================== +# Test Cases - Organized by Feature Category +# ======================================== +# +# Test organization: +# 1. Basic Functionality Tests +# - Decode path across GRAPH/PYNATIVE, dynamic/static, fp16/bf16; mask-free baseline +# 2. Mask Type Tests +# - NORM (triangular), ALIBI positional bias, SPEC for MTP/parallel decoding +# 3. GQA Tests +# - Head ratios such as 8:1, 4:2, 32:8 +# 4. Configuration Variation Tests +# - Odd/non-16-aligned head counts, larger head size, small/large seq lens +# varied batch seq lens, small block size +# 5. Quantization Tests +# - Dequant Fusion (int8 KV) with/without offsets; Full QKV quant (offline/online) +# output dtype selection +# 6. Combined Feature Tests +# - GQA + int8 KV (with/without offsets), ALIBI + GQA +# 7. BF16 + int8 KV (Dequant Fusion) +# - Varied kv seq lens, head sizes (incl. non-standard), GQA ratios, block sizes, edge cases +# 8. MLA (Multi-Head Latent Attention) Tests +# - Split/combined cache, varied Q/K vs V/O dims, NORM mask, varied block sizes +# - Full QKV quant (online/offline), unaligned dims +# 9. MLA + MTP Combined Tests +# - Prefill (large/small embeds, fp16/bf16, with/without SPEC) +# - Multi-token outputs under SPEC, scalability with many heads +# 10. Lookahead Decoding Tests +# - Mixed q/k lengths, single-head, varied block sizes, very long contexts, with GQA +# + + +# Common test decorators +def _paged_attention_test(test_func): + """Apply common decorators for PagedAttention tests. + + Adds shared marks used by most tests in this module to reduce repetition. + """ + decorators = [ + pytest.mark.level0, + pytest.mark.platform_arm_ascend910b_training, + pytest.mark.env_onecard, + ] + for decorator in reversed(decorators): + test_func = decorator(test_func) + return test_func + + +# ==================== 1. 
Basic Functionality Tests ==================== +# Test basic features: data types, execution modes, simple configurations + +@_paged_attention_test +@pytest.mark.parametrize('ms_dtype', [ms.float16, ms.bfloat16]) +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dynamic', [False, True]) +def test_pa_basic_dtype_and_mode(ms_dtype, run_mode, dynamic): + """ + Feature: PagedAttention - basic decode path across modes and dtypes + Description: Unquantized decode with random inputs over GRAPH/PYNATIVE, + static/dynamic, fp16/bf16 + Expectation: Runs successfully; output shape matches and matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(3001) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode: 1 token per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms_dtype, + 'kv_dtype': ms_dtype, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + } + _run_paged_attention_test( + generator, + test_config, + run_mode, + validate_accuracy=True, + dynamic=dynamic, + ) + + +@_paged_attention_test +def test_pa_no_mask(): + """ + Feature: PagedAttention - mask-free decode + Description: Simplest decode with no attention mask, fp16 inputs + Expectation: Operator executes; output shape/dtype correct and accuracy within tolerance + """ + generator = PagedAttentionDataGenerator(4115) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode: 1 token per sequence + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +# ==================== 2. 
Mask Type Tests ==================== +# Test different attention mask types + +@_paged_attention_test +@pytest.mark.parametrize('ms_dtype', [ms.float16, ms.bfloat16]) +def test_pa_norm_mask(ms_dtype): + """ + Feature: PagedAttention - normal causal mask + Description: Decode with triangular causal mask under fp16/bf16 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4116) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode: 1 token per sequence + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms_dtype, + 'kv_dtype': ms_dtype, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('ms_dtype', [ms.float16, ms.bfloat16]) +def test_pa_alibi_mask(ms_dtype): + """ + Feature: PagedAttention - ALIBI positional bias + Description: Decode with ALIBI bias per head using fp16/bf16 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4100) + test_config = { + 'q_seq_lens': [1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [500, 500], + 'q_dtype': ms_dtype, + 'kv_dtype': ms_dtype, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('ms_dtype', [ms.float16, ms.bfloat16]) +def test_pa_spec_mask_mtp(ms_dtype): + """ + Feature: PagedAttention - SPEC mask for MTP + Description: Multi-token prediction (q_len=2) with SPEC mask in fp16/bf16 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(3002) + test_config = { + 'q_seq_lens': [2, 2, 2, 2], # MTP: 2 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms_dtype, + 'kv_dtype': ms_dtype, + 'mask_type': MASK_SPEC, + 'qk_scale': 0.01, + 'calc_type': 1, # Enable MTP mode + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +# ==================== 3. 
GQA (Grouped Query Attention) Tests ==================== +# Test various query-to-kv head ratios + +@_paged_attention_test +@pytest.mark.parametrize('ms_dtype', [ms.float16, ms.bfloat16]) +def test_pa_gqa_8to1(ms_dtype): + """ + Feature: PagedAttention - GQA 8:1 ratio + Description: Decode with 8 query heads and 1 KV head (ALIBI mask) + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4101) + test_config = { + 'q_seq_lens': [1] * 13, # Basic decode + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms_dtype, + 'kv_dtype': ms_dtype, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_gqa_4to2(): + """ + Feature: PagedAttention - GQA 4:2 ratio + Description: Decode with 4 query heads and 2 KV heads (ALIBI mask) + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4102) + test_config = { + 'q_seq_lens': [1] * 13, # Basic decode + 'num_heads': 4, + 'kv_heads': 2, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_gqa_32to8(): + """ + Feature: PagedAttention - GQA 32:8 ratio + Description: Decode with 32 query heads and 8 KV heads (ALIBI mask) + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4103) + test_config = { + 'q_seq_lens': [1] * 13, # Basic decode + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +# ==================== 4. 
Configuration Variation Tests ==================== +# Test different head counts, dimensions, block sizes, sequence lengths + +@_paged_attention_test +def test_pa_odd_heads(): + """ + Feature: PagedAttention - odd head count + Description: Decode with 7 heads to verify non-power-of-2 handling + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4104) + test_config = { + 'q_seq_lens': [1] * 13, # Basic decode + 'num_heads': 7, + 'kv_heads': 7, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_head_not_aligned(): + """ + Feature: PagedAttention - non-16-aligned head count + Description: Decode with 20 heads (bf16) to verify edge-case handling + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4110) + test_config = { + 'q_seq_lens': [1, 1], # Basic decode + 'num_heads': 20, + 'kv_heads': 20, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [500, 500], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.bfloat16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_headsize_256(): + """ + Feature: PagedAttention - larger head dimension + Description: Decode with head_size=256 and block_size=16, no mask + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4107) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 16, + 'kv_heads': 16, + 'head_size': 256, + 'block_size': 16, + 'num_blocks': 512, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(256), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_blocksize_16(): + """ + Feature: PagedAttention - small block size + Description: Decode with block_size=16 and fp16, no mask + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4108) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 16, + 'num_blocks': 512, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_small_seqlen(): + """ + Feature: PagedAttention - small sequence lengths + Description: Decode with kv_seq_len=33 and NORM mask + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4105) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [33, 33, 33, 33], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + 
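+
+# Editor's note: the configs in this section vary `context_lens`, `block_size` and
+# `num_blocks` by hand. As an illustrative, hedged sketch (a hypothetical helper,
+# not used by the tests and not part of the operator API), the minimum block
+# budget a config needs follows from each sequence occupying
+# ceil(context_len / block_size) block-table entries:
+def _min_blocks_required(context_lens, block_size):
+    """Smallest total number of KV-cache blocks able to hold all sequences' tokens."""
+    return sum((ctx + block_size - 1) // block_size for ctx in context_lens)
+
+
+# For example, context_lens=[192, 193, 194, 195] with block_size=128 needs
+# _min_blocks_required(...) == 8 blocks, comfortably below the num_blocks=64
+# used in the surrounding configs.
+
+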
+@_paged_attention_test +def test_pa_large_seqlen(): + """ + Feature: PagedAttention - large sequence lengths + Description: Decode with kv_seq_len=3000 (bf16) and NORM mask + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4106) + test_config = { + 'q_seq_lens': [1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 128, + 'context_lens': [3000, 3000], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.bfloat16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_multi_batch_varied_seqlens(): + """ + Feature: PagedAttention - varied sequence lengths across batch + Description: Decode with kv_seq_lens [100, 500, 1000, 2000] and NORM mask + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4111) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [100, 500, 1000, 2000], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +# ==================== 5. Quantization Tests ==================== +# Test KV cache quantization and full QKV quantization + +@_paged_attention_test +def test_pa_dequant_fusion_kv_int8(): + """ + Feature: PagedAttention - Dequant Fusion (int8 KV) + Description: Decode with int8 KV cache and per-element descales (no offsets) + Expectation: Operator executes; output dtype/shape correct and accuracy within tolerance + """ + generator = PagedAttentionDataGenerator(3003) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dynamic', [False, True]) +def test_pa_dequant_fusion_with_offset(run_mode, dynamic): + """ + Feature: PagedAttention - Dequant Fusion with offsets (int8 KV) + Description: Decode testing k_offset/v_offset handling across modes and dynamic shapes + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(3004) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + 'quant_type': DEQUANT_FUSION, + 'has_quant_offset': True, + } + _run_paged_attention_test( + generator, + test_config, + run_mode, + validate_accuracy=True, + dynamic=dynamic, + ) + + +@_paged_attention_test +@pytest.mark.parametrize('out_dtype_sel, expect_ms_dtype', [(1, ms.float16), (27, ms.bfloat16)]) +def test_pa_full_quant_qkv_offline(out_dtype_sel, expect_ms_dtype): + """ + Feature: PagedAttention - Full QKV quant (offline) + 
Description: int8 Q/K/V with per-head K/V scales and P scale; fp16/bf16 outputs + Expectation: Operator executes; output dtype matches selection and accuracy within tolerance + """ + generator = PagedAttentionDataGenerator(3004) + test_config = { + 'q_seq_lens': [1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [64, 64], + 'q_dtype': ms.int8, + 'kv_dtype': ms.int8, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(128), + 'quant_type': QUANT_QKV_OFFLINE, + 'out_data_type': out_dtype_sel, + 'expected_dtype': expect_ms_dtype, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('quant_type', [QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE]) +@pytest.mark.parametrize('out_dtype_sel, expect_ms_dtype', [(1, ms.float16), (27, ms.bfloat16)]) +@pytest.mark.parametrize('dynamic', [False, True]) +def test_pa_full_quant_qkv_configs(quant_type, out_dtype_sel, expect_ms_dtype, dynamic): + """ + Feature: PagedAttention - Full QKV quant (offline/online) configs + Description: Validate offline vs online quant, layouts, and output dtypes + with dynamic/static shapes + Expectation: Operator executes; output dtype/accuracy meet expectations + """ + generator = PagedAttentionDataGenerator(4001 + quant_type) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [96, 64, 64, 64], + 'q_dtype': ms.int8, + 'kv_dtype': ms.int8, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(128), + 'quant_type': quant_type, + 'out_data_type': out_dtype_sel, + 'expected_dtype': expect_ms_dtype, + } + _run_paged_attention_test( + generator, + test_config, + context.GRAPH_MODE, + validate_accuracy=True, + dynamic=dynamic, + ) + + +# ==================== 6. 
Combined Feature Tests ==================== +# Test combinations of multiple features + +@_paged_attention_test +def test_pa_gqa_with_int8_kv(): + """ + Feature: PagedAttention - GQA with int8 KV (Dequant Fusion) + Description: bf16 query, 32:8 GQA, int8 KV, NORM mask + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4112) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0, + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_gqa_int8_with_offset(): + """ + Feature: PagedAttention - GQA with int8 KV and offsets + Description: fp16 query, 32:8 GQA, Dequant Fusion with k/v offsets + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4113) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0, + 'quant_type': DEQUANT_FUSION, + 'has_quant_offset': True, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_alibi_with_gqa(): + """ + Feature: PagedAttention - ALIBI with GQA + Description: ALIBI bias combined with 8:1 GQA in fp16 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(4114) + test_config = { + 'q_seq_lens': [1] * 13, # Basic decode + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +# ==================== 7. 
BFloat16 Quantization Tests ==================== +# Test bfloat16 with quantization (int8 KV cache) + +@_paged_attention_test +def test_pa_bf16_int8_kv_basic(): + """ + Feature: PagedAttention - bf16 with int8 KV (Dequant Fusion) + Description: Decode with bf16 Q and int8 KV over moderate sequence length + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(6001) + test_config = { + 'q_seq_lens': [1], # Basic decode + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [768], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('k_seqlen', [128, 513, 768, 1024, 1025]) +def test_pa_bf16_int8_kv_varied_seqlen(k_seqlen): + """ + Feature: PagedAttention - bf16 with int8 KV across sequence lengths + Description: Decode across varied kv_seq_len including non-aligned sizes + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(6002 + k_seqlen) + test_config = { + 'q_seq_lens': [1], + 'num_heads': 24, + 'kv_heads': 24, + 'head_size': 64, + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [k_seqlen], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(64), + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('head_size', [32, 33, 64, 128]) +def test_pa_bf16_int8_kv_varied_headsize(head_size): + """ + Feature: PagedAttention - bf16 with int8 KV across head sizes + Description: Decode with varied head_size including non-standard 33 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(6100 + head_size) + test_config = { + 'q_seq_lens': [1], + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': head_size, + 'block_size': 16 if head_size * 16 <= 128 * 128 else 128, + 'num_blocks': 512, + 'context_lens': [512], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(head_size), + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('num_heads,kv_heads', [(4, 1), (24, 2), (64, 32)]) +def test_pa_bf16_int8_kv_gqa_combinations(num_heads, kv_heads): + """ + Feature: PagedAttention - bf16 with int8 KV and GQA variants + Description: Decode with varied GQA ratios under bf16+int8 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(6200 + num_heads * 100 + kv_heads) + test_config = { + 'q_seq_lens': [1], + 'num_heads': num_heads, + 'kv_heads': kv_heads, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [766], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +@pytest.mark.parametrize('block_size', [16, 128]) +def test_pa_bf16_int8_kv_varied_blocksize(block_size): + """ + Feature: PagedAttention - bf16 with int8 KV across block 
sizes + Description: Decode with block_size in {16, 128} under bf16+int8 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(6300 + block_size) + test_config = { + 'q_seq_lens': [1], + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 64, + 'block_size': block_size, + 'num_blocks': 1024, + 'context_lens': [1024], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(64), + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_bf16_int8_kv_edge_case_309(): + """ + Feature: PagedAttention - bf16 with int8 KV edge case (k_len=309) + Description: Decode with k_len=309, head_size=32, block_size=16 + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(6400) + test_config = { + 'q_seq_lens': [1], + 'num_heads': 64, + 'kv_heads': 64, + 'head_size': 32, + 'block_size': 16, + 'num_blocks': 256, + 'context_lens': [309], + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.int8, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(32), + 'quant_type': DEQUANT_FUSION, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +# ==================== 8. MLA (Multi-Head Latent Attention) Tests ==================== +# Test MLA scenarios with different Q/K and V/O head dimensions +# MLA特性: head_size_qk != head_size_vo, 需要设置mla_v_head_size参数 + +@_paged_attention_test +def test_pa_mla_split_cache_basic(): + """ + Feature: PagedAttention + MLA - split KV cache basic + Description: MLA with head_size_qk=576, head_size_vo=512, no mask + Expectation: Graph mode executes and matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(7001) + test_config = { + 'q_seq_lens': [1] * 20, # 20 sequences, 1 token each + 'num_heads': 4, + 'kv_heads': 1, + 'head_size': 576, # Q/K head size + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [128] * 20, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(576), + 'mla_v_dim': 512, # V/O head size (different from Q/K) + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_split_cache_large_groupnum(): + """ + Feature: PagedAttention + MLA - split cache with large head count + Description: MLA with 128 Q heads and 1 KV head (128:1), no mask + Expectation: Graph mode executes and matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(7002) + test_config = { + 'q_seq_lens': [1] * 20, + 'num_heads': 128, + 'kv_heads': 1, + 'head_size': 576, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [256] * 20, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(576), + 'mla_v_dim': 512, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +@pytest.mark.parametrize('head_size_qk,head_size_vo', [(576, 512), (192, 128)]) +def test_pa_mla_varied_head_dimensions(head_size_qk, head_size_vo): + """ + Feature: PagedAttention + MLA - varied head dimensions + Description: Validate MLA with different Q/K and V/O sizes + Expectation: Graph mode executes and matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(7100 + head_size_qk) + 
test_config = { + 'q_seq_lens': [1] * 32, + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': head_size_qk, + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [256] * 32, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(head_size_qk), + 'mla_v_dim': head_size_vo, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_with_norm_mask(): + """ + Feature: PagedAttention + MLA - NORM mask + Description: MLA with varied kv_seq_lens and causal mask + Expectation: Graph mode executes and matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(7200) + test_config = { + 'q_seq_lens': [1] * 9, + 'num_heads': 128, + 'kv_heads': 1, + 'head_size': 576, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [3000, 300, 1400, 33, 65, 1, 16, 1400, 300], # Varied lengths + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(576), + 'mla_v_dim': 512, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +@pytest.mark.parametrize('block_size', [128, 64]) +def test_pa_mla_varied_blocksize(block_size): + """ + Feature: PagedAttention + MLA - varied block sizes + Description: MLA with block_size in {128, 256}, no mask + Expectation: Graph mode executes and matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(7300 + block_size) + test_config = { + 'q_seq_lens': [1] * 16, + 'num_heads': 16, + 'kv_heads': 1, + 'head_size': 576, + 'block_size': block_size, + 'num_blocks': 256, + 'context_lens': [256] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(576), + 'mla_v_dim': 512, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_int8_kv_quant(): + """ + Feature: PagedAttention + MLA - full QKV quant (online) + Description: MLA with int8 Q/K/V, online quant, bf16 output + Expectation: Graph mode executes; output dtype/accuracy are correct + """ + generator = PagedAttentionDataGenerator(7400) + test_config = { + 'q_seq_lens': [1] * 20, + 'num_heads': 40, + 'kv_heads': 1, + 'head_size': 576, + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [1024] * 20, + 'q_dtype': ms.int8, # Full quantization: Q is int8 + 'kv_dtype': ms.int8, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(576), + 'quant_type': QUANT_QKV_ONLINE, # Online quantization + 'out_data_type': 27, # bfloat16 output + 'expected_dtype': ms.bfloat16, + 'mla_v_dim': 512, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_int8_kv_quant_offline(): + """ + Feature: PagedAttention + MLA - full QKV quant (offline) + Description: MLA with int8 Q/K/V, offline quant, varied kv_seq_lens, bf16 output + Expectation: Graph mode executes; output dtype/accuracy are correct + """ + generator = PagedAttentionDataGenerator(7500) + test_config = { + 'q_seq_lens': [1] * 9, + 'num_heads': 128, + 'kv_heads': 1, + 'head_size': 576, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [3000, 300, 1400, 33, 65, 1, 16, 1400, 300], + 'q_dtype': ms.int8, # Full quantization: Q is int8 + 'kv_dtype': ms.int8, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(576), + 
'quant_type': QUANT_QKV_OFFLINE, # Offline quantization + 'out_data_type': 27, # bfloat16 output + 'expected_dtype': ms.bfloat16, + 'mla_v_dim': 512, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_unaligned_embed(): + """ + Feature: PagedAttention + MLA - unaligned head dimensions + Description: MLA with Q/K=290 and V/O=130, full QKV quant + Expectation: Graph mode executes; output dtype/accuracy are correct + """ + generator = PagedAttentionDataGenerator(7600) + test_config = { + 'q_seq_lens': [1] * 25, + 'num_heads': 25, + 'kv_heads': 1, + 'head_size': 290, # Non-standard Q/K dimension + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [128] * 25, + 'q_dtype': ms.int8, + 'kv_dtype': ms.int8, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(290), + 'quant_type': QUANT_QKV_ONLINE, + 'out_data_type': 1, # float16 output + 'expected_dtype': ms.float16, + 'mla_v_dim': 130, # Non-standard V/O dimension + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +# ==================== 9. MLA + MTP Combined Tests ==================== +# Test MLA (Multi-Head Latent Attention) combined with MTP (Multi-Token Prediction) +# These tests cover scenarios where Q/K and V/O have different dimensions AND +# multiple tokens are predicted per sequence (q_seqlen > 1) + +@_paged_attention_test +def test_pa_mla_mtp_prefill_large_embed_no_mask_fp16(): + """ + Feature: PagedAttention + MLA + MTP - prefill large embed (fp16) + Description: Prefill q_len=128, Q/K=576, V/O=512, no mask + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8001) + batch = 27 + test_config = { + 'q_seq_lens': [128] * batch, # Prefill: 128 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 576, # Q/K head size + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [512] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, # No mask + 'qk_scale': 1.0 / math.sqrt(576), + 'calc_type': 1, # MTP mode (required to pass q_seq_lens into op) + 'mla_v_dim': 512, # V/O head size (different from Q/K) + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_mtp_prefill_large_embed_no_mask_bf16(): + """ + Feature: PagedAttention + MLA + MTP - prefill large embed (bf16) + Description: Prefill q_len=64, Q/K=576, V/O=512, no mask, bf16 dtype + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8002) + batch = 32 + test_config = { + 'q_seq_lens': [64] * batch, # Prefill: 64 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 576, # Q/K head size + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [273] * batch, + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.bfloat16, + 'mask_type': MASK_UNDEFINED, # No mask + 'qk_scale': 1.0 / math.sqrt(576), + 'calc_type': 1, # MTP mode + 'mla_v_dim': 512, # V/O head size + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_mtp_prefill_small_embed_no_mask_fp16(): + """ + Feature: PagedAttention + MLA + MTP - prefill small embed (fp16) + Description: Prefill q_len=23, Q/K=192, V/O=128, no mask + Expectation: Graph mode executes; output matches golden within 
tolerance + """ + generator = PagedAttentionDataGenerator(8003) + batch = 25 + test_config = { + 'q_seq_lens': [23] * batch, # Prefill: 23 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 192, # Q/K head size (smaller) + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [156] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, # No mask + 'qk_scale': 1.0 / math.sqrt(192), + 'calc_type': 1, # MTP mode + 'mla_v_dim': 128, # V/O head size (smaller) + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_mtp_prefill_small_embed_spec_mask_fp16(): + """ + Feature: PagedAttention + MLA + MTP - prefill small embed with SPEC (fp16) + Description: Prefill q_len=256, Q/K=192, V/O=128 using SPEC mask + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8004) + batch = 27 + test_config = { + 'q_seq_lens': [256] * batch, # Prefill: 256 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 192, # Q/K head size + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [766] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, # SPEC mask for parallel decoding + 'qk_scale': 1.0 / math.sqrt(192), + 'calc_type': 1, # MTP mode + 'mla_v_dim': 128, # V/O head size + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_mtp_prefill_small_embed_spec_mask_bf16(): + """ + Feature: PagedAttention + MLA + MTP - prefill small embed with SPEC (bf16) + Description: Prefill q_len=1056, Q/K=192, V/O=128 using SPEC mask and bf16 dtype + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8005) + batch = 11 + test_config = { + 'q_seq_lens': [1056] * batch, # Large prefill: 1056 tokens per sequence + 'num_heads': 16, + 'kv_heads': 1, + 'head_size': 192, # Q/K head size + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [1963] * batch, + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.bfloat16, + 'mask_type': MASK_SPEC, # SPEC mask + 'qk_scale': 1.0 / math.sqrt(192), + 'calc_type': 1, # MTP mode + 'mla_v_dim': 128, # V/O head size + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_mtp_multi_token_large_embed_spec_mask_fp16(): + """ + Feature: PagedAttention + MLA - MTP with SPEC mask (fp16) + Description: MTP (q_len=4) with large Q/K (576) and V/O (512) head dims using SPEC mask + Expectation: Graph mode executes and matches golden within tolerance; shapes/dtypes correct + """ + generator = PagedAttentionDataGenerator(8006) + batch = 27 + test_config = { + 'q_seq_lens': [4] * batch, # MTP: 4 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 576, # Q/K head size (large) + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [512] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, # SPEC mask for MTP + 'qk_scale': 1.0 / math.sqrt(576), + 'calc_type': 1, # MTP mode + 'mla_v_dim': 512, # V/O head size (large) + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_mtp_multi_token_large_embed_spec_mask_bf16(): + """ + Feature: PagedAttention + MLA - MTP with SPEC 
mask (bf16) + Description: MTP (q_len=4) in bf16 with large Q/K (576) and V/O (512) using SPEC mask + Expectation: Graph mode executes and matches golden within tolerance; shapes/dtypes correct + """ + generator = PagedAttentionDataGenerator(8007) + batch = 32 + test_config = { + 'q_seq_lens': [4] * batch, # MTP: 4 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 576, # Q/K head size + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [273] * batch, + 'q_dtype': ms.bfloat16, + 'kv_dtype': ms.bfloat16, + 'mask_type': MASK_SPEC, # SPEC mask for MTP + 'qk_scale': 1.0 / math.sqrt(576), + 'calc_type': 1, # MTP mode + 'mla_v_dim': 512, # V/O head size + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +@_paged_attention_test +def test_pa_mla_mtp_multi_token_large_heads_long_context_fp16(): + """ + Feature: PagedAttention + MLA - MTP scalability (fp16) + Description: Stress with 128 heads and 4096 context under SPEC mask in MTP + Expectation: Graph mode executes; results align with golden within tolerance + """ + generator = PagedAttentionDataGenerator(8008) + batch = 16 + test_config = { + 'q_seq_lens': [4] * batch, # MTP: 4 tokens per sequence + 'num_heads': 128, # Very large head count + 'kv_heads': 1, + 'head_size': 576, # Q/K head size + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [4096] * batch, # Very long context + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, # SPEC mask for MTP + 'qk_scale': 1.0 / math.sqrt(576), + 'calc_type': 1, # MTP mode + 'mla_v_dim': 512, # V/O head size + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True) + + +# ==================== 10. Lookahead Decoding Tests ==================== +# Test lookahead/speculative decoding scenarios (MASK_SPEC with varied q/k lengths) + +@_paged_attention_test +def test_pa_lookahead_mixed_lengths(): + """ + Feature: PagedAttention - lookahead/speculative decoding with SPEC mask + Description: Mixed q_seq_lens [1,15,30,6] and kv_seq_lens [10,64,64,64] in MTP mode + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(5001) + test_config = { + 'q_seq_lens': [1, 15, 30, 6], # MTP: varied tokens per sequence + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [10, 64, 64, 64], # Different context lengths per sequence + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, # MTP mode for lookahead + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_lookahead_single_head(): + """ + Feature: PagedAttention - lookahead with single head + Description: MTP with single head and varied sequences, SPEC mask + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(5002) + test_config = { + 'q_seq_lens': [256, 256, 15], # MTP: varied tokens + 'num_heads': 1, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [512, 512, 2048], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_lookahead_blocksize_16(): + """ + 
Feature: PagedAttention - lookahead with small block size + Description: SPEC mask MTP using block_size=16 to verify granularity handling + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(5003) + test_config = { + 'q_seq_lens': [1, 15, 30, 6], # MTP: varied tokens per sequence + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 16, + 'num_blocks': 512, + 'context_lens': [10, 64, 64, 64], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_lookahead_blocksize_32(): + """ + Feature: PagedAttention - lookahead with medium block size + Description: SPEC mask MTP using block_size=32 with varied q/k lengths + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(5004) + test_config = { + 'q_seq_lens': [15, 103, 1024], # MTP: varied tokens + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 32, + 'num_blocks': 256, + 'context_lens': [64, 103, 1025], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_lookahead_very_long_context(): + """ + Feature: PagedAttention - lookahead with very long contexts + Description: Decode (q_len=1) with kv_seq_lens up to 33k to test long-context handling + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(5005) + test_config = { + 'q_seq_lens': [1, 1, 1], # Decode mode with very long context + 'num_heads': 40, + 'kv_heads': 40, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1234, + 'context_lens': [13333, 23333, 33331], # Very long contexts + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_test +def test_pa_lookahead_with_gqa(): + """ + Feature: PagedAttention - lookahead with GQA + Description: SPEC mask MTP using kv_heads < num_heads (GQA) + Expectation: Graph mode executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(5006) + test_config = { + 'q_seq_lens': [1, 15, 30, 6], # MTP: varied tokens per sequence + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [10, 64, 64, 64], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, + } + _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) diff --git a/tests/st/test_custom_paged_attention_nz.py b/tests/st/test_custom_paged_attention_nz.py new file mode 100644 index 0000000..31c82fe --- /dev/null +++ b/tests/st/test_custom_paged_attention_nz.py @@ -0,0 +1,1500 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""System tests for paged_attention operator on Ascend 310P with NZ format. + +This module tests the paged_attention operator on 310P hardware with NZ data format. +NZ format conversion (ND->NZ) is performed within the graph using trans_data operator. + +Supported Features (310P NZ path): +- Basic decode with fp16 +- Mask types: UNDEFINED, NORM, ALIBI, SPEC, FREE +- Grouped Query Attention (GQA) +- Multi-Token Prediction (MTP) / Lookahead decoding +- Prefill scenarios + +Unsupported Features (310P NZ path): +- Dequant Fusion / Full QKV quantization (INT8 KV cache) +- MLA (Multi-Head Latent Attention) - use dedicated MLA operator instead +""" + +import math +from typing import Dict, Optional + +import numpy as np +import pytest +from test_custom_paged_attention import ( + PagedAttentionDataGenerator, + MASK_UNDEFINED, + MASK_NORM, + MASK_ALIBI, + MASK_SPEC, + MASK_FREE, + QUANT_UNQUANT, + QUANT_QKV_OFFLINE, + QUANT_QKV_ONLINE, + INPUT_LAYOUT_BSND, + INPUT_FORMAT_NZ, +) +import mindspore as ms +from mindspore import Tensor, context, nn, ops +import ms_custom_ops + +# ========== Constants ========== + +# Trans_data operation types +_TRANSDATA_ND_TO_NZ = 1 # Convert ND format to NZ format +_TRANSDATA_NZ_TO_ND = 0 # Convert NZ format to ND format + +# ========== Helper Functions ========== + + +def _select_descale_tensors( + quant_type: int, + k_descale: Optional[Tensor], + v_descale: Optional[Tensor], + k_scale_per_head: Optional[Tensor], + v_scale_per_head: Optional[Tensor] +) -> tuple: + """Select appropriate descale tensors based on quantization type. + + Args: + quant_type: Quantization type + k_descale: Per-element K descale (for DEQUANT_FUSION) + v_descale: Per-element V descale (for DEQUANT_FUSION) + k_scale_per_head: Per-head K scale (for QUANT_QKV_*) + v_scale_per_head: Per-head V scale (for QUANT_QKV_*) + + Returns: + Tuple of (selected_k_descale, selected_v_descale) + """ + if quant_type in (QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE): + return k_scale_per_head, v_scale_per_head + return k_descale, v_descale + + +def _check_310p_capability(test_config: Dict) -> None: + """Check if test configuration is supported on 310P NZ path. + + Raises: + pytest.skip: If configuration uses unsupported features + """ + quant_type = test_config.get('quant_type', QUANT_UNQUANT) + has_quant_offset = test_config.get('has_quant_offset', False) + mla_v_dim = test_config.get('mla_v_dim', 0) + + # Check quantization support + is_full_quant = quant_type in (QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE) + if quant_type != QUANT_UNQUANT or is_full_quant or has_quant_offset: + pytest.skip( + "310P paged_attention does not support quantization on NZ path " + "(Dequant Fusion/Full Quant/Offsets)." + ) + + # Check MLA support + if int(mla_v_dim or 0) > 0: + pytest.skip( + "310P paged_attention does not support MLA on NZ path. " + "Use dedicated MLA operator for MLA functionality." 
+ ) + + +# ========== Network Definition ========== + +class PagedAttentionNzNet(nn.Cell): + """MindSpore network wrapper for paged_attention with NZ format conversion. + + This network: + 1) Converts ND format inputs to NZ format via trans_data + 2) Executes paged_attention + 3) Converts NZ output back to ND format + + Platform-specific handling: + - 310P (NZ): context_lens stays on NPU as tensor input + - 910B (ND): context_lens moved to CPU and passed via param + + Note: No manual padding/reshape is required; trans_data and kernel handle layout/alignment internally. + """ + + def __init__( + self, + q_head_num: int, + qk_scale: float, + kv_head_num: int, + mask_type: int, + batch_run_status_enable: bool = False, + quant_type: int = QUANT_UNQUANT, + out_data_type: int = -1, + has_quant_offset: bool = False, + compress_type: int = 0, + calc_type: int = 0, + scale_type: int = 0, + input_layout: int = INPUT_LAYOUT_BSND, + mla_v_dim: int = 0, + input_format: int = INPUT_FORMAT_NZ, + ): + """Initialize PagedAttentionNzNet. + + Args: + q_head_num: Number of query heads + qk_scale: QK scale factor + kv_head_num: Number of KV heads (for GQA) + mask_type: Attention mask type + batch_run_status_enable: Enable batch run status + quant_type: Quantization type + out_data_type: Output data type override + has_quant_offset: Whether quantization uses offsets + compress_type: Compression type + calc_type: Calculation type (0=decode, 1=MTP) + scale_type: Scale type + input_layout: Input tensor layout + mla_v_dim: MLA V dimension (0 means no MLA) + input_format: Input format (0=ND, 1=NZ) + """ + super().__init__() + self.q_head_num = int(q_head_num) + self.qk_scale = float(qk_scale) + self.kv_head_num = int(kv_head_num) + self.mask_type = int(mask_type) + self.batch_run_status_enable = bool(batch_run_status_enable) + self.quant_type = int(quant_type) + self.out_data_type = int(out_data_type) + self.has_quant_offset = bool(has_quant_offset) + self.compress_type = int(compress_type) + self.calc_type = int(calc_type) + self.scale_type = int(scale_type) + self.input_layout = int(input_layout) + self.mla_v_dim = int(mla_v_dim) + self.input_format = int(input_format) + self._is_pynative = context.get_context("mode") == context.PYNATIVE_MODE + + def construct( + self, + query_2d: Tensor, + key_cache_3d: Tensor, + value_cache_3d: Tensor, + block_tables: Tensor, + attn_mask_nd: Optional[Tensor], + batch_run_status: Optional[Tensor], + k_descale: Optional[Tensor], + k_offset: Optional[Tensor], + v_descale: Optional[Tensor], + v_offset: Optional[Tensor], + razor_offset: Optional[Tensor], + p_scale: Optional[Tensor], + log_n: Optional[Tensor], + q_seq_lens: Optional[Tensor], + kv_seq_lens: Tensor, + ) -> Tensor: + """Forward pass with ND->NZ conversion and paged attention execution. + + This method performs the following transformations: + 1. Query: (T, H*D) -> trans_data -> NZ format + 2. KV Cache: (B, S, KH*D) -> trans_data -> NZ format + 3. Mask: ND format -> trans_data -> NZ format (if present) + 4. Execute paged_attention with NZ inputs + 5. Output: NZ format -> trans_data -> (T, H*D) -> (T, H, D) + + Note: Input tensors are expected to be already reshaped to 2D/3D format. + Use _prepare_inputs_for_nz_network() to prepare inputs before calling this network. 
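+
+        A minimal call sketch (a non-authoritative example; argument values and
+        shapes are illustrative only), after preparing inputs with
+        _prepare_inputs_for_nz_network():
+
+            net = PagedAttentionNzNet(q_head_num=32, qk_scale=1.0 / 128 ** 0.5,
+                                      kv_head_num=32, mask_type=MASK_UNDEFINED)
+            out = net(query_2d, key_cache_3d, value_cache_3d, block_tables,
+                      None, None, None, None, None, None, None, None, None,
+                      None, kv_seq_lens)  # -> (tokens, q_head_num, head_size)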
+ + Args: + query_2d: Query tensor in 2D format (tokens, num_heads * head_size) + key_cache_3d: Key cache tensor in 3D format (num_blocks, block_size, kv_heads * head_size_qk) + value_cache_3d: Value cache tensor in 3D format (num_blocks, block_size, kv_heads * head_size_vo) + block_tables: Block table mapping (batch_size, max_blocks_per_seq) + attn_mask_nd: Attention mask in ND format (optional) + batch_run_status: Batch run status (optional) + k_descale: Key descale factors (optional) + k_offset: Key quantization offsets (optional) + v_descale: Value descale factors (optional) + v_offset: Value quantization offsets (optional) + razor_offset: Razor offset (unused) + p_scale: P scale factor (optional) + log_n: Log N tensor (unused) + q_seq_lens: Query sequence lengths (optional, for MTP) + kv_seq_lens: KV sequence lengths (context lengths) + + Returns: + Output tensor in ND format (tokens, num_heads, head_size) + """ + # Step 1: Ensure query is contiguous for trans_data + # query_2d = query_2d.contiguous() + + # Step 2: Convert ND to NZ format (2D/3D ND -> 4D NZ) + query_nz = ms_custom_ops.trans_data(query_2d, transdata_type=_TRANSDATA_ND_TO_NZ) + key_cache_nz = ms_custom_ops.trans_data(key_cache_3d, transdata_type=_TRANSDATA_ND_TO_NZ) + value_cache_nz = ms_custom_ops.trans_data(value_cache_3d, transdata_type=_TRANSDATA_ND_TO_NZ) + + # Step 3: Convert mask if present (ND -> NZ via trans_data) + # Masks are in ND format with proper shapes: + # - ALIBI: (batch*num_heads, 16, max_context_len_pad) + # - NORM: (num_tokens, 16, max_context_len_pad) for decode or (batch, max_q_len, max_context_len) for prefill + # - SPEC: (num_tokens, max_context_len) + # - MASK_FREE: (1, 128, 128) + attn_mask_nz = None + if attn_mask_nd is not None: + attn_mask_nz = ms_custom_ops.trans_data(attn_mask_nd, transdata_type=_TRANSDATA_ND_TO_NZ) + + # Step 4: Handle seq_lens based on platform + # 310P NZ: kv_seq_lens stays on NPU, q_seq_lens moves to CPU + kv_seq_input = kv_seq_lens # Keep on NPU for 310P + + q_seq_input = None + if self.calc_type == 1 and q_seq_lens is not None: + # MTP mode: move q_seq_lens to CPU + if self._is_pynative: + q_seq_input = q_seq_lens.move_to("CPU") + else: + q_seq_input = ops.move_to(q_seq_lens, "CPU") + + # Step 5: Execute paged attention with NZ format + out_nz = ms_custom_ops.paged_attention( + query_nz, key_cache_nz, value_cache_nz, block_tables, kv_seq_input, + attn_mask_nz, batch_run_status, + k_descale, k_offset, v_descale, v_offset, + razor_offset, p_scale, log_n, + q_seq_input, + self.q_head_num, self.qk_scale, self.kv_head_num, self.mask_type, + self.batch_run_status_enable, self.quant_type, self.out_data_type, + self.has_quant_offset, self.compress_type, self.calc_type, + self.scale_type, self.input_layout, self.mla_v_dim, self.input_format + ) + + # Step 6: Convert output from NZ back to ND format + out_nd_2d = ms_custom_ops.trans_data(out_nz, transdata_type=_TRANSDATA_NZ_TO_ND) + + # Step 7: Reshape output from 2D to 3D TND format + # Calculate token count and output head size from query_2d shape + token_count = query_2d.shape[0] + output_head_size = self.mla_v_dim if self.mla_v_dim > 0 else (query_2d.shape[1] // self.q_head_num) + # (tokens, H*D_out) -> (tokens, H_out, D_out) + out_nd = ops.reshape(out_nd_2d, (token_count, self.q_head_num, output_head_size)) + + return out_nd + + +# ========== Test Execution ========== + +def _prepare_inputs_for_nz_network( + query: np.ndarray, + key_cache: np.ndarray, + value_cache: np.ndarray, + num_heads: int, + kv_heads: int, + 
q_dtype: ms.dtype, + kv_dtype: ms.dtype +) -> tuple: + """Prepare query and KV cache tensors for NZ network input. + + This function reshapes numpy arrays to the format expected by PagedAttentionNzNet: + - Query: (tokens, num_heads, head_size) -> (tokens, num_heads * head_size) + - Key cache: (num_blocks, block_size, kv_heads, head_size_qk) -> (num_blocks, block_size, kv_heads * head_size_qk) + - Value cache: (num_blocks, block_size, kv_heads, head_size_vo) -> (num_blocks, block_size, kv_heads * head_size_vo) + + Args: + query: Query numpy array (tokens, num_heads, head_size) + key_cache: Key cache numpy array (num_blocks, block_size, kv_heads, head_size_qk) + value_cache: Value cache numpy array (num_blocks, block_size, kv_heads, head_size_vo) + num_heads: Number of query heads + kv_heads: Number of KV heads + q_dtype: Query dtype + kv_dtype: KV cache dtype + + Returns: + Tuple of (query_2d_tensor, key_cache_3d_tensor, value_cache_3d_tensor) + """ + # Reshape query: (tokens, num_heads, head_size) -> (tokens, num_heads * head_size) + tokens = query.shape[0] + head_size = query.shape[2] + query_2d = query.reshape(tokens, num_heads * head_size) + + # Reshape key cache: (num_blocks, block_size, kv_heads, head_size_qk) -> + # (num_blocks, block_size, kv_heads * head_size_qk) + num_blocks = key_cache.shape[0] + block_size = key_cache.shape[1] + head_size_qk = key_cache.shape[3] + key_cache_3d = key_cache.reshape(num_blocks, block_size, kv_heads * head_size_qk) + + # Reshape value cache: (num_blocks, block_size, kv_heads, head_size_vo) -> + # (num_blocks, block_size, kv_heads * head_size_vo) + head_size_vo = value_cache.shape[3] + value_cache_3d = value_cache.reshape(num_blocks, block_size, kv_heads * head_size_vo) + + # Convert to Tensors + query_tensor = Tensor(query_2d, dtype=q_dtype) + key_cache_tensor = Tensor(key_cache_3d, dtype=kv_dtype) + value_cache_tensor = Tensor(value_cache_3d, dtype=kv_dtype) + + return query_tensor, key_cache_tensor, value_cache_tensor + + +def _prepare_mask_for_network( + mask: Optional[np.ndarray], + mask_type: int, + num_tokens: int, + batch_size: int, + num_heads: int, + context_lens: list, + q_seq_lens: list, + dtype: ms.dtype +) -> tuple: + """Prepare mask tensor for network input (ND format with NZ padding, ready for trans_data). + + This function handles: + 1. NZ padding for MASK_NORM and MASK_ALIBI (16-alignment) + 2. Reshape for ALIBI: (batch, num_heads, q_len, kv_len) -> (batch*num_heads, q_len, kv_len) + 3. 
Special handling for MASK_FREE and MASK_UNDEFINED + + Args: + mask: Original mask numpy array (without NZ padding, or None) + mask_type: Mask type constant + num_tokens: Total number of query tokens + batch_size: Batch size + num_heads: Number of query heads + context_lens: Context lengths for each sequence + q_seq_lens: Query sequence lengths for each sequence + dtype: Target dtype + + Returns: + Tuple of (prepared_mask, effective_mask_type) + - prepared_mask: Mask tensor in ND format with NZ padding (or None) + - effective_mask_type: Actual mask type to use (may differ from input for MASK_UNDEFINED) + """ + effective_mask_type = mask_type + np_dtype = ms.dtype_to_nptype(dtype) + max_context_len = max(context_lens) + max_q_len = max(q_seq_lens) + + # Handle different mask types + if mask_type == MASK_ALIBI: + # ALIBI: Apply NZ padding then reshape + if mask is None: + return None, effective_mask_type + + # Step 1: Apply NZ padding (16-alignment) + # Original shape: (batch, num_heads, q_len, max_context_len) + # Padded shape: (batch, num_heads, 16, max_context_len_pad) + max_context_len_pad = ((max_context_len + 15) // 16) * 16 + q_len_for_pad = 1 if max_q_len == 1 else max_q_len + mask_padded = np.zeros((batch_size, num_heads, 16, max_context_len_pad), dtype=np_dtype) + mask_padded[:, :, :q_len_for_pad, :max_context_len] = mask + + # Step 2: Reshape from (batch, num_heads, 16, max_context_len_pad) to + # (batch*num_heads, 16, max_context_len_pad) + mask_np = mask_padded.reshape((batch_size * num_heads, 16, max_context_len_pad)) + return Tensor(mask_np, dtype=dtype), effective_mask_type + + if mask_type == MASK_FREE: + # MASK_FREE: create fixed 128x128 upper triangular mask + mask_size = 128 + mask_np = np.zeros((mask_size, mask_size), dtype=np.float16) + mask_np[np.triu_indices(mask_size, k=1)] = 1.0 + mask_np = mask_np * -60000.0 + # Reshape to 3D (1, 128, 128) before trans_data + mask_nd = mask_np.reshape((1, mask_size, mask_size)) + return Tensor(mask_nd, dtype=dtype), effective_mask_type + + if mask_type == MASK_NORM: + # NORM: Apply NZ padding + if mask is None: + return None, effective_mask_type + + if max_q_len == 1: + # Decode mode: Apply NZ padding + # Original shape: (num_tokens, 1, max_context_len) + # Padded shape: (num_tokens, 16, max_context_len_pad) + max_context_len_pad = ((max_context_len + 15) // 16) * 16 + mask_padded = np.zeros((num_tokens, 16, max_context_len_pad), dtype=np_dtype) + mask_padded[:, :1, :max_context_len] = mask + return Tensor(mask_padded, dtype=dtype), effective_mask_type + # Prefill mode: no padding needed (already in correct shape) + return Tensor(mask, dtype=dtype), effective_mask_type + + if mask_type == MASK_SPEC: + # SPEC: keep original shape (no padding needed) + if mask is None: + return None, effective_mask_type + return Tensor(mask, dtype=dtype), effective_mask_type + + # MASK_UNDEFINED + # For decode mode (q_len=1), create zero-filled NORM mask to force masked kernel path + # This ensures NZ padding columns don't contribute to softmax denominator + return None, effective_mask_type + + +def _run_paged_attention_nz_test( + generator: PagedAttentionDataGenerator, + test_config: Dict, + run_mode: int, + validate_accuracy: bool = True, + dynamic: bool = False, +) -> None: + """Execute paged attention test with NZ format conversion. + + Following the pattern of test_custom_flash_attention_encoder_nz.py: + 1. Generate numpy arrays for golden calculation (no padding/reshaping) + 2. Compute golden reference using numpy arrays + 3. 
Reshape/prepare tensors for network input (ND format) + 4. Network internally converts ND->NZ via trans_data + 5. Validate output against golden + + Args: + generator: Data generator for inputs and golden reference + test_config: Test configuration dictionary + run_mode: Execution mode (GRAPH_MODE or PYNATIVE_MODE) + validate_accuracy: Whether to validate against golden reference + dynamic: Whether to use dynamic shapes + + Raises: + AssertionError: If output shape/dtype/accuracy validation fails + pytest.skip: If configuration uses unsupported 310P features + """ + context.set_context(device_target="Ascend", mode=run_mode) + + # Check 310P capability constraints + _check_310p_capability(test_config) + + # Extract test parameters + num_heads = test_config['num_heads'] + kv_heads = test_config['kv_heads'] + head_size = test_config['head_size'] + q_seq_lens_config = test_config.get('q_seq_lens', [1] * len(test_config['context_lens'])) + context_lens = test_config['context_lens'] + qk_scale = test_config.get('qk_scale', 1.0 / math.sqrt(head_size)) + quant_type = test_config.get('quant_type', QUANT_UNQUANT) + mla_v_dim = test_config.get('mla_v_dim', 0) + num_tokens = sum(q_seq_lens_config) + batch_size = len(q_seq_lens_config) + + # Step 1: Generate numpy arrays for golden calculation (without NZ padding) + # Masks are in their natural shapes for golden calculation + # NZ padding will be applied later in _prepare_mask_for_network + inputs_dict = generator.generate_inputs( + num_heads, kv_heads, head_size, + test_config['block_size'], test_config['num_blocks'], + q_seq_lens_config, context_lens, + test_config['q_dtype'], test_config['kv_dtype'], + test_config['mask_type'], quant_type, + test_config.get('has_quant_offset', False), mla_v_dim, + test_config.get('expected_dtype') + ) + + # Step 2: Compute golden reference using numpy arrays (float32 for precision) + # This is done before any tensor conversion or padding + is_full_quant = quant_type in (QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE) + golden_output_dtype = test_config.get('expected_dtype') if is_full_quant else None + + golden = generator.compute_golden_reference( + inputs_dict['query'], + inputs_dict['key_cache'], + inputs_dict['value_cache'], + inputs_dict['block_tables'], + inputs_dict['q_seq_lens'], + inputs_dict['kv_seq_lens'], + qk_scale, + inputs_dict['mask'], + test_config['mask_type'], + inputs_dict['k_descale'], + inputs_dict['v_descale'], + inputs_dict['k_offset'], + inputs_dict['v_offset'], + inputs_dict['k_scale_per_head'], + inputs_dict['v_scale_per_head'], + inputs_dict['p_scale'], + output_dtype=golden_output_dtype, + mla_v_dim=mla_v_dim, + mask_dtype=inputs_dict['mask_dtype'] + ) + + # Step 3: Prepare ND format tensors for network input + # Apply NZ padding and reshape masks in _prepare_mask_for_network + mask_nd, effective_mask_type = _prepare_mask_for_network( + inputs_dict['mask'], + test_config['mask_type'], + num_tokens, + batch_size, + num_heads, + context_lens, + q_seq_lens_config, + test_config['q_dtype'] + ) + + # Step 4: Create network + net = PagedAttentionNzNet( + q_head_num=num_heads, + qk_scale=qk_scale, + kv_head_num=kv_heads, + mask_type=effective_mask_type, # Use effective mask type (may be switched from UNDEFINED to NORM) + batch_run_status_enable=test_config.get('batch_run_status_enable', False), + quant_type=quant_type, + out_data_type=test_config.get('out_data_type', -1), + has_quant_offset=test_config.get('has_quant_offset', False), + compress_type=test_config.get('compress_type', 0), + 
calc_type=test_config.get('calc_type', 0), + scale_type=test_config.get('scale_type', 0), + input_layout=test_config.get('input_layout', INPUT_LAYOUT_BSND), + mla_v_dim=mla_v_dim, + input_format=INPUT_FORMAT_NZ, + ) + + # Configure dynamic shapes if requested + if dynamic: + _configure_dynamic_shapes(net, test_config) + + # Select descale tensors based on quantization type + k_descale, v_descale = _select_descale_tensors( + quant_type, + inputs_dict['k_descale'], inputs_dict['v_descale'], + inputs_dict['k_scale_per_head'], inputs_dict['v_scale_per_head'] + ) + + # Step 5: Prepare query and KV cache tensors (reshape to 2D/3D format) + query_2d_tensor, key_cache_3d_tensor, value_cache_3d_tensor = _prepare_inputs_for_nz_network( + inputs_dict['query'], + inputs_dict['key_cache'], + inputs_dict['value_cache'], + num_heads, + kv_heads, + test_config['q_dtype'], + test_config['kv_dtype'] + ) + + # Convert other numpy arrays to Tensors + block_tables_tensor = Tensor(inputs_dict['block_tables'], dtype=ms.int32) + q_seq_lens_tensor = Tensor(inputs_dict['q_seq_lens'], dtype=ms.int32) + kv_seq_lens_tensor = Tensor(inputs_dict['kv_seq_lens'], dtype=ms.int32) + + # Convert quantization parameters to Tensors + k_descale_tensor = Tensor(k_descale, dtype=ms.float32) if k_descale is not None else None + v_descale_tensor = Tensor(v_descale, dtype=ms.float32) if v_descale is not None else None + k_offset_tensor = Tensor(inputs_dict['k_offset'], dtype=ms.int32) if inputs_dict['k_offset'] is not None else None + v_offset_tensor = Tensor(inputs_dict['v_offset'], dtype=ms.int32) if inputs_dict['v_offset'] is not None else None + p_scale_tensor = Tensor(inputs_dict['p_scale'], dtype=ms.float32) if inputs_dict['p_scale'] is not None else None + + # Step 6: Execute network (trans_data happens inside) + output = net( + query_2d_tensor, + key_cache_3d_tensor, + value_cache_3d_tensor, + block_tables_tensor, + mask_nd, # Use prepared ND mask (already a Tensor from _prepare_mask_for_network) + test_config.get('batch_run_status'), + k_descale_tensor, k_offset_tensor, + v_descale_tensor, v_offset_tensor, + None, p_scale_tensor, None, # razor_offset, p_scale, log_n + q_seq_lens_tensor, + kv_seq_lens_tensor + ) + + # Step 7: Validate output + _validate_output(output, test_config, num_heads, head_size, mla_v_dim, num_tokens) + + # Step 8: Validate accuracy against golden reference + if validate_accuracy: + # Convert output to numpy for validation + output_np = output.asnumpy() + + validate_dtype = test_config.get('expected_dtype', test_config['q_dtype']) + is_quant_test = quant_type != QUANT_UNQUANT + + assert generator.validate_accuracy( + output_np, golden, validate_dtype, num_heads, max(context_lens), is_quant_test + ), "Accuracy validation failed: output does not match golden reference within tolerance" + + +def _configure_dynamic_shapes(net: PagedAttentionNzNet, test_config: Dict) -> None: + """Configure dynamic input shapes for the network. + + Note: Input shapes reflect the reshaped format (2D for query, 3D for KV cache). 
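+    ``None`` entries in the shapes below mark axes that MindSpore treats as
+    dynamic, while the sequence-length inputs keep a static ``batch_size``.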
+ + Args: + net: Network instance to configure + test_config: Test configuration dictionary + """ + num_heads = test_config['num_heads'] + kv_heads = test_config['kv_heads'] + head_size = test_config['head_size'] + q_dtype = test_config['q_dtype'] + kv_dtype = test_config['kv_dtype'] + mask_type = test_config['mask_type'] + batch_size = len(test_config['context_lens']) + mla_v_dim = test_config.get('mla_v_dim', 0) + + # Dynamic tensor shapes (already reshaped to 2D/3D) + # Query: (tokens, num_heads * head_size) + query_dyn = Tensor(shape=[None, num_heads * head_size], dtype=q_dtype) + # Key cache: (num_blocks, block_size, kv_heads * head_size_qk) + key_dyn = Tensor(shape=[None, None, kv_heads * head_size], dtype=kv_dtype) + # Value cache: (num_blocks, block_size, kv_heads * head_size_vo) + head_size_vo = mla_v_dim if mla_v_dim > 0 else head_size + value_dyn = Tensor(shape=[None, None, kv_heads * head_size_vo], dtype=kv_dtype) + block_tables_dyn = Tensor(shape=[None, None], dtype=ms.int32) + + # Mask shape depends on mask type + mask_dyn = None + if mask_type == MASK_ALIBI: + mask_dyn = Tensor(shape=[None, num_heads, None, None], dtype=q_dtype) + elif mask_type == MASK_NORM: + # NORM mask: shape varies by mode + # Decode: (num_tokens, 1, max_context_len) + # Prefill: (batch_size, max_q_len, max_context_len) + mask_dyn = Tensor(shape=[None, None, None], dtype=q_dtype) + elif mask_type == MASK_SPEC: + # SPEC mask: (num_tokens, max_context_len) + mask_dyn = Tensor(shape=[None, None], dtype=q_dtype) + elif mask_type == MASK_FREE: + # For NZ tests we provide a 3D (1, 128, 128) mask tensor before trans_data + mask_dyn = Tensor(shape=[None, None, None], dtype=q_dtype) + + # Sequence lengths (static batch size) + q_seq_dyn = Tensor(shape=[batch_size], dtype=ms.int32) + kv_seq_dyn = Tensor(shape=[batch_size], dtype=ms.int32) + + net.set_inputs( + query_dyn, key_dyn, value_dyn, block_tables_dyn, + mask_dyn, None, # mask, batch_run_status + None, None, None, None, # k_descale, k_offset, v_descale, v_offset + None, None, None, # razor_offset, p_scale, log_n + q_seq_dyn, kv_seq_dyn + ) + + +def _validate_output( + output: Tensor, + test_config: Dict, + num_heads: int, + head_size: int, + mla_v_dim: int, + num_tokens: int, +) -> None: + """Validate output tensor shape and dtype. + + Args: + output: Output tensor to validate + test_config: Test configuration + num_heads: Number of query heads + head_size: Head dimension size + mla_v_dim: MLA V dimension (0 if not MLA) + num_tokens: Expected number of tokens + + Raises: + AssertionError: If shape or dtype validation fails + """ + output_head_dim = mla_v_dim if mla_v_dim > 0 else head_size + expected_shape = (num_tokens, num_heads, output_head_dim) + + assert tuple(output.shape) == expected_shape, \ + f"Output shape mismatch: got {output.shape}, expected {expected_shape}" + + if 'expected_dtype' in test_config: + assert output.dtype == test_config['expected_dtype'], \ + f"Output dtype mismatch: got {output.dtype}, expected {test_config['expected_dtype']}" + + +# ========== Test Decorator ========== + +def _paged_attention_nz_test(test_func): + """Apply common decorators for 310P NZ tests.""" + decorators = [ + pytest.mark.level0, + pytest.mark.platform_ascend310p, + pytest.mark.env_onecard, + ] + for decorator in reversed(decorators): + test_func = decorator(test_func) + return test_func + + +# ==================== 1. 
Basic Functionality Tests (310P NZ) ==================== + +@_paged_attention_nz_test +@pytest.mark.parametrize('dynamic', [False, True]) +def test_pa_nz_basic_decode_graph(dynamic): + """ + Feature: PagedAttention (NZ 310P) - basic decode path (Graph mode) + Description: Basic decode with fp16 inputs, ND->NZ conversion within graph, dynamic shapes + Expectation: Operator executes successfully; output shape matches and matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8001) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # Basic decode: 1 token per sequence (total=4, not 16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE, validate_accuracy=True, dynamic=dynamic) + + +@_paged_attention_nz_test +@pytest.mark.parametrize('dynamic', [False, True]) +def test_pa_nz_basic_decode_pynative(dynamic): + """ + Feature: PagedAttention (NZ 310P) - basic decode path (PyNative mode) + Description: Basic decode with fp16 inputs, ND->NZ conversion, PyNative mode with 16-aligned num_tokens + Expectation: Operator executes successfully; output shape matches and matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + generator = PagedAttentionDataGenerator(8001) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE, validate_accuracy=True, dynamic=dynamic) + + +@_paged_attention_nz_test +def test_pa_nz_no_mask_graph(): + """ + Feature: PagedAttention (NZ 310P) - mask-free decode (Graph mode) + Description: Simplest decode with no attention mask, fp16 inputs, NZ format + Expectation: Operator executes; output shape/dtype correct and accuracy within tolerance + """ + generator = PagedAttentionDataGenerator(8002) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_no_mask_pynative(): + """ + Feature: PagedAttention (NZ 310P) - mask-free decode (PyNative mode) + Description: Simplest decode with no attention mask, fp16 inputs, NZ format, 16-aligned num_tokens + Expectation: Operator executes; output shape/dtype correct and accuracy within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. 
+ """ + generator = PagedAttentionDataGenerator(8002) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0, + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +# ==================== 2. Mask Type Tests (310P NZ) ==================== + +@_paged_attention_nz_test +def test_pa_nz_norm_mask_graph(): + """ + Feature: PagedAttention (NZ 310P) - normal causal mask (Graph mode) + Description: Decode with triangular causal mask, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8003) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_norm_mask_pynative(): + """ + Feature: PagedAttention (NZ 310P) - normal causal mask (PyNative mode) + Description: Decode with triangular causal mask, fp16, NZ format, 16-aligned num_tokens + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + generator = PagedAttentionDataGenerator(8003) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_alibi_mask_graph(): + """ + Feature: PagedAttention (NZ 310P) - ALIBI positional bias (Graph mode) + Description: Decode with ALIBI bias per head using fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8004) + test_config = { + 'q_seq_lens': [1, 1], + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [500, 500], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_alibi_mask_pynative(): + """ + Feature: PagedAttention (NZ 310P) - ALIBI positional bias (PyNative mode) + Description: Decode with ALIBI bias per head using fp16, NZ format, 16-aligned num_tokens + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. 
+ """ + generator = PagedAttentionDataGenerator(8004) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [500] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_spec_mask_mtp_graph(): + """ + Feature: PagedAttention (NZ 310P) - SPEC mask for MTP (Graph mode) + Description: Multi-token prediction (q_len=2) with SPEC mask in fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8005) + test_config = { + 'q_seq_lens': [2, 2, 2, 2], # MTP: 2 tokens per sequence (total=8, not 16-aligned) + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 0.01, + 'calc_type': 1, # Enable MTP mode + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_spec_mask_mtp_pynative(): + """ + Feature: PagedAttention (NZ 310P) - SPEC mask for MTP (PyNative mode) + Description: Multi-token prediction with SPEC mask in fp16, NZ format, 16-aligned num_tokens + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + generator = PagedAttentionDataGenerator(8005) + test_config = { + 'q_seq_lens': [2] * 8, # PyNative: num_tokens=16 (16-aligned), 2 tokens per sequence + 'num_heads': 32, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 64, + 'context_lens': [192] * 8, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 0.01, + 'calc_type': 1, # Enable MTP mode + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +# ==================== 3. GQA Tests (310P NZ) ==================== + +@_paged_attention_nz_test +def test_pa_nz_gqa_8to1_graph(): + """ + Feature: PagedAttention (NZ 310P) - GQA 8:1 (Graph mode) + Description: Grouped Query Attention with 8 query heads per 1 KV head, ALIBI mask, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8006) + test_config = { + 'q_seq_lens': [1] * 13, # total=13, not 16-aligned + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_gqa_8to1_pynative(): + """ + Feature: PagedAttention (NZ 310P) - GQA 8:1 (PyNative mode) + Description: Grouped Query Attention with 8 query heads per 1 KV head, ALIBI mask, fp16, NZ format, 16-aligned + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. 
+ """ + generator = PagedAttentionDataGenerator(8006) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_gqa_32to8_graph(): + """ + Feature: PagedAttention (NZ 310P) - GQA 32:8 (Graph mode) + Description: Grouped Query Attention with 32 query heads per 8 KV heads, ALIBI mask, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8007) + test_config = { + 'q_seq_lens': [1] * 13, # total=13, not 16-aligned + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_gqa_32to8_pynative(): + """ + Feature: PagedAttention (NZ 310P) - GQA 32:8 (PyNative mode) + Description: Grouped Query Attention with 32 query heads per 8 KV heads, ALIBI mask, fp16, NZ format, 16-aligned + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + generator = PagedAttentionDataGenerator(8007) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +# ==================== 4. Configuration Variation Tests (310P NZ) ==================== + +@_paged_attention_nz_test +def test_pa_nz_odd_heads_graph(): + """ + Feature: PagedAttention (NZ 310P) - Non-power-of-2 heads (Graph mode) + Description: MHA with 7 query/KV heads (non-power-of-2), NORM mask, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8008) + test_config = { + 'q_seq_lens': [1] * 13, # total=13, not 16-aligned + 'num_heads': 7, + 'kv_heads': 7, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_odd_heads_pynative(): + """ + Feature: PagedAttention (NZ 310P) - Non-power-of-2 heads (PyNative mode) + Description: MHA with 7 query/KV heads (non-power-of-2), NORM mask, fp16, NZ format, 16-aligned + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. 
+ """ + generator = PagedAttentionDataGenerator(8008) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 7, + 'kv_heads': 7, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_small_blocksize_graph(): + """ + Feature: PagedAttention (NZ 310P) - Small block size (Graph mode) + Description: Decode with small block size (16), fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8009) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # total=4, not 16-aligned + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 16, + 'num_blocks': 512, + 'context_lens': [192, 193, 194, 195], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_small_blocksize_pynative(): + """ + Feature: PagedAttention (NZ 310P) - Small block size (PyNative mode) + Description: Decode with small block size (16), fp16, NZ format, 16-aligned num_tokens + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + generator = PagedAttentionDataGenerator(8009) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 16, + 'num_blocks': 512, + 'context_lens': [192] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_UNDEFINED, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_varied_seqlens_graph(): + """ + Feature: PagedAttention (NZ 310P) - Varied context lengths (Graph mode) + Description: Decode with diverse context lengths [100, 500, 1000, 2000], NORM mask, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8010) + test_config = { + 'q_seq_lens': [1, 1, 1, 1], # total=4, not 16-aligned + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 256, + 'context_lens': [100, 500, 1000, 2000], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_varied_seqlens_pynative(): + """ + Feature: PagedAttention (NZ 310P) - Varied context lengths (PyNative mode) + Description: Decode with diverse context lengths, NORM mask, fp16, NZ format, 16-aligned num_tokens + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. 
+ """ + generator = PagedAttentionDataGenerator(8010) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 256, + # PyNative: all 16-aligned + 'context_lens': [96, 496, 1008, 2000, 96, 496, 1008, 2000, + 96, 496, 1008, 2000, 96, 496, 1008, 2000], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_NORM, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +# ==================== 5. Combined Feature Tests (310P NZ) ==================== + +@_paged_attention_nz_test +def test_pa_nz_alibi_with_gqa_graph(): + """ + Feature: PagedAttention (NZ 310P) - ALIBI + GQA (Graph mode) + Description: ALIBI mask combined with GQA (8:1), fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8014) + test_config = { + 'q_seq_lens': [1] * 13, # total=13, not 16-aligned + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 13, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_alibi_with_gqa_pynative(): + """ + Feature: PagedAttention (NZ 310P) - ALIBI + GQA (PyNative mode) + Description: ALIBI mask combined with GQA (8:1), fp16, NZ format, 16-aligned num_tokens + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + generator = PagedAttentionDataGenerator(8014) + test_config = { + 'q_seq_lens': [1] * 16, # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [88] * 16, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_ALIBI, + 'qk_scale': 1.0 / math.sqrt(128), + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +# ==================== 6. 
Lookahead/MTP Tests (310P NZ) ==================== + +@_paged_attention_nz_test +def test_pa_nz_lookahead_mixed_lengths_graph(): + """ + Feature: PagedAttention (NZ 310P) - Lookahead with mixed seq lengths (Graph mode) + Description: Speculative decoding/lookahead with mixed query lengths [1, 15, 30, 6], SPEC mask, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8017) + test_config = { + 'q_seq_lens': [1, 15, 30, 6], # total=52, not 16-aligned + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [10, 64, 64, 64], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, # Enable MTP mode + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_lookahead_mixed_lengths_pynative(): + """ + Feature: PagedAttention (NZ 310P) - Lookahead with mixed seq lengths (PyNative mode) + Description: Speculative decoding/lookahead with mixed query lengths, SPEC mask, fp16, NZ format, 16-aligned + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + generator = PagedAttentionDataGenerator(8017) + test_config = { + 'q_seq_lens': [4, 4, 4, 4], # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 32, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [10, 64, 64, 64], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, # Enable MTP mode + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_lookahead_with_gqa_graph(): + """ + Feature: PagedAttention (NZ 310P) - Lookahead + GQA (Graph mode) + Description: Speculative decoding/lookahead with GQA (32:8), SPEC mask, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + generator = PagedAttentionDataGenerator(8018) + test_config = { + 'q_seq_lens': [1, 15, 30, 6], # total=52, not 16-aligned + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [10, 64, 64, 64], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, # Enable MTP mode + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_lookahead_with_gqa_pynative(): + """ + Feature: PagedAttention (NZ 310P) - Lookahead + GQA (PyNative mode) + Description: Speculative decoding/lookahead with GQA (32:8), SPEC mask, fp16, NZ format, 16-aligned + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. 
+ """ + generator = PagedAttentionDataGenerator(8018) + test_config = { + 'q_seq_lens': [4, 4, 4, 4], # PyNative: num_tokens=16 (16-aligned) + 'num_heads': 32, + 'kv_heads': 8, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 512, + 'context_lens': [10, 64, 64, 64], + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, # Enable MTP mode + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +# ==================== 7. Prefill Tests (310P NZ) ==================== + +@_paged_attention_nz_test +def test_pa_nz_prefill_mask_free_graph(): + """ + Feature: PagedAttention (NZ 310P) - Prefill without mask (Graph mode) + Description: Prefill scenario with long query lengths (139), mask-free, fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + batch = 2 + generator = PagedAttentionDataGenerator(8019) + test_config = { + 'q_seq_lens': [128 + 11] * batch, # total=278, not 16-aligned + 'num_heads': 8, + 'kv_heads': 8, + 'head_size': 64, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [128 * 2 + 11] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_FREE, + 'qk_scale': 1.0 / math.sqrt(64), + 'calc_type': 1, # Enable MTP/prefill mode + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_prefill_mask_free_pynative(): + """ + Feature: PagedAttention (NZ 310P) - Prefill without mask (PyNative mode) + Description: Prefill scenario with long query lengths, mask-free, fp16, NZ format, 16-aligned + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. 
+ """ + batch = 2 + generator = PagedAttentionDataGenerator(8019) + test_config = { + 'q_seq_lens': [128] * batch, # PyNative: num_tokens=256 (16-aligned) + 'num_heads': 8, + 'kv_heads': 8, + 'head_size': 64, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [128 * 2] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_FREE, + 'qk_scale': 1.0 / math.sqrt(64), + 'calc_type': 1, # Enable MTP/prefill mode + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_prefill_spec_mask_graph(): + """ + Feature: PagedAttention (NZ 310P) - Prefill with SPEC mask (Graph mode) + Description: Prefill scenario with long query lengths (139), SPEC mask, GQA (8:1), fp16, NZ format + Expectation: Operator executes; output matches golden within tolerance + """ + batch = 1 + generator = PagedAttentionDataGenerator(8020) + test_config = { + 'q_seq_lens': [128 * 1 + 11] * batch, # total=139, not 16-aligned + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [128 * 2 + 11] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, # Enable MTP/prefill mode + } + _run_paged_attention_nz_test(generator, test_config, context.GRAPH_MODE) + + +@_paged_attention_nz_test +def test_pa_nz_prefill_spec_mask_pynative(): + """ + Feature: PagedAttention (NZ 310P) - Prefill with SPEC mask (PyNative mode) + Description: Prefill scenario with long query lengths, SPEC mask, GQA (8:1), fp16, NZ format, 16-aligned + Expectation: Operator executes; output matches golden within tolerance + + Note: PyNative mode requires num_tokens (H dimension) to be 16-aligned for trans_data. + """ + batch = 1 + generator = PagedAttentionDataGenerator(8020) + test_config = { + 'q_seq_lens': [128] * batch, # PyNative: num_tokens=128 (16-aligned) + 'num_heads': 8, + 'kv_heads': 1, + 'head_size': 128, + 'block_size': 128, + 'num_blocks': 1024, + 'context_lens': [128 * 2] * batch, + 'q_dtype': ms.float16, + 'kv_dtype': ms.float16, + 'mask_type': MASK_SPEC, + 'qk_scale': 1.0 / math.sqrt(128), + 'calc_type': 1, # Enable MTP/prefill mode + } + _run_paged_attention_nz_test(generator, test_config, context.PYNATIVE_MODE) -- Gitee From 36d1ca386e089a4674a9ddc39ace4b4629c2fdfd Mon Sep 17 00:00:00 2001 From: tianxiaodong3 Date: Mon, 24 Nov 2025 14:34:09 +0800 Subject: [PATCH 2/2] fixbug --- tests/st/paged_attention_ms_reference.py | 699 +++++++++++++++++++++++ tests/st/test_custom_paged_attention.py | 171 ++++-- 2 files changed, 820 insertions(+), 50 deletions(-) create mode 100644 tests/st/paged_attention_ms_reference.py diff --git a/tests/st/paged_attention_ms_reference.py b/tests/st/paged_attention_ms_reference.py new file mode 100644 index 0000000..0b1acca --- /dev/null +++ b/tests/st/paged_attention_ms_reference.py @@ -0,0 +1,699 @@ +"""MindSpore reference implementation for paged attention tests. + +This module mirrors the numpy-based logic in ``test_custom_paged_attention.py`` +using MindSpore Tensor operators so that test data generation, mask construction, +and golden reference computation can be performed fully inside MindSpore. 
+""" + +from __future__ import annotations + +import math +from typing import Dict, List, Optional + +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops + +MASK_UNDEFINED = 0 +MASK_NORM = 1 +MASK_ALIBI = 2 +MASK_SPEC = 3 +MASK_FREE = 4 + +QUANT_UNQUANT = 0 +DEQUANT_FUSION = 1 +QUANT_QKV_OFFLINE = 2 +QUANT_QKV_ONLINE = 3 + + +class PagedAttentionMsReference: + """MindSpore version of paged attention data generator and golden calc.""" + + def __init__(self, seed: int = 2025): + self.seed = int(seed) + ms.set_seed(self.seed) + self.np_rng = np.random.default_rng(self.seed) + self.round = ops.round + self.clip = ops.clip_by_value + self.concat = ops.concat + self.stack = ops.stack + # Use explicit max / exp / sum for softmax to mirror numpy reference. + self.exp = ops.exp + self.transpose = ops.transpose + self.matmul = ops.matmul + self.maximum = ops.maximum + self.abs = ops.abs + + # ---------------------------------------------------------------------- + # Random helpers + + def _uniform_np(self, shape, low, high): + """Generate numpy array from uniform distribution.""" + return self.np_rng.uniform(low, high, size=shape) + + def _randint_np(self, low, high, shape): + return self.np_rng.integers(low, high, size=shape, endpoint=False, dtype=np.int32) + + # ---------------------------------------------------------------------- + # Public APIs + + def generate_inputs_ms( + self, + num_heads: int, + kv_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + q_seq_lens: List[int], + context_lens: List[int], + q_dtype: ms.dtype, + kv_dtype: ms.dtype, + mask_type: int, + quant_type: int = QUANT_UNQUANT, + has_quant_offset: bool = False, + mla_v_dim: int = 0, + mask_out_dtype: Optional[ms.dtype] = None, + ) -> Dict[str, Tensor]: + """MindSpore variant of generate_inputs (returns Tensors).""" + is_full_quant = quant_type in (QUANT_QKV_OFFLINE, QUANT_QKV_ONLINE) + is_dequant_fusion = quant_type == DEQUANT_FUSION + is_mla = mla_v_dim > 0 + + batch_size = len(q_seq_lens) + num_tokens = sum(q_seq_lens) + head_size_qk = head_size + head_size_vo = mla_v_dim if mla_v_dim > 0 else head_size + + # Query + if q_dtype == ms.int8: + q_range = 2.0 if is_full_quant else 1.0 + query_np = self._uniform_np((num_tokens, num_heads, head_size_qk), -q_range, q_range).astype(np.float32) + query_np = np.clip(np.rint(query_np * (127.0 / q_range)), -127, 127).astype(np.int8) + query = Tensor(query_np, dtype=ms.int8) + else: + query_np = self._uniform_np((num_tokens, num_heads, head_size_qk), -1.0, 1.0).astype(np.float32) + query = Tensor(query_np, dtype=q_dtype) + + # Key/Value cache + if kv_dtype == ms.int8: + kv_range = 2.0 if is_full_quant else 4.0 + key_np = self._uniform_np((num_blocks, block_size, kv_heads, head_size_qk), -kv_range, kv_range).astype(np.float32) + key_np = np.clip(np.rint(key_np * (127.0 / kv_range)), -127, 127).astype(np.int8) + key_cache = Tensor(key_np, dtype=ms.int8) + if is_mla: + value_cache = Tensor(key_np[:, :, :, :head_size_vo], dtype=ms.int8) + else: + value_np = self._uniform_np((num_blocks, block_size, kv_heads, head_size_vo), -kv_range, kv_range).astype(np.float32) + value_np = np.clip(np.rint(value_np * (127.0 / kv_range)), -127, 127).astype(np.int8) + value_cache = Tensor(value_np, dtype=ms.int8) + else: + key_np = self._uniform_np((num_blocks, block_size, kv_heads, head_size_qk), -1.0, 1.0).astype(np.float32) + key_cache = Tensor(key_np, dtype=kv_dtype) + if is_mla: + value_cache = Tensor(key_np[:, :, :, :head_size_vo], dtype=kv_dtype) + else: + 
value_np = self._uniform_np((num_blocks, block_size, kv_heads, head_size_vo), -1.0, 1.0).astype(np.float32) + value_cache = Tensor(value_np, dtype=kv_dtype) + + # Block tables + max_context_len = max(context_lens) + max_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables_np = self._randint_np(0, num_blocks, (batch_size, max_blocks_per_seq)) + block_tables = Tensor(block_tables_np, dtype=ms.int32) + + q_seq_tensor = Tensor(q_seq_lens, ms.int32) + kv_seq_tensor = Tensor(context_lens, ms.int32) + + mask_dtype = mask_out_dtype if mask_out_dtype is not None else q_dtype + mask = self._generate_mask_tensor( + mask_type, + num_tokens, + max_context_len, + q_seq_lens, + context_lens, + mask_dtype, + num_heads, + is_mla, + ) + + k_descale = v_descale = k_offset = v_offset = None + k_scale_per_head = v_scale_per_head = p_scale = None + + if is_dequant_fusion: + k_descale_np = self._uniform_np((kv_heads * head_size_qk,), -1.0, 1.0).astype(np.float32) + v_descale_np = self._uniform_np((kv_heads * head_size_vo,), -1.0, 1.0).astype(np.float32) + k_descale = Tensor(k_descale_np, ms.float32) + v_descale = Tensor(v_descale_np, ms.float32) + if has_quant_offset: + k_offset_np = self._randint_np(-20, 20, (kv_heads * head_size_qk,)) + v_offset_np = self._randint_np(-20, 20, (kv_heads * head_size_vo,)) + k_offset = Tensor(k_offset_np, ms.int32) + v_offset = Tensor(v_offset_np, ms.int32) + elif is_full_quant: + k_scale_np = self._uniform_np((num_heads,), -1.0, 2.0).astype(np.float32) + v_scale_np = self._uniform_np((num_heads,), -1.0, 2.0).astype(np.float32) + k_scale_per_head = Tensor(k_scale_np, ms.float32) + v_scale_per_head = Tensor(v_scale_np, ms.float32) + if quant_type == QUANT_QKV_OFFLINE: + p_scale_np = self._uniform_np((num_heads,), -1.0, 2.0).astype(np.float32) + p_scale = Tensor(p_scale_np, ms.float32) + + return { + "query": query, + "key_cache": key_cache, + "value_cache": value_cache, + "block_tables": block_tables, + "q_seq_lens": q_seq_tensor, + "kv_seq_lens": kv_seq_tensor, + "mask": mask, + "k_descale": k_descale, + "v_descale": v_descale, + "k_offset": k_offset, + "v_offset": v_offset, + "k_scale_per_head": k_scale_per_head, + "v_scale_per_head": v_scale_per_head, + "p_scale": p_scale, + "q_dtype": q_dtype, + "kv_dtype": kv_dtype, + "mask_dtype": mask_dtype, + } + + # ------------------------------------------------------------------ + # Mask construction + + def _generate_mask_tensor( + self, + mask_type: int, + num_tokens: int, + max_context_len: int, + q_seq_lens: List[int], + kv_seq_lens: List[int], + dtype: ms.dtype, + num_heads: int, + is_mla: bool, + ) -> Optional[Tensor]: + mask_np = self._generate_mask_np( + mask_type, + num_tokens, + max_context_len, + q_seq_lens, + kv_seq_lens, + dtype, + num_heads, + is_mla, + ) + if mask_np is None: + return None + return Tensor(mask_np, dtype=dtype) + + def _generate_mask_np( + self, + mask_type: int, + num_tokens: int, + max_context_len: int, + q_seq_lens: List[int], + kv_seq_lens: List[int], + dtype: ms.dtype, + num_heads: int, + is_mla: bool, + ) -> Optional[np.ndarray]: + if mask_type == MASK_UNDEFINED: + return None + + if mask_type == MASK_FREE: + mask = np.zeros((num_tokens, max_context_len), dtype=np.float32) + row = 0 + for q_len, k_len in zip(q_seq_lens, kv_seq_lens): + tri = np.triu(np.ones((q_len, q_len), dtype=np.float32), k=1) * -60000.0 + start = k_len - q_len + mask[row : row + q_len, start:k_len] = tri + row += q_len + return mask + + if mask_type == MASK_NORM: + max_q_len = max(q_seq_lens) + if 
max_q_len == 1: + mask = np.zeros((num_tokens, 1, max_context_len), dtype=np.float32) + for token_idx in range(num_tokens): + if token_idx > 0: + mask[token_idx, :, :token_idx] = -10000.0 + return mask + + mask = np.zeros((len(q_seq_lens), max_q_len, max_context_len), dtype=np.float32) + for batch_idx, q_len in enumerate(q_seq_lens): + tri = np.triu(np.ones((q_len, q_len), dtype=np.float32), k=1) * -10000.0 + mask[batch_idx, -q_len:, -q_len:] = tri + return mask + + if mask_type == MASK_ALIBI: + max_q_len = max(q_seq_lens) + base_shape = ( + len(q_seq_lens), + num_heads, + 1 if max_q_len == 1 else max_q_len, + max_context_len, + ) + mask = np.zeros(base_shape, dtype=np.float32) + slopes = self._get_alibi_slopes_numpy(num_heads) + for i, (q_len, k_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): + if k_len == 0: + continue + positions = np.arange(k_len, dtype=np.float32) + alibi_bias = (positions - (k_len - 1)).astype(np.float32) + alibi_bias = alibi_bias.reshape(1, 1, k_len) + bias = slopes[:, :, None] * alibi_bias + if max_q_len == 1: + mask[i, :, :, :k_len] = bias + else: + mask[i, :, :q_len, :k_len] = bias + return mask + + if mask_type == MASK_SPEC: + mask = np.zeros((num_tokens, max_context_len), dtype=np.float32) + row = 0 + pre_mask_factor = 1.0 if (is_mla and dtype == ms.bfloat16) else -10000.0 + for q_len, k_len in zip(q_seq_lens, kv_seq_lens): + # Keep row (global token index) semantics aligned with numpy reference: + # - Always advance row by q_len whenever there are query tokens, + # even if k_len == 0 (no valid context); this matches the + # behaviour of pre_q += qseq in the numpy implementation. + if q_len == 0: + continue + if k_len == 0: + row += q_len + continue + start = max(0, k_len - q_len) + tri = np.triu(np.ones((q_len, q_len), dtype=np.float32), k=1) * pre_mask_factor + mask[row : row + q_len, start:k_len] = tri + row += q_len + return mask + + return None + + def _get_alibi_slopes(self, num_heads: int) -> Tensor: + nearest_pow = 2 ** int(math.floor(math.log2(num_heads))) + base = 2.0 ** (-8.0 / nearest_pow) + slopes = [base ** i for i in range(1, nearest_pow + 1)] + if nearest_pow < num_heads: + extra_base = 2.0 ** (-4.0 / nearest_pow) + extra = [ + extra_base ** i for i in range(1, 1 + 2 * (num_heads - nearest_pow), 2) + ] + slopes.extend(extra) + return Tensor(slopes[:num_heads], ms.float32).reshape((num_heads, 1, 1)) + + def _get_alibi_slopes_numpy(self, num_heads: int) -> np.ndarray: + nearest_pow = 2 ** int(math.floor(math.log2(num_heads))) + base = 2.0 ** (-8.0 / nearest_pow) + slopes = [base ** i for i in range(1, nearest_pow + 1)] + if nearest_pow < num_heads: + extra_base = 2.0 ** (-4.0 / nearest_pow) + extra = [ + extra_base ** i for i in range(1, 1 + 2 * (num_heads - nearest_pow), 2) + ] + slopes.extend(extra) + slopes = np.array(slopes[:num_heads], dtype=np.float32) + return slopes.reshape(num_heads, 1, 1) + + # ------------------------------------------------------------------ + # Golden reference computation + + def compute_golden_reference_ms( + self, + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + block_tables: Tensor, + q_seq_lens: Tensor, + kv_seq_lens: Tensor, + scale: float, + mask: Optional[Tensor], + mask_type: int = MASK_UNDEFINED, + k_descale: Optional[Tensor] = None, + v_descale: Optional[Tensor] = None, + k_offset: Optional[Tensor] = None, + v_offset: Optional[Tensor] = None, + k_scale_per_head: Optional[Tensor] = None, + v_scale_per_head: Optional[Tensor] = None, + p_scale: Optional[Tensor] = None, + output_dtype: 
Optional[ms.dtype] = None,
+        mla_v_dim: int = 0,
+        mask_dtype: Optional[ms.dtype] = None,
+    ) -> Tensor:
+        """Compute golden reference using MindSpore operators, mirroring the numpy reference.
+
+        To align numerically with the numpy implementation used in
+        ``test_custom_paged_attention.py``, all heavy computations are carried
+        out in float32/int32, and softmax is implemented via explicit
+        max/exp/sum rather than a fused op.
+        """
+        original_q_dtype = query.dtype
+        q = ops.cast(query, ms.float32)
+        kc = ops.cast(key_cache, ms.float32)
+        vc = ops.cast(value_cache, ms.float32)
+        num_tokens, num_heads, head_size_qk = q.shape
+        kv_heads = kc.shape[2]
+        head_size_v = vc.shape[3]
+        output_head_dim = head_size_v if mla_v_dim > 0 else head_size_qk
+
+        # Dequant fusion (int8 KV → float domain)
+        if k_descale is not None and v_descale is not None:
+            k_scale = ops.reshape(ops.cast(k_descale, ms.float32), (kv_heads, -1))
+            v_scale = ops.reshape(ops.cast(v_descale, ms.float32), (kv_heads, -1))
+            if k_offset is not None:
+                k_off = ops.reshape(ops.cast(k_offset, ms.float32), (kv_heads, -1))
+                kc = kc + k_off[None, None, :, :]
+            if v_offset is not None:
+                v_off = ops.reshape(ops.cast(v_offset, ms.float32), (kv_heads, -1))
+                vc = vc + v_off[None, None, :, :]
+            kc = kc * k_scale[None, None, :, :]
+            vc = vc * v_scale[None, None, :, :]
+
+        # --------------------------------------------------------------
+        # Simulate low-precision matmul behaviour for fp16/bf16 inputs:
+        # - For bf16/fp16 Q inputs, many hardware kernels effectively
+        #   run Q/K/V matmuls in that low precision and accumulate to
+        #   fp32. To make the golden output closer to the real kernel
+        #   (and more consistent with the numpy reference thresholds),
+        #   we round Q/K/V onto the corresponding grid and then lift
+        #   back to float32 for the rest of the computation.
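+        #   (The down-cast/up-cast pair below snaps each element onto the
+        #   fp16/bf16 grid; exactly representable values pass through
+        #   unchanged, so only tiny rounding perturbations are introduced.)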
+ # -------------------------------------------------------------- + if original_q_dtype in (ms.bfloat16, ms.float16): + lowp_dtype = ms.bfloat16 if original_q_dtype == ms.bfloat16 else ms.float16 + q = ops.cast(ops.cast(q, lowp_dtype), ms.float32) + kc = ops.cast(ops.cast(kc, lowp_dtype), ms.float32) + vc = ops.cast(ops.cast(vc, lowp_dtype), ms.float32) + + is_full_quant = k_scale_per_head is not None and v_scale_per_head is not None + if is_full_quant: + q = ops.cast(q, ms.int32) + kc = ops.cast(kc, ms.int32) + vc = ops.cast(vc, ms.int32) + + out = ops.zeros( + (num_tokens, num_heads, output_head_dim), ms.float32 + ) + q_seq_list = q_seq_lens.asnumpy().tolist() + kv_seq_list = kv_seq_lens.asnumpy().tolist() + block_tables_np = block_tables.asnumpy().tolist() + block_size = kc.shape[1] + q_ptr = 0 + + for batch_idx, (ql, kl) in enumerate(zip(q_seq_list, kv_seq_list)): + if ql == 0 or kl == 0: + q_ptr += ql + continue + + keys = [] + values = [] + block_table = block_tables_np[batch_idx] + for j in range(kl): + block_id = int(block_table[j // block_size]) + offset = j % block_size + keys.append(kc[block_id, offset]) + values.append(vc[block_id, offset]) + key_seq = self.stack(keys, axis=0) + value_seq = self.stack(values, axis=0) + + q_slice = q[q_ptr : q_ptr + ql] + q_t = self.transpose(q_slice, (1, 0, 2)) + k_t = self.transpose(key_seq, (1, 2, 0)) + v_t = self.transpose(value_seq, (1, 0, 2)) + + if is_full_quant: + out_slice = self._compute_full_quant_attention_ms( + q_t, + k_t, + v_t, + num_heads, + kv_heads, + scale, + k_scale_per_head, + v_scale_per_head, + p_scale, + mask, + mask_type, + batch_idx, + q_ptr, + ql, + kl, + mla_v_dim, + mask_dtype, + ) + else: + out_slice = self._compute_float_attention_ms( + q_t, + k_t, + v_t, + num_heads, + kv_heads, + scale, + mask, + mask_type, + batch_idx, + q_ptr, + ql, + kl, + mla_v_dim, + mask_dtype, + ) + + out_slice = ops.reshape(out_slice, (num_heads, ql, output_head_dim)) + out[q_ptr : q_ptr + ql] = self.transpose(out_slice, (1, 0, 2)) + q_ptr += ql + + if output_dtype is not None: + return ops.cast(out, output_dtype) + return ops.cast(out, query.dtype) + + def _compute_full_quant_attention_ms( + self, + query_t: Tensor, + key_t: Tensor, + value_t: Tensor, + num_heads: int, + kv_heads: int, + scale: float, + k_scale: Tensor, + v_scale: Tensor, + p_scale: Optional[Tensor], + mask: Optional[Tensor], + mask_type: int, + batch_idx: int, + q_ptr: int, + ql: int, + kl: int, + mla_v_dim: int, + mask_dtype: Optional[ms.dtype], + ) -> Tensor: + scores_int32 = self._group_matmul_ms(query_t, key_t, num_heads, kv_heads) + scores = ops.zeros_like(scores_int32, ms.float32) + for h in range(num_heads): + scores[h] = ops.cast(scores_int32[h], ms.float32) * k_scale[h] + scores = scores * Tensor(scale, ms.float32) + + if mask is not None: + scores = self._apply_mask_ms( + scores, mask, mask_type, batch_idx, q_ptr, ql, kl, mla_v_dim, mask_dtype + ) + + scores_max = ops.amax(scores, axis=-1, keepdims=True) + exp_scores = self.exp(scores - scores_max) + row_sum = ops.sum(exp_scores, dim=-1, keepdim=True) + + if p_scale is not None: + probs = exp_scores * ops.reshape(p_scale, (num_heads, 1, 1)) + probs = self.round(probs.astype(ms.float16)).astype(ms.int32) + out_int32 = self._group_matmul_pv_ms(probs, value_t, num_heads, kv_heads) + out = ops.zeros_like(out_int32, ms.float32) + for h in range(num_heads): + out[h] = ops.cast(out_int32[h], ms.float32) * v_scale[h] + else: + row_max = ops.amax(exp_scores, axis=-1, keepdims=True) + p_scale_dyn = row_max / 127.0 + 
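+                # Online quantization: scale each row of probabilities into the
+                # int8 range [0, 127] before the integer P*V matmul.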
probs = exp_scores / p_scale_dyn + probs = self.round(probs.astype(ms.float16)).astype(ms.int32) + out_int32 = self._group_matmul_pv_ms(probs, value_t, num_heads, kv_heads) + out = ops.zeros_like(out_int32, ms.float32) + for h in range(num_heads): + de_scale = v_scale[h] * row_max[h, 0, 0] / 127.0 + out[h] = ops.cast(out_int32[h], ms.float32) * de_scale + + return out / row_sum + + def _compute_float_attention_ms( + self, + query_t: Tensor, + key_t: Tensor, + value_t: Tensor, + num_heads: int, + kv_heads: int, + scale: float, + mask: Optional[Tensor], + mask_type: int, + batch_idx: int, + q_ptr: int, + ql: int, + kl: int, + mla_v_dim: int, + mask_dtype: Optional[ms.dtype], + ) -> Tensor: + # Mirror numpy reference: + # scores = group_matmul(q, k) * scale + # scores = apply_mask(scores) + # scores_max = max(scores, axis=-1, keepdims=True) + # exp_scores = exp(scores - scores_max) + # probs = exp_scores / sum(exp_scores, axis=-1, keepdims=True) + # out = group_matmul_pv(probs, v) + scores = self._group_matmul_ms(query_t, key_t, num_heads, kv_heads) + scores = scores * Tensor(scale, ms.float32) + if mask is not None: + scores = self._apply_mask_ms( + scores, mask, mask_type, batch_idx, q_ptr, ql, kl, mla_v_dim, mask_dtype + ) + + scores_max = ops.amax(scores, axis=-1, keepdims=True) + exp_scores = self.exp(scores - scores_max) + denom = ops.sum(exp_scores, dim=-1, keepdim=True) + probs = exp_scores / denom + + return self._group_matmul_pv_ms(probs, value_t, num_heads, kv_heads) + + def _apply_mask_ms( + self, + scores: Tensor, + mask: Tensor, + mask_type: int, + batch_idx: int, + q_ptr: int, + ql: int, + kl: int, + mla_v_dim: int, + mask_dtype: Optional[ms.dtype], + ) -> Tensor: + post_factor = 1.0 + if mla_v_dim > 0 and mask_type == MASK_SPEC and mask_dtype == ms.bfloat16: + post_factor = -10000.0 + mask_float = ops.cast(mask, ms.float32) + if mask_type == MASK_FREE: + mask_slice = mask_float[q_ptr : q_ptr + ql, :kl] + scores = scores + mask_slice[None, :, :] + elif mask_type == MASK_ALIBI: + mask_slice = mask_float[batch_idx, :, :ql, :kl] + scores = scores + mask_slice + elif mask_type == MASK_SPEC: + mask_slice = mask_float[q_ptr : q_ptr + ql, :kl] + scores = scores + mask_slice[None, :, :] * post_factor + else: # MASK_NORM + # Align MASK_NORM behaviour with numpy reference implementation: + # - Decode mode (max_q_len == 1): mask shape is (num_tokens, 1, max_context_len) + # Use batch_idx to index the current sequence and add a 1D mask + # broadcasted over (num_heads, ql, kl). + # - Prefill mode: mask shape is (batch_size, max_q_len, max_context_len) + # Use last ql rows for current batch, same as numpy path. + if mask_float.ndim == 3 and mask_float.shape[1] == 1: + # Decode mode: take 1D mask [kl] for current sequence. + mask_slice = mask_float[batch_idx, 0, :kl] # [kl] + scores = scores + mask_slice[None, None, :] # -> broadcast to [num_heads, ql, kl] + else: + # Prefill mode: causal upper triangular on the last ql positions. 
+ mask_slice = mask_float[batch_idx, -ql:, :kl] # [ql, kl] + scores = scores + mask_slice[None, :, :] # -> [1, ql, kl] broadcast over heads + return scores + + def _group_matmul_ms( + self, query_block: Tensor, key_block: Tensor, num_heads: int, kv_heads: int + ) -> Tensor: + group_size = num_heads // kv_heads + outputs = [] + for kv_h in range(kv_heads): + q_group = query_block[kv_h * group_size : (kv_h + 1) * group_size] + k_head = key_block[kv_h : kv_h + 1] + outputs.append(self.matmul(q_group, k_head)) + return self.concat(outputs, axis=0) + + def _group_matmul_pv_ms( + self, prob_block: Tensor, value_block: Tensor, num_heads: int, kv_heads: int + ) -> Tensor: + group_size = num_heads // kv_heads + outputs = [] + for kv_h in range(kv_heads): + p_group = prob_block[kv_h * group_size : (kv_h + 1) * group_size] + v_head = value_block[kv_h : kv_h + 1] + outputs.append(self.matmul(p_group, v_head)) + return self.concat(outputs, axis=0) + + # ------------------------------------------------------------------ + # Accuracy check + + def validate_accuracy_ms( + self, + output, + golden, + dtype: ms.dtype, + num_heads: int, + max_context_len: int, + is_quant: bool = False, + ) -> bool: + # Convert to numpy float32 for comparison + out_np = np.array(output, dtype=np.float32) + golden_np = np.array(golden, dtype=np.float32) + + out_flat = out_np.reshape(-1) + golden_flat = golden_np.reshape(-1) + diff = np.abs(out_flat - golden_flat) + max_diff = np.max(diff) + + ratios = [0.001, 0.001, 0.005, 0.005] + rel_loose, abs_loose, rel_strict, abs_strict = ratios + strict_scale = 2.0 if (dtype == ms.bfloat16 and is_quant) else 1.0 + rel_strict_eff = rel_strict * strict_scale + abs_strict_eff = abs_strict * strict_scale + + limit_error = np.maximum(np.abs(golden_flat) * rel_loose, abs_loose) + strict_limit_error = np.maximum(np.abs(golden_flat) * rel_strict_eff, abs_strict_eff) + error_count = np.sum(diff > limit_error) + strict_error_count = np.sum(diff > strict_limit_error) + + out_len = max(1, out_flat.shape[0]) + accuracy_loose = 1.0 - float(error_count) / out_len + accuracy_strict = 1.0 - float(strict_error_count) / out_len + + print(f"[MS-REF] Max difference: {max_diff:.6f}") + print(f"[MS-REF] Loose accuracy (1/1000): {accuracy_loose:.6f}") + print(f"[MS-REF] Strict accuracy (5/1000): {accuracy_strict:.6f}") + + error_ratio = float(strict_error_count) / out_len + if dtype == ms.bfloat16 or is_quant: + legacy_pass = error_ratio <= rel_strict_eff + else: + legacy_pass = error_ratio <= ratios[0] + + calc_times = num_heads * max_context_len + 4 + if dtype == ms.bfloat16: + base = 2 ** (-7) if calc_times < 2048 else 2 ** (-6) + error_factor = base * (2.0 if is_quant else 1.0) + elif dtype == ms.float16: + error_factor = 2 ** (-8) if calc_times < 2048 else 2 ** (-7) + else: + if calc_times < 2048: + error_factor = 2 ** (-11) + elif calc_times < 16384: + error_factor = 2 ** (-10) + else: + error_factor = 2 ** (-9) + + error_threshold = np.maximum(np.abs(golden_flat), 1.0) * error_factor + adaptive_pass = np.all(diff <= error_threshold) + + print(f"[MS-REF] Calculation complexity: {calc_times}") + print(f"[MS-REF] Error factor: {error_factor:.6e}") + print(f"[MS-REF] Adaptive test: {'PASS' if adaptive_pass else 'FAIL'}") + print(f"[MS-REF] Legacy test: {'PASS' if legacy_pass else 'FAIL'}") + + return bool(adaptive_pass or legacy_pass) + + +__all__ = ["PagedAttentionMsReference"] + + diff --git a/tests/st/test_custom_paged_attention.py b/tests/st/test_custom_paged_attention.py index c91627f..971a6ad 100644 
--- a/tests/st/test_custom_paged_attention.py +++ b/tests/st/test_custom_paged_attention.py @@ -27,6 +27,8 @@ import mindspore as ms from mindspore import Tensor, context, ops, nn import ms_custom_ops +from paged_attention_ms_reference import PagedAttentionMsReference + # Mask type enumerations MASK_UNDEFINED = 0 @@ -49,6 +51,8 @@ INPUT_LAYOUT_BNSD = 1 INPUT_FORMAT_ND = 0 INPUT_FORMAT_NZ = 1 +REFERENCE_NUMPY = "numpy" +REFERENCE_MINDSPORE = "mindspore" class PagedAttentionDataGenerator: """Data generator and golden reference calculator for paged attention tests. @@ -58,6 +62,7 @@ class PagedAttentionDataGenerator: def __init__(self, rng_seed: int = 2025): """Initialize with random seed.""" + self.seed = rng_seed self.rng = np.random.default_rng(rng_seed) random.seed(rng_seed) np.random.seed(rng_seed) @@ -1216,7 +1221,8 @@ def _set_dynamic_shapes_for_pa(net: PagedAttentionNet, test_config: dict) -> Non def _run_paged_attention_test(generator: PagedAttentionDataGenerator, test_config: dict, - run_mode: int, validate_accuracy: bool = True, dynamic: bool = False): + run_mode: int, validate_accuracy: bool = True, dynamic: bool = False, + reference_impl: str = REFERENCE_NUMPY): """Execute paged attention test with given configuration. Following the refactored workflow: @@ -1232,6 +1238,7 @@ def _run_paged_attention_test(generator: PagedAttentionDataGenerator, test_confi run_mode: GRAPH_MODE or PYNATIVE_MODE validate_accuracy: Whether to validate accuracy against golden reference dynamic: Whether to use dynamic shapes + reference_impl: "numpy" (default) or "mindspore" for MS reference """ context.set_context(device_target="Ascend", mode=run_mode) @@ -1259,49 +1266,91 @@ def _run_paged_attention_test(generator: PagedAttentionDataGenerator, test_confi # MLA configuration mla_v_dim = test_config.get('mla_v_dim', 0) - # Step 1: Generate numpy arrays for golden calculation - # Determine mask dtype to match expected output dtype when provided - mask_out_dtype = test_config.get('expected_dtype') - inputs_dict = generator.generate_inputs( - num_heads, - kv_heads, - head_size, - block_size, - num_blocks, - q_seq_lens_config, - context_lens, - q_dtype, - kv_dtype, - mask_type, - quant_type, - has_quant_offset, - mla_v_dim, - mask_out_dtype, - ) + ms_ref = None + if reference_impl == REFERENCE_MINDSPORE: + ms_ref = PagedAttentionMsReference(getattr(generator, "seed", 2025)) + ms_inputs = ms_ref.generate_inputs_ms( + num_heads, + kv_heads, + head_size, + block_size, + num_blocks, + q_seq_lens_config, + context_lens, + q_dtype, + kv_dtype, + mask_type, + quant_type, + has_quant_offset, + mla_v_dim, + test_config.get('expected_dtype'), + ) + inputs_dict = {key: _tensor_to_numpy(val) for key, val in ms_inputs.items()} + golden_tensor = ms_ref.compute_golden_reference_ms( + ms_inputs['query'], + ms_inputs['key_cache'], + ms_inputs['value_cache'], + ms_inputs['block_tables'], + ms_inputs['q_seq_lens'], + ms_inputs['kv_seq_lens'], + qk_scale, + ms_inputs['mask'], + mask_type, + ms_inputs['k_descale'], + ms_inputs['v_descale'], + ms_inputs['k_offset'], + ms_inputs['v_offset'], + ms_inputs['k_scale_per_head'], + ms_inputs['v_scale_per_head'], + ms_inputs['p_scale'], + output_dtype=test_config.get('expected_dtype') if is_full_quant else None, + mla_v_dim=mla_v_dim, + mask_dtype=ms_inputs['mask_dtype'], + ) + golden = golden_tensor.asnumpy() + else: + # Step 1: Generate numpy arrays for golden calculation + mask_out_dtype = test_config.get('expected_dtype') + inputs_dict = generator.generate_inputs( + num_heads, + 
kv_heads, + head_size, + block_size, + num_blocks, + q_seq_lens_config, + context_lens, + q_dtype, + kv_dtype, + mask_type, + quant_type, + has_quant_offset, + mla_v_dim, + mask_out_dtype, + ) - # Step 2: Compute golden reference using numpy arrays - golden_output_dtype = test_config.get('expected_dtype') if is_full_quant else None - golden = generator.compute_golden_reference( - inputs_dict['query'], - inputs_dict['key_cache'], - inputs_dict['value_cache'], - inputs_dict['block_tables'], - inputs_dict['q_seq_lens'], - inputs_dict['kv_seq_lens'], - qk_scale, - inputs_dict['mask'], - mask_type, - inputs_dict['k_descale'], - inputs_dict['v_descale'], - inputs_dict['k_offset'], - inputs_dict['v_offset'], - inputs_dict['k_scale_per_head'], - inputs_dict['v_scale_per_head'], - inputs_dict['p_scale'], - output_dtype=golden_output_dtype, - mla_v_dim=mla_v_dim, - mask_dtype=inputs_dict['mask_dtype'] - ) + # Step 2: Compute golden reference using numpy arrays + golden_output_dtype = test_config.get('expected_dtype') if is_full_quant else None + golden = generator.compute_golden_reference( + inputs_dict['query'], + inputs_dict['key_cache'], + inputs_dict['value_cache'], + inputs_dict['block_tables'], + inputs_dict['q_seq_lens'], + inputs_dict['kv_seq_lens'], + qk_scale, + inputs_dict['mask'], + mask_type, + inputs_dict['k_descale'], + inputs_dict['v_descale'], + inputs_dict['k_offset'], + inputs_dict['v_offset'], + inputs_dict['k_scale_per_head'], + inputs_dict['v_scale_per_head'], + inputs_dict['p_scale'], + output_dtype=golden_output_dtype, + mla_v_dim=mla_v_dim, + mask_dtype=inputs_dict['mask_dtype'] + ) # Step 3: Convert numpy arrays to Tensors for network input query = Tensor(inputs_dict['query'], dtype=q_dtype) @@ -1396,15 +1445,23 @@ def _run_paged_attention_test(generator: PagedAttentionDataGenerator, test_confi # Step 5: Validate accuracy if requested if validate_accuracy: - # Convert output to numpy for validation output_np = output.asnumpy() - - # For quantization, use output dtype for validation and set is_quant flag validate_dtype = test_config.get('expected_dtype', q_dtype) is_quant_test = quant_type != QUANT_UNQUANT max_ctx_len = max(context_lens) - assert generator.validate_accuracy(output_np, golden, validate_dtype, - num_heads, max_ctx_len, is_quant_test) + if reference_impl == REFERENCE_MINDSPORE and ms_ref is not None: + assert ms_ref.validate_accuracy_ms( + output_np, + golden, + validate_dtype, + num_heads, + max_ctx_len, + is_quant_test, + ) + else: + assert generator.validate_accuracy( + output_np, golden, validate_dtype, num_heads, max_ctx_len, is_quant_test + ) # ======================================== @@ -1455,6 +1512,14 @@ def _paged_attention_test(test_func): return test_func +def _tensor_to_numpy(value): + if value is None: + return None + if isinstance(value, Tensor): + return value.asnumpy() + return value + + # ==================== 1. 
Basic Functionality Tests ==================== # Test basic features: data types, execution modes, simple configurations @@ -2051,7 +2116,8 @@ def test_pa_alibi_with_gqa(): # Test bfloat16 with quantization (int8 KV cache) @_paged_attention_test -def test_pa_bf16_int8_kv_basic(): +@pytest.mark.parametrize('reference_impl', [REFERENCE_NUMPY, REFERENCE_MINDSPORE]) +def test_pa_bf16_int8_kv_basic(reference_impl): """ Feature: PagedAttention - bf16 with int8 KV (Dequant Fusion) Description: Decode with bf16 Q and int8 KV over moderate sequence length @@ -2072,7 +2138,12 @@ def test_pa_bf16_int8_kv_basic(): 'qk_scale': 1.0 / math.sqrt(128), 'quant_type': DEQUANT_FUSION, } - _run_paged_attention_test(generator, test_config, context.GRAPH_MODE) + _run_paged_attention_test( + generator, + test_config, + context.GRAPH_MODE, + reference_impl=reference_impl, + ) @_paged_attention_test -- Gitee
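The adaptive tolerance used by validate_accuracy_ms in the MindSpore reference above can be reproduced outside the test harness with plain numpy. The sketch below is illustrative only, not part of the patch: the function name adaptive_accuracy_check and the synthetic data are ours, while the error factors mirror the bf16/fp16/fp32 branches added in this change.

import numpy as np


def adaptive_accuracy_check(output, golden, num_heads, max_context_len,
                            dtype="float16", is_quant=False):
    """Return True when |output - golden| stays within the dtype-dependent bound."""
    out = np.asarray(output, dtype=np.float32).reshape(-1)
    ref = np.asarray(golden, dtype=np.float32).reshape(-1)
    diff = np.abs(out - ref)

    # Complexity estimate drives the tolerance, as in validate_accuracy_ms.
    calc_times = num_heads * max_context_len + 4
    if dtype == "bfloat16":
        base = 2 ** -7 if calc_times < 2048 else 2 ** -6
        error_factor = base * (2.0 if is_quant else 1.0)
    elif dtype == "float16":
        error_factor = 2 ** -8 if calc_times < 2048 else 2 ** -7
    else:  # float32 and other wide types
        if calc_times < 2048:
            error_factor = 2 ** -11
        elif calc_times < 16384:
            error_factor = 2 ** -10
        else:
            error_factor = 2 ** -9

    # Per-element bound: relative to |golden| but never tighter than error_factor.
    threshold = np.maximum(np.abs(ref), 1.0) * error_factor
    return bool(np.all(diff <= threshold))


if __name__ == "__main__":
    rng = np.random.default_rng(2025)
    golden = rng.standard_normal((32, 128)).astype(np.float32)
    noisy = golden + rng.uniform(-1e-3, 1e-3, golden.shape).astype(np.float32)
    # num_heads=32, max_context_len=128 -> calc_times=4100, fp16 factor 2**-7.
    print(adaptive_accuracy_check(noisy, golden, num_heads=32, max_context_len=128))

This is the "adaptive test" path only; the legacy 1/1000 and 5/1000 error-ratio checks from validate_accuracy_ms are kept in the patch as an alternative pass condition, and the real helper accepts MindSpore dtypes rather than the string names used here.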