diff --git a/torch_npu/npu/graphs.py b/torch_npu/npu/graphs.py index 7e21ce5ed9a78512d66e1bd915eb834732aa00fc..ef4d0d498bb1b06afcd7b7b264f33960dccd84cf 100644 --- a/torch_npu/npu/graphs.py +++ b/torch_npu/npu/graphs.py @@ -135,7 +135,7 @@ class _GraphDispatchMode(torch.utils._python_dispatch.TorchDispatchMode): return _GraphDispatchRecord(event=event, handle=handle, kwargs=kwargs_ref, args=tuple(args_ref), op_cache_entry=func) def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if func.__name__ == "npu_fused_infer_attention_score": + if func.__name__ in ["npu_fused_infer_attention_score", "npu_fused_infer_attention_score.default"]: func_out = torch_npu.npu_fused_infer_attention_score.out self.update_schema(str(func_out.__name__), str(func_out._schema)) stream = torch_npu.npu.current_stream()