From f857818c3153feac4a43ac716debb01416492a30 Mon Sep 17 00:00:00 2001 From: huangyuan Date: Fri, 11 Jul 2025 09:50:49 +0800 Subject: [PATCH] add inductor graph mode adapt --- include/csrc/functions.h | 2 +- mx_driving/csrc/MaxPool2d.cpp | 6 ++-- mx_driving/csrc/pybind.cpp | 34 +++++++++++++++++++ mx_driving/ops/deform_conv2d.py | 18 ++++++++-- mx_driving/ops/multi_scale_deformable_attn.py | 16 +++++++-- mx_driving/ops/npu_add_relu.py | 15 ++++++-- mx_driving/ops/npu_batch_matmul.py | 13 +++++-- mx_driving/ops/npu_deformable_aggregation.py | 15 ++++++-- mx_driving/ops/npu_max_pool2d.py | 12 ++++++- 9 files changed, 115 insertions(+), 16 deletions(-) diff --git a/include/csrc/functions.h b/include/csrc/functions.h index 8fb5e700..8cfdb6ad 100644 --- a/include/csrc/functions.h +++ b/include/csrc/functions.h @@ -52,7 +52,7 @@ void assign_score_withk_grad(const at::Tensor& grad_out, const at::Tensor& point const at::Tensor& knn_idx, at::Tensor& grad_points, at::Tensor& grad_centers, at::Tensor& grad_scores, int32_t B, int32_t N, int32_t npoint, int32_t M, int32_t K, int32_t out_dim, int32_t aggregate); -at::Tensor npu_max_pool2d(const at::Tensor& x, int kernel_size, int stride, int padding); +at::Tensor npu_max_pool2d(const at::Tensor& x, int64_t kernel_size, int64_t stride, int64_t padding); at::Tensor multi_scale_deformable_attn(const at::Tensor& value, const at::Tensor& value_spatial_shapes, const at::Tensor& value_level_start_index, const at::Tensor& sampling_locations, diff --git a/mx_driving/csrc/MaxPool2d.cpp b/mx_driving/csrc/MaxPool2d.cpp index 56fc798b..147f6486 100644 --- a/mx_driving/csrc/MaxPool2d.cpp +++ b/mx_driving/csrc/MaxPool2d.cpp @@ -20,7 +20,7 @@ constexpr size_t C_LIMIT = 64; constexpr size_t X_NUM_LIMIT = 1000000000; -at::Tensor npu_max_pool2d(const at::Tensor& x, int kernel_size, int stride, int padding) +at::Tensor npu_max_pool2d(const at::Tensor& x, int64_t kernel_size, int64_t stride, int64_t padding) { TORCH_CHECK_NPU(x); TORCH_CHECK(x.scalar_type() == at::kFloat || x.scalar_type() == at::kHalf, @@ -68,13 +68,13 @@ at::Tensor npu_max_pool2d(const at::Tensor& x, int kernel_size, int stride, int TORCH_CHECK(channel % 16 == 0, "channel: expected 16X when dtype is fp16 but got: ", channel); } - at::Tensor x_trans = x.permute({0, 2, 3, 1}); + at::Tensor x_trans = x.permute({0, 2, 3, 1}).contiguous(); auto output_size = {batch, output_height, output_width, channel}; at::Tensor y_trans = at::empty(output_size, x.options()); EXEC_NPU_CMD(aclnnMaxPool2d, x_trans, y_trans); - at::Tensor y = y_trans.permute({0, 3, 1, 2}); + at::Tensor y = y_trans.permute({0, 3, 1, 2}).contiguous(); return y; } } diff --git a/mx_driving/csrc/pybind.cpp b/mx_driving/csrc/pybind.cpp index 84f7b35a..336c5ce4 100644 --- a/mx_driving/csrc/pybind.cpp +++ b/mx_driving/csrc/pybind.cpp @@ -255,3 +255,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // grid_sampler3d_grad_v1 m.def("grid_sampler3d_grad_v1", &grid_sampler3d_grad_v1); } + +namespace { +TORCH_LIBRARY_FRAGMENT(mx_driving, m) +{ + m.def("multi_scale_deformable_attn(Tensor value, Tensor value_spatial_shapes, Tensor value_level_start_index, Tensor sampling_locations, Tensor attention_weights) -> (Tensor y)"); + m.def("multi_scale_deformable_attn_backward(Tensor value, Tensor value_spatial_shapes, Tensor value_level_start_index, Tensor sampling_locations, Tensor attention_weights, Tensor grad_output) -> (Tensor grad_value, Tensor grad_sampling_loc, Tensor grad_attn_weight)"); + m.def("deformable_aggregation(Tensor mc_ms_feat, Tensor spatial_shape, 
Tensor scale_start_index, Tensor sampling_location, Tensor weights) -> (Tensor out)"); + m.def("deformable_aggregation_backward(Tensor mc_ms_feat, Tensor spatial_shape, Tensor scale_start_index, Tensor sampling_location, Tensor weights, Tensor grad_output, Tensor grad_mc_ms_feat, Tensor grad_sampling_location, Tensor grad_weights) -> (Tensor grad_mc_ms_feat, Tensor grad_sampling_location, Tensor grad_weights)"); + m.def("deformable_conv2d(Tensor input, Tensor offset, Tensor weight, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups, int deformable_groups) -> (Tensor result1, Tensor result2)"); + m.def("deformable_conv2d_backward(Tensor input, Tensor weight, Tensor offset, Tensor offset_output, Tensor grad_y, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups, int deformable_groups) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_offset)"); + m.def("modulated_deformable_conv2d(Tensor input, Tensor offset, Tensor weight, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups, int deformable_groups) -> (Tensor result1, Tensor result2)"); + m.def("modulated_deformable_conv2d_backward(Tensor input, Tensor weight, Tensor offset, Tensor offset_output, Tensor grad_y, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups, int deformable_groups) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_offset)"); + m.def("npu_batch_matmul(Tensor projection_mat, Tensor pts_extend) -> (Tensor result)"); + m.def("npu_max_pool2d(Tensor x, int kernel_size, int stride, int padding) -> (Tensor y)"); + m.def("npu_add_relu(Tensor x, Tensor y) -> (Tensor result)"); + m.def("npu_add_relu_grad(Tensor self, Tensor grad_output) -> (Tensor result)"); +} +} + +namespace { +TORCH_LIBRARY_IMPL(mx_driving, PrivateUse1, m) +{ + m.impl("multi_scale_deformable_attn", TORCH_FN(multi_scale_deformable_attn)); + m.impl("multi_scale_deformable_attn_backward", TORCH_FN(multi_scale_deformable_attn_backward)); + m.impl("deformable_aggregation", TORCH_FN(deformable_aggregation)); + m.impl("deformable_aggregation_backward", TORCH_FN(deformable_aggregation_backward)); + m.impl("deformable_conv2d", TORCH_FN(deformable_conv2d)); + m.impl("deformable_conv2d_backward", TORCH_FN(deformable_conv2d_backward)); + m.impl("npu_batch_matmul", TORCH_FN(npu_batch_matmul)); + m.impl("npu_max_pool2d", TORCH_FN(npu_max_pool2d)); + m.impl("npu_add_relu", TORCH_FN(npu_add_relu)); + m.impl("npu_add_relu_grad", TORCH_FN(npu_add_relu_grad)); +} +} \ No newline at end of file diff --git a/mx_driving/ops/deform_conv2d.py b/mx_driving/ops/deform_conv2d.py index 3e0c6330..1c61621b 100644 --- a/mx_driving/ops/deform_conv2d.py +++ b/mx_driving/ops/deform_conv2d.py @@ -9,6 +9,7 @@ Modification 1. 
Add support for Ascend NPU
 
 from typing import Tuple, Union
 
 import torch
+from torch.library import Library, impl
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
@@ -43,7 +44,7 @@ class DeformConv2dFunction(Function):
         nhwc_offset = offset.permute(0, 2, 3, 1).contiguous()
         nhwc_weight = weight.permute(0, 2, 3, 1).contiguous()
 
-        out, offset_output = mx_driving._C.deformable_conv2d(
+        out, offset_output = torch.ops.mx_driving.deformable_conv2d(
             nhwc_x,
             nhwc_offset,
             nhwc_weight,
@@ -63,7 +64,7 @@ class DeformConv2dFunction(Function):
     def backward(ctx, grad_out):
         nhwc_x, nhwc_offset, nhwc_weight, offset_output = ctx.saved_tensors
         nhwc_grad_out = grad_out.permute(0, 2, 1, 3).contiguous()
-        grad_x, grad_weight, grad_offset = mx_driving._C.deformable_conv2d_backward(
+        grad_x, grad_weight, grad_offset = torch.ops.mx_driving.deformable_conv2d_backward(
             nhwc_x,
             nhwc_weight,
             nhwc_offset,
@@ -89,3 +90,16 @@ class DeformConv2dFunction(Function):
 
 
 deform_conv2d = DeformConv2dFunction.apply
+
+m = Library("mx_driving", "IMPL", "Meta")
+
+@impl(m, "deformable_conv2d", "Meta")
+def deformable_conv2d_meta(nhwc_x, nhwc_offset, nhwc_weight, kernel_size, stride, padding, dilation, groups, deformable_groups):
+    n, h_in, w_in, c_in = nhwc_x.shape
+    h_out, w_out = nhwc_offset.shape[1], nhwc_offset.shape[2]
+    c_out, kh, kw = nhwc_weight.shape[0], nhwc_weight.shape[1], nhwc_weight.shape[2]
+    return torch.empty(n, h_out, c_out, w_out, device=nhwc_x.device, dtype=nhwc_x.dtype), torch.empty(n, h_out * w_out, groups, kh * kw, c_in // groups, device=nhwc_x.device, dtype=nhwc_x.dtype)
+
+@impl(m, "deformable_conv2d_backward", "Meta")
+def deformable_conv2d_backward_meta(nhwc_x, nhwc_weight, nhwc_offset, offset_output, nhwc_grad_out, kernel_size, stride, padding, dilation, groups, deformable_groups):
+    return torch.empty_like(nhwc_x), torch.empty_like(nhwc_weight), torch.empty_like(nhwc_offset)
\ No newline at end of file
diff --git a/mx_driving/ops/multi_scale_deformable_attn.py b/mx_driving/ops/multi_scale_deformable_attn.py
index 0e213489..c2494f9c 100644
--- a/mx_driving/ops/multi_scale_deformable_attn.py
+++ b/mx_driving/ops/multi_scale_deformable_attn.py
@@ -10,6 +10,7 @@ Modification 1. Add support for Ascend NPU
 import warnings
 
 import torch
+from torch.library import Library, impl
 from torch.autograd.function import Function, once_differentiable
 from torch.npu.amp import custom_bwd, custom_fwd
 import mx_driving._C
@@ -33,7 +34,7 @@ class MultiScaleDeformableAttnFunction(Function):
         sampling_locations = sampling_locations.type_as(value)
         attention_weights = attention_weights.type_as(value)
 
-        output = mx_driving._C.multi_scale_deformable_attn(
+        output = torch.ops.mx_driving.multi_scale_deformable_attn(
             value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
         )
         ctx.save_for_backward(
@@ -47,7 +48,7 @@ class MultiScaleDeformableAttnFunction(Function):
     # pylint: disable=too-many-return-values
     def backward(ctx, grad_output: torch.Tensor) -> tuple:
         value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = mx_driving._C.multi_scale_deformable_attn_backward(
+        grad_value, grad_sampling_loc, grad_attn_weight = torch.ops.mx_driving.multi_scale_deformable_attn_backward(
             value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output
         )
         return grad_value, None, None, grad_sampling_loc, grad_attn_weight
@@ -81,3 +82,14 @@ def npu_multi_scale_deformable_attn_function(value, shape, offset, locations, weight):
         DeprecationWarning,
     )
     return MultiScaleDeformableAttnFunction.apply(value, shape, offset, locations, weight)
+
+
+m = Library("mx_driving", "IMPL", "Meta")
+
+@impl(m, "multi_scale_deformable_attn", "Meta")
+def multi_scale_deformable_attn_meta(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights):
+    return torch.empty(value.shape[0], sampling_locations.shape[1], value.shape[2] * value.shape[3], device=value.device, dtype=value.dtype)
+
+@impl(m, "multi_scale_deformable_attn_backward", "Meta")
+def multi_scale_deformable_attn_backward_meta(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output):
+    return torch.empty_like(value), torch.empty_like(sampling_locations), torch.empty_like(attention_weights)
diff --git a/mx_driving/ops/npu_add_relu.py b/mx_driving/ops/npu_add_relu.py
index dbb66145..77d9ca75 100644
--- a/mx_driving/ops/npu_add_relu.py
+++ b/mx_driving/ops/npu_add_relu.py
@@ -8,6 +8,7 @@ Modification 1. 
Add support for Ascend NPU """ import torch +from torch.library import Library, impl import torch.nn.functional as F import torch_npu from torch.autograd import Function @@ -19,7 +20,7 @@ class AddReluFunction(Function): @staticmethod def forward(ctx, x, y): if x.numel() >= 2000000: - x = mx_driving._C.npu_add_relu(x, y) + x = torch.ops.mx_driving.npu_add_relu(x, y) else: x = F.relu(x + y) ctx.save_for_backward(x) @@ -28,8 +29,18 @@ class AddReluFunction(Function): @staticmethod def backward(ctx, grad_output): (x,) = ctx.saved_tensors - result = mx_driving._C.npu_add_relu_grad(x, grad_output) + result = torch.ops.mx_driving.npu_add_relu_grad(x, grad_output) return result, result npu_add_relu = AddReluFunction.apply + +m = Library("mx_driving", "IMPL", "Meta") + +@impl(m, "npu_add_relu", "Meta") +def custom_op_meta(x, y): + return torch.empty_like(x) + +@impl(m, "npu_add_relu_grad", "Meta") +def custom_op_meta(x, grad_output): + return torch.empty_like(x) diff --git a/mx_driving/ops/npu_batch_matmul.py b/mx_driving/ops/npu_batch_matmul.py index d8ed8b96..00bcfae3 100644 --- a/mx_driving/ops/npu_batch_matmul.py +++ b/mx_driving/ops/npu_batch_matmul.py @@ -3,6 +3,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. """ import torch +from torch.library import Library, impl import torch.nn.functional as F import torch_npu from torch.autograd import Function @@ -17,7 +18,7 @@ class BacthMatmulFunction(Function): broadcast_shape = [max(a, b) for a, b in zip(projection_mat.shape, pts_extend.shape)] projection_mat = projection_mat.expand(broadcast_shape).contiguous() pts_extend = pts_extend.expand(broadcast_shape).contiguous() - result = mx_driving._C.npu_batch_matmul(projection_mat, pts_extend) + result = torch.ops.mx_driving.npu_batch_matmul(projection_mat, pts_extend) result = result.sum(dim=-1, keepdim=True) ctx.save_for_backward(projection_mat, pts_extend) return result @@ -27,10 +28,16 @@ class BacthMatmulFunction(Function): (projection_mat, pts_extend) = ctx.saved_tensors broadcast_shape = projection_mat.shape grad = grad.expand(broadcast_shape).contiguous() - dx = mx_driving._C.npu_batch_matmul(grad, pts_extend) - dw = mx_driving._C.npu_batch_matmul(projection_mat, grad) + dx = torch.ops.mx_driving.npu_batch_matmul(grad, pts_extend) + dw = torch.ops.mx_driving.npu_batch_matmul(projection_mat, grad) dw = dw.sum(dim=-2, keepdim=True).transpose(-1, -2).contiguous() return dx, dw npu_batch_matmul = BacthMatmulFunction.apply + +m = Library("mx_driving", "IMPL", "Meta") + +@impl(m, "npu_batch_matmul", "Meta") +def custom_op_meta(projection_mat, pts_extend): + return torch.empty_like(pts_extend) diff --git a/mx_driving/ops/npu_deformable_aggregation.py b/mx_driving/ops/npu_deformable_aggregation.py index 1940802d..cf934363 100644 --- a/mx_driving/ops/npu_deformable_aggregation.py +++ b/mx_driving/ops/npu_deformable_aggregation.py @@ -1,5 +1,6 @@ import numpy as np import torch +from torch.library import Library, impl import torch_npu from torch.autograd import Function @@ -27,7 +28,7 @@ class AdsDeformableAggregation(Function): sampling_location = sampling_location.contiguous() weights = weights.contiguous() - output = mx_driving._C.npu_deformable_aggregation( + output = torch.ops.mx_driving.deformable_aggregation( mc_ms_feat, spatial_shape, scale_start_index, @@ -64,7 +65,7 @@ class AdsDeformableAggregation(Function): grad_mc_ms_feat = torch.zeros_like(mc_ms_feat) grad_sampling_location = torch.zeros_like(sampling_location) grad_weights = torch.zeros_like(weights) - 
grad_mc_ms_feat, grad_sampling_location, grad_weights = mx_driving._C.npu_deformable_aggregation_backward(
+        grad_mc_ms_feat, grad_sampling_location, grad_weights = torch.ops.mx_driving.deformable_aggregation_backward(
             mc_ms_feat,
             spatial_shape,
             scale_start_index,
             sampling_location,
             weights,
@@ -86,3 +87,13 @@ class AdsDeformableAggregation(Function):
 
 npu_deformable_aggregation = AdsDeformableAggregation.apply
 deformable_aggregation = AdsDeformableAggregation.apply
+
+m = Library("mx_driving", "IMPL", "Meta")
+
+@impl(m, "deformable_aggregation", "Meta")
+def deformable_aggregation_meta(mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights):
+    return torch.empty(mc_ms_feat.shape[0], sampling_location.shape[1], mc_ms_feat.shape[2], device=mc_ms_feat.device, dtype=mc_ms_feat.dtype)
+
+@impl(m, "deformable_aggregation_backward", "Meta")
+def deformable_aggregation_backward_meta(mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights):
+    return torch.empty_like(mc_ms_feat), torch.empty_like(sampling_location), torch.empty_like(weights)
\ No newline at end of file
diff --git a/mx_driving/ops/npu_max_pool2d.py b/mx_driving/ops/npu_max_pool2d.py
index fa4e72f0..50dace78 100644
--- a/mx_driving/ops/npu_max_pool2d.py
+++ b/mx_driving/ops/npu_max_pool2d.py
@@ -6,6 +6,8 @@ Modification date: 2024-06-04
 Modification Description:
 Modification 1. Add support for Ascend NPU
 """
+import torch
+from torch.library import Library, impl
 from torch.autograd import Function
 
 import mx_driving._C
@@ -14,7 +16,15 @@ class MaxPool2d(Function):
     @staticmethod
     # 'pylint: disable=too-many-arguments,huawei-too-many-arguments
     def forward(ctx, x, kernel_size, stride, padding):
-        y = mx_driving._C.npu_max_pool2d(x, kernel_size, stride, padding)
+        y = torch.ops.mx_driving.npu_max_pool2d(x, kernel_size, stride, padding)
         return y
 
 npu_max_pool2d = MaxPool2d.apply
+
+m = Library("mx_driving", "IMPL", "Meta")
+
+@impl(m, "npu_max_pool2d", "Meta")
+def npu_max_pool2d_meta(x, kernel_size, stride, padding):
+    height = (x.shape[2] + 1) // 2  # ceil division; must mirror the output size computed by the NPU kernel
+    width = (x.shape[3] + 1) // 2
+    return torch.empty(x.shape[0], x.shape[1], height, width, device=x.device, dtype=x.dtype)
--
Gitee
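
Usage note (not part of the patch): a minimal sketch of how the TORCH_LIBRARY_FRAGMENT schemas and the Meta shape functions registered in this change are exercised under torch.compile (Inductor graph mode). The wrapper function, tensor shapes, and the assumption that importing mx_driving loads both the C++ extension and the ops/*.py Meta registrations are illustrative, not code from this patch.

    import torch
    import torch_npu  # assumption: Ascend environment providing the .npu() device
    import mx_driving  # assumption: importing the package registers torch.ops.mx_driving.* and the Meta impls

    def fused_add_relu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # At run time this dispatches to the NPU kernel (PrivateUse1 impl);
        # while Dynamo/Inductor traces, the Meta impl from npu_add_relu.py
        # supplies the output shape and dtype, so no graph break is needed.
        return torch.ops.mx_driving.npu_add_relu(x, y)

    compiled = torch.compile(fused_add_relu)
    out = compiled(torch.randn(8, 16).npu(), torch.randn(8, 16).npu())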