Ascend/cann-var-sequence-gemm: batch_cal_valid_len.h
/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "./custom_type.h"
#include "./data_transfer.h"
#include "./kernel_const.h"
#include "./kernel_utils.h"
#include "kernel_operator.h"
/**
* @brief: Compute the valid text length of each batch with the help of a mask matrix.
* @param [in] maskType: data type of the mask matrix
* @param [in] zeroPaddingM: text length after zero padding
* @param [in] batchCount: number of batches
* @param [out] d_validM: valid text length of each batch
* @param [in] d_maskA: mask matrix
*/
template<typename maskType>
[aicore] inline __attribute__((always_inline)) void BatchCalValidMWithMask(
    uint32_t zeroPaddingM,
    uint32_t batchCount,
    __gm__ uint32_t * d_validM,
    __gm__ maskType * d_maskA
){
    // Claim the whole UB as one raw byte tensor, then release the TPipe so
    // that event flags can be managed manually below.
    AscendC::TBuf<AscendC::TPosition::VECIN> ub_buf;
    AscendC::TPipe ub_pipe;
    ub_pipe.InitBuffer(ub_buf, UB_BYTES);
    AscendC::LocalTensor<uint8_t> ub_tensor = ub_buf.Get<uint8_t>();
    ub_pipe.Destroy();
    static constexpr uint64_t ub_reduceSumNum = 1;
    // Ratio between the compute type (float) and the mask type:
    // 4 for 1-byte masks (e.g. int8), 2 for half, 1 for float.
    uint32_t ub_typeRate = sizeof(float) / sizeof(maskType);
    uint32_t ub_divideBuf = 0;
    uint32_t ub_inputMaskBufMulRate = 0;
    uint32_t ub_calcMaskBufMulRate = 0;
    uint32_t ub_workLocalMulRate = 0;
    uint32_t ub_srcResultDivRate = 0;
    uint32_t ub_dstResultDivRate = 0;
    if(ub_typeRate == 4){
        ub_divideBuf = 8;
        ub_inputMaskBufMulRate = 1;
        ub_calcMaskBufMulRate = ub_typeRate;
        ub_workLocalMulRate = ub_typeRate / 2;
        ub_srcResultDivRate = 2;
        ub_dstResultDivRate = 2;
    }else if(ub_typeRate == 2){
        ub_divideBuf = 6;
        ub_inputMaskBufMulRate = 1;
        ub_calcMaskBufMulRate = ub_typeRate;
        ub_workLocalMulRate = ub_typeRate / 2;
        ub_srcResultDivRate = 1;
        ub_dstResultDivRate = 1;
    }else if(ub_typeRate == 1){
        ub_divideBuf = 4;
        ub_inputMaskBufMulRate = 1;
        ub_calcMaskBufMulRate = ub_typeRate;
        ub_workLocalMulRate = ub_typeRate;
        ub_srcResultDivRate = 2;
        ub_dstResultDivRate = 2;
    }
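    // Worked example (assuming UB_BYTES = 192 KiB; the actual value comes
    // from kernel_const.h): with maskType = half, ub_typeRate == 2, so the
    // UB is cut into 6 equal 32 KiB parts: 1 part of raw half mask (16 K
    // elements), 2 parts for its float cast, 1 part of ReduceSum scratch,
    // and 1 part each for the float sum and its int32 cast, which together
    // fill the buffer exactly (1 + 2 + 1 + 1 + 1 = 6).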
    uint64_t ub_reduceSumBufBytes = UB_BYTES / ub_reduceSumNum;
    uint64_t ub_reduceSumPartBytes = ub_reduceSumBufBytes / ub_divideBuf;
    uint64_t ub_reduceSumInputMaskBytes = ub_reduceSumPartBytes * ub_inputMaskBufMulRate;
    uint32_t ub_reduceSumInputMaskSize = ub_reduceSumInputMaskBytes / sizeof(maskType);
    uint64_t ub_reduceSumMaskBytes = ub_reduceSumPartBytes * ub_calcMaskBufMulRate;
    uint32_t ub_reduceSumMaskSize = ub_reduceSumMaskBytes / sizeof(float);
    uint64_t ub_reduceSumWorkLocalBytes = ub_reduceSumPartBytes * ub_workLocalMulRate;
    uint32_t ub_reduceSumWorkLocalSize = ub_reduceSumWorkLocalBytes / sizeof(float);
    uint64_t ub_reduceSumResultSrcBytes = ub_reduceSumPartBytes / ub_srcResultDivRate;
    uint32_t ub_reduceSumResultSrcSize = ub_reduceSumResultSrcBytes / sizeof(float);
    uint64_t ub_reduceSumResultDstBytes = ub_reduceSumPartBytes / ub_dstResultDivRate;
    uint32_t ub_reduceSumResultDstSize = ub_reduceSumResultDstBytes / sizeof(int32_t);
    // Carve the five regions out of the raw UB tensor, back to back.
    AscendC::LocalTensor<maskType> ub_reduceSumInputMaskBuf[ub_reduceSumNum];
    AscendC::LocalTensor<float> ub_reduceSumMaskBuf[ub_reduceSumNum];
    AscendC::LocalTensor<float> ub_reduceSumWorkLocalBuf[ub_reduceSumNum];
    AscendC::LocalTensor<float> ub_reduceSumResultSrcBuf[ub_reduceSumNum];
    AscendC::LocalTensor<int32_t> ub_reduceSumResultDstBuf[ub_reduceSumNum];
    for(uint32_t i = 0; i < ub_reduceSumNum; i++){
        ub_reduceSumInputMaskBuf[i] = ub_tensor[ i * ub_reduceSumBufBytes ].template ReinterpretCast<maskType>();
        ub_reduceSumMaskBuf[i] = ub_tensor[ i * ub_reduceSumBufBytes
                                            + ub_reduceSumInputMaskBytes ].template ReinterpretCast<float>();
        ub_reduceSumWorkLocalBuf[i] = ub_tensor[ i * ub_reduceSumBufBytes
                                                 + ub_reduceSumInputMaskBytes
                                                 + ub_reduceSumMaskBytes ].template ReinterpretCast<float>();
        ub_reduceSumResultSrcBuf[i] = ub_tensor[ i * ub_reduceSumBufBytes
                                                 + ub_reduceSumInputMaskBytes
                                                 + ub_reduceSumMaskBytes
                                                 + ub_reduceSumWorkLocalBytes ].template ReinterpretCast<float>();
        ub_reduceSumResultDstBuf[i] = ub_tensor[ i * ub_reduceSumBufBytes
                                                 + ub_reduceSumInputMaskBytes
                                                 + ub_reduceSumMaskBytes
                                                 + ub_reduceSumWorkLocalBytes
                                                 + ub_reduceSumResultSrcBytes ].template ReinterpretCast<int32_t>();
    }
    // Pre-set the pipeline flags so the first WaitFlag in the loop does not block.
    for(uint32_t i = 0; i < ub_reduceSumNum; i++){
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>((event_t)(i));
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>((event_t)(i));
    }
    uint32_t loopSumInner = CeilDiv<uint32_t>(zeroPaddingM, ub_reduceSumMaskSize);
    uint32_t loopSumTotal = batchCount;
    uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
    uint32_t aivIdx = AscendC::GetBlockIdx();
    AscendC::GlobalTensor<maskType> gm_maskA;
    AscendC::GlobalTensor<int32_t> gm_validM;
    uint32_t curAicoreTaskIdx = 0;
    uint32_t batchIdx = -1;      // unsigned wrap is intentional: ++ at loop start yields 0
    uint32_t batchTaskIdx = -1;
    uint32_t batchTaskLen = 0;
    uint8_t flagId = -1;
    for(uint32_t loopIdx = 0; loopIdx < loopSumTotal; loopIdx++ ){
        batchIdx++;
        gm_maskA.SetGlobalBuffer( ( __gm__ maskType * )(d_maskA + batchIdx * zeroPaddingM) );
        gm_validM.SetGlobalBuffer( ( __gm__ int32_t * )(d_validM + batchIdx) );
        // Round-robin batch assignment across AIV cores.
        if(loopIdx % aivNum != aivIdx){
            continue;
        }
        for(uint32_t loopIdxInner = 0; loopIdxInner < loopSumInner; loopIdxInner++){
            batchTaskIdx = loopIdxInner;
            batchTaskLen = (batchTaskIdx == loopSumInner - 1
                            ? zeroPaddingM - batchTaskIdx * ub_reduceSumMaskSize
                            : ub_reduceSumMaskSize);
            flagId = curAicoreTaskIdx % ub_reduceSumNum;
            // GM -> UB: load one chunk of this batch's mask row.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>((event_t)(flagId));
            Gm2Ub<maskType>(
                ub_reduceSumInputMaskBuf[flagId],
                gm_maskA[batchTaskIdx * ub_reduceSumMaskSize],
                1,
                batchTaskLen,
                0,
                batchTaskLen,
                0,
                0
            );
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>((event_t)(flagId));
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>((event_t)(flagId));
            AscendC::PipeBarrier<PIPE_V>();
            // Widen the mask to float so it can be fed to ReduceSum.
            AscendC::Cast<float, maskType>(
                ub_reduceSumMaskBuf[flagId],
                ub_reduceSumInputMaskBuf[flagId],
                AscendC::RoundMode::CAST_NONE,
                batchTaskLen
            );
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>((event_t)(flagId));
            AscendC::PipeBarrier<PIPE_V>();
            // Sum of mask entries == number of valid positions in this chunk.
            AscendC::ReduceSum<float>(
                ub_reduceSumResultSrcBuf[flagId],
                ub_reduceSumMaskBuf[flagId],
                ub_reduceSumWorkLocalBuf[flagId],
                batchTaskLen
            );
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>((event_t)(flagId));
            AscendC::Cast<int32_t, float>(
                ub_reduceSumResultDstBuf[flagId],
                ub_reduceSumResultSrcBuf[flagId],
                AscendC::RoundMode::CAST_RINT,
                1
            );
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>((event_t)(flagId));
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>((event_t)(flagId));
            // Chunks after the first accumulate into the same GM slot atomically.
            if(batchTaskIdx != 0) AscendC::SetAtomicAdd<int32_t>();
            Ub2Gm<int32_t>(
                gm_validM,
                ub_reduceSumResultDstBuf[flagId],
                1,
                1,
                1,
                0
            );
            if(batchTaskIdx != 0) AscendC::SetAtomicNone();
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>((event_t)(flagId));
            curAicoreTaskIdx++;
        }
    }
    for(uint32_t i = 0; i < ub_reduceSumNum; i++){
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>((event_t)(i));
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>((event_t)(i));
    }
}
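/*
 * Call sketch (hypothetical, assuming a 0/1 mask where valid positions are 1):
 *
 *   BatchCalValidMWithMask<half>(zeroPaddingM, batchCount, d_validM, d_mask);
 *
 * Each batch row of d_mask is summed chunk by chunk; chunks after the first
 * are accumulated into d_validM[b] with an int32 atomic add, so d_validM[b]
 * ends up holding the number of nonzero mask entries of batch b.
 */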
/**
* @brief: Compute the valid text length of each batch without a mask matrix
*         (binary search for the last nonzero row of each input matrix).
* @param [in] layout: row-/column-major storage of the input matrices
* @param [in] zeroPaddingM: text length after zero padding
* @param [in] zeroPaddingK: K dimension of the matrix
* @param [in] batchCount: number of batches
* @param [out] d_validM: valid text length of each batch
* @param [in] d_A_pointer: array of matrix pointers
*/
[aicore] inline __attribute__((always_inline)) void BatchCalValidMNoMask(
    layoutType layout,
    uint32_t zeroPaddingM,
    uint32_t zeroPaddingK,
    uint32_t batchCount,
    __gm__ uint32_t * d_validM,
    __gm__ half ** d_A_pointer
){
    AscendC::TBuf<AscendC::TPosition::VECIN> ub_buf;
    AscendC::TPipe ub_pipe;
    ub_pipe.InitBuffer(ub_buf, UB_BYTES);
    AscendC::LocalTensor<uint8_t> ub_tensor = ub_buf.Get<uint8_t>();
    ub_pipe.Destroy();
    static constexpr uint64_t ub_bufNum = 1;
    // One probe buffer (256 B of half), one compare-bitmask buffer, and one
    // single-element result buffer, laid out back to back in UB.
    uint64_t ub_midTensorBytes = 256;
    AscendC::LocalTensor<half> ub_midTensorBuf[ub_bufNum];
    for(uint64_t i = 0; i < ub_bufNum; i++){
        ub_midTensorBuf[i] = ub_tensor[ i * ub_midTensorBytes ].template ReinterpretCast<half>();
    }
    uint64_t ub_cmpResultTensorBytes = ub_midTensorBytes / sizeof(half) / 8;
    AscendC::LocalTensor<uint8_t> ub_cmpResultTensorBuf[ub_bufNum];
    for(uint64_t i = 0; i < ub_bufNum; i++){
        ub_cmpResultTensorBuf[i] = ub_tensor[ ub_bufNum * ub_midTensorBytes
                                              + i * ub_cmpResultTensorBytes ].template ReinterpretCast<uint8_t>();
    }
    uint64_t ub_validMTensorBytes = sizeof(uint32_t);
    AscendC::LocalTensor<uint32_t> ub_validMTensorBuf[ub_bufNum];
    for(uint64_t i = 0; i < ub_bufNum; i++){
        ub_validMTensorBuf[i] = ub_tensor[ ub_bufNum * ub_midTensorBytes
                                           + ub_bufNum * ub_cmpResultTensorBytes
                                           + i * ub_validMTensorBytes ].template ReinterpretCast<uint32_t>();
    }
    AscendC::GlobalTensor<half> gm_matrix;
    AscendC::GlobalTensor<uint32_t> gm_validM;
    // Event-id scheme: V_MTE2 uses ids [0, ub_bufNum), S_V uses
    // [ub_bufNum, 2*ub_bufNum), MTE3_S uses [2*ub_bufNum, 3*ub_bufNum).
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>((event_t)(i % ub_bufNum));
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::SetFlag<AscendC::HardEvent::S_V>((event_t)(ub_bufNum + i % ub_bufNum));
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::SetFlag<AscendC::HardEvent::MTE3_S>((event_t)(ub_bufNum + ub_bufNum + i % ub_bufNum));
    }
    uint32_t curAicoreBatchIdx = 0;
    for(uint32_t batchIdx = 0; batchIdx < batchCount; batchIdx++){
        if(batchIdx % (AscendC::GetBlockNum() * AscendC::GetSubBlockNum()) != AscendC::GetBlockIdx()){
            continue;
        }
        gm_matrix.SetGlobalBuffer( (__gm__ half * )d_A_pointer[batchIdx] );
        gm_validM.SetGlobalBuffer( (__gm__ uint32_t * )(d_validM) );
        // Binary search over rows; invariant: row `left` is nonzero, row
        // `right` is zero (or one past the end). Row 0 is assumed valid.
        int64_t left = 0, right = zeroPaddingM;
        while(right - left > 1){
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>((event_t)(curAicoreBatchIdx % ub_bufNum));
            AscendC::WaitFlag<AscendC::HardEvent::S_V>((event_t)(ub_bufNum + curAicoreBatchIdx % ub_bufNum));
            int64_t mid = left + (right - left) / 2;
            // Probe the first element of row `mid` (RowMajor) or column `mid`.
            Gm2Ub<half>(
                ub_midTensorBuf[curAicoreBatchIdx % ub_bufNum],
                gm_matrix[mid * ( layout == RowMajor ? zeroPaddingK : 1 )],
                1,
                1,
                0,
                1,
                0,
                0
            );
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>((event_t)(curAicoreBatchIdx % ub_bufNum));
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>((event_t)(curAicoreBatchIdx % ub_bufNum));
            AscendC::CompareScalar<half, uint8_t>(
                ub_cmpResultTensorBuf[curAicoreBatchIdx % ub_bufNum],
                ub_midTensorBuf[curAicoreBatchIdx % ub_bufNum],
                0.0f,
                AscendC::CMPMODE::NE,
                ub_midTensorBytes / sizeof(half)
            );
            AscendC::SetFlag<AscendC::HardEvent::V_S>((event_t)(ub_bufNum + curAicoreBatchIdx % ub_bufNum));
            AscendC::WaitFlag<AscendC::HardEvent::V_S>((event_t)(ub_bufNum + curAicoreBatchIdx % ub_bufNum));
            auto temp = ub_cmpResultTensorBuf[curAicoreBatchIdx % ub_bufNum].GetValue(0);
            AscendC::PipeBarrier<PIPE_S>();
            // Bit 0 of the compare bitmask is set iff the probed element != 0.
            int64_t firstOneIdx = AscendC::ScalarGetSFFValue<1>(temp);
            bool midNE0 = (firstOneIdx == 0);
            if(midNE0){
                left = mid;
            }else{
                right = mid;
            }
            AscendC::PipeBarrier<PIPE_S>();
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>((event_t)(curAicoreBatchIdx % ub_bufNum));
            AscendC::SetFlag<AscendC::HardEvent::S_V>((event_t)(ub_bufNum + curAicoreBatchIdx % ub_bufNum));
        }
        // Valid length = index of the last nonzero row + 1.
        int64_t validM = left + 1;
        ub_validMTensorBuf[curAicoreBatchIdx % ub_bufNum].SetValue(0, validM);
        AscendC::SetFlag<AscendC::HardEvent::S_MTE3>((event_t)(ub_bufNum + ub_bufNum + curAicoreBatchIdx % ub_bufNum));
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_S>((event_t)(ub_bufNum + ub_bufNum + curAicoreBatchIdx % ub_bufNum));
        AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>((event_t)(ub_bufNum + ub_bufNum + curAicoreBatchIdx % ub_bufNum));
        Ub2Gm<uint32_t>(
            gm_validM[batchIdx],
            ub_validMTensorBuf[curAicoreBatchIdx % ub_bufNum],
            1,
            1,
            1,
            0
        );
        AscendC::SetFlag<AscendC::HardEvent::MTE3_S>((event_t)(ub_bufNum + ub_bufNum + curAicoreBatchIdx % ub_bufNum));
        curAicoreBatchIdx++;
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>((event_t)(i % ub_bufNum));
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::WaitFlag<AscendC::HardEvent::S_V>((event_t)(ub_bufNum + i % ub_bufNum));
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_S>((event_t)(ub_bufNum + ub_bufNum + i % ub_bufNum));
    }
}
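/*
 * Worked example of the row search (hypothetical): with zeroPaddingM = 8 and
 * rows 0..4 nonzero, rows 5..7 zero padding, the probes are mid = 4 (nonzero,
 * left = 4), mid = 6 (zero, right = 6), mid = 5 (zero, right = 5); the loop
 * stops at right - left == 1 and validM = left + 1 = 5. Note that only the
 * first element of each row (or column) is probed, so a valid row whose first
 * element happens to be exactly 0 would be misclassified.
 */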
/**
* @brief: Copy an array of valid lengths.
* @param [in] batchCount: array length
* @param [out] d_dstValidLen: destination array
* @param [in] d_srcValidLen: source array
*/
[aicore] inline __attribute__((always_inline)) void CopyValidLen(
    uint32_t batchCount,
    __gm__ uint32_t * d_dstValidLen,
    __gm__ uint32_t * d_srcValidLen
){
    AscendC::TBuf<AscendC::TPosition::VECIN> ub_buf;
    AscendC::TPipe ub_pipe;
    ub_pipe.InitBuffer(ub_buf, UB_BYTES);
    AscendC::LocalTensor<uint8_t> ub_tensor = ub_buf.Get<uint8_t>();
    ub_pipe.Destroy();
    static constexpr uint8_t ub_bufNum = 2;
    uint64_t ub_bufBytes = UB_BYTES / ub_bufNum;
    uint32_t ub_bufSize = ub_bufBytes / sizeof(uint32_t);
    AscendC::LocalTensor<uint32_t> ub_bufArr[ub_bufNum];
    for(uint32_t i = 0; i < ub_bufNum; i++){
        ub_bufArr[i] = ub_tensor[ i * ub_bufBytes ].template ReinterpretCast<uint32_t>();
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(i));
    }
    uint32_t loopSum = CeilDiv<uint32_t>(batchCount, ub_bufSize);
    uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
    uint32_t aivIdx = AscendC::GetBlockIdx();
    AscendC::GlobalTensor<uint32_t> gm_srcValidLen;
    AscendC::GlobalTensor<uint32_t> gm_dstValidLen;
    uint32_t curLen = 0;
    uint8_t flagId = -1;
    uint32_t curAicoreTaskIdx = 0;
    for(uint32_t loopIdx = 0; loopIdx < loopSum; loopIdx++){
        if(loopIdx % aivNum != aivIdx){
            continue;
        }
        gm_srcValidLen.SetGlobalBuffer( (__gm__ uint32_t *)(d_srcValidLen + loopIdx * ub_bufSize) );
        gm_dstValidLen.SetGlobalBuffer( (__gm__ uint32_t *)(d_dstValidLen + loopIdx * ub_bufSize) );
        curLen = (loopIdx == loopSum - 1 ? batchCount - loopIdx * ub_bufSize : ub_bufSize);
        flagId = curAicoreTaskIdx % ub_bufNum;
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(flagId));
        Gm2Ub<uint32_t>(
            ub_bufArr[flagId],
            gm_srcValidLen,
            1,
            curLen,
            0,
            curLen,
            0,
            0
        );
        AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>((event_t)(flagId));
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>((event_t)(flagId));
        Ub2Gm<uint32_t>(
            gm_dstValidLen,
            ub_bufArr[flagId],
            1,
            curLen,
            curLen,
            0
        );
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(flagId));
        curAicoreTaskIdx++;
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>((event_t)(i));
    }
}
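/*
 * Usage sketch (hypothetical): CopyValidLen(batchCount, d_backup, d_validM)
 * snapshots d_validM into d_backup. Chunks of the array are distributed
 * round-robin across AIV cores and ping-ponged through the two UB buffers.
 */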
/**
* @brief: Fill an array of valid lengths with a constant value.
* @param [in] batchCount: array length
* @param [out] d_validLen: destination array
* @param [in] validLen: fill value
*/
[aicore] inline __attribute__((always_inline)) void PadValidLen(
    uint32_t batchCount,
    __gm__ uint32_t * d_validLen,
    uint32_t validLen
){
    AscendC::TBuf<AscendC::TPosition::VECIN> ub_buf;
    AscendC::TPipe ub_pipe;
    ub_pipe.InitBuffer(ub_buf, UB_BYTES);
    AscendC::LocalTensor<uint8_t> ub_tensor = ub_buf.Get<uint8_t>();
    ub_pipe.Destroy();
    static constexpr uint8_t ub_bufNum = 2;
    uint64_t ub_bufBytes = UB_BYTES / ub_bufNum;
    uint32_t ub_bufSize = ub_bufBytes / sizeof(uint32_t);
    AscendC::LocalTensor<uint32_t> ub_bufArr[ub_bufNum];
    for(uint32_t i = 0; i < ub_bufNum; i++){
        ub_bufArr[i] = ub_tensor[ i * ub_bufBytes ].template ReinterpretCast<uint32_t>();
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>((event_t)(i));
    }
    uint32_t loopSum = CeilDiv<uint32_t>(batchCount, ub_bufSize);
    uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
    uint32_t aivIdx = AscendC::GetBlockIdx();
    AscendC::GlobalTensor<uint32_t> gm_validLen;
    uint32_t curLen = 0;
    uint8_t flagId = -1;
    uint32_t curAicoreTaskIdx = 0;
    for(uint32_t loopIdx = 0; loopIdx < loopSum; loopIdx++){
        if(loopIdx % aivNum != aivIdx){
            continue;
        }
        gm_validLen.SetGlobalBuffer( (__gm__ uint32_t *)(d_validLen + loopIdx * ub_bufSize) );
        curLen = (loopIdx == loopSum - 1 ? batchCount - loopIdx * ub_bufSize : ub_bufSize);
        flagId = curAicoreTaskIdx % ub_bufNum;
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>((event_t)(flagId));
        AscendC::Duplicate<uint32_t>(ub_bufArr[flagId], validLen, curLen);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE3>((event_t)(flagId));
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>((event_t)(flagId));
        Ub2Gm<uint32_t>(
            gm_validLen,
            ub_bufArr[flagId],
            1,
            curLen,
            curLen,
            0
        );
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>((event_t)(flagId));
        curAicoreTaskIdx++;
    }
    for(uint32_t i = 0; i < ub_bufNum; i++){
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>((event_t)(i));
    }
}
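/*
 * Usage sketch (hypothetical): for batches that are all the same length,
 * PadValidLen(batchCount, d_validM, zeroPaddingM) fills every entry of
 * d_validM with zeroPaddingM (Duplicate into UB, then a UB -> GM copy),
 * so no separate mask pass or host-side fill is needed.
 */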