diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 0e89ad3e0344204794d14822000ed649fe630339..43ee62f823c028c85ae24155e3f9760c53716eb7 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -1 +1,2 @@
-"ms_custom_ops/ops/ascendc/grid_sample/op_kernel/grid_sample_ms.cpp" ""
\ No newline at end of file
+"ms_custom_ops/ops/ascendc/grid_sample/op_kernel/grid_sample_ms.cpp" ""
+"ms_custom_ops/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd.cpp" "unreadVariable"
diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt
index 6589a7906cdfac92bf23a1f75b3f8d6f682646fb..7b878a2313102e49d36da30332b9502fc91a016c 100644
--- a/.jenkins/check/config/filter_cpplint.txt
+++ b/.jenkins/check/config/filter_cpplint.txt
@@ -1,2 +1,10 @@
 # ms_custom_ops
+"ms_custom_ops/ops/ascendc/apply_rotary_pos_emb_ms/op_kernel" "build/include_subdir"
+"ms_custom_ops/ops/ascendc/apply_rotary_pos_emb_ms/op_host" "build/include_subdir"
+"ms_custom_ops/ops/ascendc/unpad_fa_npd/op_kernel" "build/include_subdir"
+"ms_custom_ops/ops/ascendc/unpad_fa_npd/op_host" "build/include_subdir"
+"ms_custom_ops/ops/ascendc/reshape_and_cache_npd/op_kernel" "build/include_subdir"
+"ms_custom_ops/ops/ascendc/reshape_and_cache_npd/op_host" "build/include_subdir"
+"ms_custom_ops/ops/ascendc/kernel_common/op_kernel" "build/include_subdir"
+# ms_custom_ops
 "ms_custom_ops/ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" "build/include_subdir"
\ No newline at end of file
diff --git a/.jenkins/rules/codespell/codespell.allow b/.jenkins/rules/codespell/codespell.allow
index d6c01ca1af2554d59408c242251a7b467f9a2f89..7ff55fc31dfacab2ecce55ac40a6a8007153e59c 100644
--- a/.jenkins/rules/codespell/codespell.allow
+++ b/.jenkins/rules/codespell/codespell.allow
@@ -1 +1,6 @@
 EnQue
+CANN
+ND
+nd
+ArchType
+CopyIn
diff --git a/cmake/compile_ascendc_ops.cmake b/cmake/compile_ascendc_ops.cmake
index c5a5b803aa0702e955fa8cba47c6ee7ff1011bf6..40137fe04d8c7df6160484222a0475b4f6ad0870 100644
--- a/cmake/compile_ascendc_ops.cmake
+++ b/cmake/compile_ascendc_ops.cmake
@@ -37,6 +37,7 @@ endif()
 add_custom_target(
   build_custom_op ALL
   COMMAND ${Python3_EXECUTABLE} ${OP_COMPILER_SCRIPT}
+  --common_dirs="${ASCENDC_OP_COMMON_DIRS}"
   --op_dirs="${ASCENDC_OP_DIRS}"
   --build_path=${CMAKE_BUILD_PATH}
   --build_type=${CMAKE_BUILD_TYPE}
diff --git a/ops/ascendc/CMakeLists.txt b/ops/ascendc/CMakeLists.txt
index 532eb13435e7294d554fd43958a9105164a21e53..d8eddfc607d2bc229fb995b9bffc51ffcc8a62e2 100644
--- a/ops/ascendc/CMakeLists.txt
+++ b/ops/ascendc/CMakeLists.txt
@@ -8,6 +8,7 @@ if(DEFINED ENV{OP_DIRS})
 else()
   set(ASCENDC_OP_DIRS "")
   file(GLOB ITEMS "${CMAKE_CURRENT_SOURCE_DIR}/*")
+  list(FILTER ITEMS EXCLUDE REGEX ".*kernel_common/.*")
   foreach(ITEM ${ITEMS})
     if(IS_DIRECTORY "${ITEM}" AND EXISTS "${ITEM}/op_host" AND EXISTS "${ITEM}/op_kernel")
       list(APPEND ASCENDC_OP_DIRS ${ITEM})
@@ -15,6 +16,8 @@ else()
   endforeach()
 endif()
 
+set(ASCENDC_OP_COMMON_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/kernel_common")
+
 # AscendC src files
 file(GLOB_RECURSE SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.cc")
 list(FILTER SRC_FILES EXCLUDE REGEX ".*op_host.*")
diff --git a/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms.cpp b/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms.cpp
index f5e7278816ba4a585678c11f7b0af62eca6f7644..9de772b4b21cb2b4d9c3a9e41c335389ddad72af 100644
--- a/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms.cpp
+++ b/ops/ascendc/apply_rotary_pos_emb_ms/op_host/apply_rotary_pos_emb_ms.cpp
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include
-#include "apply_rotary_pos_emb_ms_tiling.h" // NOLINT(build/include_subdir)
+#include "apply_rotary_pos_emb_ms_tiling.h"
 #include "register/op_def_registry.h"
 #include "graph/utils/type_utils.h"
 #include "utils/log/asc_cpu_log.h"
diff --git a/ops/ascendc/apply_rotary_pos_emb_ms/op_kernel/apply_rotary_pos_emb_ms.cpp b/ops/ascendc/apply_rotary_pos_emb_ms/op_kernel/apply_rotary_pos_emb_ms.cpp
index 05ff468a414e850a83396dd2553c1858a1508b19..9f61cf479b9c2dc9e694b205f7fd4b6fbef0ca1e 100644
--- a/ops/ascendc/apply_rotary_pos_emb_ms/op_kernel/apply_rotary_pos_emb_ms.cpp
+++ b/ops/ascendc/apply_rotary_pos_emb_ms/op_kernel/apply_rotary_pos_emb_ms.cpp
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "kernel_operator.h" // NOLINT(build/include_subdir)
+#include "kernel_operator.h"
 constexpr int32_t BUFFER_NUM = 1;
 template
 class KernelApplyRotaryPosEmbMS {
diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/common.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..48ed2d1e563e0b5bc77e08d4c59f82297b5521ea
--- /dev/null
+++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/common.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef INCLUDE_COMMON_H
+#define INCLUDE_COMMON_H
+
+#define CONST_2 2
+
+#define SET_FLAG(trigger, waiter, e) AscendC::SetFlag((e))
+#define WAIT_FLAG(trigger, waiter, e) AscendC::WaitFlag((e))
+#define PIPE_BARRIER(pipe) AscendC::PipeBarrier()
+
+#ifndef __force_inline__
+#define __force_inline__ inline __attribute__((always_inline))
+#endif
+
+#endif
diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/common_func.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/common_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..a0665eded38e8a672aad9b6adb5e2f12897b416f
--- /dev/null
+++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/common_func.h
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef INCLUDE_COMMON_FUNC_H
+#define INCLUDE_COMMON_FUNC_H
+
+#include
+#include
+
+#ifdef __CCE_KT_TEST__
+#include "include/stub_def.h"
+#include "include/stub_fun.h"
+#else
+#include "impl/kernel_macros.h"
+#endif
+
+template
+inline __aicore__ T RoundUp(const T val) {
+    static_assert(ALIGN != 0, "align must not be zero");
+    static_assert(std::is_arithmetic::value, "T must be an arithmetic type");
+    T align = ALIGN;
+    if (val + align - 1 < val) {
+        return val;
+    }
+    return (val + align - 1) / align * align;
+}
+
+template
+inline __aicore__ T RoundUp(const T val, const T align) {
+    static_assert(std::is_arithmetic::value, "T must be an arithmetic type");
+    if (align == 0 || val + align - 1 < val) {
+        return val;
+    }
+    return (val + align - 1) / align * align;
+}
+
+template
+inline __aicore__ T CeilDiv(const T dividend) {
+    static_assert(DIVISOR != 0, "align must not be zero");
+    static_assert(std::is_arithmetic::value, "T must be an arithmetic type");
+    T divisor = DIVISOR;
+    if (dividend + divisor - 1 < dividend) {
+        return dividend;
+    }
+    return (dividend + divisor - 1) / divisor;
+}
+
+template
+constexpr T T_MAX = std::numeric_limits::max();
+
+template
+inline __aicore__ T CeilDiv(const T dividend, const T divisor) {
+    static_assert(std::is_arithmetic::value, "T must be an arithmetic type");
+    if (divisor == 0 || dividend + divisor - 1 < dividend) {
+        return T_MAX;
+    }
+    return (dividend + divisor - 1) / divisor;
+}
+
+template
+__aicore__ inline T Min(const T lhs, const T rhs) {
+    return lhs < rhs ? lhs : rhs;
+}
+
+template
+__aicore__ __attribute__((always_inline)) inline uint32_t BlockSize() {
+    return 32 / sizeof(Dtype);
+}
+
+template
+__aicore__ __attribute__((always_inline)) inline uint32_t MatrixSize() {
+    return 512 / sizeof(Dtype);
+}
+
+template
+__aicore__ __attribute__((always_inline)) inline uint64_t BlockSizeRoundUp(uint64_t num) {
+    return (num + BlockSize() - 1) / BlockSize() * BlockSize();
+}
+
+template
+__aicore__ __attribute__((always_inline)) inline uint64_t NumBlocksRoundUp(uint64_t num) {
+    return (num + BlockSize() - 1) / BlockSize();
+}
+
+template
+__aicore__ __attribute__((always_inline)) inline uint64_t MatrixSizeRoundUp(uint64_t num) {
+    return (num + MatrixSize() - 1) / MatrixSize() * MatrixSize();
+}
+
+template
+__aicore__ __attribute__((always_inline)) inline uint64_t NumMatrixsRoundUp(uint64_t num) {
+    return (num + MatrixSize() - 1) / MatrixSize();
+}
+
+template
+__aicore__ __attribute__((always_inline)) inline uint64_t L0HalfSize() {
+    return 32 * 1024 / sizeof(Dtype);
+}
+
+#endif
diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/hardware.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/hardware.h
new file mode 100644
index 0000000000000000000000000000000000000000..a9a22b5b9c8a1f53eb8e666cde42e706bf4ab518
--- /dev/null
+++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/hardware.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef INCLUDE_HARDWARE_H
+#define INCLUDE_HARDWARE_H
+
+enum class ArchType { ASCEND_V220, ASCEND_V200, ASCEND_M200 };
+
+template
+struct HardwareInfo {
+    static uint32_t const l2BW = 5;
+    static uint32_t const hbmBW = 1;
+    static uint32_t const supportMix = 0;
+    static uint32_t const l1Size = 512 * 1024;
+    static uint32_t const l0ASize = 64 * 1024;
+    static uint32_t const l0BSize = 64 * 1024;
+    static uint32_t const l0CSize = 128 * 1024;
+    static uint32_t const l2Size = 192 * 1024 * 1024;
+    static uint32_t const biasSize = 1024;
+    static uint32_t const fixBufSize = 7 * 1024;
+    static uint32_t const ubSize = 192 * 1024;
+    static uint32_t const fractalSize = 512;
+    static uint32_t const l1l0BlockSize = 32;
+    static uint32_t const btBlockSize = 64;
+    static uint32_t const fbBlockSize = 128;
+};
+
+#endif
diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterator.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterator.h
new file mode 100644
index 0000000000000000000000000000000000000000..813c0035c905a1cc6c5337cfa44b584d44c5b113
--- /dev/null
+++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterator.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef INCLUDE_ITERTOR_H
+#define INCLUDE_ITERTOR_H
+
+#include "common_func.h"
+#include "hardware.h"
+#include "tikcfw/kernel_operator.h"
+#include "layout.h"
+#include "mem.h"
+
+/////////////////////////////////////////////////////
+// gm_to_l1
+/////////////////////////////////////////////////////
+template
+struct gm_to_l1 {
+    __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, AscendC::GlobalTensor gmTensor,
+                        uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual, uint32_t dTileCeil,
+                        uint32_t dVal) {}
+};
+
+template
+struct m_gm_to_l1 {
+    __aicore__ m_gm_to_l1(AscendC::LocalTensor l1Tensor, AscendC::GlobalTensor gmTensor,
+                          uint32_t ndNum, uint32_t nValue, uint32_t dValue, uint32_t srcNdMatrixStride,
+                          uint32_t srcDValue, uint32_t dstNzC0Stride, uint32_t dstNzNStride, uint32_t dstNzMatrixStride) {
+    }
+};
+
+/////////////////////////////////////////////////////
+// l1_to_l0_a
+/////////////////////////////////////////////////////
+template
+struct l1_to_l0_a {
+    __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, AscendC::LocalTensor l1Tensor,
+                          uint32_t mTileCeil, uint32_t kPartCeil, uint32_t mSrcStride, uint32_t kSrcStride,
+                          uint32_t mDstStride, uint32_t kDstStride) {}
+};
+
+/////////////////////////////////////////////////////
+// l1_to_l0_b
+/////////////////////////////////////////////////////
+template
+struct l1_to_l0_b {
+    __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, AscendC::LocalTensor l1Tensor,
+                          uint32_t nTileCeil, uint32_t kPartCeil, uint32_t nSrcStride, uint32_t kSrcStride,
+                          uint32_t nDstStride, uint32_t kDstStride) {}
+};
+
+// l1_to_l0_a
+/////////////////////////////////////////////////////
+template
+struct l1_to_l0_a_v1 {
+    __aicore__ l1_to_l0_a_v1(AscendC::LocalTensor l0_tensor, AscendC::LocalTensor l1_tensor,
+                             uint32_t m_tile_ceil, uint32_t k_tile_ceil, uint32_t k_part, uint32_t k_part_ceil,
+                             uint32_t k_part_idx) {}
+};
+
+/////////////////////////////////////////////////////
+// l1_to_l0_b
+/////////////////////////////////////////////////////
+template
+struct l1_to_l0_b_v1 {
+    __aicore__ l1_to_l0_b_v1(AscendC::LocalTensor l0_tensor, AscendC::LocalTensor l1_tensor,
+                             int32_t n_tile_ceil, int32_t k_tile_ceil, int32_t k_part_ceil, int32_t k_part_idx) {}
+};
+
+/////////////////////////////////////////////////////
+// l0c_to_gm
+/////////////////////////////////////////////////////
+template
+struct l0c_to_gm {
+    __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, AscendC::LocalTensor l0cTensor,
+                         uint32_t mTileActual, uint32_t nTileActual, uint32_t mTileCeil, uint32_t nActual) {}
+};
+
+/////////////////////////////////////////////////////
+// l0c_to_l1
+/////////////////////////////////////////////////////
+template
+struct l0c_to_l1 {
+    __aicore__ l0c_to_l1(AscendC::LocalTensor l1Tensor, AscendC::LocalTensor l0cTensor,
+                         AscendC::LocalTensor deqTensor, uint32_t mTileActual, uint32_t nTileActual,
+                         uint32_t mTileCeil, uint32_t nActual) {}
+};
+
+#include "iterators/gm_to_l1_iterator.inc"
+#include "iterators/gm_to_ub_iterator.inc"
+#include "iterators/l0c_to_gm_iterator.inc"
+#include "iterators/l0c_to_l1_iterator.inc"
+#include "iterators/l0c_to_ub_iterator.inc"
+#include "iterators/l1_to_bt_iterator.inc"
+#include "iterators/l1_to_fb_iterator.inc"
+#include "iterators/l1_to_l0_iterator.inc"
+#include "iterators/l1_to_ub_iterator.inc"
+#endif
diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/gm_to_l1_iterator.inc
b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/gm_to_l1_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..393698b4de9582353bdddf0c4e099905459dfdc4 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/gm_to_l1_iterator.inc @@ -0,0 +1,236 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ +#include "../iterator.h" + +constexpr uint32_t STRIDE_LIMIT = 65536; + +// Partial specialization for V220, ND_in, ND_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + AscendC::DataCopy(l1Tensor, + gmTensor, + AscendC::DataCopyParams(1, // nBurst + CeilDiv(nTileActual * dTileActual), // lenBurst + 0, // srcGap + 0)); // dstGap + }; +}; + +// Partial specialization for NZ_in, NZ_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + uint64_t srcStride = nTileCeil - nTileActual; + if (srcStride < STRIDE_LIMIT) { + AscendC::DataCopy(l1Tensor, gmTensor, + AscendC::DataCopyParams(dTileActual / BLOCK_SIZE, // nBurst + nTileActual, // lenBurst + nTileCeil - nTileActual, // srcGap + 0)); // dstGap + } else { + for (uint64_t i = 0; i < dTileActual / BLOCK_SIZE; i++) { + uint64_t dstOffset = i * nTileActual * BLOCK_SIZE; + uint64_t srcOffset = i * nTileCeil * BLOCK_SIZE; + AscendC::DataCopy(l1Tensor[dstOffset], gmTensor[srcOffset], + AscendC::DataCopyParams(1, // nBurst + nTileActual, // lenBurst + 0, // srcGap + 0)); // dstGap + } + } + }; +}; + +// Partial specialization for V220, ND_in, ND_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + if (dVal < STRIDE_LIMIT) { + AscendC::DataCopy(l1Tensor, + gmTensor, + AscendC::Nd2NzParams(1, // ndNum + nTileActual, // nValue + dTileActual, // dValue + 0, // srcNdMatrixStride, unused + dVal, // srcDValue + nTileCeil, // dstNzC0Stride + 1, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } else { + for (uint32_t i = 0; i < nTileActual; i++) { + AscendC::DataCopy(l1Tensor[i * BLOCK_SIZE], + gmTensor[i * dVal], + AscendC::Nd2NzParams(1, // ndNum + 1, // nValue + 
dTileActual, // dValue + 0, // srcNdMatrixStride, unused + 0, // srcDValue + nTileCeil, // dstNzC0Stride + 0, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } + } + }; +}; + +// Partial specialization for V220, ND_in, NZ_out +template +struct gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ gm_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::GlobalTensor gmTensor, + uint32_t nTileActual, + uint32_t nTileCeil, + uint32_t nVal, + uint32_t dTileActual, + uint32_t dTileCeil, + uint32_t dVal) + { + if (dVal < STRIDE_LIMIT) { + AscendC::DataCopy(l1Tensor, + gmTensor, + AscendC::Nd2NzParams(1, // ndNum + nTileActual, // nValue + dTileActual, // dValue + 0, // srcNdMatrixStride, unused + dVal, // srcDValue + nTileCeil, // dstNzC0Stride + 1, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } else { + for (uint32_t i = 0; i < nTileActual; ++i) { + AscendC::DataCopy(l1Tensor, + gmTensor, + AscendC::Nd2NzParams(1, // ndNum + 1, // nValue + dTileActual, // dValue + 0, // srcNdMatrixStride, unused + 0, // srcDValue + nTileCeil, // dstNzC0Stride + 0, // dstNzNStride + 0)); // dstNzMatrixStride, unused + } + } + }; +}; + +// Partial specialization for V220, ND_in, NZ_out +template +struct m_gm_to_l1 { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ m_gm_to_l1(AscendC::LocalTensor l1Tensor, AscendC::GlobalTensor gmTensor, + uint32_t ndNum, uint32_t nValue, uint32_t dValue, uint32_t srcNdMatrixStride, + uint32_t srcDValue, uint32_t dstNzC0Stride, uint32_t dstNzNStride, uint32_t dstNzMatrixStride) { + if (srcNdMatrixStride < STRIDE_LIMIT) { + AscendC::DataCopy(l1Tensor, gmTensor, + AscendC::Nd2NzParams(ndNum, // ndNum + nValue, // nValue + dValue, // dValue + srcNdMatrixStride, // srcNdMatrixStride, unused + srcDValue, // srcDValue + dstNzC0Stride, // dstNzC0Stride + dstNzNStride, // dstNzNStride + dstNzMatrixStride)); // dstNzMatrixStride, unused + } else { + for (int i = 0; i < ndNum; i++) { + AscendC::DataCopy(l1Tensor[i * dstNzMatrixStride], gmTensor[i * srcNdMatrixStride], + AscendC::Nd2NzParams(1, // ndNum + nValue, // nValue + dValue, // dValue + 0, // unused + srcDValue, // srcDValue + dstNzC0Stride, // dstNzC0Stride + dstNzNStride, // dstNzNStride + 0)); // unused + } + } + }; +}; + +template + __aicore__ __attribute__((always_inline)) inline void CopyGmToL1Npd(AscendC::LocalTensor l1_dst, + AscendC::GlobalTensor gm_src, + size_t page_size, size_t kv_len, + size_t kv_len_round, size_t embed, + size_t embed_round, size_t kv_stride) { + + + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + uint32_t nd_num = kv_len / page_size; + if (nd_num > 0) { + m_gm_to_l1( + l1_dst, gm_src, + nd_num, // ndNum + page_size, // nValue + embed, // dValue + kv_stride, // srcNdMatrixStride + embed_round, // srcDValue + kv_len_round, // dstNzC0Stride + 1, // dstNzNStride + page_size * BLOCK_SIZE); // dstNzMatrixStride + } + + uint32_t reminder = kv_len % page_size; + if (reminder > 0) { + size_t l1_k_offset = (nd_num) * page_size * BLOCK_SIZE; + size_t gm_k_offset = (nd_num) * kv_stride; + + m_gm_to_l1( + l1_dst[l1_k_offset], gm_src[gm_k_offset], + 1, // ndNum + reminder, // nValue + embed, // dValue + kv_stride, // srcNdMatrixStride + embed_round, // srcDValues + kv_len_round, // dstNzC0Stride + 1, // dstNzNStride + page_size * 
BLOCK_SIZE); // dstNzMatrixStride + } + return; + }; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/gm_to_ub_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/gm_to_ub_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..33d4c2edfa6550a7380d04a1e1ff15c2ece2e6bd --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/gm_to_ub_iterator.inc @@ -0,0 +1,84 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ +#include "../iterator.h" + +template struct gm_to_ub { + __aicore__ inline gm_to_ub(AscendC::LocalTensor dstTensor, AscendC::GlobalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride) + { + AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +template struct gm_to_ub_align { + __aicore__ inline gm_to_ub_align(AscendC::LocalTensor dstTensor, AscendC::GlobalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum, + uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap) + { + AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0), + AscendC::DataCopyPadExtParams(false, leftPaddingNum, rightPaddingNum, 0)); + }; +}; + +template struct ub_to_ub { + __aicore__ inline ub_to_ub(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride) + { + AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +template +struct ub_to_gm { + __aicore__ inline ub_to_gm(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride) + { + AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +template struct ub_to_gm { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ ub_to_gm(AscendC::GlobalTensor gmTensor, AscendC::LocalTensor l1Tensor, + uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual, + uint32_t dTileCeil, uint32_t dVal) + { + uint64_t dstStride = nTileCeil - nTileActual; + if (dstStride < STRIDE_LIMIT) { + AscendC::DataCopy(gmTensor, l1Tensor, + AscendC::DataCopyParams(dTileActual / BLOCK_SIZE, // nBurst + nTileActual, // lenBurst + 0, // srcGap + dstStride)); // dstGap + } else { + for (uint64_t i = 0; i < dTileActual / BLOCK_SIZE; i++) { + uint64_t srcOffset = i * nTileActual * BLOCK_SIZE; + uint64_t dstOffset = i * nTileCeil * BLOCK_SIZE; + AscendC::DataCopy(gmTensor[dstOffset], l1Tensor[srcOffset], + AscendC::DataCopyParams(1, // nBurst + nTileActual, // lenBurst + 0, // srcGap + 
0)); // dstGap + } + } + }; +}; + +template struct ub_to_gm_align { + __aicore__ inline ub_to_gm_align(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum, + uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap) + { + AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0)); + }; +}; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_gm_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_gm_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..04599801e2f7f2c0f5302a903d502d9e326ef074 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_gm_iterator.inc @@ -0,0 +1,200 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ +#include "../iterator.h" + +constexpr uint32_t BLOCK_NUM = 16; +constexpr uint32_t BLOCK_SIZE_INT8 = 32; + +template <> +struct l0c_to_gm { + /** + * @brief Copy data from L0C buffer to global memory, partial specialized for + * + * @param gmTensor the destination tensor on global memory, which is stored in ND format. + * @param l0cTensor the source tensor on L0C buffer, which is stored in FRACTAL_NZ format. + * @param mTileActual the m-direction size of the matrix in L0C buffer. + * @param nTileActual the n-direction size of the matrix in L0C buffer. + * @param srcStride the source stride between the adjacent fractal matrices along n-direction in unit of C0_SIZE. + * @param dstStride the leading dimension of the destination matrix in unit of element. 
+ */ + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::F322F16; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::F322F16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::VDEQF16; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::VDEQF16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor<__bf16> gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::F322BF16; + AscendC::Fixpipe<__bf16, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::F322BF16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +// Partial specialization ND, float +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::NoQuant; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride); + intriParams.nz2ndParams = {true, 1, 0, 0, static_cast(nTileActual)}; + 
intriParams.quantParams = {QuantMode_t::NoQuant}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::F322F16; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#else + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8), + 0, + dstStride - (nTileActual * sizeof(half) / sizeof(float))); + intriParams.quantParams = {QuantMode_t::F322F16}; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; + +template <> +struct l0c_to_gm { + __aicore__ l0c_to_gm(AscendC::GlobalTensor gmTensor, + AscendC::LocalTensor l0cTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t srcStride, + uint32_t dstStride) + { +#ifdef __DAV_C220_CUBE__ + auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize + mTileActual, // mSize + srcStride, // srcStride + dstStride, // dstStride + false); // enRelu + + intriParams.quantPre = QuantMode_t::NoQuant; + AscendC::Fixpipe(gmTensor, l0cTensor, intriParams); +#endif + }; +}; diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_l1_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_l1_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..463f2ade11c77095a9641bd1818ed54a797f44b0 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_l1_iterator.inc @@ -0,0 +1,40 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. 
+*/ +#include "../iterator.h" +///////////////////////////////////////////////////// +// l0c_to_l1 +///////////////////////////////////////////////////// + +// Partial specialization ZN, half, int32_t +template +struct l0c_to_l1 { + using ElementOut = half; + using ElementIn = int32_t; + __aicore__ l0c_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::LocalTensor l0cTensor, + AscendC::LocalTensor deqTensor, + uint32_t mTileActual, + uint32_t nTileActual, + uint32_t mTileCeil, + uint32_t nActual) + { + constexpr uint32_t BLOCK_NUM = 16; + constexpr uint32_t BLOCK_SIZE = 32; + AscendC::FixpipeParams intriParams( + (nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE, + static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE), + 0, + mTileCeil - static_cast(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE) * + sizeof(ElementOut) / sizeof(ElementIn)); + intriParams.nz2ndParams = {false, 1, 0, 0, static_cast(nTileActual)}; + intriParams.quantParams = {QuantMode_t::VDEQF16}; + AscendC::Fixpipe(l1Tensor, l0cTensor, deqTensor, intriParams); + }; +}; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_ub_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_ub_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..ef8cc958b13652d04eca9a33a71a98c3d2450db3 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l0c_to_ub_iterator.inc @@ -0,0 +1,62 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l0c_to_ub +///////////////////////////////////////////////////// + +// Partial specialization ZN, half, int32_t +template struct l0c_to_ub { + __aicore__ l0c_to_ub(AscendC::LocalTensor ubTensor, AscendC::LocalTensor l0cTensor, + uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride) + { + constexpr auto mode = + MatrixMode ? 
AscendC::BlockMode::BLOCK_MODE_MATRIX : AscendC::BlockMode::BLOCK_MODE_VECTOR; + AscendC::DataCopy(ubTensor, l0cTensor, + AscendC::DataCopyParams(nBurst, // count + lenBurst, // len + srcStride, // srcStrideIn + dstStride), // dstStrideIn + AscendC::DataCopyEnhancedParams(mode, // blockModeIn + AscendC::DeqScale::DEQ_NONE, // deqScaleIn + 0, // deqValueIn + 0, // sidStoreModeIn + false, // isReluIn + pad_t::PAD_NONE, // padModeIn + 0) // padValueIn + ); + }; +}; + +template +struct l0c_to_ub { + __aicore__ l0c_to_ub(AscendC::LocalTensor ubTensor, + AscendC::LocalTensor l0cTensor, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcStride, + uint16_t dstStride) + { + AscendC::DataCopy(ubTensor, l0cTensor, + AscendC::DataCopyParams(nBurst, // count + lenBurst, // len + srcStride, // srcStrideIn + dstStride), // dstStrideIn + AscendC::DataCopyEnhancedParams(AscendC::BlockMode::BLOCK_MODE_MATRIX, // blockModeIn + AscendC::DeqScale::VDEQ16, // deqScaleIn + 0, // deqValueIn + 0, // sidStoreModeIn + false, // isReluIn + pad_t::PAD_NONE, // padModeIn + 0) // padValueIn + ); + }; +}; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_bt_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_bt_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..7a09c31654561ad19b9e623141097ffdf0dfee2a --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_bt_iterator.inc @@ -0,0 +1,29 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_bt +///////////////////////////////////////////////////// + +// Partial specialization for V220 +template +struct l1_to_bt { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::btBlockSize / sizeof(DataType); + + __aicore__ l1_to_bt(AscendC::LocalTensor biasTableTensor, + AscendC::LocalTensor biasL1Tensor, + uint32_t ntileActual) + { + AscendC::DataCopy( + biasTableTensor, biasL1Tensor, {1, static_cast(CeilDiv(ntileActual)), 0, 0}); + }; +}; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_fb_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_fb_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..c9c32d4fb8e8d08116acac91407b6b71c6641fc0 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_fb_iterator.inc @@ -0,0 +1,32 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. 
+* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_fb +///////////////////////////////////////////////////// + +// Partial specialization for V220 +template +struct l1_to_fb { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::fbBlockSize / sizeof(DataType); + + __aicore__ + l1_to_fb(AscendC::LocalTensor fbTensor, AscendC::LocalTensor l1Tensor, uint32_t ntileActual) + { + copy_cbuf_to_fbuf((__fbuf__ DataType *)fbTensor.GetPhyAddr(), + (__cbuf__ DataType *)l1Tensor.GetPhyAddr(), + 1, + CeilDiv(ntileActual), + 0, + 0); + }; +}; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_l0_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_l0_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..199629978681bfcff2bf6142c7bf1b36b5ce2d69 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_l0_iterator.inc @@ -0,0 +1,248 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. 
+*/ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_l0_a +///////////////////////////////////////////////////// + +// Partial specialization for vector +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + AscendC::LoadData(l0Tensor, + l1Tensor, + AscendC::LoadData2dParams(0, // baseIdx + kPartCeil, // repeat + kSrcStride, // srcStride + 0, // sid + kDstStride, // dstStride + IsTransPose, // transpose + 0)); // addrCalMode + }; +}; + +// Partial specialization for no transpose, not vector +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < mTileCeil / BLOCK_NUM_PER_FRACTAL; i++) { + AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], + l1Tensor[i * mSrcStride * FRACTAL_SIZE], + AscendC::LoadData2dParams(0, // baseIdx + static_cast(kPartCeil / BLOCK_SIZE), // repeat + kSrcStride, // srcStride + 0, // sid + kDstStride - 1, // dstStride + false, // transpose + 0)); // addrCalMode + } + }; +}; + +// Partial specialization for transpose, not vector +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < mTileCeil / BLOCK_SIZE; i++) { + AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], + l1Tensor[i * mSrcStride * FRACTAL_SIZE], + AscendC::LoadData2dParams(0, + static_cast(kPartCeil / BLOCK_NUM_PER_FRACTAL), + kSrcStride, + 0, + kDstStride - 1, + true, + 0)); + } + }; +}; + +template +struct l1_to_l0_a { + using HardwareParams = HardwareInfo; + // 16 * 32 + static constexpr uint32_t ROW_BLOCK_SIZE = 16; + static constexpr uint32_t COL_BLOCK_SIZE = 32 / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_a(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t mTileCeil, + uint32_t kPartCeil, + uint32_t mSrcStride, + uint32_t kSrcStride, + uint32_t mDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < mTileCeil / ROW_BLOCK_SIZE; i++) { + AscendC::LoadData(l0Tensor[i * 
ROW_BLOCK_SIZE * kPartCeil], + l1Tensor[i * FRACTAL_SIZE], + AscendC::LoadData2dParams(0, + static_cast(kPartCeil / COL_BLOCK_SIZE), + mTileCeil / ROW_BLOCK_SIZE, + 0, + 0, + false, + 0)); + } + }; +}; + +///////////////////////////////////////////////////// +// l1_to_l0_b +///////////////////////////////////////////////////// + +// Partial specialization for vector +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + AscendC::LoadData( + l0Tensor, l1Tensor, AscendC::LoadData2dParams(0, kPartCeil, kSrcStride, 0, kDstStride, IsTransPose, 0)); + }; +}; + +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + using DataType = int8_t; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < nTileCeil / BLOCK_SIZE; i++) { + AscendC::LoadDataWithTranspose(l0Tensor[i * kPartCeil * BLOCK_SIZE], + l1Tensor[i * BLOCK_SIZE * BLOCK_SIZE], + AscendC::LoadData2dTransposeParams(0, // startIndexIn + kPartCeil / BLOCK_SIZE, // repeatTimesIn + nTileCeil / BLOCK_SIZE, // srcStrideIn + 1, // dstGapIn + 0, // dstfracGapIn + 0) // addrModeIn + ); + } + }; +}; + +// Partial specialization for no transpose, not vector +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + for (uint32_t i = 0; i < kPartCeil / BLOCK_NUM_PER_FRACTAL; i++) { + AscendC::LoadData(l0Tensor[i * kDstStride * FRACTAL_SIZE], + l1Tensor[i * kSrcStride * FRACTAL_SIZE], + AscendC::LoadData2dParams(0, // baseIdx + static_cast(nTileCeil / BLOCK_SIZE), // repeat + nSrcStride, // srcStride + 0, // sid + nDstStride - 1, // dstStride + true, // transpose + 0)); // addrCalMode + } + }; +}; + +// Partial specialization for transpose, not vector +template +struct l1_to_l0_b { + using HardwareParams = HardwareInfo; + static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType); + static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType); + static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; + __aicore__ l1_to_l0_b(AscendC::LocalTensor l0Tensor, + AscendC::LocalTensor l1Tensor, + uint32_t nTileCeil, + uint32_t kPartCeil, + uint32_t nSrcStride, + uint32_t kSrcStride, + uint32_t nDstStride, + uint32_t kDstStride) + { + AscendC::LoadData( + l0Tensor, + l1Tensor, + AscendC::LoadData2dParams(0, // baseIdx + static_cast(kPartCeil * nTileCeil / FRACTAL_SIZE), // repeat + 1, // 
srcStride + 0, // sid + 0, // dstStride + false, // transpose + 0)); // addr_cal_mode_t + }; +}; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_ub_iterator.inc b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_ub_iterator.inc new file mode 100644 index 0000000000000000000000000000000000000000..75dd887f973c28fde64cfe9bf766473e63feb06e --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/iterators/l1_to_ub_iterator.inc @@ -0,0 +1,42 @@ +/* +* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ +#include "../iterator.h" + +///////////////////////////////////////////////////// +// l1_to_ub +///////////////////////////////////////////////////// +template +struct l1_to_ub { + __aicore__ l1_to_ub(AscendC::LocalTensor ubTensor, + AscendC::LocalTensor l1Tensor, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcStride, + uint16_t dstStride) + { + AscendC::DataCopy(ubTensor, l1Tensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; + +///////////////////////////////////////////////////// +// ub_to_l1 +///////////////////////////////////////////////////// +template +struct ub_to_l1 { + __aicore__ ub_to_l1(AscendC::LocalTensor l1Tensor, + AscendC::LocalTensor ubTensor, + uint16_t nBurst, + uint16_t lenBurst, + uint16_t srcStride, + uint16_t dstStride) + { + AscendC::DataCopy(l1Tensor, ubTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride)); + }; +}; \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/layout.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..71377d1f840780869f165a35843cd2b9b16dac03 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/layout.h @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef INCLUDE_LAYOUT_H +#define INCLUDE_LAYOUT_H + +enum class DataFormat { ND = 0, NZ, ZN, ZZ, NN, VECTOR }; + +#endif diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/mem.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/mem.h new file mode 100644 index 0000000000000000000000000000000000000000..87449fc70703db88793d3b0840618d5b914ae973 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/mem.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef INCLUDE_MEM_H +#define INCLUDE_MEM_H + +#include "hardware.h" +#include "impl/kernel_event.h" +#include "kernel_tensor.h" + +enum class BufferType { ASCEND_UB, ASCEND_CB, ASCEND_L0A, ASCEND_L0B, ASCEND_L0C, ASCEND_MAX }; + +template +__aicore__ constexpr AscendC::TPosition GetPosition() { + if constexpr (BufferType_ == BufferType::ASCEND_UB) { + return AscendC::TPosition::VECIN; + } else if constexpr (BufferType_ == BufferType::ASCEND_CB) { + return AscendC::TPosition::A1; + } else if constexpr (BufferType_ == BufferType::ASCEND_L0A) { + return AscendC::TPosition::A2; + } else if constexpr (BufferType_ == BufferType::ASCEND_L0B) { + return AscendC::TPosition::B2; + } else if constexpr (BufferType_ == BufferType::ASCEND_L0C) { + return AscendC::TPosition::CO1; + } + return AscendC::TPosition::GM; +} + +template +struct AsdopsBuffer { + public: + __aicore__ AsdopsBuffer() { + constexpr uint32_t bufferSize[static_cast(BufferType::ASCEND_MAX)] = { + HardwareInfo::ubSize, HardwareInfo::l1Size, HardwareInfo::l0ASize, + HardwareInfo::l0BSize, HardwareInfo::l0CSize}; +#ifdef __DAV_C220_VEC__ + tensor[static_cast(BufferType::ASCEND_UB)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_UB)]); + tensor[static_cast(BufferType::ASCEND_UB)].address_.logicPos = + static_cast(AscendC::TPosition::VECIN); +#elif __DAV_C220_CUBE__ + tensor[static_cast(BufferType::ASCEND_CB)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_CB)]); + tensor[static_cast(BufferType::ASCEND_CB)].address_.logicPos = + static_cast(AscendC::TPosition::A1); + tensor[static_cast(BufferType::ASCEND_L0A)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_L0A)]); + tensor[static_cast(BufferType::ASCEND_L0A)].address_.logicPos = + static_cast(AscendC::TPosition::A2); + tensor[static_cast(BufferType::ASCEND_L0B)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_L0B)]); + tensor[static_cast(BufferType::ASCEND_L0B)].address_.logicPos = + static_cast(AscendC::TPosition::B2); + tensor[static_cast(BufferType::ASCEND_L0C)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_L0C)]); + tensor[static_cast(BufferType::ASCEND_L0C)].address_.logicPos = + static_cast(AscendC::TPosition::CO1); +#else + tensor[static_cast(BufferType::ASCEND_UB)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_UB)]); + tensor[static_cast(BufferType::ASCEND_UB)].address_.logicPos = + static_cast(AscendC::TPosition::VECIN); + 
tensor[static_cast(BufferType::ASCEND_CB)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_CB)]); + tensor[static_cast(BufferType::ASCEND_CB)].address_.logicPos = + static_cast(AscendC::TPosition::A1); + tensor[static_cast(BufferType::ASCEND_L0A)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_L0A)]); + tensor[static_cast(BufferType::ASCEND_L0A)].address_.logicPos = + static_cast(AscendC::TPosition::A2); + tensor[static_cast(BufferType::ASCEND_L0B)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_L0B)]); + tensor[static_cast(BufferType::ASCEND_L0B)].address_.logicPos = + static_cast(AscendC::TPosition::B2); + tensor[static_cast(BufferType::ASCEND_L0C)].InitBuffer( + 0, bufferSize[static_cast(BufferType::ASCEND_L0C)]); + tensor[static_cast(BufferType::ASCEND_L0C)].address_.logicPos = + static_cast(AscendC::TPosition::CO1); +#endif + } + + template + __aicore__ AscendC::LocalTensor GetBuffer(const uint32_t offset) const { + return tensor[static_cast(BufferType_)][offset].template ReinterpretCast(); + } + + public: + AscendC::LocalTensor tensor[static_cast(BufferType::ASCEND_MAX)]; +}; +#endif diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/mma.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/mma.h new file mode 100644 index 0000000000000000000000000000000000000000..d32e494286ab078d94c1c3e70393dee81690f3db --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/mma.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef INCLUDE_MMA_H +#define INCLUDE_MMA_H + +#include "hardware.h" +#include "kernel_tensor.h" + +template +struct mmad { + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint32_t mTileActual, uint32_t nTileActual, + uint32_t kPartActual, bool initC) {} + + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint64_t biasBt, uint32_t mTileActual, uint32_t nTileActual, + uint32_t kPartActual, bool initC) {} +}; + +// Partial specialization for V220, int8_t, not_vector_A, not TransposeA +template +struct mmad { + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint32_t mTileActual, uint32_t nTileActual, + uint32_t kPartActual, bool initC) { + AscendC::Mmad(l0cTensor, l0aTensor, l0bTensor, + AscendC::MmadParams(mTileActual, nTileActual, kPartActual, 0, false, initC)); + } + + __aicore__ mmad(AscendC::LocalTensor l0cTensor, AscendC::LocalTensor l0aTensor, + AscendC::LocalTensor l0bTensor, uint64_t biasBt, uint32_t mTileActual, uint32_t nTileActual, + uint32_t kPartActual, bool initC) { + AscendC::LocalTensor biasTensor; + biasTensor.InitBuffer(biasBt, mTileActual); + biasTensor.address_.logicPos = static_cast(AscendC::TPosition::C2); + AscendC::Mmad(l0cTensor, l0aTensor, l0bTensor, biasTensor, + AscendC::MmadParams(mTileActual, nTileActual, kPartActual, 0, false, initC)); + } +}; + +#endif diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/set_fpc.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/set_fpc.h new file mode 100644 index 0000000000000000000000000000000000000000..c93d45f28903ab018902a3f447e772d437269c4c --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/set_fpc.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef INCLUDE_SET_FPC_H +#define INCLUDE_SET_FPC_H + +#include "hardware.h" +#include "kernel_tensor.h" + +///////////////////////////////////////////////////// +// SetQuantPreAddr +///////////////////////////////////////////////////// +template +struct SetQuantPreAddr { + __aicore__ SetQuantPreAddr(AscendC::LocalTensor quantPreTensor) {} +}; + +template +struct SetQuantPreAddr { + static constexpr uint32_t QUANT_PRE_ADDR_MASK = 0xffff; + static constexpr uint32_t USELESS_BIT_NUM = 7; + static constexpr uint32_t QUANT_PRE_BIT_POS_IN_FPC = 8; + + __aicore__ SetQuantPreAddr(AscendC::LocalTensor quantPreTensor) { + uint64_t quantPreAddr = static_cast((__fbuf__ uint64_t *)quantPreTensor.GetPhyAddr()); + AscendC::SetFixPipeConfigImpl(quantPreTensor); + } +}; +#endif diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/simd.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/simd.h new file mode 100644 index 0000000000000000000000000000000000000000..7d85535ca34ce4a3fcd11f1c0aa5ec7b107dcf15 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/simd.h @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef INCLUDE_SIMD_H +#define INCLUDE_SIMD_H + +#include "hardware.h" +#include "kernel_operator.h" + +///////////////////////////////////////////////////// +// vadd +///////////////////////////////////////////////////// +template +__aicore__ inline void add_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) { + AscendC::Add(dst, src0, src1, static_cast(0), repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vadds +///////////////////////////////////////////////////// +template +__aicore__ inline void adds_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, DType scalarValue, + uint8_t repeat, uint8_t dstBlockStride, uint8_t srcBlockStride, uint8_t dstRepeatStride, + uint8_t srcRepeatStride) { + AscendC::Adds( + dst, src, scalarValue, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vcadd +///////////////////////////////////////////////////// +template +__aicore__ inline void cadd_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride) { + AscendC::RepeatReduceSum(dst, src, repeat, 0, 0, srcBlockStride, dstRepeatStride, srcRepeatStride); +} +///////////////////////////////////////////////////// +// vbrcb +///////////////////////////////////////////////////// +template +__aicore__ inline void 
brcb_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint16_t dstBlockStride, + uint16_t dstRepeatStride, uint8_t repeat) { + AscendC::Brcb(dst, src, repeat, AscendC::BrcbRepeatParams(dstBlockStride, dstRepeatStride)); +} + +///////////////////////////////////////////////////// +// vcmax +///////////////////////////////////////////////////// +template +__aicore__ inline void cmax_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride) { +#if defined(__DAV_C220_VEC__) + AscendC::WholeReduceMax(dst, src, static_cast(0), repeat, dstRepeatStride, srcBlockStride, + srcRepeatStride, OrderType); +#else + AscendC::WholeReduceMax(dst, src, static_cast(0), repeat, dstRepeatStride, srcBlockStride, + srcRepeatStride); +#endif +} + +///////////////////////////////////////////////////// +// vconv +///////////////////////////////////////////////////// +template +__aicore__ inline void conv_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) { + if constexpr (std::is_same::value && std::is_same::value) { + AscendC::Cast( + dst, src, AscendC::RoundMode::CAST_RINT, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); + } else { + AscendC::Cast( + dst, src, AscendC::RoundMode::CAST_NONE, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); + } +} + +///////////////////////////////////////////////////// +// vconv_f322bf16r +///////////////////////////////////////////////////// +template +__aicore__ inline void convr_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) { + AscendC::Cast( + dst, src, AscendC::RoundMode::CAST_RINT, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vdiv +///////////////////////////////////////////////////// +template +__aicore__ inline void div_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) { + AscendC::Div(dst, src0, src1, static_cast(0), repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vexp +///////////////////////////////////////////////////// +template +__aicore__ inline void exp_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) { + AscendC::Exp( + dst, src, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vmax +///////////////////////////////////////////////////// +template +__aicore__ inline void max_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, 
uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) { + AscendC::Max(dst, src0, src1, static_cast(0), repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vmul +///////////////////////////////////////////////////// +template +__aicore__ inline void mul_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) { + AscendC::Mul(dst, src0, src1, static_cast(0), repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vmuls +///////////////////////////////////////////////////// +template +__aicore__ inline void muls_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, DType src1, + uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride, + uint16_t dstRepeatStride, uint16_t srcRepeatStride) { + AscendC::Muls( + dst, src0, src1, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vsub +///////////////////////////////////////////////////// +template +__aicore__ inline void sub_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, + AscendC::LocalTensor src1, uint8_t repeat, uint8_t dstBlockStride, + uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride, + uint8_t src0RepeatStride, uint8_t src1RepeatStride) { + AscendC::Sub(dst, src0, src1, static_cast(0), repeat, + AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride, + dstRepeatStride, src0RepeatStride, src1RepeatStride)); +} + +///////////////////////////////////////////////////// +// vmaxs +///////////////////////////////////////////////////// +template +__aicore__ inline void maxs_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, DType src1, + uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride, + uint16_t dstRepeatStride, uint16_t srcRepeatStride) { + AscendC::Maxs( + dst, src0, src1, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vmins +///////////////////////////////////////////////////// +template +__aicore__ inline void mins_v(AscendC::LocalTensor dst, AscendC::LocalTensor src0, DType src1, + uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride, + uint16_t dstRepeatStride, uint16_t srcRepeatStride) { + AscendC::Mins( + dst, src0, src1, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vsqrt +///////////////////////////////////////////////////// +template +__aicore__ inline void sqrt_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) { + AscendC::Sqrt( + dst, src, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} 
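+
+/////////////////////////////////////////////////////
+// usage note
+/////////////////////////////////////////////////////
+// Illustrative only: each *_v helper above is a thin wrapper that forwards to the
+// matching AscendC vector intrinsic with an explicit repeat count plus block/repeat
+// strides, and it assumes the caller has already programmed the lane mask. A
+// hypothetical whole-vector fp32 add over 128 contiguous elements could look like:
+//
+//   SetVectorMask((uint64_t)-1, (uint64_t)-1);   // full lane mask, as done in the kernels
+//   add_v<ArchTag, float>(dst, src0, src1,
+//                         2,                     // repeat: 2 x 64 fp32 lanes
+//                         1, 1, 1,               // block strides (contiguous within a repeat)
+//                         8, 8, 8);              // repeat strides, in 32 B blocks
+//
+// "ArchTag" and the tensor names are placeholders rather than identifiers defined in
+// this header; real call sites pass their own template arguments and stride layout.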
+ +///////////////////////////////////////////////////// +// vln +///////////////////////////////////////////////////// +template +__aicore__ inline void ln_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, uint8_t repeat, + uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride, + uint16_t srcRepeatStride) { + AscendC::Ln( + dst, src, static_cast(0), repeat, + AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride)); +} + +///////////////////////////////////////////////////// +// vtranspose +///////////////////////////////////////////////////// +template +__aicore__ inline void tranpose_v(AscendC::LocalTensor dst, AscendC::LocalTensor src) { + AscendC::Transpose(dst, src); +} + +///////////////////////////////////////////////////// +// vcgmax +///////////////////////////////////////////////////// +template +__aicore__ inline void cgmax_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, const int32_t repeat, + const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride) { + AscendC::BlockReduceMax(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride); +} + +///////////////////////////////////////////////////// +// vcgadd +///////////////////////////////////////////////////// +template +__aicore__ inline void cgadd_v(AscendC::LocalTensor dst, AscendC::LocalTensor src, const int32_t repeat, + const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride) { + AscendC::BlockReduceSum(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride); +} +#endif diff --git a/ops/ascendc/kernel_common/op_kernel/asd/common/include/utils.h b/ops/ascendc/kernel_common/op_kernel/asd/common/include/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..cceeb61ae2ce4838d257b6b5ad749f978915cde1 --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/common/include/utils.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef INCLUDE_UTILS_H +#define INCLUDE_UTILS_H + +template +__aicore__ inline void CreateCaMatrix(const AscendC::LocalTensor &dst, const uint16_t repeats, + const uint16_t blockNum, const uint16_t dstGap, const IN_DTYPE initValue) { + AscendC::InitConstValue(dst, AscendC::InitConstValueParams(repeats, blockNum, dstGap, initValue)); +} + +__aicore__ inline void SetFftsBaseAddr(uint64_t config) { AscendC::SetSyncBaseAddr(config); } + +template +__aicore__ inline void SetPadding(IN_DTYPE padValue) { + AscendC::SetLoadDataPaddingValue(padValue); +} + +__aicore__ inline void SetAtomicnone() { AscendC::SetAtomicNone(); } + +__aicore__ inline void SetMasknorm() { +#if __CCE_AICORE__ == 100 + return; +#endif + AscendC::SetMaskNorm(); +} + +__aicore__ inline void SetNdpara(uint16_t ndNum, uint16_t srcNdStride, uint16_t dstNdStride) { + AscendC::SetFixpipeNz2ndFlag(ndNum, srcNdStride, dstNdStride); +} + +template +__aicore__ inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow) { + AscendC::SetVectorMask(maskHigh, maskLow); +} + +__aicore__ inline int64_t GetSubBlockidx() { return AscendC::GetSubBlockIdx(); } + +__aicore__ inline void WaitFlagDev(uint16_t flagId) { AscendC::WaitEvent(flagId); } + +template +__aicore__ inline void FftsCrossCoreSync(uint16_t flagId) { + AscendC::CrossCoreSetFlag(flagId); +} + +template +__aicore__ inline void SetFpc(const AscendC::LocalTensor &preTensor, bool isUnitFlag = false) { + AscendC::SetFixPipeConfig(preTensor, isUnitFlag); +} + +template +__aicore__ inline void CopyCbufToFbuf(AscendC::LocalTensor &dst, AscendC::LocalTensor &src, + uint16_t burstNum, uint16_t burstLen, uint16_t srcGapSize, uint16_t dstGapSize) { + dst.address_.logicPos = static_cast(AscendC::TPosition::C2PIPE2GM); + AscendC::DataCopy(dst, src, + AscendC::DataCopyParams(burstNum, // nBurst + burstLen, // lenBurst + srcGapSize, // srcGap + dstGapSize)); // dstGap); +} + +template +__aicore__ inline void CopyCbufToBt(uint64_t dst, const AscendC::LocalTensor &src, uint16_t convControl, + uint16_t nBurst, uint16_t lenBurst, uint16_t sourceGap, uint16_t dstGap) { + AscendC::LocalTensor dstTensor; + dstTensor.InitBuffer(dst, nBurst * lenBurst); + dstTensor.address_.logicPos = static_cast(AscendC::TPosition::C2); + AscendC::DataCopy(dstTensor, src, + AscendC::DataCopyParams(nBurst, // nBurst + lenBurst, // lenBurst + sourceGap, // srcGap + dstGap)); // dstGap); +} +#endif diff --git a/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/fa_common.cce b/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/fa_common.cce new file mode 100644 index 0000000000000000000000000000000000000000..2e263b91eb504bffa5662600233b7bae637f054c --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/fa_common.cce @@ -0,0 +1,544 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "../common/include/common.h" +#include "../common/include/common_func.h" +#include "../common/include/simd.h" +#include "../common/include/iterator.h" +#include "../common/include/mma.h" +#include "../common/include/utils.h" +#include "kernel_operator.h" + +#ifdef __DAV_C220_VEC__ +constexpr int32_t ROW_OPS_SPEC_MASK_32 = 32; +constexpr int32_t ROW_OPS_SPEC_MASK_8 = 8; +constexpr int32_t ROW_OPS_SPEC_MASK_4 = 4; +constexpr int32_t REDUCE_UB_SIZE = 1024; +constexpr int32_t FLOAT_VECTOR_SIZE = 64; +constexpr int32_t VECTOR_SIZE = 128; +constexpr int32_t BLOCK_SIZE = 16; +constexpr int32_t FLOAT_BLOCK_SIZE = 8; +constexpr int32_t S_DB_SIZE = 8192; + +enum class RowCalcTile { TAIL_TILE = 0, SPEC_TILE_256, SPEC_TILE_512 }; + +enum ScaleType { SCALE_TOR = 0, SCALE_LOGN = 1, SCALE_LOGN_FP32 = 2 }; + +enum class MaskType { + MASK_TYPE_NONE = 0, + MASK_TYPE_TRIU = 1, + MASK_TYPE_ALIBI = 2, + MASK_TYPE_ALIBI_COMPRESS = 6, + MASK_TYPE_ALIBI_COMPRESS_SQRT = 7, + MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN = 8, + MASK_TYPE_ALIBI_COMPRESS_128 = 9 +}; + +__aicore__ __attribute__((always_inline)) inline void SetVecMask(int32_t len) { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + + if (len == VECTOR_SIZE || len == 0) { + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + SetVectorMask(mask, (uint64_t)-1); + } else { + SetVectorMask(0x0, mask); + } +} + +template +__aicore__ __attribute__((always_inline)) inline void SetBlockReduceMask(int32_t len); + +template +struct Rowsum { + __aicore__ + __attribute__((always_inline)) inline Rowsum(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowsum_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned); +}; + +template +struct Rowmax { + __aicore__ + __attribute__((always_inline)) inline Rowmax(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowmax_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned); +}; + +template +struct OnlineSoftmaxStage1 { + __aicore__ __attribute__((always_inline)) inline OnlineSoftmaxStage1( + const AscendC::LocalTensor &s_ub, const AscendC::LocalTensor &mask_orig_ub, + const AscendC::LocalTensor &mask_processed_ub, const AscendC::LocalTensor &local_rowmax_ub, + const AscendC::LocalTensor &hat_rowmax_ub, const AscendC::LocalTensor &global_rowmax_ub, + const AscendC::LocalTensor &diff_rowmax_ub, const AscendC::LocalTensor &s_exp_ub, + const AscendC::LocalTensor &local_rowsum_ub, const AscendC::LocalTensor &global_rowsum_ub, + const AscendC::LocalTensor &p_ub, const AscendC::LocalTensor &tmp_ub, + const AscendC::GlobalTensor &s_gm, const AscendC::GlobalTensor &p_gm, bool first_n_iter, + S_DTYPE tor, uint32_t m, uint32_t n_real, uint32_t n_stride, uint32_t pingpong_flag); +}; + +template <> +__aicore__ __attribute__((always_inline)) inline void SetBlockReduceMask(int32_t len) { + if (len > 8 || len < 1) { + SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + uint64_t subMask = ((uint64_t)1 << len) - 1; + uint64_t maskValue = (subMask << 48) + (subMask << 32) + (subMask << 16) + subMask + (subMask << 56) + + (subMask << 40) + (subMask << 24) + (subMask << 8); + SetVectorMask(maskValue, maskValue); +} + +template <> +__aicore__ __attribute__((always_inline)) inline void SetBlockReduceMask(int32_t len) { + if (len > 16 || len < 1) { + 
SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + uint64_t subMask = ((uint64_t)1 << len) - 1; + uint64_t maskValue = (subMask << 48) + (subMask << 32) + (subMask << 16) + subMask; + SetVectorMask(maskValue, maskValue); +} + +template <> +struct Rowsum { + __aicore__ + __attribute__((always_inline)) inline Rowsum(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowsum_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned) { + cgadd_v(tmp_ub, src_ub, num_rows_round * num_elems_aligned / FLOAT_VECTOR_SIZE, 1, 1, + 8); + PIPE_BARRIER(V); + cgadd_v(tmp_ub[REDUCE_UB_SIZE], tmp_ub, + num_rows_round * num_elems_aligned / FLOAT_BLOCK_SIZE / FLOAT_VECTOR_SIZE, 1, + 1, 8); + PIPE_BARRIER(V); + cgadd_v(rowsum_ub, tmp_ub[REDUCE_UB_SIZE], + num_rows_round * num_elems_aligned / FLOAT_VECTOR_SIZE / FLOAT_VECTOR_SIZE, 1, + 1, 8); + PIPE_BARRIER(V); + } +}; + +template <> +struct Rowsum { + __aicore__ + __attribute__((always_inline)) inline Rowsum(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowsum_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned) { + cgadd_v(tmp_ub, src_ub, num_rows_round * num_elems_aligned / FLOAT_VECTOR_SIZE, 1, 1, + 8); + PIPE_BARRIER(V); + SetVecMask(ROW_OPS_SPEC_MASK_32); + cgadd_v(tmp_ub[REDUCE_UB_SIZE], tmp_ub, num_rows_round, 1, 1, 4); + PIPE_BARRIER(V); + SetBlockReduceMask(ROW_OPS_SPEC_MASK_4); + cgadd_v( + rowsum_ub, tmp_ub[REDUCE_UB_SIZE], + (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } +}; + +template <> +struct Rowsum { + __aicore__ + __attribute__((always_inline)) inline Rowsum(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowsum_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned) { + if (num_elems >= FLOAT_VECTOR_SIZE) { + cgadd_v(tmp_ub, src_ub, num_rows_round, 1, 1, num_elems_aligned / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + cgadd_v( + rowsum_ub, tmp_ub, (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + for (uint64_t rowsum_idx = 1; rowsum_idx < (uint64_t)num_elems / FLOAT_VECTOR_SIZE; ++rowsum_idx) { + cgadd_v(tmp_ub, src_ub[rowsum_idx * FLOAT_VECTOR_SIZE], num_rows_round, 1, 1, + num_elems_aligned / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + cgadd_v( + tmp_ub[REDUCE_UB_SIZE], tmp_ub, + (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + SetVecMask(num_rows_round); + add_v(rowsum_ub, rowsum_ub, tmp_ub[REDUCE_UB_SIZE], + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + if (num_elems % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(num_elems % FLOAT_VECTOR_SIZE); + cgadd_v(tmp_ub, src_ub[num_elems / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + num_rows_round, 1, 1, num_elems_aligned / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + SetBlockReduceMask((num_elems % FLOAT_VECTOR_SIZE + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + if (num_elems < FLOAT_VECTOR_SIZE) { + cgadd_v( + rowsum_ub, tmp_ub, (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + } else { + cgadd_v( 
+ tmp_ub[REDUCE_UB_SIZE], tmp_ub, + (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + SetVecMask(num_rows_round); + add_v(rowsum_ub, rowsum_ub, tmp_ub[REDUCE_UB_SIZE], + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + } + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } +}; + +template <> +struct Rowmax { + __aicore__ + __attribute__((always_inline)) inline Rowmax(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowmax_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned) { + cgmax_v(tmp_ub, src_ub, num_rows_round * num_elems_aligned / FLOAT_VECTOR_SIZE, 1, 1, + 8); + PIPE_BARRIER(V); + cgmax_v(tmp_ub[REDUCE_UB_SIZE], tmp_ub, + num_rows_round * num_elems_aligned / FLOAT_BLOCK_SIZE / FLOAT_VECTOR_SIZE, 1, + 1, 8); + PIPE_BARRIER(V); + cgmax_v(rowmax_ub, tmp_ub[REDUCE_UB_SIZE], + num_rows_round * num_elems_aligned / FLOAT_VECTOR_SIZE / FLOAT_VECTOR_SIZE, 1, + 1, 8); + PIPE_BARRIER(V); + } +}; + +template <> +struct Rowmax { + __aicore__ + __attribute__((always_inline)) inline Rowmax(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowmax_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned) { + cgmax_v(tmp_ub, src_ub, num_rows_round * num_elems_aligned / FLOAT_VECTOR_SIZE, 1, 1, + 8); + PIPE_BARRIER(V); + SetVecMask(ROW_OPS_SPEC_MASK_32); + cgmax_v(tmp_ub[REDUCE_UB_SIZE], tmp_ub, num_rows_round, 1, 1, 4); + PIPE_BARRIER(V); + SetBlockReduceMask(ROW_OPS_SPEC_MASK_4); + cgmax_v( + rowmax_ub, tmp_ub[REDUCE_UB_SIZE], + (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } +}; + +template <> +struct Rowmax { + __aicore__ + __attribute__((always_inline)) inline Rowmax(const AscendC::LocalTensor &src_ub, + const AscendC::LocalTensor &rowmax_ub, + const AscendC::LocalTensor &tmp_ub, uint32_t num_rows_round, + uint32_t num_elems, uint32_t num_elems_aligned) { + if (num_elems >= FLOAT_VECTOR_SIZE) { + cgmax_v(tmp_ub, src_ub, num_rows_round, 1, 1, num_elems_aligned / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + cgmax_v( + rowmax_ub, tmp_ub, (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + for (uint64_t rowmax_idx = 1; rowmax_idx < (uint64_t)num_elems / FLOAT_VECTOR_SIZE; ++rowmax_idx) { + cgmax_v(tmp_ub, src_ub[rowmax_idx * FLOAT_VECTOR_SIZE], num_rows_round, 1, 1, + num_elems_aligned / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + cgmax_v( + tmp_ub[REDUCE_UB_SIZE], tmp_ub, + (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + SetVecMask(num_rows_round); + max_v(rowmax_ub, rowmax_ub, tmp_ub[REDUCE_UB_SIZE], + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + if (num_elems % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(num_elems % FLOAT_VECTOR_SIZE); + cgmax_v(tmp_ub, src_ub[num_elems / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + num_rows_round, 1, 1, num_elems_aligned / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + SetBlockReduceMask((num_elems % FLOAT_VECTOR_SIZE + 
FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + if (num_elems < FLOAT_VECTOR_SIZE) { + cgmax_v( + rowmax_ub, tmp_ub, (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + } else { + cgmax_v( + tmp_ub[REDUCE_UB_SIZE], tmp_ub, + (num_rows_round * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + SetVecMask(num_rows_round); + max_v(rowmax_ub, rowmax_ub, tmp_ub[REDUCE_UB_SIZE], + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + } + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } +}; + +template +struct OnlineSoftmaxStage1 { + __aicore__ __attribute__((always_inline)) inline OnlineSoftmaxStage1( + const AscendC::LocalTensor &s_ub, const AscendC::LocalTensor &mask_orig_ub, + const AscendC::LocalTensor &mask_processed_ub, const AscendC::LocalTensor &local_rowmax_ub, + const AscendC::LocalTensor &hat_rowmax_ub, const AscendC::LocalTensor &global_rowmax_ub, + const AscendC::LocalTensor &diff_rowmax_ub, const AscendC::LocalTensor &s_exp_ub, + const AscendC::LocalTensor &local_rowsum_ub, const AscendC::LocalTensor &global_rowsum_ub, + const AscendC::LocalTensor &p_ub, const AscendC::LocalTensor &tmp_ub, + const AscendC::GlobalTensor &s_gm, const AscendC::GlobalTensor &p_gm, bool first_n_iter, float tor, + uint32_t m, uint32_t n_real, uint32_t n_stride, uint32_t pingpong_flag) { + uint32_t round_m = (m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE; + WAIT_FLAG(MTE3, MTE2, pingpong_flag); + // input QK + gm_to_ub(s_ub, s_gm, + 0, // sid + m, // nBurst + n_stride / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, pingpong_flag); + WAIT_FLAG(MTE2, V, pingpong_flag); + // *** ls = tor * ls + muls_v(s_ub, s_ub, tor, + (m * n_stride + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + if (n_real == 512) { + Rowmax(s_ub, local_rowmax_ub, tmp_ub, round_m, n_real, n_stride); + } else if (n_real == 256) { + Rowmax(s_ub, local_rowmax_ub, tmp_ub, round_m, n_real, n_stride); + } else { + Rowmax(s_ub, local_rowmax_ub, tmp_ub, round_m, n_real, n_stride); + } + + if (first_n_iter) { + // *** hm = lm + ub_to_ub(hat_rowmax_ub, local_rowmax_ub, + 0, // sid + 1, // nBurst + round_m / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } else { + SetVecMask(m); + // *** hm = vmax(lm, gm) + max_v(hat_rowmax_ub, local_rowmax_ub, global_rowmax_ub, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** dm = gm - hm + sub_v(diff_rowmax_ub, global_rowmax_ub, hat_rowmax_ub, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** dm = exp(dm) + exp_v(diff_rowmax_ub, diff_rowmax_ub, + 1, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + SetVectorMask((uint64_t)-1, (uint64_t)-1); + PIPE_BARRIER(V); + // *** gm = hm + ub_to_ub(global_rowmax_ub, hat_rowmax_ub, + 0, // sid + 1, // nBurst + round_m / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + 
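+    // Online-softmax bookkeeping so far, per row (descriptive summary):
+    //   hat_m = first block ? local_max : max(local_max, global_max)
+    //   dm    = exp(global_max_old - hat_m)   // rescale factor, skipped on the first block
+    //   gm    = hat_m                         // copied by the ub_to_ub just above
+    // The steps below broadcast hat_m, compute exp(score - hat_m), take the row sum,
+    // and finally fold the running sum as gl = dm * gl + ll.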
PIPE_BARRIER(V); + // *** hm_block = expand_to_block(hm), 存放于 tv + brcb_v(tmp_ub.template ReinterpretCast(), + hat_rowmax_ub.template ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** ls = ls - hm_block + for (uint32_t vsub_idx = 0; vsub_idx < n_real / FLOAT_VECTOR_SIZE; ++vsub_idx) { + sub_v(s_ub[vsub_idx * FLOAT_VECTOR_SIZE], s_ub[vsub_idx * FLOAT_VECTOR_SIZE], + tmp_ub, + m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + n_stride / FLOAT_BLOCK_SIZE, // dstRepeatStride + n_stride / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (n_real % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(n_real % FLOAT_VECTOR_SIZE); + sub_v(s_ub[n_real / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + s_ub[n_real / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], tmp_ub, + m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + n_stride / FLOAT_BLOCK_SIZE, // dstRepeatStride + n_stride / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + + // *** ls = exp(ls) + exp_v(s_exp_ub, s_ub, + (m * n_stride + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + // *** ll = rowsum(ls32) + if (n_real == 512) { + Rowsum(s_exp_ub, local_rowsum_ub, tmp_ub, round_m, n_real, n_stride); + } else if (n_real == 256) { + Rowsum(s_exp_ub, local_rowsum_ub, tmp_ub, round_m, n_real, n_stride); + } else { + Rowsum(s_exp_ub, local_rowsum_ub, tmp_ub, round_m, n_real, n_stride); + } + + // *** lp = castfp32to16(ls) + conv_v(p_ub, s_exp_ub, + (m * n_stride + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + SET_FLAG(V, MTE3, pingpong_flag); + WAIT_FLAG(V, MTE3, pingpong_flag); + ub_to_gm(p_gm, p_ub, + 0, // sid + m, // nBurst + n_stride * 2 / BlockSize(), // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE3, MTE2, pingpong_flag); + if (first_n_iter) { + // *** gl = ll + ub_to_ub(global_rowsum_ub, local_rowsum_ub, + 0, // sid + 1, // nBurst + round_m / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } else { + SetVecMask(m); + // *** gl = dm * gl + mul_v(global_rowsum_ub, diff_rowmax_ub, global_rowsum_ub, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** gl = ll + gl + add_v(global_rowsum_ub, global_rowsum_ub, local_rowsum_ub, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } +}; + +#endif diff --git a/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/unpad_flash_attention_mix.cce b/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/unpad_flash_attention_mix.cce new file mode 100644 index 0000000000000000000000000000000000000000..52e05529fda461ac936f9785d776e2c4a87c505d --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/unpad_flash_attention_mix.cce @@ -0,0 +1,2335 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "../common/include/common.h" +#include "../common/include/common_func.h" +#include "../common/include/simd.h" +#include "../common/include/iterator.h" +#include "../common/include/mma.h" +#include "../common/include/utils.h" +#include "kernel_operator.h" + +namespace unpda_fa_npd_half { +#ifdef __CCE_KT_TEST__ +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +enum class ScaleType { SCALE_TOR = 0, SCALE_LOGN = 1, SCALE_LOGN_FP32 = 2 }; +// FFTS Flag +constexpr int32_t QK_READY = 1; +constexpr int32_t SOFTMAX_READY = 2; +constexpr int32_t UPDATE_READY = 3; +constexpr int32_t BIT_SHIFT = 8; +constexpr int32_t SOFTMAX_MAX_LENGTH = 256; +constexpr int32_t NO_STACK_S_BLOCK_LIMIT = 4; + +#ifdef __DAV_C220_CUBE__ +constexpr int32_t L0AB_HALF_BUF_SIZE = 16384; // 128 * 128 +constexpr int32_t BLOCK_SIZE = 16; +constexpr int32_t CUBE_MATRIX_SIZE = 256; // 16 * 16 +constexpr int32_t L0AB_UINT8_BLOCK_SIZE = 32768; // 128 * 128 * 2B +constexpr int32_t TMP_SIZE = 32768 * 4; // 128 * 256 * 2 + +template +class UnpadAttentionDecoderAic { + public: + template + __aicore__ __attribute__((always_inline)) inline uint32_t BlockSize() { + return 32 / sizeof(T); + } + + template + __aicore__ __attribute__((always_inline)) inline uint32_t MatrixSize() { + return 512 / sizeof(T); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t BlockSizeRoundUp(uint64_t num) { + return (num + BlockSize() - 1) / BlockSize() * BlockSize(); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t NumBlocksRoundUp(uint64_t num) { + return (num + BlockSize() - 1) / BlockSize(); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t MatrixSizeRoundUp(uint64_t num) { + return (num + MatrixSize() - 1) / MatrixSize() * MatrixSize(); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t NumMatrixsRoundUp(uint64_t num) { + return (num + MatrixSize() - 1) / MatrixSize(); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t L0HalfSize() { + return 32 * 1024 / sizeof(T); + } + + __aicore__ __attribute__((always_inline)) inline void Run( + __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ q_gm, __gm__ uint8_t *__restrict__ k_gm, + __gm__ uint8_t *__restrict__ v_gm, __gm__ uint8_t *__restrict__ layerID_gm, __gm__ uint8_t *__restrict__ mask_gm, + __gm__ uint8_t *__restrict__ alibi_coeff_gm, __gm__ uint8_t *__restrict__ deq_qk_gm, + __gm__ uint8_t *__restrict__ off_qk_gm, __gm__ uint8_t *__restrict__ deq_pv_gm, + __gm__ uint8_t *__restrict__ off_pv_gm, __gm__ uint8_t *__restrict__ quant_p_gm, __gm__ uint8_t *__restrict__ o_gm, + __gm__ uint8_t *__restrict__ s_gm, __gm__ uint8_t *__restrict__ p_gm, __gm__ uint8_t *__restrict__ o_tmp_gm, + __gm__ uint8_t *__restrict__ upo_tmp_gm, __gm__ uint8_t *__restrict__ tiling_para_gm) { + SetFftsBaseAddr((unsigned long)sync); + SetPadding(0); + SetAtomicnone(); + SetNdpara(1, 0, 0); + SetMasknorm(); + + const uint32_t l1q_buf_addr_offset = 0; + const uint32_t 
l1k_buf_addr_offset = 2 * L0AB_UINT8_BLOCK_SIZE; + const uint32_t l1p_buf_addr_offset = 4 * L0AB_UINT8_BLOCK_SIZE; + const uint32_t l1v_buf_addr_offset = 6 * L0AB_UINT8_BLOCK_SIZE; + + AsdopsBuffer buf; + + AscendC::LocalTensor l1q_buf_addr_tensor = + buf.GetBuffer(l1q_buf_addr_offset); + AscendC::LocalTensor l1k_buf_addr_tensor = + buf.GetBuffer(l1k_buf_addr_offset); + AscendC::LocalTensor l1p_buf_addr_tensor = + buf.GetBuffer(l1p_buf_addr_offset); + AscendC::LocalTensor l1v_buf_addr_tensor = + buf.GetBuffer(l1v_buf_addr_offset); + + AscendC::LocalTensor l0a_buf_tensor = buf.GetBuffer(0); + AscendC::LocalTensor l0b_buf_tensor = buf.GetBuffer(0); + AscendC::LocalTensor l0c_buf_tensor = buf.GetBuffer(0); + + uint32_t batch_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm)); + uint32_t max_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1)); + uint32_t q_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2)); + uint32_t embd = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3)); + uint32_t kv_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 4)); + uint32_t is_triu_mask = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8)); + uint32_t total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9)); + uint32_t tiling_head_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 14)); + uint32_t tiling_para_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 15)); + uint32_t max_kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 18)); + uint32_t quantType = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 24)); + uint32_t data_shape_type = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 25)); + uint32_t window_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 27)); + uint32_t npd = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 37)); + npd = (npd == 2); + uint32_t page_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 38)); + + uint32_t group_num = q_heads / kv_heads; + uint64_t stride_qo = q_heads * embd; + uint64_t stride_kv = kv_heads * embd; + if (data_shape_type == 1) { + stride_qo = embd; + stride_kv = embd; + } + if (npd) { + stride_kv = kv_heads * embd * page_size; + } + + uint64_t batch_stride_kv = batch_size * max_kv_seqlen * kv_heads * embd * sizeof(mm1InputType); + if (layerID_gm != nullptr) { + uint32_t layer_id = *(__gm__ uint32_t *)layerID_gm; + k_gm = k_gm + layer_id * batch_stride_kv; + v_gm = v_gm + layer_id * batch_stride_kv; + } + + AscendC::GlobalTensor q_gm_tensor; + AscendC::GlobalTensor k_gm_tensor; + AscendC::GlobalTensor v_gm_tensor; + + AscendC::GlobalTensor s_gm_tensor; + AscendC::GlobalTensor p_gm_tensor; + AscendC::GlobalTensor o_tmp_gm_tensor; + AscendC::GlobalTensor o_gm_tensor; + AscendC::GlobalTensor deq_qk_gm_tensor; + + q_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm1InputType *>(q_gm)); + k_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm1InputType *>(k_gm)); + v_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm2InputType *>(v_gm)); + + s_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(s_gm)); + p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm2InputType *>(p_gm)); + o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mmOutputType *>(o_tmp_gm)); + o_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(o_gm)); + deq_qk_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(deq_qk_gm)); + + uint32_t __k = embd; + uint32_t round_k = (__k + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + 
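+    // (Flag priming continues below: each SET_FLAG pre-arms a cross-pipe event so the
+    //  corresponding WAIT_FLAG inside the processing loop does not block on the first
+    //  iteration; the loop re-sets each flag after consuming it.)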
SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(M, MTE1, EVENT_ID2); + SET_FLAG(M, MTE1, EVENT_ID3); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, M, EVENT_ID1); + + uint64_t cur_batch = 0; + uint32_t pre_total_q_blk_num = 0; + uint32_t offset_tiling = tiling_head_size + tiling_para_size * cur_batch; + uint32_t cur_total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 13 + offset_tiling)); + uint32_t process_num = total_q_blk_num * q_heads; + uint32_t next_process = 0; + for (uint32_t process = block_idx; process < process_num; process = next_process) { + while (process >= cur_total_q_blk_num * q_heads) { + cur_batch++; + pre_total_q_blk_num = cur_total_q_blk_num; + offset_tiling += tiling_para_size; + cur_total_q_blk_num = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 13 + offset_tiling)); + } + next_process = process + block_num; + if (is_triu_mask) { + uint32_t curr_iter = process / block_num; + next_process = + curr_iter % 2 == 1 ? (curr_iter + 1) * block_num + block_idx : (curr_iter + 2) * block_num - 1 - block_idx; + } + + // get tiling args + uint32_t q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + offset_tiling)); + uint32_t kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1 + offset_tiling)); + if (q_seqlen == 0 || kv_seqlen == 0) { + continue; + } + uint32_t pp_m_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2 + offset_tiling)); + uint32_t pp_n_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3 + offset_tiling)); + uint32_t addr_q_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 4 + offset_tiling)); + uint32_t addr_q_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 5 + offset_tiling)); + uint64_t addr_q_scalar = (uint64_t)(((uint64_t)addr_q_high32) << 32 | addr_q_loww32); + uint32_t addr_k_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 6 + offset_tiling)); + uint32_t addr_k_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 7 + offset_tiling)); + uint64_t addr_k_scalar = (uint64_t)(((uint64_t)addr_k_high32) << 32 | addr_k_loww32); + uint32_t addr_v_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8 + offset_tiling)); + uint32_t addr_v_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9 + offset_tiling)); + uint64_t addr_v_scalar = (uint64_t)(((uint64_t)addr_v_high32) << 32 | addr_v_loww32); + + uint32_t process_idx = process - pre_total_q_blk_num * q_heads; + uint32_t m_idx = process_idx / q_heads; + uint64_t head_idx = process_idx % q_heads; + + uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; + uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; + + uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; + uint32_t qk_round_m = (qk_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + /**************** pre_load *****************/ + uint32_t qk_n = n_loop == 1 ? 
kv_seqlen : pp_n_scalar; + uint32_t qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + uint32_t pingpong_flag = 0; + uint32_t offset = pingpong_flag * L0AB_HALF_BUF_SIZE; + + uint64_t q_offset = addr_q_scalar + head_idx * embd + m_idx * pp_m_scalar * stride_qo; + uint64_t k_base_offset = addr_k_scalar + (head_idx / group_num) * embd; + if (data_shape_type == 1) { + q_offset = addr_q_scalar + head_idx * embd * max_seqlen + m_idx * pp_m_scalar * stride_qo; + k_base_offset = addr_k_scalar + (head_idx / group_num) * embd * max_kv_seqlen; + } + if (npd) { + k_base_offset = addr_k_scalar + (head_idx / group_num) * embd * page_size; + } + uint64_t k_offset = 0; + // Only need load Q once + if (qk_m == 1) { + gm_to_l1( + l1q_buf_addr_tensor, q_gm_tensor[q_offset], 1, 0, 0, RoundUp(round_k, 32 / sizeof(mm1InputType)), 0, + 0); + } else { + gm_to_l1(l1q_buf_addr_tensor, + q_gm_tensor[q_offset], + qk_m, // nValue + qk_round_m, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_qo // srcDValue + ); + } + SET_FLAG(MTE2, MTE1, pingpong_flag); + WAIT_FLAG(MTE2, MTE1, pingpong_flag); + + uint32_t sv_n = n_loop == 1 ? kv_seqlen : pp_n_scalar; + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + uint64_t v_base_offset = addr_v_scalar + (head_idx / group_num) * embd; + if (data_shape_type == 1) { + v_base_offset = addr_v_scalar + (head_idx / group_num) * embd * max_kv_seqlen; + } + if (npd) { + v_base_offset = addr_v_scalar + (head_idx / group_num) * embd * page_size; + } + uint64_t v_offset = 0; + uint32_t n_end = n_loop; + uint32_t window_start = (window_size + pp_n_scalar - 1) / pp_n_scalar; + uint32_t n_start = 0; + if ((is_triu_mask && pp_m_scalar >= pp_n_scalar) || window_size > 0) { + uint32_t n_offset = ((m_idx + 1) * pp_m_scalar + kv_seqlen - q_seqlen + pp_n_scalar - 1) / pp_n_scalar; + n_end = n_offset > n_end ? n_end : n_offset; + } + uint32_t k_token_id = 0; + uint32_t v_token_id = 0; + if constexpr (swa_flag) { + if (window_size > 0 && window_size < kv_seqlen) { + n_start = (m_idx < window_start) ? 0 : m_idx - window_start; + k_offset += n_start * stride_kv * pp_n_scalar; + v_offset += n_start * stride_kv * pp_n_scalar; + k_token_id = (n_start * pp_n_scalar); + v_token_id = (n_start * pp_n_scalar); + } + } + uint32_t s_block_stack = n_end > NO_STACK_S_BLOCK_LIMIT ? 
2 : 1; // Currently not splitting K + uint32_t launch_delay = s_block_stack * 2; + uint32_t vect_mod = 2 * launch_delay; + for (uint32_t n_idx = n_start; n_idx < n_end + launch_delay; n_idx += s_block_stack) { + if (n_idx < n_end) { + for (uint32_t split_idx = 0; split_idx < s_block_stack && n_idx + split_idx < n_end; split_idx++) { + pingpong_flag = (n_idx + split_idx - n_start) % 2; + offset = pingpong_flag * L0AB_HALF_BUF_SIZE; + if (n_idx + split_idx == (n_loop - 1)) { + qk_n = (kv_seqlen - (n_idx + split_idx) * pp_n_scalar); + qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + } + bool last_split = split_idx == s_block_stack - 1 || n_idx + split_idx == n_end - 1; + WAIT_FLAG(M, MTE1, pingpong_flag); + uint32_t round_row_1 = RoundUp(round_k, 32 / sizeof(mm1InputType)); + if (qk_m == 1) { + l1_to_l0_a( + l0a_buf_tensor[offset], l1q_buf_addr_tensor, 0, + NumMatrixsRoundUp(round_row_1), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + } else { + l1_to_l0_a( + l0a_buf_tensor[offset], l1q_buf_addr_tensor, qk_round_m, round_row_1, 0, 0, 0, 0); + } + // *** Prepare K to L1 + SET_FLAG(MTE1, M, pingpong_flag); + WAIT_FLAG(MTE1, MTE2, pingpong_flag); + + if (npd) { + k_offset = (k_token_id / page_size) * stride_kv + (k_token_id % page_size) * embd; + CopyGmToL1Npd(l1k_buf_addr_tensor[offset], + k_gm_tensor[k_base_offset + k_offset], page_size, qk_n, + qk_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1k_buf_addr_tensor[offset], k_gm_tensor[k_base_offset + k_offset], + qk_n, // nValue + qk_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + k_offset += pp_n_scalar * stride_kv; + } + k_token_id += qk_n; + SET_FLAG(MTE2, MTE1, pingpong_flag); + WAIT_FLAG(M, MTE1, pingpong_flag + 2); + WAIT_FLAG(MTE2, MTE1, pingpong_flag); + l1_to_l0_b( + l0b_buf_tensor[offset], l1k_buf_addr_tensor[offset], 0, + NumMatrixsRoundUp(RoundUp(round_k, 32 / sizeof(mm1InputType)) * + qk_round_n), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + SET_FLAG(MTE1, MTE2, pingpong_flag); + SET_FLAG(MTE1, M, pingpong_flag + 2); + WAIT_FLAG(MTE1, M, pingpong_flag); + WAIT_FLAG(MTE1, M, pingpong_flag + 2); + if (split_idx == 0) { + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, M, EVENT_ID1); + } + mmad( + l0c_buf_tensor[split_idx * qk_round_m * pp_n_scalar], l0a_buf_tensor[offset], l0b_buf_tensor[offset], + qk_m, // m + qk_n, // n + __k, // k + 1 // cmatrixInitVal + ); + SET_FLAG(M, MTE1, pingpong_flag); + SET_FLAG(M, MTE1, pingpong_flag + 2); + } + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + uint32_t sv_n_triu = n_end * pp_n_scalar; + if (n_idx + s_block_stack > n_end - 1) { + sv_n = sv_n_triu > kv_seqlen ? 
kv_seqlen - n_idx * pp_n_scalar : sv_n_triu - n_idx * pp_n_scalar; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + if constexpr (int8_flag) { + float tmp = deq_qk_gm_tensor.GetValue(head_idx); + uint64_t deqScalar = static_cast(*reinterpret_cast(&tmp)); + AscendC::SetFixpipeNz2ndFlag(1, 0, 0); + AscendC::DataCopyCO12DstParams intriParams(sv_round_n, qk_m, sv_round_n, qk_round_m, QuantMode_t::DEQF16, 0, + false, true); + AscendC::SetFixpipePreQuantFlag(deqScalar); + AscendC::PipeBarrier(); + AscendC::DataCopy(s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod], + l0c_buf_tensor, intriParams); + } else { + l0c_to_gm( + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod], + l0c_buf_tensor, + qk_m, // MSize + sv_round_n, // NSize + qk_round_m, // srcStride + sv_round_n // dstStride_dst_D + ); + } + + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, M, EVENT_ID1); + FftsCrossCoreSync(QK_READY); + } + if (n_idx >= launch_delay + n_start) { + uint32_t l0c_pingpong_flag = (n_idx - n_start) % 2; + uint32_t l0c_offset = l0c_pingpong_flag * L0AB_HALF_BUF_SIZE; + uint32_t sv_n_triu = n_end * pp_n_scalar; + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + sv_n = sv_n_triu > kv_seqlen ? kv_seqlen - (n_idx - launch_delay) * pp_n_scalar + : sv_n_triu - (n_idx - launch_delay) * pp_n_scalar; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + if constexpr (int8_flag) { + gm_to_l1( + l1v_buf_addr_tensor, v_gm_tensor[v_base_offset + v_offset], + sv_n, // nValue + RoundUp(sv_round_n, 32), // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + v_offset += sv_n * stride_kv; + } else { + if (npd) { + v_offset = ((v_token_id) / page_size) * stride_kv + (v_token_id % page_size) * embd; + CopyGmToL1Npd(l1v_buf_addr_tensor, + v_gm_tensor[v_base_offset + v_offset], page_size, sv_n, + sv_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1v_buf_addr_tensor, v_gm_tensor[v_base_offset + v_offset], + sv_n, // nValue + sv_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + v_offset += sv_n * stride_kv; + } + } + v_token_id += sv_n; + SET_FLAG(MTE2, MTE1, EVENT_ID2); + WAIT_FLAG(MTE2, MTE1, EVENT_ID2); + WAIT_FLAG(M, MTE1, EVENT_ID2); + WAIT_FLAG(M, MTE1, EVENT_ID3); + if constexpr (int8_flag) { + for (uint32_t l0b_load_idx = 0; l0b_load_idx < (sv_round_n + 31) / 32 * 32 / BlockSize(); + ++l0b_load_idx) { + AscendC::LoadDataWithTranspose( + l0b_buf_tensor[l0b_load_idx * round_k * BlockSize()], + l1v_buf_addr_tensor[l0b_load_idx * BlockSize() * BlockSize()], + AscendC::LoadData2dTransposeParams(0, // startIndexIn + (round_k + 31) / 32 * 32 / BlockSize(), // repeatTimesIn + (sv_round_n + 31) / 32 * 32 / BlockSize(), // srcStrideIn + 1, // dstGapIn + 0, // dstfracGapIn + 0) // addrModeIn + ); + } + } else { + for (uint32_t l0b_load_idx = 0; l0b_load_idx < sv_round_n / BLOCK_SIZE; ++l0b_load_idx) { + l1_to_l0_b( + l0b_buf_tensor[l0b_load_idx * round_k * BLOCK_SIZE], + l1v_buf_addr_tensor[l0b_load_idx * CUBE_MATRIX_SIZE], 0, + round_k / BLOCK_SIZE, // repeat + 0, + sv_round_n / BLOCK_SIZE, // srcStride + 0, + 0 // dstStride + ); + } + } + + SET_FLAG(MTE1, M, EVENT_ID6); + SET_FLAG(MTE1, MTE2, 
EVENT_ID2); + WaitFlagDev(SOFTMAX_READY); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + if (qk_m == 1) { + gm_to_l1( + l1p_buf_addr_tensor, + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod) * + 2 / sizeof(mm2InputType)], + 1, 0, 0, RoundUp(sv_round_n, BlockSize()), // lenBurst + 0, 0); + } else { + gm_to_l1( + l1p_buf_addr_tensor, + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod) * + 2 / sizeof(mm2InputType)], + qk_m, // nValue + qk_round_m, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + sv_n, // dValue + 0, // dstNzMatrixStride, unused + sv_round_n * 2 / sizeof(mm2InputType) // srcDValue + ); + } + SET_FLAG(MTE2, MTE1, EVENT_ID3); + WAIT_FLAG(MTE2, MTE1, EVENT_ID3); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + uint32_t round_row = + RoundUp(RoundUp(sv_round_n, BlockSize()), 32 / sizeof(mm2InputType)); + if (qk_m == 1) { + l1_to_l0_a( + l0a_buf_tensor, l1p_buf_addr_tensor, 0, + NumMatrixsRoundUp(round_row), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + } else { + l1_to_l0_a( + l0a_buf_tensor, l1p_buf_addr_tensor, qk_round_m, + round_row, // repeat + 0, + 0, // srcStride + 0, + 0 // dstStride + ); + } + SET_FLAG(MTE1, M, EVENT_ID5); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(MTE1, M, EVENT_ID5); + WAIT_FLAG(MTE1, M, EVENT_ID6); + WAIT_FLAG(FIX, M, l0c_pingpong_flag); + mmad(l0c_buf_tensor[l0c_offset], + l0a_buf_tensor, l0b_buf_tensor, + qk_m, // m + __k, // n + sv_n, // k + 1 // cmatrixInitVal + ); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(M, MTE1, EVENT_ID2); + SET_FLAG(M, MTE1, EVENT_ID3); + SET_FLAG(M, FIX, l0c_pingpong_flag); + WAIT_FLAG(M, FIX, l0c_pingpong_flag); + // copy O to gm + l0c_to_gm( + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod], + l0c_buf_tensor[l0c_offset], + qk_m, // MSize + round_k, // NSize + qk_round_m, // srcStride + round_k // dstStride_dst_D + ); + SET_FLAG(FIX, M, l0c_pingpong_flag); + FftsCrossCoreSync(UPDATE_READY); + } + } + } + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(M, MTE1, EVENT_ID2); + WAIT_FLAG(M, MTE1, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, M, EVENT_ID1); + + PIPE_BARRIER(ALL); + } +}; +#elif __DAV_C220_VEC__ +constexpr int32_t BLOCK_SIZE = 16; +constexpr int32_t FLOAT_BLOCK_SIZE = 8; +constexpr int32_t LONG_SEQ_MASK_LEN = 128; +constexpr int32_t VECTOR_SIZE = 128; +constexpr int32_t COMPRESS_MASK_SIZE = 8192; +constexpr int32_t FLOAT_VECTOR_SIZE = 64; +constexpr int32_t UB_UINT8_BLOCK_SIZE = 16384; // 64 * 128 * 2B +constexpr int32_t UB_HALF_BUF_SIZE = 8192; // 64 * 128 +constexpr int32_t UB_UINT8_LINE_SIZE = 512; // 128 * 4B +constexpr int32_t UB_FLOAT_LINE_SIZE = 64; // 128 +constexpr int32_t UB_HALF_LINE_SIZE = 128; // UB_FLOAT_LINE_SIZE * 2 +constexpr int32_t TMP_SIZE = 32768 * 4; // 128 * 256 +constexpr int32_t TOTAL_UB_SIZE = 192 * 1024; +constexpr int32_t ROWMAX_TEMP_BUF_OFFSET = 1024; + +template +class UnpadAttentionDecoderAiv { + public: + template + __aicore__ __attribute__((always_inline)) inline uint32_t BlockSize() { + return 32 / sizeof(T); + } + + template + __aicore__ __attribute__((always_inline)) inline uint32_t MatrixSize() { + return 512 / sizeof(T); + } + + 
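+  // Sizing helpers mirror the cube-side class above: a "block" is 32 B and a fractal
+  // "matrix" is 512 B, so BlockSize<half>() = 16 and MatrixSize<half>() = 256 elements,
+  // while BlockSize<float>() = 8 and MatrixSize<float>() = 128.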
template + __aicore__ __attribute__((always_inline)) inline uint64_t BlockSizeRoundUp(uint64_t num) { + return (num + BlockSize() - 1) / BlockSize() * BlockSize(); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t NumBlocksRoundUp(uint64_t num) { + return (num + BlockSize() - 1) / BlockSize(); + } + + __aicore__ __attribute__((always_inline)) inline void __set_mask(int32_t len) { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + + if (len == VECTOR_SIZE || len == 0) { + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + SetVectorMask(mask, (uint64_t)-1); + } else { + SetVectorMask(0x0, mask); + } + } + + __aicore__ __attribute__((always_inline)) inline void __set_vcg_mask(int32_t len) { + if (len > 16 || len < 1) { + SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + uint64_t subMask = ((uint64_t)1 << len) - 1; + uint64_t maskValue = (subMask << 48) + (subMask << 32) + (subMask << 16) + subMask; + SetVectorMask(maskValue, maskValue); + } + + template + __aicore__ __attribute__((always_inline)) inline uint32_t VectorSize() { + return 256 / sizeof(T); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t NumVectorsRoundUp(uint64_t num) { + return (num + VectorSize() - 1) / VectorSize(); + } + + template + __aicore__ inline void RowMaxRepeatM(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src, + const AscendC::LocalTensor &tempTensor, const uint32_t &sub_m, + const uint32_t &qk_n, const uint32_t &qk_round_n) { + uint32_t T_BLOCK_SIZE = 32 / sizeof(T); + uint32_t T_VECTOR_SIZE = 256 / sizeof(T); + if (qk_n <= T_VECTOR_SIZE) { + __set_mask(qk_n); + cmax_v( + dst, src, + sub_m, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + qk_round_n / T_BLOCK_SIZE // srcRepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + ub_to_ub(tempTensor, src, + 0, // sid + sub_m, // nBurst + 8, // lenBurst + (qk_round_n - T_VECTOR_SIZE) / T_BLOCK_SIZE, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + for (uint32_t rowmax_idx = 1; rowmax_idx < qk_n / T_VECTOR_SIZE; ++rowmax_idx) { + max_v(tempTensor, tempTensor, src[rowmax_idx * T_VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + qk_round_n / T_BLOCK_SIZE // src1RepeatStride + ); + PIPE_BARRIER(V); + } + if (qk_n % T_VECTOR_SIZE > 0) { + __set_mask(qk_n % T_VECTOR_SIZE); + max_v(tempTensor, tempTensor, src[qk_n / T_VECTOR_SIZE * T_VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + qk_round_n / T_BLOCK_SIZE // src1RepeatStride + ); + } + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + cmax_v(dst, tempTensor, + sub_m, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + 8 // srcRepeatStride + ); + } + SetVectorMask((uint64_t)-1, (uint64_t)-1); + PIPE_BARRIER(V); + } + + template + __aicore__ inline void MulRepeatM(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src0, + const AscendC::LocalTensor &src1, const uint32_t &sub_m, const uint32_t &sub_n, + const uint32_t &sub_round_n) { + uint32_t T_BLOCK_SIZE = 32 / sizeof(T); + uint32_t T_VECTOR_SIZE = 256 / sizeof(T); + + for (uint32_t vmuls_idx = 0; vmuls_idx < sub_n / T_VECTOR_SIZE; ++vmuls_idx) { + mul_v(dst[vmuls_idx * T_VECTOR_SIZE], 
src0[vmuls_idx * T_VECTOR_SIZE], src1, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + sub_round_n / T_BLOCK_SIZE, // dstRepeatStride + sub_round_n / T_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (sub_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(sub_n % FLOAT_VECTOR_SIZE); + mul_v(dst[sub_n / T_VECTOR_SIZE * T_VECTOR_SIZE], + src0[sub_n / T_VECTOR_SIZE * T_VECTOR_SIZE], src1, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + sub_round_n / T_BLOCK_SIZE, // dstRepeatStride + sub_round_n / T_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + } + + template + __aicore__ inline void DivRepeatM(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src0, + const AscendC::LocalTensor &src1, const uint32_t &sub_m, const uint32_t &sub_n, + const uint32_t &sub_round_n) { + uint32_t temp = sizeof(T); + uint32_t T_BLOCK_SIZE = 32 / sizeof(T); + uint32_t T_VECTOR_SIZE = 256 / sizeof(T); + + for (uint32_t vdiv_idx = 0; vdiv_idx < sub_n / T_VECTOR_SIZE; ++vdiv_idx) { + div_v(dst[vdiv_idx * T_VECTOR_SIZE], src0[vdiv_idx * T_VECTOR_SIZE], src1, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + sub_round_n / T_BLOCK_SIZE, // dstRepeatStride + sub_round_n / T_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (sub_n % T_VECTOR_SIZE > 0) { + __set_mask(sub_n % T_VECTOR_SIZE); + div_v(dst[sub_n / T_VECTOR_SIZE * T_VECTOR_SIZE], + src0[sub_n / T_VECTOR_SIZE * T_VECTOR_SIZE], src1, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + sub_round_n / T_BLOCK_SIZE, // dstRepeatStride + sub_round_n / T_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + } + + template + __aicore__ inline void SymmetricQuant(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src, + const AscendC::LocalTensor &scale_ubuf_tensor, + const AscendC::LocalTensor &tempTensor, + const AscendC::GlobalTensor &quant_p_gm_tensor, + const AscendC::LocalTensor &lm_ubuf_tensor, + const AscendC::LocalTensor &hm_ubuf_tensor, const uint32_t &head_idx, + const uint32_t &quantType, const uint32_t &m_split, + const uint32_t &round_m_split, const uint32_t &qk_n, const uint32_t &qk_round_n, + const uint32_t &n_idx) { + if (quantType == 3) { + if (n_idx == 0) { + AscendC::Duplicate(scale_ubuf_tensor, (T)((T)1.0 / (T)127), round_m_split); + } else { + AscendC::Sub(tempTensor.template ReinterpretCast(), lm_ubuf_tensor, hm_ubuf_tensor, m_split); + PIPE_BARRIER(V); + AscendC::Cast(tempTensor, tempTensor.template ReinterpretCast(), AscendC::RoundMode::CAST_NONE, m_split); + PIPE_BARRIER(V); + AscendC::Exp(tempTensor, tempTensor, m_split); + PIPE_BARRIER(V); + AscendC::Muls(scale_ubuf_tensor, tempTensor, (T)((T)1 / (T)127), m_split); + } + + PIPE_BARRIER(V); + + brcb_v(tempTensor, scale_ubuf_tensor, + 1, // dstBlockStride + 8, // dstRepeatStride + round_m_split / 8 // repeat + ); + PIPE_BARRIER(V); + + DivRepeatM(dst, src, tempTensor, m_split, qk_n, qk_round_n); + } else { + float value = quant_p_gm_tensor.GetValue(head_idx); + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(dst, src, (float)((float)1.0 / (float)value), round_m_split * qk_round_n); + PIPE_BARRIER(V); + } + } + + template + __aicore__ inline void SymmetricDeQuant(const AscendC::LocalTensor 
&lo_ubuf_tensor, + const AscendC::LocalTensor &scale_ubuf_tensor, + const AscendC::LocalTensor &tv_ubuf_tensor, + const AscendC::GlobalTensor &deq_pv_gm_tensor, + const AscendC::GlobalTensor &quant_p_gm_tensor, uint32_t head_idx, + uint32_t sub_m, uint32_t round_sub_m, uint32_t qk_n, uint32_t qk_round_n, + uint32_t quantType) { + conv_v(lo_ubuf_tensor, lo_ubuf_tensor.template ReinterpretCast(), + NumVectorsRoundUp(sub_m * qk_round_n), // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + float deq_pv = deq_pv_gm_tensor.GetValue(head_idx); + if (quantType == 3) { + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + AscendC::Muls(scale_ubuf_tensor, scale_ubuf_tensor, deq_pv, sub_m); + + PIPE_BARRIER(V); + + brcb_v(tv_ubuf_tensor, scale_ubuf_tensor, + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / 8 // repeat + ); + PIPE_BARRIER(V); + + MulRepeatM(lo_ubuf_tensor, lo_ubuf_tensor, tv_ubuf_tensor, sub_m, qk_n, qk_round_n); + + PIPE_BARRIER(V); + } else { + float quant_p = quant_p_gm_tensor.GetValue(head_idx); + float value = deq_pv * quant_p; + SET_FLAG(S, V, EVENT_ID0); + WAIT_FLAG(S, V, EVENT_ID0); + muls_v( + lo_ubuf_tensor, lo_ubuf_tensor, (float)value, + (sub_m * qk_round_n + VectorSize() - 1) / VectorSize(), // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + } + + __aicore__ __attribute__((always_inline)) inline void Run( + __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ q_gm, __gm__ uint8_t *__restrict__ k_gm, + __gm__ uint8_t *__restrict__ v_gm, __gm__ uint8_t *__restrict__ layerID_gm, __gm__ uint8_t *__restrict__ mask_gm, + __gm__ uint8_t *__restrict__ alibi_coeff_gm, __gm__ uint8_t *__restrict__ deq_qk_gm, + __gm__ uint8_t *__restrict__ off_qk_gm, __gm__ uint8_t *__restrict__ deq_pv_gm, + __gm__ uint8_t *__restrict__ off_pv_gm, __gm__ uint8_t *__restrict__ quant_p_gm, + __gm__ uint8_t *__restrict__ logN_gm, __gm__ uint8_t *__restrict__ o_gm, __gm__ uint8_t *__restrict__ s_gm, + __gm__ uint8_t *__restrict__ p_gm, __gm__ uint8_t *__restrict__ o_tmp_gm, __gm__ uint8_t *__restrict__ upo_tmp_gm, + __gm__ uint8_t *__restrict__ tiling_para_gm) { + SetFftsBaseAddr((unsigned long)sync); + int32_t sub_block_idx = GetSubBlockidx(); + SetAtomicnone(); + SetMasknorm(); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + + const uint32_t ls_ubuf_offset = 0; + const uint32_t lp_ubuf_offset = 0; + const uint32_t ls32_ubuf_offset = 2 * UB_UINT8_BLOCK_SIZE; + const uint32_t mask_ubuf_offset = 4 * UB_UINT8_BLOCK_SIZE; + const uint32_t lo_ubuf_offset = 6 * UB_UINT8_BLOCK_SIZE; + const uint32_t lm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE; + const uint32_t hm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 1 * UB_UINT8_LINE_SIZE; + const uint32_t gm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 2 * UB_UINT8_LINE_SIZE; + const uint32_t dm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 4 * UB_UINT8_LINE_SIZE; + const uint32_t ll_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 8 * UB_UINT8_LINE_SIZE; + const uint32_t gl_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 16 * UB_UINT8_LINE_SIZE; + const uint32_t scale_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 18 * UB_UINT8_LINE_SIZE; + const uint32_t log_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 30 * UB_UINT8_LINE_SIZE; + + const uint32_t tv_ubuf_offset = 11 * UB_UINT8_BLOCK_SIZE; + + const uint32_t go_ubuf_offset = 9 * UB_UINT8_BLOCK_SIZE; + const uint32_t ls32_quant_ubuf_offset = 6 * UB_UINT8_BLOCK_SIZE; + + 
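+    // Unified-buffer layout used by the tensors declared below (offsets in bytes, within
+    // TOTAL_UB_SIZE = 192 KB): the large per-tile buffers occupy 16 KB UB_UINT8_BLOCK_SIZE
+    // slots (ls/lp at block 0, ls32 at block 2, mask at block 4, lo and the quantized ls32
+    // at block 6, go at block 9, tv scratch at block 11); the per-row statistics (lm, hm,
+    // gm, dm, ll, gl, scale, logN) are packed from block 8 onward in 512-byte
+    // UB_UINT8_LINE_SIZE lines. Tensors that share an offset alias the same memory.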
AsdopsBuffer buf; + + AscendC::LocalTensor ls_ubuf_tensor = buf.GetBuffer(ls_ubuf_offset); + AscendC::LocalTensor lp_ubuf_tensor = + buf.GetBuffer(lp_ubuf_offset); + AscendC::LocalTensor lp_int8_ubuf_tensor = buf.GetBuffer(lp_ubuf_offset); + AscendC::LocalTensor lp32_ubuf_tensor = buf.GetBuffer(lp_ubuf_offset); + AscendC::LocalTensor ls32_ubuf_tensor = buf.GetBuffer(ls32_ubuf_offset); + AscendC::LocalTensor ls32_quant_ubuf_tensor = + buf.GetBuffer(ls32_quant_ubuf_offset); + AscendC::LocalTensor mask_ubuf_tensor = buf.GetBuffer(mask_ubuf_offset); + AscendC::LocalTensor mask_value_ubuf_tensor = + buf.GetBuffer(11 * UB_UINT8_BLOCK_SIZE); + AscendC::LocalTensor mask_u8_ubuf_tensor = + buf.GetBuffer(11 * UB_UINT8_BLOCK_SIZE); + AscendC::LocalTensor lo_ubuf_tensor = buf.GetBuffer(lo_ubuf_offset); + AscendC::LocalTensor lm_ubuf_tensor = buf.GetBuffer(lm_ubuf_offset); + AscendC::LocalTensor hm_ubuf_tensor = buf.GetBuffer(hm_ubuf_offset); + AscendC::LocalTensor gm_ubuf_tensor = buf.GetBuffer(gm_ubuf_offset); + AscendC::LocalTensor dm_ubuf_tensor = buf.GetBuffer(dm_ubuf_offset); + AscendC::LocalTensor ll_ubuf_tensor = buf.GetBuffer(ll_ubuf_offset); + AscendC::LocalTensor gl_ubuf_tensor = buf.GetBuffer(gl_ubuf_offset); + AscendC::LocalTensor tv_ubuf_tensor = buf.GetBuffer(tv_ubuf_offset); + AscendC::LocalTensor tv32_ubuf_tensor = buf.GetBuffer(tv_ubuf_offset); + AscendC::LocalTensor tv16_ubuf_tensor = buf.GetBuffer(tv_ubuf_offset); + AscendC::LocalTensor tv_u8_ubuf_tensor = buf.GetBuffer(tv_ubuf_offset); + AscendC::LocalTensor go_ubuf_tensor = buf.GetBuffer(go_ubuf_offset); + AscendC::LocalTensor scale_ubuf_tensor = buf.GetBuffer(scale_ubuf_offset); + AscendC::LocalTensor log_ubuf_tensor = buf.GetBuffer(log_ubuf_offset); + + AscendC::GlobalTensor mask_gm_tensor; + AscendC::GlobalTensor mask_u8_gm_tensor; + AscendC::GlobalTensor o_gm_tensor; + AscendC::GlobalTensor s_gm_tensor; + AscendC::GlobalTensor p_gm_tensor; + AscendC::GlobalTensor o_tmp_gm_tensor; + AscendC::GlobalTensor deq_qk_gm_tensor; + AscendC::GlobalTensor deq_pv_gm_tensor; + AscendC::GlobalTensor quant_p_gm_tensor; + AscendC::GlobalTensor logN_gm_tensor; + + mask_u8_gm_tensor.SetGlobalBuffer(mask_gm); + mask_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(mask_gm)); + o_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(o_gm)); + s_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(s_gm)); + p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm2InputType *>(p_gm)); + o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mmOutputType *>(o_tmp_gm)); + deq_qk_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(deq_qk_gm)); + deq_pv_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(deq_pv_gm)); + quant_p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(quant_p_gm)); + logN_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(logN_gm)); + + uint32_t go_flag_scalar = 1; + uint32_t batch_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm)); + uint32_t max_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1)); + uint32_t q_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2)); + uint32_t embd = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3)); + half tor = (half)(*((__gm__ float *)tiling_para_gm + 5)); + uint32_t head_stride = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 6)); + uint32_t mask_stride = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 7)); + uint32_t is_triu_mask = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8)); + uint32_t total_q_blk_num = 
(uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9)); + uint32_t isClamp = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 10)); + half clampMin = (half)(*((__gm__ float *)tiling_para_gm + 11)); + half clampMax = (half)(*((__gm__ float *)tiling_para_gm + 12)); + uint32_t tiling_head_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 14)); + uint32_t tiling_para_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 15)); + uint32_t long_seq = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 17)); + uint32_t mask_type = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 20)); + uint32_t quantType = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 24)); + uint32_t data_shape_type = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 25)); + uint32_t window_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 27)); + + uint64_t stride_qo = q_heads * embd; + + if (data_shape_type == 1) { + stride_qo = embd; + } + + uint32_t __k = embd; + uint32_t round_k = (__k + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + SET_FLAG(MTE3, MTE2, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID1); + SET_FLAG(MTE3, MTE2, EVENT_ID2); + SET_FLAG(V, MTE2, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID1); + SET_FLAG(V, MTE2, EVENT_ID2); + SET_FLAG(V, MTE2, EVENT_ID3); + SET_FLAG(V, MTE2, EVENT_ID4); + SET_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(MTE3, V, EVENT_ID1); + SET_FLAG(V, MTE2, EVENT_ID7); + + uint64_t cur_batch = 0; + uint32_t pre_total_q_blk_num = 0; + uint32_t offset_tiling = tiling_head_size + tiling_para_size * cur_batch; + uint32_t cur_total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 13 + offset_tiling)); + uint32_t process_num = total_q_blk_num * q_heads; + uint32_t next_process = 0; + for (uint32_t process = block_idx; process < process_num; process = next_process) { + while (process >= cur_total_q_blk_num * q_heads) { + cur_batch++; + pre_total_q_blk_num = cur_total_q_blk_num; + offset_tiling += tiling_para_size; + cur_total_q_blk_num = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 13 + offset_tiling)); + } + next_process = process + block_num; + if (is_triu_mask) { + uint32_t curr_iter = process / block_num; + next_process = + curr_iter % 2 == 1 ? (curr_iter + 1) * block_num + block_idx : (curr_iter + 2) * block_num - 1 - block_idx; + } + + // get tiling args + uint32_t q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + offset_tiling)); + uint32_t kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1 + offset_tiling)); + if (q_seqlen == 0 || kv_seqlen == 0) { + continue; + } + uint32_t pp_m_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2 + offset_tiling)); + uint32_t pp_n_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3 + offset_tiling)); + uint32_t addr_o_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 10 + offset_tiling)); + uint32_t addr_o_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 11 + offset_tiling)); + uint64_t addr_o_scalar = (uint64_t)(((uint64_t)addr_o_high32) << 32 | addr_o_loww32); + ScaleType scaleType = (ScaleType)(*((__gm__ int32_t *)tiling_para_gm + 26)); + + uint32_t process_idx = process - pre_total_q_blk_num * q_heads; + uint32_t m_idx = process_idx / q_heads; + uint64_t head_idx = process_idx % q_heads; + + uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; + uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; + + uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; + uint32_t sub_m = (sub_block_idx == 1) ? 
(qk_m - qk_m / 2) : qk_m / 2; + uint32_t sub_m_d128 = (sub_m + VECTOR_SIZE - 1) / VECTOR_SIZE; // up aligned to 128 + uint32_t sub_m_d64 = (sub_m + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE; // up aligned to 64 + uint32_t round_sub_m = (sub_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + /******** pre_load *******/ + uint32_t qk_n = n_loop == 1 ? kv_seqlen : pp_n_scalar; + uint32_t qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + uint32_t pingpong_flag = 0; + uint32_t offset = pingpong_flag * UB_HALF_BUF_SIZE; + uint64_t mask_batch_offset = cur_batch * mask_stride * max_seqlen; + uint64_t mask_head_offset = head_idx * ((uint64_t)head_stride) * max_seqlen; + uint64_t mask_offset = mask_batch_offset + mask_head_offset; + if (long_seq == 0) { + mask_offset += m_idx * pp_m_scalar * max_seqlen; + } else { + gm_to_ub(mask_ubuf_tensor, + mask_gm_tensor[(uint64_t)sub_block_idx * qk_m / 2 * VECTOR_SIZE], + 0, // sid + sub_m, // nBurst + VECTOR_SIZE / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + } + + uint64_t o_offset = addr_o_scalar + head_idx * embd + m_idx * pp_m_scalar * stride_qo; + if (data_shape_type == 1) { + o_offset = addr_o_scalar + head_idx * embd * max_seqlen + m_idx * pp_m_scalar * stride_qo; + } + + uint32_t n_end = n_loop; + if ((is_triu_mask && pp_m_scalar >= pp_n_scalar) || window_size > 0) { + uint32_t n_offset = ((m_idx + 1) * pp_m_scalar + kv_seqlen - q_seqlen + pp_n_scalar - 1) / pp_n_scalar; + n_end = n_offset > n_end ? n_end : n_offset; + } + uint32_t window_start = (window_size + pp_n_scalar - 1) / pp_n_scalar; + uint32_t n_start = 0; + if constexpr (swa_flag) { + if (window_size > 0 && window_size < kv_seqlen) { + n_start = (m_idx < window_start) ? 0 : m_idx - window_start; + if constexpr (!swa_compress) { + mask_offset += n_start * pp_n_scalar; + } + } + } + uint32_t qk_n_triu = n_end * pp_n_scalar; + uint32_t s_block_stack = n_end > NO_STACK_S_BLOCK_LIMIT ? 2 : 1; + uint32_t pv_stage = 3; + uint32_t launch_delay = s_block_stack * 2; + uint32_t vect_mod = 2 * launch_delay; + uint32_t m_slice = sub_m > 32 ? 32 : 0; // s_block_stack=2时,UB可以放下 + uint32_t m_end = sub_m > 32 ? 2 : 1; + + for (uint32_t n_idx = n_start; n_idx < n_end + launch_delay; n_idx += s_block_stack) { + if (n_idx < n_end) { + uint32_t p_scale_offset = + n_idx / s_block_stack % pv_stage * RoundUp(pp_m_scalar, FLOAT_VECTOR_SIZE); + if (n_idx + s_block_stack > n_end - 1) { + qk_n = qk_n_triu > kv_seqlen ? 
kv_seqlen - n_idx * pp_n_scalar : qk_n_triu - n_idx * pp_n_scalar; + } else { + qk_n = pp_n_scalar * s_block_stack; + } + uint32_t delta_idx = m_idx - n_idx; + bool skip_mask = window_start > 3 && delta_idx > 1 && delta_idx < window_start - 1; + if constexpr (swa_compress) { + mask_offset = 0; + if (window_start <= 3) { // window < 128*3 + if (m_idx < n_idx) { + mask_offset = pp_n_scalar; // swa with midx 0 && mask_type != 0 && long_seq == 0) { + if (qk_n <= pp_n_scalar) { + pingpong_flag = (n_idx - n_start) % 2; + offset = pingpong_flag * UB_HALF_BUF_SIZE; + WAIT_FLAG(V, MTE2, pingpong_flag + 2); + if constexpr (swa_compress) { + if (!skip_mask) { + gm_to_ub_align( + mask_ubuf_tensor[offset], + mask_gm_tensor[mask_offset + (uint64_t)sub_block_idx * qk_m / 2 * max_seqlen], + 0, // sid + sub_m, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } + } else { + gm_to_ub_align( + mask_ubuf_tensor[offset], + mask_gm_tensor[mask_offset + (uint64_t)sub_block_idx * qk_m / 2 * max_seqlen], + 0, // sid + sub_m, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE2, V, pingpong_flag + 2); + } else { + WAIT_FLAG(V, MTE2, EVENT_ID2); + if constexpr (swa_compress) { + if (!skip_mask) { + gm_to_ub_align( + mask_ubuf_tensor, mask_gm_tensor[mask_offset + (uint64_t)sub_block_idx * qk_m / 2 * max_seqlen], + 0, // sid + sub_m, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } + } else { + gm_to_ub_align( + mask_ubuf_tensor, mask_gm_tensor[mask_offset + (uint64_t)sub_block_idx * qk_m / 2 * max_seqlen], + 0, // sid + sub_m, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE2, V, EVENT_ID2); + } + if constexpr (!swa_compress) { + mask_offset += qk_n; + } + } + WaitFlagDev(QK_READY); + uint32_t qk_n_reduce_sum = qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE; + if (qk_n <= VECTOR_SIZE) { + pingpong_flag = (n_idx - n_start) % 2; + offset = pingpong_flag * UB_HALF_BUF_SIZE; + if (sub_m > 0) { + // int32_t + WAIT_FLAG(MTE3, MTE2, pingpong_flag); + if (s_block_stack == 2) { + WAIT_FLAG(MTE3, MTE2, 1 - pingpong_flag); + } + // input QK + gm_to_ub( + ls_ubuf_tensor[offset], + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n], + 0, // sid + sub_m, // nBurst + qk_round_n / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID0); + if (scaleType == ScaleType::SCALE_LOGN) { + WAIT_FLAG(V, MTE2, EVENT_ID7); + gm_to_ub_align( + log_ubuf_tensor, logN_gm_tensor[m_idx * pp_m_scalar + (uint64_t)sub_block_idx * qk_m / 2], + 0, // sid + 1, // nBurst + sub_m * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap byte + (round_sub_m - sub_m) * 2 // dstGap block + ); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + brcb_v(tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + log_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + (sub_m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + WAIT_FLAG(MTE2, V, EVENT_ID0); + for (uint32_t vdiv_idx = 0; vdiv_idx < qk_n / VECTOR_SIZE; ++vdiv_idx) { + mul_v(ls_ubuf_tensor[offset + vdiv_idx * 
VECTOR_SIZE], + ls_ubuf_tensor[offset + vdiv_idx * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + PIPE_BARRIER(V); + if (qk_n % VECTOR_SIZE > 0) { + __set_mask(qk_n % VECTOR_SIZE); + mul_v(ls_ubuf_tensor[offset + qk_n / VECTOR_SIZE * VECTOR_SIZE], + ls_ubuf_tensor[offset + qk_n / VECTOR_SIZE * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + muls_v(ls_ubuf_tensor[offset], ls_ubuf_tensor[offset], tor, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + SET_FLAG(V, MTE2, EVENT_ID7); + } else { + WAIT_FLAG(MTE2, V, EVENT_ID0); + // *** ls = tor * ls + muls_v(ls_ubuf_tensor[offset], ls_ubuf_tensor[offset], tor, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + if (isClamp == 1) { + // get min(clampMin,ls_ubuf) + maxs_v(ls_ubuf_tensor[offset], ls_ubuf_tensor[offset], clampMin, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + + // get max(clampMin,ls_ubuf) + mins_v(ls_ubuf_tensor[offset], ls_ubuf_tensor[offset], clampMax, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + // *** ls = ls + mask + if (mask_type != 0) { + if (long_seq == 0) { + WAIT_FLAG(MTE2, V, pingpong_flag + 2); + if constexpr (swa_compress) { + if (!skip_mask) { + add_v( + ls_ubuf_tensor[offset], ls_ubuf_tensor[offset], mask_ubuf_tensor[offset], + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + } else { + add_v(ls_ubuf_tensor[offset], ls_ubuf_tensor[offset], + mask_ubuf_tensor[offset], + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + SET_FLAG(V, MTE2, pingpong_flag + 2); + } else if (pp_n_scalar == FLOAT_VECTOR_SIZE && s_block_stack == 2 && n_idx == n_end - 2) { + __set_mask(qk_n - FLOAT_VECTOR_SIZE); + add_v(ls_ubuf_tensor[offset + FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[offset + FLOAT_VECTOR_SIZE], mask_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 8 // src1RepeatStride + ); + } else if (n_idx == n_end - 1) { + __set_mask(qk_n); + add_v(ls_ubuf_tensor[offset], ls_ubuf_tensor[offset], mask_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / 
BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 8 // src1RepeatStride + ); + } + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + // *** lm = rowmax(ls) + if (qk_n <= VECTOR_SIZE) { + __set_mask(qk_n); + cgmax_v(tv_ubuf_tensor.ReinterpretCast(), ls_ubuf_tensor[offset], + sub_m, 2, 1, qk_round_n / BLOCK_SIZE); + PIPE_BARRIER(V); + __set_vcg_mask(qk_round_n / BLOCK_SIZE); + cgmax_v(lm_ubuf_tensor, tv_ubuf_tensor.ReinterpretCast(), + (sub_m * BLOCK_SIZE + VECTOR_SIZE - 1) / VECTOR_SIZE, 1, 1, 8); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + cgmax_v(tv_ubuf_tensor.ReinterpretCast(), ls_ubuf_tensor[offset], + sub_m, 2, 1, qk_round_n / BLOCK_SIZE); + PIPE_BARRIER(V); + __set_mask(qk_n - VECTOR_SIZE); + cgmax_v(tv_ubuf_tensor.ReinterpretCast()[ROWMAX_TEMP_BUF_OFFSET], + ls_ubuf_tensor[offset + VECTOR_SIZE], sub_m, 2, 1, + qk_round_n / BLOCK_SIZE); + PIPE_BARRIER(V); + __set_vcg_mask((qk_round_n - VECTOR_SIZE) / BLOCK_SIZE); + max_v(tv_ubuf_tensor.ReinterpretCast(), + tv_ubuf_tensor.ReinterpretCast(), + tv_ubuf_tensor.ReinterpretCast()[ROWMAX_TEMP_BUF_OFFSET], + (sub_m * BLOCK_SIZE + VECTOR_SIZE - 1) / VECTOR_SIZE, + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + PIPE_BARRIER(V); + __set_vcg_mask(VECTOR_SIZE / BLOCK_SIZE); + cgmax_v(lm_ubuf_tensor, tv_ubuf_tensor.ReinterpretCast(), + (sub_m * BLOCK_SIZE + VECTOR_SIZE - 1) / VECTOR_SIZE, 1, 1, 8); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + if (n_idx == n_start) { + // *** hm = lm + ub_to_ub(hm_ubuf_tensor, lm_ubuf_tensor, + 0, // sid + 1, // nBurst + round_sub_m / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } else { + // *** hm = vmax(lm, gm) + max_v(hm_ubuf_tensor, lm_ubuf_tensor, gm_ubuf_tensor, + sub_m_d128, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** dm = gm - hm + sub_v( + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_HALF_LINE_SIZE], gm_ubuf_tensor, + hm_ubuf_tensor, + sub_m_d128, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + } + // *** gm = hm + ub_to_ub(gm_ubuf_tensor, hm_ubuf_tensor, + 0, // sid + 1, // nBurst + round_sub_m / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + // *** hm_block = expand_to_block(hm), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + hm_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** ls = ls - hm_block + for (uint32_t vsub_idx = 0; vsub_idx < qk_n / VECTOR_SIZE; ++vsub_idx) { + sub_v(ls_ubuf_tensor[offset + vsub_idx * VECTOR_SIZE], + ls_ubuf_tensor[offset + vsub_idx * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (qk_n % VECTOR_SIZE > 0) { + __set_mask(qk_n % VECTOR_SIZE); + sub_v(ls_ubuf_tensor[offset + qk_n / VECTOR_SIZE * VECTOR_SIZE], + ls_ubuf_tensor[offset + qk_n / VECTOR_SIZE * VECTOR_SIZE], + 
tv_ubuf_tensor.ReinterpretCast(), + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** ls = castfp16to32(ls) + conv_v( + ls32_ubuf_tensor, ls_ubuf_tensor[offset], + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + PIPE_BARRIER(V); + // *** ls = exp(ls) + exp_v( + ls32_ubuf_tensor, ls32_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + + PIPE_BARRIER(V); + if constexpr (int8_flag) { + WAIT_FLAG(V, MTE2, EVENT_ID4); + SymmetricQuant(ls32_quant_ubuf_tensor, ls32_ubuf_tensor, scale_ubuf_tensor[p_scale_offset], + tv32_ubuf_tensor, quant_p_gm_tensor, lm_ubuf_tensor, hm_ubuf_tensor, head_idx, quantType, + sub_m, round_sub_m, qk_n, qk_round_n, n_idx); + AscendC::Cast( + lp_ubuf_tensor.template ReinterpretCast()[offset], + ls32_quant_ubuf_tensor.template ReinterpretCast(), AscendC::RoundMode::CAST_RINT, (uint64_t)0, + (sub_m * qk_round_n + VectorSize() - 1) / VectorSize(), {1, 1, 4, 8}); + PIPE_BARRIER(V); + SET_FLAG(V, MTE2, EVENT_ID4); + for (uint32_t row_idx = 0; row_idx < qk_n / VectorSize(); ++row_idx) { + AscendC::Cast( + lp_ubuf_tensor.template ReinterpretCast()[offset * 2 + row_idx * VectorSize()], + lp_ubuf_tensor.template ReinterpretCast()[offset + row_idx * VectorSize()], + AscendC::RoundMode::CAST_RINT, (uint64_t)0, sub_m, + {1, 1, (uint8_t)(qk_round_n / BlockSize()), (uint8_t)(qk_round_n / BlockSize())}); + } + PIPE_BARRIER(V); + if (qk_n % VectorSize() > 0) { + __set_mask(qk_n % VectorSize()); + AscendC::Cast( + lp_ubuf_tensor + .template ReinterpretCast()[offset * 2 + qk_n / VectorSize() * VectorSize()], + lp_ubuf_tensor + .template ReinterpretCast()[offset + qk_n / VectorSize() * VectorSize()], + AscendC::RoundMode::CAST_RINT, (uint64_t)0, sub_m, + {1, 1, (uint8_t)(qk_round_n / BlockSize()), (uint8_t)(qk_round_n / BlockSize())}); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + } else { + // *** lp = castfp32to16(ls) + conv_v( + lp_ubuf_tensor[offset], ls32_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } + + PIPE_BARRIER(V); + SET_FLAG(V, MTE3, EVENT_ID0); + // *** ll = rowsum(ls32) + if (qk_n <= FLOAT_VECTOR_SIZE) { + __set_mask(qk_n); + cadd_v( + ll_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE], ls32_ubuf_tensor, + sub_m, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + qk_round_n / FLOAT_BLOCK_SIZE // srcRepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + for (uint32_t rowsum_idx = 1; rowsum_idx < qk_n / FLOAT_VECTOR_SIZE; ++rowsum_idx) { + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[rowsum_idx * FLOAT_VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n / FLOAT_BLOCK_SIZE // src1RepeatStride + ); + PIPE_BARRIER(V); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % 
FLOAT_VECTOR_SIZE); + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n / FLOAT_BLOCK_SIZE // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + cadd_v( + ll_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE], ls32_ubuf_tensor, + sub_m, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + qk_round_n / FLOAT_BLOCK_SIZE // srcRepeatStride + ); + } + PIPE_BARRIER(V); + WAIT_FLAG(V, MTE3, EVENT_ID0); + if constexpr (int8_flag) { + ub_to_gm( + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n) * + 2 / sizeof(mm2InputType)], + lp_ubuf_tensor.template ReinterpretCast()[offset * 2], + 0, // sid + 1, // nBurst + sub_m * qk_round_n * 2 / BlockSize(), // lenBurst + 0, // srcGap + 0 // dstGap + ); + } else { + ub_to_gm( + p_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n], + lp_ubuf_tensor[offset], + 0, // sid + 1, // nBurst + sub_m * qk_round_n / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + } + + SET_FLAG(MTE3, MTE2, pingpong_flag); + if (s_block_stack == 2) { + SET_FLAG(MTE3, MTE2, 1 - pingpong_flag); + } + } + } else { + bool last_n_loop = n_idx + s_block_stack > n_end - 1; + if (sub_m > 0) { + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + // input QK + gm_to_ub( + ls_ubuf_tensor, + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n], + 0, // sid + m_slice, // nBurst + qk_round_n / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + if (sub_m > m_slice) { + if (m_end > 1) { + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + } + gm_to_ub( + ls_ubuf_tensor[m_slice * qk_round_n], + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + m_slice * qk_round_n], + 0, // sid + sub_m - m_slice, // nBurst + qk_round_n / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE2, V, EVENT_ID0); + if (scaleType == ScaleType::SCALE_LOGN) { + WAIT_FLAG(V, MTE2, EVENT_ID7); + gm_to_ub_align( + log_ubuf_tensor, logN_gm_tensor[m_idx * pp_m_scalar + (uint64_t)sub_block_idx * qk_m / 2], + 0, // sid + 1, // nBurst + sub_m * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap byte + (round_sub_m - sub_m) * 2 // dstGap block + ); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + brcb_v(tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + log_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + (sub_m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + WAIT_FLAG(MTE2, V, EVENT_ID0); + for (uint32_t vdiv_idx = 0; vdiv_idx < qk_n / VECTOR_SIZE; ++vdiv_idx) { + mul_v(ls_ubuf_tensor[vdiv_idx * VECTOR_SIZE], + ls_ubuf_tensor[vdiv_idx * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + 
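+            // Tail of the logN scaling: the remaining qk_n % VECTOR_SIZE columns are scaled
+            // under a partial mask before the whole tile is multiplied by tor below.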
PIPE_BARRIER(V); + if (qk_n % VECTOR_SIZE > 0) { + __set_mask(qk_n % VECTOR_SIZE); + mul_v(ls_ubuf_tensor[qk_n / VECTOR_SIZE * VECTOR_SIZE], + ls_ubuf_tensor[qk_n / VECTOR_SIZE * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + muls_v(ls_ubuf_tensor, ls_ubuf_tensor, tor, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + SET_FLAG(V, MTE2, EVENT_ID7); + } else { + WAIT_FLAG(MTE2, V, EVENT_ID0); + // *** ls = tor * ls + muls_v(ls_ubuf_tensor, ls_ubuf_tensor, tor, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + if (mask_type != 0) { + if (long_seq == 0) { + WAIT_FLAG(MTE2, V, EVENT_ID2); + if constexpr (swa_compress) { + if (!skip_mask) { + add_v( + ls_ubuf_tensor, ls_ubuf_tensor, mask_ubuf_tensor, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + } else { + add_v(ls_ubuf_tensor, ls_ubuf_tensor, mask_ubuf_tensor, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + SET_FLAG(V, MTE2, EVENT_ID2); + } else if (n_idx == n_end - 2) { + __set_mask(qk_n - pp_n_scalar); + add_v(ls_ubuf_tensor[pp_n_scalar], ls_ubuf_tensor[pp_n_scalar], + mask_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 8 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + } + if (isClamp == 1) { + // get min(clampMin,ls_ubuf) + maxs_v(ls_ubuf_tensor, ls_ubuf_tensor, clampMin, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + // get max(clampMin,ls_ubuf) + mins_v(ls_ubuf_tensor, ls_ubuf_tensor, clampMax, + (sub_m * qk_round_n + VECTOR_SIZE - 1) / VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + if (qk_n != SOFTMAX_MAX_LENGTH) { + ub_to_ub(ls32_ubuf_tensor.ReinterpretCast(), ls_ubuf_tensor, + 0, // sid + sub_m, // nBurst + VECTOR_SIZE / BLOCK_SIZE, // lenBurst + (qk_round_n - VECTOR_SIZE) / BLOCK_SIZE, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + __set_mask(qk_n - VECTOR_SIZE); + max_v(ls32_ubuf_tensor.ReinterpretCast(), + ls32_ubuf_tensor.ReinterpretCast(), + ls_ubuf_tensor[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + qk_round_n / BLOCK_SIZE // src1RepeatStride + ); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + cgmax_v(tv_ubuf_tensor.ReinterpretCast(), + ls32_ubuf_tensor.ReinterpretCast(), sub_m, 2, 1, 8); + 
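+            // Second reduction stage: a group-max over the per-block partial maxima held in
+            // tv collapses them to one running row maximum per row (lm = rowmax(ls)).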
PIPE_BARRIER(V); + __set_vcg_mask(VECTOR_SIZE / BLOCK_SIZE); + cgmax_v(lm_ubuf_tensor, tv_ubuf_tensor.ReinterpretCast(), + (sub_m * BLOCK_SIZE + VECTOR_SIZE - 1) / VECTOR_SIZE, 1, 1, 8); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + cgmax_v(ls32_ubuf_tensor.ReinterpretCast(), ls_ubuf_tensor, + sub_m * qk_round_n / VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + cgmax_v(lm_ubuf_tensor, ls32_ubuf_tensor.ReinterpretCast(), + (sub_m * BLOCK_SIZE + VECTOR_SIZE - 1) / VECTOR_SIZE, 1, 1, 8); + } + PIPE_BARRIER(V); + if (n_idx == n_start) { + // *** hm = lm + ub_to_ub(gm_ubuf_tensor, lm_ubuf_tensor, + 0, // sid + 1, // nBurst + round_sub_m / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + lm_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + } else { + // *** hm = vmax(lm, gm) + max_v(hm_ubuf_tensor, lm_ubuf_tensor, gm_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** hm_block = expand_to_block(hm), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + hm_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + // *** dm = gm - hm + sub_v( + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % launch_delay * UB_HALF_LINE_SIZE], + gm_ubuf_tensor, hm_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** gm = hm + ub_to_ub(gm_ubuf_tensor, hm_ubuf_tensor, + 0, // sid + 1, // nBurst + round_sub_m / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } + // *** ls = ls - hm_block + for (uint32_t vsub_idx = 0; vsub_idx < qk_n / VECTOR_SIZE; ++vsub_idx) { + sub_v(ls_ubuf_tensor[vsub_idx * VECTOR_SIZE], + ls_ubuf_tensor[vsub_idx * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (qk_n % VECTOR_SIZE > 0) { + __set_mask(qk_n % VECTOR_SIZE); + sub_v(ls_ubuf_tensor[qk_n / VECTOR_SIZE * VECTOR_SIZE], + ls_ubuf_tensor[qk_n / VECTOR_SIZE * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BLOCK_SIZE, // dstRepeatStride + qk_round_n / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + for (uint32_t split_idx = 0; split_idx < m_end; split_idx++) { + bool last_m_loop = split_idx == m_end - 1; + uint32_t m_split = last_m_loop ? 
sub_m - split_idx * m_slice : m_slice; + uint32_t round_m_split = (m_split + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE; + // *** ls = castfp16to32(ls) + conv_v( + ls32_ubuf_tensor, ls_ubuf_tensor[split_idx * m_slice * qk_round_n], + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + PIPE_BARRIER(V); + // *** ls = exp(ls) + exp_v( + ls32_ubuf_tensor, ls32_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + if constexpr (int8_flag) { + WAIT_FLAG(V, MTE2, EVENT_ID4); + SymmetricQuant(ls32_quant_ubuf_tensor, ls32_ubuf_tensor, + scale_ubuf_tensor[p_scale_offset + split_idx * m_slice], tv32_ubuf_tensor, + quant_p_gm_tensor, lm_ubuf_tensor[split_idx * m_slice], + hm_ubuf_tensor[split_idx * m_slice], head_idx, quantType, m_split, round_m_split, qk_n, + qk_round_n, n_idx); + + AscendC::Cast( + lp_ubuf_tensor.template ReinterpretCast()[split_idx * m_slice * qk_round_n], + ls32_quant_ubuf_tensor, AscendC::RoundMode::CAST_RINT, (uint64_t)0, + (m_split * qk_round_n + VectorSize() - 1) / VectorSize(), {1, 1, 4, 8}); + PIPE_BARRIER(V); + SET_FLAG(V, MTE2, EVENT_ID4); + for (uint32_t row_idx = 0; row_idx < qk_n / VectorSize(); ++row_idx) { + AscendC::Cast( + lp_ubuf_tensor.template ReinterpretCast()[(split_idx * m_slice * qk_round_n) * 2 + + row_idx * VectorSize()], + lp_ubuf_tensor.template ReinterpretCast()[split_idx * m_slice * qk_round_n + + row_idx * VectorSize()], + AscendC::RoundMode::CAST_RINT, (uint64_t)0, m_split, + {1, 1, (uint8_t)(qk_round_n / BlockSize()), (uint8_t)(qk_round_n / BlockSize())}); + } + if (qk_n % VectorSize() > 0) { + __set_mask(qk_n % VectorSize()); + AscendC::Cast( + lp_ubuf_tensor.template ReinterpretCast()[(split_idx * m_slice * qk_round_n) * 2 + + qk_n / VectorSize() * VectorSize()], + lp_ubuf_tensor.template ReinterpretCast()[split_idx * m_slice * qk_round_n + + qk_n / VectorSize() * VectorSize()], + AscendC::RoundMode::CAST_RINT, (uint64_t)0, m_split, + {1, 1, (uint8_t)(qk_round_n / BlockSize()), (uint8_t)(qk_round_n / BlockSize())}); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } else { + // *** lp = castfp32to16(ls) + conv_v( + lp_ubuf_tensor[split_idx * m_slice * qk_round_n], ls32_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } + + PIPE_BARRIER(V); + SET_FLAG(V, MTE3, EVENT_ID0); + // *** ll = rowsum(ls32) + for (uint32_t rowsum_idx = 1; rowsum_idx < qk_n / FLOAT_VECTOR_SIZE; ++rowsum_idx) { + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[rowsum_idx * FLOAT_VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n / FLOAT_BLOCK_SIZE // src1RepeatStride + ); + PIPE_BARRIER(V); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[qk_n_reduce_sum], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n 
/ FLOAT_BLOCK_SIZE // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + cadd_v( + ll_ubuf_tensor[((n_idx - n_start) / s_block_stack) % launch_delay * UB_FLOAT_LINE_SIZE + + split_idx * m_slice], + ls32_ubuf_tensor, + m_split, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + qk_round_n / FLOAT_BLOCK_SIZE // srcRepeatStride + ); + PIPE_BARRIER(V); + WAIT_FLAG(V, MTE3, EVENT_ID0); + if constexpr (int8_flag) { + ub_to_gm( + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + + ((uint64_t)sub_block_idx * qk_m / 2 + split_idx * m_slice) * qk_round_n) * + 2 / sizeof(mm2InputType)], + lp_ubuf_tensor.template ReinterpretCast()[(split_idx * m_slice * qk_round_n) * 2], + 0, // sid + m_split, // nBurst + qk_round_n * 2 / BlockSize(), // lenBurst + 0, // srcGap + 0 // dstGap + ); + } else { + ub_to_gm( + p_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + ((uint64_t)sub_block_idx * qk_m / 2 + split_idx * m_slice) * qk_round_n], + lp_ubuf_tensor[split_idx * m_slice * qk_round_n], + 0, // sid + m_split, // nBurst + qk_round_n / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + } + SET_FLAG(MTE3, MTE2, split_idx); + } + } + } + FftsCrossCoreSync(SOFTMAX_READY); + } + if (n_idx >= launch_delay + n_start) { + uint32_t p_scale_offset = + (n_idx - launch_delay) / s_block_stack % pv_stage * RoundUp(pp_m_scalar, FLOAT_VECTOR_SIZE); + WaitFlagDev(UPDATE_READY); // 4 + if (sub_m > 0) { + // *** 更新 L 和 O + if (n_idx != launch_delay + n_start) { + WAIT_FLAG(V, MTE2, EVENT_ID4); + + gm_to_ub( + lo_ubuf_tensor.template ReinterpretCast(), + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k], + 0, // sid + 1, // nBurst + sub_m * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID4); + // *** dm32 = castfp16to32(dm), 存放于 tv + conv_v( + tv_ubuf_tensor, + dm_ubuf_tensor[((n_idx - launch_delay - n_start) / s_block_stack % 4) * UB_HALF_LINE_SIZE], + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + PIPE_BARRIER(V); + // *** dm = exp(dm) + exp_v(tv_ubuf_tensor, tv_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + // *** dm_block = expand_to_block(dm), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** gl = dm * gl + mul_v(gl_ubuf_tensor, tv_ubuf_tensor, gl_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** gl = ll + gl + add_v( + gl_ubuf_tensor, gl_ubuf_tensor, + ll_ubuf_tensor[((n_idx - launch_delay - n_start) / s_block_stack % 4) * UB_FLOAT_LINE_SIZE], + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** go = go * dm_block + for (uint32_t vmul_idx = 0; vmul_idx < __k / FLOAT_VECTOR_SIZE; ++vmul_idx) { + mul_v(go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], + 
go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % FLOAT_VECTOR_SIZE > 0) { + __set_mask(__k % FLOAT_VECTOR_SIZE); + mul_v(go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + WAIT_FLAG(MTE2, V, EVENT_ID4); + + if constexpr (int8_flag) { + SymmetricDeQuant(lo_ubuf_tensor, scale_ubuf_tensor[p_scale_offset], tv_ubuf_tensor, deq_pv_gm_tensor, + quant_p_gm_tensor, head_idx, sub_m, round_sub_m, __k, round_k, quantType); + } + // *** go = lo + go + add_v( + go_ubuf_tensor, go_ubuf_tensor, lo_ubuf_tensor, + (sub_m * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SET_FLAG(V, MTE2, EVENT_ID4); + } else { + // *** gl = ll + ub_to_ub( + gl_ubuf_tensor, + ll_ubuf_tensor[((n_idx - launch_delay - n_start) / s_block_stack % 4) * UB_FLOAT_LINE_SIZE], + 0, // sid + 1, // nBurst + round_sub_m / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + WAIT_FLAG(MTE3, MTE2, EVENT_ID2); + gm_to_ub( + go_ubuf_tensor.template ReinterpretCast(), + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k], + 0, // sid + 1, // nBurst + sub_m * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID5); + WAIT_FLAG(MTE2, V, EVENT_ID5); + PIPE_BARRIER(V); + if constexpr (int8_flag) { + SymmetricDeQuant(go_ubuf_tensor, scale_ubuf_tensor[p_scale_offset], tv_ubuf_tensor, deq_pv_gm_tensor, + quant_p_gm_tensor, head_idx, sub_m, round_sub_m, __k, round_k, quantType); + } + PIPE_BARRIER(V); + } + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + // *** gl = castfp32to16(gl) + conv_v(gl_ubuf_tensor.ReinterpretCast(), gl_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + // *** go = castfp32to16(go) + conv_v( + go_ubuf_tensor.ReinterpretCast(), go_ubuf_tensor, + (sub_m * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + // *** gl_block = expand_to_block(gl), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + gl_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** go = go / gl_block + for (uint32_t vdiv_idx = 0; vdiv_idx < __k / VECTOR_SIZE; ++vdiv_idx) { + div_v(go_ubuf_tensor.ReinterpretCast()[vdiv_idx * VECTOR_SIZE], + go_ubuf_tensor.ReinterpretCast()[vdiv_idx * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // 
src1BlockStride + round_k / BLOCK_SIZE, // dstRepeatStride + round_k / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % VECTOR_SIZE > 0) { + __set_mask(__k % VECTOR_SIZE); + div_v( + go_ubuf_tensor.ReinterpretCast()[__k / VECTOR_SIZE * VECTOR_SIZE], + go_ubuf_tensor.ReinterpretCast()[__k / VECTOR_SIZE * VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / BLOCK_SIZE, // dstRepeatStride + round_k / BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // ********************* move O to GM ************************ + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + ub_to_gm_align( + o_gm_tensor[o_offset + (uint64_t)sub_block_idx * qk_m / 2 * stride_qo], + go_ubuf_tensor.ReinterpretCast(), + 0, // sid + sub_m, // nBurst + __k * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + (stride_qo - __k) * 2 // dstGap + ); + SET_FLAG(MTE3, MTE2, EVENT_ID2); + } + } + } + } + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID2); + WAIT_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID1); + WAIT_FLAG(V, MTE2, EVENT_ID2); + WAIT_FLAG(V, MTE2, EVENT_ID3); + WAIT_FLAG(V, MTE2, EVENT_ID4); + WAIT_FLAG(MTE3, V, EVENT_ID0); + WAIT_FLAG(MTE3, V, EVENT_ID1); + WAIT_FLAG(V, MTE2, EVENT_ID7); + PIPE_BARRIER(ALL); + } +}; +#endif + +} // namespace unpda_fa_npd_half \ No newline at end of file diff --git a/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/unpad_flashattention_bf16_mix.cce b/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/unpad_flashattention_bf16_mix.cce new file mode 100644 index 0000000000000000000000000000000000000000..82681bb8863818769b014d5c0f81ff3dcd96c27e --- /dev/null +++ b/ops/ascendc/kernel_common/op_kernel/asd/unpad_fa/unpad_flashattention_bf16_mix.cce @@ -0,0 +1,4230 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "../common/include/common.h" +#include "../common/include/common_func.h" +#include "../common/include/simd.h" +#include "../common/include/iterator.h" +#include "../common/include/mma.h" +#include "../common/include/utils.h" +#include "kernel_operator.h" + +namespace unpda_fa_npd_bf16 { + +#ifdef __CCE_KT_TEST__ +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif + +// FFTS Flag +constexpr int32_t QK_READY = 1; +constexpr int32_t SOFTMAX_READY = 2; +constexpr int32_t UPDATE_READY = 3; +constexpr int32_t BIT_SHIFT = 8; +constexpr int32_t SOFTMAX_MAX_LENGTH = 256; +constexpr int32_t TMP_SIZE = 32768 * 8; // 128 * 256 + +#ifdef __DAV_C220_CUBE__ +constexpr int32_t L0AB_HALF_BUF_SIZE = 16384; // 128 * 128 +constexpr int32_t BLOCK_SIZE = 16; +constexpr int32_t CUBE_MATRIX_SIZE = 256; // 16 * 16 +constexpr int32_t L0AB_UINT8_BLOCK_SIZE = 32768; // 128 * 128 * 2B +constexpr int32_t KV_DB_SIZE = 65536; // 128 * 128 * 2B + +template +class FlashAttentionEncoderHighPrecision { + public: + __aicore__ __attribute__((always_inline)) inline FlashAttentionEncoderHighPrecision( + __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ q_gm, __gm__ uint8_t *__restrict__ k_gm, + __gm__ uint8_t *__restrict__ v_gm, __gm__ uint8_t *__restrict__ layerID_gm, __gm__ uint8_t *__restrict__ s_gm, + __gm__ uint8_t *__restrict__ p_gm, __gm__ uint8_t *__restrict__ o_tmp_gm, + __gm__ uint8_t *__restrict__ tiling_para_gm) + : q_gm(q_gm), s_gm(s_gm), p_gm(p_gm), o_tmp_gm(o_tmp_gm), tiling_para_gm(tiling_para_gm) { + SetFftsBaseAddr((unsigned long)sync); + SetPadding(0); + SetAtomicnone(); + SetNdpara(1, 0, 0); + SetMasknorm(); + + this->batch_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm)); + this->max_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1)); + this->q_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2)); + this->embd = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3)); + this->kv_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 4)); + this->is_triu_mask = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8)); + this->total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9)); + this->tiling_head_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 14)); + this->tiling_para_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 15)); + this->tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 16)); + this->max_kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 18)); + this->window_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 27)); + this->data_shape_type = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 25)); + this->max_q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 29)); + this->npd = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 37)); + this->npd = (this->npd == 2); + this->page_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 38)); + this->group_num = q_heads / kv_heads; + this->stride_qo = (uint64_t)q_heads * embd; + this->stride_kv = kv_heads * embd; + if (data_shape_type == 1) { + this->stride_qo = embd; + this->stride_kv = embd; + } + if (this->npd) { + this->stride_kv = kv_heads * embd * page_size; + } + this->batch_stride_kv = batch_size * max_kv_seqlen * kv_heads * embd * sizeof(QKV_DT); + this->__k = embd; + this->round_k = (__k + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + if (layerID_gm != nullptr) { + uint32_t layer_id = *(__gm__ uint32_t *)layerID_gm; + this->k_gm = k_gm + layer_id * batch_stride_kv; + this->v_gm = v_gm + layer_id * batch_stride_kv; + } else { + 
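+      // Note (added comment): the tiling block in tiling_para_gm is read as a flat array of
+      // 32-bit words at fixed offsets (batch_size at 0, max_seqlen at 1, q_heads at 2, ...),
+      // so the host tiling code and this kernel must agree on that layout. round_k rounds the
+      // head dim up to the 16-element cube block, e.g. __k = 72 -> (72 + 15) / 16 * 16 = 80.
+      // `npd` is used as a boolean: it is set only when the raw tiling value equals 2 (the
+      // paged / "nd" KV layout), in which case stride_kv spans a whole page per kv block.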
this->k_gm = k_gm; + this->v_gm = v_gm; + } + + q_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->q_gm)); + k_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->k_gm)); + v_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->v_gm)); + s_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(this->s_gm)); + p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->p_gm)); + o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(this->o_tmp_gm)); + + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(M, MTE1, EVENT_ID2); + SET_FLAG(M, MTE1, EVENT_ID3); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, M, EVENT_ID1); + } + + __aicore__ __attribute__((always_inline)) inline ~FlashAttentionEncoderHighPrecision() { + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(M, MTE1, EVENT_ID2); + WAIT_FLAG(M, MTE1, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, M, EVENT_ID1); + PIPE_BARRIER(ALL); + } + + __aicore__ __attribute__((always_inline)) inline uint32_t GetTilingKey() { return this->tilingKey; } + + __aicore__ __attribute__((always_inline)) inline void LoadDataToCa(AscendC::LocalTensor dst_tensor, + AscendC::LocalTensor src_tensor, + uint32_t round_k, uint32_t qk_round_m, + uint32_t qk_m) { + uint32_t round_row = RoundUp(round_k, 32 / sizeof(QKV_DT)); + if (qk_m == 1) { + l1_to_l0_a( + dst_tensor, src_tensor, 0, + NumMatrixsRoundUp(round_row), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + } else { + l1_to_l0_a( + dst_tensor, src_tensor, qk_round_m, round_row, 0, 0, 0, 0); + } + } + + template + __aicore__ __attribute__((always_inline)) inline void Run() { + uint64_t cur_batch = 0; + uint64_t pre_total_q_blk_num = 0; + uint32_t offset_tiling = tiling_head_size + tiling_para_size * cur_batch; + uint32_t cur_total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 13 + offset_tiling)); + uint64_t process_num = (uint64_t)total_q_blk_num * q_heads; + uint64_t next_process = 0; + for (uint64_t process = block_idx; process < process_num; process = next_process) { + while (process >= (uint64_t)cur_total_q_blk_num * q_heads) { + cur_batch++; + pre_total_q_blk_num = cur_total_q_blk_num; + offset_tiling += tiling_para_size; + cur_total_q_blk_num = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 13 + offset_tiling)); + } + next_process = process + block_num; + if (is_triu_mask) { + uint64_t curr_iter = process / block_num; + next_process = + curr_iter % 2 == 1 ? 
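+      // Note (added comment): every SET_FLAG issued in the constructor has a matching
+      // WAIT_FLAG in the destructor; pre-setting the events lets the first iteration's
+      // WAIT_FLAG(M, MTE1, ...), WAIT_FLAG(MTE1, MTE2, ...) and WAIT_FLAG(FIX, M, ...) pass
+      // without a real producer, and the destructor drains whatever is still pending before
+      // the final PIPE_BARRIER(ALL).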
(curr_iter + 1) * block_num + block_idx : (curr_iter + 2) * block_num - 1 - block_idx; + } + // get tiling args + uint32_t q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + offset_tiling)); + uint32_t kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1 + offset_tiling)); + if (q_seqlen == 0 || kv_seqlen == 0) { + continue; + } + uint32_t pp_m_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2 + offset_tiling)); + uint32_t pp_n_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3 + offset_tiling)); + uint32_t addr_q_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 4 + offset_tiling)); + uint32_t addr_q_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 5 + offset_tiling)); + uint64_t addr_q_scalar = (uint64_t)(((uint64_t)addr_q_high32) << 32 | addr_q_loww32); + uint32_t addr_k_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 6 + offset_tiling)); + uint32_t addr_k_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 7 + offset_tiling)); + uint64_t addr_k_scalar = (uint64_t)(((uint64_t)addr_k_high32) << 32 | addr_k_loww32); + uint32_t addr_v_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8 + offset_tiling)); + uint32_t addr_v_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9 + offset_tiling)); + uint64_t addr_v_scalar = (uint64_t)(((uint64_t)addr_v_high32) << 32 | addr_v_loww32); + uint64_t process_idx = process - pre_total_q_blk_num * q_heads; + uint32_t m_idx = process_idx / q_heads; + uint64_t head_idx = process_idx % q_heads; + + uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; + uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; + + uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; + uint32_t qk_round_m = (qk_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + /**************** pre_load *****************/ + uint32_t qk_n = n_loop == 1 ? kv_seqlen : pp_n_scalar; + uint32_t qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + uint32_t pingpong_flag = 0; + uint32_t offset = pingpong_flag * L0AB_HALF_BUF_SIZE; + + uint64_t q_offset = addr_q_scalar + head_idx * embd + m_idx * pp_m_scalar * stride_qo; + uint64_t k_base_offset = addr_k_scalar + (head_idx / group_num) * embd; + if (data_shape_type == 1) { + q_offset = addr_q_scalar + head_idx * embd * max_q_seqlen + m_idx * pp_m_scalar * stride_qo; + k_base_offset = addr_k_scalar + (head_idx / group_num) * embd * max_kv_seqlen; + } + if (this->npd) { + k_base_offset = addr_k_scalar + (head_idx / group_num) * embd * page_size; + } + uint64_t k_offset = 0; + // Only need load Q once + if (qk_m == 1) { + gm_to_l1( + l1q_buf_addr_tensor, q_gm_tensor[q_offset], 1, 0, 0, + RoundUp(round_k, 32 / sizeof(QKV_DT)), // lenBurst + 0, 0); + } else { + gm_to_l1(l1q_buf_addr_tensor, + q_gm_tensor[q_offset], + qk_m, // nValue + qk_round_m, // dstNzC0Stride + 0, + __k, // dValue + 0, + stride_qo // srcDValue + ); + } + SET_FLAG(MTE2, MTE1, pingpong_flag); + WAIT_FLAG(MTE2, MTE1, pingpong_flag); + + uint32_t sv_n = n_loop == 1 ? 
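+      // Note (added comment): with is_triu_mask the next block index alternates between a
+      // forward (... + block_idx) and a reversed (... - 1 - block_idx) ordering from one pass
+      // to the next, which should balance the causal-mask workload where later q blocks see
+      // more kv blocks. The per-batch q/k/v base addresses arrive as two 32-bit tiling words
+      // and are recombined as addr = ((uint64_t)high32 << 32) | low32.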
kv_seqlen : pp_n_scalar; + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + uint64_t v_base_offset = addr_v_scalar + (head_idx / group_num) * embd; + if (data_shape_type == 1) { + v_base_offset = addr_v_scalar + (head_idx / group_num) * embd * max_kv_seqlen; + } + if (this->npd) { + v_base_offset = addr_k_scalar + (head_idx / group_num) * embd * page_size; + } + uint64_t v_offset = 0; + uint32_t n_end = n_loop; + if ((is_triu_mask && pp_m_scalar >= pp_n_scalar) || window_size > 0) { + uint32_t n_offset = ((m_idx + 1) * pp_m_scalar + kv_seqlen - q_seqlen + pp_n_scalar - 1) / pp_n_scalar; + n_end = n_offset > n_end ? n_end : n_offset; + } + + uint32_t window_start = (window_size + pp_n_scalar - 1) / pp_n_scalar; + uint32_t n_start = 0; + uint32_t k_token_id = 0; + uint32_t v_token_id = 0; + if constexpr (swa_flag) { + if (window_size > 0 && window_size < kv_seqlen) { + n_start = (m_idx < window_start) ? 0 : m_idx - window_start; + k_offset += n_start * stride_kv * pp_n_scalar; + v_offset += n_start * stride_kv * pp_n_scalar; + k_token_id = n_start * pp_n_scalar; + v_token_id = n_start * pp_n_scalar; + } + } + uint32_t s_block_stack = n_end > 4 ? 2 : 1; // Currently not splitting K + uint32_t launch_delay = s_block_stack * 2; + uint32_t vect_mod = 2 * launch_delay; + for (uint32_t n_idx = n_start; n_idx < n_end + launch_delay; n_idx += s_block_stack) { + if (n_idx < n_end) { + for (uint32_t split_idx = 0; split_idx < s_block_stack && n_idx + split_idx < n_end; split_idx++) { + pingpong_flag = (n_idx + split_idx - n_start) % 2; + offset = pingpong_flag * L0AB_HALF_BUF_SIZE; + if (n_idx + split_idx == (n_loop - 1)) { + qk_n = (kv_seqlen - (n_idx + split_idx) * pp_n_scalar); + qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + } + bool last_split = split_idx == s_block_stack - 1 || n_idx + split_idx == n_end - 1; + WAIT_FLAG(M, MTE1, pingpong_flag); + LoadDataToCa(l0a_buf_tensor[offset], l1q_buf_addr_tensor, round_k, qk_round_m, qk_m); + // *** Prepare K to L1 + SET_FLAG(MTE1, M, pingpong_flag); + WAIT_FLAG(MTE1, MTE2, pingpong_flag); + if (this->npd) { + k_offset = (k_token_id / page_size) * stride_kv + (k_token_id % page_size) * embd; + CopyGmToL1Npd(l1k_buf_addr_tensor[offset], + k_gm_tensor[k_base_offset + k_offset], this->page_size, qk_n, + qk_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1k_buf_addr_tensor[offset], k_gm_tensor[k_base_offset + k_offset], + qk_n, // nValue + qk_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + k_offset += pp_n_scalar * stride_kv; + } + k_token_id += qk_n; + SET_FLAG(MTE2, MTE1, pingpong_flag); + WAIT_FLAG(M, MTE1, pingpong_flag + 2); + WAIT_FLAG(MTE2, MTE1, pingpong_flag); + l1_to_l0_b( + l0b_buf_tensor[offset], l1k_buf_addr_tensor[offset], 0, + NumMatrixsRoundUp(RoundUp(round_k, 32 / sizeof(QKV_DT)) * qk_round_n), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + SET_FLAG(MTE1, MTE2, pingpong_flag); + SET_FLAG(MTE1, M, pingpong_flag + 2); + WAIT_FLAG(MTE1, M, pingpong_flag); + WAIT_FLAG(MTE1, M, pingpong_flag + 2); + if (split_idx == 0) { + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, M, EVENT_ID1); + } + if constexpr (int8_flag) { + mmad( + l0c_buf_tensor.ReinterpretCast()[split_idx * qk_round_m * pp_n_scalar], l0a_buf_tensor[offset], + l0b_buf_tensor[offset], + qk_m, // m + qk_n, // n + __k, // k + 1 // cmatrixInitVal + ); + } else { + mmad( + l0c_buf_tensor[split_idx * qk_round_m * 
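+              // Note (added comment): in the paged ("npd") layout the kv offset is derived
+              // from the absolute token id, roughly
+              //   offset = (token_id / page_size) * stride_kv + (token_id % page_size) * embd,
+              // whereas the contiguous layout just advances by pp_n_scalar * stride_kv per
+              // step. s_block_stack (1 or 2 here) batches kv blocks per sync, and
+              // launch_delay = 2 * s_block_stack keeps the PV stage that many blocks behind QK.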
pp_n_scalar], l0a_buf_tensor[offset], l0b_buf_tensor[offset], + qk_m, // m + qk_n, // n + __k, // k + 1 // cmatrixInitVal + ); + } + SET_FLAG(M, MTE1, pingpong_flag); + SET_FLAG(M, MTE1, pingpong_flag + 2); + } + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + uint32_t sv_n_triu = n_end * pp_n_scalar; + if (n_idx + s_block_stack > n_end - 1) { + sv_n = sv_n_triu > kv_seqlen ? kv_seqlen - n_idx * pp_n_scalar : sv_n_triu - n_idx * pp_n_scalar; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + // copy S to gm + l0c_to_gm( + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod], + l0c_buf_tensor, + qk_m, // MSize + sv_round_n, // NSize + qk_round_m, // srcStride + sv_round_n // dstStride_dst_D + ); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, M, EVENT_ID1); + FftsCrossCoreSync(QK_READY); + } + if (n_idx >= launch_delay + n_start) { + uint32_t l0c_pingpong_flag = (n_idx - n_start) % 2; + uint32_t l0c_offset = l0c_pingpong_flag * L0AB_HALF_BUF_SIZE; + uint32_t sv_n_triu = n_end * pp_n_scalar; + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + sv_n = sv_n_triu > kv_seqlen ? kv_seqlen - (n_idx - launch_delay) * pp_n_scalar + : sv_n_triu - (n_idx - launch_delay) * pp_n_scalar; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + + if constexpr (int8_flag) { + gm_to_l1( + l1v_buf_addr_tensor, v_gm_tensor[v_base_offset + v_offset], + sv_n, // nValue + RoundUp(sv_round_n, 32), // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + v_offset += sv_n * stride_kv; + } else { + if (this->npd) { + v_offset = ((v_token_id) / page_size) * stride_kv + (v_token_id % page_size) * embd; + CopyGmToL1Npd(l1v_buf_addr_tensor, v_gm_tensor[v_base_offset + v_offset], + this->page_size, sv_n, sv_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1v_buf_addr_tensor, v_gm_tensor[v_base_offset + v_offset], + sv_n, // nValue + sv_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + } + v_offset += sv_n * stride_kv; + } + v_token_id += sv_n; + SET_FLAG(MTE2, MTE1, EVENT_ID2); + WAIT_FLAG(MTE2, MTE1, EVENT_ID2); + WAIT_FLAG(M, MTE1, EVENT_ID2); + WAIT_FLAG(M, MTE1, EVENT_ID3); + if constexpr (int8_flag) { + for (uint32_t l0b_load_idx = 0; l0b_load_idx < (sv_round_n + 31) / 32 * 32 / BlockSize(); + ++l0b_load_idx) { + AscendC::LoadDataWithTranspose( + l0b_buf_tensor[l0b_load_idx * round_k * BlockSize()], + l1v_buf_addr_tensor[l0b_load_idx * BlockSize() * BlockSize()], + AscendC::LoadData2dTransposeParams(0, // startIndexIn + (round_k + 31) / 32 * 32 / BlockSize(), // repeatTimesIn + (sv_round_n + 31) / 32 * 32 / BlockSize(), // srcStrideIn + 1, // dstGapIn + 0, // dstfracGapIn + 0) // addrModeIn + ); + } + } else { + for (uint32_t l0b_load_idx = 0; l0b_load_idx < sv_round_n / BLOCK_SIZE; ++l0b_load_idx) { + l1_to_l0_b( + l0b_buf_tensor[l0b_load_idx * round_k * BLOCK_SIZE], + l1v_buf_addr_tensor[l0b_load_idx * CUBE_MATRIX_SIZE], 0, + round_k / BLOCK_SIZE, // repeat + 0, + sv_round_n / BLOCK_SIZE, // srcStride + 0, + 0 // dstStride + ); + } + } + SET_FLAG(MTE1, M, EVENT_ID6); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + WaitFlagDev(SOFTMAX_READY); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + if (qk_m == 1) { + 
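+              // Note (added comment): cube and vector cores hand work to each other through GM
+              // staging buffers plus FFTS flags: S tiles go to s_gm and QK_READY is raised, the
+              // vector core writes the softmax result to p_gm and raises SOFTMAX_READY (waited
+              // on above), and after the PV matmul below UPDATE_READY tells it to fold o_tmp
+              // into the running output. The `% vect_mod` term rotates the staging slots so
+              // in-flight blocks do not overwrite each other.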
gm_to_l1( + l1p_buf_addr_tensor, + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod) * + 2 / sizeof(QKV_DT)], + 1, 0, 0, RoundUp(sv_round_n, BlockSize()), // lenBurst + 0, 0); + } else { + gm_to_l1( + l1p_buf_addr_tensor, + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod) * + 2 / sizeof(QKV_DT)], + qk_m, // nValue + qk_round_m, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + sv_n, // dValue + 0, // dstNzMatrixStride, unused + sv_round_n * 2 / sizeof(QKV_DT) // srcDValue + ); + } + SET_FLAG(MTE2, MTE1, EVENT_ID3); + WAIT_FLAG(MTE2, MTE1, EVENT_ID3); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + LoadDataToCa(l0a_buf_tensor, l1p_buf_addr_tensor, RoundUp(sv_round_n, BlockSize()), + qk_round_m, qk_m); + SET_FLAG(MTE1, M, EVENT_ID5); + SET_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(MTE1, M, EVENT_ID5); + WAIT_FLAG(MTE1, M, EVENT_ID6); + WAIT_FLAG(FIX, M, l0c_pingpong_flag); + if constexpr (int8_flag) { + mmad( + l0c_buf_tensor.template ReinterpretCast()[l0c_offset], l0a_buf_tensor, l0b_buf_tensor, + qk_m, // m + __k, // n + sv_n, // k + 1 // cmatrixInitVal + ); + } else { + mmad(l0c_buf_tensor[l0c_offset], l0a_buf_tensor, + l0b_buf_tensor, + qk_m, // m + __k, // n + sv_n, // k + 1 // cmatrixInitVal + ); + } + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(M, MTE1, EVENT_ID2); + SET_FLAG(M, MTE1, EVENT_ID3); + SET_FLAG(M, FIX, l0c_pingpong_flag); + WAIT_FLAG(M, FIX, l0c_pingpong_flag); + // copy O to gm + l0c_to_gm( + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod], + l0c_buf_tensor[l0c_offset], + qk_m, // MSize + round_k, // NSize + qk_round_m, // srcStride + round_k // dstStride_dst_D + ); + SET_FLAG(FIX, M, l0c_pingpong_flag); + FftsCrossCoreSync(UPDATE_READY); + } + } + } + } + + private: + __gm__ uint8_t *__restrict__ q_gm{nullptr}; + __gm__ uint8_t *__restrict__ k_gm{nullptr}; + __gm__ uint8_t *__restrict__ v_gm{nullptr}; + __gm__ uint8_t *__restrict__ s_gm{nullptr}; + __gm__ uint8_t *__restrict__ p_gm{nullptr}; + __gm__ uint8_t *__restrict__ o_tmp_gm{nullptr}; + __gm__ uint8_t *__restrict__ tiling_para_gm{nullptr}; + + const uint32_t l1q_buf_addr_offset = 0; + const uint32_t l1k_buf_addr_offset = 2 * L0AB_UINT8_BLOCK_SIZE; + const uint32_t l1p_buf_addr_offset = 4 * L0AB_UINT8_BLOCK_SIZE; + const uint32_t l1v_buf_addr_offset = 6 * L0AB_UINT8_BLOCK_SIZE; + + AsdopsBuffer buf; + + AscendC::LocalTensor l1q_buf_addr_tensor = buf.GetBuffer(l1q_buf_addr_offset); + AscendC::LocalTensor l1k_buf_addr_tensor = buf.GetBuffer(l1k_buf_addr_offset); + AscendC::LocalTensor l1p_buf_addr_tensor = buf.GetBuffer(l1p_buf_addr_offset); + AscendC::LocalTensor l1v_buf_addr_tensor = buf.GetBuffer(l1v_buf_addr_offset); + + AscendC::GlobalTensor q_gm_tensor; + AscendC::GlobalTensor k_gm_tensor; + AscendC::GlobalTensor v_gm_tensor; + AscendC::GlobalTensor s_gm_tensor; + AscendC::GlobalTensor p_gm_tensor; + AscendC::GlobalTensor o_tmp_gm_tensor; + + AscendC::LocalTensor l0a_buf_tensor = buf.GetBuffer(0); + AscendC::LocalTensor l0b_buf_tensor = buf.GetBuffer(0); + AscendC::LocalTensor l0c_buf_tensor = buf.GetBuffer(0); + + uint32_t batch_size{0}; + uint32_t max_seqlen{0}; + uint32_t max_kv_seqlen{0}; + uint32_t max_q_seqlen{0}; + uint32_t q_heads{0}; + uint32_t embd{0}; + uint32_t kv_heads{0}; + uint32_t is_triu_mask{0}; + uint32_t total_q_blk_num{0}; + uint32_t 
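+  // Note (added comment): L1 is carved up by fixed byte offsets (Q at 0, K at
+  // 2 * L0AB_UINT8_BLOCK_SIZE, P at 4 *, V at 6 *, each unit being 128 * 128 * 2B), so the
+  // Q/K/P/V staging regions do not alias; the L0A/L0B/L0C tensors are simply taken at offset
+  // 0 of their dedicated hardware buffers and ping-ponged by L0AB_HALF_BUF_SIZE inside Run().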
group_num{0}; + uint64_t stride_qo{0}; + uint64_t stride_kv{0}; + uint64_t batch_stride_kv{0}; + uint32_t __k{0}; + uint32_t round_k{0}; + uint32_t data_shape_type{0}; + uint32_t window_size{0}; + uint32_t tilingKey{0}; + uint32_t tiling_head_size{0}; + uint32_t tiling_para_size{0}; + uint32_t npd{0}; + uint32_t page_size{0}; +}; + +template +class FlashAttentionEncoderHighPrecisionCubeOpt { + public: + __aicore__ __attribute__((always_inline)) inline FlashAttentionEncoderHighPrecisionCubeOpt( + __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ q_gm, __gm__ uint8_t *__restrict__ k_gm, + __gm__ uint8_t *__restrict__ v_gm, __gm__ uint8_t *__restrict__ layerID_gm, __gm__ uint8_t *__restrict__ s_gm, + __gm__ uint8_t *__restrict__ p_gm, __gm__ uint8_t *__restrict__ o_tmp_gm, + __gm__ uint8_t *__restrict__ tiling_para_gm) + : q_gm(q_gm), s_gm(s_gm), p_gm(p_gm), o_tmp_gm(o_tmp_gm), tiling_para_gm(tiling_para_gm) { + SetFftsBaseAddr((unsigned long)sync); + SetPadding(0); + SetAtomicnone(); + SetNdpara(1, 0, 0); + SetMasknorm(); + + this->batch_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm)); + this->max_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1)); + this->q_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2)); + this->embd = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3)); + this->kv_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 4)); + this->is_triu_mask = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8)); + this->total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9)); + this->tiling_head_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 14)); + this->tiling_para_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 15)); + this->tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 16)); + this->max_kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 18)); + this->max_q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 29)); + this->npd = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 37)); + this->npd = (this->npd == 2); + this->page_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 38)); + this->data_shape_type = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 25)); + this->group_num = q_heads / kv_heads; + this->stride_qo = (uint64_t)q_heads * embd; + this->stride_kv = kv_heads * embd; + if (data_shape_type == 1) { + this->stride_qo = embd; + this->stride_kv = embd; + } + if constexpr (splitm) { + this->tmp_times = 2; + } + if (this->npd) { + this->stride_kv = kv_heads * embd * page_size; + } + this->batch_stride_kv = batch_size * max_kv_seqlen * kv_heads * embd * sizeof(QKV_DT); + this->__k = embd; + this->round_k = (__k + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + if (layerID_gm != nullptr) { + uint32_t layer_id = *(__gm__ uint32_t *)layerID_gm; + this->k_gm = k_gm + layer_id * batch_stride_kv; + this->v_gm = v_gm + layer_id * batch_stride_kv; + } else { + this->k_gm = k_gm; + this->v_gm = v_gm; + } + + q_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->q_gm)); + k_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->k_gm)); + v_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->v_gm)); + s_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(this->s_gm)); + p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(this->p_gm)); + o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(this->o_tmp_gm)); + + SET_FLAG(MTE1, MTE2, EVENT_ID0); + SET_FLAG(MTE1, MTE2, EVENT_ID1); + SET_FLAG(MTE1, MTE2, EVENT_ID2); + 
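+    // Note (added comment): this "CubeOpt" variant registers two extra MTE1->MTE2 events
+    // (ID4/ID5, set just below) to guard the double-buffered P staging in L1, and when the
+    // splitm template flag is set it doubles tmp_times so the s_gm / p_gm / o_tmp_gm staging
+    // stride can hold both 128-row halves of a large M tile.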
SET_FLAG(MTE1, MTE2, EVENT_ID3); + SET_FLAG(MTE1, MTE2, EVENT_ID4); + SET_FLAG(MTE1, MTE2, EVENT_ID5); + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + SET_FLAG(M, MTE1, EVENT_ID2); + SET_FLAG(M, MTE1, EVENT_ID3); + SET_FLAG(FIX, M, EVENT_ID0); + SET_FLAG(FIX, M, EVENT_ID1); + } + + __aicore__ __attribute__((always_inline)) inline ~FlashAttentionEncoderHighPrecisionCubeOpt() { + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(MTE1, MTE2, EVENT_ID4); + WAIT_FLAG(MTE1, MTE2, EVENT_ID5); + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(M, MTE1, EVENT_ID2); + WAIT_FLAG(M, MTE1, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); + WAIT_FLAG(FIX, M, EVENT_ID1); + PIPE_BARRIER(ALL); + } + + __aicore__ __attribute__((always_inline)) inline uint32_t GetTilingKey() { return this->tilingKey; } + + __aicore__ __attribute__((always_inline)) inline void LoadDataToCa(AscendC::LocalTensor dst_tensor, + AscendC::LocalTensor src_tensor, + uint32_t round_k, uint32_t qk_round_m, + uint32_t qk_m) { + uint32_t round_row = RoundUp(round_k, 32 / sizeof(QKV_DT)); + if (qk_m == 1) { + l1_to_l0_a( + dst_tensor, src_tensor, 0, + NumMatrixsRoundUp(round_row), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + } else { + l1_to_l0_a( + dst_tensor, src_tensor, qk_round_m, round_row, 0, 0, 0, 0); + } + } + + __aicore__ __attribute__((always_inline)) inline void Run() { + uint64_t cur_batch = 0; + uint64_t pre_total_q_blk_num = 0; + uint32_t offset_tiling = tiling_head_size + tiling_para_size * cur_batch; + uint32_t cur_total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 13 + offset_tiling)); + uint64_t process_num = (uint64_t)total_q_blk_num * q_heads; + uint64_t next_process = 0; + for (uint64_t process = block_idx; process < process_num; process = next_process) { + while (process >= (uint64_t)cur_total_q_blk_num * q_heads) { + cur_batch++; + pre_total_q_blk_num = cur_total_q_blk_num; + offset_tiling += tiling_para_size; + cur_total_q_blk_num = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 13 + offset_tiling)); + } + next_process = process + block_num; + if (is_triu_mask) { + uint64_t curr_iter = process / block_num; + next_process = + curr_iter % 2 == 1 ? 
(curr_iter + 1) * block_num + block_idx : (curr_iter + 2) * block_num - 1 - block_idx; + } + // get tiling args + uint32_t q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + offset_tiling)); + uint32_t kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1 + offset_tiling)); + if (q_seqlen == 0 || kv_seqlen == 0) { + continue; + } + uint32_t pp_m_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2 + offset_tiling)); + uint32_t pp_n_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3 + offset_tiling)); + uint32_t addr_q_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 4 + offset_tiling)); + uint32_t addr_q_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 5 + offset_tiling)); + uint64_t addr_q_scalar = (uint64_t)(((uint64_t)addr_q_high32) << 32 | addr_q_loww32); + uint32_t addr_k_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 6 + offset_tiling)); + uint32_t addr_k_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 7 + offset_tiling)); + uint64_t addr_k_scalar = (uint64_t)(((uint64_t)addr_k_high32) << 32 | addr_k_loww32); + uint32_t addr_v_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8 + offset_tiling)); + uint32_t addr_v_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9 + offset_tiling)); + uint64_t addr_v_scalar = (uint64_t)(((uint64_t)addr_v_high32) << 32 | addr_v_loww32); + uint64_t process_idx = process - pre_total_q_blk_num * q_heads; + uint32_t m_idx = process_idx / q_heads; + uint64_t head_idx = process_idx % q_heads; + + uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; + uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; + + uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; + uint32_t qk_round_m = (qk_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + /**************** pre_load *****************/ + uint32_t qk_n = n_loop == 1 ? kv_seqlen : pp_n_scalar; + uint32_t qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + uint32_t pingpong_flag = 0; + uint32_t offset = pingpong_flag * L0AB_HALF_BUF_SIZE; + + uint64_t q_offset = + addr_q_scalar + head_idx * embd + m_idx * pp_m_scalar * stride_qo; // + 17592186044416 / sizeof(QKV_DT); + uint64_t k_base_offset = addr_k_scalar + (head_idx / group_num) * embd; // + 17592186044416 / sizeof(QKV_DT); + if (data_shape_type == 1) { + q_offset = addr_q_scalar + head_idx * embd * max_q_seqlen + + m_idx * pp_m_scalar * stride_qo; // + 17592186044416 / sizeof(QKV_DT); + k_base_offset = + addr_k_scalar + (head_idx / group_num) * embd * max_kv_seqlen; // + 17592186044416 / sizeof(QKV_DT); + } + if (this->npd) { + k_base_offset = addr_k_scalar + (head_idx / group_num) * embd * page_size; + } + uint64_t k_offset = 0; + // Only need load Q once + if (qk_m == 1) { + gm_to_l1( + l1q_buf_addr_tensor, q_gm_tensor[q_offset], 1, 0, 0, + RoundUp(round_k, 32 / sizeof(QKV_DT)), // lenBurst + 0, 0); + } else { + gm_to_l1(l1q_buf_addr_tensor, + q_gm_tensor[q_offset], + qk_m, // nValue + qk_round_m, // dstNzC0Stride + 0, + __k, // dValue + 0, + stride_qo // srcDValue + ); + } + SET_FLAG(MTE2, MTE1, pingpong_flag); + WAIT_FLAG(MTE2, MTE1, pingpong_flag); + + uint32_t sv_n = n_loop == 1 ? 
kv_seqlen : pp_n_scalar; + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + uint64_t v_base_offset = addr_v_scalar + (head_idx / group_num) * embd; // + 17592186044416 / sizeof(QKV_DT); + if (data_shape_type == 1) { + v_base_offset = + addr_v_scalar + (head_idx / group_num) * embd * max_kv_seqlen; // + 17592186044416 / sizeof(QKV_DT); + } + if (this->npd) { + v_base_offset = addr_k_scalar + (head_idx / group_num) * embd * page_size; + } + uint64_t v_offset = 0; + uint32_t n_end = n_loop; + if (is_triu_mask && pp_m_scalar >= pp_n_scalar) { + uint32_t n_offset = ((m_idx + 1) * pp_m_scalar + kv_seqlen - q_seqlen + pp_n_scalar - 1) / pp_n_scalar; + n_end = n_offset > n_end ? n_end : n_offset; + } + uint32_t k_token_id = 0; + uint32_t v_token_id = 0; + uint32_t s_block_stack = n_end > 8 ? 4 : (n_end > 4 ? 2 : 1); + uint32_t launch_delay = s_block_stack * 2; + uint32_t vect_mod = 2 * launch_delay; + uint32_t kv_pingpong_flag = 0; + uint64_t kv_pingpong_offset = kv_pingpong_flag * KV_DB_SIZE; + + if constexpr (splitm) { + uint32_t split_m = 128; + uint32_t m_inner_loop = (qk_m + 127) / 128; + uint32_t l0c_pingpong_flag = 0; + for (uint32_t n_idx = 0; n_idx < n_end + launch_delay; n_idx += s_block_stack) { + if (n_idx < n_end) { + uint32_t sv_n_triu = n_end * pp_n_scalar; + uint32_t l0c_offset = l0c_pingpong_flag * 128 * 128; + if (n_idx + s_block_stack > n_end - 1) { + sv_n = sv_n_triu > kv_seqlen ? kv_seqlen - n_idx * pp_n_scalar : sv_n_triu - n_idx * pp_n_scalar; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + LoadDataToCa(l0a_buf_tensor, l1q_buf_addr_tensor, round_k, qk_round_m, qk_m); + // *** Prepare K to L1 + SET_FLAG(MTE1, M, EVENT_ID0); + WAIT_FLAG(MTE1, M, EVENT_ID0); + + for (uint32_t split_idx = 0; split_idx < s_block_stack && n_idx + split_idx < n_end; split_idx++) { + pingpong_flag = (n_idx + split_idx) % 2; + offset = pingpong_flag * L0AB_HALF_BUF_SIZE; + if (n_idx + split_idx == (n_loop - 1)) { + qk_n = (kv_seqlen - (n_idx + split_idx) * pp_n_scalar); + qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + } + bool last_split = split_idx == s_block_stack - 1 || n_idx + split_idx == n_end - 1; + WAIT_FLAG(MTE1, MTE2, pingpong_flag + 2 * kv_pingpong_flag); + + if (this->npd) { + k_offset = (k_token_id / page_size) * stride_kv + (k_token_id % page_size) * embd; + CopyGmToL1Npd(l1k_buf_addr_tensor[kv_pingpong_offset + offset], + k_gm_tensor[k_base_offset + k_offset], this->page_size, + qk_n, qk_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1k_buf_addr_tensor[kv_pingpong_offset + offset], k_gm_tensor[k_base_offset + k_offset], + qk_n, // nValue + qk_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + k_offset += pp_n_scalar * stride_kv; + } + k_token_id += qk_n; + SET_FLAG(MTE2, MTE1, pingpong_flag); + + WAIT_FLAG(M, MTE1, pingpong_flag + 2); + WAIT_FLAG(MTE2, MTE1, pingpong_flag); + l1_to_l0_b( + l0b_buf_tensor[offset], l1k_buf_addr_tensor[kv_pingpong_offset + offset], 0, + NumMatrixsRoundUp(RoundUp(round_k, 32 / sizeof(QKV_DT)) * qk_round_n), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + SET_FLAG(MTE1, MTE2, pingpong_flag + 2 * kv_pingpong_flag); + SET_FLAG(MTE1, M, pingpong_flag + 2); + WAIT_FLAG(MTE1, M, pingpong_flag + 2); + if (m_inner_loop == 1) { + WAIT_FLAG(FIX, M, 
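+              // Note (added comment): K/V tiles are double-buffered in L1 (kv_pingpong_offset
+              // toggles between the two KV_DB_SIZE halves each round) so the next GM->L1 copy
+              // can overlap the matmuls on the current tile; s_block_stack grows to 4 for
+              // longer kv ranges, which also stretches launch_delay and the staging rotation.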
l0c_pingpong_flag); + mmad(l0c_buf_tensor[l0c_offset], l0a_buf_tensor, + l0b_buf_tensor[offset], + qk_m, // m + qk_n, // n + __k, // k + 1 // cmatrixInitVal + ); + + SET_FLAG(M, FIX, l0c_pingpong_flag); + WAIT_FLAG(M, FIX, l0c_pingpong_flag); + // copy S to gm + l0c_to_gm( + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE * tmp_times + + n_idx % vect_mod * TMP_SIZE * tmp_times / vect_mod + split_idx * pp_n_scalar], + l0c_buf_tensor[l0c_offset], + qk_m, // MSize + qk_round_n, // NSize + qk_round_m, // srcStride + sv_round_n // dstStride_dst_D + ); + SET_FLAG(FIX, M, l0c_pingpong_flag); + l0c_pingpong_flag = 1 - l0c_pingpong_flag; + l0c_offset = l0c_pingpong_flag * L0AB_HALF_BUF_SIZE; + } else { + for (uint32_t m_inner_idx = 0; m_inner_idx < m_inner_loop; m_inner_idx++) { + uint32_t nowM = (m_inner_idx == m_inner_loop - 1) ? qk_m - m_inner_idx * split_m : split_m; + uint32_t nowMRound = (nowM + 15) / 16 * 16; + WAIT_FLAG(FIX, M, l0c_pingpong_flag); + mmad( + l0c_buf_tensor[l0c_offset], l0a_buf_tensor[l0c_pingpong_flag * split_m * round_k], + l0b_buf_tensor[offset], + nowM, // m + qk_n, // n + __k, // k + 1 // cmatrixInitVal + ); + + SET_FLAG(M, FIX, l0c_pingpong_flag); + WAIT_FLAG(M, FIX, l0c_pingpong_flag); + // copy S to gm + l0c_to_gm( + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE * tmp_times + + n_idx % vect_mod * TMP_SIZE * tmp_times / vect_mod + split_idx * pp_n_scalar + + m_inner_idx * split_m * sv_round_n], + l0c_buf_tensor[l0c_offset], + nowM, // MSize + qk_round_n, // NSize + nowMRound, // srcStride + sv_round_n // dstStride_dst_D + ); + SET_FLAG(FIX, M, l0c_pingpong_flag); + l0c_pingpong_flag = 1 - l0c_pingpong_flag; + l0c_offset = l0c_pingpong_flag * L0AB_HALF_BUF_SIZE; + } + } + SET_FLAG(M, MTE1, pingpong_flag + 2); + } + SET_FLAG(M, MTE1, EVENT_ID0); + SET_FLAG(M, MTE1, EVENT_ID1); + FftsCrossCoreSync(QK_READY); + kv_pingpong_flag = 1 - kv_pingpong_flag; + kv_pingpong_offset = kv_pingpong_flag * KV_DB_SIZE; + } + if (n_idx >= launch_delay) { + // uint32_t l0c_pingpong_flag = 0; + uint32_t l0c_offset = l0c_pingpong_flag * 128 * 128; + uint32_t sv_n_triu = n_end * pp_n_scalar; + bool last = false; + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + sv_n = sv_n_triu > kv_seqlen ? 
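+            // Note (added comment): with splitm the Q tile is processed in 128-row slices so
+            // each partial S fits one 128 * 128 half of L0C; the two halves ping-pong on the
+            // FIX<->M event pair while finished slices stream out to s_gm at
+            // m_inner_idx * split_m * sv_round_n.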
kv_seqlen - (n_idx - launch_delay) * pp_n_scalar + : sv_n_triu - (n_idx - launch_delay) * pp_n_scalar; + last = true; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + uint32_t n_slice = pp_n_scalar * ((s_block_stack + 1) / 2); + // n_slice = 256; + uint32_t l1_split_loop = (sv_n + n_slice - 1) / n_slice; + WAIT_FLAG(MTE1, MTE2, kv_pingpong_flag * 2); + WAIT_FLAG(MTE1, MTE2, kv_pingpong_flag * 2 + 1); + + if (this->npd) { + v_offset = ((v_token_id) / page_size) * stride_kv + (v_token_id % page_size) * embd; + CopyGmToL1Npd(l1v_buf_addr_tensor[kv_pingpong_offset], + v_gm_tensor[v_base_offset + v_offset], this->page_size, sv_n, + sv_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1v_buf_addr_tensor[kv_pingpong_offset], v_gm_tensor[v_base_offset + v_offset], + sv_n, // nValue + sv_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + v_offset += sv_n * stride_kv; + } + v_token_id += sv_n; + WaitFlagDev(SOFTMAX_READY); + for (uint32_t gm_split_idx = 0; gm_split_idx < m_inner_loop; gm_split_idx++) { + WAIT_FLAG(FIX, M, l0c_pingpong_flag); + bool m_last_split = gm_split_idx == m_inner_loop - 1; + uint64_t gm_p_offset = gm_split_idx * split_m * sv_round_n; + uint32_t nowM = m_last_split ? qk_m - gm_split_idx * split_m : split_m; + uint32_t nowMRound = (nowM + 15) / 16 * 16; + for (uint32_t l1_k_split_idx = 0; l1_k_split_idx < l1_split_loop; l1_k_split_idx++) { + uint32_t l1_pingpong_flag = l1_k_split_idx % 2; + uint32_t l1_offset = l1_pingpong_flag * 128 * 256; + bool l1_last_split = l1_k_split_idx == l1_split_loop - 1; + uint32_t d = l1_last_split ? sv_n - l1_k_split_idx * n_slice : n_slice; + WAIT_FLAG(MTE1, MTE2, l1_pingpong_flag + 4); + if (nowM == 1) { + gm_to_l1( + l1p_buf_addr_tensor[l1_offset], + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod) * + 2 / sizeof(QKV_DT) + + gm_p_offset], + 1, 0, 0, RoundUp(sv_round_n, BlockSize()), // lenBurst + 0, 0); + } else { + gm_to_l1( + l1p_buf_addr_tensor[l1_offset], + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod) * + 2 / sizeof(QKV_DT) + + l1_k_split_idx * n_slice + gm_p_offset], + nowM, // nValue + nowMRound, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + d, // dValue + 0, // dstNzMatrixStride, unused + sv_round_n * 2 / sizeof(QKV_DT) // srcDValue + ); + } + SET_FLAG(MTE2, MTE1, l1_pingpong_flag + 4); + WAIT_FLAG(MTE2, MTE1, l1_pingpong_flag + 4); + uint32_t d_split_loop = (d + 127) / 128; + for (uint32_t l0_k_split_idx = 0; l0_k_split_idx < d_split_loop; l0_k_split_idx++) { + uint32_t l0_pingpong_flag = l0_k_split_idx % 2; + uint32_t l0_offset = l0_pingpong_flag * 128 * 128; + bool l0_last_split = l0_k_split_idx == d_split_loop - 1; + int32_t l0_p_offset = nowM == 1 ? l0_k_split_idx * 128 : l0_k_split_idx * 128 * nowMRound; + bool initC = l0_k_split_idx == 0 && l1_k_split_idx == 0; + uint32_t reduce_d = l0_last_split ? 
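+                  // Note (added comment): the PV reduction dimension is split twice: GM->L1 in
+                  // n_slice = pp_n_scalar * ((s_block_stack + 1) / 2) chunks, then L1->L0 in
+                  // 128-column pieces; initC is true only for the very first piece, so every
+                  // later mmad accumulates into the same L0C tile instead of overwriting it.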
d - l0_k_split_idx * 128 : 128; + uint32_t round_reduce_d = (reduce_d + 15) / 16 * 16; + WAIT_FLAG(M, MTE1, l0_pingpong_flag); + LoadDataToCa(l0a_buf_tensor[l0_offset], l1p_buf_addr_tensor[l1_offset + l0_p_offset], + RoundUp(round_reduce_d, BlockSize()), nowMRound, nowM); + if (l0_last_split) { + SET_FLAG(MTE1, MTE2, l1_pingpong_flag + 4); + } + WAIT_FLAG(M, MTE1, l0_pingpong_flag + 2); + for (uint32_t l0b_load_idx = 0; l0b_load_idx < 128 / BLOCK_SIZE; ++l0b_load_idx) { + l1_to_l0_b( + l0b_buf_tensor[l0_offset + l0b_load_idx * round_k * BLOCK_SIZE], + l1v_buf_addr_tensor[kv_pingpong_offset + l0b_load_idx * CUBE_MATRIX_SIZE + + l1_k_split_idx * n_slice * BLOCK_SIZE + l0_k_split_idx * 128 * BLOCK_SIZE], + 0, + round_k / BLOCK_SIZE, // repeat + 0, + sv_round_n / BLOCK_SIZE, // srcStride + 0, + 0 // dstStride + ); + } + if (l0_last_split && l1_last_split && m_last_split) { + SET_FLAG(MTE1, MTE2, kv_pingpong_flag * 2); + SET_FLAG(MTE1, MTE2, kv_pingpong_flag * 2 + 1); + } + SET_FLAG(MTE1, M, l0_pingpong_flag + 6); + WAIT_FLAG(MTE1, M, l0_pingpong_flag + 6); + mmad( + l0c_buf_tensor[l0c_offset], l0a_buf_tensor[l0_offset], l0b_buf_tensor[l0_offset], + nowM, // m + __k, // n + reduce_d, // k + initC // cmatrixInitVal + ); + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, l0_pingpong_flag); + SET_FLAG(M, MTE1, l0_pingpong_flag + 2); + } + } + SET_FLAG(M, FIX, l0c_pingpong_flag); + WAIT_FLAG(M, FIX, l0c_pingpong_flag); + // // copy O to gm + l0c_to_gm( + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod + + gm_split_idx * split_m * round_k], + l0c_buf_tensor[l0c_offset], + nowM, // MSize + round_k, // NSize + nowMRound, // srcStride + round_k // dstStride_dst_D + ); + SET_FLAG(FIX, M, l0c_pingpong_flag); + l0c_pingpong_flag = 1 - l0c_pingpong_flag; + l0c_offset = l0c_pingpong_flag * 128 * 128; + } + FftsCrossCoreSync(UPDATE_READY); + kv_pingpong_flag = 1 - kv_pingpong_flag; + kv_pingpong_offset = kv_pingpong_flag * KV_DB_SIZE; + } + } + } else { + for (uint32_t n_idx = 0; n_idx < n_end + launch_delay; n_idx += s_block_stack) { + if (n_idx < n_end) { + uint32_t sv_n_triu = n_end * pp_n_scalar; + if (n_idx + s_block_stack > n_end - 1) { + sv_n = sv_n_triu > kv_seqlen ? 
kv_seqlen - n_idx * pp_n_scalar : sv_n_triu - n_idx * pp_n_scalar; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + for (uint32_t split_idx = 0; split_idx < s_block_stack && n_idx + split_idx < n_end; split_idx++) { + pingpong_flag = (n_idx + split_idx) % 2; + offset = pingpong_flag * L0AB_HALF_BUF_SIZE; + if (n_idx + split_idx == (n_loop - 1)) { + qk_n = (kv_seqlen - (n_idx + split_idx) * pp_n_scalar); + qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + } + bool last_split = split_idx == s_block_stack - 1 || n_idx + split_idx == n_end - 1; + WAIT_FLAG(MTE1, MTE2, pingpong_flag + 2 * kv_pingpong_flag); + if (this->npd) { + k_offset = (k_token_id / page_size) * stride_kv + (k_token_id % page_size) * embd; + CopyGmToL1Npd(l1k_buf_addr_tensor[kv_pingpong_offset + offset], + k_gm_tensor[k_base_offset + k_offset], this->page_size, + qk_n, qk_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1k_buf_addr_tensor[kv_pingpong_offset + offset], k_gm_tensor[k_base_offset + k_offset], + qk_n, // nValue + qk_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + k_offset += pp_n_scalar * stride_kv; + } + k_token_id += qk_n; + SET_FLAG(MTE2, MTE1, pingpong_flag); + WAIT_FLAG(M, MTE1, pingpong_flag); + LoadDataToCa(l0a_buf_tensor[offset], l1q_buf_addr_tensor, round_k, qk_round_m, qk_m); + // *** Prepare K to L1 + SET_FLAG(MTE1, M, pingpong_flag); + WAIT_FLAG(M, MTE1, pingpong_flag + 2); + WAIT_FLAG(MTE2, MTE1, pingpong_flag); + l1_to_l0_b( + l0b_buf_tensor[offset], l1k_buf_addr_tensor[kv_pingpong_offset + offset], 0, + NumMatrixsRoundUp(RoundUp(round_k, 32 / sizeof(QKV_DT)) * qk_round_n), // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + SET_FLAG(MTE1, MTE2, pingpong_flag + 2 * kv_pingpong_flag); + SET_FLAG(MTE1, M, pingpong_flag + 2); + WAIT_FLAG(MTE1, M, pingpong_flag); + WAIT_FLAG(MTE1, M, pingpong_flag + 2); + WAIT_FLAG(FIX, M, pingpong_flag); + if constexpr (int8_flag) { + mmad( + l0c_buf_tensor.ReinterpretCast()[pingpong_flag * L0AB_HALF_BUF_SIZE], l0a_buf_tensor[offset], + l0b_buf_tensor[offset], + qk_m, // m + qk_n, // n + __k, // k + 1 // cmatrixInitVal + ); + } else { + mmad( + l0c_buf_tensor[pingpong_flag * L0AB_HALF_BUF_SIZE], l0a_buf_tensor[offset], l0b_buf_tensor[offset], + qk_m, // m + qk_n, // n + __k, // k + 1 // cmatrixInitVal + ); + } + SET_FLAG(M, MTE1, pingpong_flag); + SET_FLAG(M, MTE1, pingpong_flag + 2); + + SET_FLAG(M, FIX, pingpong_flag); + WAIT_FLAG(M, FIX, pingpong_flag); + // copy S to gm + l0c_to_gm( + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE * tmp_times + + n_idx % vect_mod * TMP_SIZE * tmp_times / vect_mod + split_idx * pp_n_scalar], + l0c_buf_tensor[pingpong_flag * L0AB_HALF_BUF_SIZE], + qk_m, // MSize + qk_round_n, // NSize + qk_round_m, // srcStride + sv_round_n // dstStride_dst_D + ); + SET_FLAG(FIX, M, pingpong_flag); + } + FftsCrossCoreSync(QK_READY); + kv_pingpong_flag = 1 - kv_pingpong_flag; + kv_pingpong_offset = kv_pingpong_flag * KV_DB_SIZE; + } + if (n_idx >= launch_delay) { + uint32_t l0c_pingpong_flag = (n_idx + 1) % 2; + uint32_t l0c_offset = l0c_pingpong_flag * L0AB_HALF_BUF_SIZE; + uint32_t sv_n_triu = n_end * pp_n_scalar; + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + sv_n = sv_n_triu > kv_seqlen ? 
kv_seqlen - (n_idx - launch_delay) * pp_n_scalar + : sv_n_triu - (n_idx - launch_delay) * pp_n_scalar; + } else { + sv_n = pp_n_scalar * s_block_stack; + } + uint32_t sv_round_n = (sv_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + uint32_t n_slice = pp_n_scalar * ((s_block_stack + 1) / 2); + // n_slice = 256; + uint32_t l1_split_loop = (sv_n + n_slice - 1) / n_slice; + WAIT_FLAG(MTE1, MTE2, kv_pingpong_flag * 2); + WAIT_FLAG(MTE1, MTE2, kv_pingpong_flag * 2 + 1); + + if (this->npd) { + v_offset = ((v_token_id) / page_size) * stride_kv + (v_token_id % page_size) * embd; + CopyGmToL1Npd(l1v_buf_addr_tensor[kv_pingpong_offset], + v_gm_tensor[v_base_offset + v_offset], this->page_size, sv_n, + sv_round_n, __k, round_k, stride_kv); + } else { + gm_to_l1( + l1v_buf_addr_tensor[kv_pingpong_offset], v_gm_tensor[v_base_offset + v_offset], + sv_n, // nValue + sv_round_n, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + __k, // dValue + 0, // dstNzMatrixStride, unused + stride_kv // srcDValue + ); + v_offset += sv_n * stride_kv; + } + v_token_id += sv_n; + WaitFlagDev(SOFTMAX_READY); + WAIT_FLAG(FIX, M, l0c_pingpong_flag); + for (uint32_t l1_k_split_idx = 0; l1_k_split_idx < l1_split_loop; l1_k_split_idx++) { + uint32_t l1_pingpong_flag = l1_k_split_idx % 2; + uint32_t l1_offset = l1_pingpong_flag * 128 * 256; + bool l1_last_split = l1_k_split_idx == l1_split_loop - 1; + uint32_t d = l1_last_split ? sv_n - l1_k_split_idx * n_slice : n_slice; + WAIT_FLAG(MTE1, MTE2, l1_pingpong_flag + 4); + if (qk_m == 1) { + gm_to_l1( + l1p_buf_addr_tensor[l1_offset], + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod) * + 2 / sizeof(QKV_DT) + + l1_k_split_idx * n_slice], + 1, 0, 0, RoundUp(sv_round_n, BlockSize()), // lenBurst + 0, 0); + } else { + gm_to_l1( + l1p_buf_addr_tensor[l1_offset], + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod) * + 2 / sizeof(QKV_DT) + + l1_k_split_idx * n_slice], + qk_m, // nValue + qk_round_m, // dstNzC0Stride + 0, // dstNzMatrixStride, unused + d, // dValue + 0, // dstNzMatrixStride, unused + sv_round_n * 2 / sizeof(QKV_DT) // srcDValue + ); + } + SET_FLAG(MTE2, MTE1, l1_pingpong_flag + 4); + WAIT_FLAG(MTE2, MTE1, l1_pingpong_flag + 4); + uint32_t d_split_loop = (d + 127) / 128; + for (uint32_t l0_k_split_idx = 0; l0_k_split_idx < d_split_loop; l0_k_split_idx++) { + uint32_t l0_pingpong_flag = l0_k_split_idx % 2; + uint32_t l0_offset = l0_pingpong_flag * 128 * 128; + bool l0_last_split = l0_k_split_idx == d_split_loop - 1; + int32_t l0_p_offset = qk_m == 1 ? l0_k_split_idx * 128 : l0_k_split_idx * 128 * qk_round_m; + bool initC = l0_k_split_idx == 0 && l1_k_split_idx == 0; + uint32_t reduce_d = l0_last_split ? 
d - l0_k_split_idx * 128 : 128; + uint32_t round_reduce_d = (reduce_d + 15) / 16 * 16; + WAIT_FLAG(M, MTE1, l0_pingpong_flag); + LoadDataToCa(l0a_buf_tensor[l0_offset], l1p_buf_addr_tensor[l1_offset + l0_p_offset], + RoundUp(round_reduce_d, BlockSize()), qk_round_m, qk_m); + if (l0_last_split) { + SET_FLAG(MTE1, MTE2, l1_pingpong_flag + 4); + } + WAIT_FLAG(M, MTE1, l0_pingpong_flag + 2); + for (uint32_t l0b_load_idx = 0; l0b_load_idx < 128 / BLOCK_SIZE; ++l0b_load_idx) { + l1_to_l0_b( + l0b_buf_tensor[l0_offset + l0b_load_idx * round_k * BLOCK_SIZE], + l1v_buf_addr_tensor[kv_pingpong_offset + l0b_load_idx * CUBE_MATRIX_SIZE + + l1_k_split_idx * n_slice * BLOCK_SIZE + l0_k_split_idx * 128 * BLOCK_SIZE], + 0, + round_k / BLOCK_SIZE, // repeat + 0, + sv_round_n / BLOCK_SIZE, // srcStride + 0, + 0 // dstStride + ); + } + if (l0_last_split && l1_last_split) { + SET_FLAG(MTE1, MTE2, kv_pingpong_flag * 2); + SET_FLAG(MTE1, MTE2, kv_pingpong_flag * 2 + 1); + } + SET_FLAG(MTE1, M, l0_pingpong_flag + 6); + WAIT_FLAG(MTE1, M, l0_pingpong_flag + 6); + if constexpr (int8_flag) { + mmad( + l0c_buf_tensor.template ReinterpretCast()[l0c_offset], l0a_buf_tensor[l0_offset], + l0b_buf_tensor[l0_offset], + qk_m, // m + __k, // n + sv_n, // k + 1 // cmatrixInitVal + ); + } else { + mmad( + l0c_buf_tensor[l0c_offset], l0a_buf_tensor[l0_offset], l0b_buf_tensor[l0_offset], + qk_m, // m + __k, // n + reduce_d, // k + initC // cmatrixInitVal + ); + } + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, l0_pingpong_flag); + SET_FLAG(M, MTE1, l0_pingpong_flag + 2); + } + } + SET_FLAG(M, FIX, l0c_pingpong_flag); + WAIT_FLAG(M, FIX, l0c_pingpong_flag); + // copy O to gm + l0c_to_gm( + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod], + l0c_buf_tensor[l0c_offset], + qk_m, // MSize + round_k, // NSize + qk_round_m, // srcStride + round_k // dstStride_dst_D + ); + SET_FLAG(FIX, M, l0c_pingpong_flag); + FftsCrossCoreSync(UPDATE_READY); + kv_pingpong_flag = 1 - kv_pingpong_flag; + kv_pingpong_offset = kv_pingpong_flag * KV_DB_SIZE; + } + } + } + } + } + + private: + __gm__ uint8_t *__restrict__ q_gm{nullptr}; + __gm__ uint8_t *__restrict__ k_gm{nullptr}; + __gm__ uint8_t *__restrict__ v_gm{nullptr}; + __gm__ uint8_t *__restrict__ s_gm{nullptr}; + __gm__ uint8_t *__restrict__ p_gm{nullptr}; + __gm__ uint8_t *__restrict__ o_tmp_gm{nullptr}; + __gm__ uint8_t *__restrict__ tiling_para_gm{nullptr}; + + const uint32_t l1q_buf_addr_offset = 0; + const uint32_t l1k_buf_addr_offset = 4 * L0AB_UINT8_BLOCK_SIZE; + const uint32_t l1v_buf_addr_offset = 4 * L0AB_UINT8_BLOCK_SIZE; + const uint32_t l1p_buf_addr_offset = 12 * L0AB_UINT8_BLOCK_SIZE; + + AsdopsBuffer buf; + + AscendC::LocalTensor l1q_buf_addr_tensor = buf.GetBuffer(l1q_buf_addr_offset); + AscendC::LocalTensor l1k_buf_addr_tensor = buf.GetBuffer(l1k_buf_addr_offset); + AscendC::LocalTensor l1p_buf_addr_tensor = buf.GetBuffer(l1p_buf_addr_offset); + AscendC::LocalTensor l1v_buf_addr_tensor = buf.GetBuffer(l1v_buf_addr_offset); + + AscendC::GlobalTensor q_gm_tensor; + AscendC::GlobalTensor k_gm_tensor; + AscendC::GlobalTensor v_gm_tensor; + AscendC::GlobalTensor s_gm_tensor; + AscendC::GlobalTensor p_gm_tensor; + AscendC::GlobalTensor o_tmp_gm_tensor; + + AscendC::LocalTensor l0a_buf_tensor = buf.GetBuffer(0); + AscendC::LocalTensor l0b_buf_tensor = buf.GetBuffer(0); + AscendC::LocalTensor l0c_buf_tensor = buf.GetBuffer(0); + + uint32_t batch_size{0}; + uint32_t max_seqlen{0}; + uint32_t max_kv_seqlen{0}; + 
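+  // Note (added comment): in this variant l1k_buf_addr_offset and l1v_buf_addr_offset are the
+  // same (4 * L0AB_UINT8_BLOCK_SIZE), i.e. K and V share one L1 region; the MTE1->MTE2 event
+  // pairs are presumably what keep the QK stage's reads of K out of L1 ahead of the V loads
+  // into the same bytes.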
uint32_t max_q_seqlen{0}; + uint32_t q_heads{0}; + uint32_t embd{0}; + uint32_t kv_heads{0}; + uint32_t is_triu_mask{0}; + uint32_t total_q_blk_num{0}; + uint32_t group_num{0}; + uint64_t stride_qo{0}; + uint64_t stride_kv{0}; + uint64_t batch_stride_kv{0}; + uint32_t __k{0}; + uint32_t round_k{0}; + uint32_t data_shape_type{0}; + uint32_t tmp_times{1}; + uint32_t tilingKey{0}; + uint32_t tiling_head_size{0}; + uint32_t tiling_para_size{0}; + uint32_t npd{0}; + uint32_t page_size{0}; +}; +#elif __DAV_C220_VEC__ + +#include "fa_common.cce" + +constexpr int32_t UB_UINT8_BLOCK_SIZE = 16384; // 64 * 128 * 2B +constexpr int32_t UB_HALF_BUF_SIZE = 8192; +constexpr int32_t UB_bf16_BUF_SIZE = 8192; // 64 * 128 +constexpr int32_t UB_UINT8_LINE_SIZE = 512; // 128 * 4B +constexpr int32_t UB_FLOAT_LINE_SIZE = 64; // 128 +constexpr int32_t UB_HALF_LINE_SIZE = 128; // UB_FLOAT_LINE_SIZE * 2 +constexpr int32_t BASE_MASK_SIZE = 128; +constexpr int32_t COMPRESS_MASK_SIZE = 8192; // 64 * 128 +constexpr float BASE_Y = 128; + +__aicore__ __attribute__((always_inline)) inline void __set_mask(int32_t len) { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + + if (len == VECTOR_SIZE || len == 0) { + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + SetVectorMask(mask, (uint64_t)-1); + } else { + SetVectorMask(0x0, mask); + } +} + +__aicore__ __attribute__((always_inline)) inline void __set_vcg_mask(int32_t len) { + if (len > 16 || len < 1) { + SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + uint64_t subMask = ((uint64_t)1 << len) - 1; + uint64_t maskValue = (subMask << 48) + (subMask << 32) + (subMask << 16) + subMask + (subMask << 56) + + (subMask << 40) + (subMask << 24) + (subMask << 8); + SetVectorMask(maskValue, maskValue); +} + +template +class FlashAttentionEncoderHighPrecisionVec { + public: + __aicore__ __attribute__((always_inline)) inline FlashAttentionEncoderHighPrecisionVec( + __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ mask_gm, + __gm__ uint8_t *__restrict__ alibi_coeff_gm, __gm__ uint8_t *__restrict__ o_gm, __gm__ uint8_t *__restrict__ s_gm, + __gm__ uint8_t *__restrict__ p_gm, __gm__ uint8_t *__restrict__ o_tmp_gm, + __gm__ uint8_t *__restrict__ tiling_para_gm, __gm__ uint8_t *__restrict__ deq_qk_gm, + __gm__ uint8_t *__restrict__ off_qk_gm, __gm__ uint8_t *__restrict__ quant_p_gm, + __gm__ uint8_t *__restrict__ deq_pv_gm, __gm__ uint8_t *__restrict__ off_pv_gm, + __gm__ uint8_t *__restrict__ logN_gm) + : mask_gm(mask_gm), + o_gm(o_gm), + alibi_coeff_gm(alibi_coeff_gm), + s_gm(s_gm), + p_gm(p_gm), + o_tmp_gm(o_tmp_gm), + tiling_para_gm(tiling_para_gm), + deq_qk_gm(deq_qk_gm), + off_qk_gm(off_qk_gm), + quant_p_gm(quant_p_gm), + deq_pv_gm(deq_pv_gm), + off_pv_gm(off_pv_gm), + logN_gm(logN_gm) { + SetFftsBaseAddr((unsigned long)sync); + SetAtomicnone(); + SetMasknorm(); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + + this->sub_block_idx = GetSubBlockidx(); + this->batch_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm)); + this->max_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1)); + this->q_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2)); + this->embd = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3)); + this->tor = (float)(*((__gm__ float *)tiling_para_gm + 5)); + this->head_stride = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 6)); + this->mask_stride = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 
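+    // Note (added comment): __set_mask enables only the first `len` lanes of a 16-bit vector
+    // op. For example len = 96 gives temp = 96 % 64 = 32, so mask = 0xFFFFFFFF and
+    // SetVectorMask(mask, -1) leaves lanes 0..95 enabled; len = 0 or a full 128 restores the
+    // all-ones mask. __set_vcg_mask builds a repeating per-group pattern, presumably for the
+    // grouped vcg-style reductions used elsewhere in this file.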
7)); + this->is_triu_mask = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8)); + this->total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9)); + this->isClamp = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 10)); + this->clampMin = (float)(*((__gm__ float *)tiling_para_gm + 11)); + this->clampMax = (float)(*((__gm__ float *)tiling_para_gm + 12)); + this->tiling_head_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 14)); + this->tiling_para_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 15)); + this->tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 16)); + this->long_seq = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 17)); + this->is_sqrt = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 19)); + this->mask_type = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 20)); + this->alibi_compress_offset = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 21)); + this->alibi_left_align = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 22)); + this->quantType = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 24)); + this->data_shape_type = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 25)); + this->window_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 27)); + this->max_q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 29)); + this->stride_qo = (uint64_t)q_heads * embd; + if (this->data_shape_type == 1) { + this->stride_qo = embd; + } + + this->__k = embd; + this->round_k = (__k + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + this->scaleType = (ScaleType)(*((__gm__ int32_t *)tiling_para_gm + 26)); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID1); + SET_FLAG(MTE3, MTE2, EVENT_ID2); + SET_FLAG(V, MTE2, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID1); + SET_FLAG(V, MTE2, EVENT_ID2); + SET_FLAG(MTE3, V, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID7); + } + __aicore__ __attribute__((always_inline)) inline ~FlashAttentionEncoderHighPrecisionVec() { + WAIT_FLAG(V, MTE2, EVENT_ID7); + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID2); + WAIT_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID1); + WAIT_FLAG(V, MTE2, EVENT_ID2); + WAIT_FLAG(MTE3, V, EVENT_ID0); + PIPE_BARRIER(ALL); + } + + __aicore__ __attribute__((always_inline)) inline uint32_t GetTilingKey() { return this->tilingKey; } + + template + __aicore__ __attribute__((always_inline)) inline uint32_t VectorSize() { + return 256 / sizeof(Dtype); + } + + template + __aicore__ __attribute__((always_inline)) inline uint64_t NumVectorsRoundUp(uint64_t num) { + return (num + VectorSize() - 1) / VectorSize(); + } + + __aicore__ __attribute__((always_inline)) inline void DeqPerHeadS322F32(AscendC::LocalTensor &s, + __gm__ uint8_t *deq_qk_gm, + __gm__ uint8_t *off_qk_gm, uint32_t head_idx, + uint32_t len) { + // dequant QK + // int32_t转成float类型 + conv_v(s, s.ReinterpretCast(), + NumVectorsRoundUp(len), // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + + // scale + float s_quant_scale = *((__gm__ float *)deq_qk_gm + head_idx); + muls_v(s, s, s_quant_scale, + NumVectorsRoundUp(len), // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + + template + __aicore__ inline void DivRepeatM(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src0, + const AscendC::LocalTensor &src1, const uint32_t sub_m, const uint32_t qk_n, + const uint32_t qk_round_n) { + uint32_t 
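+    // Note (added comment): DeqPerHeadS322F32 dequantizes the int32 QK result per head: it
+    // first converts int32_t to float (that is what the Chinese comment above says) and then
+    // multiplies by the per-head scale deq_qk_gm[head_idx], i.e.
+    //   s_fp32 = (float)s_int32 * deq_qk[head_idx];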
T_BLOCK_SIZE = BlockSize(); + uint32_t T_VECTOR_SIZE = VectorSize(); + for (uint32_t row_idx = 0; row_idx < qk_n / T_VECTOR_SIZE; ++row_idx) { + div_v(dst[row_idx * T_VECTOR_SIZE], src0[row_idx * T_VECTOR_SIZE], src1, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / T_BLOCK_SIZE, // dstRepeatStride + qk_round_n / T_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (qk_n % T_VECTOR_SIZE > 0) { + __set_mask(qk_n % T_VECTOR_SIZE); + div_v(dst[qk_n / T_VECTOR_SIZE * T_VECTOR_SIZE], + src0[qk_n / T_VECTOR_SIZE * T_VECTOR_SIZE], src1, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / T_BLOCK_SIZE, // dstRepeatStride + qk_round_n / T_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + } + + __aicore__ inline void SymmetricDeQuant(const AscendC::LocalTensor &lo_ubuf_tensor, + const AscendC::LocalTensor &p_scale_ubuf_tensor, uint32_t sub_m, + uint32_t round_sub_m, uint32_t qk_n, uint32_t qk_round_n, uint32_t head_idx) { + if (quantType == 3) { + brcb_v(tv_ubuf_tensor, p_scale_ubuf_tensor, + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / BlockSize() // repeat + ); + PIPE_BARRIER(V); + for (uint32_t row_idx = 0; row_idx < qk_n / VectorSize(); ++row_idx) { + mul_v(lo_ubuf_tensor[row_idx * VectorSize()], + lo_ubuf_tensor[row_idx * VectorSize()], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BlockSize(), // dstRepeatStride + qk_round_n / BlockSize(), // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (qk_n % VectorSize() > 0) { + __set_mask(qk_n % VectorSize()); + mul_v(lo_ubuf_tensor[qk_n / VectorSize() * VectorSize()], + lo_ubuf_tensor[qk_n / VectorSize() * VectorSize()], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / BlockSize(), // dstRepeatStride + qk_round_n / BlockSize(), // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + } else { + float p_scale = *((__gm__ float *)quant_p_gm + head_idx); + muls_v( + lo_ubuf_tensor, lo_ubuf_tensor, p_scale, + (sub_m * qk_round_n + VectorSize() - 1) / VectorSize(), // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + } + + __aicore__ inline void SymmetricQuant(const AscendC::LocalTensor &lp_ubuf_tensor, + const AscendC::LocalTensor &ls32_ubuf_tensor, + const AscendC::LocalTensor &lm_ubuf_tensor, + const AscendC::LocalTensor &hm_ubuf_tensor, + const AscendC::LocalTensor &p_scale_ubuf_tensor, const uint32_t sub_m, + const uint32_t round_sub_m, const uint32_t qk_n, const uint32_t qk_round_n, + uint32_t head_idx) { + // online quant + if (quantType == 3) { + // max(exp^(x-hm)) = max(exp^(x-lm)*exp^(lm-hm)) = 1 * exp^(lm-hm)) + sub_v(lm_ubuf_tensor, lm_ubuf_tensor, hm_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + exp_v(lm_ubuf_tensor, lm_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + muls_v(p_scale_ubuf_tensor, lm_ubuf_tensor, ((float)1 / (float)127), + 1, // repeat + 1, // dstBlockStride + 1, // 
srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + brcb_v(tv_ubuf_tensor, p_scale_ubuf_tensor, + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / BlockSize() // repeat + ); + PIPE_BARRIER(V); + DivRepeatM(lp_ubuf_tensor, ls32_ubuf_tensor, tv_ubuf_tensor, sub_m, qk_n, qk_round_n); + } else { // offline quant + float p_scale = (float)1.0 / *((__gm__ float *)quant_p_gm + head_idx); + muls_v( + lp_ubuf_tensor, ls32_ubuf_tensor, p_scale, + (sub_m * qk_round_n + VectorSize() - 1) / VectorSize(), // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + AscendC::Cast( + lp_ubuf_tensor.ReinterpretCast(), lp_ubuf_tensor, AscendC::RoundMode::CAST_RINT, (uint64_t)0, + (sub_m * qk_round_n + VectorSize() - 1) / VectorSize(), {1, 1, 4, 8}); + PIPE_BARRIER(V); + for (uint32_t row_idx = 0; row_idx < qk_n / VectorSize(); ++row_idx) { + AscendC::Cast( + lp_ubuf_tensor.ReinterpretCast()[row_idx * VectorSize()], + lp_ubuf_tensor.ReinterpretCast()[row_idx * VectorSize()], AscendC::RoundMode::CAST_RINT, + (uint64_t)0, sub_m, + {1, 1, (uint8_t)(qk_round_n / BlockSize()), (uint8_t)(qk_round_n / BlockSize())}); + } + if (qk_n % VectorSize() > 0) { + __set_mask(qk_n % VectorSize()); + AscendC::Cast( + lp_ubuf_tensor.ReinterpretCast()[qk_n / VectorSize() * VectorSize()], + lp_ubuf_tensor.ReinterpretCast()[qk_n / VectorSize() * VectorSize()], + AscendC::RoundMode::CAST_RINT, (uint64_t)0, sub_m, + {1, 1, (uint8_t)(qk_round_n / BlockSize()), (uint8_t)(qk_round_n / BlockSize())}); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + } + + template + __aicore__ __attribute__((always_inline)) inline void Run() { + uint64_t cur_batch = 0; + uint64_t pre_total_q_blk_num = 0; + uint32_t offset_tiling = tiling_head_size + tiling_para_size * cur_batch; + uint32_t cur_total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 13 + offset_tiling)); + uint64_t process_num = (uint64_t)total_q_blk_num * q_heads; + float alibi_coeff = 1; + + mask_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ IN_DATA_TYPE *>(mask_gm)); + o_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ O_DT *>(o_gm)); + s_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(s_gm)); + p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ QKV_DT *>(p_gm)); + o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(o_tmp_gm)); + logN_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ IN_DATA_TYPE *>(logN_gm)); + logN_float_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(logN_gm)); + uint64_t next_process = 0; + for (uint64_t process = block_idx; process < process_num; process = next_process) { + while (process >= (uint64_t)cur_total_q_blk_num * q_heads) { + cur_batch++; + pre_total_q_blk_num = cur_total_q_blk_num; + offset_tiling += tiling_para_size; + cur_total_q_blk_num = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 13 + offset_tiling)); + } + next_process = process + block_num; + if (is_triu_mask) { + uint64_t curr_iter = process / block_num; + next_process = + curr_iter % 2 == 1 ? 
(curr_iter + 1) * block_num + block_idx : (curr_iter + 2) * block_num - 1 - block_idx; + } + + // get tiling args + uint32_t q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + offset_tiling)); + uint32_t kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1 + offset_tiling)); + if (q_seqlen == 0 || kv_seqlen == 0) { + continue; + } + uint32_t pp_m_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2 + offset_tiling)); + uint32_t pp_n_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3 + offset_tiling)); + uint32_t addr_o_high32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 10 + offset_tiling)); + uint32_t addr_o_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 11 + offset_tiling)); + uint64_t addr_o_scalar = (uint64_t)(((uint64_t)addr_o_high32) << 32 | addr_o_loww32); + + uint64_t process_idx = process - pre_total_q_blk_num * q_heads; + uint32_t m_idx = process_idx / q_heads; + uint64_t head_idx = process_idx % q_heads; + if (alibi_coeff_gm != nullptr) { + alibi_coeff = (float)(*((__gm__ float *)alibi_coeff_gm + head_idx)); + } + + uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; + uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; + + uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; + uint32_t sub_m = (sub_block_idx == 1) ? (qk_m - qk_m / 2) : qk_m / 2; + uint32_t sub_m_d128 = (sub_m + VECTOR_SIZE - 1) / VECTOR_SIZE; // up aligned to 128 + uint32_t sub_m_d64 = (sub_m + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE; // up aligned to 64 + uint32_t round_sub_m = (sub_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + /******** pre_load *******/ + uint32_t qk_n = n_loop == 1 ? kv_seqlen : pp_n_scalar; + uint32_t qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + uint64_t mask_batch_offset = cur_batch * mask_stride * max_seqlen; + uint64_t mask_head_offset = head_idx * head_stride * max_seqlen; + uint64_t mask_offset = mask_batch_offset + mask_head_offset; + uint32_t delta_uint = 0; + float base_y = -128; + float delta = 0; + + if (long_seq == 0) { + mask_offset += m_idx * pp_m_scalar * max_seqlen; + } else { + gm_to_ub(mask16_ubuf_tensor, + mask_gm_tensor[(uint64_t)sub_block_idx * qk_m / 2 * VECTOR_SIZE], + 0, // sid + sub_m, // nBurst + VECTOR_SIZE / BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + conv_v(mask_ubuf_tensor, mask16_ubuf_tensor, + sub_m * VECTOR_SIZE / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + PIPE_BARRIER(V); + muls_v(mask_ubuf_tensor, mask_ubuf_tensor, (float)-3e38, + sub_m * VECTOR_SIZE / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + + uint64_t o_offset = addr_o_scalar + head_idx * embd + m_idx * pp_m_scalar * stride_qo; + if (data_shape_type == 1) { + o_offset = addr_o_scalar + head_idx * embd * max_q_seqlen + m_idx * pp_m_scalar * stride_qo; + } + + uint32_t n_end = n_loop; + if ((is_triu_mask && pp_m_scalar >= pp_n_scalar) || window_size > 0) { + uint32_t n_offset = ((m_idx + 1) * pp_m_scalar + kv_seqlen - q_seqlen + pp_n_scalar - 1) / pp_n_scalar; + n_end = n_offset > n_end ? 
n_end : n_offset; + } + uint32_t window_start = (window_size + pp_n_scalar - 1) / pp_n_scalar; + uint32_t n_start = 0; + if constexpr (swa_flag) { + if (window_size > 0 && window_size < kv_seqlen) { + n_start = (m_idx < window_start) ? 0 : m_idx - window_start; + if constexpr (!swa_compress) { + mask_offset += n_start * pp_n_scalar; + } + } + } + uint32_t qk_n_triu = n_end * pp_n_scalar; + uint32_t s_block_stack = n_end > 4 ? 2 : 1; + // PV is in stage 3, which means the 1st PV block corresponding to the 3th QK in our pipeline strategy + + uint32_t pv_stage = 3; + uint32_t launch_delay = s_block_stack * 2; + uint32_t vect_mod = 2 * launch_delay; + uint32_t m_slice = sub_m > 32 ? 32 : 0; // s_block_stack=2时,UB可以放下 + uint32_t m_end = sub_m > 32 ? 2 : 1; + for (uint32_t n_idx = n_start; n_idx < n_end + launch_delay; n_idx += s_block_stack) { + if (n_idx < n_end) { + uint32_t p_scale_offset = + n_idx / s_block_stack % pv_stage * RoundUp(pp_m_scalar, FLOAT_VECTOR_SIZE); + if (n_idx + s_block_stack > n_end - 1) { + qk_n = qk_n_triu > kv_seqlen ? kv_seqlen - n_idx * pp_n_scalar : qk_n_triu - n_idx * pp_n_scalar; + } else { + qk_n = pp_n_scalar * s_block_stack; + } + uint32_t delta_idx = m_idx - n_idx; + bool skip_mask = window_start > 3 && delta_idx > 1 && delta_idx < window_start - 1; + if constexpr (swa_compress) { + mask_offset = 0; + if (window_start <= 3) { // window < 128*3 最多跨4个基块 + if (m_idx < n_idx) { + mask_offset = pp_n_scalar; // 偏移128个数, midx=0, nidx=1 + } else { + mask_offset = delta_idx * max_seqlen * pp_m_scalar; + } + } else { + if (delta_idx == 0) { + mask_offset = 0; // m = n + } else if (delta_idx == window_start) { + mask_offset = 3 * max_seqlen * pp_m_scalar; + } else if (delta_idx == 1) { + mask_offset = max_seqlen * pp_m_scalar; + } else if (delta_idx == window_start - 1) { + mask_offset = 2 * max_seqlen * pp_m_scalar; + } // delta idx in [2, window-1) do not move and add mask + } + } + qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + if (qk_n <= VECTOR_SIZE) { + if (sub_m > 0 && mask_type != 0) { + if (alibi_coeff_gm != nullptr) { + if (alibi_left_align == 0) { + if (n_idx == n_end - 1) { + mask_offset = 0; + delta_uint = 0; + delta = 0; + } else { + mask_offset = BASE_MASK_SIZE * max_seqlen; + delta_uint = m_idx * pp_m_scalar - n_idx * pp_n_scalar; + delta = base_y + delta_uint; + } + } else { + if (n_idx == n_end - 1) { + mask_offset = 0; + } else { + mask_offset = BASE_MASK_SIZE * max_seqlen; + } + delta = -base_y * n_idx; + } + } else if (mask_type == 2 && alibi_compress_offset > 0) { + if (n_idx == n_end - 1) { + mask_offset = head_idx * alibi_compress_offset * BASE_MASK_SIZE; + } else { + delta_uint = m_idx * pp_m_scalar - n_idx * pp_n_scalar; + mask_offset = BASE_MASK_SIZE * delta_uint + head_idx * alibi_compress_offset * BASE_MASK_SIZE; + } + } + if (long_seq == 0) { + WAIT_FLAG(V, MTE2, EVENT_ID1); + if constexpr (!swa_compress) { + gm_to_ub_align( + mask16_ubuf_tensor, mask_gm_tensor[mask_offset + sub_block_idx * qk_m / 2 * max_seqlen], + 0, // sid + sub_m, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } else { + if (!(skip_mask)) { + gm_to_ub_align( + mask16_ubuf_tensor, mask_gm_tensor[mask_offset + sub_block_idx * qk_m / 2 * max_seqlen], + 0, // sid + sub_m, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } + } + SET_FLAG(MTE2, V, EVENT_ID1); + if constexpr 
(!swa_compress) { + mask_offset += qk_n; + } + WAIT_FLAG(MTE2, V, EVENT_ID1); + if constexpr (!swa_compress) { + conv_v( + mask_ubuf_tensor, mask16_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + } else { + if (!(skip_mask)) { + conv_v( + mask_ubuf_tensor, mask16_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + } + } + if (alibi_coeff_gm != nullptr) { + PIPE_BARRIER(V); + if (is_sqrt == 1 && m_idx != n_idx) { + mul_v( + mask_ubuf_tensor, mask_ubuf_tensor, mask_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + } + adds_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)delta, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + if (is_sqrt == 1 && m_idx != n_idx) { + sqrt_v( + mask_ubuf_tensor, mask_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + muls_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)alibi_coeff, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + if (head_stride == 0 && mask_type != 2) { + if constexpr (!swa_compress) { + muls_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)-3e38, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } else { + if (!(skip_mask)) { + muls_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)-3e38, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + } + PIPE_BARRIER(V); + } + } + } + WaitFlagDev(QK_READY); + if (sub_m > 0) { + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + // input QK + gm_to_ub( + ls_ubuf_tensor, + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n], + 0, // sid + sub_m, // nBurst + qk_round_n / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID0); + if (scaleType == ScaleType::SCALE_LOGN_FP32) { + WAIT_FLAG(V, MTE2, EVENT_ID7); + gm_to_ub_align( + log_ubuf_float_tensor, logN_float_gm_tensor[m_idx * pp_m_scalar + (uint64_t)sub_block_idx * qk_m / 2], + 0, // sid + 1, // nBurst + sub_m * 4, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap byte + (round_sub_m - sub_m) * 4 // dstGap block + ); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + brcb_v(tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + log_ubuf_float_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + (sub_m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + WAIT_FLAG(MTE2, 
V, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID7); + if constexpr (int8_flag) { + DeqPerHeadS322F32(ls_ubuf_tensor, deq_qk_gm, off_qk_gm, head_idx, sub_m * qk_round_n); + } + for (uint32_t vdiv_idx = 0; vdiv_idx < qk_n / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + mul_v(ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + PIPE_BARRIER(V); + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + mul_v(ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** ls = tor * ls + muls_v( + ls_ubuf_tensor, ls_ubuf_tensor, tor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } else if (scaleType == ScaleType::SCALE_LOGN) { + WAIT_FLAG(V, MTE2, EVENT_ID7); + gm_to_ub_align( + log_ubuf_tensor, logN_gm_tensor[m_idx * pp_m_scalar + (uint64_t)sub_block_idx * qk_m / 2], + 0, // sid + 1, // nBurst + sub_m * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap byte + (round_sub_m - sub_m) * 2 // dstGap block + ); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + conv_v(tv_ubuf_tensor, log_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + PIPE_BARRIER(V); + brcb_v(tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + (sub_m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + WAIT_FLAG(MTE2, V, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID7); + if constexpr (int8_flag) { + DeqPerHeadS322F32(ls_ubuf_tensor, deq_qk_gm, off_qk_gm, head_idx, sub_m * qk_round_n); + } + for (uint32_t vdiv_idx = 0; vdiv_idx < qk_n / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + mul_v(ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + PIPE_BARRIER(V); + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + mul_v(ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** ls = tor * ls + muls_v( + ls_ubuf_tensor, ls_ubuf_tensor, tor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / 
FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } else { + WAIT_FLAG(MTE2, V, EVENT_ID0); + if constexpr (int8_flag) { + DeqPerHeadS322F32(ls_ubuf_tensor, deq_qk_gm, off_qk_gm, head_idx, sub_m * qk_round_n); + } + // *** ls = tor * ls + muls_v( + ls_ubuf_tensor, ls_ubuf_tensor, tor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + + if (isClamp == 1) { + // get min(clampMin,ls_ubuf) + maxs_v( + ls_ubuf_tensor, ls_ubuf_tensor, clampMin, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + + // get max(clampMin,ls_ubuf) + mins_v( + ls_ubuf_tensor, ls_ubuf_tensor, clampMax, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + + // *** ls = ls + mask + if (mask_type != 0) { + if (long_seq == 0) { + if constexpr (!swa_compress) { + add_v( + ls_ubuf_tensor, ls_ubuf_tensor, mask_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } else { + if (!(skip_mask)) { + add_v( + ls_ubuf_tensor, ls_ubuf_tensor, mask_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + } + SET_FLAG(V, MTE2, EVENT_ID1); + } else if (pp_n_scalar == FLOAT_VECTOR_SIZE && s_block_stack == 2 && n_idx == n_end - 2) { + __set_mask(qk_n - FLOAT_VECTOR_SIZE); + add_v(ls_ubuf_tensor[FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[FLOAT_VECTOR_SIZE], mask_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 16 // src1RepeatStride + ); + } else if (n_idx == n_end - 1) { + if (qk_n < FLOAT_VECTOR_SIZE) { + __set_mask(qk_n); + } else { + __set_mask(FLOAT_VECTOR_SIZE); + } + add_v(ls_ubuf_tensor, ls_ubuf_tensor, mask_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 16 // src1RepeatStride + ); + if (qk_n > FLOAT_VECTOR_SIZE) { + __set_mask(qk_n - FLOAT_VECTOR_SIZE); + add_v(ls_ubuf_tensor[FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[FLOAT_VECTOR_SIZE], + mask_ubuf_tensor[FLOAT_VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 16 // src1RepeatStride + ); + } + } + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + // *** lm = rowmax(ls) + if (qk_n <= FLOAT_VECTOR_SIZE) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + cgmax_v(tv_ubuf_tensor, ls_ubuf_tensor, sub_m, 1, 1, + qk_round_n / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + __set_vcg_mask((qk_n + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); 
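+ // NOTE (assumed reading of this step): the "lm = rowmax(ls)" reduction appears to be done in two cgmax_v passes --
+ // the pass above reduces each row of ls to per-block partial maxima in tv, and the pass below, under the vcg mask
+ // just set, collapses those partials into one maximum per row in lm. Scalar sketch of the intended result
+ // (illustration only, array names assumed):
+ //   for (uint32_t r = 0; r < sub_m; ++r) { float m = -INFINITY; for (uint32_t c = 0; c < qk_n; ++c) m = fmaxf(m, ls[r][c]); lm[r] = m; }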
+ cgmax_v( + lm_ubuf_tensor, tv_ubuf_tensor, + (sub_m * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + cgmax_v(tv_ubuf_tensor, ls_ubuf_tensor, sub_m, 1, 1, qk_round_n / 8); + PIPE_BARRIER(V); + cgmax_v(lm_ubuf_tensor, tv_ubuf_tensor, round_sub_m * 8 / 64, 1, 1, 8); + PIPE_BARRIER(V); + for (uint32_t rowMaxIdx = 1; rowMaxIdx < qk_n / FLOAT_VECTOR_SIZE; ++rowMaxIdx) { + cgmax_v(tv_ubuf_tensor, ls_ubuf_tensor[rowMaxIdx * FLOAT_VECTOR_SIZE], + sub_m, 1, 1, qk_round_n / 8); + PIPE_BARRIER(V); + cgmax_v(tv_ubuf_tensor, tv_ubuf_tensor, round_sub_m * 8 / 64, 1, 1, 8); + PIPE_BARRIER(V); + __set_mask(sub_m); + max_v(lm_ubuf_tensor, lm_ubuf_tensor, tv_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + cgmax_v(tv_ubuf_tensor, + ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + sub_m, 1, 1, qk_round_n / 8); + PIPE_BARRIER(V); + __set_vcg_mask((qk_n % FLOAT_VECTOR_SIZE + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + cgmax_v(tv_ubuf_tensor, tv_ubuf_tensor, round_sub_m * 8 / 64, 1, 1, 8); + PIPE_BARRIER(V); + __set_mask(sub_m); + max_v(lm_ubuf_tensor, lm_ubuf_tensor, tv_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + if (n_idx == n_start) { + // *** hm = lm + ub_to_ub(hm_ubuf_tensor, lm_ubuf_tensor, + 0, // sid + 1, // nBurst + round_sub_m / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } else { + // *** hm = vmax(lm, gm) + max_v(hm_ubuf_tensor, lm_ubuf_tensor, gm_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** dm = gm - hm + sub_v( + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE], gm_ubuf_tensor, + hm_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** dm = exp(dm) + exp_v( + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE], + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE], + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + // *** gm = hm + ub_to_ub(gm_ubuf_tensor, hm_ubuf_tensor, + 0, // sid + 1, // nBurst + round_sub_m / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + // *** hm_block = expand_to_block(hm), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + hm_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** ls = ls - hm_block + for (uint32_t vsub_idx = 0; vsub_idx < qk_n / FLOAT_VECTOR_SIZE; ++vsub_idx) { + sub_v(ls_ubuf_tensor[vsub_idx * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[vsub_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // 
src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + sub_v(ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** ls = exp(ls) + exp_v( + ls32_ubuf_tensor, ls_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + if constexpr (int8_flag) { + SymmetricQuant(lp_ubuf_tensor, ls32_ubuf_tensor, lm_ubuf_tensor, hm_ubuf_tensor, + p_scale_ubuf_tensor[p_scale_offset], sub_m, round_sub_m, qk_n, qk_round_n, head_idx); + } else { + // *** lp = castfp32to16(ls) + if (IS_BF16) { + convr_v( + lp_ubuf_tensor.ReinterpretCast(), ls32_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } else { + conv_v( + lp_ubuf_tensor.ReinterpretCast(), ls32_ubuf_tensor, + (sub_m * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + } + + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + ub_to_gm( + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n) * + 2 / sizeof(QKV_DT)], + lp_ubuf_tensor.ReinterpretCast(), + 0, // sid + 1, // nBurst + sub_m * qk_round_n * 2 / BlockSize(), // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + // *** ll = rowsum(ls32) + if (qk_n <= FLOAT_VECTOR_SIZE) { + __set_mask(qk_n); + cadd_v(ll_ubuf_tensor, ls32_ubuf_tensor, + sub_m, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + qk_round_n / FLOAT_BLOCK_SIZE // srcRepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + for (uint32_t rowSumIdx = 1; rowSumIdx < qk_n / FLOAT_VECTOR_SIZE; ++rowSumIdx) { + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[rowSumIdx * FLOAT_VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n / FLOAT_BLOCK_SIZE // src1RepeatStride + ); + PIPE_BARRIER(V); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n / FLOAT_BLOCK_SIZE // src1RepeatStride + ); + } + SetVectorMask((uint64_t)-1, (uint64_t)-1); + PIPE_BARRIER(V); + cadd_v(ll_ubuf_tensor, ls32_ubuf_tensor, + sub_m, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + qk_round_n / 
FLOAT_BLOCK_SIZE // srcRepeatStride + ); + } + PIPE_BARRIER(V); + if (n_idx == n_start) { + // *** gl = ll + ub_to_ub(gl_ubuf_tensor, ll_ubuf_tensor, + 0, // sid + 1, // nBurst + round_sub_m / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } else { + // *** gl = dm * gl + mul_v( + gl_ubuf_tensor, dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE], + gl_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** gl = ll + gl + add_v(gl_ubuf_tensor, gl_ubuf_tensor, ll_ubuf_tensor, + sub_m_d64, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + } + } + } else { + bool last_n_loop = n_idx + s_block_stack > n_end - 1; + for (uint32_t split_idx = 0; split_idx < m_end; split_idx++) { + bool last_m_loop = split_idx == m_end - 1; + uint32_t m_split = last_m_loop ? sub_m - split_idx * m_slice : m_slice; + uint32_t round_m_split = (m_split + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE; + if (sub_m > 0 && mask_type != 0 && long_seq == 0) { + WAIT_FLAG(V, MTE2, EVENT_ID1); + uint64_t mask_offset_tail = 0; + if (alibi_coeff_gm != nullptr) { + mask_offset = BASE_MASK_SIZE * SOFTMAX_MAX_LENGTH; + if (alibi_left_align == 0) { + delta = base_y * (n_idx + 1 - m_idx); + } else { + delta = -base_y * n_idx; + } + gm_to_ub( + mask16_ubuf_tensor, + mask_gm_tensor[mask_offset + (sub_block_idx * qk_m / 2 + split_idx * m_slice) * SOFTMAX_MAX_LENGTH], + 0, // sid + m_split, // nBurst + qk_round_n / BLOCK_SIZE, // lenBurst + (SOFTMAX_MAX_LENGTH - qk_round_n) / BLOCK_SIZE, // srcGap + 0 // dstGap + ); + } else if (mask_type == 2 && alibi_compress_offset > 0) { + delta_uint = m_idx * pp_m_scalar - n_idx * pp_n_scalar; + mask_offset = BASE_MASK_SIZE * delta_uint + head_idx * alibi_compress_offset * BASE_MASK_SIZE; + mask_offset_tail = mask_offset - BASE_MASK_SIZE * pp_n_scalar; + if (n_idx == n_end - 2) { + mask_offset_tail = head_idx * alibi_compress_offset * BASE_MASK_SIZE; + } + gm_to_ub( + mask16_ubuf_tensor, + mask_gm_tensor[mask_offset + (sub_block_idx * qk_m / 2 + split_idx * m_slice) * VECTOR_SIZE], + 0, // sid + m_split, // nBurst + 8, // lenBurst + 0, // srcGap + (qk_round_n - VECTOR_SIZE) / BLOCK_SIZE // dstGap + ); + gm_to_ub( + mask16_ubuf_tensor[VECTOR_SIZE], + mask_gm_tensor[mask_offset_tail + (sub_block_idx * qk_m / 2 + split_idx * m_slice) * VECTOR_SIZE], + 0, // sid + m_split, // nBurst + (qk_round_n - VECTOR_SIZE) / BLOCK_SIZE, // lenBurst + (SOFTMAX_MAX_LENGTH - qk_round_n) / BLOCK_SIZE, // srcGap + 8 // dstGap + ); + } else { + if constexpr (!swa_compress) { + gm_to_ub_align( + mask16_ubuf_tensor, + mask_gm_tensor[mask_offset + (sub_block_idx * qk_m / 2 + split_idx * m_slice) * max_seqlen], + 0, // sid + m_split, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } else { + if (!skip_mask) { + gm_to_ub_align( + mask16_ubuf_tensor, + mask_gm_tensor[mask_offset + (sub_block_idx * qk_m / 2 + split_idx * m_slice) * max_seqlen], + 0, // sid + m_split, // nBurst + qk_n * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + (max_seqlen - qk_n) * 2, // srcGap + 0 // dstGap + ); + } + } + } + SET_FLAG(MTE2, V, EVENT_ID1); + } + if (split_idx == 0) { + 
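+ // NOTE (assumption): QK_READY is taken here to be the cross-core flag raised once the Q*K^T score block has been
+ // written to s_gm by the matmul core; the vector core blocks on it before loading those scores. Only the first
+ // m-split waits, since later splits reuse the same QK block already resident in global memory.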
WaitFlagDev(QK_READY); + } + if (sub_m > 0) { + if (m_split > 0) { + if (mask_type != 0 && long_seq == 0) { + WAIT_FLAG(MTE2, V, EVENT_ID1); + if constexpr (!swa_compress) { + conv_v( + mask_ubuf_tensor, mask16_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + } else { + if (!skip_mask) { + conv_v( + mask_ubuf_tensor, mask16_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + } + } + + SET_FLAG(V, MTE2, EVENT_ID1); + if (alibi_coeff_gm != nullptr) { + PIPE_BARRIER(V); + if (n_idx != n_end - 2) { + if (is_sqrt == 1) { + mul_v( + mask_ubuf_tensor, mask_ubuf_tensor, mask_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + } + adds_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)delta, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + if (alibi_left_align == 1) { + adds_v(mask_ubuf_tensor[128], mask_ubuf_tensor, (float)-base_y, + m_split, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + qk_round_n / 8, // dstRepeatStride + qk_round_n / 8 // srcRepeatStride + ); + adds_v(mask_ubuf_tensor[128 + FLOAT_VECTOR_SIZE], + mask_ubuf_tensor[FLOAT_VECTOR_SIZE], (float)-base_y, + m_split, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + qk_round_n / 8, // dstRepeatStride + qk_round_n / 8 // srcRepeatStride + ); + } else { + adds_v(mask_ubuf_tensor[128], mask_ubuf_tensor, base_y, + m_split, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + qk_round_n / 8, // dstRepeatStride + qk_round_n / 8 // srcRepeatStride + ); + adds_v(mask_ubuf_tensor[128 + FLOAT_VECTOR_SIZE], + mask_ubuf_tensor[FLOAT_VECTOR_SIZE], base_y, + m_split, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + qk_round_n / 8, // dstRepeatStride + qk_round_n / 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + if (is_sqrt == 1) { + sqrt_v( + mask_ubuf_tensor, mask_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + } else if (alibi_left_align == 1) { + adds_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)delta, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + + muls_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)alibi_coeff, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + if (head_stride == 0 && mask_type != 2) { + if constexpr (!swa_compress) { + muls_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)-3e38, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } else { + if (!skip_mask) { + 
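+ // NOTE (assumed semantics): with a plain 0/1 boolean mask (head_stride == 0, non-alibi path), the kernel converts
+ // it into an additive mask by scaling with -3e38 (close to -FLT_MAX), so the later "ls = ls + mask" drives masked
+ // logits toward -inf before the softmax. Scalar sketch (names illustrative): mask_f32[i] = mask01[i] * -3e38f;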
muls_v( + mask_ubuf_tensor, mask_ubuf_tensor, (float)-3e38, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + } + PIPE_BARRIER(V); + } + } + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + // input QK + gm_to_ub( + ls_ubuf_tensor, + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)(sub_block_idx * qk_m / 2 + split_idx * m_slice) * qk_round_n], + 0, // sid + m_split, // nBurst + qk_round_n / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID0); + WAIT_FLAG(MTE2, V, EVENT_ID0); + if constexpr (int8_flag) { + DeqPerHeadS322F32(ls_ubuf_tensor, deq_qk_gm, off_qk_gm, head_idx, m_split * qk_round_n); + } + SET_FLAG(MTE2, V, EVENT_ID0); + if (scaleType == ScaleType::SCALE_LOGN_FP32) { + WAIT_FLAG(V, MTE2, EVENT_ID7); + gm_to_ub_align( + log_ubuf_float_tensor, + logN_float_gm_tensor[m_idx * pp_m_scalar + + (uint64_t)(sub_block_idx * qk_m / 2 + split_idx * m_slice)], + 0, // sid + 1, // nBurst + m_split * 4, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap byte + (round_m_split - m_split) * 4 // dstGap block + ); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + brcb_v(tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + log_ubuf_float_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_m_split / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + WAIT_FLAG(MTE2, V, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID7); + for (uint32_t vdiv_idx = 0; vdiv_idx < qk_n / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + mul_v(ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + PIPE_BARRIER(V); + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + mul_v(ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + muls_v( + ls_ubuf_tensor, ls_ubuf_tensor, tor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } else if (scaleType == ScaleType::SCALE_LOGN) { + WAIT_FLAG(V, MTE2, EVENT_ID7); + gm_to_ub_align( + log_ubuf_tensor, + logN_gm_tensor[m_idx * pp_m_scalar + (uint64_t)(sub_block_idx * qk_m / 2 + split_idx * m_slice)], + 0, // sid + 1, // nBurst + m_split * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap byte + (round_m_split - m_split) * 2 // dstGap block + ); + SET_FLAG(MTE2, V, EVENT_ID1); + WAIT_FLAG(MTE2, V, EVENT_ID1); + conv_v( + tv_ubuf_tensor, log_ubuf_tensor, + (m_split + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 4 // srcRepeatStride + ); + PIPE_BARRIER(V); + 
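+ // NOTE (assumption): in the SCALE_LOGN path each query row carries its own scale read from logN_gm. The fp16
+ // scales were just widened to fp32 in tv; brcb_v below broadcasts one scale per row across a 32-byte block so the
+ // whole row of ls can be multiplied by it, before the global tor scale is applied.
+ // Rough scalar equivalent (illustration only): ls[r][c] = ls[r][c] * logn[r] * tor;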
brcb_v(tv_ubuf_tensor.ReinterpretCast()[VECTOR_SIZE], + tv_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_m_split / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + WAIT_FLAG(MTE2, V, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID7); + for (uint32_t vdiv_idx = 0; vdiv_idx < qk_n / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + mul_v(ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + PIPE_BARRIER(V); + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + mul_v(ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor[VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + muls_v( + ls_ubuf_tensor, ls_ubuf_tensor, tor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } else { + WAIT_FLAG(MTE2, V, EVENT_ID0); + // *** ls = tor * ls + muls_v( + ls_ubuf_tensor, ls_ubuf_tensor, tor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + if (isClamp == 1) { + // get min(clampMin,ls_ubuf) + maxs_v( + ls_ubuf_tensor, ls_ubuf_tensor, clampMin, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + // get max(clampMin,ls_ubuf) + mins_v( + ls_ubuf_tensor, ls_ubuf_tensor, clampMax, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + } + // *** ls = ls + mask + if (mask_type != 0) { + if (long_seq == 0) { + if constexpr (!swa_compress) { + add_v( + ls_ubuf_tensor, ls_ubuf_tensor, mask_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } else { + if (!skip_mask) { + add_v( + ls_ubuf_tensor, ls_ubuf_tensor, mask_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + } + } else if (n_idx == n_end - 2) { + if (qk_n - pp_n_scalar < FLOAT_VECTOR_SIZE) { + __set_mask(qk_n - pp_n_scalar); + } else { + __set_mask(FLOAT_VECTOR_SIZE); + } + add_v(ls_ubuf_tensor[pp_n_scalar], ls_ubuf_tensor[pp_n_scalar], + mask_ubuf_tensor[split_idx * m_slice * VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // 
dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 16 // src1RepeatStride + ); + if (qk_n - pp_n_scalar > FLOAT_VECTOR_SIZE) { + __set_mask(qk_n - pp_n_scalar - FLOAT_VECTOR_SIZE); + add_v( + ls_ubuf_tensor[pp_n_scalar + FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[pp_n_scalar + FLOAT_VECTOR_SIZE], + mask_ubuf_tensor[FLOAT_VECTOR_SIZE + split_idx * m_slice * VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 16 // src1RepeatStride + ); + } + } + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + if (qk_n == SOFTMAX_MAX_LENGTH) { + cgmax_v(tv_ubuf_tensor, ls_ubuf_tensor, + m_split * qk_n / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + __set_mask(32); + cgmax_v(tv_ubuf_tensor, tv_ubuf_tensor, m_split, 1, 1, 4); + PIPE_BARRIER(V); + __set_vcg_mask(4); + cgmax_v( + lm_ubuf_tensor, tv_ubuf_tensor, + (m_split * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + __set_mask(m_split); + } else { + cgmax_v(tv_ubuf_tensor, ls_ubuf_tensor, m_split, 1, 1, + qk_round_n / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + cgmax_v( + lm_ubuf_tensor, tv_ubuf_tensor, + (m_split * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + for (uint64_t rowmax_idx = 1; rowmax_idx < (uint64_t)qk_n / FLOAT_VECTOR_SIZE; ++rowmax_idx) { + cgmax_v(tv_ubuf_tensor, + ls_ubuf_tensor[rowmax_idx * FLOAT_VECTOR_SIZE], m_split, 1, + 1, qk_round_n / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + cgmax_v( + tv_ubuf_tensor, tv_ubuf_tensor, + (m_split * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + __set_mask(m_split); + max_v(lm_ubuf_tensor, lm_ubuf_tensor, tv_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + cgmax_v( + tv_ubuf_tensor, ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], m_split, 1, 1, + qk_round_n / FLOAT_BLOCK_SIZE); + PIPE_BARRIER(V); + __set_vcg_mask((qk_n % FLOAT_VECTOR_SIZE + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + cgmax_v( + tv_ubuf_tensor, tv_ubuf_tensor, + (m_split * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 1, 1, 8); + PIPE_BARRIER(V); + __set_mask(m_split); + max_v(lm_ubuf_tensor, lm_ubuf_tensor, tv_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + } + } + PIPE_BARRIER(V); + if (n_idx == n_start) { + // *** hm = lm + ub_to_ub(hm_ubuf_tensor[split_idx * m_slice], lm_ubuf_tensor, + 0, // sid + 1, // nBurst + round_m_split / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } else { + // *** hm = vmax(lm, gm) + max_v(hm_ubuf_tensor[split_idx * m_slice], lm_ubuf_tensor, + gm_ubuf_tensor[split_idx * m_slice], + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** dm = gm - hm + sub_v( + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE + + split_idx * m_slice], + 
gm_ubuf_tensor[split_idx * m_slice], hm_ubuf_tensor[split_idx * m_slice], + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** dm = exp(dm) + exp_v( + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE + + split_idx * m_slice], + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE + + split_idx * m_slice], + 1, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + } + SetVectorMask((uint64_t)-1, (uint64_t)-1); + PIPE_BARRIER(V); + // *** gm = hm + ub_to_ub(gm_ubuf_tensor[split_idx * m_slice], + hm_ubuf_tensor[split_idx * m_slice], + 0, // sid + 1, // nBurst + round_m_split / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + // *** hm_block = expand_to_block(hm), 存放于 tv + brcb_v( + tv_ubuf_tensor.ReinterpretCast(), + hm_ubuf_tensor.ReinterpretCast()[split_idx * m_slice], + 1, // dstBlockStride + 8, // dstRepeatStride + round_m_split / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** ls = ls - hm_block + for (uint32_t vsub_idx = 0; vsub_idx < qk_n / FLOAT_VECTOR_SIZE; ++vsub_idx) { + sub_v(ls_ubuf_tensor[vsub_idx * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[vsub_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + sub_v(ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + ls_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** ls = exp(ls) + exp_v( + ls32_ubuf_tensor, ls_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + if constexpr (int8_flag) { + SymmetricQuant(lp_ubuf_tensor, ls32_ubuf_tensor, lm_ubuf_tensor, + hm_ubuf_tensor[split_idx * m_slice], + p_scale_ubuf_tensor[p_scale_offset + split_idx * m_slice], m_split, round_m_split, + qk_n, qk_round_n, head_idx); + } else { + // *** lp = castfp32to16(ls) + if (IS_BF16) { + convr_v( + lp_ubuf_tensor.ReinterpretCast(), ls32_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } else { + conv_v( + lp_ubuf_tensor.ReinterpretCast(), ls32_ubuf_tensor, + (m_split * qk_round_n + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + } + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + ub_to_gm( + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + (n_idx - n_start) % vect_mod * TMP_SIZE / vect_mod + + ((uint64_t)sub_block_idx * qk_m / 2 + split_idx * m_slice) * qk_round_n) * + 2 / sizeof(QKV_DT)], + 
lp_ubuf_tensor.ReinterpretCast(), + 0, // sid + m_split, // nBurst + qk_round_n * 2 / BlockSize(), // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE3, MTE2, EVENT_ID0); + // *** ll = rowsum(ls32) + for (uint32_t rowsum_idx = 1; rowsum_idx < qk_n / FLOAT_VECTOR_SIZE; ++rowsum_idx) { + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[rowsum_idx * FLOAT_VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n / FLOAT_BLOCK_SIZE // src1RepeatStride + ); + PIPE_BARRIER(V); + } + if (qk_n % FLOAT_VECTOR_SIZE > 0) { + __set_mask(qk_n % FLOAT_VECTOR_SIZE); + add_v(ls32_ubuf_tensor, ls32_ubuf_tensor, + ls32_ubuf_tensor[qk_n / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + m_split, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + qk_round_n / FLOAT_BLOCK_SIZE, // dstRepeatStride + qk_round_n / FLOAT_BLOCK_SIZE, // src0RepeatStride + qk_round_n / FLOAT_BLOCK_SIZE // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + cadd_v(ll_ubuf_tensor, ls32_ubuf_tensor, + m_split, // repeat + 1, // dstRepeatStride + 1, // srcBlockStride + qk_round_n / FLOAT_BLOCK_SIZE // srcRepeatStride + ); + PIPE_BARRIER(V); + if (n_idx == n_start) { + // *** gl = ll + ub_to_ub(gl_ubuf_tensor[split_idx * m_slice], ll_ubuf_tensor, + 0, // sid + 1, // nBurst + round_m_split / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + PIPE_BARRIER(V); + } else { + __set_mask(m_split); + // *** gl = dm * gl + mul_v( + gl_ubuf_tensor[split_idx * m_slice], + dm_ubuf_tensor[((n_idx - n_start) / s_block_stack) % 4 * UB_FLOAT_LINE_SIZE + + split_idx * m_slice], + gl_ubuf_tensor[split_idx * m_slice], + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + // *** gl = ll + gl + add_v(gl_ubuf_tensor[split_idx * m_slice], + gl_ubuf_tensor[split_idx * m_slice], ll_ubuf_tensor, + 1, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + } + } + if constexpr (!swa_compress) { + mask_offset += qk_n; + } + } + FftsCrossCoreSync(SOFTMAX_READY); + } + if (n_idx >= launch_delay + n_start) { + uint32_t p_scale_offset = + (n_idx - launch_delay) / s_block_stack % pv_stage * RoundUp(pp_m_scalar, FLOAT_VECTOR_SIZE); + WaitFlagDev(UPDATE_READY); // 4 + if (sub_m > 0) { + // *** 更新 L 和 O + if (n_idx != launch_delay + n_start) { + WAIT_FLAG(V, MTE2, EVENT_ID2); + gm_to_ub( + lo_ubuf_tensor, + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k], + 0, // sid + 1, // nBurst + sub_m * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID2); + // *** dm_block = expand_to_block(dm), 存放于 tv + brcb_v( + tv_ubuf_tensor.ReinterpretCast(), + dm_ubuf_tensor[((n_idx - launch_delay - n_start) / s_block_stack % 4) * UB_FLOAT_LINE_SIZE] + .ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** go = go * dm_block + for (uint32_t vmul_idx = 0; vmul_idx < __k / 
FLOAT_VECTOR_SIZE; ++vmul_idx) { + mul_v(go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % FLOAT_VECTOR_SIZE > 0) { + __set_mask(__k % FLOAT_VECTOR_SIZE); + mul_v(go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** go = lo + go + WAIT_FLAG(MTE2, V, EVENT_ID2); + if constexpr (int8_flag) { + DeqPerHeadS322F32(lo_ubuf_tensor, deq_pv_gm, off_pv_gm, head_idx, sub_m * round_k); + SymmetricDeQuant(lo_ubuf_tensor, p_scale_ubuf_tensor[p_scale_offset], sub_m, round_sub_m, __k, round_k, + head_idx); + } + add_v( + go_ubuf_tensor, go_ubuf_tensor, lo_ubuf_tensor, + (sub_m * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SET_FLAG(V, MTE2, EVENT_ID2); + } else { + WAIT_FLAG(MTE3, MTE2, EVENT_ID2); + gm_to_ub( + go_ubuf_tensor, + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay - n_start) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k], + 0, // sid + 1, // nBurst + sub_m * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID3); + WAIT_FLAG(MTE2, V, EVENT_ID3); + if constexpr (int8_flag) { + DeqPerHeadS322F32(go_ubuf_tensor, deq_pv_gm, off_pv_gm, head_idx, sub_m * round_k); + SymmetricDeQuant(go_ubuf_tensor, p_scale_ubuf_tensor[p_scale_offset], sub_m, round_sub_m, __k, round_k, + head_idx); + } + } + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + // *** gl_block = expand_to_block(gl), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + gl_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** go = go / gl_block + for (uint32_t vdiv_idx = 0; vdiv_idx < __k / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + div_v(go_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % FLOAT_VECTOR_SIZE > 0) { + __set_mask(__k % FLOAT_VECTOR_SIZE); + div_v(go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + if (IS_BF16) { + convr_v( + go_ubuf_tensor.ReinterpretCast(), go_ubuf_tensor, + (sub_m * round_k + FLOAT_VECTOR_SIZE 
- 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } else { + conv_v( + go_ubuf_tensor.ReinterpretCast(), go_ubuf_tensor, + (sub_m * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + // ********************* move O to GM ************************ + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + ub_to_gm_align( + o_gm_tensor[o_offset + (uint64_t)sub_block_idx * qk_m / 2 * stride_qo], + go_ubuf_tensor.ReinterpretCast(), + 0, // sid + sub_m, // nBurst + __k * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + (stride_qo - __k) * 2 // dstGap + ); + SET_FLAG(MTE3, MTE2, EVENT_ID2); + } + } + } + } + } + } + + private: + __gm__ uint8_t *__restrict__ mask_gm{nullptr}; + __gm__ uint8_t *__restrict__ o_gm{nullptr}; + __gm__ uint8_t *__restrict__ s_gm{nullptr}; + __gm__ uint8_t *__restrict__ p_gm{nullptr}; + __gm__ uint8_t *__restrict__ o_tmp_gm{nullptr}; + __gm__ uint8_t *__restrict__ tiling_para_gm{nullptr}; + __gm__ uint8_t *__restrict__ alibi_coeff_gm{nullptr}; + __gm__ uint8_t *__restrict__ deq_qk_gm{nullptr}; + __gm__ uint8_t *__restrict__ off_qk_gm{nullptr}; + __gm__ uint8_t *__restrict__ quant_p_gm{nullptr}; + __gm__ uint8_t *__restrict__ deq_pv_gm{nullptr}; + __gm__ uint8_t *__restrict__ off_pv_gm{nullptr}; + __gm__ uint8_t *__restrict__ logN_gm{nullptr}; + + const uint32_t ls_ubuf_offset = 0; + const uint32_t lp_ubuf_offset = 0; + const uint32_t ls32_ubuf_offset = 2 * UB_UINT8_BLOCK_SIZE; + const uint32_t mask_ubuf_offset = 4 * UB_UINT8_BLOCK_SIZE; + const uint32_t lo_ubuf_offset = 6 * UB_UINT8_BLOCK_SIZE; + const uint32_t lm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE; + const uint32_t hm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 1 * UB_UINT8_LINE_SIZE; + const uint32_t gm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 2 * UB_UINT8_LINE_SIZE; + const uint32_t dm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 4 * UB_UINT8_LINE_SIZE; + const uint32_t ll_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 8 * UB_UINT8_LINE_SIZE; + const uint32_t gl_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 12 * UB_UINT8_LINE_SIZE; + const uint32_t tv_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 13 * UB_UINT8_LINE_SIZE; + const uint32_t p_scale_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 21 * UB_UINT8_LINE_SIZE; + const uint32_t log_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 30 * UB_UINT8_LINE_SIZE; + const uint32_t log_ubuf_float_offset = 8 * UB_UINT8_BLOCK_SIZE + 30 * UB_UINT8_LINE_SIZE; + const uint32_t go_ubuf_offset = 9 * UB_UINT8_BLOCK_SIZE; + const uint32_t mask16_ubuf_offset = 11 * UB_UINT8_BLOCK_SIZE; + + AsdopsBuffer buf; + AscendC::LocalTensor ls_ubuf_tensor = buf.GetBuffer(ls_ubuf_offset); + AscendC::LocalTensor lp_ubuf_tensor = buf.GetBuffer(lp_ubuf_offset); + AscendC::LocalTensor ls32_ubuf_tensor = buf.GetBuffer(ls32_ubuf_offset); + AscendC::LocalTensor mask_ubuf_tensor = buf.GetBuffer(mask_ubuf_offset); + AscendC::LocalTensor lo_ubuf_tensor = buf.GetBuffer(lo_ubuf_offset); + AscendC::LocalTensor lm_ubuf_tensor = buf.GetBuffer(lm_ubuf_offset); + AscendC::LocalTensor hm_ubuf_tensor = buf.GetBuffer(hm_ubuf_offset); + AscendC::LocalTensor gm_ubuf_tensor = buf.GetBuffer(gm_ubuf_offset); + AscendC::LocalTensor dm_ubuf_tensor = buf.GetBuffer(dm_ubuf_offset); + AscendC::LocalTensor ll_ubuf_tensor = buf.GetBuffer(ll_ubuf_offset); + AscendC::LocalTensor gl_ubuf_tensor = 
buf.GetBuffer(gl_ubuf_offset); + AscendC::LocalTensor tv_ubuf_tensor = buf.GetBuffer(tv_ubuf_offset); + AscendC::LocalTensor p_scale_ubuf_tensor = buf.GetBuffer(p_scale_ubuf_offset); + AscendC::LocalTensor go_ubuf_tensor = buf.GetBuffer(go_ubuf_offset); + AscendC::LocalTensor mask16_ubuf_tensor = + buf.GetBuffer(mask16_ubuf_offset); + + AscendC::LocalTensor log_ubuf_tensor = + buf.GetBuffer(log_ubuf_offset); + AscendC::LocalTensor log_ubuf_float_tensor = + buf.GetBuffer(log_ubuf_float_offset); + + AscendC::GlobalTensor mask_gm_tensor; + AscendC::GlobalTensor o_gm_tensor; + AscendC::GlobalTensor s_gm_tensor; + AscendC::GlobalTensor p_gm_tensor; + AscendC::GlobalTensor o_tmp_gm_tensor; + AscendC::GlobalTensor logN_gm_tensor; + AscendC::GlobalTensor logN_float_gm_tensor; + ScaleType scaleType = ScaleType::SCALE_TOR; + uint32_t batch_size{0}; + uint32_t max_seqlen{0}; + uint32_t max_q_seqlen{0}; + uint32_t q_heads{0}; + uint32_t embd{0}; + float tor{0}; + uint32_t head_stride{0}; + uint32_t mask_stride{0}; + uint32_t is_triu_mask{0}; + uint32_t total_q_blk_num{0}; + uint32_t isClamp{0}; + float clampMin; + float clampMax; + uint64_t stride_qo{0}; + uint32_t __k{0}; + uint32_t round_k{0}; + uint32_t go_flag_scalar{1}; + + int32_t sub_block_idx{-1}; + uint32_t tilingKey{0}; + uint32_t tiling_head_size{0}; + uint32_t tiling_para_size{0}; + uint32_t long_seq{0}; + uint32_t is_sqrt{0}; + uint32_t mask_type{0}; + uint32_t alibi_compress_offset{0}; + uint32_t alibi_left_align{0}; + uint32_t data_shape_type{0}; + uint32_t quantType{0}; + uint32_t window_size{0}; +}; +template +class FlashAttentionEncoderHighPrecisionVecOpt { + public: + __aicore__ __attribute__((always_inline)) inline FlashAttentionEncoderHighPrecisionVecOpt( + __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ mask_gm, + __gm__ uint8_t *__restrict__ alibi_coeff_gm, __gm__ uint8_t *__restrict__ o_gm, __gm__ uint8_t *__restrict__ s_gm, + __gm__ uint8_t *__restrict__ p_gm, __gm__ uint8_t *__restrict__ o_tmp_gm, + __gm__ uint8_t *__restrict__ tiling_para_gm, __gm__ uint8_t *__restrict__ deq_qk_gm, + __gm__ uint8_t *__restrict__ off_qk_gm, __gm__ uint8_t *__restrict__ quant_p_gm, + __gm__ uint8_t *__restrict__ deq_pv_gm, __gm__ uint8_t *__restrict__ off_pv_gm, + __gm__ uint8_t *__restrict__ logN_gm) + : mask_gm(mask_gm), + o_gm(o_gm), + alibi_coeff_gm(alibi_coeff_gm), + s_gm(s_gm), + p_gm(p_gm), + o_tmp_gm(o_tmp_gm), + tiling_para_gm(tiling_para_gm), + deq_qk_gm(deq_qk_gm), + off_qk_gm(off_qk_gm), + quant_p_gm(quant_p_gm), + deq_pv_gm(deq_pv_gm), + off_pv_gm(off_pv_gm), + logN_gm(logN_gm) { + SetFftsBaseAddr((unsigned long)sync); + SetAtomicnone(); + SetMasknorm(); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + + this->sub_block_idx = GetSubBlockidx(); + this->batch_size = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm)); + this->max_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1)); + this->q_heads = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2)); + this->embd = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3)); + this->tor = (float)(*((__gm__ float *)tiling_para_gm + 5)); + this->head_stride = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 6)); + this->mask_stride = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 7)); + this->is_triu_mask = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 8)); + this->total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 9)); + this->isClamp = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 10)); + this->clampMin = (float)(*((__gm__ float *)tiling_para_gm + 
11)); + this->clampMax = (float)(*((__gm__ float *)tiling_para_gm + 12)); + this->tiling_head_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 14)); + this->tiling_para_size = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 15)); + this->tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 16)); + this->long_seq = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 17)); + this->is_sqrt = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 19)); + this->mask_type = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 20)); + this->alibi_compress_offset = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 21)); + this->alibi_left_align = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 22)); + this->quantType = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 24)); + this->data_shape_type = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 25)); + this->max_q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 29)); + this->stride_qo = (uint64_t)q_heads * embd; + if (this->data_shape_type == 1) { + this->stride_qo = embd; + } + if constexpr (splitm) { + this->tmp_times = 2; + } + this->__k = embd; + this->round_k = (__k + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + SET_FLAG(MTE3, MTE2, EVENT_ID0); + SET_FLAG(MTE3, MTE2, EVENT_ID1); + SET_FLAG(MTE3, MTE2, EVENT_ID2); + SET_FLAG(V, MTE2, EVENT_ID0); + SET_FLAG(V, MTE2, EVENT_ID1); + SET_FLAG(V, MTE2, EVENT_ID2); + SET_FLAG(V, MTE2, EVENT_ID6); + SET_FLAG(V, MTE2, EVENT_ID7); + SET_FLAG(MTE3, V, EVENT_ID0); + } + __aicore__ __attribute__((always_inline)) inline ~FlashAttentionEncoderHighPrecisionVecOpt() { + WAIT_FLAG(MTE3, MTE2, EVENT_ID0); + WAIT_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, MTE2, EVENT_ID2); + WAIT_FLAG(V, MTE2, EVENT_ID0); + WAIT_FLAG(V, MTE2, EVENT_ID1); + WAIT_FLAG(V, MTE2, EVENT_ID2); + WAIT_FLAG(V, MTE2, EVENT_ID6); + WAIT_FLAG(V, MTE2, EVENT_ID7); + WAIT_FLAG(MTE3, V, EVENT_ID0); + PIPE_BARRIER(ALL); + } + __aicore__ __attribute__((always_inline)) inline void Run() { + uint64_t cur_batch = 0; + uint64_t pre_total_q_blk_num = 0; + uint32_t offset_tiling = tiling_head_size + tiling_para_size * cur_batch; + uint32_t cur_total_q_blk_num = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 13 + offset_tiling)); + uint64_t process_num = (uint64_t)total_q_blk_num * q_heads; + mask_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ MASK_DTYPE *>(mask_gm)); + o_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ O_DTYPE *>(o_gm)); + s_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ S_DTYPE *>(s_gm)); + p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ P_DTYPE *>(p_gm)); + o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(o_tmp_gm)); + uint64_t next_process = 0; + for (uint64_t process = block_idx; process < process_num; process = next_process) { + while (process >= (uint64_t)cur_total_q_blk_num * q_heads) { + cur_batch++; + pre_total_q_blk_num = cur_total_q_blk_num; + offset_tiling += tiling_para_size; + cur_total_q_blk_num = (uint32_t)(*((__gm__ uint32_t *)tiling_para_gm + 13 + offset_tiling)); + } + next_process = process + block_num; + // get tiling args + uint32_t q_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + offset_tiling)); + uint32_t kv_seqlen = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 1 + offset_tiling)); + if (q_seqlen == 0 || kv_seqlen == 0) { + continue; + } + uint32_t pp_m_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 2 + offset_tiling)); + uint32_t pp_n_scalar = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 3 + offset_tiling)); + uint32_t addr_o_high32 = 
(uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 10 + offset_tiling)); + uint32_t addr_o_loww32 = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 11 + offset_tiling)); + uint64_t addr_o_scalar = (uint64_t)(((uint64_t)addr_o_high32) << 32 | addr_o_loww32); + + uint64_t process_idx = process - pre_total_q_blk_num * q_heads; + uint32_t m_idx = process_idx / q_heads; + uint64_t head_idx = process_idx % q_heads; + + uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; + uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; + + uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; + uint32_t sub_m = (sub_block_idx == 1) ? (qk_m - qk_m / 2) : qk_m / 2; + uint32_t sub_m_d128 = (sub_m + VECTOR_SIZE - 1) / VECTOR_SIZE; // up aligned to 128 + uint32_t sub_m_d64 = (sub_m + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE; // up aligned to 64 + uint32_t round_sub_m = (sub_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + /******** pre_load *******/ + uint32_t qk_n = n_loop == 1 ? kv_seqlen : pp_n_scalar; + uint32_t qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + + uint64_t o_offset = addr_o_scalar + head_idx * embd + m_idx * pp_m_scalar * stride_qo; + if (data_shape_type == 1) { + o_offset = addr_o_scalar + head_idx * embd * max_q_seqlen + m_idx * pp_m_scalar * stride_qo; + } + + uint32_t n_end = n_loop; + uint32_t qk_n_triu = n_end * pp_n_scalar; + uint32_t s_block_stack = n_end > 8 ? 4 : (n_end > 4 ? 2 : 1); + uint32_t launch_delay = s_block_stack * 2; + uint32_t vect_mod = 2 * launch_delay; + uint32_t m_slice = FLOAT_VECTOR_SIZE / s_block_stack; + uint32_t m_end = (sub_m + m_slice - 1) / m_slice; + for (uint32_t n_idx = 0; n_idx < n_end + launch_delay; n_idx += s_block_stack) { + if (n_idx < n_end) { + if (n_idx + s_block_stack > n_end - 1) { + qk_n = qk_n_triu > kv_seqlen ? kv_seqlen - n_idx * pp_n_scalar : qk_n_triu - n_idx * pp_n_scalar; + } else { + qk_n = pp_n_scalar * s_block_stack; + } + qk_round_n = (qk_n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + if (sub_m == 0) { + WaitFlagDev(QK_READY); + } + uint32_t pingpong_flag = 0; + for (uint32_t m_ind = 0; m_ind < m_end; m_ind++) { + uint32_t row_offset = m_ind * m_slice; + uint32_t curr_m = m_ind == m_end - 1 ? 
sub_m - row_offset : m_slice; + uint32_t s_ub_offset = pingpong_flag * S_DB_SIZE; + uint64_t sp_gm_offset = (uint64_t)block_idx * TMP_SIZE * tmp_times + + n_idx % vect_mod * TMP_SIZE * tmp_times / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + row_offset * qk_round_n; + if (m_ind == 0) { + WaitFlagDev(QK_READY); + } + if (curr_m == 0) { + continue; + } + /* int32_t div_m = 6; + if constexpr (splitm) { + div_m = 4; + } */ + int32_t div_m = 4; + OnlineSoftmaxStage1( + ls_ubuf_tensor[s_ub_offset], mask16_ubuf_tensor, mask_ubuf_tensor, lm_ubuf_tensor[row_offset], + hm_ubuf_tensor[row_offset], gm_ubuf_tensor[row_offset], + dm_ubuf_tensor[((n_idx / s_block_stack) % div_m) * UB_FLOAT_LINE_SIZE * tmp_times + row_offset], + ls_ubuf_tensor[s_ub_offset], ll_ubuf_tensor[row_offset], gl_ubuf_tensor[row_offset], + lp_ubuf_tensor[s_ub_offset * 2], tv_ubuf_tensor, s_gm_tensor[sp_gm_offset], p_gm_tensor[sp_gm_offset], + n_idx == 0, this->tor, curr_m, qk_n, qk_round_n, pingpong_flag); + pingpong_flag = 1 - pingpong_flag; + } + FftsCrossCoreSync(SOFTMAX_READY); + } + if (n_idx >= launch_delay) { + WaitFlagDev(UPDATE_READY); // 4 + if (sub_m == 0) { + continue; + } + // *** update L and O + if constexpr (splitm) { + uint32_t sub_km = 64; + uint32_t loop_m = (sub_m + sub_km - 1) / sub_km; + if (n_idx != launch_delay) { + // *** dm_block = expand_to_block(dm), stored in tv + brcb_v( + tv_ubuf_tensor.ReinterpretCast(), + dm_ubuf_tensor[((n_idx - launch_delay) / s_block_stack % 4) * UB_FLOAT_LINE_SIZE * tmp_times] + .ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** go = go * dm_block + for (uint32_t vmul_idx = 0; vmul_idx < __k / FLOAT_VECTOR_SIZE; ++vmul_idx) { + mul_v(go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(__k % FLOAT_VECTOR_SIZE); + mul_v(go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** go = lo + go + for (uint32_t ms_idx = 0; ms_idx < loop_m; ms_idx++) { + uint32_t nowm = (ms_idx == (loop_m - 1)) ?
sub_m - ms_idx * sub_km : sub_km; + uint32_t now_roundm = (nowm + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + uint32_t resm = ms_idx * sub_km; + WAIT_FLAG(V, MTE2, EVENT_ID2); + gm_to_ub( + lo_ubuf_tensor, + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k + resm * round_k], + 0, // sid + 1, // nBurst + nowm * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID2); + // *** go = lo + go + WAIT_FLAG(MTE2, V, EVENT_ID2); + add_v( + go_ubuf_tensor[resm * round_k], go_ubuf_tensor[resm * round_k], lo_ubuf_tensor, + (nowm * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SET_FLAG(V, MTE2, EVENT_ID2); + } + } else { + WAIT_FLAG(MTE3, MTE2, EVENT_ID2); + gm_to_ub( + go_ubuf_tensor, + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE * tmp_times + + (n_idx - launch_delay) % vect_mod * TMP_SIZE * tmp_times / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k], + 0, // sid + 1, // nBurst + sub_m * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID3); + WAIT_FLAG(MTE2, V, EVENT_ID3); + } + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + // *** gl_block = expand_to_block(gl), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + gl_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** go = go / gl_block + for (uint32_t vdiv_idx = 0; vdiv_idx < __k / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + div_v(go_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(__k % FLOAT_VECTOR_SIZE); + div_v(go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + if (sub_m <= 16) { + PIPE_BARRIER(V); + } + for (uint32_t ms_idx = 0; ms_idx < loop_m; ms_idx++) { + uint32_t nowm = (ms_idx == (loop_m - 1)) ? 
sub_m - ms_idx * sub_km : sub_km; + uint32_t now_roundm = (nowm + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; + uint32_t resm = ms_idx * sub_km; + conv_v( + go_ubuf_tensor.ReinterpretCast()[resm * round_k], go_ubuf_tensor[resm * round_k], + (nowm * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + } + PIPE_BARRIER(V); + // ********************* move O to GM ************************ + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + ub_to_gm_align( + o_gm_tensor[o_offset + (uint64_t)sub_block_idx * qk_m / 2 * stride_qo], + go_ubuf_tensor.ReinterpretCast(), + 0, // sid + sub_m, // nBurst + __k * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + (stride_qo - __k) * 2 // dstGap + ); + SET_FLAG(MTE3, MTE2, EVENT_ID2); + } + } else { + if (n_idx != launch_delay) { + WAIT_FLAG(V, MTE2, EVENT_ID2); + gm_to_ub( + lo_ubuf_tensor, + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k], + 0, // sid + 1, // nBurst + sub_m * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID2); + // *** dm_block = expand_to_block(dm), 存放于 tv + brcb_v( + tv_ubuf_tensor.ReinterpretCast(), + dm_ubuf_tensor[((n_idx - launch_delay) / s_block_stack % 4) * UB_FLOAT_LINE_SIZE] + .ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** go = go * dm_block + for (uint32_t vmul_idx = 0; vmul_idx < __k / FLOAT_VECTOR_SIZE; ++vmul_idx) { + mul_v(go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[vmul_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(__k % FLOAT_VECTOR_SIZE); + mul_v(go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + // *** go = lo + go + WAIT_FLAG(MTE2, V, EVENT_ID2); + add_v( + go_ubuf_tensor, go_ubuf_tensor, lo_ubuf_tensor, + (sub_m * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + SET_FLAG(V, MTE2, EVENT_ID2); + } else { + WAIT_FLAG(MTE3, MTE2, EVENT_ID2); + gm_to_ub( + go_ubuf_tensor, + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE + + (n_idx - launch_delay) % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * round_k], + 0, // sid + 1, // nBurst + sub_m * round_k / FLOAT_BLOCK_SIZE, // lenBurst + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE2, V, EVENT_ID3); + WAIT_FLAG(MTE2, V, EVENT_ID3); + } + if (n_idx + s_block_stack > n_end + launch_delay - 1) { + // *** gl_block = expand_to_block(gl), 存放于 tv + brcb_v(tv_ubuf_tensor.ReinterpretCast(), + 
gl_ubuf_tensor.ReinterpretCast(), + 1, // dstBlockStride + 8, // dstRepeatStride + round_sub_m / FLOAT_BLOCK_SIZE // repeat + ); + PIPE_BARRIER(V); + // *** go = go / gl_block + for (uint32_t vdiv_idx = 0; vdiv_idx < __k / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + div_v(go_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[vdiv_idx * FLOAT_VECTOR_SIZE], tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + } + if (__k % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(__k % FLOAT_VECTOR_SIZE); + div_v(go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + go_ubuf_tensor[__k / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tv_ubuf_tensor, + sub_m, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 0, // src1BlockStride + round_k / FLOAT_BLOCK_SIZE, // dstRepeatStride + round_k / FLOAT_BLOCK_SIZE, // src0RepeatStride + 1 // src1RepeatStride + ); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + PIPE_BARRIER(V); + conv_v( + go_ubuf_tensor.ReinterpretCast(), go_ubuf_tensor, + (sub_m * round_k + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + // ********************* move O to GM ************************ + SET_FLAG(V, MTE3, EVENT_ID0); + WAIT_FLAG(V, MTE3, EVENT_ID0); + ub_to_gm_align( + o_gm_tensor[o_offset + (uint64_t)sub_block_idx * qk_m / 2 * stride_qo], + go_ubuf_tensor.ReinterpretCast(), + 0, // sid + sub_m, // nBurst + __k * 2, // lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + (stride_qo - __k) * 2 // dstGap + ); + SET_FLAG(MTE3, MTE2, EVENT_ID2); + } + } + } + } + } + } + + private: + __gm__ uint8_t *__restrict__ mask_gm{nullptr}; + __gm__ uint8_t *__restrict__ o_gm{nullptr}; + __gm__ uint8_t *__restrict__ s_gm{nullptr}; + __gm__ uint8_t *__restrict__ p_gm{nullptr}; + __gm__ uint8_t *__restrict__ o_tmp_gm{nullptr}; + __gm__ uint8_t *__restrict__ tiling_para_gm{nullptr}; + __gm__ uint8_t *__restrict__ alibi_coeff_gm{nullptr}; + __gm__ uint8_t *__restrict__ deq_qk_gm{nullptr}; + __gm__ uint8_t *__restrict__ off_qk_gm{nullptr}; + __gm__ uint8_t *__restrict__ quant_p_gm{nullptr}; + __gm__ uint8_t *__restrict__ deq_pv_gm{nullptr}; + __gm__ uint8_t *__restrict__ off_pv_gm{nullptr}; + __gm__ uint8_t *__restrict__ logN_gm{nullptr}; + + const uint32_t ls_ubuf_offset = 0; + const uint32_t lp_ubuf_offset = 0; + const uint32_t ls32_ubuf_offset = 2 * UB_UINT8_BLOCK_SIZE; + const uint32_t mask_ubuf_offset = 4 * UB_UINT8_BLOCK_SIZE; + const uint32_t go_ubuf_offset = 4 * UB_UINT8_BLOCK_SIZE; + const uint32_t lm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE; + const uint32_t hm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 1 * UB_UINT8_LINE_SIZE; + const uint32_t gm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 2 * UB_UINT8_LINE_SIZE; + const uint32_t dm_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 3 * UB_UINT8_LINE_SIZE; + const uint32_t ll_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 7 * UB_UINT8_LINE_SIZE; + const uint32_t gl_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 8 * UB_UINT8_LINE_SIZE; + const uint32_t tv_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 9 * UB_UINT8_LINE_SIZE; + const uint32_t p_scale_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 21 * UB_UINT8_LINE_SIZE; + const uint32_t log_ubuf_offset = 8 * UB_UINT8_BLOCK_SIZE + 30 * UB_UINT8_LINE_SIZE; + const uint32_t log_ubuf_float_offset = 8 * 
UB_UINT8_BLOCK_SIZE + 30 * UB_UINT8_LINE_SIZE; + const uint32_t lo_ubuf_offset = 10 * UB_UINT8_BLOCK_SIZE; + const uint32_t mask16_ubuf_offset = 11 * UB_UINT8_BLOCK_SIZE; + + AsdopsBuffer buf; + AscendC::LocalTensor ls_ubuf_tensor = buf.GetBuffer(ls_ubuf_offset); + AscendC::LocalTensor lp_ubuf_tensor = buf.GetBuffer(lp_ubuf_offset); + AscendC::LocalTensor ls32_ubuf_tensor = buf.GetBuffer(ls32_ubuf_offset); + AscendC::LocalTensor mask_ubuf_tensor = buf.GetBuffer(mask_ubuf_offset); + AscendC::LocalTensor lo_ubuf_tensor = buf.GetBuffer(lo_ubuf_offset); + AscendC::LocalTensor lm_ubuf_tensor = buf.GetBuffer(lm_ubuf_offset); + AscendC::LocalTensor hm_ubuf_tensor = buf.GetBuffer(hm_ubuf_offset); + AscendC::LocalTensor gm_ubuf_tensor = buf.GetBuffer(gm_ubuf_offset); + AscendC::LocalTensor dm_ubuf_tensor = buf.GetBuffer(dm_ubuf_offset); + AscendC::LocalTensor ll_ubuf_tensor = buf.GetBuffer(ll_ubuf_offset); + AscendC::LocalTensor gl_ubuf_tensor = buf.GetBuffer(gl_ubuf_offset); + AscendC::LocalTensor tv_ubuf_tensor = buf.GetBuffer(tv_ubuf_offset); + AscendC::LocalTensor p_scale_ubuf_tensor = buf.GetBuffer(p_scale_ubuf_offset); + AscendC::LocalTensor go_ubuf_tensor = buf.GetBuffer(go_ubuf_offset); + AscendC::LocalTensor mask16_ubuf_tensor = + buf.GetBuffer(mask16_ubuf_offset); + + AscendC::LocalTensor log_ubuf_tensor = buf.GetBuffer(log_ubuf_offset); + AscendC::LocalTensor log_ubuf_float_tensor = + buf.GetBuffer(log_ubuf_float_offset); + + AscendC::GlobalTensor mask_gm_tensor; + AscendC::GlobalTensor o_gm_tensor; + AscendC::GlobalTensor s_gm_tensor; + AscendC::GlobalTensor p_gm_tensor; + AscendC::GlobalTensor o_tmp_gm_tensor; + AscendC::GlobalTensor logN_gm_tensor; + AscendC::GlobalTensor logN_float_gm_tensor; + + uint32_t batch_size{0}; + uint32_t max_seqlen{0}; + uint32_t max_q_seqlen{0}; + uint32_t q_heads{0}; + uint32_t embd{0}; + float tor{0}; + uint32_t head_stride{0}; + uint32_t mask_stride{0}; + uint32_t is_triu_mask{0}; + uint32_t total_q_blk_num{0}; + uint32_t isClamp{0}; + float clampMin; + float clampMax; + uint64_t stride_qo{0}; + uint32_t __k{0}; + uint32_t round_k{0}; + uint32_t tmp_times{1}; + int32_t sub_block_idx{-1}; + uint32_t tilingKey{0}; + uint32_t tiling_head_size{0}; + uint32_t tiling_para_size{0}; + uint32_t long_seq{0}; + uint32_t is_sqrt{0}; + uint32_t mask_type{0}; + uint32_t alibi_compress_offset{0}; + uint32_t alibi_left_align{0}; + uint32_t data_shape_type{0}; + uint32_t quantType{0}; +}; +#endif + +} // namespace unpda_fa_npd_bf16 \ No newline at end of file diff --git a/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd.cpp b/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd.cpp index 43e92468193670e5aab93269a205e235db777964..ed6f6df1371b766262f7c3fbe872f1adf8cd11c1 100644 --- a/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd.cpp +++ b/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd.cpp @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "../op_host/reshape_and_cache_npd_tiling.h" +#include +#include +#include "reshape_and_cache_npd.h" +#include "reshape_and_cache_npd_tiling.h" #include "register/op_def_registry.h" #include "graph/utils/type_utils.h" #include "tiling/platform/platform_ascendc.h" +#include "utils/log/asc_cpu_log.h" namespace optiling { -constexpr int32_t KCACHE_VCACHE = 11; -constexpr int32_t KCACHE = 10; -constexpr int32_t NPD_TILING = 1000; static constexpr auto kRank2 = 2; static constexpr auto kRank3 = 3; static constexpr auto kDim1 = 1; @@ -30,116 +31,206 @@ static constexpr auto kDim3 = 3; static constexpr auto kDim4 = 4; static constexpr auto kDim7 = 7; -#define GE_LOGE if (false) std::cout -int32_t ReshapeAndCacheNpdTilingId(ge::DataType dtype, int32_t cacheConfig, uint32_t npd) { - int32_t tilingId = 0; - return tilingId; +int32_t ReshapeAndCacheNpdTilingId(ge::DataType dtype, int32_t cacheConfig, int32_t KvConfig) { + int32_t tiling_id = 0; + if (cacheConfig) tiling_id |= 1; + if (KvConfig) tiling_id |= 1 << 1; + return tiling_id; } -static ge::graphStatus ReshapeAndCacheNpdTiling(gert::TilingContext *context) { - if (context->GetInputDesc(0)->GetDataType() == ge::DataType::DT_UNDEFINED || - context->GetInputDesc(kDim2)->GetDataType() == ge::DataType::DT_UNDEFINED || - context->GetInputDesc(kDim4)->GetDataType() == ge::DataType::DT_UNDEFINED) { - GE_LOGE << "Input 0, 2, 4(key, key_cache, slot_mapping) are required. Their dtypes cannot be None, but got " - << context->GetInputDesc(0)->GetDataType() << ", " << context->GetInputDesc(kDim2)->GetDataType() << ", " - << context->GetInputDesc(kDim4)->GetDataType(); +static ge::graphStatus ReshapeAndCacheNpdTiling(gert::TilingContext *context, int *kv_cache_mode, int *kv_mode) { + *kv_cache_mode = std::string(context->GetAttrs()->GetAttrPointer(0)) == std::string(NPD_LAYOUT); + *kv_mode = std::string(context->GetAttrs()->GetAttrPointer(1)) == std::string(NPD_LAYOUT); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetSeqLen(gert::TilingContext *context, int batch, size_t input_idx, + std::vector *seq_len) { + auto t = context->GetInputTensor(input_idx); + if (t == nullptr) { + ASC_CPU_LOG_ERROR("seq length is null"); return ge::GRAPH_FAILED; } - if ((context->GetInputDesc(kDim1)->GetDataType() == ge::DataType::DT_UNDEFINED) ^ - (context->GetInputDesc(kDim3)->GetDataType() == ge::DataType::DT_UNDEFINED)) { - GE_LOGE << "Input 1, 3 (value, value_cache) should either both be None or have the same dtype, but got " - << (context->GetInputDesc(kDim1)->GetDataType() == ge::DataType::DT_UNDEFINED) << ", " - << (context->GetInputDesc(0)->GetDataType() == ge::DataType::DT_UNDEFINED); + if (t->GetShapeSize() != batch) { + ASC_CPU_LOG_ERROR("seq length size is illegal, size = %d batch size = %d", t->GetShapeSize(), batch); return ge::GRAPH_FAILED; } - ReshapeAndCacheNpdTilingData tiling; - uint64_t ub_size; + auto p = t->GetData(); + seq_len->assign(p, p + batch); + return ge::GRAPH_SUCCESS; +} + +// Helper: validate inputs and extract shapes/parameters +static ge::graphStatus ValidateAndExtractNpdParams(gert::TilingContext *context, int32_t &cache_npd, int32_t &kv_npd, + uint32_t &num_tokens, uint32_t &hidden_size, uint32_t &key_head_num, + uint32_t &value_head_num, uint32_t &page_size, uint32_t &batch_size, + uint32_t &core_num, std::vector &q_seq_len, + std::vector &kv_seq_len) { + size_t k_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputKeyIndex); + size_t k_cache_idx = 
static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputKeyCacheIndex); + size_t sm_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputSlotMappingIndex); + size_t v_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputValueIndex); + size_t v_cache_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputValueCacheIndex); + size_t block_tbl_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputBlockTblIndex); + size_t q_seq_len_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputQSeqIndex); + size_t kv_seq_len_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputKVSeqIndex); + + ReshapeAndCacheNpdTiling(context, &cache_npd, &kv_npd); + + // dtype validation + auto k_data_data_type = context->GetInputDesc(k_idx)->GetDataType(); + auto k_cache_data_type = context->GetInputDesc(k_cache_idx)->GetDataType(); + auto sm_data_type = context->GetInputDesc(sm_idx)->GetDataType(); + if (k_data_data_type == ge::DataType::DT_UNDEFINED || k_cache_data_type == ge::DataType::DT_UNDEFINED || + sm_data_type == ge::DataType::DT_UNDEFINED) { + ASC_CPU_LOG_ERROR("Input key/key_cache/slot_mapping dtypes cannot be None"); + return ge::GRAPH_FAILED; + } + + auto v_data_data_type = context->GetInputDesc(v_idx)->GetDataType(); + auto v_cache_data_type = context->GetInputDesc(v_cache_idx)->GetDataType(); + if ((v_data_data_type == ge::DataType::DT_UNDEFINED) ^ (v_cache_data_type == ge::DataType::DT_UNDEFINED)) { + ASC_CPU_LOG_ERROR("value/value_cache must both be None or same dtype"); + return ge::GRAPH_FAILED; + } + + // platform info auto ascendc_platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint64_t ub_size; ascendc_platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size); - auto core_num = ascendc_platform.GetCoreNum(); - auto key_shape = context->GetInputShape(0)->GetOriginShape(); - auto key_cache_shape = context->GetInputShape(kDim2)->GetOriginShape(); - auto value_cache_shape = context->GetInputShape(kDim3)->GetOriginShape(); - auto slot_mapping_shape = context->GetInputShape(kDim4)->GetOriginShape(); - auto block_tbl_shape = context->GetInputShape(kDim7)->GetOriginShape(); - if (key_shape.GetDimNum() != kDim3 && key_shape.GetDimNum() != kDim2 || slot_mapping_shape.GetDimNum() != kDim1 || - key_cache_shape.GetDimNum() < kDim3) { - GE_LOGE << "The dim of input should be key=" << kDim3 << " or " << kDim2 << ", slot_mapping=" << kDim1 - << ", key_cache>=" << kDim3 << ", but got " << key_shape.GetDimNum() << ", " - << slot_mapping_shape.GetDimNum() << ", " << key_cache_shape.GetDimNum(); + core_num = ascendc_platform.GetCoreNumAiv(); + + // shape validation + auto key_shape = context->GetInputShape(k_idx)->GetOriginShape(); + auto key_cache_shape = context->GetInputShape(k_cache_idx)->GetOriginShape(); + auto value_cache_shape = context->GetInputShape(v_cache_idx)->GetOriginShape(); + auto slot_mapping_shape = context->GetInputShape(sm_idx)->GetOriginShape(); + auto block_tbl_shape = context->GetInputShape(block_tbl_idx)->GetOriginShape(); + + if (key_shape.GetDimNum() != kDim3 && key_shape.GetDimNum() != kDim2) { + ASC_CPU_LOG_ERROR("Key dim must be 2 or 3"); + return ge::GRAPH_FAILED; + } + if (slot_mapping_shape.GetDimNum() != kDim1) { + ASC_CPU_LOG_ERROR("Slot mapping dim must be 1"); + return ge::GRAPH_FAILED; + } + if (key_cache_shape.GetDimNum() != kDim4 || value_cache_shape.GetDimNum() != kDim4) { + ASC_CPU_LOG_ERROR("Key/value cache dim must be 4"); return 
ge::GRAPH_FAILED; } - uint32_t num_tokens = key_shape.GetDim(0); - uint32_t hidden_size = key_shape.GetDim(kDim1); - if (key_shape.GetDimNum() == kRank3) { + + // token/hidden size + num_tokens = key_shape.GetDim(0); + hidden_size = key_shape.GetDim(kDim1); + if (key_shape.GetDimNum() == kRank3) { // BSH hidden_size = key_shape.GetDim(kDim2); num_tokens *= key_shape.GetDim(kDim1); } + auto num_tokens_max = key_cache_shape.GetDim(0) * key_cache_shape.GetDim(kDim1); - const uint32_t *npd = context->GetAttrs()->GetAttrPointer(0); - if (*npd) { - // [BlockNum, N, BlockSize, D] + if (cache_npd) { num_tokens_max = key_cache_shape.GetDim(0) * key_cache_shape.GetDim(kDim2); } if (num_tokens_max != 0 && num_tokens_max < num_tokens) { - GE_LOGE << "The number of tokens should be less than or equal to block_num * block_size = " << num_tokens_max - << ", but got " << num_tokens; + ASC_CPU_LOG_ERROR("num_tokens exceeds max capacity"); return ge::GRAPH_FAILED; } if (num_tokens < core_num) { core_num = num_tokens; } - uint32_t key_head_num = key_cache_shape.GetDim(kDim1); - uint32_t page_size = key_cache_shape.GetDim(kDim2); - if (!*npd) { + key_head_num = key_cache_shape.GetDim(kDim1); + page_size = key_cache_shape.GetDim(kDim2); + if (!cache_npd) { key_head_num = key_cache_shape.GetDim(kDim2); page_size = key_cache_shape.GetDim(kDim1); } - uint32_t value_head_num = key_head_num; - uint32_t batch_size = block_tbl_shape.GetDim(0); + value_head_num = key_head_num; + batch_size = block_tbl_shape.GetDim(0); + + if (GetSeqLen(context, batch_size, q_seq_len_idx, &q_seq_len) != ge::GRAPH_SUCCESS || + GetSeqLen(context, batch_size, kv_seq_len_idx, &kv_seq_len) != ge::GRAPH_SUCCESS) { + ASC_CPU_LOG_ERROR("Failed to get sequence lengths"); + return ge::GRAPH_FAILED; + } + + if (batch_size > KMaxBatch) { + ASC_CPU_LOG_ERROR("Batch size too big"); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} - int32_t cache_config = (context->GetInputDesc(kDim1)->GetDataType() == ge::DataType::DT_UNDEFINED && - context->GetInputTensor(kDim3)->GetDataType() == ge::DataType::DT_UNDEFINED) - ? kDim1 - : 0; +// Main function: uses helper above, then fills tiling buffer +static ge::graphStatus ReshapeAndCacheNpdTiling(gert::TilingContext *context) { + int32_t cache_npd, kv_npd; + uint32_t num_tokens, hidden_size, key_head_num, value_head_num, page_size, batch_size, core_num; + std::vector q_seq_len, kv_seq_len; - if (cache_config == 0) { - value_head_num = *npd ? 
value_cache_shape.GetDim(kDim1) : value_cache_shape.GetDim(kDim2); + auto status = ValidateAndExtractNpdParams(context, cache_npd, kv_npd, num_tokens, hidden_size, key_head_num, + value_head_num, page_size, batch_size, core_num, q_seq_len, kv_seq_len); + if (status != ge::GRAPH_SUCCESS) { + return status; } - auto tiling_key = optiling::ReshapeAndCacheNpdTilingId(context->GetInputDesc(0)->GetDataType(), cache_config, *npd); - tiling.set_numTokens(num_tokens); - tiling.set_hiddenSize(hidden_size); - tiling.set_kHeadNum(key_head_num); - tiling.set_vHeadNum(value_head_num); - tiling.set_pageSize(page_size); - tiling.set_batchSize(batch_size); - tiling.set_tilingId(tiling_key); + ReshapeAndCacheNpdTilingData tiling; + auto tiling_size = (KLastIdx + KSeqLenArrayNum * batch_size) * sizeof(uint32_t); + auto tiling_buf = tiling.get_buf(); + auto tiling_key = optiling::ReshapeAndCacheNpdTilingId(context->GetInputDesc(0)->GetDataType(), cache_npd, kv_npd); + + tiling_buf[KTilingId] = tiling_key; + tiling_buf[KUseCoreNumIdx] = core_num; + tiling_buf[KNumTokensIdx] = num_tokens; + tiling_buf[KHiddenSizeIdx] = hidden_size; + tiling_buf[KKeyHeadNumIdx] = key_head_num; + tiling_buf[KVHeadNumIdx] = value_head_num; + tiling_buf[KPageSizeIdx] = page_size; + tiling_buf[KBatchSizeIdx] = batch_size; + + for (uint32_t i = 0; i < batch_size; i++) { + tiling_buf[KLastIdx + i] = static_cast(kv_seq_len.at(i)); + tiling_buf[KLastIdx + i + batch_size] = static_cast(q_seq_len.at(i)); + } context->SetBlockDim(core_num); tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); - context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + context->GetRawTilingData()->SetDataSize(tiling_size); size_t *currentWorkspace = context->GetWorkspaceSizes(kDim1); currentWorkspace[0] = 0; + context->SetTilingKey(tiling_key); + + ASC_CPU_LOG_INFO("Reshape and cache npd tiling %d %d %d %d %d %d %d %d %d", num_tokens, hidden_size, key_head_num, + value_head_num, page_size, batch_size, tiling_key, core_num); + return ge::GRAPH_SUCCESS; } } // namespace optiling namespace ge { static ge::graphStatus ReshapeAndCacheNpdInferShape(gert::InferShapeContext *context) { + int32_t cache_kv_layout = std::string(context->GetAttrs()->GetAttrPointer(0)) == std::string(NPD_LAYOUT); + int32_t kv_layout = std::string(context->GetAttrs()->GetAttrPointer(1)) == std::string(NPD_LAYOUT); + auto k_idx = static_cast(ms_custom_ops::ReshapeAndCacheNpdInputIndex::kInputKeyIndex); + auto k_shape = context->GetInputShape(k_idx); auto t = context->GetInputShape(0)->GetDim(0); - auto ps = context->GetInputShape(optiling::kDim2)->GetDim(optiling::kDim2); + auto ps = context->GetInputShape(optiling::kDim2)->GetDim(cache_kv_layout ? 
optiling::kDim2 : optiling::kDim1); auto bs = context->GetInputShape(optiling::kDim7)->GetDim(0); int max_seq = t + (ps - optiling::kDim1) * bs; gert::Shape *out_key_shape = context->GetOutputShape(0); gert::Shape *out_value_shape = context->GetOutputShape(optiling::kDim1); + *out_key_shape = *(context->GetInputShape(0)); *out_value_shape = *(context->GetInputShape(0)); - out_key_shape->SetDim(0, max_seq); - out_value_shape->SetDim(0, max_seq); + if ((kv_layout == 1) && (k_shape->GetDim(0) >= 0)) { + out_key_shape->SetDim(0, max_seq); + out_value_shape->SetDim(0, max_seq); + } return GRAPH_SUCCESS; } + static graphStatus ReshapeAndCacheNpdInferDataType(gert::InferDataTypeContext *context) { const auto inputDataType = context->GetInputDataType(0); context->SetOutputDataType(0, inputDataType); @@ -182,17 +273,19 @@ class ReshapeAndCacheNpd : public OpDef { .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) .AutoContiguous(); - this->Input("q_seq") + this->Input("actual_seq_qlen") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32}) + .DataType({ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .ValueDepend(Option::OPTIONAL, DependScope::TILING) .AutoContiguous(); - this->Input("kv_seq") + this->Input("actual_seq_kvlen") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32}) + .DataType({ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .ValueDepend(Option::OPTIONAL, DependScope::TILING) .AutoContiguous(); this->Input("block_tbl") .ParamType(REQUIRED) @@ -200,13 +293,14 @@ class ReshapeAndCacheNpd : public OpDef { .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) .AutoContiguous(); - this->Attr("cache_mode").AttrType(OPTIONAL).Int(1); - this->Output("k_out") + this->Attr("kv_cache_layout").AttrType(OPTIONAL).String(NPD_LAYOUT); + this->Attr("key_value_layout").AttrType(OPTIONAL).String(NPD_LAYOUT); + this->Output("out_key") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT16, ge::DT_BF16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("v_out") + this->Output("out_value") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT16, ge::DT_BF16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) diff --git a/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd.h b/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd.h new file mode 100644 index 0000000000000000000000000000000000000000..374042feb10458ed0b4b0901600e984caee09e65 --- /dev/null +++ b/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd.h @@ -0,0 +1,41 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MS_CUSTOM_OPS_ASCENDC_OP_HOST_RESHAPE_AND_CACHE_NPD_H +#define MS_CUSTOM_OPS_ASCENDC_OP_HOST_RESHAPE_AND_CACHE_NPD_H + +namespace ms_custom_ops { +#define TH_LAYOUT "TH" +#define ND_LAYOUT "ND" +#define NPD_LAYOUT "NPD" + +enum class ReshapeAndCacheNpdInputIndex { + kInputKeyIndex = 0, + kInputValueIndex = 1, + kInputKeyCacheIndex = 2, + kInputValueCacheIndex = 3, + kInputSlotMappingIndex = 4, + kInputQSeqIndex = 5, + kInputKVSeqIndex = 6, + kInputBlockTblIndex = 7, + kInputKVCacheModeIndex = 8, + kInputKeyValueModeIndex = 9 +}; + +enum class ReshapeAndCacheNpdOutputIndex { kOutputKeyIndex = 0, kOutputValueIndex = 1 }; + +} // namespace ms_custom_ops +#endif // MS_CUSTOM_OPS_ASCENDC_OP_HOST_RESHAPE_AND_CACHE_NPD_H diff --git a/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd_tiling.h b/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd_tiling.h index 276b72360316519df2fc5e02a2a8d5b3994294d6..3e5a63022df0969e01988f49b943cb3a0517cae5 100644 --- a/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd_tiling.h +++ b/ops/ascendc/reshape_and_cache_npd/op_host/reshape_and_cache_npd_tiling.h @@ -18,17 +18,22 @@ #include "register/tilingdata_base.h" namespace optiling { + +#define KTilingId 0 +#define KUseCoreNumIdx 1 +#define KNumTokensIdx 2 +#define KHiddenSizeIdx 3 +#define KKeyHeadNumIdx 4 +#define KVHeadNumIdx 5 +#define KPageSizeIdx 6 +#define KBatchSizeIdx 7 +#define KLastIdx 32 +#define KSeqLenArrayNum 2 +#define KMaxBatch 512 + BEGIN_TILING_DATA_DEF(ReshapeAndCacheNpdTilingData) - TILING_DATA_FIELD_DEF(uint32_t, tilingId); - TILING_DATA_FIELD_DEF(uint32_t, useCoreNum); - TILING_DATA_FIELD_DEF(uint32_t, numTokens); - TILING_DATA_FIELD_DEF(uint32_t, hiddenSize); - TILING_DATA_FIELD_DEF(uint32_t, kHeadNum); - TILING_DATA_FIELD_DEF(uint32_t, vHeadNum); - TILING_DATA_FIELD_DEF(uint32_t, pageSize); - TILING_DATA_FIELD_DEF(uint32_t, batchSize); +TILING_DATA_FIELD_DEF_ARR(uint32_t, KLastIdx + KSeqLenArrayNum * KMaxBatch, buf); END_TILING_DATA_DEF; - -REGISTER_TILING_DATA_CLASS(ReshapeAndCacheNpd, ReshapeAndCacheNpdTilingData) +REGISTER_TILING_DATA_CLASS(ReshapeAndCacheNpd, ReshapeAndCacheNpdTilingData); } // namespace optiling #endif // MS_CUSTOM_OPS_OPS_ASCEND_C_RESHAPE_AND_CACHE_NPD_OP_HOST_RESHAPE_AND_CACHE_NPD_TILING_H diff --git a/ops/ascendc/reshape_and_cache_npd/op_kernel/reshape_and_cache_npd.cpp b/ops/ascendc/reshape_and_cache_npd/op_kernel/reshape_and_cache_npd.cpp index d1370af1a4f36aac749be6e26d5e9770987d5778..5e2a3357baa20f927e1b514a965d822041268002 100644 --- a/ops/ascendc/reshape_and_cache_npd/op_kernel/reshape_and_cache_npd.cpp +++ b/ops/ascendc/reshape_and_cache_npd/op_kernel/reshape_and_cache_npd.cpp @@ -17,134 +17,183 @@ #include "ascendc/basic_api/kernel_operator.h" static constexpr int32_t BLOCK_SIZE = 32; -void __aicore__ __inline__ reshape_and_cache_npd_kv_npd_offset(int token_idx, int block_size, int head_num, int embed, - __gm__ uint8_t *__restrict__ kv_seq, - uint32_t *token_offset) { - uint32_t offset = 0; - uint32_t i = 0; - int32_t limit = (int32_t)(*((__gm__ int32_t *)kv_seq + i)); - int32_t prev_seq_len = limit; - - uint32_t prev_token = 0; - while (token_idx >= limit) { - offset += ((prev_seq_len + block_size - 1) / block_size) * block_size; - prev_token += prev_seq_len; - i++; - prev_seq_len = (int32_t)(*((__gm__ int32_t *)kv_seq + i)); - limit += prev_seq_len; - } - uint32_t inner_token_idx = token_idx - prev_token; - uint32_t batch_offset = offset * head_num * embed; - uint32_t inner_batch_offset = 
(inner_token_idx / block_size) * head_num * block_size * embed + - (inner_token_idx % block_size) * embed; - *token_offset = batch_offset + inner_batch_offset; -} -template -__aicore__ __inline__ bool reshape_and_cache_npd_kv_copy(__gm__ uint8_t *__restrict__ input_gm, - __gm__ uint8_t *__restrict__ slot_mapping_input_gm, - __gm__ uint8_t *__restrict__ q_seq, - __gm__ uint8_t *__restrict__ kv_seq, - __gm__ uint8_t *__restrict__ block_tbl, - int32_t block_size, - int32_t head_num, - int32_t embed, - int64_t start_index, - __gm__ uint8_t *__restrict__ cache_output_gm, - __gm__ uint8_t *__restrict__ update_output_gm) { - int64_t hidden_size = head_num * embed; - int64_t start = start_index * hidden_size; - int32_t slot_value = (int32_t)(*((__gm__ int32_t *)slot_mapping_input_gm + start_index)); - if (slot_value < 0) return false; - int32_t num_hidden_blocks = hidden_size * sizeof(Dtype) / BLOCK_SIZE; - - __ubuf__ uint8_t *temp_ubuf = (__ubuf__ uint8_t *)get_imm(0); // 临时存放 token - copy_gm_to_ubuf((__ubuf__ Dtype *)temp_ubuf, (__gm__ Dtype *)input_gm + start, - 0, // sid - 1, // nBurst - num_hidden_blocks, // lenBurst - 0, // srcGap - 0); // dstGap - set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); - - // cache is stored [B_NUM, N, B_SIZE, D] - int64_t cache_start = (slot_value / block_size) * head_num * block_size * embed + (slot_value % block_size) * embed; - int32_t dst_stride = (block_size - 1) * embed * sizeof(Dtype) / BLOCK_SIZE; - int32_t num_embed_blocks = embed * sizeof(Dtype) / BLOCK_SIZE; - - copy_ubuf_to_gm((__gm__ Dtype *)cache_output_gm + cache_start, (__ubuf__ Dtype *)temp_ubuf, - 0, // sid - head_num, // nBurst - num_embed_blocks, // lenBurst - 0, // srcGap - dst_stride); // dstGap - - // update key\value - [T/B_SIZE, N, B_SIZE, D] - // Get Token base offset - uint32_t token_offset = 0; - reshape_and_cache_npd_kv_npd_offset(start_index, block_size, head_num, embed, kv_seq, &token_offset); - copy_ubuf_to_gm((__gm__ Dtype *)update_output_gm + token_offset, (__ubuf__ Dtype *)temp_ubuf, - 0, // sid - head_num, // nBurst - num_embed_blocks, // lenBurst - 0, // srcGap - dst_stride); // dstGap - - set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - return true; -} +template +class ReshapeAndCache { + public: + void __aicore__ __inline__ Init(GM_ADDR key, GM_ADDR value, GM_ADDR key_cache, GM_ADDR value_cache, + GM_ADDR slot_mapping, GM_ADDR q_seq, GM_ADDR kv_seq, GM_ADDR block_tbl, GM_ADDR k_out, + GM_ADDR v_out, GM_ADDR workspace, GM_ADDR tiling) { + in_key_.SetGlobalBuffer(reinterpret_cast<__gm__ Dtype *>(key)); + in_value_.SetGlobalBuffer(reinterpret_cast<__gm__ Dtype *>(value)); + in_out_key_cache_.SetGlobalBuffer(reinterpret_cast<__gm__ Dtype *>(key_cache)); + in_out_value_cache_.SetGlobalBuffer(reinterpret_cast<__gm__ Dtype *>(value_cache)); + in_slot_mapping_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slot_mapping)); + in_block_tbl_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(block_tbl)); + out_key_.SetGlobalBuffer(reinterpret_cast<__gm__ Dtype *>(k_out)); + out_value_.SetGlobalBuffer(reinterpret_cast<__gm__ Dtype *>(v_out)); + in_tiling.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(tiling)); -template -__aicore__ __inline__ void reshape_and_cache_npd_kernel( - __gm__ uint8_t *__restrict__ key_input_gm, __gm__ uint8_t *__restrict__ value_input_gm, - __gm__ uint8_t *__restrict__ key_cache_gm, __gm__ uint8_t *__restrict__ value_cache_gm, - __gm__ uint8_t *__restrict__ 
slot_mapping_input_gm, __gm__ uint8_t *__restrict__ q_seq, - __gm__ uint8_t *__restrict__ kv_seq, __gm__ uint8_t *__restrict__ block_tbl, __gm__ uint8_t *__restrict__ k_out, - __gm__ uint8_t *__restrict__ v_out, __gm__ uint8_t *__restrict__ workspace, - __gm__ uint8_t *__restrict__ tiling_para_gm) { - int32_t num_tokens = (int32_t)(*((__gm__ int32_t *)tiling_para_gm + 2)); - int32_t hidden_size = (int32_t)(*((__gm__ int32_t *)tiling_para_gm + 3)); - int32_t block_size = (int32_t)(*((__gm__ int32_t *)tiling_para_gm + 6)); - int32_t k_head_num = (int32_t)(*((__gm__ int32_t *)tiling_para_gm + 4)); - int32_t v_head_num = (int32_t)(*((__gm__ int32_t *)tiling_para_gm + 5)); - - int32_t core_num = get_block_num(); - int64_t per_core_task_num = num_tokens * 2 / core_num; - int32_t tail_task_num = num_tokens * 2 % core_num; - - int32_t block_id = get_block_idx(); - int64_t start_task_id = block_id * per_core_task_num; - - if (block_id < tail_task_num) { - per_core_task_num++; - start_task_id += block_id; - } else { - start_task_id += tail_task_num; + num_tokens_ = in_tiling.GetValue(2); + hidden_size_ = in_tiling.GetValue(3); + k_head_num_ = in_tiling.GetValue(4); + v_head_num_ = in_tiling.GetValue(5); + page_size_ = in_tiling.GetValue(6); + batch_size_ = in_tiling.GetValue(7); + uint32_t allc_size = (k_head_num_ > v_head_num_ ? k_head_num_ : v_head_num_) * hidden_size_; + in_local_ = ub_allocator_.Alloc(allc_size); + in_kv_seq_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(tiling) + 32); + in_q_seq_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(tiling) + 32 + batch_size_); } - for (int64_t i = 0; i < per_core_task_num; i++) { - if (i + start_task_id < num_tokens) { - reshape_and_cache_npd_kv_copy(key_input_gm, slot_mapping_input_gm, q_seq, kv_seq, block_tbl, block_size, - k_head_num, hidden_size / k_head_num, (i + start_task_id), key_cache_gm, - k_out); + void __aicore__ __inline__ Process() { + int32_t core_num = get_block_num(); + int64_t per_core_task_num = num_tokens_ * 2 / core_num; + int32_t tail_task_num = num_tokens_ * 2 % core_num; + + int32_t block_id = get_block_idx(); + int64_t start_task_id = block_id * per_core_task_num; + + if (block_id < tail_task_num) { + per_core_task_num++; + start_task_id += block_id; } else { - reshape_and_cache_npd_kv_copy(value_input_gm, slot_mapping_input_gm, q_seq, kv_seq, block_tbl, block_size, - v_head_num, hidden_size / v_head_num, (i + start_task_id - num_tokens), - value_cache_gm, v_out); + start_task_id += tail_task_num; + } + + for (int64_t i = 0; i < per_core_task_num; i++) { + if (i + start_task_id < num_tokens_) { + ReshapeAndCacheDo(in_key_, k_head_num_, hidden_size_ / k_head_num_, (i + start_task_id), in_out_key_cache_, + out_key_); + } else { + ReshapeAndCacheDo(in_value_, v_head_num_, hidden_size_ / v_head_num_, (i + start_task_id - num_tokens_), + in_out_value_cache_, out_value_); + } } } -} + + private: + void __aicore__ __inline__ CopyIn(AscendC::GlobalTensor &in_kv, uint64_t src_offset, + AscendC::DataCopyParams cpy_param) { + DataCopy(in_local_, in_kv[src_offset], cpy_param); + } + + void __aicore__ __inline__ CopyOut(AscendC::GlobalTensor &out, uint64_t dst_offset, + AscendC::DataCopyParams cpy_param) { + DataCopy(out[dst_offset], in_local_, cpy_param); + } + + void __aicore__ __inline__ ComputeKvNpdOffset(int token_idx, int head_num, int embed, uint64_t *token_offset) { + uint32_t offset = 0; + uint32_t i = 0; + int32_t limit = in_kv_seq_.GetValue(i); + int32_t prev_seq_len = limit; + uint32_t prev_token = 0; + while 
(token_idx >= limit) { + offset += ((prev_seq_len + page_size_ - 1) / page_size_) * page_size_; + prev_token += prev_seq_len; + i++; + prev_seq_len = in_kv_seq_.GetValue(i); + limit += prev_seq_len; + } + uint32_t inner_token_idx = token_idx - prev_token; + uint32_t batch_offset = offset * head_num * embed; + uint32_t inner_batch_offset = + (inner_token_idx / page_size_) * head_num * page_size_ * embed + (inner_token_idx % page_size_) * embed; + *token_offset = batch_offset + inner_batch_offset; + } + + void __aicore__ __inline__ ReshapeAndCacheDo(AscendC::GlobalTensor &in_kv, int32_t kv_head_num, int32_t embed, + int32_t token_idx, AscendC::GlobalTensor &out_cache, + AscendC::GlobalTensor &out_kv) { + uint64_t start = token_idx * hidden_size_; + int32_t slot_value = in_slot_mapping_.GetValue(token_idx); + if (slot_value < 0) return; + uint16_t num_hidden_blocks = hidden_size_ * sizeof(Dtype) / BLOCK_SIZE; + uint16_t num_embed_blocks = embed * sizeof(Dtype) / BLOCK_SIZE; + + // step I: copy token into local memory + AscendC::DataCopyParams cpy_in_param(1, num_hidden_blocks, 0, 0); + CopyIn(in_kv, start, cpy_in_param); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + // step II: copy token into kv_cache + // cache is stored [B_NUM, B_SIZE, N, D] + uint64_t cache_start = slot_value * hidden_size_; + uint16_t dst_stride = 0; + uint16_t n_burst = 1; + uint16_t n_blks = num_hidden_blocks; + if constexpr (kv_cache_npd) { + // cache is stored [B_NUM, N, B_SIZE, D] + cache_start = (slot_value / page_size_) * kv_head_num * page_size_ * embed + (slot_value % page_size_) * embed; + dst_stride = (page_size_ - 1) * embed * sizeof(Dtype) / BLOCK_SIZE; + n_burst = kv_head_num; + n_blks = num_embed_blocks; + } + AscendC::DataCopyParams cpy_out_cache_param(n_burst, n_blks, 0, dst_stride); + CopyOut(out_cache, cache_start, cpy_out_cache_param); + + // step III: copy token into kv_out + // update key\value - [TND] + uint64_t token_start = start; + dst_stride = 0; + n_burst = 1; + n_blks = num_hidden_blocks; + if constexpr (kv_npd) { + // update key\value - [T/B_SIZE, N, B_SIZE, D] + dst_stride = (page_size_ - 1) * embed * sizeof(Dtype) / BLOCK_SIZE; + n_burst = kv_head_num; + n_blks = num_embed_blocks; + ComputeKvNpdOffset(token_idx, kv_head_num, embed, &token_start); + } + AscendC::DataCopyParams cpy_out_kv_param(n_burst, n_blks, 0, dst_stride); + CopyOut(out_kv, token_start, cpy_out_kv_param); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + return; + } + AscendC::LocalMemAllocator ub_allocator_; + AscendC::GlobalTensor in_key_; + AscendC::GlobalTensor in_value_; + AscendC::GlobalTensor in_out_key_cache_; + AscendC::GlobalTensor in_out_value_cache_; + AscendC::GlobalTensor in_slot_mapping_; + AscendC::GlobalTensor in_q_seq_; + AscendC::GlobalTensor in_kv_seq_; + AscendC::GlobalTensor in_block_tbl_; + AscendC::GlobalTensor in_tiling; + AscendC::GlobalTensor out_key_; + AscendC::GlobalTensor out_value_; + AscendC::LocalTensor in_local_; + + int32_t num_tokens_ = 0; + int32_t hidden_size_ = 0; + int32_t k_head_num_ = 0; + int32_t v_head_num_ = 0; + int32_t page_size_ = 0; + int32_t batch_size_ = 0; +}; extern "C" __global__ __aicore__ void reshape_and_cache_npd(GM_ADDR key, GM_ADDR value, GM_ADDR key_cache, GM_ADDR value_cache, GM_ADDR slot_mapping, GM_ADDR q_seq, GM_ADDR kv_seq, GM_ADDR block_tbl, GM_ADDR k_out, GM_ADDR v_out, GM_ADDR workspace, GM_ADDR tiling) { - int32_t tiling_id = (int32_t)(*((__gm__ int32_t *)tiling + 0)); - if (tiling_id == 0) { - 
reshape_and_cache_npd_kernel(key, value, key_cache, value_cache, slot_mapping, q_seq, kv_seq, block_tbl, - k_out, v_out, workspace, tiling); + if (TILING_KEY_IS(0)) { + ReshapeAndCache r; + r.Init(key, value, key_cache, value_cache, slot_mapping, q_seq, kv_seq, block_tbl, k_out, v_out, workspace, tiling); + r.Process(); + } else if (TILING_KEY_IS(1)) { + ReshapeAndCache r; + r.Init(key, value, key_cache, value_cache, slot_mapping, q_seq, kv_seq, block_tbl, k_out, v_out, workspace, tiling); + r.Process(); + } else if (TILING_KEY_IS(2)) { + ReshapeAndCache r; + r.Init(key, value, key_cache, value_cache, slot_mapping, q_seq, kv_seq, block_tbl, k_out, v_out, workspace, tiling); + r.Process(); + } else if (TILING_KEY_IS(3)) { + ReshapeAndCache r; + r.Init(key, value, key_cache, value_cache, slot_mapping, q_seq, kv_seq, block_tbl, k_out, v_out, workspace, tiling); + r.Process(); } } diff --git a/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.cc b/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.cc index 03167d74453bbdc9493cda5cb53fb61513e10e55..8da49c002869033fd9732a2123c2cb03472662ff 100644 --- a/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.cc +++ b/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.cc @@ -18,41 +18,55 @@ // GRAPH MODE IMPLEMENTATION // ============================================================================= +#include #include +#include +#include #include #include +#include +#include +#include "op_host/reshape_and_cache_npd.h" #include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h" #include "ops/framework/utils.h" #include "mindspore/include/custom_op_api.h" namespace ms_custom_ops { -enum class CacheMode : int32_t { - ND = 0, - NPD = 1, -}; +const char *op_name = "ReshapeAndCacheNpd"; -enum class ReshapeAndCacheNpdInputIndex : size_t { - kInputKeyIndex = 0, - kInputValueIndex = 1, - kInputKeyCacheIndex = 2, - kInputValueCacheIndex = 3, - kInputSlotMappingIndex = 4, - kInputQSeqIndex = 5, - kInputKVSeqIndex = 6, - kInputBlockTblIndex = 7, - kInputCacheModeIndex = 8 -}; -enum class ReshapeAndCacheNpdOutputIndex : size_t { kOutputKeyIndex = 0, kOutputValueIndex = 1}; - -static void ReshapeAndCacheNpdCheckInputsShape(const std::string &op_name, const std::vector &key_shape, - const std::vector &value_shape, - const std::vector &key_cache_shape, - const std::vector &value_cache_shape, - const std::vector &slot_mapping_shape, - const std::vector &q_seq_shape, - const std::vector &kv_seq_shape, - const std::vector &block_tbl_shape) { +template +static inline std::vector GetSeqLenFromTensor(const mindspore::tensor::TensorPtr &seq_length_tensor) { + if (seq_length_tensor != nullptr) { + auto seq_length_values = static_cast(seq_length_tensor->data_c()); + auto seq_length_values_num = seq_length_tensor->DataSize(); + std::vector seq_len; + seq_len.reserve(seq_length_values_num); + std::transform(seq_length_values, seq_length_values + seq_length_values_num, std::back_inserter(seq_len), + [](T1 val) { return static_cast(val); }); + return seq_len; + } + return {}; +} + +template +static inline std::vector CastVector(const std::vector &src) { + auto elem_num = src.size(); + if (elem_num > 0) { + std::vector dst; + auto src_data = src.data(); + dst.reserve(elem_num); + std::transform(src_data, src_data + elem_num, std::back_inserter(dst), [](T1 val) { return static_cast(val); }); + return dst; + } + return {}; +} + +static void ReshapeAndCacheNpdCheckInputsShape( + const std::string &op_name, const std::vector &key_shape, const std::vector 
&value_shape, + const std::vector &key_cache_shape, const std::vector &value_cache_shape, + const std::vector &slot_mapping_shape, const std::vector &q_seq_shape, + const std::vector &kv_seq_shape, const std::vector &block_tbl_shape) { if (key_cache_shape.size() != kDim4 || key_shape.size() != kDim2 || value_cache_shape.size() != kDim4 || value_shape.size() != kDim2) { MS_LOG(EXCEPTION) << op_name << ", the dim of inputs should be value.dim=key.dim=2, " @@ -65,27 +79,24 @@ static void ReshapeAndCacheNpdCheckInputsShape(const std::string &op_name, const } if (q_seq_shape.size() != kDim1 || kv_seq_shape.size() != kDim1) { MS_LOG(EXCEPTION) << op_name << ", the dim of inputs should be q_seq.dim=kv_seq.dim=1, but got q_seq.dim=" - << "q_seq_shape.dim=" << q_seq_shape.size() - << "kv_seq_shape.dim=" << kv_seq_shape.size(); + << "q_seq_shape.dim=" << q_seq_shape.size() << "kv_seq_shape.dim=" << kv_seq_shape.size(); } if (block_tbl_shape.size() != kDim2) { MS_LOG(EXCEPTION) << op_name << ", the dim of block_tbl_shape is illegal, " << "block_tbl_shape.dim=" << block_tbl_shape.size(); } - MS_CHECK_VALUE(key_shape == value_shape && key_cache_shape == value_cache_shape, - CheckAndConvertUtils::FormatCommMsg( - op_name, ", key_shape should be equal value_shape, key_cache_shape should be equal value_cache_shape,", - " but got, key.shape=", key_shape, ", value_shape.shape=", value_shape, ", key_cache_shape=", - key_cache_shape, ", value_cache_shape=", value_cache_shape)); + MS_CHECK_VALUE( + key_shape == value_shape && key_cache_shape == value_cache_shape, + CheckAndConvertUtils::FormatCommMsg( + op_name, ", key_shape should be equal value_shape, key_cache_shape should be equal value_cache_shape,", + " but got, key.shape=", key_shape, ", value_shape.shape=", value_shape, ", key_cache_shape=", key_cache_shape, + ", value_cache_shape=", value_cache_shape)); } static void ReshapeAndCacheNpdCheckInputsType(const std::string &op_name, const TypeId &key_dtype, - const TypeId &value_dtype, - const TypeId &key_cache_dtype, - const TypeId &value_cache_dtype, - const TypeId &slot_mapping_dtype, - const TypeId &q_seq_dtype, - const TypeId &kv_seq_dtype, - const TypeId &block_tbl_dtype) { + const TypeId &value_dtype, const TypeId &key_cache_dtype, + const TypeId &value_cache_dtype, const TypeId &slot_mapping_dtype, + const TypeId &q_seq_dtype, const TypeId &kv_seq_dtype, + const TypeId &block_tbl_dtype) { const std::unordered_set valid_types = {kNumberTypeFloat16, kNumberTypeBFloat16}; std::unordered_set input_types = {key_dtype, value_dtype, key_cache_dtype, value_cache_dtype}; if (input_types.size() > 1) { @@ -108,38 +119,74 @@ static void ReshapeAndCacheNpdCheckInputsType(const std::string &op_name, const } if (valid_int_types.find(slot_mapping_dtype) == valid_int_types.end()) { MS_LOG(EXCEPTION) << op_name << ", the dtype of 'slot_mapping, q_seq, kv_seq, block_tbl' should be " - << TypeIdToString(kNumberTypeInt32) << " or " << TypeIdToString(kNumberTypeInt) - << ", but got '" << TypeIdToString(slot_mapping_dtype) << ", " << TypeIdToString(q_seq_dtype) - << ", " << TypeIdToString(kv_seq_dtype) << ", " << TypeIdToString(block_tbl_dtype) << "'"; + << TypeIdToString(kNumberTypeInt32) << " or " << TypeIdToString(kNumberTypeInt) << ", but got '" + << TypeIdToString(slot_mapping_dtype) << ", " << TypeIdToString(q_seq_dtype) << ", " + << TypeIdToString(kv_seq_dtype) << ", " << TypeIdToString(block_tbl_dtype) << "'"; } } static mindspore::ShapeVector ReshapeAndCacheNpdDoInferShape(const mindspore::ShapeVector &shape, - 
const uint32_t &block_size, const uint32_t &batch) { - int max_seq = shape.at(0) + (block_size - 1) * batch; + const mindspore::ShapeVector &cache_shape, + const uint32_t &batch, + const uint32_t &cache_key_value_layout, + const bool &key_value_layout_npd) { ShapeVector out_shape = shape; - out_shape.at(0) = max_seq; + if (key_value_layout_npd) { + if (cache_shape.size() != kDim4) { + MS_LOG(EXCEPTION) << "The dim cache should be 4 but got," << cache_shape.size(); + } + if (out_shape.at(0) >= 0) { + auto block_size = cache_key_value_layout ? cache_shape.at(kDim2) : cache_shape.at(kDim1); + auto max_seq = shape.at(0) + (block_size - 1) * batch; + out_shape.at(0) = max_seq; + } + } return out_shape; } + +static void ReshapeAndCacheNpdCheckLayout(const std::string &op_name, const std::string &out_kv_cache_layout, + const std::string &out_kv_layout) { + if (out_kv_cache_layout != ND_LAYOUT && out_kv_cache_layout != NPD_LAYOUT) { + MS_LOG(EXCEPTION) << op_name << ", out kv cache layout must be ND or NPD, but got" << out_kv_cache_layout; + } + if (out_kv_layout != TH_LAYOUT && out_kv_layout != NPD_LAYOUT) { + MS_LOG(EXCEPTION) << op_name << ", out kv layout must be TH or NPD, but got" << out_kv_cache_layout; + } +} class OPS_API ReshapeAndCacheNpdOpFuncImpl : public OpFuncImpl { public: ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { auto key_shape = input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyIndex)]->GetShape(); auto k_cache_shape = - input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyCacheIndex)]->GetShape(); - int32_t batch = - input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputBlockTblIndex)]->GetShape().at(0); - int32_t block_size = k_cache_shape.at(kDim2); - ShapeVector out_shape = ReshapeAndCacheNpdDoInferShape(key_shape, block_size, batch); + input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyCacheIndex)]->GetShape(); + auto blk_tbl_shape = + input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputBlockTblIndex)]->GetShape(); + if (blk_tbl_shape.size() != kDim2) { + MS_LOG(EXCEPTION) << "The dim of block table shape is illegal, " + << "block_tbl_shape.dim=" << blk_tbl_shape.size(); + } + int32_t batch = blk_tbl_shape.at(0); + + auto cache_kv_layout = input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputKVCacheModeIndex)] + ->GetScalarValueWithCheck() == std::string(NPD_LAYOUT); + + auto kv_layout = input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyValueModeIndex)] + ->GetScalarValueWithCheck() == std::string(NPD_LAYOUT); + + ShapeVector out_shape = ReshapeAndCacheNpdDoInferShape(key_shape, k_cache_shape, batch, cache_kv_layout, kv_layout); return {out_shape, out_shape}; } std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { auto out_type = input_infos[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyIndex)]->GetType(); return {out_type, out_type}; } + std::set GetValueDependArgIndices() const override { + return {static_cast(ReshapeAndCacheNpdInputIndex::kInputQSeqIndex), + static_cast(ReshapeAndCacheNpdInputIndex::kInputKVSeqIndex)}; + } + bool GeneralInferRegistered() const override { return true; } }; - class ReshapeAndCacheNpdAscend : public AclnnCustomKernelMod { public: ReshapeAndCacheNpdAscend() : AclnnCustomKernelMod(std::move("aclnnReshapeAndCacheNpd")) {} @@ -148,33 +195,40 @@ class ReshapeAndCacheNpdAscend : public AclnnCustomKernelMod { bool Launch(const std::vector &inputs, const 
std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { MS_EXCEPTION_IF_NULL(stream_ptr); - RunOp( - stream_ptr, workspace, inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputValueIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyCacheIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputValueCacheIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputSlotMappingIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputQSeqIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKVSeqIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputBlockTblIndex)], - cache_mode_, - outputs[static_cast(ReshapeAndCacheNpdOutputIndex::kOutputKeyIndex)], - outputs[static_cast(ReshapeAndCacheNpdOutputIndex::kOutputValueIndex)]); + RunOp(stream_ptr, workspace, inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyIndex)], + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputValueIndex)], + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyCacheIndex)], + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputValueCacheIndex)], + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputSlotMappingIndex)], + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputQSeqIndex)], + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKVSeqIndex)], + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputBlockTblIndex)], kv_cache_mode_, kv_mode_, + outputs[static_cast(ReshapeAndCacheNpdOutputIndex::kOutputKeyIndex)], + outputs[static_cast(ReshapeAndCacheNpdOutputIndex::kOutputValueIndex)]); return true; } void GetWorkSpaceInfo(const std::vector &inputs, const std::vector &outputs) override { - cache_mode_ = inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputCacheModeIndex)] - ->GetValueWithCheck(); + kv_cache_mode_ = inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKVCacheModeIndex)] + ->GetValueWithCheck(); + kv_mode_ = inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyValueModeIndex)] + ->GetValueWithCheck(); + + ReshapeAndCacheNpdCheckLayout(op_name, kv_cache_mode_, kv_mode_); + std::vector q_cpu = + CastVector(inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputQSeqIndex)] + ->GetValueWithCheck>()); + std::vector kv_cpu = + CastVector(inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKVSeqIndex)] + ->GetValueWithCheck>()); + GetWorkspaceForResize(inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyIndex)], inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputValueIndex)], inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKeyCacheIndex)], inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputValueCacheIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputSlotMappingIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputQSeqIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputKVSeqIndex)], - inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputBlockTblIndex)], - cache_mode_, + inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputSlotMappingIndex)], q_cpu, + kv_cpu, inputs[static_cast(ReshapeAndCacheNpdInputIndex::kInputBlockTblIndex)], + kv_cache_mode_, kv_mode_, outputs[static_cast(ReshapeAndCacheNpdOutputIndex::kOutputKeyIndex)], outputs[static_cast(ReshapeAndCacheNpdOutputIndex::kOutputValueIndex)]); return; @@ -182,7 +236,8 @@ class ReshapeAndCacheNpdAscend : public AclnnCustomKernelMod { private: DEFINE_GET_WORKSPACE_FOR_RESIZE(); - int64_t 
cache_mode_{1}; + std::string kv_cache_mode_ = NPD_LAYOUT; + std::string kv_mode_ = NPD_LAYOUT; }; } // namespace ms_custom_ops @@ -198,20 +253,34 @@ std::vector reshape_and_cache_npd_custom(const ms::Tensor &key, cons const ms::Tensor &key_cache, const ms::Tensor &value_cache, const ms::Tensor &slot_mapping, const ms::Tensor &q_seq, const ms::Tensor &kv_seq, const ms::Tensor &block_tbl, - const int32_t &cache_mode) { - std::string op_name = "ReshapeAndCacheNpd"; + const std::string &kv_cache_layout, + const std::string &key_value_layout) { + auto q_cpu = GetSeqLenFromTensor(q_seq.tensor()->cpu()); + if (q_cpu.empty()) { + MS_LOG(EXCEPTION) << "Get q_cpu seq len failed "; + } + auto kv_cpu = GetSeqLenFromTensor(kv_seq.tensor()->cpu()); + if (kv_cpu.empty()) { + MS_LOG(EXCEPTION) << "Get kv_cpu seq len failed "; + } + auto runner = std::make_shared(op_name); ReshapeAndCacheNpdCheckInputsShape(op_name, key.shape(), value.shape(), key_cache.shape(), value_cache.shape(), slot_mapping.shape(), q_seq.shape(), kv_seq.shape(), block_tbl.shape()); ReshapeAndCacheNpdCheckInputsType(op_name, key.data_type(), value.data_type(), key_cache.data_type(), value_cache.data_type(), slot_mapping.data_type(), q_seq.data_type(), kv_seq.data_type(), block_tbl.data_type()); - ShapeVector out_shape = ReshapeAndCacheNpdDoInferShape(key.shape(), key_cache.shape().at(kDim2), - block_tbl.shape().at(0)); + ReshapeAndCacheNpdCheckLayout(op_name, kv_cache_layout, key_value_layout); + + auto cache_layout = kv_cache_layout == std::string(NPD_LAYOUT); + auto kv_layout = key_value_layout == std::string(NPD_LAYOUT); + + ShapeVector out_shape = + ReshapeAndCacheNpdDoInferShape(key.shape(), key_cache.shape(), block_tbl.shape().at(0), cache_layout, kv_layout); auto out_k = ms::Tensor(key.data_type(), out_shape); auto out_v = ms::Tensor(value.data_type(), out_shape); runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnReshapeAndCacheNpd, key, value, key_cache, value_cache, slot_mapping, - q_seq, kv_seq, block_tbl, cache_mode, out_k, out_v)); + q_cpu, kv_cpu, block_tbl, kv_cache_layout, key_value_layout, out_k, out_v)); runner->Run({key, value, key_cache, value_cache, slot_mapping, q_seq, kv_seq, block_tbl}, {out_k, out_v}); return {out_k, out_v}; } @@ -220,6 +289,6 @@ std::vector reshape_and_cache_npd_custom(const ms::Tensor &key, cons MS_CUSTOM_OPS_EXTENSION_MODULE(m) { m.def("reshape_and_cache_npd", &ms_custom_ops::reshape_and_cache_npd_custom, pybind11::arg("key"), pybind11::arg("value"), pybind11::arg("key_cache"), pybind11::arg("value_cache"), pybind11::arg("slot_mapping"), - pybind11::arg("q_seq"), pybind11::arg("kv_seq"), pybind11::arg("block_tbl"), pybind11::arg("cache_mode") = 1); + pybind11::arg("actual_seq_qlen"), pybind11::arg("actual_seq_kvlen"), pybind11::arg("block_tbl"), + pybind11::arg("kv_cache_layout") = NPD_LAYOUT, pybind11::arg("key_value_layout") = NPD_LAYOUT); } - diff --git a/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.md b/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.md index 060bb1458ac970e568efe7c886862359420c21aa..ca2362f8593f544ddcc5e51c106152d3faee6469 100644 --- a/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.md +++ b/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd.md @@ -2,25 +2,45 @@ ## 描述 -reshape_and_cache_npd - store KV cache in NPD format, N-KV head number, P - block Size, D - embedding size +reshape_and_cache_npd - the operator responsible for storing key/value pairs in a KV cache in ND or NPD format. 
It also produces two outputs, the current key and value tensors, together with the cached history when the NPD output layout is selected.
-## 输入参数
+### 名称
-| Name | DType | Shape | Optional | Inplace | Format | Description |
-|------------------------|-----------------------------------|-------------------------------------------------------------------------------------------------------------------------|----------|---------|--------|--------------------------------|
-| key | Tensor[float16/bfloat16] | (num_tokens, num_head*head_dim) | No | No | ND | key 张量 |
-| value | Tensor[float16/bfloat16] | (num_tokens, num_head*head_dim) | No | No | ND | value 张量 |
-| key_cache | Tensor[float16/bfloat16] | NPD: (num_blocks, num_head, block_size, head_dim) | No | No | NPD | key_cache 张量 |
-| value_cache | Tensor[float16/bfloat16] | NPD: (num_blocks, num_head, block_size, head_dim) | No | No | NPD | value_cache 张量 |
-| slot_mapping | Tensor[int32] | (num_tokens,) | No | No | ND | slot_mapping 张量 |
-| cache_mode | int | - | No | - | - | 缓存模式:0 表示 ND 格式,1 表示 NPD 格式 |
+- 算子名:`reshape_and_cache_npd`
+
+### 输入参数
+
+T - token number
+BN - block number
+N - head number
+P - block size
+D - embedding size
+B - batch size
+
+| Name | DType | Shape | Optional | Format | Description |
+|------------------|---------------------------|--------------------------------------|----------|--------|-------------|
+| key | Tensor[float16/bfloat16] | TH/NPD | No | ND | Key tensor |
+| value | Tensor[float16/bfloat16] | TH/NPD | No | ND | Value tensor |
+| key_cache | Tensor[float16/bfloat16] | NPD: (BN, N, P, D); ND: (BN, P, N, D)| No | ND | Key cache |
+| value_cache | Tensor[float16/bfloat16] | NPD: (BN, N, P, D); ND: (BN, P, N, D)| No | ND | Value cache |
+| actual_seq_qlen | Tensor[int32] | (B,) | No | ND | Number of query tokens in each batch element |
+| actual_seq_kvlen | Tensor[int32] | (B,) | No | ND | Number of key/value tokens in each batch element |
+| slot_mapping | Tensor[int32] | (T,) | No | ND | Mapping of each token into the key/value cache |
+| block_tbl | Tensor[int32] | (B, max_query_len) | No | ND | Mapping of each page into the key/value cache |
+| kv_cache_layout | string | - | No | - | Key/value cache in/out layout; supported values: ND or NPD (default: NPD) |
+| key_value_layout | string | - | No | - | Key/value output layout; supported values: TH or NPD (default: NPD) |
+
+#### 参数补充说明
+
+- actual_seq_qlen / actual_seq_kvlen
+  - Both must be provided and must reside in CPU memory.
+- kv_cache_layout / key_value_layout
+  - The NPD addressing scheme is illustrated in the sketch below.
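+
+Below is a minimal NumPy sketch of how a `slot_mapping` entry addresses the NPD cache layout `(BN, N, P, D)` listed above. It is an illustration only, not part of the operator API; the sizes and the helper name `write_token_npd` are invented for the example.
+
+```python
+import numpy as np
+
+# Illustrative sizes only.
+BN, N, P, D = 4, 2, 16, 8
+key_cache_npd = np.zeros((BN, N, P, D), dtype=np.float16)
+
+def write_token_npd(cache, token, slot, page_size):
+    """Scatter one token of shape (N, D) into an NPD-layout cache.
+
+    slot is the flat position taken from slot_mapping: the page index is
+    slot // page_size and the offset inside the page is slot % page_size,
+    so every head of the token lands in cache[page, head, offset, :].
+    """
+    page, offset = divmod(slot, page_size)
+    cache[page, :, offset, :] = token
+
+token = np.random.rand(N, D).astype(np.float16)
+write_token_npd(key_cache_npd, token, slot=21, page_size=P)  # page 1, offset 5
+```
+
+With the ND cache layout `(BN, P, N, D)`, the same token is written contiguously to `cache[page, offset, :, :]` instead.
+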
## 输出参数 -| Name | DType | Shape | Description | -|--------|-----------------|--------------------------------------|-------------| -| key_out | Tensor[float16/bfloat16] | (num_tokens/block_size, num_head, block_size, head_dim) | key out in NPD format | -| value_out | Tensor[float16/bfloat16] | (num_tokens/block_size, num_head, block_size, head_dim) | key out in NPD format | +| Name | DType | Shape | Description | +|-----------|--------------------------|-------------------------------------------|---------------| +| key_out | Tensor[float16/bfloat16] | NPD:(T+History, N * P * D) TH: (T, N * D) | key output | +| value_out | Tensor[float16/bfloat16] | NPD:(T+History, N * P * D) TH: (T, N * D) | value output | ## 使用示例 @@ -28,16 +48,27 @@ reshape_and_cache_npd - store KV cache in NPD format, N-KV head number, P - bloc import mindspore as ms import ms_custom_ops import numpy as np +from mindspore import Tensor, context + + +np.random.seed(0) +context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) + +T = 16 +BN = 1024 +P = 32 +N = 8 +D = 128 +B = 1 -# 创建输入张量 -key = ms.Tensor(np.random.rand(16, 4096), ms.float16) -value = ms.Tensor(np.random.rand(16, 4096), ms.float16) -key_cache = ms.Tensor(np.random.rand(1024, 32, 16, 128), ms.float16) -value_cache = ms.Tensor(np.random.rand(1024, 32, 16, 128), ms.float16) -slot_mapping = ms.Tensor(np.arange(128), ms.int32) -q_seq = ms.Tensor(np.arange(1), ms.int32) -kv_seq = ms.Tensor(np.arange(1), ms.int32) -block_tbl = ms.Tensor(np.arange(128).reshape(1,128), ms.int32) +key = ms.Tensor(np.random.rand(T, N * D), ms.float16) +value = ms.Tensor(np.random.rand(T, N * D), ms.float16) +key_cache = ms.Tensor(np.random.rand(BN, N, P, D), ms.float16) +value_cache = ms.Tensor(np.random.rand(BN, N, P, D), ms.float16) +slot_mapping = ms.Tensor(np.arange(T), ms.int32) +q_seq = ms.Tensor(np.full(B, T).astype(np.int32)).move_to("CPU") +kv_seq = ms.Tensor(np.full(B, T).astype(np.int32)).move_to("CPU") +block_tbl = ms.Tensor(np.arange(BN).reshape(B,BN), ms.int32) # 调用算子 k_out, v_out = ms_custom_ops.reshape_and_cache_npd( key=key, @@ -45,9 +76,11 @@ k_out, v_out = ms_custom_ops.reshape_and_cache_npd( key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping, - q_seq=q_seq, - kv_seq=kv_seq, + actual_seq_qlen=q_seq, + actual_seq_kvlen=kv_seq, block_tbl=block_tbl, - cache_mode=1 #ND0,NPD1 + kv_cache_layout="NPD", + key_value_layout="NPD" ) -``` \ No newline at end of file +print("shape ", k_out.shape) +``` diff --git a/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd_op.yaml b/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd_op.yaml index d38f7271e79e4974b3e80fcc282fd6188fb8aa6b..baf7280a5c039563e781e5cbf71a4ccb2ad135e2 100644 --- a/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd_op.yaml +++ b/ops/ascendc/reshape_and_cache_npd/reshape_and_cache_npd_op.yaml @@ -11,15 +11,18 @@ reshape_and_cache_npd: dtype: tensor slot_mapping: dtype: tensor - q_seq: + actual_seq_qlen: dtype: tensor - kv_seq: + actual_seq_kvlen: dtype: tensor block_tbl: dtype: tensor - cache_mode: - dtype: int - default: 1 + kv_cache_layout: + dtype: str + default: "'NPD'" + key_value_layout: + dtype: str + default: "'NPD'" args_signature: rw_write: key_cache, value_cache dtype_group: (key, key_cache), (value, value_cache) diff --git a/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd.cpp b/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0969fd55620f2da96f876dda105289c05e1c2f09 --- /dev/null +++ 
b/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd.cpp @@ -0,0 +1,520 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "unpad_fa_npd.h" +#include +#include +#include +#include +#include "unpad_fa_npd_tiling.h" +#include "register/op_def_registry.h" +#include "graph/utils/type_utils.h" +#include "tiling/platform/platform_ascendc.h" +#include "utils/log/asc_cpu_log.h" + +namespace optiling { +static constexpr auto kDim1 = 1; +static constexpr auto kDim2 = 2; +static constexpr auto kDim3 = 3; +static constexpr auto kDim4 = 4; +static constexpr auto kDim9 = 9; +static constexpr auto kHigh32Bit = 32; + +class tiling_info { + public: + int32_t batch; + int32_t head_num = 0; + int32_t d = 0; + int32_t kv_head_num = 0; + int32_t head_stride = 0; + std::vector q_seq_len; + std::vector kv_seq_len; + int32_t max_q_seq = 0; + int32_t max_kv_seq = 0; + bool is_triu = false; + bool is_long_seq = false; + int32_t mask_type = 0; + int32_t mask_stride = 0; + int32_t mask_seq_len = 0; + int32_t splitm = 0; + // Friend function to overload the << operator + friend std::ostream &operator<<(std::ostream &os, const tiling_info &obj) { + os << "\nbatch " << obj.batch << "\n" + << "head_num " << obj.head_num << "\n" + << "d " << obj.d << "\n" + << "kv_head_num " << obj.kv_head_num << "\n" + << "head_stride " << obj.head_stride << "\n"; + os << "q_seq_len "; + for (auto &s : obj.q_seq_len) os << " " << s; + os << "\n"; + os << "kv_seq_len "; + for (auto &s : obj.kv_seq_len) os << " " << s; + os << "\n"; + os << "max_q_seq " << obj.max_q_seq << "\n" + << "is_triu " << obj.is_triu << "\n" + << "is_long_seq " << obj.is_long_seq << "\n" + << "mask_type " << obj.mask_type << "\n" + << "mask_stride " << obj.mask_stride << "\n" + << "mask_seq_len " << obj.mask_seq_len << "\n" + << "splitm " << obj.splitm; + return os; + } + void print_to_log() { + std::ostringstream oss; + oss << this; + ASC_CPU_LOG_INFO("tiling info: %s\n", oss.str().c_str()); + } +}; + +constexpr int32_t triu_dim = 128; +constexpr int32_t block_size = 16; +constexpr int32_t double_ping_pong_size = 32768 * 8; +constexpr int32_t splitm_double_ping_pong_size = 32768 * 16; + +#define FA_ASSERT_RET(context, condition, logMessage, ret) \ + do { \ + if (!(condition)) { \ + const char *name = ((context)->GetNodeName() == nullptr) ? 
"nil" : (context)->GetNodeName(); \ + ASC_CPU_LOG_ERROR("%s: %s", name, logMessage); \ + return ret; \ + } \ + } while (0) + +inline static int32_t UpDiv(int32_t val, int32_t divider = block_size) { return (val + divider - 1) / divider; } + +inline static int32_t UpRound(int32_t val, int32_t round = block_size) { return UpDiv(val, round) * round; } + +inline static uint32_t FloatToBits(float x) { + uint32_t u; + std::memcpy(&u, &x, sizeof(u)); + return u; +} + +static ge::graphStatus GetD(gert::TilingContext *context, int32_t head_num, int32_t *d) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputQIndex); + auto q_shape = context->GetInputShape(input_idx); + FA_ASSERT_RET(context, q_shape != nullptr, "q sahpe is null", ge::GRAPH_FAILED); + auto q_dims_num = q_shape->GetOriginShape().GetDimNum(); + if (q_dims_num == kDim2) { // TH + *d = static_cast(q_shape->GetOriginShape().GetDim(kDim1)) / head_num; + } else if (q_dims_num == kDim3) { // BSH + *d = static_cast(q_shape->GetOriginShape().GetDim(kDim2)) / head_num; + } else { + FA_ASSERT_RET(context, 0, "query format not supported", ge::GRAPH_FAILED); + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetKvHeadNum(gert::TilingContext *context, int32_t d, int32_t *kv_head_num) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputKIndex); + auto k_shape = context->GetInputShape(input_idx); + FA_ASSERT_RET(context, k_shape != nullptr, "k sahpe is null", ge::GRAPH_FAILED); + auto k_dims_num = k_shape->GetOriginShape().GetDimNum(); + if (k_dims_num == kDim2) { // TH + *kv_head_num = static_cast(k_shape->GetOriginShape().GetDim(kDim1)) / d; + } else if (k_dims_num == kDim3) { // BSH + *kv_head_num = static_cast(k_shape->GetOriginShape().GetDim(kDim2)) / d; + } else { + FA_ASSERT_RET(context, 0, "kv format not supported", ge::GRAPH_FAILED); + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetBatch(gert::TilingContext *context, int32_t *b) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputActualQSeqIndex); + auto q_seq_len = context->GetInputShape(input_idx); + FA_ASSERT_RET(context, q_seq_len != nullptr, "q sequence length sahpe is null", ge::GRAPH_FAILED); + auto q_seq_dims_num = q_seq_len->GetOriginShape().GetDimNum(); + if (q_seq_dims_num == 1) { + *b = q_seq_len->GetOriginShape().GetDim(0); + } else if (q_seq_dims_num == kDim2) { + *b = q_seq_len->GetOriginShape().GetDim(kDim1); + } else { + FA_ASSERT_RET(context, 0, "query sequence length format not supported", ge::GRAPH_FAILED); + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetQSeqLens(gert::TilingContext *context, int batch, std::vector *q_seq_len) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputActualQSeqIndex); + auto t = context->GetInputTensor(input_idx); + FA_ASSERT_RET(context, t != nullptr, "q sequence length is null", ge::GRAPH_FAILED); + FA_ASSERT_RET(context, t->GetShapeSize() == batch, "q sequence length size is not valid", ge::GRAPH_FAILED); + auto p = t->GetData(); + + q_seq_len->assign(p, p + batch); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetKvSeqLens(gert::TilingContext *context, int batch, std::vector *kv_seq_len) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputActualKVSeqIndex); + auto t = context->GetInputTensor(input_idx); + FA_ASSERT_RET(context, t != nullptr, "kv sequence length is null", ge::GRAPH_FAILED); + FA_ASSERT_RET(context, t->GetShapeSize() == batch, "kv sequence length size is not valid", ge::GRAPH_FAILED); + auto p = t->GetData(); + 
kv_seq_len->assign(p, p + batch); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetTriuMask(gert::TilingContext *context, bool *is_triu) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputAttnMaskIndex); + *is_triu = false; + if (context->GetOptionalInputTensor(input_idx) == nullptr) { + return ge::GRAPH_SUCCESS; + } + auto mask_shape = context->GetInputShape(input_idx); + if ((mask_shape && mask_shape->GetOriginShape().GetDimNum() == kDim2) && + (mask_shape->GetOriginShape().GetDim(0) == triu_dim) && (mask_shape->GetOriginShape().GetDim(1) == triu_dim)) { + *is_triu = true; + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetLongMask(gert::TilingContext *context, bool is_triu, bool *is_long_seq) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputAttnMaskIndex); + *is_long_seq = false; + if (context->GetOptionalInputTensor(input_idx) == nullptr) { + return ge::GRAPH_SUCCESS; + } + auto mask_shape = context->GetInputShape(input_idx); + if (mask_shape) { + auto dim_num = mask_shape->GetOriginShape().GetDimNum(); + auto max_seq = mask_shape->GetOriginShape().GetDim(dim_num - 1); + if ((max_seq == triu_dim) && is_triu) { + *is_long_seq = true; + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetMaskType(gert::TilingContext *context, int32_t *mask_type) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputAttnMaskIndex); + *mask_type = 0; + if (context->GetOptionalInputTensor(input_idx) == nullptr) { + return ge::GRAPH_SUCCESS; + } + auto mask_shape = context->GetInputShape(input_idx); + + if (mask_shape) { + *mask_type = 1; + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetMaskStride(gert::TilingContext *context, int32_t *mask_stride) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputAttnMaskIndex); + *mask_stride = 0; + if (context->GetOptionalInputTensor(input_idx) == nullptr) { + return ge::GRAPH_SUCCESS; + } + auto mask_shape = context->GetInputShape(input_idx); + + if (mask_shape) { + auto dim_num = mask_shape->GetOriginShape().GetDimNum(); + if (dim_num == kDim2) { + *mask_stride = 0; + } else if (dim_num == kDim3) { + *mask_stride = mask_shape->GetOriginShape().GetDim(kDim1); + } else { + FA_ASSERT_RET(context, 0, "mask format not supported", ge::GRAPH_FAILED); + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetMaskSeq(gert::TilingContext *context, int32_t *mask_seq_len) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputAttnMaskIndex); + *mask_seq_len = 0; + if (context->GetOptionalInputTensor(input_idx) == nullptr) { + return ge::GRAPH_SUCCESS; + } + auto mask_shape = context->GetInputShape(input_idx); + + if (mask_shape) { + auto dim_num = mask_shape->GetOriginShape().GetDimNum(); + *mask_seq_len = mask_shape->GetOriginShape().GetDim(dim_num - 1); + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetTilingKey(gert::TilingContext *context, tiling_info *ti, int32_t *tiling_key) { + auto input_idx = static_cast(UnpadFaNpdInputIndex::kInputQIndex); + auto t = context->GetInputTensor(input_idx); + FA_ASSERT_RET(context, t != nullptr, "q is null", ge::GRAPH_FAILED); + *tiling_key = 0; + bool is_bf_16 = t->GetDataType() == ge::DataType::DT_BF16; + *tiling_key = is_bf_16 ? 
1 : 0; + if (is_bf_16 && ti->d <= 128 && ti->mask_type == 0) { + *tiling_key |= (1 << 1); + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InitUnpadFaNpdTilingInfo(gert::TilingContext *context, tiling_info &ti, float &tor, bool &npd, + int64_t &page_size) { + ti.head_num = *context->GetAttrs()->GetAttrPointer(0); + + FA_ASSERT_RET(context, (GetD(context, ti.head_num, &ti.d) == ge::GRAPH_SUCCESS), "fail to get embed size", + ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetKvHeadNum(context, ti.d, &ti.kv_head_num) == ge::GRAPH_SUCCESS), "fail to get embed size", + ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetBatch(context, &ti.batch) == ge::GRAPH_SUCCESS), "fail to get batch size", + ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetQSeqLens(context, ti.batch, &ti.q_seq_len) == ge::GRAPH_SUCCESS), + "fail to get query sequnch length", ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetKvSeqLens(context, ti.batch, &ti.kv_seq_len) == ge::GRAPH_SUCCESS), + "fail to key&value sequnch length", ge::GRAPH_FAILED); + ti.head_stride = 0; + ti.max_q_seq = *std::max_element(ti.q_seq_len.begin(), ti.q_seq_len.end()); + ti.max_kv_seq = *std::max_element(ti.kv_seq_len.begin(), ti.kv_seq_len.end()); + FA_ASSERT_RET(context, (GetTriuMask(context, &ti.is_triu) == ge::GRAPH_SUCCESS), "fail to setup is triu mask", + ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetLongMask(context, ti.is_triu, &ti.is_long_seq) == ge::GRAPH_SUCCESS), + "fail to setup is long mask", ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetMaskStride(context, &ti.mask_stride) == ge::GRAPH_SUCCESS), "fail to setup mask stride", + ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetMaskType(context, &ti.mask_type) == ge::GRAPH_SUCCESS), "fail to setup mask type", + ge::GRAPH_FAILED); + FA_ASSERT_RET(context, (GetMaskSeq(context, &ti.mask_seq_len) == ge::GRAPH_SUCCESS), "fail to setup mask type", + ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +static int32_t ComputeBatchTilingParams(gert::TilingContext *context, tiling_info &ti, UnpadFaNpdTilingData &tiling, + bool npd, int64_t page_size, int d_round, int &q_blocks_total, + uint64_t &in_out_offset, uint64_t &kv_offset) { + constexpr int pp_buff_size = 128 * 128; + constexpr int pp_mm[] = {16, 32, 48, 64, 80, 96, 112, 128}; + constexpr int pp_mm_len = sizeof(pp_mm) / sizeof(pp_mm[0]); + constexpr int pp_mm_max = pp_mm[pp_mm_len - 1]; + constexpr int pp_nn[] = {16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256}; + constexpr int pp_nn_len = sizeof(pp_nn) / sizeof(pp_nn[0]); + constexpr int pp_nn_max = pp_nn[pp_nn_len - 1]; + + for (int i = 0; i < ti.batch; i++) { + auto tiling_param = &(tiling.get_buf()[i * KParaLastIdx + KHeadLastIdx]); + auto cur_q_seq = ti.q_seq_len.at(i); + auto cur_q_seq_round = UpRound(cur_q_seq); + auto cur_kv_seq = ti.kv_seq_len.at(i); + auto cur_kv_seq_round = UpRound(cur_kv_seq); + auto tiling_k = d_round < triu_dim ? triu_dim : d_round; + + uint32_t nUbd = std::min((pp_buff_size / tiling_k / block_size) * block_size, cur_kv_seq_round); + uint32_t nIbd = (nUbd > pp_nn_max) ? pp_nn_len - 1 : (nUbd / block_size - 1); + uint32_t mUbd = + std::min((pp_buff_size / std::max(d_round, pp_nn[nIbd]) / block_size) * block_size, cur_q_seq_round); + uint32_t mIbd = (mUbd > pp_mm_max) ? pp_mm_len - 1 : (mUbd / block_size - 1); + mUbd = (ti.splitm) ? pp_nn_max : pp_mm[mIbd]; + + int32_t q_blocks_num = (cur_q_seq != 0 && cur_kv_seq != 0) ? 
(UpDiv(cur_q_seq, mUbd)) : 0; + q_blocks_total += q_blocks_num; + + // sets M,N and K + tiling_param[KQSeqLenIdx] = cur_q_seq; + tiling_param[KKvSeqLenIdx] = cur_kv_seq; + tiling_param[KMMIdx] = mUbd; + tiling_param[KNNIdx] = pp_nn[nIbd]; + // sets IO offsets + tiling_param[KAddrQSeqOffsetHighIdx] = (in_out_offset >> kHigh32Bit); + tiling_param[KAddrQSeqOffsetLowIdx] = (static_cast(in_out_offset)); + tiling_param[KAddrKSeqOffsetHighIdx] = (kv_offset >> kHigh32Bit); + tiling_param[KAddrKSeqOffsetLowIdx] = (static_cast(kv_offset)); + tiling_param[KAddrVSeqOffsetHighIdx] = (kv_offset >> kHigh32Bit); + tiling_param[KAddrVSeqOffsetLowIdx] = (static_cast(kv_offset)); + tiling_param[KAddrOSeqOffsetHighIdx] = (in_out_offset >> kHigh32Bit); + tiling_param[KAddrOSeqOffsetLowIdx] = (static_cast(in_out_offset)); + tiling_param[KTotalQBlockIdx] = q_blocks_total; + tiling_param[KSplitIdx] = ti.splitm; + tiling_param[KStateIdx] = (static_cast(cur_q_seq) > 0 && static_cast(cur_kv_seq) > 0); + in_out_offset += static_cast(cur_q_seq * ti.head_num * ti.d); + if (npd) { + kv_offset += static_cast(UpRound(cur_kv_seq, page_size) * ti.kv_head_num * ti.d); + } else { + kv_offset += static_cast(cur_kv_seq * ti.kv_head_num * ti.d); + } + } + return q_blocks_total; +} + +static ge::graphStatus UnpadFaNpdTiling(gert::TilingContext *context) { + tiling_info ti; + UnpadFaNpdTilingData tiling; + + auto tor = static_cast(*context->GetAttrs()->GetAttrPointer(1)); + bool npd = std::string(context->GetAttrs()->GetAttrPointer(3)) == std::string("NPD"); + auto page_size = *context->GetAttrs()->GetAttrPointer(4); + auto status = InitUnpadFaNpdTilingInfo(context, ti, tor, npd, page_size); + if (status != ge::GRAPH_SUCCESS) { + return status; + } + + auto d_round = UpRound(ti.d); + int32_t q_blocks_total = 0; + uint64_t in_out_offset = 0; + uint64_t kv_offset = 0; + int tiling_key = 0; + ti.splitm = false; + + auto tiling_size = (KHeadLastIdx + KParaLastIdx * ti.batch) * sizeof(uint32_t); + FA_ASSERT_RET(context, (KMaxBatch >= ti.batch), "tiling size is too big", ge::GRAPH_FAILED); + ti.print_to_log(); + + // Call helper + q_blocks_total = + ComputeBatchTilingParams(context, ti, tiling, npd, page_size, d_round, q_blocks_total, in_out_offset, kv_offset); + + auto require_cores = static_cast(ti.head_num * q_blocks_total); + + auto tiling_head = tiling.get_buf(); + tiling_head[KBatchIdx] = ti.batch; + tiling_head[KMaxSeqIdx] = ti.mask_seq_len; + tiling_head[KInnerBatchSizeIdx] = ti.head_num; + tiling_head[KEmbeddingSizeIdx] = ti.d; + tiling_head[KBvHeadNumIdx] = ti.kv_head_num; + tiling_head[KTorIdx] = FloatToBits(tor); + tiling_head[KHeadStrideIdx] = ti.head_stride; + tiling_head[KMaskStrideIdx] = ti.mask_stride; + tiling_head[KIsTriuMaskIdx] = (static_cast(ti.is_triu)); + tiling_head[KTotalQBlkNumIdx] = q_blocks_total; + tiling_head[KIsClampIdx] = 0; + float clamp = 0.0f; + auto clampInt = FloatToBits(clamp); + tiling_head[KClampMinUptrIdx] = clampInt; + tiling_head[KClampMaxUptrIdx] = clampInt; + tiling_head[KNoneIdx] = 0; + tiling_head[KTilingHeadSizeIdx] = KHeadLastIdx; + tiling_head[KTilingParaSizeIdx] = KParaLastIdx; + FA_ASSERT_RET(context, (GetTilingKey(context, &ti, &tiling_key) == ge::GRAPH_SUCCESS), "fail to compute tiling key", + ge::GRAPH_FAILED); + tiling_head[KKeyIdx] = tiling_key; // TBD + tiling_head[KIsLongSeqIdx] = static_cast(ti.is_long_seq); + tiling_head[KMaxKVSeqIdx] = ti.max_kv_seq; + tiling_head[KIsAlibiMaskIdx] = 0; + tiling_head[KMaskTypeIdx] = static_cast(ti.mask_type); + 
tiling_head[KAlibiCompresOffsetIdx] = 0; + tiling_head[KAlibiLeftAlignIdx] = 0; + tiling_head[KEmbeddingSizeVIdx] = ti.d; + tiling_head[KQuantTypeIdx] = 0; + tiling_head[KDataShapeTypeIdx] = 0; + tiling_head[KScaleTypeIdx] = 0; + tiling_head[KWindowSizeIdx] = 0; + tiling_head[KMaxNumBlocksIdx] = 0; + tiling_head[KQMaxSeqLenIdx] = ti.max_q_seq; + tiling_head[KPreTokensIdx] = 0; + tiling_head[KNextTokensIdx] = 0; + tiling_head[KRazorLenIdx] = 0; + tiling_head[KTileQIdx] = 0; + tiling_head[KTileKvIdx] = 0; + tiling_head[KTextQLenIdx] = 0; + tiling_head[KTextKvLenIdx] = 0; + tiling_head[KKvLayoutIdx] = npd ? kDim2 : 0; + tiling_head[KPageSizeIdx] = page_size; + + auto ascendc_platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + auto core_num = ascendc_platform.GetCoreNumAic(); + core_num = (core_num > require_cores) ? require_cores : core_num; + uint64_t work_size = static_cast(core_num) * static_cast(double_ping_pong_size) * sizeof(float); + if (ti.splitm) { + work_size = static_cast(core_num) * static_cast(splitm_double_ping_pong_size) * sizeof(float); + } + tiling_head[KWsLowIdx] = (static_cast(work_size)); + + size_t *currentWorkspace = context->GetWorkspaceSizes(kDim1); + currentWorkspace[0] = work_size * kDim9; + + context->SetBlockDim(core_num); + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling_size); + context->SetTilingKey(tiling_key); + + return ge::GRAPH_SUCCESS; +} +} // namespace optiling + +namespace ge { +static graphStatus UnpadFaNpdInferShape(gert::InferShapeContext *context) { + auto q_shape = context->GetInputShape(0); + gert::Shape *attn_out_shape = context->GetOutputShape(0); + *attn_out_shape = *q_shape; + return ge::GRAPH_SUCCESS; +} + +static graphStatus UnpadFaNpdInferDataType(gert::InferDataTypeContext *context) { + const auto inputDataType = context->GetInputDataType(0); + context->SetOutputDataType(0, inputDataType); + return ge::GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class UnpadFaNpd : public OpDef { + public: + explicit UnpadFaNpd(const char *name) : OpDef(name) { + this->Input("query") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("value") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("attn_mask") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_qlen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .ValueDepend(Option::OPTIONAL, DependScope::TILING) + .AutoContiguous(); + this->Input("actual_seq_kvlen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}) + .ValueDepend(Option::OPTIONAL, DependScope::TILING) + .AutoContiguous(); + + 
this->Attr("head_num").AttrType(REQUIRED).Int(1); + this->Attr("scale_value").AttrType(OPTIONAL).Float(1.0f); + this->Attr("q_input_layout").AttrType(OPTIONAL).String(LAYOUT_TH); + this->Attr("kv_input_layout").AttrType(OPTIONAL).String(LAYOUT_NPD); + this->Attr("block_size").AttrType(OPTIONAL).Int(16); + this->Output("attention_out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->SetInferShape(ge::UnpadFaNpdInferShape).SetInferDataType(ge::UnpadFaNpdInferDataType); + this->AICore().SetTiling(optiling::UnpadFaNpdTiling).AddConfig("ascend910b"); + } +}; +OP_ADD(UnpadFaNpd); +} // namespace ops diff --git a/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd.h b/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd.h new file mode 100644 index 0000000000000000000000000000000000000000..3f0b030beaf10ed4468cd9fed0844f525de3acaf --- /dev/null +++ b/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd.h @@ -0,0 +1,46 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_CUSTOM_OPS_ASCENDC_UNPAD_FA_NPD_H +#define MS_CUSTOM_OPS_ASCENDC_UNPAD_FA_NPD_H + +namespace optiling { + +enum class InputLayout { BSH = 0, TH = 1, NPD = 2 }; + +enum class UnpadFaNpdInputIndex { + kInputQIndex = 0, + kInputKIndex = 1, + kInputVIndex = 2, + kInputAttnMaskIndex = 3, + kInputActualQSeqIndex = 4, + kInputActualKVSeqIndex = 5, + kInputHeadNumIndex = 6, + kInputScaleValueIndex = 7, + kInputInputQLayoutIndex = 8, + kInputInputKVLayoutIndex = 9, + kInputBlockSizeIndex = 10 +}; + +enum class UnpadFaNpdOutputIndex { kOutputAttnOutIndex = 0 }; + +#define GRAPH_INPUT_KV_SEQ_NAME "actual_seq_kvlen", "batch_valid_length" +#define GRAPH_INPUT_Q_SEQ_NAME "actual_seq_qlen", "q_seq_lens" +#define LAYOUT_BSH "BSH" +#define LAYOUT_TH "TH" +#define LAYOUT_NPD "NPD" +} // namespace optiling +#endif // MS_CUSTOM_OPS_ASCENDC_UNPAD_FA_NPD_H diff --git a/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd_tiling.h b/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..91bdc31e60de73792750115897dd7a7a3c9d4a8b --- /dev/null +++ b/ops/ascendc/unpad_fa_npd/op_host/unpad_fa_npd_tiling.h @@ -0,0 +1,89 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MS_CUSTOM_OPS_OPS_ASCENDC_UNPAD_FA_NPD_OP_HOST_UNPAD_FA_NPD_TILING_H +#define MS_CUSTOM_OPS_OPS_ASCENDC_UNPAD_FA_NPD_OP_HOST_UNPAD_FA_NPD_TILING_H +#include "register/tilingdata_base.h" + +namespace optiling { +#define KBatchIdx 0 +#define KMaxSeqIdx 1 +#define KInnerBatchSizeIdx 2 +#define KEmbeddingSizeIdx 3 +#define KBvHeadNumIdx 4 +#define KTorIdx 5 +#define KHeadStrideIdx 6 +#define KMaskStrideIdx 7 +#define KIsTriuMaskIdx 8 +#define KTotalQBlkNumIdx 9 +#define KIsClampIdx 10 +#define KClampMinUptrIdx 11 +#define KClampMaxUptrIdx 12 +#define KNoneIdx 13 +#define KTilingHeadSizeIdx 14 +#define KTilingParaSizeIdx 15 +#define KKeyIdx 16 +#define KIsLongSeqIdx 17 +#define KMaxKVSeqIdx 18 +#define KIsAlibiMaskIdx 19 +#define KMaskTypeIdx 20 +#define KAlibiCompresOffsetIdx 21 +#define KAlibiLeftAlignIdx 22 +#define KEmbeddingSizeVIdx 23 +#define KQuantTypeIdx 24 +#define KDataShapeTypeIdx 25 +#define KScaleTypeIdx 26 +#define KWindowSizeIdx 27 +#define KMaxNumBlocksIdx 28 +#define KQMaxSeqLenIdx 29 +#define KPreTokensIdx 30 +#define KNextTokensIdx 31 +#define KRazorLenIdx 32 +#define KTileQIdx 33 +#define KTileKvIdx 34 +#define KTextQLenIdx 35 +#define KTextKvLenIdx 36 +#define KKvLayoutIdx 37 +#define KPageSizeIdx 38 +#define KWsLowIdx 39 +#define KHeadLastIdx 40 + +#define KQSeqLenIdx 0 +#define KKvSeqLenIdx 1 +#define KMMIdx 2 +#define KNNIdx 3 +#define KAddrQSeqOffsetHighIdx 4 +#define KAddrQSeqOffsetLowIdx 5 +#define KAddrKSeqOffsetHighIdx 6 +#define KAddrKSeqOffsetLowIdx 7 +#define KAddrVSeqOffsetHighIdx 8 +#define KAddrVSeqOffsetLowIdx 9 +#define KAddrOSeqOffsetHighIdx 10 +#define KAddrOSeqOffsetLowIdx 11 +#define KSplitIdx 12 +#define KTotalQBlockIdx 13 +#define KStateIdx 14 +#define KParaLastIdx 15 + +#define KMaxBatch 512 + +BEGIN_TILING_DATA_DEF(UnpadFaNpdTilingData) +TILING_DATA_FIELD_DEF_ARR(uint32_t, KHeadLastIdx + KParaLastIdx * KMaxBatch, buf); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(UnpadFaNpd, UnpadFaNpdTilingData) + +} // namespace optiling +#endif // MS_CUSTOM_OPS_OPS_ASCENDC_UNPAD_FA_NPD_OP_HOST_UNPAD_FA_NPD_TILING_H diff --git a/ops/ascendc/unpad_fa_npd/op_kernel/unpad_fa_npd.cpp b/ops/ascendc/unpad_fa_npd/op_kernel/unpad_fa_npd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..17688df197b8292540ed1a21ca041b4ba5486f8d --- /dev/null +++ b/ops/ascendc/unpad_fa_npd/op_kernel/unpad_fa_npd.cpp @@ -0,0 +1,83 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ascendc/basic_api/kernel_operator.h" +#include "lib/matmul_intf.h" +#include "asd/unpad_fa/unpad_flash_attention_mix.cce" +#include "asd/unpad_fa/unpad_flashattention_bf16_mix.cce" +__aicore__ inline int64_t GetFftsBaseAddr() { return get_ffts_base_addr(); } + +extern "C" __global__ __aicore__ void unpad_fa_npd(GM_ADDR query, GM_ADDR key, GM_ADDR value, GM_ADDR attn_mask, + GM_ADDR q_seq, GM_ADDR kv_seq, GM_ADDR attn_out, GM_ADDR workspace, + GM_ADDR tiling) { +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) + GM_ADDR sync = (__gm__ uint8_t *)GetFftsBaseAddr(); + GM_ADDR q_gm = query; + GM_ADDR k_gm = key; + GM_ADDR v_gm = value; + GM_ADDR layerID_gm = nullptr; + GM_ADDR mask_gm = attn_mask; + GM_ADDR alibi_coeff_gm = nullptr; + GM_ADDR deq_qk_gm = nullptr; + GM_ADDR off_qk_gm = nullptr; + GM_ADDR deq_pv_gm = nullptr; + GM_ADDR off_pv_gm = nullptr; + GM_ADDR quant_p_gm = nullptr; + GM_ADDR o_gm = attn_out; + GM_ADDR tiling_para_gm = tiling; + uint32_t ws_size = static_cast(*((__gm__ uint32_t *)tiling_para_gm + 39)); + GM_ADDR s_gm = workspace; + GM_ADDR p_gm = workspace + ws_size; + GM_ADDR o_tmp_gm = workspace + 2 * ws_size; + GM_ADDR upo_tmp_gm = nullptr; + GM_ADDR logN_gm = nullptr; +#endif + if (TILING_KEY_IS(0)) { +#ifdef __DAV_C220_CUBE__ + unpda_fa_npd_half::UnpadAttentionDecoderAic fa_aic_fp16; + fa_aic_fp16.Run(sync, q_gm, k_gm, v_gm, layerID_gm, mask_gm, alibi_coeff_gm, deq_qk_gm, off_qk_gm, deq_pv_gm, + off_pv_gm, quant_p_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); +#elif __DAV_C220_VEC__ + unpda_fa_npd_half::UnpadAttentionDecoderAiv fa_aiv_fp16; + fa_aiv_fp16.Run(sync, q_gm, k_gm, v_gm, layerID_gm, mask_gm, alibi_coeff_gm, deq_qk_gm, off_qk_gm, deq_pv_gm, + off_pv_gm, quant_p_gm, logN_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); +#endif + } else if (TILING_KEY_IS(1)) { +#ifdef __DAV_C220_CUBE__ + unpda_fa_npd_bf16::FlashAttentionEncoderHighPrecision<__bf16> fa_cube(sync, q_gm, k_gm, v_gm, layerID_gm, s_gm, + p_gm, o_tmp_gm, tiling_para_gm); + fa_cube.Run(); +#elif __DAV_C220_VEC__ + unpda_fa_npd_bf16::FlashAttentionEncoderHighPrecisionVec<__bf16> fa_vec( + sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, + deq_pv_gm, off_pv_gm, logN_gm); + fa_vec.Run(); +#endif + } else if (TILING_KEY_IS(3)) { +#ifdef __DAV_C220_CUBE__ + unpda_fa_npd_bf16::FlashAttentionEncoderHighPrecisionCubeOpt<__bf16> fa_cube(sync, q_gm, k_gm, v_gm, layerID_gm, + s_gm, p_gm, o_tmp_gm, tiling_para_gm); + fa_cube.Run(); +#elif __DAV_C220_VEC__ + unpda_fa_npd_bf16::FlashAttentionEncoderHighPrecisionVecOpt + fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, tiling_para_gm, deq_qk_gm, off_qk_gm, + quant_p_gm, deq_pv_gm, off_pv_gm, logN_gm); + fa_vec.Run(); +#endif + } +} diff --git a/ops/ascendc/unpad_fa_npd/unpad_fa_npd.cc b/ops/ascendc/unpad_fa_npd/unpad_fa_npd.cc new file mode 100644 index 0000000000000000000000000000000000000000..80edea958d51e149146f002cae21989a585502e4 --- /dev/null +++ b/ops/ascendc/unpad_fa_npd/unpad_fa_npd.cc @@ -0,0 +1,259 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include +#include +#include +#include +#include +#include +#include +#include +#include "op_host/unpad_fa_npd.h" +// #include "kernel_common/op_host/npd/npd_common.h" +#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h" +#include "ops/framework/utils.h" +#include "mindspore/include/custom_op_api.h" + +using optiling::UnpadFaNpdInputIndex; +using optiling::UnpadFaNpdOutputIndex; + +namespace ms_custom_ops { + +template +static inline std::vector GetSeqLenFromTensor(const mindspore::tensor::TensorPtr &seq_length_tensor) { + if (seq_length_tensor != nullptr) { + auto seq_length_values = static_cast(seq_length_tensor->data_c()); + auto seq_length_values_num = seq_length_tensor->DataSize(); + std::vector seq_len; + seq_len.reserve(seq_length_values_num); + std::transform(seq_length_values, seq_length_values + seq_length_values_num, std::back_inserter(seq_len), + [](T1 val) { return static_cast(val); }); + return seq_len; + } + return {}; +} + +template +static inline std::vector CastVector(const std::vector &src) { + auto elem_num = src.size(); + if (elem_num > 0) { + std::vector dst; + auto src_data = src.data(); + dst.reserve(elem_num); + std::transform(src_data, src_data + elem_num, std::back_inserter(dst), [](T1 val) { return static_cast(val); }); + return dst; + } + return {}; +} + +static void UnpadFaNpdCheckInputsShape(const std::string &op_name, const std::vector &q_shape, + const std::vector &k_shape, const std::vector &v_shape, + const std::vector &act_q_seq_shape, + const std::vector &act_kv_seq_shape) { + if (q_shape.size() != kDim2 && q_shape.size() != kDim3) { + MS_LOG(EXCEPTION) << op_name << ", q should be in TH or BSH layout"; + } + + if (k_shape.size() != v_shape.size()) { + MS_LOG(EXCEPTION) << op_name << ", key dim number should be equal to value dim number, " + << "key.dim=" << k_shape.size() << "value.dim=" << v_shape.size(); + } + + if (k_shape.size() != kDim4 && k_shape.size() != kDim2) { + MS_LOG(EXCEPTION) << op_name << ", key dim number should be 4 or 2, but got " + << "key.dim=" << k_shape.size(); + } + + if (v_shape.size() != kDim4 && v_shape.size() != kDim2) { + MS_LOG(EXCEPTION) << op_name << ", value dim number should be 4 or 2, but got " + << "value.dim=" << k_shape.size(); + } + + if (act_q_seq_shape.size() != kDim1 && act_kv_seq_shape.size() != kDim1) { + MS_LOG(EXCEPTION) << op_name + << ", the dim of inputs should be act_q_seq_shape.dim==act_kv_seq_shape.dim=1, but got " + << "act_q_seq_shape.dim=" << act_q_seq_shape.size() + << "act_kv_seq_shape.dim=" << act_kv_seq_shape.size(); + } +} + +static void UnpadFaNpdCheckInputsType(const std::string &op_name, const TypeId &query_dtype, const TypeId &key_dtype, + const TypeId &value_dtype, const TypeId &act_q_seq_dtype, + const TypeId &act_kv_seq_dtype) { + const std::unordered_set valid_types = {kNumberTypeFloat16, kNumberTypeBFloat16}; + std::unordered_set input_types = {query_dtype, 
key_dtype, value_dtype}; + if (input_types.size() > 1) { + MS_LOG(EXCEPTION) << op_name << ", the dtype of 'query_dtype, key_dtype, value_dtype' should be same, but got '" + << TypeIdToString(query_dtype) << ", " << TypeIdToString(key_dtype) << ", " + << TypeIdToString(value_dtype) << "'"; + } + if (valid_types.find(key_dtype) == valid_types.end()) { + MS_LOG(EXCEPTION) << op_name << ", the dtype of 'query_dtype, key_dtype, value_dtype' should be " + << TypeIdToString(kNumberTypeFloat16) << " or " << TypeIdToString(kNumberTypeBFloat16) + << ", but got '" << TypeIdToString(query_dtype) << ", " << TypeIdToString(key_dtype) << ", " + << TypeIdToString(value_dtype) << ", " << "'"; + } + const std::unordered_set valid_int_types = {kNumberTypeInt32, kNumberTypeInt}; + std::unordered_set input_int_types = {act_q_seq_dtype, act_kv_seq_dtype}; + if (input_int_types.size() > 1) { + MS_LOG(EXCEPTION) << op_name << ", the dtype of 'act_q_seq_dtype, act_kv_seq_dtype' should be same, but got '" + << TypeIdToString(act_q_seq_dtype) << ", " << TypeIdToString(act_kv_seq_dtype) << "'"; + } + if (valid_int_types.find(act_q_seq_dtype) == valid_int_types.end()) { + MS_LOG(EXCEPTION) << op_name << ", the dtype of 'act_q_seq_dtype, act_kv_seq_dtype' should be " + << TypeIdToString(kNumberTypeInt32) << " or " << TypeIdToString(kNumberTypeInt) << ", but got '" + << TypeIdToString(act_q_seq_dtype) << ", " << TypeIdToString(act_kv_seq_dtype); + } +} + +static mindspore::ShapeVector UnpadFaNpdDoInferShape(mindspore::ShapeVector shape) { + ShapeVector out_shape = shape; + return out_shape; +} + +class OPS_API UnpadFaNpdOpFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + auto q_shape = input_infos[static_cast(UnpadFaNpdInputIndex::kInputQIndex)]->GetShape(); + auto q_layout = input_infos[static_cast(UnpadFaNpdInputIndex::kInputInputQLayoutIndex)] + ->GetScalarValueWithCheck(); + auto kv_layout = input_infos[static_cast(UnpadFaNpdInputIndex::kInputInputKVLayoutIndex)] + ->GetScalarValueWithCheck(); + auto op_name = primitive->name(); + MS_CHECK_VALUE( + q_layout == LAYOUT_BSH || q_layout == LAYOUT_TH, + CheckAndConvertUtils::FormatCommMsg(op_name, " q_layout should be 'BSH' or 'TH', but got ", q_layout)); + + MS_CHECK_VALUE( + kv_layout == LAYOUT_BSH || kv_layout == LAYOUT_TH || kv_layout == LAYOUT_NPD, + CheckAndConvertUtils::FormatCommMsg(op_name, " kv_layout should be 'BSH' or 'TH' or 'NPD', but got ", kv_layout)); + + ShapeVector out_shape = UnpadFaNpdDoInferShape(q_shape); + return {out_shape}; + } + std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + auto out_type = input_infos[static_cast(UnpadFaNpdInputIndex::kInputQIndex)]->GetType(); + return {out_type}; + } + std::set GetValueDependArgIndices() const override { + return {static_cast(UnpadFaNpdInputIndex::kInputActualQSeqIndex), + static_cast(UnpadFaNpdInputIndex::kInputActualKVSeqIndex)}; + } + + bool GeneralInferRegistered() const override { return true; } +}; + +class UnpadFaNpdAscend : public AclnnCustomKernelMod { + public: + UnpadFaNpdAscend() : AclnnCustomKernelMod(std::move("aclnnUnpadFaNpd")) {} + ~UnpadFaNpdAscend() = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + MS_EXCEPTION_IF_NULL(stream_ptr); + RunOp(stream_ptr, workspace, inputs[static_cast(UnpadFaNpdInputIndex::kInputQIndex)], + 
inputs[static_cast(UnpadFaNpdInputIndex::kInputKIndex)], + inputs[static_cast(UnpadFaNpdInputIndex::kInputVIndex)], + inputs[static_cast(UnpadFaNpdInputIndex::kInputAttnMaskIndex)], + inputs[static_cast(UnpadFaNpdInputIndex::kInputActualQSeqIndex)], + inputs[static_cast(UnpadFaNpdInputIndex::kInputActualKVSeqIndex)], head_num_, scale_, q_layout_, + kv_layout_, block_size_, outputs[static_cast(UnpadFaNpdOutputIndex::kOutputAttnOutIndex)]); + return true; + } + + void GetWorkSpaceInfo(const std::vector &inputs, + const std::vector &outputs) override { + head_num_ = inputs[static_cast(UnpadFaNpdInputIndex::kInputHeadNumIndex)]->GetValueWithCheck(); + scale_ = static_cast( + inputs[static_cast(UnpadFaNpdInputIndex::kInputScaleValueIndex)]->GetValueWithCheck()); + q_layout_ = + inputs[static_cast(UnpadFaNpdInputIndex::kInputInputQLayoutIndex)]->GetValueWithCheck(); + kv_layout_ = + inputs[static_cast(UnpadFaNpdInputIndex::kInputInputKVLayoutIndex)]->GetValueWithCheck(); + block_size_ = inputs[static_cast(UnpadFaNpdInputIndex::kInputBlockSizeIndex)]->GetValueWithCheck(); + + auto q_cpu = CastVector(inputs[static_cast(UnpadFaNpdInputIndex::kInputActualQSeqIndex)] + ->GetValueWithCheck>()); + auto kv_cpu = CastVector(inputs[static_cast(UnpadFaNpdInputIndex::kInputActualKVSeqIndex)] + ->GetValueWithCheck>()); + + GetWorkspaceForResize(inputs[static_cast(UnpadFaNpdInputIndex::kInputQIndex)], + inputs[static_cast(UnpadFaNpdInputIndex::kInputKIndex)], + inputs[static_cast(UnpadFaNpdInputIndex::kInputVIndex)], + inputs[static_cast(UnpadFaNpdInputIndex::kInputAttnMaskIndex)], q_cpu, kv_cpu, + head_num_, scale_, q_layout_, kv_layout_, block_size_, + outputs[static_cast(UnpadFaNpdOutputIndex::kOutputAttnOutIndex)]); + return; + } + + private: + DEFINE_GET_WORKSPACE_FOR_RESIZE(); + std::string q_layout_{""}; + std::string kv_layout_{""}; + int64_t head_num_{0}; + int64_t block_size_{0}; + double scale_{1.0f}; +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(unpad_fa_npd, ms_custom_ops::UnpadFaNpdOpFuncImpl, ms_custom_ops::UnpadFaNpdAscend); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +namespace ms_custom_ops { +constexpr size_t kUnpadFaNpdOutputNum = 1; + +ms::Tensor unpad_fa_npd_custom(const ms::Tensor &query, const ms::Tensor &key, const ms::Tensor &value, + const std::optional &attn_mask, const ms::Tensor &q_seq, + const ms::Tensor &kv_seq, const int32_t &head_num, const float &scale, + const std::string &q_layout, const std::string &kv_layout, const int32_t &block_size) { + std::string op_name = "UnpadFaNpd"; + auto q_cpu = GetSeqLenFromTensor(q_seq.tensor()->cpu()); + if (q_cpu.empty()) { + MS_LOG(EXCEPTION) << "Get q_cpu seq len failed "; + } + auto kv_cpu = GetSeqLenFromTensor(kv_seq.tensor()->cpu()); + if (kv_cpu.empty()) { + MS_LOG(EXCEPTION) << "Get kv_cpu seq len failed "; + } + auto runner = std::make_shared(op_name); + UnpadFaNpdCheckInputsShape(op_name, query.shape(), key.shape(), value.shape(), q_seq.shape(), kv_seq.shape()); + UnpadFaNpdCheckInputsType(op_name, query.data_type(), key.data_type(), value.data_type(), q_seq.data_type(), + kv_seq.data_type()); + ShapeVector out_shape = UnpadFaNpdDoInferShape(query.shape()); + auto attn_out = ms::Tensor(query.data_type(), out_shape); + + runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnUnpadFaNpd, query, key, value, attn_mask, q_cpu, kv_cpu, head_num, + static_cast(scale), q_layout, 
kv_layout, block_size, attn_out));
+ runner->Run({query, key, value, GetTensorOrEmpty(attn_mask), q_seq, kv_seq}, {attn_out});
+ return attn_out;
+}
+} // namespace ms_custom_ops
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+ m.def("unpad_fa_npd", &ms_custom_ops::unpad_fa_npd_custom, pybind11::arg("query"), pybind11::arg("key"),
+ pybind11::arg("value"), pybind11::arg("attn_mask") = nullptr, pybind11::arg("actual_seq_qlen"),
+ pybind11::arg("actual_seq_kvlen"), pybind11::arg("head_num") = 32, pybind11::arg("scale") = 1.0f,
+ pybind11::arg("q_layout") = LAYOUT_TH, pybind11::arg("kv_layout") = LAYOUT_NPD,
+ pybind11::arg("block_size") = 32);
+}
diff --git a/ops/ascendc/unpad_fa_npd/unpad_fa_npd.md b/ops/ascendc/unpad_fa_npd/unpad_fa_npd.md new file mode 100644 index 0000000000000000000000000000000000000000..58f3496b4d59d62e07ae509923c8459564c5ffdb --- /dev/null +++ b/ops/ascendc/unpad_fa_npd/unpad_fa_npd.md @@ -0,0 +1,106 @@
+# FlashAttention Encoder 算子
+
+## 描述
+
+FlashAttention is a high-performance implementation of self-attention that mainly reduces memory usage and improves throughput through blocking/recomputation, mask compression, and better memory access via the NPD shape, where N is the head number, P the page size, and D the head size.
+
+## 接口与输入输出
+
+### 名称
+
+- 算子名:`unpad_fa_npd`
+
+### 输入参数
+
+| Name | DType | Shape | Optional | Format | Description |
+|------------------|---------------------------|---------------------------|----------|--------|-------------|
+| query | Tensor[float16/bfloat16] | TH | No | ND | Query tensor |
+| key | Tensor[float16/bfloat16] | TH/NPD | No | ND | Key tensor |
+| value | Tensor[float16/bfloat16] | TH/NPD | No | ND | Value tensor |
+| attn_mask | Tensor[float16/bfloat16] | (128,128) | Yes | ND | Upper triangular mask or None |
+| actual_seq_qlen | Tensor[int32] | (batch,) | No | ND | Number of query tokens in each batch element |
+| actual_seq_kvlen | Tensor[int32] | (batch,) | No | ND | Number of key/value tokens in each batch element |
+| head_num | int | - | No | - | 注意力头数(H),默认 32(需显式设置 >0) |
+| scale_value | float | - | No | - | QK 缩放系数 `qk_scale`(通常为 1/sqrt(head_dim)),默认 1.0 |
+| q_input_layout | string | - | No | - | Query input layout: only TH is supported,默认 TH |
+| kv_input_layout | string | - | No | - | Key/Value input layout: TH or NPD,默认 NPD |
+| block_size | int | - | No | - | Page table block size,默认 32 |
+
+#### 参数补充说明
+
+- attn_mask
+  - The mask should be an upper triangular matrix containing values of 0 and 1 for bfloat16, and values of 0 and -10000 for float16; see the construction sketch below.
+
+- q_seq_len / kv_seq_len
+  - Both must be provided and must reside in CPU memory.
+
+### 输出参数
+
+| Name | DType | Shape | Description |
+|---------------|--------------------------|-----------------|-------------|
+| attention_out | Tensor[float16/bfloat16] | 与 query 对齐 | 注意力输出 |
+
+### 杂项(Miscellaneous)
+
+- Typically used together with `reshape_and_cache_npd` so that all key/value permutations are handled efficiently inside an operator.
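+
+### attn_mask 构造示例
+
+下面是一个最小示意(a minimal sketch),based on the value convention above and on the mask used in the ST tests; the helper name `make_attn_mask` is illustrative only and is not part of the operator API:
+
+```python
+import numpy as np
+
+def make_attn_mask(use_bfloat16: bool) -> np.ndarray:
+    """Strict upper-triangular causal mask following the value convention above."""
+    mask = np.triu(np.ones((128, 128), dtype=np.float32), k=1)
+    if use_bfloat16:
+        # bfloat16 path: keep 0 / 1, cast to bfloat16 when wrapping into a mindspore Tensor.
+        return mask
+    # float16 path: 0 on/below the diagonal, -10000 above it.
+    return (mask * -10000.0).astype(np.float16)
+```
+
+The `create_attn_mask` helper in `tests/st/test_custom_unpad_fa_npd.py` follows the same convention.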
+ +### Python 使用示例 + +```python +import mindspore as ms +from mindspore import Tensor, context +import ms_custom_ops +import numpy as np + +# 创建输入张量 +np.random.seed(0) +context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE) + +head_num = 8 +kv_head_num = 4 +head_dim = 128 +scale_value = float(1.0 / np.sqrt(head_dim)) +block_size = 32 + +q_seq = np.array([block_size, block_size*2], dtype=np.int32) +kv_seq = np.array([block_size, block_size*2], dtype=np.int32) + +q_tokens = int(q_seq.sum()) +kv_tokens = int(kv_seq.sum()) + +query = Tensor(np.random.uniform(-1, 1, size=(q_tokens, head_num * head_dim)).astype(np.float16)) +key = np.random.uniform(-1, 1, size=(kv_tokens, kv_head_num * head_dim)) +value = np.random.uniform(-1, 1, size=(kv_tokens, kv_head_num * head_dim)) + +q_seq_len = Tensor(q_seq).move_to("CPU") +kv_seq_len = Tensor(kv_seq).move_to("CPU") + +# Permute to NPD format +npd_key = key.reshape(kv_tokens // block_size, block_size, kv_head_num, head_dim) +npd_value = value.reshape(kv_tokens // block_size, block_size, kv_head_num, head_dim) +npd_key = npd_key.transpose(0,2,1,3) +npd_value = npd_value.transpose(0,2,1,3) +npd_key = npd_key.reshape(kv_tokens, kv_head_num * head_dim) +npd_value = npd_value.reshape(kv_tokens, kv_head_num * head_dim) + +attn_mask = Tensor(np.ones(shape=(128, 128)).astype(np.float32), ms.float16) +npd_key = Tensor(npd_key, ms.float16) +npd_value = Tensor(npd_value, ms.float16) + +# 调用算子 +attention_out = ms_custom_ops.unpad_fa_npd( + query=query, + key=npd_key, + value=npd_value, + attn_mask=attn_mask, + actual_seq_qlen=q_seq_len, + actual_seq_kvlen=kv_seq_len, + head_num=head_num, + scale_value=scale_value, + q_input_layout="TH", + kv_input_layout="NPD", + block_size=block_size +) +print(attention_out.shape) # 期望: (q_tokens, head_num * head_dim) +``` diff --git a/ops/ascendc/unpad_fa_npd/unpad_fa_npd_op.yaml b/ops/ascendc/unpad_fa_npd/unpad_fa_npd_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5495aab2af1eeef2147d5d07bd4fc4937d01a9ab --- /dev/null +++ b/ops/ascendc/unpad_fa_npd/unpad_fa_npd_op.yaml @@ -0,0 +1,36 @@ +#operator unpad_fa_npd +unpad_fa_npd: + args: + query: + dtype: tensor + key: + dtype: tensor + value: + dtype: tensor + attn_mask: + dtype: tensor + default: None + actual_seq_qlen: + dtype: tensor + default: None + actual_seq_kvlen: + dtype: tensor + default: None + head_num: + dtype: int + default: 32 + scale_value: + dtype: float + default: 1.0 + q_input_layout: + dtype: str + default: "'TH'" + kv_input_layout: + dtype: str + default: "'NPD'" + block_size: + dtype: int + default: 32 + returns: + attention_out: + dtype: tensor diff --git a/scripts/op_compiler.py b/scripts/op_compiler.py index dbb8392e34fabedc37421b0e072cb0edd1ba9bcc..5b7c3b429bc1069f3b3525d76eefa7b95ffd0206 100644 --- a/scripts/op_compiler.py +++ b/scripts/op_compiler.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) OP_HOST = "op_host" OP_KERNEL = "op_kernel" -code_suffix = {"cpp", "h"} +code_suffix = {"cpp", "cc", "h"} SOC_VERSION_MAP = { @@ -43,7 +43,7 @@ SOC_VERSION_MAP = { "ascend910b3": "ascend910b", "ascend910b4": "ascend910b", "ascend910b4-1": "ascend910b", - "ascend910c": "ascend910_93", + "ascend910_93": "ascend910_93", "ascend910_9391": "ascend910_93", "ascend910_9392": "ascend910_93", "ascend910_9381": "ascend910_93", @@ -70,6 +70,7 @@ SOC_VERSION_MAP = { def get_config(): """get config from user""" parser = argparse.ArgumentParser() + parser.add_argument("--common_dirs", type=str, required=True) 
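+    # --common_dirs points at the shared AscendC sources (ops/ascendc/kernel_common); their op_host/op_kernel
+    # contents are copied into every generated custom-op project by copy_code_file() below.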
parser.add_argument("--op_dirs", type=str, required=True) parser.add_argument("--build_type", type=str, default="Release") parser.add_argument("--build_path", type=str, default="") @@ -100,13 +101,13 @@ class CustomOPCompiler(): for op_dir in self.op_dirs: if not os.path.isdir(op_dir): raise ValueError( - f"Config error! op directpry [{op_dir}] is not exist, " + f"Config error! op directory [{op_dir}] is not exist, " f"please check your set --op_dirs") if self.args.soc_version != "": soc_version_list = re.split(r"[;,]", self.args.soc_version) for soc_version in soc_version_list: - if soc_version.lower() not in SOC_VERSION_MAP.keys(): + if soc_version.lower() not in SOC_VERSION_MAP: raise ValueError( f"Config error! Unsupported soc version(s): {soc_version}! " f"Please check your set --soc_version and use ';' or ',' to separate multiple soc_versions. " @@ -115,7 +116,7 @@ class CustomOPCompiler(): if self.args.ascend_cann_package_path != "": if not os.path.isdir(self.args.ascend_cann_package_path): raise ValueError( - f"Config error! ascend cann package path [{self.args.ascend_cann_package_path}] is not valid path, " + f"Config error! ascend CANN package path [{self.args.ascend_cann_package_path}] is not valid path, " f"please check your set --ascend_cann_package_path") if self.args.install or self.args.install_path != "": @@ -129,31 +130,64 @@ class CustomOPCompiler(): os.makedirs(self.args.install_path, exist_ok=True) def exec_shell_command(self, command, stdout=None): + """run exec shell""" try: - result = subprocess.run(command, stdout=stdout, stderr=subprocess.STDOUT, shell=False, text=True, check=True) + result = subprocess.run(command, stdout=stdout, + stderr=subprocess.STDOUT, + shell=False, + text=True, + check=True) except FileNotFoundError as e: - logger.error(f"Command not found: {e}") - raise RuntimeError(f"Command not found: {e}") + logger.error("Command not found: %s", e) + raise RuntimeError(f'Command not found: {e}') from e except subprocess.CalledProcessError as e: - logger.error(f"Run {command} Command failed with return code {e.returncode}: {e.output}") - raise RuntimeError(f"Run {command} Command failed with return code {e.returncode}: {e.output}") + logger.error("Run %s Command failed with return code %s: %s", command, e.returncode, e.output) + raise RuntimeError( + f"Run {command} Command failed with return code {e.returncode}: {e.output}" + ) from e return result def init_config(self): """initialize config""" if self.args.ascend_cann_package_path == "": - self.args.ascend_cann_package_path = os.environ.get('ASCEND_HOME_PATH', "/usr/local/Ascend/ascend-toolkit/latest") + self.args.ascend_cann_package_path = os.environ.get('ASCEND_HOME_PATH', + "/usr/local/Ascend/ascend-toolkit/latest") if self.args.soc_version == "": self.args.soc_version = "ascend910b1,ascend310p1" def copy_code_file(self): """copy code file to custom project""" + + def _copy_dir(source_dir, dest_dir): + for item in os.listdir(source_dir): + source_item_path = os.path.join(source_dir, item) + dest_item_path = os.path.join(dest_dir, item) + shutil.copytree(source_item_path, dest_item_path) + + common_host = os.path.join(self.args.common_dirs, OP_HOST) + common_kernel = os.path.join(self.args.common_dirs, OP_KERNEL) + + if os.path.exists(common_host): + logger.info("Copy common files for host: %s", common_host) + target_path = os.path.join(self.custom_project, OP_HOST) + _copy_dir(common_host, target_path) + + if os.path.exists(common_kernel): + logger.info("Copy common files for kernel: %s", 
common_kernel) + target_path = os.path.join(self.custom_project, OP_KERNEL) + _copy_dir(common_kernel, target_path) + for op_dir in self.op_dirs: op_host_dir = os.path.join(op_dir, OP_HOST) op_kernel_dir = os.path.join(op_dir, OP_KERNEL) if not os.path.exists(op_host_dir) or not os.path.exists(op_kernel_dir): - logger.warning(f"The {op_dir} dose not contain {op_host_dir} or {op_kernel_dir}, skipped!") + logger.warning( + "The %s dose not contain %s or %s, skipped!", + op_dir, + op_host_dir, + op_kernel_dir, + ) continue for item in os.listdir(op_host_dir): @@ -177,13 +211,14 @@ class CustomOPCompiler(): os.chmod(os.path.join(root, f), 0o700) def trans_soc_version(self, soc_version_args): + """Convert soc_version string into the ai_core compute unit format""" soc_version_list = re.split(r"[;,]", soc_version_args) if len(soc_version_list) == 1: version_map = {"ascend910": "ascend910a", "ascend910b": "ascend910b1", "ascend310p": "ascend310p1", "ascned310b": "ascend310b1", - "ascend910c": "ascend910_9391"} + "ascend910_93": "ascend910_9391"} soc = soc_version_list[0].lower() return f"ai_core-{version_map.get(soc, soc)}" @@ -206,12 +241,13 @@ class CustomOPCompiler(): json.dump(json_data, f, indent=4) os.chmod(custom_json, stat.S_IRUSR | stat.S_IWUSR) - gen_command = ["msopgen", "gen", "-i", custom_json, "-c", compute_unit, "-lan", "cpp", "-out", self.custom_project] + gen_command = ["msopgen", "gen", "-i", custom_json, "-c", + compute_unit, "-lan", "cpp", "-out", self.custom_project] self.exec_shell_command(gen_command) if os.getenv("GCC_TOOLCHAIN"): gcc_path = os.getenv("GCC_TOOLCHAIN") - bisheng_gcc = ['sed', '-i', + bisheng_gcc = ['sed', '-i', f'/options.append("-I" + tikcpp_path)/i\\ options.append("--gcc-toolchain={gcc_path}")', f'{self.custom_project}/cmake/util/ascendc_impl_build.py'] self.exec_shell_command(bisheng_gcc) @@ -260,7 +296,7 @@ class CustomOPCompiler(): def install_custom_op(self): """install custom run""" if self.args.install or self.args.install_path != "": - logger.info("Install custom opp run in {}".format(self.args.install_path)) + logger.info("Install custom opp run in %s", self.args.install_path) os.environ['ASCEND_CUSTOM_OPP_PATH'] = self.args.install_path run_path = [] build_out_path = os.path.join(self.custom_project, "build_out") @@ -268,13 +304,15 @@ class CustomOPCompiler(): if item.split('.')[-1] == "run": run_path.append(os.path.join(build_out_path, item)) if not run_path: - raise RuntimeError("There is no custom run in {}".format(build_out_path)) + raise RuntimeError(f"There is no custom run in {build_out_path}") self.exec_shell_command(['bash', run_path[0]]) logger.info("Install custom run opp successfully!") logger.info( - "Please set [source ASCEND_CUSTOM_OPP_PATH={}/vendors/{}:$ASCEND_CUSTOM_OPP_PATH] to " - "make the custom operator effective in the current path.".format( - self.args.install_path, self.args.vendor_name)) + "Please set [source ASCEND_CUSTOM_OPP_PATH=%s/vendors/%s:$ASCEND_CUSTOM_OPP_PATH] to " + "make the custom operator effective in the current path.", + self.args.install_path, + self.args.vendor_name, + ) def clear_compile_project(self): """clear log and build out""" diff --git a/tests/st/test_custom_reshape_and_cache_npd.py b/tests/st/test_custom_reshape_and_cache_npd.py index b216af77fbdf8ff5823c5207252227dd15baad0f..fb181c2082cdf948bdbbac3061048a97bd0e1aa1 100644 --- a/tests/st/test_custom_reshape_and_cache_npd.py +++ b/tests/st/test_custom_reshape_and_cache_npd.py @@ -12,7 +12,7 @@ # See the License for the specific language 
governing permissions and # limitations under the License. # ============================================================================ -""" tests_custom_pyboost_ascend """ +"""tests_custom_pyboost_ascend""" # Standard library imports from enum import Enum @@ -32,18 +32,22 @@ from mindspore.common.np_dtype import bfloat16 # Local imports import ms_custom_ops + def jit_for_graph_mode(fn): """ A decorator that conditionally applies jit to a function at runtime based on the context mode. """ jitted_fn = jit(fn) + @wraps(fn) def wrapper(*args, **kwargs): if context.get_context("mode") == context.GRAPH_MODE: return jitted_fn(*args, **kwargs) return fn(*args, **kwargs) + return wrapper + # Global constants NUM_SLOTS = 20 SLOT_SIZE = 16 @@ -53,14 +57,17 @@ NUM_HEADS = 32 K_HEAD_DIM = 128 V_HEAD_DIM = 128 + class CacheFormat(Enum): """Cache format enumeration""" + ND = "nd" NPD = "npd" class DataType(Enum): """Data type enumeration""" + FLOAT16 = np.float16 BFLOAT16 = bfloat16 @@ -68,21 +75,50 @@ class DataType(Enum): class ReshapeAndCacheNpdAll(nn.Cell): """Reshape and cache operation for ND/NPD format with all parameters""" + def __init__(self, cache_layout="NPD", kv_layout="NPD"): + super().__init__() + self.cache_layout = cache_layout + self.kv_layout = kv_layout + @jit_for_graph_mode - def construct(self, key, value, key_cache, value_cache, slot_map, q_seq, kv_seq, block_tbl, cache_mode=1): - k_out, value_out = ms_custom_ops.reshape_and_cache_npd( - key, value, key_cache, value_cache, slot_map, q_seq, kv_seq, block_tbl, cache_mode) + def construct(self, + key=None, + value=None, + key_cache=None, + value_cache=None, + slot_map=None, + actual_seq_qlen=None, + actual_seq_kvlen=None, + block_tbl=None): + """ReshapeAndCacheNpdAll construct""" + k_out, value_out = ms_custom_ops.reshape_and_cache_npd( + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + slot_mapping=slot_map, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, + block_tbl=block_tbl, + kv_cache_layout=self.cache_layout, + key_value_layout=self.kv_layout + ) return k_out, value_out - class MindSporeInputFactory: """Factory for creating MindSpore inputs""" @staticmethod - def create_inputs(np_k: np.ndarray, np_v: np.ndarray, - np_k_cache: np.ndarray, np_v_cache: np.ndarray, - np_slot_map: np.ndarray, np_q_seq: np.ndarray, np_kv_seq: np.ndarray, - np_block_tbl: np.ndarray) -> Tuple[Tensor, ...]: + def create_inputs( + np_k: np.ndarray, + np_v: np.ndarray, + np_k_cache: np.ndarray, + np_v_cache: np.ndarray, + np_slot_map: np.ndarray, + np_q_seq: np.ndarray, + np_kv_seq: np.ndarray, + np_block_tbl: np.ndarray, + ) -> Tuple[Tensor, ...]: """Create MindSpore inputs""" ms_key = Tensor(np_k) ms_value = Tensor(np_v) @@ -97,16 +133,23 @@ class MindSporeInputFactory: def create_ms_inputs(np_k, np_v, np_k_cache, np_v_cache, np_slot_map, q_seq, kv_seq, block_tbl): """Legacy function for backward compatibility""" - return MindSporeInputFactory.create_inputs(np_k, np_v, np_k_cache, np_v_cache, np_slot_map, q_seq, kv_seq, - block_tbl) + return MindSporeInputFactory.create_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, q_seq, kv_seq, block_tbl + ) class TestResultVerifier: """Verify test results""" @staticmethod - def verify_results(ms_cache: Tensor, np_cache: np.ndarray, - dtype: np.dtype, truncate: bool = False, rtol: float = 0.001, atol: float = 0.001) -> None: + def verify_results( + ms_cache: Tensor, + np_cache: np.ndarray, + dtype: np.dtype, + truncate: bool = False, + rtol: float = 
0.001, + atol: float = 0.001, + ) -> None: """Verify results with appropriate dtype handling""" if dtype == bfloat16: ms_cache_np = ms_cache.float().asnumpy() @@ -115,11 +158,11 @@ class TestResultVerifier: ms_cache_np = ms_cache.asnumpy() if truncate is False: - ms_cache_np = ms_cache_np.flatten() - np_cache = np_cache.flatten() + ms_cache_np = ms_cache_np.flatten() + np_cache = np_cache.flatten() else: - ms_cache_np = ms_cache_np[:np_cache.shape[0]].flatten() - np_cache = np_cache.flatten() + ms_cache_np = ms_cache_np[: np_cache.shape[0]].flatten() + np_cache = np_cache.flatten() assert np.allclose(ms_cache_np, np_cache, rtol=rtol, atol=atol) @@ -127,8 +170,12 @@ class TestResultVerifier: class TestConfig: """Test configuration""" - def __init__(self, device_target: str = "Ascend", mode: context = context.GRAPH_MODE, - jit_config: Optional[Dict[str, Any]] = None): + def __init__( + self, + device_target: str = "Ascend", + mode: context = context.GRAPH_MODE, + jit_config: Optional[Dict[str, Any]] = None, + ): self.device_target = device_target self.mode = mode self.jit_config = jit_config or {} @@ -163,30 +210,28 @@ class DimensionTestHelper: # =============================== # RESHAPE AND CACHE NPD TEST ARCHITECTURE # =============================== -""" -Test Structure Overview: - -1. ND FORMAT TESTS (cache_mode=0): - - Direct ND format testing without format conversion - - Data flow: Input(ND) → ReshapeAndCacheNpd → Output(ND) → Verify - - Tests: test_reshape_and_cache_nd_* - -2. NZ FORMAT TESTS (cache_mode=1): - - Tests NPD format with format conversion using trans_data - - Data flow: Input(ND) → TransData(ND→NZ) → ReshapeAndCacheNpd → TransData(NZ→ND) → Verify - - Tests: test_reshape_and_cache_npd_* - -3. KEY COMPONENTS: - - create_nd_inputs(): Generate ND format test data - - create_npd_inputs(): Generate NZ-compatible test data (different layout) - - get_nd_cached_slots(): Extract verification data from ND format cache - - get_npd_cached_slots(): Extract verification data from NZ format cache (legacy) - - npd_inference(): Generate golden reference results - -4. VERIFICATION STRATEGY: - - ND tests: Both actual and golden use ND format → direct comparison - - NPD tests: Convert actual results back to ND format → compare with ND golden -""" +# Test Structure Overview: +# +# 1. ND FORMAT TESTS (cache_mode=0): +# - Direct ND format testing without format conversion +# - Data flow: Input(ND) → ReshapeAndCacheNpd → Output(ND) → Verify +# - Tests: test_reshape_and_cache_nd_* +# +# 2. NZ FORMAT TESTS (cache_mode=1): +# - Tests NPD format with format conversion using trans_data +# - Data flow: Input(ND) → TransData(ND→NZ) → ReshapeAndCacheNpd → TransData(NZ→ND) → Verify +# - Tests: test_reshape_and_cache_npd_* +# +# 3. KEY COMPONENTS: +# - create_inputs(): Generate ND format test data +# - get_nd_cached_slots(): Extract verification data from ND format cache +# - get_npd_cached_slots(): Extract verification data from NZ format cache (legacy) +# - npd_inference(): Generate golden reference results +# +# 4. 
VERIFICATION STRATEGY: +# - ND tests: Both actual and golden use ND format → direct comparison +# - NPD tests: Convert actual results back to ND format → compare with ND golden + # =============================== # NPD FORMAT TESTS @@ -207,12 +252,12 @@ class TestDataGenerator: return np.random.choice(np.arange(num_tokens), num_tokens, replace=False).astype(np.int32) @staticmethod - def create_q_seq(batch: int, seq:int) -> np.ndarray: + def create_q_seq(batch: int, seq: int) -> np.ndarray: """Create q seq lens""" return np.full(batch, seq).astype(np.int32) @staticmethod - def create_kv_seq(batch: int, seq:int) -> np.ndarray: + def create_kv_seq(batch: int, seq: int) -> np.ndarray: """Create kv seq lens""" return np.full(batch, seq).astype(np.int32) @@ -221,7 +266,8 @@ class TestDataGenerator: return np.arange(page_num, dtype=np.int32).reshape(batch, (page_num + batch - 1) // batch) @staticmethod - def get_update_shapes(kv_dim: int, k_head_dim=None, v_head_dim=None, batch=None, head_num = None + def get_update_shapes( + kv_dim: int, k_head_dim=None, v_head_dim=None, batch=None, head_num=None ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int, int]: """Get update shapes for key and value, and number of tokens based on dimension""" # Use provided dimensions or fall back to global constants @@ -244,19 +290,22 @@ class TestDataGenerator: return key_update_shape, value_update_shape, num_tokens, page_num @staticmethod - def get_update_shape(kv_dim: int, is_key: bool = True, k_head_dim=None, v_head_dim=None, batch=None, head_num=None + def get_update_shape( + kv_dim: int, is_key: bool = True, k_head_dim=None, v_head_dim=None, batch=None ) -> Tuple[Tuple[int, ...], int, int]: """Legacy method for backward compatibility""" - key_shape, value_shape, num_tokens, page_num = TestDataGenerator.get_update_shapes(kv_dim, k_head_dim, - v_head_dim, batch) + key_shape, value_shape, num_tokens, page_num = TestDataGenerator.get_update_shapes( + kv_dim, k_head_dim, v_head_dim, batch + ) return (key_shape if is_key else value_shape), num_tokens, page_num -class NPDDataGenerator(TestDataGenerator): +class DataGenerator(TestDataGenerator): """Data generator for NPD format""" @staticmethod - def create_npd_inputs(dtype: np.dtype, kv_dim: int, k_head_dim=None, v_head_dim=None, batch=None, head_num=None + def create_inputs( + dtype: np.dtype, kv_dim: int, k_head_dim=None, v_head_dim=None, batch=None, head_num=None, npd=False ) -> Tuple[np.ndarray, ...]: """Create NPD format inputs""" # Use provided dimensions or fall back to global constants @@ -265,11 +314,14 @@ class NPDDataGenerator(TestDataGenerator): actual_batch = batch if batch is not None else BATCH_SIZE actual_head_num = head_num if head_num is not None else NUM_HEADS - key_new_shape, value_new_shape, num_tokens, page_num = TestDataGenerator.get_update_shapes(kv_dim, k_head_dim, - v_head_dim, batch, - head_num) - key_cache_shape = (page_num, actual_head_num, SLOT_SIZE, actual_k_head_dim) - value_cache_shape = (page_num, actual_head_num, SLOT_SIZE, actual_v_head_dim) + key_new_shape, value_new_shape, num_tokens, page_num = TestDataGenerator.get_update_shapes( + kv_dim, k_head_dim, v_head_dim, batch, head_num + ) + key_cache_shape = (page_num, SLOT_SIZE, actual_head_num, actual_k_head_dim) + value_cache_shape = (page_num, SLOT_SIZE, actual_head_num, actual_v_head_dim) + if npd: + key_cache_shape = (page_num, actual_head_num, SLOT_SIZE, actual_k_head_dim) + value_cache_shape = (page_num, actual_head_num, SLOT_SIZE, actual_v_head_dim) key_update = 
TestDataGenerator.create_random_data(key_new_shape, dtype) value_update = TestDataGenerator.create_random_data(value_new_shape, dtype) @@ -283,19 +335,21 @@ class NPDDataGenerator(TestDataGenerator): return key_update, value_update, key_cache, value_cache, slot_map, q_seq, kv_seq, block_tbl -def create_npd_inputs(dtype=np.float16, kv_dim=3, k_head_dim=None, v_head_dim=None, batch=None, head_num=None): +def create_inputs(dtype=np.float16, kv_dim=3, k_head_dim=None, v_head_dim=None, batch=None, head_num=None, npd=False): """Legacy function for backward compatibility""" - return NPDDataGenerator.create_npd_inputs(dtype, kv_dim, k_head_dim, v_head_dim, batch, head_num) + return DataGenerator.create_inputs(dtype, kv_dim, k_head_dim, v_head_dim, batch, head_num, npd) class InferenceEngine: """Inference engine for different formats""" + @staticmethod def up_div(x, y): return (x + y - 1) // y @staticmethod - def npd_permute(np_arr: np.ndarray, heads:int, embed:int, batch:int, ps:int): + def npd_permute(np_arr: np.ndarray, heads: int, embed: int, batch: int, ps: int): + """npd permute function""" if np_arr.ndim == 2: t_total, hd = np_arr.shape else: @@ -318,15 +372,19 @@ class InferenceEngine: # Step 3: reshape to [batch, blocks, ps, heads, embed] → permute to [batch*blocks, heads, ps, embed] blocks = InferenceEngine.up_div(t, ps) - np_arr = np_arr.reshape(batch*blocks, ps, heads, embed).transpose(0, 2, 1, 3) + np_arr = np_arr.reshape(batch * blocks, ps, heads, embed).transpose(0, 2, 1, 3) np_arr = np_arr.reshape(-1, heads, embed) - return np_arr.reshape(-1, heads,embed) + return np_arr.reshape(-1, heads, embed) @staticmethod - def npd_inference(key: np.ndarray, value: np.ndarray, key_cache: np.ndarray, value_cache: np.ndarray, - slot_map: np.ndarray, block_tbl:np.array - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def npd_cache_inference( + key: np.ndarray, + value: np.ndarray, + key_cache: np.ndarray, + value_cache: np.ndarray, + slot_map: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: """NPD format inference""" key_tmp = key.copy() value_tmp = value.copy() @@ -339,7 +397,6 @@ class InferenceEngine: value_head = value_cache.shape[1] value_head_dim = value_cache.shape[3] slot_size = key_cache.shape[2] - batch = block_tbl.shape[0] key_tmp = key_tmp.reshape(-1, key_head, key_head_dim) value_tmp = value_tmp.reshape(-1, value_head, value_head_dim) @@ -349,44 +406,329 @@ class InferenceEngine: for h in range(key_head): key_cache_ans[slot_idx][h][slot_offset] = key_tmp[i][h] value_cache_ans[slot_idx][h][slot_offset] = value_tmp[i][h] - return (InferenceEngine.npd_permute(key, key_head, key_head_dim, batch, slot_size), - InferenceEngine.npd_permute(value, value_head, value_head_dim, batch, slot_size), key_cache_ans, - value_cache_ans) + return (key_cache_ans, value_cache_ans) + + @staticmethod + def npd_kv_inference( + key: np.ndarray, + value: np.ndarray, + key_cache: np.ndarray, + value_cache: np.ndarray, + block_tbl: np.array, + cache_layout: bool, + ) -> Tuple[np.ndarray, np.ndarray]: + """NPD format inference""" + if cache_layout: + key_head = key_cache.shape[1] + key_head_dim = key_cache.shape[3] + value_head = value_cache.shape[1] + value_head_dim = value_cache.shape[3] + slot_size = key_cache.shape[2] + else: + key_head = key_cache.shape[2] + key_head_dim = key_cache.shape[3] + value_head = value_cache.shape[2] + value_head_dim = value_cache.shape[3] + slot_size = key_cache.shape[1] + batch = block_tbl.shape[0] + return ( + InferenceEngine.npd_permute(key, key_head, 
key_head_dim, batch, slot_size), + InferenceEngine.npd_permute(value, value_head, value_head_dim, batch, slot_size), + ) + + @staticmethod + def nd_cache_inference( + key: np.ndarray, + value: np.ndarray, + key_cache: np.ndarray, + value_cache: np.ndarray, + slot_map: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ND format inference""" + key_tmp = key.copy() + value_tmp = value.copy() + key_cache_ans = key_cache.copy() + value_cache_ans = value_cache.copy() + + # Use different dimensions for key and value + key_head = key_cache.shape[2] + key_head_dim = key_cache.shape[3] + value_head = value_cache.shape[2] + value_head_dim = value_cache.shape[3] -def npd_inference(key, value, key_cache, value_cache, slot_map, block_tbl): + key_tmp = key_tmp.reshape(-1, key_head, key_head_dim) + value_tmp = value_tmp.reshape(-1, value_head, value_head_dim) + + for i, slot in enumerate(slot_map): + slot_idx = slot // key_cache.shape[1] + slot_offset = slot % key_cache.shape[1] + key_cache_ans[slot_idx][slot_offset] = key_tmp[i] + value_cache_ans[slot_idx][slot_offset] = value_tmp[i] + + return (key_cache_ans, value_cache_ans) + + @staticmethod + def nd_kv_inference( + key: np.ndarray, + value: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + return (key, value) + + +def nd_kv_inference(key, value): + """Legacy function for backward compatibility""" + return InferenceEngine.nd_kv_inference(key, value) + + +def nd_cache_inference(key, value, key_cache, value_cache, slot_map): """Legacy function for backward compatibility""" - return InferenceEngine.npd_inference(key, value, key_cache, value_cache, slot_map, block_tbl) + return InferenceEngine.nd_cache_inference(key, value, key_cache, value_cache, slot_map) + + +def npd_cache_inference(key, value, key_cache, value_cache, slot_map): + return InferenceEngine.npd_cache_inference(key, value, key_cache, value_cache, slot_map) + + +def npd_kv_inference(key, value, key_cache, value_cache, block_tbl, cache_layout): + return InferenceEngine.npd_kv_inference(key, value, key_cache, value_cache, block_tbl, cache_layout) + @pytest.mark.level0 @pytest.mark.platform_ascend910b @pytest.mark.env_onecard -@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('np_dtype', [np.float16, bfloat16]) -@pytest.mark.parametrize('kv_embed', [32, 128, 256]) -@pytest.mark.parametrize('batch', [1, 2, 8, 13, 16]) -@pytest.mark.parametrize('head_num', [4, 8, 32]) -def test_reshape_and_cache_nd_key_value(np_dtype, kv_embed, batch, head_num, run_mode): +@pytest.mark.parametrize("run_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("np_dtype", [np.float16, bfloat16]) +@pytest.mark.parametrize("kv_embed", [32, 128, 256]) +@pytest.mark.parametrize("batch", [1, 2, 8, 13, 16]) +@pytest.mark.parametrize("head_num", [4, 8, 32]) +def test_reshape_and_cache_npd_kv_npd_ckv(np_dtype, kv_embed, batch, head_num, run_mode): """ Feature: Test ReshapeAndCacheNpd. Description: Test ND format with key and value. Expectation: Assert that results are consistent with numpy. 
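+    This case runs with key/value layout 'NPD' and kv cache layout 'NPD'.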
""" kv_dim = 2 + kv_cache_mode = "NPD" + kv_mode = "NPD" + test_config = TestConfig(device_target="Ascend", mode=run_mode) test_config.apply() - net = ReshapeAndCacheNpdAll() + net = ReshapeAndCacheNpdAll(kv_cache_mode, kv_mode) + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl = create_inputs( + np_dtype, kv_dim, kv_embed, kv_embed, batch, head_num, npd = "NPD" + ) + + if kv_mode == "NPD": + np_k_out, np_v_out = npd_kv_inference( + np_k, np_v, np_k_cache, np_v_cache, np_block_tbl, cache_layout= "NPD" + ) + else: + np_k_out, np_v_out = nd_kv_inference( + np_k, np_v + ) + + if kv_cache_mode == "NPD": + np_k_cache_out, np_v_cache_out = npd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) + else: + np_k_cache_out, np_v_cache_out = nd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) + + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl + ) + ms_q_seq = ms_q_seq.move_to("CPU") + ms_kv_seq = ms_kv_seq.move_to("CPU") + + # Run test + ms_k_out, ms_v_out = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl) + + # Verify Output + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_k_out, np_k_out, np_dtype, True) + TestResultVerifier.verify_results(ms_v_out, np_v_out, np_dtype, True) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize("run_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("np_dtype", [np.float16, bfloat16]) +@pytest.mark.parametrize("kv_embed", [32, 128, 256]) +@pytest.mark.parametrize("batch", [1, 2, 8, 13, 16]) +@pytest.mark.parametrize("head_num", [4, 8, 32]) +def test_reshape_and_cache_npd_kv_nd_ckv(np_dtype, kv_embed, batch, head_num, run_mode): + """ + Feature: Test ReshapeAndCacheNpd. + Description: Test ND format with key and value. + Expectation: Assert that results are consistent with numpy. 
+ """ + kv_dim = 2 + kv_cache_mode = "ND" + kv_mode = "NPD" + + test_config = TestConfig(device_target="Ascend", mode=run_mode) + test_config.apply() + + net = ReshapeAndCacheNpdAll(kv_cache_mode, kv_mode) + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl = create_inputs( + np_dtype, kv_dim, kv_embed, kv_embed, batch, head_num, kv_cache_mode == "NPD" + ) + + if kv_mode == "NPD": + np_k_out, np_v_out = npd_kv_inference( + np_k, np_v, np_k_cache, np_v_cache, np_block_tbl, kv_cache_mode == "NPD" + ) + else: + np_k_out, np_v_out = nd_kv_inference( + np_k, np_v + ) + + if kv_cache_mode == "NPD": + np_k_cache_out, np_v_cache_out = npd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) + else: + np_k_cache_out, np_v_cache_out = nd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) + + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl + ) + ms_q_seq = ms_q_seq.move_to("CPU") + ms_kv_seq = ms_kv_seq.move_to("CPU") + + # Run test + ms_k_out, ms_v_out = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl) + + # Verify Output + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_k_out, np_k_out, np_dtype, True) + TestResultVerifier.verify_results(ms_v_out, np_v_out, np_dtype, True) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize("run_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("np_dtype", [np.float16, bfloat16]) +@pytest.mark.parametrize("kv_embed", [32, 128, 256]) +@pytest.mark.parametrize("batch", [1, 2, 8, 13, 16]) +@pytest.mark.parametrize("head_num", [4, 8, 32]) +def test_reshape_and_cache_th_kv_npd_ckv(np_dtype, kv_embed, batch, head_num, run_mode): + """ + Feature: Test ReshapeAndCacheNpd. + Description: Test ND format with key and value. + Expectation: Assert that results are consistent with numpy. 
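+    This case runs with key/value layout 'TH' and kv cache layout 'NPD'.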
+ """ + kv_dim = 2 + kv_cache_mode = "NPD" + kv_mode = "TH" + + test_config = TestConfig(device_target="Ascend", mode=run_mode) + test_config.apply() + + net = ReshapeAndCacheNpdAll(kv_cache_mode, kv_mode) + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl = create_inputs( + np_dtype, kv_dim, kv_embed, kv_embed, batch, head_num, kv_cache_mode == "NPD" + ) + + if kv_mode == "NPD": + np_k_out, np_v_out = npd_kv_inference( + np_k, np_v, np_k_cache, np_v_cache, np_block_tbl, kv_cache_mode == "NPD" + ) + else: + np_k_out, np_v_out = nd_kv_inference( + np_k, np_v + ) + + if kv_cache_mode == "NPD": + np_k_cache_out, np_v_cache_out = npd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) + else: + np_k_cache_out, np_v_cache_out = nd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) + + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl + ) + ms_q_seq = ms_q_seq.move_to("CPU") + ms_kv_seq = ms_kv_seq.move_to("CPU") + + # Run test + ms_k_out, ms_v_out = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl) + + # Verify Output + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_k_out, np_k_out, np_dtype, True) + TestResultVerifier.verify_results(ms_v_out, np_v_out, np_dtype, True) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize("run_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("np_dtype", [np.float16, bfloat16]) +@pytest.mark.parametrize("kv_embed", [32, 128, 256]) +@pytest.mark.parametrize("batch", [1, 2, 8, 13, 16]) +@pytest.mark.parametrize("head_num", [4, 8, 32]) +def test_reshape_and_cache_th_kv_nd_ckv(np_dtype, kv_embed, batch, head_num, run_mode): + """ + Feature: Test ReshapeAndCacheNpd. + Description: Test ND format with key and value. + Expectation: Assert that results are consistent with numpy. 
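+    This case runs with key/value layout 'TH' and kv cache layout 'ND'.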
+ """ + kv_dim = 2 + kv_cache_mode = "ND" + kv_mode = "TH" + + test_config = TestConfig(device_target="Ascend", mode=run_mode) + test_config.apply() - np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl = create_npd_inputs( - np_dtype, kv_dim, kv_embed, kv_embed, batch, head_num) - np_k_out, np_v_out, np_k_cache_out, np_v_cache_out = npd_inference( - np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_block_tbl) + net = ReshapeAndCacheNpdAll(kv_cache_mode, kv_mode) + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl = create_inputs( + np_dtype, kv_dim, kv_embed, kv_embed, batch, head_num, kv_cache_mode == "NPD" + ) + + if kv_mode == "NPD": + np_k_out, np_v_out = npd_kv_inference( + np_k, np_v, np_k_cache, np_v_cache, np_block_tbl, kv_cache_mode == "NPD" + ) + else: + np_k_out, np_v_out = nd_kv_inference( + np_k, np_v + ) + + if kv_cache_mode == "NPD": + np_k_cache_out, np_v_cache_out = npd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) + else: + np_k_cache_out, np_v_cache_out = nd_cache_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map + ) ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl = create_ms_inputs( - np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl) + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_q_seq, np_kv_seq, np_block_tbl + ) + ms_q_seq = ms_q_seq.move_to("CPU") + ms_kv_seq = ms_kv_seq.move_to("CPU") # Run test - ms_k_out, ms_v_out = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl, 1) + ms_k_out, ms_v_out = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, ms_q_seq, ms_kv_seq, ms_block_tbl) # Verify Output TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) diff --git a/tests/st/test_custom_unpad_fa_npd.py b/tests/st/test_custom_unpad_fa_npd.py new file mode 100644 index 0000000000000000000000000000000000000000..38e2a00608d28afc86ba5f15e7936293e89ff8a0 --- /dev/null +++ b/tests/st/test_custom_unpad_fa_npd.py @@ -0,0 +1,532 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""tests_custom_pyboost_ascend""" + +# Standard library imports +from enum import Enum +from functools import wraps +from typing import Tuple, Optional, Dict, Any + +# Third-party imports +import numpy as np +import pytest + +# MindSpore imports +import mindspore as ms +from mindspore import Tensor, context +from mindspore.common.api import jit +from mindspore.common.np_dtype import bfloat16 +from mindspore import ops +from mindspore.nn.cell import Cell +from mindspore.ops.operations.nn_ops import FlashAttentionScore + +# Local imports +import ms_custom_ops + +ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) +# ms.set_device("Ascend", 6) + + +def jit_for_graph_mode(fn): + """ + A decorator that conditionally applies jit to a function at runtime based on the context mode. + """ + jitted_fn = jit(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if context.get_context("mode") == context.GRAPH_MODE: + return jitted_fn(*args, **kwargs) + return fn(*args, **kwargs) + + return wrapper + + +class DataType(Enum): + """Data type enumeration""" + + FLOAT16 = np.float16 + BFLOAT16 = bfloat16 + + +class FlashAttention(Cell): + """Flash Attention Layer.""" + + def __init__(self, head_num, scale_value): + super().__init__() + self.head_num = head_num + self.scale_value = scale_value + self.reshape_and_cache = ops.auto_generate.ReshapeAndCache() + self.flash_attention = FlashAttentionScore(head_num=head_num, scale_value=scale_value, input_layout="TH") + + @jit_for_graph_mode + def construct( + self=None, + query=None, + key=None, + value=None, + key_cache=None, + value_cache=None, + slot_mapping=None, + attn_mask=None, + actual_seq_qlen=None, + actual_seq_kvlen=None + ): + """Forward process of the FlashAttention.""" + key = key.contiguous() + value = value.contiguous() + self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) + _, _, _, context_ = self.flash_attention( + query, key, value, None, None, None, attn_mask, None, actual_seq_qlen, actual_seq_kvlen + ) + return context_ + + +class FlashAttentionNpd(Cell): + """Flash Attention NPD Layer.""" + + def __init__(self, head_num, scale_value, block_size): + super().__init__() + self.head_num = head_num + self.scale_value = scale_value + self.cache_layout = "ND" + self.kv_layout = "NPD" + self.block_size = block_size + + @jit_for_graph_mode + def construct( + self, + query=None, + key=None, + value=None, + key_cache=None, + value_cache=None, + slot_mapping=None, + attn_mask=None, + actual_seq_qlen=None, + actual_seq_kvlen=None, + block_table=None, + ): + """ReshapeAndCacheNpd construct""" + key = key.contiguous() + value = value.contiguous() + k_out, value_out = ms_custom_ops.reshape_and_cache_npd( + key, + value, + key_cache, + value_cache, + slot_mapping, + actual_seq_qlen, + actual_seq_kvlen, + block_table, + self.cache_layout, + self.kv_layout, + ) + out = ms_custom_ops.unpad_fa_npd( + query=query, + key=k_out, + value=value_out, + attn_mask=attn_mask, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, + head_num=self.head_num, + scale_value=self.scale_value, + q_input_layout="TH", + kv_input_layout=self.kv_layout, + block_size=self.block_size, + ) + return out + + +class MindSporeInputFactory: + """Factory for creating MindSpore inputs""" + + @staticmethod + def create_inputs( + np_q: np.ndarray, + np_k: np.ndarray, + np_v: np.ndarray, + np_k_cache: np.ndarray, + np_v_cache: np.ndarray, + np_slot_map: 
np.ndarray, + np_attn_mask: np.ndarray, + np_q_seq: np.ndarray, + np_kv_seq: np.ndarray, + np_block_tbl: np.ndarray, + ) -> Tuple[Tensor, ...]: + """Create MindSpore inputs""" + ms_query = Tensor(np_q) + ms_key = Tensor(np_k) + ms_value = Tensor(np_v) + ms_key_cache = Tensor(np_k_cache) + ms_value_cache = Tensor(np_v_cache) + ms_slot_map = Tensor(np_slot_map) + ms_attn_mask = Tensor(np_attn_mask) + ms_q_seq = Tensor(np_q_seq) + ms_kv_seq = Tensor(np_kv_seq) + ms_block_tbl = Tensor(np_block_tbl) + return ( + ms_query, + ms_key, + ms_value, + ms_key_cache, + ms_value_cache, + ms_slot_map, + ms_attn_mask, + ms_q_seq, + ms_kv_seq, + ms_block_tbl, + ) + + +def create_ms_inputs(np_q, np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_attn_mask, q_seq, kv_seq, block_tbl): + return MindSporeInputFactory.create_inputs( + np_q, np_k, np_v, np_k_cache, np_v_cache, np_slot_map, np_attn_mask, q_seq, kv_seq, block_tbl + ) + + +class TestResultVerifier: + """Verify test results""" + + @staticmethod + def verify_results(golden: Tensor, out: Tensor, dtype: np.dtype, rtol: float = 0.001, atol: float = 0.001) -> None: + """Verify results with appropriate dtype handling""" + + golden_np = golden.float().asnumpy().astype(np.float32).flatten() + out_np = out.float().asnumpy().astype(np.float32).flatten() + if dtype == bfloat16: + rtol = 0.01 + assert np.allclose(golden_np, out_np, rtol=rtol, atol=atol) + + +class TestConfig: + """Test configuration""" + + def __init__( + self, + device_target: str = "Ascend", + mode: context = context.GRAPH_MODE, + jit_config: Optional[Dict[str, Any]] = None, + ): + self.device_target = device_target + self.mode = mode + self.jit_config = jit_config or {} + + def apply(self): + """Apply test configuration""" + ms.set_device(self.device_target) + context.set_context(mode=self.mode) + if self.jit_config: + context.set_context(jit_config=self.jit_config) + + +# =============================== +# UNPAD FA NPD TEST ARCHITECTURE +# =============================== +#""" +#Test Structure Overview: +# +#1. ND FORMAT TESTS (cache_mode=0): +# - Direct ND format testing without format conversion +# - Data flow: Input(ND) → flash attention → Output(ND) → Verify +# - Tests: test_reshape_and_cache_nd_* +# +#3. KEY COMPONENTS: +# - create_inputs(): Generate ND format test data +# - FlashAttention - generate golden output +# - FlashAttentionNpd - new attention output +# +#4. 
VERIFICATION STRATEGY: +# - compare results with FlashAttention + + +# =============================== +# NPD FORMAT TESTS +# =============================== +class TestDataGenerator: + """Data generator for test inputs""" + + @staticmethod + def create_random_data(shape: Tuple[int, ...], dtype: np.dtype) -> np.ndarray: + """Create random data with specified shape and dtype""" + if dtype == np.int8: + return np.random.randint(low=-128, high=127, size=shape, dtype=np.int8) + return np.random.rand(*shape).astype(dtype) + + @staticmethod + def create_slot_map(num_tokens: int) -> np.ndarray: + """Create slot mapping""" + return np.arange(num_tokens, dtype=np.int32) + + @staticmethod + def create_q_seq(batch: int, seq: int) -> np.ndarray: + """Create q seq lens""" + return np.full(batch, seq).astype(np.int32) + + @staticmethod + def create_kv_seq(batch: int, seq: int) -> np.ndarray: + """Create kv seq lens""" + return np.full(batch, seq).astype(np.int32) + + @staticmethod + def create_blk_tbl(page_num: int, batch: int) -> np.ndarray: + return np.arange(page_num, dtype=np.int32).reshape(batch, (page_num + batch - 1) // batch) + + @staticmethod + def create_attn_mask(dtype: np.dtype) -> np.ndarray: + amask = np.ones(shape=(128, 128)) + amask = np.triu(amask, 1) + if dtype == bfloat16: + return amask.astype(dtype) + return (amask * -10000).astype(dtype) + + @staticmethod + def get_update_shapes( + batch_dim: int, seq: int, q_head_num: int, kv_head_num: int, head_dim: int, block_size: int + ) -> Tuple[Tuple[int, ...]]: + """Get update shapes for key and value, and number of tokens based on dimension""" + + query = (batch_dim * seq, q_head_num * head_dim) + key = (batch_dim * seq, kv_head_num * head_dim) + value = (batch_dim * seq, kv_head_num * head_dim) + num_tokens = key[0] + + block_num = ((seq + seq - 1) // block_size) * batch_dim + k_cache = (block_num, block_size, kv_head_num, head_dim) + v_cache = (block_num, block_size, kv_head_num, head_dim) + slot_map = num_tokens + q_seq = num_tokens + k_seq = num_tokens + block_tbl = (block_num, batch_dim) + + return query, key, value, k_cache, v_cache, slot_map, q_seq, k_seq, block_tbl + + +class DataGenerator(TestDataGenerator): + """Data generator for NPD format""" + + @staticmethod + def create_inputs( + dtype: np.dtype, batch_dim: int, seq: int, q_head_num: int, kv_head_num: int, head_dim: int, block_size: int + ) -> Tuple[np.ndarray, ...]: + """create inputs""" + ( + query_shape, + key_shape, + value_shape, + k_cache_shape, + v_cache_shape, + slot_map_shape, + _, + _, + block_tbl_shape, + ) = TestDataGenerator.get_update_shapes(batch_dim, seq, q_head_num, kv_head_num, head_dim, block_size) + query = TestDataGenerator.create_random_data(query_shape, dtype) + key = TestDataGenerator.create_random_data(key_shape, dtype) + value = TestDataGenerator.create_random_data(value_shape, dtype) + key_cache = TestDataGenerator.create_random_data(k_cache_shape, dtype) + value_cache = TestDataGenerator.create_random_data(v_cache_shape, dtype) + slot_map = TestDataGenerator.create_slot_map(slot_map_shape) + q_seq = TestDataGenerator.create_q_seq(batch_dim, seq) + kv_seq = TestDataGenerator.create_kv_seq(batch_dim, seq) + block_tbl = TestDataGenerator.create_blk_tbl(block_tbl_shape[0], block_tbl_shape[1]) + attn_mask = TestDataGenerator.create_attn_mask(dtype) + + return query, key, value, key_cache, value_cache, slot_map, attn_mask, q_seq, kv_seq, block_tbl + + +def create_inputs(dtype, batch_dim: int, seq: int, q_head_num: int, kv_head_num: int, head_dim: int, 
block_size: int): + """Legacy function for backward compatibility""" + return DataGenerator.create_inputs(dtype, batch_dim, seq, q_head_num, kv_head_num, head_dim, block_size) + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize("np_dtype", [np.float16]) +@pytest.mark.parametrize("run_mode", [context.PYNATIVE_MODE, context.GRAPH_MODE]) +@pytest.mark.parametrize("batch", [1, 3]) +@pytest.mark.parametrize("seq", [128, 300, 8192]) +@pytest.mark.parametrize("head_num", [32]) +@pytest.mark.parametrize("kv_head_num", [8]) +@pytest.mark.parametrize("embed", [128]) +def test_unpad_fa_npd(np_dtype, run_mode, batch, seq, head_num, kv_head_num, embed): + """ + Feature: Test unpad flash attention. + Description: Test ND format query, key, value and mask. + Expectation: Assert that results are consistent with numpy. + """ + test_config = TestConfig(device_target="Ascend", mode=run_mode) + test_config.apply() + + block_size = 32 + net0 = FlashAttention(head_num, 1.0) + net1 = FlashAttentionNpd(head_num, 1.0, block_size) + + ( + np_query, + np_key, + np_value, + np_key_cache, + np_value_cache, + np_slot_map, + np_attn_mask, + np_q_seq, + np_kv_seq, + np_block_tbl, + ) = create_inputs(np_dtype, batch, seq, head_num, kv_head_num, embed, block_size) + + ( + ms_query, + ms_key, + ms_value, + ms_key_cache, + ms_value_cache, + ms_slot_map, + ms_attn_mask, + ms_q_seq, + ms_kv_seq, + ms_block_tbl, + ) = create_ms_inputs( + np_query, + np_key, + np_value, + np_key_cache, + np_value_cache, + np_slot_map, + np_attn_mask, + np_q_seq, + np_kv_seq, + np_block_tbl, + ) + + # Run test + golden_out = net0( + ms_query, + ms_key, + ms_value, + ms_key_cache, + ms_value_cache, + ms_slot_map, + ms_attn_mask, + ms_q_seq, + ms_kv_seq + ) + ms_q_seq = ms_q_seq.move_to("CPU") + ms_kv_seq = ms_kv_seq.move_to("CPU") + out = net1( + ms_query, + ms_key, + ms_value, + ms_key_cache, + ms_value_cache, + ms_slot_map, + ms_attn_mask, + ms_q_seq, + ms_kv_seq, + ms_block_tbl, + ) + # Verify Output + TestResultVerifier.verify_results(golden_out, out, np_dtype) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize("np_dtype", [np.float16, bfloat16]) +@pytest.mark.parametrize("run_mode", [context.PYNATIVE_MODE, context.GRAPH_MODE]) +@pytest.mark.parametrize("batch", [1, 3]) +@pytest.mark.parametrize("seq", [128, 300, 8192]) +@pytest.mark.parametrize("head_num", [16]) +@pytest.mark.parametrize("kv_head_num", [16]) +@pytest.mark.parametrize("embed", [128]) +def test_unpad_fa_npd_none_mask(np_dtype, run_mode, batch, seq, head_num, kv_head_num, embed): + """ + Feature: Test unpad flash attention. + Description: Test ND format query, key, value and mask. + Expectation: Assert that results are consistent with numpy. 
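+    This variant passes attn_mask=None to both networks and also covers bfloat16.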
+ """ + test_config = TestConfig(device_target="Ascend", mode=run_mode) + test_config.apply() + + block_size = 32 + net0 = FlashAttention(head_num, 1.0) + net1 = FlashAttentionNpd(head_num, 1.0, block_size) + + ( + np_query, + np_key, + np_value, + np_key_cache, + np_value_cache, + np_slot_map, + np_attn_mask, + np_q_seq, + np_kv_seq, + np_block_tbl, + ) = create_inputs(np_dtype, batch, seq, head_num, kv_head_num, embed, block_size) + + ( + ms_query, + ms_key, + ms_value, + ms_key_cache, + ms_value_cache, + ms_slot_map, + ms_attn_mask, + ms_q_seq, + ms_kv_seq, + ms_block_tbl, + ) = create_ms_inputs( + np_query, + np_key, + np_value, + np_key_cache, + np_value_cache, + np_slot_map, + np_attn_mask, + np_q_seq, + np_kv_seq, + np_block_tbl, + ) + + # Run test + ms_attn_mask = None + golden_out = net0( + ms_query, + ms_key, + ms_value, + ms_key_cache, + ms_value_cache, + ms_slot_map, + ms_attn_mask, + ms_q_seq, + ms_kv_seq + ) + ms_q_seq = ms_q_seq.move_to("CPU") + ms_kv_seq = ms_kv_seq.move_to("CPU") + out = net1( + ms_query, + ms_key, + ms_value, + ms_key_cache, + ms_value_cache, + ms_slot_map, + ms_attn_mask, + ms_q_seq, + ms_kv_seq, + ms_block_tbl, + ) + # Verify Output + TestResultVerifier.verify_results(golden_out, out, np_dtype)
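+
+# A quick local-run sketch (assuming ms_custom_ops is built and installed and an Ascend 910B device is visible):
+#   pytest -vs tests/st/test_custom_unpad_fa_npd.py -k "unpad_fa_npd"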