diff --git a/README.md b/README.md
index 29aa9b3c214fc3849b7289a1bda1527b3e8f5721..b01408beacf0614020c7cd690c59f0058190e96b 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,58 @@
-
-
-# kunpeng Unifined Transformer ACCelerating
-
-鲲鹏Transformer加速库
\ No newline at end of file
+# Kunpeng Unified Transformer Acceleration Library (KuTACC)
+
+## 1. Introduction
+Kunpeng processors provide vector and matrix compute and, combined with a high-speed RDMA network, deliver very large bandwidth at microsecond-level latency. This strong floating-point throughput and bandwidth are a natural fit for AI inference. Building on this, we provide a library of fused Transformer operators for the Kunpeng platform ("KuTACC" for short) that runs Transformer model inference efficiently on Kunpeng processors.
+
+## 2. Running Locally
+
+### 2.1 Installing Dependencies
+
+#### 2.1.1 Installing HPCKit
+KuTACC is built with the BiSheng compiler shipped in HPCKit; see the [official guide](https://www.hikunpeng.com/developer/hpc/hpckit-download) for the HPCKit installation procedure.
+
+Installing KuTACC requires the BiSheng compiler and KUPL from the HPCKit environment; see the [HPCKit introduction](https://www.hikunpeng.com/document/detail/zh/kunpenghpcs/hpckit/devg/KunpengHPCKit_developer_002.html) for the configuration steps.
+
+### 2.2 Building and Installing from Source
+Use build.sh to install KuTACC to any path you choose; both Release and Debug builds of the library are supported.
+```shell
+sh build.sh --install_path=/path/to/your/kutacc-path --build_type=Release/Debug
+```
+
+### 2.3 Configuring Environment Variables
+After setting the KuTACC library and include paths, the corresponding KuTACC interfaces can be called.
+```shell
+export KUTACC_LIB=/path_to_kutacc/install/lib
+export KUTACC_INCLUDE=/path_to_kutacc/install/include
+```
+To use the interfaces provided by KuTACC, reference these two paths in your project, i.e. add the following to the C++ compile and link flags:
+```shell
+export CXXFLAGS="-I${KUTACC_INCLUDE}"
+export LDFLAGS="-L${KUTACC_LIB}"
+export LDLIBS="-lkutacc"
+```
+With the environment variables set, a program can be compiled and linked as follows:
+```shell
+g++/clang++ ${CXXFLAGS} -c xxx.cpp -o xxx.o
+g++/clang++ xxx.o ${LDFLAGS} ${LDLIBS} -o yourapp
+```
+
+If the project is managed with CMake, linking through the target-based commands is recommended:
+```cmake
+# include path
+target_include_directories(yourapp PRIVATE $ENV{KUTACC_INCLUDE})
+
+# library path and library
+target_link_directories(yourapp PRIVATE $ENV{KUTACC_LIB})
+target_link_libraries(yourapp PRIVATE kutacc)
+```
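+
+### 2.4 Minimal Usage Example
+The sketch below is illustrative only: the buffer sizes, fill values, epsilon and the source file name are placeholders. Only the `kutacc_af2_layernorm` signature comes from `kutacc.h`; building it requires an AArch64 compiler with `__bf16` support (for example the BiSheng compiler). Compile and link it with the flags from section 2.3.
+```cpp
+#include <cstdint>
+#include "kutacc.h"
+
+int main()
+{
+    const int64_t size = 128;   // illustrative vector length
+    __bf16 data[128] = {};      // input activations (fill with real bf16 values)
+    __bf16 out[128] = {};       // normalized output
+    float gamma[128];           // per-element scale
+    float beta[128];            // per-element shift
+    for (int64_t i = 0; i < size; i++) {
+        gamma[i] = 1.0f;
+        beta[i] = 0.0f;
+    }
+    // 1e-5f is a typical layernorm epsilon; use the value your model expects.
+    kutacc_af2_layernorm(data, gamma, beta, size, 1e-5f, out);
+    return 0;
+}
+```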
+
+## 3. Supported Applications
+
+| Supported application | Version |
+| :------------ | -----------: |
+| Alphafold2 | v1.0 |
+
+## License
+This code is released under the [OpenSoftware License 1.0](LICENSE), which is derived from the MIT license.
+
+## Contact
+If you have any questions, feel free to open an issue for discussion.
diff --git a/include/kutacc.h b/include/kutacc.h
index cd11cd2042c318db3c77cc41d5427f6438ccfc54..15ae56f6d71ba4ddde0673610e47a331a282fbbc 100644
--- a/include/kutacc.h
+++ b/include/kutacc.h
@@ -120,12 +120,77 @@ kutacc_export void kutacc_af2_transition(kutacc_tensor_h input_act, const kutacc
  * @return Null
  */
 kutacc_export void kutacc_af2_layernorm(__bf16 *data, float *gamma, float *beta, int64_t size, float eps, __bf16 *out);
+
+/**
+ * @brief af2_invariant_point algorithm: invariant_point interface for af2 model
+ * @param [in] q
+ * @param [in] k
+ * @param [in] v
+ * @param [in] q_pts
+ * @param [in] k_pts
+ * @param [in] v_pts
+ * @param [in] b
+ * @param [in] a
+ * @param [in] head_weights
+ * @param [in] weights_head_weights
+ * @param [in] z
+ * @param [in] rigid_rot_mats
+ * @param [in] rigid_trans
+ * @param [in] mask
+ * @param [in] linear_b_w
+ * @param [in] linear_b_b
+ * @param [in] n_res
+ * @param [in] c_z
+ * @param [in] c_hidden
+ * @param [in] no_heads
+ * @param [in] no_qk_points
+ * @param [in] no_v_points
+ * @param [out] o
+ * @param [out] o_pt
+ * @param [out] o_pt_norm
+ * @param [out] o_pair
+ * @return Null
+ */
+kutacc_export void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v, kutacc_tensor_h q_pts, kutacc_tensor_h k_pts, kutacc_tensor_h v_pts,
+    kutacc_tensor_h b, kutacc_tensor_h a, kutacc_tensor_h head_weights, kutacc_tensor_h weights_head_weights, kutacc_tensor_h o, kutacc_tensor_h o_pt, kutacc_tensor_h o_pt_norm, kutacc_tensor_h o_pair,
+    kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_tensor_h linear_b_w, kutacc_tensor_h linear_b_b,
+    int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads, int64_t no_qk_points, int64_t no_v_points);
+
+/**
+ * @brief impl of rot_vec_mul
+ * @param rot_mats shape [..., 3, 3], fp32
+ * @param pts shape [..., 3], bf16 / fp32
+ * @param trans shape [..., 3], bf16 / fp32
+ * @return shape [..., 3], bf16 / fp32
+ */
+kutacc_export void kutacc_af2_rigid_rot_vec_mul(kutacc_tensor_h pts, kutacc_tensor_h rot_mats, kutacc_tensor_h out, kutacc_tensor_h trans);
+
+/**
+ * @brief impl of rot_mat_mul
+ * @param a shape [..., 3, 3], fp32
+ * @param b shape [..., 3, 3], fp32
+ * @return shape [..., 3, 3], fp32
+ */
+kutacc_export void kutacc_af2_rigid_rot_matmul(kutacc_tensor_h a, kutacc_tensor_h b, kutacc_tensor_h out);
+
+/**
+ * @brief af2_linear algorithm: linear interface for af2 model
+ * @param [in] act
+ * @param [in] weight
+ * @param [in] bias_data
+ * @param [in] beta
+ * @param [in/out] result
+ * @return Null
+ */
+kutacc_export void kutacc_af2_linear(kutacc_tensor_h act, kutacc_tensor_h weight, float* bias_data, kutacc_tensor_h result, int64_t beta);
+
 /**
  * @brief gemm prepack for linear layer
  * @param weight shape [n, k]
  * @return result shape [len]
  */
-kutacc_export void kutacc_linear_weight_prepack(const __bf16 *weight, __bf16 *result, int64_t n, int64_t k, int64_t ldb, int64_t num_threads = 0);
+kutacc_export void kutacc_af2_linear_weight_prepack(const __bf16 *weight, __bf16 *result, int64_t n, int64_t k, int64_t ldb, int64_t num_threads = 0);
 
 /**
  * @brief get pack size of A or B for linear layer
@@ -135,7 +200,7 @@ kutacc_export void kutacc_linear_weight_prepack(const __bf16 *weight, __bf16 *re
  * @param m,n,k A shape[m, k] B shape[k, n]
  * @return size: m * k or k * n
  */
-kutacc_export size_t kutacc_gemm_pack_get_size(char identifier, char transa, char transb, int m ,int n, int k);
+kutacc_export size_t kutacc_af2_gemm_pack_get_size(char identifier, char transa, char transb, int m ,int n, int k);
 
 #ifdef __cplusplus
 }
@@ -169,6 +234,7 @@ private:
 public:
     TensorWrapper(void *data_ptr, std::vector<int64_t> sizes, std::vector<int64_t> strides, int64_t dim, DType dtype);
+    TensorWrapper();
     kutacc_tensor_h get_tensor(){return tensor_;}
     ~TensorWrapper();
 };
diff --git a/src/attention/invariant_point.cpp b/src/attention/invariant_point.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..386b1eb0454d339af68c2aca60cb0f2209684123
--- /dev/null
+++ b/src/attention/invariant_point.cpp
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ *
+ * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+*/ +#include "kutacc.h" +#include +#include "tensor/tensor.h" +#include "linear/mm.h" +#include "wrapper/wrapper.h" +#include "utils/collapse.h" +#include "utils/parallel.h" +#include "utils/blas.h" +#include "utils/memory.h" +#include "softmax/softmax.h" +#include +#include +#include +#include "rigid.h" + +namespace kutacc { +void kutacc_af2_invariant_point_kernel(Tensor q, Tensor k, Tensor v, Tensor q_pts, Tensor k_pts, Tensor v_pts, + Tensor b, Tensor a, Tensor head_weights, Tensor weights_head_weights, Tensor o, Tensor o_pt, Tensor o_pt_norm, Tensor o_pair, + Tensor z, Tensor rigid_rot_mats, Tensor rigid_trans, Tensor mask, Tensor linear_b_w, Tensor linear_b_b, + int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads, int64_t no_qk_points, int64_t no_v_points) +{ + addmm(to_bf16(1), + Tensor((__bf16 *)linear_b_w.data_ptr(), {no_heads, c_z}, {linear_b_w.strides()[0], 1}, 2, kBF16), + Tensor((__bf16 *)z.data_ptr(), {c_z, n_res * n_res}, {1, z.strides()[1]}, 2, kBF16), + to_bf16(0), + Tensor((__bf16 *)b.data_ptr(), {no_heads, n_res * n_res}, {b.strides()[0], 1}, 2, kBF16), + {.col_bias = true, .bias = linear_b_b.data_ptr()}); + + for (int64_t hi = 0; hi < no_heads; hi++) { + addmm(to_bf16(1), + Tensor((__bf16 *)q.data_ptr() + hi * q.strides()[1], {n_res, c_hidden}, {q.strides()[0], 1}, 2, kBF16), + Tensor((__bf16 *)k.data_ptr() + hi * k.strides()[1], {c_hidden, n_res}, {1, k.strides()[0]}, 2, kBF16), + to_bf16(0), + Tensor((__bf16 *)a.data_ptr() + hi * a.strides()[0], {n_res, n_res}, {a.strides()[1],1}, 2, kBF16)); + } + + for (int64_t i = 0; i < no_heads; i++) { + float x = ((float *)weights_head_weights.data_ptr())[i]; + x = std::log(std::exp(x) + 1); + x *= std::sqrt(1.0f / (3 * ((float)no_qk_points * 9.0f / 2))); + ((float *)head_weights.data_ptr())[i] = x; + } + auto pg_pt_att = svwhilelt_b32((int64_t)0, no_qk_points * 3); + float scale_a = std::sqrt(1.0f / (3 * (float)c_hidden)); + float scale_b = std::sqrt(1.0f / 3); + parallel_for(0, no_heads * n_res, 1, [&](int64_t start, int64_t end) { + auto pt_att = kutacc::alloc(n_res); + auto softmax_buf = kutacc::alloc(n_res); + kutacc::collapse_for(start, end, no_heads, n_res, [&](int64_t hi, int64_t qi){ + for (int64_t ki = 0; ki < n_res; ki++) { + auto q_values = svld1(pg_pt_att, + (__bf16 *)q_pts.data_ptr() + qi * q_pts.strides()[0] + hi * q_pts.strides()[1]); + auto k_values = svld1(pg_pt_att, + (__bf16 *)k_pts.data_ptr() + ki * k_pts.strides()[0] + hi * k_pts.strides()[1]); + auto values = svsub_x(pg_pt_att, q_values, k_values); + values = svmul_x(pg_pt_att, values, values); + float sum = svaddv(pg_pt_att, values); + sum *= ((float *)head_weights.data_ptr())[hi] * (-0.5f); + pt_att[(unsigned long)ki] = sum; + } + int64_t vl = (int64_t)svcntw(); + svfloat32_t q_mask = svdup_f32(kutacc::to_float(((__bf16 *)mask.data_ptr())[qi])); + svfloat32_t reduce_max = svdup_f32(-INFINITY); + for (int64_t ki = 0; ki < n_res; ki += vl) { + svbool_t pg =svwhilelt_b32(ki, n_res); + auto values = svld1(pg, + (__bf16 *)a.data_ptr() + hi * a.strides()[0] + qi * a.strides()[1] + ki); + { + auto b_values = svld1(pg, + (__bf16 *)b.data_ptr() + hi * b.strides()[0] + qi * b.strides()[1] + ki); + b_values = svmul_x(pg, b_values, scale_b); + values = svmla_x(pg, b_values, values, scale_a); + } + { + auto pt_att_values = svld1(pg, pt_att.get() + ki); + values = svadd_x(pg, values, pt_att_values); + } + { + auto mask_values = svld1(pg, (__bf16 *)mask.data_ptr() + ki); + mask_values = svmad_x(pg, mask_values, q_mask, -1.f); + values = svmla_x(pg, values, 
mask_values, 1e5f); + } + reduce_max = svmax_m(pg, reduce_max, values); + svst1(pg, softmax_buf.get() + ki, values); + } + kutacc::softmax_with_max(softmax_buf.get(), + (__bf16 *)a.data_ptr() + hi * a.strides()[0] + qi * a.strides()[1], n_res, + svmaxv(svptrue_b32(), reduce_max)); + }); + }); + for (int64_t hi = 0; hi < no_heads; hi++) { + addmm(to_bf16(1), + Tensor((__bf16 *)a.data_ptr() + hi * a.strides()[0], {n_res, n_res}, {a.strides()[1], 1}, 2, kBF16), + Tensor((__bf16 *)v.data_ptr() + hi * v.strides()[1], {n_res, c_hidden}, {v.strides()[0], 1}, 2, kBF16), + to_bf16(0), + Tensor((__bf16 *)o.data_ptr() + hi * o.strides()[1], {n_res, c_hidden}, {o.strides()[0], 1}, 2 ,kBF16)); + } + parallel_for(0, n_res, 1, [&](int64_t start, int64_t end) { + for (int64_t qi = start; qi < end; qi++) { + for (int64_t hi = 0; hi < no_heads; hi ++) { + for (int64_t vpi = 0; vpi < no_v_points; vpi++) { + float o_pt_buf[3]; + for (int64_t di = 0; di < 3; di++) { + auto a_data = (__bf16 *)a.data_ptr() + hi * a.strides()[0] + qi * a.strides()[1]; + auto v_pts_data = (__bf16 *)v_pts.data_ptr() + hi * v_pts.strides()[0] + + vpi * v_pts.strides()[1] + di * v_pts.strides()[2]; + int64_t vl = (int64_t)svcntw(); + svfloat32_t reduce_sum = svdup_f32(0); + for (int64_t ki = 0; ki < n_res; ki +=vl) { + svbool_t pg = svwhilelt_b32(ki, n_res); + auto values = svld1(pg, a_data + ki); + auto v_pts_values = svld1(pg, v_pts_data + ki); + values = svmul_x(pg, values, v_pts_values); + reduce_sum = svadd_m(pg, reduce_sum, values); + } + float sum = svaddv(svptrue_b32(), reduce_sum); + o_pt_buf[di] = sum; + } + auto o_pt_data = (__bf16 *)o_pt.data_ptr() + qi * o_pt.strides()[0] + hi * o_pt.strides()[2] + + vpi * o_pt.strides()[3]; + auto o_pt_norm_data = (__bf16 *)o_pt_norm.data_ptr() + qi * o_pt_norm.strides()[0] + + hi * o_pt_norm.strides()[1] + vpi; + rigid_rot_vec_mul_kernel(o_pt_buf, (float *)rigid_rot_mats.data_ptr() + qi * 9, o_pt_buf, + (float *)rigid_trans.data_ptr() + qi * 3, true); + float sqrsum = 0; + for (int64_t i = 0; i < 3; i ++) { + o_pt_data[i * o_pt.strides()[1]] = to_bf16(o_pt_buf[i]); + sqrsum += o_pt_buf[i] * o_pt_buf[i]; + } + *o_pt_norm_data = to_bf16(std::sqrt(sqrsum) + 1e-8f); + } + } + } + }); + parallel_for(0, n_res, 1, [&](int64_t start, int64_t end) { + for (int64_t ri = start; ri < end; ri ++) { + addmm(to_bf16(1), + Tensor((__bf16 *)a.data_ptr() + ri * a.strides()[1], {no_heads, n_res}, {a.strides()[0], 1}, 2, kBF16), + Tensor((__bf16 *)z.data_ptr() + ri * z.strides()[0], {n_res, c_z}, {z.strides()[1], 1}, 2, kBF16), + to_bf16(0), + Tensor((__bf16 *)o_pair.data_ptr() + ri * o_pair.strides()[0], {no_heads, c_z}, {o_pair.strides()[1], 1}, 2, kBF16), + BlasExtendParams{.num_threads = 1}); + } + }); + // out = linear(collect, linear_out_w, linenar_out_b); +} + +} // namespace kutacc + + +void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v, kutacc_tensor_h q_pts, kutacc_tensor_h k_pts, kutacc_tensor_h v_pts, + kutacc_tensor_h b, kutacc_tensor_h a, kutacc_tensor_h head_weights, kutacc_tensor_h weights_head_weights, kutacc_tensor_h o, kutacc_tensor_h o_pt, kutacc_tensor_h o_pt_norm, kutacc_tensor_h o_pair, + kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_tensor_h linear_b_w, kutacc_tensor_h linear_b_b, + int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads, int64_t no_qk_points, int64_t no_v_points) +{ + kutacc::kutacc_af2_invariant_point_kernel(*kutacc::convertKutaccTensor(q), 
*kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(q_pts), *kutacc::convertKutaccTensor(k_pts), *kutacc::convertKutaccTensor(v_pts), + *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(head_weights), *kutacc::convertKutaccTensor(weights_head_weights), *kutacc::convertKutaccTensor(o), *kutacc::convertKutaccTensor(o_pt), *kutacc::convertKutaccTensor(o_pt_norm), *kutacc::convertKutaccTensor(o_pair), + *kutacc::convertKutaccTensor(z), *kutacc::convertKutaccTensor(rigid_rot_mats), *kutacc::convertKutaccTensor(rigid_trans), *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(linear_b_w), *kutacc::convertKutaccTensor(linear_b_b), + n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points); + +} + diff --git a/src/attention/rigid.cpp b/src/attention/rigid.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2669951b2f7050443bae1494e4bd7329488d2bf7 --- /dev/null +++ b/src/attention/rigid.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ +#include "rigid.h" +#include "tensor/tensor.h" +#include "utils/parallel.h" +#include "kutacc.h" +#include "utils/collapse.h" + +namespace kutacc { +template +void rigid_rot_vec_mul(Tensor &pts, Tensor &rot_mats, Tensor &out, kutacc_tensor_h trans) +{ + parallel_for(0, pts.numel() / 3, 1024, [&](int64_t start, int64_t end) { + if (pts.dim() == 2) { + collapse_for(start, end, pts.sizes()[0], [&](int64_t i0) { + auto pts_data = (scalar_t *)pts.data_ptr() + pts.strides()[0] * i0; + auto rot_mats_data = + (float *)rot_mats.data_ptr() + rot_mats.strides()[0] * (rot_mats.sizes()[0] != 1 ? i0 : 0); + auto out_data = (scalar_t *)out.data_ptr() + out.strides()[0] * i0; + float* trans_data = nullptr; + if (trans != nullptr) { + auto trans_t = *convertKutaccTensor(trans); + trans_data = (float *)trans_t.data_ptr() + trans_t.strides()[0] * + (trans_t.sizes()[0] != 1 ? i0 : 0); + } + rigid_rot_vec_mul_kernel(pts_data, rot_mats_data, out_data, trans_data); + }); + } else if (pts.dim() == 3) { + collapse_for(start, end, pts.sizes()[0], pts.sizes()[1], [&](int64_t i0, int64_t i1) { + auto pts_data = (scalar_t *)pts.data_ptr() + pts.strides()[0] * i0 + pts.strides()[1] * i1; + auto rot_mats_data = (float *)rot_mats.data_ptr() + + rot_mats.strides()[0] * (rot_mats.sizes()[0] != 1 ? i0 : 0) + + rot_mats.strides()[1] * (rot_mats.sizes()[1] != 1 ? i1 : 0); + auto out_data = (scalar_t *)out.data_ptr() + out.strides()[0] * i0 + + out.strides()[1] * i1; + float* trans_data = nullptr; + if (trans != nullptr) { + auto trans_t = *convertKutaccTensor(trans); + trans_data = (float *)trans_t.data_ptr() + trans_t.strides()[0] * + (trans_t.sizes()[0] != 1 ? i0 : 0) + + trans_t.strides()[1] * (trans_t.sizes()[1] != 1 ? 
i1 : 0); + } + rigid_rot_vec_mul_kernel(pts_data, rot_mats_data, out_data, trans_data); + }); + } else if (pts.dim() == 4) { + collapse_for(start, end, pts.sizes()[0], pts.sizes()[1], pts.sizes()[2], + [&](int64_t i0, int64_t i1, int64_t i2) { + auto pts_data = (scalar_t *)pts.data_ptr() + pts.strides()[0] * i0 + pts.strides()[1] * i1 + + pts.strides()[2] * i2; + auto rot_mats_data = (float *)rot_mats.data_ptr() + + rot_mats.strides()[0] * (rot_mats.sizes()[0] != 1 ? i0 : 0) + + rot_mats.strides()[1] * (rot_mats.sizes()[1] != 1 ? i1 : 0) + + rot_mats.strides()[2] * (rot_mats.sizes()[2] != 1 ? i2 : 0); + auto out_data = (scalar_t *)out.data_ptr() + out.strides()[0] * i0 + + out.strides()[1] * i1 + out.strides()[2] * i2; + float* trans_data = nullptr; + if (trans != nullptr) { + auto trans_t = *convertKutaccTensor(trans); + trans_data = (float *)trans_t.data_ptr() + trans_t.strides()[0] * + (trans_t.sizes()[0] != 1 ? i0 : 0) + + trans_t.strides()[1] * (trans_t.sizes()[1] != 1 ? i1 : 0) + + trans_t.strides()[2] * (trans_t.sizes()[2] != 1 ? i2 : 0); + } + rigid_rot_vec_mul_kernel(pts_data, rot_mats_data, out_data, trans_data); + }); + } else { + KUTACC_CHECK(false, "not implemented"); + } + }); +} + +void rigid_rot_matmul(Tensor &a, Tensor &b, Tensor &out) +{ + parallel_for(0, b.numel() / 9, 1024, [&](int64_t start, int64_t end) { + if (a.dim() == 3) { + collapse_for(start, end, [&](int64_t i0){ + auto a_data = (float *)a.data_ptr() + a.strides()[0] * (a.sizes()[0] != 1 ? i0 : 0); + auto b_data = (float *)b.data_ptr() + b.strides()[0] * i0; + auto out_data = (float *)out.data_ptr() + out.strides()[0] * i0; + rigid_rot_matmul_kernel(a_data, b_data, out_data); + }); + } else if (a.dim() == 4) { + collapse_for(start, end, a.sizes()[0], a.sizes()[1], [&](int64_t i0, int64_t i1) { + auto a_data = (float *)a.data_ptr() + a.strides()[0] * (a.sizes()[0] != 1 ? i0 : 0) + + a.strides()[1] * (a.sizes()[1] != 1 ? i1 : 0); + auto b_data = (float *)b.data_ptr() + b.strides()[0] * i0 + b.strides()[1] * i1; + auto out_data = (float *)out.data_ptr() + out.strides()[0] * i0 + out.strides()[1] * i1; + rigid_rot_matmul_kernel(a_data, b_data, out_data); + }); + } else { + KUTACC_CHECK(false, "not implemented"); + } + }); +} +} + +void kutacc_af2_rigid_rot_vec_mul(kutacc_tensor_h pts, kutacc_tensor_h rot_mats, kutacc_tensor_h out, kutacc_tensor_h trans) +{ + if ((*kutacc::convertKutaccTensor(pts)).dtype() == kutacc::kBF16) { + kutacc::rigid_rot_vec_mul<__bf16>(*kutacc::convertKutaccTensor(pts), *kutacc::convertKutaccTensor(rot_mats), *kutacc::convertKutaccTensor(out), trans); + } +} + +void kutacc_af2_rigid_rot_matmul(kutacc_tensor_h a, kutacc_tensor_h b, kutacc_tensor_h out) +{ + kutacc::rigid_rot_matmul(*kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(out)); +} \ No newline at end of file diff --git a/src/attention/rigid.h b/src/attention/rigid.h new file mode 100644 index 0000000000000000000000000000000000000000..4fc89681c87e71314c70b2dfb1ab07502a08948c --- /dev/null +++ b/src/attention/rigid.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ +#ifndef KUTACC_RIGID_H +#define KUTACC_RIGID_H + +#include "utils/check.h" +#include "wrapper/wrapper.h" + +namespace kutacc{ +template +inline void rigid_rot_vec_mul_kernel(scalar_t *pt, const float *rot_mat, scalar_t *out, const float *trans = nullptr, + bool invert = false) +{ + float output[3]; + float input[3]; + { + auto pg = svwhilelt_b32(0, 3); + auto pt_values = svld1(pg, pt); + svst1(pg, input, pt_values); + } + if (!invert) { + for (int64_t i = 0; i < 3; i++) { + float sum = 0; + for (int64_t j = 0; j < 3; j++) { + sum += rot_mat[i * 3 + j] * input[j]; + } + output[i] = sum + (trans != nullptr ? trans[i] : 0); + } + } else { + for (int64_t i = 0; i < 3; i++) { + float sum = 0; + for (int64_t j = 0; j < 3; j++) { + sum += rot_mat[j * 3 + i] * (input[j] - (trans != nullptr ? trans[j] : 0)); + } + output[i] = sum; + } + } + { + auto pg = svwhilelt_b32(0, 3); + auto pt_values = svld1(pg, output); + svst1(pg, out, pt_values); + } +} + +inline void rigid_rot_matmul_kernel(const float *a, const float *b, float *out) +{ + for (int64_t i = 0; i < 3; i++) { + for (int64_t j = 0; j < 3; j++) { + float sum = 0; + for (int64_t k = 0; k < 3; k++) { + sum += a[i * 3 + k] * b[k * 3 + j]; + } + out[i * 3 + j] = sum; + } + } +} +} + +#endif \ No newline at end of file diff --git a/src/linear/linear.cpp b/src/linear/linear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c30e45492e7f55142b2b136cc8b6804d8be0259c --- /dev/null +++ b/src/linear/linear.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ +#include +#include "kutacc.h" +#include "tensor/tensor.h" +#include "linear/mm.h" +#include "utils/check.h" +#include "utils/collapse.h" + +namespace kutacc{ +void linear_kernel(Tensor &act, Tensor &weight, float* bias_data, Tensor& result, int64_t beta) +{ + int64_t m = act.numel() / act.sizes().back(); + int64_t n = weight.numel() / weight.sizes().back(); + int64_t k = act.sizes().back(); + bool bias_flag = bias_data != nullptr ? 
true : false; + if (bias_flag) { + addmm(to_bf16(1), + Tensor((__bf16 *)act.data_ptr(), {m, k}, {act.strides()[(unsigned long)act.dim() - 2], 1}, 2, kBF16), + Tensor((__bf16 *)weight.data_ptr(), {k, n}, {1, weight.strides()[(unsigned long)weight.dim() - 2]}, 2, kBF16), + to_bf16((float)beta), + Tensor((__bf16 *)result.data_ptr(), {m, n}, {n, 1}, 2, kBF16), + BlasExtendParams{.row_bias = bias_flag, .bias = bias_data}); + } else { + addmm(to_bf16(1), + Tensor((__bf16 *)act.data_ptr(), {m, k}, {act.strides()[(unsigned long)act.dim() - 2], 1}, 2, kBF16), + Tensor((__bf16 *)weight.data_ptr(), {k, n}, {1, weight.strides()[(unsigned long)weight.dim() - 2]}, 2, kBF16), + to_bf16((float)beta), + Tensor((__bf16 *)result.data_ptr(), {m, n}, {n, 1}, 2, kBF16), + BlasExtendParams{.row_bias = bias_flag, .bias = bias_data}); + } +} +} + +void kutacc_af2_linear(kutacc_tensor_h act, kutacc_tensor_h weight, float* bias_data, kutacc_tensor_h result, int64_t beta) +{ + kutacc::linear_kernel(*kutacc::convertKutaccTensor(act), *kutacc::convertKutaccTensor(weight), bias_data, + *kutacc::convertKutaccTensor(result), beta); +} \ No newline at end of file diff --git a/src/linear/mm.cpp b/src/linear/mm.cpp index 1aebcb4f2a1ee423e93597a882de68ed30e062c5..7d2b46f009b7afe129b8942d250a675397c8bf56 100644 --- a/src/linear/mm.cpp +++ b/src/linear/mm.cpp @@ -12,9 +12,9 @@ * SOFTWARE. */ #include "mm.h" -#include "../utils/blas.h" -#include "../utils/check.h" -#include "../utils/parallel.h" +#include "utils/blas.h" +#include "utils/check.h" +#include "utils/parallel.h" #define DIM_2 2 @@ -93,7 +93,7 @@ namespace kutacc { * @brief weight shape [n, k], dtype = __bf16 * @brief shape [len], dtype = __bf16 */ -kutacc_export void kutacc_linear_weight_prepack(const __bf16 *weight, __bf16 *result, int64_t n, int64_t k, int64_t ldb, int64_t num_threads) +kutacc_export void kutacc_af2_linear_weight_prepack(const __bf16 *weight, __bf16 *result, int64_t n, int64_t k, int64_t ldb, int64_t num_threads) { if (num_threads != 1) { int64_t nblocks = (n + 32 - 1) / 32; diff --git a/src/linear/mm.h b/src/linear/mm.h index 686e4956707984a18464a432c2b20e64ae44a4b8..bab961426b9928116de19b67b2f7e34bf3ce84af 100644 --- a/src/linear/mm.h +++ b/src/linear/mm.h @@ -15,7 +15,7 @@ #define KUTACC_MM_H #include "kutacc.h" -#include "../tensor/tensor.h" +#include "tensor/tensor.h" #include #include #include @@ -23,7 +23,7 @@ namespace kutacc { void addmm(__bf16 alpha, const kutacc::Tensor &a, const kutacc::Tensor &b, __bf16 beta, const kutacc::Tensor &c, - kutacc::BlasExtendParams param); + kutacc::BlasExtendParams param = {}); } #endif \ No newline at end of file diff --git a/src/normalization/layernorm_kernel.h b/src/normalization/layernorm_kernel.h index 4bf3d2931735a8b225108b01ee53fa424fc0b2e3..74400d496b587906081a152feb32b5c019e8ed1a 100644 --- a/src/normalization/layernorm_kernel.h +++ b/src/normalization/layernorm_kernel.h @@ -14,7 +14,7 @@ #ifndef KUTACC_LAYERNORM_KERNEL_H #define KUTACC_LAYERNORM_KERNEL_H -#include "../wrapper/wrapper.h" +#include "wrapper/wrapper.h" #include #include diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index c9cf467beb2ea1248132023308a418e20a826ecb..4064e6a7e25a4156957b2766c823e11c49056e54 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -14,7 +14,7 @@ #include "tensor.h" #include #include "kutacc.h" -#include "../utils/check.h" +#include "utils/check.h" #include namespace kutacc{ @@ -25,6 +25,11 @@ TensorWrapper::TensorWrapper(void *data_ptr, std::vector sizes, tensor_ = t; } 
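+// Default constructor: leaves the wrapper empty; tensor_ stays nullptr until a tensor is attached.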
+TensorWrapper::TensorWrapper() +{ + tensor_ = nullptr; +} + void* Tensor::data_ptr() const { return data.data_ptr; @@ -50,6 +55,18 @@ int64_t Tensor::dim() const return data.dim; } +int64_t Tensor::numel() const +{ + if (data.dim == 0) { + return 0; + } + int64_t res = 1; + for (int64_t i = 0; i < data.dim; i++) { + res *= data.sizes[(uint64_t)i]; + } + return res; +} + // void Tensor::SetTensorStrides(std::vector &strides) // { // 预留接口优化addmm装包解包过程 // data.strides = strides; diff --git a/src/tensor/tensor.h b/src/tensor/tensor.h index 12ac5a3d12e4ed5314996b959a2ec6a635e6ef7c..cedaac24b2ca2403123783540a714e5c3de753e6 100644 --- a/src/tensor/tensor.h +++ b/src/tensor/tensor.h @@ -16,7 +16,7 @@ #include #include "kutacc.h" -#include "../utils/check.h" +#include "utils/check.h" namespace kutacc { struct SimpleTensor { @@ -44,6 +44,7 @@ struct Tensor { std::vector strides() const; DType dtype() const; int64_t dim() const; + int64_t numel() const; }; Tensor* convertKutaccTensor(void* tensor_); diff --git a/src/utils/blas.cpp b/src/utils/blas.cpp index dbaf0f675b3b2f38ab2dacac21e70b8fa70b56bc..befa4ea7bc25d681cb0e35db97c30ee731739acb 100644 --- a/src/utils/blas.cpp +++ b/src/utils/blas.cpp @@ -91,7 +91,7 @@ typedef struct BlasExtendBiasExtra_ { DEF_GEMM(kutacc_core_b, __bf16, __bf16); -kutacc_export size_t kutacc_gemm_pack_get_size(char identifier, char transa, char transb, int m ,int n, int k) +kutacc_export size_t kutacc_af2_gemm_pack_get_size(char identifier, char transa, char transb, int m ,int n, int k) { (void)transa; (void)transb; diff --git a/src/utils/collapse.h b/src/utils/collapse.h index 772559a7cd6b7440f012a7d2ee29ad90c9513ae0..78968f43f99abf43c3edeec2b89201a79925da24 100644 --- a/src/utils/collapse.h +++ b/src/utils/collapse.h @@ -67,6 +67,15 @@ inline void collapse_for(int64_t start, int64_t end, const F &f) } } +template +inline void collapse_for(int64_t start, int64_t end, int64_t n0, const F &f) +{ + (void)n0; + for (int64_t i = start; i < end; i ++) { + f(i); + } +} + template inline void collapse_for(int64_t start, int64_t end, int64_t n0, int64_t n1, const F &f) {