From 535972084722b1fdf40e45f242c1a6e9fd7db2a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=BA=E9=9B=A8=E6=9D=B0?=
Date: Mon, 19 May 2025 10:22:24 +0000
Subject: [PATCH 1/2] update kernels/op_kernel/subm_sparse_conv3d_v2.cpp.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 贺雨杰
---
 .../op_host/subm_sparse_conv3d_v2_tiling.cpp | 6 ++-
 .../op_host/subm_sparse_conv3d_v2_tiling.h | 1 +
 kernels/op_kernel/subm_sparse_conv3d_v2.cpp | 54 +++++++++++++++----
 mx_driving/csrc/SubmSparseCov3d.cpp | 2 +-
 mx_driving/ops/sparse_functional.py | 13 ++---
 5 files changed, 55 insertions(+), 21 deletions(-)

diff --git a/kernels/op_host/subm_sparse_conv3d_v2_tiling.cpp b/kernels/op_host/subm_sparse_conv3d_v2_tiling.cpp
index d772e6f7..7c417348 100644
--- a/kernels/op_host/subm_sparse_conv3d_v2_tiling.cpp
+++ b/kernels/op_host/subm_sparse_conv3d_v2_tiling.cpp
@@ -72,8 +72,11 @@ static ge::graphStatus TilingFunc(gert::TilingContext* context)
     uint32_t totalTaskCount = featureShapeArr.GetDim(TOTAL_TASK_DIM_IDX);
     uint32_t coreTaskCount = totalTaskCount / aivNum;
     uint32_t bigCoreCount = totalTaskCount % aivNum;
+    int32_t kernelSizeAligned = CeilAlign(static_cast(kernelSizeArr[KERNEL_SIZE_IDX_0] * kernelSizeArr[KERNEL_SIZE_IDX_1] *
+        kernelSizeArr[KERNEL_SIZE_IDX_2]), BYTE_ALIGN_SIZE / FLOAT_BYTE_SIZE);
     uint32_t singleLoopTask = ubSize / (SINGLE_LOOP_UB_SIZE +
-        CeilAlign(*inChannelsPtr, BYTE_ALIGN_SIZE / FLOAT_BYTE_SIZE) * FLOAT_BYTE_SIZE);
+        CeilAlign(*inChannelsPtr, BYTE_ALIGN_SIZE / FLOAT_BYTE_SIZE) * FLOAT_BYTE_SIZE +
+        CeilAlign(kernelSizeAligned, BYTE_ALIGN_SIZE / FLOAT_BYTE_SIZE) * FLOAT_BYTE_SIZE);
 
     tiling.set_k0(kernelSizeArr[KERNEL_SIZE_IDX_0]);
     tiling.set_k1(kernelSizeArr[KERNEL_SIZE_IDX_1]);
@@ -87,6 +90,7 @@ static ge::graphStatus TilingFunc(gert::TilingContext* context)
     tiling.set_coreTaskCount(coreTaskCount);
     tiling.set_bigCoreCount(bigCoreCount);
     tiling.set_singleLoopTask(singleLoopTask);
+    tiling.set_totalTaskCount(totalTaskCount);
 
     if (context->GetRawTilingData() == nullptr) {
         return ge::GRAPH_FAILED;
diff --git a/kernels/op_host/subm_sparse_conv3d_v2_tiling.h b/kernels/op_host/subm_sparse_conv3d_v2_tiling.h
index e7f354e4..3cfb9f3e 100644
--- a/kernels/op_host/subm_sparse_conv3d_v2_tiling.h
+++ b/kernels/op_host/subm_sparse_conv3d_v2_tiling.h
@@ -19,6 +19,7 @@ BEGIN_TILING_DATA_DEF(SubmSparseConv3dV2TilingData)
     TILING_DATA_FIELD_DEF(uint32_t, coreTaskCount);
     TILING_DATA_FIELD_DEF(uint32_t, bigCoreCount);
     TILING_DATA_FIELD_DEF(uint32_t, singleLoopTask);
+    TILING_DATA_FIELD_DEF(uint32_t, totalTaskCount);
 END_TILING_DATA_DEF;
 
 REGISTER_TILING_DATA_CLASS(SubmSparseConv3dV2, SubmSparseConv3dV2TilingData)
diff --git a/kernels/op_kernel/subm_sparse_conv3d_v2.cpp b/kernels/op_kernel/subm_sparse_conv3d_v2.cpp
index 9b01eb11..8e50a076 100644
--- a/kernels/op_kernel/subm_sparse_conv3d_v2.cpp
+++ b/kernels/op_kernel/subm_sparse_conv3d_v2.cpp
@@ -31,6 +31,7 @@ constexpr int32_t MAP2_OFFSET_1 = 1;
 constexpr int32_t MAP2_OFFSET_2 = 2;
 constexpr int32_t MAP2_OFFSET_3 = 3;
 constexpr int32_t MAP2_OFFSET_4 = 4;
+constexpr float SPARSE_THRESHOLD = 0.5;
 };
 
 class KernelSubmSparseConv3dV2 {
@@ -63,6 +64,7 @@ public:
         spatialShape12_ = spatialShape1_ * spatialShape2_;
         totalSpatialShape_ = (int64_t)spatialShape01_ * spatialShape2_;
         useTwolevelMap_ = totalSpatialShape_ * (int64_t)batchSize_ >= SPATIAL_SHAPE_THRESHOLD;
+        sparseRate = (tilingData->totalTaskCount + 0.0f) / (float)(totalSpatialShape_ * batchSize_);
         copyByteSize_ = inChannels_ * FLOAT_BYTE_SIZE;
 
         if (blkIdx_ < tilingData->bigCoreCount) {
@@ -99,6 +101,7 @@ public:
         pipe_->InitBuffer(tmpFeatureBuf_, singleLoopTaskAligned_ * inChannelsAligned_ * FLOAT_BYTE_SIZE);
         pipe_->InitBuffer(mapValBuf_, k0_ * k1_ * k2Aligned_ * INT32_BYTE_SIZE);
         pipe_->InitBuffer(mapValFloatBuf_, MAP_VAL_FLOAT_BUF_LENGTH * k0_ * k1_ * k2Aligned_ * FLOAT_BYTE_SIZE);
+        pipe_->InitBuffer(indicesOffsetBuf_, singleLoopTaskAligned_ * kernelSizeAligned_ * INT32_BYTE_SIZE);
 
         inputIndicesLocal_ = inputIndicesBuf_.Get();
         tmpFeatureLocal_ = tmpFeatureBuf_.Get();
@@ -111,6 +114,7 @@ public:
         mapValFloatLocal_ = mapValFloatBuf_.Get();
         mapValFloatLocalBak_ = mapValFloatLocal_[k0_ * k1_ * k2Aligned_];
         workLocal_ = mapValFloatLocal_[WORK_LOCAL_IDX * k0_ * k1_ * k2Aligned_];
+        indicesOffsetLocal_ = indicesOffsetBuf_.Get();
     }
 
     __aicore__ inline void Init(TPipe *pipe, GM_ADDR feature, GM_ADDR indices, GM_ADDR map1, GM_ADDR map2,
@@ -129,7 +133,6 @@ public:
         for (int32_t taskOffset = 0; taskOffset < coreTaskCount_;
             taskOffset += singleLoopTask_, globalTaskOffset_ += singleLoopTask_) {
             uint32_t taskCount = min(singleLoopTask_, coreTaskCount_ - taskOffset);
-            uint32_t copyInDataElemCount = AlignUp(taskCount * INDICES_TASK_SIZE, ALIGNED_BYTE_SIZE / FLOAT_BYTE_SIZE);
 
             PipeBarrier();
             DataCopyPad(tmpFeatureLocal_, inputFeatureGM_[globalTaskOffset_ * inChannels_],
@@ -139,9 +142,10 @@
                 {static_cast(taskCount), static_cast(inChannels_ * FLOAT_BYTE_SIZE), 0,
                 static_cast(kernelSize_ - 1) * inChannels_ * FLOAT_BYTE_SIZE, 0});
             // CopyIn
-            DataCopy(inputIndicesLocal_, indicesGM_[INDICES_TASK_SIZE * globalTaskOffset_], copyInDataElemCount);
+            DataCopyPad(inputIndicesLocal_, indicesGM_[globalTaskOffset_ * INDICES_TASK_SIZE],
+                {1, static_cast(INDICES_TASK_SIZE * taskCount * INT32_BYTE_SIZE), 0, 0, 0}, {false, 0, 0, 0});
             PipeBarrier();
-
+
             uint32_t mask = 0;
             uint64_t rsvdCnt = 0;
             uint16_t repeatTimes = Ceil(taskCount * 4, REPEAT_BYTE_SIZE / INT32_BYTE_SIZE);
@@ -149,6 +153,7 @@
             GatherMask(spatial0Local_, inputIndicesLocal_, SRC_PARTTEN_1, false, mask, { 1, repeatTimes, 8, 0 }, rsvdCnt);
             GatherMask(spatial1Local_, inputIndicesLocal_, SRC_PARTTEN_2, false, mask, { 1, repeatTimes, 8, 0 }, rsvdCnt);
             GatherMask(spatial2Local_, inputIndicesLocal_, SRC_PARTTEN_3, false, mask, { 1, repeatTimes, 8, 0 }, rsvdCnt);
+            Duplicate(indicesOffsetLocal_, static_cast(-1), singleLoopTaskAligned_ * kernelSizeAligned_);
 
             Adds(spatial0Local_, spatial0Local_, - halfk0_, taskCount);
             Adds(spatial1Local_, spatial1Local_, - halfk0_, taskCount);
@@ -157,8 +162,15 @@
             if (useTwolevelMap_) {
                 ProcessOneLoopForTwoLevelMap(taskOffset, taskCount);
             } else {
-                ProcessOneLoopForOneLevelMap(taskOffset, taskCount);
+                if (sparseRate < SPARSE_THRESHOLD) {
+                    ProcessOneLoopForOneLevelMap(taskOffset, taskCount);
+                } else {
+                    ProcessOneLoopForOneLevelMapDense(taskOffset, taskCount);
+                }
             }
+
+            DataCopyPad(indicesOffsetGM_[globalTaskOffset_ * kernelSize_], indicesOffsetLocal_,
+                {static_cast(taskCount), static_cast(kernelSize_ * INT32_BYTE_SIZE), 0, 0, 0});
 
         }
     }
@@ -168,20 +180,22 @@
             return;
         }
 
-        int32_t outputIdx = (globalTaskOffset_ + i) * kernelSize_ + k0Idx * k12_ + k1Idx * k2_ + k2Idx;
-        indicesOffsetGM_.SetValue(outputIdx, map1Val);
+        int32_t innerKernelOffset = k0Idx * k12_ + k1Idx * k2_ + k2Idx;
+        int32_t indicesOffsetOutputIdx = i * kernelSizeAligned_ + innerKernelOffset;
+        indicesOffsetLocal_.SetValue(indicesOffsetOutputIdx, map1Val);
+
         if (k0Idx == halfk0_ && k1Idx == halfk0_ && k2Idx == halfk0_) {
             return;
         }
 
         DataCopyPad(tmpFeatureLocal_[copyInOffset_ * inChannelsAligned_], inputFeatureGM_[map1Val * inChannels_],
             {1, copyByteSize_, 0, 0, 0}, {false, 0, 0, 0});
-
+
         SetFlag(eventMTE2ToMTE3_);
         WaitFlag(eventMTE2ToMTE3_);
 
-        DataCopyPad(outputFeatureGM_[outputIdx * inChannels_], tmpFeatureLocal_[copyInOffset_ * inChannelsAligned_],
+        DataCopyPad(outputFeatureGM_[((globalTaskOffset_ + i) * kernelSize_ + innerKernelOffset) * inChannels_], tmpFeatureLocal_[copyInOffset_ * inChannelsAligned_],
             {1, copyByteSize_, 0, 0, 0});
 
         copyInOffset_ = (copyInOffset_ + 1) % tmpBufLength_;
@@ -223,6 +237,25 @@
         }
     }
 
+    __aicore__ inline void ProcessOneLoopForOneLevelMapDense(int32_t taskOffset, uint32_t taskCount)
+    {
+        for (int16_t i = 0; i < taskCount; i++) {
+            int16_t batchIdx = batchIdxLocal_.GetValue(i);
+            int16_t spatial0BaseIdx = spatial0Local_.GetValue(i);
+            int16_t spatial1BaseIdx = spatial1Local_.GetValue(i);
+            int16_t spatial2BaseIdx = spatial2Local_.GetValue(i);
+
+            for (int16_t k = 0; k < kernelSize_; k++) {
+                int8_t k2Idx = k % k2_;
+                int8_t k1Idx = (k % k12_) / k2_;
+                int8_t k0Idx = k / k12_;
+
+                ProcessOnePoint(i, k0Idx, k1Idx, k2Idx, map1GM_.GetValue(batchIdx * totalSpatialShape_ + (spatial0BaseIdx + k0Idx) * spatialShape12_ +
+                    (spatial1BaseIdx + k1Idx) * spatialShape2_ + (spatial2BaseIdx + k2Idx)));
+            }
+        }
+    }
+
     __aicore__ inline void ProcessOneLoopForTwoLevelMap(int32_t taskOffset, uint32_t taskCount)
     {
         for (int16_t i = 0; i < taskCount; i++) {
@@ -276,12 +309,13 @@ private:
         spatialShape2_, spatialShape01_, spatialShape12_, coreTaskCount_, singleLoopTask_, singleLoopTaskAligned_,
         globalTaskOffset_, inChannelsAligned_, copyInOffset_, kernelSizeAligned_, outputOneLineElementCount_, outputHalfLineElementCount_;
     int32_t eventMTE2ToMTE3_;
+    float sparseRate;
     int64_t totalSpatialShape_;
 
     GlobalTensor inputFeatureGM_, outputFeatureGM_;
     GlobalTensor indicesGM_, map1GM_, map2GM_, indicesOffsetGM_;
     LocalTensor tmpFeatureLocal_, mapValFloatLocal_, mapValFloatLocalBak_, workLocal_;
-    LocalTensor inputIndicesLocal_, batchIdxLocal_, spatial0Local_, spatial1Local_, spatial2Local_, mapValLocal_;
-    TBuf inputIndicesBuf_, totalIndicesBuf_, tmpFeatureBuf_, mapValBuf_, mapValFloatBuf_;
+    LocalTensor inputIndicesLocal_, batchIdxLocal_, spatial0Local_, spatial1Local_, spatial2Local_, mapValLocal_, indicesOffsetLocal_;
+    TBuf inputIndicesBuf_, totalIndicesBuf_, tmpFeatureBuf_, mapValBuf_, mapValFloatBuf_, indicesOffsetBuf_;
     TPipe* pipe_;
 };
diff --git a/mx_driving/csrc/SubmSparseCov3d.cpp b/mx_driving/csrc/SubmSparseCov3d.cpp
index b142753c..d16b5700 100644
--- a/mx_driving/csrc/SubmSparseCov3d.cpp
+++ b/mx_driving/csrc/SubmSparseCov3d.cpp
@@ -78,7 +78,7 @@ std::tuple npu_subm_sparse_conv3d_v2(const at::Tensor& f
     c10::SmallVector indices_out_size = {outputsum};
 
     at::Tensor feature_out = at::zeros(output_size, feature.options());
-    at::Tensor indices_offset = at::empty(indices_out_size, feature.options().dtype(at::kInt)).fill_(-1);
+    at::Tensor indices_offset = at::empty(indices_out_size, feature.options().dtype(at::kInt));
 
     EXEC_NPU_CMD(aclnnSubmSparseConv3dV2, feature, indices, map1, map2, kernel_size, in_channels,
         out_spatial_shape, batch_size, feature_out, indices_offset);
diff --git a/mx_driving/ops/sparse_functional.py b/mx_driving/ops/sparse_functional.py
index 5b9def3a..dd219aac 100644
--- a/mx_driving/ops/sparse_functional.py
+++ b/mx_driving/ops/sparse_functional.py
@@ -87,15 +87,10 @@ def generate_map(coors, origin_spatial_shape, bs, kernel_size):
     new_coors1 = spatial_shape1 * coors[:, 0] + spatial_shape[1] * coors[:, 1] + coors[:, 2]
     map1 = torch.full((spatial_shape1 *
         bs, ), -1, dtype=torch.int32, device=coors.device)
-    if bs * spatial_shape1 > 2**24:
-        map1[new_coors1] = 1
-        mask = (map1 != -1)
-        map1_length = map1[mask].shape[0]
-        map1[mask] = torch.arange(map1_length, dtype=torch.int32, device=coors.device)
-    else:
-        coors_unique = mx_driving.npu_unique(new_coors1)
-        map1_length = coors_unique.shape[0]
-        map1[coors_unique] = torch.arange(map1_length, dtype=torch.int32, device=coors.device)
+    map1[new_coors1] = 1
+    mask = (map1 != -1)
+    map1_length = map1[mask].shape[0]
+    map1[mask] = torch.arange(map1_length, dtype=torch.int32, device=coors.device)
 
     map2 = torch.full((map1_length, spatial_shape[2]), -1, dtype=torch.int32, device=coors.device)
     map2[map1[new_coors1], coors[:, 3]] = torch.arange(new_coors1.numel(), dtype=torch.int32, device=coors.device)
-- 
Gitee

From 5529a4c9e8697d917f83fd5c40d0441932205aeb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=BA=E9=9B=A8=E6=9D=B0?=
Date: Wed, 21 May 2025 09:34:54 +0800
Subject: [PATCH 2/2] update DeformableDETR requirements

---
 model_examples/Deformable-DETR/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model_examples/Deformable-DETR/requirements.txt b/model_examples/Deformable-DETR/requirements.txt
index b0f2650c..0b6909ef 100644
--- a/model_examples/Deformable-DETR/requirements.txt
+++ b/model_examples/Deformable-DETR/requirements.txt
@@ -2,7 +2,7 @@ pyyaml
 torchvision==0.16.0
 tqdm
 pycocotools
-numpy==1.22
+numpy==1.23
 scipy
 decorator
 attrs
-- 
Gitee