diff --git a/csrc/tpp/alphafold/bind.h b/csrc/tpp/alphafold/bind.h index 64779474c8eedaf08a166a11b64aabf50fc15b7b..e1e079572407fe0b6ee85e8e4b1b462648cc1915 100644 --- a/csrc/tpp/alphafold/bind.h +++ b/csrc/tpp/alphafold/bind.h @@ -18,6 +18,8 @@ #include "gating_attention.h" #include "transition.h" +#include "global_attention.h" +#include "triangle_multiplication.h" // #include "utils/layernorm.h" namespace alphafold { @@ -32,12 +34,30 @@ inline void bind(pybind11::module &m) submodule.def("gating_attention", &gating_attention, py::arg("q_data"), py::arg("m_data"), py::arg("bias"), py::arg("nonbatched_bias"), py::arg("weights"), py::arg("block_size") = std::nullopt); + py::class_(submodule, "GlobalAttentionWeight") + .def(py::init(), + py::arg("query_w"), py::arg("key_w"), py::arg("value_w"), py::arg("gate_w"), py::arg("gate_b"), + py::arg("output_w"), py::arg("output_b")); + submodule.def("global_attention", &global_attention, py::arg("q_data"), py::arg("m_data"), py::arg("q_mask"), + py::arg("weights")); + py::class_(submodule, "TransitionWeight") .def(py::init(), py::arg("input_ln_w"), py::arg("input_ln_b"), py::arg("linear1_w"), py::arg("linear1_b"), py::arg("linear2_w"), py::arg("linear2_b")); submodule.def("transition", &transition, py::arg("act"), py::arg("weights")); + py::class_(submodule, "TriangleMultiplicationWeight") + .def(py::init(), + py::arg("is_incoming"), py::arg("input_ln_w"), py::arg("input_ln_b"), py::arg("left_proj_w"), + py::arg("left_proj_b"), py::arg("right_proj_w"), py::arg("right_proj_b"), py::arg("left_gate_w"), + py::arg("left_gate_b"), py::arg("right_gate_w"), py::arg("right_gate_b"), py::arg("gating_w"), + py::arg("gating_b"), py::arg("center_ln_w"), py::arg("center_ln_b"), py::arg("output_proj_w"), + py::arg("output_proj_b")); + submodule.def("triangle_multiplication", &triangle_multiplication, py::arg("act"), py::arg("mask"), py::arg("weights")); // submodule.def("layernorm", &layernorm, py::arg("act"), py::arg("weight_"), py::arg("bias_")); } } diff --git a/csrc/tpp/alphafold/global_attention.cpp b/csrc/tpp/alphafold/global_attention.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89d9e3bd7ce252aeed74efe22bf6ccca8d629427 --- /dev/null +++ b/csrc/tpp/alphafold/global_attention.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#include "global_attention.h" +#include +#include "ATen/native/cpu/utils.h" +#include "ATen/ops/empty.h" +#include "ATen/record_function.h" +#include +#include + +namespace alphafold { + + GlobalAttentionWeight::GlobalAttentionWeight(at::Tensor &query_w, at::Tensor &key_w, at::Tensor &value_w, at::Tensor &gating_w, + at::Tensor &gating_b, at::Tensor &output_w, at::Tensor &output_b) + : nchannels(query_w.sizes()[2]), nheads(query_w.sizes()[0]), head_size(query_w.sizes()[1]) + { + KPEX_CHECK(nchannels == nheads * head_size, "invalid query_w shape [", nchannels, ", ", nheads, ", ", head_size, "]"); + KPEX_CHECK_TENSOR_SHAPE(query_w, nheads, head_size, nchannels); + KPEX_CHECK_TENSOR_SHAPE(key_w, head_size, nchannels); + KPEX_CHECK_TENSOR_SHAPE(value_w, head_size, nchannels); + KPEX_CHECK_TENSOR_SHAPE(gating_w, nheads, head_size, nchannels); + KPEX_CHECK_TENSOR_SHAPE(gating_b, nheads, head_size); + KPEX_CHECK_TENSOR_SHAPE(output_w, nchannels, nheads, head_size); + KPEX_CHECK_TENSOR_SHAPE(output_b, nchannels); + + auto float_opt = query_w.options().device(kpex::device()).dtype(c10::kFloat); + auto bf16_opt = query_w.options().device(kpex::device()).dtype(c10::kBFloat16); + + query_w = query_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + key_w = key_w.to(bf16_opt).contiguous().view({head_size, nchannels}); + value_w = value_w.to(bf16_opt).contiguous().view({head_size, nchannels}); + gating_w = gating_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + output_w = output_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + + this->query_w = linear_weight_prepack(query_w); + this->key_w = linear_weight_prepack(key_w); + this->value_w = linear_weight_prepack(value_w); + this->gating_w = linear_weight_prepack(gating_w); + this->output_w = linear_weight_prepack(output_w); + + this->gating_b = gating_b.to(float_opt).contiguous(); + this->output_b = output_b.to(float_opt).contiguous(); + } + + at::Tensor global_attention(at::Tensor &q_data, at::Tensor &m_data, at::Tensor &q_mask, const GlobalAttentionWeight &weights) + { + at::Tensor out = at::empty(q_data.sizes(), q_data.options()); + int64_t batch = q_data.sizes()[0]; + int64_t seq_len = q_data.sizes()[1]; + int64_t nchannels = weights.nchannels; + int64_t nheads = weights.nheads; + int64_t head_size = weights.head_size; + + KPEX_CHECK(q_data.dtype() == c10::kBFloat16, q_data.dtype()); + KPEX_CHECK(m_data.dtype() == c10::kBFloat16, m_data.dtype()); + KPEX_CHECK(q_mask.dtype() == c10::kBFloat16, q_mask.dtype()); + KPEX_CHECK_TENSOR_SHAPE(q_data, batch, seq_len, nchannels); + KPEX_CHECK_TENSOR_SHAPE(q_mask, batch, seq_len, 1); + + q_mask = q_mask.contiguous(); + + auto q_avg = q_data.new_empty({batch, nchannels}); + auto q = q_data.new_empty({batch, nheads, head_size}); + auto k = q_data.new_empty({batch, seq_len, head_size}); + auto v = q_data.new_empty({head_size, batch, seq_len}); + auto gate = q_data.new_empty({batch, seq_len, nheads, head_size}); + + auto q_avg_tw = convert_to_tensor_wrapper(q_avg); + auto q_tw = convert_to_tensor_wrapper(q); + auto k_tw = convert_to_tensor_wrapper(k); + auto v_tw = convert_to_tensor_wrapper(v); + auto gate_tw = convert_to_tensor_wrapper(gate); + auto q_data_tw = convert_to_tensor_wrapper(q_data); + auto m_data_tw = convert_to_tensor_wrapper(m_data); + auto q_mask_tw = convert_to_tensor_wrapper(q_mask); + auto out_tw = convert_to_tensor_wrapper(out); + auto query_w_tw = convert_to_tensor_wrapper(weights.query_w); + auto key_w_tw = convert_to_tensor_wrapper(weights.key_w); + auto value_w_tw = convert_to_tensor_wrapper(weights.value_w); + auto gating_w_tw = convert_to_tensor_wrapper(weights.gating_w); + auto gating_b_tw = convert_to_tensor_wrapper(weights.gating_b); + auto output_w_tw = convert_to_tensor_wrapper(weights.output_w); + auto output_b_tw = convert_to_tensor_wrapper(weights.output_b); + + kutacc_af2_global_attention(q_avg_tw.get_tensor(), q_tw.get_tensor(), k_tw.get_tensor(), v_tw.get_tensor(), batch, seq_len, + nchannels, nheads, head_size, gate_tw.get_tensor(), q_data_tw.get_tensor(), m_data_tw.get_tensor(), q_mask_tw.get_tensor(), + query_w_tw.get_tensor(), key_w_tw.get_tensor(), value_w_tw.get_tensor(), gating_w_tw.get_tensor(), gating_b_tw.get_tensor(), + output_w_tw.get_tensor(), output_b_tw.get_tensor(), out_tw.get_tensor()); + return out; + } +} \ No newline at end of file diff --git a/csrc/tpp/alphafold/global_attention.h b/csrc/tpp/alphafold/global_attention.h new file mode 100644 index 0000000000000000000000000000000000000000..6d718f8cc978aee56e461e5f561c95ddf3ac6aae --- /dev/null +++ b/csrc/tpp/alphafold/global_attention.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#ifndef KPEX_TPP_ALPHAFOLD_GLOBAL_ATTENTION_H +#define KPEX_TPP_ALPHAFOLD_GLOBAL_ATTENTION_H + +#include "utils/check.h" +#include + +namespace alphafold { + + struct GlobalAttentionWeight { + int64_t nchannels; + int64_t nheads; + int64_t head_size; + + at::Tensor query_w; + at::Tensor key_w; + at::Tensor value_w; + at::Tensor gating_w; + at::Tensor gating_b; + at::Tensor output_w; + at::Tensor output_b; + + GlobalAttentionWeight(at::Tensor &query_w, at::Tensor &key_w, at::Tensor &value_w, at::Tensor &gating_w, + at::Tensor &gating_b, at::Tensor &output_w, at::Tensor &output_b); + + }; + + at::Tensor global_attention(at::Tensor &q_data, at::Tensor &m_data, at::Tensor &q_mask, const GlobalAttentionWeight &weights); + +} + +#endif \ No newline at end of file diff --git a/csrc/tpp/alphafold/transition.h b/csrc/tpp/alphafold/transition.h index 8fef41522fedded735d45ad1eb65393d017da983..6035913ca1cb2926768dab926c161b54b1249be6 100644 --- a/csrc/tpp/alphafold/transition.h +++ b/csrc/tpp/alphafold/transition.h @@ -36,4 +36,4 @@ struct TransitionWeight { at::Tensor transition(at::Tensor &act, const TransitionWeight &weights); } -#endif \ No newline at end of file +#endif diff --git a/csrc/tpp/alphafold/triangle_multiplication.cpp b/csrc/tpp/alphafold/triangle_multiplication.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61253aac79afa8854a3f4fd0ded365fda83ffd8f --- /dev/null +++ b/csrc/tpp/alphafold/triangle_multiplication.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#include "triangle_multiplication.h" +#include "kutacc.h" +#include +#include "ATen/ops/empty.h" +#include "ATen/native/cpu/utils.h" +#include "ATen/record_function.h" + +namespace alphafold { + + at::Tensor triangle_multiplication(at::Tensor &act, at::Tensor &mask, const TriangleMultiplicationWeight &weights) + { + at::Tensor out = at::empty(act.sizes(), act.options()); + int64_t n_res = act.sizes()[0]; + int64_t n_res_gather = act.sizes()[1]; + int64_t c_o = weights.c_o; + int64_t c_i = weights.c_i; + + KPEX_CHECK(act.dtype() == c10::kBFloat16, act.dtype()); + KPEX_CHECK(mask.dtype() == c10::kBFloat16, act.dtype()); + + at::Tensor input_act = layernorm(act, weights.input_ln_w, weights.input_ln_b); + at::Tensor center_act; + at::Tensor left_proj_act = input_act.new_empty({c_i, n_res, n_res_gather}); + at::Tensor right_proj_act = input_act.new_empty({c_i, n_res, n_res_gather}); + at::Tensor gate = act.new_empty({n_res, n_res_gather, c_o}); + bool input_prepack = false; + + if (input_prepack) { + input_act = linear_weight_prepack(input_act.view({n_res * n_res_gather, c_o})); + } + + auto left_proj_act_tw = convert_to_tensor_wrapper(left_proj_act); + auto right_proj_act_tw = convert_to_tensor_wrapper(right_proj_act); + auto gate_tw = convert_to_tensor_wrapper(gate); + auto input_act_tw = convert_to_tensor_wrapper(input_act); + auto mask_tw = convert_to_tensor_wrapper(mask); + auto left_proj_w_tw = convert_to_tensor_wrapper(weights.left_proj_w); + auto left_proj_b_tw = convert_to_tensor_wrapper(weights.left_proj_b); + auto left_gate_w_tw = convert_to_tensor_wrapper(weights.left_gate_w); + auto left_gate_b_tw = convert_to_tensor_wrapper(weights.left_gate_b); + auto right_proj_w_tw = convert_to_tensor_wrapper(weights.right_proj_w); + auto right_proj_b_tw = convert_to_tensor_wrapper(weights.right_proj_b); + auto right_gate_w_tw= convert_to_tensor_wrapper(weights.right_gate_w); + auto right_gate_b_tw= convert_to_tensor_wrapper(weights.right_gate_b); + + at::Tensor gate_left = input_act.new_empty({c_i, n_res, n_res_gather}); + auto gate_left_tw = convert_to_tensor_wrapper(gate_left); + kutacc_af2_triangle_multiplication_calc_proj(left_proj_act_tw.get_tensor(), gate_left_tw.get_tensor(), input_act_tw.get_tensor(), + mask_tw.get_tensor(), left_proj_w_tw.get_tensor(), left_proj_b_tw.get_tensor(), left_gate_w_tw.get_tensor(), + left_gate_b_tw.get_tensor(), n_res, n_res_gather, c_o, c_i, input_prepack); + + at::Tensor gate_right = input_act.new_empty({c_i, n_res, n_res_gather}); + auto gate_right_tw = convert_to_tensor_wrapper(gate_right); + kutacc_af2_triangle_multiplication_calc_proj(right_proj_act_tw.get_tensor(), gate_right_tw.get_tensor(), input_act_tw.get_tensor(), + mask_tw.get_tensor(), right_proj_w_tw.get_tensor(), right_proj_b_tw.get_tensor(), right_gate_w_tw.get_tensor(), + right_gate_b_tw.get_tensor(), n_res, n_res_gather, c_o, c_i, input_prepack); + + if (n_res < n_res_gather) { + left_proj_act = transpose(left_proj_act, c_i, n_res_gather); + right_proj_act = transpose(right_proj_act, c_i, n_res_gather); + } + center_act = act.new_empty({left_proj_act.sizes()[0], n_res_gather, n_res_gather}); + + auto center_act_tw = convert_to_tensor_wrapper(center_act); + if (n_res < n_res_gather) { + auto left_proj_act_new_tw = convert_to_tensor_wrapper(left_proj_act); // transpose后需要重新包装,规避内存重复释放问题 + auto right_proj_act_new_tw = convert_to_tensor_wrapper(right_proj_act); + kutacc_af2_triangle_multiplication_equation(center_act_tw.get_tensor(), left_proj_act_new_tw.get_tensor(), right_proj_act_new_tw.get_tensor(), n_res_gather, + weights.is_incoming); + center_act = transpose(center_act, c_i, n_res_gather); + } else { + kutacc_af2_triangle_multiplication_equation(center_act_tw.get_tensor(), left_proj_act_tw.get_tensor(), right_proj_act_tw.get_tensor(), n_res_gather, + weights.is_incoming); + } + center_act = center_act.permute({1, 2, 0}).contiguous(); + center_act = layernorm(center_act, weights.center_ln_w, weights.center_ln_b); + auto center_act_new_tw = convert_to_tensor_wrapper(center_act); // permute & layernorm之后重新包装center_act,规避内存重复释放问题 + + auto out_tw = convert_to_tensor_wrapper(out); + auto gating_w_tw = convert_to_tensor_wrapper(weights.gating_w); + auto gating_b_tw = convert_to_tensor_wrapper(weights.gating_b); + auto output_proj_w_tw = convert_to_tensor_wrapper(weights.output_proj_w); + auto output_proj_b_tw = convert_to_tensor_wrapper(weights.output_proj_b); + + kutacc_af2_triangle_multiplication_gate_and_out_linear(gate_tw.get_tensor(), out_tw.get_tensor(), input_act_tw.get_tensor(), center_act_new_tw.get_tensor(), + gating_w_tw.get_tensor(), gating_b_tw.get_tensor(), output_proj_w_tw.get_tensor(), output_proj_b_tw.get_tensor(), + n_res, n_res_gather, c_o, c_i, input_prepack); + kutacc_af2_triangle_multiplication_last(out_tw.get_tensor(), gate_tw.get_tensor(), n_res, n_res_gather, c_o); + return out; + } + + TriangleMultiplicationWeight::TriangleMultiplicationWeight(bool is_incoming, at::Tensor &input_ln_w, + at::Tensor &input_ln_b, at::Tensor &left_proj_w, at::Tensor &left_proj_b, at::Tensor &right_proj_w, + at::Tensor &right_proj_b, at::Tensor &left_gate_w, at::Tensor &left_gate_b, at::Tensor &right_gate_w, + at::Tensor &right_gate_b, at::Tensor &gating_w, at::Tensor &gating_b, at::Tensor ¢er_ln_w, + at::Tensor ¢er_ln_b, at::Tensor &output_proj_w, at::Tensor &output_proj_b) + : c_o(input_ln_w.sizes()[0]), c_i(left_proj_w.sizes()[0]), is_incoming(is_incoming) + { + KPEX_CHECK_TENSOR_SHAPE(input_ln_w, c_o); + KPEX_CHECK_TENSOR_SHAPE(input_ln_b, c_o); + KPEX_CHECK_TENSOR_SHAPE(left_proj_w, c_i, c_o); + KPEX_CHECK_TENSOR_SHAPE(left_proj_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(right_proj_w, c_i, c_o); + KPEX_CHECK_TENSOR_SHAPE(right_proj_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(left_gate_w, c_i, c_o); + KPEX_CHECK_TENSOR_SHAPE(left_gate_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(right_gate_w, c_i, c_o); + KPEX_CHECK_TENSOR_SHAPE(right_gate_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(gating_w, c_i, c_o); + KPEX_CHECK_TENSOR_SHAPE(gating_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(center_ln_w, c_i); + KPEX_CHECK_TENSOR_SHAPE(center_ln_b, c_i); + KPEX_CHECK_TENSOR_SHAPE(output_proj_w, c_o, c_i); + KPEX_CHECK_TENSOR_SHAPE(output_proj_b, c_o); + + auto float_opt = left_proj_w.options().device(kpex::device()).dtype(c10::kFloat); + auto bf16_opt = left_proj_w.options().device(kpex::device()).dtype(c10::kBFloat16); + left_proj_w = left_proj_w.to(bf16_opt).contiguous(); + right_proj_w = right_proj_w.to(bf16_opt).contiguous(); + left_gate_w = left_gate_w.to(bf16_opt).contiguous(); + right_gate_w = right_gate_w.to(bf16_opt).contiguous(); + gating_w = gating_w.to(bf16_opt).contiguous(); + output_proj_w = output_proj_w.to(bf16_opt).contiguous(); + + this->left_proj_w = linear_weight_prepack(left_proj_w); + this->right_proj_w = linear_weight_prepack(right_proj_w); + this->left_gate_w = linear_weight_prepack(left_gate_w); + this->right_gate_w = linear_weight_prepack(right_gate_w); + this->gating_w = linear_weight_prepack(gating_w); + this->output_proj_w = linear_weight_prepack(output_proj_w); + + this->input_ln_w = input_ln_w.to(float_opt).contiguous(); + this->input_ln_b = input_ln_b.to(float_opt).contiguous(); + this->left_proj_b = left_proj_b.to(float_opt).contiguous(); + this->right_proj_b = right_proj_b.to(float_opt).contiguous(); + this->left_gate_b = left_gate_b.to(float_opt).contiguous(); + this->right_gate_b = right_gate_b.to(float_opt).contiguous(); + this->gating_b = gating_b.to(float_opt).contiguous(); + this->center_ln_w = center_ln_w.to(float_opt).contiguous(); + this->center_ln_b = center_ln_b.to(float_opt).contiguous(); + this->output_proj_b = output_proj_b.to(float_opt).contiguous(); + } +} \ No newline at end of file diff --git a/csrc/tpp/alphafold/triangle_multiplication.h b/csrc/tpp/alphafold/triangle_multiplication.h new file mode 100644 index 0000000000000000000000000000000000000000..87e97ebbfa511c8c419aa8e9c7a7477654597f46 --- /dev/null +++ b/csrc/tpp/alphafold/triangle_multiplication.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#ifndef KPEX_TPP_ALPHAFOLD_TRIANGLE_MULTIPLICATION_H +#define KPEX_TPP_ALPHAFOLD_TRIANGLE_MULTIPLICATION_H + +#include +#include +#include +#include +#include + +namespace alphafold { + + struct TriangleMultiplicationWeight { + int64_t c_o; + int64_t c_i; + + bool is_incoming; + at::Tensor input_ln_w; + at::Tensor input_ln_b; + at::Tensor left_proj_w; + at::Tensor left_proj_b; + at::Tensor right_proj_w; + at::Tensor right_proj_b; + at::Tensor left_gate_w; + at::Tensor left_gate_b; + at::Tensor right_gate_w; + at::Tensor right_gate_b; + at::Tensor gating_w; + at::Tensor gating_b; + at::Tensor center_ln_w; + at::Tensor center_ln_b; + at::Tensor output_proj_w; + at::Tensor output_proj_b; + + TriangleMultiplicationWeight(bool is_incoming, at::Tensor &input_ln_w, at::Tensor &input_ln_b, + at::Tensor &left_proj_w, at::Tensor &left_proj_b, at::Tensor &right_proj_w, at::Tensor &right_proj_b, + at::Tensor &left_gate_w, at::Tensor &left_gate_b, at::Tensor &right_gate_w, at::Tensor &right_gate_b, + at::Tensor &gating_w, at::Tensor &gating_b, at::Tensor ¢er_ln_w, at::Tensor ¢er_ln_b, + at::Tensor &output_proj_w, at::Tensor &output_proj_b); + + }; + + at::Tensor triangle_multiplication(at::Tensor &act, at::Tensor &mask, const TriangleMultiplicationWeight &weights); + +} + +#endif \ No newline at end of file diff --git a/csrc/utils/check.h b/csrc/utils/check.h index 0d57155c216d0fa99a01e15c8170f750a411761d..893472c354810223b7fd306394513b6afac9f412 100644 --- a/csrc/utils/check.h +++ b/csrc/utils/check.h @@ -11,7 +11,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef KPEXC_UTILS_CHECK_H +#ifndef KPEX_UTILS_CHECK_H #define KPEX_UTILS_CHECK_H #include diff --git a/csrc/utils/layernorm.h b/csrc/utils/layernorm.h index 7eb86a5e3b8794a6443089e0eb52611c0b5b9de9..21430d0d6b75eda6480ac9bf02a4179186d048fe 100644 --- a/csrc/utils/layernorm.h +++ b/csrc/utils/layernorm.h @@ -37,7 +37,7 @@ inline at::Tensor layernorm(const at::Tensor &act, const at::Tensor &weight_, co int64_t mi, ni; at::native::data_index_init(start, mi, m, ni, n); for ([[maybe_unused]] int64_t _ : c10::irange(start, end)) { - kutacc_layernorm( + kutacc_af2_layernorm( (__bf16 *)act.data_ptr() + mi * act.strides()[0] + ni * act.strides()[1], (float *)weight.data_ptr(), (float *)bias.data_ptr(), len, 1e-5, (__bf16 *)out.data_ptr() + mi * out.strides()[0] + ni * out.strides()[1]); @@ -48,4 +48,4 @@ inline at::Tensor layernorm(const at::Tensor &act, const at::Tensor &weight_, co } -#endif \ No newline at end of file +#endif diff --git a/csrc/utils/parallel.h b/csrc/utils/parallel.h index 890c5ecb0143f2bc8280b7185c1c09054151cc4c..c2acc57fbf2e9e8ca736350c7d674e804586e414 100644 --- a/csrc/utils/parallel.h +++ b/csrc/utils/parallel.h @@ -93,4 +93,4 @@ inline void parallel(int num_threads, const F& f) { } // kpex -#endif \ No newline at end of file +#endif diff --git a/csrc/utils/transpose.h b/csrc/utils/transpose.h new file mode 100644 index 0000000000000000000000000000000000000000..41bd30ae62b131fe3ea354b9aabcbd9cc70b12b8 --- /dev/null +++ b/csrc/utils/transpose.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + * + * Licensed under a modified version of the MIT license. See LICENSE in the project root for license information. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. +*/ + +#ifndef KPEX_TRANSPOSE_H +#define KPEX_TRANSPOSE_H +#include +#include +#include "TensorWrapper.h" +#include "check.h" +#include +#include +#include + +inline at::Tensor transpose(at::Tensor &data, int64_t m, int64_t n) +{ + int64_t block_m = (m + world_size - 1) / world_size; + int64_t block_n = (n + world_size - 1) / world_size; + int64_t len = data.sizes()[2]; + int64_t scalar_size = c10::elementSize(data.scalar_type()); + + KPEX_CHECK(data.strides()[2] == 1, data.strides()[2]); + at::Tensor out; + if (data.sizes()[0] < m) { + out = at::empty({m, std::min(block_n, n - rank * block_n), len}, data.options()); + } else { + out = at::empty({std::min(block_m, m - rank * block_m), n, len}, data.options()); + } + auto data_tw = convert_to_tensor_wrapper(data); + auto out_tw = convert_to_tensor_wrapper(out); + kutacc_af2_transpose(data_tw.get_tensor(), out_tw.get_tensor()); + return out; +} + +#endif \ No newline at end of file diff --git a/kpex/tpp/alphafold/alphafold.py b/kpex/tpp/alphafold/alphafold.py index d9d0d1c8a932f48270f85cc17e274634c3f4d0a5..81a8d13668c52b301499dbf583c94347f0f80b64 100644 --- a/kpex/tpp/alphafold/alphafold.py +++ b/kpex/tpp/alphafold/alphafold.py @@ -95,6 +95,26 @@ def gating_attention_forward(self, q_data, m_data, bias, nonbatched_bias=torch.T ) return out +def global_attention_forward(self, q_data, m_data, q_mask, bias): + if not hasattr(self, "kpex_weights"): + self.kpex_weights = kernel.alphafold.GlobalAttentionWeight( + self.query_w.permute(1, 2, 0), + self.key_w.permute(1, 0), + self.value_w.permute(1, 0), + self.gating_w.permute(1, 2, 0), + self.gating_b, + self.output_w.permute(2, 0, 1), + self.output_b + ) + act = q_data.to(torch.bfloat16) + out = kernel.alphafold.global_attention( + act, + act, + q_mask.to(torch.bfloat16), + self.kpex_weights, + ) + return out + def transition_forward(self, act, mask): if not hasattr(self, "kpex_weights"): self.kpex_weights = kernel.alphafold.TransitionWeight( @@ -108,6 +128,34 @@ def transition_forward(self, act, mask): out = kernel.alphafold.transition(act.to(torch.bfloat16), self.kpex_weights) return out +def triangleMultiplication_forward(self, act, mask): + if not hasattr(self, "kpex_weights"): + self.kpex_weights = kernel.alphafold.TriangleMultiplicationWeight( + self.c_equation == "kjc, kic->ijc", + self.layer_norm_input.weight, + self.layer_norm_input.bias, + self.left_projection.weight, + self.left_projection.bias, + self.right_projection.weight, + self.right_projection.bias, + self.left_gate.weight, + self.left_gate.bias, + self.right_gate.weight, + self.right_gate.bias, + self.gating_linear.weight, + self.gating_linear.bias, + self.center_layer_norm.weight, + self.center_layer_norm.bias, + self.output_projection.weight, + self.output_projection.bias, + ) + out = kernel.alphafold.triangle_multiplication( + act.to(torch.bfloat16), + mask.to(torch.bfloat16), + self.kpex_weights, + ) + return out + def kpex_alphafold(model, model_config, dtype=torch.float): new_model = copy.deepcopy(model) evoformer = new_model.model.impl.evoformer @@ -118,6 +166,10 @@ def kpex_alphafold(model, model_config, dtype=torch.float): gating_attention_forward, block.msa_row_attention_with_pair_bias.attention ) + block.msa_column_global_attention.attention.forward = types.MethodType( + global_attention_forward, + block.msa_column_global_attention.attention + ) block.triangle_attention_starting_node.attention.forward = types.MethodType( gating_attention_forward, block.triangle_attention_starting_node.attention @@ -134,6 +186,14 @@ def kpex_alphafold(model, model_config, dtype=torch.float): transition_forward, block.pair_transition ) + block.triangle_multiplication_outgoing.forward = types.MethodType( + triangleMultiplication_forward, + block.triangle_multiplication_outgoing + ) + block.triangle_multiplication_incoming.forward = types.MethodType( + triangleMultiplication_forward, + block.triangle_multiplication_incoming + ) if hasattr(evoformer, "evoformer_iteration"): for block in evoformer.evoformer_iteration: block.msa_row_attention_with_pair_bias.attention.forward = types.MethodType( @@ -175,6 +235,14 @@ def kpex_alphafold(model, model_config, dtype=torch.float): transition_forward, block.pair_transition ) + block.triangle_multiplication_outgoing.forward = types.MethodType( + triangleMultiplication_forward, + block.triangle_multiplication_outgoing + ) + block.triangle_multiplication_incoming.forward = types.MethodType( + triangleMultiplication_forward, + block.triangle_multiplication_incoming + ) return new_model diff --git a/kpex_example/csrc/kupl_example/bind.h b/kpex_example/csrc/kupl_example/bind.h index e6658c0112b13d41061f1c129d8b772dd2a4c87b..1d1fa92b67fbf6e4b5ebddc0e1c992d8505b1d96 100644 --- a/kpex_example/csrc/kupl_example/bind.h +++ b/kpex_example/csrc/kupl_example/bind.h @@ -29,4 +29,4 @@ inline void bind(pybind11::module &m) } } -#endif \ No newline at end of file +#endif diff --git a/kpex_example/csrc/kupl_example/kupl_example.h b/kpex_example/csrc/kupl_example/kupl_example.h index 534f1964f2b63470142aefb5fd612a383ea65874..8073f01ab634e027fc687d4215db5563a27358ca 100644 --- a/kpex_example/csrc/kupl_example/kupl_example.h +++ b/kpex_example/csrc/kupl_example/kupl_example.h @@ -25,4 +25,4 @@ void test_kupl_parallel_for_error(); } // namespace kupl_example -#endif \ No newline at end of file +#endif diff --git a/kpex_example/frontend.py b/kpex_example/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391