diff --git a/README.md b/README.md index be00f60601365e449685186dcdf70347720e8797..7d2a059ef635eddfb08a66d4bb37cea160c0ea86 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ kunpeng-extension-for-pytorch is a package for extending the official Pytorch th **一、环境设置** -1.安装pytorch +1.在鲲鹏平台编译安装pytorch([pytorch使能KML操作指南](https://gitee.com/kunpeng-hpc/kunpeng-extension-for-pytorch/wikis/PyTorch%20v2.5.0%20%E5%AF%B9%E6%8E%A5KML)) 2.加载kutacc环境 diff --git a/csrc/tpp/alphafold/bind.h b/csrc/tpp/alphafold/bind.h index e1e079572407fe0b6ee85e8e4b1b462648cc1915..7ffe4f010a6e8b3b6bd3d2e2b6d4dd16a42d332a 100644 --- a/csrc/tpp/alphafold/bind.h +++ b/csrc/tpp/alphafold/bind.h @@ -17,10 +17,12 @@ #include #include "gating_attention.h" +#include "outer_product_mean.h" #include "transition.h" #include "global_attention.h" #include "triangle_multiplication.h" // #include "utils/layernorm.h" +#include "kupl.h" namespace alphafold { inline void bind(pybind11::module &m) @@ -34,13 +36,14 @@ inline void bind(pybind11::module &m) submodule.def("gating_attention", &gating_attention, py::arg("q_data"), py::arg("m_data"), py::arg("bias"), py::arg("nonbatched_bias"), py::arg("weights"), py::arg("block_size") = std::nullopt); - py::class_(submodule, "GlobalAttentionWeight") - .def(py::init(submodule, "OuterProductMeanWeight") + .def(py::init(), - py::arg("query_w"), py::arg("key_w"), py::arg("value_w"), py::arg("gate_w"), py::arg("gate_b"), - py::arg("output_w"), py::arg("output_b")); - submodule.def("global_attention", &global_attention, py::arg("q_data"), py::arg("m_data"), py::arg("q_mask"), - py::arg("weights")); + py::arg("input_ln_w"), py::arg("input_ln_b"), py::arg("left_proj_w"), py::arg("left_proj_b"), + py::arg("right_proj_w"), py::arg("right_proj_b"), py::arg("output_w"), py::arg("output_b")); + submodule.def("outer_product_mean", &outer_product_mean, py::arg("act"), py::arg("mask"), py::arg("weights"), + py::arg("left_block_size") = std::nullopt, py::arg("right_block_size") = 
std::nullopt, + py::arg("no_mpi") = false); py::class_(submodule, "TransitionWeight") .def(py::init(), @@ -59,6 +62,12 @@ inline void bind(pybind11::module &m) py::arg("output_proj_b")); submodule.def("triangle_multiplication", &triangle_multiplication, py::arg("act"), py::arg("mask"), py::arg("weights")); // submodule.def("layernorm", &layernorm, py::arg("act"), py::arg("weight_"), py::arg("bias_")); + + auto mpimodule = m.def_submodule("mpi"); + mpimodule.def("initialize", &initialize, py::arg("world_size"), py::arg("rank"), py::arg("buffer_size")); + mpimodule.def("finalize", &finalize); + mpimodule.def("all_gather", &af2_all_gather, py::arg("data"), py::arg("m"), py::arg("n")); + mpimodule.def("all2all", &af2_transpose, py::arg("data"), py::arg("m"), py::arg("n")); } } diff --git a/csrc/tpp/alphafold/kupl.h b/csrc/tpp/alphafold/kupl.h new file mode 100644 index 0000000000000000000000000000000000000000..8c610a8f92aab2787d8c82e0386b51587fca3e8b --- /dev/null +++ b/csrc/tpp/alphafold/kupl.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+*/
+
+#ifndef KPEX_KUPL_H
+#define KPEX_KUPL_H
+
+#include
+// Forwards to kutacc_initialize. Marked inline: a non-inline definition in a header breaks the one-definition rule.
+inline void initialize(int64_t _world_size, int64_t _rank, int64_t _buffer_size)
+{
+    kutacc_initialize(_world_size, _rank, _buffer_size);
+}
+// Forwards to kutacc_finalize. Marked inline for the same reason as initialize().
+inline void finalize()
+{
+    kutacc_finalize();
+}
+
+#endif
\ No newline at end of file
diff --git a/csrc/tpp/alphafold/outer_product_mean.cpp b/csrc/tpp/alphafold/outer_product_mean.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ddf998f50c811bcc6344f7762e5773fcdf0b6c1e
--- /dev/null
+++ b/csrc/tpp/alphafold/outer_product_mean.cpp
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ *
+ * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+*/
+
+#include "outer_product_mean.h"
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace alphafold {
+// default_block_size: fixed table of block sizes when n_seq == 512, otherwise clamped fractions of n_res_gather.
+namespace {
+void default_block_size(int64_t n_seq, int64_t n_res_gather, int64_t &left_block_size, int64_t &right_block_size)
+{
+    if (n_seq == 512) {
+        if (n_res_gather < 200) {
+            left_block_size = 4;
+        } else if (n_res_gather < 400) {
+            left_block_size = 8;
+        } else {
+            left_block_size = 16;
+        }
+        if (n_res_gather < 200) {
+            right_block_size = 16;
+        } else if (n_res_gather < 300) {
+            right_block_size = 32;
+        } else if (n_res_gather < 700) {
+            right_block_size = 80;
+        } else {
+            right_block_size = 112;
+        }
+    } else {
+        left_block_size = std::clamp(n_res_gather / 24, int64_t{8}, int64_t{80});
+        right_block_size = std::clamp(n_res_gather / 8, int64_t{16}, int64_t{192});
+    }
+}
+}
+// Checks the weight shapes, then stores prepacked bf16 projection/output weights and float layernorm/bias tensors.
+OuterProductMeanWeight::OuterProductMeanWeight(at::Tensor &input_ln_w, at::Tensor &input_ln_b, at::Tensor &left_proj_w,
+                                               at::Tensor &left_proj_b, at::Tensor &right_proj_w,
+                                               at::Tensor &right_proj_b, at::Tensor &output_w, at::Tensor &output_b)
+    : c_m(input_ln_w.sizes()[0]),
+      c_i(left_proj_w.sizes()[0]),
+      c_z(output_w.sizes()[0])
+{
+    KPEX_CHECK_TENSOR_SHAPE(input_ln_w, c_m);
+    KPEX_CHECK_TENSOR_SHAPE(input_ln_b, c_m);
+    KPEX_CHECK_TENSOR_SHAPE(left_proj_w, c_i, c_m);
+    KPEX_CHECK_TENSOR_SHAPE(left_proj_b, c_i);
+    KPEX_CHECK_TENSOR_SHAPE(right_proj_w, c_i, c_m);
+    KPEX_CHECK_TENSOR_SHAPE(right_proj_b, c_i);
+    KPEX_CHECK_TENSOR_SHAPE(output_w, c_z, c_i, c_i);
+    KPEX_CHECK_TENSOR_SHAPE(output_b, c_z);
+
+    auto float_opt = left_proj_w.options().device(kpex::device()).dtype(c10::kFloat);
+    auto bf16_opt = left_proj_w.options().device(kpex::device()).dtype(c10::kBFloat16);
+    this->left_proj_w = left_proj_w.to(bf16_opt).contiguous();
+    this->right_proj_w = right_proj_w.to(bf16_opt).contiguous();
+    this->outer_w = output_w.to(bf16_opt).contiguous().view({c_z, c_i * c_i});
+
+    auto left_proj_w_res = linear_weight_prepack(this->left_proj_w);
+    auto right_proj_w_res = linear_weight_prepack(this->right_proj_w);
+    auto outer_w_res = linear_weight_prepack(this->outer_w);
+    this->left_proj_w = left_proj_w_res;
+    this->right_proj_w = right_proj_w_res;
+    this->outer_w = outer_w_res;
+
+    this->input_ln_w = input_ln_w.to(float_opt).contiguous();
+    this->input_ln_b = input_ln_b.to(float_opt).contiguous();
+    this->left_proj_b = left_proj_b.to(float_opt).contiguous();
+    this->right_proj_b = right_proj_b.to(float_opt).contiguous();
+    this->outer_b = output_b.to(float_opt).contiguous();
+}
+// bf16 act (n_seq, n_res, c_m) and bf16 mask (n_seq, n_res_gather) -> out (n_res, n_res_gather, c_z).
+at::Tensor outer_product_mean(at::Tensor &act, at::Tensor &mask, const OuterProductMeanWeight &weights,
+                              std::optional<int64_t> left_block_size_, std::optional<int64_t> right_block_size_,
+                              bool no_mpi)
+{
+    at::Tensor out = act.new_empty({act.sizes()[1], mask.sizes()[1], weights.c_z});
+    int64_t n_seq = act.sizes()[0];
+    int64_t n_res = act.sizes()[1];
+    int64_t n_res_gather = mask.sizes()[1];
+    int64_t c_m = weights.c_m;
+    int64_t c_i = weights.c_i;
+    int64_t c_z = weights.c_z;
+    int64_t left_block_size = 0;
+    int64_t right_block_size = 0;
+    default_block_size(n_seq, n_res_gather, left_block_size, right_block_size);
+    left_block_size = left_block_size_.value_or(left_block_size);
+    right_block_size = right_block_size_.value_or(right_block_size);
+    // Validate dtypes and shapes; act and mask are then re-bound through the references (mask becomes transposed).
+    KPEX_CHECK(act.dtype() == c10::kBFloat16, act.dtype());
+    KPEX_CHECK(mask.dtype() == c10::kBFloat16, mask.dtype());
+    KPEX_CHECK_TENSOR_SHAPE(act, n_seq, n_res, c_m);
+    KPEX_CHECK_TENSOR_SHAPE(mask, n_seq, n_res_gather);
+    act = act.contiguous();
+    mask = mask.transpose(0, 1).contiguous();
+
+    at::Tensor left_proj = act.new_empty({c_i, n_res, n_seq});
+    at::Tensor right_proj = act.new_empty({c_i, n_res, n_seq});
+    at::Tensor left_proj_ = act.new_empty({n_res, c_i, n_seq});
+    at::Tensor right_proj_ = act.new_empty({n_res, c_i, n_seq});
+    at::Tensor norm = mask.new_empty({n_res, n_res_gather});
+    int64_t mask_bias = 0;
+    // Element offset of row rank * ceil(n_res_gather / world_size) in the transposed mask.
+    if (n_res_gather > n_res) {
+        mask_bias = rank * ((n_res_gather + world_size - 1) / world_size) * mask.strides()[0];
+    }
+
+    // convert weights to kutacc::weights
+    // weights :: at::Tensor
+    // kutacc :: TensorWrapper
+
+    // 1. kutacc only: does not work
+    // 2. kutacc + kpex
+    // 3. kutacc only (would require changing the OuterProductMean input parameters)
+
+    at::Tensor input_act = layernorm(act.transpose(0, 1), weights.input_ln_w, weights.input_ln_b);
+
+    auto input_act_tw = convert_to_tensor_wrapper(input_act);
+    auto mask_tw = convert_to_tensor_wrapper(mask);
+    auto left_proj_w_tw = convert_to_tensor_wrapper(weights.left_proj_w);
+    auto left_proj_b_tw = convert_to_tensor_wrapper(weights.left_proj_b);
+    auto right_proj_w_tw = convert_to_tensor_wrapper(weights.right_proj_w);
+    auto right_proj_b_tw = convert_to_tensor_wrapper(weights.right_proj_b);
+    auto left_proj_tw = convert_to_tensor_wrapper(left_proj);
+    auto right_proj_tw = convert_to_tensor_wrapper(right_proj);
+    auto left_proj_tw_ = convert_to_tensor_wrapper(left_proj_);
+    auto right_proj_tw_ = convert_to_tensor_wrapper(right_proj_);
+    auto norm_tw = convert_to_tensor_wrapper(norm);
+
+    kutacc_outer_product_mean_calc_left_and_right_mul(
+        left_proj_tw.get_tensor(), right_proj_tw.get_tensor(), left_proj_tw_.get_tensor(), right_proj_tw_.get_tensor(),
+        input_act_tw.get_tensor(), mask_tw.get_tensor(), norm_tw.get_tensor(), left_proj_w_tw.get_tensor(),
+        left_proj_b_tw.get_tensor(), right_proj_w_tw.get_tensor(), right_proj_b_tw.get_tensor(), c_i, c_m, n_res,
+        n_res_gather, n_seq, mask_bias);
+
+    if (n_res_gather > n_res) {
+        if (!no_mpi) {
+            right_proj_ = af2_all_gather(right_proj_, n_res_gather, c_i);
+        } else {
+            right_proj_ = at::empty({n_res_gather, c_i, n_seq}, right_proj_.options());
+        }
+    }
+    // Re-wrap right_proj_: the branch above may have replaced the tensor, which leaves right_proj_tw_ stale.
+    auto right_proj_full_tw = convert_to_tensor_wrapper(right_proj_);
+    auto output_w_tw = convert_to_tensor_wrapper(weights.outer_w);
+    auto output_b_tw = convert_to_tensor_wrapper(weights.outer_b);
+    auto out_tw = convert_to_tensor_wrapper(out);
+    kutacc_outer_product_mean_chunk(output_b_tw.get_tensor(), output_w_tw.get_tensor(), out_tw.get_tensor(),
+                                    left_proj_tw_.get_tensor(), right_proj_full_tw.get_tensor(), norm_tw.get_tensor(),
+                                    left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq);
+
+    return out;
+}
+
+}
\ No newline at end of file
diff --git a/csrc/tpp/alphafold/outer_product_mean.h b/csrc/tpp/alphafold/outer_product_mean.h
new file mode 100644
index 0000000000000000000000000000000000000000..d7a64de6a383853d400e64df43a903e9eec3ab42
--- /dev/null
+++ b/csrc/tpp/alphafold/outer_product_mean.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ *
+ * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+*/
+
+#ifndef KPEX_OUTER_PRODUCT_MEAN_H
+#define KPEX_OUTER_PRODUCT_MEAN_H
+
+#include
+
+namespace alphafold {
+    // Channel sizes (c_m, c_i, c_z) and the weight tensors consumed by outer_product_mean().
+    struct OuterProductMeanWeight {
+        int64_t c_m;
+        int64_t c_i;
+        int64_t c_z;
+
+        at::Tensor input_ln_w;
+        at::Tensor input_ln_b;
+        at::Tensor left_proj_w;
+        at::Tensor left_proj_b;
+        at::Tensor right_proj_w;
+        at::Tensor right_proj_b;
+        at::Tensor outer_w;
+        at::Tensor outer_b;
+
+        OuterProductMeanWeight(at::Tensor &input_ln_w, at::Tensor &input_ln_b, at::Tensor &left_proj_w,
+            at::Tensor &left_proj_b, at::Tensor &right_proj_w, at::Tensor &right_proj_b, at::Tensor &output_w,
+            at::Tensor &output_b);
+    };
+    // Block sizes left as std::nullopt are chosen inside the implementation.
+    at::Tensor outer_product_mean(at::Tensor &act, at::Tensor &mask, const OuterProductMeanWeight &weights,
+        std::optional<int64_t> left_block_size, std::optional<int64_t> right_block_size, bool no_mpi);
+
+}
+
+#endif
\ No newline at end of file
diff --git a/csrc/utils/TensorWrapper.h b/csrc/utils/TensorWrapper.h
index 6126833f3e918e040f91e0627c829ad53cab0433..08585b95b92a26d5a7fdff465cfee3d3fd245403 100644
--- a/csrc/utils/TensorWrapper.h
+++ b/csrc/utils/TensorWrapper.h
@@ -50,6 +50,39 @@ inline const kutacc::TensorWrapper convert_to_tensor_wrapper(const at::Tensor &t
     );
 }
 
+// Byte-granular wrapper of a 3-D tensor: the last dimension and the outer strides are scaled by the element size.
+inline kutacc::TensorWrapper convert_to_tensor_wrapper_comm(at::Tensor &tensor) {
+    int64_t scalar_size = c10::elementSize(tensor.scalar_type());
+    kutacc::DType dtype{};  // value-initialised so a non-bf16 tensor never yields an indeterminate dtype
+    auto scalar_type = tensor.scalar_type();
+    if (scalar_type == at::kBFloat16) {
+        dtype = kutacc::kBF16;
+    }
+    return kutacc::TensorWrapper(
+        tensor.data_ptr(),
+        {tensor.sizes()[0], tensor.sizes()[1], tensor.sizes()[2] * scalar_size},
+        {tensor.strides()[0] * scalar_size, tensor.strides()[1] * scalar_size, 1},
+        tensor.dim(), dtype
+    );
+}
+// Const-tensor overload; builds the same byte-granular wrapper as the overload above.
+inline const kutacc::TensorWrapper convert_to_tensor_wrapper_comm(const at::Tensor &tensor)
+{
+    int64_t scalar_size = c10::elementSize(tensor.scalar_type());
+    kutacc::DType dtype{};  // value-initialised so a non-bf16 tensor never yields an indeterminate dtype
+    auto scalar_type = tensor.scalar_type();
+    if (scalar_type == at::kBFloat16) {
+        dtype = kutacc::kBF16;
+    }
+    return kutacc::TensorWrapper(
+        tensor.data_ptr(),
+        {tensor.sizes()[0], tensor.sizes()[1], tensor.sizes()[2] * scalar_size},
+        {tensor.strides()[0] * scalar_size, tensor.strides()[1] * scalar_size, 1},
+        tensor.dim(),
+        dtype
+    );
+}
+
 inline at::Tensor linear_weight_prepack(const at::Tensor &weight, int64_t num_threads = 0)
 {
     int64_t n = weight.sizes()[0];
diff --git a/csrc/utils/all_gather.h b/csrc/utils/all_gather.h
new file mode 100644
index 0000000000000000000000000000000000000000..218dc2763ae23e03c40599c38d66fa49b37708c8
--- /dev/null
+++ b/csrc/utils/all_gather.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ *
+ * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+*/
+
+#ifndef KPEX_ALL_GATHER_H
+#define KPEX_ALL_GATHER_H
+
+#include
+#include
+#include
+#include
+#include
+#include "check.h"
+// Allocates an (m, n, data.sizes()[2]) tensor with data's options and fills it through kutacc_af2_all_gather.
+inline at::Tensor af2_all_gather(at::Tensor &data, int64_t m, int64_t n)
+{
+    KPEX_CHECK(data.dim() == 3, data.dim());  // the indexing below and the byte wrapper both assume rank 3
+    KPEX_CHECK(data.strides()[2] == 1, data.strides()[2]);
+    int64_t len = data.sizes()[2];
+
+    at::Tensor out = at::empty({m, n, len}, data.options());
+
+    auto data_tw = convert_to_tensor_wrapper_comm(data);
+    auto out_tw = convert_to_tensor_wrapper_comm(out);
+
+    kutacc_af2_all_gather(data_tw.get_tensor(), out_tw.get_tensor());
+    return out;
+}
+
+#endif
\ No newline at end of file
diff --git a/csrc/utils/transpose.h b/csrc/utils/transpose.h
index 41bd30ae62b131fe3ea354b9aabcbd9cc70b12b8..e1a38dd562e4be33c2b6d33e54d5968f6fd0418d 100644
--- a/csrc/utils/transpose.h
+++ b/csrc/utils/transpose.h
@@ -1,45 +1,47 @@
-/*
- * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
- *
- * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
-*/ - -#ifndef KPEX_TRANSPOSE_H -#define KPEX_TRANSPOSE_H -#include -#include -#include "TensorWrapper.h" -#include "check.h" -#include -#include -#include - -inline at::Tensor transpose(at::Tensor &data, int64_t m, int64_t n) -{ - int64_t block_m = (m + world_size - 1) / world_size; - int64_t block_n = (n + world_size - 1) / world_size; - int64_t len = data.sizes()[2]; - int64_t scalar_size = c10::elementSize(data.scalar_type()); - - KPEX_CHECK(data.strides()[2] == 1, data.strides()[2]); - at::Tensor out; - if (data.sizes()[0] < m) { - out = at::empty({m, std::min(block_n, n - rank * block_n), len}, data.options()); - } else { - out = at::empty({std::min(block_m, m - rank * block_m), n, len}, data.options()); - } - auto data_tw = convert_to_tensor_wrapper(data); - auto out_tw = convert_to_tensor_wrapper(out); - kutacc_af2_transpose(data_tw.get_tensor(), out_tw.get_tensor()); - return out; -} - +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+*/
+
+#ifndef KPEX_TRANSPOSE_H
+#define KPEX_TRANSPOSE_H
+
+#include
+#include
+#include
+#include
+#include
+#include "check.h"
+// Allocates this rank's output (split along dim 1 if data has fewer than m rows, else dim 0), then calls kutacc.
+inline at::Tensor af2_transpose(at::Tensor &data, int64_t m, int64_t n)
+{
+    KPEX_CHECK(data.dim() == 3, data.dim());  // the indexing below and the byte wrapper both assume rank 3
+    KPEX_CHECK(data.strides()[2] == 1, data.strides()[2]);
+    int64_t block_m = (m + world_size - 1) / world_size;
+    int64_t block_n = (n + world_size - 1) / world_size;
+    int64_t len = data.sizes()[2];
+
+    at::Tensor out;
+    if (data.sizes()[0] < m) {
+        out = at::empty({m, std::min(block_n, n - rank * block_n), len}, data.options());
+    } else {
+        out = at::empty({std::min(block_m, m - rank * block_m), n, len}, data.options());
+    }
+
+    auto data_tw = convert_to_tensor_wrapper_comm(data);
+    auto out_tw = convert_to_tensor_wrapper_comm(out);
+
+    kutacc_af2_transpose(data_tw.get_tensor(), out_tw.get_tensor());
+    return out;
+}
+
 #endif
\ No newline at end of file
diff --git a/kpex/tpp/alphafold/alphafold.py b/kpex/tpp/alphafold/alphafold.py
index 81a8d13668c52b301499dbf583c94347f0f80b64..53f8324565c6408f79deec838748f81769038a7d 100644
--- a/kpex/tpp/alphafold/alphafold.py
+++ b/kpex/tpp/alphafold/alphafold.py
@@ -95,26 +95,32 @@ def gating_attention_forward(self, q_data, m_data, bias, nonbatched_bias=torch.T
     )
     return out
 
-def global_attention_forward(self, q_data, m_data, q_mask, bias):
+
+def outer_product_mean_forward(
+    self, act, mask, left_block_size=None, right_block_size=None, no_mpi=False
+):
     if not hasattr(self, "kpex_weights"):
-        self.kpex_weights = kernel.alphafold.GlobalAttentionWeight(
-            self.query_w.permute(1, 2, 0),
-            self.key_w.permute(1, 0),
-            self.value_w.permute(1, 0),
-            self.gating_w.permute(1, 2, 0),
-            self.gating_b,
+        self.kpex_weights = kernel.alphafold.OuterProductMeanWeight(
+            self.layer_norm_input.weight,
+            self.layer_norm_input.bias,
+            self.left_projection.weight,
+            self.left_projection.bias,
+            self.right_projection.weight,
+            self.right_projection.bias,
             self.output_w.permute(2, 0, 1),
-            self.output_b
+            self.output_b,
         )
-    act = 
q_data.to(torch.bfloat16) - out = kernel.alphafold.global_attention( - act, - act, - q_mask.to(torch.bfloat16), + out = kernel.alphafold.outer_product_mean( + act.to(torch.bfloat16), + mask.to(torch.bfloat16), self.kpex_weights, + left_block_size, + right_block_size, + no_mpi, ) return out + def transition_forward(self, act, mask): if not hasattr(self, "kpex_weights"): self.kpex_weights = kernel.alphafold.TransitionWeight( @@ -178,6 +184,10 @@ def kpex_alphafold(model, model_config, dtype=torch.float): gating_attention_forward, block.triangle_attention_ending_node.attention ) + block.outer_product_mean.forward = types.MethodType( + outer_product_mean_forward, + block.outer_product_mean + ) block.msa_transition.forward = types.MethodType( transition_forward, block.msa_transition @@ -212,6 +222,10 @@ def kpex_alphafold(model, model_config, dtype=torch.float): gating_attention_forward, block.triangle_attention_ending_node.attention ) + block.outer_product_mean.forward = types.MethodType( + outer_product_mean_forward, + block.outer_product_mean + ) block.msa_transition.forward = types.MethodType( transition_forward, block.msa_transition