diff --git a/include/kutacc.h b/include/kutacc.h index 11c214b5b3e88b5397bfecf5e7c03884bf6d2d71..b01e2b53a55e0486983e674bac4a299a317ea069 100644 --- a/include/kutacc.h +++ b/include/kutacc.h @@ -111,6 +111,7 @@ kutacc_export void kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out * @param [in] GatingAttentionWeight * @param [out] out * @return Null + * constraint: nchannels = nheads * head_size */ kutacc_export void kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v, kutacc_tensor_h gate, kutacc_tensor_h weighted_avg, int64_t batch, int64_t seq_len, kutacc_tensor_h m_data, kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias, diff --git a/src/attention/gating_attention.cpp b/src/attention/gating_attention.cpp index 78c6dcbf5cb588fdf667f03b2a7ed1fc144e5b22..61cc8c3738f2569211a7a033a52acb39c34c975c 100644 --- a/src/attention/gating_attention.cpp +++ b/src/attention/gating_attention.cpp @@ -159,11 +159,12 @@ void kutacc_export kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_ten KUTACC_CHECK(input != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && weighted_avg != nullptr && m_data != nullptr && bias != nullptr && nonbatched_bias != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr && output_w != nullptr && output_b != nullptr && out != nullptr, - "kutacc_af2_gating_attention: input args nullptr error"); + "kutacc_af2_gating_attention: input args nullptr error\n"); + KUTACC_CHECK(batch > 0 && seq_len > 0 && block_size_ > 0 && head_size > 0 && nheads > 0 && nchannels > 0, + "kutacc_af2_gating_attention: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; } - kutacc::gating_attention_kernel(*kutacc::convertKutaccTensor(input), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(weighted_avg), batch, seq_len, *kutacc::convertKutaccTensor(bias), *kutacc::convertKutaccTensor(nonbatched_bias), *kutacc::convertKutaccTensor(query_w), diff --git a/src/attention/global_attention.cpp b/src/attention/global_attention.cpp index e5ce6404fa4ca80ebdb9c923fcd43be4f5320fd7..28c04e197bd2cfd7b38c1ed63bd1fc28a2b90abc 100644 --- a/src/attention/global_attention.cpp +++ b/src/attention/global_attention.cpp @@ -187,7 +187,9 @@ kutacc_export void kutacc_af2_global_attention(kutacc_tensor_h q_avg, kutacc_ten KUTACC_CHECK(q_avg != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && q_data != nullptr && m_data != nullptr && q_mask != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr && output_w != nullptr && output_b != nullptr && out != nullptr, - "kutacc_af2_global_attention: input args nullptr error"); + "kutacc_af2_global_attention: input args nullptr error\n"); + KUTACC_CHECK(batch > 0 && seq_len > 0 && nchannels > 0 && nheads > 0 && head_size > 0, + "kutacc_af2_global_attention: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; } diff --git a/src/attention/invariant_point.cpp b/src/attention/invariant_point.cpp index f1545754f99ac96dc3d419122493ceb3bde9fd49..38292492a801d06c01c1d9d6c1c1a6a89c6a0d48 100644 --- a/src/attention/invariant_point.cpp +++ b/src/attention/invariant_point.cpp @@ -169,7 +169,9 @@ void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_ten KUTACC_CHECK(q != nullptr && k != nullptr && v != nullptr && q_pts != nullptr && k_pts != nullptr && v_pts != nullptr && b != nullptr && a != nullptr && head_weights != nullptr && weights_head_weights != nullptr && o != nullptr && o_pt != nullptr && o_pt_norm != nullptr && o_pair != nullptr && z != nullptr && rigid_rot_mats != nullptr && rigid_trans != nullptr && mask != nullptr && linear_b_w != nullptr && linear_b_b != nullptr, - "kutacc_af2_invariant_point: input args nullptr error"); + "kutacc_af2_invariant_point: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && c_z > 0 && c_hidden > 0 && no_heads > 0 && no_v_points > 0, + "kutacc_af2_invariant_point: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; } @@ -177,6 +179,5 @@ void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_ten *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(head_weights), *kutacc::convertKutaccTensor(weights_head_weights), *kutacc::convertKutaccTensor(o), *kutacc::convertKutaccTensor(o_pt), *kutacc::convertKutaccTensor(o_pt_norm), *kutacc::convertKutaccTensor(o_pair), *kutacc::convertKutaccTensor(z), *kutacc::convertKutaccTensor(rigid_rot_mats), *kutacc::convertKutaccTensor(rigid_trans), *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(linear_b_w), *kutacc::convertKutaccTensor(linear_b_b), n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points); - } diff --git a/src/attention/outer_product_mean.cpp b/src/attention/outer_product_mean.cpp index e9fdd5df9aa6c41b1f862a6d205e2d212e2de349..c0ff99e52d33d30f58c33c6a72b3e927630722bb 100644 --- a/src/attention/outer_product_mean.cpp +++ b/src/attention/outer_product_mean.cpp @@ -169,7 +169,9 @@ void kutacc_export kutacc_af2_outer_product_mean_calc_left_and_right_mul( { KUTACC_CHECK(left_proj != nullptr && right_proj != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && input_act != nullptr && mask != nullptr && norm != nullptr && left_proj_w != nullptr && left_proj_b != nullptr && right_proj_w != nullptr && right_proj_b != nullptr, - "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error"); + "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error\n"); + KUTACC_CHECK(c_i > 0 && c_m > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0 && mask_bias >= 0, + "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args negative value error\n"); if (kutacc::kutacc_check_err_set == true) { return; } @@ -188,10 +190,12 @@ void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out { KUTACC_CHECK(output_b != nullptr && output_w != nullptr && out != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && norm != nullptr, "kutacc_af2_outer_product_mean_chunk: input args nullptr error"); + KUTACC_CHECK(left_block_size > 0 && right_block_size > 0 && c_i > 0 && c_z > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0, + "kutacc_af2_outer_product_mean_chunk: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; - } + } kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w), - *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_), - *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq); + *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_), + *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq); } \ No newline at end of file diff --git a/src/attention/rigid.cpp b/src/attention/rigid.cpp index 94ca9980d129ceb88774bfcdcf9bd6c9e38a4ee0..e2b205c866c13cedbb93c1b738e2a11d2d8fadae 100644 --- a/src/attention/rigid.cpp +++ b/src/attention/rigid.cpp @@ -121,6 +121,6 @@ void kutacc_af2_rigid_rot_matmul(kutacc_tensor_h a, kutacc_tensor_h b, kutacc_te KUTACC_CHECK(a != nullptr && b != nullptr && out != nullptr, "kutacc_af2_rigid_rot_matmul: input args nullptr error"); if (kutacc::kutacc_check_err_set == true) { return; - } + } kutacc::rigid_rot_matmul(*kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(out)); } \ No newline at end of file diff --git a/src/attention/transition.cpp b/src/attention/transition.cpp index df81d9e9d151a4ac5150b732c4e898df83a60a25..cdf28db158e5112aff2884fa4d600df257ac825d 100644 --- a/src/attention/transition.cpp +++ b/src/attention/transition.cpp @@ -51,7 +51,9 @@ void kutacc_af2_transition(kutacc_tensor_h input_act, const kutacc_tensor_h line int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i) { KUTACC_CHECK(input_act != nullptr && linear1_w != nullptr && linear1_b != nullptr && linear2_w != nullptr && linear2_b != nullptr && intermediate_act != nullptr && out != nullptr, - "kutacc_af2_transition: input args nullptr error"); + "kutacc_af2_transition: input args nullptr error\n"); + KUTACC_CHECK(batch > 0 && n_res > 0 && c_o > 0 && c_i > 0, + "kutacc_af2_transition: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; } diff --git a/src/attention/triangle_multiplication.cpp b/src/attention/triangle_multiplication.cpp index 6d3978fc8e2fd5f2d9260cea94e5145251b6cbaa..64e76f60282ef001a722326173660e2de7471c0b 100644 --- a/src/attention/triangle_multiplication.cpp +++ b/src/attention/triangle_multiplication.cpp @@ -174,7 +174,9 @@ kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_tensor_h int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { KUTACC_CHECK(proj_act != nullptr && gate != nullptr && input_act != nullptr && mask != nullptr && proj_w != nullptr && proj_b != nullptr && gate_w != nullptr && gate_b != nullptr, - "kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error"); + "kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0 && c_i > 0, + "kutacc_af2_triangle_multiplication_calc_proj: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; } @@ -186,7 +188,10 @@ kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_tensor_h kutacc_export void kutacc_af2_triangle_multiplication_equation(kutacc_tensor_h center_act, kutacc_tensor_h left_proj_act, kutacc_tensor_h right_proj_act, int64_t n_res_gather, bool is_incoming) { - KUTACC_CHECK(center_act != nullptr && left_proj_act != nullptr && right_proj_act != nullptr, "kutacc_af2_triangle_multiplication_equation: input args nullptr error"); + KUTACC_CHECK(center_act != nullptr && left_proj_act != nullptr && right_proj_act != nullptr, + "kutacc_af2_triangle_multiplication_equation: input args nullptr error\n"); + KUTACC_CHECK(n_res_gather > 0, + "kutacc_af2_triangle_multiplication_equation: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; } @@ -199,7 +204,9 @@ kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { KUTACC_CHECK(gate != nullptr && out != nullptr && input_act != nullptr && center_act != nullptr && gating_w != nullptr && gating_b != nullptr && output_proj_w != nullptr && output_proj_b != nullptr, - "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error"); + "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0 && c_i > 0, + "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; } @@ -210,7 +217,9 @@ kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc kutacc_export void kutacc_af2_triangle_multiplication_last(kutacc_tensor_h out, kutacc_tensor_h gate, int64_t n_res, int64_t n_res_gather, int64_t c_o) { - KUTACC_CHECK(out != nullptr && gate != nullptr, "kutacc_af2_triangle_multiplication_last: input args nullptr error"); + KUTACC_CHECK(out != nullptr && gate != nullptr, "kutacc_af2_triangle_multiplication_last: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0, + "kutacc_af2_triangle_multiplication_last: input args int values error, values are less than or equal to zero\n"); if (kutacc::kutacc_check_err_set == true) { return; }