diff --git a/test/_inductor/test_expanded_reduction.py b/test/_inductor/test_expanded_reduction.py
new file mode 100644
index 0000000000000000000000000000000000000000..40e0363916c2d5920e7c57b75343421eea46683a
--- /dev/null
+++ b/test/_inductor/test_expanded_reduction.py
@@ -0,0 +1,27 @@
+import torch
+from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
+import torch_npu
+from testutils import TestUtils
+
+
+class TestExpandedReduction(TestUtils):
+    def op_calc(self, first_element, second_element):
+        x = first_element * second_element
+        return x.sum((0, 1))
+
+    @parametrize('shape_x', [(2, 197, 256)])
+    @parametrize('shape_y', [(2, 1, 256)])
+    @parametrize('dtype', ['float16'])
+    def test_expanded_reduction_cases(self, shape_x, shape_y, dtype):
+        first_element = self._generate_tensor(shape_x, dtype)
+        second_element = self._generate_tensor(shape_y, dtype)
+        std_result = self.op_calc(first_element, second_element)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
+        inductor_result = compiled_op_calc(first_element, second_element)
+        self.assertEqual(std_result, inductor_result, atol=1e-5, rtol=1e-5)
+
+instantiate_parametrized_tests(TestExpandedReduction)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py
index 782cc9f7455cd6d9f4eea63075768fdeb1af0690..9b94ce3c5cce240c3234e1088849ccbcd27b9fb2 100644
--- a/torch_npu/_inductor/codegen/split_tiling.py
+++ b/torch_npu/_inductor/codegen/split_tiling.py
@@ -141,6 +141,8 @@ class SplitTiling:
 
         def select_tiling(low_dim=True, reduction=True):
             for axis in reversed(self.kernel.sorted_axis):
+                if meet_stop_condition():
+                    break
                 if low_dim and axis.sorted_order in self.kernel.low_dims and axis not in self.kernel.tiling_axis:
                     axis.is_tiling_axis = True
                     self.kernel.tiling_axis.append(axis)