diff --git a/include/csrc/functions.h b/include/csrc/functions.h index bb725cdc30f76bfa4f7c82dacd564015b2576a60..27f71309a0d298be5f5f8f3ca10669b2fb5c141b 100644 --- a/include/csrc/functions.h +++ b/include/csrc/functions.h @@ -36,8 +36,6 @@ at::Tensor npu_scatter_mean_grad(at::Tensor& grad_out, at::Tensor& index, at::Te std::tuple npu_scatter_mean(at::Tensor& src, at::Tensor& index, c10::optional out, c10::optional dim, c10::optional dim_size); -std::tuple npu_sort_pairs( - const at::Tensor& keys_in, const at::Tensor& values_in, int64_t dim, bool descending); at::Tensor npu_hypot(const at::Tensor& x, const at::Tensor& y); diff --git a/mx_driving/_C/__init__.pyi b/mx_driving/_C/__init__.pyi index 1be1169c83f0fceff0984d1c94afbbbd562ed035..6a8f30d6f61410bcefe9ee69096f4c270a5e7323 100644 --- a/mx_driving/_C/__init__.pyi +++ b/mx_driving/_C/__init__.pyi @@ -29,9 +29,6 @@ def npu_scatter_mean( dim: Optional[int] = None, dim_size: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... -def npu_sort_pairs( - keys_in: torch.Tensor, values_in: torch.Tensor, dim: int, descending: bool -) -> Tuple[torch.Tensor, torch.Tensor]: ... def npu_hypot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... def npu_hypot_grad( x: torch.Tensor, y: torch.Tensor, out: torch.Tensor, out_grad: torch.Tensor @@ -461,7 +458,6 @@ __all__ = [ "npu_scatter", "npu_scatter_mean_grad", "npu_scatter_mean", - "npu_sort_pairs", "npu_hypot", "npu_hypot_grad", "assign_score_withk", diff --git a/mx_driving/csrc/SortPairs.cpp b/mx_driving/csrc/SortPairs.cpp deleted file mode 100644 index 9ea09a4dfcd6127fed4890b8c6ad5dee5c857350..0000000000000000000000000000000000000000 --- a/mx_driving/csrc/SortPairs.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2024 Huawei Technologies Co., Ltd -// Copyright (c) 2019, Facebook CORPORATION. -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "csrc/OpApiCommon.h" -#include "csrc/functions.h" - -/** - * @brief sort key-value pairs - * @param keys_in: keys to be sorted - * @param values_in: values of corresponding keys - * @param dim: dim to sort - * @param descending: true-descending, false-ascending - * @return (keys_out, values_out): (sorted keys, corresponding values of sorted keys) - */ -std::tuple npu_sort_pairs( - const at::Tensor& keys_in, const at::Tensor& values_in, int64_t dim, bool descending) -{ - TORCH_CHECK_NPU(keys_in); - TORCH_CHECK_NPU(values_in); - bool stable = true; - at::Tensor keys_out = at::zeros_like(keys_in, keys_in.options()); - at::Tensor values_out = at::zeros_like(values_in, values_in.options()); - at::Tensor indices = at::zeros_like(values_in, values_in.options().dtype(at::kLong)); - EXEC_NPU_CMD(aclnnSort, keys_in, stable, dim, descending, keys_out, indices); - EXEC_NPU_CMD(aclnnGather, values_in, dim, indices, values_out); - return std::tie(keys_out, values_out); -} diff --git a/mx_driving/csrc/pybind.cpp b/mx_driving/csrc/pybind.cpp index f044c29f8536691fb7533c201657f36b6a64cfa5..4727c880cc661b95d135c8bfa00240070334a4cd 100644 --- a/mx_driving/csrc/pybind.cpp +++ b/mx_driving/csrc/pybind.cpp @@ -47,9 +47,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("scatter_max_with_argmax_v2", &scatter_max_with_argmax_v2); m.def("npu_scatter_max_backward", &npu_scatter_max_backward); - // npu_sort_pairs - m.def("npu_sort_pairs", &npu_sort_pairs, "sort_pairs NPU version"); - // npu_hypot m.def("npu_hypot", &npu_hypot); m.def("npu_hypot_grad", &npu_hypot_grad); diff --git a/mx_driving/ops/sort_pairs.py b/mx_driving/ops/sort_pairs.py deleted file mode 100644 index e6c6f0b468111d5aca45aad98ffdda9d7241ee46..0000000000000000000000000000000000000000 --- a/mx_driving/ops/sort_pairs.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch - -import mx_driving._C - - -class SortPairs(torch.autograd.Function): - @staticmethod - def forward(ctx, keys_in, values_in, dim, descending=False): - res = mx_driving._C.npu_sort_pairs(keys_in, values_in, dim, descending) - return res - - -sort_pairs = SortPairs.apply