diff --git a/kernels/op_host/draw_gaussian_to_heatmap.cpp b/kernels/op_host/draw_gaussian_to_heatmap.cpp index 1734615d3db205197727072886424347dfe83491..b608edf93610f19fcff30eaf7b4f8042f568d63a 100644 --- a/kernels/op_host/draw_gaussian_to_heatmap.cpp +++ b/kernels/op_host/draw_gaussian_to_heatmap.cpp @@ -30,12 +30,12 @@ static ge::graphStatus TilingFuncForDrawGaussianToHeatmap(gert::TilingContext* c CHECK_NULLPTR(centerIntPtr); auto centerIntShape = centerIntPtr->GetStorageShape(); uint32_t taskObj = centerIntShape.GetDim(1); - uint32_t coreTaskLen = Ceil(taskObj, coreNum); - uint32_t usedCoreNum = Ceil(taskObj, coreTaskLen); + uint32_t coreTaskLen = Ceil(numClasses, coreNum); + uint32_t usedCoreNum = Ceil(numClasses, coreTaskLen); uint32_t radiusLen = 1024 * 4 * sizeof(float); uint32_t singlePorcessCopyLen = (ubSize - RESERVED_UB_SIZE - radiusLen) / 4 / 5; singlePorcessCopyLen = AlignUp(singlePorcessCopyLen, 32); - uint32_t taskRepeatTimes = Ceil(coreTaskLen, singlePorcessCopyLen); + uint32_t taskRepeatTimes = Ceil(taskObj, singlePorcessCopyLen); context->SetBlockDim(usedCoreNum); tiling.set_coreTaskLen(coreTaskLen); tiling.set_numClasses(numClasses);