diff --git a/kernels/op_kernel/subm_sparse_conv3d_v2.cpp b/kernels/op_kernel/subm_sparse_conv3d_v2.cpp
index 1b1ddc998b24ca1dba197dd2e3190ffb4f72109c..337c70538bdcfad13041dbefffd7c04bba025afc 100644
--- a/kernels/op_kernel/subm_sparse_conv3d_v2.cpp
+++ b/kernels/op_kernel/subm_sparse_conv3d_v2.cpp
@@ -19,6 +19,7 @@ constexpr uint8_t SRC_PARTTEN_1 = 4;
 constexpr uint8_t SRC_PARTTEN_2 = 5;
 constexpr uint8_t SRC_PARTTEN_3 = 6;
 constexpr uint8_t MAP_VAL_FLOAT_BUF_LENGTH = 3;
+constexpr uint8_t K2_SIZE_1 = 1;
 constexpr uint8_t K2_SIZE_3 = 3;
 constexpr uint8_t K2_SIZE_5 = 5;
 constexpr int8_t K2_IDX_0 = 0;
@@ -254,7 +255,9 @@ public:
         for (int8_t k0Idx = 0; k0Idx < k0_; k0Idx++) {
             innerKernelOffset = k0Idx * k1_ * k2Aligned_;
             for (int8_t k1Idx = 0; k1Idx < k1_; k1Idx++) {
-                if (k2_ == K2_SIZE_3) {
+                if (k2_ == K2_SIZE_1) {
+                    ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_0, mapValLocal_.GetValue(innerKernelOffset));
+                } else if (k2_ == K2_SIZE_3) {
                     ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_0, mapValLocal_.GetValue(innerKernelOffset));
                     ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_1, mapValLocal_.GetValue(innerKernelOffset + MAP2_OFFSET_1));
                     ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_2, mapValLocal_.GetValue(innerKernelOffset + MAP2_OFFSET_2));
@@ -302,7 +305,9 @@ public:
             int8_t k1Idx = mapIdx % k1Aligned_;
             int32_t map2Offset = map1Val * spatialShape2_ + spatial2BaseIdx;

-            if (k2_ == K2_SIZE_3) {
+            if (k2_ == K2_SIZE_1) {
+                ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_0, map2GM_.GetValue(map2Offset));
+            } else if (k2_ == K2_SIZE_3) {
                 ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_0, map2GM_.GetValue(map2Offset));
                 ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_1, map2GM_.GetValue(map2Offset + MAP2_OFFSET_1));
                 ProcessOnePoint(i, k0Idx, k1Idx, K2_IDX_2, map2GM_.GetValue(map2Offset + MAP2_OFFSET_2));
diff --git a/mx_driving/csrc/MultiScaleDeformableAttn.cpp b/mx_driving/csrc/MultiScaleDeformableAttn.cpp
index 8d54973a1aaa3a789888433c54cbe3d1ccba670a..3d65199eaf2446f6fd7fee387a2b58d130703756 100644
--- a/mx_driving/csrc/MultiScaleDeformableAttn.cpp
+++ b/mx_driving/csrc/MultiScaleDeformableAttn.cpp
@@ -101,12 +101,12 @@ std::tuple multi_scale_deformable_attn_backw
     }

     if (ASCEND_UNLIKELY(value.scalar_type() == at::kHalf)) {
-        at::Tensor grad_value_fp32 = grad_value.to(at::kFloat);
+        at::Tensor grad_output_fp32 = grad_output.to(at::kFloat);
         at::Tensor value_fp32 = value.to(at::kFloat);
         at::Tensor sampling_locations_fp32 = sampling_locations.to(at::kFloat);
         at::Tensor attention_weights_fp32 = attention_weights.to(at::kFloat);
         EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnGrad, value_fp32, value_spatial_shapes, value_level_start_index,
-            sampling_locations_fp32, attention_weights_fp32, grad_value_fp32, grad_value, grad_sampling_loc,
+            sampling_locations_fp32, attention_weights_fp32, grad_output_fp32, grad_value, grad_sampling_loc,
             grad_attn_weight);
         return std::make_tuple(
             grad_value.to(at::kHalf), grad_sampling_loc.to(at::kHalf), grad_attn_weight.to(at::kHalf));
diff --git a/mx_driving/csrc/SubmSparseCov3d.cpp b/mx_driving/csrc/SubmSparseCov3d.cpp
index 2940075be7e212627d29ff1fef51b11c6741e55f..779058ad41dc7b6ce9e62b6f2963ccd1cd77d3ba 100644
--- a/mx_driving/csrc/SubmSparseCov3d.cpp
+++ b/mx_driving/csrc/SubmSparseCov3d.cpp
@@ -19,6 +19,7 @@ namespace {
 constexpr size_t TOTAL_CAPACITY = 8;
+constexpr uint8_t KERNEL_SIZE_1 = 1;
 constexpr uint8_t KERNEL_SIZE_3 = 3;
 constexpr uint8_t KERNEL_SIZE_5 = 5;
 constexpr uint32_t KERNEL_SIZE_IDX_0 = 0;
@@ -62,9 +63,10 @@ std::tuple npu_subm_sparse_conv3d_v2(const at::Tensor& f
 {
     TORCH_CHECK_NPU(feature);
    TORCH_CHECK_NPU(indices);
-    TORCH_CHECK((kernel_size[KERNEL_SIZE_IDX_0] == KERNEL_SIZE_3 && kernel_size[KERNEL_SIZE_IDX_1] == KERNEL_SIZE_3 && kernel_size[KERNEL_SIZE_IDX_2] == KERNEL_SIZE_3) ||
+    TORCH_CHECK((kernel_size[KERNEL_SIZE_IDX_0] == KERNEL_SIZE_1 && kernel_size[KERNEL_SIZE_IDX_1] == KERNEL_SIZE_1 && kernel_size[KERNEL_SIZE_IDX_2] == KERNEL_SIZE_1) ||
+        (kernel_size[KERNEL_SIZE_IDX_0] == KERNEL_SIZE_3 && kernel_size[KERNEL_SIZE_IDX_1] == KERNEL_SIZE_3 && kernel_size[KERNEL_SIZE_IDX_2] == KERNEL_SIZE_3) ||
        (kernel_size[KERNEL_SIZE_IDX_0] == KERNEL_SIZE_5 && kernel_size[KERNEL_SIZE_IDX_1] == KERNEL_SIZE_5 && kernel_size[KERNEL_SIZE_IDX_2] == KERNEL_SIZE_5),
-        "kernel size current only support (3, 3, 3) and (5, 5, 5) but got: (",
+        "kernel size currently only supports (1, 1, 1), (3, 3, 3) and (5, 5, 5), but got: (",
        kernel_size[KERNEL_SIZE_IDX_0], ", ", kernel_size[KERNEL_SIZE_IDX_1], ", ", kernel_size[KERNEL_SIZE_IDX_2], ")");
    auto indices_size = indices.sizes();
diff --git a/tests/torch/test_multi_scale_deformable_attn.py b/tests/torch/test_multi_scale_deformable_attn.py
index 1265bac41ac8f0cf23ab690042866712c663b61b..9d7491a8dc7f74b3b6f8a3bafb17a9e2a2cc7b45 100644
--- a/tests/torch/test_multi_scale_deformable_attn.py
+++ b/tests/torch/test_multi_scale_deformable_attn.py
@@ -202,6 +202,16 @@ class TestMultiScaleDeformableAttnFunction(TestCase):
         self.assertRtolEqual(cpu_results.grad_attention_weights, npu_results.grad_attention_weights)
         self.assertRtolEqual(cpu_results.grad_sampling_locations, npu_results.grad_sampling_locations)

+    def test_fp16(self):
+        shape = [6, 9680, 32, 8, 4, 4]
+        cpu_inputs, npu_inputs = self.gen_inputs(shape, torch.float16)
+        cpu_results = self.cpu_to_exec(cpu_inputs)
+        npu_results = self.npu_to_exec(npu_inputs)
+        self.assertRtolEqual(cpu_results.output, npu_results.output)
+        self.assertRtolEqual(cpu_results.grad_value, npu_results.grad_value)
+        self.assertRtolEqual(cpu_results.grad_attention_weights, npu_results.grad_attention_weights)
+        self.assertRtolEqual(cpu_results.grad_sampling_locations, npu_results.grad_sampling_locations)
+

 if __name__ == "__main__":
     run_tests()
diff --git a/tests/torch/test_subm_sparse_conv3d.py b/tests/torch/test_subm_sparse_conv3d.py
index 741f7ecde737ddb7de34a8de9233699f2c1443a2..4cf778c12eac6ac6a51b24e7966369a9594b98c9 100644
--- a/tests/torch/test_subm_sparse_conv3d.py
+++ b/tests/torch/test_subm_sparse_conv3d.py
@@ -254,6 +254,41 @@ class TestSubmSparseConv3d(TestCase):
         res, golden = get_output(num_points, batch_size, in_channels, out_channels, kernel_size, out_spatial_shape, dtype)
         self.assertRtolEqual(golden, res, 1e-3, 1e-3)

+    def test_1x1_small_spatial_shape(self):
+        num_points = [20000]
+        out_spatial_shape = [1180, 180, 5]
+        in_channels = 32
+        out_channels = 64
+        kernel_size = 1
+        batch_size = len(num_points)
+
+        res, golden = get_output(num_points, batch_size, in_channels, out_channels, kernel_size, out_spatial_shape)
+        self.assertRtolEqual(golden, res, 1e-3, 1e-3)
+
+    def test_1x1_small_spatial_shape_fp16(self):
+        num_points = [20000]
+        out_spatial_shape = [1180, 180, 5]
+        in_channels = 32
+        out_channels = 64
+        kernel_size = 1
+        batch_size = len(num_points)
+        dtype = torch.float16
+
+        res, golden = get_output(num_points, batch_size, in_channels, out_channels, kernel_size, out_spatial_shape, dtype)
+        self.assertRtolEqual(golden, res, 1e-3, 1e-3)
+
+    def test_1x1_large_spatial_shape_fp16(self):
+        num_points = [10000]
+        out_spatial_shape = [3280, 2480, 500]
+        in_channels = 128
+        out_channels = 256
+        kernel_size = 1
+        batch_size = len(num_points)
+        dtype = torch.float16
+
+        res, golden = get_output(num_points, batch_size, in_channels, out_channels, kernel_size, out_spatial_shape, dtype)
+        self.assertRtolEqual(golden, res, 1e-3, 1e-3)
+
 if __name__ == "__main__":
     np.random.seed(100)
     run_tests()
\ No newline at end of file
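
For context on the kernel_size = 1 path added above: a 1x1x1 submanifold sparse convolution only touches the centre kernel offset (0, 0, 0), so each active voxel's output is its feature vector times a single [in_channels, out_channels] weight slice and no neighbour gathering is needed, which is why the K2_SIZE_1 branch issues a single ProcessOnePoint call per (k0Idx, k1Idx) iteration. A minimal reference sketch in plain PyTorch; the function name and the assumed [1, 1, 1, in_channels, out_channels] weight layout are illustrative and are not the get_output helper used in the tests:

```python
import torch


def subm_conv3d_1x1x1_reference(features: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Illustrative golden reference, not the project's test helper.

    With a 1x1x1 kernel the only offset is (0, 0, 0), so every active voxel is
    transformed independently: out[p] = features[p] @ W.
    features: [num_active_points, in_channels]
    weight:   [1, 1, 1, in_channels, out_channels]  (assumed dense layout)
    """
    in_channels, out_channels = weight.shape[-2], weight.shape[-1]
    return features @ weight.reshape(in_channels, out_channels)


if __name__ == "__main__":
    # Sizes mirror test_1x1_small_spatial_shape_fp16: 20000 points, 32 -> 64 channels.
    feats = torch.randn(20000, 32, dtype=torch.float16)
    w = torch.randn(1, 1, 1, 32, 64, dtype=torch.float16)
    out = subm_conv3d_1x1x1_reference(feats, w)
    print(out.shape)  # torch.Size([20000, 64])
```

Because no spatial neighbours are gathered, the indices tensor only determines which voxels are active, not how features are combined; the spatial-shape variations in the new tests exercise the index-map bookkeeping rather than the arithmetic.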