diff --git a/test/_inductor/test_check_accuracy.py b/test/_inductor/test_check_accuracy.py
index e30fe2228c008370cc91fb1099f9a95d8513d057..f2c2550fd127d55f57873a6326cd5ea98822abf5 100644
--- a/test/_inductor/test_check_accuracy.py
+++ b/test/_inductor/test_check_accuracy.py
@@ -7,6 +7,7 @@ from torch.testing._internal.common_utils import run_tests
 from testutils import TestUtils
 import torch_npu
 
+torch._inductor.config.fx_graph_cache = False
 os.environ["INDUCTOR_ASCEND_CHECK_ACCURACY"] = "1"
 
 
diff --git a/test/_inductor/test_force_fallback.py b/test/_inductor/test_force_fallback.py
index 5f3c5f5243b3567f0c681a0b01e1cd664c3e9150..40049e802697bf5d29ef9c48591343b9f75fe413 100644
--- a/test/_inductor/test_force_fallback.py
+++ b/test/_inductor/test_force_fallback.py
@@ -7,6 +7,7 @@ from torch.testing._internal.common_utils import run_tests
 from testutils import TestUtils
 import torch_npu
 
+torch._inductor.config.fx_graph_cache = False
 os.environ["INDUCTOR_ASCEND_DUMP_FX_GRAPH"] = "1"
 
 
diff --git a/test/_inductor/test_lowering_fx.py b/test/_inductor/test_lowering_fx.py
new file mode 100644
index 0000000000000000000000000000000000000000..db15c1949c5f036eaea2a369ee8a08988360ed02
--- /dev/null
+++ b/test/_inductor/test_lowering_fx.py
@@ -0,0 +1,44 @@
+import os
+import torch
+from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
+from torch._inductor.utils import run_and_get_code
+from testutils import TestUtils
+import torch_npu
+
+torch._inductor.config.fx_graph_cache = False
+os.environ["INDUCTOR_ASCEND_CHECK_ACCURACY"] = "1"
+
+
+class TestLoweringFx(TestUtils):
+    @parametrize('shape', [(32, 16, 64, 128)])
+    @parametrize('dim', [0])
+    @parametrize('dtype', ['float32'])
+    def test_sum_not_fallback(self, shape, dim, dtype):
+        input_element = self._generate_tensor(shape, dtype, floatPOSIFLAG=1)
+        golden = torch.sum(input_element, dim)
+        compiled_op_calc = torch.compile(torch.sum, backend="inductor", dynamic=False)
+        inductor_result, output_codes = run_and_get_code(compiled_op_calc, input_element, dim)
+        self.assertTrue(len(output_codes) == 1)
+        self.assertTrue(output_codes[0].count("async_compile.triton") == 1)
+        self.assertTrue(output_codes[0].count(".run(") == 1)
+        torch.testing.assert_close(golden, inductor_result, rtol=1e-4, atol=1e-4)
+
+    @parametrize('shape', [(32, 16, 64, 128)])
+    @parametrize('dtype', ['float32'])
+    def test_div_by_reciprocal_mul(self, shape, dtype):
+        input_element = self._generate_tensor(shape, dtype)
+        divisor = 128.0
+        golden = torch.div(input_element, divisor)
+        compiled_op_calc = torch.compile(torch.div, backend="inductor", dynamic=False)
+        inductor_result, output_codes = run_and_get_code(compiled_op_calc, input_element, divisor)
+        self.assertTrue(len(output_codes) == 1)
+        self.assertTrue(output_codes[0].count("async_compile.triton") == 1)
+        self.assertTrue("torch.ops.aten.div.Tensor(" not in output_codes[0])
+        self.assertTrue("torch.ops.aten.mul.Tensor(" in output_codes[0])
+        torch.testing.assert_close(golden, inductor_result, rtol=1e-4, atol=1e-4)
+
+
+instantiate_parametrized_tests(TestLoweringFx)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/_inductor/codegen/scheduling.py b/torch_npu/_inductor/codegen/scheduling.py
index 5c5f2c9d2526d3bff6c87af3255cc23141c572aa..5c40625e7fc494ca07f698e7f220b071d028b2be 100644
--- a/torch_npu/_inductor/codegen/scheduling.py
+++ b/torch_npu/_inductor/codegen/scheduling.py
@@ -8,6 +8,7 @@ from typing import Dict, Sequence, List, Iterable
 from typing import List, Union, Any
 from typing import Union, Iterable
 import sympy
+import torch
 from torch._dynamo.utils import counters
 from torch._dynamo.utils import counters
 from torch._inductor import scheduler, metrics
@@ -136,7 +137,7 @@ class NPUTritonScheduling(TritonScheduling):
         if traced_graph is None:
             log.warning(f"For nodes {nodes}, could not gen fx graph while dump-graph.")
         else:
-            traced_graph_hash = code_hash(traced_graph.print_readable(print_output=False))
+            traced_graph_hash = code_hash(traced_graph.print_readable(print_output=False) + torch.__version__)
 
         kernel_name, src_code = self.define_kernel(src_code, node_schedule, kernel, traced_graph_hash)
 
diff --git a/torch_npu/_inductor/lowering_fx.py b/torch_npu/_inductor/lowering_fx.py
index 2f0ceea660f7094eaf63efe82a5b33ab9adf20ea..1335374b8215292bd76401b7ae8c4bd40450bf08 100644
--- a/torch_npu/_inductor/lowering_fx.py
+++ b/torch_npu/_inductor/lowering_fx.py
@@ -1,6 +1,7 @@
 import functools
 import itertools
 import os
+import math
 import textwrap
 from typing import (
     Any,
@@ -14,15 +15,10 @@ from typing import (
 )
 import sympy
 import torch._ops
-import torch._ops
 from sympy.core import Expr, Integer, Symbol
 from torch._inductor import ir
-from torch._inductor import ir
 from torch._inductor import lowering
-from torch._inductor import lowering
-from torch._inductor import scheduler
 from torch._inductor import scheduler
-from torch._inductor.decomposition import decompositions
 from torch._inductor.decomposition import decompositions, pw_cast_for_opmath
 from torch._inductor.fx_passes.post_grad import view_to_reshape
 from torch._inductor.ir import (
@@ -39,18 +35,12 @@ from torch._inductor.ir import (
     validate_ir,
     View,
 )
-from torch._inductor.ir import ExpandView, TensorBox
-from torch._inductor.ir import ExpandView, TensorBox
-from torch._inductor.ir import Reduction
-from torch._inductor.ir import Reduction
-from torch._inductor.lowering import sum_, _validate_dim, get_promoted_dtype
+from torch._inductor.lowering import _validate_dim, get_promoted_dtype
 from torch._inductor.utils import ModularIndexing, FloorDiv
 from torch._inductor.utils import (
     decode_device,
     sympy_product,
 )
-from torch._inductor.utils import sympy_product
-from torch._inductor.utils import sympy_product
 from torch._inductor.virtualized import ops, V
 from torch._prims_common import (
     canonicalize_dims,
@@ -63,16 +53,6 @@ from torch._prims_common import (
     is_integer_dtype,
     Number,
 )
-from torch._prims_common import (
-    is_boolean_dtype,
-    is_integer_dtype,
-    get_computation_dtype,
-)
-from torch._prims_common import (
-    is_boolean_dtype,
-    is_integer_dtype,
-    get_computation_dtype,
-)
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._sympy.functions import (
     FloorDiv,
@@ -1967,6 +1947,14 @@ def _register_npu_inductor_fallbacks():
         if is_integral:
             return truncdiv(a, b)
 
+        if (divisor := lowering.get_constant_value(b)) is not None:
+            # Replace divide by constant with multiply by reciprocal
+            if divisor.value == 0:
+                reciprocal = math.copysign(float("inf"), divisor.value)
+            else:
+                reciprocal = 1.0 / divisor.value
+            return mul(a, reciprocal)
+
         def fn(*args):
             return ops.truediv(*args)
 
@@ -2191,6 +2179,16 @@ def _register_npu_inductor_fallbacks():
     register_inplace(aten.__ixor__, aten.__xor__)
 
     ##########################################################################
 
+    @register_lowering([aten.sum, prims.sum])
+    def sum_(x, axis=None, keepdims=False, *, dtype=None):
+        if (
+            is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
+        ) and dtype is None:
+            dtype = torch.int64
+
+        fn = make_reduction("sum", override_return_dtype=dtype)
+        return fn(x, axis, keepdims, dtype=dtype)
+
     @register_lowering(aten.mean)
     def mean(x, axis=None, keepdim=False, *, dtype=None):