From 3182f22aa7114dc45f2621d45cffa3eb6100de16 Mon Sep 17 00:00:00 2001 From: YeZZ Date: Thu, 15 May 2025 09:46:44 +0800 Subject: [PATCH] A fullload constant tiling --- .../tiling/matmul_constant_tiling_impl.h | 40 +++++++++++++++++++ .../tiling/matmul_constant_tiling_utils.h | 3 ++ lib/matmul/constant_tiling.h | 14 ++++--- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/impl/matmul/tiling/matmul_constant_tiling_impl.h b/impl/matmul/tiling/matmul_constant_tiling_impl.h index b1691c9d..1d90462a 100644 --- a/impl/matmul/tiling/matmul_constant_tiling_impl.h +++ b/impl/matmul/tiling/matmul_constant_tiling_impl.h @@ -429,5 +429,45 @@ __aicore__ constexpr L1Status GetL1Factor(const MatmulConfig &mmCFG, int32_t l1S } return neitherFullLoad; } + +template +__aicore__ constexpr bool CalcAL1FullLoadTiling(int32_t l1Size, MatmulApiStaticTiling &tiling) +{ + if (tiling.singleCoreM > Impl::MIN_MN_SIZE || tiling.singleCoreN > Impl::MIN_MN_SIZE) { + return false; + } + l1Size = tiling.isBias ? (l1Size - Impl::BT_SIZE / Impl::BITS_PER_BYTE) : l1Size; + int32_t baseKaAlign = Align(tiling.baseK, + GetKAAlignValue() * GetReduceC0Size()); + int32_t baseKbAlign = Align(tiling.baseK, + GetKBAlignValue() * GetReduceC0Size()); + int32_t baseA = tiling.baseM * baseKaAlign * GetBitSize() / Impl::BITS_PER_BYTE; + int32_t baseB = tiling.baseN * baseKbAlign * GetBitSize() / Impl::BITS_PER_BYTE; + int32_t depthA1 = (l1Size - Impl::DB_ON * baseB) / baseA; + if (depthA1 * tiling.baseM * baseKaAlign < tiling.singleCoreM * tiling.singleCoreK) { + return false; + } + depthA1 = MaxValue(tiling.singleCoreM, tiling.baseM) * MaxValue(tiling.singleCoreK, baseKaAlign) / + tiling.baseM / baseKaAlign; + tiling.depthA1 = depthA1; + tiling.stepKa = depthA1; + + int32_t stepKb = (l1Size - depthA1 * baseA) / baseB / Impl::DB_ON; + if (stepKb * Impl::DB_ON * baseB > tiling.singleCoreK * tiling.singleCoreN) { + stepKb = MaxValue(tiling.singleCoreK, baseKbAlign) * MaxValue(tiling.singleCoreN, tiling.baseN) / + tiling.baseN / baseKbAlign / Impl::DB_ON; + } + if (stepKb < 1) { + tiling.depthB1 = 1; + tiling.stepKb = 1; + return true; + } + while (tiling.stepKa % stepKb != 0 && stepKb % tiling.stepKa != 0 && stepKb > 1) { + stepKb--; + } + tiling.depthB1 = stepKb * Impl::DB_ON; + tiling.stepKb = stepKb; + return true; +} } // namespace AscendC #endif // _MATMUL_CONSTANT_TILING_IMPL_ \ No newline at end of file diff --git a/impl/matmul/tiling/matmul_constant_tiling_utils.h b/impl/matmul/tiling/matmul_constant_tiling_utils.h index e9c67d01..01166a86 100644 --- a/impl/matmul/tiling/matmul_constant_tiling_utils.h +++ b/impl/matmul/tiling/matmul_constant_tiling_utils.h @@ -25,6 +25,9 @@ constexpr int32_t DB_ON = 2; constexpr int32_t DB_OFF = 1; constexpr int32_t MIN_MTE1_LOAD = 32; constexpr int32_t OUTER_STEP = 2; +constexpr int32_t BT_SIZE = 1024; +constexpr int32_t MIN_MN_SIZE = 16; +constexpr int32_t BITS_PER_BYTE = 8; #if __CCE_AICORE__ < 220 constexpr int32_t L1_SIZE = 1024 * 1024; #elif __CCE_AICORE__ == 300 diff --git a/lib/matmul/constant_tiling.h b/lib/matmul/constant_tiling.h index 7d589f7e..63eaa5df 100644 --- a/lib/matmul/constant_tiling.h +++ b/lib/matmul/constant_tiling.h @@ -46,16 +46,18 @@ __aicore__ constexpr MatmulApiStaticTiling GetMatmulApiTiling(const MatmulConfig tiling.baseM = mmCFG.basicM; tiling.baseN = mmCFG.basicN; tiling.baseK = mmCFG.basicK; + tiling.isBias = mmCFG.enableSetBias; tiling.stepM = l1Factor.mAL1; tiling.stepN = l1Factor.nBL1; int32_t reduceC0Size = GetReduceC0Size(); - int32_t kL0 = GetKL0(mmCFG); - tiling.stepKa = CeilNoLog(l1Factor.kAL1, kL0); - tiling.stepKb = CeilNoLog(l1Factor.kBL1, kL0); - tiling.depthA1 = CeilNoLog(l1Factor.kAL1, kL0) * l1Factor.mAL1 * l1Factor.dbAL1; - tiling.depthB1 = CeilNoLog(l1Factor.kBL1, kL0) * l1Factor.nBL1 * l1Factor.dbBL1; + if (!CalcAL1FullLoadTiling(l1Size, tiling)) { + int32_t kL0 = GetKL0(mmCFG); + tiling.stepKa = CeilNoLog(l1Factor.kAL1, kL0); + tiling.stepKb = CeilNoLog(l1Factor.kBL1, kL0); + tiling.depthA1 = CeilNoLog(l1Factor.kAL1, kL0) * l1Factor.mAL1 * l1Factor.dbAL1; + tiling.depthB1 = CeilNoLog(l1Factor.kBL1, kL0) * l1Factor.nBL1 * l1Factor.dbBL1; + } tiling.iterateOrder = GetIterateOrder(l1Factor, mmCFG); - tiling.isBias = mmCFG.enableSetBias; tiling.dbL0A = GetL0ADb(mmCFG, TOTAL_L0A_SIZE); tiling.dbL0B = GetL0BDb(mmCFG, TOTAL_L0B_SIZE); // keep the same with runtime tiling, fix l0c db -- Gitee