diff --git a/impl/matmul/param/matmul_shape_info.h b/impl/matmul/param/matmul_shape_info.h
index 683e906621fcf8aa7e4d4b2b0fb6d68cbb7e0663..2ac8196deeaf5bb64dafdf5adb9a837baafc7adc 100644
--- a/impl/matmul/param/matmul_shape_info.h
+++ b/impl/matmul/param/matmul_shape_info.h
@@ -104,13 +104,13 @@ public:
         if constexpr (A_TYPE::format != CubeFormat::VECTOR) {
             ASCENDC_ASSERT((GetSingleCoreM() % ToMatmulConfig(MM_CFG).basicM == 0), {
                 KERNEL_LOG(KERNEL_ERROR,
-                    "singleCoreM is %d, basicM is %d, singleCoreM sould be a multiple of basicM in Basic Block mode.",
+                    "singleCoreM is %d, basicM is %d, singleCoreM should be a multiple of basicM in Basic Block mode.",
                     GetSingleCoreM(), ToMatmulConfig(MM_CFG).basicM);
             });
         }
         ASCENDC_ASSERT((GetSingleCoreN() % ToMatmulConfig(MM_CFG).basicN == 0), {
             KERNEL_LOG(KERNEL_ERROR,
-                "singleCoreN is %d, basicN is %d, singleCoreN sould be a multiple of basicN in Basic Block mode.",
+                "singleCoreN is %d, basicN is %d, singleCoreN should be a multiple of basicN in Basic Block mode.",
                 GetSingleCoreN(), ToMatmulConfig(MM_CFG).basicN);
         });
     }
diff --git a/impl/matmul/param/matmul_shape_tiling.h b/impl/matmul/param/matmul_shape_tiling.h
index c7244e8d1820fb98b11bf1344f86c1e8b01f8a36..4af37db0b4c210d3f562f5dccf0c2f239613780d 100644
--- a/impl/matmul/param/matmul_shape_tiling.h
+++ b/impl/matmul/param/matmul_shape_tiling.h
@@ -44,12 +44,12 @@ public:
         NumericalValidCheck();
         ShareInfoCheck();
         if constexpr (!HasScalePosition::value && !HasScalePosition::value) {
-            ShapeVaildCheck();
+            ShapeValidCheck();
             DepthCheck();
             ConfigCommonCheck();
             ConfigSpecificCheck();
         } else {
-            MxShapeVaildCheck();
+            MxShapeValidCheck();
             DepthCheck();
         }
 #endif
@@ -112,7 +112,7 @@ private:
     }
 
     template
-    __aicore__ inline void ShapeVaildCheck()
+    __aicore__ inline void ShapeValidCheck()
     {
         const auto L0ABUseSizeFactor = (tiling_.GetDbL0A() - 1) & (tiling_.GetDbL0B() - 1) ? Impl::DB_FACTOR : 1;
         const auto L0CUseSizeFactor = (tiling_.GetDbL0C() == Impl::DB_FACTOR) ? Impl::DB_FACTOR : 1;
@@ -160,7 +160,7 @@ private:
     }
 
     template
-    __aicore__ inline void MxShapeVaildCheck()
+    __aicore__ inline void MxShapeValidCheck()
     {
 #if defined(__DAV_C310__) || defined(__DAV_310R6__)
         const auto L0ABUseSizeFactor = (tiling_.GetDbL0A() - 1) & (tiling_.GetDbL0B() - 1) ? Impl::DB_FACTOR : 1;
diff --git a/impl/matmul/tiling/matmul_tiling_algorithm.cpp b/impl/matmul/tiling/matmul_tiling_algorithm.cpp
index 0e9a5f05ac2a91fce06c9c792a4408595739f18f..a88e0cee67aff3b764f7b62c0a8cc8030bf35e84 100644
--- a/impl/matmul/tiling/matmul_tiling_algorithm.cpp
+++ b/impl/matmul/tiling/matmul_tiling_algorithm.cpp
@@ -2540,7 +2540,7 @@ void MatmulTilingAlgorithm::AdjustFloatL1Factor(const SingleCoreStatus& singleCo
 int64_t MatmulTilingAlgorithm::UpdateTiling(const MatmulRunParas& param, const CoreStatusPack &coreStatus,
     SingleCoreStatus& singleCoreStatus) const
 {
-    int32_t coreUse = singelBlockDim_ ? tilingIns_->blockDim : coreStatus.batchDim * coreStatus.mDim * coreStatus.kDim * coreStatus.nDim;
+    int32_t coreUse = singleBlockDim_ ? tilingIns_->blockDim : coreStatus.batchDim * coreStatus.mDim * coreStatus.kDim * coreStatus.nDim;
     int32_t singleCoreM;
     int32_t singleCoreN;
     int32_t singleCoreK;
@@ -2971,13 +2971,13 @@ void MatmulTilingAlgorithm::FillParam(MatmulRunParas& param)
         realM = tilingIns_->singleCoreM != -1 ? tilingIns_->singleCoreM : tilingIns_->singleM;
         realK = tilingIns_->singleCoreK != -1 ? tilingIns_->singleCoreK : tilingIns_->singleK;
         realN = tilingIns_->singleCoreN != -1 ? tilingIns_->singleCoreN : tilingIns_->singleN;
-        singelBlockDim_ = true;
+        singleBlockDim_ = true;
         numOfBlock_ = 1;
     } else {
         realM = GetSingleM();
         realK = GetSingleK();
         realN = GetSingleN();
-        singelBlockDim_ = false;
+        singleBlockDim_ = false;
         numOfBlock_ = tilingIns_->blockDim;
     }
 
@@ -3373,7 +3373,7 @@ void MatmulTilingAlgorithm::GetSingleShape(const CoreStatusPack &coreStatus, con
     singleCoreN = MathUtil::CeilDivision(singleCoreN, coreStatus.nDim);
     singleCoreK = GetSingleK();
     singleCoreK = MathUtil::CeilDivision(singleCoreK, coreStatus.kDim);
-    if (singelBlockDim_) {
+    if (singleBlockDim_) {
         singleCoreM = tilingIns_->singleCoreM != -1 ? tilingIns_->singleCoreM : tilingIns_->singleM;
         singleCoreN = tilingIns_->singleCoreN != -1 ? tilingIns_->singleCoreN : tilingIns_->singleN;
         singleCoreK = tilingIns_->singleCoreK != -1 ? tilingIns_->singleCoreK : tilingIns_->singleK;
@@ -3736,7 +3736,7 @@ int64_t MatmulTilingAlgorithm::Process()
         TILING_LOG_WARNING("check baseM/baseN not pass");
         return -1;
     }
-    singelBlockDim_ = false;
+    singleBlockDim_ = false;
     splitCoreFlag_ = false;
     CoreStatusPack coreStatus;
     SingleCoreStatus singleCoreStatus;
diff --git a/impl/matmul/tiling/matmul_tiling_algorithm.h b/impl/matmul/tiling/matmul_tiling_algorithm.h
index 200fb0d1dc27274aaf62508b7658019e41254414..562b00c7300fe36f7b57f484662b6abfe8b695f6 100644
--- a/impl/matmul/tiling/matmul_tiling_algorithm.h
+++ b/impl/matmul/tiling/matmul_tiling_algorithm.h
@@ -523,7 +523,7 @@ private:
     int64_t UpdateTiling(const MatmulRunParas& param, const CoreStatusPack &coreStatus, SingleCoreStatus& singleCoreStatus) const;
 private:
     MatmulApiTilingBase* tilingIns_ = nullptr;
-    bool singelBlockDim_ = false;
+    bool singleBlockDim_ = false;
     bool splitCoreFlag_ = false;
     int32_t dbL0A_ = DB_ON;
     int32_t dbL0B_ = DB_ON;
diff --git a/impl/matmul/tiling/matmul_tiling_base.cpp b/impl/matmul/tiling/matmul_tiling_base.cpp
index 4efb3b2e61db2fc60e5a63fab186eb38265019fc..c4557dccee5ff4416e9a6f01f8b957a5d60ac67c 100644
--- a/impl/matmul/tiling/matmul_tiling_base.cpp
+++ b/impl/matmul/tiling/matmul_tiling_base.cpp
@@ -655,7 +655,7 @@ bool MatmulApiTilingBase::CheckSetParam()
     if (this->baseM != -1 && this->baseK != -1) {
         // set baseM, baseK, L0A limited
         if (this->baseM * this->baseK * dataBits / BITS_PER_BYTE > this->bufferPool_.l0ASize) {
-            TILING_LOG_INFO("baseM * baseK is larger then L0ASize");
+            TILING_LOG_INFO("baseM * baseK is larger than L0ASize");
             return false;
         }
     }
@@ -664,7 +664,7 @@ bool MatmulApiTilingBase::CheckSetParam()
     if (this->baseK != -1 && this->baseN != -1) {
         // set baseM, baseK, L0B limited
         if (this->baseK * this->baseN * dataBits / BITS_PER_BYTE > this->bufferPool_.l0BSize) {
-            TILING_LOG_INFO("baseN * baseK is larger then l0BSize");
+            TILING_LOG_INFO("baseN * baseK is larger than l0BSize");
             return false;
         }
     }
@@ -672,7 +672,7 @@ bool MatmulApiTilingBase::CheckSetParam()
     if (this->baseM != -1 && this->baseN != -1) {
         // set baseM, baseN, L0C limited
         if (this->baseM * this->baseN * FP32_BYTES > this->bufferPool_.l0CSize) {
-            TILING_LOG_INFO("baseM * baseN is larger then L0CSize");
+            TILING_LOG_INFO("baseM * baseN is larger than L0CSize");
             return false;
         }
     }
diff --git a/lib/matmul/matmul_tiling_base.h b/lib/matmul/matmul_tiling_base.h
index b0c6666a24b4c6b5e6bd1ee5a60cd0006a2b10df..51d9ba51b3379f7c8a20d8512f0996d3e9971753 100644
--- a/lib/matmul/matmul_tiling_base.h
+++ b/lib/matmul/matmul_tiling_base.h
@@ -91,125 +91,217 @@ const std::map DTYPE_BIT_TAB = {
     {DataType::DT_FLOAT8_E4M3FN, 8},
     {DataType::DT_FLOAT8_E5M2, 8},
     {DataType::DT_FLOAT8_E8M0, 8}
 };
 #endif // __ASCC_DEVICE__
+/**
+* @enum TPosition
+* @brief Enumeration with underlying type int32_t that lists the supported data storage positions
+*/
 enum class TPosition : int32_t {
-    GM,
-    A1,
-    A2,
-    B1,
-    B2,
-    C1,
-    C2,
-    CO1,
-    CO2,
-    VECIN,
-    VECOUT,
-    VECCALC,
-    LCM = VECCALC,
-    SPM,
-    SHM = SPM,
-    TSCM,
-    MAX,
+    GM,              ///< Global memory (GM) position
+    A1,              ///< A1 position
+    A2,              ///< A2 position
+    B1,              ///< B1 position
+    B2,              ///< B2 position
+    C1,              ///< C1 position
+    C2,              ///< C2 position
+    CO1,             ///< CO1 position
+    CO2,             ///< CO2 position
+    VECIN,           ///< Vector input position
+    VECOUT,          ///< Vector output position
+    VECCALC,         ///< Vector calculation position
+    LCM = VECCALC,   ///< LCM position (equivalent to VECCALC)
+    SPM,             ///< SPM position
+    SHM = SPM,       ///< SHM position (equivalent to SPM)
+    TSCM,            ///< TSCM position
+    MAX,             ///< Maximum value (sentinel)
 };
-
+/**
+* @enum TilingPolicy
+* @brief Enumeration with underlying type int32_t that lists the supported tiling policies
+*/
 enum class TilingPolicy : int32_t {
-    FIXED_A_TSCM,
-    FIXED_B_TSCM,
-    FIXED_A_B_TSCM,
-    NO_POLICY
+    FIXED_A_TSCM,    ///< Matrix A is fixed in TSCM
+    FIXED_B_TSCM,    ///< Matrix B is fixed in TSCM
+    FIXED_A_B_TSCM,  ///< Both matrix A and matrix B are fixed in TSCM
+    NO_POLICY        ///< No tiling policy
 };
-
+/**
+* @enum CubeFormat
+* @brief Enumeration with underlying type int32_t that lists the supported cube formats
+*/
 enum class CubeFormat : int32_t {
-    ND = 0,
-    NZ,
-    ZN,
-    ZZ,
-    NN,
-    ND_ALIGN,
-    SCALAR,
-    VECTOR,
-    ROW_MAJOR = ND, // ND
-    COLUMN_MAJOR = 8, // DN
+    ND = 0,              ///< ND format
+    NZ,                  ///< NZ format
+    ZN,                  ///< ZN format
+    ZZ,                  ///< ZZ format
+    NN,                  ///< NN format
+    ND_ALIGN,            ///< Aligned ND format
+    SCALAR,              ///< Scalar format
+    VECTOR,              ///< Vector format
+    ROW_MAJOR = ND,      ///< Row-major format, equivalent to ND
+    COLUMN_MAJOR = 8,    ///< Column-major format (DN)
 };
-
+/**
+* @enum MatrixTraverse
+* @brief Enumeration with underlying type int32_t that lists the supported traversal orders
+*/
 enum class MatrixTraverse : int32_t {
-    NOSET = 0,
-    FIRSTM = 1,
-    FIRSTN = 2,
+    NOSET = 0,       ///< Traversal order not set
+    FIRSTM = 1,      ///< Traverse along the M direction first
+    FIRSTN = 2,      ///< Traverse along the N direction first
 };
-
+/**
+* @enum MatrixMadType
+* @brief Enumeration with underlying type int32_t that lists the supported multiply-accumulate (MAD) modes
+*/
 enum class MatrixMadType : int32_t {
-    NORMAL = 0,
-    HF32 = 1, // V220 HF32
-    MXMODE = 2, // v310 MxMatmulFlag
+    NORMAL = 0,      ///< Normal MAD mode
+    HF32 = 1,        ///< HF32 MAD mode (V220 HF32)
+    MXMODE = 2,      ///< MX matmul mode (V310 MxMatmulFlag)
 };
-
+/**
+* @enum DequantType
+* @brief Enumeration with underlying type int32_t that lists the supported dequantization modes
+*/
 enum class DequantType : int32_t {
-    SCALAR = 0,
-    TENSOR = 1,
+    SCALAR = 0,      ///< Dequantization with a scalar parameter
+    TENSOR = 1,      ///< Dequantization with a tensor parameter
 };
-
+/**
+* @enum ScheduleType
+* @brief Enumeration with underlying type int32_t that lists the supported schedule types
+*/
 enum class ScheduleType : int32_t {
-    INNER_PRODUCT = 0,
-    OUTER_PRODUCT = 1,
-    N_BUFFER_33 = 2,
+    INNER_PRODUCT = 0,    ///< Inner-product schedule
+    OUTER_PRODUCT = 1,    ///< Outer-product schedule
+    N_BUFFER_33 = 2,      ///< N-buffer-33 schedule
 };
-
+/**
+* @struct SysTilingTempBufSize
+* @brief System tiling temporary buffer size structure
+*
+* This structure stores the temporary buffer sizes required during system tiling.
+*/
 struct SysTilingTempBufSize {
     int32_t ubSize = 0;
     int32_t l1Size = 0;
     int32_t l0cSize = 0;
 };
-
+/**
+* @struct MatTilingType
+* @brief Structure for matrix tiling type configuration
+*/
 struct MatTilingType {
+    /**
+     * @brief Matrix position, default is global memory (GM)
+     */
     TPosition pos = TPosition::GM;
+    /**
+     * @brief Matrix format, default is ND format
+     */
     CubeFormat type = CubeFormat::ND;
+    /**
+     * @brief Matrix data type, default is float
+     */
     DataType dataType = DataType::DT_FLOAT;
+    /**
+     * @brief Whether the matrix is transposed, default is false
+     */
     bool isTrans = false;
+    /**
+     * @brief Whether the matrix uses double buffering, default is false
+     */
     bool isDB = false;
+    /**
+     * @brief Whether the scale type has been set, default is false
+     */
     bool hasSetScaleType = false;
+    /**
+     * @brief Scale position, default is global memory (GM)
+     */
     TPosition scalePos = TPosition::GM;
+    /**
+     * @brief Scale format, default is ND format
+     */
     CubeFormat scaleType = CubeFormat::ND;
+    /**
+     * @brief Whether the scale is transposed, default is false
+     */
     bool isScaleTrans = false;
 };
-
+/**
+* @struct BufferPool
+* @brief Buffer pool structure holding the sizes of the hardware buffers available to tiling
+*/
 struct BufferPool {
-    int32_t l1Size;
-    int32_t l0CSize;
-    int32_t ubSize;
-    int32_t l0ASize;
-    int32_t l0BSize;
-    int32_t btSize;
-
-    int32_t l1AlignSize;
-    int32_t l0CAlignSize;
-    int32_t l0AAlignSize;
-    int32_t l0BAlignSize;
-    int32_t ubAlignSize;
+    int32_t l1Size;          ///< Size of the L1 buffer
+    int32_t l0CSize;         ///< Size of the L0C buffer
+    int32_t ubSize;          ///< Size of the UB buffer
+    int32_t l0ASize;         ///< Size of the L0A buffer
+    int32_t l0BSize;         ///< Size of the L0B buffer
+    int32_t btSize;          ///< Size of the BT buffer
+
+    int32_t l1AlignSize;     ///< Aligned size of the L1 buffer
+    int32_t l0CAlignSize;    ///< Aligned size of the L0C buffer
+    int32_t l0AAlignSize;    ///< Aligned size of the L0A buffer
+    int32_t l0BAlignSize;    ///< Aligned size of the L0B buffer
+    int32_t ubAlignSize;     ///< Aligned size of the UB buffer
 };
-
+/**
+* @struct PlatformInfo
+* @brief A structure that stores platform information.
+*/
 struct PlatformInfo {
+    /**
+     * @brief SoC version information.
+     */
     platform_ascendc::SocVersion socVersion;
-    uint64_t l1Size = 0;
-    uint64_t l0CSize = 0;
-    uint64_t ubSize = 0;
-    uint64_t l0ASize = 0;
-    uint64_t l0BSize = 0;
+    uint64_t l1Size = 0;     ///< Size of the L1 buffer, in bytes
+    uint64_t l0CSize = 0;    ///< Size of the L0C buffer, in bytes
+    uint64_t ubSize = 0;     ///< Size of the UB buffer, in bytes
+    uint64_t l0ASize = 0;    ///< Size of the L0A buffer, in bytes
+    uint64_t l0BSize = 0;    ///< Size of the L0B buffer, in bytes
 };
-
+/**
+* @struct MatmulConfigParams
+* @brief Matrix multiplication configuration parameters structure
+*/
 struct MatmulConfigParams {
+    /**
+     * @brief Matrix multiplication configuration type
+     */
     int32_t mmConfigType;
+    /**
+     * @brief Whether to enable L1 caching of UB data
+     */
     bool enableL1CacheUB;
+    /**
+     * @brief Schedule type
+     */
     ScheduleType scheduleType;
+    /**
+     * @brief Matrix traversal order
+     */
     MatrixTraverse traverse;
+    /**
+     * @brief Whether to enable ND2NZ conversion on the vector unit
+     */
     bool enVecND2NZ;
+    /**
+     * @brief Constructor
+     * @param [in] mmConfigTypeIn Matrix multiplication configuration type, default is 1
+     * @param [in] enableL1CacheUBIn Whether to enable L1 caching of UB data, default is false
+     * @param [in] scheduleTypeIn Schedule type, default is ScheduleType::INNER_PRODUCT
+     * @param [in] traverseIn Matrix traversal order, default is MatrixTraverse::NOSET
+     * @param [in] enVecND2NZIn Whether to enable ND2NZ conversion on the vector unit, default is false
+     */
     MatmulConfigParams(int32_t mmConfigTypeIn = 1, bool enableL1CacheUBIn = false,
         ScheduleType scheduleTypeIn = ScheduleType::INNER_PRODUCT,
         MatrixTraverse traverseIn = MatrixTraverse::NOSET, bool enVecND2NZIn = false)
     {
-        mmConfigType = mmConfigTypeIn;
-        enableL1CacheUB = enableL1CacheUBIn;
-        scheduleType = scheduleTypeIn;
-        traverse = traverseIn;
-        enVecND2NZ = enVecND2NZIn;
+        mmConfigType = mmConfigTypeIn;          // matrix multiplication configuration type
+        enableL1CacheUB = enableL1CacheUBIn;    // whether to enable L1 caching of UB data
+        scheduleType = scheduleTypeIn;          // schedule type
+        traverse = traverseIn;                  // matrix traversal order
+        enVecND2NZ = enVecND2NZIn;              // whether to enable ND2NZ conversion on the vector unit
     }
 };
@@ -219,59 +311,246 @@ public:
     explicit MatmulApiTilingBase(const platform_ascendc::PlatformAscendC& ascendcPlatform);
     explicit MatmulApiTilingBase(const PlatformInfo& platform);
     virtual ~MatmulApiTilingBase();
+    /**
+     * @brief Set the type information of matrix A
+     * @param [in] pos: the storage position, of type TPosition
+     * @param [in] type: the matrix format, of type CubeFormat
+     * @param [in] dataType: the data type, of type DataType
+     * @param [in] isTrans: whether matrix A is transposed, default is false
+     */
     int32_t SetAType(TPosition pos, CubeFormat type, DataType dataType, bool isTrans = false);
+    /**
+     * @brief Set the type information of matrix B
+     * @param [in] pos: the storage position, of type TPosition
+     * @param [in] type: the matrix format, of type CubeFormat
+     * @param [in] dataType: the data type, of type DataType
+     * @param [in] isTrans: whether matrix B is transposed, default is false
+     */
     int32_t SetBType(TPosition pos, CubeFormat type, DataType dataType, bool isTrans = false);
+    /**
+     * @brief Set the type information of the scale of matrix A
+     * @param [in] scalePos: the scale storage position, of type TPosition
+     * @param [in] scaleType: the scale format, of type CubeFormat
+     * @param [in] isScaleTrans: whether the scale is transposed, default is false
+     */
     int32_t SetScaleAType(TPosition scalePos, CubeFormat scaleType, bool isScaleTrans = false);
+    /**
+     * @brief Set the type information of the scale of matrix B
+     * @param [in] scalePos: the scale storage position, of type TPosition
+     * @param [in] scaleType: the scale format, of type CubeFormat
+     * @param [in] isScaleTrans: whether the scale is transposed, default is true
+     */
     int32_t SetScaleBType(TPosition scalePos, CubeFormat scaleType, bool isScaleTrans = true);
+    /**
+     * @brief Set the type information of matrix C
+     * @param [in] pos: the storage position, of type TPosition
+     * @param [in] type: the matrix format, of type CubeFormat
+     * @param [in] dataType: the data type, of type DataType
+     */
     int32_t SetCType(TPosition pos, CubeFormat type, DataType dataType);
+    /**
+     * @brief Set the type information of the bias
+     * @param [in] pos: the storage position, of type TPosition
+     * @param [in] type: the matrix format, of type CubeFormat
+     * @param [in] dataType: the data type, of type DataType
+     */
     int32_t SetBiasType(TPosition pos, CubeFormat type, DataType dataType);
+    /**
+     * @brief Set the dequantization type
+     * @param [in] dequantType: the dequantization type enumeration value
+     * @return Return 0 if success
+     */
     int32_t SetDequantType(DequantType dequantType)
     {
         this->deqType = dequantType;
         return 0;
     }
-
+    /**
+     * @brief Set the shape of the matrix multiplication
+     * @param [in] m: size of the M dimension
+     * @param [in] n: size of the N dimension
+     * @param [in] k: size of the K (reduction) dimension
+     */
     virtual int32_t SetShape(int32_t m, int32_t n, int32_t k);
+    /**
+     * @brief Set the original shape dimensions
+     * @param [in] orgMIn: the M dimension size of the original shape
+     * @param [in] orgNIn: the N dimension size of the original shape
+     * @param [in] orgKIn: the K dimension size of the original shape
+     */
     int32_t SetOrgShape(int32_t orgMIn, int32_t orgNIn, int32_t orgKIn);
+    /**
+     * @brief Set the original shape dimensions
+     * @param [in] orgMIn: the M dimension size of the original shape
+     * @param [in] orgNIn: the N dimension size of the original shape
+     * @param [in] orgKaIn: the Ka dimension size of the original shape
+     * @param [in] orgKbIn: the Kb dimension size of the original shape
+     */
     int32_t SetOrgShape(int32_t orgMIn, int32_t orgNIn, int32_t orgKaIn, int32_t orgKbIn);
+    /**
+     * @brief Set the layout axis information for matrix A, including the B, S, N, G and D axes
+     * @param [in] b: batch dimension (B-axis) size
+     * @param [in] s: spatial dimension (S-axis) size
+     * @param [in] n: channel dimension (N-axis) size
+     * @param [in] g: group dimension (G-axis) size
+     * @param [in] d: D-axis size
+     */
     int32_t SetALayout(int32_t b, int32_t s, int32_t n, int32_t g, int32_t d);
+    /**
+     * @brief Set the layout axis information for matrix B, including the B, S, N, G and D axes
+     * @param [in] b: batch dimension (B-axis) size
+     * @param [in] s: spatial dimension (S-axis) size
+     * @param [in] n: channel dimension (N-axis) size
+     * @param [in] g: group dimension (G-axis) size
+     * @param [in] d: D-axis size
+     */
     int32_t SetBLayout(int32_t b, int32_t s, int32_t n, int32_t g, int32_t d);
+    /**
+     * @brief Set the layout axis information for matrix C, including the B, S, N, G and D axes
+     * @param [in] b: batch dimension (B-axis) size
+     * @param [in] s: spatial dimension (S-axis) size
+     * @param [in] n: channel dimension (N-axis) size
+     * @param [in] g: group dimension (G-axis) size
+     * @param [in] d: D-axis size
+     */
     int32_t SetCLayout(int32_t b, int32_t s, int32_t n, int32_t g, int32_t d);
+    /**
+     * @brief Set the batch information for normal processing
+     * @param [in] batchA: the batch number of matrix A
+     * @param [in] batchB: the batch number of matrix B
+     * @param [in] m: size of the M dimension
+     * @param [in] n: size of the N dimension
+     * @param [in] k: size of the K dimension
+     */
     int32_t SetBatchInfoForNormal(int32_t batchA, int32_t batchB, int32_t m, int32_t n, int32_t k);
+    /**
+     * @brief Set the batch number
+     * @param [in] batch: the batch number to set
+     */
     int32_t SetBatchNum(int32_t batch);
+    /**
+     * @brief Enable the bias
+     * @param [in] isBiasIn: true to enable the bias, false to disable it, default is false
+     */
     int32_t EnableBias(bool isBiasIn = false);
+    /**
+     * @brief Set the bias parameter
+     * @param [in] isBiasIn: whether to use bias, default is false
+     */
     int32_t SetBias(bool isBiasIn = false);
+    /**
+     * @brief Set fixed split parameters (baseM/baseN/baseK)
+     * @param [in] baseMIn: the value to fix baseM to, default is -1 (not fixed)
+     * @param [in] baseNIn: the value to fix baseN to, default is -1 (not fixed)
+     * @param [in] baseKIn: the value to fix baseK to, default is -1 (not fixed)
+     * @return Return the result of the setting operation
+     */
     int32_t SetFixSplit(int32_t baseMIn = -1, int32_t baseNIn = -1, int32_t baseKIn = -1);
+    /**
+     * @brief Set the size of buffer spaces
+     * @param [in] l1Size: size of L1 buffer in bytes; -1 leaves the current setting unchanged
+     * @param [in] l0CSize: size of L0C buffer in bytes; -1 leaves the current setting unchanged
+     * @param [in] ubSize: size of UB buffer in bytes; -1 leaves the current setting unchanged
+     * @param [in] btSize: size of BT buffer in bytes; -1 leaves the current setting unchanged
+     * @return Return 0 if success, -1 if failure
+     */
     int32_t SetBufferSpace(int32_t l1Size = -1, int32_t l0CSize = -1, int32_t ubSize = -1, int32_t btSize = -1);
+    /**
+     * @brief Set the traversal order for the matrix
+     * @param [in] traverse: the traversal order to be set
+     * @return Return 0 if success
+     */
     int32_t SetTraverse(MatrixTraverse traverse); // Set the N direction first for the upper left corner matrix
+    /**
+     * @brief Set the MAD (multiply-accumulate) mode of the matrix multiplication
+     * @param [in] madType: the MAD type to set
+     */
     int32_t SetMadType(MatrixMadType madType); // Set hf32 mode
     // L0C: BaseM * baseN = GetTensorC()
     // L1 : BaseM * BaseK + BaseK*BaseN, --> [disable temporarily] BaseK/k(1)=k1, BaseM/m(1)=m1, BaseN/n(1) = n1
+
+    /**
+     * @brief Set the split range
+     * @param [in] maxBaseM: maximum baseM value, default is -1
+     * @param [in] maxBaseN: maximum baseN value, default is -1
+     * @param [in] maxBaseK: maximum baseK value, default is -1
+     * @param [in] minBaseM: minimum baseM value, default is -1
+     * @param [in] minBaseN: minimum baseN value, default is -1
+     * @param [in] minBaseK: minimum baseK value, default is -1
+     * @return Return the result of the setting
+     */
     int32_t SetSplitRange(int32_t maxBaseM = -1, int32_t maxBaseN = -1, int32_t maxBaseK = -1, int32_t minBaseM = -1,
         int32_t minBaseN = -1, int32_t minBaseK = -1);
-
+    /**
+     * @brief Set the double buffer mode
+     * @param [in] a: enable double buffering for matrix A
+     * @param [in] b: enable double buffering for matrix B
+     * @param [in] c: enable double buffering for matrix C
+     * @param [in] bias: enable double buffering for the bias
+     * @param [in] transND2NZ: enable double buffering for the ND-to-NZ transformation, default is true
+     * @param [in] transNZ2ND: enable double buffering for the NZ-to-ND transformation, default is true
+     * @return Return 0 if success
+     */
     int32_t SetDoubleBuffer(bool a, bool b, bool c, bool bias, bool transND2NZ = true, bool transNZ2ND = true);
-
+    /**
+     * @brief Set matrix multiplication configuration parameters
+     * @param [in] mmConfigTypeIn: matrix multiplication configuration type, default is 1
+     * @param [in] enableL1CacheUBIn: enable L1 caching of UB data, default is false
+     * @param [in] scheduleTypeIn: schedule type, default is INNER_PRODUCT
+     * @param [in] traverseIn: matrix traversal order, default is NOSET
+     * @param [in] enVecND2NZIn: enable ND2NZ conversion on the vector unit, default is false
+     * @note This function sets the matrix multiplication configuration parameters,
+     * including configuration type, cache enablement, schedule type, traversal order and vector conversion
+     */
     void SetMatmulConfigParams(int32_t mmConfigTypeIn = 1, bool enableL1CacheUBIn = false,
         ScheduleType scheduleTypeIn = ScheduleType::INNER_PRODUCT,
         MatrixTraverse traverseIn = MatrixTraverse::NOSET, bool enVecND2NZIn = false);
+    /**
+     * @brief Set matrix multiplication configuration parameters
+     * @param [in] configParams: matrix multiplication configuration parameters object
+     * @note This overload sets the matrix multiplication configuration parameters from a MatmulConfigParams object
+     */
     void SetMatmulConfigParams(const MatmulConfigParams& configParams);
+    /**
+     * @brief Set the sparse matrix flag
+     * @param [in] isSparseIn: sparse matrix flag; true indicates the matrix is sparse, default is false
+     */
     int32_t SetSparse(bool isSparseIn = false);
-
+    /**
+     * @brief Get the base M value
+     * @return Return the base M value
+     */
     int32_t GetBaseM() const
     {
         return baseM;
     }
+    /**
+     * @brief Get the base N value
+     * @return Return the base N value
+     */
     int32_t GetBaseN() const
    {
         return baseN;
     }
+    /**
+     * @brief Get the base K value
+     * @return Return the base K value
+     */
     int32_t GetBaseK() const
     {
         return baseK;
     }
-
+    /**
+     * @brief Interface to get tiling information
+     * @param [out] tiling: reference that receives the tiling information
+     * @note The TCubeTiling type used by this overload is in namespace optiling
+     */
     virtual int64_t GetTiling(optiling::TCubeTiling& tiling) = 0;
+    /**
+     * @brief Interface to get tiling information
+     * @param [out] tiling: reference that receives the tiling information
+     * @note The TCubeTiling type used by this overload is in the global namespace
+     */
     virtual int64_t GetTiling(TCubeTiling& tiling) = 0;
 
 public:
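For reference, a minimal usage sketch of the tiling interface documented above: it configures the operand types and shapes on a tiling object and then fills a TCubeTiling structure through GetTiling(). The concrete class name MultiCoreMatmulTiling and the matmul_tiling namespace are assumptions not shown in this patch; any concrete subclass of MatmulApiTilingBase would be driven the same way.

// Usage sketch only. MultiCoreMatmulTiling and the matmul_tiling namespace are
// assumed here; this patch shows only the abstract MatmulApiTilingBase interface.
#include "lib/matmul/matmul_tiling_base.h"

int64_t BuildMatmulTiling(const platform_ascendc::PlatformAscendC &ascendcPlatform,
                          optiling::TCubeTiling &tiling)
{
    using namespace matmul_tiling;                     // assumed enclosing namespace
    MultiCoreMatmulTiling tilingApi(ascendcPlatform);  // assumed concrete subclass of MatmulApiTilingBase

    // Describe the operands: storage position, format, data type and transpose flag.
    tilingApi.SetAType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT, false);
    tilingApi.SetBType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT, false);
    tilingApi.SetCType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT);
    tilingApi.SetBiasType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT);
    tilingApi.SetBias(true);

    // Problem shape: M, N and K of the matrix multiplication.
    tilingApi.SetOrgShape(1024, 1024, 512);
    tilingApi.SetShape(1024, 1024, 512);

    // Fill the TCubeTiling structure; a negative return value indicates failure.
    return tilingApi.GetTiling(tiling);
}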