diff --git a/test/test_network_ops/test_fast_gelu.py b/test/test_network_ops/test_fast_gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbed325e24a62b21b116f5bb477dea65a741c3f
--- /dev/null
+++ b/test/test_network_ops/test_fast_gelu.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestFastGelu(TestCase):
+    def npu_op_exec(self, input1):
+        output = torch_npu.fast_gelu(input1)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_fastgelu(self, device):
+        input1 = torch.tensor([1.,2.,3.,4.]).npu()
+        exoutput = torch.tensor([0.8458, 1.9357, 2.9819, 3.9956])
+        output = self.npu_op_exec(input1)
+        self.assertRtolEqual(exoutput.numpy(), output)
+
+instantiate_device_type_tests(TestFastGelu, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_fast_gelu_backward.py b/test/test_network_ops/test_fast_gelu_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fe30f1d21ae34d758dc5abd4a49258c05764c9a
--- /dev/null
+++ b/test/test_network_ops/test_fast_gelu_backward.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestFastGelu(TestCase):
+    def npu_op_exec(self, input1):
+        input1.requires_grad = True
+        output = torch_npu.fast_gelu(input1)
+        output.backward(torch.ones_like(output))
+        output_grad = input1.grad
+        output_grad = output_grad.to("cpu")
+        output_grad = output_grad.detach().numpy()
+        output = output.cpu().detach().numpy()
+        return output_grad, output
+
+    def test_fastgelu(self, device):
+        input1 = torch.tensor([1.,2.,3.,4.]).npu()
+        exoutputgrad = torch.tensor([1.0677795, 1.0738151, 1.0245483, 1.0064018])
+        exoutput = torch.tensor([0.8458, 1.9357, 2.9819, 3.9956])
+        outputgrad, output = self.npu_op_exec(input1)
+        self.assertRtolEqual(exoutputgrad.numpy(), outputgrad)
+        self.assertRtolEqual(exoutput.numpy(), output)
+
+instantiate_device_type_tests(TestFastGelu, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index a7a709fa46c0070cc261ca1d6c7d88886164b333..0e6ab1ad91639253d5b03045d5f9cae854f00cd6 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1890,6 +1890,8 @@ custom:
     variants: function, method
   - func: one_(Tensor(a!) self) -> Tensor(a!)
     variants: method, function
+  - func: fast_gelu_backward(Tensor grad, Tensor self) -> Tensor
+    variants: function, method
   - func: npu_bert_apply_adam(Scalar lr, Scalar beta1, Scalar beta2, Scalar epsilon, Tensor grad, Scalar max_grad_norm, Scalar global_grad_norm, Scalar weight_decay) -> (Tensor var, Tensor m, Tensor v)
   - func: npu_bert_apply_adam.out(Scalar lr, Scalar beta1, Scalar beta2, Scalar epsilon, Tensor grad, Scalar max_grad_norm, Scalar global_grad_norm, Scalar weight_decay, *, Tensor(a!) var, Tensor(b!) m, Tensor(c!) v) -> (Tensor(a!), Tensor(b!), Tensor(c!))
   - func: npu_conv_transpose2d_backward(Tensor input, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
@@ -1918,3 +1920,4 @@ custom:
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
+  - func: fast_gelu(Tensor self) -> Tensor
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/FastGeluKernelNpu.cpp b/torch_npu/csrc/aten/ops/FastGeluKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0b5ebb1b52b85e72784d5fb619dc1b5c7d47622a
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/FastGeluKernelNpu.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+namespace {
+at::Tensor fast_gelu_npu_nocheck(at::Tensor& result, const at::Tensor& self) {
+
+  OpCommand cmd;
+  cmd.Name("FastGelu")
+      .Input(self)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+} // namespace
+
+namespace {
+at::Tensor& fast_gelu_backward_npu_nocheck(
+    at::Tensor& grad_input,
+    const at::Tensor& grad,
+    const at::Tensor& self) {
+  // constructs the input and output NPUTensorDesc
+  OpCommand cmd;
+  cmd.Name("FastGeluGrad")
+      .Input(grad)
+      .Input(self)
+      .Output(grad_input)
+      .Run();
+
+  return grad_input;
+}
+}
+
+class NPUFastGeluFunction : public torch::autograd::Function<NPUFastGeluFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self) {
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({self});
+    auto outputSize = input_same_output_size(self);
+    at::Tensor result = OpPreparation::ApplyTensor(self);
+
+    return fast_gelu_npu_nocheck(result, self);
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0];
+
+    at::Tensor result = NPUNativeFunctions::fast_gelu_backward(grad_outputs[0], input);
+    tensor_list output = {result};
+    return output;
+  }
+};
+
+at::Tensor NPUNativeFunctions::fast_gelu_backward(
+    const at::Tensor& grad,
+    const at::Tensor& self) {
+  // calculate the output size
+  auto outputSize = input_same_output_size(self);
+
+  // construct the output tensor of the NPU
+  at::Tensor grad_input = OpPreparation::ApplyTensor(self);
+
+  // calculate the output result of the NPU
+  fast_gelu_backward_npu_nocheck(grad_input, grad, self);
+
+  return grad_input;
+}
+
+at::Tensor NPUNativeFunctions::fast_gelu(const at::Tensor& self) {
+  return NPUFastGeluFunction::apply(self);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file