diff --git a/test/_inductor/__init__.py b/test/_inductor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2497b84aeda72a81c72604bbd678e7cd0494594
--- /dev/null
+++ b/test/_inductor/__init__.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
\ No newline at end of file
diff --git a/test/_inductor/commonutils.py b/test/_inductor/commonutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..00527a41d87baaf9de5f31495acf6fe9f852bc27
--- /dev/null
+++ b/test/_inductor/commonutils.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+
+import subprocess
+from io import StringIO
+from typing import Dict, List
+
+
+def get_available_npu_device_ids():
+    """
+    Public entry point: return all NPU device ids, least used first.
+    """
+    npu_ids = _get_all_npu_device_ids()
+    sorted_npu_dict = _sort_npu_by_usage_cap(npu_ids)
+    return list(sorted_npu_dict.keys())
+
+
+def _get_all_npu_device_ids():
+    """
+    Collect all NPU device ids via `ls /dev/davinci*`.
+    """
+    ## ls /dev/davinci*
+    buffer = StringIO()
+    try:
+        result = subprocess.run(
+            ["ls /dev/davinci*"],
+            capture_output=True,
+            shell=True,
+            text=True,
+            check=True
+        )
+        output = result.stdout
+        buffer.write(output)
+    except subprocess.CalledProcessError as e:
+        print(f"Error running command: {e}")
+    finally:
+        content = buffer.getvalue()
+        buffer.close()
+
+    npu_ids = []
+    if not content:
+        return npu_ids
+    for line in content.splitlines():
+        if not line or not line[-1].isdigit():
+            continue
+        idx = -1
+        while line[idx].isdigit():
+            idx -= 1
+        dev_id = line[idx + 1:]
+        npu_ids.append(dev_id)
+    return npu_ids
+
+
+def _sort_npu_by_usage_cap(npu_ids: List[str]) -> Dict[int, List[str]]:
+    """
+    Query each card via `npu-smi info -t usages -i <id>` and return a dict
+    {id: [HBM Capacity(MB), HBM Usage Rate(%)]}, sorted by usage rate ascending, ties broken by capacity descending.
+    """
+    npu_dict = dict()
+    try:
+        for dev_id in npu_ids:
+            result = subprocess.run(["npu-smi info -t usages -i " + dev_id],
+                                    capture_output=True,
+                                    text=True,
+                                    shell=True,
+                                    check=True)
+            ss = result.stdout
+            ## [HBM Capacity(MB), HBM Usage Rate(%)]
+            tmp = []
+            for line in ss.splitlines():
+                if ":" not in line:
+                    continue
+                key, val = line.split(":", 1)
+                key, val = key.strip(), val.strip()
+                if key == "HBM Usage Rate(%)":
+                    tmp.append(val)
+                if key == "HBM Capacity(MB)":
+                    tmp.append(val)
+            if tmp:
+                npu_dict[int(dev_id)] = tmp
+    except subprocess.CalledProcessError as e:
+        print(f"Error running command: {e}")
+    sorted_npu_dict = dict(sorted(npu_dict.items(), key=lambda x: (int(x[1][1]), -int(x[1][0]))))
+    return sorted_npu_dict
+
+
+if __name__ == '__main__':
+    res = get_available_npu_device_ids()
+    print(res)
\ No newline at end of file
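The helper above returns NPU device ids ordered from least to most HBM usage. A minimal usage sketch follows (illustrative only, not part of this change; the import path and the availability of `torch.npu.set_device` through torch_npu are assumptions):

# Sketch: pick the least-loaded NPU before a test session (assumed API and import path).
import torch
import torch_npu  # registers the NPU device module on torch
from commonutils import get_available_npu_device_ids  # assumed import path

ids = get_available_npu_device_ids()      # e.g. [3, 1, 0, 2], least used first
device_id = ids[0] if ids else 0          # fall back to device 0 if detection fails
torch.npu.set_device(f"npu:{device_id}")  # assumed torch_npu API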
diff --git a/test/_inductor/conftest.py b/test/_inductor/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b907d9a290b85769e9e746082fdda13c46791d6
--- /dev/null
+++ b/test/_inductor/conftest.py
@@ -0,0 +1,22 @@
+import pytest
+import os
+import torch_npu._inductor
+import getpass
+
+
+def pytest_addoption(parser):
+    parser.addoption("--npu_indexing", action='store', default='False',
+                     help='whether to enable npu indexing or not, default is False', choices=['True', 'False'])
+
+
+@pytest.fixture(scope="session")
+def clear_cache():
+    # os.system('rm -rf /tmp/torchinductor_' + getpass.getuser() + '/*')
+    # os.system('rm -rf ~/.triton/dump')
+    # os.system('rm -rf ~/.triton/cache')
+    return
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_npu_indexing(pytestconfig):
+    torch_npu._inductor.config.enable_npu_indexing = pytestconfig.getoption("--npu_indexing") == 'True'
diff --git a/test/_inductor/run_ut.sh b/test/_inductor/run_ut.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bdbe08e8dfecf6ccaabba9c550241c4b30a45ffd
--- /dev/null
+++ b/test/_inductor/run_ut.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+set -ex
+
+source /root/anaconda3/bin/activate inductor260
+pip list
+
+# Build the Triton NPU backend (triton-ascend) first
+pip uninstall -y triton
+
+mkdir -p ${WORKSPACE}TritonNpu
+cd ${WORKSPACE}TritonNpu
+git clone https://gitee.com/ascend/triton-ascend.git -b master
+
+if [ -d ${WORKSPACE}TritonNpu/triton-ascend/triton ];then
+    rm -rf ${WORKSPACE}TritonNpu/triton-ascend/triton
+fi
+
+if [ -d ~/.triton/dump ];then
+    rm -rf ~/.triton/dump
+fi
+
+if [ -d ~/.triton/cache ];then
+    rm -rf ~/.triton/cache
+fi
+
+cd ${WORKSPACE}TritonNpu/triton-ascend
+git clone --depth 1 https://gitee.com/shijingchang/triton.git
+#cp -r /triton_depends/triton ${WORKSPACE}TritonNpu/triton-ascend/triton
+#cd ${WORKSPACE}TritonNpu/triton-ascend/triton
+#git apply ${WORKSPACE}TritonNpu/triton-ascend/build/patch/triton_ebce7f.patch
+#git apply ${WORKSPACE}TritonNpu/triton-ascend/build/patch/0001-AttrDescriptor-fix-and-delete-power-of-two.patch
+#cd ${WORKSPACE}TritonNpu/triton-ascend
+echo "${PWD}"
+
+TRITON_PLUGIN_DIRS=${WORKSPACE}TritonNpu/triton-ascend/ascend \
+LLVM_INCLUDE_DIRS=$LLVM_SYSPATH/include \
+LLVM_LIBRARY_DIR=$LLVM_SYSPATH/lib \
+LLVM_SYSPATH=$LLVM_SYSPATH \
+TRITON_BUILD_WITH_CLANG_LLD=true \
+pip install -e ${WORKSPACE}TritonNpu/triton-ascend/triton/python --no-build-isolation -vvv
+
+pip list
+
+cd ${WORKSPACE}
+echo ${PWD}
+ls -al
+
+# run inductor ut
+export PYTHONPATH=${WORKSPACE}:$PYTHONPATH
+export TORCHINDUCTOR_COMPILE_THREADS=1
+export ASCEND_LAUNCH_BLOCKING=1
+export CI=""
+env
+
+if [ -d ~/.triton/dump ];then
+    rm -rf ~/.triton/dump
+fi
+
+if [ -d ~/.triton/cache ];then
+    rm -rf ~/.triton/cache
+fi
+
+tree
+
+cd test
+
+pytest -svvv . --npu_indexing=True || { exit 1; }
+
+
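Every operator test added below follows the same eager-versus-inductor comparison on NPU. A condensed, self-contained sketch of that pattern (the op and shapes are illustrative, not taken from any single test file):

# Sketch of the comparison pattern used throughout the tests below.
import torch
import torch_npu
import torch_npu._inductor  # registers the NPU inductor backend

def op_calc(x, y):
    return torch.abs(x) + y

x = torch.randn(1024, 32, dtype=torch.float32, device="npu")
y = torch.randn(1024, 32, dtype=torch.float32, device="npu")

std_result = op_calc(x, y)                                     # eager reference
compiled_op_calc = torch.compile(op_calc, backend="inductor")  # compile via inductor
inductor_result = compiled_op_calc(x, y)

torch.testing.assert_close(std_result, inductor_result, rtol=1e-3, atol=1e-3)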
diff --git a/test/_inductor/test_abs.py b/test/_inductor/test_abs.py
new file mode 100644
index 0000000000000000000000000000000000000000..8440aab1fcadce97e53de9bf5d28d25ae4335d6a
--- /dev/null
+++ b/test/_inductor/test_abs.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch
+import torch_npu
+import torch_npu._inductor
+
+import pytest
+from .testutils import OperatorType, TestUtils
+
+
+class TestAbs(TestUtils):
+    __TIME_LIMIT = 100
+    __OPTYPE = OperatorType.POINTWISE
+
+    # optimized function, auto timeout after __TIME_LIMIT seconds
+
+    # @torch.compile(options={"aggressive_fusion": False})
+
+    def op_calc(self, first_element):
+        result = torch.abs(first_element)
+        return result
+
+    # Results can be unstable when many cases run back to back; rerun failed cases individually.
+    # To cover more data types, replace the dtype list with ProtoTestCase._test_dtypes.
+    # Toggling npu indexing is done via the external option --npu_indexing=True/False.
+
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape', [(1024, 32), (256, 8)])
+    @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16'])
+    def test_pointwise_cases(self, shape, dtype, clear_cache):
+        print(shape)
+        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing))
+        first_element = self._generate_tensor(shape, dtype)
+
+        std_result = self.op_calc(first_element)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
+        inductor_result = compiled_op_calc(first_element)
+        # print(std_result[0:8])
+        # print(inductor_result[0:8])
+        torch.testing.assert_close(std_result, inductor_result, atol=1e-3, rtol=1e-3)
diff --git a/test/_inductor/test_add.py b/test/_inductor/test_add.py
new file mode 100644
index 0000000000000000000000000000000000000000..8da8dff4f5949946980e23ac4a8a54e0f333d789
--- /dev/null
+++ b/test/_inductor/test_add.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch
+import torch_npu
+import torch_npu._inductor
+
+import pytest
+from testutils import OperatorType, TestUtils
+
+
+class TestAdd(TestUtils):
+    __TIME_LIMIT = 100
+    __OPTYPE = OperatorType.POINTWISE
+
+    # optimized function, auto timeout after __TIME_LIMIT seconds
+
+    # @torch.compile(options={"aggressive_fusion": False})
+
+    def op_calc(self, first_element, second_element):
+        result = first_element + second_element
+        return result
+
+    # Results can be unstable when many cases run back to back; rerun failed cases individually.
+    # To cover more data types, replace the dtype list with ProtoTestCase._test_dtypes.
+    # Toggling npu indexing is done via the external option --npu_indexing=True/False.
+
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes)
+    @pytest.mark.parametrize('dtype', ['float32', 'int64'])
+    def test_pointwise_cases(self, shape, dtype, clear_cache):
+        print(shape)
+        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing))
+        first_element = self._generate_tensor(shape, dtype)
+        second_element = self._generate_tensor(shape, dtype)
+
+        std_sum = self.op_calc(first_element, second_element)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
+        inductor_sum = compiled_op_calc(first_element, second_element)
+
+        torch.testing.assert_close(std_sum, inductor_sum)
+
+    # should be implemented when __OPTYPE is OperatorType.REDUCTION
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding)
+    @pytest.mark.parametrize('dtype', TestUtils._test_dtypes)
+    @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator')
+    def test_reduction_cases(self, shape, dim, dtype, clear_cache):
+        pass
+
+if __name__ == "__main__":
+    size = (1024, 1024)
+    test = TestAdd()
+    test.test_pointwise_cases(size, 'float32', None)
\ No newline at end of file
diff --git a/test/_inductor/test_add_sum.py b/test/_inductor/test_add_sum.py new
file mode 100644 index 0000000000000000000000000000000000000000..3e54152be9d8ca082797a19c06c884ab5cabb194 --- /dev/null +++ b/test/_inductor/test_add_sum.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestSumAdd(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + + def foo(self,a, b, dim): + y = a + b + y = y.sum(dim) + return y + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(9, 9, 31, 64)]) + @pytest.mark.parametrize('dim', [3]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + r1 = self.foo(a, b, dim) + func = torch.compile(self.foo, backend="inductor", dynamic=False) + r = func(a, b, dim) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(9, 10, 31, 63)]) + @pytest.mark.parametrize('dim', [0, 1]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes1(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + r1 = self.foo(a, b, dim) + func = torch.compile(self.foo, backend="inductor", dynamic=False) + r = func(a, b, dim) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) diff --git a/test/_inductor/test_alias.py b/test/_inductor/test_alias.py new file mode 100644 index 0000000000000000000000000000000000000000..7f93f091ce4b6e5ecb4d60b34fa14f5b624db2d1 --- /dev/null +++ b/test/_inductor/test_alias.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestAlias(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + x = torch.ops.aten.alias(input_element) + y = x + 1.0 + return y + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 64)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + print(f"input_element= {input_element}") + std_ret = self.op_calc(input_element, dim) + print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + print(f"inductor_ret= {inductor_ret}") + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_ret, inductor_ret, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (32, 64) + test = TestAlias() + test.test_reduction_cases_shapes(size, -1, 'float32', None) diff --git a/test/_inductor/test_argmax.py b/test/_inductor/test_argmax.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ddb3771f6df92cb6da150c4cc63f631f68c723 --- /dev/null +++ b/test/_inductor/test_argmax.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestArgmax(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + def argmax(self, a, dim): + return torch.argmax(a, dim) + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.skip(reason='not support yet') + def test_argmax(self): + shape=(512, 64) + dim = -1 + print(f"start to test argmax on shape:{shape} dim:{dim} ") + a = torch.randn(shape, requires_grad=False, dtype=torch.float32, device='npu') + + argmax_triton = torch.compile(self.argmax, backend="inductor", dynamic=False) + r = self.argmax(a, dim) + r1 = argmax_triton(a, dim) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) + + diff --git a/test/_inductor/test_argmax_unalign.py b/test/_inductor/test_argmax_unalign.py new file mode 100644 index 0000000000000000000000000000000000000000..66beb403054f04cd733f93d25be168bbf7c981d6 --- /dev/null +++ b/test/_inductor/test_argmax_unalign.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import sys +sys.path.append("../..") +import torch_npu._inductor + +import pytest +# from .testutils import OperatorType, TestUtils +torch_npu._inductor.config.enable_npu_indexing = True +class TestMaxWithIndex(): + __TIME_LIMIT = 100 + def op_calc(self, input_element, dim): + return torch.argmax(input_element, dim) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(512, 64)]) # (513, 64), (514,33) + @pytest.mark.parametrize('dim', [-1 ]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases(self, shape, dim, dtype): + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + input_element = torch.randn(size=shape, dtype=eval('torch.' + dtype), device=torch.device("npu")) * 2000 + std_argmax = self.op_calc(input_element, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_argmax = compiled_op_calc(input_element, dim) + torch.testing.assert_close(std_argmax, inductor_argmax, rtol=1e-2, atol=1e-2) +if __name__ == '__main__': + self = TestMaxWithIndex() + self.test_reduction_cases((513, 64), -1, 'float32') \ No newline at end of file diff --git a/test/_inductor/test_arrange.py b/test/_inductor/test_arrange.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe320fdb47d56e701d8d4029979ab2454a2721c --- /dev/null +++ b/test/_inductor/test_arrange.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestArrange(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, start, end, step): + a = torch.arange(start, end, step, device=torch.device('npu')) + y = a + a + return y + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(2, )]) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + s = self._generate_tensor(shape, dtype) + start = min(s) + end = max(s) + step = (end - start) / 32 + + std_arrange = self.op_calc(start, end, step) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_arrange = compiled_op_calc(start, end, step) + + torch.testing.assert_close(std_arrange, inductor_arrange) + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass diff --git a/test/_inductor/test_attncp.py b/test/_inductor/test_attncp.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e40d3469915334040f103f1938bee1064849c5 --- /dev/null +++ b/test/_inductor/test_attncp.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei 
Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestAttnCp(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + # @torch.compile(options={"aggressive_fusion": False}) + shape = (8, 8, 256, 128) + dim = -1 + def foo(self, a, b, c): + y = a + b + y = y.sum(self.dim) + y = y.unsqueeze(self.dim) + y = y.broadcast_to(self.shape) + b + y = c + y.permute(0, 1, 3, 2) + return y + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + def test_pointwise_cases(self): + a, b = [torch.randn(self.shape, dtype=torch.float32, device="npu") for _ in range(2)] + d = torch.randn(self.shape, dtype=torch.float32, device="npu") + c = d.permute(0, 1, 3, 2).contiguous() + func = torch.compile(self.foo, backend="inductor") + r = func(a, b, c) + r1 = self.foo(a, b, c) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) diff --git a/test/_inductor/test_batch_norm.py b/test/_inductor/test_batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d330800f8f75ca78cf33b2974d85e549569386f4 --- /dev/null +++ b/test/_inductor/test_batch_norm.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestNativeBatchNorm(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element): + # 创建权重和偏置张量 + weight = torch.ones(32).npu() + bias = torch.zeros(32).npu() + + # 创建运行均值和方差张量 + running_mean = torch.zeros(32).npu() + running_var = torch.ones(32).npu() + + + # 执行批量归一化 + output, running_mean_out, running_var_out = torch.native_batch_norm( + input=input_element, + weight=weight, + bias=bias, + running_mean=running_mean, + running_var=running_var, + training=True, + momentum=0.1, + eps=1e-05 + ) + return output, running_mean_out, running_var_out + + @pytest.mark.skip(reason="npu compiler bug") + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(16, 32, 64)]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + + print(f"input_element= {input_element}") + std_ret, std_ret2, std_ret3 = self.op_calc(input_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret, inductor_ret2, inductor_ret3 = compiled_op_calc(input_element) + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_ret, inductor_ret, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (16, 32, 64) + test = TestNativeBatchNorm() + test.test_reduction_cases_shapes(size, 'float32', None) + diff --git a/test/_inductor/test_broadcast.py b/test/_inductor/test_broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..85ec062ff53acc408a149ca12620b09a36ff9aca --- /dev/null +++ b/test/_inductor/test_broadcast.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu + +import torch_npu._inductor + +import copy +import pytest +from testutils import OperatorType, TestUtils + +class TestBroadcast(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + broadcast_size = 128 + + def op_calc(self, a, b, dim, new_shape): + a = a.unsqueeze(dim) + a = a.broadcast_to(new_shape) + b = b.unsqueeze(dim) + b = b.broadcast_to(new_shape) + y = a + b + return y + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 8, 256)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16']) + def test_view_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + for dim in [3, 2, 1, 0]: + new_shape = list(copy.deepcopy(shape)) + new_shape.insert(dim, self.broadcast_size) + std_broadcast = self.op_calc(a, b, dim, new_shape) + inductor_broadcast = compiled_op_calc(a, b, dim, new_shape) + + torch.testing.assert_close(std_broadcast.float(), inductor_broadcast.float(), rtol=1e-3, atol=1e-3) + print(f"data validation passed") + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass diff --git a/test/_inductor/test_cat.py b/test/_inductor/test_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..715c44bf561b14b8636b92c9f8360a8a1513e76a --- /dev/null +++ b/test/_inductor/test_cat.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestCat(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.cat([input_element, input_element], dim) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 16, 32, 64)]) + @pytest.mark.parametrize('dim', [-1]) + @pytest.mark.parametrize('dtype', ['bfloat16']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_cat = self.op_calc(input_element, dim) + # print(f"std_cat.shape= {std_cat.shape}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_cat = compiled_op_calc(input_element, dim) + # print(f"inductor_cat.shape= {inductor_cat.shape}") + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_cat, inductor_cat, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (8, 8, 8, 2048) + test = TestCat() + test.test_reduction_cases_shapes(size, 2, 'float32', None) diff --git a/test/_inductor/test_ceil.py b/test/_inductor/test_ceil.py new file mode 100644 index 0000000000000000000000000000000000000000..da2d7cc73be89b4f4f676c1ce5fdb94db7462e10 --- /dev/null +++ b/test/_inductor/test_ceil.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRelu(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element): + result = torch.ceil(first_element) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result) + + +if __name__ == '__main__': + TestRelu() diff --git a/test/_inductor/test_clamp.py b/test/_inductor/test_clamp.py new file mode 100644 index 0000000000000000000000000000000000000000..adc3bcaf12286201a4df5db299ff341deb1d2da5 --- /dev/null +++ b/test/_inductor/test_clamp.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestClamp(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, input, min=None, max=None): + return input.clamp(min, max) + + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases_minmax_is_tensor(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = self._generate_tensor(shape, dtype) + max = self._generate_tensor(shape, dtype) + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=min, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=max) + + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1,)]) + @pytest.mark.parametrize('dtype', ['float32']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases_single_scalar(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = 0 + max = 100 + + first_element = 200 * torch.rand(size=shape, dtype=eval('torch.' 
+ dtype), device=torch.device("npu")) + + std_result = self.op_calc(first_element, min=min, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=max) + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1024, 32)]) + @pytest.mark.parametrize('dtype', ['int32']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases_minmax_is_number(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = 0 + max = 100 + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=min, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=max) + + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases_max_only(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + max = 100 + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=None, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=None, max=max) + + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases_min_only(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = 0 + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=min, max=None) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=None) + + torch.testing.assert_close(std_result, inductor_result) +if __name__ == '__main__': + obj = TestClamp() + obj.test_pointwise_cases_single_scalar((1,), 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_clone.py b/test/_inductor/test_clone.py new file mode 100644 index 0000000000000000000000000000000000000000..374317523b7a9959015fdc6301937e1e0c1bf6fc --- /dev/null +++ b/test/_inductor/test_clone.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestClone(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.clone(input_element) + + # case: change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_ret = self.op_calc(input_element, dim) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + + assert torch.allclose(std_ret, inductor_ret, equal_nan=True) + + +if __name__ == "__main__": + size = (8, 64, 128) + test = TestClone() + test.test_reduction_cases_shapes(size, 2, 'float32', None) + + + diff --git a/test/_inductor/test_cos.py b/test/_inductor/test_cos.py new file mode 100644 index 0000000000000000000000000000000000000000..b963eb8ce822c6f4be91057137aa82a56c487089 --- /dev/null +++ b/test/_inductor/test_cos.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestLog(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element): + result = torch.cos(first_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float32', 'int64']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result) + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass + +if __name__ == '__main__': + TestLog() + + diff --git a/test/_inductor/test_device_put.py b/test/_inductor/test_device_put.py new file mode 100644 index 0000000000000000000000000000000000000000..39b17ea27db16d315dc28c8c2f876c1fc7bf35f9 --- /dev/null +++ b/test/_inductor/test_device_put.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestDevicePut(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element1, input_element2): + return torch.add(input_element1, input_element2) + + # case: change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 16, 8)]) + @pytest.mark.parametrize('dtype', ['int32']) + def test_cases_shapes(self, shape, dtype, clear_cache): + low = 0 + high = 2 + dtype = eval('torch.' + dtype) + print(f"shape= {shape}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + # 指定目标设备为 NPU + npu_device = torch.device('npu:0') + input_element1_tmp = torch.randint(low, high, shape, dtype=dtype).cpu() + input_element2_tmp = torch.randint(low, high, shape, dtype=dtype).cpu() + input_element1 = torch.ops.prims.device_put(input_element1_tmp, npu_device) + input_element2 = torch.ops.prims.device_put(input_element2_tmp, npu_device) + + std_ret = self.op_calc(input_element1, input_element2) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element1, input_element2) + + assert torch.allclose(std_ret, inductor_ret, equal_nan=True) + + +if __name__ == "__main__": + size = (8, 16, 8) + test = TestDevicePut() + test.test_cases_shapes(size, 2, 'int32', None) + + + diff --git a/test/_inductor/test_div.py b/test/_inductor/test_div.py new file mode 100644 index 0000000000000000000000000000000000000000..318b521fe43b9e8ee72d2b58b0f75b5e1d28a5a9 --- /dev/null +++ b/test/_inductor/test_div.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestMul(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element, second_element): + result = torch.div(first_element, second_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + torch.testing.assert_close(std_result, inductor_result, equal_nan=True) + + + diff --git a/test/_inductor/test_embedding.py b/test/_inductor/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..c7732ec5c65a050f11deda648c9dc1c9bdf0c38c --- /dev/null +++ b/test/_inductor/test_embedding.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +#from testutils import OperatorType, TestUtils +import torch.nn as nn + +class TestSub(): + + def op_calc(self): + embedding = nn.Embedding(16, 128).npu() + + input = torch.tensor([[14, 1, 2, 10, 0, 10, 0], + [ 9, 13, 13, 4, 7, 15, 14], + [ 8, 0, 3, 15, 4, 2, 6], + [15, 12, 13, 9, 0, 8, 1], + [ 8, 15, 4, 15, 12, 9, 3], + [ 6, 11, 12, 8, 0, 13, 8], + [ 4, 10, 1, 12, 0, 0, 4], + [ 6, 6, 15, 6, 0, 10, 15], + [ 2, 5, 14, 0, 5, 7, 9], + [13, 4, 14, 11, 11, 9, 2], + [ 1, 1, 5, 1, 1, 6, 14], + [ 3, 9, 8, 4, 13, 8, 3], + [ 4, 10, 8, 13, 6, 8, 3]], device='npu:0') + + output = embedding(input.npu()) + return output + + def test_pointwise_cases(self): + torch_npu._inductor.config.enable_npu_indexing = True + + std_sub = self.op_calc() + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_sum = compiled_op_calc() + #torch.testing.assert_close(std_sub, inductor_sum) + + +if __name__ == "__main__": + test = TestSub() + test.test_pointwise_cases() + + + + diff --git a/test/_inductor/test_embedding_fallback.py b/test/_inductor/test_embedding_fallback.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5b492b9de662d8c5fe16a594faa73a68aa16ad --- /dev/null +++ b/test/_inductor/test_embedding_fallback.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRsqrt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, slice_4, sum_23): + result = torch.ops.aten.embedding_dense_backward.default(sum_23, slice_4, 512, -1, False) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1, 512, 128)]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_pointwise_cases(self, shape, dtype): + torch_npu._inductor.config.enable_npu_indexing = True + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = torch.randint(low=0, high=128, size=(1, 512), dtype=torch.int64).npu() + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + print(std_result) + print(inductor_result) + + torch.testing.assert_close(std_result, inductor_result, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (1, 512, 128) + test = TestRsqrt() + test.test_pointwise_cases(size, 'float32') + + + + diff --git a/test/_inductor/test_empty.py b/test/_inductor/test_empty.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc8fe36f5d0379cc61ac2ec48dffbbb8a890765 --- /dev/null +++ b/test/_inductor/test_empty.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestEmpty(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self): + x = torch.empty(8, 64, 128, dtype=torch.float32).npu() + x.uniform_(-100, 100) + return x + def op_calc_empty_permuted(self): + input_shape = (8, 64, 128) + physical_layout =(0, 1, 2) #物理布局与输入形状相同 + x = torch.empty_permuted(input_shape, physical_layout).npu() + x.uniform_(-100, 100) + return x + + # case: change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_cases_empty(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + std_ret = self.op_calc() + # print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc() + # print(f"inductor_ret= {inductor_ret}") + + assert inductor_ret.numel() > 0 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_cases_empty_permuted(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + std_ret = self.op_calc_empty_permuted() + # print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc_empty_permuted, backend="inductor") + inductor_ret = compiled_op_calc() + # print(f"inductor_ret= {inductor_ret}") + + assert inductor_ret.numel() > 0 + + +if __name__ == "__main__": + size = (8, 64, 128) + test = TestEmpty() + test.test_reduction_cases_shapes(size, 2, 'float32', None) + + + diff --git a/test/_inductor/test_eq.py b/test/_inductor/test_eq.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4c9d103bc96d260f2337b14a694136c4c2b9da --- /dev/null +++ b/test/_inductor/test_eq.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestEq(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element, second_element): + return torch.eq(first_element, second_element) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + first_element = self._generate_tensor(shape, dtype) + second_element = first_element.clone() + + # randomly change some elements in second tensor + flat_second_view = second_element.flatten() + num_elements_to_change = first_element.numel() //3 + random_indices = torch.randint(0, first_element.numel(), (num_elements_to_change,)) + flat_second_view[random_indices] = 1- flat_second_view[random_indices] + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + torch.testing.assert_close(std_result, inductor_result) + + + diff --git a/test/_inductor/test_exp.py b/test/_inductor/test_exp.py new file mode 100644 index 0000000000000000000000000000000000000000..078f9e653d9cee2b9d292cd0717d3801e462bf2a --- /dev/null +++ b/test/_inductor/test_exp.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestExp(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element): + result = torch.exp(first_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int64']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + # print(std_result[0:8]) + # print(inductor_result[0:8]) + # torch.testing.assert_close(std_result, inductor_result) + # 需要比较包含 NaN 值的张量, 并且希望认为两个 NaN值是相等的, 您可以使用 torch.allclose 函数, 并设置 equal_nan=True 参数 + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_result, inductor_result, equal_nan=True, rtol=rtol, atol=atol) + + + diff --git a/test/_inductor/test_expm1.py b/test/_inductor/test_expm1.py new file mode 100644 index 0000000000000000000000000000000000000000..27d8e053466b9eb50c6909efa92e3d88573ee7f5 --- /dev/null +++ b/test/_inductor/test_expm1.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestSqrt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element): + result = torch.expm1(first_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int64']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.allclose(std_result, inductor_result, equal_nan=True) + + diff --git a/test/_inductor/test_floor.py b/test/_inductor/test_floor.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7d144feed47d8fddaf4119a1af812219424562 --- /dev/null +++ b/test/_inductor/test_floor.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRelu(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element): + result = torch.floor(first_element) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result) + + +if __name__ == '__main__': + TestRelu() + + + + diff --git a/test/_inductor/test_foreach_add.py b/test/_inductor/test_foreach_add.py new file mode 100644 index 0000000000000000000000000000000000000000..66111096f641f0b3af5d31d90afe1c63b79c51eb --- /dev/null +++ b/test/_inductor/test_foreach_add.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRsqrt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element, second_element): + tensor_list = [first_element, second_element] + + add_list =[first_element, second_element] + result = torch._foreach_add_(tensor_list, add_list) + return result + + @pytest.mark.skip(reason='compile error, torch npu segmet fault') + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['int32']) + def test_pointwise_cases(self, shape, dtype): + torch_npu._inductor.config.enable_npu_indexing = True + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + torch.testing.assert_close(std_result, inductor_result, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (1024, 32) + test = TestRsqrt() + test.test_pointwise_cases(size, 'float32') + + + diff --git a/test/_inductor/test_ge.py b/test/_inductor/test_ge.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe7a95e50f2d15c2b7500485926efde143c5f6e --- /dev/null +++ b/test/_inductor/test_ge.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestGe(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element, second_element): + return torch.ge(first_element, second_element) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + torch.testing.assert_close(std_result, inductor_result) + + + diff --git a/test/_inductor/test_geometric.py b/test/_inductor/test_geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..827146f2e5a43b3190867d53b33e22005649bcb3 --- /dev/null +++ b/test/_inductor/test_geometric.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestGeometric(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self): + # 创建一个形状为 (3, 3)的张量, 每个位置的概率为 0.5 + prob =torch.full((16, 16), 0.5).npu() + + #使用 aten.geometric生成几何分布的随机数 + geometric_tensor =torch.ops.aten.geometric(prob, p=0.5) + + return geometric_tensor + + # case: change shapes + @pytest.mark.skip(reason="this has problem in torch 260") + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(16, 16, 16)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['int32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + std_ret = self.op_calc() + std_ret_mean =torch.mean(std_ret) + print(f"std_ret_mean= {std_ret_mean}") + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc() + + inductor_ret_mean = torch.mean(inductor_ret) + print(f"inductor_ret_mean= {inductor_ret_mean}") + assert inductor_ret_mean is not None + + +if __name__ == "__main__": + size = (16, 16, 16) + test = TestGeometric() + test.test_reduction_cases_shapes(size, -1, 'float32', None) + + + diff --git a/test/_inductor/test_gt.py b/test/_inductor/test_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..4b29d7eef7ed75e4561b23a62db8c422e66745db --- /dev/null +++ b/test/_inductor/test_gt.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestGt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element, second_element): + result = torch.gt(first_element, second_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + print("start test!") + torch.testing.assert_close(std_result, inductor_result) + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass + +if __name__ == '__main__': + TestGt() + + + + + diff --git 
a/test/_inductor/test_high_order_sum.py b/test/_inductor/test_high_order_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..8b48e963e3bd205e279c24b17fa68642220c299d --- /dev/null +++ b/test/_inductor/test_high_order_sum.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch.nn.functional as F +import torch +import torch_npu +import torch_npu._inductor + +def op_sum(npu_dropout_backward_9): + view_337: "f32[32768, 256]" = torch.ops.aten.view.default(npu_dropout_backward_9, [32768, 256]); + sum_63: "f32[1, 256]" = torch.ops.aten.sum.dim_IntList(view_337, [0], True); + view_338: "f32[256]" = torch.ops.aten.view.default(sum_63, [256]); + return view_338 + +device='npu' + +def test_high_order_sum(): + npu_dropout_backward_9 = torch.randn((32768, 256), device=device, dtype=torch.float32) + ref = op_sum(npu_dropout_backward_9) + func = torch.compile(op_sum, backend="inductor", dynamic=False) + calc = func(npu_dropout_backward_9) + + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + +if __name__ == "__main__": + npu_dropout_backward_9 = torch.randn((32768, 256), device=device, dtype=torch.float32) + ref = op_sum(npu_dropout_backward_9) + func = torch.compile(op_sum, backend="inductor", dynamic=False) + calc = func(npu_dropout_backward_9) + + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False + ) + with torch_npu.profiler.profile( + activities=[ # torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU], + with_stack=False, #采集torch 算子的函数调用栈的开关,该参数选填,默认关闭 + record_shapes=False, # 采集torch 算子的input shape和input type的开关,该参数选填,默认关闭 + profile_memory=False, # 采集memory相关数据的开关,该参数选填,默认关闭 + schedule=torch_npu.profiler.schedule(wait=1, + warmup=1, + active=10, + repeat=1, + skip_first=1), + # schedule=torch_npu.profiler.schedule(wait=1, warmup=1, active=1, skip_first=6), + # warmup默认为0,老版本torch_npu包该参数为必填项 + experimental_config=experimental_config, # 该参数选填,默认为Level0 + # 产生的profling文件的位置 + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result_dir") + # 导出tensorboard可呈现的数据形式,可指定worker_name, 默认为:{host名称}_{进程id} + ) as prof: + for i in range(20): + # ref1 = call(args) + op_sum(npu_dropout_backward_9) + func(npu_dropout_backward_9) + prof.step() + + + + + + diff --git a/test/_inductor/test_issue54.py b/test/_inductor/test_issue54.py new file mode 100644 index 0000000000000000000000000000000000000000..2f532c059bcffd2c16363a41c1814eb9a97fc834 --- /dev/null +++ b/test/_inductor/test_issue54.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import torch.nn.functional as F +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from torch.nn import CrossEntropyLoss +from torch import nn +from test2.npu_indexing.utils import benchmark_test + + +class Test_issue54(): + def func_layernorm(self, add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11): + # 原网络 + permute: "f32[256, 256]" = torch.ops.aten.permute.default(primals_6, [1, 0]); + addmm: "f32[32768, 256]" = torch.ops.aten.addmm.default(primals_7, view, permute); + view_1: "f32[64, 512, 256]" = torch.ops.aten.view.default(addmm, [64, 512, 256]); + addmm_1: "f32[32768, 256]" = torch.ops.aten.addmm.default(primals_9, view, permute_1); + view_3: "f32[64, 512, 256]" = torch.ops.aten.view.default(addmm_1, [64, 512, 256]); + view_4: "f32[64, 512, 4, 64]" = torch.ops.aten.view.default(view_3, [64, 512, 4, 64]); + permute_2: "f32[64, 4, 512, 64]" = torch.ops.aten.permute.default(view_4, [0, 2, 1, 3]); + permute_3: "f32[256, 256]" = torch.ops.aten.permute.default(primals_10, [1, 0]); + addmm_2: "f32[32768, 256]" = torch.ops.aten.addmm.default(primals_11, view, permute_3); + view_6: "f32[64, 512, 256]" = torch.ops.aten.view.default(addmm_2, [64, 512, 256]); + + view_8: "f32[64, 512, 4, 64]" = torch.ops.aten.view.default(view_1, [64, 512, 4, 64]); + permute_5: "f32[64, 4, 512, 64]" = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3]); + + permute_6: "f32[64, 4, 64, 512]" = torch.ops.aten.permute.default(permute_2, [0, 1, 3, 2]); + expand_1: "f32[64, 4, 512, 64]" = torch.ops.aten.expand.default(permute_5, [64, 4, 512, 64]) + clone: "f32[64, 4, 512, 64]" = torch.ops.aten.clone.default(expand_1, memory_format=torch.contiguous_format); + view_9: "f32[256, 512, 64]" = torch.ops.aten.view.default(clone, [256, 512, 64]); + expand_2: "f32[64, 4, 64, 512]" = torch.ops.aten.expand.default(permute_6, [64, 4, 64, 512]) + clone_1: "f32[64, 4, 64, 512]" = torch.ops.aten.clone.default(expand_2, memory_format=torch.contiguous_format); + view_10: "f32[256, 64, 512]" = torch.ops.aten.view.default(clone_1, [256, 64, 512]); + bmm: "f32[256, 512, 512]" = torch.ops.aten.bmm.default(view_9, view_10); + view_7: "f32[64, 512, 4, 64]" = torch.ops.aten.view.default(view_6, [64, 512, 4, 64]); + permute_4: "f32[64, 4, 512, 64]" = torch.ops.aten.permute.default(view_7, [0, 2, 1, 3]); + expand_4: "f32[64, 4, 512, 64]" = torch.ops.aten.expand.default(permute_4, [64, 4, 512, 64]) + clone_2: "f32[64, 4, 512, 64]" = torch.ops.aten.clone.default(expand_4, memory_format=torch.contiguous_format); + view_13: "f32[256, 512, 64]" = torch.ops.aten.view.default(clone_2, [256, 512, 64]); + + return bmm, view_13 + + def test_issue54(self): + device = 'npu' + test = Test_issue54() + # add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11 + + add_3 = torch.randn((64, 512, 256), device=device, dtype=torch.float32) + primals_6 = torch.randn((256, 256), device=device, dtype=torch.float32) + primals_7 = torch.randn((256), device=device, dtype=torch.float32) + view = torch.randn((32768, 256), device=device, dtype=torch.float32) + primals_9 = torch.randn((256), device=device, dtype=torch.float32) + permute_1 = torch.randn((256, 256), device=device, dtype=torch.float32) + primals_10 = torch.randn((256, 256), device=device, dtype=torch.float32) + primals_11 = torch.randn((256), device=device, dtype=torch.float32) + + ref = test.func_layernorm(add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11) + func = 
torch.compile(test.func_layernorm, backend="inductor", dynamic=False, + options={"unroll_reductions_threshold": 1, "aggressive_fusion": True}) + calc = func(add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11) + torch.testing.assert_close(ref[0], calc[0], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(ref[1], calc[1], rtol=1e-3, atol=1e-3) + print("valid ok") + + benchmark_test(test.func_layernorm, func, + args=(add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11,), + name="test_layernorm", times=10, repeat=10, profile=False) + + +if __name__ == "__main__": + test = Test_issue54() + test.test_issue54() \ No newline at end of file diff --git a/test/_inductor/test_issue57.py b/test/_inductor/test_issue57.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad6be8e2d0de3aa4f5fea42afb0caa858997f61 --- /dev/null +++ b/test/_inductor/test_issue57.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import torch.nn.functional as F +import torch +import torch_npu +import torch_npu._inductor +import pytest +from test2.npu_indexing.utils import benchmark_test + + +class Test_issue57(): + def op_sum(self, view_12, embedding_1, slice_11): + # 原网络 + + permute_7 = torch.ops.aten.permute.default(embedding_1, [2, 0, 1]); + embedding_1 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(permute_7, 0); + permute_7 = None + + add_5 = torch.ops.aten.add.Tensor(unsqueeze_4, slice_11); + slice_8 = slice_11 = None + add_6 = torch.ops.aten.add.Tensor(view_12, add_5); + view_12 = None + return add_6 + + def test_issue57(self): + device = 'npu' + test = Test_issue57() + embedding_1 = torch.randn((512, 512, 64), device=device, dtype=torch.float32) + primals_221 = torch.randn((1, 1, 1, 512), device=device, dtype=torch.float32) + view_12 = torch.randn((1, 64, 512, 512), device=device, dtype=torch.float32) + slice_11 = torch.randn((1, 1, 1, 512), device=device, dtype=torch.float32) + + ref = test.op_sum(view_12, embedding_1, primals_221) + func = torch.compile(test.op_sum, backend="inductor", dynamic=False) + calc = func(view_12, embedding_1, primals_221) + + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + + print("valid ok") + benchmark_test(test.op_sum, func, args=(view_12, embedding_1, primals_221), + name="issue57", times=10, repeat=10, profile=False) + + +if __name__ == "__main__": + test = Test_issue57() + test.test_issue57() \ No newline at end of file diff --git a/test/_inductor/test_issue59.py b/test/_inductor/test_issue59.py new file mode 100644 index 0000000000000000000000000000000000000000..a1644749e4ff0cf1739e1877d9af6078e0d317de --- /dev/null +++ b/test/_inductor/test_issue59.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
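+# Note on the hand-written graph below: torch.sum(x) returns a 0-d tensor, so
+# torch.numel(sum) evaluates to 1 and `mean` ends up equal to the full sum
+# (the same holds for mean_1). The test compares eager and inductor results
+# against this exact graph, so the check remains valid even though the math
+# is not a textbook layer norm.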
+import torch +import torch_npu +import torch_npu._inductor +import pytest +from test2.npu_indexing.utils import benchmark_test + + +class Test_issue59(): + def layernorm_backward(self, x, y, z): + sum = torch.sum(x) + mean = sum / torch.numel(sum) + sub = x - mean + sqr = sub * sub + sum_1 = torch.sum(sqr) + mean_1 = sum_1 / torch.numel(sum_1) + 1e-05 + rsqrt = torch.rsqrt(mean_1) + mul = sub * rsqrt + mul_1 = mul * y + add = mul_1 + z + mean_2 = rsqrt / torch.numel(rsqrt) + return mul, add, mean_2 + + def test_issue59(self): + device = 'npu' + test = Test_issue59() + x = torch.randn((1, 1024), device=device, dtype=torch.float32) + y = torch.randn((1, 1024), device=device, dtype=torch.float32) + z = torch.randn((1, 1024), device=device, dtype=torch.float32) + + mul, add, mean_2 = test.layernorm_backward(x, y, z) + func = torch.compile(test.layernorm_backward, backend="inductor", dynamic=False) + mul_t, add_t, mean_2_t = func(x, y, z) + + torch.testing.assert_close(mul, mul_t, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(add, add_t, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(mean_2, mean_2_t, rtol=1e-3, atol=1e-3) + + print("valid ok") + benchmark_test(test.layernorm_backward, func, args=(x, y, z), + name="issue59", times=10, repeat=10, profile=False) + + +if __name__ == "__main__": + test = Test_issue59() + test.test_issue59() diff --git a/test/_inductor/test_issue62.py b/test/_inductor/test_issue62.py new file mode 100644 index 0000000000000000000000000000000000000000..075b45a7b04ef9c6f37a340deb47c1d0ff157235 --- /dev/null +++ b/test/_inductor/test_issue62.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import torch +import torch_npu +import triton +import triton.language as tl +import torch_npu._inductor +import pytest + + +# 实际就是 layernorm的计算过程 : torch.nn.LayerNorm(convert_element_type_25, elementwise_affine=False, eps=1e-6) +class Test_issue62(): + def op_func(self, addmm_5, add): + split = torch.ops.aten.split.Tensor(addmm_5, 1536, 1) + getitem = split[0] + getitem_1 = split[1] + getitem_2 = split[2] + getitem_3 = split[3] + getitem_4 = split[4] + getitem_5 = split[5] + + clone_1 = torch.ops.aten.clone.default(add, memory_format=torch.contiguous_format) + convert_element_type_25 = torch.ops.prims.convert_element_type.default(clone_1, torch.float32) + var_mean = torch.ops.aten.var_mean.correction(convert_element_type_25, [2], correction=0, keepdim=True) + getitem_6 = var_mean[0] + getitem_7 = var_mean[1] + add_3 = torch.ops.aten.add.Tensor(getitem_6, 1e-06) + rsqrt = torch.ops.aten.rsqrt.default(add_3) + sub = torch.ops.aten.sub.Tensor(clone_1, getitem_7) + mul_7 = torch.ops.aten.mul.Tensor(sub, rsqrt) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(mul_7, torch.float16) + slice_11 = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 9223372036854775807) + unsqueeze_2 = torch.ops.aten.unsqueeze.default(slice_11, 1) + add_4 = torch.ops.aten.add.Tensor(unsqueeze_2, 1) + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_26, add_4) + slice_12 = torch.ops.aten.slice.Tensor(getitem, 0, 0, 9223372036854775807) + unsqueeze_3 = torch.ops.aten.unsqueeze.default(slice_12, 1) + add_5 = torch.ops.aten.add.Tensor(mul_8, unsqueeze_3) + return add_5 + + def test_issue62(self): + test = Test_issue62() + addmm_5 = torch.randn((2, 9216), device='npu:0', dtype=torch.float16) + add = torch.randn((2, 4096, 1536), device='npu:0', dtype=torch.float16) + + std_ret = test.op_func(addmm_5, add) + 
compiled_func = torch.compile(test.op_func, backend="inductor") + inductor_ret = compiled_func(addmm_5, add) + assert torch.allclose(std_ret, inductor_ret, atol=1e-2, rtol=1e-2), "Tensors are not close enough!" + print("valid ok") + + +if __name__ == "__main__": + test = Test_issue62() + test.test_issue62() \ No newline at end of file diff --git a/test/_inductor/test_issue70.py b/test/_inductor/test_issue70.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8410bb1d100cdba70eb8dbf4eb187778fb78d1 --- /dev/null +++ b/test/_inductor/test_issue70.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import torch.nn as nn +import pytest + + +class Test_issue70(): + + def op_forward(self, x): + return x.mean(-1) + + + def test_issue70(self): + test = Test_issue70() + compiled_net = torch.compile(test.op_forward, backend="inductor") + + input = torch.randn((1, 1, 7168)).npu() + + output = test.op_forward(input) + output1 = compiled_net(input) + torch.testing.assert_allclose(output, output1, rtol=1e-03, atol=1e-03) + print("valid ok") + + +if __name__ == "__main__": + test = Test_issue70() + test.test_issue70() diff --git a/test/_inductor/test_opensora_graph1.py b/test/_inductor/test_opensora_graph1.py new file mode 100644 index 0000000000000000000000000000000000000000..5b7f35992f986abac4c90f52a65f00379ba9974b --- /dev/null +++ b/test/_inductor/test_opensora_graph1.py @@ -0,0 +1,343 @@ +import torch +import torch_npu +import torch_npu._inductor +import pytest +__TIME_LIMIT = 100 +from torch import device +device_npu = 'npu' + +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_9_inference(): + def forward(primals_1: "f32[1, 9600, 2304]"): + permute: "f32[9600, 1, 2304]" = torch.ops.aten.permute.default(primals_1, [1, 0, 2]); + return permute + primals_2 = torch.randn((1, 9600, 2304), device = device_npu, dtype=torch.float32) + ref = forward(primals_2) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_2) + assert torch.allclose(ref, calc, equal_nan=True, rtol=1e-4, atol=1e-4) + primals_3 = torch.randn((1, 512, 2304), device=device_npu, dtype=torch.float32) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_3) + ref = forward(primals_3) + assert torch.allclose(ref, calc, equal_nan=True, rtol=1e-4, atol=1e-4) + primals_4 = torch.randn((9600, 1, 2304), device=device_npu, dtype=torch.float32) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_4) + ref = forward(primals_4) + assert torch.allclose(ref, calc, equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.skip +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_11_inference(): + def forward(arg0_1: "f32[1, 1, 9600]", arg1_1: "f32[1, 1, 512]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:119 in prepare_sparse_mask, code: video_mask = video_mask.unsqueeze(1) + unsqueeze: "f32[1, 1, 1, 9600]" = torch.ops.aten.unsqueeze.default(arg0_1, 1); + arg0_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:120 in prepare_sparse_mask, code: prompt_mask = prompt_mask.unsqueeze(1) + unsqueeze_1: "f32[1, 1, 1, 512]" = torch.ops.aten.unsqueeze.default(arg1_1, 1); + arg1_1 = None + # File: 
/root/anaconda3/envs/inductor2.3_sora/lib/python3.9/site-packages/torch/nn/functional.py:4522 in pad, code: return torch._C._nn.pad(input, pad, mode, value) + constant_pad_nd: "f32[1, 1, 1, 9600]" = torch.ops.aten.constant_pad_nd.default(unsqueeze, [0, 0, 0, 0], + -9980.0); + unsqueeze = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:128 in prepare_sparse_mask, code: video_mask_sparse_1d = rearrange( + view: "f32[1, 9600, 1]" = torch.ops.aten.view.default(constant_pad_nd, [1, 9600, 1]) + permute: "f32[1, 1, 9600]" = torch.ops.aten.permute.default(view, [2, 0, 1]); + view = None + view_1: "f32[1, 1, 1, 9600]" = torch.ops.aten.view.default(permute, [1, 1, 1, 9600]); + permute = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:133 in prepare_sparse_mask, code: video_mask_sparse_1d_group = rearrange( + view_2: "f32[1, 9600, 1, 1]" = torch.ops.aten.view.default(constant_pad_nd, [1, 9600, 1, 1]); + constant_pad_nd = None + permute_1: "f32[1, 1, 9600, 1]" = torch.ops.aten.permute.default(view_2, [2, 0, 1, 3]); + view_2 = None + view_3: "f32[1, 1, 1, 9600]" = torch.ops.aten.view.default(permute_1, [1, 1, 1, 9600]); + permute_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:139 in prepare_sparse_mask, code: prompt_mask_sparse = prompt_mask.repeat(sparse_n, 1, 1, 1) + repeat: "f32[1, 1, 1, 512]" = torch.ops.aten.repeat.default(unsqueeze_1, [1, 1, 1, 1]); + unsqueeze_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:142 in get_attention_mask, code: mask = mask.to(torch.bool) + npu_dtype_cast: "b8[1, 1, 1, 9600]" = torch.ops.npu.npu_dtype_cast.default(view_1, torch.bool); + view_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:143 in get_attention_mask, code: mask = mask.repeat(1, 1, repeat_num, 1) + repeat_1: "b8[1, 1, 9600, 9600]" = torch.ops.aten.repeat.default(npu_dtype_cast, [1, 1, 9600, 1]); + npu_dtype_cast = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:142 in get_attention_mask, code: mask = mask.to(torch.bool) + npu_dtype_cast_1: "b8[1, 1, 1, 9600]" = torch.ops.npu.npu_dtype_cast.default(view_3, torch.bool); + view_3 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:143 in get_attention_mask, code: mask = mask.repeat(1, 1, repeat_num, 1) + repeat_2: "b8[1, 1, 9600, 9600]" = torch.ops.aten.repeat.default(npu_dtype_cast_1, [1, 1, 9600, 1]); + npu_dtype_cast_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:142 in get_attention_mask, code: mask = mask.to(torch.bool) + npu_dtype_cast_2: "b8[1, 1, 1, 512]" = torch.ops.npu.npu_dtype_cast.default(repeat, torch.bool); + repeat = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:143 in get_attention_mask, code: mask = mask.repeat(1, 1, repeat_num, 1) + repeat_3: "b8[1, 1, 9600, 512]" = torch.ops.aten.repeat.default(npu_dtype_cast_2, [1, 1, 9600, 1]); + npu_dtype_cast_2 = None + return (repeat_1, repeat_3, repeat_2) + arg0_1 = torch.rand((1, 1, 9600), device=device_npu, dtype=torch.float32) + arg1_1 = torch.rand((1, 1, 512), device=device_npu, dtype=torch.float32) + ref = forward(arg0_1, arg1_1) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(arg0_1, arg1_1) + for i in range(len(ref)): + 
print(ref[i]) + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.skip +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_14_backward(): + def forward(primals_5: "f32[1, 9600, 2304]", getitem_3: "f32[1, 9600, 1]", rsqrt: "f32[1, 9600, 1]", + add_2: "f32[1, 1, 2304]", view: "f32[9600, 2304]", permute_1: "f32[32, 2304]", + tangents_1: "f32[1, 9600, 32]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:384 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3); + primals_5 = getitem_3 = None + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); + sub = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:387 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + view_2: "f32[9600, 32]" = torch.ops.aten.view.default(tangents_1, [9600, 32]); + tangents_1 = None + mm: "f32[9600, 2304]" = torch.ops.aten.mm.default(view_2, permute_1); + permute_1 = None + permute_2: "f32[32, 9600]" = torch.ops.aten.permute.default(view_2, [1, 0]) + mm_1: "f32[32, 2304]" = torch.ops.aten.mm.default(permute_2, view); + permute_2 = view = None + permute_3: "f32[2304, 32]" = torch.ops.aten.permute.default(mm_1, [1, 0]); + mm_1 = None + sum_1: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); + view_2 = None + view_3: "f32[32]" = torch.ops.aten.view.default(sum_1, [32]); + sum_1 = None + permute_4: "f32[32, 2304]" = torch.ops.aten.permute.default(permute_3, [1, 0]); + permute_3 = None + view_4: "f32[1, 9600, 2304]" = torch.ops.aten.view.default(mm, [1, 9600, 2304]); + mm = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:386 in _get_output_for_patched_inputs, code: latents = latents * (1 + scale) + shift + sum_2: "f32[1, 1, 2304]" = torch.ops.aten.sum.dim_IntList(view_4, [1], True) + mul_2: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(view_4, mul) + mul_3: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(view_4, add_2); + view_4 = add_2 = None + sum_3: "f32[1, 1, 2304]" = torch.ops.aten.sum.dim_IntList(mul_2, [1], True); + mul_2 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:384 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + mul_5: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul_3, 2304) + sum_4: "f32[1, 9600, 1]" = torch.ops.aten.sum.dim_IntList(mul_3, [2], True) + mul_6: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul_3, mul); + mul_3 = None + sum_5: "f32[1, 9600, 1]" = torch.ops.aten.sum.dim_IntList(mul_6, [2], True); + mul_6 = None + mul_7: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, sum_5); + mul = sum_5 = None + sub_2: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(mul_5, sum_4); + mul_5 = sum_4 = None + sub_3: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(sub_2, mul_7); + sub_2 = mul_7 = None + div: "f32[1, 9600, 1]" = torch.ops.aten.div.Tensor(rsqrt, 2304); + rsqrt = None + mul_8: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(div, sub_3); + div = sub_3 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:383 in _get_output_for_patched_inputs, code: shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + cat: "f32[1, 2, 2304]" = torch.ops.aten.cat.default([sum_2, sum_3], 1); + sum_2 = sum_3 = None + 
sum_6: "f32[1, 1, 2304]" = torch.ops.aten.sum.dim_IntList(cat, [1], True) + squeeze_1: "f32[1, 2304]" = torch.ops.aten.squeeze.dim(sum_6, 1); + sum_6 = None + full_default: "f32[1, 2304]" = torch.ops.aten.full.default([1, 2304], 0, dtype=torch.float32, + layout=torch.strided, + device=device(type='npu', index=0), pin_memory=False) + slice_scatter: "f32[1, 2304]" = torch.ops.aten.slice_scatter.default(full_default, squeeze_1, 0, 0, + 9223372036854775807); + full_default = squeeze_1 = None + squeeze_2: "f32[2, 2304]" = torch.ops.aten.squeeze.dim(cat, 0); + cat = None + return [squeeze_2, permute_4, view_3, slice_scatter, mul_8] + primals_5 = torch.randn((1, 9600, 2304), device=device_npu, dtype=torch.float32) + getitem_3 = torch.randn((1, 9600, 1), device=device_npu, dtype=torch.float32) + rsqrt = torch.randn((1, 9600, 1), device=device_npu, dtype=torch.float32) + add_2 = torch.randn((1, 1, 2304), device=device_npu, dtype=torch.float32) + view = torch.randn((9600, 2304), device=device_npu, dtype=torch.float32) + permute_1 = torch.randn((32, 2304), device=device_npu, dtype=torch.float32) + tangents_1 = torch.randn((1, 9600, 32), device=device_npu, dtype=torch.float32) + ref = forward(primals_5, getitem_3, rsqrt, + add_2, view, permute_1,tangents_1) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_5, getitem_3, rsqrt, + add_2, view, permute_1,tangents_1) + for i in range(len(ref)): + # 1e-3 can not pass, should check reduction accuracy + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_14_forward(): + def forward(primals_1: "f32[2, 2304]", primals_2: "f32[32, 2304]", primals_3: "f32[32]", + primals_4: "f32[1, 2304]", primals_5: "f32[1, 9600, 2304]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:383 in _get_output_for_patched_inputs, code: shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + unsqueeze: "f32[1, 2, 2304]" = torch.ops.aten.unsqueeze.default(primals_1, 0); + primals_1 = None + slice_1: "f32[1, 2304]" = torch.ops.aten.slice.Tensor(primals_4, 0, 0, 9223372036854775807); + primals_4 = None + unsqueeze_1: "f32[1, 1, 2304]" = torch.ops.aten.unsqueeze.default(slice_1, 1); + slice_1 = None + add: "f32[1, 2, 2304]" = torch.ops.aten.add.Tensor(unsqueeze, unsqueeze_1); + unsqueeze = unsqueeze_1 = None + split = torch.ops.aten.split.Tensor(add, 1, 1); + add = None + getitem: "f32[1, 1, 2304]" = split[0] + getitem_1: "f32[1, 1, 2304]" = split[1]; + split = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:384 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + var_mean = torch.ops.aten.var_mean.correction(primals_5, [2], correction=0, keepdim=True) + getitem_2: "f32[1, 9600, 1]" = var_mean[0] + getitem_3: "f32[1, 9600, 1]" = var_mean[1]; + var_mean = None + add_1: "f32[1, 9600, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-06); + getitem_2 = None + rsqrt: "f32[1, 9600, 1]" = torch.ops.aten.rsqrt.default(add_1); + add_1 = None + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3) + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); + sub = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:386 in _get_output_for_patched_inputs, code: latents = latents * (1 + scale) + shift + add_2: "f32[1, 1, 2304]" = 
torch.ops.aten.add.Tensor(getitem_1, 1); + getitem_1 = None + mul_1: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, add_2); + mul = None + add_3: "f32[1, 9600, 2304]" = torch.ops.aten.add.Tensor(mul_1, getitem); + mul_1 = getitem = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:387 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + view: "f32[9600, 2304]" = torch.ops.aten.view.default(add_3, [9600, 2304]); + add_3 = None + permute: "f32[2304, 32]" = torch.ops.aten.permute.default(primals_2, [1, 0]); + primals_2 = None + addmm: "f32[9600, 32]" = torch.ops.aten.addmm.default(primals_3, view, permute); + primals_3 = None + view_1: "f32[1, 9600, 32]" = torch.ops.aten.view.default(addmm, [1, 9600, 32]); + addmm = None + # No stacktrace found for following nodes + squeeze: "f32[1, 9600, 32]" = torch.ops.aten.squeeze.dim(view_1, 1); + view_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:387 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + permute_1: "f32[32, 2304]" = torch.ops.aten.permute.default(permute, [1, 0]); + permute = None + return [squeeze, primals_5, getitem_3, rsqrt, add_2, view, permute_1] + primals_1 = torch.ones((2, 2304), device=device_npu, dtype=torch.float32) + primals_2 = torch.ones((32, 2304), device=device_npu, dtype=torch.float32) + primals_3 = torch.ones((32,), device=device_npu, dtype=torch.float32) + primals_4 = torch.ones((1, 2304), device=device_npu, dtype=torch.float32) + primals_5 = torch.ones((1, 9600, 2304), device=device_npu, dtype=torch.float32) + ref = forward(primals_1, primals_2, primals_3,primals_4, primals_5) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_1, primals_2, primals_3,primals_4, primals_5) + for i in range(len(ref)): + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_15_forward(): + def forward(primals_1: "f32[1, 8, 30, 40, 1, 2, 2, 8]", primals_2: "i64[]", primals_3: "i64[]", + primals_4: "i64[]"): + permute: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = torch.ops.aten.permute.default(primals_1, [0, 7, 1, 4, 2, 5, 3, 6]); + mul: "i64[]" = torch.ops.aten.mul.Tensor(primals_2, 1); + mul_1: "i64[]" = torch.ops.aten.mul.Tensor(primals_3, 2); + mul_2: "i64[]" = torch.ops.aten.mul.Tensor(primals_4, 2); + return [permute, mul, mul_1, mul_2] + + primals_1 = torch.randn((1, 8, 30, 40, 1, 2, 2, 8), device=device_npu, dtype=torch.float32) + primals_2 = torch.tensor((1), device=device_npu, dtype=torch.int64) + primals_3 = torch.tensor((1), device=device_npu, dtype=torch.int64) + primals_4 = torch.tensor((1), device=device_npu, dtype=torch.int64) + ref = forward(primals_1, primals_2, primals_3, + primals_4) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_1, primals_2, primals_3, + primals_4) + for i in range(len(ref)): + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +def find_first_mismatch(output_calc, out, rtol=1e-2, atol=1e-2): + for index in torch.cartesian_prod(*[torch.arange(s) for s in output_calc.shape]): + index = tuple(index.tolist()) + diff = torch.abs(output_calc[index] - out[index]) + rel_diff = diff / torch.abs(out[index]) if torch.abs(out[index]) > 0 else 0 + if diff > atol or rel_diff > rtol: + return index + return None + +@pytest.mark.skip 
+@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_16_forward(): + def forward(primals_1: "f32[2, 2304]", primals_2: "f32[32, 2304]", primals_3: "f32[32]", primals_4: "f32[1, 2304]", primals_5: "f32[1, 9600, 2304]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:407 in _get_output_for_patched_inputs, code: shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + unsqueeze: "f32[1, 2, 2304]" = torch.ops.aten.unsqueeze.default(primals_1, 0); primals_1 = None + slice_1: "f32[1, 2304]" = torch.ops.aten.slice.Tensor(primals_4, 0, 0, 9223372036854775807); primals_4 = None + unsqueeze_1: "f32[1, 1, 2304]" = torch.ops.aten.unsqueeze.default(slice_1, 1); slice_1 = None + add: "f32[1, 2, 2304]" = torch.ops.aten.add.Tensor(unsqueeze, unsqueeze_1); unsqueeze = unsqueeze_1 = None + split = torch.ops.aten.split.Tensor(add, 1, 1); add = None + getitem: "f32[1, 1, 2304]" = split[0] + getitem_1: "f32[1, 1, 2304]" = split[1]; split = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:408 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + var_mean = torch.ops.aten.var_mean.correction(primals_5, [2], correction = 0, keepdim = True) + getitem_2: "f32[1, 9600, 1]" = var_mean[0] + getitem_3: "f32[1, 9600, 1]" = var_mean[1]; var_mean = None + add_1: "f32[1, 9600, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-06); getitem_2 = None + rsqrt: "f32[1, 9600, 1]" = torch.ops.aten.rsqrt.default(add_1); add_1 = None + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3) + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:410 in _get_output_for_patched_inputs, code: latents = latents * (1 + scale) + shift + add_2: "f32[1, 1, 2304]" = torch.ops.aten.add.Tensor(getitem_1, 1); getitem_1 = None + mul_1: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, add_2); mul = None + add_3: "f32[1, 9600, 2304]" = torch.ops.aten.add.Tensor(mul_1, getitem); mul_1 = getitem = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:411 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + view: "f32[9600, 2304]" = torch.ops.aten.view.default(add_3, [9600, 2304]); add_3 = None + permute: "f32[2304, 32]" = torch.ops.aten.permute.default(primals_2, [1, 0]); primals_2 = None + addmm: "f32[9600, 32]" = torch.ops.aten.addmm.default(primals_3, view, permute); primals_3 = None + #import pdb;pdb.set_trace() + view_1: "f32[1, 9600, 32]" = torch.ops.aten.view.default(addmm, [1, 9600, 32]); + # No stacktrace found for following nodes + squeeze: "f32[1, 9600, 32]" = torch.ops.aten.squeeze.dim(view_1, 1); + # import pdb; + # pdb.set_trace() + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:418 in _get_output_for_patched_inputs, code: latents = latents.reshape( + view_2: "f32[1, 8, 30, 40, 1, 2, 2, 8]" = torch.ops.aten.view.default(squeeze, [1, 8, 30, 40, 1, 2, 2, 8]); squeeze = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:428 in _get_output_for_patched_inputs, code: latents = latents.permute(0, 7, 1, 4, 2, 5, 3, 6).contiguous() + permute_1: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = torch.ops.aten.permute.default(view_2, [0, 7, 1, 4, 2, 5, 3, 6]); view_2 = None + clone: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = 
torch.ops.aten.clone.default(permute_1); permute_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:429 in _get_output_for_patched_inputs, code: output = latents.reshape( + clone_1: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = torch.ops.aten.clone.default(clone, memory_format = torch.contiguous_format); clone = None + view_3: "f32[1, 8, 8, 60, 80]" = torch.ops.aten.view.default(clone_1, [1, 8, 8, 60, 80]); clone_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:411 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + permute_3: "f32[32, 2304]" = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + return [view_3, primals_5, getitem_3, rsqrt, add_2, view, permute_3] + + import random + import numpy as np + import os + def seed_all(seed=1234, mode=False): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(mode) + torch_npu.npu.manual_seed_all(seed) + torch_npu.npu.manual_seed(seed) + + seed_all(True) + primals_1 = torch.randn((2, 2304), device=device_npu,dtype=torch.float32) + print(primals_1) + primals_2 = torch.randn((32, 2304), device=device_npu,dtype=torch.float32) + primals_3 = torch.randn((32,), device=device_npu,dtype=torch.float32) + primals_4 = torch.randn((1, 2304), device=device_npu,dtype=torch.float32) + primals_5 = torch.randn((1, 9600, 2304), device=device_npu,dtype=torch.float32) + + ref = forward(primals_1, primals_2, primals_3, primals_4, primals_5) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_1, primals_2, primals_3, primals_4, primals_5) + for i in range(len(ref)): + print("i=", i) + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-3, atol=1e-3) + +if __name__ == '__main__': + test_opensora_cases_model_15_forward() + #test_opensora_cases_model_15_forward() + #test_opensora_cases_model_16_forward() diff --git a/test/_inductor/test_permute.py b/test/_inductor/test_permute.py new file mode 100644 index 0000000000000000000000000000000000000000..fee281959207f1e1dfb11be26ad38c1016cf45d9 --- /dev/null +++ b/test/_inductor/test_permute.py @@ -0,0 +1,47 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +torch_npu._inductor.config.enable_npu_indexing = True + + +class TestPermute(TestUtils): + __TIME_LIMIT = 100 + + _permute_dims = [ + (0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), + (0, 3, 1, 2), (0, 3, 2, 1), (1, 0, 2, 3), (1, 0, 3, 2), + (1, 2, 0, 3), (1, 2, 3, 0), (1, 3, 0, 2), (1, 3, 2, 0), + (2, 0, 1, 3), (2, 0, 3, 1), (2, 1, 0, 3), (2, 1, 3, 0), + (2, 3, 0, 1), (2, 3, 1, 0), (3, 0, 1, 2), (3, 0, 2, 1), + (3, 1, 0, 2), (3, 1, 2, 0), (3, 2, 0, 1), (3, 2, 1, 0), + ] + + def op_calc(self, a, b, dim): + a = a.permute(dim) + b = b.permute(dim) + y = a + b + return y + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 8, 512, 128)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64']) + def test_view_cases(self, shape, dtype, clear_cache): + print(f"shape={shape}") + print(f"dtype={dtype}") + print("npu_indexing={}".format(torch_npu._inductor.config.enable_npu_indexing)) + + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + for dim in self._permute_dims: + print(f"start to test permute on dim :{dim}") + 
std_permute = self.op_calc(a, b, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_permute = compiled_op_calc(a, b, dim) + + torch.testing.assert_close(std_permute, inductor_permute, rtol=1e-3, atol=1e-3) + print("data validation passed.") diff --git a/test/_inductor/test_reduction_brocast_add.py b/test/_inductor/test_reduction_brocast_add.py new file mode 100644 index 0000000000000000000000000000000000000000..29e86fdae90af454a819ee7ddc624e8d3ab2ecd5 --- /dev/null +++ b/test/_inductor/test_reduction_brocast_add.py @@ -0,0 +1,34 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestSumAdd(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + + def foo(self,a, b, dim, shape): + y = a + b + y = y.sum(dim) + y = y.unsqueeze(dim) + y = y.broadcast_to(shape) + b + return y + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(9, 9, 31, 63)]) + @pytest.mark.parametrize('dim', [0, 1, 2]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes1(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + r1 = self.foo(a, b, dim, shape) + func = torch.compile(self.foo, backend="inductor", dynamic=False) + r = func(a, b, dim, shape) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) diff --git a/test/_inductor/test_relu.py b/test/_inductor/test_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..3107d4d5cb73d3b525238687597ab557ff4e612e --- /dev/null +++ b/test/_inductor/test_relu.py @@ -0,0 +1,34 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRelu(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element): + result = torch.relu(first_element) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result) + + +if __name__ == '__main__': + TestRelu() diff --git a/test/_inductor/test_renorm.py b/test/_inductor/test_renorm.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e55a833d9849cd37a903f3d96f52ecad90e0b6 --- /dev/null +++ b/test/_inductor/test_renorm.py @@ -0,0 +1,40 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestRenorm(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.renorm(input_element, p=2, dim=dim, maxnorm=5) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 64)]) + @pytest.mark.parametrize('dim', 
[-1]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + print(f"input_element= {input_element}") + std_ret = self.op_calc(input_element, dim) + print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + print(f"inductor_ret= {inductor_ret}") + + assert torch.allclose(std_ret, inductor_ret, equal_nan=True) + + +if __name__ == "__main__": + size = (32, 64) + test = TestRenorm() + test.test_reduction_cases_shapes(size, -1, 'float32', None) + diff --git a/test/_inductor/test_repeat.py b/test/_inductor/test_repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..d3df6a138d27d23f63501175402a17cc4496f0bb --- /dev/null +++ b/test/_inductor/test_repeat.py @@ -0,0 +1,40 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestRepeat(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return input_element.repeat(dim) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(16, 128, 64)]) + @pytest.mark.parametrize('dim', [(1, 1, 2), (1, 2, 1), (2, 1, 1)]) #(2, 3, 4), (1, 2, 3) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + + std_ret = self.op_calc(input_element, dim) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + + torch.testing.assert_close(std_ret, inductor_ret, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (16, 512, 64) + dim = (2, 3, 4) + test = TestRepeat() + test.test_reduction_cases_shapes(size, dim, 'float32', None) + diff --git a/test/_inductor/test_reshape.py b/test/_inductor/test_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..910e0c83b402776be45963e4101fba839d0f941e --- /dev/null +++ b/test/_inductor/test_reshape.py @@ -0,0 +1,39 @@ +import torch +import torch_npu +import pytest +import torch_npu._inductor +from testutils import OperatorType, TestUtils + +torch_npu._inductor.config.enable_npu_indexing = True + +class TestReshape(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + B, N, S, D = (1, 12, 256, 8) + + def op_calc(self, a, b): + a = a.reshape(self.S, self.B, self.N * self.D) + b = b.reshape(self.S, self.B, self.N * self.D) + y = a + b + return y + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1, 12, 256, 8)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64']) + def test_view_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + print(f"start to test reshape on shape :{shape} ") + std_reshape = self.op_calc(a, b) + + compiled_op_calc = 
torch.compile(self.op_calc, backend="inductor") + inductor_reshape = compiled_op_calc(a, b) + + torch.testing.assert_close(std_reshape, inductor_reshape, rtol=1e-3, atol=1e-3) + + print("data validation passed") + diff --git a/test/_inductor/test_rsqrt.py b/test/_inductor/test_rsqrt.py new file mode 100644 index 0000000000000000000000000000000000000000..b76e1779f48f43db46f2a61594d3967188fe3857 --- /dev/null +++ b/test/_inductor/test_rsqrt.py @@ -0,0 +1,35 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRsqrt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element): + result = torch.rsqrt(first_element) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype, 1) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result, rtol=1e-1, atol=1e-1) + + +if __name__ == '__main__': + TestRsqrt() + diff --git a/test/_inductor/test_slice.py b/test/_inductor/test_slice.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8e75a91ba170a172a4ef27b3730360aff6e262 --- /dev/null +++ b/test/_inductor/test_slice.py @@ -0,0 +1,55 @@ +import torch +import torch_npu +import pytest +import torch_npu._inductor +from testutils import OperatorType, TestUtils + + +class TestSlice(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, a, b, dim, step): + if dim == 0: + target = a.shape[0] + end = target // step + a = a[:end:, ::, ::, ::] + b = b[:end:, ::, ::, ::] + elif dim == 1: + target = a.shape[1] + end = target // step + a = a[::, :end:, ::, ::] + b = b[::, :end:, ::, ::] + elif dim == 2: + target = a.shape[2] + end = target // step + a = a[::, ::, :end:, ::] + b = b[::, ::, :end:, ::] + elif dim == 3: + target = a.shape[3] + end = target // step + a = a[::, ::, ::, :end:] + b = b[::, ::, ::, :end:] + y = a + b + return y + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 8, 256, 128)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64']) + def test_view_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + for dim in [3, 2, 1, 0]: + print(f"start to test slice on dim :{dim} ") + std_slice = self.op_calc(a, b, dim, min(shape)//2) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_slice = compiled_op_calc(a, b, dim, min(shape)//2) + + torch.testing.assert_close(std_slice, inductor_slice, rtol=1e-3, atol=1e-3) + + print("data validation passed") + diff --git a/test/_inductor/test_split_loop.py b/test/_inductor/test_split_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..5682a22836ec958114bc752c30b3646d180e7a55 --- /dev/null +++ b/test/_inductor/test_split_loop.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# 
Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch
+import torch_npu
+import torch_npu._inductor
+
+import pytest
+from testutils import OperatorType, TestUtils
+
+
+class TestSplitLoop(TestUtils):
+    __TIME_LIMIT = 100
+
+    def op_calc(self, a, b):
+        return torch.nn.functional.gelu(a + b)
+
+    # case: change shapes
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape', [(8, 86, 1152), (61, 89, 157), (7, 89, 971)])
+    @pytest.mark.parametrize('dtype', ['float32'])
+    def test_split_loop(self, shape, dtype):
+        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing))
+
+        a = self._generate_tensor(shape, dtype)
+        b = self._generate_tensor((shape[0], 1, shape[2]), dtype)
+
+        std_ = self.op_calc(a, b)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False)
+        inductor_ = compiled_op_calc(a, b)
+        # print(f"inductor_.shape= {inductor_.shape}")
+        torch.testing.assert_close(std_, inductor_, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    size = (8, 86, 1152)
+    test = TestSplitLoop()
+    test.test_split_loop(size, 'float32')
diff --git a/test/_inductor/test_sqrt.py b/test/_inductor/test_sqrt.py
new file mode 100644
index 0000000000000000000000000000000000000000..201b646f9c2cb369ce5adac6f57a4edf629a6935
--- /dev/null
+++ b/test/_inductor/test_sqrt.py
@@ -0,0 +1,44 @@
+import torch
+import torch_npu
+import torch_npu._inductor
+
+import pytest
+from .testutils import OperatorType, TestUtils
+
+
+class TestSqrt(TestUtils):
+    __TIME_LIMIT = 100
+    __OPTYPE = OperatorType.POINTWISE
+
+    # optimized function, auto timeout after __TIME_LIMIT seconds
+
+    # @torch.compile(options={"aggressive_fusion": False})
+
+    def op_calc(self, first_element):
+        result = torch.sqrt(first_element)
+        return result
+
+    # Results can be unstable when many cases run back to back; re-run the failed cases from a batch individually.
+    # To cover more data types, replace the dtype list below with ProtoTestCase._test_dtypes.
+    # The npu indexing switch is exercised through the external pytest option --npu_indexing=True/False.
+
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes)
+    @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64'])
+    def test_pointwise_cases(self, shape, dtype, clear_cache):
+        print(shape)
+        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing))
+        first_element = self._generate_tensor(shape, dtype, 1)
+
+        std_result = self.op_calc(first_element)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
+        inductor_result = compiled_op_calc(first_element)
+        # print(std_result[0:8])
+        # print(inductor_result[0:8])
+        # torch.testing.assert_close(std_result, inductor_result)
+        # To compare tensors that may contain NaN and treat two NaNs as equal, use torch.allclose with equal_nan=True.
+        rtol = 1e-1
+        atol = 1e-1
+        assert torch.allclose(std_result, inductor_result, equal_nan=True, rtol=rtol, atol=atol)
+
diff --git a/test/_inductor/test_sub.py b/test/_inductor/test_sub.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfc44938c3554b79353b574e498a6490dd6378a6
--- /dev/null
+++ b/test/_inductor/test_sub.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
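+# The indexing switch mentioned in the comments of these tests is driven from
+# the command line via the option registered in conftest.py, for example
+# (illustrative invocation):
+#     pytest -svvv test_sub.py --npu_indexing=True
+#     pytest -svvv test_sub.py --npu_indexing=False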
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestSub(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element, second_element): + result = first_element - second_element + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_sub = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_sum = compiled_op_calc(first_element, second_element) + # print(std_sub[0:8]) + # print(inductor_sum[0:8]) + torch.testing.assert_close(std_sub, inductor_sum) diff --git a/test/_inductor/test_sum.py b/test/_inductor/test_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..6b13e88c48b854c01b2e63fe5201907fd233d598 --- /dev/null +++ b/test/_inductor/test_sum.py @@ -0,0 +1,75 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestSum(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + + def op_calc(self, input_element, dim): + return torch.sum(input_element, dim) + # 规约轴和非规约轴对齐用例 float32 XBLOCK_SUB>=8:shape=(8,32) + # non-persistent reduction 用例 规约轴>1024:shape=(8,8,8,2048) dim=-1 + _reduction_extest_shape4d_all = [(8, 32), (8, 8, 8, 2048)] + _reduction_extest_dim4d_low = [-1] + _reduction_extest_dim4d_all = [0, 1, 2] + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的case + # 若需测试更多数据类型,将dtype手动修改,若在一个ut中涉及多个dtype的更改,可能因为tiling固化导致失败 + # 对indexing开关情况的测试需要用外部参数--npu-indexing=True/False完成 + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', _reduction_extest_shape4d_all) + @pytest.mark.parametrize('dim', _reduction_extest_dim4d_low) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape={shape}") + print(f"dim={dim}") + print(f"dtype={dtype}") + print('npu_indexing={}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_sum = self.op_calc(input_element, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_sum_tmp = compiled_op_calc(input_element, dim) + if dtype == 'int32' or dtype == 'int64': + # inductor return float32,need to change int64 for assert + inductor_sum = inductor_sum_tmp.long() + elif dtype == 'float16': + # inductor return float32,need to change float16 for assert + inductor_sum = inductor_sum_tmp.half() + elif dtype == 'bfloat16': + # inductor return float32,need to change float32 for assert + std_sum = std_sum.float() + inductor_sum = inductor_sum_tmp + else: + inductor_sum = inductor_sum_tmp + + # print(f"std_sum={std_sum[0:8]}") + # 
print(f"inductor_sum={inductor_sum[0:8]}") + torch.testing.assert_close(std_sum, inductor_sum, rtol=1e-1, atol=1e-1) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 16, 64, 128)]) + @pytest.mark.parametrize('dim', _reduction_extest_dim4d_all) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_dims(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_sum = self.op_calc(input_element, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_sum = compiled_op_calc(input_element, dim) + + torch.testing.assert_close(std_sum, inductor_sum, rtol=1e-1, atol=1e-1) + +if __name__ == "__main__": + size = (32, 16, 64, 128) + test = TestSum() + test.test_reduction_cases_shapes(size, 2, 'float32', None) diff --git a/test/_inductor/test_sum_add.py b/test/_inductor/test_sum_add.py new file mode 100644 index 0000000000000000000000000000000000000000..670623d722fd07904631ed6c3028a3474eb40729 --- /dev/null +++ b/test/_inductor/test_sum_add.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestSumAdd(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + def op_calc(self, input_element, dim, input_element2): + tmp = torch.sum(input_element, dim) + return tmp + input_element2 + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 64, 128, 2048)]) + @pytest.mark.parametrize('dim', [0, 1, 2, 3]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + if dim == -1 or dim == 3: + input_element2 = torch.full(size=(32, 64, 128), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + elif dim == 2: + input_element2 = torch.full(size=(32, 64, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + elif dim == 1: + input_element2 = torch.full(size=(32, 128, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + else: + input_element2 = torch.full(size=(64, 128, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + + std_sum = self.op_calc(input_element, dim, input_element2) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_sum = compiled_op_calc(input_element, dim, input_element2) + + torch.testing.assert_close(std_sum, inductor_sum, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (32, 64, 128, 2048) + test = TestSumAdd() + test.test_reduction_cases_shapes(size, -1, 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_var.py b/test/_inductor/test_var.py new file mode 100644 index 0000000000000000000000000000000000000000..5c583452c8d54e5ab178b51271504fb6081b7fb6 --- /dev/null +++ b/test/_inductor/test_var.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. 
All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestVar(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.var(input_element, dim) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0, 1, 2]) + @pytest.mark.parametrize('dtype', ['float16']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_ret = self.op_calc(input_element, dim) + # print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + # print(f"inductor_ret= {inductor_ret}") + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_ret, inductor_ret, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (8, 64, 128) + test = TestVar() + test.test_reduction_cases_shapes(size, 2, 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_var_mean.py b/test/_inductor/test_var_mean.py new file mode 100644 index 0000000000000000000000000000000000000000..a36403daabde8f546fb1af2c86c6fd03d6e143fa --- /dev/null +++ b/test/_inductor/test_var_mean.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestVarMean(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.var_mean(input_element, dim) + + # case:The shape must not be too large + #@pytest.mark.skip(reason="npu compiler bug") + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0, 1, 2, (0, 2), (0, 1)]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + + std_var, std_mean = self.op_calc(input_element, dim) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_var, inductor_mean = compiled_op_calc(input_element, dim) + + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_var, inductor_var, equal_nan=True, rtol=rtol, atol=atol) + assert torch.allclose(std_mean, inductor_mean, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (8, 64, 1024) + test = TestVarMean() + test.test_reduction_cases_shapes(size, 2, 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_var_mean_add_mul.py b/test/_inductor/test_var_mean_add_mul.py new file mode 100644 index 0000000000000000000000000000000000000000..a20cdfe54749b0ea17d2b73667c26770be988d8d --- /dev/null +++ b/test/_inductor/test_var_mean_add_mul.py @@ -0,0 +1,45 @@ +import torch +import torch_npu +import torch_npu._inductor +import pytest + +__TIME_LIMIT = 100 +@pytest.mark.timeout(__TIME_LIMIT) +def test_reduction_cases_shapes(): + device = 'npu' + + 
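+    # The inline `forward` below replays an aten graph captured by dynamo:
+    # aten.var_mean + rsqrt normalize `primals_5` (a LayerNorm-style step),
+    # the result is scaled by (1 + one slice of `add`), shifted by the other
+    # slice, and flattened with aten.view; the eager run of this graph is
+    # then compared against the inductor-compiled run.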
def forward(add: "f32[1, 2, 2304]", primals_2: "f32[32, 2304]", primals_5: "f32[1, 9600, 2304]"): + split = torch.ops.aten.split.Tensor(add, 1, 1); + getitem: "f32[1, 1, 2304]" = split[0] + getitem_1: "f32[1, 1, 2304]" = split[1]; + + var_mean = torch.ops.aten.var_mean.correction(primals_5, [2], correction=0, keepdim=True) + getitem_2: "f32[1, 9600, 1]" = var_mean[0] + getitem_3: "f32[1, 9600, 1]" = var_mean[1]; + add_1: "f32[1, 9600, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-06); + rsqrt: "f32[1, 9600, 1]" = torch.ops.aten.rsqrt.default(add_1); + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3) + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); + + add_2: "f32[1, 1, 2304]" = torch.ops.aten.add.Tensor(getitem_1, 1); + mul_1: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, add_2); + add_3: "f32[1, 9600, 2304]" = torch.ops.aten.add.Tensor(mul_1, getitem); + + view: "f32[9600, 2304]" = torch.ops.aten.view.default(add_3, [9600, 2304]); + return [None, primals_5, getitem_3, rsqrt, add_2, view, primals_2] + + torch_npu._inductor.config.enable_npu_indexing = True + primals_2: "f32[32, 2304]" = torch.randn((32, 2304), device = device, dtype=torch.float32) + primals_5: "f32[1, 9600, 2304]" = torch.randn((1, 9600, 2304), device = device, dtype=torch.float32) + add: "f32[1, 2, 2304]" = torch.randn((1, 2, 2304), device =device, dtype=torch.float32) + + _, primals_5_ref, getitem_3_ref, rsqrt_ref, add_2_ref, view_ref, primals_2_ref = forward(add, primals_2, primals_5) + + forward = torch.compile(forward, backend="inductor", dynamic=False) + _, primals_5, getitem_3, rsqrt, add_2, view, primals_2 = forward(add, primals_2, primals_5) + + assert torch.allclose(primals_5_ref, primals_5, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(getitem_3_ref, getitem_3, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(rsqrt_ref, rsqrt, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(add_2_ref, add_2, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(primals_2_ref, primals_2, equal_nan=True, rtol=1e-3, atol=1e-3) \ No newline at end of file diff --git a/test/_inductor/test_where.py b/test/_inductor/test_where.py new file mode 100644 index 0000000000000000000000000000000000000000..b10b0aa3d98fe5fe5cb743f1277368a16c036ec4 --- /dev/null +++ b/test/_inductor/test_where.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
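The operator tests in this patch all share one eager-versus-inductor comparison pattern. The sketch below shows that pattern in isolation; it is illustrative only (the helper name eager_vs_inductor and the example shapes are not part of the patch) and assumes torch_npu with a reachable NPU device:

import torch
import torch_npu
import torch_npu._inductor  # registers the NPU backend for torch._inductor


def eager_vs_inductor(fn, *inputs, rtol=1e-1, atol=1e-1):
    # The eager result is the reference; the inductor-compiled result must match it.
    expected = fn(*inputs)
    compiled = torch.compile(fn, backend="inductor")
    actual = compiled(*inputs)
    torch.testing.assert_close(expected, actual, rtol=rtol, atol=atol)


# Example: a where(condition, a, b) kernel, mirroring TestWhere below.
cond = torch.randint(0, 2, (1024, 32), device="npu").bool()
a = torch.randn(1024, 32, device="npu")
b = torch.randn(1024, 32, device="npu")
eager_vs_inductor(lambda c, x, y: torch.where(c, x, y), cond, a, b)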
+import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestWhere(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, condition, first_element, second_element): + return torch.where(condition, first_element, second_element) + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + condition = self._generate_tensor(shape, 'bool') + + std_result = self.op_calc(condition, first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(condition, first_element, second_element) + + torch.testing.assert_close(std_result, inductor_result) \ No newline at end of file diff --git a/test/_inductor/testutils.py b/test/_inductor/testutils.py new file mode 100644 index 0000000000000000000000000000000000000000..3559820fc21e9a32a4559dbe122b96b7b3806e7d --- /dev/null +++ b/test/_inductor/testutils.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +from enum import Enum, unique +import os + + +@unique +class OperatorType(Enum): + POINTWISE = 'POINTWISE' + REDUCTION = 'REDUCTION' + + +class TestUtils: + _pointwise_test_shape2d = [(4096, 256), (1024, 32), (8, 2048), (8, 4096)] # (8, 4), (8, 8), not supported + _pointwise_test_shape3d = [(8, 8, 4), (8, 8, 8), (8, 8, 2048), (8, 8, 4096)] + _pointwise_test_shape4d = [(128, 128, 4096, 4), (128, 128, 4096, 8), + (32, 32, 1024, 1024)] # 128*128*4096*2048 is too big(512G) + _pointwise_test_shapes = _pointwise_test_shape2d + _pointwise_test_shape3d + _pointwise_test_shape4d + + _pointwise_demo_shapes = [(1024, 32), (8, 16, 256, 32)] + _reduction_extest_shape4d = [(8, 8, 8, 16384), (8, 8, 16384, 8), (8, 16384, 8, 8), (16384, 8, 8, 8)] + _reduction_extest_dim4d = [-1, -2, 1, 0] + _reduction_extest_SDbinding = list(zip(_reduction_extest_shape4d, _reduction_extest_dim4d)) + + _test_dtypes = ['float32', 'int32', 'float16', 'bfloat16', 'int64'] + + @staticmethod + def _generate_tensor(shape, dtype, floatPOSIFLAG=0): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + if floatPOSIFLAG: + return 1000 * torch.rand(size=shape, dtype=eval('torch.' + dtype), device=torch.device("npu")) + else: + return torch.randn(size=shape, dtype=eval('torch.' + dtype), device=torch.device("npu")) * 2000 + elif dtype == 'int32' or dtype == 'int64': + return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' 
+ dtype), device=torch.device("npu")) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape, device=torch.device("npu")).bool() + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) \ No newline at end of file diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd98467d4a9f2c46275dbcd9f84390e91dce9c2f --- /dev/null +++ b/torch_npu/_inductor/__init__.py @@ -0,0 +1,93 @@ + +import torch +from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides +from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device +from torch._inductor import lowering as inductor_lowering +from torch._inductor.choices import InductorChoices + +from torch_npu.utils._inductor import NPUDeviceOpOverrides +from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device +from torch_npu.npu.utils import device_count + +from .lowering import _register_npu_inductor_fallbacks, make_reduction +from .decomposition import _register_npu_inductor_decompositons +from .utils import get_current_raw_stream +from .config import log as npulog +from .config import aggresive_autotune, num_vector_core +from .npu_choices import should_use_persistent_reduction +from . import config as npu_config +#register fx_pass should be put behind of _register_npu_inductor_decompositons +from . import codegen +from . import npu_fusion_attention_graph +from . import dynamo_patch3 + +npulog.info("perform torch_npu._inductor patch") + + +def _inductor_register_backend_for_device(): + from .codegen.schduling import NPUTritonScheduling + from .codegen.wrapper import NPUWrapperCodeGen + register_backend_for_device('npu', NPUTritonScheduling, NPUWrapperCodeGen) + +_inductor_register_backend_for_device() + + +## Override original inductor device overrides in torch_npu +class NewNPUDeviceOpOverrides(NPUDeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch_npu._inductor import get_current_raw_stream as {name}" + + + +def _inductor_register_device_op_overrides(): + register_device_op_overrides('npu', NewNPUDeviceOpOverrides()) + +_inductor_register_device_op_overrides() + + +## Override original dynamo device interface in torch_npu +class NewNpuInterface(NpuInterface): + + @staticmethod + def is_available() -> bool: + return device_count() > 0 + + @staticmethod + def get_compute_capability(mydevice=None): + # npu has no concept of cc. 
triton-npu compiler depends on subarch instead + return torch.npu.get_device_name(mydevice) + + @staticmethod + def exchange_device(device_id: int) -> int: + curr_device = current_device() + set_device(device_id) + return curr_device + + @staticmethod + def maybe_exchange_device(device_id: int) -> int: + return device_id + + +register_interface_for_device("npu", NewNpuInterface) +register_interface_for_device("npu:0", NewNpuInterface) +device = get_interface_for_device("npu") + + + +inductor_lowering.make_reduction = make_reduction +_register_npu_inductor_fallbacks() +_register_npu_inductor_decompositons() + + +def _replace_benchmark_all_configs(): + from torch._inductor.triton_heuristics import CachingAutotuner + from .npu_triton_heuristics import benchmark_all_configs + CachingAutotuner.benchmark_all_configs = benchmark_all_configs + + +if (aggresive_autotune): + _replace_benchmark_all_configs() + import os + os.environ["TRITON_BENCH_METHOD"] = "npu" + +InductorChoices.should_use_persistent_reduction = should_use_persistent_reduction \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/__init__.py b/torch_npu/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a6832665e9c1afe472d174a93d26dce1337e61 --- /dev/null +++ b/torch_npu/_inductor/codegen/__init__.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + + + +from torch._inductor.ir import Reduction, LoopBody +from torch._inductor.codegen.triton import TritonScheduling +from torch._inductor import sizevars +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.codegen.simd import SIMDKernel + +from torch_npu._inductor.codegen._sizevars import simplify +from torch_npu._inductor.codegen.ir import (num_splits, loopbody__call__, transform_dims_in_indexing, substituted_dims_in_indexing) +from torch_npu._inductor.codegen.triton import is_compatible +from torch_npu._inductor.codegen.triton import group_fn, select_index_dtype +from torch_npu._inductor.codegen.schduling import create_tiling + +from ..config import log as npulog +npulog.info("perform npu_indexing patch") +#graph +#common +#ir + + +Reduction.num_splits = num_splits +setattr(LoopBody, 'transform_dims_in_indexing', transform_dims_in_indexing) +setattr(LoopBody, 'substituted_dims_in_indexing', substituted_dims_in_indexing) + +LoopBody.__call__ = loopbody__call__ +#need to enable this to speedup attn_cp_test +#ComputedBuffer.simplify_and_reorder = simplify_and_reorder +#triton scheduling +TritonScheduling.group_fn = group_fn +TritonScheduling.select_index_dtype = select_index_dtype +TritonScheduling.create_tiling = create_tiling +#triton kernel +setattr(SIMDKernel, 'is_compatible', is_compatible) + +#util +sizevars.SizeVarAllocator.simplify = simplify \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/_sizevars.py b/torch_npu/_inductor/codegen/_sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..84206554041b15e3930fead7d0759bb3b9c8ab8e --- /dev/null +++ b/torch_npu/_inductor/codegen/_sizevars.py @@ -0,0 +1,10 @@ +import sympy +from sympy import Expr +from torch._inductor.utils import sympy_subs + + +def simplify(self, expr: Expr): + if isinstance(expr, (tuple, list)): + return [sympy.expand(s).xreplace(self.replacements) for s in expr] + return sympy.expand(expr).xreplace(self.replacements) + diff --git a/torch_npu/_inductor/codegen/ir.py 
b/torch_npu/_inductor/codegen/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb83e5c1fa4d13f9e54c38a48a12a0379fce318 --- /dev/null +++ b/torch_npu/_inductor/codegen/ir.py @@ -0,0 +1,203 @@ + +from typing import List, Tuple, Dict, Any, Optional +import itertools +import sympy + + +from torch._inductor.virtualized import V +from torch._inductor.ir import (ReductionHint, IRNode, ModularIndexing, FloorDiv) +from torch._inductor.utils import sympy_subs, sympy_index_symbol +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel + +from ..config import log + + +# NPU doesn't need to support ReductionHint.OUTER, and persistent reduction +def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, + ): + return ReductionHint.DEFAULT, 1 + + +def detect_flattened_dims(kernel, index): + new_vars = {} + if not isinstance(index, (sympy.core.add.Add, ModularIndexing, FloorDiv)): + return new_vars + + def detect_flattened_axis(expr): + def init_new_vars(var, length): + if var not in new_vars: + new_vars[var] = {length: [None, None]} + if length not in new_vars[var]: + new_vars[var][length] = [None, None] + if isinstance(expr, ModularIndexing): + var, divisor, length = expr.args + init_new_vars(var, length) + new_vars[var][length][1] = (expr, divisor, length) + elif isinstance(expr, FloorDiv): + var, divisor = expr.args + init_new_vars(var, divisor) + # over than 1 node_schedule, var may be deleted in kernel.range_tree_nodes + # it shoule be find in range_tree_nodes_removed dict + if (var in kernel.range_tree_nodes): + numel = kernel.range_tree_nodes[var].length + else: + numel = kernel.range_tree_nodes_removed[var].length + + length = expr.eval(numel, divisor) + new_vars[var][divisor][0] = (expr, divisor, length) + + else: + for x in expr.args: + detect_flattened_axis(x) + + # add + if isinstance(index, sympy.core.add.Add): + for x in index.args: + detect_flattened_axis(x) + elif isinstance(index, (ModularIndexing, FloorDiv)): + detect_flattened_axis(index) + else: + pass + + # make sure FloorDiv, MouldarIndexing must be in-pair + for var, divisors in new_vars.items(): + if var in kernel.range_tree_nodes: + parent_axis = kernel.range_tree_nodes[var] + else: + parent_axis = kernel.range_tree_nodes_removed[var] + for divisor, pair in divisors.items(): + if not pair[0] and not pair[1]: + pass + #FloorDiv not inplace + elif not pair[0]: + _, _, length = pair[1] + expr = FloorDiv(var, length) + new_vars[var][divisor][0] = (expr, length, parent_axis.length // length) + #ModularIndexing not inplace + elif not pair[1]: + expr = ModularIndexing(var, 1, divisor) + new_vars[var][divisor][1] = (expr, 1, divisor) + else: + pass + + return new_vars + + +def rebuild_flattened_dims(indexing): + def rebuild_flattened_dim(key, index, old_node, flatten_dim): + for _, pair in flatten_dim.items(): + new_var_expr = sympy.Integer(0) + origin_axis_length = 0 + pair_is_valid = True + # don't create duplicated axis, e.g. y1:1024, y1 % 1024 is duplicated + expr, divisor, length = pair[1] + if not old_node.parent.duplicated_check(divisor, length): + V.kernel.expr_substituted[expr] = old_node.symbol() + break + + for axis in pair: + expr, divisor, length = axis + # 3. try to rebuild the axis in kernel + new_node = old_node.parent.lookup(divisor, length) + + # 4. 
substitute div/mod expression in indexing + index = index.subs(expr, new_node.symbol()) + indexing[key] = index + if isinstance(expr, FloorDiv): + new_var_expr = new_var_expr + new_node.symbol() * divisor + origin_axis_length = divisor * length + elif isinstance(expr, ModularIndexing): + new_var_expr = new_var_expr + new_node.symbol() + V.kernel.expr_substituted[expr] = new_node.symbol() + + if var not in V.kernel.range_tree_nodes_substituted: + V.kernel.range_tree_nodes_substituted[var] = [] + V.kernel.range_tree_nodes_substituted[var].append((origin_axis_length, new_var_expr)) + + def find_index_in_substitute(index, kernel): + return any([index.find(key) for key in kernel.expr_substituted.keys()]) + + kernel = V.kernel + for key, index in indexing.items(): + # 1. try to find out flattened axis from indexing + flatten_dims = detect_flattened_dims(kernel, index) + #2. try to rebuild these flattened dims + for var, flatten_dim in flatten_dims.items(): + if (var in kernel.range_tree_nodes): + old_node = kernel.range_tree_nodes[var] + else: + old_node = kernel.range_tree_nodes_removed[var] + + rebuild_flattened_dim(key, index, old_node, flatten_dim) + + if find_index_in_substitute(index, kernel): + new_index = sympy_subs(index, kernel.expr_substituted) + indexing[key] = new_index + + +def substituted_dims_in_indexing(self, indexing, kernel, range_tree_nodes_substituted): + substituted = False + for var, candidates in range_tree_nodes_substituted.items(): + if not (len(candidates) > 0): + raise RuntimeError("assert len(candidates) > 0, candidates") + exprs = sorted(candidates, reverse=True, key=lambda x: x[0]) + # the best candidate is with the longest numel + numel = exprs[0][0] + expr = exprs[0][1] + node = kernel.range_tree_nodes[var] + if node.length != numel: + log.debug("sub nodes (expr%s, numel:%d) can not substitute parent node(%s:%d)", + expr, numel, node.symbol(), node.length) + continue + for key, index in indexing.items(): + if var in index.free_symbols: + index = index.subs(var, expr) + indexing[key] = index + substituted = True + + return substituted + + +def generate_body_indexing(body, indices): + index = list(itertools.chain.from_iterable(indices)) + if not (len(index) == len(body.var_ranges)): + raise RuntimeError("assert len(index) == len(body.var_ranges), (index, body.var_ranges)") + if not (all(v not in body.var_ranges for v in index)): + raise RuntimeError("assert all(v not in body.var_ranges for v in index)") + + replacements = dict(zip(body.var_ranges.keys(), index)) + indexing_map = dict(zip(index, body.var_ranges.keys())) + setattr(body, 'indexing_map', indexing_map) + body.indexing = { + name: sympy_subs(expr, replacements) + for name, expr in body.indexing_exprs.items() + } + + +def transform_dims_in_indexing(self, indices): + if self.indexing is None: + generate_body_indexing(self, indices) + + if V.kernel is not None and isinstance(V.kernel, NPUIndexTritonKernel): + rebuild_flattened_dims(self.indexing) + + +# select tiling axis, recover missing dimensions, +def loopbody__call__(self, *indices): + if self.indexing is None: + generate_body_indexing(self, indices) + result = self.root_block() + self.indexing = None + return result + + + diff --git a/torch_npu/_inductor/codegen/npu_kernel_features.py b/torch_npu/_inductor/codegen/npu_kernel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..57c1211e3531963f952b6adea8b0fb82ffa74d6d --- /dev/null +++ b/torch_npu/_inductor/codegen/npu_kernel_features.py @@ -0,0 +1,94 @@ +import functools 
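+# NumelList below lets a tuple of per-dimension numels stand in for their
+# product: comparisons and arithmetic collapse the tuple via numels() (hashing
+# stays tuple-based), so upstream SIMD scheduling code that expects one scalar
+# numel keeps working when NPU tiling groups several dimensions together.
+# For example, NumelList((8, 16)) == 128 and NumelList((8, 16)) // 4 == 32.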
+from typing import Tuple, List +from typing import Iterable +import sympy + +import torch +from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures, NodeScheduleEntry +from torch._inductor.utils import cache_on_self +from torch.utils._ordered_set import OrderedSet +from torch._inductor.virtualized import V +from torch._inductor.codegen.simd import SIMDScheduling + + +class NumelList(Tuple): + + def numels(self): + numel = functools.reduce(lambda a, b: a * b, self) + return numel + + def __eq__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel == numel2 + + def __le__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel <= numel2 + + def __lt__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel < numel2 + + def __ge__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel >= numel2 + + def __gt__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel > numel2 + + + def __mod__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel % numel2 + + def __truediv__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel / numel2 + + def __floordiv__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel // numel2 + + def __mul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __rmul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __add__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __radd__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __hash__(self): + return super(NumelList, self).__hash__() + + +class NPUKernelFeatures(SIMDKernelFeatures): + def __init__( + self, + node_schedule: List[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr = sympy.S.One, + ): + super().__init__(node_schedule, numel, reduction_numel) + self.numel = NumelList(self.numel) if isinstance(self.numel, Iterable) else self.numel + self.reduction_numel = NumelList(self.reduction_numel) if isinstance(self.reduction_numel, Iterable) else self.reduction_numel diff --git a/torch_npu/_inductor/codegen/schduling.py b/torch_npu/_inductor/codegen/schduling.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4b3061441048784bfcf72be5ad5eb7a5f9fed5 --- /dev/null +++ b/torch_npu/_inductor/codegen/schduling.py @@ -0,0 +1,221 @@ +import itertools +import contextlib +from typing import Union, Iterable +from typing import Dict, Sequence, List, Iterable +import sympy + + +from torch.fx.immutable_collections import immutable_dict +from torch._inductor.codegen.triton import (TritonScheduling, log, config) +from torch._inductor.codegen.simd import DisableReduction, EnableReduction, SIMDKernelFeatures, SIMDKernel +from torch._inductor.codegen.simd import schedule_log, scheduler +from 
torch._inductor.codegen.multi_kernel import MultiKernel +from torch._inductor.virtualized import (V,) +from torch._inductor.codecache import code_hash +from torch._dynamo.utils import counters +from torch._inductor.utils import sympy_index_symbol, ModularIndexing, FloorDiv + +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel, flatten +from .split_tiling import SplitTiling +from .npu_kernel_features import NumelList, NPUKernelFeatures + + +def flatten_groups(nums): + res = [] + for i in nums: + if isinstance(i, Iterable): + for x in i: + res.append(x) + else: + res.append(i) + return res + + +@classmethod +def create_tiling( + cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] + ) -> Dict[str, sympy.Expr]: + """ + Create a tiling dict from pointwise and reduction splits. + """ + + pw_tiling = flatten_groups(pw_tiling) + pw_prefixes = ["w", "v", "t", "z", "y", "x"][-len(pw_tiling):] + reduction_tiling = flatten_groups(reduction_tiling) + reduction_tiling = [NumelList(reduction_tiling).numels()] + reduction_prefixes = ["r"][: len(reduction_tiling)] + tiling = immutable_dict( + list(zip(pw_prefixes, pw_tiling)) + + list(zip(reduction_prefixes, reduction_tiling))) + return tiling + + + +class NPUTritonScheduling(TritonScheduling): + def __init__(self, input_scheduler): + super().__init__(input_scheduler) + self.kernel_type = NPUIndexTritonKernel + + def create_kernel_choices( + self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs + ) -> List[SIMDKernel]: + + return [ + self.kernel_type( + *kernel_args, + **kernel_kwargs, + ) + ] + + # transform indexing before call codegen_node_schedule_with_kernel + def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): + node_schedule = kernel_features.node_schedule + tiling = self.select_tiling( + node_schedule, kernel_features.numel, kernel_features.reduction_numel + ) + + kernels = self.create_kernel_choices( + kernel_features, [tiling], {"features": kernel_features} + ) + kernel = kernels[0] + setattr(kernel, "node_schedule", node_schedule) + self.decide_codegen_dims_in_kernel(node_schedule, kernel) + + for kernel in kernels: + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + MultiKernel.merge_workspaces_inplace(kernels) + for kernel in kernels: + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + del kernel + + final_kernel: Union[SIMDKernel, MultiKernel] + if len(kernels) > 1: + final_kernel = MultiKernel(kernels) + else: + (final_kernel,) = kernels + + with V.set_kernel_handler(final_kernel): + for node in kernel_features.scheduler_nodes(): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernels[0].kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. 
+ live_outs = kernels[0].args.live_output_buffers() + for node in kernel_features.scheduler_nodes(): + name = node.get_name() + if name not in live_outs: + continue + if node.node is None: + raise RuntimeError("assert node.node is not None") + + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of pre-fused nodes, generate a Triton kernel. + """ + + nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule( + NPUKernelFeatures(node_schedule, numel, rnumel) + ) + + def decide_codegen_dims_in_kernel(self, node_schedule, kernel): + def current_reduction_nodes(nodes): + return itertools.takewhile(lambda n: n is not DisableReduction, nodes) + + with kernel: + # 1. transform dims: create new dims to substitute floor_divide and modular expression + stack = contextlib.ExitStack() + for _, node in enumerate(node_schedule): + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node._body.transform_dims_in_indexing(index_vars) + # 2. go through range_tree_nodes to findout, to find one axis could be substituted by others + self.additional_nodes_to_be_subs(kernel, kernel.range_tree_nodes_substituted) + # 3.do the substitution on all indexing + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + continue + indexing = node._body.indexing + node._body.substituted_dims_in_indexing(indexing, kernel, kernel.range_tree_nodes_substituted) + + # 4.remove the substituted dims from kernel + for var, _ in kernel.range_tree_nodes_substituted.items(): + if (var in kernel.range_tree_nodes): + root = kernel.range_tree_nodes[var].parent + root.remove_entry(var) + # select split and tiling axis + split_tiling = SplitTiling(kernel) + split_tiling.select_tiling_axis() + # debug print index transforms + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + continue + for x, y in zip(node._body.indexing_exprs.values(), node._body.indexing.values()): + print(f"index transform:{x}->{y}") + + def additional_nodes_to_be_subs(self, kernel, node_to_be_substituted): + for node in kernel.range_tree_nodes.values(): + if node.expr != sympy_index_symbol(f"{node.parent.prefix}index") \ + or len(node.parent.var_ranges) == 1 \ + or node.symbol() in node_to_be_substituted: + continue + numel = sympy.Integer(1) + new_var_expr = sympy.Integer(0) + for k, s in node.parent.var_ranges.items(): + if k == node.symbol(): + continue + numel = numel * s + sub_node = kernel.range_tree_nodes[k] + new_var_expr = new_var_expr + sub_node.symbol() * sub_node.divisor + + if numel == node.length: + node_to_be_substituted[node.symbol()] = [(node.length, new_var_expr)] + else: + log.warning("sub nodes (expr%s, numel:%d) can not make up parent node(%s:%d)", + new_var_expr, numel, node.symbol(), node.length) + + + + + diff --git a/torch_npu/_inductor/codegen/split_tiling.py 
b/torch_npu/_inductor/codegen/split_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..7be80830d94eca0b3ddea998a28f3fa37818ba4f --- /dev/null +++ b/torch_npu/_inductor/codegen/split_tiling.py @@ -0,0 +1,297 @@ +import sympy as sympy + +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.utils import ModularIndexing, sympy_subs +from torch._inductor.virtualized import V +from torch._inductor.codegen.simd import (EnableReduction, DisableReduction) +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from torch._inductor.loop_body import MemoryUsageType + +from .triton_utils import get_aligned_numel +from ..config import num_vector_core, log + + +# split and tiling axis selector +class SplitTiling: + + def __init__(self, kernel: TritonKernel): + self.kernel = kernel + self.indexing = [] + + def key(x): + # to be higher than x and y + if x.name[0] == 'w' or x.name[0] == 'v' or x.name[0] == 'p' or x.name[0] == 't': + return "z" + x.name + # to be lower than floor_dir + elif isinstance(x.expr, ModularIndexing): + return x.name[0] + "0" + x.name[1:] + else: + return x.name + + kernel.sorted_axis = [x for x in kernel.range_tree_nodes.values()] + kernel.sorted_axis.sort(reverse=True, key=key) + for i, dim in enumerate(kernel.sorted_axis): + dim.sorted_order = i + + self.find_lowest_dimension() + self.should_outer_reduce = False + + # Split 原则1 :先做维度合并,再切分 。通过维度合并降维降低,split和tiling轴选择策略的复杂性 。 + # Split 原则2: 切分的数量要和AIcore的数量对齐(相同或是倍数)。每个核要分配的split的量一致。每个split形状要一致(包括维度和尺寸)。 + # Split 原则3: 对于规约类融合算子, 从非规约选择切分轴。对于非规约类融合算子, 从所有轴中选切分轴。 + # 为了tiling时刻的低维tilesize最大化,切分轴最好不是低维轴且长度大于aicore的数量 。 + # Split 原则4: 如果高维规约类融合算子,而且高维尺寸非常大( >= 64KB),低维度尺寸比较小( <= 32B), 可以选择对规约轴切分,然后在核间用atomic + # 原语做规约。 + # Split 原则5 :根据算子逻辑,优先选择一维发射。 + def select_split_axis(self): + + def select_longest_dim(can_be_low_dim=True): + longest = -1 + longest_dim = None + for x in candidates: + if SplitTiling.great_than(x.length, longest) and (can_be_low_dim or not self.is_lowest_dimension(x)): + longest_dim = x + longest = x.length + return longest_dim + # point-wise : all dims , reduction: outer_reduction dim or non-reduction dims + is_reduction = lambda x: x.prefix == 'r' + candidates = [x for x in self.kernel.sorted_axis if not is_reduction(x) or self.should_outer_reduce_me(x)] + if self.should_outer_reduce: + return self.kernel.split_axis + + # 0307 patch 5lines + if len(candidates) > 0: + longest_dim = candidates[0] + self.kernel.split_axis = longest_dim + self.kernel.split_axis.is_split_axis = True + return longest_dim + + #longest and not low dims + longest_dim = select_longest_dim(can_be_low_dim=False) + + # longest and can be low dims + if longest_dim is None or SplitTiling.less_than(longest_dim.length, int(num_vector_core * 0.8)): + longest_dim = select_longest_dim(can_be_low_dim=True) + if longest_dim is not None: + self.kernel.split_axis = longest_dim + self.kernel.split_axis.is_split_axis = True + elif len(self.kernel.sorted_axis) > 0: + longest_dim = self.kernel.sorted_axis[0] + self.kernel.split_axis = longest_dim + self.kernel.split_axis.is_split_axis = True + + return longest_dim + + # Tiling 原则1:切分要照顾所有load / store 中索引表达式的中的低维轴 :所有的低维轴都被切分 从而成为tiling 轴。写代码的时候对所有的tiling + # 轴通过make_range产生连续索引,从而保证load / store的连续性。 + # Tiling 原则2 :规约的tile必须要二维。 对于低维规约算子,规约轴和至少一个非规约轴要选择为tiling轴。对于高维规约,规约轴和低维轴要选择为tiling轴 + # 对于是多维规约, 所有的规约轴都要选择为tiling 轴 。 + # Tiling 原则3: 如果tiling轴是低维,在该轴上的切分的尺寸要与SIMD的BlockSize 对齐(32bytes) + # Tiling 原则4: 低维轴的tile size 越大,性能越好。这个其实autotune 
的原则,放在这里只是为了更好解释用例中使用的数值 。 + + def select_tiling_axis(self): + + # True :self.kernel.axis2 is Not None and all reduction axis selected, False : other cases + def axis2_selection_done(axis): + if self.kernel.total_numels <= 1: + return True + elif self.kernel.axis2 is not None: + is_reduction = axis.prefix == "r" + if not is_reduction: + return True + reduction_axis = self.kernel.numof_reduction_axis() + return True if reduction_axis <= 1 else len(self.kernel.axis2_list) == reduction_axis + else: + return False + + if self.kernel.axis2 is not None or self.kernel.axis1 is not None: + return + # two or more reduction axises, need to flatten reduction dims to one to do 1 dim reduction . + if self.kernel.numof_reduction_axis() > 1: + self.kernel.persistent_reduction = True + biggest = -1 + dims = self.kernel.sorted_axis + if self.kernel.split_axis is None: + self.select_split_axis() + + if self.kernel.split_axis is None: + return + # select tiling_axis2 then tiling_axis1, for reduction, all reduction axis will be selected as tiling_axis2 + for i in range(len(dims) - 1, -1, -1): + axis = dims[i] + numel = axis.length + if isinstance(numel, (sympy.Symbol, sympy.Expr)) and not isinstance(numel, sympy.Integer): + numel = numel.subs(V.graph.sizevars.var_to_val) + if axis.is_split_axis: + dtype = self.kernel.get_axis_dtype(axis) + + min_aligned_numel = get_aligned_numel(dtype) + _, numel = SplitTiling.decide_nblocks_xblock(numel, len(self.kernel.sorted_axis) <= 1, min_aligned_numel) + + # choose reduction axis or low-dim as axis2 + if not axis2_selection_done(axis): + axis.is_tiling_axis2 = True if SplitTiling.great_than(numel, 1) else False + # axis2 must be the reduction axis in case inside_reduction + if axis.prefix == "r": + axis.is_tiling_axis2 = True + if axis.is_tiling_axis2 and self.kernel.axis2 is None: + self.kernel.axis2 = axis.symbol() + if self.kernel.numof_reduction_axis() > 1: + self.kernel.axis2_list.append(axis.symbol()) + self.kernel.axis2 = axis.symbol() if isinstance(axis.expr, ModularIndexing) else self.kernel.axis2 + else: + # for _higher_order_reduction, axis1 must be the lowest dimension + if self.kernel.inside_reduction and self.kernel.is_higher_order_reduction(): + self.kernel.axis1 = axis.symbol() + break + + # low-dim should be selected as another tiling axis + if self.is_lowest_dimension(axis): + self.kernel.axis1 = axis.symbol() + break + # select the longest in other cases + if numel > biggest: + self.kernel.axis1 = axis.symbol() + biggest = numel + if self.kernel.axis1 is not None: + axis = self.kernel.range_tree_nodes[self.kernel.axis1] + axis.is_tiling_axis1 = True + + + log.debug(f"split_tiling numels:{self.kernel.numels} split_axis: {self.kernel.split_axis.symbol()} " + f"axis1:{self.kernel.axis1} axis2:{self.kernel.axis2} low_dims:{self.kernel.low_dims}, " + f"indexing: {self.indexing}") + + + + + def should_outer_reduce_me(self, x): + should_outer = self.kernel.is_higher_order_reduction(True) and SplitTiling.great_than(x.length, 32768) and x.is_loop + if should_outer: + self.should_outer_reduce = True + self.kernel.split_axis = x + self.kernel.split_axis.is_split_axis = True + return should_outer + + @staticmethod + def decide_nblocks_xblock(numel, no_axis2, min_aligned_numel, xblock=None): + #no_axis2 mean there's only on dims + min_xblock = min_aligned_numel if no_axis2 else 1 + + # need to keep linearity for low_dims + if xblock is None: + xblock = (numel + num_vector_core - 1) // num_vector_core if numel > num_vector_core else min_xblock + + xblock = 
next_power_of_2(xblock) + + nblocks = (numel + xblock - 1) // xblock + return nblocks, xblock + + @staticmethod + def get_nblocks_before_launch(numel, xblock): + nblocks = (numel + xblock - 1) // xblock + return nblocks, xblock + + @staticmethod + def get_nblocks_xblock_list(numel): + ret = [] + XBLOCK = numel + NBLOCKS = 1 + ret.append((NBLOCKS, XBLOCK)) + while NBLOCKS <= num_vector_core and XBLOCK > 1: + XBLOCK -= 1 + NBLOCKS = (numel + XBLOCK - 1) // XBLOCK + XBLOCK = (numel + NBLOCKS - 1) // NBLOCKS + ret.append((NBLOCKS, XBLOCK)) + + return ret + + # return True when x is the low-dim in indexing + def is_lowest_dimension(self, x): + return x.sorted_order in self.kernel.low_dims + + def find_lowest_dimension(self): + def construct_low_dim(): + for index in self.indexing: + coefficients_dict = index.as_coefficients_dict() + for key, value in coefficients_dict.items(): + if not key.free_symbols: + continue + key = list(key.free_symbols)[0] + if key not in self.kernel.range_tree_nodes: + continue + + if value == sympy.Integer(1): + axis = self.kernel.range_tree_nodes[key] + self.kernel.low_dims.add(axis.sorted_order) + + # all read index should be considered + buf_names = [node.node.name for node in self.kernel.node_schedule if node not in (EnableReduction, DisableReduction)] + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + + for read in node._body.memory_usage[MemoryUsageType.LOAD]: + name = read.index_name + arg = read.buffer_name + read_is_inptr = False if arg[:3] != 'arg' and arg in buf_names else True + if read_is_inptr: + names.append(name) + + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + + if self.kernel.inside_reduction: + construct_low_dim() + return + + # for non-reduction, write index should be considered + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + for write in node._body.memory_usage[MemoryUsageType.STORE]: + names.append(write.index_name) + for write in node._body.memory_usage[MemoryUsageType.STORE_REDUCTION]: + names.append(write.index_name) + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + + construct_low_dim() + + @staticmethod + def convert(x, y): + xnumel = x + ynumel = y + if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance(xnumel, sympy.Integer): + xnumel = xnumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(ynumel, (sympy.Symbol, sympy.Expr)) and not isinstance(ynumel, sympy.Integer): + ynumel = ynumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(xnumel, sympy.Integer) and isinstance(ynumel, int): + ynumel = sympy.Integer(ynumel) + + if isinstance(ynumel, sympy.Integer) and isinstance(xnumel, int): + xnumel = sympy.Integer(xnumel) + + return (xnumel, ynumel) + + + @staticmethod + def less_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel < ynumel + + @staticmethod + def great_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel > ynumel + + @staticmethod + def ge_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel >= ynumel diff --git a/torch_npu/_inductor/codegen/tile_generator.py b/torch_npu/_inductor/codegen/tile_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6cca5e4d76804a8b7d9b7657eeb82549cbf3823d --- /dev/null +++ 
b/torch_npu/_inductor/codegen/tile_generator.py @@ -0,0 +1,135 @@ +import copy +import math + +from torch._inductor.runtime.triton_heuristics import Config +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from .triton_utils import get_aligned_numel, byte_per_numel + + +# generate tiling configs +class TileGenerator: + + @staticmethod + def aligned_numel(numel): + aligned = next_power_of_2(numel) + return aligned + + @staticmethod + def get_byte_per_numel(dtype): + if dtype is None: + return 1 + return byte_per_numel[dtype] + + @staticmethod + def valid_config(config, align_numel, rnumel=1): + + count_bytes = align_numel + max_numel = 16384 * 4 // count_bytes + + rblock = config["RBLOCK"] if "RBLOCK" in config else rnumel + xblock_sub = config["XBLOCK_SUB"] + if rblock * xblock_sub <= max_numel: + return True + + return False + + # when rblock is low dim, need to maximize rblock + @staticmethod + def descend_xblock(rnumel, xblock, configs, cfg, align_numel, aggresive=True): + + count_bytes = align_numel + start_numel = 2048 // count_bytes if aggresive else 1024 // count_bytes + # include rblock is too big, need to decend rblock first + rblock = rnumel if rnumel > 0 else 1 + while (rblock > start_numel): + newcfg = copy.deepcopy(cfg) + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + rblock = rblock // 2 + cfg["RBLOCK"] = rblock + xblock_sub = TileGenerator.aligned_numel(xblock) + + while True: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + if TileGenerator.valid_config(newcfg, align_numel, rnumel=rblock): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + xblock_sub = xblock_sub // 2 + if xblock_sub * rblock <= start_numel: + break + + @staticmethod + def descend_rblock(rnumel, xblock, configs, cfg, align_numel, aggresive=True): + count_bytes = align_numel + start_numel = 4096 // count_bytes if aggresive else 1024 // count_bytes + + xblock_sub = start_numel if xblock > start_numel else xblock + cfg["XBLOCK_SUB"] = xblock_sub + rblock = rnumel + while True: + newcfg = copy.deepcopy(cfg) + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + rblock = rblock // 2 + if xblock_sub * rblock <= start_numel: + break + + @staticmethod + def descend_xblock_rblock(rnumel, xblock, configs, cfg, align_numel, aggresive=True): + count_bytes = align_numel + start_numel = 4096 // count_bytes if aggresive else 1024 // count_bytes + + # Depending on the number of bytes available to the hardware UB, + # 4096 bytes is an appropriate empirical value for an intra-core split. 
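+        # The loops below keep halving XBLOCK_SUB and/or RBLOCK, appending every
+        # candidate that passes valid_config, until the product fits this budget.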
+ # Rule: xblock_sub * rblock <= start_numel + end_numel = math.floor(math.sqrt(start_numel)) + + xblock = next_power_of_2(xblock) + rnumel = next_power_of_2(rnumel) + + xblock_sub = xblock if xblock > start_numel else xblock + rblock = start_numel if rnumel > start_numel else rnumel + + rblock_is_biggerr = rblock > xblock_sub + + if xblock_sub * rblock <= start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + + if rblock_is_biggerr: + while rblock > xblock_sub and xblock_sub * rblock > start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["RBLOCK"] = rblock + xblock_sub = xblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + rblock = rblock // 2 + else: + while rblock < xblock_sub and xblock_sub * rblock > start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + xblock_sub = xblock_sub // 2 + + while xblock_sub * rblock > start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + if xblock_sub >= end_numel: + xblock_sub = xblock_sub // 2 + if rblock >= end_numel: + rblock = rblock // 2 + + @staticmethod + def nearest_power_of_2(n): + big = next_power_of_2(n) + small = big // 2 + return big if (big - n) < (n - small) else small diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..68ffb16a4ed19d5bb21ba801a4c98502e5b0af68 --- /dev/null +++ b/torch_npu/_inductor/codegen/triton.py @@ -0,0 +1,2107 @@ +import os +from typing import List, Set, Iterable, Callable, Sequence +from typing import Dict +import operator +import itertools +from enum import Enum +import functools + +from typing import ( + Optional, + Union, + Tuple, + Any, + cast +) + +import re +import textwrap +import sympy + +import torch +from torch._inductor.utils import sympy_subs +from torch._inductor.scheduler import SchedulerNode + +from torch._inductor.codegen.simd import CantSplit, DisableReduction, EnableReduction +from torch._inductor.codegen.common import free_symbol_is_type +from torch._inductor.codegen.triton import ( + IndexingOptions, + triton_reshape, + TritonCSEVariable, + OpsHandler, +) +from torch._inductor.runtime.hints import ReductionHint +from torch._inductor.codegen.triton import ( + TritonKernel, + TritonKernelOverrides, + IterationRangesRoot, + IterationRangesEntry, + CSEVariable, + gen_common_triton_imports, + BlockPtrOptions, + triton_acc_type, + constant_repr, + is_welford_reduction, FixedTritonConfig, + prefix_is_reduction, upcast_acc_dtype +) + +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing +from torch._inductor.utils import sympy_index_symbol, generate_assert +from torch.utils import _pytree as pytree +from torch.utils._sympy.value_ranges import ValueRanges +from torch._inductor import config, ir +from torch._inductor.virtualized import ( + V, + StoreMode, + ReductionType, + _ops as ops, +) + +from torch._inductor.utils import ( + Placeholder, +) +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from 
torch._inductor.codegen.common import ( + IndentedBuffer, + SizeArg, + DeferredLine, +) +from torch._inductor.codegen.triton_utils import config_of, signature_of, signature_to_meta +from torch.utils._sympy.symbol import SymT, symbol_is_type +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.numbers import int_oo +from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + +from ..runtime import NPUDeviceProperties +from .npu_kernel_features import NumelList + + +def flatten(nums): + res = [] + for i in nums: + if isinstance(i, list): + res.extend(flatten(i)) + else: + res.append(i) + return res + + +class AxisDirection(Enum): + Flat = 0, + Vertical = 1, + Horizontal = 2 + + +def reverse_direction(direction): + if direction == AxisDirection.Vertical: + return AxisDirection.Horizontal + elif direction == AxisDirection.Horizontal: + return AxisDirection.Vertical + else: + return AxisDirection.Flat + + +class NPUTritonKernelOverrides(TritonKernelOverrides): + + @staticmethod + def exp(x): + return f"tl_math.exp({x})" + + @staticmethod + def sqrt(x): + return f"tl_math.sqrt({x})" + + @staticmethod + def tanh(x): + return f"tl_math.tanh({x})" + + @staticmethod + def rsqrt(x): + return f"tl.rsqrt({x})" + + @staticmethod + def floor(x): + return f"tl_math.floor({x})" + + @staticmethod + def erf(x): + return f"tl_math.erf({x})" + + @staticmethod + def ceil(x): + return f"tl_math.ceil({x})" + + +def group_fn(self, sizes): + groups = list() + for s in sizes: + if not s: + groups.append(1) + elif isinstance(s, list): + group = flatten(s) + groups.append(NumelList(tuple(group)) if isinstance(group, list) else group) + else: + groups.append(s) + return tuple(groups) + + +@staticmethod +def select_index_dtype(node_schedule, numel, reduction_numel): + return "tl.int32" + + + +class IterationRangesEntryNPUIndex(IterationRangesEntry): + def __init__( + self, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_tiling_axis1 = False + self.is_tiling_axis2 = False + self.is_split_axis = False + self.indexing_code = IndentedBuffer() + self.sorted_order = None + self.low_dims = set() + + + def _codegen_mask(self): + if self.is_tiling_axis1 or self.is_tiling_axis2: + upper = f"{self.name}_numel" + line = f"{self.name}_mask = {self.name} < {upper}" + self.writeline(line) + line = f"{self.name}_prime_mask = {self.name}_prime < {upper}" + self.writeline(line) + else: + pass + + def _codegen(self): + index = None + vertical = self.is_tiling_axis1 if V.kernel.numof_reduction_axis() <= 1 else not isinstance(self.expr, ModularIndexing) + direction = V.kernel.get_axis_direction(vertical) + # for multiple reduce dims, don't need this + if self.is_tiling_axis1 and V.kernel.numof_reduction_axis() <= 1: + index = f"{self.name} = {self.codegen_index(direction)}" + #to be fixed, only permute need to this . + self.writeline(f"{self.name}_prime = {self.codegen_index(reverse_direction(direction))}") + + elif self.is_tiling_axis2: + index = f"{self.name} = {self.codegen_index(direction)}" + #to be fixed, only permute need to this . 
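+            # `{name}_prime` is the same index range broadcast along the opposite
+            # axis ([None, :] vs [:, None]); per the note above it is only needed
+            # for permuted accesses that want the transposed orientation.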
+ self.writeline(f"{self.name}_prime = {self.codegen_index(reverse_direction(direction))}") + if V.kernel.inside_reduction and V.kernel.current_node \ + and isinstance(V.kernel.current_node, SchedulerNode) \ + and V.kernel.current_node.node \ + and V.kernel.current_node.node.data \ + and isinstance(V.kernel.current_node.node.data, ir.Reduction): + reduction_type = V.kernel.current_node.node.data.reduction_type + if reduction_type in {"argmax", "argmin"}: + self.writeline(f"{self.parent.prefix}index = " + f"{self.codegen_index(reverse_direction(AxisDirection.Flat))}") + if index: + self.writeline(index) + self._codegen_mask() + return self.name + + def writeline(self, line): + self.indexing_code.writeline(line) + + def codegen_index(self, direction): + if self.is_tiling_axis1 and V.kernel.axis2 is None and V.kernel.persistent_reduction: + index = f"tl.arange(0, RBLOCK)" + return index + elif self.is_tiling_axis1: + if self.is_split_axis: + offset = f"{self.symbol()}_offset" + index = f"{offset} + (loop1 * XBLOCK_SUB) + base1" + else: + index = f"(loop1 * XBLOCK_SUB) + base1" + + if V.kernel.axis2 is not None and direction != AxisDirection.Flat: + index += ("[None, :]" if direction == AxisDirection.Horizontal else "[:, None]") + return index + elif self.is_tiling_axis2: + if V.kernel.persistent_reduction: + index = f"tl.arange(0, RBLOCK_{self.symbol()})" if V.kernel.numof_reduction_axis() > 1 else "base2" + elif self.is_split_axis: + offset = f"{self.symbol()}_offset" + index = f"{offset} + (loop2 * RBLOCK) + base2" + else: + index = "loop2 * RBLOCK + base2" + + if direction != AxisDirection.Flat: + index += ("[:, None]" if direction == AxisDirection.Vertical else "[None, :]") + return index + else: + raise RuntimeError("codegen_index") + + def codegen_header(self, code): + # generate offset index loop + lines = [] + + if self.is_split_axis and not (V.kernel.axis2 is None and V.kernel.persistent_reduction): + lines.append(f"{self.symbol()}_offset = tl.program_id(0) * XBLOCK") + + if self.is_tiling_axis1 and not (V.kernel.axis2 is None and V.kernel.persistent_reduction): + # don't create loops for multi-reductions + if V.kernel.numof_reduction_axis() <= 1: + lines.append("base1 = tl.arange(0, XBLOCK_SUB)") + xblock = f"XBLOCK" if self.is_split_axis else f"{self.symbol()}_numel" + lines.append(f"loops1 = ({xblock} + XBLOCK_SUB - 1) // XBLOCK_SUB") + + elif self.is_tiling_axis2 and len(V.kernel.axis2_list) <= 1: + lines.append("base2 = tl.arange(0, RBLOCK)") + if self.is_split_axis: + lines.append(f"loops2 = (XBLOCK + RBLOCK - 1) // RBLOCK") + else: + lines.append(f"loops2 = ({self.name}_numel + RBLOCK - 1) // RBLOCK") + else: + pass + + code.writelines(lines) + + def precomputed_args(self): + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: List[sympy.Expr] = [] + if isinstance(self.expr, (sympy.Symbol, sympy.Integer)): + return precomputed_args + + if not isinstance(self.expr, (FloorDiv, ModularIndexing)): + raise RuntimeError("assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr)") + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + +class IterationRangesRootNPUIndex(IterationRangesRoot): + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: TritonKernel, + 
pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + ): + super().__init__(name, numel, prefix, index, kernel, pid_cache, is_loop=is_loop, tensor_dim=tensor_dim, + grid_dim=grid_dim, has_zdim=False) + + def __repr__(self): + return f"IterationRangesRootNPUIndex({self.name!r}, {self.numel}, ...)" + + def remove_entry(self, name): + if name in self.var_ranges: + del self.var_ranges[name] + if name in self.var_list: + del self.var_list[self.var_list.index(name)] + if name in V.kernel.range_tree_nodes: + V.kernel.range_tree_nodes_removed[name] = V.kernel.range_tree_nodes[name] + del V.kernel.range_tree_nodes[name] + if name in self.nodes: + del self.nodes[name] + + def duplicated_check(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + return expr not in self.nodes + + + def lookup(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + if expr not in self.nodes: + node = IterationRangesEntryNPUIndex( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + + + return self.nodes[expr] + + +def is_compatible(groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]): + try: + groups = flatten(groups) + NPUIndexTritonKernel._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + +class NPUIndexTritonKernel(TritonKernel): + overrides = NPUTritonKernelOverrides + + def __init__( + self, + tiling: Dict[str, sympy.Expr], + min_elem_per_thread=0, + optimize_mask=True, + fixed_config: Optional[FixedTritonConfig] = None, + **kwargs,): + + super().__init__(tiling=tiling, + min_elem_per_thread=min_elem_per_thread, + optimize_mask=optimize_mask, + fixed_config=fixed_config, + **kwargs) + self.first_node = True + self.inside_high_order_reduction = False + # split axis + self.split_axis = None + # tiling axis + self.axis1 = None + self.axis2 = None + # incase two reduction axis + self.axis2_list = [] + self.low_dims = set() + + self.range_tree_nodes_removed: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.range_tree_nodes_substituted = {} + self.expr_substituted = {} + self.sorted_axis = [] + self.prefix: IndentedBuffer = IndentedBuffer() + + def gen_triton_ext_imports(self): + imports = IndentedBuffer() + imports.splice( + """ + from torch._inductor.runtime import triton_helpers + from torch_npu._inductor import npu_triton_heuristics + from torch_npu._inductor import npu_triton_helpers + from torch_npu._inductor.runtime import NPUDeviceProperties + from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math + import torch + """ + ) + return imports.getvalue() + + + def patch_triton_hash(self): + # remove this method once the original invocation is fixed + import hashlib + from triton.compiler.compiler import triton_key, make_backend + 
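+        # The key combines the generic triton_key() with the active backend's
+        # hash, so kernels compiled for the current NPU target are cached
+        # separately from artifacts built for other targets.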
from triton.runtime.driver import driver + backend = make_backend(driver.active.get_current_target()) + key = f"{triton_key()}-{backend.hash()}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def numof_reduction_axis(self): + root = self.range_trees[-1] + if root is None: + return 0 + + return len(root.var_list) + + def numof_tiling_axis(self): + return (1 if self.axis1 is not None else 0) + (1 if self.axis2 is not None else 0) + + #do nothing in NpuTritonKernel + def codegen_range_tree(self): + pass + + + def initialize_range_tree(self, pid_cache): + self.total_numels = 0 + for k, x in self.numels.items(): + if not isinstance(x, sympy.Integer): + x = x.subs(V.graph.sizevars.var_to_val) + self.numels[k] = x + if x > 1: + self.total_numels += 1 + + no_r_dim = not self.inside_reduction or self.numels["r"] == 1 + prefixes = "wvtzyxr" + active_prefixes = prefixes[-len(self.numels):] + #prefix can not be 's', 'u', 'ps' , 'i', 'z', 'q' + #prefix can not be 'p' from torch 2.6.0 + grid_dims = "xyztvw" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyztvw" + else: + tensor_dims = "xyztvwr" + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix_is_reduction(prefix) + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRootNPUIndex( + f"{prefix}index", + self.numels[prefix], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim + ) + ) + + # numels sent to autotune configs + def get_size_hints(self): + size_hints = [] + + if (len(self.range_tree_nodes.values()) == 0): + return size_hints + + for _, node in enumerate(self.sorted_axis): + if isinstance(node.expr, ModularIndexing): + numel_expr = node.length + else: + numel_expr = node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + numel_expr = V.graph.sizevars.symbolic_hint(numel_expr) + + size_hints.append(numel_expr) + return size_hints + + # torch251 done + def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid): + for node in self.sorted_axis: + if isinstance(node.expr, ModularIndexing): + numel_expr = node.length + else: + numel_expr = node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + if isinstance(numel_expr, (sympy.Integer, sympy.Symbol)): + expr = numel_expr + else: + expr = V.graph.wrapper_code.generate_node_numel_expr(name, node, numel_expr) + call_args.append(expr) + arg_types.append(type(expr)) + if node.parent.grid_dim is not None: + grid.append(expr) + + def gen_numel_args(self, signature, triton_meta_signature, argdefs): + for node in self.sorted_axis: + arg_name = f"{node.name}_numel" + if not os.environ.get('INDUCTOR_STATIC_MODE'): + sizearg = SizeArg(arg_name, node.length) + signature.append(sizearg) + triton_meta_signature[arg_name] = signature_of( + sizearg, size_dtype=self.index_dtype + ) + argdefs.append(arg_name) + else: + argdefs.append(f"{arg_name}: tl.constexpr") + self.triton_meta["constants"][arg_name] = node.length + + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + size_hints = self.get_size_hints() + heuristics = self._get_heuristic() + if name is None: + code.splice(gen_common_triton_imports()) + # Note: add extra imports 
for extensions + code.splice(self.gen_triton_ext_imports()) + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + triton_meta_signature = signature_to_meta(signature, size_dtype=self.index_dtype, argdefs=argdefs) + + triton_meta = { + "signature": triton_meta_signature, + "device": + NPUDeviceProperties.create( + V.graph.get_current_device_or_throw() + ), + "constants": {}, + # special config for NPU, specify compile target + "mix_mode": "aiv", + } + + inductor_meta = self.create_inductor_meta() + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + self.triton_meta = triton_meta + self.gen_numel_args(signature, triton_meta_signature, argdefs) + + #add in tiling args + self.add_autotune_args(argdefs) + #for scalar codegen + if len(self.range_tree_nodes) == 0: + self.write_scalar() + else: + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + + # Note: override original triton_heuristics + if self.inside_reduction: + reduction_hint = self.features.get_reduction_hint() + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + + + def codegen_static_numels(self, code): + no_x_axis = self.numof_reduction_axis() > 1 + symbols = [] + if self.axis2 is not None: + symbols = list(self.axis2_list) if no_x_axis else list([self.axis2]) + elif self.persistent_reduction and self.axis1 is not None: + symbols = list([self.axis1]) + + nodes = [self.range_tree_nodes[symbol] for symbol in symbols if symbol is not None] + for node in nodes: + if node.prefix == "r" and self.persistent_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(node.length) + if isinstance(simplified_tree_numel, (sympy.Integer, int)): + val = int(simplified_tree_numel) + else: + continue + val = next_power_of_2(val) + if no_x_axis: + code.writeline(f"RBLOCK_{node.symbol()}: tl.constexpr = {val}") + else: + code.writeline(f"RBLOCK: tl.constexpr = {val}") + + def axis2_variable(self): + if self.axis2 is not None: + return self.range_tree_nodes[self.axis2] + return None + + def 
is_isolated_symbol(self, input_str, symbol): + # 使用正则表达式查找独立的符号, 防止out_ptr0 匹配上r0 r0_prime + pattern1 = r'\b' + re.escape(symbol) + r'\b' + pattern2 = r'\b' + re.escape(symbol + '_prime') + r'\b' + + return bool(re.search(pattern1, input_str)) or bool(re.search(pattern2, input_str)) + + def find_axis2_in_load_store(self): + var = self.axis2_variable() + if not var: + return False + for line in self.loads._lines: + if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, var.name): + return True + for line in self.compute._lines: + if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, var.name): + return True + for line in self.post_loop_store._lines: + if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, var.name): + return True + for line in self.stores._lines: + if isinstance(line, DeferredLine): + line = line.line + if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, var.name): + return True + return False + + def find_axis2_in_indexing(self): + var = self.axis2_variable() + if not var: + return False + if self.current_node is None: + return False + for index in self.current_node._body.indexing.values(): + if var.symbol() in index.free_symbols: + return True + return False + + def write_scalar(self): + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_store.clear() + self.prefix.clear() + + def is_1d_reduction(self): + return self.numels["r"] > 1 and self.axis2 is None + + def codegen_body(self): + if not ( + self.loads + or self.stores + or self.compute + or self.post_loop_store + ): + return + + def write_pointwise(): + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + def codegen_range(index): + + def loop_body(index, indexing_code, is_last_axis, do_indent=True): + if do_indent: + self.body.do_indent() + if indexing_code: + self.body.splice(indexing_code) + + if is_last_axis: + write_pointwise() + else: + codegen_range(index + 1) + + if do_indent: + self.body.do_unindent() + + if index < 0 or index >= len(self.range_tree_nodes): + return + nodes = self.sorted_axis + range_node = nodes[index] + is_tilling_asix1 = getattr(range_node, "is_tiling_axis1") + is_tilling_asix2 = getattr(range_node, "is_tiling_axis2") + is_last_axis = index == len(nodes) - 1 + indexing_code = getattr(range_node, "indexing_code") + numof_axis2 = self.numof_reduction_axis() + if is_tilling_asix1: + do_indent = True + reduction_1d = self.is_1d_reduction() + if reduction_1d: + self.body.splice(self.prefix) + self.prefix.clear() + + # multi-dim reduction, i.e. 
var_mean[1,2] + if numof_axis2 > 1: + if range_node.is_split_axis: + offset = f"{range_node.name}_offset" + self.body.writeline(f"for {range_node.name} in range({offset}, " + f"min({offset} + XBLOCK, {range_node.name}_numel)):") + else: + self.body.writeline(f"for {range_node.name} in range({range_node.name}_numel):") + # 1D persistent_reduction or 1d reduction non-first-node + elif self.axis2 is None and (self.persistent_reduction or len(self.loads._lines) == 0): + do_indent = False + if len(self.loads._lines) == 0: + indexing_code = None + else: + self.body.writeline(f"for loop1 in range(loops1):") + + + if not reduction_1d and self.persistent_reduction: + self.body.do_indent() + self.body.splice(self.prefix) + self.prefix.clear() + self.body.do_unindent() + + loop_body(index, indexing_code, is_last_axis, do_indent=do_indent) + + # for 1D reduction, need to add in suffix for persist_reduction or second node of 1d reduction + if self.is_1d_reduction() or self.persistent_reduction: + self.body.splice(self.post_loop_store) + self.post_loop_store.clear() + + + elif is_tilling_asix2: + do_indent = False + need_axis2_loop = self.find_axis2_in_load_store() + if not need_axis2_loop: + indexing_code = None + if (not self.inside_reduction or not self.persistent_reduction) \ + and need_axis2_loop: + self.body.splice(self.prefix) + self.body.writeline(f"for loop2 in range(loops2):") + do_indent = True + loop_body(index, indexing_code, is_last_axis, do_indent) + self.body.splice(self.post_loop_store) + self.post_loop_store.clear() + + elif is_last_axis and range_node.numel == 1: + #pointwise , last axis =1 + write_pointwise() + else: + if range_node.is_split_axis: + offset = f"{range_node.symbol()}_offset" + self.body.writeline(f"for {range_node.symbol()} in range({offset}, min({offset} + XBLOCK, {range_node.name}_numel)):") + else: + self.body.writeline(f"for {range_node.symbol()} in range({range_node.name}_numel):") + loop_body(index, indexing_code, is_last_axis) + + if self.first_node: + for node in self.sorted_axis: + node.codegen_header(self.body) + + + if self.first_node: + codegen_range(0) + else: + if self.axis2 is None: + codegen_range(0) + else: + axis2_order = self.range_tree_nodes[self.axis2].sorted_order + if self.persistent_reduction and self.numof_reduction_axis() > 1: + axis2_order = axis2_order - self.numof_reduction_axis() + 1 + for _ in range(axis2_order): + self.body.do_indent() + codegen_range(axis2_order) + for _ in range(axis2_order): + self.body.do_unindent() + + self.cse.invalidate(self.outside_loop_vars) + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_store.clear() + self.prefix.clear() + self.first_node = False + + # for creat constant tensor, if have two axis, constant=tl.full([1,1]) else tl.full([1]) + def triton_tensor_ndim(self): + if self.numof_reduction_axis() > 1: + return 1 + if self.axis1 is not None and self.axis2 is not None: + ndim = 2 + else: + ndim = 1 + return ndim + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction") + + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + if isinstance(indexing, BlockPtrOptions): + self.post_loop_store.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) 
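+        # no block pointer available: fall back to a plain tl.store queued in post_loop_store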
+ else: + if not isinstance(indexing, IndexingOptions): + raise RuntimeError("assert isinstance(indexing, IndexingOptions)") + line = f"tl.store({var} + ({indexing.index_str} ), {value}, {indexing.mask_str})" + if self.numof_reduction_axis() > 1: + line = f"tl.store({var} + ({indexing.index_str} + tl.arange(0,1) ), {value}, {indexing.mask_str})" + self.post_loop_store.writeline( + DeferredLine(name, line) + ) + + def apply_var_prime(self, index, line, mask): + # axis should only be replaced once + axis_list = [] + for key in index.as_coefficients_dict().keys(): + if not key.free_symbols: + continue + symbol = list(key.free_symbols)[0] + if symbol not in self.range_tree_nodes: + continue + range_node = self.range_tree_nodes[symbol] + if (range_node.is_tiling_axis1 or range_node.is_tiling_axis2) and (symbol not in axis_list): + line = line.replace(f"{range_node.name}", f"{range_node.name}_prime") + mask = mask.replace(f"{range_node.name}", f"{range_node.name}_prime") + axis_list.append(symbol) + return line, mask + + # apply xxx_prime var in case dim are permuted + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + + var = self.args.output(name) + original_index = index + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None) + index_str = indexing.index_str + value_str = f"{value}" + + # need to reshape when value's dimensions > 2, e.g. (XBLOCK,1,RBLOCK) + is_permuted = self.need_permuted(index) + + mask_str = indexing.mask_str + if is_permuted: + index_str, mask_str = self.apply_var_prime(index, index_str, indexing.mask_str) + value_str = value_str.replace(f"{value}", f"{value}.permute(1,0)") + + advance_block_ptr = None + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing + ) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({index_str}), {value_str}, {mask_str})" + if len(self.axis2_list) > 1: + line = f"tl.store({var} + ({index_str} + tl.arange(0,1) ), {value_str}, {indexing.mask_str})" + + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({index_str}), {value_str}, {indexing.mask_str})" + else: + raise NotImplementedError(f"store mode={mode}") + + self.stores.writeline(DeferredLine(name, line)) + if advance_block_ptr: + self.stores.writeline(advance_block_ptr) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + + @staticmethod + def _get_next_scheduler_node(node_schedule, current_node): + found_current = False if current_node else True + for node in node_schedule: + if isinstance(node, SchedulerNode): + if not found_current and node.get_name() == current_node.get_name(): + found_current = True + continue + if found_current: + return node + return None + + def get_next_scheduler_node(self, node): + return self._get_next_scheduler_node(self.node_schedule, node) + + def get_prev_scheduler_node(self, node): + return self._get_next_scheduler_node(reversed(self.node_schedule), node) + + def check_all_index_is_1d_for_dual_reduction(self): + if self.numof_reduction_axis() <= 1: + return False + + all_index_is_1d = True + for _, index in self.current_node._body.indexing.items(): + count = 0 + for symbol in index.free_symbols: + if symbol in self.axis2_list: + count = count + 1 + if count > 1: + all_index_is_1d = False + + if not all_index_is_1d: + break + return 
all_index_is_1d + + # to generate the shape of the accumulator of RBLOCK loop + def dense_size_list(self, is_permute) -> List[str]: + + sizes = [] + if self.numof_reduction_axis() > 1: + sizes = [] if self.check_all_index_is_1d_for_dual_reduction() else [f"RBLOCK_{axis}" for axis in self.axis2_list] + return sizes + if self.persistent_reduction and self.axis2 is None: + sizes = ["RBLOCK"] + return sizes + # current computedbuffer is reduction + cb_is_reduction = self.inside_reduction if not self.current_node else isinstance(self.current_node.node.data, ir.Reduction) + + for tree in self.sorted_axis: + if tree.is_tiling_axis1: + sizes.append("XBLOCK_SUB") + elif tree.is_tiling_axis2: + sizes.append("RBLOCK") + + if cb_is_reduction and self.inside_reduction and self.is_higher_order_reduction() or is_permute: + sizes = sizes[::-1] + + return sizes + + def dense_size_str(self, is_permute=False): + sizes = self.dense_size_list(is_permute) + if self.numof_reduction_axis() > 1: + return f"[{'* '.join(sizes)}]" + return f"[{', '.join(sizes)}]" + + def filter_masks(self, mask_vars): + for node in self.sorted_axis: + if not(node.is_tiling_axis1 or node.is_tiling_axis2): + mask_vars.discard(f"{node.name}_mask") + if len(self.axis2_list) > 1 and not node.is_tiling_axis2: + mask_vars.discard(f"{node.name}_mask") + + # and add to shape to value + def reduction_resize(self, value): + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + is_higher_order_reduction = self.is_higher_order_reduction() + + expand_str = "1," if is_higher_order_reduction else ",1" + if is_higher_order_reduction: + return f"{value}.reshape({expand_str}XBLOCK_SUB)" + else: + return f"{value}.reshape(XBLOCK_SUB{expand_str})" + + def get_axis_direction(self, is_axis1, is_reversed=False): + + if self.check_all_index_is_1d_for_dual_reduction(): + result = AxisDirection.Flat + elif not self.inside_reduction: + if self.numof_tiling_axis() > 1: + result = AxisDirection.Vertical if is_axis1 else AxisDirection.Horizontal + else: + result = AxisDirection.Flat + else: + if is_axis1: + result = AxisDirection.Horizontal if V.kernel.is_higher_order_reduction() else AxisDirection.Vertical + else: + result = AxisDirection.Vertical if V.kernel.is_higher_order_reduction() else AxisDirection.Horizontal + + result = reverse_direction(result) if is_reversed else result + return result + + def is_higher_order_reduction(self, check_prev_node=False): + if self.numof_reduction_axis() > 1: + return False + if not (self.inside_reduction): + raise RuntimeError("assert self.inside_reduction") + + if self.inside_high_order_reduction: + return self.inside_high_order_reduction + + node = self.current_node if self.current_node is not None else self.get_prev_scheduler_node(None) + if node is None or not isinstance(node, SchedulerNode): + return False + + reduction = node.node.data + while check_prev_node and reduction is not None and not isinstance(reduction, ir.Reduction): + node = self.get_prev_scheduler_node(node) + if node is None: + reduction = None + else: + reduction = node.node.data + + + if reduction is None or not isinstance(reduction, ir.Reduction): + return False + if not hasattr(reduction, "reduced_idx"): + return False + + reduced_order = reduction.reduced_idx[0] + is_last_axis = all(_ < reduced_order for _ in reduction.kept_idx) + self.inside_high_order_reduction = not is_last_axis + return self.inside_high_order_reduction + + def get_axis_dtype(self, axis): + dtype = None + if axis is None: + 
return None + for node in self.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + if axis.symbol() in node._body.indexing_map: + dtype = V.graph.get_dtype(node.node.name) + break + if dtype is None: + should_break_all = False + for node in self.node_schedule: + if should_break_all: + break + if node in (EnableReduction, DisableReduction): + continue + for key, _ in node._body.indexing_map.items(): + if key in self.range_tree_nodes: + dim = self.range_tree_nodes[key] + else: + dim = self.range_tree_nodes_removed[key] + + if dim.parent == axis.parent: + dtype = V.graph.get_dtype(node.node.name) + should_break_all = True + break + return dtype + + def create_inductor_meta(self): + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + mutated_args = sorted(mutated_args) + axis1_order = self.range_tree_nodes[self.axis1].sorted_order if self.axis1 is not None else None + axis2_order = self.range_tree_nodes[self.axis2].sorted_order if self.axis2 is not None else None + split_axis_dtype = self.get_axis_dtype(self.split_axis) + inductor_meta = { + "autotune_hints": set(self.autotune_hints), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + "no_x_dim": self.no_x_dim, + # Due to breaking change of triton 3.0, the original invocation is broken + "backend_hash": self.patch_triton_hash(), # torch.utils._triton.triton_hash_with_backend(), + "split_axis_order": self.split_axis.sorted_order if self.split_axis is not None else None, + "axis1_order": axis1_order, + "axis2_order": axis2_order, + "low_dims": self.low_dims, + "numof_reduction_axis": self.numof_reduction_axis(), + "split_axis_dtype": split_axis_dtype + } + return inductor_meta + + def reduction_dim(self): + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction") + if self.numof_reduction_axis() > 1: + return 0 + return 0 if self.is_higher_order_reduction() or len(self.sorted_axis) == 1 else 1 + + def reduction_var(self): + var = self.axis2 + return var + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction") + masks = {f"{node.symbol()}_mask" for node in self.sorted_axis} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix + + dense_size_str = self.dense_size_str(False) + + if len(dense_size_str) > 2: + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, f"tl.reshape({v}, {dense_size_str})", dtype=v.dtype, + ), + value, + + ) + + dim: int + root_op: str + + def final_reduction(value): + module = "tl" + # use tl + # use tl.max + if reduction_type in {"max", "min"}: + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})") + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})") + + def final_argreduce(buffer, result_var, value, index): + 
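+            # fused max/min-with-index reduction: keep the winning index and reshape it back to the kernel's tile shape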
buffer.splice( + f"""\ + _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f'{result_var}_tmp')} + """ + ) + + def get_reduction_axis(): + return list(self.range_tree_nodes.values())[-1] + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + dim = self.reduction_dim() + acc_type = triton_acc_type(src_dtype) + torch_acc_type = upcast_acc_dtype(src_dtype) + result_var: Any = self.cse.newvar(dtype=torch_acc_type) + result_var.mask_vars = {var for var in masks if var[0] != "r"} + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + + def _mask_value(value, default): + return self.cse.generate(self.compute, where_cond(value, default), dtype=value.dtype) + + if self.numof_reduction_axis() == 1: + if isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + else: + masked_value = value + + if reduction_type in {"argmax", "argmin", "max", "min"}: + reduce_axis = get_reduction_axis() + broadcast_string: str + if self.is_1d_reduction(): + broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape({reduction_range_prefix.upper()}BLOCK), {masked_value}.shape)" + elif self.is_higher_order_reduction(): + broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape({reduction_range_prefix.upper()}BLOCK,1), {masked_value}.shape)" + else: + broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape(1,{reduction_range_prefix.upper()}BLOCK), {masked_value}.shape)" + accumulator_index = str( + self.cse.generate( + self.compute, + broadcast_string, + dtype=torch.int64 + ) + ) + if reduction_type == "argmax" or reduction_type == "argmin": + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + elif reduction_type == "max" or reduction_type == "min": + result_var = self.cse.generate( + self.compute, final_reduction(masked_value), dtype=masked_value.dtype, + ) + elif reduction_type == "welford_reduce": + raise RuntimeError("assert False, welford_reduction and is not supported now..") + elif reduction_type == "welford_combine": + raise RuntimeError("assert False, welford_combine and is not supported now..") + else: + result_var = self.cse.generate( + self.compute, final_reduction(masked_value), dtype=masked_value.dtype, + ) + else: + accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + if not isinstance(default, tuple): + self.prefix.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = f"_{result_var}_index" + long_max = torch.iinfo(torch.int64).max + self.prefix.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + 
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} + """ + ) + final_argreduce(self.post_loop_store, result_var, accumulator, accumulator_index) + elif is_welford_reduction(reduction_type): + raise RuntimeError("assert False, welford_reduction and is not supported now..") + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + accumulator = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + self.post_loop_store.writeline( + f"{result_var} = {final_reduction(accumulator)}.to({result_type})" + ) + else: + self.post_loop_store.writeline( + f"{result_var} = {final_reduction(accumulator)}" + ) + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + self.outside_loop_vars |= set(result_var) + else: + self.outside_loop_vars.add(result_var) + + return result_var + + #XBLICK:split size, XBLOCK_SUB : tile1 size, RBLOCK:tile2 size + def add_autotune_args(self, argdefs): + # no tiling in this case + if self.persistent_reduction and self.axis2 is None: + return + argdefs.append(f"XBLOCK: tl.constexpr") + if self.numof_reduction_axis() <= 1: + argdefs.append(f"XBLOCK_SUB: tl.constexpr") + if self.axis2 is not None and not self.persistent_reduction: + argdefs.append(f"RBLOCK: tl.constexpr") + + def _get_heuristic(self): + if self.persistent_reduction: + if not (self.inside_reduction): + raise RuntimeError(" assert self.inside_reduction") + + return "persistent_reduction_npu_index" + elif self.inside_reduction: + return "reduction_npu_index" + return "pointwise_npu_index" + + def need_broadcast(self, index: sympy.Expr): + tiling_axis = [False, False] + for axis in index.free_symbols: + if axis not in self.range_tree_nodes: + continue + if self.range_tree_nodes[axis].is_tiling_axis1: + tiling_axis[0] = True + elif self.range_tree_nodes[axis].is_tiling_axis2: + tiling_axis[1] = True + #implict broadcast + result = (self.numof_tiling_axis() > 1 and not self.persistent_reduction) and (tiling_axis[1] ^ tiling_axis[0]) + result = result and self.find_axis2_in_indexing() + return result, tiling_axis + + def current_node_has_permute(self): + if not self.current_node: + return False + for index in self.current_node._body.indexing.values(): + if self.need_permuted(index): + return True + return False + + def need_permuted(self, index: sympy.Expr): + if self.numof_tiling_axis() <= 1: + return False + + need_permute = False + tmp_list = [] + coefficients_dict = index.as_coefficients_dict() + need_permute_axis1 = False + need_permute_axis2 = False + for key, value in coefficients_dict.items(): + if not key.free_symbols: + continue + key = list(key.free_symbols)[0] + if key not in self.range_tree_nodes: + continue + axis = self.range_tree_nodes[key] + # normally, axis2 is lowest dimension, except for higher_order_reduction + if (self.inside_reduction and self.is_higher_order_reduction(True)): + if axis.is_tiling_axis1 and value > sympy.Integer(1): + need_permute_axis1 = True + elif axis.is_tiling_axis2 and value > sympy.Integer(1): + need_permute_axis2 = True if self.numof_reduction_axis() <= 1 else isinstance(axis.expr, ModularIndexing) + tmp_list.append(True if value > sympy.Integer(1) else False) + 
+ # If all axes have coefficients greater than 1, + # then the stride is not 1, and in this case, return false, + # indicating that the transpose is not required. + if all(tmp_list): + return False + return need_permute_axis1 or need_permute_axis2 + + def get_reshape_dense_str(self, tiling_axis): + # there must be one tiling asis missing + if not (tiling_axis[1] or tiling_axis[0]): + raise RuntimeError("assert tiling_axis[1] or tiling_axis[0]") + + sizes = ["XBLOCK_SUB", "1"] + if not tiling_axis[0]: + sizes = ["1", "RBLOCK"] + + if self.inside_reduction and self.is_higher_order_reduction(): + sizes = reversed(sizes) + return f"[{', '.join(sizes)}]" + + def get_reshape_str(self, tiling_axis, check_prev_node=True): + # there must be one tiling asis missing + if not (tiling_axis[1] or tiling_axis[0]): + raise RuntimeError("assert tiling_axis[1] or tiling_axis[0]") + + sizes = ["XBLOCK_SUB", "RBLOCK"] + if not tiling_axis[0]: + sizes[0] = "1" + elif not tiling_axis[1]: + sizes[1] = "1" + if self.inside_reduction and self.is_higher_order_reduction(check_prev_node): + sizes = reversed(sizes) + + return f"[{', '.join(sizes)}]" + + def get_broadcast_dense_str(self, tiling_axis, check_prev_node=True): + # there must be one tiling asis missing + if not (tiling_axis[1] or tiling_axis[0]): + raise RuntimeError("assert tiling_axis[1] or tiling_axis[0]") + + sizes = ["XBLOCK_SUB", "RBLOCK"] + if self.inside_reduction and self.is_higher_order_reduction(check_prev_node): + sizes = reversed(sizes) + return f"[{', '.join(sizes)}]" + + #broadcast, permute handling + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + original_index = index + is_permuted = self.need_permuted(index) + store_cache = self.cse.store_cache + if name in store_cache: + broadcasted, tiling_axis = self.need_broadcast(original_index) + result_var = store_cache[name] + if broadcasted: + line = f"{result_var}.broadcast_to({self.get_broadcast_dense_str(tiling_axis, True)})" + buffer = self.compute if self.persistent_reduction else self.loads + result_var = self.cse.generate(buffer, line, dtype=result_var.dtype) + elif is_permuted: + line = f"{result_var}.permute(1,0)" + buffer = self.compute if self.persistent_reduction else self.loads + result_var = self.cse.generate(self.loads, line, dtype=result_var.dtype) + return result_var + + need_broadcast, tiling_axis = self.need_broadcast(index) + indirect_indexing = self.is_indirect_indexing(index) + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + ep = "" + if ( + (has_tmpmask or has_rindex) + and V.graph.get_dtype(name) != torch.bool + and indexing.has_mask() + ): + other = ", other=0.0" + else: + other = "" + + advance_block_ptr = None + append_broadcast = None + dtype = V.graph.get_dtype(name) + + if V.graph.is_unspec_arg(name): + line = var + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing, other + ) + line = f"tl.load({block_ptr}{other}{ep})" + # add needed size=1 dimensions + line = triton_reshape( + line, indexing.block_shape, indexing.reshape_suffix + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + num_size = len(self.dense_size_list(is_permuted)) + append_broadcast = "[1, 1]" if (num_size > 1) else "[1]" + else: + index_str = indexing.index_str + 
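+                # generic path: plain tl.load with the computed index and mask; permuted accesses switch to the *_prime index vars below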
mask_str = indexing.mask_str + if is_permuted: + index_str, mask_str = self.apply_var_prime(index, index_str, mask_str) + line = f"tl.load({var} + ({index_str}), {mask_str}{ep}{other})" + + dtype = V.graph.get_dtype(name) + if dtype in (torch.bfloat16, ): + line += ".to(tl.float32)" + if dtype == torch.bool and torch.version.hip is None: + line += ".to(tl.int1)" + if has_tmpmask: + # Masked loads must come after the mask is computed + load_buffer = self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indirect_indexing + and not has_rindex + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + load_buffer = self.prefix + + else: + load_buffer = self.loads + + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + if not (isinstance(result_var, TritonCSEVariable)): + raise RuntimeError("assert isinstance(result_var, TritonCSEVariable)") + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast and append_broadcast != '[]': + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + elif need_broadcast and not indirect_indexing: + line = f"{result_var}.broadcast_to({self.get_broadcast_dense_str(tiling_axis)})" + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + elif is_permuted: + line = f"{result_var}.permute(1,0)" + result_var = self.cse.generate(self.loads, line, dtype=dtype) + + if advance_block_ptr: + load_buffer.writeline(advance_block_ptr) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + # don't call symlify_indexing + def prepare_indexing( + self, + index: sympy.Expr, + ): + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + simp_index = index + + # Now that we are done simplifying we can unwrap Identity so that downstream handling + # for its contained expression will work. previously, tl.full wrapping of sympy.Integer + # would not occur + simp_index = ( + simp_index if not isinstance(simp_index, Identity) else simp_index.args[0] + ) + + return self.codegen_indexing(simp_index) + + #1. only remove the line which asserts index var should be in "xyr" + #2. don't do simplify_indexing, which combine continuous dims + #3. 
removed block_ptr, removed dense mask/broadcast support + # fixme, dense_mask_vars should be generated from sorted_axis + # upgraded to torch251 + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ) -> Union[IndexingOptions, BlockPtrOptions]: + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.prepare_indexing(index) + index_vars = index.free_symbols + has_rindex = False + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + s.name.startswith("s") or s.name.startswith("ps") for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + index_vars = index.free_symbols + has_rindex = False + + mask_vars: Set[str] = set() + for var in index_vars: + if not (isinstance(var, sympy.Symbol)): + raise RuntimeError("assert isinstance(var, sympy.Symbol)") + + has_rindex = has_rindex or var.name.startswith("r") + if override_mask: + pass + elif var.name.startswith("tmp"): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif var.name.startswith(("s", "ps", "i")): + pass + else: + # var is one of xN, yN or rN + mask_vars.add(f"{var.name}_mask") + + expand_str = None + index_str = self.index_to_str(index) + is_permute = self.need_permuted(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str(is_permute) + if (index != 0): + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + else: + index_str = f"tl.arange(0,1)" + return IndexingOptions(index_str, set(), "None", expand_str, has_rindex, index) + + if override_mask: + mask_vars = {override_mask} + if self._load_mask: + mask_vars.add(self._load_mask) + self.filter_masks(mask_vars) + mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index) # type: ignore[arg-type] + + + + def codegen_indexing(self, expr: sympy.Expr): + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, replacements # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]): + groups = [rt.numel for rt in self.range_trees] + if not self.inside_reduction: + groups[-1] = sympy.S.One + + return self.map_kernel_groups_to_node_sizes(groups, lengths, 
+            self.set_ranges)
+
+    # support splitting multiple ranges (instead of just two) from one flattened range;
+    # triple ranges are needed in the mamba model
+    @staticmethod
+    def _split_iteration_ranges(
+        groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
+    ):
+        sv = V.graph.sizevars
+        new_ranges: List[List[sympy.Expr]] = [[] for _ in groups]
+        remaining = [sv.simplify(g) for g in groups]
+        for i, group in enumerate(remaining):
+            if isinstance(group, (list, tuple)):
+                remaining[i] = NumelList(group).numels()
+
+        var_count = itertools.count()
+
+        def add_range(i, expr):
+            expr = sv.simplify(expr)
+            if not sv.statically_known_multiple_of(remaining[i], expr):
+                raise CantSplit()
+            # guard on the last item out
+            remaining[i] = FloorDiv(remaining[i], expr)
+            new_ranges[i].append(expr)
+            return next(var_count)
+
+        def make_combined(strides, index_list):
+            def getter(flat_vars):
+                expr = sympy.Integer(0)
+                for stride, index in zip(strides, index_list):
+                    expr = stride * flat_vars[index] + expr
+                return expr
+
+            return getter
+
+        def size_hints(group):
+            if isinstance(group, (list, tuple)):
+                return sv.size_hint(NumelList(group).numels())
+            return sv.size_hint(group)
+
+        def add_multiple_range(size, return_getters):
+            # need to break size into multiple ranges
+            index_list = []
+            stride_list = []
+            group = current_group
+            remained_size = size
+            # Two checks:
+            # 1. there are remaining sizes to be merged
+            # 2. remained_size has not yet been divided down to 1
+            while (group < len(remaining) and remaining[group] > 1) and (remained_size > 1):
+                group_size = remaining[group]
+                # size should be divisible by group_size
+                if not sv.statically_known_multiple_of(remained_size, group_size):
+                    raise CantSplit()
+                index_list.append(add_range(group, group_size))
+                remained_size = FloorDiv(remained_size, group_size)
+                stride_list.append(remained_size)
+                group = group + 1
+            if remained_size != 1:
+                raise CantSplit()
+            return_getters.append(make_combined(stride_list, index_list))
+
+        return_getters_groups = []
+        current_group = 0
+
+        for length_group in lengths:
+            return_getters = []
+            for size in length_group:
+                if sv.statically_known_equals(size, 1):  # type: ignore[arg-type]
+                    return_getters.append(lambda _: sympy.Integer(0))
+                    continue
+
+                while (
+                    current_group < len(remaining)
+                    and size_hints(remaining[current_group]) == 1
+                ):
+                    # scroll to next group with remaining elements
+                    current_group += 1
+                size_hint = sv.size_hint(size)
+                if current_group >= len(remaining):
+                    # all kernel groups are exhausted but sizes remain to be mapped
+                    raise CantSplit()
+                if size_hint > size_hints(remaining[current_group]):
+                    # add multiple ranges (two or more) to the list, as well as the getter funcs
+                    add_multiple_range(size_hint, return_getters)
+                else:
+                    return_getters.append(
+                        operator.itemgetter(add_range(current_group, size_hint))
+                    )
+            return_getters_groups.append(return_getters)
+
+        if not (all(V.graph.sizevars.size_hint(s) == 1 for s in remaining)):
+            raise RuntimeError("assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining)")
+
+        return new_ranges, return_getters_groups
+
+    # torch260 done
+    # just to override the load method of CSEProxy; however, CSEProxy is an inner class which
+    # cannot be monkey-patched, so we need to override the whole inner class
+    def __enter__(self):
+        class CSEProxy:
+            self.name = "CSEProxy"
+            vr_analysis = ValueRangeAnalysis()
+
+            @staticmethod
+            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
+                def inner(*args, **kwargs):
+                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)
+
+                    value = getattr(parent_handler, name)(*args, **kwargs)  # type:
ignore[has-type] + dtype_handler = DtypePropagationOpsHandler() + + output_idx = 0 + + def do_cse(v): + # cpp backend doesnt set current device - TODO: fix + if V.graph.current_device is not None: + device_str = V.graph.get_current_device_or_throw().type + triton_backend = ( + config.cpu_backend == "triton" + if device_str == "cpu" + else config.cuda_backend == "triton" + ) + else: + triton_backend = False + + # only triton backend tracks dtype currently + if triton_backend: + if name == "masked": + output_dtype = value.dtype + else: + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + else: + # cpp backend doesnt track dtype yet + output_dtype = None + + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + ) + + nonlocal output_idx + if ( + config.test_configs.runtime_triton_dtype_assert + and triton_backend + ): + from torch._inductor.codegen.triton import triton_type + + # we tree_map over the output, so we need to fetch corresponding dtype + if isinstance(output_dtype, (list, tuple)): + output_dtype = output_dtype[output_idx] + + V.kernel.compute.writeline( + f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" + ) + output_idx += 1 + + csevar.update_on_args(name, args, kwargs) + + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def _bound_variable(name, *args, **kwargs): + """ + If the variable comes from an FX node, we forward the bound we have already computed + Else, if the variable when codegen'ing another op, we try to compute its bounds + """ + from torch._inductor.select_algorithm import TritonTemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.node_to_bounds is not None: + if not (isinstance(self.node_to_bounds, dict)): + raise RuntimeError("assert isinstance(self.node_to_bounds, dict)") + + return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops + # We will also likely not get much from computing VRs on these nodes + if any( + s in fx_node.target + for s in ("set_indirect", "reduction", "scan") + ): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds. 
+ + # If there is no FX bound but we know how to compute one we do so + if (kwargs): + raise RuntimeError("assert not kwargs") + + def arg_to_bound(x): + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) + return ValueRanges.unknown() + + @staticmethod + def indirect_indexing( + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg=True, + ): + if isinstance(size, int): + size = sympy.Integer(size) + if not (isinstance(size, sympy.Expr)): + raise RuntimeError("assert isinstance(size, sympy.Expr), size") + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.cse.generate(self.compute, stm, bounds=new_bounds) + + sympy_var = parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + @staticmethod + def check_bounds( + expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + return self.check_bounds(expr, size, lower, upper) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return self.load(name, index) + out = self.load(name, index) + # count load that is not in the store_cache, and also not in the + # cse cache. 
+ if out.use_count == 1: + self.num_load += 1 + return out + + @staticmethod + def _update_store_cache(name: str, value: CSEVariable): + self.cse.store_cache[name] = value + if self.current_node and name in V.graph.name_to_buffer: + buf = self.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.cse.store_cache[other_name] = value + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + CSEProxy._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + CSEProxy._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + raise RuntimeError("store_reduction") + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + self.num_reduction += 1 + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], + Tuple[CSEVariable, ...], + ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + return self.scan(dtypes, combine_fn, values) + + @staticmethod + def sort( + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + return self.sort(dtypes, values, stable, descending) + + @staticmethod + def bucketize( + values: CSEVariable, + boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[Tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Inputs: + ------- + values: the values to be bucketized. + boundaries: a tuple containing + (a) the name of the boundaries tensor (which must be sorted, unless + the sorting tensor is present), + (b) the length of the tensor in the last dimension (i.e. the length of + one set of boundaries), + (c) the number of elements in the underlying storage (i.e. the length + of the flattened tensor, ignoring striding), and + (d) the stride of the tensor in the last dimension. + boundary_indices: indices into a flattened version of the boundaries + tensor, of the same size and shape as "values". Each index points to + the first element in the set of boundaries to be used for the + corresponding value. + indexing_dtype: the dtype to use when indexing into the boundaries + tensor. This must be int64 or int32. This additionally specifies the + dtype of the return value. + right: see "Details" below. + sorter: an optional tuple containing + (a) the name of an optional sorting tensor, used to access unsorted + boundaries without reordering the boundaries tensor, and + (b) the stride of the tensor in the last dimension. + The values in the sorting tensor are used as indices into the *last* + dimension of the boundaries tensor, with all other indices matching. + The size of the sorting and boundaries tensors must be equivalent. 
+ sorter_indices: must be present if the sorting array is present; see + "boundary_indices" for the equivalent definition for the boundaries + tensor. + + Output: + ------- + The buckets each value belongs in, within a given set of boundaries. 0 + indicates a position before the first boundary, and len(boundaries_set) + represents a position after the last boundary. + + Details: + -------- + Given a value and a set of boundaries, calculate the bucket that each + value belongs to. This works differently in 1-D and N-D cases. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True + return = [[ 0, 1, 1, 1], [1, 3, 3, 4]]. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True + return = [[ 0, 1, 1, 1], [0, 1, 1, 2]] + + Note that in the N-D boundaries case, the shape of "values" and + "boundaries" must match in every dimension _except_ the last. + + When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]]. + When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]). + + Boundaries must be non-decreasing, or a sorter must be provided which + would re-index offsets in a non-decreasing order (e.g. the second output + of torch.sort(offsets)). Otherwise, the result is undefined. + """ + return self.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + + # Use mypy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + if not (self.overrides): + raise RuntimeError("assert self.overrides") + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self diff --git a/torch_npu/_inductor/codegen/triton_utils.py b/torch_npu/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5acd971ba027980ccede319c0a8870b1419ae37a --- /dev/null +++ b/torch_npu/_inductor/codegen/triton_utils.py @@ -0,0 +1,29 @@ + +import torch + +# wrapper npu 32 bytes align, get and pass unalign info to triton meta +# then autotune choose tiling param and send them to bishengIR +byte_per_numel = { + torch.float32: 4, # torch.float32 or torch.float + torch.float64: 8, # torch.float64 or torch.double + torch.float16: 2, # torch.float16 or torch.half + torch.bfloat16: 2, # torch.bfloat16 + torch.int32: 4, # torch.int32 or torch.int + torch.int64: 8, # torch.int64 or torch.long + torch.int16: 2, # torch.int16 or torch.short + torch.int8: 1, # torch.int8 + torch.uint8: 1, # torch.uint8 + torch.bool: 1, # torch.bool + torch.complex32: 4, # torch.complex32 (not yet available in PyTorch as of the latest stable release) + torch.complex64: 8, # torch.complex64 + torch.complex128: 16 # torch.complex128 +} + + +def get_aligned_numel(dtype): + if dtype in byte_per_numel: + return 32 // byte_per_numel[dtype] + else: + return 1 + + diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8daeafbf6349f90907a62b7fdf3feb006087bd24 --- /dev/null +++ b/torch_npu/_inductor/codegen/wrapper.py @@ -0,0 +1,73 @@ +from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, SubgraphPythonWrapperCodegen +from torch._inductor.virtualized import V +from torch._inductor.utils import ( + cache_on_self, +) +from 
torch._inductor.runtime import triton_heuristics +from torch._inductor import config + + +class NPUWrapperCodeGen(PythonWrapperCodegen): + def __init__(self): + super().__init__() + + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + if is_subgraph: + return SubgraphPythonWrapperCodegen(subgraph_name, parent_wrapper) + return NPUWrapperCodeGen() + + @cache_on_self + def write_triton_header_once(self) -> None: + import_str = f""" + import triton + import triton.language as tl + from {triton_heuristics.__name__} import ( + split_scan_grid, + grid_combo_kernels, + start_graph, + end_graph, + cooperative_reduction_grid, + ) + from torch_npu._inductor.npu_triton_heuristics import grid + """ + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.splice(import_str) + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + #generate numel expr for range_tree_node + def generate_node_numel_expr(self, kernel_name: str, node, numel_expr): + expr = f"{kernel_name}_{node.name}_numel" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline( + f"{self.declare}{expr} = {self.expr_printer(numel_expr)}{self.ending}" + ) + else: + self.writeline(f"{expr} = {self.expr_printer(numel_expr)}{self.ending}") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. 
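+        # Illustrative example (editorial, hypothetical names): for a kernel such as
+        # "triton_poi_fused_add_0" and a range-tree node named "x0" whose numel is 64*s0,
+        # the writeline above emits roughly
+        #     triton_poi_fused_add_0_x0_numel = 64*s0
+        # and the SymbolicCallArg returned below carries both that generated name and the
+        # original sympy expression for later use when call arguments are generated.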
+ return SymbolicCallArg(expr, numel_expr) + + # don't free anything + def make_buffer_free(self, buffer): + return "" + + # don't assert + def codegen_input_size_asserts(self) -> None: + pass diff --git a/torch_npu/_inductor/config.py b/torch_npu/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0e168c78ee3f317af8db820e7361eb8af655c8 --- /dev/null +++ b/torch_npu/_inductor/config.py @@ -0,0 +1,46 @@ +import os # noqa: C101 +import sys +import logging +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from triton.runtime.driver import driver +from torch._inductor import config +enable_npu_indexing = True + +config.triton.unique_kernel_names = True +# avoid test_opensora_cases_model_16_forward reinterpre_tensor issue +config.allow_buffer_reuse = False +#inductor debug switch +config.trace.enabled = True + +# npu hardware params from trion +target = driver.active.get_current_target() +device = driver.active.get_current_device() +prop = driver.active.utils.get_device_properties(device) + +num_cube_core = prop["num_aicore"] +num_vector_core = prop["num_aicore"] + +# unit byte +npu_block = 32 + +if ("Ascend910B" in target.arch): + num_vector_core = num_cube_core * 2 + +log_level_env = os.getenv('INDUCTOR_ASCEND_LOG_LEVEL', 'INFO').upper() +log_level_mapping = { + 'DEBUG': logging.DEBUG, + 'INFO': logging.INFO, + 'WARNING': logging.WARNING, + 'ERROR': logging.ERROR, + 'CRITICAL': logging.CRITICAL +} +log_level = log_level_mapping.get(log_level_env.upper(), logging.INFO) +logging.basicConfig( + level=log_level, + format='%(asctime)s - %(levelname)s - %(message)s' +) +log = logging.getLogger(__name__) + +aggresive_autotune = os.getenv("INDUCTOR_ASCEND_AGGRESSIVE_AUTOTUNE", '0').lower() in ('1', 'true') + +profile_path = "./profile_result/" \ No newline at end of file diff --git a/torch_npu/_inductor/decomposition.py b/torch_npu/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..af5ecbf311ecf2471bdc1a89bb8f991f6f9f4c34 --- /dev/null +++ b/torch_npu/_inductor/decomposition.py @@ -0,0 +1,48 @@ +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +from torch._inductor.decomposition import register_decomposition +import torch._ops +from .lowering import _init_set + + +aten = torch.ops.aten + +DECOMPOSITION_OVERLOAD_OP = [ + aten._log_softmax, + aten.nll_loss_forward, + # aten.gelu_backward, + # aten.gelu, + aten.nll_loss_backward, + aten._log_softmax_backward_data, + aten.embedding_dense_backward +] + + +def _register_npu_inductor_decompositons(): + + overload_op_set = set() + _init_set(DECOMPOSITION_OVERLOAD_OP, overload_op_set) + + for op in overload_op_set: + if (op in decompositions): + del decompositions[op] + + @register_decomposition([aten.scatter.src]) + @pw_cast_for_opmath + def scatter_src(self, input_tensor, dim, index_tensor, source_tensor): + (XNUMEL, YS) = input_tensor.shape + index_rblock = torch.arange(YS).npu().reshape((1, YS)).repeat((XNUMEL, 1)) + + index_tensor_brd = index_tensor.to(torch.int32).broadcast_to(XNUMEL, YS) + source_tensor_brd = source_tensor.broadcast_to(XNUMEL, YS).to(torch.float32) + scatter1 = torch.where(index_rblock == index_tensor_brd, 1.0, 0.0) * source_tensor_brd + return scatter1 + + @register_decomposition([aten.expm1]) + def expm1(x): + tensor = torch.exp(x) - torch.ones_like(x) + return tensor + + @register_decomposition([aten.erfc]) + def erfc(x): + tensor = torch.ones_like(x) - torch.exp(x) + return tensor \ No newline at end of 
file diff --git a/torch_npu/_inductor/dynamo_embedding_backward_dispatch.py b/torch_npu/_inductor/dynamo_embedding_backward_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..6584c99e7a864619f73b564a060cafc52f3654fa --- /dev/null +++ b/torch_npu/_inductor/dynamo_embedding_backward_dispatch.py @@ -0,0 +1,10 @@ +import torch +from torch.library import Library, impl +python_dispatcher_lib = Library("aten", "IMPL", "PythonDispatcher") + + +@impl(python_dispatcher_lib, "embedding_backward") +def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse): + if sparse: + raise RuntimeError("the current NPU does not yet support sparse tensor, when sparse is set to True") + return torch.ops.aten.embedding_dense_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq) \ No newline at end of file diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..a7580f8cb492f8b4955423fbd24468f04ccc8003 --- /dev/null +++ b/torch_npu/_inductor/lowering.py @@ -0,0 +1,313 @@ +import sympy +from torch._inductor.ir import Reduction +from torch._inductor.utils import sympy_product +from torch._inductor import ir +from torch._inductor.ir import ExpandView, TensorBox, ops_wrapper +from torch._inductor.lowering import sum_ +from torch._inductor import lowering +from torch._prims_common import ( + is_boolean_dtype, + is_integer_dtype, + get_computation_dtype, +) +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +import torch._ops + +from torch._inductor.lowering import ( + lowerings, + make_fallback, + register_lowering, + to_dtype, + # make_reduction, + # reduce_amax, + # reduce_amin, + fallback_cumsum, + _validate_reduction_axis, + div, + squeeze, + square, + sub, + fallback_handler, + is_boolean_type, + logical_and, + make_pointwise, + _make_reduction_inner, + _validate_reduction_axis, +) +import torch_npu +from torch_npu import npu_dtype_cast + + +def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + if isinstance( + result.data.data, Reduction + ): + #Only realize if reduction isn't unrolled + size = x.get_size() + axis = set(_validate_reduction_axis(x, axis)) + kept_idx = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + else: + kept_idx.append(i) + + object.__setattr__(result.data.data, "kept_idx", kept_idx) + object.__setattr__(result.data.data, "reduced_idx", reduced_idx) + + result.realize() + return result + + return inner + + +lowering.make_reduction = make_reduction + + +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims + + +def _init_set(input_list, output_set): + for fn in input_list: + output_set.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + output_set.add(other_fn) + + +GENERATE_LIST = [ + aten.mul, + aten.add, + aten.sub, + aten.div, + aten.exp, + aten.maximum, + aten.sum, + aten.select, + aten.unsqueeze, + aten.repeat, + #aten.clone, + aten.reshape, + aten.where, + aten.lt, + aten.minimum, + aten.gt, + aten.le, + aten.ceil, + aten.floor, + aten.rsqrt, + aten.abs, + aten.log, 
+ aten.bitwise_xor, + aten.amax, + # backward + prims.convert_element_type, + aten.min, + aten.max, + aten.erf, + aten.argmax, + aten.argmin, + aten.clamp_min, + aten.slice, + aten.neg, + aten.cat, + aten.arange, + aten.expand, + aten.eq, + aten.where, + aten.scalar_tensor, + aten.ge, + aten.permute, + aten.sqrt, + aten.relu, + aten.clamp, + aten.clamp_max, + aten.mean, + # npu.npu_dtype_cast + npu_dtype_cast, + aten.select_scatter, + aten.slice_scatter, + prims.broadcast_in_dim, + prims.maximum, + aten.ne, + aten.sigmoid, + aten.sign, + aten.logical_and, + aten.logical_or, + aten.logical_not, + aten.pow, + aten.gelu, + aten.tanh, + aten.isnan, + aten.bitwise_and, + aten.squeeze, + aten.copy, + aten.reciprocal +] + +GENERATE_LIST2 = [ + "foreach" +] + +FALLBACK_LIST = [] + +# 先删除从lowering已经注册的op,再更新,不然会lowering的时候找到在torch注册的op +LOWERING_OVERLOAD_OP = [ + aten.cumsum, + aten.mean, + # aten.max, + # aten.min, + # aten.mul, + aten.var_mean, + aten.var, + + aten.embedding, + aten.split, + aten.split_with_sizes, + aten.nll_loss_forward, + aten.gather, + aten.cat, + aten.clone +] + + +def _register_npu_inductor_fallbacks(): + gen_set = set() + _init_set(GENERATE_LIST, gen_set) + overload_op_set = set() + _init_set(LOWERING_OVERLOAD_OP, overload_op_set) + + # 把不在白名单的op fallback + for op in lowerings: + if op not in decompositions and op not in gen_set: + if isinstance(op, torch._ops.OpOverloadPacket) or \ + isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + flag = False + for gens in GENERATE_LIST2: + if str(op).find(gens) != -1: + flag = True + if flag: + continue + else: + make_fallback(op) + FALLBACK_LIST.append(op) + # 把需要overload的op在lowering里删除 + for op in overload_op_set: + if op in lowerings: + del lowerings[op] + + @register_lowering(aten.mean) + def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + + @register_lowering(aten.cumsum) + def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + # torch.int64->torch.int32 + dtype = torch.int32 + if len(x.get_size()) == 0: + if axis not in [0, -1]: + raise ValueError("axis must be 0 or -1") + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + return fallback_cumsum(x, dim=axis, dtype=dtype) + + @register_lowering(npu_dtype_cast, type_promotion_kind=None) + def _convert_npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + + def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), 
device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + + def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + @register_lowering(aten.var_mean) + def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + @register_lowering([aten.var, prims.var]) + def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + @register_lowering(aten.embedding, type_promotion_kind=None) + def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + return fallback_handler(aten.embedding.default)(weight, indices, padding_idx=-1, scale_grad_by_freq=False, + sparse=False) + + @register_lowering(aten.cat) + def cat(inputs, dim=0): + return fallback_handler(aten.cat.default)(inputs, dim) + + make_fallback(aten._log_softmax) + make_fallback(aten.gather) + make_fallback(aten.nll_loss_forward) diff --git a/torch_npu/_inductor/npu_choices.py b/torch_npu/_inductor/npu_choices.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0c11b4ff0e8bf20d91f480e5db0da327096751 --- /dev/null +++ b/torch_npu/_inductor/npu_choices.py @@ -0,0 +1,38 @@ +import typing +from typing import Any, Dict, List, Type, TYPE_CHECKING + +import sympy + +from torch._inductor import config +from torch._inductor.runtime.hints import ReductionHint +from torch._inductor.virtualized import V +from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures +from torch._inductor.codegen.triton import TritonKernel + + + + + +@staticmethod +def should_use_persistent_reduction( + features: SIMDKernelFeatures, cooperative_reduction: bool +) -> bool: + """ + Heuristic to decide if a persistent reduction should be used. + """ + if not config.triton.persistent_reductions: + return False + threshold = { + ReductionHint.INNER: 1024, + ReductionHint.DEFAULT: 1024 + }.get(features.get_reduction_hint(), 64) + if cooperative_reduction: + # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements + try: + threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32) + except ValueError: + pass # unbacked symint + + if config.triton.multi_kernel: + threshold *= 16 + return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold) # type: ignore[arg-types] \ No newline at end of file diff --git a/torch_npu/_inductor/npu_fusion_attention_graph.py b/torch_npu/_inductor/npu_fusion_attention_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..2e421b6a1c2201a7a678d91eacf06ade5da9f105 --- /dev/null +++ b/torch_npu/_inductor/npu_fusion_attention_graph.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
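# ---------------------------------------------------------------------------
# Editorial sketch (not part of the patch): the var_mean lowering in lowering.py
# above rewrites variance as sum((x - mean)**2) / max(N - correction, 0) before
# casting back to the output dtype.  Below is a minimal CPU-only check of that
# identity using stock PyTorch; `var_mean_reference` is a name chosen here for
# illustration and does not exist in this repository.
import torch


def var_mean_reference(x: torch.Tensor, dim: int, correction: int = 1):
    # Mirror var_mean_sum_: mean with keepdim, squared differences, sum, divide.
    mean = x.mean(dim=dim, keepdim=True)
    diffs = (x - mean).square()
    denom = max(x.shape[dim] - correction, 0)
    var = diffs.sum(dim=dim) / denom
    return var, mean.squeeze(dim)


_x = torch.randn(4, 8)
_var, _mean = var_mean_reference(_x, dim=1)
_var_ref, _mean_ref = torch.var_mean(_x, dim=1, correction=1)
assert torch.allclose(_var, _var_ref) and torch.allclose(_mean, _mean_ref)
# ---------------------------------------------------------------------------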
+import functools +import sympy +import torch +from torch.autograd import Function +from torch.library import Library, impl +import torch.nn.functional as F +import torch_npu + + + +npu_def = Library("npu_graph", "DEF") +npu_lib = Library("npu_graph", "IMPL", "PrivateUse1") +meta_lib = Library("npu_graph", "IMPL", "Meta") + +npu_def.define("npu_fa(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)") +npu_def.define("npu_fa_backward(Tensor query, Tensor key, Tensor value, Tensor dy, int head_num, str input_layout, *, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? softmax_max=None, Tensor? softmax_sum=None, Tensor? softmax_in=None, Tensor? attention_in=None, float scale_value=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, Tensor? seed=None, Tensor? offset=None, Tensor? numels=None, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor)") + + +@impl(npu_lib, "npu_fa") +def npu_fa(*args, **kwargs): + if len(args) > 8: + args = list(args) + # for scale + try: + args[8] = 1.0 / args[8] + except IndexError: + args[8] = 1.0 / (args[8] + 1e-6) + print("args[8]: zero can not be divided") + r1, r2, r3, r4, seed, offset, numel = torch_npu.npu_fusion_attention(*args, **kwargs) + r2.requires_grad = False + r3.requires_grad = False + r4.requires_grad = False + return r1, r2, r3, r4, torch.tensor([seed], requires_grad=False), torch.tensor([offset], requires_grad=False), torch.tensor([numel], requires_grad=False) + + +@impl(npu_lib, "npu_fa_backward") +def npu_fa_backward(*args, **kwargs): + if 'scale_value' in kwargs: + kwargs['scale_value'] = 1.0 / kwargs['scale_value'] + return torch_npu.npu_fusion_attention_grad(*args, **kwargs) + + +@impl(meta_lib, "npu_fa") +def npu_fa(query, key, value, head_num, input_layout, pse=None, padding_mask=None, + atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, + inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + B = query.size(0) + N = head_num + S1 = query.size(2) + S2 = key.size(2) + + if input_layout == "BSH": + B = query.size(0) + S1 = query.size(1) + S2 = key.size(1) + + if input_layout == "SBH": + B = query.size(1) + S1 = query.size(0) + S2 = key.size(0) + + attention_score = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous() + softmax_max = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta') + softmax_sum = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta') + softmax_out = torch.empty([0], dtype=query.dtype, device='meta') + return (torch.empty_like(attention_score), + torch.empty_like(softmax_max), + torch.empty_like(softmax_sum), + torch.empty_like(softmax_out), + torch.tensor([0], device='meta', requires_grad=False), + torch.tensor([0], device='meta', requires_grad=False), + torch.tensor([0], device='meta', requires_grad=False)) + + +@impl(meta_lib, "npu_fa_backward") +def 
npu_fa_backward(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None, + softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, scale_value=1.0, + keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, seed=0, offset=0, + numels=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + + dq = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous() + dk = torch.empty_like(key, dtype=query.dtype, device='meta').contiguous() + dv = torch.empty_like(value, dtype=query.dtype, device='meta').contiguous() + dpse = torch.empty([0], dtype=query.dtype, device='meta').contiguous() + return (torch.empty_like(dq), torch.empty_like(dk), torch.empty_like(dv), torch.empty_like(dpse) if pse else None) + + +class NpuGraphAttentionFunction(Function): + @staticmethod + def forward(ctx, query, key, value, head_num, input_layout, pse=None, padding_mask=None, atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + # 前向传播逻辑 + # 这里假设有一个实现前向传播的函数 `npu_fusion_attention_forward` + result0, result1, result2, result3, result4, result5, result6 = torch.ops.npu_graph.npu_fa( + query, key, value, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, scale=scale, keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens, inner_precise=inner_precise, prefix=prefix, actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen, sparse_mode=sparse_mode, gen_mask_parallel=gen_mask_parallel, sync=sync + ) + # 保存中间结果,以便在反向传播中使用 + ctx.save_for_backward(query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0, result4, result5, result6) + ctx.head_num = head_num + ctx.input_layout = input_layout + ctx.scale = scale + ctx.keep_prob = keep_prob + ctx.pre_tockens = pre_tockens + ctx.next_tockens = next_tockens + ctx.inner_precise = inner_precise + ctx.prefix = prefix + ctx.actual_seq_qlen = actual_seq_qlen + ctx.actual_seq_kvlen = actual_seq_kvlen + ctx.sparse_mode = sparse_mode + ctx.gen_mask_parallel = gen_mask_parallel + ctx.sync = sync + + return result0, result1, result2, result3, result4, result5, result6 + + @staticmethod + def backward(ctx, grad_result0, grad_result1, grad_result2, grad_result3, grad_result4, grad_result5, grad_result6): + # 获取保存的中间结果 + query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0, result4, result5, result6 = ctx.saved_tensors + # 反向传播逻辑 + # 这里假设有一个实现反向传播的函数 `npu_fusion_attention_backward` + grad_query, grad_key, grad_value, grad_pse = torch.ops.npu_graph.npu_fa_backward( + query, key, value, grad_result0, ctx.head_num, ctx.input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, softmax_max=result1, softmax_sum=result2, softmax_in=result3, attention_in=result0, scale_value=ctx.scale, keep_prob=ctx.keep_prob, pre_tockens=ctx.pre_tockens, next_tockens=ctx.next_tockens, inner_precise=ctx.inner_precise, seed=result4, offset=result5, numels=result6, prefix=ctx.prefix, actual_seq_qlen=ctx.actual_seq_qlen, actual_seq_kvlen=ctx.actual_seq_kvlen, sparse_mode=ctx.sparse_mode, gen_mask_parallel=ctx.gen_mask_parallel, sync=ctx.sync + ) + return (grad_query, grad_key, grad_value, None, None, grad_pse, None, None, None, None, None, None, None, None, None, None, None, None, 
None, None, None, None, None, None, None, None) + + +def npu_fusion_attention_graph(query, key, value, head_num, input_layout, pse=None, padding_mask=None, + atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, + inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + return NpuGraphAttentionFunction.apply(query, key, value, head_num, input_layout, pse, padding_mask, + atten_mask, scale, keep_prob, pre_tockens, next_tockens, + inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode, gen_mask_parallel, sync) +torch_npu.npu_fusion_attention_graph = npu_fusion_attention_graph + + +def register_fx_pass(): + TOKEN_MAX = 2147483647 + from torch._inductor.pattern_matcher import register_replacement, fwd_only, joint_fwd_bwd + from torch._inductor.fx_passes.joint_graph import patterns + from torch._dynamo.utils import counters + from torch._inductor.fx_passes.fuse_attention import partialize_and_update_signature + + def _npu_fusion_attention_graph_pattern_1(query, key, value, inv_scale_factor, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.nn.functional.dropout( + torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(v) + + + def _npu_fusion_attention_graph_replacement_1(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + head_num = query.size(2) + input_layout = "BNSD" + return torch_npu.npu_fusion_attention_graph( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + head_num, + input_layout, + None, + atten_mask=None, + scale=inv_scale_factor, + keep_prob=1.0 - dropout_p, + )[0] + + def _get_sfdp_patterns(): + device = 'npu' + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + c_inp = functools.partial(torch.tensor, 2.0, device=device) + d = {"dropout_p": 0.113377} + candidates = [] + for dtype in [torch.float]: + g = functools.partial(g_inp, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + candidates.append(( + _npu_fusion_attention_graph_pattern_1, + _npu_fusion_attention_graph_replacement_1, + [g(), g(), g(), c()], + d, + )) + + for pattern, replacement, args, workaround in candidates: + # gets serialized to a python file and does not require tracing at runtime. 
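+            # Each candidate below is registered twice: a "_training" variant traced with
+            # joint_fwd_bwd and an "_inference" variant traced with fwd_only, where the
+            # dropout_p scalar workaround is folded into the signature as 0.0.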
+ if not isinstance(workaround, dict): + raise ValueError("workaround not dict") + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield training_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + if workaround: + if not (len(workaround) == 1 and "dropout_p" in workaround): + raise ValueError("not (len(workaround) == 1 and dropout_p in workaround)") + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + for _, register_replacement_kwargs in _get_sfdp_patterns(): + register_replacement( + **register_replacement_kwargs, + ) + +register_fx_pass() + + + diff --git a/torch_npu/_inductor/npu_triton_helpers.py b/torch_npu/_inductor/npu_triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..5f60f5dd499eec9257d4e37bb6cf9094059b506f --- /dev/null +++ b/torch_npu/_inductor/npu_triton_helpers.py @@ -0,0 +1,20 @@ +import triton +import triton.language as tl + +import triton.language.extra.ascend.libdevice as libdevice +from torch._inductor.runtime import triton_helpers +libdevice = tl.extra.ascend.libdevice +math = tl.math + + +@triton.jit +def maximum(a, b): + return tl.maximum(a, b) + + +@triton.jit +def minimum(a, b): + return tl.minimum(a, b) + +triton_helpers.maximum = maximum +triton_helpers.minimum = minimum diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..42575d31712a7133982767b8b11b2e1ca7f1cb23 --- /dev/null +++ b/torch_npu/_inductor/npu_triton_heuristics.py @@ -0,0 +1,961 @@ +# This file is based on triton_heuristics with heuristics designed for NPU +import os +import functools +import copy +from typing import Any, Callable, List, Optional +import logging +import re +import hashlib +import json + +import torch +from torch._inductor import config +from torch._dynamo.utils import dynamo_timed +from torch._inductor.runtime.triton_heuristics import ( + CachingAutotuner, + HeuristicType, + unique_configs, + hash_configs, + Config, + ASTSource, + _find_names, + get_first_attr, + collected_calls, +) +from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.runtime.autotune_cache import AutotuneCache + + +from torch._inductor.runtime.runtime_utils import ( + create_bandwidth_info_str, + get_num_bytes, + +) + +import triton +from triton.compiler import CompiledKernel + +try: + from triton.backends.compiler import GPUTarget + from triton.runtime.autotuner import OutOfResources + import torch.autograd.profiler as autograd_profiler +except ImportError: + GPUTarget = None + OutOfResources = None + autograd_profiler = None + +from .codegen.split_tiling import SplitTiling +from .utils import get_current_raw_stream +from .codegen.tile_generator import TileGenerator +from .codegen.triton_utils import get_aligned_numel +from .config import aggresive_autotune +from .config 
import log + + +# torch-261 +class NPUCachingAutotuner(CachingAutotuner): + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names: List[str], # see [Note: clone mutated buffers] + optimize_mem, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + filename: Optional[str] = None, + reset_to_zero_arg_names: Optional[List[str]] = None, + ): + super().__init__(fn, triton_meta, configs, save_cache_hook, mutated_arg_names, optimize_mem, heuristic_type, + size_hints, inductor_meta, custom_kernel, filename, reset_to_zero_arg_names) + + self.exceptions = [] + + def precompile(self, warm_cache_only=False): + # xpu_graph changed TORCHINDUCTOR_CACHE_DIR. + # When TORCHINDUCTOR_COMPILE_THREADS > 1, multiprocessing's fork method + # does not propagate TORCHINDUCTOR_CACHE_DIR into the child threads. + # However, after all the child threads finished, the main thread reaches + # here and inherits xpu_graph's TORCHINDUCTOR_CACHE_DIR. Then the main + # thread finds the cache dir does not have any compiled kernel. It will + # compile all kernels one by one. + # So we directly replace TORCHINDUCTOR_CACHE_DIR with the standard cache dir. + if ("xpu_graph" in os.getenv("TORCHINDUCTOR_CACHE_DIR", "")): + import getpass + import tempfile + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + cache_dir = os.path.join( + tempfile.gettempdir(), + "torchinductor_" + sanitized_username, + ) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir + os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_dir, "triton", "0") + with self.lock: + if self.launchers: + return + self.launchers = [] + compiled_binaries = [] + if not self.configs: + raise RuntimeError("No triton configs are available") + for c in self.configs: + try: + compiled_binary, launcher = self._precompile_config( + c, warm_cache_only + ) + except Exception as e: + log.debug(f"[thread {os.getpid()}][InductorNPU.precompile] Exception = {e}, kernel = {self.fn.__name__} config = {c}") + # Skip the config if the compilation fails + continue + if launcher is not None: + self.launchers.append(launcher) + compiled_binaries.append(compiled_binary) + + if len(self.launchers) == 0: + raise RuntimeError( + "No valid triton configs. 
Report a fatal compilation error" + ) + + self.configs = None + + + + def _precompile_config(self, cfg: Config, warm_cache_only: bool): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + + for k, v in cfg.kwargs.items(): + if k not in self.fn.arg_names: + continue + compile_meta["constants"][k] = v + + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + + compile_meta["debug"] = ( + os.getenv("INDUCTOR_ASCEND_DEBUG", 'false').lower() in ('true', '1') and + config.assert_indirect_indexing and torch.version.hip is None + ) + + # device type will be "hip" rather than "cuda" here + compile_meta["device_type"] = self.device_props.type + compile_meta["cc"] = self.device_props.cc + + if ASTSource: + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + ), + ) + + cc_str = str(compile_meta["cc"]) + if "gfx10" in cc_str or "gfx11" in cc_str: + rocm_warp_size = 32 + else: + rocm_warp_size = 64 + + if GPUTarget: + target = GPUTarget( + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size if torch.version.hip else 32, + ) + else: + target = ( + (compile_meta["device_type"], compile_meta["cc"]) + if not torch.version.hip + else [ + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size, + ] + ) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + if self.device_props.type == "hip": + if "waves_per_eu" in compile_meta: + options["waves_per_eu"] = compile_meta["waves_per_eu"] + if "matrix_instr_nonkdim" in compile_meta: + options["matrix_instr_nonkdim"] = compile_meta[ + "matrix_instr_nonkdim" + ] + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn,) + compile_kwargs = compile_meta + if warm_cache_only: + return ( + triton.compile(*compile_args, **compile_kwargs), + None, + ) + + # importing from torch is safe now that precompile has returned + from torch._dynamo.device_interface import DeviceGuard + + device_interface = self.get_device_interface() + + # load binary to the correct device + with DeviceGuard(device_interface, compile_meta["device"]): # type: ignore[attr-defined] + # need to initialize context + device_interface.synchronize(device_interface.current_device()) + + try: + + binary = triton.compile(*compile_args, **compile_kwargs) + binary._init_handles() + + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + return None, None + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] + def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs] + + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": CompiledKernel.launch_enter_hook, + "launch_exit_hook": CompiledKernel.launch_exit_hook, + "metadata": binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata, + "shared": binary_shared, + } + + scope["num_warps"] = ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ) + + scope["cta_args"] = ( + (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) + if hasattr(binary, "num_ctas") + else ( + 
(binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ) + + scope["function"] = get_first_attr(binary, "function", "cu_function") + + def get_launch_args_without_kernel_launch_metadata( + input_grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args before CompiledKernel.launch_metadata is added. + """ + return ( + grid_0, + grid_1, + grid_2, + num_warps, + *cta_args, + shared, + stream, + function, + launch_enter_hook, + launch_exit_hook, + metadata, + ) + + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + if binary.launch_enter_hook: + + def get_launch_args_with_kernel_launch_metadata( + input_grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin.launch_metadata(input_grid, stream, *args), + launch_enter_hook, + launch_exit_hook, + ) + + else: + + def get_launch_args_with_kernel_launch_metadata( + input_grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + None, + launch_enter_hook, + launch_exit_hook, + ) + + scope["get_launch_args"] = ( + get_launch_args_with_kernel_launch_metadata + if hasattr(binary, "launch_metadata") + else get_launch_args_without_kernel_launch_metadata + ) + + scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + args = {', '.join(call_args)}, + launch_args = get_launch_args( + grid, grid_0, grid_1, grid_2, stream, function, + metadata, bin, launch_enter_hook, launch_exit_hook, + num_warps, shared, cta_args, args + ) + runner(*launch_args, *args) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.store_cubin = True + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = self.fn + launcher.bin = binary + + return binary, launcher + + def save_gpu_kernel(self, input_grid, input_stream, input_launcher): + self.save_npu_kernel(input_grid, input_stream, input_launcher) + + def save_npu_kernel(self, input_grid, input_stream, input_launcher): + if callable(input_grid): + grid_x, grid_y, grid_z = input_grid(input_launcher.config.kwargs) + else: + grid_x, grid_y, grid_z = input_grid + + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + + if key is None: + raise RuntimeError("assert key is not None, 
kernel_name can not be None") + params = { + "mangled_name": ( + input_launcher.bin.metadata.name + if hasattr(input_launcher.bin.metadata, "name") + else input_launcher.bin.metadata["name"] + ), + "grid_x": grid_x, + "grid_y": grid_y, + "grid_z": grid_z, + "num_warps": ( + input_launcher.bin.num_warps + if hasattr(input_launcher.bin, "num_warps") + else input_launcher.bin.metadata.num_warps + ), + "shared_mem": ( + input_launcher.bin.shared + if hasattr(input_launcher.bin, "shared") + else input_launcher.bin.metadata.shared + ), + "stream": input_stream, + # User defined triton kernels will have arbitrary kwarg names + "meta": input_launcher.config.kwargs, + } + from torch._inductor.codecache import CudaKernelParamCache + + bin_type = "npubin" + binary = input_launcher.bin.asm[bin_type] # npubin type = npubin + CudaKernelParamCache.set(key, params, binary, bin_type='cubin') # CudaKernelParam + + self.cuda_kernel_saved = True + + # bench method is called by torch, grid can not be modified + def bench(self, launcher, *args, grid, with_profiler=False, **kwargs): + """Measure the performance of a given launcher""" + + if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( + "spill_threshold", 16 + ): + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + device_interface = self.get_device_interface() + stream = device_interface.get_raw_stream(device_interface.current_device()) + + def kernel_call(): + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=grid, + stream=stream, + ) + + if with_profiler: + from torch._inductor.utils import do_bench_using_profiling + + return do_bench_using_profiling(kernel_call, warmup=10, rep=40) + # remove fast_flush=True for high version triton + return benchmarker.benchmark_gpu(kernel_call, rep=40) + + + +class NPUDebugAutotuner(NPUCachingAutotuner): + def __init__(self, *args, regex_filter="", **kwargs): + self.regex_filter = regex_filter + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, input_grid, stream): + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + super().run(*args, grid=input_grid, stream=stream) + (launcher,) = self.launchers + + if self.cached is None: + ms = self.bench(launcher, *args, input_grid=input_grid) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = (ms, num_gb, gb_per_s, kernel_name) + else: + ms, num_gb, gb_per_s, kernel_name = self.cached + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + print( + create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}") + ) + + +def cached_autotune( + size_hints: Optional[List[int]], + configs: List[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. 
+ """ + configs = unique_configs(configs) + if not (len(configs) == 1 or filename): + raise RuntimeError("assert len(configs) == 1 or filename") + + inductor_meta = {} if inductor_meta is None else inductor_meta + + disabled = inductor_meta.get("force_disable_caches", False) + + # on disk caching logic and/or remote caching + autotune_cache = None + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) + and not os.environ.get("TRITON_INTERPRET", "0") == "1" + ): + configs_hash = hash_configs(configs) + + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): + configs = [best_config] + + else: + if disabled: + log.debug("autotune caching is disabled by config.force_disable_caches") + + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + optimize_mem = inductor_meta.pop("optimize_mem", True) + + if "restore_value" in triton_meta: + mutated_arg_names += triton_meta.pop("restore_value") + + reset_to_zero_arg_names: List[str] = [] + if "reset_to_zero" in triton_meta: + reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero")) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. + import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + if tconfig.kwargs["XBLOCK"] != 1: + raise ValueError("tconfig.kwargs[XBLOCK] != 1") + tconfig.kwargs.pop("XBLOCK") + + if inductor_meta.get("profile_bandwidth"): + return NPUDebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=inductor_meta["profile_bandwidth_regex"], + with_profiler=inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ], + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + with_bandwidth_info=True, + ) + return NPUCachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + ) + + return decorator + + +###################################################### +## Main entry points for triton kernel invocation ## +## adapts original heuristics for NPU arch, and ## +## redirect to NPUCaching autotuner ## +###################################################### + +def grid(*numels): + def grid_fn(meta): + split_axis_order = meta["split_axis_order"] + + if split_axis_order is not None and split_axis_order < len(numels): + numel = numels[split_axis_order] if split_axis_order is not None else 1 + xblock = meta["XBLOCK"] + NBLOCKS, _ = SplitTiling.get_nblocks_before_launch(numel, xblock) + else: + NBLOCKS = 1 + + log.debug("launch grid(%s), NBLOCKS:%d, meta:%s", numels, NBLOCKS, meta) + return ( + NBLOCKS, + 1, + 1, + ) + + return grid_fn + + +# split:sizeof split, xblock:axis1 
length, rblock:axis2 length +def triton_config_npu_index( + size_hints, + inductor_meta, + triton_meta=None, + reduction=False, + persistent_reduction=False, + +) -> List[Config]: + num_warps = 1 + num_stages = 1 + configs = [] + log.info("[InductorNPU] processing kernel %s", inductor_meta['kernel_name']) + split_axis_order = inductor_meta["split_axis_order"] + axis1_order = inductor_meta["axis1_order"] + axis2_order = inductor_meta["axis2_order"] + low_dims = inductor_meta["low_dims"] + split_axis_dtype = inductor_meta["split_axis_dtype"] + split_numel = size_hints[split_axis_order] if split_axis_order is not None else 1 + is_low_dim = True if split_axis_order is not None and split_axis_order in low_dims else False + + min_aligned_numel = get_aligned_numel(split_axis_dtype) + + grid_list = [] + if (aggresive_autotune): + grid_list = SplitTiling.get_nblocks_xblock_list(split_numel) + else: + nblocks, split = SplitTiling.decide_nblocks_xblock(split_numel, axis2_order is None, min_aligned_numel) + grid_list.append((nblocks, split)) + + for nblocks, split in grid_list: + log.debug("generating tiling : size_hints:%s split_axis_order:%s, axis1_order:%s, axis2_order:%s, " + "low_dims:%s nblocks %s, split:%s persistent_reduction:%s split_axis_dtype:%s", size_hints, + split_axis_order, axis1_order, axis2_order, low_dims, nblocks, split, + persistent_reduction, split_axis_dtype) + # xblock is a range, don't auto_tune + xnumel = split if split_axis_order == axis1_order else size_hints[axis1_order] + rblock = 1 + if axis2_order is not None: + rblock = split if split_axis_order == axis2_order else size_hints[axis2_order] + + xblock_sub = xnumel + cfg = {"NBLOCKS": nblocks, "XBLOCK": split, "XBLOCK_SUB": xblock_sub} + # forward to grid() + cfg["split_axis_order"] = split_axis_order + cfg["axis2_order"] = axis2_order if not(axis2_order is None) else -1 + cfg["is_low_dim"] = is_low_dim + cfg["min_aligned_numel"] = min_aligned_numel + is_1d_reduction = reduction and axis2_order is None + if persistent_reduction: + numof_reduction_axis = inductor_meta["numof_reduction_axis"] + if numof_reduction_axis > 1: + del cfg["XBLOCK_SUB"] + configs.append(Config(cfg, num_warps=1, num_stages=1)) + elif axis2_order is None: + del cfg["XBLOCK"] + del cfg["XBLOCK_SUB"] + cfg["NBLOCKS"] = 1 + configs.append(Config(cfg, num_warps=1, num_stages=1)) + else: + TileGenerator.descend_xblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif is_1d_reduction: + cfg["NBLOCKS"] = 1 + cfg["XBLOCK"] = split_numel + cfg["XBLOCK_SUB"] = split_numel + TileGenerator.descend_xblock(rnumel=rblock, xblock=split_numel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + # both of the two axis are low dims + elif axis1_order in low_dims and axis2_order in low_dims: + cfg["RBLOCK"] = rblock + TileGenerator.descend_xblock_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif axis2_order is None and axis1_order is not None: + TileGenerator.descend_xblock(rnumel=0, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + # need to maximize xblock_sub + elif axis1_order in low_dims: + cfg["RBLOCK"] = rblock + TileGenerator.descend_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif axis2_order in low_dims: + cfg["RBLOCK"] = rblock + TileGenerator.descend_xblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif len(low_dims) == 0: + 
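+            # No kernel axis maps onto a low (innermost, contiguous) dimension, so keep
+            # RBLOCK and generate tilings for whichever axes exist, without aggressive descent.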
cfg["RBLOCK"] = rblock + if (axis1_order is not None) and (axis2_order is not None): + TileGenerator.descend_xblock_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel, aggresive=False) + elif axis1_order is not None: + TileGenerator.descend_xblock(rnumel=0, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel, aggresive=False) + else: + TileGenerator.descend_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel, aggresive=False) + else: + cfg["RBLOCK"] = rblock + tmp = Config(cfg, num_warps=num_warps, num_stages=num_stages) + configs.append(tmp) + + for cfg in configs: + log.debug("generated tiling configs %s", cfg.kwargs) + + return configs + + +def pointwise_npu_index( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + + inductor_meta = {} if inductor_meta is None else inductor_meta + triton_config_with_settings = functools.partial( + triton_config_npu_index + ) + return cached_autotune( + size_hints, + triton_config_with_settings(size_hints, inductor_meta=inductor_meta), + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + + +def reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if triton_meta is None: + raise RuntimeError("assert triton_meta is not None") + + contiguous_config = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True) + return cached_autotune( + size_hints, + [ + *contiguous_config, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.REDUCTION, + ) + + +def persistent_reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + configs = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True, + persistent_reduction=True) + + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +@dynamo_timed +def benchmark_all_configs(self, *args, input_grid, **kwargs): + print(f"candidate launcher count = {len(self.launchers)}") + + tilling_kernel_list = [] + + def kernel_call(launcher): + def call_kernel(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} + ) + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=input_grid, + stream=stream, + ) + return call_kernel + + for launcher in self.launchers: + if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold: + log.debug( + "Skip config %s because of 
register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg] + self.gpu_device.current_device() + ) + tilling_kernel_list.append(kernel_call(launcher)) + + def do_batch_benchmark(tilling_kernel_list): + + def delete_file(base_path): + import shutil + if os.path.exists(base_path): + shutil.rmtree(base_path) + + import torch_npu + + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + + import uuid + random_uuid = uuid.uuid4().hex + md5_hash = hashlib.md5(random_uuid.encode()).hexdigest() + + from torch_npu._inductor.config import profile_path + + torch_path = profile_path + md5_hash + rep = 1 + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=rep, repeat=1, skip_first=1), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(rep + 3): + for fn in tilling_kernel_list: + fn() + prof.step() + stream.synchronize() + + import pandas as pd + for root, _, files in os.walk(torch_path): + for file in files: + if file != 'kernel_details.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['Name'].str.startswith('triton', na=False)] + ret = triton_rows['Duration(us)'].astype(float).tolist() + delete_file(torch_path) + return ret + + delete_file(torch_path) + return [] + + try: + timinglist = do_batch_benchmark(tilling_kernel_list) + if not len(timinglist) == len(self.launchers): + raise RuntimeError("not len(timinglist) == len(self.launchers)") + timings = {launcher: timing for launcher, timing in zip(self.launchers, timinglist)} + except Exception as e: + print("some cases in batch benchmark has error! 
Logging Exception as:") + print(e) + print("switched to single bench...") + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + print(f"final valid tillings count = {len(timings)}") + return timings \ No newline at end of file diff --git a/torch_npu/_inductor/runtime.py b/torch_npu/_inductor/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..ae00e9904361326ee61d36742a0b542017e0e5d2 --- /dev/null +++ b/torch_npu/_inductor/runtime.py @@ -0,0 +1,43 @@ +from typing import Optional +import functools + +from torch._inductor.runtime.hints import DeviceProperties +from .config import num_vector_core + + +class NPUDeviceProperties(DeviceProperties): + + + @classmethod + @functools.lru_cache(None) + def create(cls, device) -> DeviceProperties: + import torch + from torch._dynamo.device_interface import get_interface_for_device + + device_type = device.type + + if torch.version.hip and device_type == "cuda": + device_type = "hip" + + device_interface = get_interface_for_device(device) + props = device_interface.get_device_properties(device) + + try: + multi_processor_count = num_vector_core + except AttributeError: + if device_type == "xpu": + multi_processor_count = props.gpu_subslice_count + else: + raise + return cls( + type=device_type, + index=device.index, + multi_processor_count=multi_processor_count, + cc=device_interface.get_compute_capability(device), + major=getattr(props, "major", None), + regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None), + max_threads_per_multi_processor=getattr( + props, "max_threads_per_multi_processor", None + ), + warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None), + ) diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..697059f7a885b8b2ec3912954728b6a3a84a234c --- /dev/null +++ b/torch_npu/_inductor/utils.py @@ -0,0 +1,7 @@ +import torch +import torch_npu + + +# Not good implementation, but no other way +def get_current_raw_stream(device): + return torch.npu.current_stream(device).npu_stream \ No newline at end of file
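# ---------------------------------------------------------------------------
# Editorial sketch (not part of the patch): how the helpers above are typically
# consumed when launching a compiled kernel.  `compiled_kernel`, `args`, `numels`
# and `launch_example` are hypothetical placeholders for a launcher produced by
# NPUCachingAutotuner and its inputs; only get_current_raw_stream and grid come
# from the modules in this patch.
from torch_npu._inductor.npu_triton_heuristics import grid
from torch_npu._inductor.utils import get_current_raw_stream


def launch_example(compiled_kernel, args, numels, device_index=0):
    # The launcher expects a raw stream handle rather than a torch.npu.Stream
    # object, which is what get_current_raw_stream returns.
    stream = get_current_raw_stream(device_index)
    # grid(*numels) returns a callable; the launcher evaluates it against the chosen
    # config's kwargs (split_axis_order, XBLOCK, ...) to obtain (NBLOCKS, 1, 1).
    return compiled_kernel(*args, grid=grid(*numels), stream=stream)
# ---------------------------------------------------------------------------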