diff --git a/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder.cc b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder.cc new file mode 100644 index 0000000000000000000000000000000000000000..03863344f84717c31a5fe91828b2c0f30bd7f662 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flash_attention_encoder.h" +#include "ccsrc/ops/ms_kernels_internal/utils/attention_utils.h" + +namespace ms_custom_ops { + +ShapeArray FlashAttentionEncoderOpFuncImpl::InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + (void)primitive; + auto query_shape = input_infos[kQueryIdx]->GetShape(); + return {query_shape}; +} + +std::vector FlashAttentionEncoderOpFuncImpl::InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + (void)primitive; + auto query_type = input_infos[kQueryIdx]->GetType(); + return {query_type}; +} + +internal::InternalOpPtr CustomFlashAttentionEncoder::CreateKernel(const internal::InputsImmutableInfoList &inputs_ii, + const internal::OutputsImmutableInfoList &outputs_ii, + const std::vector &ms_inputs, + const std::vector &ms_outputs) { + param_.headNum = static_cast(ms_inputs[kHeadNumIdx]->GetValueWithCheck()); + param_.qkScale = ms_inputs[kScaleValueIdx]->GetValueWithCheck(); + param_.kvHeadNum = static_cast(ms_inputs[kKvHeadNumIdx]->GetValueWithCheck()); + // 固定为 PA_ENCODER 路径 + param_.calcType = internal::SelfAttentionParam::CalcType::PA_ENCODER; + param_.kernelType = static_cast( + ms_inputs[kKernelTypeIdx]->GetValueWithCheck()); + param_.maskType = static_cast( + ms_inputs[kMaskTypeIdx]->GetValueWithCheck()); + param_.windowSize = static_cast(ms_inputs[kWindowSizeIdx]->GetValueWithCheck()); + param_.cacheType = static_cast( + ms_inputs[kCacheTypeIdx]->GetValueWithCheck()); + // cache_mode: default 0 ND, 1 force NZ + param_.cacheMode = static_cast(ms_inputs[kCacheModeIdx]->GetValueWithCheck()); + + MS_EXCEPTION_IF_CHECK_FAIL(ms_inputs[kQSeqLenIdx]->dtype_id() == TypeId::kNumberTypeInt32, + "q_seq_len must be int32"); + param_.q_seq_len = ms_inputs[kQSeqLenIdx]->GetValueWithCheck>(); + + MS_EXCEPTION_IF_CHECK_FAIL(ms_inputs[kKVSeqLenIdx]->dtype_id() == TypeId::kNumberTypeInt32, + "kv_seq_len must be int32"); + param_.kv_seq_len = ms_inputs[kKVSeqLenIdx]->GetValueWithCheck>(); + + created_flag_ = true; + // 当cache_mode为NZ时,设置Q/K/V和输出为FRACTAL_NZ格式 + if (param_.cacheMode == 1) { + auto inputs_clone = inputs_ii; + auto outputs_clone = outputs_ii; + inputs_clone[static_cast(kQueryIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kKeyIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kValueIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kMaskIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + 
outputs_clone[static_cast(kAttentionOutIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + return internal::CreateFlashAttentionEncoderOp(inputs_clone, outputs_clone, param_, + internal::kInternalFlashAttentionEncoderOpName); + } + + return internal::CreateFlashAttentionEncoderOp(inputs_ii, outputs_ii, param_, + internal::kInternalFlashAttentionEncoderOpName); +} + +bool CustomFlashAttentionEncoder::UpdateParam(const std::vector &inputs, + const std::vector &outputs) { + if (created_flag_) { + created_flag_ = false; + return true; + } + + auto q_need_recreate = GetSeqLenAndCheckUpdate(inputs[kQSeqLenIdx], ¶m_.q_seq_len); + auto kv_need_recreate = GetSeqLenAndCheckUpdate(inputs[kKVSeqLenIdx], ¶m_.kv_seq_len); + if (q_need_recreate || kv_need_recreate) { + auto ret = internal_op_->UpdateParam(¶m_); + if (ret != internal::kInternalOk) { + MS_LOG(ERROR) << "CustomFlashAttentionEncoder UpdateParam failed, kernel_name: " << kernel_name_; + return false; + } + return true; + } + return true; +} + +void CustomFlashAttentionEncoder::InitKernelInputsOutputsIndex() { + kernel_inputs_index_ = {kQueryIdx, kKeyIdx, kValueIdx, kLayerIdIdx, kMaskIdx, kAlibiCoeffIdx, + kDeqScaleQkIdx, kDeqOffsetQkIdx, kDeqScalePvIdx, kDeqOffsetPvIdx, kQuantPIdx, kLogNIdx}; + kernel_outputs_index_ = {kAttentionOutIdx}; +} + +uint64_t CustomFlashAttentionEncoder::GenerateTilingKey(const std::vector &inputs) { + return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len); +} + +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(flash_attention_encoder, ms_custom_ops::FlashAttentionEncoderOpFuncImpl, ms_custom_ops::CustomFlashAttentionEncoder); + + diff --git a/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder.h b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder.h new file mode 100644 index 0000000000000000000000000000000000000000..7656f828c72e444185c784b3e6df9824a7907df5 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder.h @@ -0,0 +1,98 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef CCSRC_OPS_MS_KERNELS_INTERNAL_FLASH_ATTENTION_ENCODER_FLASH_ATTENTION_ENCODER_H_ +#define CCSRC_OPS_MS_KERNELS_INTERNAL_FLASH_ATTENTION_ENCODER_FLASH_ATTENTION_ENCODER_H_ + +#include +#include +#include +#include "mindspore/ccsrc/ms_extension/api.h" +#include "mindspore/core/include/ops/base_operator.h" +#include "mindspore/core/include/ops/ops_func_impl/op_func_impl.h" +#include "mindspore/core/include/ops/ops_func_impl/simple_infer.h" +#include "mindspore/core/include/utils/check_convert_utils.h" +#include "internal_kernel_mod.h" +#include "mindspore/core/include/mindapi/ir/tensor.h" +#include "mindspore/ops/kernel/ascend/acl_ir/acl_convert.h" +#include "include/op_creator.h" +#include "include/op_param.h" + +namespace ms_custom_ops { + +enum FlashAttentionEncoderInputIndex : int { + kQueryIdx = 0, + kKeyIdx, + kValueIdx, + kLayerIdIdx, + kMaskIdx, + kAlibiCoeffIdx, + kDeqScaleQkIdx, + kDeqOffsetQkIdx, + kDeqScalePvIdx, + kDeqOffsetPvIdx, + kQuantPIdx, + kLogNIdx, + kQSeqLenIdx, + kKVSeqLenIdx, + kHeadNumIdx, + kScaleValueIdx, + kKvHeadNumIdx, + kMaskTypeIdx, + kKernelTypeIdx, + kWindowSizeIdx, + kCacheTypeIdx, + kCacheModeIdx, + kInputNums +}; + +enum FlashAttentionEncoderOutputIndex : int { kAttentionOutIdx = 0, kOutputNums }; + +class OPS_API FlashAttentionEncoderOpFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override; + std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override; + bool GeneralInferRegistered() const override { return true; } + std::set GetValueDependArgIndices() const override { + return {kQSeqLenIdx, kKVSeqLenIdx, kHeadNumIdx, kScaleValueIdx, kKvHeadNumIdx, kMaskTypeIdx, kKernelTypeIdx, + kWindowSizeIdx, kCacheTypeIdx, kCacheModeIdx}; + }; +}; + +class CustomFlashAttentionEncoder : public InternalKernelMod { + public: + CustomFlashAttentionEncoder() = default; + ~CustomFlashAttentionEncoder() override = default; + + protected: + void InitKernelInputsOutputsIndex() override; + internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override; + bool UpdateParam(const std::vector &inputs, const std::vector &outputs) override; + uint64_t GenerateTilingKey(const std::vector &inputs) override; + + private: + bool created_flag_{false}; + internal::SelfAttentionParam param_{}; +}; + +} // namespace ms_custom_ops + +#endif // CCSRC_OPS_MS_KERNELS_INTERNAL_FLASH_ATTENTION_ENCODER_FLASH_ATTENTION_ENCODER_H_ + + diff --git a/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder_runner.cc b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder_runner.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1e1f0c6cb88fc41e2020c8ca0c87553952fcd81 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder_runner.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flash_attention_encoder_runner.h" +#include "ccsrc/utils/utils.h" + +namespace ms_custom_ops { +void FlashAttentionEncoderRunner::SetParam(int64_t head_num, float scale_value, int64_t kv_head_num, + int64_t mask_type, int64_t kernel_type, int64_t window_size, + int64_t cache_type, const std::vector &q_seq_len, + const std::vector &kv_seq_len) { + param_.headNum = static_cast(head_num); + param_.qkScale = static_cast(scale_value); + param_.kvHeadNum = static_cast(kv_head_num); + + param_.calcType = internal::SelfAttentionParam::CalcType::PA_ENCODER; + param_.maskType = static_cast(mask_type); + param_.kernelType = static_cast(kernel_type); + param_.windowSize = static_cast(window_size); + param_.cacheType = static_cast(cache_type); + + param_.q_seq_len = q_seq_len; + param_.kv_seq_len = kv_seq_len; + param_.cacheMode = 0; +} + +bool FlashAttentionEncoderRunner::UpdateParam() { + if (created_flag_) { + created_flag_ = false; + return true; + } + + auto ret = internal_op_->UpdateParam(¶m_); + if (ret != internal::kInternalOk) { + MS_LOG(ERROR) << "Internal FlashAttentionEncoder UpdateParam failed."; + return false; + } + return true; +} + +internal::InternalOpPtr FlashAttentionEncoderRunner::CreateKernel( + const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs) { + created_flag_ = true; + // NZ format routing in PyBoost mode when cache_mode == 1 + if (param_.cacheMode == 1) { + auto inputs_clone = inputs; + auto outputs_clone = outputs; + inputs_clone[static_cast(kQueryIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kKeyIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kValueIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_clone[static_cast(kMaskIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + outputs_clone[static_cast(kAttentionOutIdx)].SetFormat(internal::kFormatFRACTAL_NZ); + return internal::CreateFlashAttentionEncoderOp(inputs_clone, outputs_clone, param_, + internal::kInternalFlashAttentionEncoderOpName); + } + return internal::CreateFlashAttentionEncoderOp(inputs, outputs, param_, + internal::kInternalFlashAttentionEncoderOpName); +} + +static std::vector npu_flash_attention_encoder( + const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &value, const std::optional &layer_id, + const std::optional &mask, const std::optional &alibi_coeff, + const std::optional &deq_scale_qk, const std::optional &deq_offset_qk, + const std::optional &deq_scale_pv, const std::optional &deq_offset_pv, + const std::optional &quant_p, const std::optional &logN, + const std::optional &q_seq_len, const std::optional &kv_seq_len, int64_t head_num, + float scale_value, int64_t kv_head_num, int64_t mask_type, int64_t kernel_type, int64_t window_size, + int64_t cache_type, int64_t cache_mode) { + static auto op_name = "FlashAttentionEncoder"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + // TODO + // TH/TND 必须提供 q_seq_len/kv_seq_len;BSH/BNSD 可为空 + // q_seq_len/kv_seq_len 必须为 int32 + if (!q_seq_len.has_value() || !kv_seq_len.has_value()) 
{ + MS_LOG(EXCEPTION) << "For " << op_name << ", the q_seq_len and kv_seq_len can not be None, but got q_seq_len.has_value(): " + << q_seq_len.has_value() << ", kv_seq_len.has_value(): " << kv_seq_len.has_value(); + } + auto q_seq = GetValueFromTensor>(q_seq_len.value(), op_name, "q_seq_len"); + auto kv_seq = GetValueFromTensor>(kv_seq_len.value(), op_name, "kv_seq_len"); + + runner->SetParam(head_num, scale_value, kv_head_num, mask_type, kernel_type, window_size, cache_type, q_seq, kv_seq); + runner->SetCacheMode(cache_mode); + + // Setup the runner with all parameters to form cache key + runner->Setup(op_name, query, key, value, layer_id, mask, alibi_coeff, deq_scale_qk, deq_offset_qk, deq_scale_pv, deq_offset_pv, + quant_p, logN, q_seq_len, kv_seq_len, head_num, scale_value, kv_head_num, + mask_type, kernel_type, window_size, cache_type, cache_mode); + + // outputs + auto attn_out = ms::Tensor(query.data_type(), query.shape()); + std::vector inputs = {query, + key, + value, + GetTensorOrEmpty(layer_id), + GetTensorOrEmpty(mask), + GetTensorOrEmpty(alibi_coeff), + GetTensorOrEmpty(deq_scale_qk), + GetTensorOrEmpty(deq_offset_qk), + GetTensorOrEmpty(deq_scale_pv), + GetTensorOrEmpty(deq_offset_pv), + GetTensorOrEmpty(quant_p), + GetTensorOrEmpty(logN)}; + std::vector outputs = {attn_out}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs; +} + +static auto pyboost_flash_attention_encoder( + const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &value, const std::optional &layer_id, + const std::optional &mask, const std::optional &alibi_coeff, + const std::optional &deq_scale_qk, const std::optional &deq_offset_qk, + const std::optional &deq_scale_pv, const std::optional &deq_offset_pv, + const std::optional &quant_p, const std::optional &logN, + const std::optional &q_seq_len, const std::optional &kv_seq_len, int64_t head_num, + float scale_value, int64_t kv_head_num, int64_t mask_type, int64_t kernel_type, int64_t window_size, + int64_t cache_type, int64_t cache_mode) { + return ms::pynative::PyboostRunner::Call<1>(npu_flash_attention_encoder, query, key, value, layer_id, mask, alibi_coeff, + deq_scale_qk, deq_offset_qk, deq_scale_pv, deq_offset_pv, quant_p, + logN, q_seq_len, kv_seq_len, head_num, scale_value, kv_head_num, + mask_type, kernel_type, window_size, cache_type, cache_mode); +} +} // namespace ms_custom_ops + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("flash_attention_encoder", &ms_custom_ops::pyboost_flash_attention_encoder, "Flash Attention Encoder", + pybind11::arg("query"), pybind11::arg("key"), pybind11::arg("value"), pybind11::arg("layer_id") = std::nullopt, + pybind11::arg("mask") = std::nullopt, pybind11::arg("alibi_coeff") = std::nullopt, + pybind11::arg("deq_scale_qk") = std::nullopt, pybind11::arg("deq_offset_qk") = std::nullopt, + pybind11::arg("deq_scale_pv") = std::nullopt, pybind11::arg("deq_offset_pv") = std::nullopt, + pybind11::arg("quant_p") = std::nullopt, pybind11::arg("logN") = std::nullopt, + pybind11::arg("q_seq_len") = std::nullopt, pybind11::arg("kv_seq_len") = std::nullopt, + pybind11::arg("head_num") = 32, pybind11::arg("scale_value") = 1.0, pybind11::arg("kv_head_num") = 0, + pybind11::arg("mask_type") = 0, pybind11::arg("kernel_type") = 0, pybind11::arg("window_size") = 0, + pybind11::arg("cache_type") = 0, pybind11::arg("cache_mode") = 0); +} diff --git a/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder_runner.h 
b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..acb47607b2ea0b5b3d24e6a768fcf8948fcd703b --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/flash_attention_encoder/flash_attention_encoder_runner.h @@ -0,0 +1,50 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CCSRC_OPS_MS_KERNELS_INTERNAL_FLASH_ATTENTION_ENCODER_FLASH_ATTENTION_ENCODER_RUNNER_H_ +#define CCSRC_OPS_MS_KERNELS_INTERNAL_FLASH_ATTENTION_ENCODER_FLASH_ATTENTION_ENCODER_RUNNER_H_ + +#include +#include +#include + +#include "internal_kernel_mod.h" +#include "internal_pyboost_runner.h" +#include "mindspore/core/include/mindapi/ir/tensor.h" +#include "ccsrc/ops/ms_kernels_internal/utils/attention_utils.h" +#include "flash_attention_encoder.h" + +namespace ms_custom_ops { +class FlashAttentionEncoderRunner : public InternalPyboostRunner { + public: + using InternalPyboostRunner::InternalPyboostRunner; + void SetParam(int64_t head_num, float scale_value, int64_t kv_head_num, int64_t mask_type, int64_t kernel_type, + int64_t window_size, int64_t cache_type, const std::vector &q_seq_len, + const std::vector &kv_seq_len); + void SetCacheMode(int64_t cache_mode) { param_.cacheMode = static_cast(cache_mode); } + + protected: + bool UpdateParam() override; + internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs) override; + + private: + internal::SelfAttentionParam param_{}; + bool created_flag_{false}; +}; +} // namespace ms_custom_ops + +#endif // CCSRC_OPS_MS_KERNELS_INTERNAL_FLASH_ATTENTION_ENCODER_FLASH_ATTENTION_ENCODER_RUNNER_H_ diff --git a/tests/st/test_custom_flash_attention_encoder.py b/tests/st/test_custom_flash_attention_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..02d81af8884ea3c8bb203e70ca046992ad13aed3 --- /dev/null +++ b/tests/st/test_custom_flash_attention_encoder.py @@ -0,0 +1,640 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import math +import numpy as np +import pytest +import mindspore as ms +import ms_custom_ops +from mindspore import Tensor, context, ops, nn +from mindspore.common.np_dtype import bfloat16 as np_bfloat16 + +MASK_NONE = 0 +MASK_NORM = 1 +MASK_ALIBI = 2 +MASK_NORM_COMPRESS = 3 +MASK_SWA_NORM = 7 +MASK_SWA_COMPRESS = 8 + +KERNEL_DEFAULT = 0 +KERNEL_HIGH_PRECISION = 1 + +CACHE_NORM = 0 +CACHE_SWA = 1 + +class FlashAttentionTestSuite: + """Unified test suite for Flash Attention encoder operations. + + Provides comprehensive testing with support for different layouts (TH/TND), + data types (fp16/bf16), mask types, and GQA configurations. + """ + + def __init__(self, rng_seed: int = 2024): + """Initialize test suite with random generator.""" + self.rng = np.random.default_rng(rng_seed) + + # ========== Utility Methods ========== + + @staticmethod + def _ms_tensor(x: np.ndarray) -> Tensor: + """Convert NumPy array to MindSpore Tensor, handling bfloat16.""" + if x is None: + return None + return Tensor(x.astype(np.float32)).astype(ms.bfloat16) if x.dtype == np_bfloat16 else Tensor(x) + + @staticmethod + def _get_alibi_slopes(n_heads: int) -> np.ndarray: + """Generate ALIBI slopes for positional bias""" + n = 2 ** int(np.floor(np.log2(n_heads))) + m0 = 2.0 ** (-8.0 / n) + slopes = np.array([m0 ** i for i in range(1, n + 1)], dtype=np.float32) + + if n < n_heads: + m1 = 2.0 ** (-4.0 / n) + # Generate additional slopes with step size 2 + additional_count = n_heads - n + mm = np.array([m1 ** i for i in range(1, 1 + 2 * additional_count, 2)], dtype=np.float32) + slopes = np.concatenate([slopes, mm], axis=0) + + return slopes + + # ========== Input Generation ========== + + def build_inputs(self, q_seq: np.ndarray, kv_seq: np.ndarray, heads: int, head_dim: int, + np_dtype: np.dtype, kv_heads: int = None, layout: str = 'TH') -> tuple: + """Build test inputs for either TH or TND layout.""" + q_ntok = int(q_seq.sum()) + kv_ntok = int(kv_seq.sum()) + kvh = heads if kv_heads is None else kv_heads + + if layout == 'TH': + # TH layout: [tokens, heads*head_dim] + q_np = self.rng.standard_normal((q_ntok, heads * head_dim)).astype(np_dtype) + k_np = self.rng.standard_normal((kv_ntok, kvh * head_dim)).astype(np_dtype) + v_np = self.rng.standard_normal((kv_ntok, kvh * head_dim)).astype(np_dtype) + else: # TND layout + # TND layout: [tokens, heads, head_dim] + q_np = self.rng.standard_normal((q_ntok, heads, head_dim)).astype(np_dtype) + k_np = self.rng.standard_normal((kv_ntok, kvh, head_dim)).astype(np_dtype) + v_np = self.rng.standard_normal((kv_ntok, kvh, head_dim)).astype(np_dtype) + + return q_np, k_np, v_np + + # ========== Mask Generation ========== + + @staticmethod + def _get_pre_mask_coef(np_dtype: np.dtype, is_alibi: bool = False) -> float: + """Get pre-mask coefficient for operator input (following reference implementation).""" + if np_dtype == np.float16: + return -10000.0 + elif np_dtype == np_bfloat16 and is_alibi: + return -float("inf") + elif np_dtype == np.float32 and is_alibi: + return 1.0 + else: # bf16 non-alibi + return 1.0 + + @staticmethod + def _get_post_mask_coef(np_dtype: np.dtype, is_alibi: bool = False) -> float: + """Get post-mask coefficient for golden computation (following reference implementation).""" + if np_dtype == np.float16: + return 1.0 + elif np_dtype == np_bfloat16 and is_alibi: + return 1.0 + elif np_dtype == np.float32 and is_alibi: + return 1.0 + else: # bf16 non-alibi + return -3e38 + + 
@classmethod + def _build_swa_mask(cls, max_seq: int, window_size: int, mask_coef: float, mask_shape: tuple, + is_compress: bool, head_dim: int = None) -> np.ndarray: + """Build SWA mask (normal or compressed)""" + if is_compress: + swa_mask = np.ones((1, 512, 512), dtype=np.float32) * mask_coef + + # Calculate true window size + pp_n = 128 if head_dim <= 128 else 64 + if window_size <= pp_n * 3: + true_size = window_size + elif window_size % pp_n == 0: + true_size = pp_n * 3 + else: + true_size = pp_n * 2 + window_size % pp_n + + # Apply SWA compress pattern + triu_mask = np.triu(swa_mask, 1) + tril_mask = np.tril(swa_mask, -true_size) + swa_mask = triu_mask + tril_mask + else: + # Normal SWA: use the provided mask_shape + swa_mask = np.ones(mask_shape, dtype=np.float32) * mask_coef + + # For encoder: apply SWA pattern if window_size < max_seq + if window_size < max_seq: + triu_mask = np.triu(swa_mask, 1) + tril_mask = np.tril(swa_mask, -window_size) + swa_mask = triu_mask + tril_mask + else: + # Window larger than sequence: just upper triangle + swa_mask = np.triu(swa_mask, 1) + return swa_mask + + @classmethod + def _build_alibi_mask(cls, max_seq: int, heads: int, alibi_slopes: np.ndarray, + mask_coef: float, mask_shape: tuple) -> np.ndarray: + """Build ALIBI mask with positional bias.""" + if alibi_slopes is None or heads is None: + # Fallback to triangular mask + mask = np.ones(mask_shape, dtype=np.float32) * mask_coef + return np.triu(mask, 1) + + # Create base triangular mask + ALIBI bias + mask = np.ones(mask_shape, dtype=np.float32) * mask_coef + base_triu = np.triu(mask, 1) + + # Add ALIBI positional bias + q_pos = np.arange(max_seq, dtype=np.float32).reshape(-1, 1) + k_pos = np.arange(max_seq, dtype=np.float32).reshape(1, -1) + alibi_bias = k_pos - q_pos # Relative position matrix + + batch = mask_shape[0] + for b in range(batch): + for h in range(heads): + base_triu[b, h] += alibi_slopes[h] * alibi_bias + + return base_triu + + def build_operator_mask(self, mask_type: int, max_seq: int, np_dtype: np.dtype, batch: int, + window_size: int = 0, heads: int = None, head_dim: int = None, + alibi_slopes: np.ndarray = None) -> np.ndarray: + """Build mask tensor for operator input.""" + # Get pre-mask coefficient for operator input and determine shape + mask_coef = self._get_pre_mask_coef(np_dtype, mask_type == MASK_ALIBI) + if mask_type == MASK_ALIBI: + mask_shape = (batch, heads, max_seq, max_seq) + elif mask_type == MASK_NORM: + mask_shape = (batch, max_seq, max_seq) + else: + mask_shape = (max_seq, max_seq) + + # Build mask based on type + if mask_type == MASK_NONE: + return None + elif mask_type == MASK_NORM: + mask = np.ones(mask_shape, dtype=np.float32) * mask_coef + return np.triu(mask, 1) + elif mask_type == MASK_NORM_COMPRESS: + # NORM_COMPRESS uses fixed (128, 128) shape as in reference implementation + compress_mask = np.ones((128, 128), dtype=np.float32) * mask_coef + result_mask = np.triu(compress_mask, 1) + return result_mask + elif mask_type in (MASK_SWA_NORM, MASK_SWA_COMPRESS): + swa_mask = self._build_swa_mask(max_seq, window_size, mask_coef, mask_shape, + mask_type == MASK_SWA_COMPRESS, head_dim) + # SWA_COMPRESS returns (1, 512, 512), reshape to (512, 512) for operator as in reference + if mask_type == MASK_SWA_COMPRESS and swa_mask.shape == (1, 512, 512): + return swa_mask.reshape(512, 512) + return swa_mask + elif mask_type == MASK_ALIBI: + return self._build_alibi_mask(max_seq, heads, alibi_slopes, mask_coef, mask_shape) + else: + return np.zeros(mask_shape, 
dtype=np.float32) + + def build_golden_mask(self, mask_type: int, ql: int, kl: int, np_dtype: np.dtype, + batch_idx: int = 0, head_idx: int = 0, op_mask: np.ndarray = None, **kwargs) -> np.ndarray: + """Build mask slice for golden reference computation.""" + if mask_type == MASK_NONE: + return None + + # Special handling for compress masks: golden calculation uses actual sequence length + if mask_type == MASK_NORM_COMPRESS: + # Generate actual-size mask for golden computation (not from op_mask) + pre_mask_coef = self._get_pre_mask_coef(np_dtype, False) + post_mask_coef = self._get_post_mask_coef(np_dtype, False) + golden_mask = np.ones((ql, kl), dtype=np.float32) * pre_mask_coef + result_mask = np.triu(golden_mask, 1) + # Apply post coefficient for golden computation + return result_mask * post_mask_coef + + elif mask_type == MASK_SWA_COMPRESS: + # Generate actual-size SWA mask for golden computation (not from 512x512 op_mask) + pre_mask_coef = self._get_pre_mask_coef(np_dtype, False) + post_mask_coef = self._get_post_mask_coef(np_dtype, False) + golden_mask = np.ones((ql, kl), dtype=np.float32) * pre_mask_coef + + # Apply SWA pattern with actual sequence length + window_size = kwargs.get('window_size', 0) + max_seq = max(ql, kl) + if window_size < max_seq: + triu_mask = np.triu(golden_mask, 1) + tril_mask = np.tril(golden_mask, -window_size) + result_mask = triu_mask + tril_mask + else: + result_mask = np.triu(golden_mask, 1) + + # Apply post coefficient for golden computation + return result_mask * post_mask_coef + + # For other mask types, extract from op_mask + if op_mask is None: + return None + + # Extract appropriate slice based on mask dimensions + if mask_type == MASK_ALIBI and op_mask.ndim == 4: + mask_slice = op_mask[batch_idx, head_idx, :ql, :kl] + elif op_mask.ndim >= 3: + mask_slice = op_mask[batch_idx, :ql, :kl] + else: + mask_slice = op_mask[:ql, :kl] + + # Apply post coefficient for golden computation + post_mask_coef = self._get_post_mask_coef(np_dtype, mask_type == MASK_ALIBI) + return mask_slice.astype(np.float32) * post_mask_coef + + # ========== Golden Reference ========== + + def compute_golden_attention(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, q_seq: np.ndarray, + kv_seq: np.ndarray, heads: int, scale: float, kv_heads: int = None, + mask: np.ndarray = None, mask_type: int = MASK_NONE, **kwargs) -> np.ndarray: + """Compute golden reference attention output.""" + kv_heads = kv_heads or heads + head_dim = q.shape[-1] // heads + q_tokens = q.reshape(-1, heads, head_dim).astype(np.float32) + k_tokens = k.reshape(-1, kv_heads, head_dim).astype(np.float32) + v_tokens = v.reshape(-1, kv_heads, head_dim).astype(np.float32) + + out = np.zeros_like(q_tokens) + q_off = kv_off = 0 + + for b in range(len(q_seq)): + ql, kl = int(q_seq[b]), int(kv_seq[b]) + if ql == 0: + continue + + qs = q_tokens[q_off:q_off + ql] + if kl == 0: + out[q_off:q_off + ql] = 0.0 + else: + ks = k_tokens[kv_off:kv_off + kl] + vs = v_tokens[kv_off:kv_off + kl] + + for h in range(heads): + kh = h // (heads // kv_heads) # GQA support + logits = np.einsum("qd,kd->qk", qs[:, h, :], ks[:, kh, :]) * scale + + # Apply mask if present + mask_slice = self.build_golden_mask( + mask_type, ql, kl, kwargs.get('np_dtype', np.float16), b, h, mask, + window_size=kwargs.get('window_size', 0) + ) + if mask_slice is not None: + logits += mask_slice + + # Softmax and attention + logits_max = np.max(logits, axis=1, keepdims=True) + exp_logits = np.exp(logits - logits_max) + attn_weights = exp_logits / 
np.sum(exp_logits, axis=1, keepdims=True) + out[q_off:q_off + ql, h, :] = np.einsum("qk,kd->qd", attn_weights, vs[:, kh, :]) + + q_off += ql + kv_off += kl + + return out.reshape(-1, heads * head_dim) + + # ========== Result Validation ========== + + def validate_output(self, out: np.ndarray, golden: np.ndarray, np_dtype: np.dtype, + heads: int, max_seq: int) -> bool: + """Validate operator output against golden reference with adaptive precision.""" + golden_flat = golden.flatten().astype(np.float32) + out_flat = out.flatten().astype(np.float32) + diff = np.abs(golden_flat - out_flat) + out_len = out_flat.shape[0] + max_diff = np.max(diff) + + # Legacy standard with fixed ratios + if np_dtype == np_bfloat16: + ratios = [0.001, 0.001, 0.005, 0.005] # [rel_loose, abs_loose, rel_strict, abs_strict] + else: # fp16 + ratios = [0.001, 0.001, 0.005, 0.005] + + # Calculate accuracy metrics + limit_error = np.maximum(np.abs(golden_flat) * ratios[0], ratios[1]) + strict_limit_error = np.maximum(np.abs(golden_flat) * ratios[2], ratios[3]) + error_count = np.sum(diff > limit_error) + strict_error_count = np.sum(diff > strict_limit_error) + + accuracy_loose = 1.0 - float(error_count) / out_len + accuracy_strict = 1.0 - float(strict_error_count) / out_len + + print(f"Max difference: {max_diff:.6f}") + print(f"Loose accuracy (1/1000): {accuracy_loose:.6f}") + print(f"Strict accuracy (5/1000): {accuracy_strict:.6f}") + + # Legacy pass calculation with data type distinction + if np_dtype == np_bfloat16: + legacy_pass = (float(strict_error_count) / max(1, out_len)) <= ratios[2] + else: + legacy_pass = (float(strict_error_count) / max(1, out_len)) <= ratios[0] + + # Adaptive validation based on computation complexity + calc_times = heads * max_seq + 4 + if np_dtype == np_bfloat16: + error_factor = 2 ** (-7 if calc_times < 2048 else -6) + elif np_dtype == np.float16: + error_factor = 2 ** (-8 if calc_times < 2048 else -7) + else: # float32 + if calc_times < 2048: + error_factor = 2 ** (-11) + elif calc_times < 16384: + error_factor = 2 ** (-10) + else: + error_factor = 2 ** (-9) + + # Adaptive threshold: max(|golden|, 1.0) * error_factor + error_threshold = np.maximum(np.abs(golden_flat), 1.0) * error_factor + adaptive_pass = np.all(diff <= error_threshold) + + print(f"Calculation complexity: {calc_times}") + print(f"Error factor: {error_factor:.6e}") + print(f"Adaptive precision test: {'PASS' if adaptive_pass else 'FAIL'}") + print(f"Legacy precision test: {'PASS' if legacy_pass else 'FAIL'}") + + return adaptive_pass or legacy_pass + + # ========== Dynamic Shape Configuration ========== + + def set_dynamic_shapes(self, net, test_case) -> None: + """Set dynamic input shapes for the network.""" + ms_dtype = ms.float16 if test_case.np_dtype == np.float16 else ms.bfloat16 + + # Input shapes depend on layout + if test_case.layout == 'TH': + q_shape = [None, test_case.hidden] + k_shape = v_shape = [None, test_case.kv_heads * test_case.head_dim] + else: # TND layout + q_shape = [None, test_case.heads, test_case.head_dim] + k_shape = v_shape = [None, test_case.kv_heads, test_case.head_dim] + + # Mask shape depends on mask type + if test_case.mask_type == MASK_ALIBI: + mask_shape = [None, None, None, None] + elif test_case.mask_type == MASK_NORM: + mask_shape = [None, None, None] + else: + mask_shape = [None, None] + mask_tensor = Tensor(shape=mask_shape, dtype=ms_dtype) if test_case.mask_type != MASK_NONE else None + # ALIBI slopes (optional) + alibi_dyn = None + + net.set_inputs( + Tensor(shape=q_shape, 
dtype=ms_dtype), # q + Tensor(shape=k_shape, dtype=ms_dtype), # k + Tensor(shape=v_shape, dtype=ms_dtype), # v + mask_tensor, # mask + alibi_dyn, # alibi_slopes + Tensor(shape=[None], dtype=ms.int32), # q_seq_lens + Tensor(shape=[None], dtype=ms.int32) # kv_seq_lens + ) + + # ========== Main Test Execution ========== + + def run_test_case(self, test_case, dynamic: bool = False) -> None: + """Execute a complete test case with all necessary components.""" + # 1. Generate ALIBI slopes if needed + alibi_slopes = self._get_alibi_slopes(test_case.heads) if (test_case.mask_type == MASK_ALIBI or test_case.use_alibi) else None + + # 2. Build inputs based on layout + q_data, k_data, v_data = self.build_inputs( + test_case.q_seq, test_case.kv_seq, + test_case.heads, test_case.head_dim, test_case.np_dtype, + test_case.kv_heads, test_case.layout + ) + + # 3. Build operator mask + mask_np = self.build_operator_mask( + test_case.mask_type, test_case.max_seq, test_case.np_dtype, test_case.batch, + window_size=test_case.window_size, heads=test_case.heads, head_dim=test_case.head_dim, + alibi_slopes=alibi_slopes + ) + + # 4. Prepare inputs for golden calculation (flatten TND to TH layout) + if test_case.layout == 'TND': + q_golden = q_data.reshape(q_data.shape[0], -1).astype(np.float32) + k_golden = k_data.reshape(k_data.shape[0], -1).astype(np.float32) + v_golden = v_data.reshape(v_data.shape[0], -1).astype(np.float32) + else: # TH layout + q_golden = q_data.astype(np.float32) + k_golden = k_data.astype(np.float32) + v_golden = v_data.astype(np.float32) + + # 5. Compute golden reference + golden = self.compute_golden_attention( + q_golden, k_golden, v_golden, + test_case.q_seq, test_case.kv_seq, test_case.heads, test_case.scale_value, + kv_heads=test_case.kv_heads, mask=mask_np, mask_type=test_case.mask_type, + alibi_slopes=alibi_slopes, window_size=test_case.window_size, np_dtype=test_case.np_dtype + ).astype(test_case.np_dtype) + + # 6. Create and configure network + net = FlashAttentionEncoderNet( + test_case.heads, test_case.scale_value, test_case.kv_heads, test_case.mask_type, + test_case.kernel_type, test_case.window_size, test_case.cache_type + ) + + if dynamic: + self.set_dynamic_shapes(net, test_case) + + # 7. Execute operator + ms_dtype = ms.float16 if test_case.np_dtype == np.float16 else ms.bfloat16 + out = net( + self._ms_tensor(q_data), self._ms_tensor(k_data), self._ms_tensor(v_data), + self._ms_tensor(mask_np).astype(ms_dtype) if mask_np is not None else None, + None, + self._ms_tensor(test_case.q_seq), self._ms_tensor(test_case.kv_seq) + ) + + # 8. 
Verify results + out_np = (out.float().asnumpy() if test_case.np_dtype == np_bfloat16 else out.asnumpy()).astype(np.float32) + assert self.validate_output(out_np, golden.astype(np.float32), test_case.np_dtype, test_case.heads, int(test_case.q_seq.max())) + + +class FlashAttentionEncoderNet(nn.Cell): + """MindSpore network wrapper for flash_attention_encoder operator.""" + + def __init__(self, heads: int, scale_value: float, kv_heads: int, mask_type: int, + kernel_type: int, window_size: int, cache_type: int): + super().__init__() + self.heads = heads + self.scale_value = scale_value + self.kv_heads = kv_heads + self.mask_type = mask_type + self.kernel_type = kernel_type + self.window_size = window_size + self.cache_type = cache_type + # determine execution mode once during initialization + self._is_pynative = (context.get_context("mode") == context.PYNATIVE_MODE) + + def construct(self, q, k, v, mask, alibi_slopes, q_seq_lens, kv_seq_lens): + if self._is_pynative: + q_lens_cpu = q_seq_lens.move_to("CPU") + kv_lens_cpu = kv_seq_lens.move_to("CPU") + else: + q_lens_cpu = ops.move_to(q_seq_lens, "CPU") + kv_lens_cpu = ops.move_to(kv_seq_lens, "CPU") + return ms_custom_ops.flash_attention_encoder( + q, k, v, None, mask, alibi_slopes, None, None, None, None, None, None, + q_lens_cpu, kv_lens_cpu, + self.heads, self.scale_value, self.kv_heads, self.mask_type, + self.kernel_type, self.window_size, self.cache_type) + + +class TestCase: + """Configuration for a single test case.""" + def __init__(self, layout: str, heads: int, head_dim: int, q_seq: list, kv_seq: list, + np_dtype: np.dtype, kv_heads: int = None, mask_type: int = MASK_NONE, + kernel_type: int = KERNEL_DEFAULT, window_size: int = 0, + cache_type: int = CACHE_NORM, use_alibi: bool = False): + self.layout = layout + self.heads = heads + self.head_dim = head_dim + self.q_seq = np.array(q_seq, dtype=np.int32) + self.kv_seq = np.array(kv_seq, dtype=np.int32) + self.np_dtype = np_dtype + self.kv_heads = kv_heads or heads + self.mask_type = mask_type + self.kernel_type = kernel_type + self.window_size = window_size + self.cache_type = cache_type + self.use_alibi = use_alibi + self.q_ntokens = int(sum(q_seq)) + self.kv_ntokens = int(sum(kv_seq)) + self.hidden = heads * head_dim + self.max_seq = max(max(q_seq), max(kv_seq)) + self.batch = len(q_seq) + self.scale_value = 1.0 / math.sqrt(float(head_dim)) + + +def _run_test_with_case(test_case: TestCase, run_mode: int, rng_seed: int, dynamic: bool = False): + """Execute a test case with the given run mode using the unified test suite.""" + context.set_context(device_target="Ascend", mode=run_mode) + test_suite = FlashAttentionTestSuite(rng_seed) + test_suite.run_test_case(test_case, dynamic) + + +# Common test decorators +def _flash_attention_test(test_func): + """Apply common decorators for flash attention encoder tests.""" + decorators = [ + pytest.mark.level0, + pytest.mark.platform_ascend910b, + pytest.mark.env_onecard, + pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]), + pytest.mark.parametrize('dynamic', [True, False]), + pytest.mark.parametrize('np_dtype', [np.float16, np_bfloat16]) + ] + + for decorator in reversed(decorators): + test_func = decorator(test_func) + return test_func + + +@_flash_attention_test +def test_flash_attention_encoder_th_layout(run_mode, dynamic, np_dtype): + """Test TH layout with both fp16 and bf16 data types.""" + test_case = TestCase('TH', 4, 32, [48, 16], [48, 16], np_dtype) + _run_test_with_case(test_case, run_mode, 2027, 
dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_layout(run_mode, dynamic, np_dtype): + """Test TND layout with both fp16 and bf16 data types.""" + test_case = TestCase('TND', 2, 64, [20, 12], [20, 12], np_dtype) + _run_test_with_case(test_case, run_mode, 2028, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_th_high_precision(run_mode, dynamic, np_dtype): + """Test TH layout with high precision kernel for both fp16 and bf16.""" + test_case = TestCase('TH', 4, 32, [48, 16], [48, 16], np_dtype, kernel_type=KERNEL_HIGH_PRECISION) + _run_test_with_case(test_case, run_mode, 2033, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_high_precision(run_mode, dynamic, np_dtype): + """Test TND layout with high precision kernel for both fp16 and bf16.""" + test_case = TestCase('TND', 2, 64, [20, 12], [20, 12], np_dtype, kernel_type=KERNEL_HIGH_PRECISION) + _run_test_with_case(test_case, run_mode, 2034, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_th_norm_mask(run_mode, dynamic, np_dtype): + """Test TH layout with normal mask for both fp16 and bf16.""" + test_case = TestCase('TH', 4, 32, [48, 16], [48, 16], np_dtype, mask_type=MASK_NORM) + _run_test_with_case(test_case, run_mode, 2041, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_norm_mask(run_mode, dynamic, np_dtype): + """Test TND layout with normal mask for both fp16 and bf16.""" + test_case = TestCase('TND', 2, 64, [20, 12], [20, 12], np_dtype, mask_type=MASK_NORM) + _run_test_with_case(test_case, run_mode, 2042, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_gqa(run_mode, dynamic, np_dtype): + """Test TND layout with GQA (Grouped Query Attention) for both fp16 and bf16.""" + test_case = TestCase('TND', 4, 32, [24, 16], [24, 16], np_dtype, kv_heads=2) + _run_test_with_case(test_case, run_mode, 2043, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_swa_full_window(run_mode, dynamic, np_dtype): + """Test TND layout with SWA full window for both fp16 and bf16.""" + test_case = TestCase('TND', 2, 64, [20, 12], [20, 12], np_dtype, + mask_type=MASK_SWA_NORM, window_size=20, cache_type=CACHE_SWA) + _run_test_with_case(test_case, run_mode, 2044, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_swa_compress(run_mode, dynamic, np_dtype): + """Test TND layout with SWA compress mask for both fp16 and bf16.""" + test_case = TestCase('TND', 2, 64, [20, 12], [20, 12], np_dtype, + mask_type=MASK_SWA_COMPRESS, window_size=8, cache_type=CACHE_SWA) + _run_test_with_case(test_case, run_mode, 2053, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_th_nomask(run_mode, dynamic, np_dtype): + """Test TH layout with no mask for both fp16 and bf16.""" + test_case = TestCase('TH', 2, 64, [20, 12], [20, 12], np_dtype) + _run_test_with_case(test_case, run_mode, 2050, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_norm_compress(run_mode, dynamic, np_dtype): + """Test TND layout with normal compress mask for both fp16 and bf16.""" + test_case = TestCase('TND', 2, 64, [20, 12], [20, 12], np_dtype, mask_type=MASK_NORM_COMPRESS) + _run_test_with_case(test_case, run_mode, 2051, dynamic) + + +@_flash_attention_test +def test_flash_attention_encoder_tnd_alibi(run_mode, dynamic, np_dtype): + """Test TND layout with ALIBI mask for both fp16 and bf16 - critical for bf16 special handling.""" + test_case = TestCase('TND', 4, 32, [24, 
16], [24, 16], np_dtype, + mask_type=MASK_ALIBI, use_alibi=True) + _run_test_with_case(test_case, run_mode, 2052, dynamic) + diff --git a/tests/st/test_custom_flash_attention_encoder_nz.py b/tests/st/test_custom_flash_attention_encoder_nz.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2a1e14aafcc11412e4a61442a871d6acb1c294 --- /dev/null +++ b/tests/st/test_custom_flash_attention_encoder_nz.py @@ -0,0 +1,174 @@ +import math +import numpy as np +import pytest +import mindspore as ms +import ms_custom_ops +from mindspore import Tensor, context, ops, nn +from mindspore.common.np_dtype import bfloat16 as np_bfloat16 + +from test_custom_flash_attention_encoder import ( + FlashAttentionTestSuite, + TestCase, + MASK_NONE, + MASK_NORM, + MASK_ALIBI, + MASK_SWA_NORM, + MASK_SWA_COMPRESS, + CACHE_NORM, + CACHE_SWA, +) + + +class FlashAttentionEncoderNzNet(nn.Cell): + def __init__(self, heads: int, scale_value: float, kv_heads: int, mask_type: int, + kernel_type: int, window_size: int, cache_type: int, cache_mode: int = 1): + super().__init__() + self.heads = heads + self.scale_value = scale_value + self.kv_heads = kv_heads + self.mask_type = mask_type + self.kernel_type = kernel_type + self.window_size = window_size + self.cache_type = cache_type + self.cache_mode = cache_mode + self._is_pynative = (context.get_context("mode") == context.PYNATIVE_MODE) + + def construct(self, q, k, v, mask, alibi_slopes, q_seq_lens, kv_seq_lens): + if self._is_pynative: + q_lens_cpu = q_seq_lens.move_to("CPU") + kv_lens_cpu = kv_seq_lens.move_to("CPU") + else: + q_lens_cpu = ops.move_to(q_seq_lens, "CPU") + kv_lens_cpu = ops.move_to(kv_seq_lens, "CPU") + return ms_custom_ops.flash_attention_encoder( + q, k, v, None, mask, alibi_slopes, None, None, None, None, None, None, + q_lens_cpu, kv_lens_cpu, + self.heads, self.scale_value, self.kv_heads, self.mask_type, + self.kernel_type, self.window_size, self.cache_type, self.cache_mode) + + +def _ms_tensor(x: np.ndarray) -> Tensor: + if x is None: + return None + return Tensor(x.astype(np.float32)).astype(ms.bfloat16) if x.dtype == np_bfloat16 else Tensor(x) + + +def _run_nz_test(test_case: TestCase, run_mode: int, rng_seed: int, dynamic: bool = False): + context.set_context(device_target="Ascend", mode=run_mode) + test_suite = FlashAttentionTestSuite(rng_seed) + + # Build ALIBI slopes if needed + alibi_slopes = ( + FlashAttentionTestSuite._get_alibi_slopes(test_case.heads) + if (test_case.mask_type == MASK_ALIBI or getattr(test_case, 'use_alibi', False)) + else None + ) + q_data, k_data, v_data = test_suite.build_inputs( + test_case.q_seq, test_case.kv_seq, + test_case.heads, test_case.head_dim, test_case.np_dtype, + test_case.kv_heads, test_case.layout + ) + + mask_np = test_suite.build_operator_mask( + test_case.mask_type, test_case.max_seq, test_case.np_dtype, test_case.batch, + window_size=test_case.window_size, heads=test_case.heads, head_dim=test_case.head_dim, + alibi_slopes=alibi_slopes + ) + + if test_case.layout == 'TND': + q_golden = q_data.reshape(q_data.shape[0], -1).astype(np.float32) + k_golden = k_data.reshape(k_data.shape[0], -1).astype(np.float32) + v_golden = v_data.reshape(v_data.shape[0], -1).astype(np.float32) + else: + q_golden = q_data.astype(np.float32) + k_golden = k_data.astype(np.float32) + v_golden = v_data.astype(np.float32) + + golden = test_suite.compute_golden_attention( + q_golden, k_golden, v_golden, + test_case.q_seq, test_case.kv_seq, test_case.heads, test_case.scale_value, + kv_heads=test_case.kv_heads, 
mask=mask_np, mask_type=test_case.mask_type, + alibi_slopes=alibi_slopes, window_size=test_case.window_size, np_dtype=test_case.np_dtype + ).astype(test_case.np_dtype) + + net = FlashAttentionEncoderNzNet( + test_case.heads, test_case.scale_value, test_case.kv_heads, test_case.mask_type, + test_case.kernel_type, test_case.window_size, test_case.cache_type, cache_mode=1 + ) + + if dynamic: + # Reuse dynamic shape setup from original suite + test_suite.set_dynamic_shapes(net, test_case) + + ms_dtype = ms.float16 if test_case.np_dtype == np.float16 else ms.bfloat16 + + # Convert upstream ND tensors to NZ before calling operator + q_ms = _ms_tensor(q_data).astype(ms_dtype) + k_ms = _ms_tensor(k_data).astype(ms_dtype) + v_ms = _ms_tensor(v_data).astype(ms_dtype) + q_nz = ms_custom_ops.trans_data(q_ms, transdata_type=1) + k_nz = ms_custom_ops.trans_data(k_ms, transdata_type=1) + v_nz = ms_custom_ops.trans_data(v_ms, transdata_type=1) + if mask_np is not None: + mask_ms = _ms_tensor(mask_np).astype(ms_dtype) + mask_nz = ms_custom_ops.trans_data(mask_ms, transdata_type=1) + else: + mask_nz = None + + out_nz = net( + q_nz, k_nz, v_nz, + mask_nz, + _ms_tensor(alibi_slopes.astype(np.float32)) if alibi_slopes is not None else None, + _ms_tensor(test_case.q_seq), _ms_tensor(test_case.kv_seq) + ) + + # Convert NZ output back to ND for validation + out_nd = ms_custom_ops.trans_data(out_nz, transdata_type=0) + out_np = (out_nd.float().asnumpy() if test_case.np_dtype == np_bfloat16 else out_nd.asnumpy()).astype(np.float32) + assert test_suite.validate_output(out_np, golden.astype(np.float32), test_case.np_dtype, test_case.heads, int(test_case.q_seq.max())) + + +def _flash_attention_nz_test(func): + decorators = [ + pytest.mark.level0, + pytest.mark.platform_ascend310p, + pytest.mark.env_onecard, + pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]), + pytest.mark.parametrize('dynamic', [True, False]), + pytest.mark.parametrize('np_dtype', [np.float16]) + ] + for d in reversed(decorators): + func = d(func) + return func + + +@_flash_attention_nz_test +def test_flash_attention_encoder_th_nz_nomask(run_mode, dynamic, np_dtype): + test_case = TestCase('TH', 2, 64, [20, 12], [20, 12], np_dtype) + _run_nz_test(test_case, run_mode, 3061, dynamic) + + +@_flash_attention_nz_test +def test_flash_attention_encoder_th_nz_norm_mask(run_mode, dynamic, np_dtype): + test_case = TestCase('TH', 2, 64, [20, 12], [20, 12], np_dtype, mask_type=MASK_NORM) + _run_nz_test(test_case, run_mode, 3062, dynamic) + + +@_flash_attention_nz_test +def test_flash_attention_encoder_th_nz_alibi_mask(run_mode, dynamic, np_dtype): + test_case = TestCase('TH', 4, 32, [24, 16], [24, 16], np_dtype, mask_type=MASK_ALIBI, use_alibi=True) + _run_nz_test(test_case, run_mode, 3063, dynamic) + + +@_flash_attention_nz_test +def test_flash_attention_encoder_th_nz_swa_norm(run_mode, dynamic, np_dtype): + test_case = TestCase('TH', 2, 64, [20, 12], [20, 12], np_dtype, + mask_type=MASK_SWA_NORM, window_size=16, cache_type=CACHE_SWA) + _run_nz_test(test_case, run_mode, 3064, dynamic) + + +@_flash_attention_nz_test +def test_flash_attention_encoder_th_nz_swa_compress(run_mode, dynamic, np_dtype): + test_case = TestCase('TH', 2, 64, [20, 12], [20, 12], np_dtype, + mask_type=MASK_SWA_COMPRESS, window_size=8, cache_type=CACHE_SWA) + _run_nz_test(test_case, run_mode, 3065, dynamic) diff --git a/yaml/doc/flash_attention_encoder.md b/yaml/doc/flash_attention_encoder.md new file mode 100644 index 
0000000000000000000000000000000000000000..598257baab4995eb488e416880e0b04c0ffcbdec
--- /dev/null
+++ b/yaml/doc/flash_attention_encoder.md
@@ -0,0 +1,155 @@
+# FlashAttention Encoder (Self-Attention) Operator
+
+## Background
+FlashAttention is a high-performance self-attention implementation that reduces memory footprint and raises throughput mainly through tiling/recomputation, mask compression, and better memory access and operator fusion. In this project, `flash_attention_encoder` is a custom-operator interface that uniformly dispatches to the Unpad Flash Attention kernel family on Ascend 910B (ND kernels) and 310P (NZ kernels):
+- 910B: `UnpadFlashAttentionOperation` (ND path, supports a high-precision FP32 mode)
+- 310P: `UnpadFlashAttentionNzOperation` (NZ path, supports SWA / ALIBI compressed variants)
+
+Reference (ATB documentation, SelfAttentionOperation parameters and features):
+- [SelfAttentionOperation parameter list](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/API/ascendtbapi/ascendtb_01_0278.html)
+
+## Computation
+Given Q/K/V, the scaling factor `qkScale`, and an optional mask M:
+
+- Let Q ∈ R^(N_q × H × D), K ∈ R^(N_k × H_kv × D), V ∈ R^(N_k × H_kv × D_v), where H is the number of attention heads and D is head_dim.
+- Compute the logits:
+  ```
+  S = qkScale * (Q · K^T) + M    # qkScale is typically 1 / sqrt(head_dim)
+  ```
+- Normalize:
+  ```
+  P = softmax(S)
+  ```
+- Output:
+  ```
+  O = P · V
+  ```
+
+Note: this version does not expose the LOGN variant or any quantization overlays; only the regular scale and the common mask types are wired up. Internal implementation differences between the platforms (ND/NZ) are transparent to the user.
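+
+As a cross-check, the golden reference in the ST tests of this PR evaluates exactly these formulas in NumPy over the flattened (TH-layout) token stream. Below is a condensed, mask-free sketch of that reference; the function and argument names are illustrative and not part of the operator API:
+
+```python
+import numpy as np
+
+def ref_flash_attention_encoder(q, k, v, q_seq, kv_seq, head_num, kv_head_num, qk_scale):
+    """Unpadded (TH-layout) self-attention reference over per-batch segments, GQA-aware."""
+    head_dim = q.shape[-1] // head_num
+    q = q.reshape(-1, head_num, head_dim).astype(np.float32)
+    k = k.reshape(-1, kv_head_num, head_dim).astype(np.float32)
+    v = v.reshape(-1, kv_head_num, head_dim).astype(np.float32)
+    out = np.zeros_like(q)
+    q_off = kv_off = 0
+    for ql, kl in zip(q_seq, kv_seq):
+        qs, ks, vs = q[q_off:q_off + ql], k[kv_off:kv_off + kl], v[kv_off:kv_off + kl]
+        for h in range(head_num):
+            kh = h // (head_num // kv_head_num)            # GQA: several Q heads share one KV head
+            s = qs[:, h, :] @ ks[:, kh, :].T * qk_scale    # S = qkScale * (Q K^T)
+            p = np.exp(s - s.max(axis=1, keepdims=True))
+            p /= p.sum(axis=1, keepdims=True)              # P = softmax(S)
+            out[q_off:q_off + ql, h, :] = p @ vs[:, kh, :]  # O = P V
+        q_off += ql
+        kv_off += kl
+    return out.reshape(-1, head_num * head_dim)
+```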
+
+## Interface, inputs and outputs
+### Name
+- Operator name: `flash_attention_encoder`
+
+### Input parameters
+
+| Name | DType | Shape | Optional | Format | Description |
+|-----------------|---------------------------|---------------------------|----------|--------|-------------|
+| query | Tensor[float16/bfloat16] | TH/TND | No | ND/NZ | Query tensor Q |
+| key | Tensor[float16/bfloat16] | same as query | No | ND/NZ | Key tensor K |
+| value | Tensor[float16/bfloat16] | same as query | No | ND/NZ | Value tensor V |
+| layer_id | Tensor[int32] | (1,) or as required by the implementation | Yes | ND | Layer index (reserved, optional) |
+| mask | Tensor[float16/bfloat16] | depends on mask type | Yes | ND | Attention mask; supports NORM/ALIBI/SWA etc. |
+| alibi_coeff | Tensor[float32] | as required by ALIBI | Yes | ND | ALIBI-related coefficients (optional) |
+| deq_scale_qk | Tensor[float32] | (head_num,) | Yes | ND | Quantization related (reserved) |
+| deq_offset_qk | Tensor[float32] | (head_num,) | Yes | ND | Quantization related (reserved) |
+| deq_scale_pv | Tensor[float32] | (head_num,) | Yes | ND | Quantization related (reserved) |
+| deq_offset_pv | Tensor[float32] | (head_num,) | Yes | ND | Quantization related (reserved) |
+| quant_p | Tensor[float32] | as needed | Yes | ND | Quantization related (reserved) |
+| logN | Tensor[float32] | (1,) or as required by the implementation | Yes | ND | Reserved (unused in this version) |
+| q_seq_len | Tensor[int32] | (batch,) | Cond | ND | Required for TH/TND layouts: per-batch Q sequence lengths |
+| kv_seq_len | Tensor[int32] | (batch,) | Cond | ND | Required for TH/TND layouts: per-batch KV sequence lengths |
+| head_num | int | - | Yes | - | Number of attention heads (H) |
+| scale_value | float | - | Yes | - | QK scaling factor `qkScale` (usually 1/sqrt(head_dim)) |
+| kv_head_num | int | - | Yes | - | Number of KV heads (GQA); 0 means aligned with head_num |
+| mask_type | int | - | Yes | - | Mask type: UNDEFINED/NORM/ALIBI/SWA_* |
+| kernel_type | int | - | Yes | - | Kernel precision: 0 - default half precision, 1 - high precision (FP32 BMM1) |
+| window_size | int | - | Yes | - | SWA window size (SWA scenarios) |
+| cache_type | int | - | Yes | - | Cache type: 0 - NORM, 1 - SWA (SWA optimization) |
+| cache_mode | int | - | Yes | - | Data-format routing: 0 - ND (default), 1 - force FRACTAL_NZ (310P NZ kernel) |
+
+Note: the current version does not wire up quantization, online/offline QKV, clamp, ring, prefix, or other advanced features; the corresponding tensors are placeholders and can stay None.
+
+### Output parameters
+
+| Name | DType | Shape | Description |
+|---------------|--------------------------|-----------------|-------------|
+| attention_out | Tensor[float16/bfloat16] | same as query | Attention output; hidden dimension aligned with `value` |
+
+### Shapes and layouts
+- Supported layouts: BSH / TH / TND (plus the equivalent BNSD cases)
+  - BSH: `query [B, S_q, H*D]`, `key/value [B, S_kv, H_kv*D]` → output `[B, S_q, H*D]`
+  - TH: flattened token stream; `query [sum(q_seq), H*D]`, `key/value [sum(kv_seq), H_kv*D]` → output shaped like Q
+  - TND: the token stream can be built from the real per-batch sequence lengths (TH-style input is recommended to avoid extra padding)
+- On 910B (ND) and 310P (NZ) the necessary ND↔NZ format handling and parameter setup are done internally; the user-facing interface is identical.
+
+### Python usage example
+
+```python
+import numpy as np
+from mindspore import context, Tensor
+import ms_custom_ops
+
+np.random.seed(0)
+context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
+
+head_num = 8
+kv_head_num = 4
+head_dim = 128
+scale_value = float(1.0 / np.sqrt(head_dim))
+
+q_seq = np.array([16, 20], dtype=np.int32)
+kv_seq = np.array([16, 20], dtype=np.int32)
+q_tokens = int(q_seq.sum())
+kv_tokens = int(kv_seq.sum())
+
+q = Tensor(np.random.uniform(-1, 1, size=(q_tokens, head_num, head_dim)).astype(np.float16))
+k = Tensor(np.random.uniform(-1, 1, size=(kv_tokens, kv_head_num, head_dim)).astype(np.float16))
+v = Tensor(np.random.uniform(-1, 1, size=(kv_tokens, kv_head_num, head_dim)).astype(np.float16))
+
+q_seq_len = Tensor(q_seq).move_to("CPU")
+kv_seq_len = Tensor(kv_seq).move_to("CPU")
+
+out = ms_custom_ops.flash_attention_encoder(
+    q, k, v,
+    layer_id=None, mask=None, alibi_coeff=None,
+    deq_scale_qk=None, deq_offset_qk=None, deq_scale_pv=None, deq_offset_pv=None,
+    quant_p=None, logN=None,
+    q_seq_len=q_seq_len, kv_seq_len=kv_seq_len,
+    head_num=head_num, scale_value=scale_value, kv_head_num=kv_head_num,
+    kernel_type=0, mask_type=0, window_size=0, cache_type=0
+)
+
+print(out.shape)  # expected: (q_tokens, head_num, head_dim)
+```
+
+## Supported parameters and call conventions
+- Fixed setting: `calcType = PA_ENCODER`
+- Scalar parameters (mapping to ATB):
+  - head_num → atb.headNum (> 0)
+  - scale_value → atb.qkScale (usually 1/sqrt(head_dim))
+  - kv_head_num → atb.kvHeadNum (0 means aligned with head_num; GQA constraint: head_num % kv_head_num == 0)
+  - mask_type → atb.MaskType (NORM, ALIBI and SWA_* are supported, see below)
+  - kernel_type → atb.KernelType (0: default half precision; 1: high precision / FP32 BMM1 on 910B)
+  - window_size → atb.windowSize (SWA window)
+  - cache_type → atb.CacheType (0: NORM; 1: SWA)
+- Dynamic lengths (TH/TND only): `q_seq_len`, `kv_seq_len` (int32, CPU tensors)
+
+Reference: ATB SelfAttention parameter definition (official documentation)
+Link: https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/API/ascendtbapi/ascendtb_01_0278.html
+
+## Feature support matrix (this version)
+- Platform / kernel
+  - 910B (ND): `UnpadFlashAttentionOperation`
+    - Masks: NORM, ALIBI, UNDEFINED
+    - High precision: enabled via `kernel_type=1` (FP32 BMM1); `kernel_type=0` uses half precision for higher throughput
+    - LOGN: not exposed
+  - 310P (NZ): `UnpadFlashAttentionNzOperation`
+    - Masks: NORM, ALIBI, SWA_NORM, SWA_COMPRESS (including the ALIBI compressed variant)
+    - LOGN: not exposed in this version of the interface
+    - SWA: selected via `mask_type`; `window_size` and `cache_type` should be configured together
+
+- Dynamic sequence lengths
+  - TH/TND: `q_seq_len` / `kv_seq_len` are mandatory (int32, on CPU)
+  - BSH: sequence lengths are derived internally from the batch dimension
+
+- Not yet supported (reserved / disabled by default)
+  - Quantization (online/offline INT8), clamp, LOGN, prefix/decoder, ring/send-recv and other advanced features
+
+## Performance and precision tips
+- 910B: set `kernel_type=1` (FP32 BMM1) for precision-sensitive workloads; the default `0` uses half precision for higher throughput.
+- 310P: for long sequences enable SWA and tune `window_size` / `cache_type` to balance memory and throughput.
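+
+## NZ (310P) usage sketch
+
+On 310P the NZ path can also be driven explicitly. The sketch below mirrors `tests/st/test_custom_flash_attention_encoder_nz.py`: inputs are converted from ND to FRACTAL_NZ with `ms_custom_ops.trans_data(..., transdata_type=1)`, the operator is called with `cache_mode=1`, and the output is converted back with `transdata_type=0`. Shapes, dtypes, and the random data are illustrative only:
+
+```python
+import numpy as np
+from mindspore import Tensor, context
+import ms_custom_ops
+
+context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
+
+head_num, kv_head_num, head_dim = 2, 2, 64
+q_seq = np.array([20, 12], dtype=np.int32)
+kv_seq = np.array([20, 12], dtype=np.int32)
+
+q = Tensor(np.random.standard_normal((int(q_seq.sum()), head_num * head_dim)).astype(np.float16))
+k = Tensor(np.random.standard_normal((int(kv_seq.sum()), kv_head_num * head_dim)).astype(np.float16))
+v = Tensor(np.random.standard_normal((int(kv_seq.sum()), kv_head_num * head_dim)).astype(np.float16))
+
+# ND -> FRACTAL_NZ before calling the operator
+q_nz = ms_custom_ops.trans_data(q, transdata_type=1)
+k_nz = ms_custom_ops.trans_data(k, transdata_type=1)
+v_nz = ms_custom_ops.trans_data(v, transdata_type=1)
+
+out_nz = ms_custom_ops.flash_attention_encoder(
+    q_nz, k_nz, v_nz,
+    q_seq_len=Tensor(q_seq).move_to("CPU"), kv_seq_len=Tensor(kv_seq).move_to("CPU"),
+    head_num=head_num, scale_value=float(1.0 / np.sqrt(head_dim)), kv_head_num=kv_head_num,
+    mask_type=0, kernel_type=0, window_size=0, cache_type=0,
+    cache_mode=1)  # cache_mode=1 routes Q/K/V (and mask/output) through FRACTAL_NZ
+
+# FRACTAL_NZ -> ND for downstream consumers
+out = ms_custom_ops.trans_data(out_nz, transdata_type=0)
+```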
+
+## Versioning and compatibility
+- This document corresponds to the first version of `flash_attention_encoder`. The interface is aligned with ATB `SelfAttentionParam`; quantization, decoder, prefix, ring and other features will be opened up incrementally while keeping the YAML and parameter structure unchanged and backward compatible.
+
diff --git a/yaml/ms_kernels_internal/flash_attention_encoder_op.yaml b/yaml/ms_kernels_internal/flash_attention_encoder_op.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ffb528d75e502fc86ac88fc240744f18fbd3c
--- /dev/null
+++ b/yaml/ms_kernels_internal/flash_attention_encoder_op.yaml
@@ -0,0 +1,71 @@
+#operator flash_attention_encoder
+flash_attention_encoder:
+  args:
+    query:
+      dtype: tensor
+    key:
+      dtype: tensor
+    value:
+      dtype: tensor
+    layer_id:
+      dtype: tensor
+      default: None
+    mask:
+      dtype: tensor
+      default: None
+    alibi_coeff:
+      dtype: tensor
+      default: None
+    deq_scale_qk:
+      dtype: tensor
+      default: None
+    deq_offset_qk:
+      dtype: tensor
+      default: None
+    deq_scale_pv:
+      dtype: tensor
+      default: None
+    deq_offset_pv:
+      dtype: tensor
+      default: None
+    quant_p:
+      dtype: tensor
+      default: None
+    logN:
+      dtype: tensor
+      default: None
+    q_seq_len:
+      dtype: tensor
+      default: None
+    kv_seq_len:
+      dtype: tensor
+      default: None
+    head_num:
+      dtype: int
+      default: 0
+    scale_value:
+      dtype: float
+      default: 1.0
+    kv_head_num:
+      dtype: int
+      default: 0
+    mask_type:
+      dtype: int
+      default: 0
+    kernel_type:
+      dtype: int
+      default: 0
+    window_size:
+      dtype: int
+      default: 0
+    cache_type:
+      dtype: int
+      default: 0
+    cache_mode:
+      dtype: int
+      default: 0
+  returns:
+    attention_out:
+      dtype: tensor
+
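The ST tests exercise this signature from a small `nn.Cell` so the same call runs in both GRAPH_MODE and PYNATIVE_MODE; a condensed sketch of that wrapper (simplified from `tests/st/test_custom_flash_attention_encoder.py`: no mask, default kernel, names illustrative):

```python
import mindspore as ms
from mindspore import context, nn, ops
import ms_custom_ops

class EncoderNet(nn.Cell):
    """Minimal wrapper around flash_attention_encoder (no mask, no quantization inputs)."""
    def __init__(self, head_num, scale_value, kv_head_num):
        super().__init__()
        self.head_num = head_num
        self.scale_value = scale_value
        self.kv_head_num = kv_head_num
        self._is_pynative = context.get_context("mode") == context.PYNATIVE_MODE

    def construct(self, q, k, v, q_seq_lens, kv_seq_lens):
        # q_seq_len / kv_seq_len must be int32 tensors on CPU; graph mode uses ops.move_to
        if self._is_pynative:
            q_lens, kv_lens = q_seq_lens.move_to("CPU"), kv_seq_lens.move_to("CPU")
        else:
            q_lens, kv_lens = ops.move_to(q_seq_lens, "CPU"), ops.move_to(kv_seq_lens, "CPU")
        return ms_custom_ops.flash_attention_encoder(
            q, k, v, None, None, None, None, None, None, None, None, None,
            q_lens, kv_lens, self.head_num, self.scale_value, self.kv_head_num,
            0, 0, 0, 0)  # mask_type, kernel_type, window_size, cache_type
```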