8 Star 124 Fork 6

Ascend/cann-var-sequence-gemm

Create your Gitee Account
Explore and code with more than 13.5 million developers. Free private repositories!
Sign up
文件
include
test
BatchGemm
batch_test
simulator
single_test
batchgemm_device.cpp
batchgemm_host.h
batchgemm_main.cpp
batchgemm_make.sh
batchgemm_make_simulator.sh
batchgemm_profiling_main.cpp
batchgemm_profiling_make.sh
test_gen_data.py
LLMsGEMM_batch_QKT
LLMsGEMM_batch_QKTV
LLMsGEMM_batch_QKTVP
LLMsGEMM_whole_task
acl.json
data_utils.h
kernel_host.h
run_main.h
.gitignore
LICENSE
OAT.xml
OWNERS
README.en.md
README.md
Clone or Download
batchgemm_device.cpp 6.68 KB
Copy Edit Raw Blame History
/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include "kernel_operator.h"
#include "../../include/batch_cal_valid_len.h"
#include "../../include/batch_padding.h"
#include "../../include/batch_matmul.h"
#include "../../include/batch_epilogue.h"
/**
 * @brief: Valid-length computation head, executed on the vector (AIV) cores.
 *         Step 1 derives the per-batch valid M lengths from the mask matrix;
 *         steps 2 and 3 fill the valid N and valid K arrays based on the
 *         zero-padded sizes (presumably every batch uses the padded N/K since
 *         no mask is consulted — confirm against include/batch_cal_valid_len.h).
 *         Each step is separated by a cross-core flag barrier so all cores
 *         complete one array before the next is produced.
 * @param [in] layoutA: layout of matrix A (unused in this function body)
 * @param [in] layoutB: layout of matrix B (unused in this function body)
 * @param [in] zeroPaddingM: M dimension of A/C after zero padding
 * @param [in] zeroPaddingN: N dimension of B/C after zero padding
 * @param [in] zeroPaddingK: K dimension of A/B after zero padding
 * @param [in] batchCount: number of batches in the batched matmul
 * @param [in] d_maskA: mask matrix used to derive the valid M lengths
 * @param [in] d_APointer: array of A-matrix pointers (unused in this function body)
 * @param [out] d_validM: per-batch valid M lengths of A/C
 * @param [out] d_validN: per-batch valid N lengths of B/C
 * @param [out] d_validK: per-batch valid K lengths of A/B
 */
[aicore] inline __attribute__((always_inline)) void CalValidLenHead(
layoutType layoutA,
layoutType layoutB,
uint32_t zeroPaddingM,
uint32_t zeroPaddingN,
uint32_t zeroPaddingK,
uint32_t batchCount,
__gm__ half* d_maskA,
__gm__ half **d_APointer,
__gm__ uint32_t *d_validM,
__gm__ uint32_t *d_validN,
__gm__ uint32_t *d_validK
) {
// Step 1: compute valid M per batch from the half-precision mask matrix.
BatchCalValidMWithMask<half>(
zeroPaddingM,
batchCount,
d_validM,
d_maskA
);
// Barrier on flag 0, ordered after MTE3 (GM write) stage: ensure every core's
// d_validM writes have drained before any core starts on d_validN.
AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
AscendC::CrossCoreWaitFlag(0);
// Step 2: populate d_validN from the zero-padded N size
// (exact fill/rounding semantics live in PadValidLen — see include/ headers).
PadValidLen(
batchCount,
d_validN,
zeroPaddingN
);
// Same barrier pattern between step 2 and step 3.
AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
AscendC::CrossCoreWaitFlag(0);
// Step 3: populate d_validK from the zero-padded K size.
PadValidLen(
batchCount,
d_validK,
zeroPaddingK
);
}
/**
 * @brief: Batched variable-sequence GEMM kernel entry (computes alpha*AB + beta*C
 *         per batch). Compiled twice: the cube (AIC) cores run the batched
 *         matmul core loop, while the vector (AIV) cores compute valid lengths,
 *         pad A and B, and run the epilogue. The two sides rendezvous through
 *         cross-core flags (AIC waits on flag 1, which AIV raises after padding).
 * @param [in] layoutA: layout of matrix A
 * @param [in] layoutB: layout of matrix B
 * @param [in] zeroPaddingM: M dimension of A/C after zero padding
 * @param [in] zeroPaddingN: N dimension of B/C after zero padding
 * @param [in] zeroPaddingK: K dimension of A/B after zero padding
 * @param [in] batchCount: number of batches in the batched matmul
 * @param [in] d_maskA: mask matrix
 * @param [out] d_validM: per-batch valid M lengths of A/C
 * @param [out] d_validN: per-batch valid N lengths of B/C
 * @param [out] d_validK: per-batch valid K lengths of A/B
 * @param [in] alpha: scalar in alpha*AB + beta*C
 * @param [in] d_APointer: per-batch base addresses of the zero-padded A matrices
 * @param [in] d_BPointer: per-batch base addresses of the zero-padded B matrices
 * @param [in] beta: scalar in alpha*AB + beta*C
 * @param [out] d_CPointer: per-batch base addresses of the zero-padded C matrices
 * @param [in] d_isAPadding: per-batch flag, whether matrix A needs padding
 * @param [in] d_isBPadding: per-batch flag, whether matrix B needs padding
 * @param [in] d_APointerPadding: new device memory for A matrices that need padding
 * @param [in] d_BPointerPadding: new device memory for B matrices that need padding
 * @param [in] paddingDirA: padding direction for A (semantics defined by BatchMatrixPadding)
 * @param [in] paddingDirB: padding direction for B (semantics defined by BatchMatrixPadding)
 * @param [in] d_AicAivWorkspacePointer: base addresses of the GM workspace used for AIC/AIV exchange
 * @param [in] fftsAddr: base address required for cross-core synchronization
 * @param [in] isAlpha1Beta0: whether alpha==1.0 && beta==0.0 (lets callees skip the scaling path)
 */
extern "C" __global__ [aicore] void batchgemm_device (
layoutType layoutA,
layoutType layoutB,
uint32_t zeroPaddingM,
uint32_t zeroPaddingN,
uint32_t zeroPaddingK,
uint32_t batchCount,
__gm__ half* d_maskA,
__gm__ uint32_t* d_validM,
__gm__ uint32_t* d_validN,
__gm__ uint32_t* d_validK,
half alpha,
__gm__ half** d_APointer,
__gm__ half** d_BPointer,
half beta,
__gm__ half** d_CPointer,
__gm__ uint8_t *d_isAPadding,
__gm__ uint8_t *d_isBPadding,
__gm__ half** d_APointerPadding,
__gm__ half** d_BPointerPadding,
uint8_t paddingDirA,
uint8_t paddingDirB,
__gm__ half** d_AicAivWorkspacePointer,
uint64_t fftsAddr,
uint8_t isAlpha1Beta0
) {
#if __DAV_C220_CUBE__
// ---- Cube (AIC) core path ----
AscendC::SetSyncBaseAddr(fftsAddr);
AscendC::SetAtomicNone();
// Zero-fill out-of-range elements loaded by LoadData.
AscendC::SetLoadDataPaddingValue<uint64_t>((uint64_t)0);
// Block until the AIV side raises flag 1 (set below, after valid-length
// computation and A/B padding are complete).
AscendC::CrossCoreWaitFlag(1);
BatchMatmul<L1M0, L1N0, L1K0, WORKSPACENUM>(
layoutA,
layoutB,
zeroPaddingM,
zeroPaddingN,
zeroPaddingK,
batchCount,
d_validM,
d_validN,
d_validK,
alpha,
d_APointer,
d_BPointer,
beta,
d_CPointer,
d_isAPadding,
d_isBPadding,
d_APointerPadding,
d_BPointerPadding,
paddingDirA,
paddingDirB,
d_AicAivWorkspacePointer,
isAlpha1Beta0
);
#elif __DAV_C220_VEC__
// ---- Vector (AIV) core path ----
AscendC::SetSyncBaseAddr(fftsAddr);
AscendC::SetAtomicNone();
// Use the normal (full) vector mask mode for subsequent vector instructions.
AscendC::SetMaskNorm();
AscendC::SetVectorMask<half, AscendC::MaskMode::NORMAL>( (uint64_t)-1, (uint64_t)-1 );
// Phase 1: compute per-batch valid M/N/K lengths.
CalValidLenHead(
layoutA,
layoutB,
zeroPaddingM,
zeroPaddingN,
zeroPaddingK,
batchCount,
d_maskA,
d_APointer,
d_validM,
d_validN,
d_validK
);
// Barrier: valid lengths must be globally visible before padding reads them.
AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
AscendC::CrossCoreWaitFlag(0);
// Phase 2a: pad A matrices (M x K tiles of L1M0 x L1K0) where required.
BatchMatrixPadding<L1M0, L1K0>(
layoutA,
zeroPaddingM,
zeroPaddingK,
batchCount,
d_validM,
d_validK,
d_APointer,
d_isAPadding,
d_APointerPadding,
paddingDirA
);
AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
AscendC::CrossCoreWaitFlag(0);
// Phase 2b: pad B matrices (K x N tiles of L1K0 x L1N0) where required.
BatchMatrixPadding<L1K0, L1N0>(
layoutB,
zeroPaddingK,
zeroPaddingN,
batchCount,
d_validK,
d_validN,
d_BPointer,
d_isBPadding,
d_BPointerPadding,
paddingDirB
);
AscendC::CrossCoreSetFlag<0, PIPE_MTE3>(0);
AscendC::CrossCoreWaitFlag(0);
// Release the cube cores: flag 1 is what the AIC path waits on above.
AscendC::CrossCoreSetFlag<2, PIPE_MTE3>(1);
// Phase 3: epilogue — consumes matmul results via the AIC/AIV workspace and
// applies the alpha/beta scaling into C.
BatchMatmulEpilogue<L1M0, L1N0, L1K0, WORKSPACENUM>(
zeroPaddingM,
zeroPaddingN,
batchCount,
d_validM,
d_validN,
alpha,
beta,
d_CPointer,
d_AicAivWorkspacePointer,
isAlpha1Beta0
);
#endif
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ascend/cann-var-sequence-gemm.git
git@gitee.com:ascend/cann-var-sequence-gemm.git
ascend
cann-var-sequence-gemm
cann-var-sequence-gemm
master

Search