From 9f910cd23cea94c3c868205c4ee297adb8a6c6fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=9D=A8?= Date: Thu, 25 Sep 2025 16:43:30 +0800 Subject: [PATCH] allgathermm optimize --- .../AclNNInvocation/README.md | 1 - .../AclNNInvocation/src/main.cpp | 6 - .../AclNNInvocation/src/op_runner.cpp | 3 +- .../op_host/all_gather_matmul_custom.cpp | 5 - .../op_kernel/all_gather_matmul_custom.cpp | 123 ++++++++----- .../op_kernel/gather_mm.h | 99 ----------- .../op_kernel/mc2_matmul_block.h | 167 ------------------ .../op_kernel/mc2_matmul_compute.h | 98 ---------- .../21_all_gather_matmul_custom/README.md | 2 +- .../all_gather_matmul_custom.json | 12 +- 10 files changed, 78 insertions(+), 438 deletions(-) delete mode 100644 operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h delete mode 100644 operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h delete mode 100644 operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md index 40cfc9d50..4d5b71060 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md @@ -28,7 +28,6 @@ aclnnStatus aclnnAllGatherMatmulCustomGetWorkspaceSize( const aclTensor *a, const aclTensor *b, - const aclTensor *biasOptional, char *group, const aclTensor *cOut, const aclTensor *gatherOutOut, diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp index 86ff36642..bc9eac908 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp @@ -24,18 +24,15 @@ bool g_isDevice = false; namespace { -constexpr int32_t INPUT_BUFFER_BIAS = 2; OperatorDesc CreateOpDesc() { // define operator std::vector shapeA { RANK_M, RANK_K }; std::vector shapeB { RANK_K, RANK_N }; - std::vector shapeBias {}; std::vector shapeC { RANK_M * RANK_DIM, RANK_N }; std::vector shapeGatherOut { RANK_M * RANK_DIM, RANK_K }; aclDataType dataTypeA = ACL_FLOAT16; aclDataType dataTypeB = ACL_FLOAT16; - aclDataType dataTypeBias = ACL_FLOAT16; aclDataType dataTypeC = ACL_FLOAT16; aclDataType dataTypeGatherOut = ACL_FLOAT16; aclFormat format = ACL_FORMAT_ND; @@ -43,7 +40,6 @@ OperatorDesc CreateOpDesc() OperatorDesc opDesc; opDesc.AddInputTensorDesc(dataTypeA, shapeA.size(), shapeA.data(), format); opDesc.AddInputTensorDesc(dataTypeB, shapeB.size(), shapeB.data(), format); - opDesc.AddInputTensorDesc(dataTypeBias, shapeBias.size(), shapeBias.data(), format); opDesc.AddOutputTensorDesc(dataTypeC, shapeC.size(), shapeC.data(), format); opDesc.AddOutputTensorDesc(dataTypeGatherOut, shapeGatherOut.size(), shapeGatherOut.data(), format); return opDesc; @@ -56,8 +52,6 @@ bool SetInputData(OpRunner &runner, uint32_t rankId) runner.GetInputBuffer(0), runner.GetInputSize(0)); // Read input_a file ReadFile("../input/input_b_" + std::to_string(rankId) + ".bin", fileSize, runner.GetInputBuffer(1), runner.GetInputSize(1)); // Read input_b file - ReadFile("../input/input_bias_" + std::to_string(rankId) + ".bin", fileSize, - runner.GetInputBuffer(INPUT_BUFFER_BIAS), runner.GetInputSize(INPUT_BUFFER_BIAS)); INFO_LOG("Set input success"); return true; } diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp index 5aa62934f..c47e2ddee 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp @@ -298,10 +298,9 @@ bool OpRunner::RunOp(std::string group, aclrtStream stream) INFO_LOG("Copy input[%zu] success", i); } - aclTensor *bias = nullptr; size_t workspaceSize = 0; aclOpExecutor *handle = nullptr; - auto ret = aclnnAllGatherMatmulCustomGetWorkspaceSize(inputTensor_[0], inputTensor_[1], bias, (char*)group.c_str(), + auto ret = aclnnAllGatherMatmulCustomGetWorkspaceSize(inputTensor_[0], inputTensor_[1], (char*)group.c_str(), outputTensor_[0], outputTensor_[1], &workspaceSize, &handle); if (ret != ACL_SUCCESS) { ERROR_LOG("Get Operator Workspace failed. error code is %d", static_cast(ret)); diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp index 9916b7b9d..fedb14ba8 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp @@ -131,11 +131,6 @@ public: .Format({ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND}) .IgnoreContiguous(); - this->Input("bias") - .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}); this->Output("c") .ParamType(REQUIRED) diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp index bcdae45b8..ad77b44a8 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp @@ -7,76 +7,103 @@ * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. */ - #include "kernel_operator.h" -#include "kernel_operator_intf.h" #include "lib/matmul_intf.h" -#include "gather_mm.h" #include "all_gather_matmul_custom_tiling.h" - using namespace AscendC; +using MATMUL_TYPE = MatmulType; -extern "C" __global__ __aicore__ void all_gather_matmul_custom(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR biasGM, GM_ADDR cGM, - GM_ADDR gatherOutGM, GM_ADDR workspaceGM, GM_ADDR tilingGM) +__aicore__ inline void MatmulKernel(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM, TCubeTiling &tiling, + MatmulImpl &mm) { - if ASCEND_IS_AIV { + if (GetBlockIdx() >= tiling.usedCoreNum) { return; } - REGISTER_TILING_DEFAULT(AllGatherMatmulCustomTilingData); - auto tiling = (__gm__ AllGatherMatmulCustomTilingData*)tilingGM; - __gm__ void *mc2InitTiling = (__gm__ void *)(&(tiling->mc2InitTiling)); - __gm__ void *mc2CcTiling = (__gm__ void *)(&(tiling->mc2CcTiling)); + GlobalTensor aGlobal, bGlobal, cGlobal; + aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(aGM), tiling.M * tiling.Ka); + bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(bGM), tiling.Ka * tiling.N); + cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(cGM), tiling.M * tiling.N); - GET_TILING_DATA(tilingData, tilingGM); - auto &&cfg = tilingData.cfg; - auto &&localTiling = tilingData.localTiling; - auto &&tileTiling = tilingData.tileTiling; - auto &&tailTiling = tilingData.tailTiling; - const auto tileNum = cfg.tileNum; - const auto tailNum = cfg.tailNum; + int mSingleBlocks = (tiling.M + tiling.singleCoreM - 1) / tiling.singleCoreM; + int mCoreIndex = GetBlockIdx() % mSingleBlocks; + int nCoreIndex = GetBlockIdx() / mSingleBlocks; + int offsetA = mCoreIndex * tiling.Ka * tiling.singleCoreM; + int offsetB = nCoreIndex * tiling.singleCoreN; + int offsetC = mCoreIndex * tiling.N * tiling.singleCoreM + nCoreIndex * tiling.singleCoreN; + int tailM = Std::min(tiling.M - mCoreIndex * tiling.singleCoreM, tiling.singleCoreM); + int tailN = Std::min(tiling.N - nCoreIndex * tiling.singleCoreN, tiling.singleCoreN); - const auto aTileCnt = tileTiling.M * tileTiling.Ka; - const auto aTileOffset = tileTiling.M * tileTiling.Ka * sizeof(A_DTYPE); - const auto cTileOffset = tileTiling.M * tileTiling.N * sizeof(C_DTYPE); - const auto aTailCnt = tailTiling.M * tailTiling.Ka; - const auto aRankCnt = cfg.rankM * cfg.rankK; + mm.SetOrgShape(tiling.M, tiling.N, tiling.Ka, tiling.Kb); + mm.SetTensorA(aGlobal[offsetA]); + mm.SetTensorB(bGlobal[offsetB]); + mm.SetTail(tailM, tailN); + mm.IterateAll(cGlobal[offsetC]); +} +extern "C" __global__ __aicore__ void all_gather_matmul_custom(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM, + GM_ADDR gatherOutGM, GM_ADDR workspaceGM, + GM_ADDR tilingGM) +{ + if ASCEND_IS_AIV { + return; + } + REGISTER_TILING_DEFAULT(AllGatherMatmulCustomTilingData); + GET_TILING_DATA(tilingData, tilingGM); TPipe pipe; + auto &&localTiling = tilingData.localTiling; + auto &&tileTiling = tilingData.tileTiling; + auto &&tailTiling = tilingData.tailTiling; + const auto tileNum = tilingData.cfg.tileNum; + const auto tailNum = tilingData.cfg.tailNum; + const auto aTileEleCnt = tileTiling.M * tileTiling.Ka; + const auto aTileSize = tileTiling.M * tileTiling.Ka * sizeof(half); + const auto cTileSize = tileTiling.M * tileTiling.N * sizeof(half); + const auto aTailEleCnt = tailTiling.M * tailTiling.Ka; + const auto aRankEleCnt = localTiling.M * localTiling.Ka; + const auto aRankSize = localTiling.M * localTiling.Ka * sizeof(half); + const auto cRankSize = localTiling.M * localTiling.N * sizeof(half); + Hccl hccl; GM_ADDR contextGM = GetHcclContext(); - hccl.Init(contextGM, mc2InitTiling); - hccl.SetCcTiling(mc2CcTiling); - - // 下发allgather任务 - // 首块 - auto handleId = hccl.AllGather(aGM, gatherOutGM, aTileCnt, HcclDataType::HCCL_DATA_TYPE_FP16, aRankCnt, tileNum); - // 尾块 - auto tailHandleId = hccl.AllGather(aGM + tileNum * aTileOffset, gatherOutGM + tileNum * aTileOffset, aTailCnt, - HcclDataType::HCCL_DATA_TYPE_FP16, aRankCnt, tailNum); - - using A_TYPE = MatmulType; - using B_TYPE = MatmulType; - using C_TYPE = MatmulType; + hccl.InitV2(contextGM, &tilingData); + hccl.SetCcTilingV2(offsetof(AllGatherMatmulCustomTilingData, mc2CcTiling)); + auto handleId = + hccl.AllGather(aGM, gatherOutGM, aTileEleCnt, HcclDataType::HCCL_DATA_TYPE_FP16, aRankEleCnt, tileNum); + auto tailHandleId = hccl.AllGather(aGM + tileNum * aTileSize, gatherOutGM + tileNum * aTileSize, aTailEleCnt, + HcclDataType::HCCL_DATA_TYPE_FP16, aRankEleCnt, tailNum); - // 本卡数据计算 - MatmulKernelLocal(aGM, bGM, cGM, cfg, localTiling, hccl); + MatmulImpl mm; + mm.SetSubBlockIdx(0); + mm.Init(&localTiling); + MatmulKernel(aGM, bGM, cGM + hccl.GetRankId() * cRankSize, localTiling, mm); - // tile首块计算 - auto aAddr = gatherOutGM; // gatherOut 作为 mm A矩阵地址 + auto aAddr = gatherOutGM; auto cAddr = cGM; - if (tileNum > 0) { - MatmulKernel(aAddr, bGM, cAddr, cfg, tileTiling, hccl, handleId, - tileNum); + mm.Init(&tileTiling); + for (uint32_t i = 0; i < tileNum; i++) { + hccl.Wait(handleId); + for (uint32_t rankId = 0; rankId < hccl.GetRankDim(); rankId++) { + if (rankId == hccl.GetRankId()) + continue; + MatmulKernel(aAddr + rankId * aRankSize, bGM, cAddr + rankId * cRankSize, tileTiling, mm); + } + aAddr += aTileSize; + cAddr += cTileSize; } - // tail尾块计算 - aAddr = gatherOutGM + tileNum * aTileOffset; - cAddr = cGM + tileNum * cTileOffset; + aAddr = gatherOutGM + tileNum * aTileSize; + cAddr = cGM + tileNum * cTileSize; if (tailNum > 0) { - MatmulKernel(aAddr, bGM, cAddr, cfg, tailTiling, hccl, tailHandleId, - tailNum); + mm.Init(&tailTiling); + hccl.Wait(tailHandleId); + for (uint32_t rankId = 0; rankId < hccl.GetRankDim(); rankId++) { + if (rankId == hccl.GetRankId()) + continue; + MatmulKernel(aAddr + rankId * aRankSize, bGM, cAddr + rankId * cRankSize, tailTiling, mm); + } } + mm.End(); hccl.Finalize(); } \ No newline at end of file diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h deleted file mode 100644 index b363d8ce2..000000000 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * @file gather_mm.h - * - * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - */ - -#ifndef MC2_GATHER_MM_H -#define MC2_GATHER_MM_H - -#if defined ASCENDC_CPU_DEBUG -#define SET_G_CORE_TYPE_IS_AIV thread_local int g_coreType = 2 -#define SET_G_CORE_TYPE_IS_AIC thread_local int g_coreType = 1 -#define DTYPE_A half -#define DTYPE_C half -#else -#define SET_G_CORE_TYPE_IS_AIV -#define SET_G_CORE_TYPE_IS_AIC -#endif - -#include "kernel_operator_intf.h" -#include "mc2_matmul_compute.h" -#include "all_gather_matmul_custom_tiling.h" - -namespace AscendC { -using A_DTYPE = DTYPE_A; -using B_DTYPE = DTYPE_B; -using C_DTYPE = DTYPE_C; - -template -__aicore__ inline void MatmulKernelLocal(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM, AllGatherMatmulTiling &cfg, - TCubeTiling &tiling, Hccl &hccl) -{ - if ASCEND_IS_AIC { - if (GetBlockIdx() >= tiling.usedCoreNum) { - return; - } - using C_T = typename C_TYPE::T; - const auto aRankDataCnt = cfg.rankM * cfg.rankK; - const auto cRankDataCnt = cfg.rankM * cfg.rankN; - - MatmulCompute mmLocal; - mmLocal.Init(cfg, tiling); - mmLocal.UpdateWeight(bGM); - mmLocal.UpdateAddress(aGM, aRankDataCnt, cGM + hccl.GetRankId() * cRankDataCnt * sizeof(C_T), cRankDataCnt); - mmLocal.Process(); - mmLocal.End(); - } -} - -template -__aicore__ inline void MatmulKernel(GM_ADDR aAddr, GM_ADDR bGM, GM_ADDR cAddr, AllGatherMatmulTiling &cfg, - TCubeTiling &tiling, Hccl &hccl, HcclHandle &handleId, uint32_t tileCnt) -{ - if ASCEND_IS_AIC { - if (GetBlockIdx() >= tiling.usedCoreNum) { - for (uint32_t i = 0; i < tileCnt; i++) { - CrossCoreSetFlag<0x0, PIPE_FIX>(0x8); - CrossCoreWaitFlag(0x8); - } - return; - } - using A_T = typename A_TYPE::T; - using C_T = typename C_TYPE::T; - const auto aDataCnt = tiling.M * tiling.Ka; - const auto aOffset = aDataCnt * sizeof(A_T); - const auto cDataCnt = tiling.M * tiling.N; - const auto cOffset = cDataCnt * sizeof(C_T); - const auto aRankOffset = cfg.rankM * cfg.rankK * sizeof(A_T); - const auto cRankOffset = cfg.rankM * cfg.rankN * sizeof(C_T); - - MatmulCompute mm; - mm.Init(cfg, tiling); - mm.UpdateWeight(bGM); - for (uint32_t i = 0; i < tileCnt; i++) { - // wait current handle allgather - hccl.Wait(handleId); - CrossCoreSetFlag<0x0, PIPE_FIX>(0x8); - CrossCoreWaitFlag(0x8); - // calculate all ranks data - for (uint32_t rankId = 0; rankId < hccl.GetRankDim(); rankId++) { - // skip local rank - if (rankId == hccl.GetRankId()) { - continue; - } - mm.UpdateAddress(aAddr + rankId * aRankOffset, aDataCnt, cAddr + rankId * cRankOffset, cDataCnt); - mm.Process(); - } - aAddr += aOffset; - cAddr += cOffset; - } - mm.End(); - } -} -} -#endif // MC2_GATHER_MM_H diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h deleted file mode 100644 index 00d4322b5..000000000 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h +++ /dev/null @@ -1,167 +0,0 @@ -/** - * @file mc2_matmul_block.h - * - * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - */ - -#ifndef MC2_MATMUL_BLOCK_H -#define MC2_MATMUL_BLOCK_H - -namespace AscendC { - -constexpr uint32_t C0_SIZE = 16; - -struct BaseBlockOffset { - uint64_t offsetA; - uint64_t offsetB; - uint64_t offsetC; -}; - -struct BaseBlockArguments -{ - bool isRowOrder; - uint32_t singleCoreM; - uint32_t singleCoreN; - uint32_t mBlockCnt; // M方向的基本块个数 - uint32_t nBlockCnt; // N方向的基本块个数 - uint32_t nBaseTail; // N方向的尾块大小 - uint32_t mBaseTail; // M方向的尾块大小 - uint32_t totalBlockCnt; // C矩阵的全部基本块个数 - uint32_t blockCnt; // 单核需要计算的基本块个数 - uint32_t blockStartIdx; // 当前核需计算的基本块的起始位置索引 - uint32_t blockCurrIdx; // 当前核需计算的基本块的当前位置索引 - uint32_t preCoreNum; // 满核分配后剩余基本块数/需要预分配1个block的核数 - uint32_t preCoreStartIdx; // 多分配一个基本块的核起始位置 - uint64_t mBlockOffset; - uint64_t nBlockOffset; - uint64_t mCWorkOffset; -}; - -class MatmulBaseBlock { -public: - __aicore__ inline MatmulBaseBlock() {} - __aicore__ inline void Init(TCubeTiling& tiling); - __aicore__ inline void InitBlockWithoutIndex(); - __aicore__ inline void UpdateBlockIndex(uint32_t currPos); - __aicore__ inline void UpdateBlockParams(int32_t mTileIndex=0, int32_t nTileIndex=0); - __aicore__ inline void CalcGMOffset(); - __aicore__ inline void GetBlockStartIdx(uint32_t startIdx, uint32_t endIdx); - -public: - BaseBlockOffset offset_; - BaseBlockArguments args_; - TCubeTiling tiling_; -}; - -__aicore__ inline void MatmulBaseBlock::Init(TCubeTiling& tiling) -{ - tiling_ = tiling; - args_.preCoreStartIdx = 0; - args_.mBlockCnt = DivCeil(tiling.M, tiling.baseM); //M方向分Base块个数 - args_.nBlockCnt = DivCeil(tiling.N, tiling.baseN); //N方向分Base块个数 - args_.nBaseTail = tiling.N - (args_.nBlockCnt - 1) * tiling.baseN; - args_.mBaseTail = tiling.M - (args_.mBlockCnt - 1) * tiling.baseM; - args_.totalBlockCnt = args_.mBlockCnt * args_.nBlockCnt; - args_.isRowOrder = true; - if (tiling_.N > 5 * tiling_.M) { // 5: ratio of rowOrder - args_.isRowOrder = false; - } -} - -__aicore__ inline void MatmulBaseBlock::InitBlockWithoutIndex() -{ - args_.totalBlockCnt = args_.mBlockCnt * args_.nBlockCnt; - args_.blockCnt = args_.totalBlockCnt / tiling_.usedCoreNum; - args_.preCoreNum = args_.totalBlockCnt % tiling_.usedCoreNum; - - // 多分配1个基本块的核索引, 从上一次结束位置开始 - auto startIdx = args_.preCoreStartIdx; - auto endIdx = (startIdx + args_.preCoreNum) % tiling_.usedCoreNum; - args_.preCoreStartIdx = endIdx; - GetBlockStartIdx(startIdx, endIdx); -} - -__aicore__ inline void MatmulBaseBlock::GetBlockStartIdx(uint32_t startIdx, uint32_t endIdx) -{ - if (startIdx > endIdx) { - if (block_idx < endIdx) { - args_.blockCnt += 1; - args_.blockStartIdx = block_idx * args_.blockCnt; - } else if (block_idx >= startIdx) { - args_.blockCnt += 1; - args_.blockStartIdx = block_idx * args_.blockCnt - (tiling_.usedCoreNum - args_.preCoreNum); - } else { - args_.blockStartIdx = block_idx * args_.blockCnt + endIdx; - } - } else { - if (block_idx < startIdx) { - args_.blockStartIdx = block_idx * args_.blockCnt; - } else if (block_idx >= endIdx) { - args_.blockStartIdx = block_idx * args_.blockCnt + args_.preCoreNum; - } else { - args_.blockCnt += 1; - args_.blockStartIdx = block_idx * args_.blockCnt - startIdx; - } - } - - if (!args_.isRowOrder) { - auto blockStart = args_.blockStartIdx; - args_.blockStartIdx = blockStart / args_.mBlockCnt + blockStart % args_.mBlockCnt * args_.nBlockCnt; - } -} - -__aicore__ inline void MatmulBaseBlock::UpdateBlockIndex(uint32_t currPos) -{ - // 按行取,计算第i个基本块的index - if (args_.isRowOrder) { - args_.blockCurrIdx = args_.blockStartIdx + currPos % args_.blockCnt; - return; - } - - args_.blockCurrIdx = args_.blockStartIdx + (currPos % args_.blockCnt) * args_.nBlockCnt; - // 按列取,如果block超行,需计算下一列的位置 - if (args_.blockCurrIdx >= args_.totalBlockCnt) { - args_.blockCurrIdx = args_.blockCurrIdx % args_.totalBlockCnt + args_.blockCurrIdx / args_.totalBlockCnt; - } - return; -} - -__aicore__ inline void MatmulBaseBlock::UpdateBlockParams(int32_t mTileIndex, int32_t nTileIndex) -{ - (void)mTileIndex; - (void)nTileIndex; - if (args_.blockCurrIdx == (args_.totalBlockCnt - 1)) { - // 当前矩阵最后一块 - args_.singleCoreM = args_.mBaseTail; - args_.singleCoreN = args_.nBaseTail; - } else if (args_.blockCurrIdx >= (args_.mBlockCnt - 1) * args_.nBlockCnt) { - // 当前矩阵最后一行 - args_.singleCoreM = args_.mBaseTail; - args_.singleCoreN = tiling_.baseN; - } else if ((args_.blockCurrIdx + 1) % args_.nBlockCnt == 0) { - // 当前矩阵最后一列 - args_.singleCoreM = tiling_.baseM; - args_.singleCoreN = args_.nBaseTail; - } else { - args_.singleCoreM = tiling_.baseM; - args_.singleCoreN = tiling_.baseN; - } - - // 更新基本块的地址偏移 - args_.mBlockOffset = args_.blockCurrIdx / args_.nBlockCnt * tiling_.baseM; // 基本块所在的行偏移 - args_.nBlockOffset = args_.blockCurrIdx % args_.nBlockCnt * tiling_.baseN; // 基本块所在的列偏移 - args_.mCWorkOffset = args_.mBlockOffset; -} - -__aicore__ inline void MatmulBaseBlock::CalcGMOffset() -{ - offset_.offsetA = args_.mBlockOffset * tiling_.Ka; - offset_.offsetB = args_.nBlockOffset; - offset_.offsetC = args_.nBlockOffset + args_.mCWorkOffset * tiling_.N; -} -} // namespace ASCENDC -#endif // MC2_MATMUL_BLOCK_H diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h deleted file mode 100644 index 0bac09100..000000000 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * @file mc2_matmul_compute.h - * - * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - */ - -#ifndef MC2_MATMUL_COMPUTE_H -#define MC2_MATMUL_COMPUTE_H -#include "mc2_matmul_block.h" -#include "all_gather_matmul_custom_tiling.h" - -namespace AscendC { -using namespace matmul; - -template -class MatmulCompute { - using A_T = typename A_TYPE::T; - using B_T = typename B_TYPE::T; - using C_T = typename C_TYPE::T; - -public: - __aicore__ inline MatmulCompute() {} - __aicore__ inline void Init(AllGatherMatmulTiling& cfg, TCubeTiling& tiling); - __aicore__ inline void UpdateWeight(GM_ADDR bGM); - __aicore__ inline void UpdateAddress(GM_ADDR aGM, uint32_t aSize, GM_ADDR cGM, uint32_t cSize); - __aicore__ inline void Process(); - __aicore__ inline void End(); - -private: - MatmulImpl mm_; - GlobalTensor aGlobal; - GlobalTensor bGlobal; - GlobalTensor cGlobal; - MatmulBaseBlock block_; - TCubeTiling tiling_; - AllGatherMatmulTiling cfg_; -}; - -template -__aicore__ inline void MatmulCompute::UpdateWeight(GM_ADDR bGM) -{ - // MC2的计算流中默认B矩阵不变,GM地址无需偏移 - bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ B_T *>(bGM), tiling_.Kb * tiling_.N); -} - -template -__aicore__ inline void MatmulCompute::UpdateAddress( - GM_ADDR aGM, uint32_t aSize, GM_ADDR cGM, uint32_t cSize) -{ - aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ A_T *>(aGM), aSize); - cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ C_T *>(cGM), cSize); -} - -template -__aicore__ inline void MatmulCompute::Init(AllGatherMatmulTiling& cfg, TCubeTiling& tiling) -{ - mm_.SetSubBlockIdx(0); - mm_.Init(&tiling, GetTPipePtr()); - tiling_ = tiling; - cfg_ = cfg; - block_.Init(tiling); -} - -template -__aicore__ inline void MatmulCompute::Process() -{ - // 每次block循环开始前需要计算初始blockIndex - block_.InitBlockWithoutIndex(); - for (uint32_t i = 0; i < block_.args_.blockCnt; i++) { - // calculate blockCurrIndex - block_.UpdateBlockIndex(i); - if (block_.args_.blockCurrIdx < block_.args_.totalBlockCnt) { - block_.UpdateBlockParams(); - block_.CalcGMOffset(); - mm_.SetSingleShape(block_.args_.singleCoreM, block_.args_.singleCoreN, tiling_.singleCoreK); - mm_.SetTensorA(aGlobal[block_.offset_.offsetA]); - mm_.SetTensorB(bGlobal[block_.offset_.offsetB]); - mm_.Iterate(); - mm_.GetTensorC(cGlobal[block_.offset_.offsetC]); - // 增加M等FIX同步 - event_t eventIDFixToM = static_cast(GetTPipePtr()->FetchEventID(HardEvent::FIX_M)); - SetFlag(eventIDFixToM); - WaitFlag(eventIDFixToM); - } - } -} - -template -__aicore__ inline void MatmulCompute::End() -{ - mm_.End(); -} -} -#endif // MC2_MATMUL_COMPUTE_H diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md index 03c7430e5..0ba75d35a 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md @@ -32,7 +32,7 @@ $$ 算子输入nameshapedata typeformat a512 * 5120float16ND b5120 * 640float16ND -bias/// + 算子输出c4096 * 640float16ND diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json index 6ab16e763..96aa36210 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json @@ -13,7 +13,7 @@ ] }, { - "name": "x2", + "name": "b", "param_type": "required", "format": [ "ND" @@ -21,16 +21,6 @@ "type": [ "float16" ] - }, - { - "name": "bias", - "param_type": "optional", - "format": [ - "ND" - ], - "type": [ - "float16" - ] } ], "output_desc":[ -- Gitee