diff --git a/impl/matmul/kfc/matmul_server.h b/impl/matmul/kfc/matmul_server.h
index 7a73fd1a64986f885a547d12519e7799b5d1c01b..f022f96c093c2e8196e506a623c2ebe23e0e9a54 100644
--- a/impl/matmul/kfc/matmul_server.h
+++ b/impl/matmul/kfc/matmul_server.h
@@ -311,6 +311,7 @@ public:
             }
         }
     }
+    __aicore__ inline bool IterateBatch(MSG_POS KfcMsg* msg);
     __aicore__ inline void StartIterateNBatch(MsgTmpPos MatmulConfigParams* body, uint32_t &cntIterator);
     __aicore__ inline bool IterateNBatch(MSG_POS KfcMsg* msg);
diff --git a/impl/matmul/kfc/matmul_server_impl.h b/impl/matmul/kfc/matmul_server_impl.h
index a9e2f011fc9ebdbbe3994e33a4cd4e402f415f74..e456a5272b69cb5ee4ce09dc6eee7fc627593fa6 100644
--- a/impl/matmul/kfc/matmul_server_impl.h
+++ b/impl/matmul/kfc/matmul_server_impl.h
@@ -344,13 +344,13 @@ __aicore__ inline void MatmulService
     batchA, body->batchB, batchC, batchOffset);
-    BmmOffset batchLoopOffset;
     for (uint32_t loopIdx = 0U; loopIdx < body->batchLoop; loopIdx++) {
         const uint64_t biasOffset = batchOffsetBias * loopIdx;
-        CalcNBatchoffset(body->batchA, body->batchB, batchC, loopIdx, batchOffset, batchLoopOffset);
+        batchLoopOffset.offA = CalcNBatchoffset(body->batchA, loopIdx, tiling_.GetALayoutInfoN(), tiling_.GetALayoutInfoG(), tiling_.GetALayoutInfoD(), tiling_.GetALayoutInfoS());
+        batchLoopOffset.offB = CalcNBatchoffset(body->batchB, loopIdx, tiling_.GetBLayoutInfoN(), tiling_.GetBLayoutInfoG(), tiling_.GetBLayoutInfoD(), tiling_.GetBLayoutInfoS());
+        batchLoopOffset.offC = CalcNBatchoffset(batchC, loopIdx, tiling_.GetCLayoutInfoN(), tiling_.GetCLayoutInfoG(), tiling_.GetCLayoutInfoS2(), tiling_.GetCLayoutInfoS1());
+        IterateSetMessage(body, singleBatchASize, singleBatchBSize, batchLoopOffset.offA, batchLoopOffset.offB, biasOffset);
         GlobalTensor<DstT> cGlobal;
         cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ DstT*>(body->cAddr + batchLoopOffset.offC), size);
diff --git a/impl/matmul/scheduler/batch/batch_scheduler.h b/impl/matmul/scheduler/batch/batch_scheduler.h
index bdc990d71b90a13b3adb551cb36eaa744dfa7d47..b8666a2ba8c83e2f2641488e543ebeb796711e79 100644
--- a/impl/matmul/scheduler/batch/batch_scheduler.h
+++ b/impl/matmul/scheduler/batch/batch_scheduler.h
@@ -94,9 +94,12 @@ public:
     event_t eventIDMToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::M_MTE1));
     auto batchLoop = MATMUL_MODULE(BatchLoop);
     for (batchLoop->SplitStart(); !batchLoop->SplitEnd(); batchLoop->SplitNext()) {
-        MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, batchLoop->GetOuterIndex(),
+        uint32_t outerIdxA;
+        uint32_t outerIdxB;
+        batchLoop->CalcBatchOuterIdx(outerIdxA, outerIdxB);
+        MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, outerIdxA,
             batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
-        MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, batchLoop->GetOuterIndex(),
+        MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, outerIdxB,
             batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
         SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
         WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
@@ -294,6 +297,7 @@ private:
                 ComputeMDb(a1, b1, bias, ctx, sL0CInit, sL0CLast, enPartialSum);
             } else {
                 ComputeNDb(a1, b1, bias, ctx, sL0CInit, sL0CLast, enPartialSum);
+                MATMUL_MODULE(BiasScheduler)->Free();
             }
         } while(MATMUL_MODULE(KLoop)->OuterNext());
     }
@@ -337,7 +341,6 @@ private:
             cmatrixInitVal, false);
         bufferPool.Free();
-        MATMUL_MODULE(BiasScheduler)->Free();
         axisL1DbOffset += MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseM();
     }
 }
diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop.h
index d149be8051b7deb4800956d6740e42358cd2eeab..fece97902b2c43b7f612b57f8dc3b1e4803d1fa1 100644
--- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop.h
+++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop.h
@@ -15,7 +15,8 @@
 #ifndef IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_H
 #define IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_H
 
-#include "batch_loop_multi.h"
 #include "batch_loop_single.h"
+#include "batch_loop_batch_less.h"
+#include "batch_loop_batch_large.h"
 
 #endif // _BATCH_LOOP_H_
\ No newline at end of file
diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_base.h
similarity index 58%
rename from impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h
rename to impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_base.h
index fb98aec92043604285a8fab59dc2a72cf610d81b..ca00011443382e8a89be5037af1a46191a5fe3c7 100644
--- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h
+++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_base.h
@@ -10,28 +10,26 @@
  */
 
 /*!
- * \file batch_loop_multi.h
+ * \file batch_loop_batch_base.h
  * \brief
  */
-#ifndef IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_MULTI_H
-#define IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_MULTI_H
+#ifndef IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_BASE_H
+#define IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_BASE_H
 
-#include "batch_loop_intf.h"
+#include "../../../utils/matmul_module.h"
 
 namespace AscendC {
 namespace Impl {
 namespace Detail {
 /*
-    BatchLoop is considered entirely experimental.
+    BatchLoopBase is considered entirely experimental.
     We retain the freedom to make incompatible changes, but do not guarantee the stability.
-    BatchLoop is only for internal usage, does not support extension or customized specialization!
+    BatchLoopBase is only for internal usage, does not support extension or customized specialization!
 */
 template 
-class BatchLoop() == Impl::Detail::CopyCubeInType::BMM) ||
-    (Impl::Detail::IsBMMFromL1())>>
+class BatchLoopBase
 {
     MATMUL_USE_MODULE(MatmulShapeTiling);
     MATMUL_USE_MODULE(MatmulShapeInfo);
@@ -39,27 +37,26 @@ class BatchLoop
         const auto tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
-        CalcBatchNum(tiling.GetALayoutInfoB(), tiling.GetBLayoutInfoB(), tiling.GetBatchNum(), tiling.GetBatchNum());
-
         if constexpr (IsBmmDoubleBuffer()) {
             auto batchNum = tiling.GetBatchNum();
             splitSize_ = (batchNum % DB_FACTOR == 0) ? DB_FACTOR : 1;
             splitBatchNum_ = batchNum / splitSize_;
         }
-        UpdateBatchNumParams();
+        CalcBatchNum(tiling.GetBatchNum(), tiling.GetBatchNum());
+        batchNum_ = batchA_ > batchB_ ? batchA_ : batchB_;
     }
 
     __aicore__ inline void SetBatchNum(int32_t batchNumA, int32_t batchNumB)
     {
-        CalcBatchNum(batchNumA, batchNumB, batchNumA, batchNumB);
-        UpdateBatchNumParams();
+        CalcBatchNum(batchNumA, batchNumB);
+        batchNum_ = batchA_ > batchB_ ? batchA_ : batchB_;
     }
 
     __aicore__ inline void SetNBatchOutNum(int32_t nBatchOutNum)
@@ -80,16 +77,6 @@ public:
     {
         outerIdx_++;
         dstOffset_ += batchCalcSize_;
-        if (oddAndLargeThanL1_ && outerIdx_ == batchOuter_ - 1) {
-            const int32_t tail = inputBatchNum_ % batchA_;
-            batchA_ = tail == 0 ? mainBatchInner_ : tail;
-            batchB_ = batchA_;
-            batchNum_ = batchA_;
-            batchCalcSize_ = batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreM() *
-                MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN();
-            splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1;
-            splitBatchNum_ = batchNum_ / splitSize_;
-        }
     }
 
     __aicore__ inline bool OuterEnd()
@@ -97,16 +84,6 @@ public:
         return outerIdx_ >= batchOuter_;
     }
 
-    __aicore__ inline int32_t GetMainBatchBlockA() const
-    {
-        return oddAndLargeThanL1_ ? mainBatchInner_ : batchA_; // batchNum main block in outLoop
-    }
-
-    __aicore__ inline int32_t GetMainBatchBlockB() const
-    {
-        return oddAndLargeThanL1_ ? mainBatchInner_ : batchB_; // batchNum main block in outLoop
-    }
-
     __aicore__ inline uint32_t GetOuterIndex() const
     {
         return outerIdx_;
@@ -148,9 +125,19 @@ public:
         return batchB_;
     }
 
+    __aicore__ inline int32_t GetMainBatchBlockA() const
+    {
+        return batchA_; // batchNum main block in outLoop
+    }
+
+    __aicore__ inline int32_t GetMainBatchBlockB() const
+    {
+        return batchB_; // batchNum main block in outLoop
+    }
+
     __aicore__ inline int32_t GetBiasBatchSrcOffset() const
     {
-        return outerIdx_ * (oddAndLargeThanL1_ ? mainBatchInner_ : batchNum_) * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN();
+        return outerIdx_ * batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN();
     }
 
     // Double Buffer Loop
@@ -222,18 +209,10 @@ public:
     __aicore__ inline bool InnerEnd()
     {
-        if ((!oddAndLargeThanL1_) || (batchNum_ % DB_FACTOR == 0) || (splitSize_ < DB_FACTOR)) {
-            if constexpr (IsBmmDoubleBuffer()) {
-                return (innerIdx_ >= splitBatchNum_) || (splitOuterIdx_ * splitBatchNum_ >= batchNum_) || (innerBatchIdx_ >= batchNum_);
-            } else {
-                return (innerIdx_ >= splitBatchNum_) || (splitOuterIdx_ * splitBatchNum_ >= batchNum_);
-            }
-        }
-        const auto firstBatchNum = batchNum_ / splitSize_;
-        if (splitOuterIdx_ < 1) {
-            return innerIdx_ >= firstBatchNum;
+        if constexpr (IsBmmDoubleBuffer()) {
+            return (innerIdx_ >= splitBatchNum_) || (splitOuterIdx_ * splitBatchNum_ >= batchNum_) || (innerBatchIdx_ >= batchNum_);
         } else {
-            return innerIdx_ >= batchNum_ - firstBatchNum;
+            return (innerIdx_ >= splitBatchNum_) || (splitOuterIdx_ * splitBatchNum_ >= batchNum_);
         }
     }
@@ -289,49 +268,27 @@ public:
     }
 
 private:
-    __aicore__ inline void CalcBatchNum(int32_t layoutBatchNumA, int32_t layoutBatchNumB,
-        int32_t batchNumA, int32_t batchNumB)
+    __aicore__ inline void CalcBatchNum(int32_t batchNumA, int32_t batchNumB)
     {
         totalBatchNum_ = batchNumA > batchNumB ? batchNumA : batchNumB;
-        if constexpr (ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LARGE_THAN_L1) {
-            ASSERT(batchNumA > 0 && batchNumB > 0 &&
-                (batchNumA % batchNumB == 0 || batchNumB % batchNumA == 0));
-            batchA_ = batchNumA;
-            batchB_ = batchNumB;
-            mainBatchInner_ = 0;
-            return;
-        }
+        ASSERT(batchNumA > 0 && batchNumB > 0 &&
+            (batchNumA % batchNumB == 0 || batchNumB % batchNumA == 0));
+        batchA_ = batchNumA;
+        batchB_ = batchNumB;
+    }
 
-        ASSERT(layoutBatchNumA > 0 && layoutBatchNumB > 0 &&
-            (layoutBatchNumA % layoutBatchNumB == 0 || layoutBatchNumB % layoutBatchNumA == 0));
-        int32_t aMatrixSingleBatchSize = GetSingleSizeAlignA();
-        int32_t bMatrixSingleBatchSize = GetSingleSizeAlignB();
-        if ((layoutBatchNumA * aMatrixSingleBatchSize + layoutBatchNumB * bMatrixSingleBatchSize +
-            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().IsBias() *
-            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() * sizeof(BiasT)) <= TOTAL_L1_SIZE) {
-            batchOuter_ = 1;
-            batchA_ = layoutBatchNumA;
-            batchB_ = layoutBatchNumB;
-            return;
-        }
-        int32_t batchNumLarge;
-        int32_t batchNumLess;
-        int32_t largeMatrixSingleBatchSize;
-        int32_t lessMatrixSingleBatchSize;
-        if (layoutBatchNumA >= layoutBatchNumB) {
-            batchNumLarge = layoutBatchNumA;
-            batchNumLess = layoutBatchNumB;
-            largeMatrixSingleBatchSize = aMatrixSingleBatchSize;
-            lessMatrixSingleBatchSize = bMatrixSingleBatchSize;
-        } else {
-            batchNumLarge = layoutBatchNumB;
-            batchNumLess = layoutBatchNumA;
-            largeMatrixSingleBatchSize = bMatrixSingleBatchSize;
-            lessMatrixSingleBatchSize = aMatrixSingleBatchSize;
-        }
-        CalcBatchAB(batchNumLarge, batchNumLess, largeMatrixSingleBatchSize, lessMatrixSingleBatchSize, layoutBatchNumA >= layoutBatchNumB);
+    __aicore__ inline void UpdateSplitParams()
+    {
+        splitBatchIdx_ += splitBatchNum_;
     }
 
+    __aicore__ inline void UpdateInnerParams()
+    {
+        innerBatchIdx_ = innerIdx_ + splitBatchIdx_;
+    }
+
+protected:
+
     __aicore__ inline int32_t GetSingleSizeAlignA()
     {
         const auto matmulShapeInfo = MATMUL_MODULE(MatmulShapeInfo);
@@ -366,65 +323,6 @@ private:
         }
     }
 
-    __aicore__ inline void CalcBatchAB(int32_t batchNumLarge, int32_t batchNumLess,
-        int32_t largeMatrixSingleBatchSize, int32_t lessMatrixSingleBatchSize, bool isBatchALarger)
-    {
-        int32_t multiples = batchNumLarge / batchNumLess;
-        int32_t singleBatchSize = multiples * largeMatrixSingleBatchSize + lessMatrixSingleBatchSize +
-            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().IsBias() *
-            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() * sizeof(BiasT);
-        int32_t batchInner = TOTAL_L1_SIZE / singleBatchSize;
-        inputBatchNum_ = batchNumLarge;
-
-        ASSERT(batchInner > 0);
-        oddAndLargeThanL1_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0);
-        if (oddAndLargeThanL1_) {
-            mainBatchInner_ = batchInner;
-            batchOuter_ = CeilT(batchNumLess, batchInner);
-            batchA_ = batchInner;
-            batchB_ = batchInner;
-        } else {
-            while (batchNumLess % batchInner != 0 && batchInner > 0) {
-                --batchInner;
-            }
-            mainBatchInner_ = batchInner;
-            batchOuter_ = batchNumLess / batchInner;
-            if (isBatchALarger) {
-                batchA_ = multiples * batchInner;
-                batchB_ = batchInner;
-            } else {
-                batchA_ = batchInner;
-                batchB_ = multiples * batchInner;
-            }
-        }
-    }
-
-    __aicore__ inline void UpdateBatchNumParams()
-    {
-        batchNum_ = batchA_ > batchB_ ? batchA_ : batchB_;
-        if constexpr (!IsBmmDoubleBuffer()) {
-            if (batchOuter_ > 1 && batchA_ == batchB_) {
-                splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1;
-                splitBatchNum_ = batchNum_ / splitSize_;
-            } else {
-                splitSize_ = (batchNum_ >= DB_FACTOR) && (batchA_ % DB_FACTOR == 0) && (batchB_ % DB_FACTOR == 0)
-                    ? DB_FACTOR
-                    : 1;
-                splitBatchNum_ = batchNum_ / splitSize_;
-            }
-        }
-    }
-
-    __aicore__ inline void UpdateSplitParams()
-    {
-        splitBatchIdx_ += splitBatchNum_;
-    }
-
-    __aicore__ inline void UpdateInnerParams()
-    {
-        innerBatchIdx_ = innerIdx_ + splitBatchIdx_;
-    }
-
     int32_t batchA_;   // outerLoop main/tail block
     int32_t batchB_;   // outerLoop main/tail block
     int32_t batchNum_; // outerLoop main/tail block
@@ -453,10 +351,8 @@ private:
     int32_t batchOutOffsetNum_ = 0;
     int32_t inputBatchNum_ = 0;
-    bool oddAndLargeThanL1_ = false; // new logical judgment condition for handling odd batchNum && large than L1
-    int32_t mainBatchInner_ = 0;     // outerLoop main block
 };
 } // namespace Detail
 } // namespace Impl
 } // namespace AscendC
-#endif // IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_MULTI_H
+#endif // IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_BASE_H
diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_large.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_large.h
new file mode 100644
index 0000000000000000000000000000000000000000..a3752a78331ea81d693a8680feeb7d9f5d5d9f5e
--- /dev/null
+++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_large.h
@@ -0,0 +1,220 @@
+/*
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file batch_loop_batch_large.h
+ * \brief
+ */
+#ifndef IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LARGE_THAN_L1_H
+#define IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LARGE_THAN_L1_H
+
+#include "batch_loop_intf.h"
+#include "batch_loop_batch_base.h"
+
+namespace AscendC {
+namespace Impl {
+namespace Detail {
+/*
+    BatchLoop is considered entirely experimental.
+    We retain the freedom to make incompatible changes, but do not guarantee the stability.
+    BatchLoop is only for internal usage, does not support extension or customized specialization!
+*/
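+// BatchLoop specialization for BatchMode::BATCH_LARGE_THAN_L1: the batches of A and B
+// (plus bias) cannot all reside in L1 at once, so the batch dimension is tiled across
+// an extra outer loop on top of the base-class iteration.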
+template 
+class BatchLoop>
+    : public BatchLoopBase
+{
+    MATMUL_USE_MODULE(MatmulShapeTiling);
+    MATMUL_USE_MODULE(MatmulShapeInfo);
+    using SrcT = typename INPUT_TYPE::T;
+    using BiasT = typename BIAS_TYPE::T;
+
+public:
+    using BASE_MODULE = AscendC::Impl::Detail::BatchLoopBase;
+    __aicore__ inline BatchLoop() = default;
+    __aicore__ inline ~BatchLoop() = default;
+
+    __aicore__ inline void Init()
+    {
+        const auto tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
+        CalcBatchNum(tiling.GetALayoutInfoB(), tiling.GetBLayoutInfoB(), tiling.GetBatchNum(), tiling.GetBatchNum());
+        UpdateBatchNumParams();
+    }
+
+    __aicore__ inline void SetBatchNum(int32_t batchNumA, int32_t batchNumB)
+    {
+        CalcBatchNum(batchNumA, batchNumB, batchNumA, batchNumB);
+        UpdateBatchNumParams();
+    }
+
+    __aicore__ inline void OuterNext()
+    {
+        BASE_MODULE::outerIdx_++;
+        BASE_MODULE::dstOffset_ += BASE_MODULE::batchCalcSize_;
+        if (oddAndLargeThanL1_ && BASE_MODULE::outerIdx_ == BASE_MODULE::batchOuter_ - 1) {
+            const int32_t tail = BASE_MODULE::inputBatchNum_ % BASE_MODULE::batchA_;
+            BASE_MODULE::batchA_ = tail == 0 ? mainBatchInner_ : tail;
+            BASE_MODULE::batchB_ = BASE_MODULE::batchA_;
+            BASE_MODULE::batchNum_ = BASE_MODULE::batchA_;
+            BASE_MODULE::batchCalcSize_ = BASE_MODULE::batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreM() *
+                MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN();
+            BASE_MODULE::splitSize_ = (BASE_MODULE::batchA_ >= DB_FACTOR) ? DB_FACTOR : 1;
+            BASE_MODULE::splitBatchNum_ = BASE_MODULE::batchNum_ / BASE_MODULE::splitSize_;
+        }
+    }
+
+    __aicore__ inline bool InnerEnd()
+    {
+        if ((!oddAndLargeThanL1_) || (BASE_MODULE::batchNum_ % DB_FACTOR == 0) || (BASE_MODULE::splitSize_ < DB_FACTOR)) {
+            return (BASE_MODULE::innerIdx_ >= BASE_MODULE::splitBatchNum_) || (BASE_MODULE::splitOuterIdx_ * BASE_MODULE::splitBatchNum_ >= BASE_MODULE::batchNum_);
+        }
+        const auto firstBatchNum = BASE_MODULE::batchNum_ / BASE_MODULE::splitSize_;
+        if (BASE_MODULE::splitOuterIdx_ < 1) {
+            return BASE_MODULE::innerIdx_ >= firstBatchNum;
+        } else {
+            return BASE_MODULE::innerIdx_ >= BASE_MODULE::batchNum_ - firstBatchNum;
+        }
+    }
+
+    __aicore__ inline void CalcBatchOuterIdx(uint32_t& outerIdxA, uint32_t& outerIdxB)
+    {
+        if (outerLoop_ == 1 || BASE_MODULE::batchA_ == BASE_MODULE::batchB_) {
+            outerIdxA = BASE_MODULE::outerIdx_;
+            outerIdxB = BASE_MODULE::outerIdx_;
+        } else if (BASE_MODULE::batchA_ > BASE_MODULE::batchB_) {
+            outerIdxA = BASE_MODULE::outerIdx_;
+            outerIdxB = BASE_MODULE::outerIdx_ / outerLoop_;
+        } else {
+            outerIdxA = BASE_MODULE::outerIdx_ / outerLoop_;
+            outerIdxB = BASE_MODULE::outerIdx_;
+        }
+    }
+
+    __aicore__ inline int32_t GetMainBatchBlockA() const
+    {
+        return oddAndLargeThanL1_ ? mainBatchInner_ : BASE_MODULE::batchA_; // batchNum main block in outLoop
+    }
+
+    __aicore__ inline int32_t GetMainBatchBlockB() const
+    {
+        return oddAndLargeThanL1_ ? mainBatchInner_ : BASE_MODULE::batchB_; // batchNum main block in outLoop
+    }
+
+    __aicore__ inline int32_t GetBiasBatchSrcOffset() const
+    {
+        return BASE_MODULE::outerIdx_ * (oddAndLargeThanL1_ ? mainBatchInner_ : BASE_MODULE::batchNum_) * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN();
+    }
+
+private:
+    __aicore__ inline void CalcBatchNum(int32_t layoutBatchNumA, int32_t layoutBatchNumB,
+        int32_t batchNumA, int32_t batchNumB)
+    {
+        BASE_MODULE::totalBatchNum_ = batchNumA > batchNumB ? batchNumA : batchNumB;
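+        // Work out how the batch dimension must be tiled so that one outer pass of A, B
+        // and bias data fits in L1; when everything already fits, no outer split is needed.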
+
+        ASSERT(layoutBatchNumA > 0 && layoutBatchNumB > 0 &&
+            (layoutBatchNumA % layoutBatchNumB == 0 || layoutBatchNumB % layoutBatchNumA == 0));
+        int32_t aMatrixSingleBatchSize = BASE_MODULE::GetSingleSizeAlignA();
+        int32_t bMatrixSingleBatchSize = BASE_MODULE::GetSingleSizeAlignB();
+        if ((layoutBatchNumA * aMatrixSingleBatchSize + layoutBatchNumB * bMatrixSingleBatchSize +
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().IsBias() *
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() * sizeof(BiasT)) <= TOTAL_L1_SIZE) {
+            BASE_MODULE::batchA_ = layoutBatchNumA;
+            BASE_MODULE::batchB_ = layoutBatchNumB;
+            return;
+        }
+        int32_t batchNumLarge;
+        int32_t batchNumLess;
+        int32_t largeMatrixSingleBatchSize;
+        int32_t lessMatrixSingleBatchSize;
+        if (layoutBatchNumA >= layoutBatchNumB) {
+            batchNumLarge = layoutBatchNumA;
+            batchNumLess = layoutBatchNumB;
+            largeMatrixSingleBatchSize = aMatrixSingleBatchSize;
+            lessMatrixSingleBatchSize = bMatrixSingleBatchSize;
+        } else {
+            batchNumLarge = layoutBatchNumB;
+            batchNumLess = layoutBatchNumA;
+            largeMatrixSingleBatchSize = bMatrixSingleBatchSize;
+            lessMatrixSingleBatchSize = aMatrixSingleBatchSize;
+        }
+        CalcBatchAB(batchNumLarge, batchNumLess, largeMatrixSingleBatchSize, lessMatrixSingleBatchSize, layoutBatchNumA >= layoutBatchNumB);
+    }
+
+    __aicore__ inline void CalcBatchAB(int32_t batchNumLarge, int32_t batchNumLess,
+        int32_t largeMatrixSingleBatchSize, int32_t lessMatrixSingleBatchSize, bool isBatchALarger)
+    {
+        int32_t multiples = batchNumLarge / batchNumLess;
+        int32_t singleBatchSize = multiples * largeMatrixSingleBatchSize + lessMatrixSingleBatchSize +
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().IsBias() *
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() * sizeof(BiasT);
+
+        int32_t batchInner = TOTAL_L1_SIZE / singleBatchSize;
+        BASE_MODULE::inputBatchNum_ = batchNumLarge;
+        oddAndLargeThanL1_ = (multiples == 1) && (BASE_MODULE::inputBatchNum_ % DB_FACTOR != 0);
+        if (batchInner <= 0) {
+            outerLoop_ = 1;
+            while (batchInner <= 0) {
+                outerLoop_ += 1;
+                while (multiples % outerLoop_ != 0 && outerLoop_ < multiples) {
+                    outerLoop_ += 1;
+                }
+                singleBatchSize = multiples / outerLoop_ * largeMatrixSingleBatchSize + lessMatrixSingleBatchSize +
+                    MATMUL_MODULE(MatmulShapeTiling)->GetTiling().IsBias() *
+                    MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() * sizeof(BiasT);
+                batchInner = TOTAL_L1_SIZE / singleBatchSize;
+            }
+            multiples /= outerLoop_;
+        }
+        ASSERT(batchInner > 0);
+        if (oddAndLargeThanL1_) {
+            mainBatchInner_ = batchInner;
+            BASE_MODULE::batchOuter_ = CeilT(batchNumLess, batchInner);
+            BASE_MODULE::batchA_ = batchInner;
+            BASE_MODULE::batchB_ = batchInner;
+        } else {
+            while (batchNumLess % batchInner != 0 && batchInner > 0) {
+                --batchInner;
+            }
+            mainBatchInner_ = batchInner;
+            BASE_MODULE::batchOuter_ = batchNumLess / batchInner * outerLoop_;
+            if (isBatchALarger) {
+                BASE_MODULE::batchA_ = multiples * batchInner;
+                BASE_MODULE::batchB_ = batchInner;
+            } else {
+                BASE_MODULE::batchA_ = batchInner;
+                BASE_MODULE::batchB_ = multiples * batchInner;
+            }
+        }
+    }
+
+    __aicore__ inline void UpdateBatchNumParams()
+    {
+        BASE_MODULE::batchNum_ = BASE_MODULE::batchA_ > BASE_MODULE::batchB_ ? BASE_MODULE::batchA_ : BASE_MODULE::batchB_;
+        if constexpr (!IsBmmDoubleBuffer()) {
+            if (BASE_MODULE::batchOuter_ > 1 && BASE_MODULE::batchA_ == BASE_MODULE::batchB_) {
+                BASE_MODULE::splitSize_ = (BASE_MODULE::batchA_ >= DB_FACTOR) ? DB_FACTOR : 1;
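+                // splitBatchNum_ is the number of batches handled per double-buffer split pass.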
+                BASE_MODULE::splitBatchNum_ = BASE_MODULE::batchNum_ / BASE_MODULE::splitSize_;
+            } else {
+                BASE_MODULE::splitSize_ = (BASE_MODULE::batchNum_ >= DB_FACTOR) && (BASE_MODULE::batchA_ % DB_FACTOR == 0) &&
+                    (BASE_MODULE::batchB_ % DB_FACTOR == 0) ? DB_FACTOR : 1;
+                BASE_MODULE::splitBatchNum_ = BASE_MODULE::batchNum_ / BASE_MODULE::splitSize_;
+            }
+        }
+    }
+
+    int32_t outerLoop_ = 1;
+    bool oddAndLargeThanL1_ = false; // new logical judgment condition for handling odd batchNum && large than L1
+    int32_t mainBatchInner_ = 0;     // outerLoop main block
+};
+} // namespace Detail
+} // namespace Impl
+} // namespace AscendC
+#endif // IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LARGE_THAN_L1_H
diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_less.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_less.h
new file mode 100644
index 0000000000000000000000000000000000000000..c30ed8be0e7939597a5fff0bbe90c5a36b38fed2
--- /dev/null
+++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_batch_less.h
@@ -0,0 +1,44 @@
+/*
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file batch_loop_batch_less.h
+ * \brief
+ */
+#ifndef IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_BATCH_LESS_THAN_L1_H
+#define IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_BATCH_LESS_THAN_L1_H
+
+#include "batch_loop_intf.h"
+#include "batch_loop_batch_base.h"
+
+namespace AscendC {
+namespace Impl {
+namespace Detail {
+/*
+    BatchLoop is considered entirely experimental.
+    We retain the freedom to make incompatible changes, but do not guarantee the stability.
+    BatchLoop is only for internal usage, does not support extension or customized specialization!
+*/
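+// Specialization for batches that fit in L1 as a whole: the base-class iteration is used unchanged.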
+template 
+class BatchLoop>
+    : public BatchLoopBase
+{
+public:
+    using BASE_MODULE = AscendC::Impl::Detail::BatchLoopBase;
+    __aicore__ inline BatchLoop() = default;
+    __aicore__ inline ~BatchLoop() = default;
+};
+} // namespace Detail
+} // namespace Impl
+} // namespace AscendC
+#endif // IMPL_MATMUL_SCHEDULER_ITERATOR_BATCH_LOOP_BATCH_LOOP_BATCH_LESS_THAN_L1_H
diff --git a/impl/matmul/utils/batch_matmul_utils.h b/impl/matmul/utils/batch_matmul_utils.h
index 9c70c87611fe9c5eee225abea9ba536492b88e82..718c72b699b7e07b48462b5198a851e164b8b0f9 100644
--- a/impl/matmul/utils/batch_matmul_utils.h
+++ b/impl/matmul/utils/batch_matmul_utils.h
@@ -35,6 +35,91 @@ constexpr bool IsBmmBatchScheduler = DoMatmulNorm(MM_CFG) &&
 template 
 constexpr bool IsBmmSingleScheduler = DoMatmulNorm(MM_CFG) &&
     (A_TYPE::layout == LayoutMode::NORMAL && ToMatmulConfig(MM_CFG).batchMode == BatchMode::SINGLE_LARGE_THAN_L1);
+
+struct BatchOffsetInfo
+{
+    int32_t modA;
+    int32_t divisorA;
+    int32_t alignA;
+    int32_t modB;
+    int32_t divisorB;
+    int32_t alignB;
+    int32_t modBias;
+    int32_t divisorBias;
+    int32_t alignBias;
+    bool setBiasFlag {false};
+};
+
+struct SplitParams
+{
+    int16_t axisL1Len;
+    int16_t kAxisL1Len;
+    int16_t axisL1Offset;
+    int16_t kAxisL1Offset;
+    int16_t axisL0Len;
+};
+
+struct BatchSchedulerContext
+{
+    int32_t offsetA;
+    int32_t offsetB;
+    int32_t offsetBias;
+    uint32_t reduceGNum;
+    bool isReduceG;
+    SplitParams aL0Params;
+    SplitParams bL0Params;
+};
+
+struct BmmOffset {
+    uint64_t offA = 0;
+    uint64_t offB = 0;
+    uint64_t offC = 0;
+};
+
+// It is invoked by the matmulV3 operator and cannot be removed at present
+__aicore__ inline uint16_t CeilDiv(uint16_t num1, uint16_t num2)
+{
+    ASSERT(num2 > 0);
+    return (num1 + num2 - 1) / num2;
+}
+
+// It is invoked by the matmulV3 operator and cannot be removed at present
+__aicore__ inline uint16_t CeilAlign(uint16_t num1, uint16_t num2)
+{
+    ASSERT(num2 > 0);
+    return CeilDiv(num1, num2) * num2;
+}
+
+template <class INPUT_TYPE>
+__aicore__ inline uint64_t CalcNBatchoffset(uint32_t batchValue, uint32_t loopIdx, uint32_t layoutInfoN, uint32_t layoutInfoG, uint32_t layoutInfoD, uint32_t layoutInfoS)
+{
+    uint32_t alignedSingleCoreN = layoutInfoD;
+    if constexpr (INPUT_TYPE::format == CubeFormat::ND_ALIGN) {
+        alignedSingleCoreN = CeilAlign(layoutInfoD, AscendCUtils::GetC0Count(sizeof(typename INPUT_TYPE::T)));
+    }
+    uint64_t offset;
+    if constexpr (INPUT_TYPE::layout == LayoutMode::BNGS1S2 || INPUT_TYPE::layout == LayoutMode::NORMAL) {
+        offset = alignedSingleCoreN * layoutInfoS * batchValue * loopIdx * sizeof(typename INPUT_TYPE::T);
+    } else if constexpr (INPUT_TYPE::layout == LayoutMode::SBNGD) {
+        offset = alignedSingleCoreN * batchValue * loopIdx * sizeof(typename INPUT_TYPE::T);
+    } else if constexpr (INPUT_TYPE::layout == LayoutMode::BSNGD) {
+        uint64_t layoutBIdx = loopIdx * batchValue / (layoutInfoN * layoutInfoG);
+        uint64_t layoutNGOff = loopIdx * batchValue % (layoutInfoN * layoutInfoG);
+        offset = (layoutBIdx * alignedSingleCoreN * layoutInfoS * layoutInfoN * layoutInfoG + layoutNGOff * alignedSingleCoreN) * sizeof(typename INPUT_TYPE::T);
+    }
+    return offset;
+}
+
+__aicore__ inline uint64_t GetBatchCNum(uint32_t batchA, uint32_t batchB, uint32_t aLayoutInfoG, uint32_t bLayoutInfoG, uint32_t cLayoutInfoG)
+{
+    uint32_t batchC = batchA > batchB ? batchA : batchB;
+    bool layoutGCondition = cLayoutInfoG == 1 &&
+        (aLayoutInfoG != 1 || bLayoutInfoG != 1);
+    if (layoutGCondition) {
+        int32_t layoutG = bLayoutInfoG > aLayoutInfoG ? bLayoutInfoG : aLayoutInfoG;
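+        // C collapses the G axis here, so the group factor is folded out of the output batch count.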
+        batchC = batchC / layoutG;
+    }
+    return batchC;
+}
 } // namespace AscendC
 #endif // IMPL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H
\ No newline at end of file
diff --git a/impl/matmul/utils/matmul_module.h b/impl/matmul/utils/matmul_module.h
index e46e186b6d3a38f60d7378cd4b09d06e3c023961..6e5d89e368c01aa1283a8f54d58cba1e483eb4c8 100644
--- a/impl/matmul/utils/matmul_module.h
+++ b/impl/matmul/utils/matmul_module.h
@@ -164,6 +164,9 @@ MATMUL_PRIVATE_TEMPLATE
 ::type; \
+friend typename AscendC::Impl::Detail::MatmulModuleBaseBase::type; \
+friend typename AscendC::Impl::Detail::MatmulModuleBaseBaseBase::type; \
+friend typename AscendC::Impl::Detail::MatmulModuleRoot::type; \
 friend NAME
 
 #define MATMUL_ALLOW_USING_TEMPLATE_PRIVATE(NAME, ...) \
diff --git a/impl/matmul/utils/matmul_utils.h b/impl/matmul/utils/matmul_utils.h
index adeacb481c41146f791d20351a92894e58f06e92..0c55960b637e2da83973bc7e7c3476832649a1ff 100644
--- a/impl/matmul/utils/matmul_utils.h
+++ b/impl/matmul/utils/matmul_utils.h
@@ -89,45 +89,12 @@ struct DataCopyOutParams {
     uint64_t cbufWorkspaceAddr = 0;
 };
 
-struct SplitParams
-{
-    int16_t axisL1Len;
-    int16_t kAxisL1Len;
-    int16_t axisL1Offset;
-    int16_t kAxisL1Offset;
-    int16_t axisL0Len;
-};
-
 struct MxSplitParams : public SplitParams
 {
     int16_t kAuxMatrixL1Len;
     int16_t kAuxMatrixL1Offset;
 };
 
-struct BatchOffsetInfo
-{
-    int32_t modA;
-    int32_t divisorA;
-    int32_t alignA;
-    int32_t modB;
-    int32_t divisorB;
-    int32_t alignB;
-    int32_t modBias;
-    int32_t divisorBias;
-    int32_t alignBias;
-    bool setBiasFlag {false};
-};
-
-struct BatchSchedulerContext
-{
-    int32_t offsetA;
-    int32_t offsetB;
-    int32_t offsetBias;
-    uint32_t reduceGNum;
-    bool isReduceG;
-    SplitParams aL0Params;
-    SplitParams bL0Params;
-};
 
 template 
 __aicore__ inline constexpr int32_t GetC0Size()
 {
     if (sizeof(SrcT) == sizeof(float)) {
@@ -575,20 +542,6 @@ __aicore__ inline T CeilAlign(T num1, T num2)
     return Ceil(num1, num2) * num2;
 }
 
-// It is invoked by the matmulV3 operator and cannot be removed at present
-__aicore__ inline uint16_t CeilDiv(uint16_t num1, uint16_t num2)
-{
-    ASSERT(num2 > 0);
-    return (num1 + num2 - 1) / num2;
-}
-
-// It is invoked by the matmulV3 operator and cannot be removed at present
-__aicore__ inline uint16_t CeilAlign(uint16_t num1, uint16_t num2)
-{
-    ASSERT(num2 > 0);
-    return CeilDiv(num1, num2) * num2;
-}
-
 template 
 __aicore__ inline constexpr bool IsL0ACache()
 {
diff --git a/lib/matmul/matmul_client.h b/lib/matmul/matmul_client.h
index cab2905c5ee88e5a5c0aea00df7066c1270a88b7..0ef04bc78d3df3bc99e68247514c7c7942276934 100644
--- a/lib/matmul/matmul_client.h
+++ b/lib/matmul/matmul_client.h
@@ -1081,8 +1081,8 @@ public:
     {
         static_assert(!ToMatmulConfig(MM_CFG).enableMixDualMaster,
            "IterateNBatch not support when enableMixDualMaster is enabled.");
-        static_assert(A_TYPE::layout != LayoutMode::NONE && B_TYPE::layout != LayoutMode::NONE && 
-            A_TYPE::layout != LayoutMode::NORMAL && B_TYPE::layout != LayoutMode::NORMAL && C_TYPE::layout != LayoutMode::NORMAL, 
+        static_assert(A_TYPE::layout != LayoutMode::NONE && B_TYPE::layout != LayoutMode::NONE &&
+            A_TYPE::layout != LayoutMode::NORMAL && B_TYPE::layout != LayoutMode::NORMAL && C_TYPE::layout != LayoutMode::NORMAL,
             "BMM does not support the layout being NONE or NORMAL");
         if constexpr (!ToMatmulConfig(MM_CFG).isNBatch) {
             return;
diff --git a/tests/matmul/iterator/test_batch_loop.cpp b/tests/matmul/iterator/test_batch_loop.cpp
index e74d5f0871c45b0ced7b2fc2d458326636d323bb..f15462c9046a951c83edb28d416d6de679933a9a 100644
--- a/tests/matmul/iterator/test_batch_loop.cpp
+++ b/tests/matmul/iterator/test_batch_loop.cpp
@@ -23,7 +23,7 @@
 #include "impl/matmul/policy/matmul_private_modules.h"
 #include "impl/matmul/param/matmul_tensor_info.h"
 #include "impl/matmul/param/matmul_shape_tiling.h"
-#include "impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h"
+#include "impl/matmul/scheduler/iterator/batch_loop/batch_loop.h"
 
 using namespace std;
 using namespace AscendC;
diff --git a/tests/matmul/scheduler/batch_scheduler/test_batch_scheduler.cpp b/tests/matmul/scheduler/batch_scheduler/test_batch_scheduler.cpp
index 435ceaa1a10afef839a7f264c37acfe6d6641ffb..eabdf635e10554972e030f1f2da53257e9cc331c 100644
--- a/tests/matmul/scheduler/batch_scheduler/test_batch_scheduler.cpp
+++ b/tests/matmul/scheduler/batch_scheduler/test_batch_scheduler.cpp
@@ -140,6 +140,10 @@ public:
         return 0;
     }
 
+    __aicore__ inline void CalcBatchOuterIdx(uint32_t& outerIdxA, uint32_t& outerIdxB)
+    {
+    }
+
 private:
     uint32_t outerIdx_;
     uint32_t splitIdx_;