diff --git a/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp b/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp index 8a16104f8923955179e7fb83f178370efe3670de..d82bda61349248c019cf67082d5ca7c03a004652 100644 --- a/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp +++ b/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp @@ -22,6 +22,8 @@ #include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "torch_npu/csrc/aten/ops/op_api/op_api_common.h" #include "torch_npu/csrc/aten/NPUGeneratorImpl.h" +#include "torch_npu/csrc/aten/NPUNativeOpApiFunctions.h" + namespace at_npu { namespace native { @@ -177,14 +179,26 @@ std::vector npu_flash_attention_backward( dpse = OpPreparation::ApplyTensorWithoutFormat(format_pse); } + // aply fp32 dq kv + + at::Tensor dq_32 = OpPreparation::ApplyTensorWithoutFormat(query.sizes(), query.options().dtype(at::kFloat)); + at::Tensor dk_32 = OpPreparation::ApplyTensorWithoutFormat(query.sizes(), query.options().dtype(at::kFloat)); + at::Tensor dv_32 = OpPreparation::ApplyTensorWithoutFormat(query.sizes(), query.options().dtype(at::kFloat)); + EXEC_NPU_NO_FORMAT_CHECK_CMD( aclnnFlashAttentionScoreGrad, format_query, format_key, format_value, format_dy, format_pse, format_drop_mask, format_padding_mask, dtype_atten_mask, format_softmax_max, format_softmax_sum, format_softmax, format_attention, scale_value, keep_prob, - pre_tockens, next_tockens, is_flash, head_num, input_layout_ptr, dq, dk, dv, dpse); + pre_tockens, next_tockens, is_flash, head_num, input_layout_ptr, dq_32, dk_32, dv_32, dpse); + + at::Tensor dq_scalared = at::mul(dq_32, at::Scalar(scale)); + + //cast + dq = NPUNativeOpApiFunctions::npu_dtype_cast(dq_scalared, query.scalar_type()); + dk = NPUNativeOpApiFunctions::npu_dtype_cast(dk_32, query.scalar_type()); + dv = NPUNativeOpApiFunctions::npu_dtype_cast(dv_32, query.scalar_type()); - at::Tensor dq_scalared = at::mul(dq, at::Scalar(scale)); - return {dq_scalared, dk, dv, + return {dq, dk, dv, at::Tensor(), at::Tensor(), dpse, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()}; }