From 957fe88edac279e4d2180f514949a64442b40cd1 Mon Sep 17 00:00:00 2001 From: tedeasonwang Date: Fri, 19 Sep 2025 16:49:46 +0800 Subject: [PATCH 1/3] add invariant_point linear rigid --- csrc/tpp/alphafold/bind.h | 22 +++- csrc/tpp/alphafold/gating_attention.cpp | 2 +- csrc/tpp/alphafold/gating_attention.h | 4 +- csrc/tpp/alphafold/invariant_point.cpp | 153 ++++++++++++++++++++++++ csrc/tpp/alphafold/invariant_point.h | 60 ++++++++++ csrc/tpp/alphafold/rigid.cpp | 83 +++++++++++++ csrc/tpp/alphafold/rigid.h | 25 ++++ csrc/tpp/alphafold/transition.cpp | 2 +- csrc/tpp/alphafold/transition.h | 4 +- csrc/utils/TensorWrapper.h | 4 +- csrc/utils/bf16.h | 4 +- csrc/utils/check.h | 4 +- csrc/utils/layernorm.h | 6 +- csrc/utils/linear.h | 70 +++++++++++ csrc/utils/memory.h | 4 +- csrc/utils/parallel.h | 4 +- kpex/tpp/alphafold/alphafold.py | 78 ++++++++++++ 17 files changed, 509 insertions(+), 20 deletions(-) create mode 100644 csrc/tpp/alphafold/invariant_point.cpp create mode 100644 csrc/tpp/alphafold/invariant_point.h create mode 100644 csrc/tpp/alphafold/rigid.cpp create mode 100644 csrc/tpp/alphafold/rigid.h create mode 100644 csrc/utils/linear.h diff --git a/csrc/tpp/alphafold/bind.h b/csrc/tpp/alphafold/bind.h index 6477947..23a87d8 100644 --- a/csrc/tpp/alphafold/bind.h +++ b/csrc/tpp/alphafold/bind.h @@ -18,6 +18,8 @@ #include "gating_attention.h" #include "transition.h" +#include "invariant_point.h" +#include "rigid.h" // #include "utils/layernorm.h" namespace alphafold { @@ -38,7 +40,25 @@ inline void bind(pybind11::module &m) py::arg("linear2_b")); submodule.def("transition", &transition, py::arg("act"), py::arg("weights")); - // submodule.def("layernorm", &layernorm, py::arg("act"), py::arg("weight_"), py::arg("bias_")); + submodule.def("rigid_rot_vec_mul", &rigid_rot_vec_mul, py::arg("pts"), py::arg("rot_mats"), + py::arg("trans") = std::nullopt); + submodule.def("rigid_rot_matmul", &rigid_rot_matmul, py::arg("a"), py::arg("b")); + + py::class_(submodule, "InvariantPointAttentionWeight") + .def(py::init(), + py::arg("c_s"), py::arg("c_z"), py::arg("c_hidden"), py::arg("no_heads"), py::arg("no_qk_points"), + py::arg("no_v_points"), py::arg("is_multimer"), py::arg("linear_q_w"), py::arg("linear_q_b"), py::arg("linear_kv_w"), py::arg("linear_kv_b"), + py::arg("linear_q_points_w"), py::arg("linear_q_points_b"), py::arg("linear_kv_points_w"), py::arg("linear_kv_points_b"), py::arg("linear_b_w"), py::arg("linear_b_b"), py::arg("head_weights"), + py::arg("linear_out_w"), py::arg("linear_out_b")); + submodule.def("invariant_point_attention", &invariant_point_attention, py::arg("s"), py::arg("z"), py::arg("rigid_trans"), py::arg("rigid_rot_mats"), py::arg("mask"), + py::arg("weights")); + + + submodule.def("rigid_rot_vec_mul", &rigid_rot_vec_mul, py::arg("pts"), py::arg("rot_mats"), + py::arg("trans") = std::nullopt); + + submodule.def("rigid_rot_matmul", &rigid_rot_matmul, py::arg("a"), py::arg("b")); } } diff --git a/csrc/tpp/alphafold/gating_attention.cpp b/csrc/tpp/alphafold/gating_attention.cpp index 28c05e0..3069f71 100644 --- a/csrc/tpp/alphafold/gating_attention.cpp +++ b/csrc/tpp/alphafold/gating_attention.cpp @@ -133,7 +133,7 @@ at::Tensor gating_attention(at::Tensor &q_data, at::Tensor &m_data, at::Tensor & auto out_tw = convert_to_tensor_wrapper(out); - kutacc_gating_attention(input_tw.get_tensor(), q_tw.get_tensor(), k_tw.get_tensor(), v_tw.get_tensor(), + kutacc_af2_gating_attention(input_tw.get_tensor(), q_tw.get_tensor(), k_tw.get_tensor(), v_tw.get_tensor(), 
gate_tw.get_tensor(), weighted_avg_tw.get_tensor(), batch, seq_len, m_data_tw.get_tensor(), bias_tw.get_tensor(), nonbatched_bias_tw.get_tensor(), query_w_tw.get_tensor(), key_w_tw.get_tensor(), value_w_tw.get_tensor(), gating_w_tw.get_tensor(), gating_b_tw.get_tensor(), output_w_tw.get_tensor(), diff --git a/csrc/tpp/alphafold/gating_attention.h b/csrc/tpp/alphafold/gating_attention.h index b934344..dfb2dcf 100644 --- a/csrc/tpp/alphafold/gating_attention.h +++ b/csrc/tpp/alphafold/gating_attention.h @@ -11,8 +11,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef KUTACC_TPP_ALPHAFOLD_GATING_ATTENTION_H -#define KUTACC_TPP_ALPHAFOLD_GATING_ATTENTION_H +#ifndef KPEX_TPP_ALPHAFOLD_GATING_ATTENTION_H +#define KPEX_TPP_ALPHAFOLD_GATING_ATTENTION_H #include "utils/check.h" #include diff --git a/csrc/tpp/alphafold/invariant_point.cpp b/csrc/tpp/alphafold/invariant_point.cpp new file mode 100644 index 0000000..28f798b --- /dev/null +++ b/csrc/tpp/alphafold/invariant_point.cpp @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ +#include "invariant_point.h" +#include "rigid.h" + +#include "kutacc.h" +#include "utils/linear.h" +#include +#include +#include "utils/memory.h" + +namespace alphafold { + + +at::Tensor invariant_point_attention(at::Tensor &s, at::Tensor &z, at::Tensor &rigid_trans, at::Tensor &rigid_rot_mats, + at::Tensor &mask, const InvariantPointAttentionWeight &weights) +{ + at::Tensor out = at::empty(s.sizes(), s.options()); + int64_t n_res = s.sizes()[0]; + int64_t c_s = weights.c_s; + int64_t c_z = weights.c_z; + int64_t c_hidden = weights.c_hidden; + int64_t no_heads = weights.no_heads; + int64_t no_qk_points = weights.no_qk_points; + int64_t no_v_points = weights.no_v_points; + bool is_multimer = weights.is_multimer; + + KPEX_CHECK(s.scalar_type() == c10::kBFloat16, s.scalar_type()); + KPEX_CHECK(z.scalar_type() == c10::kBFloat16, z.scalar_type()); + KPEX_CHECK(rigid_trans.scalar_type() == c10::kFloat, rigid_trans.scalar_type()); + KPEX_CHECK(rigid_rot_mats.scalar_type() == c10::kFloat, rigid_rot_mats.scalar_type()); + KPEX_CHECK(mask.scalar_type() == c10::kBFloat16, mask.scalar_type()); + + KPEX_CHECK_TENSOR_SHAPE(s, n_res, c_s); + KPEX_CHECK_TENSOR_SHAPE(z, n_res, n_res, c_z); + KPEX_CHECK_TENSOR_SHAPE(rigid_trans, n_res, 3); + KPEX_CHECK_TENSOR_SHAPE(rigid_rot_mats, n_res, 3, 3); + KPEX_CHECK_TENSOR_SHAPE(mask, n_res); + KPEX_CHECK(!is_multimer, "not implemented"); + + rigid_trans = rigid_trans.view({n_res, 1, 1, 3}); + rigid_rot_mats = rigid_rot_mats.view({n_res, 1, 1, 3, 3}); + + auto q = linear(s, weights.linear_q_w, weights.linear_q_b); + auto k = linear(s, weights.linear_k_w, weights.linear_k_b); + auto v = linear(s, weights.linear_v_w, weights.linear_v_b); + + auto q_pts = linear(s, weights.linear_q_points_w, weights.linear_q_points_b); + q_pts = 
rigid_rot_vec_mul(q_pts, rigid_rot_mats, rigid_trans); + + auto k_pts = linear(s, weights.linear_k_points_w, weights.linear_k_points_b); + k_pts = rigid_rot_vec_mul(k_pts, rigid_rot_mats, rigid_trans); + auto v_pts = linear(s, weights.linear_v_points_w, weights.linear_v_points_b); + v_pts = rigid_rot_vec_mul(v_pts, rigid_rot_mats, rigid_trans); + v_pts = v_pts.permute({1, 2, 3, 0}).contiguous(); + + auto b = at::empty({no_heads, n_res, n_res}, s.options()); + auto a = at::empty({no_heads, n_res, n_res}, q.options()); + auto head_weights = at::empty(weights.head_weights.sizes(), weights.head_weights.options()); + auto collect = at::empty({n_res, no_heads * (c_hidden + no_v_points * 4 + c_z)}, s.options()); + auto o = collect.narrow(1, 0, no_heads * c_hidden).view({n_res, no_heads, c_hidden}); + auto o_pt = collect.narrow(1, no_heads * c_hidden, no_heads * no_v_points * 3).view({n_res, 3, no_heads, no_v_points}); + auto o_pt_norm = collect.narrow(1, no_heads * (c_hidden + no_v_points * 3), no_heads * no_v_points).view({n_res, no_heads, no_v_points}); + auto o_pair = collect.narrow(1, no_heads * (c_hidden + no_v_points * 4), no_heads * c_z).view({n_res, no_heads, c_z}); + + auto q_tw = convert_to_tensor_wrapper(q); + auto k_tw = convert_to_tensor_wrapper(k); + auto v_tw = convert_to_tensor_wrapper(v); + auto q_pts_tw = convert_to_tensor_wrapper(q_pts); + auto k_pts_tw = convert_to_tensor_wrapper(k_pts); + auto v_pts_tw = convert_to_tensor_wrapper(v_pts); + auto b_tw = convert_to_tensor_wrapper(b); + auto a_tw = convert_to_tensor_wrapper(a); + auto head_weights_tw = convert_to_tensor_wrapper(head_weights); + auto weights_head_weights_tw = convert_to_tensor_wrapper(weights.head_weights); + // auto collect_tw = convert_to_tensor_wrapper(collect); + auto o_tw = convert_to_tensor_wrapper(o); + auto o_pt_tw = convert_to_tensor_wrapper(o_pt); + auto o_pt_norm_tw = convert_to_tensor_wrapper(o_pt_norm); + auto o_pair_tw = convert_to_tensor_wrapper(o_pair); + auto z_tw = convert_to_tensor_wrapper(z); + auto rigid_rot_mats_tw = convert_to_tensor_wrapper(rigid_rot_mats); + auto rigid_trans_tw = convert_to_tensor_wrapper(rigid_trans); + auto mask_tw = convert_to_tensor_wrapper(mask); + auto linear_b_w_tw = convert_to_tensor_wrapper(weights.linear_b_w); + auto linear_b_b_tw = convert_to_tensor_wrapper(weights.linear_b_b); + + kutacc_af2_invariant_point(q_tw.get_tensor(), k_tw.get_tensor(), v_tw.get_tensor(), q_pts_tw.get_tensor(), k_pts_tw.get_tensor(), v_pts_tw.get_tensor(), + b_tw.get_tensor(), a_tw.get_tensor(), head_weights_tw.get_tensor(), weights_head_weights_tw.get_tensor(), o_tw.get_tensor(), o_pt_tw.get_tensor(), o_pt_norm_tw.get_tensor(), o_pair_tw.get_tensor(), + z_tw.get_tensor(), rigid_rot_mats_tw.get_tensor(), rigid_trans_tw.get_tensor(), mask_tw.get_tensor(), linear_b_w_tw.get_tensor(), linear_b_b_tw.get_tensor(), + n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points); + + out = linear(collect, weights.linear_out_w, weights.linear_out_b); + + return out; +} + + +InvariantPointAttentionWeight::InvariantPointAttentionWeight(int64_t c_s, int64_t c_z, int64_t c_hidden, + int64_t no_heads, int64_t no_qk_points, int64_t no_v_points, bool is_multimer, at::Tensor &linear_q_w, + at::Tensor &linear_q_b, at::Tensor &linear_kv_w, at::Tensor linear_kv_b, at::Tensor &linear_q_points_w, + at::Tensor &linear_q_points_b, at::Tensor &linear_kv_points_w, at::Tensor &linear_kv_points_b, + at::Tensor &linear_b_w, at::Tensor &linear_b_b, at::Tensor &head_weights, at::Tensor &linear_out_w, + at::Tensor 
&linear_out_b) + : c_s(c_s), c_z(c_z), c_hidden(c_hidden), no_heads(no_heads), no_qk_points(no_qk_points), no_v_points(no_v_points), + is_multimer(is_multimer) +{ + // TO DO CHECK + linear_q_w = linear_q_w.view({no_heads, c_hidden, c_s}); + linear_q_b = linear_q_b.view({no_heads, c_hidden}); + linear_kv_w = linear_kv_w.view({no_heads, 2 * c_hidden, c_s}); + linear_kv_b = linear_kv_b.view({no_heads, 2 * c_hidden}); + linear_q_points_w = linear_q_points_w.view({3, no_heads, no_qk_points, c_s}).permute({1, 2, 0, 3}); + linear_q_points_b = linear_q_points_b.view({3, no_heads, no_qk_points}).permute({1, 2, 0}); + linear_kv_points_w = linear_kv_points_w.view({3, no_heads, (no_qk_points + no_v_points), c_s}).permute({1, 2, 0, 3}); + linear_kv_points_b = linear_kv_points_b.view({3, no_heads, (no_qk_points + no_v_points)}).permute({1, 2, 0}); + + auto float_opt = linear_q_w.options().device(kpex::device()).dtype(c10::kFloat); + auto bf16_opt = linear_q_w.options().device(kpex::device()).dtype(c10::kBFloat16); + this->linear_q_w = linear_q_w.to(bf16_opt).contiguous(); + this->linear_q_b = linear_q_b.to(float_opt).contiguous(); + this->linear_k_w = linear_kv_w.narrow(1, 0, c_hidden).to(bf16_opt).contiguous(); + this->linear_v_w = linear_kv_w.narrow(1, c_hidden, c_hidden).to(bf16_opt).contiguous(); + this->linear_k_b = linear_kv_b.narrow(1, 0, c_hidden).to(float_opt).contiguous(); + this->linear_v_b = linear_kv_b.narrow(1, c_hidden, c_hidden).to(float_opt).contiguous(); + this->linear_q_points_w = linear_q_points_w.to(bf16_opt).contiguous(); + this->linear_q_points_b = linear_q_points_b.to(float_opt).contiguous(); + this->linear_k_points_w = linear_kv_points_w.narrow(1, 0, no_qk_points).to(bf16_opt).contiguous(); + this->linear_k_points_b = linear_kv_points_b.narrow(1, 0, no_qk_points).to(float_opt).contiguous(); + this->linear_v_points_w = linear_kv_points_w.narrow(1, no_qk_points, no_v_points).to(bf16_opt).contiguous(); + this->linear_v_points_b = linear_kv_points_b.narrow(1, no_qk_points, no_v_points).to(float_opt).contiguous(); + this->linear_b_w = linear_b_w.to(bf16_opt).contiguous(); + this->linear_b_b = linear_b_b.to(float_opt).contiguous(); + this->head_weights = head_weights.to(float_opt).contiguous(); + this->linear_out_w = linear_out_w.to(bf16_opt).contiguous(); + this->linear_out_b = linear_out_b.to(float_opt).contiguous(); + +} + + + +} \ No newline at end of file diff --git a/csrc/tpp/alphafold/invariant_point.h b/csrc/tpp/alphafold/invariant_point.h new file mode 100644 index 0000000..69ea56b --- /dev/null +++ b/csrc/tpp/alphafold/invariant_point.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+*/ +#ifndef KPEX_TPP_ALPHAFOLD_INVARIANT_POINT_H +#define KPEX_TPP_ALPHAFOLD_INVARIANT_POINT_H + +#include "utils/check.h" +#include + +namespace alphafold { +struct InvariantPointAttentionWeight { + int64_t c_s; + int64_t c_z; + int64_t c_hidden; + int64_t no_heads; + int64_t no_qk_points; + int64_t no_v_points; + bool is_multimer; + + at::Tensor linear_q_w; + at::Tensor linear_q_b; + at::Tensor linear_k_w; + at::Tensor linear_k_b; + at::Tensor linear_v_w; + at::Tensor linear_v_b; + at::Tensor linear_q_points_w; + at::Tensor linear_q_points_b; + at::Tensor linear_k_points_w; + at::Tensor linear_k_points_b; + at::Tensor linear_v_points_w; + at::Tensor linear_v_points_b; + at::Tensor linear_b_w; + at::Tensor linear_b_b; + at::Tensor head_weights; + at::Tensor linear_out_w; + at::Tensor linear_out_b; + + InvariantPointAttentionWeight(int64_t c_s, int64_t c_z, int64_t c_hidden, + int64_t no_heads, int64_t no_qk_points, int64_t no_v_points, bool is_multimer, at::Tensor &linear_q_w, + at::Tensor &linear_q_b, at::Tensor &linear_kv_w, at::Tensor linear_kv_b, at::Tensor &linear_q_points_w, + at::Tensor &linear_q_points_b, at::Tensor &linear_kv_points_w, at::Tensor &linear_kv_points_b, + at::Tensor &linear_b_w, at::Tensor &linear_b_b, at::Tensor &head_weights, at::Tensor &linear_out_w, + at::Tensor &linear_out_b); +}; + + at::Tensor invariant_point_attention(at::Tensor &s, at::Tensor &z, at::Tensor &rigid_trans, at::Tensor &rigid_rot_mats, + at::Tensor &mask, const InvariantPointAttentionWeight &weights); +} + +#endif \ No newline at end of file diff --git a/csrc/tpp/alphafold/rigid.cpp b/csrc/tpp/alphafold/rigid.cpp new file mode 100644 index 0000000..4147547 --- /dev/null +++ b/csrc/tpp/alphafold/rigid.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+*/ +#include "rigid.h" +#include "utils/TensorWrapper.h" + +#include +#include +#include + +#include "kutacc.h" + +namespace alphafold { + at::Tensor rigid_rot_vec_mul(at::Tensor &pts, at::Tensor &rot_mats, std::optional trans) + { + int64_t dim = pts.dim() - 1; + KPEX_CHECK(pts.strides()[dim] == 1, pts.strides()); + KPEX_CHECK(rot_mats.dim() == dim + 2, rot_mats.dim(), dim); + KPEX_CHECK(rot_mats.strides()[dim] == 3 && rot_mats.strides()[dim + 1] == 1, rot_mats.strides()); + kutacc_tensor_h trans_ = nullptr; + if (trans.has_value()) { + KPEX_CHECK(trans->dim() == dim + 1, trans->dim(), dim); + KPEX_CHECK(trans->strides()[dim] == 1, trans->strides()); + auto trans_tensor = trans.value(); + + auto trans_tw = convert_to_tensor_wrapper(trans_tensor); + trans_ = convert_to_tensor_wrapper(trans.value()).get_tensor(); + auto out = at::empty(pts.sizes(), pts.options()); + + auto pts_tw = convert_to_tensor_wrapper(pts); + auto rot_mats_tw = convert_to_tensor_wrapper(rot_mats); + auto out_tw = convert_to_tensor_wrapper(out); + kutacc_af2_rigid_rot_vec_mul(pts_tw.get_tensor(), rot_mats_tw.get_tensor(), out_tw.get_tensor(), trans_tw.get_tensor()); + return out; + } else { + auto out = at::empty(pts.sizes(), pts.options()); + + auto pts_tw = convert_to_tensor_wrapper(pts); + auto rot_mats_tw = convert_to_tensor_wrapper(rot_mats); + auto out_tw = convert_to_tensor_wrapper(out); + kutacc_af2_rigid_rot_vec_mul(pts_tw.get_tensor(), rot_mats_tw.get_tensor(), out_tw.get_tensor(), nullptr); + return out; + } + + auto out = at::empty(pts.sizes(), pts.options()); + + auto pts_tw = convert_to_tensor_wrapper(pts); + auto rot_mats_tw = convert_to_tensor_wrapper(rot_mats); + auto out_tw = convert_to_tensor_wrapper(out); + kutacc_af2_rigid_rot_vec_mul(pts_tw.get_tensor(), rot_mats_tw.get_tensor(), out_tw.get_tensor(), trans_); + return out; + } + + at::Tensor rigid_rot_matmul(at::Tensor &a, at::Tensor &b) + { + int64_t dim = a.dim() - 2; + KPEX_CHECK(b.dim() == dim + 2, b.dim(), dim); + KPEX_CHECK(a.strides()[dim] == 3 && a.strides()[dim + 1] == 1, a.strides()); + KPEX_CHECK(b.strides()[dim] == 3 && b.strides()[dim + 1] == 1, b.strides()); + KPEX_CHECK(a.scalar_type() == c10::kFloat, a.scalar_type()); + KPEX_CHECK(b.scalar_type() == c10::kFloat, b.scalar_type()); + + auto out = at::empty(b.sizes(), b.options()); + + auto a_tw = convert_to_tensor_wrapper(a); + auto b_tw = convert_to_tensor_wrapper(b); + auto out_tw = convert_to_tensor_wrapper(out); + kutacc_af2_rigid_rot_matmul(a_tw.get_tensor(), b_tw.get_tensor(), out_tw.get_tensor()); + return out; + } + +} + diff --git a/csrc/tpp/alphafold/rigid.h b/csrc/tpp/alphafold/rigid.h new file mode 100644 index 0000000..6dc8b31 --- /dev/null +++ b/csrc/tpp/alphafold/rigid.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+*/ +#pragma once + +#include "utils/check.h" +#include "kutacc.h" + +#include +#include + +namespace alphafold { + at::Tensor rigid_rot_vec_mul(at::Tensor &pts, at::Tensor &rot_mats, std::optional trans); + at::Tensor rigid_rot_matmul(at::Tensor &a, at::Tensor &b); +} diff --git a/csrc/tpp/alphafold/transition.cpp b/csrc/tpp/alphafold/transition.cpp index b413622..639ccbb 100644 --- a/csrc/tpp/alphafold/transition.cpp +++ b/csrc/tpp/alphafold/transition.cpp @@ -43,7 +43,7 @@ at::Tensor transition(at::Tensor &act, const TransitionWeight &weights) auto intermediate_act_tw = convert_to_tensor_wrapper(intermediate_act); auto out_tw = convert_to_tensor_wrapper(out); - kutacc_transition(input_act_tw.get_tensor(), linear1_w_tw.get_tensor(), linear1_b_tw.get_tensor(), linear2_w_tw.get_tensor(), linear2_b_tw.get_tensor(), + kutacc_af2_transition(input_act_tw.get_tensor(), linear1_w_tw.get_tensor(), linear1_b_tw.get_tensor(), linear2_w_tw.get_tensor(), linear2_b_tw.get_tensor(), intermediate_act_tw.get_tensor(), out_tw.get_tensor(), batch, n_res, c_o, c_i); return out; } diff --git a/csrc/tpp/alphafold/transition.h b/csrc/tpp/alphafold/transition.h index 2fd09aa..8fef415 100644 --- a/csrc/tpp/alphafold/transition.h +++ b/csrc/tpp/alphafold/transition.h @@ -11,8 +11,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef KUTACC_TPP_ALPHAFOLD_TRANSITION_H -#define KUTACC_TPP_ALPHAFOLD_TRANSITION_H +#ifndef KPEX_TPP_ALPHAFOLD_TRANSITION_H +#define KPEX_TPP_ALPHAFOLD_TRANSITION_H #include "utils/check.h" #include diff --git a/csrc/utils/TensorWrapper.h b/csrc/utils/TensorWrapper.h index 6126833..d5b5fcf 100644 --- a/csrc/utils/TensorWrapper.h +++ b/csrc/utils/TensorWrapper.h @@ -55,10 +55,10 @@ inline at::Tensor linear_weight_prepack(const at::Tensor &weight, int64_t num_th int64_t n = weight.sizes()[0]; int64_t k = weight.sizes()[1]; int64_t ldb = weight.strides()[0]; - int64_t pack_size = kutacc_gemm_pack_get_size('A', 'T', 'N', n, 0, k); + int64_t pack_size = kutacc_af2_gemm_pack_get_size('A', 'T', 'N', n, 0, k); at::Tensor result = weight.new_empty({pack_size}); - kutacc_linear_weight_prepack((__bf16 *)weight.data_ptr(), (__bf16 *)result.data_ptr(), n, k, ldb); + kutacc_af2_linear_weight_prepack((__bf16 *)weight.data_ptr(), (__bf16 *)result.data_ptr(), n, k, ldb); return result; } diff --git a/csrc/utils/bf16.h b/csrc/utils/bf16.h index 2f0d0fc..1945369 100644 --- a/csrc/utils/bf16.h +++ b/csrc/utils/bf16.h @@ -11,8 +11,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef KUTACC_UTILS_BF16_H -#define KUTACC_UTILS_BF16_H +#ifndef KPEX_UTILS_BF16_H +#define KPEX_UTILS_BF16_H #include #include diff --git a/csrc/utils/check.h b/csrc/utils/check.h index b9ddad3..893472c 100644 --- a/csrc/utils/check.h +++ b/csrc/utils/check.h @@ -11,8 +11,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef KUTACC_UTILS_CHECK_H -#define KUTACC_UTILS_CHECK_H +#ifndef KPEX_UTILS_CHECK_H +#define KPEX_UTILS_CHECK_H #include #include diff --git a/csrc/utils/layernorm.h b/csrc/utils/layernorm.h index 4d79a90..1aa5fce 100644 --- a/csrc/utils/layernorm.h +++ b/csrc/utils/layernorm.h @@ -11,8 +11,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ -#ifndef KUTACC_UTILS_LAYERNORM_H -#define KUTACC_UTILS_LAYERNORM_H +#ifndef KPEX_UTILS_LAYERNORM_H +#define KPEX_UTILS_LAYERNORM_H #include #include @@ -37,7 +37,7 @@ inline at::Tensor layernorm(const at::Tensor &act, const at::Tensor &weight_, co int64_t mi, ni; at::native::data_index_init(start, mi, m, ni, n); for ([[maybe_unused]] int64_t _ : c10::irange(start, end)) { - kutacc_layernorm( + kutacc_af2_layernorm( (__bf16 *)act.data_ptr() + mi * act.strides()[0] + ni * act.strides()[1], (float *)weight.data_ptr(), (float *)bias.data_ptr(), len, 1e-5, (__bf16 *)out.data_ptr() + mi * out.strides()[0] + ni * out.strides()[1]); diff --git a/csrc/utils/linear.h b/csrc/utils/linear.h new file mode 100644 index 0000000..834a3c7 --- /dev/null +++ b/csrc/utils/linear.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ +#ifndef KPEX_LINEAR_H +#define KPEX_LINEAR_H + +#include +#include +#include +#include +#include "utils/check.h" +#include "utils/TensorWrapper.h" +#include "kutacc.h" + +namespace alphafold { +inline at::Tensor linear(const at::Tensor &act, const at::Tensor &weight, std::optional bias, + std::optional result_ = std::nullopt) +{ + KPEX_CHECK(act.dim() >= 2, act.dim()); + KPEX_CHECK(weight.dim() >= 2, weight.dim()); + if (bias.has_value()) { + KPEX_CHECK(bias.value().dim() == weight.dim() - 1, bias.value().dim()); + } + int64_t beta; + at::Tensor result; + if (!result_.has_value()) { + auto result_sizes = act.sizes().vec(); + result_sizes.pop_back(); + result_sizes.insert(result_sizes.end(), weight.sizes().begin(), weight.sizes().end() - 1); + result = at::empty(result_sizes, act.options()); + beta = 0; + } else { + result = result_.value(); + beta = 1; + } + if (bias.has_value()) { + auto bias_ = bias.value(); + auto bias_data = (float *)bias_.data_ptr(); + auto act_tw = convert_to_tensor_wrapper(act); + auto weight_tw = convert_to_tensor_wrapper(weight); + auto result_tw = convert_to_tensor_wrapper(result); + kutacc_af2_linear(act_tw.get_tensor(), weight_tw.get_tensor(), bias_data, result_tw.get_tensor(), beta); + return result; + } else { + auto act_tw = convert_to_tensor_wrapper(act); + auto weight_tw = convert_to_tensor_wrapper(weight); + auto result_tw = convert_to_tensor_wrapper(result); + kutacc_af2_linear(act_tw.get_tensor(), weight_tw.get_tensor(), nullptr, result_tw.get_tensor(), beta); + return result; + } + // auto bias_data = bias.has_value() ? 
(float *)bias.value().data_ptr() : nullptr; + // auto act_tw = convert_to_tensor_wrapper(act); + // auto weight_tw = convert_to_tensor_wrapper(weight); + // auto result_tw = convert_to_tensor_wrapper(result); + // kutacc_af2_linear(act_tw.get_tensor(), weight_tw.get_tensor(), bias_data, result_tw.get_tensor(), beta); + // return result; +} +} + +#endif \ No newline at end of file diff --git a/csrc/utils/memory.h b/csrc/utils/memory.h index ed8f517..d611b56 100644 --- a/csrc/utils/memory.h +++ b/csrc/utils/memory.h @@ -11,8 +11,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef KUTACC_UTILS_MEMORY_H -#define KUTACC_UTILS_MEMORY_H +#ifndef KPEX_UTILS_MEMORY_H +#define KPEX_UTILS_MEMORY_H #include #include diff --git a/csrc/utils/parallel.h b/csrc/utils/parallel.h index 45139fa..890c5ec 100644 --- a/csrc/utils/parallel.h +++ b/csrc/utils/parallel.h @@ -11,8 +11,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef KUTACC_UTILS_PARALLEL_H -#define KUTACC_UTILS_PARALLEL_H +#ifndef KPEX_UTILS_PARALLEL_H +#define KPEX_UTILS_PARALLEL_H #include #include diff --git a/kpex/tpp/alphafold/alphafold.py b/kpex/tpp/alphafold/alphafold.py index d9d0d1c..5ef17ff 100644 --- a/kpex/tpp/alphafold/alphafold.py +++ b/kpex/tpp/alphafold/alphafold.py @@ -108,6 +108,64 @@ def transition_forward(self, act, mask): out = kernel.alphafold.transition(act.to(torch.bfloat16), self.kpex_weights) return out +def invariant_point_forward(self, s, z, r, mask): + if not hasattr(self, "kpex_weights"): + self.kpex_weights = kernel.alphafold.InvariantPointAttentionWeight( + self.c_s, + self.c_z, + self.c_hidden, + self.no_heads, + self.no_qk_points, + self.no_v_points, + self.is_multimer, + self.linear_q.weight, + self.linear_q.bias, + self.linear_kv.weight, + self.linear_kv.bias, + self.linear_q_points.weight, + self.linear_q_points.bias, + self.linear_kv_points.weight, + self.linear_kv_points.bias, + self.linear_b.weight, + self.linear_b.bias, + self.head_weights, + self.linear_out.weight, + self.linear_out.bias, + ) + out = kernel.alphafold.invariant_point_attention( + s.to(torch.bfloat16), + z.to(torch.bfloat16), + r._trans, + r._rots.get_rot_mats(), + mask.to(torch.bfloat16), + self.kpex_weights, + ) + return out + +def rot_vec_mul(r, t): + return kernel.alphafold.rigid_rot_vec_mul(t, r) + +def rot_to_quat( + rot: torch.Tensor +): + if (rot.shape[-2:] != (3,3)): + return ValueError("Input rotation is incorrectly shaped") + + rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + k = [ + [xx + yy + zz, zy - yz, xz - zx, yx - xy], + [zy - yz, xx - yy - zz, xy + yx, xz + zx], + [xz - zx, xy + yx, yy - xx - zz, yz + zy], + [yx - xy, xz + zx, yz + zy, zz - xx - yy], + ] + + k = (1. / 3.) 
* torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) + _, vectors = np.linalg.eigh(k.numpy()) + return torch.from_numpy(vectors[..., -1]) + + def kpex_alphafold(model, model_config, dtype=torch.float): new_model = copy.deepcopy(model) evoformer = new_model.model.impl.evoformer @@ -175,6 +233,26 @@ def kpex_alphafold(model, model_config, dtype=torch.float): transition_forward, block.pair_transition ) + + # print("new model is:") + # print(new_model.model.impl.structure_module) + # print("new model print end") + + structure_model = new_model.model.impl.structure_module.model + # print("new model is:") + # print(new_model.model.impl.structure_module.model.ipa) + # print("new model print end") + structure_model.ipa.forward = types.MethodType( + invariant_point_forward, + structure_model.ipa + ) + + # rigid_utils = new_model.model.impl.structure_module.model.utils.rigid_utils + # rigid_utils.rot_vec_mul = rot_vec_mul + # rigid_utils.rot_matmul = kernel.alphafold.rigid_rot_matmul + # rigid_utils.rot_to_quat = rot_to_quat + + return new_model -- Gitee From 466fd790eac339792df985f0f6d6b61dbfd73a51 Mon Sep 17 00:00:00 2001 From: tedeasonwang Date: Fri, 19 Sep 2025 17:29:19 +0800 Subject: [PATCH 2/3] delete redundant code --- csrc/tpp/alphafold/gating_attention.cpp | 4 ---- csrc/utils/linear.h | 6 ------ 2 files changed, 10 deletions(-) diff --git a/csrc/tpp/alphafold/gating_attention.cpp b/csrc/tpp/alphafold/gating_attention.cpp index a9c7ce2..3069f71 100644 --- a/csrc/tpp/alphafold/gating_attention.cpp +++ b/csrc/tpp/alphafold/gating_attention.cpp @@ -133,11 +133,7 @@ at::Tensor gating_attention(at::Tensor &q_data, at::Tensor &m_data, at::Tensor & auto out_tw = convert_to_tensor_wrapper(out); -<<<<<<< HEAD kutacc_af2_gating_attention(input_tw.get_tensor(), q_tw.get_tensor(), k_tw.get_tensor(), v_tw.get_tensor(), -======= - kutacc_af2_gating_attention(input_tw.get_tensor(), q_tw.get_tensor(), k_tw.get_tensor(), v_tw.get_tensor(), ->>>>>>> 50190d6ef2464092e1011a08f515d3e03f42f4c9 gate_tw.get_tensor(), weighted_avg_tw.get_tensor(), batch, seq_len, m_data_tw.get_tensor(), bias_tw.get_tensor(), nonbatched_bias_tw.get_tensor(), query_w_tw.get_tensor(), key_w_tw.get_tensor(), value_w_tw.get_tensor(), gating_w_tw.get_tensor(), gating_b_tw.get_tensor(), output_w_tw.get_tensor(), diff --git a/csrc/utils/linear.h b/csrc/utils/linear.h index 834a3c7..1ae8687 100644 --- a/csrc/utils/linear.h +++ b/csrc/utils/linear.h @@ -58,12 +58,6 @@ inline at::Tensor linear(const at::Tensor &act, const at::Tensor &weight, std::o kutacc_af2_linear(act_tw.get_tensor(), weight_tw.get_tensor(), nullptr, result_tw.get_tensor(), beta); return result; } - // auto bias_data = bias.has_value() ? 
(float *)bias.value().data_ptr() : nullptr; - // auto act_tw = convert_to_tensor_wrapper(act); - // auto weight_tw = convert_to_tensor_wrapper(weight); - // auto result_tw = convert_to_tensor_wrapper(result); - // kutacc_af2_linear(act_tw.get_tensor(), weight_tw.get_tensor(), bias_data, result_tw.get_tensor(), beta); - // return result; } } -- Gitee From eb18f8280ff6dc29b64c3780012c596747762249 Mon Sep 17 00:00:00 2001 From: tedeasonwang Date: Sat, 20 Sep 2025 00:03:35 +0800 Subject: [PATCH 3/3] clean kpex_alphafold unnecessary comment --- kpex/tpp/alphafold/alphafold.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/kpex/tpp/alphafold/alphafold.py b/kpex/tpp/alphafold/alphafold.py index 5ef17ff..bc8c81b 100644 --- a/kpex/tpp/alphafold/alphafold.py +++ b/kpex/tpp/alphafold/alphafold.py @@ -233,26 +233,13 @@ def kpex_alphafold(model, model_config, dtype=torch.float): transition_forward, block.pair_transition ) - - # print("new model is:") - # print(new_model.model.impl.structure_module) - # print("new model print end") structure_model = new_model.model.impl.structure_module.model - # print("new model is:") - # print(new_model.model.impl.structure_module.model.ipa) - # print("new model print end") structure_model.ipa.forward = types.MethodType( invariant_point_forward, structure_model.ipa ) - # rigid_utils = new_model.model.impl.structure_module.model.utils.rigid_utils - # rigid_utils.rot_vec_mul = rot_vec_mul - # rigid_utils.rot_matmul = kernel.alphafold.rigid_rot_matmul - # rigid_utils.rot_to_quat = rot_to_quat - - return new_model -- Gitee
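
For reference, a minimal usage sketch of the rigid helpers bound in this series. The import path of the compiled extension is an assumption (the patches only show it referenced as `kernel.alphafold` from kpex/tpp/alphafold/alphafold.py); shapes and dtypes follow the checks in csrc/tpp/alphafold/rigid.cpp: contiguous float32 points of shape [..., 3] and rotation matrices of shape [..., 3, 3].

    import torch
    from kpex import kernel  # assumed import path for the compiled extension

    n_res = 8
    rot_mats = torch.eye(3).repeat(n_res, 1, 1)   # [n_res, 3, 3], float32, contiguous
    pts = torch.randn(n_res, 3)                   # [n_res, 3], float32
    trans = torch.randn(n_res, 3)                 # [n_res, 3], float32

    # Apply each per-residue rotation to its point and add the translation,
    # mirroring how invariant_point_attention transforms the q/k/v points.
    rotated = kernel.alphafold.rigid_rot_vec_mul(pts, rot_mats, trans)

    # trans is bound with a std::nullopt default, so a pure rotation is also valid.
    rotated_only = kernel.alphafold.rigid_rot_vec_mul(pts, rot_mats)

    # Compose two stacks of rotation matrices: out[i] = a[i] @ b[i].
    composed = kernel.alphafold.rigid_rot_matmul(rot_mats, rot_mats.clone())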