diff --git a/src/akg_reduce/README.md b/src/akg_reduce/README.md index 081e4a58aa9e39b575abc7335395bdb3cd5a64d7..2995dd417eccc0302731810d67acb5c12f0c009b 100644 --- a/src/akg_reduce/README.md +++ b/src/akg_reduce/README.md @@ -27,7 +27,7 @@ __global__ void Reduce1DMultiBlock(int x_len, T *arr, T *output, int item_per_th } } __syncthreads(); - AkgReduce(op, &temp_output[0], red_buf, acc); + AkgReduce(op, &temp_output[0], red_buf, acc, 32); __syncthreads(); if (threadIdx.x == 0) { AkgAtomicReturn(temp_output[0], &output[0], op); @@ -37,6 +37,9 @@ __global__ void Reduce1DMultiBlock(int x_len, T *arr, T *output, int item_per_th ## 4. Updates +### 2021.8.20 +- Fix bugs when using the "shfl.down" function in irregular cases. "shfl.down" is a warp-based function, but irregular reduction algorithms break this rule. For safe reduction, we use the volatile shared-memory algorithm instead. + ### 2021.8.16 - Support ProdOp, AtomicProd. Now you can use akg-reduce-lib to implement a ReduceProd kernel. diff --git a/src/akg_reduce/algorithm/shared_reduce.cuh b/src/akg_reduce/algorithm/shared_reduce.cuh index ebf96a49fb08bf14f19adc308665aa8e882b8610..29518e213ec2b1b4b628060200ed2355334b967f 100644 --- a/src/akg_reduce/algorithm/shared_reduce.cuh +++ b/src/akg_reduce/algorithm/shared_reduce.cuh @@ -25,7 +25,8 @@ namespace akg_reduce { /** - * \brief Reduction in a warp using shfl functions. + * \brief Reduction in a warp using shfl functions. This func isn't safe when BlockDimX + * isn't 2^X (X >= 0). * * \par * - Supports 1D or 2D reduction computation. The reduction direction is along x-axis. 
@@ -37,7 +38,7 @@ namespace akg_reduce { * \tparam T Dtype of reduction **/ template -__device__ __forceinline__ void WarpReduce(T *shared_buf, // Shared memory buffer +__device__ __forceinline__ void WarpReduceShfl(T *shared_buf, // Shared memory buffer const ReduceOp op, // Reduce operator const int tx = 0, // Real threadIdx.x const int ty = 0 // Real threadIdx.y @@ -65,7 +66,8 @@ __device__ __forceinline__ void WarpReduce(T *shared_buf, // Shared memory b } /** - * \brief Reduction in a warp for one btye dtype. + * \brief Reduction in a warp for all dtypes using the volatile shared-memory algorithm. + * This func is safe for all cases but a little bit slower than WarpReduceShfl. * * \par * - Supports 1D or 2D reduction computation. The reduction direction is along x-axis. @@ -74,34 +76,38 @@ __device__ __forceinline__ void WarpReduce(T *shared_buf, // Shared memo * \tparam BlockDimX Real blockDim.x * \tparam T Dtype of reduction **/ -template -__device__ __forceinline__ void WarpReduceOneByte(T *shared_buf, // Shared memory buffer +template +__device__ __forceinline__ void WarpReduceSafe(T *shared_buf, // Shared memory buffer const ReduceOp op, // Reduce operator const int tx = 0, // Real threadIdx.x const int ty = 0 // Real threadIdx.y ) { const int tid = ty * BlockDimX + tx; - if (BlockDimX >= 32) { + if (UpperBound >= 32) { if (tx < 16) ((volatile T *)shared_buf)[tid] = op(((volatile T *)shared_buf)[tid], ((volatile T *)shared_buf)[tid + 16]); } - if (BlockDimX >= 16) { + __syncthreads(); + if (UpperBound >= 16) { if (tx < 8) ((volatile T *)shared_buf)[tid] = op(((volatile T *)shared_buf)[tid], ((volatile T *)shared_buf)[tid + 8]); } - if (BlockDimX >= 8) { + __syncthreads(); + if (UpperBound >= 8) { if (tx < 4) ((volatile T *)shared_buf)[tid] = op(((volatile T *)shared_buf)[tid], ((volatile T *)shared_buf)[tid + 4]); } - if (BlockDimX >= 4) { + __syncthreads(); + if (UpperBound >= 4) { if (tx < 2) ((volatile T *)shared_buf)[tid] = op(((volatile T *)shared_buf)[tid], 
((volatile T *)shared_buf)[tid + 2]); } - if (BlockDimX >= 2) { + __syncthreads(); + if (UpperBound >= 2) { if (tx < 1) ((volatile T *)shared_buf)[tid] = op(((volatile T *)shared_buf)[tid], ((volatile T *)shared_buf)[tid + 1]); @@ -116,52 +122,53 @@ __device__ __forceinline__ void WarpReduceOneByte(T *shared_buf, // Shared me * * \tparam ReduceOp Reduce operator type * \tparam BlockDimX Real blockDim.x + * \tparam UpperBound Length of x-axis after current irregular shape * \tparam T Dtype of reduction * **/ -template +template __device__ __forceinline__ void ReduceXInBlock(T *shared_buf, // Shared memory buffer. const ReduceOp op, // Reduce operator. const int tx = 0, // Real threadIdx.x const int ty = 0 // Real threadIdx.y ) { const int tid = ty * BlockDimX + tx; - if (BlockDimX >= 1024) { + if (UpperBound >= 1024) { if (tx < 512) { shared_buf[tid] = op(shared_buf[tid], shared_buf[tid + 512]); } __syncthreads(); } - if (BlockDimX >= 512) { + if (UpperBound >= 512) { if (tx < 256) { shared_buf[tid] = op(shared_buf[tid], shared_buf[tid + 256]); } __syncthreads(); } - if (BlockDimX >= 256) { + if (UpperBound >= 256) { if (tx < 128) { shared_buf[tid] = op(shared_buf[tid], shared_buf[tid + 128]); } __syncthreads(); } - if (BlockDimX >= 128) { + if (UpperBound >= 128) { if (tx < 64) { shared_buf[tid] = op(shared_buf[tid], shared_buf[tid + 64]); } __syncthreads(); } - if (BlockDimX >= 64) { + if (UpperBound >= 64) { if (tx < 32) { shared_buf[tid] = op(shared_buf[tid], shared_buf[tid + 32]); } } - if (tx < 32) { - // choose proper algorithm for different dtype. - if (sizeof(T) == 1) { - WarpReduceOneByte(shared_buf, op, tx, ty); - } else { - WarpReduce(shared_buf, op, tx, ty); - } + __syncthreads(); + // choose proper algorithm for different scenarios. 
+ if (BlockDimX == UpperBound) { + if (tx < 32) + WarpReduceShfl(shared_buf, op, tx, ty); + } else { + WarpReduceSafe(shared_buf, op, tx, ty); } __syncthreads(); } @@ -276,7 +283,7 @@ __device__ __forceinline__ void HalvedReduce1D(T *shared_buf, // Shared memo const int tid = tx; if (IsPowOfTwo(BlockDimX)) { // Using unroll strategy. - ReduceXInBlock(shared_buf, op, tx, ty); + ReduceXInBlock(shared_buf, op, tx, ty); } else { constexpr int UpperBound = GetUpperBound(BlockDimX); @@ -286,7 +293,7 @@ __device__ __forceinline__ void HalvedReduce1D(T *shared_buf, // Shared memo } __syncthreads(); - ReduceXInBlock(shared_buf, op, tx, ty); + ReduceXInBlock(shared_buf, op, tx, ty); } } @@ -318,7 +325,7 @@ __device__ __forceinline__ void HalvedReduce2DX(T *shared_buf, // Shared mem __syncthreads(); if (IsPowOfTwo(BlockDimX)) { - ReduceXInBlock(shared_buf, op, tx, ty); + ReduceXInBlock(shared_buf, op, tx, ty); } else { constexpr int UpperBound = GetUpperBound(BlockDimX); @@ -327,7 +334,7 @@ __device__ __forceinline__ void HalvedReduce2DX(T *shared_buf, // Shared mem } __syncthreads(); - ReduceXInBlock(shared_buf, op, tx, ty); + ReduceXInBlock(shared_buf, op, tx, ty); } }