From 9b8f1411f15fdf9caec990b1beb9efc8dbaa041e Mon Sep 17 00:00:00 2001 From: qiucan4 Date: Wed, 10 Sep 2025 11:49:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4sort=5Fpairs=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/csrc/functions.h | 2 -- mx_driving/_C/__init__.pyi | 4 ---- mx_driving/csrc/SortPairs.cpp | 40 ----------------------------------- mx_driving/csrc/pybind.cpp | 3 --- mx_driving/ops/sort_pairs.py | 13 ------------ 5 files changed, 62 deletions(-) delete mode 100644 mx_driving/csrc/SortPairs.cpp delete mode 100644 mx_driving/ops/sort_pairs.py diff --git a/include/csrc/functions.h b/include/csrc/functions.h index bb725cdc..27f71309 100644 --- a/include/csrc/functions.h +++ b/include/csrc/functions.h @@ -36,8 +36,6 @@ at::Tensor npu_scatter_mean_grad(at::Tensor& grad_out, at::Tensor& index, at::Te std::tuple npu_scatter_mean(at::Tensor& src, at::Tensor& index, c10::optional out, c10::optional dim, c10::optional dim_size); -std::tuple npu_sort_pairs( - const at::Tensor& keys_in, const at::Tensor& values_in, int64_t dim, bool descending); at::Tensor npu_hypot(const at::Tensor& x, const at::Tensor& y); diff --git a/mx_driving/_C/__init__.pyi b/mx_driving/_C/__init__.pyi index 1be1169c..6a8f30d6 100644 --- a/mx_driving/_C/__init__.pyi +++ b/mx_driving/_C/__init__.pyi @@ -29,9 +29,6 @@ def npu_scatter_mean( dim: Optional[int] = None, dim_size: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... -def npu_sort_pairs( - keys_in: torch.Tensor, values_in: torch.Tensor, dim: int, descending: bool -) -> Tuple[torch.Tensor, torch.Tensor]: ... def npu_hypot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... def npu_hypot_grad( x: torch.Tensor, y: torch.Tensor, out: torch.Tensor, out_grad: torch.Tensor @@ -461,7 +458,6 @@ __all__ = [ "npu_scatter", "npu_scatter_mean_grad", "npu_scatter_mean", - "npu_sort_pairs", "npu_hypot", "npu_hypot_grad", "assign_score_withk", diff --git a/mx_driving/csrc/SortPairs.cpp b/mx_driving/csrc/SortPairs.cpp deleted file mode 100644 index 9ea09a4d..00000000 --- a/mx_driving/csrc/SortPairs.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2024 Huawei Technologies Co., Ltd -// Copyright (c) 2019, Facebook CORPORATION. -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "csrc/OpApiCommon.h" -#include "csrc/functions.h" - -/** - * @brief sort key-value pairs - * @param keys_in: keys to be sorted - * @param values_in: values of corresponding keys - * @param dim: dim to sort - * @param descending: true-descending, false-ascending - * @return (keys_out, values_out): (sorted keys, corresponding values of sorted keys) - */ -std::tuple npu_sort_pairs( - const at::Tensor& keys_in, const at::Tensor& values_in, int64_t dim, bool descending) -{ - TORCH_CHECK_NPU(keys_in); - TORCH_CHECK_NPU(values_in); - bool stable = true; - at::Tensor keys_out = at::zeros_like(keys_in, keys_in.options()); - at::Tensor values_out = at::zeros_like(values_in, values_in.options()); - at::Tensor indices = at::zeros_like(values_in, values_in.options().dtype(at::kLong)); - EXEC_NPU_CMD(aclnnSort, keys_in, stable, dim, descending, keys_out, indices); - EXEC_NPU_CMD(aclnnGather, values_in, dim, indices, values_out); - return std::tie(keys_out, values_out); -} diff --git a/mx_driving/csrc/pybind.cpp b/mx_driving/csrc/pybind.cpp index f044c29f..4727c880 100644 --- a/mx_driving/csrc/pybind.cpp +++ b/mx_driving/csrc/pybind.cpp @@ -47,9 +47,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("scatter_max_with_argmax_v2", &scatter_max_with_argmax_v2); m.def("npu_scatter_max_backward", &npu_scatter_max_backward); - // npu_sort_pairs - m.def("npu_sort_pairs", &npu_sort_pairs, "sort_pairs NPU version"); - // npu_hypot m.def("npu_hypot", &npu_hypot); m.def("npu_hypot_grad", &npu_hypot_grad); diff --git a/mx_driving/ops/sort_pairs.py b/mx_driving/ops/sort_pairs.py deleted file mode 100644 index e6c6f0b4..00000000 --- a/mx_driving/ops/sort_pairs.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch - -import mx_driving._C - - -class SortPairs(torch.autograd.Function): - @staticmethod - def forward(ctx, keys_in, values_in, dim, descending=False): - res = mx_driving._C.npu_sort_pairs(keys_in, values_in, dim, descending) - return res - - -sort_pairs = SortPairs.apply -- Gitee