diff --git a/test/npu/test_has_triton_package.py b/test/npu/test_has_triton_package.py new file mode 100644 index 0000000000000000000000000000000000000000..02a2389922b058c6fda16489d04374dc70f75a25 --- /dev/null +++ b/test/npu/test_has_triton_package.py @@ -0,0 +1,42 @@ +import os +import shutil +import importlib.util +import unittest + +# has_triton_package will be called during torch.testing init, remove existed triton due to init issue in torch 2.4 +if os.path.isdir("triton"): + shutil.rmtree("triton") +if os.path.isfile("triton"): + os.remove("triton") + +from torch.utils import _triton +from torch.testing._internal.common_utils import TestCase, run_tests + +TRITON_IS_INSTALLED = importlib.util.find_spec("triton") is not None +# clear lru cache +_triton.has_triton_package.cache_clear() + + +class TestHasTritonPackage(TestCase): + def setUp(self): + super().setUp() + if not os.path.exists("triton"): + os.mkdir("triton") + + def tearDown(self): + super().tearDown() + if os.path.isdir("triton"): + shutil.rmtree("triton") + + @unittest.skipIf(TRITON_IS_INSTALLED, "Skip this case due to triton is installed.") + def test_has_triton_package(self): + self.assertTrue(_triton.has_triton_package()) + + @unittest.skipIf(TRITON_IS_INSTALLED, "Skip this case due to triton is installed.") + def test_has_triton_package_with_patch(self): + import torch_npu + self.assertFalse(_triton.has_triton_package()) + + +if __name__ == '__main__': + run_tests() diff --git a/torch_npu/utils/_dynamo.py b/torch_npu/utils/_dynamo.py index 5915b8ed9c96a8b2d46a1cc94a31de505dcc24a4..6742a51c2e517189c0f066add50ac8afcf419e86 100644 --- a/torch_npu/utils/_dynamo.py +++ b/torch_npu/utils/_dynamo.py @@ -1,7 +1,21 @@ import inspect -from typing import Dict, List - +import functools import torch + + +# Folder with the same name as pkg `triton` in workdir could cause init problem in torch 2.4 +@functools.lru_cache(None) +def _has_triton_package() -> bool: + try: + from triton.compiler.compiler import triton_key + return triton_key is not None + except ImportError: + return False + + +from torch.utils import _triton +_triton.has_triton_package = _has_triton_package + from torch._dynamo.utils import tensortype_to_dtype from torch._dynamo.variables.torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable from torch._dynamo.variables.base import VariableTracker