diff --git a/test/custom_ops/test_prompt_flash_attention.py b/test/custom_ops/test_prompt_flash_attention.py
index c870476c117fb6881c4ac9f6ffbc3f7ae092a18c..50f3cd20be482a511ee709e0c7dbee4a2bf57ba0 100644
--- a/test/custom_ops/test_prompt_flash_attention.py
+++ b/test/custom_ops/test_prompt_flash_attention.py
@@ -23,7 +23,7 @@ class TestPromptFlashAttention(TestCase):
     def custom_op_exec(self, query, key, value, head_dim):
         scale = 1 / 0.0078125
         return torch_npu.npu_prompt_flash_attention(
-            query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
+            query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0)
 
     @unittest.skipIf(DEVICE_NAME != 'Ascend910B',
         "OP `PromptFlashAttention` is only supported on 910B, skip this ut for this device type!")
diff --git a/test/network_ops/test_prompt_flash_attention.py b/test/network_ops/test_prompt_flash_attention.py
index 631588b2f6341ea708de33d59089ff070d920972..8b3bf6b809d6cc52d1dec17904dca93eaf4353e4 100644
--- a/test/network_ops/test_prompt_flash_attention.py
+++ b/test/network_ops/test_prompt_flash_attention.py
@@ -22,7 +22,7 @@ class TestPromptFlashAttetion(TestCase):
     def prompt_flash_attention_npu(self, q, k, v, head_dim):
         scale = 1 / 0.0078125
         return torch_npu.npu_prompt_flash_attention(
-            q, k, v, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
+            q, k, v, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0)
 
     @unittest.skipIf(DEVICE_NAME != 'Ascend910B',
         "OP `PromptFlashAttention` is only supported on 910B, skip this ut for this device type!")
diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index 5e6fd04e426e10813655f116ca6b1caa3f469e13..ea86353bba47e348a39701a6b3f69fee7824a76b 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -18,10 +18,8 @@ def npu_incre_flash_attention_forward(query, key, value, *, padding_mask=None, a
 
 
 @impl(m, "npu_prompt_flash_attention")
-def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None,
-                                       actual_seq_lengths=None, num_heads=1, scale_value=1.0, pre_tokens=2147473647,
-                                       next_tokens=0, input_layout="BSH", num_key_value_heads=0):
-    return torch.empty_like(query)
+def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, actual_seq_lengths=None, num_heads=1, scale_value=1.0, pre_tokens=2147473647, next_tokens=0, input_layout="BSH", num_key_value_heads=0, actual_seq_lengths_kv=None, sparse_mode=0):
+    return torch.empty_like(query, dtype=query.dtype)
 
 
 @impl(m, "npu_fusion_attention")
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 660f38934aafc66963e3cb26bec03c3be9f0208e..6efbc6126590ba7ac547c1a709fd22a15864ef04 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -598,13 +598,16 @@ class NPUPromptFlashAttentionOP(torch.autograd.Function):
     @staticmethod
     def symbolic(g, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                  padding_mask: Optional[Tensor], atten_mask: Optional[Tensor],
-                 actual_seq_lengths: Optional[Tensor], num_heads: int = 1,
+                 actual_seq_lengths: Optional[Tensor],
+                 num_heads: int = 1,
                  scale_value: float = 1.0, pre_tokens: int = 2147473647, next_tokens: int = 0,
-                 input_layout: str = "BSH", num_key_value_heads: int = 0):
-        return g.op("npu::NPUPromptFlashAttention", self, query, key, value,
g.op("npu::NPUPromptFlashAttention", self, query, key, value, - padding_mask, atten_mask, actual_seq_lengths, - num_heads, scale_value, pre_tokens, next_tokens, - input_layout, num_key_value_heads) + input_layout: str = "BSH", num_key_value_heads: int = 0, + actual_seq_lengths_kv: Optional[Tensor], sparse_mode: int = 0): + return g.op("npu::NPUPromptFlashAttention", self, query, key, value, + padding_mask, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, + None, None, None, None, None, + num_heads, scale_value, pre_tokens, next_tokens, + input_layout, num_key_value_heads, sparse_mode) class NPUIncreFlashAttentionOP(torch.autograd.Function): @@ -843,9 +846,12 @@ def wrapper_npu_rotary_mul(x, r1, r2): def wrapper_npu_prompt_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths, - num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads): - return NPUPromptFlashAttentionOP.apply(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths, - num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads) + num_heads, scale_value, pre_tokens, next_tokens, + input_layout, num_key_value_heads, actual_seq_lengths_kv, sparse_mode): + return NPUPromptFlashAttentionOP.apply(self, query, key, value, padding_mask, atten_mask, + actual_seq_lengths, actual_seq_lengths_kv, + num_heads, scale_value, pre_tokens, next_tokens, + input_layout, num_key_value_heads, actual_seq_lengths_kv, sparse_mode) def wrapper_npu_incre_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths,