diff --git a/impl/matmul/matmul_tiling_algorithm.cpp b/impl/matmul/matmul_tiling_algorithm.cpp index a921221219062f3243307b843dddf2aac824148b..f6d3d5d8fa6a452f72ce216d353ceb26a26b22b2 100644 --- a/impl/matmul/matmul_tiling_algorithm.cpp +++ b/impl/matmul/matmul_tiling_algorithm.cpp @@ -673,7 +673,8 @@ void MatmulTilingAlgorithm::L1StatusAl1FullLoad(const CoreStatusPack& coreStatus GetABL1KAlignValue(kaAlignValue, kbAlignValue); l1Status.kAL1 = MathUtil::CeilDivision(l1Status.kAL1, l0Status.kL0) * l0Status.kL0; const int32_t curL1Size = GetL1Size(l1Status, l0Status); - if (curL1Size > 0 && curL1Size <= tilingIns_->bufferPool_.l1Size) { + const int32_t a1Length = GetAL1UbSize(l1Status, l0Status); + if (curL1Size > 0 && curL1Size <= tilingIns_->bufferPool_.l1Size && a1Length < tilingIns_->bufferPool_.ubSize) { l1Status.aL1FullLoad = true; l1Status.aL1Size = max(MathUtil::Align(coreStatus.k, kaAlignValue), MathUtil::Align(l1Status.kAL1, kaAlignValue)) * @@ -698,6 +699,12 @@ void MatmulTilingAlgorithm::L1StatusAl1FullLoad(const CoreStatusPack& coreStatus l1Status.kBL1 = min(CalL1MaxLen((l1Status.bL1Size - biasSize - dequantSize), l1Status, l0Status, kbAlignValue, L1TilingType::KBL1_16), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.dbBL1 = DB_OFF; + const int32_t b1Length = tilingIns_->bufferPool_.ubSize - a1Length; + l1Status.kBL1 = min(CalL1MaxLen(min(l1Status.bL1Size - biasSize - dequantSize, b1Length), + l1Status, l0Status, kbAlignValue, L1TilingType::KBL1_16), coreStatus.k); + } l1Status.bL1Times = min(l1Status.kBL1 / l0Status.kL0, l1Status.maxKBL1); GetNearestFactor(l1Status.allTimes, l1Status.bL1Times); // tik-mm support no factor---ncheck l1Status.kBL1 = l1Status.bL1Times * l0Status.kL0; @@ -740,7 +747,8 @@ void MatmulTilingAlgorithm::L1StatusBl1FullLoad(const CoreStatusPack& coreStatus GetABL1KAlignValue(kaAlignValue, kbAlignValue); l1Status.kBL1 = MathUtil::CeilDivision(l1Status.kBL1, l0Status.kL0) * l0Status.kL0; const int32_t curL1Size = GetL1Size(l1Status, l0Status); - if (curL1Size > 0 && curL1Size <= tilingIns_->bufferPool_.l1Size) { + const int32_t b1Length = GetBL1UbSize(l1Status, l0Status); + if (curL1Size > 0 && curL1Size <= tilingIns_->bufferPool_.l1Size && b1Length < tilingIns_->bufferPool_.ubSize) { l1Status.bL1FullLoad = true; l1Status.bL1Size = max(MathUtil::Align(coreStatus.k, kbAlignValue), MathUtil::Align(l1Status.kBL1, kbAlignValue)) * @@ -765,6 +773,12 @@ void MatmulTilingAlgorithm::L1StatusBl1FullLoad(const CoreStatusPack& coreStatus l1Status.kAL1 = min(CalL1MaxLen((l1Status.aL1Size - biasSize - dequantSize), l1Status, l0Status, kaAlignValue, L1TilingType::KAL1_16), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.dbAL1 = DB_OFF; + const int32_t a1Length = tilingIns_->bufferPool_.ubSize - b1Length; + l1Status.kAL1 = min(CalL1MaxLen(min(l1Status.aL1Size - biasSize - dequantSize, a1Length), + l1Status, l0Status, kaAlignValue, L1TilingType::KAL1_16), coreStatus.k); + } l1Status.aL1Times = min(l1Status.kAL1 / l0Status.kL0, l1Status.maxKAL1); GetNearestFactor(l1Status.allTimes, l1Status.aL1Times); // tik-mm support no factor---ncheck l1Status.kAL1 = l1Status.aL1Times * l0Status.kL0; @@ -799,11 +813,14 @@ void MatmulTilingAlgorithm::L1StatusBothFullLoad(const CoreStatusPack& coreStatu l1Status.kAL1 = MathUtil::CeilDivision(l1Status.kAL1, l0Status.kL0) * l0Status.kL0; l1Status.kBL1 = MathUtil::CeilDivision(l1Status.kBL1, l0Status.kL0) * l0Status.kL0; const int32_t curL1Size = GetL1Size(l1Status, l0Status); + const int32_t a1Length = GetAL1UbSize(l1Status, l0Status); + const int32_t b1Length = GetBL1UbSize(l1Status, l0Status); if (tilingIns_->aType_.pos == TPosition::TSCM && tilingIns_->bType_.pos == TPosition::TSCM) { l1Status.mAL1 = 1; l1Status.nBL1 = 1; } - if ((curL1Size > 0 && curL1Size <= tilingIns_->bufferPool_.l1Size) || + if ((curL1Size > 0 && curL1Size <= tilingIns_->bufferPool_.l1Size) && + a1Length + b1Length <= tilingIns_->bufferPool_.ubSize || (tilingIns_->aType_.pos == TPosition::TSCM && tilingIns_->bType_.pos == TPosition::TSCM)) { l1Status.bothFullLoad = true; l1Status.loadSize = (tilingIns_->aType_.pos == TPosition::TSCM ? 0 : coreStatus.m) + @@ -822,22 +839,27 @@ void MatmulTilingAlgorithm::NeitherFullLoadDb(const CoreStatusPack& coreStatus, { const int32_t tmpKbl116 = l1Status.kBL1; l1Status.kBL1 = kbl1Db; - if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size) { + if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size || + GetAL1UbSize(l1Status, l0Status) + GetBL1UbSize(l1Status, l0Status) > tilingIns_->bufferPool_.ubSize) { l1Status.dbBL1 = DB_OFF; - if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size) { + if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size || + GetAL1UbSize(l1Status, l0Status) + GetBL1UbSize(l1Status, l0Status) > tilingIns_->bufferPool_.ubSize) { l1Status.dbAL1 = DB_OFF; } } l1Status.kBL1 = coreStatus.k; const bool bothDoubleBuffer = coreStatus.m != l0Status.mL0 && coreStatus.k > l0Status.kL0 && - GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size; + (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size || + GetAL1UbSize(l1Status, l0Status) + GetBL1UbSize(l1Status, l0Status) > tilingIns_->bufferPool_.ubSize); l1Status.kBL1 = tmpKbl116; if (bothDoubleBuffer) { l1Status.dbAL1 = DB_ON; l1Status.dbBL1 = DB_ON; - if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size) { + if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size || + GetAL1UbSize(l1Status, l0Status) + GetBL1UbSize(l1Status, l0Status) > tilingIns_->bufferPool_.ubSize) { l1Status.dbBL1 = DB_OFF; - if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size) { + if (GetL1Size(l1Status, l0Status) > tilingIns_->bufferPool_.l1Size || + GetAL1UbSize(l1Status, l0Status) + GetBL1UbSize(l1Status, l0Status) > tilingIns_->bufferPool_.ubSize) { l1Status.dbAL1 = DB_OFF; } } @@ -878,20 +900,36 @@ void MatmulTilingAlgorithm::NeitherFullLoadMN(const CoreStatusPack& coreStatus, l1Mfirst.bL1Size = MathUtil::Align(l1Mfirst.kBL1, kbAlignValue) * l0Status.nL0 * C0_SIZE * C0_BYTE_SIZE * l1Mfirst.dbBL1; l1Mfirst.aL1Size = tilingIns_->bufferPool_.l1Size - l1Mfirst.bL1Size; + int32_t a1Length = tilingIns_->bufferPool_.ubSize - GetBL1UbSize(l1Mfirst, l0Status); l1Mfirst.mAL1 = max(min(min( CalL1MaxLen(l1Mfirst.aL1Size - biasSize - dequantSize, l1Mfirst, l0Status, kaAlignValue, L1TilingType::M_AL1), l1Mfirst.maxMAL1), mRepeat), 1); + if (IsUbNd2Nz()) { + l1Mfirst.mAL1 = max(min(min( + CalL1MaxLen(min(l1Mfirst.aL1Size - biasSize - dequantSize, a1Length), l1Mfirst, l0Status, kaAlignValue, L1TilingType::M_AL1), + l1Mfirst.maxMAL1), + mRepeat), + 1); + } GetNearestFactor(mRepeat, l1Mfirst.mAL1); // tik-mm support no factor ----ncheck l1Mfirst.aL1Size = MathUtil::Align(l1Mfirst.kAL1, kaAlignValue) * l1Mfirst.mAL1 * l0Status.mL0 * C0_SIZE * C0_BYTE_SIZE * l1Mfirst.dbAL1; l1Mfirst.bL1Size = tilingIns_->bufferPool_.l1Size - l1Mfirst.aL1Size; + int32_t b1Length = tilingIns_->bufferPool_.ubSize - GetAL1UbSize(l1Mfirst, l0Status); l1Mfirst.nBL1 = max(min(min( CalL1MaxLen(l1Mfirst.bL1Size - biasSize - dequantSize, l1Mfirst, l0Status, kbAlignValue, L1TilingType::N_BL1), l1Mfirst.maxNBL1), nRepeat), 1); + if (IsUbNd2Nz()) { + l1Mfirst.nBL1 = max(min(min( + CalL1MaxLen(min(l1Mfirst.bL1Size - biasSize - dequantSize, b1Length), l1Mfirst, l0Status, kbAlignValue, L1TilingType::N_BL1), + l1Mfirst.maxNBL1), + nRepeat), + 1); + } GetNearestFactor(nRepeat, l1Mfirst.nBL1); l1Mfirst.loadSize = coreStatus.m + coreStatus.n * MathUtil::CeilDivision(coreStatus.m, l1Mfirst.mAL1 * l0Status.mL0); @@ -900,21 +938,37 @@ void MatmulTilingAlgorithm::NeitherFullLoadMN(const CoreStatusPack& coreStatus, l1Nfirst.aL1Size = MathUtil::Align(l1Nfirst.kAL1, kaAlignValue) * l0Status.mL0 * C0_SIZE * C0_BYTE_SIZE * l1Nfirst.dbAL1; l1Nfirst.bL1Size = tilingIns_->bufferPool_.l1Size - l1Nfirst.aL1Size; + b1Length = tilingIns_->bufferPool_.ubSize - GetAL1UbSize(l1Nfirst, l0Status); l1Nfirst.nBL1 = max(min(min( CalL1MaxLen(l1Nfirst.bL1Size - biasSize - dequantSize, l1Nfirst, l0Status, kbAlignValue, L1TilingType::N_BL1), l1Nfirst.maxNBL1), nRepeat), 1); + if (IsUbNd2Nz()) { + l1Nfirst.nBL1 = max(min(min( + CalL1MaxLen(min(l1Nfirst.bL1Size - biasSize - dequantSize, b1Length), l1Nfirst, l0Status, kbAlignValue, L1TilingType::N_BL1), + l1Nfirst.maxNBL1), + nRepeat), + 1); + } GetNearestFactor(nRepeat, l1Nfirst.nBL1); l1Nfirst.bL1Size = MathUtil::Align(coreStatus.k, kbAlignValue) * l1Nfirst.nBL1 * l0Status.nL0 * C0_SIZE * C0_BYTE_SIZE * l1Nfirst.dbBL1; l1Nfirst.aL1Size = tilingIns_->bufferPool_.l1Size - l1Nfirst.bL1Size; biasSize = biasSize * l1Nfirst.nBL1; + a1Length = tilingIns_->bufferPool_.ubSize - GetBL1UbSize(l1Nfirst, l0Status); l1Nfirst.mAL1 = max(min(min( CalL1MaxLen(l1Nfirst.aL1Size - biasSize - dequantSize, l1Nfirst, l0Status, kaAlignValue, L1TilingType::M_AL1), l1Nfirst.maxMAL1), mRepeat), 1); + if (IsUbNd2Nz()) { + l1Nfirst.mAL1 = max(min(min( + CalL1MaxLen(min(l1Nfirst.aL1Size - biasSize - dequantSize, a1Length), l1Nfirst, l0Status, kaAlignValue, L1TilingType::M_AL1), + l1Nfirst.maxMAL1), + mRepeat), + 1); + } GetNearestFactor(mRepeat, l1Nfirst.mAL1); l1Nfirst.loadSize = coreStatus.m * MathUtil::CeilDivision(coreStatus.n, l1Nfirst.nBL1 * l0Status.nL0) + coreStatus.n; @@ -976,9 +1030,15 @@ void MatmulTilingAlgorithm::NeitherFullLoadKforNZ(const CoreStatusPack& coreStat l1Status.bL1Size = MathUtil::Align(coreStatus.k, kbAlignValue) * l1Status.nBL1 * l0Status.nL0 * C0_SIZE * C0_BYTE_SIZE * l1Status.dbBL1; l1Status.aL1Size = tilingIns_->bufferPool_.l1Size - l1Status.bL1Size; + int32_t a1Length = tilingIns_->bufferPool_.ubSize - GetBL1UbSize(l1Status, l0Status); l1Status.kAL1 = min(CalL1MaxLen(l1Status.aL1Size - biasSize - dequantSize, l1Status, l0Status, kaAlignValue, L1TilingType::KAL1_16), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kAL1 = min(CalL1MaxLen(min(l1Status.aL1Size - biasSize - dequantSize, a1Length), l1Status, l0Status, kaAlignValue, + L1TilingType::KAL1_16), + coreStatus.k); + } l1Status.aL1Times = max(min(l1Status.kAL1 / l0Status.kL0, l1Status.maxKAL1), 1); GetNearestFactor(l1Status.allTimes, l1Status.aL1Times); l1Status.kAL1 = l1Status.aL1Times * l0Status.kL0; @@ -989,6 +1049,13 @@ void MatmulTilingAlgorithm::NeitherFullLoadKforNZ(const CoreStatusPack& coreStat C0_SIZE * l0Status.nL0 * C0_BYTE_SIZE * l1Status.dbBL1) / l0Status.kL0 * l0Status.kL0, coreStatus.k); + if (IsUbNd2Nz()) { + perK = min(min(tilingIns_->bufferPool_.l1Size - biasSize - dequantSize, tilingIns_->bufferPool_.ubSize) / + (l0Status.mL0 * C0_SIZE * C0_BYTE_SIZE * l1Status.dbAL1 + + C0_SIZE * l0Status.nL0 * C0_BYTE_SIZE * l1Status.dbBL1) / + l0Status.kL0 * l0Status.kL0, + coreStatus.k); + } const int32_t biasFactor = tilingIns_->isBias ? l1Status.nBL1 * l0Status.nL0 : 0; const int32_t aAlignedPerK = MathUtil::Align(perK, kaAlignValue); const int32_t bAlignedPerK = MathUtil::Align(perK, kbAlignValue); @@ -1047,28 +1114,49 @@ void MatmulTilingAlgorithm::NeitherFullLoadKforND(const CoreStatusPack& coreStat l1Status.bL1Size = l1Status.kBL1 * l1Status.nBL1 * l0Status.nL0 * C0_SIZE * C0_BYTE_SIZE * l1Status.dbBL1; } l1Status.aL1Size = tilingIns_->bufferPool_.l1Size - l1Status.bL1Size; + int32_t a1Length = tilingIns_->bufferPool_.ubSize - GetBL1UbSize(l1Status, l0Status); auto factor = l1Status.mAL1 * l0Status.mL0 * C0_SIZE * l1Status.dbAL1 * C0_BYTE_SIZE; l1Status.kAL1 = (factor == 0) ? coreStatus.k : min((l1Status.aL1Size - biasSize - dequantSize) / factor, coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kAL1 = (factor == 0) ? coreStatus.k : min(min(l1Status.aL1Size - biasSize - dequantSize, a1Length) / factor, + coreStatus.k); + } l1Status.aL1Times = max(l1Status.kAL1 / l0Status.kL0, 1); GetNearestFactor(l1Status.allTimes, l1Status.aL1Times); // tik-mm support no factor ----ncheck l1Status.kAL1 = l1Status.aL1Times * l0Status.kL0; l1Status.aL1Size = l1Status.kAL1 * l1Status.mAL1 * l0Status.mL0 * C0_SIZE * C0_BYTE_SIZE * l1Status.dbAL1; l1Status.bL1Size = tilingIns_->bufferPool_.l1Size - l1Status.aL1Size; + int32_t b1Length = tilingIns_->bufferPool_.ubSize - GetAL1UbSize(l1Status, l0Status); if ((tilingIns_->bType_.dataType == DataType::DT_FLOAT) || (tilingIns_->aType_.isTrans && tilingIns_->aType_.dataType == DataType::DT_INT8)) { l1Status.kBL1 = min((l1Status.bL1Size - biasSize - dequantSize) / (l1Status.nBL1 * l0Status.nL0 * C0_SIZE * l1Status.dbBL1 * alignK), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kBL1 = min(min(l1Status.bL1Size - biasSize - dequantSize, b1Length) / + (l1Status.nBL1 * l0Status.nL0 * C0_SIZE * l1Status.dbBL1 * alignK), + coreStatus.k); + } } else if (!tilingIns_->bType_.isTrans && (tilingIns_->bType_.dataType == DataType::DT_INT8 || tilingIns_->bType_.dataType == DataType::DT_INT4)) { l1Status.kBL1 = min((l1Status.bL1Size - biasSize - dequantSize) / (alignN * l0Status.nL0 * l1Status.dbBL1 * alignK), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kBL1 = min(min(l1Status.bL1Size - biasSize - dequantSize, b1Length) / + (alignN * l0Status.nL0 * l1Status.dbBL1 * alignK), + coreStatus.k); + } } else { - l1Status.kBL1 = min((l1Status.bL1Size - biasSize - dequantSize)/ + l1Status.kBL1 = min((l1Status.bL1Size - biasSize - dequantSize) / (l1Status.nBL1 * l0Status.nL0 * C0_SIZE * l1Status.dbBL1 * C0_BYTE_SIZE), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kBL1 = min(min(l1Status.bL1Size - biasSize - dequantSize, b1Length) / + (l1Status.nBL1 * l0Status.nL0 * C0_SIZE * l1Status.dbBL1 * C0_BYTE_SIZE), + coreStatus.k); + } } l1Status.bL1Times = max(min(l1Status.kBL1 / l0Status.kL0, l1Status.maxKBL1), 1); GetNearestFactor(l1Status.allTimes, l1Status.bL1Times); @@ -1087,28 +1175,47 @@ void MatmulTilingAlgorithm::NeitherFullLoadKforND(const CoreStatusPack& coreStat } l1Status.bL1Size = tilingIns_->bufferPool_.l1Size - l1Status.aL1Size; + int32_t b1Length = tilingIns_->bufferPool_.ubSize - GetAL1UbSize(l1Status, l0Status); l1Status.kBL1 = min((l1Status.bL1Size - biasSize - dequantSize) / (l1Status.nBL1 * l0Status.nL0 * C0_SIZE * l1Status.dbBL1 * C0_BYTE_SIZE), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kBL1 = min(min(l1Status.bL1Size - biasSize - dequantSize, b1Length) / + (l1Status.nBL1 * l0Status.nL0 * C0_SIZE * l1Status.dbBL1 * C0_BYTE_SIZE), + coreStatus.k); + } l1Status.bL1Times = max(l1Status.kBL1 / l0Status.kL0, 1); GetNearestFactor(l1Status.allTimes, l1Status.bL1Times); l1Status.kBL1 = l1Status.bL1Times * l0Status.kL0; l1Status.bL1Size = l1Status.kBL1 * l1Status.nBL1 * l0Status.nL0 * C0_SIZE * C0_BYTE_SIZE * l1Status.dbBL1; l1Status.aL1Size = tilingIns_->bufferPool_.l1Size - l1Status.bL1Size; + int32_t a1Length = tilingIns_->bufferPool_.ubSize - GetBL1UbSize(l1Status, l0Status); if ((tilingIns_->aType_.isTrans && tilingIns_->aType_.dataType == DataType::DT_FLOAT) || (!tilingIns_->aType_.isTrans && (tilingIns_->aType_.dataType == DataType::DT_INT8 || tilingIns_->aType_.dataType == DataType::DT_INT4))) { auto factor = l1Status.mAL1 * l0Status.mL0 * C0_SIZE * l1Status.dbAL1 * alignK; - l1Status.kAL1 = (factor == 0) ? coreStatus.k : min((l1Status.aL1Size - biasSize - dequantSize) / factor, + l1Status.kAL1 = (factor == 0) ? coreStatus.k : min((l1Status.aL1Size - biasSize - dequantSize) / factor, coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kAL1 = (factor == 0) ? coreStatus.k : min(min(l1Status.aL1Size - biasSize - dequantSize, a1Length) / factor, + coreStatus.k); + } } else if (tilingIns_->aType_.isTrans && tilingIns_->aType_.dataType == DataType::DT_INT8) { l1Status.kAL1 = min((l1Status.aL1Size - biasSize - dequantSize) / (alignM * l0Status.mL0 * l1Status.dbAL1 * alignK), coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kAL1 = min(min(l1Status.aL1Size - biasSize - dequantSize, a1Length) / + (alignM * l0Status.mL0 * l1Status.dbAL1 * alignK), coreStatus.k); + } l1Status.aL1Size = l1Status.kAL1 * alignM * l0Status.mL0 * alignK * l1Status.dbAL1; } else { auto factor = l1Status.mAL1 * l0Status.mL0 * C0_SIZE * l1Status.dbAL1 * C0_BYTE_SIZE; l1Status.kAL1 = (factor == 0) ? coreStatus.k : min((l1Status.aL1Size - biasSize - dequantSize) / factor, coreStatus.k); + if (IsUbNd2Nz()) { + l1Status.kAL1 = (factor == 0) ? coreStatus.k : min(min(l1Status.aL1Size - biasSize - dequantSize, a1Length) / factor, + coreStatus.k); + } } l1Status.aL1Times = max(min(l1Status.kAL1 / l0Status.kL0, l1Status.maxKAL1), 1); GetNearestFactor(l1Status.allTimes, l1Status.aL1Times); @@ -1361,6 +1468,62 @@ void MatmulTilingAlgorithm::GetUsedSize(int32_t& l1Size, int32_t& l0cSize, int32 return; } +void MatmulTilingAlgorithm::GetBankConflictSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status, + int32_t& length, bool isAMatrix) const +{ + constexpr int blockSize = 32; + constexpr int bankLen = 512; + bool isBankConflict = false; + int bankConflictSize = 0; + const int32_t reduceSize = static_cast(C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE); + if (isAMatrix) { + if (tilingIns_->aType_.isTrans) { + isBankConflict = + MathUtil::CeilDivision(l1Status.mAL1 * l0Status.mL0 * C0_SIZE, C0_SIZE) * + blockSize % bankLen == + 0 ? + true : + false; + bankConflictSize = l0Status.kL0 * reduceSize * C0_SIZE * + MathUtil::CeilDivision(l1Status.kAL1, l0Status.kL0) * + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; + } else { + isBankConflict = + MathUtil::CeilDivision(MathUtil::CeilDivision(l1Status.kAL1, l0Status.kL0) * l0Status.kL0 * reduceSize, + C0_SIZE) * blockSize % bankLen == + 0 ? + true : + false; + bankConflictSize = l0Status.mL0 * C0_SIZE * C0_SIZE * l1Status.mAL1 * + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; + } + } else { + if (tilingIns_->bType_.isTrans) { + isBankConflict = + MathUtil::CeilDivision(MathUtil::CeilDivision(l1Status.kBL1, l0Status.kL0) * l0Status.kL0 * reduceSize, + C0_SIZE) * blockSize % bankLen == + 0 ? + true : + false; + bankConflictSize = l0Status.nL0 * C0_SIZE * C0_SIZE * l1Status.nBL1 * + DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE; + } else { + isBankConflict = + MathUtil::CeilDivision(l1Status.nBL1 * l0Status.nL0 * C0_SIZE, C0_SIZE) * + blockSize % bankLen == + 0 ? + true : + false; + bankConflictSize = l0Status.kL0 * reduceSize * C0_SIZE * + MathUtil::CeilDivision(l1Status.kBL1, l0Status.kL0) * + DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE; + } + } + if (isBankConflict) { + length = length + bankConflictSize; + } +} + void MatmulTilingAlgorithm::GetBankConflictSize(int32_t& length, bool isAMatrix) const { constexpr int blockSize = 32; @@ -1413,6 +1576,55 @@ void MatmulTilingAlgorithm::GetBankConflictSize(int32_t& length, bool isAMatrix) } } +int32_t MatmulTilingAlgorithm::GetAL1UbSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status) const +{ + int32_t a1Length = 0; + const int32_t reduceSize = static_cast(C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE); + if (IsUbNd2Nz()) { + // A matrix ND2NZ + if (tilingIns_->aType_.type == CubeFormat::ND) { + a1Length = l0Status.mL0 * C0_SIZE * l0Status.kL0 * reduceSize * + DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE; + if (tilingIns_->mmConfigType == 1) { + a1Length = a1Length * MathUtil::CeilDivision(l1Status.kAL1, l0Status.kL0) * l1Status.mAL1; + } + // bank conflict + GetBankConflictSize(l1Status, l0Status, a1Length, true); + } + } + return a1Length; +} + +int32_t MatmulTilingAlgorithm::GetBL1UbSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status) const +{ + int32_t b1Length = 0; + const int32_t reduceSize = static_cast(C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE); + if (IsUbNd2Nz()) { + // B matrix ND2NZ + if (tilingIns_->bType_.type == CubeFormat::ND) { + b1Length = l0Status.nL0 * C0_SIZE * l0Status.kL0 * reduceSize * + DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE; + if (tilingIns_->mmConfigType == 1) { + b1Length = b1Length * MathUtil::CeilDivision(l1Status.kBL1, l0Status.kL0) * l1Status.nBL1; + } + // bank conflict + GetBankConflictSize(l1Status, l0Status, b1Length, false); + } + } + return b1Length; +} + +bool MatmulTilingAlgorithm::IsUbNd2Nz() const +{ + if (tilingIns_->enVecND2NZ && tilingIns_->mmConfigType == 1 && + (tilingIns_->socVersion == platform_ascendc::SocVersion::ASCEND910 || + tilingIns_->socVersion == platform_ascendc::SocVersion::ASCEND310P || + tilingIns_->socVersion == platform_ascendc::SocVersion::ASCEND310B)) { + return true; + } + return false; +} + void MatmulTilingAlgorithm::GetTransLength(int32_t& transLength) const { int32_t a1Length = 0; diff --git a/impl/matmul/matmul_tiling_algorithm.h b/impl/matmul/matmul_tiling_algorithm.h index 5937c096ceb60be08648b606ae3b8ec3ca97ef72..9ef135064c391c57ab0e47c6cbab0b6fe749deef 100644 --- a/impl/matmul/matmul_tiling_algorithm.h +++ b/impl/matmul/matmul_tiling_algorithm.h @@ -354,6 +354,10 @@ private: void GetUsedSize(int32_t& l1Size, int32_t& l0cSize, int32_t& ubSize, int32_t a1LengthCache, int32_t b1LengthCache) const; void GetBankConflictSize(int32_t& length, bool isAMatrix) const; + void GetBankConflictSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status, int32_t& length, bool isAMatrix) const; + int32_t GetAL1UbSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status) const; + int32_t GetBL1UbSize(const L1StatusPack& l1Status, const L0StatusPack& l0Status) const; + bool IsUbNd2Nz() const; void GetTransLength(int32_t& transLength) const; void SetDepthL1CacheUBParams(int32_t &a1LengthCache, int32_t &b1LengthCache) const; void GetABL1KAlignValue(int32_t& kaAlignValue, int32_t& kbAlignValue) const; diff --git a/lib/matmul/matmul_tiling_base.h b/lib/matmul/matmul_tiling_base.h index 12c5f1f1d6a8a83c3d3017ad8dfe52e883e63dd2..50f3c1a9566fd0e1d273237d694cbca040446b39 100644 --- a/lib/matmul/matmul_tiling_base.h +++ b/lib/matmul/matmul_tiling_base.h @@ -176,10 +176,20 @@ struct PlatformInfo { }; struct MatmulConfigParams { - int32_t mmConfigType = 1; - bool enableL1CacheUB = false; - ScheduleType scheduleType = ScheduleType::INNER_PRODUCT; - MatrixTraverse traverse = MatrixTraverse::NOSET; + int32_t mmConfigType; + bool enableL1CacheUB; + ScheduleType scheduleType; + MatrixTraverse traverse; + bool enVecND2NZ; + MatmulConfigParams(int32_t mmConfigTypeIn = 1, bool enableL1CacheUBIn = false, + ScheduleType scheduleTypeIn = ScheduleType::INNER_PRODUCT, MatrixTraverse traverseIn = MatrixTraverse::NOSET, + bool enVecND2NZIn = false) { + mmConfigType = mmConfigTypeIn; + enableL1CacheUB = enableL1CacheUBIn; + scheduleType = scheduleTypeIn; + traverse = traverseIn; + enVecND2NZ = enVecND2NZIn; + } }; class MatmulApiTilingBase { @@ -317,6 +327,7 @@ public: platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; int32_t mmConfigType = 1; // 0: Norm; 1: MDL bool enableL1CacheUB = false; + bool enVecND2NZ = false; protected: virtual int64_t Compute() = 0;