From 95aa4276f4ae7e6fbb0ee1255e83b3de1ddee469 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=A8=E5=AD=A6=E6=96=8C?=
Date: Thu, 18 Jul 2024 02:34:19 +0000
Subject: [PATCH] layernorm optimize
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 杨学斌
---
Review notes (this section sits after the "---" separator and is not part of
the commit message):

* Purpose of the hunk: LayerNormReduceSumImpl drops the per-row
  "while (totalNum > 1)" reduction. For hLength > ONE_REPEAT_FLOAT_SIZE it
  first folds every further 64-element chunk of a row (and any tail of
  hLength & 0x3f elements) into the first chunk of that row with Add, then
  issues WholeReduceSum with a repeat count covering many rows at once, in
  batches of MAX_REPEAT_TIMES plus one remainder batch.

* Fix applied in this revision: the tail operand was written as
  "src[i + hLength & 0xFFFFFFC0]". In C++ "+" binds tighter than "&", so that
  is "(i + hLength) & 0xFFFFFFC0", which only equals the intended offset for
  the first row (i == 0). Example: hLength = 72, i = 72 gives 128 instead of
  136. Both tail branches now use "i + (hLength & 0xFFFFFFC0)".

* NOTE(review): this copy of the patch had lost its line breaks and
  indentation. The layout below is reconstructed from the hunk counts
  (34 old / 77 new lines); check leading whitespace and line wrapping against
  the target file before applying.

* NOTE(review): text inside angle brackets appears to be missing from this
  copy. The "template" keyword has no parameter list although the removed
  lines test isRelocate and isTransposeDst, LocalTensor has no element type,
  and the From/Signed-off-by lines carry no address. Restore these from the
  original patch; they are left untouched here.

* NOTE(review): the added lines no longer reference isTransposeDst. Confirm
  that the transposed-destination case is still handled as intended.

 .../layernorm/layernorm_common_impl.h         | 89 ++++++++++++++-----
 1 file changed, 66 insertions(+), 23 deletions(-)

diff --git a/impl/normalization/layernorm/layernorm_common_impl.h b/impl/normalization/layernorm/layernorm_common_impl.h
index 5ab03082..b5a191cb 100644
--- a/impl/normalization/layernorm/layernorm_common_impl.h
+++ b/impl/normalization/layernorm/layernorm_common_impl.h
@@ -24,34 +24,77 @@ template
 __aicore__ inline void LayerNormReduceSumImpl(const LocalTensor& dstMVTmp, const LocalTensor& dst,
     const LocalTensor& src, const uint32_t bsLength, const uint32_t hLength)
 {
-    for (uint32_t i = 0; i < bsLength; i++) {
-        uint32_t totalNum = hLength;
-        LocalTensor srcTmp = src[i * hLength];
-        LocalTensor dstTmp = dst[i * hLength];
-
-        while (totalNum > 1) {
-            SetVectorMask(0, totalNum);
-
-            if (totalNum <= ONE_REPEAT_FLOAT_SIZE) {
-                if constexpr (isRelocate) {
-                    WholeReduceSum(dstMVTmp[i], srcTmp, MASK_PLACEHOLDER, 1, DEFAULT_BLK_STRIDE,
-                        DEFAULT_BLK_STRIDE, DEFAULT_REPEAT_STRIDE);
-                    PipeBarrier();
-                }
-
-                if constexpr (isTransposeDst) {
-                    dstTmp = dst[i];
-                }
+    ResetMask();
+    SetMaskNorm();
+    constexpr uint32_t rightShiftSix = 6;
+    if (hLength > ONE_REPEAT_FLOAT_SIZE) {
+        uint32_t addRepeatTime = (hLength >> rightShiftSix) - 1;
+        uint32_t addTailNumber = (hLength & 0x3f);
+        if ((hLength & 0x3F) == 0) {
+            for (uint32_t i = 0; i < bsLength * hLength; i += hLength) {
+                LocalTensor dstTmp = src[i];
+                LocalTensor srcTmp = src[i + ONE_REPEAT_FLOAT_SIZE];
+                Add(dstTmp, srcTmp, dstTmp, ONE_REPEAT_FLOAT_SIZE, addRepeatTime,
+                    { DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, 0, DEFAULT_REPEAT_STRIDE, 0 });
+                PipeBarrier();
+            }
+        } else if (addRepeatTime > 0) {
+            for (uint32_t i = 0; i < bsLength * hLength; i += hLength) {
+                LocalTensor dstTmp = src[i];
+                LocalTensor srcTmp = src[i + ONE_REPEAT_FLOAT_SIZE];
+                LocalTensor srcTailTmp = src[i + (hLength & 0xFFFFFFC0)];
+                Add(dstTmp, srcTmp, dstTmp, ONE_REPEAT_FLOAT_SIZE, addRepeatTime,
+                    { DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, 0, DEFAULT_REPEAT_STRIDE, 0 });
+                PipeBarrier();
+                Add(dstTmp, srcTailTmp, dstTmp, addTailNumber, 1,
+                    { DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, 0, DEFAULT_REPEAT_STRIDE, 0 });
+                PipeBarrier();
             }
+        } else {
+            for (uint32_t i = 0; i < bsLength * hLength; i += hLength) {
+                LocalTensor dstTmp = src[i];
+                LocalTensor srcTailTmp = src[i + (hLength & 0xFFFFFFC0)];
+                Add(dstTmp, srcTailTmp, dstTmp, addTailNumber, 1,
+                    { DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE, 0, DEFAULT_REPEAT_STRIDE, 0 });
+                PipeBarrier();
+            }
+        }
+    }

-            WholeReduceSum(dstTmp, srcTmp, MASK_PLACEHOLDER, 1, DEFAULT_BLK_STRIDE, DEFAULT_BLK_STRIDE,
-                DEFAULT_REPEAT_STRIDE);
-            PipeBarrier();
+    uint32_t repeatTime = bsLength;
+    uint32_t cursorSrc = 0;
+    uint32_t wholeReduceSumHLength = (hLength > ONE_REPEAT_FLOAT_SIZE) ? ONE_REPEAT_FLOAT_SIZE : hLength;
+    constexpr uint32_t rightShiftThree = 3;
+    const uint32_t reduceSumSrcRepeatStride = hLength >> rightShiftThree;
+
+    while (repeatTime >= MAX_REPEAT_TIMES) {
+        LocalTensor srcTmp = src[cursorSrc * MAX_REPEAT_TIMES * hLength];
+        LocalTensor dstTmp = dst[cursorSrc * MAX_REPEAT_TIMES * hLength];
+        if constexpr (isRelocate) {
+            WholeReduceSum(dstMVTmp[cursorSrc * MAX_REPEAT_TIMES], srcTmp, wholeReduceSumHLength,
+                MAX_REPEAT_TIMES, 1, DEFAULT_BLK_STRIDE, reduceSumSrcRepeatStride);
+        }
+        WholeReduceSum(dstTmp, srcTmp, wholeReduceSumHLength, MAX_REPEAT_TIMES, hLength, DEFAULT_BLK_STRIDE,
+            reduceSumSrcRepeatStride);
+        PipeBarrier();
+        repeatTime -= MAX_REPEAT_TIMES;
+        ++cursorSrc;
+    }

-            totalNum = DivCeil(totalNum, ONE_REPEAT_FLOAT_SIZE);
-            srcTmp = dstTmp;
+    uint32_t reduceSumSrcRepeatTimeTail = bsLength - cursorSrc * MAX_REPEAT_TIMES;
+    if (reduceSumSrcRepeatTimeTail > 0) {
+        LocalTensor srcTmp = src[cursorSrc * MAX_REPEAT_TIMES * hLength];
+        LocalTensor dstTmp = dst[cursorSrc * MAX_REPEAT_TIMES * hLength];
+        if constexpr (isRelocate) {
+            WholeReduceSum(dstMVTmp[cursorSrc * MAX_REPEAT_TIMES], srcTmp, wholeReduceSumHLength,
+                reduceSumSrcRepeatTimeTail, 1, DEFAULT_BLK_STRIDE, reduceSumSrcRepeatStride);
         }
+        WholeReduceSum(dstTmp, srcTmp, wholeReduceSumHLength, reduceSumSrcRepeatTimeTail, hLength,
+            DEFAULT_BLK_STRIDE, reduceSumSrcRepeatStride);
+        PipeBarrier();
     }
+
+    SetMaskCount();
 }

 __aicore__ inline void GetLayerNormOutputMean(const LocalTensor& outputMean, const LocalTensor& inputX,
-- 
Gitee