diff --git a/test/test_network_ops/test_confusion_transpose.py b/test/test_network_ops/test_confusion_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..578d75df492494dbe2f8ca812e870e9c2dafd274
--- /dev/null
+++ b/test/test_network_ops/test_confusion_transpose.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestConfusionTransposeD(TestCase):
+    """Check torch_npu.npu_confusion_transpose against a CPU permute/view reference."""
+
+    def npu_op_exec(self, input1, shape, perm, transpose_first):
+        """Run the NPU operator and return its result as a NumPy array."""
+        output = torch_npu.npu_confusion_transpose(input1, perm, shape, transpose_first)
+        return output.cpu().numpy()
+
+    def cpu_op_exec(self, input1, shape, perm, transpose_first):
+        """Compute the expected result on the CPU as a NumPy array.
+
+        With ``transpose_first`` the input is permuted by ``perm`` and then
+        viewed as ``shape``; otherwise it is viewed as ``shape`` and then
+        permuted by ``perm``.
+        """
+        if transpose_first:
+            output = input1.permute(*perm).contiguous().view(shape)
+        else:
+            output = input1.view(shape).permute(*perm)
+        return output.numpy()
+
+    def test_confusion_transpose(self, device):
+        """Compare NPU and CPU results for float32, float16 and int64 inputs."""
+        # Each case: tensor spec for create_common_tensor, shape, perm,
+        # transpose_first. np.int64 stands in for the `np.int` alias (plain
+        # builtin int), which NumPy 1.24 removed.
+        shape_format = [
+            [[np.float32, 0, [1, 576, 2560]], [1, 576, 32, 80], (0, 2, 1, 3), False],
+            [[np.float32, 0, [1, 32, 576, 80]], [1, 576, 2560], (0, 2, 1, 3), True],
+            [[np.float16, 0, [1, 576, 2560]], [1, 576, 32, 80], (0, 2, 1, 3), False],
+            [[np.float16, 0, [1, 32, 576, 80]], [1, 576, 2560], (0, 2, 1, 3), True],
+            [[np.int64, 0, [1, 576, 2560]], [1, 576, 32, 80], (0, 2, 1, 3), False],
+            [[np.int64, 0, [1, 32, 576, 80]], [1, 576, 2560], (0, 2, 1, 3), True],
+        ]
+        for tensor_spec, shape, perm, transpose_first in shape_format:
+            cpu_input, npu_input = create_common_tensor(tensor_spec, 0, 100)
+            cpu_output = self.cpu_op_exec(cpu_input, shape, perm, transpose_first)
+            npu_output = self.npu_op_exec(npu_input, shape, perm, transpose_first)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+instantiate_device_type_tests(TestConfusionTransposeD, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_confusion_transpose_backward.py b/test/test_network_ops/test_confusion_transpose_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..81565ef3db458c53115dc6b17d9a0b81acf2a747
--- /dev/null
+++ b/test/test_network_ops/test_confusion_transpose_backward.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestConfusionTransposeDBackward(TestCase):
+    """Forward and gradient parity of npu_confusion_transpose versus CPU permute/view."""
+
+    def npu_op_exec(self, input1, shape, perm, transpose_first):
+        """Return (forward result, input gradient) from the NPU operator as NumPy arrays."""
+        input1.requires_grad_()
+        result = torch_npu.npu_confusion_transpose(input1, perm, shape, transpose_first)
+        result.backward(torch.ones_like(result))
+        return result.detach().cpu().numpy(), input1.grad.cpu().numpy()
+
+    def cpu_op_exec(self, input1, shape, perm, transpose_first):
+        """Return (forward result, input gradient) from the CPU reference as NumPy arrays."""
+        input1.requires_grad_()
+        result = (input1.permute(*perm).contiguous().view(shape) if transpose_first
+                  else input1.view(shape).permute(*perm))
+        result.backward(torch.ones_like(result))
+        return result.detach().numpy(), input1.grad.numpy()
+
+    def test_confusion_transpose_backward(self, device):
+        """Compare NPU and CPU outputs and gradients for every listed case."""
+        # Each case: tensor spec for create_common_tensor, shape, perm, transpose_first.
+        cases = [
+            [[np.float32, 0, [1, 576, 2560]], [1, 576, 32, 80], (0, 2, 1, 3), False],
+            [[np.float32, 0, [1, 32, 576, 80]], [1, 576, 2560], (0, 2, 1, 3), True],
+            [[np.float16, 0, [1, 576, 2560]], [1, 576, 32, 80], (0, 2, 1, 3), False],
+            [[np.float16, 0, [1, 32, 576, 80]], [1, 576, 2560], (0, 2, 1, 3), True],
+        ]
+        for tensor_spec, shape, perm, transpose_first in cases:
+            cpu_input, npu_input = create_common_tensor(tensor_spec, 0, 100)
+            cpu_fwd, cpu_grad = self.cpu_op_exec(cpu_input, shape, perm, transpose_first)
+            npu_fwd, npu_grad = self.npu_op_exec(npu_input, shape, perm, transpose_first)
+            self.assertRtolEqual(cpu_fwd, npu_fwd)
+            self.assertRtolEqual(cpu_grad, npu_grad)
+
+instantiate_device_type_tests(TestConfusionTransposeDBackward, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml
b/torch_npu/csrc/aten/npu_native_functions.yaml
index 167c196324270d51a977cc7ee802bda23166e854..df8cbfc9f3f396c3c86a4252ba01a10de998d16b 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1916,7 +1916,9 @@ custom:
     variants: function, method
   - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor
     variants: function, method
+  - func: npu_confusion_transpose_backward(Tensor grad, int[] perm, int[] shape, bool transpose_first) -> Tensor
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
-  - func: fast_gelu(Tensor self) -> Tensor
\ No newline at end of file
+  - func: npu_confusion_transpose(Tensor self, int[] perm, int[] shape, bool transpose_first) -> Tensor
+    variants: function, method
diff --git a/torch_npu/csrc/aten/ops/ConfusionTransposeKernelNpu.cpp b/torch_npu/csrc/aten/ops/ConfusionTransposeKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..82cf9adc1db86cac9b1b1f003f0557104986046f
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/ConfusionTransposeKernelNpu.cpp
@@ -0,0 +1,167 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using torch::autograd::Function;
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+
+namespace {
+// Returns `shape` reordered by `perm`, i.e. out[i] = shape[perm[i]]. Entries of
+// `perm` that do not index into `shape` are rejected instead of being read
+// out of bounds.
+c10::SmallVector<int64_t, SIZE> confusion_permute_shape(
+    at::IntArrayRef shape,
+    at::IntArrayRef perm) {
+  const int64_t rank = static_cast<int64_t>(shape.size());
+  c10::SmallVector<int64_t, SIZE> permuted;
+  for (const int64_t dim : perm) {
+    TORCH_CHECK(dim >= 0 && dim < rank,
+        "npu_confusion_transpose: perm entry ", dim,
+        " is out of range for a shape of rank ", rank);
+    permuted.emplace_back(shape[dim]);
+  }
+  return permuted;
+}
+
+// Returns the inverse of `perm`, i.e. inverse[perm[i]] = i. Entries outside
+// [0, perm.size()) are rejected instead of being written out of bounds.
+std::vector<int64_t> confusion_invert_perm(at::IntArrayRef perm) {
+  const int64_t rank = static_cast<int64_t>(perm.size());
+  std::vector<int64_t> inverse(perm.size(), 0);
+  for (int64_t i = 0; i < rank; i++) {
+    TORCH_CHECK(perm[i] >= 0 && perm[i] < rank,
+        "npu_confusion_transpose: perm entry ", perm[i],
+        " is out of range for a permutation of length ", rank);
+    inverse[perm[i]] = i;
+  }
+  return inverse;
+}
+} // namespace
+
+// Runs the ConfusionTransposeD kernel on `self` and returns a new tensor.
+at::Tensor confusion_transpose_npu(
+    const at::Tensor& self,
+    at::IntArrayRef perm,
+    at::IntArrayRef shape,
+    bool transpose_first) {
+  // With transpose_first the result is sized `shape`; otherwise it is sized
+  // `shape` reordered by `perm`.
+  c10::SmallVector<int64_t, SIZE> output_size;
+  if (transpose_first) {
+    output_size = array_to_small_vector(shape);
+  } else {
+    output_size = confusion_permute_shape(shape, perm);
+  }
+
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensor(self, output_size);
+  OpCommand cmd;
+  cmd.Name("ConfusionTransposeD")
+      .Input(self)
+      .Output(result)
+      .Attr("perm", perm)
+      .Attr("shape", shape)
+      .Attr("transpose_first", transpose_first)
+      .Run();
+
+  return result;
+}
+
+// Gradient of npu_confusion_transpose with respect to its input.
+// NPUConfusionTransposeFunction::backward below passes the forward input's
+// size as `shape` and the negated forward flag as `transpose_first`; `perm` is
+// still the forward permutation and is inverted here.
+at::Tensor NPUNativeFunctions::npu_confusion_transpose_backward(
+    const at::Tensor& grad,
+    at::IntArrayRef perm,
+    at::IntArrayRef shape,
+    bool transpose_first) {
+  c10::SmallVector<int64_t, SIZE> svec_shape;
+  if (transpose_first) {
+    svec_shape = array_to_small_vector(shape);
+  } else {
+    svec_shape = confusion_permute_shape(shape, perm);
+  }
+
+  // A std::vector rather than a variable-length array: VLAs are not standard
+  // C++ and cannot carry an initializer.
+  const std::vector<int64_t> inverse_perm = confusion_invert_perm(perm);
+
+  // The output is allocated with the size passed in as `shape`.
+  at::Tensor result = OpPreparation::ApplyTensor(grad, shape);
+  OpCommand cmd;
+  cmd.Name("ConfusionTransposeD")
+      .Input(grad)
+      .Output(result)
+      .Attr("perm", at::IntArrayRef(inverse_perm))
+      .Attr("shape", svec_shape)
+      .Attr("transpose_first", transpose_first)
+      .Run();
+
+  return result;
+}
+
+// Autograd wrapper that pairs the forward kernel with the backward above.
+class NPUConfusionTransposeFunction
+    : public torch::autograd::Function<NPUConfusionTransposeFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self,
+      at::IntArrayRef perm,
+      at::IntArrayRef shape,
+      bool transpose_first) {
+    // Saved in the form backward consumes: the input's own size and the
+    // negated flag.
+    ctx->saved_data["perm"] = perm;
+    ctx->saved_data["shape"] = self.sizes();
+    ctx->saved_data["transpose_first"] = !transpose_first;
+    at::AutoNonVariableTypeMode g;
+    return confusion_transpose_npu(self, perm, shape, transpose_first);
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto perm = ctx->saved_data["perm"].toIntVector();
+    auto shape = ctx->saved_data["shape"].toIntVector();
+    auto transpose_first = ctx->saved_data["transpose_first"].toBool();
+    at::Tensor result = NPUNativeFunctions::npu_confusion_transpose_backward(
+        grad_outputs[0], perm, shape, transpose_first);
+
+    // One slot per forward argument; only `self` receives a gradient.
+    tensor_list output = {result,
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor()};
+    return output;
+  }
+};
+
+at::Tensor NPUNativeFunctions::npu_confusion_transpose(const at::Tensor& self,
+    at::IntArrayRef perm,
+    at::IntArrayRef shape,
+    bool transpose_first) {
+  return NPUConfusionTransposeFunction::apply(self, perm, shape, transpose_first);
+}
+
+} // namespace native
+} // namespace at_npu