18 Star 46 Fork 691

Ascend/op-plugin

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
FlashAttentionKernelNpuOpApi.cpp 87.80 KB
一键复制 编辑 原始数据 按行查看 历史
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833
// Copyright (c) 2023 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Dimension indices for reading tensor shapes.
const int DIM_0 = 0;
const int DIM_1 = 1;
const int DIM_2 = 2;
const int DIM_3 = 3;
// Readable aliases for .size(2) / .size(3) accesses in the layout handling below.
const int THIRD_ELEMENT = 2;
const int FORTH_ELEMENT = 3;
// Accepted ranks for query/key/value/dy tensors.
const int DIMENSION_3D = 3;
const int DIMENSION_4D = 4;
// Size of the char buffer that carries the input_layout string into aclnn calls.
const int LAYOUT_MAX_LENGTH = 20;
// Tolerance when comparing keep_prob against exactly 0.0 / 1.0 (get_dropout_status).
const double EPSILON = 1e-9;
// Extra bytes added to the dropout-mask length (see dropout_gen_mask).
const int64_t LENGTH_BIAS = 32;
const static int FLASH_THRESHOLD = 512;  // not referenced in this part of the file
// Last-dimension size of the softmax_max / softmax_sum output tensors.
const static int64_t SOFTMAXMAX_LAST_DIMSHAPE = 8;
// NOTE(review): the three constants below are not referenced in this chunk;
// presumably used by prompt-flash-attention code further down — confirm.
const static int64_t PFA_SPARSE_HIGH_PRECISION_NO_MASK = 10;
const static int64_t PFA_SPARSE_HIGH_PRECISION_BAND = 14;
const static int64_t MAX_SEQUENCE_LENGTH = 1000000;
using namespace at_npu::native;
using npu_preparation = at_npu::native::OpPreparation;
// Dropout regime derived from keep_prob by get_dropout_status().
enum class DropOutStatus {
    DROPOUT_NORMAL = 0,  // 0 < keep_prob < 1: a random mask must be generated
    DROPOUT_NONE,        // keep_prob == 1: dropout disabled, no mask needed
    DROPOUT_ALL          // keep_prob == 0: everything dropped, all-zero mask
};
// Sparse attention-mask modes. The numeric values (0..8) are what the
// sparse_mode range checks below compare against: [0,6] for fixed layouts,
// [0,5) union (5,8] for the TND layout (PREFIX itself is excluded for TND).
enum class SparseMode {
    NO_MASK = 0,
    ALL_MASK,
    LEFT_UP_CAUSAL,
    RIGHT_DOWN_CAUSAL,
    BAND,
    PREFIX,
    PREFIX_COMPRESS,
    RIGHT_DOWN_CAUSAL_BAND,
    BAND_LEFT_UP_CAUSAL
};
namespace {
// Classify keep_prob into one of the three dropout regimes.
// keep_prob == 0.0 (within EPSILON) -> DROPOUT_ALL (every element dropped);
// keep_prob == 1.0 (within EPSILON) -> DROPOUT_NONE (dropout disabled);
// anything in between            -> DROPOUT_NORMAL (a mask must be generated).
DropOutStatus get_dropout_status(double keep_prob)
{
    if (std::abs(keep_prob - 0.0) < EPSILON) {
        return DropOutStatus::DROPOUT_ALL;
    }
    return (std::abs(keep_prob - 1.0) < EPSILON) ? DropOutStatus::DROPOUT_NONE
                                                 : DropOutStatus::DROPOUT_NORMAL;
}
// Cast a defined tensor to the base ND format; undefined tensors pass through
// untouched. Rejects tensors that are not on an NPU device.
at::Tensor format_trans(const at::Tensor &at_tensor)
{
    if (!at_tensor.defined()) {
        return at_tensor;
    }
    TORCH_CHECK(torch_npu::utils::is_npu(at_tensor),
        "Expected all tensors to be on the same device. "
        "Expected NPU tensor, please check whether the input tensor device is correct.",
        OPS_ERROR(ErrCode::TYPE));
    return custom_ops::npu_format_cast(at_tensor, ACL_FORMAT_ND);
}
// Legacy (aclop) fallback that fills `mask` in place via the
// StatelessDropOutGenMask operator. Invoked from dropout_gen_mask_impl when
// the aclnnDropoutGenMaskV2 API is unavailable (DO_COMPATIBILITY path).
// `mask` must already be allocated by the caller; returns a reference to it.
at::Tensor& stateless_dropout_gen_mask_aclop(const at::Tensor &query, double keep_prob, int64_t seed,
    const int64_t offset, const int64_t numels, at::Tensor& mask)
{
    // The first philox counter word is always zero; `offset` carries the
    // real per-call counter.
    at::SmallVector<int64_t, ::N> offsetList = {0, offset};
    const int64_t seed1 = 0;
    at_npu::native::OpCommand cmd;
    cmd.Name("StatelessDropOutGenMask")
        .Input(at::IntArrayRef{numels})
        .Input(at::Scalar(keep_prob), query.scalar_type(), at_npu::native::CompileType::MEMORY_HOST_COMPILE_DEPENDENT)
        .Input(at::Scalar(seed), at::ScalarType::Int)
        .Input(at::Scalar(seed1), at::ScalarType::Int)
        .Input(offsetList, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
        .Output(mask)
        .Run();
    return mask;
}
// Generate a packed (1 bit per element) dropout keep-mask for `numels`
// elements, returned as a uint8 tensor. Prefers the aclnnDropoutGenMaskV2
// API and falls back to the legacy aclop operator when it is unavailable.
at::Tensor dropout_gen_mask_impl(const at::Tensor &query, double keep_prob, int64_t seed,
    const int64_t offset, const int64_t numels)
{
    // Round numels up to a multiple of 128 bits, then convert bits to bytes.
    int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
    c10::TensorOptions options = query.options();
    at::Tensor mask = OpPreparation::apply_tensor_without_format(at::IntArrayRef{length}, options.dtype(at::kByte));
    c10::SmallVector<int64_t, SIZE> shapeSize = {numels};
    at::IntArrayRef shapeArray = at::IntArrayRef(shapeSize);
    double prob;
    at::Scalar probScalar;
    // Drop probability is 1 - keep_prob, computed in the query's dtype so the
    // rounding matches the kernel's precision, then widened back to double.
    if (query.scalar_type() == at::kHalf) {
        probScalar = at::Scalar(at::Half(1.0)- at::Half(keep_prob));
    } else if (query.scalar_type() == at::kBFloat16) {
        probScalar = at::Scalar(at::BFloat16(1.0)- at::BFloat16(keep_prob));
    } else {
        probScalar = at::Scalar(float(1.0) - float(keep_prob));
    }
    prob = probScalar.toDouble();
    aclDataType probDataType = at_npu::native::OpPreparation::convert_to_acl_data_type(query.scalar_type());
    // If aclnnDropoutGenMaskV2 is missing, run the legacy aclop path and
    // return from here (DO_COMPATIBILITY expands to an early return).
    DO_COMPATIBILITY(aclnnDropoutGenMaskV2,
        stateless_dropout_gen_mask_aclop(query, keep_prob, seed, offset, numels, mask));
    EXEC_NPU_CMD(aclnnDropoutGenMaskV2, shapeArray, prob, seed, offset, probDataType, mask);
    return mask;
}
// Dispatch mask generation either on the current stream, or — when
// gen_mask_parallel is true — on the secondary stream so it can overlap with
// other work. With sync=true the original stream is synchronized before the
// guard is released.
at::Tensor dropout_gen_mask_dispatch(const at::Tensor &query, double keep_prob, int64_t seed,
    const int64_t offset, const int64_t numels, const bool gen_mask_parallel, const bool sync)
{
    at::Tensor mask;
    if (gen_mask_parallel) {
        auto original_stream = c10_npu::getCurrentNPUStream();
        {
            // During the life cycle of this raii instance, the calcu stream is set as the
            // secondary stream, and tasks are distributed to the secondary stream. At the
            // same time, according to the one-stream-one-pool principle, memory is also
            // alloced from the pool of the secondary stream.
            c10_npu::SecondaryStreamGuard guard(c10_npu::getCurrentSecondaryStream());
            mask = dropout_gen_mask_impl(query, keep_prob, seed, offset, numels);
            if (sync) {
                OPS_CHECK_ERROR(c10_npu::acl::AclrtSynchronizeStreamWithTimeout(original_stream));
            }
        }
    } else {
        mask = dropout_gen_mask_impl(query, keep_prob, seed, offset, numels);
    }
    return mask;
}
} // namespace _
#if VERSION_BETWEEN(V1R11, V1R11)
// Compute the dropout-mask element count for the given layout (written to the
// out-param `numels` as B*N*S_q*S_kv) and generate the mask when dropout is
// active. `seed` and `offset` are written only on the DROPOUT_NORMAL path;
// callers must not read them otherwise. Returns an undefined tensor when
// keep_prob == 1 (no dropout needed). Note: for layouts not listed here
// (e.g. TND) `numels` is left as the caller set it.
at::Tensor dropout_gen_mask(const at::Tensor &query, const at::Tensor &key, double keep_prob, int64_t head_num,
    std::string input_layout, bool gen_mask_parallel, bool sync, int64_t &seed, int64_t &offset, int64_t &numels)
{
    at::Tensor drop_mask;
    if (input_layout == "BSH") {
        numels = query.size(0) * head_num * query.size(1) * key.size(1); // [B,N,S,S]
    } else if (input_layout == "SBH") {
        numels = query.size(1) * head_num * query.size(0) * key.size(0); // [B,N,S,S]
    } else if (input_layout == "BNSD") {
        numels = query.size(0) * query.size(1) * query.size(THIRD_ELEMENT) * key.size(THIRD_ELEMENT); // [B,N,S,S]
    } else if (input_layout == "BSND") {
        numels = query.size(0) * query.size(THIRD_ELEMENT) * query.size(1) * key.size(1); // [B,N,S,S]
    }
    // Packed 1-bit mask length in bytes, rounded up to 128-bit granularity,
    // plus LENGTH_BIAS bytes of padding.
    int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
    length += LENGTH_BIAS;
    if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_NORMAL) {
        // Draw a fresh philox (seed, offset) pair from the default NPU
        // generator so the mask can be re-generated in the backward pass.
        const auto gen = at_npu::detail::getDefaultNPUGenerator();
        auto pair = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_engine_inputs(10);
        seed = static_cast<int64_t>(pair.first);
        offset = static_cast<int64_t>(pair.second);
        drop_mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_ALL) {
        // keep_prob == 0: every element is dropped, use an all-zero mask.
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }
    return drop_mask;
}
// Flash-attention backward pass: computes gradients for query, key, value and
// (when supplied) the positional bias pse. Dispatches to the variable-length
// (unpadding) aclnn kernel when actual sequence lengths are provided,
// otherwise to the fixed-shape kernel.
// Returns (dq, dk, dv, dpse); dpse is an undefined tensor when pse was not
// supplied.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_backward(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    const std::string input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &drop_mask,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    c10::optional<at::IntArrayRef> prefix,
    c10::optional<at::IntArrayRef> actual_seq_qlen,
    c10::optional<at::IntArrayRef> actual_seq_kvlen,
    int64_t sparse_mode)
{
    // Materialize optionals as concrete (possibly undefined) tensors/arrays.
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &drop_mask_const = drop_mask.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    const at::Tensor &softmax_max_const = softmax_max.value_or(at::Tensor());
    const at::Tensor &softmax_sum_const = softmax_sum.value_or(at::Tensor());
    const at::Tensor &softmax_const = softmax_in.value_or(at::Tensor());
    const at::Tensor &attention_const = attention_in.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    // All inputs must be in base ND format for the aclnn kernels.
    at::Tensor format_query = format_trans(query);
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_dy = format_trans(dy);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_drop_mask = format_trans(drop_mask_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    at::Tensor format_softmax_max = format_trans(softmax_max_const);
    at::Tensor format_softmax_sum = format_trans(softmax_sum_const);
    at::Tensor format_softmax = format_trans(softmax_const);
    at::Tensor format_attention = format_trans(attention_const);
    // Output gradients mirror the shapes of their corresponding inputs.
    at::Tensor dq = OpPreparation::apply_tensor_without_format(format_query);
    at::Tensor dk = OpPreparation::apply_tensor_without_format(format_key);
    at::Tensor dv = OpPreparation::apply_tensor_without_format(format_value);
    at::Tensor dpse;
    if (format_pse.defined()) {
        dpse = OpPreparation::apply_tensor_without_format(format_pse);
    } else {
        dpse = at::empty({0}, query.options());
    }
    // Zero-initialize so the buffer is NUL-terminated even if input_layout is
    // LAYOUT_MAX_LENGTH - 1 characters or longer (strncpy does not terminate
    // in that case).
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND/unpadded) path.
        EXEC_NPU_CMD(
            aclnnFlashAttentionUnpaddingScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, ac_seq_qlen, ac_seq_kvlen,
            scale_value, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, dq, dk, dv, dpse);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, scale_value, keep_prob, pre_tockens,
            next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, dq, dk, dv, dpse);
    }
    FLOP_COUNT(FlopCounter::flash_attention_backward_flop, query, key, value, dy, head_num, input_layout,
        c10::OptionalIntArrayRef(actual_seq_qlen), c10::OptionalIntArrayRef(actual_seq_kvlen));
    if (!format_pse.defined()) {
        // pse was not supplied: return an undefined tensor instead of the
        // scratch dpse buffer allocated above.
        dpse = at::Tensor();
    }
    return std::make_tuple(dq, dk, dv, dpse);
}
// Host-side entry for the flash-attention backward pass: validates arguments,
// regenerates the dropout mask from the (seed, offset, numels) triple recorded
// by the forward pass, and delegates to npu_flash_attention_backward.
// Returns (dq, dk, dv, dpse).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_grad(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    int64_t seed,
    int64_t offset,
    int64_t numels,
    c10::optional<at::IntArrayRef> prefix,
    c10::optional<at::IntArrayRef> actual_seq_qlen,
    c10::optional<at::IntArrayRef> actual_seq_kvlen,
    int64_t sparse_mode,
    bool gen_mask_parallel,
    bool sync)
{
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(dy.dim() == DIMENSION_3D || dy.dim() == DIMENSION_4D,
        "The shapes of the input dy should be 3 or 4 dimensional, but got ",
        dy.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ",
        keep_prob, OPS_ERROR(ErrCode::PARAM));
    std::string input_layout_str = std::string(input_layout);
    // NOTE(review): this TND comparison happens before the uppercasing loop
    // below, so a lowercase "tnd" falls into the else branch — confirm this is
    // intentional.
    if (input_layout_str == "TND") {
        // TND allows sparse modes [0,5) and (5,8] (PREFIX excluded).
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    // Layout matching is case-insensitive: normalize to upper case.
    for (auto &c : input_layout_str) {
        c = toupper(c);
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" ||
        input_layout_str == "BNSD" || input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    // Packed 1-bit mask length in bytes (128-bit granularity) plus padding.
    int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
    length += LENGTH_BIAS;
    at::Tensor drop_mask;
    if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_NORMAL) {
        // Re-generate the forward pass's mask deterministically from its
        // recorded seed/offset.
        drop_mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_ALL) {
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }
    auto result = npu_flash_attention_backward(query,
        key, value, dy, head_num, input_layout_str, pse, drop_mask, padding_mask, atten_mask,
        softmax_max, softmax_sum, softmax_in, attention_in, scale_value, keep_prob, pre_tockens,
        next_tockens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode);
    if (!sync) {
        // Make the secondary stream wait on the main stream so the
        // asynchronously generated mask is not freed/overwritten early.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return result;
}
// Flash-attention forward pass. Validates arguments, derives the dropout-mask
// size for the chosen layout, generates the mask, and dispatches to the
// fixed-shape or variable-length aclnn kernel.
// Returns (attention_score, softmax_max, softmax_sum, softmax_out,
// seed, offset, numels); the last three let the backward pass re-generate the
// dropout mask deterministically.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t> npu_flash_attention(
    const at::Tensor &query, const at::Tensor &key,
    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse, const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask, double scale, double keep_prob,
    int64_t pre_tockens, int64_t next_tockens, int64_t inner_precise,
    c10::optional<at::IntArrayRef> prefix, c10::optional<at::IntArrayRef> actual_seq_qlen,
    c10::optional<at::IntArrayRef> actual_seq_kvlen, int64_t sparse_mode, bool gen_mask_parallel, bool sync)
{
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    TORCH_CHECK(head_num > 0, "head_num must > 0, but got ", head_num, OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ",
        keep_prob, OPS_ERROR(ErrCode::PARAM));
    std::string input_layout_str = std::string(input_layout);
    // Normalize the layout BEFORE the TND-specific validation below.
    // Previously the uppercasing happened after it, so a lowercase "tnd"
    // skipped the actual_seq length requirement yet still entered the TND
    // numels path further down and indexed ac_seq_qlen[0] unchecked.
    for (auto &c : input_layout_str) {
        c = toupper(c);
    }
    if (input_layout_str == "TND") {
        // TND allows sparse modes [0,5) and (5,8] (PREFIX excluded).
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(ac_seq_qlen.size() != 0 && ac_seq_kvlen.size() != 0 && ac_seq_qlen.size() == ac_seq_kvlen.size(),
            "the size of actual_seq_qlen and actual_seq_kvlen must be the same and cannot be empty." +
            OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" ||
        input_layout_str == "BNSD" || input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    int64_t B = 0;
    int64_t S0 = 0; // S for query
    int64_t S1 = 0; // S for key & value
    int64_t N_local = 0; // N for npu_fusion_attention
    int64_t D = 0;
    int64_t H = 0;
    int64_t T = 0;
    int64_t D2 = 0; // D2 for value head-dim
    c10::SmallVector<int64_t> atten_score_shape;
    // Decode the per-layout shape parameters; D2 (value head-dim) sizes the
    // output, which can differ from D when value's hidden dim differs.
    if (input_layout_str == "BSH") {
        B = query.size(0);
        S0 = query.size(1);
        S1 = key.size(1);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        D2 = (D == 0 || !key.size(THIRD_ELEMENT)) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {B, S0, head_num * D2};
    } else if (input_layout_str == "SBH") {
        B = query.size(1);
        S0 = query.size(0);
        S1 = key.size(0);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        D2 = (D == 0 || !key.size(THIRD_ELEMENT)) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {S0, B, head_num * D2};
    } else if (input_layout_str == "BNSD") {
        B = query.size(0);
        N_local = query.size(1);
        S0 = query.size(THIRD_ELEMENT);
        S1 = key.size(THIRD_ELEMENT);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, N_local, S0, D2};
    } else if (input_layout_str == "BSND") {
        B = query.size(0);
        N_local = query.size(THIRD_ELEMENT);
        S0 = query.size(1);
        S1 = key.size(1);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, S0, N_local, D2};
    } else if (input_layout_str == "TND") {
        T = query.size(0);
        N_local = query.size(1);
        D = query.size(THIRD_ELEMENT);
        D2 = value.size(THIRD_ELEMENT);
        atten_score_shape = {T, N_local, D2};
    }
    double scale_value = scale;
    at::Tensor format_query = format_trans(query);
    at::Tensor attention_score = npu_preparation::apply_tensor_without_format(atten_score_shape, query.options());
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    // Zero-initialize: seed/offset are only assigned on the DROPOUT_NORMAL
    // path inside dropout_gen_mask, and numels only for non-TND layouts, yet
    // all three are returned to the caller — previously they could be read
    // uninitialized (UB) when dropout was disabled.
    int64_t seed = 0;
    int64_t offset = 0;
    int64_t numels = 0;
    if (input_layout_str == "TND") {
        // numels = N * sum over batches of (per-batch S_q * per-batch S_kv),
        // reconstructed from the cumulative actual sequence lengths.
        numels = N_local;
        int64_t accum = ac_seq_qlen[0] * ac_seq_kvlen[0];
        for (size_t i = 1; i < ac_seq_qlen.size(); i++) {
            accum += ((ac_seq_qlen[i] - ac_seq_qlen[i - 1]) * (ac_seq_kvlen[i] - ac_seq_kvlen[i - 1]));
        }
        numels *= accum;
    }
    at::Tensor format_drop_mask = dropout_gen_mask(format_query, format_key, keep_prob, head_num, input_layout_str,
        gen_mask_parallel, sync, seed, offset, numels);
    at::Tensor softmax_max;
    at::Tensor softmax_sum;
    at::Tensor softmax_out;
    if (input_layout_str != "TND") {
        softmax_max = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
    } else {
        softmax_max = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
    }
    softmax_out = at::empty({0}, query.options());
    // Zero-initialize so the buffer is NUL-terminated even if the layout
    // string is LAYOUT_MAX_LENGTH - 1 characters or longer (strncpy does not
    // terminate in that case).
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout_str.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND/unpadded) path.
        EXEC_NPU_CMD(
            aclnnFlashAttentionVarLenScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            ac_seq_qlen, ac_seq_kvlen, scale, keep_prob, pre_tockens, next_tockens, head_num,
            input_layout_char, inner_precise, sparse_mode, softmax_max, softmax_sum,
            softmax_out, attention_score);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            scale, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, softmax_max, softmax_sum, softmax_out, attention_score);
    }
    FLOP_COUNT(FlopCounter::flash_attention_forward_flop, query, key, value, head_num, input_layout_str,
        c10::OptionalIntArrayRef(actual_seq_qlen), c10::OptionalIntArrayRef(actual_seq_kvlen));
    if (!sync) {
        // Make the secondary stream wait on the main stream so the
        // asynchronously generated mask stays valid for the kernel.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return std::make_tuple(attention_score, softmax_max, softmax_sum, softmax_out,
        seed, offset, numels);
}
#endif
#if VERSION_BETWEEN(V2R0, V2R0)
// V2R0 copy of dropout_gen_mask (identical to the V1R11 version above):
// computes the dropout-mask element count for the given layout (out-param
// `numels`, B*N*S_q*S_kv) and generates the mask when dropout is active.
// seed/offset are written only on the DROPOUT_NORMAL path; callers must not
// read them otherwise. Returns an undefined tensor when keep_prob == 1.
at::Tensor dropout_gen_mask(const at::Tensor &query, const at::Tensor &key, double keep_prob, int64_t head_num,
    std::string input_layout, bool gen_mask_parallel, bool sync, int64_t &seed, int64_t &offset, int64_t &numels)
{
    at::Tensor drop_mask;
    if (input_layout == "BSH") {
        numels = query.size(0) * head_num * query.size(1) * key.size(1); // [B,N,S,S]
    } else if (input_layout == "SBH") {
        numels = query.size(1) * head_num * query.size(0) * key.size(0); // [B,N,S,S]
    } else if (input_layout == "BNSD") {
        numels = query.size(0) * query.size(1) * query.size(THIRD_ELEMENT) * key.size(THIRD_ELEMENT); // [B,N,S,S]
    } else if (input_layout == "BSND") {
        numels = query.size(0) * query.size(THIRD_ELEMENT) * query.size(1) * key.size(1); // [B,N,S,S]
    }
    // Packed 1-bit mask length in bytes, rounded up to 128-bit granularity,
    // plus LENGTH_BIAS bytes of padding.
    int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
    length += LENGTH_BIAS;
    if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_NORMAL) {
        // Draw a fresh philox (seed, offset) pair so the backward pass can
        // re-generate the same mask.
        const auto gen = at_npu::detail::getDefaultNPUGenerator();
        auto pair = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_engine_inputs(10);
        seed = static_cast<int64_t>(pair.first);
        offset = static_cast<int64_t>(pair.second);
        drop_mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_ALL) {
        // keep_prob == 0: every element is dropped, use an all-zero mask.
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }
    return drop_mask;
}
// Backward of the fused flash-attention kernel.
// Computes gradients w.r.t. query/key/value (and pse, when supplied) from the
// forward pass's softmax statistics. Dispatches to the variable-length
// (unpadding) aclnn kernel when both actual sequence-length lists are present,
// otherwise to the fixed-shape kernel.
// Returns (dq, dk, dv, dpse); dpse is an undefined tensor when pse was absent.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_backward(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    const std::string input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &drop_mask,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    at::OptionalIntArrayRef prefix,
    at::OptionalIntArrayRef actual_seq_qlen,
    at::OptionalIntArrayRef actual_seq_kvlen,
    int64_t sparse_mode)
{
    // Unwrap optionals; an absent input becomes an undefined tensor, which the
    // aclnn kernels treat as "not provided".
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &drop_mask_const = drop_mask.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    const at::Tensor &softmax_max_const = softmax_max.value_or(at::Tensor());
    const at::Tensor &softmax_sum_const = softmax_sum.value_or(at::Tensor());
    const at::Tensor &softmax_const = softmax_in.value_or(at::Tensor());
    const at::Tensor &attention_const = attention_in.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    // Normalize every tensor's NPU storage format before kernel launch.
    at::Tensor format_query = format_trans(query);
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_dy = format_trans(dy);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_drop_mask = format_trans(drop_mask_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    at::Tensor format_softmax_max = format_trans(softmax_max_const);
    at::Tensor format_softmax_sum = format_trans(softmax_sum_const);
    at::Tensor format_softmax = format_trans(softmax_const);
    at::Tensor format_attention = format_trans(attention_const);
    // Gradient outputs mirror the (format-normalized) input shapes.
    at::Tensor dq = OpPreparation::apply_tensor_without_format(format_query);
    at::Tensor dk = OpPreparation::apply_tensor_without_format(format_key);
    at::Tensor dv = OpPreparation::apply_tensor_without_format(format_value);
    at::Tensor dpse;
    if (format_pse.defined()) {
        dpse = OpPreparation::apply_tensor_without_format(format_pse);
    } else {
        // The kernel still needs a valid (empty) output tensor for dpse.
        dpse = at::empty({0}, query.options());
    }
    // Zero-initialize so the buffer is NUL-terminated even if the layout string
    // fills LAYOUT_MAX_LENGTH - 1 bytes (strncpy does not terminate then).
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND/unpadding) gradient kernel.
        EXEC_NPU_CMD(
            aclnnFlashAttentionUnpaddingScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, ac_seq_qlen, ac_seq_kvlen,
            scale_value, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, dq, dk, dv, dpse);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, scale_value, keep_prob,
            pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, dq, dk, dv, dpse);
    }
    FLOP_COUNT(FlopCounter::flash_attention_backward_flop, query, key, value, dy,
               head_num, input_layout, actual_seq_qlen, actual_seq_kvlen);
    if (!format_pse.defined()) {
        // Callers expect an undefined tensor (not the empty placeholder) when
        // pse was not supplied.
        dpse = at::Tensor();
    }
    return std::make_tuple(dq, dk, dv, dpse);
}
// Public entry point for the fusion-attention backward pass.
// Validates arguments, regenerates the dropout mask from the RNG state
// (seed, offset, numels) recorded by the forward pass, then delegates to
// npu_fusion_attention_backward.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_grad(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    int64_t seed,
    int64_t offset,
    int64_t numels,
    at::OptionalIntArrayRef prefix,
    at::OptionalIntArrayRef actual_seq_qlen,
    at::OptionalIntArrayRef actual_seq_kvlen,
    int64_t sparse_mode,
    bool gen_mask_parallel,
    bool sync)
{
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(dy.dim() == DIMENSION_3D || dy.dim() == DIMENSION_4D,
        "The shapes of the input dy should be 3 or 4 dimensional, but got ",
        dy.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ",
        keep_prob, OPS_ERROR(ErrCode::VALUE));
    std::string input_layout_str = std::string(input_layout);
    // Canonicalize to upper case BEFORE the layout-dependent checks below so a
    // lower-case "tnd" is validated with the TND-specific sparse_mode rules
    // (the layout is documented as case-insensitive). Cast to unsigned char:
    // toupper on a negative char is undefined behavior.
    for (auto &c : input_layout_str) {
        c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" || input_layout_str == "BNSD" ||
        input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    at::Tensor drop_mask;
    const auto dropout_status = get_dropout_status(keep_prob);
    if (dropout_status == DropOutStatus::DROPOUT_NORMAL) {
        // Re-create the exact forward-pass mask from the recorded RNG state.
        drop_mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (dropout_status == DropOutStatus::DROPOUT_ALL) {
        // Bit-packed mask: round numels up to a multiple of 128 bits, convert
        // to bytes, and pad by LENGTH_BIAS; all-zero bytes drop everything.
        int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
        length += LENGTH_BIAS;
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }
    auto result = npu_fusion_attention_backward(query,
        key, value, dy, head_num, input_layout_str, pse, drop_mask, padding_mask, atten_mask,
        softmax_max, softmax_sum, softmax_in, attention_in, scale_value, keep_prob, pre_tockens,
        next_tockens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode);
    if (!sync) {
        // Async mask generation ran on the secondary stream; make that stream
        // wait for the compute stream before its buffers are reused.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return result;
}
// Forward fused flash-attention.
// Validates inputs, derives the attention-score output shape from the layout,
// generates the dropout mask, and launches the aclnn kernel. Returns
// (attention_score, softmax_max, softmax_sum, softmax_out, seed, offset,
// numels); the last three let the backward pass regenerate the dropout mask.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t> npu_fusion_attention(
    const at::Tensor &query, const at::Tensor &key,
    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse, const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask, double scale, double keep_prob,
    int64_t pre_tockens, int64_t next_tockens, int64_t inner_precise,
    at::OptionalIntArrayRef prefix, at::OptionalIntArrayRef actual_seq_qlen,
    at::OptionalIntArrayRef actual_seq_kvlen, int64_t sparse_mode, bool gen_mask_parallel, bool sync)
{
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    TORCH_CHECK(head_num > 0, "head_num must > 0, but got ", head_num, OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ",
        keep_prob, OPS_ERROR(ErrCode::PARAM));
    std::string input_layout_str = std::string(input_layout);
    // Canonicalize to upper case BEFORE the TND-dependent validation so a
    // lower-case "tnd" gets the TND-specific sparse_mode/seq-length checks
    // (the layout is documented as case-insensitive). Cast to unsigned char:
    // toupper on a negative char is undefined behavior.
    for (auto &c : input_layout_str) {
        c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(ac_seq_qlen.size() != 0 && ac_seq_kvlen.size() != 0 && ac_seq_qlen.size() == ac_seq_kvlen.size(),
            "the size of actual_seq_qlen and actual_seq_kvlen must be the same and cannot be empty." +
            OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" ||
        input_layout_str == "BNSD" || input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    int64_t B = 0;
    int64_t S0 = 0;      // S for query
    int64_t S1 = 0;      // S for key & value
    int64_t N_local = 0; // N for npu_fusion_attention
    int64_t D = 0;       // query/key head dim
    int64_t H = 0;       // hidden size
    int64_t T = 0;       // total token count (TND only)
    int64_t D2 = 0;      // D2 for value head-dim
    c10::SmallVector<int64_t> atten_score_shape;
    if (input_layout_str == "BSH") {
        B = query.size(0);
        S0 = query.size(1);
        S1 = key.size(1);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        // Guard the divisions below against zero-sized hidden dims.
        D2 = (D == 0 || key.size(THIRD_ELEMENT) == 0) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {B, S0, head_num * D2};
    } else if (input_layout_str == "SBH") {
        B = query.size(1);
        S0 = query.size(0);
        S1 = key.size(0);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        D2 = (D == 0 || key.size(THIRD_ELEMENT) == 0) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {S0, B, head_num * D2};
    } else if (input_layout_str == "BNSD") {
        B = query.size(0);
        N_local = query.size(1);
        S0 = query.size(THIRD_ELEMENT);
        S1 = key.size(THIRD_ELEMENT);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, N_local, S0, D2};
    } else if (input_layout_str == "BSND") {
        B = query.size(0);
        N_local = query.size(THIRD_ELEMENT);
        S0 = query.size(1);
        S1 = key.size(1);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, S0, N_local, D2};
    } else if (input_layout_str == "TND") {
        T = query.size(0);
        N_local = query.size(1);
        D = query.size(THIRD_ELEMENT);
        D2 = value.size(THIRD_ELEMENT);
        atten_score_shape = {T, N_local, D2};
    }
    at::Tensor format_query = format_trans(query);
    at::Tensor attention_score = npu_preparation::apply_tensor_without_format(atten_score_shape, query.options());
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    // Zero-initialize: dropout_gen_mask only writes seed/offset when dropout is
    // actually enabled, yet all three values are returned to the caller.
    int64_t seed = 0;
    int64_t offset = 0;
    int64_t numels = 0;
    if (input_layout_str == "TND") {
        // dropout_gen_mask has no TND branch, so compute numels here as
        // N * sum_i(Sq_i * Skv_i) from the cumulative sequence lengths.
        numels = N_local;
        int64_t accum = ac_seq_qlen[0] * ac_seq_kvlen[0];
        for (size_t i = 1; i < ac_seq_qlen.size(); i++) {
            accum += ((ac_seq_qlen[i] - ac_seq_qlen[i - 1]) * (ac_seq_kvlen[i] - ac_seq_kvlen[i - 1]));
        }
        numels *= accum;
    }
    at::Tensor format_drop_mask = dropout_gen_mask(format_query, format_key, keep_prob, head_num, input_layout_str,
        gen_mask_parallel, sync, seed, offset, numels);
    at::Tensor softmax_max;
    at::Tensor softmax_sum;
    at::Tensor softmax_out;
    if (input_layout_str != "TND") {
        softmax_max = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
    } else {
        softmax_max = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
    }
    softmax_out = at::empty({0}, query.options());
    // Zero-initialize so the buffer is NUL-terminated even if the layout string
    // fills LAYOUT_MAX_LENGTH - 1 bytes (strncpy does not terminate then).
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout_str.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND) kernel.
        EXEC_NPU_CMD(
            aclnnFlashAttentionVarLenScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            ac_seq_qlen, ac_seq_kvlen, scale, keep_prob, pre_tockens, next_tockens, head_num,
            input_layout_char, inner_precise, sparse_mode, softmax_max, softmax_sum,
            softmax_out, attention_score);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            scale, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, softmax_max, softmax_sum, softmax_out, attention_score);
    }
    FLOP_COUNT(FlopCounter::flash_attention_forward_flop, query, key, value,
               head_num, input_layout_str, actual_seq_qlen, actual_seq_kvlen);
    if (!sync) {
        // Async mask generation ran on the secondary stream; make that stream
        // wait for the compute stream before its buffers are reused.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return std::make_tuple(attention_score, softmax_max, softmax_sum, softmax_out,
                           seed, offset, numels);
}
#endif
#if VERSION_BETWEEN(V2R1, V2R1)
// Generates the dropout mask for the flash-attention forward pass (V2R1).
// Writes into the out-params:
//   numels       - logical [B, N, Sq, Skv] element count derived from the
//                  layout; left untouched for layouts not listed here (the
//                  TND caller computes it beforehand).
//   seed, offset - Philox RNG state, written only when dropout is enabled so
//                  the backward pass can regenerate the identical mask.
// Returns an undefined tensor when dropout is a no-op for this keep_prob.
at::Tensor dropout_gen_mask(const at::Tensor &query, const at::Tensor &key, double keep_prob, int64_t head_num,
    std::string input_layout, bool gen_mask_parallel, bool sync, int64_t &seed, int64_t &offset, int64_t &numels)
{
    at::Tensor drop_mask;
    if (input_layout == "BSH") {
        numels = query.size(0) * head_num * query.size(1) * key.size(1); // [B,N,S,S]
    } else if (input_layout == "SBH") {
        numels = query.size(1) * head_num * query.size(0) * key.size(0); // [B,N,S,S]
    } else if (input_layout == "BNSD") {
        numels = query.size(0) * query.size(1) * query.size(THIRD_ELEMENT) * key.size(THIRD_ELEMENT); // [B,N,S,S]
    } else if (input_layout == "BSND") {
        numels = query.size(0) * query.size(THIRD_ELEMENT) * query.size(1) * key.size(1); // [B,N,S,S]
    }
    // Evaluate the dropout status once instead of per-branch.
    const auto dropout_status = get_dropout_status(keep_prob);
    if (dropout_status == DropOutStatus::DROPOUT_NORMAL) {
        // Draw fresh Philox state and record it for the backward pass.
        const auto gen = at_npu::detail::getDefaultNPUGenerator();
        auto pair = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_engine_inputs(10);
        seed = static_cast<int64_t>(pair.first);
        offset = static_cast<int64_t>(pair.second);
        drop_mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (dropout_status == DropOutStatus::DROPOUT_ALL) {
        // Bit-packed mask: round numels up to a multiple of 128 bits, convert
        // to bytes, and pad by LENGTH_BIAS; all-zero bytes drop everything.
        int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
        length += LENGTH_BIAS;
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }
    return drop_mask;
}
// Backward of the fused flash-attention kernel (V2R1 build).
// Computes gradients w.r.t. query/key/value (and pse, when supplied) from the
// forward pass's softmax statistics. Dispatches to the variable-length
// (unpadding) aclnn kernel when both actual sequence-length lists are present,
// otherwise to the fixed-shape kernel.
// Returns (dq, dk, dv, dpse); dpse is an undefined tensor when pse was absent.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_backward(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    const std::string input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &drop_mask,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    c10::OptionalIntArrayRef prefix,
    c10::OptionalIntArrayRef actual_seq_qlen,
    c10::OptionalIntArrayRef actual_seq_kvlen,
    int64_t sparse_mode)
{
    // Unwrap optionals; an absent input becomes an undefined tensor, which the
    // aclnn kernels treat as "not provided".
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &drop_mask_const = drop_mask.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    const at::Tensor &softmax_max_const = softmax_max.value_or(at::Tensor());
    const at::Tensor &softmax_sum_const = softmax_sum.value_or(at::Tensor());
    const at::Tensor &softmax_const = softmax_in.value_or(at::Tensor());
    const at::Tensor &attention_const = attention_in.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    // Normalize every tensor's NPU storage format before kernel launch.
    at::Tensor format_query = format_trans(query);
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_dy = format_trans(dy);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_drop_mask = format_trans(drop_mask_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    at::Tensor format_softmax_max = format_trans(softmax_max_const);
    at::Tensor format_softmax_sum = format_trans(softmax_sum_const);
    at::Tensor format_softmax = format_trans(softmax_const);
    at::Tensor format_attention = format_trans(attention_const);
    // Gradient outputs mirror the (format-normalized) input shapes.
    at::Tensor dq = OpPreparation::apply_tensor_without_format(format_query);
    at::Tensor dk = OpPreparation::apply_tensor_without_format(format_key);
    at::Tensor dv = OpPreparation::apply_tensor_without_format(format_value);
    at::Tensor dpse;
    if (format_pse.defined()) {
        dpse = OpPreparation::apply_tensor_without_format(format_pse);
    } else {
        // The kernel still needs a valid (empty) output tensor for dpse.
        dpse = at::empty({0}, query.options());
    }
    // Zero-initialize so the buffer is NUL-terminated even if the layout string
    // fills LAYOUT_MAX_LENGTH - 1 bytes (strncpy does not terminate then).
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND/unpadding) gradient kernel.
        EXEC_NPU_CMD(
            aclnnFlashAttentionUnpaddingScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, ac_seq_qlen, ac_seq_kvlen,
            scale_value, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, dq, dk, dv, dpse);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, scale_value, keep_prob,
            pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, dq, dk, dv, dpse);
    }
    FLOP_COUNT(FlopCounter::flash_attention_backward_flop, query, key, value, dy,
               head_num, input_layout, actual_seq_qlen, actual_seq_kvlen);
    if (!format_pse.defined()) {
        // Callers expect an undefined tensor (not the empty placeholder) when
        // pse was not supplied.
        dpse = at::Tensor();
    }
    return std::make_tuple(dq, dk, dv, dpse);
}
// Public entry point for the fusion-attention backward pass (V2R1 build).
// Validates arguments, regenerates the dropout mask from the RNG state
// (seed, offset, numels) recorded by the forward pass, then delegates to
// npu_fusion_attention_backward.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_grad(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    int64_t seed,
    int64_t offset,
    int64_t numels,
    c10::OptionalIntArrayRef prefix,
    c10::OptionalIntArrayRef actual_seq_qlen,
    c10::OptionalIntArrayRef actual_seq_kvlen,
    int64_t sparse_mode,
    bool gen_mask_parallel,
    bool sync)
{
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(dy.dim() == DIMENSION_3D || dy.dim() == DIMENSION_4D,
        "The shapes of the input dy should be 3 or 4 dimensional, but got ",
        dy.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ", keep_prob, OPS_ERROR(ErrCode::PARAM));
    std::string input_layout_str = std::string(input_layout);
    // Canonicalize to upper case BEFORE the layout-dependent checks below so a
    // lower-case "tnd" is validated with the TND-specific sparse_mode rules
    // (the layout is documented as case-insensitive). Cast to unsigned char:
    // toupper on a negative char is undefined behavior.
    for (auto &c : input_layout_str) {
        c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" || input_layout_str == "BNSD" ||
        input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    at::Tensor drop_mask;
    const auto dropout_status = get_dropout_status(keep_prob);
    if (dropout_status == DropOutStatus::DROPOUT_NORMAL) {
        // Re-create the exact forward-pass mask from the recorded RNG state.
        drop_mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (dropout_status == DropOutStatus::DROPOUT_ALL) {
        // Bit-packed mask: round numels up to a multiple of 128 bits, convert
        // to bytes, and pad by LENGTH_BIAS; all-zero bytes drop everything.
        int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
        length += LENGTH_BIAS;
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }
    auto result = npu_fusion_attention_backward(query,
        key, value, dy, head_num, input_layout_str, pse, drop_mask, padding_mask, atten_mask,
        softmax_max, softmax_sum, softmax_in, attention_in, scale_value, keep_prob, pre_tockens,
        next_tockens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode);
    if (!sync) {
        // Async mask generation ran on the secondary stream; make that stream
        // wait for the compute stream before its buffers are reused.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return result;
}
// Forward fused flash-attention (V2R1 build).
// Validates inputs, derives the attention-score output shape from the layout,
// generates the dropout mask, and launches the aclnn kernel. Returns
// (attention_score, softmax_max, softmax_sum, softmax_out, seed, offset,
// numels); the last three let the backward pass regenerate the dropout mask.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t> npu_fusion_attention(
    const at::Tensor &query, const at::Tensor &key,
    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse, const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    double scale, double keep_prob, int64_t pre_tockens, int64_t next_tockens, int64_t inner_precise,
    c10::OptionalIntArrayRef prefix, c10::OptionalIntArrayRef actual_seq_qlen,
    c10::OptionalIntArrayRef actual_seq_kvlen, int64_t sparse_mode, bool gen_mask_parallel, bool sync)
{
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    TORCH_CHECK(head_num > 0, "head_num must > 0, but got ", head_num, OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ", keep_prob, OPS_ERROR(ErrCode::PARAM));
    std::string input_layout_str = std::string(input_layout);
    // Canonicalize to upper case BEFORE the TND-dependent validation so a
    // lower-case "tnd" gets the TND-specific sparse_mode/seq-length checks
    // (the layout is documented as case-insensitive). Cast to unsigned char:
    // toupper on a negative char is undefined behavior.
    for (auto &c : input_layout_str) {
        c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(ac_seq_qlen.size() != 0 && ac_seq_kvlen.size() != 0 && ac_seq_qlen.size() == ac_seq_kvlen.size(),
            "the size of actual_seq_qlen and actual_seq_kvlen must be the same and cannot be empty." +
            OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" ||
        input_layout_str == "BNSD" || input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    int64_t B = 0;
    int64_t S0 = 0;      // S for query
    int64_t S1 = 0;      // S for key & value
    int64_t N_local = 0; // N for npu_fusion_attention
    int64_t D = 0;       // query/key head dim
    int64_t H = 0;       // hidden size
    int64_t T = 0;       // total token count (TND only)
    int64_t D2 = 0;      // D2 for value head-dim
    c10::SmallVector<int64_t> atten_score_shape;
    if (input_layout_str == "BSH") {
        B = query.size(0);
        S0 = query.size(1);
        S1 = key.size(1);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        // Guard the divisions below against zero-sized hidden dims.
        D2 = (D == 0 || key.size(THIRD_ELEMENT) == 0) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {B, S0, head_num * D2};
    } else if (input_layout_str == "SBH") {
        B = query.size(1);
        S0 = query.size(0);
        S1 = key.size(0);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        D2 = (D == 0 || key.size(THIRD_ELEMENT) == 0) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {S0, B, head_num * D2};
    } else if (input_layout_str == "BNSD") {
        B = query.size(0);
        N_local = query.size(1);
        S0 = query.size(THIRD_ELEMENT);
        S1 = key.size(THIRD_ELEMENT);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, N_local, S0, D2};
    } else if (input_layout_str == "BSND") {
        B = query.size(0);
        N_local = query.size(THIRD_ELEMENT);
        S0 = query.size(1);
        S1 = key.size(1);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, S0, N_local, D2};
    } else if (input_layout_str == "TND") {
        T = query.size(0);
        N_local = query.size(1);
        D = query.size(THIRD_ELEMENT);
        D2 = value.size(THIRD_ELEMENT);
        atten_score_shape = {T, N_local, D2};
    }
    at::Tensor format_query = format_trans(query);
    at::Tensor attention_score = npu_preparation::apply_tensor_without_format(atten_score_shape, query.options());
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    // Zero-initialize: dropout_gen_mask only writes seed/offset when dropout is
    // actually enabled, yet all three values are returned to the caller.
    int64_t seed = 0;
    int64_t offset = 0;
    int64_t numels = 0;
    if (input_layout_str == "TND") {
        // dropout_gen_mask has no TND branch, so compute numels here as
        // N * sum_i(Sq_i * Skv_i) from the cumulative sequence lengths.
        numels = N_local;
        int64_t accum = ac_seq_qlen[0] * ac_seq_kvlen[0];
        for (size_t i = 1; i < ac_seq_qlen.size(); i++) {
            accum += ((ac_seq_qlen[i] - ac_seq_qlen[i - 1]) * (ac_seq_kvlen[i] - ac_seq_kvlen[i - 1]));
        }
        numels *= accum;
    }
    at::Tensor format_drop_mask = dropout_gen_mask(format_query, format_key, keep_prob, head_num, input_layout_str,
        gen_mask_parallel, sync, seed, offset, numels);
    at::Tensor softmax_max;
    at::Tensor softmax_sum;
    at::Tensor softmax_out;
    if (input_layout_str != "TND") {
        softmax_max = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
    } else {
        softmax_max = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
    }
    softmax_out = at::empty({0}, query.options());
    // Zero-initialize so the buffer is NUL-terminated even if the layout string
    // fills LAYOUT_MAX_LENGTH - 1 bytes (strncpy does not terminate then).
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout_str.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND) kernel.
        EXEC_NPU_CMD(
            aclnnFlashAttentionVarLenScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            ac_seq_qlen, ac_seq_kvlen, scale, keep_prob, pre_tockens, next_tockens, head_num,
            input_layout_char, inner_precise, sparse_mode, softmax_max, softmax_sum,
            softmax_out, attention_score);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            scale, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, softmax_max, softmax_sum, softmax_out, attention_score);
    }
    FLOP_COUNT(FlopCounter::flash_attention_forward_flop, query, key, value, head_num,
               input_layout_str, actual_seq_qlen, actual_seq_kvlen);
    if (!sync) {
        // Async mask generation ran on the secondary stream; make that stream
        // wait for the compute stream before its buffers are reused.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return std::make_tuple(attention_score, softmax_max, softmax_sum, softmax_out,
                           seed, offset, numels);
}
// Prompt flash-attention (inference path).
// Derives the output geometry from the layout, picks the output dtype from the
// quantization inputs, then dispatches to aclnnPromptFlashAttentionV3.
at::Tensor npu_prompt_flash_attention(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &pse_shift,
    c10::OptionalIntArrayRef actual_seq_lengths,
    const c10::optional<at::Tensor> &deq_scale1,
    const c10::optional<at::Tensor> &quant_scale1,
    const c10::optional<at::Tensor> &deq_scale2,
    const c10::optional<at::Tensor> &quant_scale2,
    const c10::optional<at::Tensor> &quant_offset2,
    int64_t num_heads, double scale_value,
    int64_t pre_tokens, int64_t next_tokens,
    c10::string_view input_layout, int64_t num_key_value_heads,
    c10::OptionalIntArrayRef actual_seq_lengths_kv,
    int64_t sparse_mode)
{
    std::string input_layout_str = std::string(input_layout);
    // Shape template for the output; allocate it exactly once per layout
    // (the original code allocated a default tensor and then discarded it for
    // the BNSD_BSND/TND layouts).
    at::Tensor tmp_output;
    if (input_layout_str == "BNSD_BSND") {
        // Output requested in BSND while the query is BNSD: swap dims 1 and 2.
        tmp_output = OpPreparation::apply_tensor_without_format(
            {query.size(DIM_0), query.size(DIM_2), query.size(DIM_1), query.size(DIM_3)},
            query.options().dtype(query.dtype()));
    } else if (input_layout_str == "TND") {
        // Last dim comes from value (may differ from the query head dim).
        tmp_output = OpPreparation::apply_tensor_without_format(
            {query.size(DIM_0), query.size(DIM_1), value.size(DIM_2)},
            query.options().dtype(query.dtype()));
    } else {
        tmp_output = npu_preparation::apply_tensor_without_format(query);
    }
    // Output dtype depends on quantization: quant_scale2 => int8 output;
    // int8 input without quant_scale2 => fp16 output; otherwise query dtype.
    at::Tensor output;
    if (quant_scale2.has_value()) {
        output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), c10::dtype(c10::ScalarType::Char));
    } else if (query.dtype() == at::kChar) {
        output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), c10::dtype(c10::ScalarType::Half));
    } else {
        output = npu_preparation::apply_tensor_without_format(tmp_output);
    }
    auto actSeqLen = actual_seq_lengths.value_or(at::IntArrayRef{});
    auto actSeqLenKv = actual_seq_lengths_kv.value_or(at::IntArrayRef{});
    int64_t inner_precise = 1;
    if (sparse_mode >= PFA_SPARSE_HIGH_PRECISION_NO_MASK && sparse_mode <= PFA_SPARSE_HIGH_PRECISION_BAND) {
        // for sparse in range [10,14], set inner calculate mode to high-precision
        inner_precise = 0;
        sparse_mode -= PFA_SPARSE_HIGH_PRECISION_NO_MASK;
    }
    // Zero-initialize so the buffer is NUL-terminated even if the layout string
    // fills LAYOUT_MAX_LENGTH - 1 bytes (strncpy does not terminate then).
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout_str.c_str(), LAYOUT_MAX_LENGTH - 1);
    // dispatch hostAPI
    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnPromptFlashAttentionV3, query, key, value, pse_shift, atten_mask, actSeqLen,
        actSeqLenKv, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value,
        pre_tokens, next_tokens, input_layout_char, num_key_value_heads, sparse_mode,
        inner_precise, output);
    return output;
}
#endif
#if VERSION_BETWEEN(V2R2, VERSION_NEWEST)
// Generate the dropout bit-mask for fused flash attention.
// Computes the logical [B, N, Sq, Skv] element count for the given layout
// (writing it into `numels`; layouts not listed here, e.g. TND, keep the value
// the caller pre-computed) and, depending on the dropout status, either draws a
// philox-seeded random mask (recording seed/offset for the backward pass) or
// returns an all-zero byte mask. With dropout disabled, an undefined tensor is
// returned.
at::Tensor dropout_gen_mask(const at::Tensor &query, const at::Tensor &key, double keep_prob, int64_t head_num,
    std::string input_layout, bool gen_mask_parallel, bool sync, int64_t &seed, int64_t &offset, int64_t &numels)
{
    if (input_layout == "BSH") {
        numels = query.size(0) * head_num * query.size(1) * key.size(1); // [B,N,S,S]
    } else if (input_layout == "SBH") {
        numels = query.size(1) * head_num * query.size(0) * key.size(0); // [B,N,S,S]
    } else if (input_layout == "BNSD") {
        numels = query.size(0) * query.size(1) * query.size(THIRD_ELEMENT) * key.size(THIRD_ELEMENT); // [B,N,S,S]
    } else if (input_layout == "BSND") {
        numels = query.size(0) * query.size(THIRD_ELEMENT) * query.size(1) * key.size(1); // [B,N,S,S]
    }
    // One bit per element, rounded up to a 128-element boundary, plus spare bytes.
    const int64_t mask_bytes = (numels + 128 - 1) / 128 * 128 / 8 + LENGTH_BIAS;
    at::Tensor mask;
    const auto dropout_status = get_dropout_status(keep_prob);
    if (dropout_status == DropOutStatus::DROPOUT_NORMAL) {
        // Record the philox counters so backward can regenerate the same mask.
        const auto gen = at_npu::detail::getDefaultNPUGenerator();
        const auto philox_state = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_engine_inputs(10);
        seed = static_cast<int64_t>(philox_state.first);
        offset = static_cast<int64_t>(philox_state.second);
        mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (dropout_status == DropOutStatus::DROPOUT_ALL) {
        mask = at::zeros(at::IntArrayRef{mask_bytes}, query.options().dtype(at::kByte));
    }
    return mask;
}
// Backward pass of fused flash attention: computes gradients w.r.t. query,
// key, value and (when a pse input was given) the positional bias.
// Dispatches to the unpadded/var-len kernel when both actual_seq_qlen and
// actual_seq_kvlen are provided, otherwise to the fixed-shape kernel.
// Returns (dq, dk, dv, dpse); dpse is an undefined tensor when pse was absent.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_backward(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    const std::string input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &drop_mask,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    c10::OptionalIntArrayRef prefix,
    c10::OptionalIntArrayRef actual_seq_qlen,
    c10::OptionalIntArrayRef actual_seq_kvlen,
    int64_t sparse_mode)
{
    // Materialize the optional inputs as (possibly undefined) tensors.
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &drop_mask_const = drop_mask.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    const at::Tensor &softmax_max_const = softmax_max.value_or(at::Tensor());
    const at::Tensor &softmax_sum_const = softmax_sum.value_or(at::Tensor());
    const at::Tensor &softmax_const = softmax_in.value_or(at::Tensor());
    const at::Tensor &attention_const = attention_in.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    // Convert every tensor input to the base (contiguous) NPU format.
    at::Tensor format_query = format_trans(query);
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_dy = format_trans(dy);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_drop_mask = format_trans(drop_mask_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    at::Tensor format_softmax_max = format_trans(softmax_max_const);
    at::Tensor format_softmax_sum = format_trans(softmax_sum_const);
    at::Tensor format_softmax = format_trans(softmax_const);
    at::Tensor format_attention = format_trans(attention_const);
    // Gradient outputs mirror the shapes of their corresponding inputs.
    at::Tensor dq = OpPreparation::apply_tensor_without_format(format_query);
    at::Tensor dk = OpPreparation::apply_tensor_without_format(format_key);
    at::Tensor dv = OpPreparation::apply_tensor_without_format(format_value);
    at::Tensor dpse;
    if (format_pse.defined()) {
        dpse = OpPreparation::apply_tensor_without_format(format_pse);
    } else {
        // Kernel still needs a valid (empty) output tensor when pse is absent.
        dpse = at::empty({0}, query.options());
    }
    // Zero-initialize so the buffer is always NUL-terminated even if the layout
    // string gets truncated by strncpy.
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND) path.
        EXEC_NPU_CMD(
            aclnnFlashAttentionUnpaddingScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, ac_seq_qlen, ac_seq_kvlen,
            scale_value, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode,
            dq, dk, dv, dpse);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScoreGrad, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefixN, scale_value, keep_prob,
            pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, dq, dk, dv, dpse);
    }
    FLOP_COUNT(FlopCounter::flash_attention_backward_flop, query, key, value,
        dy, head_num, input_layout, actual_seq_qlen, actual_seq_kvlen);
    if (!format_pse.defined()) {
        // No pse input: return an undefined tensor rather than the empty
        // placeholder allocated above.
        at::Tensor dpse_required;
        dpse = dpse_required;
    }
    return std::make_tuple(dq, dk, dv, dpse);
}
// Public entry point for the fused flash-attention backward pass.
// Validates input ranks, keep_prob, sparse_mode and layout, regenerates the
// dropout mask from the (seed, offset, numels) recorded by the forward pass,
// and delegates to npu_fusion_attention_backward.
// Returns (dq, dk, dv, dpse).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_grad(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
    int64_t inner_precise,
    int64_t seed,
    int64_t offset,
    int64_t numels,
    c10::OptionalIntArrayRef prefix,
    c10::OptionalIntArrayRef actual_seq_qlen,
    c10::OptionalIntArrayRef actual_seq_kvlen,
    int64_t sparse_mode,
    bool gen_mask_parallel,
    bool sync)
{
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(dy.dim() == DIMENSION_3D || dy.dim() == DIMENSION_4D,
        "The shapes of the input dy should be 3 or 4 dimensional, but got ", dy.dim(), "-dimensional",
        OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ", keep_prob,
        OPS_ERROR(ErrCode::PARAM));
    std::string input_layout_str = std::string(input_layout);
    // Canonicalize to upper-case BEFORE the layout-dependent validation below;
    // otherwise a lower-case "tnd" (layout matching is documented as
    // case-insensitive) would skip the TND-specific sparse_mode range check.
    // Cast through unsigned char: toupper on a negative char is undefined.
    for (auto &c : input_layout_str) {
        c = toupper(static_cast<unsigned char>(c));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" || input_layout_str == "BNSD" ||
        input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    // Mask bitmap size: one bit per element, rounded up to a 128-element
    // boundary, plus LENGTH_BIAS spare bytes.
    int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
    length += LENGTH_BIAS;
    at::Tensor drop_mask;
    if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_NORMAL) {
        // Replay the exact philox stream recorded by the forward pass.
        drop_mask = dropout_gen_mask_dispatch(query, keep_prob, seed, offset, numels, gen_mask_parallel, sync);
    } else if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_ALL) {
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }
    auto result = npu_fusion_attention_backward(query,
        key, value, dy, head_num, input_layout_str, pse, drop_mask, padding_mask, atten_mask,
        softmax_max, softmax_sum, softmax_in, attention_in, scale_value, keep_prob, pre_tockens,
        next_tockens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode);
    if (!sync) {
        // Async mask generation: make the secondary stream wait for the work
        // queued on the current stream.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return result;
}
// Forward pass of fused flash attention.
// Validates layout / sparse_mode / keep_prob, derives the attention-score and
// softmax-statistic shapes from the layout, generates the dropout mask (if
// any), then dispatches to the var-len (aclnnFlashAttentionVarLenScore) or
// fixed-shape (aclnnFlashAttentionScore) kernel.
// Returns (attention_score, softmax_max, softmax_sum, softmax_out, seed,
// offset, numels); the last three describe the dropout mask so the backward
// pass can regenerate it.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t> npu_fusion_attention(
    const at::Tensor &query, const at::Tensor &key,
    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
    const c10::optional<at::Tensor> &pse, const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    double scale, double keep_prob, int64_t pre_tockens, int64_t next_tockens, int64_t inner_precise,
    c10::OptionalIntArrayRef prefix, c10::OptionalIntArrayRef actual_seq_qlen,
    c10::OptionalIntArrayRef actual_seq_kvlen, int64_t sparse_mode, bool gen_mask_parallel, bool sync)
{
    const at::Tensor &pse_const = pse.value_or(at::Tensor());
    const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
    const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
    auto prefixN = prefix.value_or(at::IntArrayRef{});
    auto ac_seq_qlen = actual_seq_qlen.value_or(at::IntArrayRef{});
    auto ac_seq_kvlen = actual_seq_kvlen.value_or(at::IntArrayRef{});
    TORCH_CHECK(head_num > 0, "head_num must > 0, but got ", head_num, OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
        "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
        "The shapes of the input key should be 3 or 4 dimensional, but got ", key.dim(),
        "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(value.dim() == DIMENSION_3D || value.dim() == DIMENSION_4D,
        "The shapes of the input value should be 3 or 4 dimensional, but got ", value.dim(),
        "-dimensional", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(keep_prob > 0 && keep_prob <= 1,
        "The keep_prob value must be in range of (0, 1], but got ", keep_prob, OPS_ERROR(ErrCode::PARAM));
    std::string input_layout_str = std::string(input_layout);
    // Canonicalize to upper-case BEFORE the layout-dependent validation below;
    // otherwise a lower-case "tnd" (layout matching is documented as
    // case-insensitive) would skip the TND-specific sparse_mode and
    // actual_seq_* checks and could index an empty ac_seq_qlen further down.
    // Cast through unsigned char: toupper on a negative char is undefined.
    for (auto &c : input_layout_str) {
        c = toupper(static_cast<unsigned char>(c));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
            (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
            "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(ac_seq_qlen.size() != 0 && ac_seq_kvlen.size() != 0 && ac_seq_qlen.size() == ac_seq_kvlen.size(),
            "the size of actual_seq_qlen and actual_seq_kvlen must be the same and cannot be empty." +
            OPS_ERROR(ErrCode::PARAM));
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
            sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
            "The sparse_mode value must be in range of [0,6], but got ",
            sparse_mode, OPS_ERROR(ErrCode::PARAM));
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" ||
        input_layout_str == "BNSD" || input_layout_str == "BSND" || input_layout_str == "TND",
        "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ",
        input_layout, OPS_ERROR(ErrCode::PARAM));
    // Derive the attention-score output shape from the layout. D2 is the value
    // head-dim (may differ from the query/key head-dim D).
    int64_t B = 0;
    int64_t S0 = 0; // S for query
    int64_t N_local = 0; // N for npu_fusion_attention
    int64_t D = 0;
    int64_t H = 0;
    int64_t T = 0;
    int64_t D2 = 0; // D2 for value head-dim
    c10::SmallVector<int64_t> atten_score_shape;
    if (input_layout_str == "BSH") {
        B = query.size(0);
        S0 = query.size(1);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        D2 = (D == 0 || !key.size(THIRD_ELEMENT)) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {B, S0, head_num * D2};
    } else if (input_layout_str == "SBH") {
        B = query.size(1);
        S0 = query.size(0);
        H = query.size(THIRD_ELEMENT);
        D = H / head_num;
        D2 = (D == 0 || !key.size(THIRD_ELEMENT)) ? 0 : value.size(THIRD_ELEMENT) / (key.size(THIRD_ELEMENT) / D);
        atten_score_shape = {S0, B, head_num * D2};
    } else if (input_layout_str == "BNSD") {
        B = query.size(0);
        N_local = query.size(1);
        S0 = query.size(THIRD_ELEMENT);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, N_local, S0, D2};
    } else if (input_layout_str == "BSND") {
        B = query.size(0);
        N_local = query.size(THIRD_ELEMENT);
        S0 = query.size(1);
        D = query.size(FORTH_ELEMENT);
        D2 = value.size(FORTH_ELEMENT);
        atten_score_shape = {B, S0, N_local, D2};
    } else if (input_layout_str == "TND") {
        T = query.size(0);
        N_local = query.size(1);
        D = query.size(THIRD_ELEMENT);
        D2 = value.size(THIRD_ELEMENT);
        atten_score_shape = {T, N_local, D2};
    }
    at::Tensor format_query = format_trans(query);
    at::Tensor attention_score = npu_preparation::apply_tensor_without_format(atten_score_shape, query.options());
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_pse = format_trans(pse_const);
    at::Tensor format_padding_mask = format_trans(padding_mask_const);
    at::Tensor format_atten_mask = format_trans(atten_mask_const);
    // Initialize so callers never receive indeterminate values: seed/offset are
    // only written on the DROPOUT_NORMAL path (e.g. they stay 0 for keep_prob == 1).
    int64_t seed = 0;
    int64_t offset = 0;
    int64_t numels = 0;
    if (input_layout_str == "TND") {
        // numels = N * sum_i(sq_i * skv_i); the actual_seq arrays hold
        // cumulative lengths, so per-batch lengths are adjacent differences.
        numels = N_local;
        int64_t accum = ac_seq_qlen[0] * ac_seq_kvlen[0];
        for (size_t i = 1; i < ac_seq_qlen.size(); i++) {
            accum += ((ac_seq_qlen[i] - ac_seq_qlen[i - 1]) * (ac_seq_kvlen[i] - ac_seq_kvlen[i - 1]));
        }
        numels *= accum;
    }
    at::Tensor format_drop_mask = dropout_gen_mask(format_query, format_key, keep_prob, head_num, input_layout_str,
        gen_mask_parallel, sync, seed, offset, numels);
    at::Tensor softmax_max;
    at::Tensor softmax_sum;
    at::Tensor softmax_out;
    if (input_layout_str != "TND") {
        softmax_max = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [B, N, S0, 8]
    } else {
        softmax_max = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
        softmax_sum = OpPreparation::apply_tensor_without_format({T, N_local, SOFTMAXMAX_LAST_DIMSHAPE},
            query.options().dtype(at::kFloat)); // [T, N, 8]
    }
    softmax_out = at::empty({0}, query.options());
    // Zero-initialize so the buffer is always NUL-terminated even if the layout
    // string gets truncated by strncpy.
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout_str.c_str(), LAYOUT_MAX_LENGTH - 1);
    if (!ac_seq_qlen.empty() && !ac_seq_kvlen.empty()) {
        // Variable-length (TND) path.
        EXEC_NPU_CMD(
            aclnnFlashAttentionVarLenScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            ac_seq_qlen, ac_seq_kvlen, scale, keep_prob, pre_tockens, next_tockens, head_num,
            input_layout_char, inner_precise, sparse_mode, softmax_max, softmax_sum,
            softmax_out, attention_score);
    } else {
        EXEC_NPU_CMD(
            aclnnFlashAttentionScore, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            scale, keep_prob, pre_tockens, next_tockens, head_num, input_layout_char,
            inner_precise, sparse_mode, softmax_max, softmax_sum, softmax_out, attention_score);
    }
    FLOP_COUNT(FlopCounter::flash_attention_forward_flop, query, key, value, head_num,
        input_layout_str, actual_seq_qlen, actual_seq_kvlen);
    if (!sync) {
        // Async mask generation: make the secondary stream wait for the work
        // queued on the current stream.
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
        npu_event.block(c10_npu::getCurrentSecondaryStream());
    }
    return std::make_tuple(attention_score, softmax_max, softmax_sum, softmax_out,
        seed, offset, numels);
}
// Prompt (full-sequence) flash attention for NPU (variant for this version range).
// Builds the output tensor according to input_layout and the quantization
// options, then dispatches to the aclnnPromptFlashAttentionV3 host API.
// Output dtype: int8 when quant_scale2 is provided (post-quantization),
// fp16 when the query itself is int8, otherwise the query's dtype.
at::Tensor npu_prompt_flash_attention(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &pse_shift,
    c10::OptionalIntArrayRef actual_seq_lengths,
    const c10::optional<at::Tensor> &deq_scale1,
    const c10::optional<at::Tensor> &quant_scale1,
    const c10::optional<at::Tensor> &deq_scale2,
    const c10::optional<at::Tensor> &quant_scale2,
    const c10::optional<at::Tensor> &quant_offset2,
    int64_t num_heads, double scale_value,
    int64_t pre_tokens, int64_t next_tokens,
    c10::string_view input_layout, int64_t num_key_value_heads,
    c10::OptionalIntArrayRef actual_seq_lengths_kv,
    int64_t sparse_mode)
{
    // construct the output tensor of the NPU
    at::Tensor output;
    at::Tensor tmp_output = npu_preparation::apply_tensor_without_format(query);
    std::string input_layout_str = std::string(input_layout);
    if (input_layout_str == "BNSD_BSND") {
        // Input is BNSD but the output is requested as BSND: swap axes 1 and 2.
        tmp_output = OpPreparation::apply_tensor_without_format(
            {query.size(DIM_0), query.size(DIM_2), query.size(DIM_1), query.size(DIM_3)},
            query.options().dtype(query.dtype()));
    } else if (input_layout_str == "TND") {
        // TND: the output head-dim follows value, not query.
        tmp_output = OpPreparation::apply_tensor_without_format(
            {query.size(DIM_0), query.size(DIM_1), value.size(DIM_2)},
            query.options().dtype(query.dtype()));
    }
    if (quant_scale2.has_value()) {
        // Post-quantization scale present: kernel emits int8.
        output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), c10::dtype(c10::ScalarType::Char));
    } else if (query.dtype() == at::kChar) {
        // int8 inputs without a post-quant scale produce fp16.
        output = npu_preparation::apply_tensor_without_format(tmp_output.sizes(), c10::dtype(c10::ScalarType::Half));
    } else {
        output = npu_preparation::apply_tensor_without_format(tmp_output);
    }
    auto actSeqLen = actual_seq_lengths.value_or(at::IntArrayRef{});
    auto actSeqLenKv = actual_seq_lengths_kv.value_or(at::IntArrayRef{});
    int64_t inner_precise = 1;
    if (sparse_mode >= PFA_SPARSE_HIGH_PRECISION_NO_MASK && sparse_mode <= PFA_SPARSE_HIGH_PRECISION_BAND) {
        // for sparse in range [10,14], set inner calculate mode to high-precision
        inner_precise = 0;
        sparse_mode -= PFA_SPARSE_HIGH_PRECISION_NO_MASK;
    }
    // Zero-initialize so the buffer is always NUL-terminated: strncpy does not
    // terminate when the source is LAYOUT_MAX_LENGTH - 1 chars or longer.
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout_str.c_str(), LAYOUT_MAX_LENGTH - 1);
    // dispatch hostAPI
    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnPromptFlashAttentionV3, query, key, value, pse_shift, atten_mask, actSeqLen,
        actSeqLenKv, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value,
        pre_tokens, next_tokens, input_layout_char,
        num_key_value_heads, sparse_mode, inner_precise, output);
    return output;
}
#endif
#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
// Incremental (decode-step) flash attention with SymInt sequence lengths.
// Output dtype mirrors npu_prompt_flash_attention: int8 when quant_scale2 is
// given, fp16 for int8 queries, otherwise the query's dtype. Dispatches to the
// aclnnIncreFlashAttentionV4 host API.
at::Tensor npu_incre_flash_attention_symint(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
    const c10::optional<at::Tensor> &padding_mask, const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &pse_shift,
    c10::OptionalArrayRef<c10::SymInt> actual_seq_lengths, const c10::optional<at::Tensor> &antiquant_scale,
    const c10::optional<at::Tensor> &antiquant_offset, const c10::optional<at::Tensor> &block_table,
    const c10::optional<at::Tensor> &dequant_scale1, const c10::optional<at::Tensor> &quant_scale1,
    const c10::optional<at::Tensor> &dequant_scale2, const c10::optional<at::Tensor> &quant_scale2,
    const c10::optional<at::Tensor> &quant_offset2, const c10::optional<at::Tensor> &kv_padding_size,
    int64_t num_heads, double scale_value, c10::string_view input_layout, int64_t num_key_value_heads,
    int64_t block_size, int64_t inner_precise)
{
    // construct the output tensor of the NPU
    at::Tensor output;
    if (quant_scale2.has_value()) {
        // Post-quantization scale present: kernel emits int8.
        output = npu_preparation::apply_tensor_without_format(query.sizes(), c10::dtype(c10::ScalarType::Char));
    } else if (query.dtype() == at::kChar) {
        // int8 inputs without a post-quant scale produce fp16.
        output = npu_preparation::apply_tensor_without_format(query.sizes(), c10::dtype(c10::ScalarType::Half));
    } else {
        output = npu_preparation::apply_tensor_without_format(query);
    }
    // The aclnn API consumes key/value as tensor lists.
    at::TensorList keyTensors = key;
    at::TensorList valueTensors = value;
    std::string input_layout_str = std::string(input_layout);
    // Zero-initialize so the buffer is always NUL-terminated: strncpy does not
    // terminate when the source is LAYOUT_MAX_LENGTH - 1 chars or longer.
    char input_layout_char[LAYOUT_MAX_LENGTH] = {0};
    strncpy(input_layout_char, input_layout_str.c_str(), LAYOUT_MAX_LENGTH - 1);
    // dispatch hostAPI
    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnIncreFlashAttentionV4, query, keyTensors, valueTensors, pse_shift, atten_mask,
        actual_seq_lengths, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale,
        antiquant_offset, block_table, kv_padding_size, num_heads, scale_value, input_layout_char,
        num_key_value_heads, block_size, inner_precise, output);
    return output;
}
#endif
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ascend/op-plugin.git
git@gitee.com:ascend/op-plugin.git
ascend
op-plugin
op-plugin
master

搜索帮助