diff --git a/test/test_onnx/test_wrapper_onnx_ops.py b/test/test_onnx/test_wrapper_onnx_ops.py
index 99d7c950fa19f9b3505da72014d0575d2ed28c1f..9a617f7a6d39853e2e91e2032a4921104336bce0 100644
--- a/test/test_onnx/test_wrapper_onnx_ops.py
+++ b/test/test_onnx/test_wrapper_onnx_ops.py
@@ -213,6 +213,29 @@ class TestOnnxOps(TestCase):
         assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
                                             onnx_model_name)))
 
+
+    def test_wrapper_npu_group_norm_silu(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.weight = torch.rand((4,)).npu().half()
+                self.bias = torch.rand((4,)).npu().half()
+
+            def forward(self, input_):
+                return torch_npu.npu_group_norm_silu(input_, self.weight, self.bias,
+                                                     group=2, eps=0.00001)
+
+        def export_onnx(onnx_model_name):
+            input_ = torch.rand(5, 4, 3, 3).npu().half()
+            model = Model().to("npu")
+            model(input_)
+            self.onnx_export(model, input_, onnx_model_name, ["input_"], ["out1", "out2", "out3"])
+
+        onnx_model_name = "model_npu_group_norm_silu.onnx"
+        export_onnx(onnx_model_name)
+        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
+                                            onnx_model_name)))
+
     def test_wrapper_npu_fused_attention_score(self):
         class Model(torch.nn.Module):
             def __init__(self):
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 854b37ff1489d2a20d17c6f58ff08003e6f23862..f1e4247efc85356637b9e5edfb74ac3ba6465f18 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -109,6 +109,23 @@ class NPUGeGluOP(torch.autograd.Function):
         return g.op("npu::NPUGeGlu", self, dim_i=dim, approximate_i=approximate, outputs=2)
 
 
+class NPUGroupNormSiluOP(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch_npu._C._VariableFunctionsClass.npu_group_norm_silu(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, self: Tensor, gamma: Optional[Tensor], beta: Optional[Tensor],
+                 group: int, eps: float = 0.00001):
+        if gamma is None:
+            gamma = g.op("Constant", value_t=torch.tensor([]).to(torch.float))
+        if beta is None:
+            beta = g.op("Constant", value_t=torch.tensor([]).to(torch.float))
+        return g.op("npu::NPUGroupNormSilu", self, gamma, beta, group_i=group, eps_f=eps,
+                    outputs=3)
+
+
 class NPUFusedAttentionScoreOP(torch.autograd.Function):
 
     @staticmethod
@@ -681,6 +698,10 @@ def wrapper_npu_geglu(self, dim=-1, approximate=1):
     return NPUGeGluOP.apply(self, dim, approximate)
 
 
+def wrapper_npu_group_norm_silu(x, gamma, beta, group, eps=0.00001):
+    return NPUGroupNormSiluOP.apply(x, gamma, beta, group, eps)
+
+
 def wrapper_npu_fused_attention_score(query_layer, key_layer, value_layer, attention_mask,
                                       scale, keep_prob, query_transpose=False, key_transpose=False,
                                       bmm_score_transpose_a=False, bmm_score_transpose_b=False,
@@ -878,6 +899,7 @@ def add_onnx_ops():
     torch_npu.fast_gelu = wrapper_npu_fast_gelu
     torch_npu.npu_fast_gelu = wrapper_npu_fast_gelu
     torch_npu.npu_geglu = wrapper_npu_geglu
+    torch_npu.npu_group_norm_silu = wrapper_npu_group_norm_silu
     torch_npu.npu_fused_attention_score = wrapper_npu_fused_attention_score
     torch_npu.npu_ciou = wrapper_npu_ciou
     torch_npu.npu_multi_head_attention = wrapper_npu_multi_head_attention
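
A minimal usage sketch of the new wrapper, mirroring the added test_wrapper_npu_group_norm_silu case. It assumes torch_npu is installed with an NPU device available, that importing torch_npu.onnx registers the wrappers via add_onnx_ops(), and it calls torch.onnx.export directly in place of the TestOnnxOps.onnx_export helper; the opset and export options shown are illustrative, not taken from the patch.

import torch
import torch_npu
import torch_npu.onnx  # assumed to run add_onnx_ops(), installing the ONNX wrappers


class GroupNormSiluModel(torch.nn.Module):
    """4 channels normalized in 2 groups, matching the added test case."""

    def __init__(self):
        super().__init__()
        self.weight = torch.rand((4,)).npu().half()
        self.bias = torch.rand((4,)).npu().half()

    def forward(self, input_):
        return torch_npu.npu_group_norm_silu(input_, self.weight, self.bias,
                                             group=2, eps=0.00001)


model = GroupNormSiluModel().to("npu")
input_ = torch.rand(5, 4, 3, 3).npu().half()
model(input_)  # warm-up call, as in the test

# The symbolic emits npu::NPUGroupNormSilu with outputs=3, hence three output names.
# opset_version and operator_export_type are illustrative assumptions.
torch.onnx.export(model, input_, "model_npu_group_norm_silu.onnx",
                  input_names=["input_"],
                  output_names=["out1", "out2", "out3"],
                  opset_version=11,
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)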