class TestSortWithoutIndices(TestCase):
    """Compare the values returned by ``torch_npu.npu_sort_v2`` with CPU ``torch.sort``.

    Only sorted values are compared; the indices returned by ``torch.sort``
    are discarded on the CPU side.
    """

    def cpu_default_op_exec(self, input1):
        """Return the CPU reference for a sort called with default arguments.

        ``torch.sort`` sorts ascending when ``descending`` is omitted, so this
        is the ``descending=False`` reference.
        """
        return self.cpu_op_exec(input1, False)

    def npu_default_op_exec(self, input1):
        """Call ``npu_sort_v2`` with only the input tensor; return a NumPy array.

        The optional arguments are deliberately omitted so that the operator's
        own defaults are exercised.
        """
        return torch_npu.npu_sort_v2(input1).to("cpu").numpy()

    def cpu_op_exec(self, input1, descending):
        """Sort ``input1`` on CPU and return the values as a float16 NumPy array."""
        sorted_values, _ = torch.sort(input1, descending=descending)
        # The caller passes float32 input; cast back to float16 so the result
        # is compared in the same dtype as the NPU output.
        return sorted_values.to(torch.float16).numpy()

    def npu_op_exec(self, input1, descending):
        """Call ``npu_sort_v2`` with an explicit ``descending`` flag; return a NumPy array."""
        npu_result = torch_npu.npu_sort_v2(input1, descending=descending)
        return npu_result.to("cpu").numpy()

    def _check_sort_v2(self, min_value, max_value):
        """Run every shape case with inputs created from the given value bounds.

        Args:
            min_value: Lower bound forwarded to ``create_common_tensor``.
            max_value: Upper bound forwarded to ``create_common_tensor``.
        """
        # Each case is ``[tensor_spec]`` or ``[tensor_spec, descending]``.
        # NOTE(review): tensor_spec is presumably [dtype, npu_format, shape] as
        # expected by create_common_tensor -- confirm against util_test.
        shape_format = [
            [[np.float16, 0, (1, 5000)]],
            [[np.float16, 0, (1, 50000)]],
            [[np.float16, 0, (1, 289600)], False],
            [[np.float16, 0, (1, 409600)], True],
        ]

        for item in shape_format:
            cpu_input1, npu_input1 = create_common_tensor(
                item[0], min_value, max_value
            )
            # The CPU reference is computed on a float32 copy of the input.
            cpu_input1 = cpu_input1.to(torch.float)
            if len(item) == 1:
                # No flag given: use the default-argument code paths.
                cpu_output = self.cpu_default_op_exec(cpu_input1)
                npu_output = self.npu_default_op_exec(npu_input1)
            else:
                cpu_output = self.cpu_op_exec(cpu_input1, item[1])
                npu_output = self.npu_op_exec(npu_input1, item[1])
            self.assertRtolEqual(cpu_output, npu_output)

    def test_sort_v2_shape_format(self, device):
        """Check sorted values for inputs created with bounds -100 and 100."""
        self._check_sort_v2(-100, 100)

    def test_sort_v2_shape_format_big_range(self, device):
        """Check sorted values for inputs created with bounds -60000 and 60000."""
        self._check_sort_v2(-60000, 60000)
+// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +at::Tensor& NPUNativeFunctions::npu_sort_v2_out( + const at::Tensor& self, + int64_t dim, + bool descending, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("SortV2") + .Input(self) + .Output(result) + .Attr("axis", dim) + .Attr("descending", descending) + .Run(); + + return result; +} + +at::Tensor NPUNativeFunctions::npu_sort_v2( + const at::Tensor& self, + int64_t dim, + bool descending) { + auto outputSize = input_same_output_size(self); + + at::Tensor result = OpPreparation::ApplyTensor(self); + + npu_sort_v2_out(self, dim, descending, result); + + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index c2624df5c34a5feb2e077fe8689c5c5959bcb26e..c95089e0be685fff0e17894c7202e052cec7fb54 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1914,6 +1914,9 @@ custom: variants: function, method - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor variants: function, method + - func: npu_sort_v2.out(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) out) -> Tensor(a!) 
+ - func: npu_sort_v2(Tensor self, int dim=-1, bool descending=False) -> Tensor + custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor + - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor \ No newline at end of file