From 5ae7f4637ed441b0fb7b0c4f78aacd65351296d8 Mon Sep 17 00:00:00 2001 From: oniond Date: Thu, 11 Sep 2025 11:44:46 +0800 Subject: [PATCH 1/3] add outer_product_mean --- include/kutacc.h | 25 ++++ src/attention/outer_product_mean.cpp | 185 +++++++++++++++++++++++++++ 2 files changed, 210 insertions(+) create mode 100644 src/attention/outer_product_mean.cpp diff --git a/include/kutacc.h b/include/kutacc.h index 8f4af7e..da5a4c7 100644 --- a/include/kutacc.h +++ b/include/kutacc.h @@ -38,6 +38,31 @@ kutacc_export int kutacc_get_version(kutacc_version_t *version); typedef void* kutacc_tensor_h; +/** + * @brief outer_product_mean_calc_left_and_right_mu algorithm + * @param [out] left_proj, right_proj, left_proj_, right_proj_, mask + * @param [in] left_proj_w, left_proj_b, right_proj_w, right_proj_b + * @param [in] c_i, c_m, n_res, n_res_gather, n_seq, mask_bias + * @return Null + */ +kutacc_export void kutacc_af2_outer_product_mean_calc_left_and_right_mul( + kutacc_tensor_h left_proj, kutacc_tensor_h right_proj, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_, + kutacc_tensor_h input_act, kutacc_tensor_h mask, kutacc_tensor_h norm, const kutacc_tensor_h left_proj_w, + const kutacc_tensor_h left_proj_b, const kutacc_tensor_h right_proj_w, const kutacc_tensor_h right_proj_b, + int64_t c_i, int64_t c_m, int64_t n_res, int64_t n_res_gather, int64_t n_seq, int64_t mask_bias); + +/** + * @brief outer_product_mean_chunk algorithm + * @param [out] output_b, output_w out + * @param [in] left_proj_, right_proj_, norm, left_block_size, right_block_size, + * @param [in] c_i, c_z, n_res, n_res_gather, n_seq + * @return Null + */ +kutacc_export void kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h output_b, const kutacc_tensor_h output_w, + kutacc_tensor_h out, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_, kutacc_tensor_h norm, + int64_t left_block_size, int64_t right_block_size, int64_t c_i, int64_t c_z, int64_t n_res, + int64_t n_res_gather, int64_t n_seq); + /** * @brief gating_attention algorithm * @param input prepacked q_data, shape[batch * seq_len, nchannels] diff --git a/src/attention/outer_product_mean.cpp b/src/attention/outer_product_mean.cpp new file mode 100644 index 0000000..feb93d8 --- /dev/null +++ b/src/attention/outer_product_mean.cpp @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
 */

#include "../core/mm.h"
#include "../wrapper/wrapper.h"
#include "../utils/parallel.h"
#include "../utils/check.h"
#include "../utils/memory.h"

namespace kutacc {

/*
* [OUT] left_proj, right_proj, left_proj_, right_proj_, norm
* [IN] input_act, mask, left_proj_w, left_proj_b, right_proj_w, right_proj_b
* [IN] c_i, c_m, n_res, n_res_gather, n_seq, mask_bias
*/

void outer_product_mean_calc_left_and_right_mul_kernel(Tensor &left_proj, Tensor &right_proj, Tensor &left_proj_,
                                                       Tensor &right_proj_, Tensor &input_act, Tensor &mask,
                                                       Tensor &norm, const Tensor &left_proj_w,
                                                       const Tensor &left_proj_b, const Tensor &right_proj_w,
                                                       const Tensor &right_proj_b, int64_t c_i, int64_t c_m,
                                                       int64_t n_res, int64_t n_res_gather, int64_t n_seq,
                                                       int64_t mask_bias)
{
    // Matrix multiply with the result transposed:
    // input_act [n_seq, n_res, n_cm] -> left_proj [c_i, n_res, n_seq]
    addmm(to_bf16(1), Tensor((__bf16 *)left_proj_w.data_ptr(), {c_i, c_m}, {c_m, 1}, 2, kBF16),
          Tensor((__bf16 *)input_act.data_ptr(), {c_m, n_res * n_seq}, {1, c_m}, 2, kBF16), to_bf16(0),
          Tensor((__bf16 *)left_proj.data_ptr(), {c_i, n_res * n_seq}, {n_res * n_seq, 1}, 2, kBF16),
          BlasExtendParams{.prepack_a = true, .prepack_b = false, .col_bias = true, .bias = left_proj_b.data_ptr()});

    // Matrix multiply with the result transposed:
    // input_act [n_seq, n_res, n_cm] -> right_proj [c_i, n_res, n_seq]
    addmm(to_bf16(1), Tensor((__bf16 *)right_proj_w.data_ptr(), {c_i, c_m}, {c_m, 1}, 2, kBF16),
          Tensor((__bf16 *)input_act.data_ptr(), {c_m, n_res * n_seq}, {1, c_m}, 2, kBF16), to_bf16(0),
          Tensor((__bf16 *)right_proj.data_ptr(), {c_i, n_res * n_seq}, {n_res * n_seq, 1}, 2, kBF16),
          BlasExtendParams{.prepack_a = true, .prepack_b = false, .col_bias = true, .bias = right_proj_b.data_ptr()});

    // left_proj_ = (left_proj + left_proj_b) * mask
    // right_proj_ = (right_proj + right_proj_b) * mask
    // left_proj / right_proj [c_i, n_res, n_seq] -> left_proj_ / right_proj_ [n_res, c_i, n_seq]
    parallel_for(0, c_i * n_res, 1, [&](int64_t start, int64_t end) {
        int64_t ci, ri;
        data_index_init(start, ci, c_i, ri, n_res);
        for (int64_t _i = start; _i < end; _i++) {
            auto left_proj_data =
                (__bf16 *)left_proj.data_ptr() + ci * left_proj.strides()[0] + ri * left_proj.strides()[1];
            auto right_proj_data =
                (__bf16 *)right_proj.data_ptr() + ci * right_proj.strides()[0] + ri * right_proj.strides()[1];
            auto mask_data = (__bf16 *)mask.data_ptr() + mask_bias + ri * mask.strides()[0];
            auto left_proj_data_ =
                (__bf16 *)left_proj_.data_ptr() + ri * left_proj_.strides()[0] + ci * left_proj_.strides()[1];
            auto right_proj_data_ =
                (__bf16 *)right_proj_.data_ptr() + ri * right_proj_.strides()[0] + ci * right_proj_.strides()[1];
            int64_t vl = (int64_t)svcntw();
            for (int64_t i = 0; i < n_seq; i += vl) {
                svbool_t pg = svwhilelt_b32(i, n_seq);
                auto left_values = svld1(pg, &left_proj_data[i]);
                auto right_values = svld1(pg, &right_proj_data[i]);
                auto mask_values = svld1(pg, &mask_data[i]);
                left_values = svmul_x(pg, left_values, mask_values);
                right_values = svmul_x(pg, right_values, mask_values);
                svst1<__bf16, float>(pg, &left_proj_data_[i], left_values);
                svst1<__bf16, float>(pg, &right_proj_data_[i], right_values);
            }
            data_index_step(ci, c_i, ri, n_res);
        }
    });

    // norm = mask @ mask.transpose(0, 1)
    addmm(to_bf16(1), Tensor((__bf16 *)mask.data_ptr() + mask_bias, {n_res, n_seq}, {n_seq, 1}, 2, kBF16),
          Tensor((__bf16 *)mask.data_ptr(), {n_seq, n_res_gather}, {1, n_seq}, 2, kBF16), to_bf16(0),
          Tensor((__bf16 *)norm.data_ptr(), {n_res, n_res_gather},
{n_res_gather, 1}, 2, kBF16)); +} + +/* + * [out] output_b, output_w out + * [IN] left_proj_, right_proj_, norm, left_block_size, right_block_size, + * [IN] c_i, c_z, n_res, n_res_gather, n_seq +*/ +void outer_product_mean_chunk_kernel(const Tensor &output_b, const Tensor &output_w, Tensor &out, Tensor &left_proj_, + Tensor &right_proj_, Tensor &norm, int64_t left_block_size, + int64_t right_block_size, int64_t c_i, int64_t c_z, int64_t n_res, + int64_t n_res_gather, int64_t n_seq) +{ + int64_t left_nblocks = (n_res + left_block_size - 1) / left_block_size; + int64_t right_nblocks = (n_res_gather + right_block_size - 1) / right_block_size; + parallel_for(0, left_nblocks * right_nblocks, 1, [&](int64_t start, int64_t end) { + int64_t left_block_i, right_block_i; + data_index_init(start, left_block_i, left_nblocks, right_block_i, right_nblocks); + auto chunk_buf = alloc<__bf16>(left_block_size * c_i * right_block_size * c_i); + auto chunk_buf_ = alloc<__bf16>(left_block_size * right_block_size * c_i * c_i); + auto out_buf = alloc<__bf16>(left_block_size * right_block_size * c_z); + for (int64_t _i = start; _i < end; _i++) { + int64_t left_start = left_block_i * left_block_size; + int64_t left_end = std::min(left_start + left_block_size, n_res); + int64_t right_start = right_block_i * right_block_size; + int64_t right_end = std::min(right_start + right_block_size, n_res_gather); + addmm(to_bf16(1), + Tensor((__bf16 *)left_proj_.data_ptr() + left_start * left_proj_.strides()[0], + {(left_end - left_start) * c_i, n_seq}, {n_seq, 1}, 2, kBF16), + Tensor((__bf16 *)right_proj_.data_ptr() + right_start * right_proj_.strides()[0], + {n_seq, (right_end - right_start) * c_i}, {1, n_seq}, 2, kBF16), + to_bf16(0), + Tensor(chunk_buf.get(), {(left_end - left_start) * c_i, (right_end - right_start) * c_i}, + {(right_end - right_start) * c_i, 1}, 2, kBF16), + BlasExtendParams{.num_threads = 1}); + for (int64_t left_ri = left_start; left_ri < left_end; left_ri++) { + for (int64_t right_ri = right_start; right_ri < right_end; right_ri++) { + for (int64_t left_ci = 0; left_ci < c_i; left_ci++) { + std::memcpy(chunk_buf_.get() + (left_ri - left_start) * (right_end - right_start) * c_i * c_i + + (right_ri - right_start) * c_i * c_i + left_ci * c_i, + chunk_buf.get() + (left_ri - left_start) * (right_end - right_start) * c_i * c_i + + left_ci * (right_end - right_start) * c_i + (right_ri - right_start) * c_i, + size_t(c_i) * sizeof(__bf16)); + } + } + } + addmm(to_bf16(1), + Tensor(chunk_buf_.get(), {(left_end - left_start) * (right_end - right_start), c_i * c_i}, + {c_i * c_i, 1}, 2, kBF16), + Tensor((__bf16 *)output_w.data_ptr(), {c_i * c_i, c_z}, {1, c_i * c_i}, 2, kBF16), to_bf16(0), + Tensor(out_buf.get(), {(left_end - left_start) * (right_end - right_start), c_z}, {c_z, 1}, 2, kBF16), + BlasExtendParams{.num_threads = 1, + .prepack_a = false, + .prepack_b = true, + .row_bias = true, + .bias = output_b.data_ptr()}); + int64_t vl = (int64_t)svcntw(); + for (int64_t left_ri = left_start; left_ri < left_end; left_ri++) { + for (int64_t right_ri = right_start; right_ri < right_end; right_ri++) { + auto out_buf_data = out_buf.get() + (left_ri - left_start) * (right_end - right_start) * c_z + + (right_ri - right_start) * c_z; + auto out_data = (__bf16 *)out.data_ptr() + left_ri * out.strides()[0] + right_ri * out.strides()[1]; + float norm_value = to_float(((__bf16 *)norm.data_ptr())[left_ri * n_res_gather + right_ri]); + norm_value = 1 / (norm_value + 1e-3f); + for (int64_t i = 0; i < c_z; i += vl) { + svbool_t 
                    pg = svwhilelt_b32(i, c_z);
                        auto values = svld1(pg, &out_buf_data[i]);
                        values = svmul_x(pg, values, norm_value);
                        svst1<__bf16, float>(pg, &out_data[i], values);
                    }
                }
            }
            data_index_step(left_block_i, left_nblocks, right_block_i, right_nblocks);
        }
    });
}
}

void kutacc_export kutacc_af2_outer_product_mean_calc_left_and_right_mul(
    kutacc_tensor_h left_proj, kutacc_tensor_h right_proj, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_,
    kutacc_tensor_h input_act, kutacc_tensor_h mask, kutacc_tensor_h norm, const kutacc_tensor_h left_proj_w,
    const kutacc_tensor_h left_proj_b, const kutacc_tensor_h right_proj_w, const kutacc_tensor_h right_proj_b,
    int64_t c_i, int64_t c_m, int64_t n_res, int64_t n_res_gather, int64_t n_seq, int64_t mask_bias)
{
    kutacc::outer_product_mean_calc_left_and_right_mul_kernel(
        *kutacc::convertKutaccTensor(left_proj), *kutacc::convertKutaccTensor(right_proj), *kutacc::convertKutaccTensor(left_proj_),
        *kutacc::convertKutaccTensor(right_proj_), *kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(mask),
        *kutacc::convertKutaccTensor(norm), *kutacc::convertKutaccTensor(left_proj_w), *kutacc::convertKutaccTensor(left_proj_b),
        *kutacc::convertKutaccTensor(right_proj_w), *kutacc::convertKutaccTensor(right_proj_b), c_i, c_m, n_res, n_res_gather, n_seq,
        mask_bias);
}

void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h output_b, const kutacc_tensor_h output_w,
    kutacc_tensor_h out, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_,
    kutacc_tensor_h norm, int64_t left_block_size, int64_t right_block_size,
    int64_t c_i, int64_t c_z, int64_t n_res, int64_t n_res_gather, int64_t n_seq)
{
    kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w),
        *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_),
        *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq);
}
\ No newline at end of file
-- Gitee

From 016d50cc96c2167af041da304ce1affba7703100 Mon Sep 17 00:00:00 2001
From: oniond
Date: Wed, 17 Sep 2025 15:18:06 +0800
Subject: [PATCH 2/3] update README

---
 README.md | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 69 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 819abb8..d31323e 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,70 @@
-# KUTensor ACCelerating
+# Kunpeng Unified Transformer Accelerated Library
-基于KUPL优化的AI算子加速库
\ No newline at end of file

## 1. Introduction
Kunpeng chips support vector and matrix computation, raising raw compute throughput, and together with a high-speed RDMA network they deliver very high bandwidth at microsecond-level latency. This strong floating-point performance and high bandwidth make the chip a natural fit for AI inference. Building on this, we provide KuTACC, a fused-operator library for Transformer models on the Kunpeng platform, which runs Transformer model inference efficiently on Kunpeng processors.

## 2. Running Locally

### 2.1 Installing Dependencies

#### 2.1.1 Installing HPCKit
KuTACC is built with the BiSheng compiler shipped in HPCKit. For HPCKit installation, see the [official guide](https://www.hikunpeng.com/developer/hpc/hpckit-download).

Installing KuTACC requires the BiSheng compiler and KUPL from the HPCKit environment; for the configuration steps, see the [HPCKit introduction](https://www.hikunpeng.com/document/detail/zh/kunpenghpcs/hpckit/devg/KunpengHPCKit_developer_002.html).

### 2.2 Building and Installing from Source
build.sh installs KuTACC to any path you choose, and both Release and Debug builds of the library are supported.
```shell
sh build.sh --install_path=/path/to/your/kutacc-path --build_type=Release/Debug
```

### 2.3 Configuring Environment Variables
Once the KuTACC library and include paths are set, the corresponding KuTACC interfaces can be called.
```shell
export KUTACC_LIB=/path_to_kutacc/install/lib
export KUTACC_INCLUDE=/path_to_kutacc/install/include
```
To use the interfaces provided by KuTACC, reference these two variables in your project, i.e. add the following to your C++ compile and link flags:
```shell
export CXXFLAGS="-I${KUTACC_INCLUDE}"
export LDFLAGS="-L${KUTACC_LIB}"
export LDLIBS="-lkutacc"
```
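As a quick check that the header and library are found, a minimal program along the following lines can be built with the flags above. The file name `check_kutacc.cpp` and the zero-on-success return convention assumed for `kutacc_get_version` are illustrative, not mandated by KuTACC.
```cpp
// check_kutacc.cpp - minimal link check against KuTACC (illustrative sketch only).
#include <cstdio>
#include "kutacc.h"

int main()
{
    kutacc_version_t version{};                  // filled in by the library
    int ret = kutacc_get_version(&version);      // declared in include/kutacc.h
    std::printf("kutacc_get_version returned %d\n", ret);
    return ret == 0 ? 0 : 1;                     // assumes 0 means success
}
```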
With these variables set, a program can be compiled and then linked as follows:
```shell
g++/clang++ ${CXXFLAGS} -c xxx.cpp -o xxx.o
g++/clang++ xxx.o ${LDFLAGS} ${LDLIBS} -o xxx
```

If the project is managed with CMake, linking through the target-based interfaces is recommended:
```cmake
# include path
target_include_directories(yourapp PRIVATE $ENV{KUTACC_INCLUDE})

# library search path and the library itself
target_link_directories(yourapp PRIVATE $ENV{KUTACC_LIB})
target_link_libraries(yourapp PRIVATE kutacc)
```

## 3. Supported Applications

| Application | Version |
| :------------ | -----------: |
| Alphafold2 | v1.0 |

## License
This code is released under the [OpenSoftware License 1.0](LICENSE), which is derived from the MIT license.

## Contact
If you have any questions, feel free to open an issue for discussion.
-- Gitee

From 5795a926f7fa343593407c7cf6399279ca1bb4dc Mon Sep 17 00:00:00 2001
From: oniond
Date: Sat, 20 Sep 2025 15:47:33 +0800
Subject: [PATCH 3/3] fix comm swap

---
 src/comm/all_gather.cpp | 90 +++++++++++++++++++++++++++++++++++++++
 src/comm/transpose.cpp  | 94 +++++++++++++++++++++++++++++++++++++++++
 src/tensor/tensor.cpp   | 10 +++++
 src/tensor/tensor.h     |  2 +
 4 files changed, 196 insertions(+)
 create mode 100644 src/comm/all_gather.cpp
 create mode 100644 src/comm/transpose.cpp

diff --git a/src/comm/all_gather.cpp b/src/comm/all_gather.cpp
new file mode 100644
index 0000000..d7fc6e7
--- /dev/null
+++ b/src/comm/all_gather.cpp
@@ -0,0 +1,90 @@
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 *
 * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "../utils/parallel.h"
#include "../tensor/tensor.h"
#include "kutacc.h"
#include "kupl.h"

namespace kutacc {
void af2_all_gather_kernel(Tensor &data, Tensor &out)
{
    int64_t m = out.sizes()[0];
    int64_t n = out.sizes()[1];
    int64_t len = data.sizes()[2];

    // If the input is sharded along dim 1 (dim 0 already has full size m), swap dims 0/1
    // so the gather below always runs along dim 0.
    if (data.sizes()[0] == m) {
        std::swap(m, n);
        std::swap(data.sizes_ref()[0], data.sizes_ref()[1]);
        std::swap(data.strides_ref()[0], data.strides_ref()[1]);
        std::swap(out.sizes_ref()[0], out.sizes_ref()[1]);
        std::swap(out.strides_ref()[0], out.strides_ref()[1]);
    }

    int64_t block_m = (m + world_size - 1) / world_size;

    KUTACC_CHECK(data.strides()[2] == 1, data.strides()[2]);
    KUTACC_CHECK(out.strides()[2] == 1, out.strides()[2]);

    int64_t subblock_m = buffer_size / (n * len);
    KUTACC_CHECK(subblock_m > 0, buffer_size, " ", n, " ", len);
    for (int64_t block_mi_start = 0; block_mi_start < block_m; block_mi_start += subblock_m) {
        int64_t size_m = std::min(subblock_m, block_m - block_mi_start);

        // Pass 1: stage this rank's shard in the shared receive buffer and copy it straight into its slot of out.
        parallel_for(0, 2 * size_m * n, 1, [&](int64_t start, int64_t end) {
            collapse_for(start, end, 2, size_m, n, [&](int64_t direct, int64_t sbmi, int64_t ni) {
                int64_t bmi = sbmi + block_mi_start;
                if (bmi < m - rank * block_m) {
                    if (!direct) {
                        memcpy((uint8_t *)(kupl_recvbuf) + (sbmi * n + ni) * len,
                               (uint8_t *)(data.data_ptr()) + bmi * data.strides()[0] + ni * data.strides()[1],
                               (size_t)len);
                    } else {
                        memcpy((uint8_t *)(out.data_ptr()) + (bmi + rank * block_m) * out.strides()[0] +
                                   ni * out.strides()[1],
                               (uint8_t *)(data.data_ptr()) + bmi * data.strides()[0] + ni * data.strides()[1],
                               (size_t)len);
                    }
                }
            });
        });

        // Pass 2: copy every other rank's staged shard from its shared-memory window into out.
        int64_t par_size = world_size * size_m * n;
        parallel_for(0, par_size, 1, [&](int64_t start, int64_t end) {
            collapse_for(start, end, world_size, size_m, n, [&](int64_t ri, int64_t sbmi, int64_t ni) {
                int64_t bmi = sbmi + block_mi_start;
                if (bmi < m - ri * block_m) {
                    if (ri != rank) {
                        void *remote_buffer;
                        kupl_shm_win_query(kupl_recvbuf_win, (int)ri, &remote_buffer);
                        memcpy((uint8_t *)(out.data_ptr()) + (bmi + ri * block_m) * out.strides()[0] +
                                   ni * out.strides()[1],
                               (uint8_t *)(remote_buffer) + (sbmi * n + ni) * len, (size_t)len);
                    }
                }
            });
        });
    }
}
}

void kutacc_af2_all_gather(kutacc_tensor_h data, kutacc_tensor_h out)
{
    if (data == nullptr || out == nullptr) {
        KUTACC_CHECK(data != nullptr, data);
        KUTACC_CHECK(out != nullptr, out);
        return;
    }
    kutacc::af2_all_gather_kernel(*kutacc::convertKutaccTensor(data), *kutacc::convertKutaccTensor(out));
}
\ No newline at end of file
diff --git a/src/comm/transpose.cpp b/src/comm/transpose.cpp
new file mode 100644
index 0000000..1529c83
--- /dev/null
+++ b/src/comm/transpose.cpp
@@ -0,0 +1,94 @@
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 *
 * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "../utils/parallel.h"
#include "../tensor/tensor.h"
#include "kutacc.h"
#include "kupl.h"

namespace kutacc {
void af2_transpose_kernel(Tensor &data, Tensor &out)
{
    int64_t m = std::max(data.sizes()[0], out.sizes()[0]);
    int64_t n = std::max(data.sizes()[1], out.sizes()[1]);
    int64_t len = data.sizes()[2];

    // If the input is sharded along dim 1, swap dims 0/1 so the exchange below always runs along dim 0.
    if (data.sizes()[0] == m) {
        std::swap(m, n);
        std::swap(data.sizes_ref()[0], data.sizes_ref()[1]);
        std::swap(data.strides_ref()[0], data.strides_ref()[1]);
        std::swap(out.sizes_ref()[0], out.sizes_ref()[1]);
        std::swap(out.strides_ref()[0], out.strides_ref()[1]);
    }

    int64_t block_m = (m + world_size - 1) / world_size;
    int64_t block_n = (n + world_size - 1) / world_size;

    KUTACC_CHECK(data.strides()[2] == 1, data.strides()[2]);
    KUTACC_CHECK(out.strides()[2] == 1, out.strides()[2]);

    int64_t subblock_m = buffer_size / (world_size * block_n * len);
    KUTACC_CHECK(subblock_m > 0, buffer_size, " ", world_size, " ", block_n, " ", len);
    for (int64_t block_mi_start = 0; block_mi_start < block_m; block_mi_start += subblock_m) {
        int64_t size_m = std::min(subblock_m, block_m - block_mi_start);
        int64_t par_size = world_size * size_m * block_n;

        // Pass 1: stage the blocks destined for other ranks in the shared buffer; the local diagonal block goes straight into out.
        parallel_for(0, par_size, 1, [&](int64_t start, int64_t end) {
            collapse_for(start, end, world_size, size_m, block_n, [&](int64_t ri, int64_t sbmi, int64_t bni) {
                int64_t bmi = sbmi + block_mi_start;
                if (bmi < m - rank * block_m && bni < n - ri * block_n) {
                    if (ri != rank) {
                        memcpy((uint8_t *)(kupl_recvbuf) + (ri * subblock_m * block_n + sbmi * block_n + bni) * len,
                               (uint8_t *)(data.data_ptr()) + bmi * data.strides()[0] +
                                   (bni + ri * block_n) * data.strides()[1],
                               (size_t)len);
                    } else {
                        memcpy((uint8_t *)(out.data_ptr()) + (bmi + rank * block_m) * out.strides()[0] +
                                   bni * out.strides()[1],
                               (uint8_t *)(data.data_ptr()) + bmi * data.strides()[0] +
                                   (bni + rank * block_n) * data.strides()[1],
                               (size_t)len);
                    }
                }
            });
        });

        // Pass 2: fetch the block each peer staged for this rank from its shared-memory window.
        parallel_for(0, par_size, 1, [&](int64_t start, int64_t end) {
            collapse_for(start, end, world_size, size_m, block_n, [&](int64_t ri, int64_t sbmi, int64_t bni) {
                int64_t bmi = sbmi + block_mi_start;
                if (bmi < m - ri * block_m && bni < n - rank * block_n) {
                    if (ri != rank) {
                        void *remote_buffer;
                        kupl_shm_win_query(kupl_recvbuf_win, (int)ri, &remote_buffer);
                        memcpy((uint8_t *)(out.data_ptr()) + (bmi + ri * block_m) * out.strides()[0] +
                                   bni * out.strides()[1],
                               (uint8_t *)(remote_buffer) + (rank * subblock_m * block_n + sbmi * block_n + bni) * len,
                               (size_t)len);
                    }
                }
            });
        });
    }
}
}

void kutacc_af2_transpose(kutacc_tensor_h data, kutacc_tensor_h out)
{
    if (data == nullptr || out == nullptr) {
        KUTACC_CHECK(data != nullptr, data);
        KUTACC_CHECK(out != nullptr, out);
        return;
    }
    kutacc::af2_transpose_kernel(*kutacc::convertKutaccTensor(data), *kutacc::convertKutaccTensor(out));
}
\ No newline at end of file
diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp
index c0cb728..109753c 100644
--- a/src/tensor/tensor.cpp
+++ b/src/tensor/tensor.cpp
@@ -30,6 +30,16 @@ std::vector<int64_t> Tensor::strides() const
 return data.strides;
 }
 
+std::vector<int64_t>& Tensor::sizes_ref()
+{
+    return data.sizes;
+}
+
+std::vector<int64_t>& Tensor::strides_ref()
+{
+    return data.strides;
+}
+
 DType Tensor::dtype() const
 {
 return data.dtype;
diff --git a/src/tensor/tensor.h b/src/tensor/tensor.h
index 433bca0..9e1f5ac 100644
--- a/src/tensor/tensor.h
+++ b/src/tensor/tensor.h
@@ -33,6 +33,8 @@ struct Tensor {
 void* data_ptr() const;
 std::vector<int64_t> sizes() const;
 std::vector<int64_t> strides() const;
+    std::vector<int64_t>& sizes_ref();
+    std::vector<int64_t>& strides_ref();
 DType dtype() const;
 int64_t dim() const;
 };
-- Gitee
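The new `sizes_ref()` / `strides_ref()` accessors exist so the comm kernels can re-view a tensor with its two leading dimensions swapped without copying any data. A minimal sketch of that pattern, not part of the patch and assuming a 3-D `kutacc::Tensor` obtained elsewhere:
```cpp
#include <utility>
#include "tensor.h"  // src/tensor/tensor.h; adjust the path to your include layout

// Swap the two leading dimensions of a tensor view in place; the underlying buffer is untouched.
// This mirrors what af2_all_gather_kernel / af2_transpose_kernel do before exchanging blocks.
void swap_leading_dims(kutacc::Tensor &t)
{
    std::swap(t.sizes_ref()[0], t.sizes_ref()[1]);
    std::swap(t.strides_ref()[0], t.strides_ref()[1]);
}
```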