diff --git a/include/kutacc.h b/include/kutacc.h
index b01e2b53a55e0486983e674bac4a299a317ea069..2f180a625fc1cb971045fa6bd93787fbbe3d6aa4 100644
--- a/include/kutacc.h
+++ b/include/kutacc.h
@@ -73,6 +73,189 @@ kutacc_export void kutacc_af2_transpose(kutacc_tensor_h data, kutacc_tensor_h ou
  */
 kutacc_export void kutacc_af2_all_gather(kutacc_tensor_h data, kutacc_tensor_h out);
 
+/** @brief groups the weight parameters used by the gating_attention and global_attention ops */
+typedef struct kutacc_af2_attention_weights {
+    int64_t nchannels;
+    int64_t nheads;
+    int64_t head_size;
+
+    kutacc_tensor_h query_w;
+    kutacc_tensor_h key_w;
+    kutacc_tensor_h value_w;
+    kutacc_tensor_h gating_w;
+    kutacc_tensor_h gating_b;
+    kutacc_tensor_h output_w;
+    kutacc_tensor_h output_b;
+} kutacc_af2_attention_weights_t;
+
+/** @brief used by the gating_attention and global_attention ops;
+    groups the intermediate activation tensors, shaped like the act parameter in alphafold.py,
+    that serve as inputs to the intermediate steps
+ */
+typedef struct kutacc_af2_attention_inputs {
+    int64_t batch;
+    int64_t seq_len;
+
+    kutacc_tensor_h q;
+    kutacc_tensor_h k;
+    kutacc_tensor_h v;
+    kutacc_tensor_h gate;
+    kutacc_tensor_h avg;
+} kutacc_af2_attention_inputs_t;
+
+/** @brief used by the invariant point attention (IPA) op;
+    groups the weight parameters that IPA requires
+ */
+typedef struct kutacc_af2_ipa_weights {
+    int64_t c_z;
+    int64_t c_hidden;
+    int64_t no_heads;
+    int64_t no_qk_points;
+    int64_t no_v_points;
+
+    kutacc_tensor_h head_weights;
+    kutacc_tensor_h weights_head_weights;
+    kutacc_tensor_h linear_b_w;
+    kutacc_tensor_h linear_b_b;
+} kutacc_af2_ipa_weights_t;
+
+/** @brief used by the IPA op;
+    groups the intermediate tensors, shaped like the s parameter in alphafold.py,
+    that serve as inputs to the intermediate steps
+ */
+typedef struct kutacc_af2_ipa_s_inputs {
+    int64_t n_res;
+
+    kutacc_tensor_h a;
+    kutacc_tensor_h b;
+    kutacc_tensor_h q;
+    kutacc_tensor_h k;
+    kutacc_tensor_h v;
+    kutacc_tensor_h q_pts;
+    kutacc_tensor_h k_pts;
+    kutacc_tensor_h v_pts;
+} kutacc_af2_ipa_s_inputs_t;
+
+/** @brief used by the IPA op;
+    groups the intermediate tensors, shaped like the o tensors from kpex,
+    that serve as inputs to the intermediate steps
+ */
+typedef struct kutacc_af2_ipa_o_inputs {
+    kutacc_tensor_h o;
+    kutacc_tensor_h o_pt;
+    kutacc_tensor_h o_pt_norm;
+    kutacc_tensor_h o_pair;
+} kutacc_af2_ipa_o_inputs_t;
+
+/** @brief used by the triangle multiplication op;
+    groups the weight tensors of the intermediate step
+    that computes the left and right projections
+ */
+typedef struct kutacc_af2_tm_proj_weights {
+    int64_t c_o;
+    int64_t c_i;
+
+    kutacc_tensor_h proj_w;
+    kutacc_tensor_h proj_b;
+    kutacc_tensor_h gate_w;
+    kutacc_tensor_h gate_b;
+} kutacc_af2_tm_proj_weights_t;
+
+/** @brief used by the triangle multiplication op;
+    groups the activation tensors, shaped like act in alphafold.py;
+    proj_act is both the input and the output of the left/right projection step
+ */
+typedef struct kutacc_af2_tm_act_inputs {
+    int64_t n_res;
+    int64_t n_res_gather;
+
+    kutacc_tensor_h proj_act;
+    kutacc_tensor_h input_act;
+    kutacc_tensor_h proj_act_gate;
+} kutacc_af2_tm_act_inputs_t;
+
+/** @brief used by the triangle multiplication op;
+    groups the weight tensors of the intermediate step
+    that computes the gate and output values
+ */
+typedef struct kutacc_af2_tm_linear_weights {
+    int64_t c_o;
+    int64_t c_i;
+
+    kutacc_tensor_h gating_w;
+    kutacc_tensor_h gating_b;
+    kutacc_tensor_h output_proj_w;
+    kutacc_tensor_h output_proj_b;
+} kutacc_af2_tm_linear_weights_t;
+
+/** @brief used by the transition op;
+    groups the weight tensors of the transition op
+ */
+typedef struct kutacc_af2_trans_weights {
+    int64_t c_o;
+    int64_t c_i;
+
+    kutacc_tensor_h linear1_w;
+    kutacc_tensor_h linear1_b;
+    kutacc_tensor_h linear2_w;
+    kutacc_tensor_h linear2_b;
+} kutacc_af2_trans_weights_t;
+
+/** @brief used by the transition op;
+    groups the input tensors, shaped like act in alphafold.py,
+    that feed the transition op
+ */
+typedef struct kutacc_af2_trans_act_inputs {
+    int64_t batch;
+    int64_t n_res;
+
+    kutacc_tensor_h input_act;
+    kutacc_tensor_h intermediate_act;
+} kutacc_af2_trans_act_inputs_t;
+
+/** @brief used by the outer product mean op;
+    groups the weight tensors of the outer product mean op
+ */
+typedef struct kutacc_af2_opm_weights {
+    int64_t c_m;
+    int64_t c_i;
+    int64_t c_z;
+
+    kutacc_tensor_h left_proj_w;
+    kutacc_tensor_h left_proj_b;
+    kutacc_tensor_h right_proj_w;
+    kutacc_tensor_h right_proj_b;
+    kutacc_tensor_h outer_w;
+    kutacc_tensor_h outer_b;
+} kutacc_af2_opm_weights_t;
+
+/** @brief used by the outer product mean op;
+    groups the activation tensors, shaped like act in alphafold.py,
+    that feed the outer product mean op
+ */
+typedef struct kutacc_af2_opm_act_inputs {
+    int64_t n_seq;
+    int64_t n_res;
+
+    kutacc_tensor_h input_act;
+    kutacc_tensor_h left_proj;
+    kutacc_tensor_h right_proj;
+    kutacc_tensor_h left_proj_;
+    kutacc_tensor_h right_proj_;
+} kutacc_af2_opm_act_inputs_t;
+
+/** @brief used by the outer product mean op;
+    groups the mask tensors, shaped like mask in alphafold.py,
+    that feed the outer product mean op
+ */
+typedef struct kutacc_af2_opm_mask_inputs {
+    int64_t n_res_gather;
+    int64_t mask_bias;
+
+    kutacc_tensor_h mask;
+    kutacc_tensor_h norm;
+} kutacc_af2_opm_mask_inputs_t;
+
 /**
  * @brief outer_product_mean_calc_left_and_right_mu algorithm
  * @param [out] left_proj, right_proj, left_proj_, right_proj_, mask
@@ -80,11 +263,7 @@ kutacc_export void kutacc_af2_all_gather(kutacc_tensor_h data, kutacc_tensor_h o
  * @param [in] c_i, c_m, n_res, n_res_gather, n_seq, mask_bias
  * @return Null
  */
-kutacc_export void kutacc_af2_outer_product_mean_calc_left_and_right_mul(
-    kutacc_tensor_h left_proj, kutacc_tensor_h right_proj, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_,
-    kutacc_tensor_h input_act, kutacc_tensor_h mask, kutacc_tensor_h norm, const kutacc_tensor_h left_proj_w,
-    const kutacc_tensor_h left_proj_b, const kutacc_tensor_h right_proj_w, const kutacc_tensor_h right_proj_b,
-    int64_t c_i, int64_t c_m, int64_t n_res, int64_t n_res_gather, int64_t n_seq, int64_t mask_bias);
+kutacc_export void kutacc_af2_outer_product_mean_calc_left_and_right_mul(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr);
 
 /**
  * @brief outer_product_mean_chunk algorithm
@@ -93,10 +272,8 @@ kutacc_export void kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out
  * @param [in] c_i, c_z, n_res, n_res_gather, n_seq
  * @return Null
  */
-kutacc_export void kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h output_b, const kutacc_tensor_h output_w,
-    kutacc_tensor_h out, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_, kutacc_tensor_h norm,
-    int64_t left_block_size, int64_t right_block_size, int64_t c_i, int64_t c_z, int64_t n_res,
-    int64_t n_res_gather, int64_t n_seq);
+kutacc_export void kutacc_af2_outer_product_mean_chunk(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr, kutacc_tensor_h out,
+    int64_t left_block_size, int64_t right_block_size);
 
 /**
  * @brief gating_attention algorithm
@@ -113,22 +290,15 @@ kutacc_export void kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out
  * @return Null
  * constraint: nchannels = nheads * head_size
  */
-kutacc_export void kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v, kutacc_tensor_h gate, kutacc_tensor_h weighted_avg,
-    int64_t batch, int64_t seq_len, kutacc_tensor_h m_data, kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias,
-    const kutacc_tensor_h query_w, const kutacc_tensor_h key_w, const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w,
-    const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w, const kutacc_tensor_h output_b, kutacc_tensor_h out,
-    int64_t block_size_, int64_t head_size, int64_t nheads, int64_t nchannels);
-
+kutacc_export void kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h m_data, kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias,
+    kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out, int64_t block_size_);
 /**
  * @param q_data shape [batch, seq_len, nchannels], bf16
  * @param m_data shape [batch, seq_len, nchannels], bf16
  * @param q_mask shape [batch, seq_len, 1], bf16
  */
-kutacc_export void kutacc_af2_global_attention(kutacc_tensor_h q_avg, kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v,
-    int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, kutacc_tensor_h gate, kutacc_tensor_h q_data,
-    kutacc_tensor_h m_data, kutacc_tensor_h q_mask, const kutacc_tensor_h query_w, const kutacc_tensor_h key_w,
-    const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w,
-    const kutacc_tensor_h output_b, kutacc_tensor_h out);
+kutacc_export void kutacc_af2_global_attention(kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h q_data,
+    kutacc_tensor_h m_data, kutacc_tensor_h q_mask, kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out);
 
 /**
  * @brief transition algorithm
@@ -142,9 +312,7 @@ kutacc_export void kutacc_af2_global_attention(kutacc_tensor_h q_avg, kutacc_ten
  * @param [out] out
  * @return Null
  */
-kutacc_export void kutacc_af2_transition(kutacc_tensor_h input_act, const kutacc_tensor_h linear1_w, kutacc_tensor_h linear1_b,
-    kutacc_tensor_h linear2_w, kutacc_tensor_h linear2_b, kutacc_tensor_h intermediate_act, kutacc_tensor_h out,
-    int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i);
+kutacc_export void kutacc_af2_transition(kutacc_af2_trans_act_inputs_t *trans_inputs_ptr, kutacc_af2_trans_weights_t *trans_weights_ptr, kutacc_tensor_h out);
 
 /**
  * @brief af2_layernorm algorithm: layernorm interface for af2 model
@@ -189,10 +357,8 @@ kutacc_export void kutacc_af2_layernorm(__bf16 *data, float *gamma, float *beta,
  * @param [out] out
  * @return Null
  */
-kutacc_export void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v, kutacc_tensor_h q_pts, kutacc_tensor_h k_pts, kutacc_tensor_h v_pts,
-    kutacc_tensor_h b, kutacc_tensor_h a, kutacc_tensor_h head_weights, kutacc_tensor_h weights_head_weights, kutacc_tensor_h o, kutacc_tensor_h o_pt, kutacc_tensor_h o_pt_norm, kutacc_tensor_h o_pair,
-    kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_tensor_h linear_b_w, kutacc_tensor_h linear_b_b,
-    int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads, int64_t no_qk_points, int64_t no_v_points);
+kutacc_export void kutacc_af2_invariant_point(kutacc_af2_ipa_s_inputs_t *ipa_s_ptrs, kutacc_af2_ipa_o_inputs_t *ipa_o_ptrs, kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats,
+    kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_af2_ipa_weights_t *ipa_weight_ptrs);
 
 /**
  * @brief impl of rot_vec_mul
@@ -245,9 +411,7 @@ kutacc_export size_t kutacc_af2_gemm_pack_get_size(char identifier, char transa,
  * @param mask shape [n_res, n_res_gather]
  * @param [out] proj_act
  */
-kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_tensor_h proj_act, kutacc_tensor_h gate, kutacc_tensor_h input_act, kutacc_tensor_h mask,
-    const kutacc_tensor_h proj_w, const kutacc_tensor_h proj_b, const kutacc_tensor_h gate_w, const kutacc_tensor_h gate_b, int64_t n_res,
-    int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack);
+kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h mask, kutacc_af2_tm_proj_weights_t *tm_weights_ptr, bool input_prepack);
 
 /**
  * @brief center_act = left_proj * right_proj
@@ -262,9 +426,8 @@ kutacc_export void kutacc_af2_triangle_multiplication_equation(kutacc_tensor_h c
  * @param [in] input_act, center_act, gating_w, gating_b, output_proj_w, output_proj_b
  * @param [out] gate, out
 */
-kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_tensor_h input_act, kutacc_tensor_h center_act,
-    const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_proj_w, const kutacc_tensor_h output_proj_b,
-    int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack);
+kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h center_act,
+    kutacc_af2_tm_linear_weights_t *tm_weights_ptr, bool input_prepack);
 
 /**
  * @brief out = (out + out_proj_b) * sigmoid(gate + gating_b)
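Usage note (reviewer sketch, not part of the patch): the header changes above replace the long flat tensor-handle argument lists with parameter structs. A minimal C sketch of how a caller drives the new gating-attention entry point follows; the kutacc_tensor_h handles are assumed to have been created and filled elsewhere with the project's tensor API, and every numeric size is illustrative only.

#include "kutacc.h"

/* All handles below are assumed to exist already; this sketch only shows how the
 * new structs are populated and passed. */
extern kutacc_tensor_h input, m_data, bias, nonbatched_bias, out;
extern kutacc_tensor_h q, k, v, gate, avg;
extern kutacc_tensor_h query_w, key_w, value_w, gating_w, gating_b, output_w, output_b;

void run_gating_attention(void)
{
    /* Intermediate activations, shaped like act in alphafold.py. */
    kutacc_af2_attention_inputs_t acts = {
        .batch = 1, .seq_len = 256,                         /* illustrative sizes */
        .q = q, .k = k, .v = v, .gate = gate, .avg = avg,
    };

    /* Weights; the documented constraint nchannels = nheads * head_size still applies. */
    kutacc_af2_attention_weights_t weights = {
        .nchannels = 256, .nheads = 8, .head_size = 32,     /* illustrative sizes */
        .query_w = query_w, .key_w = key_w, .value_w = value_w,
        .gating_w = gating_w, .gating_b = gating_b,
        .output_w = output_w, .output_b = output_b,
    };

    /* One call replaces the former long flat argument list. */
    kutacc_af2_gating_attention(input, &acts, m_data, bias, nonbatched_bias, &weights, out, 64);
}

The same two structs are reused as-is by kutacc_af2_global_attention, which is what makes the shared kutacc_af2_attention_weights_t / kutacc_af2_attention_inputs_t definitions worthwhile.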
diff --git a/src/attention/gating_attention.cpp b/src/attention/gating_attention.cpp
index 61cc8c3738f2569211a7a033a52acb39c34c975c..cbd80d85a5758088001577b6c321e192d69eacde 100644
--- a/src/attention/gating_attention.cpp
+++ b/src/attention/gating_attention.cpp
@@ -149,24 +149,36 @@ void gating_attention_kernel(Tensor &input, Tensor &q, Tensor &k, Tensor &v, Ten
             BlasExtendParams{.prepack_a = false, .prepack_b = true, .row_bias = true, .bias = output_b.data_ptr()});
     }
 }
-
-void kutacc_export kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v,
-    kutacc_tensor_h gate, kutacc_tensor_h weighted_avg, int64_t batch, int64_t seq_len, kutacc_tensor_h m_data,
-    kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias, const kutacc_tensor_h query_w, const kutacc_tensor_h key_w,
-    const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w,
-    const kutacc_tensor_h output_b, kutacc_tensor_h out, int64_t block_size_, int64_t head_size, int64_t nheads, int64_t nchannels)
+kutacc_export void kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h m_data, kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias,
+    kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out, int64_t block_size_)
 {
-    KUTACC_CHECK(input != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && weighted_avg != nullptr && m_data != nullptr
-        && bias != nullptr && nonbatched_bias != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr
-        && output_w != nullptr && output_b != nullptr && out != nullptr,
-        "kutacc_af2_gating_attention: input args nullptr error\n");
-    KUTACC_CHECK(batch > 0 && seq_len > 0 && block_size_ > 0 && head_size > 0 && nheads > 0 && nchannels > 0,
-        "kutacc_af2_gating_attention: input args int values error, values are less than or equal to zero\n");
+    KUTACC_CHECK(input != nullptr && q_based_ptr != nullptr && m_data != nullptr && bias != nullptr && nonbatched_bias != nullptr && weight_ptr != nullptr
+        && out != nullptr && block_size_ > 0 && block_size_ < INT64_MAX, "kutacc_af2_gating_attention: input args nullptr or invalid int value error");
     if (kutacc::kutacc_check_err_set == true) {
         return;
     }
+
+    kutacc_tensor_h q = q_based_ptr->q;
+    kutacc_tensor_h k = q_based_ptr->k;
+    kutacc_tensor_h v = q_based_ptr->v;
+    kutacc_tensor_h gate = q_based_ptr->gate;
+    kutacc_tensor_h weight_avg = q_based_ptr->avg;
+    int64_t batch = q_based_ptr->batch;
+    int64_t seq_len = q_based_ptr->seq_len;
+
+    kutacc_tensor_h query_w = weight_ptr->query_w;
+    kutacc_tensor_h key_w = weight_ptr->key_w;
+    kutacc_tensor_h value_w = weight_ptr->value_w;
+    kutacc_tensor_h gating_w = weight_ptr->gating_w;
+    kutacc_tensor_h gating_b = weight_ptr->gating_b;
+    kutacc_tensor_h output_w = weight_ptr->output_w;
+    kutacc_tensor_h output_b = weight_ptr->output_b;
+    int64_t head_size = weight_ptr->head_size;
+    int64_t nheads = weight_ptr->nheads;
+    int64_t nchannels = weight_ptr->nchannels;
+
     kutacc::gating_attention_kernel(*kutacc::convertKutaccTensor(input), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k),
-        *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(weighted_avg), batch, seq_len,
+        *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(weight_avg), batch, seq_len,
         *kutacc::convertKutaccTensor(bias), *kutacc::convertKutaccTensor(nonbatched_bias), *kutacc::convertKutaccTensor(query_w),
         *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), *kutacc::convertKutaccTensor(gating_w),
         *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), *kutacc::convertKutaccTensor(output_b),
diff --git a/src/attention/global_attention.cpp b/src/attention/global_attention.cpp
index 28c04e197bd2cfd7b38c1ed63bd1fc28a2b90abc..931c3c83f1dd47f9260f0ee7335747014cf83817 100644
--- a/src/attention/global_attention.cpp
+++ b/src/attention/global_attention.cpp
@@ -178,21 +178,34 @@ void global_attention_kernel(Tensor &q_avg, Tensor &q, Tensor &k, Tensor &v, int
     }
 }
 
-kutacc_export void kutacc_af2_global_attention(kutacc_tensor_h q_avg, kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v,
-    int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, kutacc_tensor_h gate, kutacc_tensor_h q_data,
-    kutacc_tensor_h m_data, kutacc_tensor_h q_mask, const kutacc_tensor_h query_w, const kutacc_tensor_h key_w,
-    const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w,
-    const kutacc_tensor_h output_b, kutacc_tensor_h out)
+kutacc_export void kutacc_af2_global_attention(kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h q_data,
+    kutacc_tensor_h m_data, kutacc_tensor_h q_mask, kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out)
 {
-    KUTACC_CHECK(q_avg != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && q_data != nullptr && m_data != nullptr
-        && q_mask != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr
-        && output_w != nullptr && output_b != nullptr && out != nullptr,
-        "kutacc_af2_global_attention: input args nullptr error\n");
-    KUTACC_CHECK(batch > 0 && seq_len > 0 && nchannels > 0 && nheads > 0 && head_size > 0,
-        "kutacc_af2_global_attention: input args int values error, values are less than or equal to zero\n");
+    KUTACC_CHECK(q_based_ptr != nullptr && q_data != nullptr && m_data != nullptr && q_mask != nullptr && weight_ptr != nullptr
+        && out != nullptr, "kutacc_af2_global_attention: input args nullptr error");
     if (kutacc::kutacc_check_err_set == true) {
         return;
     }
+
+    kutacc_tensor_h q = q_based_ptr->q;
+    kutacc_tensor_h k = q_based_ptr->k;
+    kutacc_tensor_h v = q_based_ptr->v;
+    kutacc_tensor_h gate = q_based_ptr->gate;
+    kutacc_tensor_h q_avg = q_based_ptr->avg;
+    int64_t batch = q_based_ptr->batch;
+    int64_t seq_len = q_based_ptr->seq_len;
+
+    kutacc_tensor_h query_w = weight_ptr->query_w;
+    kutacc_tensor_h key_w = weight_ptr->key_w;
+    kutacc_tensor_h value_w = weight_ptr->value_w;
+    kutacc_tensor_h gating_w = weight_ptr->gating_w;
+    kutacc_tensor_h gating_b = weight_ptr->gating_b;
+    kutacc_tensor_h output_w = weight_ptr->output_w;
+    kutacc_tensor_h output_b = weight_ptr->output_b;
+    int64_t head_size = weight_ptr->head_size;
+    int64_t nheads = weight_ptr->nheads;
+    int64_t nchannels = weight_ptr->nchannels;
+
     kutacc::global_attention_kernel(*kutacc::convertKutaccTensor(q_avg), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k),
         *kutacc::convertKutaccTensor(v), batch, seq_len, nchannels, nheads, head_size, *kutacc::convertKutaccTensor(gate),*kutacc::convertKutaccTensor(q_data),
         *kutacc::convertKutaccTensor(q_mask), *kutacc::convertKutaccTensor(query_w), *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w),
diff --git a/src/attention/invariant_point.cpp b/src/attention/invariant_point.cpp
index 38292492a801d06c01c1d9d6c1c1a6a89c6a0d48..8504ee627c3e9b1f2ac061c8cbef82070891ec94 100644
--- a/src/attention/invariant_point.cpp
+++ b/src/attention/invariant_point.cpp
@@ -161,20 +161,40 @@ void kutacc_af2_invariant_point_kernel(Tensor q, Tensor k, Tensor v, Tensor q_pt
 
 } // namespace kutacc
 
-void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v, kutacc_tensor_h q_pts, kutacc_tensor_h k_pts, kutacc_tensor_h v_pts,
-    kutacc_tensor_h b, kutacc_tensor_h a, kutacc_tensor_h head_weights, kutacc_tensor_h weights_head_weights, kutacc_tensor_h o, kutacc_tensor_h o_pt, kutacc_tensor_h o_pt_norm, kutacc_tensor_h o_pair,
-    kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_tensor_h linear_b_w, kutacc_tensor_h linear_b_b,
-    int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads, int64_t no_qk_points, int64_t no_v_points)
+kutacc_export void kutacc_af2_invariant_point(kutacc_af2_ipa_s_inputs_t *ipa_s_ptrs, kutacc_af2_ipa_o_inputs_t *ipa_o_ptrs, kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats,
+    kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_af2_ipa_weights_t *ipa_weight_ptrs)
 {
-    KUTACC_CHECK(q != nullptr && k != nullptr && v != nullptr && q_pts != nullptr && k_pts != nullptr && v_pts != nullptr
-        && b != nullptr && a != nullptr && head_weights != nullptr && weights_head_weights != nullptr && o != nullptr && o_pt != nullptr && o_pt_norm != nullptr
-        && o_pair != nullptr && z != nullptr && rigid_rot_mats != nullptr && rigid_trans != nullptr && mask != nullptr && linear_b_w != nullptr && linear_b_b != nullptr,
+    KUTACC_CHECK(ipa_s_ptrs != nullptr && ipa_o_ptrs != nullptr && z != nullptr && rigid_rot_mats != nullptr && rigid_trans != nullptr && mask != nullptr && ipa_weight_ptrs != nullptr,
         "kutacc_af2_invariant_point: input args nullptr error\n");
-    KUTACC_CHECK(n_res > 0 && c_z > 0 && c_hidden > 0 && no_heads > 0 && no_v_points > 0,
-        "kutacc_af2_invariant_point: input args int values error, values are less than or equal to zero\n");
     if (kutacc::kutacc_check_err_set == true) {
         return;
     }
+
+    kutacc_tensor_h q = ipa_s_ptrs->q;
+    kutacc_tensor_h k = ipa_s_ptrs->k;
+    kutacc_tensor_h v = ipa_s_ptrs->v;
+    kutacc_tensor_h q_pts = ipa_s_ptrs->q_pts;
+    kutacc_tensor_h k_pts = ipa_s_ptrs->k_pts;
+    kutacc_tensor_h v_pts = ipa_s_ptrs->v_pts;
+    kutacc_tensor_h a = ipa_s_ptrs->a;
+    kutacc_tensor_h b = ipa_s_ptrs->b;
+    int64_t n_res = ipa_s_ptrs->n_res;
+
+    kutacc_tensor_h o = ipa_o_ptrs->o;
+    kutacc_tensor_h o_pt = ipa_o_ptrs->o_pt;
+    kutacc_tensor_h o_pt_norm = ipa_o_ptrs->o_pt_norm;
+    kutacc_tensor_h o_pair = ipa_o_ptrs->o_pair;
+
+    kutacc_tensor_h head_weights = ipa_weight_ptrs->head_weights;
+    kutacc_tensor_h weights_head_weights = ipa_weight_ptrs->weights_head_weights;
+    kutacc_tensor_h linear_b_w = ipa_weight_ptrs->linear_b_w;
+    kutacc_tensor_h linear_b_b = ipa_weight_ptrs->linear_b_b;
+    int64_t c_z = ipa_weight_ptrs->c_z;
+    int64_t c_hidden = ipa_weight_ptrs->c_hidden;
+    int64_t no_heads = ipa_weight_ptrs->no_heads;
+    int64_t no_qk_points = ipa_weight_ptrs->no_qk_points;
+    int64_t no_v_points = ipa_weight_ptrs->no_v_points;
+
     kutacc::kutacc_af2_invariant_point_kernel(*kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(q_pts), *kutacc::convertKutaccTensor(k_pts), *kutacc::convertKutaccTensor(v_pts),
         *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(head_weights), *kutacc::convertKutaccTensor(weights_head_weights), *kutacc::convertKutaccTensor(o), *kutacc::convertKutaccTensor(o_pt), *kutacc::convertKutaccTensor(o_pt_norm), *kutacc::convertKutaccTensor(o_pair),
         *kutacc::convertKutaccTensor(z), *kutacc::convertKutaccTensor(rigid_rot_mats), *kutacc::convertKutaccTensor(rigid_trans), *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(linear_b_w), *kutacc::convertKutaccTensor(linear_b_b),
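Usage note (reviewer sketch, not part of the patch): the IPA wrapper now takes the s-side intermediates, the o-side intermediates and the weights as three structs. A minimal C sketch follows; all handles are assumed to be created elsewhere and the hyperparameter values are purely illustrative.

#include "kutacc.h"

extern kutacc_tensor_h a, b, q, k, v, q_pts, k_pts, v_pts;            /* s-side intermediates */
extern kutacc_tensor_h o, o_pt, o_pt_norm, o_pair;                    /* o-side intermediates */
extern kutacc_tensor_h head_weights, weights_head_weights, linear_b_w, linear_b_b;
extern kutacc_tensor_h z, rigid_rot_mats, rigid_trans, mask;

void run_invariant_point(void)
{
    kutacc_af2_ipa_s_inputs_t s_bufs = {
        .n_res = 256,                                                  /* illustrative */
        .a = a, .b = b, .q = q, .k = k, .v = v,
        .q_pts = q_pts, .k_pts = k_pts, .v_pts = v_pts,
    };
    kutacc_af2_ipa_o_inputs_t o_bufs = {
        .o = o, .o_pt = o_pt, .o_pt_norm = o_pt_norm, .o_pair = o_pair,
    };
    kutacc_af2_ipa_weights_t ipa_w = {
        .c_z = 128, .c_hidden = 16, .no_heads = 12,                    /* illustrative */
        .no_qk_points = 4, .no_v_points = 8,
        .head_weights = head_weights, .weights_head_weights = weights_head_weights,
        .linear_b_w = linear_b_w, .linear_b_b = linear_b_b,
    };
    kutacc_af2_invariant_point(&s_bufs, &o_bufs, z, rigid_rot_mats, rigid_trans, mask, &ipa_w);
}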
diff --git a/src/attention/outer_product_mean.cpp b/src/attention/outer_product_mean.cpp
index c0ff99e52d33d30f58c33c6a72b3e927630722bb..8ea7225687209bacfd2f0a8e17ddd341fedb04e1 100644
--- a/src/attention/outer_product_mean.cpp
+++ b/src/attention/outer_product_mean.cpp
@@ -20,7 +20,6 @@
 #include "utils/memory.h"
 
 namespace kutacc {
-
 /*
  * [OUT] left_proj, right_proj, left_proj_, right_proj_, mask
  * [IN] left_proj_w, left_proj_b, right_proj_w, right_proj_b
@@ -161,17 +160,30 @@ void outer_product_mean_chunk_kernel(const Tensor &output_b, const Tensor &outpu
     }
 }
 
-void kutacc_export kutacc_af2_outer_product_mean_calc_left_and_right_mul(
-    kutacc_tensor_h left_proj, kutacc_tensor_h right_proj, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_,
-    kutacc_tensor_h input_act, kutacc_tensor_h mask, kutacc_tensor_h norm, const kutacc_tensor_h left_proj_w,
-    const kutacc_tensor_h left_proj_b, const kutacc_tensor_h right_proj_w, const kutacc_tensor_h right_proj_b,
-    int64_t c_i, int64_t c_m, int64_t n_res, int64_t n_res_gather, int64_t n_seq, int64_t mask_bias)
+kutacc_export void kutacc_af2_outer_product_mean_calc_left_and_right_mul(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr)
 {
-    KUTACC_CHECK(left_proj != nullptr && right_proj != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && input_act != nullptr && mask != nullptr
-        && norm != nullptr && left_proj_w != nullptr && left_proj_b != nullptr && right_proj_w != nullptr && right_proj_b != nullptr,
-        "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error\n");
-    KUTACC_CHECK(c_i > 0 && c_m > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0 && mask_bias >= 0,
-        "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args negative value error\n");
+    KUTACC_CHECK(opm_acts_ptr != nullptr && opm_masks_ptr != nullptr && opm_weights_ptr != nullptr, "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error");
     if (kutacc::kutacc_check_err_set == true) {
         return;
     }
+
+    kutacc_tensor_h input_act = opm_acts_ptr->input_act;
+    kutacc_tensor_h left_proj = opm_acts_ptr->left_proj;
+    kutacc_tensor_h right_proj = opm_acts_ptr->right_proj;
+    kutacc_tensor_h left_proj_ = opm_acts_ptr->left_proj_;
+    kutacc_tensor_h right_proj_ = opm_acts_ptr->right_proj_;
+    int64_t n_seq = opm_acts_ptr->n_seq;
+    int64_t n_res = opm_acts_ptr->n_res;
+
+    kutacc_tensor_h mask = opm_masks_ptr->mask;
+    kutacc_tensor_h norm = opm_masks_ptr->norm;
+    int64_t n_res_gather = opm_masks_ptr->n_res_gather;
+    int64_t mask_bias = opm_masks_ptr->mask_bias;
+
+    kutacc_tensor_h left_proj_w = opm_weights_ptr->left_proj_w;
+    kutacc_tensor_h right_proj_w = opm_weights_ptr->right_proj_w;
+    kutacc_tensor_h left_proj_b = opm_weights_ptr->left_proj_b;
+    kutacc_tensor_h right_proj_b = opm_weights_ptr->right_proj_b;
+    int64_t c_i = opm_weights_ptr->c_i;
+    int64_t c_m = opm_weights_ptr->c_m;
+
@@ -183,18 +195,30 @@ void kutacc_export kutacc_af2_outer_product_mean_calc_left_and_right_mul(
         mask_bias);
 }
 
-void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h output_b, const kutacc_tensor_h output_w,
-    kutacc_tensor_h out, kutacc_tensor_h left_proj_, kutacc_tensor_h right_proj_,
-    kutacc_tensor_h norm, int64_t left_block_size, int64_t right_block_size,
-    int64_t c_i, int64_t c_z, int64_t n_res, int64_t n_res_gather, int64_t n_seq)
+kutacc_export void kutacc_af2_outer_product_mean_chunk(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr, kutacc_tensor_h out,
+    int64_t left_block_size, int64_t right_block_size)
 {
-    KUTACC_CHECK(output_b != nullptr && output_w != nullptr && out != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && norm != nullptr,
+    KUTACC_CHECK(opm_acts_ptr != nullptr && opm_masks_ptr != nullptr && opm_weights_ptr != nullptr && out != nullptr,
         "kutacc_af2_outer_product_mean_chunk: input args nullptr error");
-    KUTACC_CHECK(left_block_size > 0 && right_block_size > 0 && c_i > 0 && c_z > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0,
+    KUTACC_CHECK(left_block_size > 0 && right_block_size > 0 && left_block_size < INT64_MAX && right_block_size < INT64_MAX,
         "kutacc_af2_outer_product_mean_chunk: input args int values error, values are less than or equal to zero\n");
     if (kutacc::kutacc_check_err_set == true) {
         return;
-    }
+    }
+
+    kutacc_tensor_h output_w = opm_weights_ptr->outer_w;
+    kutacc_tensor_h output_b = opm_weights_ptr->outer_b;
+    int64_t c_i = opm_weights_ptr->c_i;
+    int64_t c_z = opm_weights_ptr->c_z;
+
+    kutacc_tensor_h left_proj_ = opm_acts_ptr->left_proj_;
+    kutacc_tensor_h right_proj_ = opm_acts_ptr->right_proj_;
+    int64_t n_seq = opm_acts_ptr->n_seq;
+    int64_t n_res = opm_acts_ptr->n_res;
+
+    kutacc_tensor_h norm = opm_masks_ptr->norm;
+    int64_t n_res_gather = opm_masks_ptr->n_res_gather;
+
     kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w),
         *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_),
         *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq);
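Usage note (reviewer sketch, not part of the patch): the two outer-product-mean stages now share the same three structs, so a caller fills them once and passes the same pointers to both calls. A minimal C sketch, with all handles assumed to come from the caller and every size illustrative:

#include "kutacc.h"

extern kutacc_tensor_h input_act, left_proj, right_proj, left_proj_, right_proj_;
extern kutacc_tensor_h mask, norm, out;
extern kutacc_tensor_h left_proj_w, left_proj_b, right_proj_w, right_proj_b, outer_w, outer_b;

void run_outer_product_mean(void)
{
    kutacc_af2_opm_act_inputs_t acts = {
        .n_seq = 508, .n_res = 256,                     /* illustrative */
        .input_act = input_act,
        .left_proj = left_proj, .right_proj = right_proj,
        .left_proj_ = left_proj_, .right_proj_ = right_proj_,
    };
    kutacc_af2_opm_mask_inputs_t masks = {
        .n_res_gather = 256, .mask_bias = 0,            /* illustrative */
        .mask = mask, .norm = norm,
    };
    kutacc_af2_opm_weights_t weights = {
        .c_m = 256, .c_i = 32, .c_z = 128,              /* illustrative */
        .left_proj_w = left_proj_w, .left_proj_b = left_proj_b,
        .right_proj_w = right_proj_w, .right_proj_b = right_proj_b,
        .outer_w = outer_w, .outer_b = outer_b,
    };

    /* Stage 1 computes the left/right projections; stage 2 consumes those
     * intermediates block by block through the same struct pointers. */
    kutacc_af2_outer_product_mean_calc_left_and_right_mul(&acts, &masks, &weights);
    kutacc_af2_outer_product_mean_chunk(&acts, &masks, &weights, out, 128, 128);
}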
"kutacc_af2_outer_product_mean_chunk: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; - } + } + + kutacc_tensor_h output_w = opm_weights_ptr->outer_w; + kutacc_tensor_h output_b = opm_weights_ptr->outer_b; + int64_t c_i = opm_weights_ptr->c_i; + int64_t c_z = opm_weights_ptr->c_z; + + kutacc_tensor_h left_proj_ = opm_acts_ptr->left_proj_; + kutacc_tensor_h right_proj_ = opm_acts_ptr->right_proj_; + int64_t n_seq = opm_acts_ptr->n_seq; + int64_t n_res = opm_acts_ptr->n_res; + + kutacc_tensor_h norm = opm_masks_ptr->norm; + int64_t n_res_gather = opm_masks_ptr->n_res_gather; + kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_), *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq); diff --git a/src/attention/transition.cpp b/src/attention/transition.cpp index cdf28db158e5112aff2884fa4d600df257ac825d..fc7ef0623f2ae30edcc6e29ed7e67cebbd2732ff 100644 --- a/src/attention/transition.cpp +++ b/src/attention/transition.cpp @@ -46,17 +46,25 @@ void transition_kernel(Tensor &input_act, const Tensor &linear1_w, Tensor &linea } } -void kutacc_af2_transition(kutacc_tensor_h input_act, const kutacc_tensor_h linear1_w, kutacc_tensor_h linear1_b, - kutacc_tensor_h linear2_w, kutacc_tensor_h linear2_b, kutacc_tensor_h intermediate_act, kutacc_tensor_h out, - int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i) +kutacc_export void kutacc_af2_transition(kutacc_af2_trans_act_inputs_t *trans_inputs_ptr, kutacc_af2_trans_weights_t *trans_weights_ptr, kutacc_tensor_h out) { - KUTACC_CHECK(input_act != nullptr && linear1_w != nullptr && linear1_b != nullptr && linear2_w != nullptr && linear2_b != nullptr && intermediate_act != nullptr && out != nullptr, - "kutacc_af2_transition: input args nullptr error\n"); - KUTACC_CHECK(batch > 0 && n_res > 0 && c_o > 0 && c_i > 0, - "kutacc_af2_transition: input args int values error, values are less than or equal to zero\n"); + KUTACC_CHECK(trans_inputs_ptr != nullptr && trans_weights_ptr != nullptr && out != nullptr, "af2_transition: input args nullptr error\n"); if (kutacc::kutacc_check_err_set == true) { return; } + + kutacc_tensor_h input_act = trans_inputs_ptr->input_act; + kutacc_tensor_h intermediate_act = trans_inputs_ptr->intermediate_act; + kutacc_tensor_h linear1_w = trans_weights_ptr->linear1_w; + kutacc_tensor_h linear1_b = trans_weights_ptr->linear1_b; + kutacc_tensor_h linear2_w = trans_weights_ptr->linear2_w; + kutacc_tensor_h linear2_b = trans_weights_ptr->linear2_b; + + int64_t batch = trans_inputs_ptr->batch; + int64_t n_res = trans_inputs_ptr->n_res; + int64_t c_o = trans_weights_ptr->c_o; + int64_t c_i = trans_weights_ptr->c_i; + kutacc::transition_kernel(*kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(linear1_w), *kutacc::convertKutaccTensor(linear1_b), *kutacc::convertKutaccTensor(linear2_w), *kutacc::convertKutaccTensor(linear2_b), *kutacc::convertKutaccTensor(intermediate_act), *kutacc::convertKutaccTensor(out), batch, n_res, c_o, c_i); diff --git a/src/attention/triangle_multiplication.cpp b/src/attention/triangle_multiplication.cpp index 64e76f60282ef001a722326173660e2de7471c0b..5392aff9f46027f93c5062e0abee59cdb8eb80f6 100644 --- a/src/attention/triangle_multiplication.cpp +++ 
diff --git a/src/attention/triangle_multiplication.cpp b/src/attention/triangle_multiplication.cpp
index 64e76f60282ef001a722326173660e2de7471c0b..5392aff9f46027f93c5062e0abee59cdb8eb80f6 100644
--- a/src/attention/triangle_multiplication.cpp
+++ b/src/attention/triangle_multiplication.cpp
@@ -169,17 +169,26 @@ void last(Tensor &out, Tensor &gate, int64_t n_res, int64_t n_res_gather, int64_
     }
 }
 
-kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_tensor_h proj_act, kutacc_tensor_h gate, kutacc_tensor_h input_act, kutacc_tensor_h mask,
-    const kutacc_tensor_h proj_w, const kutacc_tensor_h proj_b, const kutacc_tensor_h gate_w, const kutacc_tensor_h gate_b, int64_t n_res,
-    int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack)
+kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h mask, kutacc_af2_tm_proj_weights_t *tm_weights_ptr, bool input_prepack)
 {
-    KUTACC_CHECK(proj_act != nullptr && gate != nullptr && input_act != nullptr && mask != nullptr && proj_w != nullptr && proj_b != nullptr && gate_w != nullptr && gate_b != nullptr,
+    KUTACC_CHECK(tm_acts_ptr != nullptr && mask != nullptr && tm_weights_ptr != nullptr,
         "kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error\n");
-    KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0 && c_i > 0,
-        "kutacc_af2_triangle_multiplication_calc_proj: input args int values error, values are less than or equal to zero\n");
     if (kutacc::kutacc_check_err_set == true) {
         return;
     }
+
+    kutacc_tensor_h proj_act = tm_acts_ptr->proj_act;
+    kutacc_tensor_h input_act = tm_acts_ptr->input_act;
+    kutacc_tensor_h gate = tm_acts_ptr->proj_act_gate;
+    int64_t n_res = tm_acts_ptr->n_res;
+    int64_t n_res_gather = tm_acts_ptr->n_res_gather;
+
+    kutacc_tensor_h proj_w = tm_weights_ptr->proj_w;
+    kutacc_tensor_h proj_b = tm_weights_ptr->proj_b;
+    kutacc_tensor_h gate_w = tm_weights_ptr->gate_w;
+    kutacc_tensor_h gate_b = tm_weights_ptr->gate_b;
+    int64_t c_o = tm_weights_ptr->c_o;
+    int64_t c_i = tm_weights_ptr->c_i;
     kutacc::calc_proj_act(*kutacc::convertKutaccTensor(proj_act), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(mask),
         *kutacc::convertKutaccTensor(proj_w), *kutacc::convertKutaccTensor(proj_b), *kutacc::convertKutaccTensor(gate_w), *kutacc::convertKutaccTensor(gate_b), n_res,
         n_res_gather, c_o, c_i, input_prepack);
@@ -190,7 +199,7 @@ kutacc_export void kutacc_af2_triangle_multiplication_equation(kutacc_tensor_h c
 {
     KUTACC_CHECK(center_act != nullptr && left_proj_act != nullptr && right_proj_act != nullptr,
         "kutacc_af2_triangle_multiplication_equation: input args nullptr error\n");
-    KUTACC_CHECK(n_res_gather > 0,
+    KUTACC_CHECK(n_res_gather > 0 && n_res_gather < INT64_MAX,
         "kutacc_af2_triangle_multiplication_equation: input args int values error, values are less than or equal to zero\n");
     if (kutacc::kutacc_check_err_set == true) {
         return;
@@ -199,17 +208,26 @@ kutacc_export void kutacc_af2_triangle_multiplication_equation(kutacc_tensor_h c
         n_res_gather, is_incoming);
 }
 
-kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_tensor_h input_act, kutacc_tensor_h center_act,
-    const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_proj_w, const kutacc_tensor_h output_proj_b,
-    int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack)
+kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h center_act,
+    kutacc_af2_tm_linear_weights_t *tm_weights_ptr, bool input_prepack)
 {
-    KUTACC_CHECK(gate != nullptr && out != nullptr && input_act != nullptr && center_act != nullptr && gating_w != nullptr && gating_b != nullptr && output_proj_w != nullptr && output_proj_b != nullptr,
+    KUTACC_CHECK(gate != nullptr && out != nullptr && tm_acts_ptr != nullptr && center_act != nullptr && tm_weights_ptr != nullptr,
         "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error\n");
-    KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0 && c_i > 0,
-        "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args int values error, values are less than or equal to zero\n");
     if (kutacc::kutacc_check_err_set == true) {
         return;
     }
+
+    kutacc_tensor_h input_act = tm_acts_ptr->input_act;
+    int64_t n_res = tm_acts_ptr->n_res;
+    int64_t n_res_gather = tm_acts_ptr->n_res_gather;
+
+    kutacc_tensor_h gating_w = tm_weights_ptr->gating_w;
+    kutacc_tensor_h gating_b = tm_weights_ptr->gating_b;
+    kutacc_tensor_h output_proj_w = tm_weights_ptr->output_proj_w;
+    kutacc_tensor_h output_proj_b = tm_weights_ptr->output_proj_b;
+    int64_t c_o = tm_weights_ptr->c_o;
+    int64_t c_i = tm_weights_ptr->c_i;
+
     kutacc::gate_and_out_linear(*kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(center_act),
         *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_proj_w), *kutacc::convertKutaccTensor(output_proj_b),
         n_res, n_res_gather, c_o, c_i, input_prepack);
@@ -218,7 +236,7 @@ kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc
 kutacc_export void kutacc_af2_triangle_multiplication_last(kutacc_tensor_h out, kutacc_tensor_h gate, int64_t n_res, int64_t n_res_gather, int64_t c_o)
 {
     KUTACC_CHECK(out != nullptr && gate != nullptr, "kutacc_af2_triangle_multiplication_last: input args nullptr error\n");
-    KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0,
+    KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0 && n_res < INT64_MAX && n_res_gather < INT64_MAX && c_o < INT64_MAX,
         "kutacc_af2_triangle_multiplication_last: input args int values error, values are less than or equal to zero\n");
     if (kutacc::kutacc_check_err_set == true) {
         return;
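Usage note (reviewer sketch, not part of the patch): the triangle-multiplication stages chain the same kutacc_af2_tm_act_inputs_t instance through the projection and the gate/output-linear steps; the equation step in between keeps its existing tensor-handle signature and is omitted here. Handles are assumed to come from the caller and all sizes are illustrative.

#include <stdbool.h>
#include "kutacc.h"

extern kutacc_tensor_h proj_act, input_act, proj_act_gate, mask, center_act, gate, out;
extern kutacc_tensor_h proj_w, proj_b, gate_w, gate_b;
extern kutacc_tensor_h gating_w, gating_b, output_proj_w, output_proj_b;

void run_triangle_multiplication(void)
{
    kutacc_af2_tm_act_inputs_t acts = {
        .n_res = 256, .n_res_gather = 256,              /* illustrative */
        .proj_act = proj_act, .input_act = input_act, .proj_act_gate = proj_act_gate,
    };
    kutacc_af2_tm_proj_weights_t proj_weights = {
        .c_o = 128, .c_i = 128,                         /* illustrative */
        .proj_w = proj_w, .proj_b = proj_b, .gate_w = gate_w, .gate_b = gate_b,
    };
    kutacc_af2_tm_linear_weights_t linear_weights = {
        .c_o = 128, .c_i = 128,                         /* illustrative */
        .gating_w = gating_w, .gating_b = gating_b,
        .output_proj_w = output_proj_w, .output_proj_b = output_proj_b,
    };

    /* Left/right projection, then (after the unchanged equation step) gate + output
     * linear, then the final elementwise combine. */
    kutacc_af2_triangle_multiplication_calc_proj(&acts, mask, &proj_weights, false);
    kutacc_af2_triangle_multiplication_gate_and_out_linear(gate, out, &acts, center_act, &linear_weights, false);
    kutacc_af2_triangle_multiplication_last(out, gate, acts.n_res, acts.n_res_gather, linear_weights.c_o);
}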