diff --git a/impl/matmul/matmul_impl.h b/impl/matmul/matmul_impl.h index 5b9088255e02c11f93158cd191771e192804e9de..c85c053c4664582dead9769ecfd73b7c945fd9b1 100644 --- a/impl/matmul/matmul_impl.h +++ b/impl/matmul/matmul_impl.h @@ -4583,8 +4583,7 @@ __aicore__ inline void MatmulImpl(); // if var.baseUseN_ is not 32B align, use DataCopy Nd2Nz if ((var.baseUseN_ * sizeof(BiasT)) % ONE_BLK_SIZE != 0) { - Nd2NzParams intriParams{ 1, 1, (uint16_t)var.baseUseN_, 0, (uint16_t)var.baseUseN_, 1, 1, 1 }; - DataCopy(bias, biasGlobal[col * var.tiling_->baseN], intriParams); + DataCopy(bias, biasGlobal[col * var.tiling_->baseN], { 1, 1, (uint16_t)var.baseUseN_, 0, 1, 1, 1, 0 }); } else { auto blockLen = Ceil(var.baseUseN_ * sizeof(BiasT), ONE_BLK_SIZE); if constexpr (MM_CFG.scheduleType == ScheduleType::OUTER_PRODUCT && MM_CFG.iterateOrder == IterateOrder::ORDER_M) { diff --git a/impl/matmul/matmul_tiling_algorithm.cpp b/impl/matmul/matmul_tiling_algorithm.cpp index d62c49510118e3156b01b9f1d864caffe7af01c4..3e9611a89cecfaae94d3b59c777eba3b7aded2cc 100644 --- a/impl/matmul/matmul_tiling_algorithm.cpp +++ b/impl/matmul/matmul_tiling_algorithm.cpp @@ -1022,11 +1022,11 @@ void MatmulTilingAlgorithm::NeitherFullLoadKforND(const CoreStatusPack& coreStat } else if (tilingIns_->aType_.dataType == DataType::DT_INT4) { alignValue = INT4_ALIGN_SIZE; } - const int32_t reduceSize = C0_BYTE_SIZE / DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); + const int32_t reduceSize = C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE; const int32_t alignM = MathUtil::CeilDivision(l1Status.mAL1 * C0_SIZE, alignValue) * alignValue; const int32_t alignN = MathUtil::CeilDivision(l1Status.nBL1 * C0_SIZE, alignValue) * alignValue; const int32_t alignK = MathUtil::CeilDivision(l0Status.kL0 * reduceSize, alignValue) * alignValue * - DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; if (kMaxAxis == 1) { // first get k_al1, 
second get k_bl1 l1Status.kBL1 = l0Status.kL0; @@ -1211,7 +1211,7 @@ void MatmulTilingAlgorithm::GetL1Factors(const std::string& opType, const Matmul // choose the final factors int32_t* tmpFactors = res[IDX_THREE]; int32_t tmpLoadSize = tmpFactors[IDX_SIX]; - int32_t reduceSize = static_cast<int32_t>(C0_BYTE_SIZE / DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType)); + int32_t reduceSize = static_cast<int32_t>(C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE); const int32_t kAl1FactorOne = res[IDX_ONE][IDX_ZERO] > 0 ? MathUtil::CeilDivision( MathUtil::CeilDivision(tilingIns_->singleK, reduceSize), (coreStatus.kDim * res[IDX_ONE][IDX_ZERO])) : 1; @@ -1259,13 +1259,13 @@ void MatmulTilingAlgorithm::GetL1Factors(const std::string& opType, const Matmul void MatmulTilingAlgorithm::GetUsedSize(int32_t& l1Size, int32_t& l0cSize, int32_t& ubSize, int32_t a1LengthCache, int32_t b1LengthCache) const { - const uint32_t aTypeSize = DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); - const uint32_t bTypeSize = DTYPE_BYTE_TAB.at(tilingIns_->bType_.dataType); + const uint32_t aTypeSize = DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType); + const uint32_t bTypeSize = DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType); const uint32_t cTypeSize = DTYPE_BYTE_TAB.at(tilingIns_->cType_.dataType); const uint32_t biasTypeSize = DTYPE_BYTE_TAB.at(tilingIns_->biasType_.dataType); - const int32_t a1Length = tilingIns_->tiling_.get_baseM() * tilingIns_->tiling_.get_baseK() * aTypeSize; - const int32_t b1Length = tilingIns_->tiling_.get_baseN() * tilingIns_->tiling_.get_baseK() * bTypeSize; + const int32_t a1Length = tilingIns_->tiling_.get_baseM() * tilingIns_->tiling_.get_baseK() * aTypeSize / BITS_PER_BYTE; + const int32_t b1Length = tilingIns_->tiling_.get_baseN() * tilingIns_->tiling_.get_baseK() * bTypeSize / BITS_PER_BYTE; const int32_t c1Length = tilingIns_->tiling_.get_baseN() * tilingIns_->tiling_.get_baseM() * FP32_BYTES; if (tilingIns_->aType_.pos != TPosition::TSCM) { @@ -1302,19 
+1302,19 @@ void MatmulTilingAlgorithm::GetUsedSize(int32_t& l1Size, int32_t& l0cSize, int32 // (2) 输入VECCALC, format是ND, 需要在ub中对非对齐的尾块进行补零操作 int32_t aUbLength = 0; int32_t bUbLength = 0; - if (!tilingIns_->aType_.isTrans && ((tilingIns_->tiling_.get_singleCoreK() * aTypeSize) % C0_BYTE_SIZE != 0)) { + if (!tilingIns_->aType_.isTrans && ((tilingIns_->tiling_.get_singleCoreK() * aTypeSize / BITS_PER_BYTE) % C0_BYTE_SIZE != 0)) { aUbLength = tilingIns_->tiling_.get_baseM() * C0_BYTE_SIZE; } if (tilingIns_->aType_.isTrans && - ((tilingIns_->tiling_.get_singleCoreM() * aTypeSize) % C0_BYTE_SIZE != 0)) { + ((tilingIns_->tiling_.get_singleCoreM() * aTypeSize / BITS_PER_BYTE) % C0_BYTE_SIZE != 0)) { aUbLength = tilingIns_->tiling_.get_baseK() * C0_BYTE_SIZE; } - if (!tilingIns_->bType_.isTrans && ((tilingIns_->tiling_.get_singleCoreN() * bTypeSize) % C0_BYTE_SIZE != 0)) { + if (!tilingIns_->bType_.isTrans && ((tilingIns_->tiling_.get_singleCoreN() * bTypeSize / BITS_PER_BYTE) % C0_BYTE_SIZE != 0)) { bUbLength = tilingIns_->tiling_.get_baseK() * C0_BYTE_SIZE; } if (tilingIns_->bType_.isTrans && - ((tilingIns_->tiling_.get_singleCoreK() * bTypeSize) % C0_BYTE_SIZE != 0)) { + ((tilingIns_->tiling_.get_singleCoreK() * bTypeSize / BITS_PER_BYTE) % C0_BYTE_SIZE != 0)) { bUbLength = tilingIns_->tiling_.get_baseN() * C0_BYTE_SIZE; } if (tilingIns_->aType_.pos == TPosition::TSCM) { @@ -1369,7 +1369,7 @@ void MatmulTilingAlgorithm::GetBankConflictSize(int32_t& length, bool isAMatrix) true : false; bankConflictSize = tilingIns_->tiling_.get_baseK() * C0_SIZE * tilingIns_->tiling_.get_stepKa() * - DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; } else { isBankConflict = MathUtil::CeilDivision(tilingIns_->tiling_.get_stepKa() * tilingIns_->tiling_.get_baseK(), C0_SIZE) * @@ -1378,7 +1378,7 @@ void MatmulTilingAlgorithm::GetBankConflictSize(int32_t& length, bool isAMatrix) true : false; bankConflictSize = 
tilingIns_->tiling_.get_baseM() * C0_SIZE * tilingIns_->tiling_.get_stepM() * - DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; } } else { if (tilingIns_->bType_.isTrans) { @@ -1389,7 +1389,7 @@ void MatmulTilingAlgorithm::GetBankConflictSize(int32_t& length, bool isAMatrix) true : false; bankConflictSize = tilingIns_->tiling_.get_baseN() * C0_SIZE * tilingIns_->tiling_.get_stepN() * - DTYPE_BYTE_TAB.at(tilingIns_->bType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE; } else { isBankConflict = MathUtil::CeilDivision(tilingIns_->tiling_.get_stepN() * tilingIns_->tiling_.get_baseN(), C0_SIZE) * @@ -1398,7 +1398,7 @@ void MatmulTilingAlgorithm::GetBankConflictSize(int32_t& length, bool isAMatrix) true : false; bankConflictSize = tilingIns_->tiling_.get_baseK() * C0_SIZE * tilingIns_->tiling_.get_stepKb() * - DTYPE_BYTE_TAB.at(tilingIns_->bType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE; } } if (isBankConflict) { @@ -1418,7 +1418,7 @@ void MatmulTilingAlgorithm::GetTransLength(int32_t& transLength) const // A matrix ND2NZ if (tilingIns_->aType_.type == CubeFormat::ND) { a1Length = tilingIns_->tiling_.get_baseM() * tilingIns_->tiling_.get_baseK() * - DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; if (tilingIns_->mmConfigType == 1) { a1Length = a1Length * tilingIns_->tiling_.get_stepKa() * tilingIns_->tiling_.get_stepM(); } @@ -1428,7 +1428,7 @@ void MatmulTilingAlgorithm::GetTransLength(int32_t& transLength) const // B matrix ND2NZ if (tilingIns_->bType_.type == CubeFormat::ND) { b1Length = tilingIns_->tiling_.get_baseN() * tilingIns_->tiling_.get_baseK() * - DTYPE_BYTE_TAB.at(tilingIns_->bType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE; if (tilingIns_->mmConfigType == 1) { b1Length = b1Length * tilingIns_->tiling_.get_stepKb() * 
tilingIns_->tiling_.get_stepN(); } @@ -1484,7 +1484,7 @@ int32_t MatmulTilingAlgorithm::GetIteratorOrder(SingleCoreStatus& singleCoreStat if (tilingIns_->traverse_ != MatrixTraverse::NOSET) { return static_cast<int32_t>(tilingIns_->traverse_) - 1; } - const int32_t reduceSize = static_cast<int32_t>(C0_BYTE_SIZE / DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType)); + const int32_t reduceSize = static_cast<int32_t>(C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE); const bool fullkAL1Load = (static_cast<double>(singleCoreK) / (singleCoreStatus.l1Status.kAL1 * reduceSize)) > 1.0 ? false : true; bool fullkBL1Load = @@ -2009,7 +2009,7 @@ void MatmulTilingAlgorithm::FillParam(MatmulRunParas& param) numOfBlock_ = tilingIns_->blockDim; } - const int32_t reduceBlockSize = C0_BYTE_SIZE / DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); + const int32_t reduceBlockSize = C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE; param.k32 = MathUtil::CeilDivision(realK, reduceBlockSize); param.m32 = MathUtil::CeilDivision(realM, C0_SIZE); param.n32 = MathUtil::CeilDivision(realN, C0_SIZE); @@ -2046,17 +2046,12 @@ bool MatmulTilingAlgorithm::CheckFinaleParams(const CoreStatusPack& coreStatus) return false; } - int dateDtypeSize = tilingIns_->aType_.dataType == DataType::DT_FLOAT ? - FP32_BYTES : - (tilingIns_->aType_.dataType == DataType::DT_INT8 ? 
INT8_BYTES : FP16_BYTES); - if (tilingIns_->aType_.dataType == DataType::DT_INT4) { - dateDtypeSize = INT8_BYTES; - } + int dateDtypeSize = DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType); if (tilingIns_->tiling_.get_BatchNum() > 0 && ((tilingIns_->tiling_.get_singleCoreM() * tilingIns_->tiling_.get_singleCoreK() + tilingIns_->tiling_.get_singleCoreN() * tilingIns_->tiling_.get_singleCoreK()) * - tilingIns_->tiling_.get_BatchNum() * dateDtypeSize + - tilingIns_->tiling_.get_singleCoreN() * tilingIns_->tiling_.get_BatchNum() * dateDtypeSize > + tilingIns_->tiling_.get_BatchNum() * dateDtypeSize / BITS_PER_BYTE + + tilingIns_->tiling_.get_singleCoreN() * tilingIns_->tiling_.get_BatchNum() * dateDtypeSize / BITS_PER_BYTE > tilingIns_->bufferPool_.l1Size)) { TILING_LOG_WARNING("a/b matrix size of batch mm should less then L1Size"); return false; @@ -2098,9 +2093,9 @@ void MatmulTilingAlgorithm::SetDepthL1CacheUBParams(int32_t &a1LengthCache, int3 return; } int32_t a1Length = tilingIns_->tiling_.get_baseM() * tilingIns_->tiling_.get_baseK() * - DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; int32_t b1Length = tilingIns_->tiling_.get_baseN() * tilingIns_->tiling_.get_baseK() * - DTYPE_BYTE_TAB.at(tilingIns_->bType_.dataType); + DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE; a1LengthCache = a1Length * tilingIns_->tiling_.get_stepKa() * tilingIns_->tiling_.get_stepM(); b1LengthCache = b1Length * tilingIns_->tiling_.get_stepKb() * tilingIns_->tiling_.get_stepN(); int32_t freeL1Size = tilingIns_->bufferPool_.l1Size - tilingIns_->tiling_.get_depthA1() * a1Length - @@ -2278,13 +2273,8 @@ int64_t MatmulTilingAlgorithm::Process() tilingIns_->tiling_.set_baseM(singleCoreStatus.l0Status.mL0 * C0_SIZE); tilingIns_->tiling_.set_baseN(singleCoreStatus.l0Status.nL0 * C0_SIZE); - const int32_t reduceSize = C0_BYTE_SIZE / DTYPE_BYTE_TAB.at(tilingIns_->aType_.dataType); - // int4 baseK should be 64 align - if 
((tilingIns_->aType_.dataType == DataType::DT_INT4) && (singleCoreStatus.l0Status.kL0 % NUM_TWO != 0)) { - tilingIns_->tiling_.set_baseK((singleCoreStatus.l0Status.kL0 + 1) * reduceSize); - } else { - tilingIns_->tiling_.set_baseK(singleCoreStatus.l0Status.kL0 * reduceSize); - } + const int32_t reduceSize = C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE; + tilingIns_->tiling_.set_baseK(singleCoreStatus.l0Status.kL0 * reduceSize); tilingIns_->baseM = tilingIns_->tiling_.get_baseM(); tilingIns_->baseN = tilingIns_->tiling_.get_baseN(); tilingIns_->baseK = tilingIns_->tiling_.get_baseK(); diff --git a/impl/matmul/matmul_tiling_base.cpp b/impl/matmul/matmul_tiling_base.cpp index a6d17ac014dd5c86dfdefb02c0beb5939cc40c45..0e8e77f8134bfff46c19e90dc6d735d267d94620 100644 --- a/impl/matmul/matmul_tiling_base.cpp +++ b/impl/matmul/matmul_tiling_base.cpp @@ -456,7 +456,7 @@ int32_t MatmulApiTilingBase::SetFixSplit(int32_t baseMIn, int32_t baseNIn, int32 this->adjust_.maxBaseN = baseNIn; this->adjust_.minBaseN = baseNIn; } - const int32_t k0 = C0_BYTE_SIZE / DTYPE_BYTE_TAB.at(this->aType_.dataType); + const int32_t k0 = C0_BYTE_SIZE / DTYPE_BIT_TAB.at(this->aType_.dataType) * BITS_PER_BYTE; if (baseKIn != -1) { if (baseKIn % k0 > 0) { return -1; @@ -592,19 +592,19 @@ bool MatmulApiTilingBase::CheckSetParam() } } - int32_t dataBytes = DTYPE_BYTE_TAB.at(aType_.dataType); + int32_t dataBits = DTYPE_BIT_TAB.at(aType_.dataType); if (this->baseM != -1 && this->baseK != -1) { // 设置了baseM, baseK, 限制 L0A - if (this->baseM * this->baseK * dataBytes > this->bufferPool_.l0ASize) { + if (this->baseM * this->baseK * dataBits / BITS_PER_BYTE > this->bufferPool_.l0ASize) { TILING_LOG_INFO("baseM * baseK is larger then L0ASize"); return false; } } - dataBytes = DTYPE_BYTE_TAB.at(bType_.dataType); + dataBits = DTYPE_BIT_TAB.at(bType_.dataType); if (this->baseK != -1 && this->baseN != -1) { // 设置了baseM, baseK, 限制 L0B - if (this->baseK * this->baseN * dataBytes > 
this->bufferPool_.l0BSize) { + if (this->baseK * this->baseN * dataBits / BITS_PER_BYTE > this->bufferPool_.l0BSize) { TILING_LOG_INFO("baseN * baseK is larger then l0BSize"); return false; } @@ -774,9 +774,9 @@ void MatmulApiTilingBase::PrintTilingData() << std::endl; std::cout << "tiling.L0ARatio = " << (this->tiling_.get_baseM() * this->tiling_.get_baseK() + 0.0) * - DTYPE_BYTE_TAB.at(this->aType_.dataType) / this->oriBufferPool_.l0ASize << std::endl; + DTYPE_BIT_TAB.at(this->aType_.dataType) / BITS_PER_BYTE / this->oriBufferPool_.l0ASize << std::endl; std::cout << "tiling.L0BRatio = " << (this->tiling_.get_baseN() * this->tiling_.get_baseK() + 0.0) * - DTYPE_BYTE_TAB.at(this->bType_.dataType) / this->oriBufferPool_.l0BSize << std::endl; + DTYPE_BIT_TAB.at(this->bType_.dataType) / BITS_PER_BYTE / this->oriBufferPool_.l0BSize << std::endl; } } // namespace matmul_tiling \ No newline at end of file diff --git a/lib/matmul/matmul_tiling_base.h b/lib/matmul/matmul_tiling_base.h index ff9de1ead7e1752e02e5c53a73c10056dc5643ed..33319d47e208eb6e71070e7bdfca06f8f7b0c946 100644 --- a/lib/matmul/matmul_tiling_base.h +++ b/lib/matmul/matmul_tiling_base.h @@ -26,6 +26,7 @@ constexpr int32_t FP32_BYTES = 4; constexpr int32_t FP16_BYTES = 2; constexpr int32_t C0_SIZE = 16; constexpr int32_t C0_BYTE_SIZE = 32; +constexpr int32_t BITS_PER_BYTE = 8; enum class DataType : int32_t { DT_FLOAT = 0, // float type DT_FLOAT16 = 1, // fp16 type @@ -70,6 +71,13 @@ const std::map<DataType, uint32_t> DTYPE_BYTE_TAB = { {DataType::DT_INT4, 1} }; +const std::map<DataType, uint32_t> DTYPE_BIT_TAB = { + {DataType::DT_FLOAT, 32}, {DataType::DT_FLOAT16, 16}, {DataType::DT_INT8, 8}, {DataType::DT_INT16, 16}, + {DataType::DT_UINT16, 16}, {DataType::DT_UINT8, 8}, {DataType::DT_INT32, 32}, {DataType::DT_INT64, 64}, + {DataType::DT_UINT32, 32}, {DataType::DT_UINT64, 64}, {DataType::DT_BF16, 16}, {DataType::DT_BFLOAT16, 16}, + {DataType::DT_INT4, 4} +}; + enum class TPosition : int32_t { GM, A1, diff --git a/lib/matmul/tiling.h 
b/lib/matmul/tiling.h index 7e42d91fface1f14eb735f1d7f4d474e87292ba9..cc8ae9e28dbdf55458d96eb43071aa9243595786 100644 --- a/lib/matmul/tiling.h +++ b/lib/matmul/tiling.h @@ -157,17 +157,13 @@ struct MatmulBatchParams { struct MatmulFuncParams { bool intrinsicsCheck; bool enVecND2NZ; - uint32_t doMTE2Preload; - bool enableQuantVector = true; - bool enableSetDefineData = true; - uint8_t iterateMode = IterateMode::ITERATE_MODE_DEFAULT; - bool enableReuse = true; - bool enableUBReuse; + bool enableDoubleCache; bool enableL1CacheUB; - bool intraBlockPartSum = false; + uint32_t doMTE2Preload; IterateOrder iterateOrder; ScheduleType scheduleType; - bool enableDoubleCache; + bool enableReuse = true; + bool enableUBReuse; }; __aicore__ constexpr MatmulConfig GetNormalConfig(const bool intrinsicsLimit = false, const bool batchLoop = false, @@ -466,26 +462,6 @@ constexpr MatmulConfig CFG_MDL = GetMDLConfig(); constexpr MatmulConfig MM_CFG_BB = GetBasicConfig(128, 128, 128); constexpr MatmulConfig CFG_IBSHARE_NORM = GetIBShareNormConfig(); -template -__aicore__ inline constexpr MatmulConfig GetMMConfig(ArgTypes&&... args) { - MatmulConfig mmConfig = CFG_NORM; - if constexpr (configMode == MatmulConfigMode::CONFIG_MDL) { - mmConfig = CFG_MDL; - } else if constexpr (configMode == MatmulConfigMode::CONFIG_SPECIALMDL) { - mmConfig = GetSpecialMDLConfig(); - } else if constexpr (configMode == MatmulConfigMode::CONFIG_IBSHARE) { - mmConfig = CFG_IBSHARE_NORM; - } - GetMMConfigImpl(mmConfig, args...); - return mmConfig; -} - -template -__aicore__ inline constexpr void GetMMConfigImpl(MatmulConfig& cfg, T arg, ArgTypes&&... 
args) { - GetMMConfigImpl(cfg, arg); - GetMMConfigImpl(cfg, args...); -} - template __aicore__ inline constexpr void GetMMConfigImpl(MatmulConfig& cfg, ArgType arg) { if constexpr (AscendC::IsSameType::value) { @@ -504,18 +480,34 @@ __aicore__ inline constexpr void GetMMConfigImpl(MatmulConfig& cfg, ArgType arg) } else if constexpr (AscendC::IsSameType::value) { cfg.intrinsicsCheck = arg.intrinsicsCheck; cfg.enVecND2NZ = arg.enVecND2NZ; - cfg.doMTE2Preload = arg.doMTE2Preload; - cfg.enableQuantVector = arg.enableQuantVector; - cfg.enableSetDefineData = arg.enableSetDefineData; - cfg.iterateMode = arg.iterateMode; - cfg.enableReuse = arg.enableReuse; - cfg.enableUBReuse = arg.enableUBReuse; + cfg.enableDoubleCache = arg.enableDoubleCache; cfg.enableL1CacheUB = arg.enableL1CacheUB; - cfg.intraBlockPartSum = arg.intraBlockPartSum; + cfg.doMTE2Preload = arg.doMTE2Preload; cfg.iterateOrder = arg.iterateOrder; cfg.scheduleType = arg.scheduleType; - cfg.enableDoubleCache = arg.enableDoubleCache; + cfg.enableReuse = arg.enableReuse; + cfg.enableUBReuse = arg.enableUBReuse; + } +} + +template +__aicore__ inline constexpr void GetMMConfigImpl(MatmulConfig& cfg, T arg, ArgTypes&&... args) { + GetMMConfigImpl(cfg, arg); + GetMMConfigImpl(cfg, args...); +} + +template +__aicore__ inline constexpr MatmulConfig GetMMConfig(ArgTypes&&... 
args) { + MatmulConfig mmConfig = CFG_NORM; + if constexpr (configMode == MatmulConfigMode::CONFIG_MDL) { + mmConfig = CFG_MDL; + } else if constexpr (configMode == MatmulConfigMode::CONFIG_SPECIALMDL) { + mmConfig = GetSpecialMDLConfig(); + } else if constexpr (configMode == MatmulConfigMode::CONFIG_IBSHARE) { + mmConfig = CFG_IBSHARE_NORM; } + GetMMConfigImpl(mmConfig, args...); + return mmConfig; } struct MatrixOffset { diff --git a/tests/matmul/test_matmul_config.cpp b/tests/matmul/test_matmul_config.cpp index 33dce3b32ba2eb13fff05d3a12febd629611289d..5f5d416a564858afee94e8504d0439df3c82e883 100644 --- a/tests/matmul/test_matmul_config.cpp +++ b/tests/matmul/test_matmul_config.cpp @@ -30,32 +30,28 @@ TEST_F(TestMatmulConfig, TestParamsConfig) constexpr static MatmulShapeParams shapeParams{128, 128, 128, 64, 64, 64}; constexpr static MatmulQuantParams quantParams{1, 1}; constexpr static MatmulBatchParams batchParams{1, BatchMode::BATCH_LARGE_THAN_L1}; - constexpr static MatmulFuncParams funcParams{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - IterateOrder::ORDER_N, ScheduleType::OUTER_PRODUCT, 1}; + constexpr static MatmulFuncParams funcParams{1, 1, 1, 1, 1, IterateOrder::ORDER_N, ScheduleType::OUTER_PRODUCT, + 1, 1}; constexpr MatmulConfig mmConfig = GetMMConfig(shapeParams, quantParams, batchParams, funcParams); - EXPECT_EQ((uint32_t)mmConfig.doNorm, 1); - EXPECT_EQ((uint32_t)mmConfig.singleCoreM, 128); - EXPECT_EQ((uint32_t)mmConfig.singleCoreN, 128); - EXPECT_EQ((uint32_t)mmConfig.singleCoreK, 128); - EXPECT_EQ((uint32_t)mmConfig.basicM, 64); - EXPECT_EQ((uint32_t)mmConfig.basicN, 64); - EXPECT_EQ((uint32_t)mmConfig.basicK, 64); - EXPECT_EQ((uint32_t)mmConfig.isPerTensor, 1); - EXPECT_EQ((uint32_t)mmConfig.hasAntiQuantOffset, 1); - EXPECT_EQ((uint32_t)mmConfig.isNBatch, 1); - EXPECT_EQ((uint32_t)mmConfig.batchMode, 2); - EXPECT_EQ((uint32_t)mmConfig.intrinsicsCheck, 1); - EXPECT_EQ((uint32_t)mmConfig.enVecND2NZ, 1); - EXPECT_EQ((uint32_t)mmConfig.doMTE2Preload, 1); - 
EXPECT_EQ((uint32_t)mmConfig.enableQuantVector, 1); - EXPECT_EQ((uint32_t)mmConfig.enableSetDefineData, 1); - EXPECT_EQ((uint32_t)mmConfig.iterateMode, 1); - EXPECT_EQ((uint32_t)mmConfig.enableReuse, 1); - EXPECT_EQ((uint32_t)mmConfig.enableUBReuse, 1); - EXPECT_EQ((uint32_t)mmConfig.enableL1CacheUB, 1); - EXPECT_EQ((uint32_t)mmConfig.intraBlockPartSum, 1); - EXPECT_EQ((uint32_t)mmConfig.iterateOrder, 1); - EXPECT_EQ((uint32_t)mmConfig.scheduleType, 1); - EXPECT_EQ((uint32_t)mmConfig.enableDoubleCache, 1); + EXPECT_EQ(mmConfig.doNorm, true); + EXPECT_EQ(mmConfig.singleCoreM, 128); + EXPECT_EQ(mmConfig.singleCoreN, 128); + EXPECT_EQ(mmConfig.singleCoreK, 128); + EXPECT_EQ(mmConfig.basicM, 64); + EXPECT_EQ(mmConfig.basicN, 64); + EXPECT_EQ(mmConfig.basicK, 64); + EXPECT_EQ(mmConfig.isPerTensor, true); + EXPECT_EQ(mmConfig.hasAntiQuantOffset, true); + EXPECT_EQ(mmConfig.isNBatch, true); + EXPECT_EQ(mmConfig.batchMode, BatchMode::BATCH_LARGE_THAN_L1); + EXPECT_EQ(mmConfig.intrinsicsCheck, true); + EXPECT_EQ(mmConfig.enVecND2NZ, true); + EXPECT_EQ(mmConfig.enableDoubleCache, true); + EXPECT_EQ(mmConfig.enableL1CacheUB, true); + EXPECT_EQ(mmConfig.doMTE2Preload, 1); + EXPECT_EQ(mmConfig.iterateOrder, IterateOrder::ORDER_N); + EXPECT_EQ(mmConfig.scheduleType, ScheduleType::OUTER_PRODUCT); + EXPECT_EQ(mmConfig.enableReuse, true); + EXPECT_EQ(mmConfig.enableUBReuse, true); } \ No newline at end of file