diff --git a/ops/c_api/flash_attention_encoder/flash_attention_encoder.cc b/ops/c_api/flash_attention_encoder/flash_attention_encoder.cc
index 680c34ac8b8cf30ac33133dc60609d39bc93f191..475d977e7c3f2efaf808208e35e15eaa674c47a9 100644
--- a/ops/c_api/flash_attention_encoder/flash_attention_encoder.cc
+++ b/ops/c_api/flash_attention_encoder/flash_attention_encoder.cc
@@ -49,6 +49,7 @@ internal_v2::InternalOpPtr CustomFlashAttentionEncoder::CreateKernel(
   param_.window_size = static_cast(ms_inputs[kWindowSizeIdx]->GetValueWithCheck());
   param_.cache_type = static_cast(ms_inputs[kCacheTypeIdx]->GetValueWithCheck());
+  param_.inner_precise = static_cast(ms_inputs[kInnerPreciseIdx]->GetValueWithCheck());
   // input_format: default 0 ND, 1 force NZ
   param_.input_format = static_cast(ms_inputs[kInputFormatIdx]->GetValueWithCheck());
diff --git a/ops/c_api/flash_attention_encoder/flash_attention_encoder.h b/ops/c_api/flash_attention_encoder/flash_attention_encoder.h
index 3fbbf787f95d017860a7aab5f3417d47a727f587..335a189335993e02b1bc952b09fb8f3810163b5f 100644
--- a/ops/c_api/flash_attention_encoder/flash_attention_encoder.h
+++ b/ops/c_api/flash_attention_encoder/flash_attention_encoder.h
@@ -48,6 +48,7 @@ enum FlashAttentionEncoderInputIndex : int {
   kKernelTypeIdx,
   kWindowSizeIdx,
   kCacheTypeIdx,
+  kInnerPreciseIdx,
   kInputFormatIdx,
   kInputNums
 };
@@ -60,8 +61,9 @@ class OPS_API FlashAttentionEncoderOpFuncImpl : public OpFuncImpl {
   std::vector<TypeId> InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override;
   bool GeneralInferRegistered() const override { return true; }
   std::set<int64_t> GetValueDependArgIndices() const override {
-    return {kQSeqLenIdx, kKVSeqLenIdx, kHeadNumIdx, kScaleValueIdx, kKvHeadNumIdx,
-            kMaskTypeIdx, kKernelTypeIdx, kWindowSizeIdx, kCacheTypeIdx, kInputFormatIdx};
+    return {kQSeqLenIdx, kKVSeqLenIdx, kHeadNumIdx, kScaleValueIdx, kKvHeadNumIdx,
+            kMaskTypeIdx, kKernelTypeIdx, kWindowSizeIdx, kCacheTypeIdx, kInnerPreciseIdx,
+            kInputFormatIdx};
   };
 };
diff --git a/ops/c_api/flash_attention_encoder/flash_attention_encoder.md b/ops/c_api/flash_attention_encoder/flash_attention_encoder.md
index d33975aad20cd2b22780926a7db7d4b9b2a5b0a0..92683183a35939de942aa6ba29c4d2d5be6b262d 100644
--- a/ops/c_api/flash_attention_encoder/flash_attention_encoder.md
+++ b/ops/c_api/flash_attention_encoder/flash_attention_encoder.md
@@ -39,6 +39,7 @@ FlashAttention is a high-performance self-attention implementation, mainly via tiling/re
 | kernel_type | int | - | Yes | - | Kernel precision; default 0: half precision, 1: high precision (FP32 BMM1) |
 | window_size | int | - | Yes | - | SWA window size; default 0 (SWA disabled) |
 | cache_type | int | - | Yes | - | Cache type; default 0: NORM, 1: SWA (SWA optimization) |
+| inner_precise | int | - | Yes | - | Internal compute precision control; default 0 |
 | input_format | int | - | Yes | - | Input format selection; default 0: ND, 1: NZ |
 
 Note: the current version does not wire up advanced features such as quantization, online/offline QKV, clamp, ring, or prefix; the corresponding tensors are placeholders and can be left as None.
diff --git a/ops/c_api/flash_attention_encoder/flash_attention_encoder_op.yaml b/ops/c_api/flash_attention_encoder/flash_attention_encoder_op.yaml
index ce21d026103ed53dac8a653438fe088ebe937858..9b7df010ea2da449570ef1261199147a15d2718d 100644
--- a/ops/c_api/flash_attention_encoder/flash_attention_encoder_op.yaml
+++ b/ops/c_api/flash_attention_encoder/flash_attention_encoder_op.yaml
@@ -61,6 +61,9 @@ flash_attention_encoder:
     cache_type:
       dtype: int
       default: 0
+    inner_precise:
+      dtype: int
+      default: 0
     input_format:
       dtype: int
       default: 0
diff --git a/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.cc b/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.cc
index 0b8776d0f66c8a19f37f40748114ea599d0e3764..4eeb6a7deb365454546cfb2e647b506974168ec1 100644
--- a/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.cc
+++ b/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.cc
@@ -23,7 +23,7 @@ namespace ms_custom_ops {
 void FlashAttentionEncoderRunner::SetParam(int64_t head_num, float scale_value, int64_t kv_head_num, int64_t mask_type,
                                            int64_t kernel_type, int64_t window_size, int64_t cache_type,
-                                           const std::vector &q_seq_len,
+                                           int64_t inner_precise, const std::vector &q_seq_len,
                                            const std::vector &kv_seq_len) {
   param_.head_num = static_cast(head_num);
   param_.qk_scale = static_cast(scale_value);
@@ -34,6 +34,7 @@ void FlashAttentionEncoderRunner::SetParam(int64_t head_num, float scale_value,
   param_.kernel_type = static_cast(kernel_type);
   param_.window_size = static_cast(window_size);
   param_.cache_type = static_cast(cache_type);
+  param_.inner_precise = static_cast(inner_precise);
   param_.q_seq_len = q_seq_len;
   param_.kv_seq_len = kv_seq_len;
@@ -81,7 +82,7 @@ static std::vector<ms::Tensor> npu_flash_attention_encoder(
     const std::optional<ms::Tensor> &quant_p, const std::optional<ms::Tensor> &logN,
     const std::optional<ms::Tensor> &q_seq_len, const std::optional<ms::Tensor> &kv_seq_len, int64_t head_num,
     float scale_value, int64_t kv_head_num, int64_t mask_type, int64_t kernel_type, int64_t window_size,
-    int64_t cache_type, int64_t input_format) {
+    int64_t cache_type, int64_t inner_precise, int64_t input_format) {
   static auto op_name = "FlashAttentionEncoder";
   auto runner = std::make_shared<FlashAttentionEncoderRunner>(op_name);
   MS_EXCEPTION_IF_NULL(runner);
@@ -96,13 +97,14 @@ static std::vector<ms::Tensor> npu_flash_attention_encoder(
   auto q_seq = GetValueFromTensor>(q_seq_len.value(), op_name, "q_seq_len");
   auto kv_seq = GetValueFromTensor>(kv_seq_len.value(), op_name, "kv_seq_len");
-  runner->SetParam(head_num, scale_value, kv_head_num, mask_type, kernel_type, window_size, cache_type, q_seq, kv_seq);
+  runner->SetParam(head_num, scale_value, kv_head_num, mask_type, kernel_type, window_size, cache_type, inner_precise,
+                   q_seq, kv_seq);
   runner->SetInputFormat(input_format);

   // Setup the runner with all parameters to form cache key
   runner->Setup(op_name, query, key, value, layer_id, mask, alibi_coeff, deq_scale_qk, deq_offset_qk, deq_scale_pv,
                 deq_offset_pv, quant_p, logN, q_seq_len, kv_seq_len, head_num, scale_value, kv_head_num, mask_type,
-                kernel_type, window_size, cache_type, input_format);
+                kernel_type, window_size, cache_type, inner_precise, input_format);

   // outputs
   auto attn_out = ms::Tensor(query.data_type(), query.shape());
@@ -132,11 +134,12 @@ static auto pyboost_flash_attention_encoder(
     const std::optional<ms::Tensor> &quant_p, const std::optional<ms::Tensor> &logN,
     const std::optional<ms::Tensor> &q_seq_len, const std::optional<ms::Tensor> &kv_seq_len, int64_t head_num,
     float scale_value, int64_t kv_head_num, int64_t mask_type, int64_t kernel_type, int64_t window_size,
-    int64_t cache_type, int64_t input_format) {
+    int64_t cache_type, int64_t inner_precise, int64_t input_format) {
   return ms::pynative::PyboostRunner::Call<1>(npu_flash_attention_encoder, query, key, value, layer_id, mask,
                                               alibi_coeff, deq_scale_qk, deq_offset_qk, deq_scale_pv, deq_offset_pv,
                                               quant_p, logN, q_seq_len, kv_seq_len, head_num, scale_value, kv_head_num,
-                                              mask_type, kernel_type, window_size, cache_type, input_format);
+                                              mask_type, kernel_type, window_size, cache_type, inner_precise,
+                                              input_format);
 }

 }  // namespace ms_custom_ops
@@ -150,5 +153,5 @@ MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
         pybind11::arg("q_seq_len") = std::nullopt, pybind11::arg("kv_seq_len") = std::nullopt,
pybind11::arg("head_num") = 0, pybind11::arg("scale_value") = 1.0, pybind11::arg("kv_head_num") = 0, pybind11::arg("mask_type") = 0, pybind11::arg("kernel_type") = 0, pybind11::arg("window_size") = 0, - pybind11::arg("cache_type") = 0, pybind11::arg("input_format") = 0); + pybind11::arg("cache_type") = 0, pybind11::arg("inner_precise") = 0, pybind11::arg("input_format") = 0); } diff --git a/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.h b/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.h index dfdd7391a965322c4ae8405f8576258c4551ffa1..463398bc64722488a487b340647f3eb6321ff443 100644 --- a/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.h +++ b/ops/c_api/flash_attention_encoder/flash_attention_encoder_runner.h @@ -33,7 +33,8 @@ class FlashAttentionEncoderRunner : public InternalPyboostRunner { public: using InternalPyboostRunner::InternalPyboostRunner; void SetParam(int64_t head_num, float scale_value, int64_t kv_head_num, int64_t mask_type, int64_t kernel_type, - int64_t window_size, int64_t cache_type, const std::vector &q_seq_len, + int64_t window_size, int64_t cache_type, int64_t inner_precise, + const std::vector &q_seq_len, const std::vector &kv_seq_len); void SetInputFormat(int64_t input_format) { param_.input_format = static_cast(input_format); } diff --git a/tests/st/test_custom_flash_attention_encoder_nz.py b/tests/st/test_custom_flash_attention_encoder_nz.py index b937c3aaa2b4dafdafa5247e72ba4f09350858ee..b995220a02aa65f2e9a4df7a89b3216847133c30 100644 --- a/tests/st/test_custom_flash_attention_encoder_nz.py +++ b/tests/st/test_custom_flash_attention_encoder_nz.py @@ -63,7 +63,7 @@ class FlashAttentionEncoderNzNet(nn.Cell): q_nz, k_nz, v_nz, None, mask_nz, alibi_slopes, None, None, None, None, None, None, q_lens_cpu, kv_lens_cpu, self.heads, self.scale_value, self.kv_heads, self.mask_type, - self.kernel_type, self.window_size, self.cache_type, self.input_format) + self.kernel_type, self.window_size, self.cache_type, input_format=self.input_format) # Convert output back to ND format within graph out_nd = ms_custom_ops.trans_data(out_nz, transdata_type=0)