From 96aea371088a5f0c783e57b8ff94e01f62f1e289 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Tue, 19 Aug 2025 22:10:44 +0800 Subject: [PATCH 1/7] Support stream into Dynamo charts --- test/dynamo/test_stream.py | 29 +++++++++++++++++++++++++++++ torch_npu/__init__.py | 4 ++++ torch_npu/dynamo/__init__.py | 1 + torch_npu/dynamo/trace_rule.py | 15 +++++++++++++++ 4 files changed, 49 insertions(+) create mode 100644 test/dynamo/test_stream.py create mode 100644 torch_npu/dynamo/trace_rule.py diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py new file mode 100644 index 0000000000..0b19cca8d2 --- /dev/null +++ b/test/dynamo/test_stream.py @@ -0,0 +1,29 @@ +# Owner(s): ["module: dynamo"] +import torch +import torch_npu + +import torch._dynamo.test_case + +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + +class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): + + @requires_npu() + def test_stream(self): + def model_1(x): + a = x * x + s = torch.npu.stream() + s.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(s): + b = x + a + return b + inp = torch.randn(2,8).npu() + m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + output = m(inp) + output1 = model_1(inp) + torch.allclose(output,, output1) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 10e15a1522..f5154e8397 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -93,6 +93,7 @@ from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs +from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() @@ -296,6 +297,9 @@ if 'TORCH_NPU_SANITIZER' in os.environ: # register npu device op overrides for inductor _inductor_register_device_op_overrides() +# Support stream into Dynamo charts +_patch_npu_trace_rules() + if hasattr(sys, 'ps1'): os.environ["TASK_QUEUE_ENABLE"] = '0' warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \ diff --git a/torch_npu/dynamo/__init__.py b/torch_npu/dynamo/__init__.py index a6c2357087..95be98be63 100644 --- a/torch_npu/dynamo/__init__.py +++ b/torch_npu/dynamo/__init__.py @@ -10,6 +10,7 @@ from torch.library import Library, impl from torch_npu.utils._error_code import ErrCode, pta_error from torch_npu.utils.utils import _should_print_warning +from .trace_rule import _patch_npu_trace_rules _global_npu_backend = None __all__ = [] diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py new file mode 100644 index 0000000000..01e8992fa1 --- /dev/null +++ b/torch_npu/dynamo/trace_rule.py @@ -0,0 +1,15 @@ +import torch +from torch._dynamo.variables import TorchInGraphFunctionVariable + +torch_c_binding_in_graph_functions_npu = dict.fromkeys( + [ + "torch.npu.current_stream", + "torch.npu.default_stream", + "torch.npu.stream", + "torch.npu.set_stream", + ] +) + +def _patch_npu_trace_rules(): + torch._dynamo.trace_rules.clear_lru_cache() + torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From 89fa1ff97ceeac94f95399f513c520584c97ed61 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 14:58:15 +0800 Subject: [PATCH 2/7] cleancode --- test/dynamo/test_stream.py | 2 +- torch_npu/__init__.py | 2 +- torch_npu/dynamo/trace_rule.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 0b19cca8d2..4b9e30bfb9 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -21,7 +21,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): m = torch.compile(model_1,backend="aot_eager",fullgraph=True) output = m(inp) output1 = model_1(inp) - torch.allclose(output,, output1) + torch.allclose(output, output1) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index f5154e8397..3d6a76432b 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -91,9 +91,9 @@ from torch_npu.asd.asd import _asd_patch from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations +from torch_npu.dynamo import _patch_npu_trace_rules from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs -from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index 01e8992fa1..ca6cd41bfc 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -7,7 +7,8 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( "torch.npu.default_stream", "torch.npu.stream", "torch.npu.set_stream", - ] + ], + TorchInGraphFunctionVariable, ) def _patch_npu_trace_rules(): -- Gitee From a174757c76dc6ca1b20e5ef6ba0207e20ab9040a Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 15:31:15 +0800 Subject: [PATCH 3/7] cleancode --- test/dynamo/test_stream.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 4b9e30bfb9..cdf614ed41 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -1,11 +1,13 @@ # Owner(s): ["module: dynamo"] +import functools +import unittest import torch -import torch_npu - import torch._dynamo.test_case +import torch_npu requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): @requires_npu() @@ -17,12 +19,13 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): with torch.npu.stream(s): b = x + a return b - inp = torch.randn(2,8).npu() - m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + inp = torch.randn(2, 8).npu() + m = torch.compile(model_1, backend="aot_eager", fullgraph=True) output = m(inp) output1 = model_1(inp) torch.allclose(output, output1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests -- Gitee From 4c12e8d60c5c94b3ce3dc694be32b15ac669fcf4 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 16:58:47 +0800 Subject: [PATCH 4/7] cleancode --- torch_npu/dynamo/trace_rule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index ca6cd41bfc..3f3593dddc 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -11,6 +11,7 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( TorchInGraphFunctionVariable, ) + def _patch_npu_trace_rules(): torch._dynamo.trace_rules.clear_lru_cache() torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From 5b5d65d731a1a73d66462fad0aed3588d9df9e6d Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 18:52:33 +0800 Subject: [PATCH 5/7] fix test --- test/dynamo/test_stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index cdf614ed41..13bd0ec26a 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -14,7 +14,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): def test_stream(self): def model_1(x): a = x * x - s = torch.npu.stream() + s = torch.npu.Stream() s.wait_stream(torch.npu.current_stream()) with torch.npu.stream(s): b = x + a -- Gitee From 7dedc6e9cd80b292fda58554cd7c7766e7350829 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 19:08:45 +0800 Subject: [PATCH 6/7] fix test --- test/allowlist_for_publicAPI.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index cbce25bfb2..53e224b46f 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2874,5 +2874,8 @@ "torch_npu.utils.profiler": [ "Singleton", "Profile" + ], + "torch_npu.dynamo.trace_rule": [ + "TorchInGraphFunctionVariable" ] } -- Gitee From a5b7c555e243bce312106aba91e121df7d21768c Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Thu, 21 Aug 2025 09:41:05 +0800 Subject: [PATCH 7/7] fix test --- test/allowlist_for_publicAPI.json | 3 --- torch_npu/dynamo/trace_rule.py | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 53e224b46f..cbce25bfb2 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2874,8 +2874,5 @@ "torch_npu.utils.profiler": [ "Singleton", "Profile" - ], - "torch_npu.dynamo.trace_rule": [ - "TorchInGraphFunctionVariable" ] } diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index 3f3593dddc..856aa214b6 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -1,6 +1,8 @@ import torch from torch._dynamo.variables import TorchInGraphFunctionVariable +__all__ = [] + torch_c_binding_in_graph_functions_npu = dict.fromkeys( [ "torch.npu.current_stream", -- Gitee