From 3fd5e59af80851e0f798c110c709406b684e2721 Mon Sep 17 00:00:00 2001
From: zhoufan37
Date: Thu, 27 Jan 2022 14:03:07 +0800
Subject: [PATCH] Add Gelu Operator

---
 test/test_network_ops/test_gelu.py        | 97 +++++++++++++++++++++++
 torch_npu/csrc/aten/ops/GeluKernelNpu.cpp | 33 ++++++++
 2 files changed, 130 insertions(+)
 create mode 100644 test/test_network_ops/test_gelu.py
 create mode 100644 torch_npu/csrc/aten/ops/GeluKernelNpu.cpp

diff --git a/test/test_network_ops/test_gelu.py b/test/test_network_ops/test_gelu.py
new file mode 100644
index 0000000000..1a4d3b133b
--- /dev/null
+++ b/test/test_network_ops/test_gelu.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+#pylint: disable=unused-argument
+
+class TestGelu(TestCase):
+    def generate_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+
+        # convert from numpy.ndarray to torch.Tensor
+        npu_input1 = torch.from_numpy(input1)
+        return npu_input1
+
+    def cpu_op_exec(self, input1):
+        output = torch.nn.functional.gelu(input1)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1):
+        input1_npu = input1.to('npu')
+        output = torch.nn.functional.gelu(input1_npu)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_fp16(self, input1):
+        input1 = input1.to(torch.float32)
+        output = torch.nn.functional.gelu(input1)
+        output = output.numpy()
+        output = output.astype(np.float16)
+        return output
+
+    def npu_op_exec_fp16(self, input1):
+        # keep float16 on the NPU so the fp16 kernel path is actually exercised
+        input1 = input1.to('npu')
+        output = torch.nn.functional.gelu(input1)
+        output = output.to("cpu").numpy().astype(np.float16)
+        return output
+
+    def test_gelu_float32_1(self, device):
+        input1 = self.generate_data(0, 100, (4,3), np.float32)
+        cpu_input1 = copy.deepcopy(input1)
+        cpu_output = self.cpu_op_exec(cpu_input1)
+        npu_output = self.npu_op_exec(input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_gelu_float32_2(self, device):
+        input1 = self.generate_data(0, 1000, (4,3), np.float32)
+        cpu_input1 = copy.deepcopy(input1)
+        cpu_output = self.cpu_op_exec(cpu_input1)
+        npu_output = self.npu_op_exec(input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_gelu_float16_1(self, device):
+        npu_input1 = self.generate_data(0, 100, (5,3), np.float16)
+        cpu_input1 = copy.deepcopy(npu_input1)
+        cpu_output = self.cpu_op_exec_fp16(cpu_input1)
+        npu_output = self.npu_op_exec_fp16(npu_input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_gelu_float16_2(self, device):
+        npu_input1 = self.generate_data(0, 1000, (5,3), np.float16)
+        cpu_input1 = copy.deepcopy(npu_input1)
+        cpu_output = self.cpu_op_exec_fp16(cpu_input1)
+        npu_output = self.npu_op_exec_fp16(npu_input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_gelu_float16_3(self, device):
+        npu_input1 = self.generate_data(0, 1000, (3,3), np.float16)
+        cpu_input1 = copy.deepcopy(npu_input1)
+        cpu_output = self.cpu_op_exec_fp16(cpu_input1)
+        npu_output = self.npu_op_exec_fp16(npu_input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestGelu, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/GeluKernelNpu.cpp b/torch_npu/csrc/aten/ops/GeluKernelNpu.cpp
new file mode 100644
index 0000000000..51180b051a
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/GeluKernelNpu.cpp
@@ -0,0 +1,33 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::gelu(const at::Tensor& self) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  // calculate the output result on the NPU
+  OpCommand cmd;
+  cmd.Name("Gelu")
+      .Input(self)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+} // namespace native
+} // namespace at_npu
--
Gitee
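
Note for reviewers: a minimal usage sketch for trying the new operator locally. It is not part of the patch; it assumes a torch_npu build with an NPU device available and simply mirrors what the new test file does.

    import torch
    import torch_npu  # registers the NPU backend, including the Gelu kernel added above
    import numpy as np

    # host-side input, moved to the NPU for the forward pass
    x_cpu = torch.from_numpy(np.random.uniform(0, 100, (4, 3)).astype(np.float32))
    x_npu = x_cpu.to('npu')
    y_npu = torch.nn.functional.gelu(x_npu)  # dispatches to NPUNativeFunctions::gelu

    # compare against the CPU reference implementation
    y_cpu = torch.nn.functional.gelu(x_cpu)
    print(np.allclose(y_cpu.numpy(), y_npu.to('cpu').numpy(), rtol=1e-3, atol=1e-3))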