From 7e776781bda7c7906a8b9b83943152888c5f5266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E6=B0=91=E5=AE=89?= Date: Tue, 22 Jul 2025 22:03:52 +0800 Subject: [PATCH] fix tiling base align && optimize bmm bias copy --- .../copy_cube_in/bias/copy_bias_in_batch.h | 20 ++++---- .../matmul/tiling/matmul_tiling_algorithm.cpp | 18 ++++++-- impl/matmul/tiling/matmul_tiling_algorithm.h | 1 + tests/tiling/test_matmul_api_tiling.cpp | 46 +++++++++++++++++++ 4 files changed, 74 insertions(+), 11 deletions(-) diff --git a/impl/matmul/stage/copy_cube_in/bias/copy_bias_in_batch.h b/impl/matmul/stage/copy_cube_in/bias/copy_bias_in_batch.h index 000a3fd4..7146a59e 100644 --- a/impl/matmul/stage/copy_cube_in/bias/copy_bias_in_batch.h +++ b/impl/matmul/stage/copy_cube_in/bias/copy_bias_in_batch.h @@ -55,20 +55,24 @@ private: __aicore__ inline void BiasCopy(LocalTensor& bias, TensorT& srcTensor, int32_t dataLen, int32_t dataNum, int32_t srcOffset) { - Nd2NzParams intriParams {1, 1, static_cast(dataLen), 0, static_cast(dataLen), 1, 1, 0}; // Check if the bias is batched or not if constexpr (!ToMatmulConfig(MM_CFG).isBiasBatch) { + Nd2NzParams intriParams {1, 1, static_cast(dataLen), 0, static_cast(dataLen), 1, 1, 0}; // Not batched, only copy the data once DataCopy(bias, srcTensor, intriParams); } else { - // Batched, copy the data one by one - int32_t dstOffset = 0; + // Batched, copy dataNumm data by one instr auto dstStride = CeilAlign(dataLen, c0Size_); - for (int32_t i = 0; i < dataNum; ++i) { - DataCopy(bias[dstOffset], srcTensor[srcOffset], intriParams); - srcOffset += dataLen; - dstOffset += dstStride; - } + Nd2NzParams intriParams; + intriParams.ndNum = dataNum; + intriParams.nValue = 1; + intriParams.dValue = static_cast(dataLen); + intriParams.srcNdMatrixStride = static_cast(dataLen); + intriParams.srcDValue = static_cast(dataLen); + intriParams.dstNzC0Stride = 1; + intriParams.dstNzNStride = 1; + intriParams.dstNzMatrixStride = dstStride; + DataCopy(bias, srcTensor[srcOffset], intriParams); } } }; diff --git a/impl/matmul/tiling/matmul_tiling_algorithm.cpp b/impl/matmul/tiling/matmul_tiling_algorithm.cpp index e4ce59e5..178053d8 100644 --- a/impl/matmul/tiling/matmul_tiling_algorithm.cpp +++ b/impl/matmul/tiling/matmul_tiling_algorithm.cpp @@ -115,6 +115,18 @@ MatmulTilingAlgorithm::MatmulTilingAlgorithm(MatmulApiTilingBase* tilingIns) tilingIns_ = tilingIns; } +int32_t MatmulTilingAlgorithm::GetC0Size() const +{ + if (tilingIns_->aType_.dataType == DataType::DT_FLOAT) { + return FLOAT32_REDUCE_BLOCK_SIZE; + } else if (tilingIns_->aType_.dataType == DataType::DT_INT8) { + return INT8_REDUCE_BLOCK_SIZE; + } else if (tilingIns_->aType_.dataType == DataType::DT_INT4) { + return INT4_REDUCE_BLOCK_SIZE; + } + return REDUCE_BLOCK_SIZE; +} + int32_t MatmulTilingAlgorithm::GetBestValue(int32_t base) const { for (uint32_t i = 0; i < BEST_VALUE_LENGTH; ++i) { @@ -2691,10 +2703,10 @@ void MatmulTilingAlgorithm::CalcMultiCoreBlockDimsSmallShape(const MatmulRunPara baseBlock.baseK = results[0].baseBlock.baseK; } - int32_t tmpSize = L0_SIZE / DB_ON / (DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE); + int32_t tmpSize = L0_SIZE / DB_ON * BITS_PER_BYTE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType); baseBlock.baseK = min(tmpSize / baseBlock.baseM, tmpSize / baseBlock.baseN); - baseBlock.baseK = min(MathUtil::AlignDown(baseBlock.baseK, BASIC_SIZE_32), - static_cast(MathUtil::Align(params.oriShapeKa, static_cast(BASIC_SIZE_32)))); + baseBlock.baseK = min(MathUtil::AlignDown(baseBlock.baseK, GetC0Size()), + static_cast(MathUtil::Align(params.oriShapeKa, static_cast(GetC0Size())))); (void)CalcMultiCoreBlockDimsPost(params, coreStatus, blockDimRes); return; } diff --git a/impl/matmul/tiling/matmul_tiling_algorithm.h b/impl/matmul/tiling/matmul_tiling_algorithm.h index cd84473f..67b0a291 100644 --- a/impl/matmul/tiling/matmul_tiling_algorithm.h +++ b/impl/matmul/tiling/matmul_tiling_algorithm.h @@ -381,6 +381,7 @@ private: bool CheckFinaleParams(const CoreStatusPack& coreStatus) const; bool CheckBaseMN() const; int32_t GetBestValue(int32_t base) const; + int32_t GetC0Size() const; int32_t GetIteratorOrder(const SingleCoreStatus& singleCoreStatus, const int32_t singleCoreM, const int32_t singleCoreN, const int32_t singleCoreK) const; void GetL0StatusFromParasCombo(L0StatusPack& l0Status, int32_t* parasCombo) const; diff --git a/tests/tiling/test_matmul_api_tiling.cpp b/tests/tiling/test_matmul_api_tiling.cpp index 2d57fab0..d3da1fcc 100644 --- a/tests/tiling/test_matmul_api_tiling.cpp +++ b/tests/tiling/test_matmul_api_tiling.cpp @@ -211,6 +211,50 @@ TEST_F(TestMatmulAPITiling, SmallShapeAlign4) int64_t res = tilingApi.GetTiling(tilingData); tilingApi.PrintTilingData(); EXPECT_EQ(res, 0); + + tilingApi.SetDim(20); + + tilingApi.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT, false); + tilingApi.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT, false); + tilingApi.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); + tilingApi.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); + + tilingApi.SetOrgShape(16, 256, 256); + tilingApi.SetShape(-1, -1, -1); + tilingApi.EnableBias(false); + + tilingApi.SetBufferSpace(-1, -1, -1); + res = tilingApi.GetTiling(tilingData); + tilingApi.PrintTilingData(); + EXPECT_EQ(res, 0); + + tilingApi.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT8, false); + tilingApi.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT8, false); + tilingApi.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT32); + tilingApi.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT32); + + tilingApi.SetOrgShape(16, 256, 256); + tilingApi.SetShape(-1, -1, -1); + tilingApi.EnableBias(false); + + tilingApi.SetBufferSpace(-1, -1, -1); + res = tilingApi.GetTiling(tilingData); + tilingApi.PrintTilingData(); + EXPECT_EQ(res, 0); + + tilingApi.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT4, false); + tilingApi.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT4, false); + tilingApi.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT32); + tilingApi.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT32); + + tilingApi.SetOrgShape(16, 256, 256); + tilingApi.SetShape(-1, -1, -1); + tilingApi.EnableBias(false); + + tilingApi.SetBufferSpace(-1, -1, -1); + res = tilingApi.GetTiling(tilingData); + tilingApi.PrintTilingData(); + EXPECT_EQ(res, 0); } TEST_F(TestMatmulAPITiling, SmallShapeNotAlign) @@ -255,3 +299,5 @@ TEST_F(TestMatmulAPITiling, SmallShapeNotAlign2) EXPECT_EQ(res, 0); } + + -- Gitee