From 04344b7c2639e6cb2a1b519487ab827ab9b571af Mon Sep 17 00:00:00 2001 From: XeonYZhang Date: Mon, 22 Sep 2025 10:02:56 +0800 Subject: [PATCH 1/3] repair some bugs when compiling outer_product_mean_ops --- csrc/tpp/alphafold/bind.h | 20 ++- csrc/tpp/alphafold/kupl.h | 30 ++++ csrc/tpp/alphafold/outer_product_mean.cpp | 163 ++++++++++++++++++ csrc/tpp/alphafold/outer_product_mean.h | 46 +++++ .../tpp/alphafold/triangle_multiplication.cpp | 6 +- csrc/utils/TensorWrapper.h | 33 ++++ csrc/utils/all_gather.h | 40 +++++ csrc/utils/transpose.h | 19 +- kpex/tpp/alphafold/alphafold.py | 38 +++- 9 files changed, 380 insertions(+), 15 deletions(-) create mode 100644 csrc/tpp/alphafold/kupl.h create mode 100644 csrc/tpp/alphafold/outer_product_mean.cpp create mode 100644 csrc/tpp/alphafold/outer_product_mean.h create mode 100644 csrc/utils/all_gather.h diff --git a/csrc/tpp/alphafold/bind.h b/csrc/tpp/alphafold/bind.h index 3ac8a35..8d17385 100644 --- a/csrc/tpp/alphafold/bind.h +++ b/csrc/tpp/alphafold/bind.h @@ -22,7 +22,10 @@ #include "rigid.h" #include "global_attention.h" #include "triangle_multiplication.h" -// #include "utils/layernorm.h" +#include "outer_product_mean.h" +#include "kupl.h" +#include +#include namespace alphafold { inline void bind(pybind11::module &m) @@ -36,6 +39,15 @@ inline void bind(pybind11::module &m) submodule.def("gating_attention", &gating_attention, py::arg("q_data"), py::arg("m_data"), py::arg("bias"), py::arg("nonbatched_bias"), py::arg("weights"), py::arg("block_size") = std::nullopt); + py::class_(submodule, "OuterProductMeanWeight") + .def(py::init(), + py::arg("input_ln_w"), py::arg("input_ln_b"), py::arg("left_proj_w"), py::arg("left_proj_b"), + py::arg("right_proj_w"), py::arg("right_proj_b"), py::arg("output_w"), py::arg("output_b")); + submodule.def("outer_product_mean", &outer_product_mean, py::arg("act"), py::arg("mask"), py::arg("weights"), + py::arg("left_block_size") = std::nullopt, py::arg("right_block_size") = std::nullopt, + py::arg("no_mpi") = false); + py::class_(submodule, "GlobalAttentionWeight") .def(py::init(), @@ -80,6 +92,12 @@ inline void bind(pybind11::module &m) py::arg("output_proj_b")); submodule.def("triangle_multiplication", &triangle_multiplication, py::arg("act"), py::arg("mask"), py::arg("weights")); // submodule.def("layernorm", &layernorm, py::arg("act"), py::arg("weight_"), py::arg("bias_")); + + auto mpimodule = m.def_submodule("mpi"); + mpimodule.def("initialize", &initialize, py::arg("world_size"), py::arg("rank"), py::arg("buffer_size")); + mpimodule.def("finalize", &finalize); + mpimodule.def("all_gather", &af2_all_gather, py::arg("data"), py::arg("m"), py::arg("n")); + mpimodule.def("all2all", &af2_transpose, py::arg("data"), py::arg("m"), py::arg("n")); } } diff --git a/csrc/tpp/alphafold/kupl.h b/csrc/tpp/alphafold/kupl.h new file mode 100644 index 0000000..5dbc0ce --- /dev/null +++ b/csrc/tpp/alphafold/kupl.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#ifndef KPEX_KUPL_H +#define KPEX_KUPL_H + +#include + +void initialize(int64_t _world_size, int64_t _rank, int64_t _buffer_size) +{ + kutacc_initialize(_world_size, _rank, _buffer_size); +} + +void finalize() +{ + kutacc_finalize(); +} + +#endif \ No newline at end of file diff --git a/csrc/tpp/alphafold/outer_product_mean.cpp b/csrc/tpp/alphafold/outer_product_mean.cpp new file mode 100644 index 0000000..fddd878 --- /dev/null +++ b/csrc/tpp/alphafold/outer_product_mean.cpp @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#include "outer_product_mean.h" +#include +#include +#include +#include +#include +#include +#include + +namespace alphafold { + +namespace { +void default_block_size(int64_t n_seq, int64_t n_res_gather, int64_t &left_block_size, int64_t &right_block_size) +{ + if (n_seq == 512) { + if (n_res_gather < 200) { + left_block_size = 4; + } else if (n_res_gather < 400) { + left_block_size = 8; + } else { + left_block_size = 16; + } + if (n_res_gather < 200) { + right_block_size = 16; + } else if (n_res_gather < 300) { + right_block_size = 32; + } else if (n_res_gather < 700) { + right_block_size = 80; + } else { + right_block_size = 112; + } + } else { + left_block_size = std::clamp(n_res_gather / 24, (int64_t)8, (int64_t)80); + right_block_size = std::clamp(n_res_gather / 8, (int64_t)16, (int64_t)192); + } +} +} + +OuterProductMeanWeight::OuterProductMeanWeight(at::Tensor &input_ln_w, at::Tensor &input_ln_b, at::Tensor &left_proj_w, + at::Tensor &left_proj_b, at::Tensor &right_proj_w, + at::Tensor &right_proj_b, at::Tensor &output_w, at::Tensor &output_b) + : c_m(input_ln_w.sizes()[0]), + c_i(left_proj_w.sizes()[0]), + c_z(output_w.sizes()[0]) +{ + KPEX_CHECK_TENSOR_SHAPE(input_ln_w, c_m); + KPEX_CHECK_TENSOR_SHAPE(input_ln_b, c_m); + KPEX_CHECK_TENSOR_SHAPE(left_proj_w, c_i, c_m); + KPEX_CHECK_TENSOR_SHAPE(left_proj_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(right_proj_w, c_i, c_m); + KPEX_CHECK_TENSOR_SHAPE(right_proj_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(output_w, c_z, c_i, c_i); + KPEX_CHECK_TENSOR_SHAPE(output_b, c_z); + + auto float_opt = left_proj_w.options().device(kpex::device()).dtype(c10::kFloat); + auto bf16_opt = left_proj_w.options().device(kpex::device()).dtype(c10::kBFloat16); + this->left_proj_w = left_proj_w.to(bf16_opt).contiguous(); + this->right_proj_w = right_proj_w.to(bf16_opt).contiguous(); + this->outer_w = output_w.to(bf16_opt).contiguous().view({c_z, c_i * c_i}); + + auto left_proj_w_res = linear_weight_prepack(this->left_proj_w); + auto right_proj_w_res = linear_weight_prepack(this->right_proj_w); + auto outer_w_res = linear_weight_prepack(this->outer_w); + this->left_proj_w = left_proj_w_res; + this->right_proj_w = right_proj_w_res; + this->outer_w = outer_w_res; + + this->input_ln_w = input_ln_w.to(float_opt).contiguous(); + this->input_ln_b = input_ln_b.to(float_opt).contiguous(); + this->left_proj_b = left_proj_b.to(float_opt).contiguous(); + this->right_proj_b = right_proj_b.to(float_opt).contiguous(); + this->outer_b = output_b.to(float_opt).contiguous(); +} + +at::Tensor outer_product_mean(at::Tensor &act, at::Tensor &mask, const OuterProductMeanWeight &weights, + std::optional left_block_size_, std::optional right_block_size_, + bool no_mpi) +{ + at::Tensor out = act.new_empty({act.sizes()[1], mask.sizes()[1], weights.c_z}); + int64_t n_seq = act.sizes()[0]; + int64_t n_res = act.sizes()[1]; + int64_t n_res_gather = mask.sizes()[1]; + int64_t c_m = weights.c_m; + int64_t c_i = weights.c_i; + int64_t c_z = weights.c_z; + int64_t left_block_size; + int64_t right_block_size; + default_block_size(n_seq, n_res_gather, left_block_size, right_block_size); + left_block_size = left_block_size_.value_or(left_block_size); + right_block_size = right_block_size_.value_or(right_block_size); + + KPEX_CHECK(act.dtype() == c10::kBFloat16, act.dtype()); + KPEX_CHECK(mask.dtype() == c10::kBFloat16, mask.dtype()); + KPEX_CHECK_TENSOR_SHAPE(act, n_seq, n_res, c_m); + KPEX_CHECK_TENSOR_SHAPE(mask, n_seq, n_res_gather); + act = act.contiguous(); + mask = mask.transpose(0, 1).contiguous(); + + at::Tensor left_proj = act.new_empty({c_i, n_res, n_seq}); + at::Tensor right_proj = act.new_empty({c_i, n_res, n_seq}); + at::Tensor left_proj_ = act.new_empty({n_res, c_i, n_seq}); + at::Tensor right_proj_ = act.new_empty({n_res, c_i, n_seq}); + at::Tensor norm = mask.new_empty({n_res, n_res_gather}); + int64_t mask_bias = 0; + + if (n_res_gather > n_res) { + mask_bias = rank * ((n_res_gather + world_size - 1) / world_size) * mask.strides()[0]; + } + + at::Tensor input_act = layernorm(act.transpose(0, 1), weights.input_ln_w, weights.input_ln_b); + + auto input_act_tw = convert_to_tensor_wrapper(input_act); + auto mask_tw = convert_to_tensor_wrapper(mask); + auto left_proj_w_tw = convert_to_tensor_wrapper(weights.left_proj_w); + auto left_proj_b_tw = convert_to_tensor_wrapper(weights.left_proj_b); + auto right_proj_w_tw = convert_to_tensor_wrapper(weights.right_proj_w); + auto right_proj_b_tw = convert_to_tensor_wrapper(weights.right_proj_b); + auto left_proj_tw = convert_to_tensor_wrapper(left_proj); + auto right_proj_tw = convert_to_tensor_wrapper(right_proj); + auto left_proj_tw_ = convert_to_tensor_wrapper(left_proj_); + auto right_proj_tw_ = convert_to_tensor_wrapper(right_proj_); + auto norm_tw = convert_to_tensor_wrapper(norm); + + kutacc_af2_outer_product_mean_calc_left_and_right_mul( + left_proj_tw.get_tensor(), right_proj_tw.get_tensor(), left_proj_tw_.get_tensor(), right_proj_tw_.get_tensor(), + input_act_tw.get_tensor(), mask_tw.get_tensor(), norm_tw.get_tensor(), left_proj_w_tw.get_tensor(), + left_proj_b_tw.get_tensor(), right_proj_w_tw.get_tensor(), right_proj_b_tw.get_tensor(), c_i, c_m, n_res, + n_res_gather, n_seq, mask_bias); + + if (n_res_gather > n_res) { + if (!no_mpi) { + right_proj_ = af2_all_gather(right_proj_, n_res_gather, c_i); + } else { + right_proj_ = at::empty({n_res_gather, c_i, n_seq}, right_proj_.options()); + } + } + + auto output_w_tw = convert_to_tensor_wrapper(weights.outer_w); + auto output_b_tw = convert_to_tensor_wrapper(weights.outer_b); + auto out_tw = convert_to_tensor_wrapper(out); + + kutacc_af2_outer_product_mean_chunk(output_b_tw.get_tensor(), output_w_tw.get_tensor(), out_tw.get_tensor(), + left_proj_tw_.get_tensor(), right_proj_tw_.get_tensor(), norm_tw.get_tensor(), + left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq); + + return out; +} + +} \ No newline at end of file diff --git a/csrc/tpp/alphafold/outer_product_mean.h b/csrc/tpp/alphafold/outer_product_mean.h new file mode 100644 index 0000000..d7a64de --- /dev/null +++ b/csrc/tpp/alphafold/outer_product_mean.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#ifndef KPEX_OUTER_PRODUCT_MEAN_H +#define KPEX_OUTER_PRODUCT_MEAN_H + +#include + +namespace alphafold { + + struct OuterProductMeanWeight { + int64_t c_m; + int64_t c_i; + int64_t c_z; + + at::Tensor input_ln_w; + at::Tensor input_ln_b; + at::Tensor left_proj_w; + at::Tensor left_proj_b; + at::Tensor right_proj_w; + at::Tensor right_proj_b; + at::Tensor outer_w; + at::Tensor outer_b; + + OuterProductMeanWeight(at::Tensor &input_ln_w, at::Tensor &input_ln_b, at::Tensor &left_proj_w, + at::Tensor &left_proj_b, at::Tensor &right_proj_w, at::Tensor &right_proj_b, at::Tensor &output_w, + at::Tensor &output_b); + }; + + at::Tensor outer_product_mean(at::Tensor &act, at::Tensor &mask, const OuterProductMeanWeight &weights, + std::optional left_block_size, std::optional right_block_size, bool no_mpi); + +} + +#endif \ No newline at end of file diff --git a/csrc/tpp/alphafold/triangle_multiplication.cpp b/csrc/tpp/alphafold/triangle_multiplication.cpp index 61253aa..4fc15ea 100644 --- a/csrc/tpp/alphafold/triangle_multiplication.cpp +++ b/csrc/tpp/alphafold/triangle_multiplication.cpp @@ -70,8 +70,8 @@ namespace alphafold { right_gate_b_tw.get_tensor(), n_res, n_res_gather, c_o, c_i, input_prepack); if (n_res < n_res_gather) { - left_proj_act = transpose(left_proj_act, c_i, n_res_gather); - right_proj_act = transpose(right_proj_act, c_i, n_res_gather); + left_proj_act = af2_transpose(left_proj_act, c_i, n_res_gather); + right_proj_act = af2_transpose(right_proj_act, c_i, n_res_gather); } center_act = act.new_empty({left_proj_act.sizes()[0], n_res_gather, n_res_gather}); @@ -81,7 +81,7 @@ namespace alphafold { auto right_proj_act_new_tw = convert_to_tensor_wrapper(right_proj_act); kutacc_af2_triangle_multiplication_equation(center_act_tw.get_tensor(), left_proj_act_new_tw.get_tensor(), right_proj_act_new_tw.get_tensor(), n_res_gather, weights.is_incoming); - center_act = transpose(center_act, c_i, n_res_gather); + center_act = af2_transpose(center_act, c_i, n_res_gather); } else { kutacc_af2_triangle_multiplication_equation(center_act_tw.get_tensor(), left_proj_act_tw.get_tensor(), right_proj_act_tw.get_tensor(), n_res_gather, weights.is_incoming); diff --git a/csrc/utils/TensorWrapper.h b/csrc/utils/TensorWrapper.h index d5b5fcf..1b4cd53 100644 --- a/csrc/utils/TensorWrapper.h +++ b/csrc/utils/TensorWrapper.h @@ -50,6 +50,39 @@ inline const kutacc::TensorWrapper convert_to_tensor_wrapper(const at::Tensor &t ); } +inline kutacc::TensorWrapper convert_to_tensor_wrapper_comm(at::Tensor &tensor) { + int64_t scalar_size = c10::elementSize(tensor.scalar_type()); + kutacc::DType dtype; + auto scalar_type = tensor.scalar_type(); + if (scalar_type == at::kBFloat16) { + dtype = kutacc::kBF16; + } + return kutacc::TensorWrapper( + tensor.data_ptr(), + {tensor.sizes()[0], tensor.sizes()[1], tensor.sizes()[2] * scalar_size}, + {tensor.strides()[0] * scalar_size, tensor.strides()[1] * scalar_size, 1}, + tensor.dim(), + dtype + ); +} + +inline const kutacc::TensorWrapper convert_to_tensor_wrapper_comm(const at::Tensor &tensor) +{ + int64_t scalar_size = c10::elementSize(tensor.scalar_type()); + kutacc::DType dtype; + auto scalar_type = tensor.scalar_type(); + if (scalar_type == at::kBFloat16) { + dtype = kutacc::kBF16; + } + return kutacc::TensorWrapper( + tensor.data_ptr(), + {tensor.sizes()[0], tensor.sizes()[1], tensor.sizes()[2] * scalar_size}, + {tensor.strides()[0] * scalar_size, tensor.strides()[1] * scalar_size, 1}, + tensor.dim(), + dtype + ); +} + inline at::Tensor linear_weight_prepack(const at::Tensor &weight, int64_t num_threads = 0) { int64_t n = weight.sizes()[0]; diff --git a/csrc/utils/all_gather.h b/csrc/utils/all_gather.h new file mode 100644 index 0000000..218dc27 --- /dev/null +++ b/csrc/utils/all_gather.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#ifndef KPEX_ALL_GATHER_H +#define KPEX_ALL_GATHER_H + +#include +#include +#include +#include +#include +#include "check.h" + +inline at::Tensor af2_all_gather(at::Tensor &data, int64_t m, int64_t n) +{ + int64_t len = data.sizes()[2]; + + KPEX_CHECK(data.strides()[2] == 1, data.strides()[2]); + + at::Tensor out = at::empty({m, n, len}, data.options()); + + auto data_tw = convert_to_tensor_wrapper_comm(data); + auto out_tw = convert_to_tensor_wrapper_comm(out); + + kutacc_af2_all_gather(data_tw.get_tensor(), out_tw.get_tensor()); + return out; +} + +#endif \ No newline at end of file diff --git a/csrc/utils/transpose.h b/csrc/utils/transpose.h index 41bd30a..286788a 100644 --- a/csrc/utils/transpose.h +++ b/csrc/utils/transpose.h @@ -14,30 +14,33 @@ #ifndef KPEX_TRANSPOSE_H #define KPEX_TRANSPOSE_H + +#include #include -#include -#include "TensorWrapper.h" -#include "check.h" #include -#include #include +#include +#include +#include "check.h" -inline at::Tensor transpose(at::Tensor &data, int64_t m, int64_t n) +inline at::Tensor af2_transpose(at::Tensor &data, int64_t m, int64_t n) { int64_t block_m = (m + world_size - 1) / world_size; int64_t block_n = (n + world_size - 1) / world_size; int64_t len = data.sizes()[2]; - int64_t scalar_size = c10::elementSize(data.scalar_type()); KPEX_CHECK(data.strides()[2] == 1, data.strides()[2]); + at::Tensor out; if (data.sizes()[0] < m) { out = at::empty({m, std::min(block_n, n - rank * block_n), len}, data.options()); } else { out = at::empty({std::min(block_m, m - rank * block_m), n, len}, data.options()); } - auto data_tw = convert_to_tensor_wrapper(data); - auto out_tw = convert_to_tensor_wrapper(out); + + auto data_tw = convert_to_tensor_wrapper_comm(data); + auto out_tw = convert_to_tensor_wrapper_comm(out); + kutacc_af2_transpose(data_tw.get_tensor(), out_tw.get_tensor()); return out; } diff --git a/kpex/tpp/alphafold/alphafold.py b/kpex/tpp/alphafold/alphafold.py index d015da9..5586b05 100644 --- a/kpex/tpp/alphafold/alphafold.py +++ b/kpex/tpp/alphafold/alphafold.py @@ -115,6 +115,30 @@ def global_attention_forward(self, q_data, m_data, q_mask, bias): ) return out +def outer_product_mean_forward( + self, act, mask, left_block_size=None, right_block_size=None, no_mpi=False +): + if not hasattr(self, "kpex_weights"): + self.kpex_weights = kernel.alphafold.OuterProductMeanWeight( + self.layer_norm_input.weight, + self.layer_norm_input.bias, + self.left_projection.weight, + self.left_projection.bias, + self.right_projection.weight, + self.right_projection.bias, + self.output_w.permute(2, 0, 1), + self.output_b, + ) + out = kernel.alphafold.outer_product_mean( + act.to(torch.bfloat16), + mask.to(torch.bfloat16), + self.kpex_weights, + left_block_size, + right_block_size, + no_mpi, + ) + return out + def transition_forward(self, act, mask): if not hasattr(self, "kpex_weights"): self.kpex_weights = kernel.alphafold.TransitionWeight( @@ -236,6 +260,10 @@ def kpex_alphafold(model, model_config, dtype=torch.float): gating_attention_forward, block.triangle_attention_ending_node.attention ) + block.outer_product_mean.forward = types.MethodType( + outer_product_mean_forward, + block.outer_product_mean + ) block.msa_transition.forward = types.MethodType( transition_forward, block.msa_transition @@ -270,6 +298,10 @@ def kpex_alphafold(model, model_config, dtype=torch.float): gating_attention_forward, block.triangle_attention_ending_node.attention ) + block.outer_product_mean.forward = types.MethodType( + outer_product_mean_forward, + block.outer_product_mean + ) block.msa_transition.forward = types.MethodType( transition_forward, block.msa_transition @@ -292,15 +324,15 @@ def kpex_alphafold(model, model_config, dtype=torch.float): block.pair_transition.forward = types.MethodType( transition_forward, block.pair_transition - ) + ) block.triangle_multiplication_outgoing.forward = types.MethodType( triangleMultiplication_forward, block.triangle_multiplication_outgoing - ) + ) block.triangle_multiplication_incoming.forward = types.MethodType( triangleMultiplication_forward, block.triangle_multiplication_incoming - ) + ) structure_model = new_model.model.impl.structure_module.model structure_model.ipa.forward = types.MethodType( -- Gitee From 682394a75e90782ef9370a3a39ae1b2e05e1cd6d Mon Sep 17 00:00:00 2001 From: XeonYZhang Date: Mon, 22 Sep 2025 10:50:47 +0800 Subject: [PATCH 2/3] fix merge conflict issuse --- csrc/tpp/alphafold/bind.h | 12 +++++------- kpex/tpp/alphafold/alphafold.py | 31 +++++++++++++------------------ 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/csrc/tpp/alphafold/bind.h b/csrc/tpp/alphafold/bind.h index e0d4686..6cc4bb3 100644 --- a/csrc/tpp/alphafold/bind.h +++ b/csrc/tpp/alphafold/bind.h @@ -52,11 +52,10 @@ inline void bind(pybind11::module &m) py::class_(submodule, "GlobalAttentionWeight") .def(py::init(), - py::arg("input_ln_w"), py::arg("input_ln_b"), py::arg("left_proj_w"), py::arg("left_proj_b"), - py::arg("right_proj_w"), py::arg("right_proj_b"), py::arg("output_w"), py::arg("output_b")); - submodule.def("outer_product_mean", &outer_product_mean, py::arg("act"), py::arg("mask"), py::arg("weights"), - py::arg("left_block_size") = std::nullopt, py::arg("right_block_size") = std::nullopt, - py::arg("no_mpi") = false); + py::arg("query_w"), py::arg("key_w"), py::arg("value_w"), py::arg("gate_w"), py::arg("gate_b"), + py::arg("output_w"), py::arg("output_b")); + submodule.def("global_attention", &global_attention, py::arg("q_data"), py::arg("m_data"), py::arg("q_mask"), + py::arg("weights")); py::class_(submodule, "TransitionWeight") .def(py::init(), @@ -78,11 +77,10 @@ inline void bind(pybind11::module &m) submodule.def("invariant_point_attention", &invariant_point_attention, py::arg("s"), py::arg("z"), py::arg("rigid_trans"), py::arg("rigid_rot_mats"), py::arg("mask"), py::arg("weights")); - submodule.def("rigid_rot_vec_mul", &rigid_rot_vec_mul, py::arg("pts"), py::arg("rot_mats"), py::arg("trans") = std::nullopt); - submodule.def("rigid_rot_matmul", &rigid_rot_matmul, py::arg("a"), py::arg("b")); + py::class_(submodule, "TriangleMultiplicationWeight") .def(py::init Date: Mon, 22 Sep 2025 10:53:31 +0800 Subject: [PATCH 3/3] bind.h cleancode --- csrc/tpp/alphafold/bind.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/tpp/alphafold/bind.h b/csrc/tpp/alphafold/bind.h index 6cc4bb3..702f5ae 100644 --- a/csrc/tpp/alphafold/bind.h +++ b/csrc/tpp/alphafold/bind.h @@ -42,7 +42,7 @@ inline void bind(pybind11::module &m) py::class_(submodule, "OuterProductMeanWeight") .def(py::init(), + at::Tensor &>(), py::arg("input_ln_w"), py::arg("input_ln_b"), py::arg("left_proj_w"), py::arg("left_proj_b"), py::arg("right_proj_w"), py::arg("right_proj_b"), py::arg("output_w"), py::arg("output_b")); submodule.def("outer_product_mean", &outer_product_mean, py::arg("act"), py::arg("mask"), py::arg("weights"), -- Gitee