From 76f76a48d8982b29ba8f43ccff503aa1290e1364 Mon Sep 17 00:00:00 2001 From: wl1259 Date: Fri, 27 Jun 2025 11:17:52 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=8F=90=E4=BA=A4=EF=BC=8Cfix=20hunyuan=20?= =?UTF-8?q?model=20compile=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/_inductor/codegen/triton.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py index e39948eb65..439f63065a 100644 --- a/torch_npu/_inductor/codegen/triton.py +++ b/torch_npu/_inductor/codegen/triton.py @@ -438,6 +438,7 @@ class NPUIndexTritonKernel(TritonKernel): self.golden_var_list = None self.reduce_analysis = None self.load_store_indexing = None + self.store_mask = {} def gen_triton_ext_imports(self): imports = IndentedBuffer() @@ -867,6 +868,18 @@ class NPUIndexTritonKernel(TritonKernel): is_last_axis = index == len(self.sorted_axis) - 1 indexing_code = getattr(range_val, "indexing_code") + if not self.first_node: + for mask in self.store_mask.keys(): + idx = self.store_mask[mask] + if idx == index: + continue + if mask in str(self.body): + continue + # add mask + other_axis_indexing_code = self.sorted_axis[idx].indexing_code + indexing_code.splice(other_axis_indexing_code) + + reduction_1d = is_1d_reduction() do_indent = False # do nothing except for writing porintwise @@ -1002,6 +1015,18 @@ class NPUIndexTritonKernel(TritonKernel): value_str = f"{value}" mask_str = indexing.mask_str + if index_analyze.var_replacements: + for tmp_var in index.free_symbols: + if tmp_var.name in index_analyze.var_replacements.keys(): + continue + + for mask in indexing.mask_vars: + str_var = str(index_analyze.var_replacements[tmp_var]) + if str_var in mask: + axis = self.range_tree_nodes[tmp_var] + idx = self.sorted_axis.index(axis) + self.store_mask[mask] = idx + if index_analyze.need_permute: value_str = value_str.replace(f"{value}", f"{value}{index_analyze.generate_statement()}") -- Gitee From 5d5e6bfc72f2f639b2a3eb432a92b1a2c562fd63 Mon Sep 17 00:00:00 2001 From: wl1259 Date: Sat, 28 Jun 2025 11:10:41 +0800 Subject: [PATCH 2/3] commit for mask bug fix --- torch_npu/_inductor/codegen/triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py index 439f63065a..ecf9bc7dc0 100644 --- a/torch_npu/_inductor/codegen/triton.py +++ b/torch_npu/_inductor/codegen/triton.py @@ -1017,7 +1017,7 @@ class NPUIndexTritonKernel(TritonKernel): if index_analyze.var_replacements: for tmp_var in index.free_symbols: - if tmp_var.name in index_analyze.var_replacements.keys(): + if tmp_var not in index_analyze.var_replacements.keys(): continue for mask in indexing.mask_vars: -- Gitee From 9cf25302951f0a317ded9df5c9994347b9410c1e Mon Sep 17 00:00:00 2001 From: wl1259 Date: Tue, 15 Jul 2025 16:45:46 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=8F=90=E4=BA=A4=EF=BC=8Cfix=20codegen=20?= =?UTF-8?q?bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/_inductor/test_permute_layernorm.py | 47 ++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 test/_inductor/test_permute_layernorm.py diff --git a/test/_inductor/test_permute_layernorm.py b/test/_inductor/test_permute_layernorm.py new file mode 100644 index 0000000000..5b69338fdb --- /dev/null +++ b/test/_inductor/test_permute_layernorm.py @@ -0,0 +1,47 @@ +import torch +from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests +from testutils import OperatorType, TestUtils +from torch._dynamo.testing import rand_strided +import torch_npu + + +class TestSqrt(TestUtils): + def forward(self, arg2_1, arg3_1, arg4_1, arg5_1): + unsqueeze = torch.ops.aten.unsqueeze.default(arg2_1, 1); + npu_dtype_cast_2 = torch.ops.npu.npu_dtype_cast.default(arg3_1, torch.float32) + npu_dtype_cast_3 = torch.ops.npu.npu_dtype_cast.default(arg4_1, torch.float32) + npu_dtype_cast_4 = torch.ops.npu.npu_dtype_cast.default(arg5_1, torch.float32) + clone = torch.ops.aten.clone.default(npu_dtype_cast_2, memory_format=torch.contiguous_format) + var_mean = torch.ops.aten.var_mean.correction(clone, [2], correction=0, keepdim=True) + getitem = var_mean[0] + getitem_1 = var_mean[1] + add = torch.ops.aten.add.Tensor(getitem, 1e-06) + rsqrt = torch.ops.aten.rsqrt.default(add) + sub = torch.ops.aten.sub.Tensor(clone, getitem_1) + mul_1 = torch.ops.aten.mul.Tensor(sub, rsqrt) + mul_2 = torch.ops.aten.mul.Tensor(mul_1, npu_dtype_cast_3) + add_1 = torch.ops.aten.add.Tensor(mul_2, npu_dtype_cast_4) + npu_dtype_cast_5 = torch.ops.npu.npu_dtype_cast.default(add_1, torch.float16) + add_2 = torch.ops.aten.add.Tensor(npu_dtype_cast_5, unsqueeze) + return add_2 + + def test_permute_layernorm_cases(self): + arg2 = rand_strided((2, 1408), (1408, 1), device='npu', dtype=torch.float32) + arg3 = rand_strided((2, 3840, 1408), (5406720, 1, 3840), device='npu', dtype=torch.float16) + arg4 = rand_strided((1408,), (1,), device='npu', dtype=torch.float16) + arg5 = rand_strided((1408,), (1,), device='npu', dtype=torch.float16) + + std_result = self.forward(arg2, arg3, arg4, arg5) + compiled_op_calc = torch.compile(self.forward, backend="inductor") + inductor_result = compiled_op_calc(arg2, arg3, arg4, arg5) + + rtol = 1e-2 + atol = 1e-2 + torch.testing.assert_close(std_result, inductor_result, equal_nan=True, rtol=rtol, atol=atol) + + +instantiate_parametrized_tests(TestSqrt) + +if __name__ == "__main__": + test = TestSqrt() + test.test_permute_layernorm_cases() -- Gitee