From d8ccaed24ee2afab3144f47c4c85de6485739923 Mon Sep 17 00:00:00 2001 From: XeonYZhang Date: Mon, 27 Oct 2025 09:33:16 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0int=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E5=8F=82=E6=95=B0=E9=9D=9E=E8=B4=9F=E6=A3=80?= =?UTF-8?q?=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/kutacc.h | 4 ++ src/CMakeLists.txt | 2 +- src/attention/gating_attention.cpp | 27 ++++++------ src/attention/global_attention.cpp | 24 ++++++----- src/attention/invariant_point.cpp | 23 +++++----- src/attention/outer_product_mean.cpp | 39 ++++++++++------- src/attention/rigid.cpp | 18 ++++---- src/attention/transition.cpp | 15 ++++--- src/attention/triangle_multiplication.cpp | 52 ++++++++++++++--------- 9 files changed, 122 insertions(+), 82 deletions(-) diff --git a/include/kutacc.h b/include/kutacc.h index 11c214b..ad047c6 100644 --- a/include/kutacc.h +++ b/include/kutacc.h @@ -18,6 +18,9 @@ #include #include +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + #ifdef __cplusplus extern "C" { #endif @@ -111,6 +114,7 @@ kutacc_export void kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out * @param [in] GatingAttentionWeight * @param [out] out * @return Null + * constraint: nchannels = nheads * head_size */ kutacc_export void kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_tensor_h q, kutacc_tensor_h k, kutacc_tensor_h v, kutacc_tensor_h gate, kutacc_tensor_h weighted_avg, int64_t batch, int64_t seq_len, kutacc_tensor_h m_data, kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index afcf98b..ec3cd01 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -30,7 +30,7 @@ elseif (CMAKE_BUILD_TYPE MATCHES "Release") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wcast-align -Wcast-qual -pipe -Wconversion") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wframe-larger-than=16384 -Wvla -fstack-protector-strong") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wstrict-prototypes") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-strong") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-strong -Wno-parentheses -fPIC -DF_INTERFACE_FLANG -mllvm -disable-lsr -Wno-unused-command-line-argument") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -Wl,-Bsymbolic -rdynamic -Wl,--no-undefined") add_compile_options(-O3) endif() diff --git a/src/attention/gating_attention.cpp b/src/attention/gating_attention.cpp index 78c6dcb..1b63c40 100644 --- a/src/attention/gating_attention.cpp +++ b/src/attention/gating_attention.cpp @@ -156,18 +156,21 @@ void kutacc_export kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_ten const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w, const kutacc_tensor_h output_b, kutacc_tensor_h out, int64_t block_size_, int64_t head_size, int64_t nheads, int64_t nchannels) { - KUTACC_CHECK(input != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && weighted_avg != nullptr && m_data != nullptr - && bias != nullptr && nonbatched_bias != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr - && output_w != nullptr && output_b != nullptr && out != nullptr, - "kutacc_af2_gating_attention: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(input == nullptr || q == nullptr || k == nullptr || v == nullptr || gate == nullptr || weighted_avg == nullptr || m_data == nullptr + || bias == nullptr || nonbatched_bias == nullptr || query_w == nullptr || key_w == nullptr || value_w == nullptr || gating_w == nullptr || gating_b == nullptr + || output_w == nullptr || output_b == nullptr || out == nullptr)) + { + printf("kutacc_af2_gating_attention: input args nullptr error"); return; + } else if (unlikely(batch < 0 || seq_len < 0 || block_size_ < 0 || head_size < 0 || nheads < 0 || nchannels < 0)) { + printf("kutacc_af2_gating_attention: input args negative value error"); + return; + } else { + kutacc::gating_attention_kernel(*kutacc::convertKutaccTensor(input), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), + *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(weighted_avg), batch, seq_len, + *kutacc::convertKutaccTensor(bias), *kutacc::convertKutaccTensor(nonbatched_bias), *kutacc::convertKutaccTensor(query_w), + *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), *kutacc::convertKutaccTensor(gating_w), + *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), *kutacc::convertKutaccTensor(output_b), + *kutacc::convertKutaccTensor(out), block_size_, head_size, nheads, nchannels); } - - kutacc::gating_attention_kernel(*kutacc::convertKutaccTensor(input), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), - *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(weighted_avg), batch, seq_len, - *kutacc::convertKutaccTensor(bias), *kutacc::convertKutaccTensor(nonbatched_bias), *kutacc::convertKutaccTensor(query_w), - *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), *kutacc::convertKutaccTensor(gating_w), - *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), *kutacc::convertKutaccTensor(output_b), - *kutacc::convertKutaccTensor(out), block_size_, head_size, nheads, nchannels); } \ No newline at end of file diff --git a/src/attention/global_attention.cpp b/src/attention/global_attention.cpp index e5ce640..6bc075d 100644 --- a/src/attention/global_attention.cpp +++ b/src/attention/global_attention.cpp @@ -184,16 +184,20 @@ kutacc_export void kutacc_af2_global_attention(kutacc_tensor_h q_avg, kutacc_ten const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w, const kutacc_tensor_h output_b, kutacc_tensor_h out) { - KUTACC_CHECK(q_avg != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && q_data != nullptr && m_data != nullptr - && q_mask != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr - && output_w != nullptr && output_b != nullptr && out != nullptr, - "kutacc_af2_global_attention: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(q_avg == nullptr || q == nullptr || k == nullptr || v == nullptr || gate == nullptr || q_data == nullptr || m_data == nullptr + || q_mask == nullptr || query_w == nullptr || key_w == nullptr || value_w == nullptr || gating_w == nullptr || gating_b == nullptr + || output_w == nullptr || output_b == nullptr || out == nullptr)) + { + printf("kutacc_af2_global_attention: input args nullptr error"); return; + } else if (unlikely(batch < 0 || seq_len < 0 || nchannels < 0 || nchannels >= INT64_MAX || nheads < 0 || head_size < 0)) { + printf("kutacc_af2_global_attention: input args int values error"); + return; + } else { + kutacc::global_attention_kernel(*kutacc::convertKutaccTensor(q_avg), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), + *kutacc::convertKutaccTensor(v), batch, seq_len, nchannels, nheads, head_size, *kutacc::convertKutaccTensor(gate),*kutacc::convertKutaccTensor(q_data), + *kutacc::convertKutaccTensor(q_mask), *kutacc::convertKutaccTensor(query_w), *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), + *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), + *kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(out)); } - kutacc::global_attention_kernel(*kutacc::convertKutaccTensor(q_avg), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), - *kutacc::convertKutaccTensor(v), batch, seq_len, nchannels, nheads, head_size, *kutacc::convertKutaccTensor(gate),*kutacc::convertKutaccTensor(q_data), - *kutacc::convertKutaccTensor(q_mask), *kutacc::convertKutaccTensor(query_w), *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), - *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), - *kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(out)); } \ No newline at end of file diff --git a/src/attention/invariant_point.cpp b/src/attention/invariant_point.cpp index f154575..fab244c 100644 --- a/src/attention/invariant_point.cpp +++ b/src/attention/invariant_point.cpp @@ -166,17 +166,20 @@ void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_ten kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_tensor_h linear_b_w, kutacc_tensor_h linear_b_b, int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads, int64_t no_qk_points, int64_t no_v_points) { - KUTACC_CHECK(q != nullptr && k != nullptr && v != nullptr && q_pts != nullptr && k_pts != nullptr && v_pts != nullptr - && b != nullptr && a != nullptr && head_weights != nullptr && weights_head_weights != nullptr && o != nullptr && o_pt != nullptr && o_pt_norm != nullptr - && o_pair != nullptr && z != nullptr && rigid_rot_mats != nullptr && rigid_trans != nullptr && mask != nullptr && linear_b_w != nullptr && linear_b_b != nullptr, - "kutacc_af2_invariant_point: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(q == nullptr || k == nullptr || v == nullptr || q_pts == nullptr || k_pts == nullptr || v_pts == nullptr + || b == nullptr || a == nullptr || head_weights == nullptr || weights_head_weights == nullptr || o == nullptr || o_pt == nullptr || o_pt_norm == nullptr + || o_pair == nullptr || z == nullptr || rigid_rot_mats == nullptr || rigid_trans == nullptr || mask == nullptr || linear_b_w == nullptr || linear_b_b == nullptr)) + { + printf("kutacc_af2_invariant_point: input args nullptr error"); return; + } else if (unlikely(n_res < 0 || c_z < 0 || c_hidden < 0 || no_heads < 0 || no_v_points < 0)) { + printf("kutacc_af2_invariant_point: input args negative value error"); + return; + } else { + kutacc::kutacc_af2_invariant_point_kernel(*kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(q_pts), *kutacc::convertKutaccTensor(k_pts), *kutacc::convertKutaccTensor(v_pts), + *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(head_weights), *kutacc::convertKutaccTensor(weights_head_weights), *kutacc::convertKutaccTensor(o), *kutacc::convertKutaccTensor(o_pt), *kutacc::convertKutaccTensor(o_pt_norm), *kutacc::convertKutaccTensor(o_pair), + *kutacc::convertKutaccTensor(z), *kutacc::convertKutaccTensor(rigid_rot_mats), *kutacc::convertKutaccTensor(rigid_trans), *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(linear_b_w), *kutacc::convertKutaccTensor(linear_b_b), + n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points); } - kutacc::kutacc_af2_invariant_point_kernel(*kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(q_pts), *kutacc::convertKutaccTensor(k_pts), *kutacc::convertKutaccTensor(v_pts), - *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(head_weights), *kutacc::convertKutaccTensor(weights_head_weights), *kutacc::convertKutaccTensor(o), *kutacc::convertKutaccTensor(o_pt), *kutacc::convertKutaccTensor(o_pt_norm), *kutacc::convertKutaccTensor(o_pair), - *kutacc::convertKutaccTensor(z), *kutacc::convertKutaccTensor(rigid_rot_mats), *kutacc::convertKutaccTensor(rigid_trans), *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(linear_b_w), *kutacc::convertKutaccTensor(linear_b_b), - n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points); - } diff --git a/src/attention/outer_product_mean.cpp b/src/attention/outer_product_mean.cpp index e9fdd5d..fbf7c41 100644 --- a/src/attention/outer_product_mean.cpp +++ b/src/attention/outer_product_mean.cpp @@ -167,18 +167,22 @@ void kutacc_export kutacc_af2_outer_product_mean_calc_left_and_right_mul( const kutacc_tensor_h left_proj_b, const kutacc_tensor_h right_proj_w, const kutacc_tensor_h right_proj_b, int64_t c_i, int64_t c_m, int64_t n_res, int64_t n_res_gather, int64_t n_seq, int64_t mask_bias) { - KUTACC_CHECK(left_proj != nullptr && right_proj != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && input_act != nullptr && mask != nullptr - && norm != nullptr && left_proj_w != nullptr && left_proj_b != nullptr && right_proj_w != nullptr && right_proj_b != nullptr, - "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(left_proj == nullptr || right_proj == nullptr || left_proj_ == nullptr || right_proj_ == nullptr || input_act == nullptr || mask == nullptr + || norm == nullptr || left_proj_w == nullptr || left_proj_b == nullptr || right_proj_w == nullptr || right_proj_b == nullptr)) + { + printf("kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error"); return; + } else if (unlikely(c_i < 0 || c_m < 0 || n_res < 0 || n_res_gather < 0 || n_seq < 0 || mask_bias < 0)) { + printf("kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args negative value error"); + return; + } else { + outer_product_mean_calc_left_and_right_mul_kernel( + *kutacc::convertKutaccTensor(left_proj), *kutacc::convertKutaccTensor(right_proj), *kutacc::convertKutaccTensor(left_proj_), + *kutacc::convertKutaccTensor(right_proj_), *kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(mask), + *kutacc::convertKutaccTensor(norm), *kutacc::convertKutaccTensor(left_proj_w), *kutacc::convertKutaccTensor(left_proj_b), + *kutacc::convertKutaccTensor(right_proj_w), *kutacc::convertKutaccTensor(right_proj_b), c_i, c_m, n_res, n_res_gather, n_seq, + mask_bias); } - outer_product_mean_calc_left_and_right_mul_kernel( - *kutacc::convertKutaccTensor(left_proj), *kutacc::convertKutaccTensor(right_proj), *kutacc::convertKutaccTensor(left_proj_), - *kutacc::convertKutaccTensor(right_proj_), *kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(mask), - *kutacc::convertKutaccTensor(norm), *kutacc::convertKutaccTensor(left_proj_w), *kutacc::convertKutaccTensor(left_proj_b), - *kutacc::convertKutaccTensor(right_proj_w), *kutacc::convertKutaccTensor(right_proj_b), c_i, c_m, n_res, n_res_gather, n_seq, - mask_bias); } void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h output_b, const kutacc_tensor_h output_w, @@ -186,12 +190,15 @@ void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out kutacc_tensor_h norm, int64_t left_block_size, int64_t right_block_size, int64_t c_i, int64_t c_z, int64_t n_res, int64_t n_res_gather, int64_t n_seq) { - KUTACC_CHECK(output_b != nullptr && output_w != nullptr && out != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && norm != nullptr, - "kutacc_af2_outer_product_mean_chunk: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(output_b == nullptr || output_w == nullptr || out == nullptr || left_proj_ == nullptr || right_proj_ == nullptr || norm == nullptr)) { + printf("kutacc_af2_outer_product_mean_chunk: input args nullptr error"); + return; + } else if (unlikely(left_block_size < 0 || right_block_size < 0 || c_i < 0 || c_z < 0 || n_res < 0 || n_res_gather < 0 || n_seq < 0)) { + printf("kutacc_af2_outer_product_mean_chunk: input args negative value error"); return; + } else { + kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w), + *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_), + *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq); } - kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w), - *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_), - *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq); } \ No newline at end of file diff --git a/src/attention/rigid.cpp b/src/attention/rigid.cpp index 94ca998..f244f19 100644 --- a/src/attention/rigid.cpp +++ b/src/attention/rigid.cpp @@ -107,20 +107,22 @@ void rigid_rot_matmul(Tensor &a, Tensor &b, Tensor &out) void kutacc_af2_rigid_rot_vec_mul(kutacc_tensor_h pts, kutacc_tensor_h rot_mats, kutacc_tensor_h out, kutacc_tensor_h trans) { - KUTACC_CHECK(out != nullptr && pts != nullptr && rot_mats != nullptr, "kutacc_af2_rigid_rot_vec_mul: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(out == nullptr || pts == nullptr || rot_mats == nullptr)) { + printf("kutacc_af2_rigid_rot_vec_mul: input args nullptr error"); return; - } - if ((*kutacc::convertKutaccTensor(pts)).dtype() == kutacc::kBF16) { - kutacc::rigid_rot_vec_mul<__bf16>(*kutacc::convertKutaccTensor(pts), *kutacc::convertKutaccTensor(rot_mats), *kutacc::convertKutaccTensor(out), trans); + } else{ + if ((*kutacc::convertKutaccTensor(pts)).dtype() == kutacc::kBF16) { + kutacc::rigid_rot_vec_mul<__bf16>(*kutacc::convertKutaccTensor(pts), *kutacc::convertKutaccTensor(rot_mats), *kutacc::convertKutaccTensor(out), trans); + } } } void kutacc_af2_rigid_rot_matmul(kutacc_tensor_h a, kutacc_tensor_h b, kutacc_tensor_h out) { - KUTACC_CHECK(a != nullptr && b != nullptr && out != nullptr, "kutacc_af2_rigid_rot_matmul: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(a == nullptr || b == nullptr || out == nullptr)) { + printf("kutacc_af2_rigid_rot_matmul: input args nullptr error"); return; + } else { + kutacc::rigid_rot_matmul(*kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(out)); } - kutacc::rigid_rot_matmul(*kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(out)); } \ No newline at end of file diff --git a/src/attention/transition.cpp b/src/attention/transition.cpp index df81d9e..3bb2d47 100644 --- a/src/attention/transition.cpp +++ b/src/attention/transition.cpp @@ -50,12 +50,15 @@ void kutacc_af2_transition(kutacc_tensor_h input_act, const kutacc_tensor_h line kutacc_tensor_h linear2_w, kutacc_tensor_h linear2_b, kutacc_tensor_h intermediate_act, kutacc_tensor_h out, int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i) { - KUTACC_CHECK(input_act != nullptr && linear1_w != nullptr && linear1_b != nullptr && linear2_w != nullptr && linear2_b != nullptr && intermediate_act != nullptr && out != nullptr, - "kutacc_af2_transition: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(input_act == nullptr || linear1_w == nullptr || linear1_b == nullptr || linear2_w == nullptr || linear2_b == nullptr || intermediate_act == nullptr || out == nullptr)) { + printf("kutacc_af2_transition: input args nullptr error"); return; + } else if (unlikely(batch < 0 || n_res < 0 || c_o < 0 || c_i < 0)) { + printf("kutacc_af2_transition: input args negative value error"); + return; + } else { + kutacc::transition_kernel(*kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(linear1_w), *kutacc::convertKutaccTensor(linear1_b), + *kutacc::convertKutaccTensor(linear2_w), *kutacc::convertKutaccTensor(linear2_b), *kutacc::convertKutaccTensor(intermediate_act), *kutacc::convertKutaccTensor(out), + batch, n_res, c_o, c_i); } - kutacc::transition_kernel(*kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(linear1_w), *kutacc::convertKutaccTensor(linear1_b), - *kutacc::convertKutaccTensor(linear2_w), *kutacc::convertKutaccTensor(linear2_b), *kutacc::convertKutaccTensor(intermediate_act), *kutacc::convertKutaccTensor(out), - batch, n_res, c_o, c_i); } \ No newline at end of file diff --git a/src/attention/triangle_multiplication.cpp b/src/attention/triangle_multiplication.cpp index 6d3978f..83c504e 100644 --- a/src/attention/triangle_multiplication.cpp +++ b/src/attention/triangle_multiplication.cpp @@ -173,46 +173,60 @@ kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_tensor_h const kutacc_tensor_h proj_w, const kutacc_tensor_h proj_b, const kutacc_tensor_h gate_w, const kutacc_tensor_h gate_b, int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { - KUTACC_CHECK(proj_act != nullptr && gate != nullptr && input_act != nullptr && mask != nullptr && proj_w != nullptr && proj_b != nullptr && gate_w != nullptr && gate_b != nullptr, - "kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(proj_act == nullptr || gate == nullptr || input_act == nullptr || mask == nullptr || proj_w == nullptr || proj_b == nullptr || gate_w == nullptr || gate_b == nullptr)) { + printf("kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error"); return; + } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0 || c_i < 0)) { + printf("kutacc_af2_triangle_multiplication_calc_proj: input args negative value error"); + return; + } else { + kutacc::calc_proj_act(*kutacc::convertKutaccTensor(proj_act), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(input_act), + *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(proj_w), *kutacc::convertKutaccTensor(proj_b), + *kutacc::convertKutaccTensor(gate_w), *kutacc::convertKutaccTensor(gate_b), n_res, n_res_gather, c_o, c_i, input_prepack); } - kutacc::calc_proj_act(*kutacc::convertKutaccTensor(proj_act), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(input_act), - *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(proj_w), *kutacc::convertKutaccTensor(proj_b), - *kutacc::convertKutaccTensor(gate_w), *kutacc::convertKutaccTensor(gate_b), n_res, n_res_gather, c_o, c_i, input_prepack); } kutacc_export void kutacc_af2_triangle_multiplication_equation(kutacc_tensor_h center_act, kutacc_tensor_h left_proj_act, kutacc_tensor_h right_proj_act, int64_t n_res_gather, bool is_incoming) { - KUTACC_CHECK(center_act != nullptr && left_proj_act != nullptr && right_proj_act != nullptr, "kutacc_af2_triangle_multiplication_equation: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(center_act == nullptr || left_proj_act == nullptr || right_proj_act == nullptr)) { + printf("kutacc_af2_triangle_multiplication_equation: input args nullptr error"); + return; + } else if (unlikely(n_res_gather < 0)) { + printf("kutacc_af2_triangle_multiplication_equation: input args negative value error"); return; + } else { + kutacc::equation(*kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(left_proj_act), *kutacc::convertKutaccTensor(right_proj_act), + n_res_gather, is_incoming); } - kutacc::equation(*kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(left_proj_act), *kutacc::convertKutaccTensor(right_proj_act), - n_res_gather, is_incoming); } kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_tensor_h input_act, kutacc_tensor_h center_act, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_proj_w, const kutacc_tensor_h output_proj_b, int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { - KUTACC_CHECK(gate != nullptr && out != nullptr && input_act != nullptr && center_act != nullptr && gating_w != nullptr && gating_b != nullptr && output_proj_w != nullptr && output_proj_b != nullptr, - "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(gate == nullptr || out == nullptr || input_act == nullptr || center_act == nullptr || gating_w == nullptr || gating_b == nullptr || output_proj_w == nullptr || output_proj_b == nullptr)) { + printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error"); return; + } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0 || c_i < 0)) { + printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args negative value error"); + return; + } else { + kutacc::gate_and_out_linear(*kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(input_act), + *kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), + *kutacc::convertKutaccTensor(output_proj_w), *kutacc::convertKutaccTensor(output_proj_b), n_res, n_res_gather, c_o, c_i, input_prepack); } - kutacc::gate_and_out_linear(*kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(input_act), - *kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), - *kutacc::convertKutaccTensor(output_proj_w), *kutacc::convertKutaccTensor(output_proj_b), n_res, n_res_gather, c_o, c_i, input_prepack); } kutacc_export void kutacc_af2_triangle_multiplication_last(kutacc_tensor_h out, kutacc_tensor_h gate, int64_t n_res, int64_t n_res_gather, int64_t c_o) { - KUTACC_CHECK(out != nullptr && gate != nullptr, "kutacc_af2_triangle_multiplication_last: input args nullptr error"); - if (kutacc::kutacc_check_err_set == true) { + if (unlikely(out == nullptr || gate == nullptr)){ + printf("kutacc_af2_triangle_multiplication_last: input args nullptr error"); + return; + } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0)) { + printf("kutacc_af2_triangle_multiplication_last: input args negative value error"); return; + } else { + kutacc::last(*kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(gate), n_res, n_res_gather, c_o); } - kutacc::last(*kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(gate), n_res, n_res_gather, c_o); } \ No newline at end of file -- Gitee From 45c0742b2a8e0059dac65308e3423b6f3d8c0872 Mon Sep 17 00:00:00 2001 From: XeonYZhang Date: Mon, 27 Oct 2025 11:03:53 +0800 Subject: [PATCH 2/4] add unlikely check and non-positive int values check --- src/attention/gating_attention.cpp | 4 ++-- src/attention/global_attention.cpp | 6 +++--- src/attention/invariant_point.cpp | 4 ++-- src/attention/outer_product_mean.cpp | 4 ++-- src/attention/rigid.cpp | 4 ++-- src/attention/transition.cpp | 4 ++-- src/attention/triangle_multiplication.cpp | 16 ++++++++-------- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/attention/gating_attention.cpp b/src/attention/gating_attention.cpp index 1b63c40..bc9d66e 100644 --- a/src/attention/gating_attention.cpp +++ b/src/attention/gating_attention.cpp @@ -160,10 +160,10 @@ void kutacc_export kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_ten || bias == nullptr || nonbatched_bias == nullptr || query_w == nullptr || key_w == nullptr || value_w == nullptr || gating_w == nullptr || gating_b == nullptr || output_w == nullptr || output_b == nullptr || out == nullptr)) { - printf("kutacc_af2_gating_attention: input args nullptr error"); + printf("kutacc_af2_gating_attention: input args nullptr error\n"); return; } else if (unlikely(batch < 0 || seq_len < 0 || block_size_ < 0 || head_size < 0 || nheads < 0 || nchannels < 0)) { - printf("kutacc_af2_gating_attention: input args negative value error"); + printf("kutacc_af2_gating_attention: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::gating_attention_kernel(*kutacc::convertKutaccTensor(input), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), diff --git a/src/attention/global_attention.cpp b/src/attention/global_attention.cpp index 6bc075d..143571a 100644 --- a/src/attention/global_attention.cpp +++ b/src/attention/global_attention.cpp @@ -188,10 +188,10 @@ kutacc_export void kutacc_af2_global_attention(kutacc_tensor_h q_avg, kutacc_ten || q_mask == nullptr || query_w == nullptr || key_w == nullptr || value_w == nullptr || gating_w == nullptr || gating_b == nullptr || output_w == nullptr || output_b == nullptr || out == nullptr)) { - printf("kutacc_af2_global_attention: input args nullptr error"); + printf("kutacc_af2_global_attention: input args nullptr error\n"); return; - } else if (unlikely(batch < 0 || seq_len < 0 || nchannels < 0 || nchannels >= INT64_MAX || nheads < 0 || head_size < 0)) { - printf("kutacc_af2_global_attention: input args int values error"); + } else if (unlikely(batch < 0 || seq_len < 0 || nchannels < 0 || nheads < 0 || head_size < 0)) { + printf("kutacc_af2_global_attention: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::global_attention_kernel(*kutacc::convertKutaccTensor(q_avg), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), diff --git a/src/attention/invariant_point.cpp b/src/attention/invariant_point.cpp index fab244c..3394fad 100644 --- a/src/attention/invariant_point.cpp +++ b/src/attention/invariant_point.cpp @@ -170,10 +170,10 @@ void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_ten || b == nullptr || a == nullptr || head_weights == nullptr || weights_head_weights == nullptr || o == nullptr || o_pt == nullptr || o_pt_norm == nullptr || o_pair == nullptr || z == nullptr || rigid_rot_mats == nullptr || rigid_trans == nullptr || mask == nullptr || linear_b_w == nullptr || linear_b_b == nullptr)) { - printf("kutacc_af2_invariant_point: input args nullptr error"); + printf("kutacc_af2_invariant_point: input args nullptr error\n"); return; } else if (unlikely(n_res < 0 || c_z < 0 || c_hidden < 0 || no_heads < 0 || no_v_points < 0)) { - printf("kutacc_af2_invariant_point: input args negative value error"); + printf("kutacc_af2_invariant_point: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::kutacc_af2_invariant_point_kernel(*kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(q_pts), *kutacc::convertKutaccTensor(k_pts), *kutacc::convertKutaccTensor(v_pts), diff --git a/src/attention/outer_product_mean.cpp b/src/attention/outer_product_mean.cpp index fbf7c41..e843755 100644 --- a/src/attention/outer_product_mean.cpp +++ b/src/attention/outer_product_mean.cpp @@ -191,10 +191,10 @@ void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out int64_t c_i, int64_t c_z, int64_t n_res, int64_t n_res_gather, int64_t n_seq) { if (unlikely(output_b == nullptr || output_w == nullptr || out == nullptr || left_proj_ == nullptr || right_proj_ == nullptr || norm == nullptr)) { - printf("kutacc_af2_outer_product_mean_chunk: input args nullptr error"); + printf("kutacc_af2_outer_product_mean_chunk: input args nullptr error\n"); return; } else if (unlikely(left_block_size < 0 || right_block_size < 0 || c_i < 0 || c_z < 0 || n_res < 0 || n_res_gather < 0 || n_seq < 0)) { - printf("kutacc_af2_outer_product_mean_chunk: input args negative value error"); + printf("kutacc_af2_outer_product_mean_chunk: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w), diff --git a/src/attention/rigid.cpp b/src/attention/rigid.cpp index f244f19..5baf026 100644 --- a/src/attention/rigid.cpp +++ b/src/attention/rigid.cpp @@ -108,7 +108,7 @@ void rigid_rot_matmul(Tensor &a, Tensor &b, Tensor &out) void kutacc_af2_rigid_rot_vec_mul(kutacc_tensor_h pts, kutacc_tensor_h rot_mats, kutacc_tensor_h out, kutacc_tensor_h trans) { if (unlikely(out == nullptr || pts == nullptr || rot_mats == nullptr)) { - printf("kutacc_af2_rigid_rot_vec_mul: input args nullptr error"); + printf("kutacc_af2_rigid_rot_vec_mul: input args nullptr error\n"); return; } else{ if ((*kutacc::convertKutaccTensor(pts)).dtype() == kutacc::kBF16) { @@ -120,7 +120,7 @@ void kutacc_af2_rigid_rot_vec_mul(kutacc_tensor_h pts, kutacc_tensor_h rot_mats, void kutacc_af2_rigid_rot_matmul(kutacc_tensor_h a, kutacc_tensor_h b, kutacc_tensor_h out) { if (unlikely(a == nullptr || b == nullptr || out == nullptr)) { - printf("kutacc_af2_rigid_rot_matmul: input args nullptr error"); + printf("kutacc_af2_rigid_rot_matmul: input args nullptr error\n"); return; } else { kutacc::rigid_rot_matmul(*kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(out)); diff --git a/src/attention/transition.cpp b/src/attention/transition.cpp index 3bb2d47..0898779 100644 --- a/src/attention/transition.cpp +++ b/src/attention/transition.cpp @@ -51,10 +51,10 @@ void kutacc_af2_transition(kutacc_tensor_h input_act, const kutacc_tensor_h line int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i) { if (unlikely(input_act == nullptr || linear1_w == nullptr || linear1_b == nullptr || linear2_w == nullptr || linear2_b == nullptr || intermediate_act == nullptr || out == nullptr)) { - printf("kutacc_af2_transition: input args nullptr error"); + printf("kutacc_af2_transition: input args nullptr error\n"); return; } else if (unlikely(batch < 0 || n_res < 0 || c_o < 0 || c_i < 0)) { - printf("kutacc_af2_transition: input args negative value error"); + printf("kutacc_af2_transition: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::transition_kernel(*kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(linear1_w), *kutacc::convertKutaccTensor(linear1_b), diff --git a/src/attention/triangle_multiplication.cpp b/src/attention/triangle_multiplication.cpp index 83c504e..79671d5 100644 --- a/src/attention/triangle_multiplication.cpp +++ b/src/attention/triangle_multiplication.cpp @@ -174,10 +174,10 @@ kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_tensor_h int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { if (unlikely(proj_act == nullptr || gate == nullptr || input_act == nullptr || mask == nullptr || proj_w == nullptr || proj_b == nullptr || gate_w == nullptr || gate_b == nullptr)) { - printf("kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error"); + printf("kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error\n"); return; } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0 || c_i < 0)) { - printf("kutacc_af2_triangle_multiplication_calc_proj: input args negative value error"); + printf("kutacc_af2_triangle_multiplication_calc_proj: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::calc_proj_act(*kutacc::convertKutaccTensor(proj_act), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(input_act), @@ -190,10 +190,10 @@ kutacc_export void kutacc_af2_triangle_multiplication_equation(kutacc_tensor_h c int64_t n_res_gather, bool is_incoming) { if (unlikely(center_act == nullptr || left_proj_act == nullptr || right_proj_act == nullptr)) { - printf("kutacc_af2_triangle_multiplication_equation: input args nullptr error"); + printf("kutacc_af2_triangle_multiplication_equation: input args nullptr error\n"); return; } else if (unlikely(n_res_gather < 0)) { - printf("kutacc_af2_triangle_multiplication_equation: input args negative value error"); + printf("kutacc_af2_triangle_multiplication_equation: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::equation(*kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(left_proj_act), *kutacc::convertKutaccTensor(right_proj_act), @@ -206,10 +206,10 @@ kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { if (unlikely(gate == nullptr || out == nullptr || input_act == nullptr || center_act == nullptr || gating_w == nullptr || gating_b == nullptr || output_proj_w == nullptr || output_proj_b == nullptr)) { - printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error"); + printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error\n"); return; } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0 || c_i < 0)) { - printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args negative value error"); + printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::gate_and_out_linear(*kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(input_act), @@ -221,10 +221,10 @@ kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc kutacc_export void kutacc_af2_triangle_multiplication_last(kutacc_tensor_h out, kutacc_tensor_h gate, int64_t n_res, int64_t n_res_gather, int64_t c_o) { if (unlikely(out == nullptr || gate == nullptr)){ - printf("kutacc_af2_triangle_multiplication_last: input args nullptr error"); + printf("kutacc_af2_triangle_multiplication_last: input args nullptr error\n"); return; } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0)) { - printf("kutacc_af2_triangle_multiplication_last: input args negative value error"); + printf("kutacc_af2_triangle_multiplication_last: input args int values error, values are less than or equal to zero\n"); return; } else { kutacc::last(*kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(gate), n_res, n_res_gather, c_o); -- Gitee From e79a71add6e1de4d23bfe479590a0b545bb63eb8 Mon Sep 17 00:00:00 2001 From: XeonYZhang Date: Mon, 27 Oct 2025 17:28:27 +0800 Subject: [PATCH 3/4] change all unlikely to KUTACC_CHECK --- include/kutacc.h | 3 -- src/CMakeLists.txt | 2 +- src/attention/gating_attention.cpp | 28 +++++------ src/attention/global_attention.cpp | 26 +++++----- src/attention/invariant_point.cpp | 24 ++++----- src/attention/outer_product_mean.cpp | 41 +++++++-------- src/attention/rigid.cpp | 20 ++++---- src/attention/transition.cpp | 17 +++---- src/attention/triangle_multiplication.cpp | 61 +++++++++++------------ 9 files changed, 101 insertions(+), 121 deletions(-) diff --git a/include/kutacc.h b/include/kutacc.h index ad047c6..b01e2b5 100644 --- a/include/kutacc.h +++ b/include/kutacc.h @@ -18,9 +18,6 @@ #include #include -#define likely(x) __builtin_expect(!!(x), 1) -#define unlikely(x) __builtin_expect(!!(x), 0) - #ifdef __cplusplus extern "C" { #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ec3cd01..afcf98b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -30,7 +30,7 @@ elseif (CMAKE_BUILD_TYPE MATCHES "Release") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wcast-align -Wcast-qual -pipe -Wconversion") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wframe-larger-than=16384 -Wvla -fstack-protector-strong") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wstrict-prototypes") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-strong -Wno-parentheses -fPIC -DF_INTERFACE_FLANG -mllvm -disable-lsr -Wno-unused-command-line-argument") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-strong") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -Wl,-Bsymbolic -rdynamic -Wl,--no-undefined") add_compile_options(-O3) endif() diff --git a/src/attention/gating_attention.cpp b/src/attention/gating_attention.cpp index bc9d66e..61cc8c3 100644 --- a/src/attention/gating_attention.cpp +++ b/src/attention/gating_attention.cpp @@ -156,21 +156,19 @@ void kutacc_export kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_ten const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w, const kutacc_tensor_h output_b, kutacc_tensor_h out, int64_t block_size_, int64_t head_size, int64_t nheads, int64_t nchannels) { - if (unlikely(input == nullptr || q == nullptr || k == nullptr || v == nullptr || gate == nullptr || weighted_avg == nullptr || m_data == nullptr - || bias == nullptr || nonbatched_bias == nullptr || query_w == nullptr || key_w == nullptr || value_w == nullptr || gating_w == nullptr || gating_b == nullptr - || output_w == nullptr || output_b == nullptr || out == nullptr)) - { - printf("kutacc_af2_gating_attention: input args nullptr error\n"); + KUTACC_CHECK(input != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && weighted_avg != nullptr && m_data != nullptr + && bias != nullptr && nonbatched_bias != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr + && output_w != nullptr && output_b != nullptr && out != nullptr, + "kutacc_af2_gating_attention: input args nullptr error\n"); + KUTACC_CHECK(batch > 0 && seq_len > 0 && block_size_ > 0 && head_size > 0 && nheads > 0 && nchannels > 0, + "kutacc_af2_gating_attention: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else if (unlikely(batch < 0 || seq_len < 0 || block_size_ < 0 || head_size < 0 || nheads < 0 || nchannels < 0)) { - printf("kutacc_af2_gating_attention: input args int values error, values are less than or equal to zero\n"); - return; - } else { - kutacc::gating_attention_kernel(*kutacc::convertKutaccTensor(input), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), - *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(weighted_avg), batch, seq_len, - *kutacc::convertKutaccTensor(bias), *kutacc::convertKutaccTensor(nonbatched_bias), *kutacc::convertKutaccTensor(query_w), - *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), *kutacc::convertKutaccTensor(gating_w), - *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), *kutacc::convertKutaccTensor(output_b), - *kutacc::convertKutaccTensor(out), block_size_, head_size, nheads, nchannels); } + kutacc::gating_attention_kernel(*kutacc::convertKutaccTensor(input), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), + *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(weighted_avg), batch, seq_len, + *kutacc::convertKutaccTensor(bias), *kutacc::convertKutaccTensor(nonbatched_bias), *kutacc::convertKutaccTensor(query_w), + *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), *kutacc::convertKutaccTensor(gating_w), + *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), *kutacc::convertKutaccTensor(output_b), + *kutacc::convertKutaccTensor(out), block_size_, head_size, nheads, nchannels); } \ No newline at end of file diff --git a/src/attention/global_attention.cpp b/src/attention/global_attention.cpp index 143571a..28c04e1 100644 --- a/src/attention/global_attention.cpp +++ b/src/attention/global_attention.cpp @@ -184,20 +184,18 @@ kutacc_export void kutacc_af2_global_attention(kutacc_tensor_h q_avg, kutacc_ten const kutacc_tensor_h value_w, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_w, const kutacc_tensor_h output_b, kutacc_tensor_h out) { - if (unlikely(q_avg == nullptr || q == nullptr || k == nullptr || v == nullptr || gate == nullptr || q_data == nullptr || m_data == nullptr - || q_mask == nullptr || query_w == nullptr || key_w == nullptr || value_w == nullptr || gating_w == nullptr || gating_b == nullptr - || output_w == nullptr || output_b == nullptr || out == nullptr)) - { - printf("kutacc_af2_global_attention: input args nullptr error\n"); + KUTACC_CHECK(q_avg != nullptr && q != nullptr && k != nullptr && v != nullptr && gate != nullptr && q_data != nullptr && m_data != nullptr + && q_mask != nullptr && query_w != nullptr && key_w != nullptr && value_w != nullptr && gating_w != nullptr && gating_b != nullptr + && output_w != nullptr && output_b != nullptr && out != nullptr, + "kutacc_af2_global_attention: input args nullptr error\n"); + KUTACC_CHECK(batch > 0 && seq_len > 0 && nchannels > 0 && nheads > 0 && head_size > 0, + "kutacc_af2_global_attention: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else if (unlikely(batch < 0 || seq_len < 0 || nchannels < 0 || nheads < 0 || head_size < 0)) { - printf("kutacc_af2_global_attention: input args int values error, values are less than or equal to zero\n"); - return; - } else { - kutacc::global_attention_kernel(*kutacc::convertKutaccTensor(q_avg), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), - *kutacc::convertKutaccTensor(v), batch, seq_len, nchannels, nheads, head_size, *kutacc::convertKutaccTensor(gate),*kutacc::convertKutaccTensor(q_data), - *kutacc::convertKutaccTensor(q_mask), *kutacc::convertKutaccTensor(query_w), *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), - *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), - *kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(out)); } + kutacc::global_attention_kernel(*kutacc::convertKutaccTensor(q_avg), *kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), + *kutacc::convertKutaccTensor(v), batch, seq_len, nchannels, nheads, head_size, *kutacc::convertKutaccTensor(gate),*kutacc::convertKutaccTensor(q_data), + *kutacc::convertKutaccTensor(q_mask), *kutacc::convertKutaccTensor(query_w), *kutacc::convertKutaccTensor(key_w), *kutacc::convertKutaccTensor(value_w), + *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), *kutacc::convertKutaccTensor(output_w), + *kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(out)); } \ No newline at end of file diff --git a/src/attention/invariant_point.cpp b/src/attention/invariant_point.cpp index 3394fad..3829249 100644 --- a/src/attention/invariant_point.cpp +++ b/src/attention/invariant_point.cpp @@ -166,20 +166,18 @@ void kutacc_af2_invariant_point(kutacc_tensor_h q, kutacc_tensor_h k, kutacc_ten kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_tensor_h linear_b_w, kutacc_tensor_h linear_b_b, int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads, int64_t no_qk_points, int64_t no_v_points) { - if (unlikely(q == nullptr || k == nullptr || v == nullptr || q_pts == nullptr || k_pts == nullptr || v_pts == nullptr - || b == nullptr || a == nullptr || head_weights == nullptr || weights_head_weights == nullptr || o == nullptr || o_pt == nullptr || o_pt_norm == nullptr - || o_pair == nullptr || z == nullptr || rigid_rot_mats == nullptr || rigid_trans == nullptr || mask == nullptr || linear_b_w == nullptr || linear_b_b == nullptr)) - { - printf("kutacc_af2_invariant_point: input args nullptr error\n"); + KUTACC_CHECK(q != nullptr && k != nullptr && v != nullptr && q_pts != nullptr && k_pts != nullptr && v_pts != nullptr + && b != nullptr && a != nullptr && head_weights != nullptr && weights_head_weights != nullptr && o != nullptr && o_pt != nullptr && o_pt_norm != nullptr + && o_pair != nullptr && z != nullptr && rigid_rot_mats != nullptr && rigid_trans != nullptr && mask != nullptr && linear_b_w != nullptr && linear_b_b != nullptr, + "kutacc_af2_invariant_point: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && c_z > 0 && c_hidden > 0 && no_heads > 0 && no_v_points > 0, + "kutacc_af2_invariant_point: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else if (unlikely(n_res < 0 || c_z < 0 || c_hidden < 0 || no_heads < 0 || no_v_points < 0)) { - printf("kutacc_af2_invariant_point: input args int values error, values are less than or equal to zero\n"); - return; - } else { - kutacc::kutacc_af2_invariant_point_kernel(*kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(q_pts), *kutacc::convertKutaccTensor(k_pts), *kutacc::convertKutaccTensor(v_pts), - *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(head_weights), *kutacc::convertKutaccTensor(weights_head_weights), *kutacc::convertKutaccTensor(o), *kutacc::convertKutaccTensor(o_pt), *kutacc::convertKutaccTensor(o_pt_norm), *kutacc::convertKutaccTensor(o_pair), - *kutacc::convertKutaccTensor(z), *kutacc::convertKutaccTensor(rigid_rot_mats), *kutacc::convertKutaccTensor(rigid_trans), *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(linear_b_w), *kutacc::convertKutaccTensor(linear_b_b), - n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points); } + kutacc::kutacc_af2_invariant_point_kernel(*kutacc::convertKutaccTensor(q), *kutacc::convertKutaccTensor(k), *kutacc::convertKutaccTensor(v), *kutacc::convertKutaccTensor(q_pts), *kutacc::convertKutaccTensor(k_pts), *kutacc::convertKutaccTensor(v_pts), + *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(head_weights), *kutacc::convertKutaccTensor(weights_head_weights), *kutacc::convertKutaccTensor(o), *kutacc::convertKutaccTensor(o_pt), *kutacc::convertKutaccTensor(o_pt_norm), *kutacc::convertKutaccTensor(o_pair), + *kutacc::convertKutaccTensor(z), *kutacc::convertKutaccTensor(rigid_rot_mats), *kutacc::convertKutaccTensor(rigid_trans), *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(linear_b_w), *kutacc::convertKutaccTensor(linear_b_b), + n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points); } diff --git a/src/attention/outer_product_mean.cpp b/src/attention/outer_product_mean.cpp index e843755..bb85393 100644 --- a/src/attention/outer_product_mean.cpp +++ b/src/attention/outer_product_mean.cpp @@ -167,22 +167,20 @@ void kutacc_export kutacc_af2_outer_product_mean_calc_left_and_right_mul( const kutacc_tensor_h left_proj_b, const kutacc_tensor_h right_proj_w, const kutacc_tensor_h right_proj_b, int64_t c_i, int64_t c_m, int64_t n_res, int64_t n_res_gather, int64_t n_seq, int64_t mask_bias) { - if (unlikely(left_proj == nullptr || right_proj == nullptr || left_proj_ == nullptr || right_proj_ == nullptr || input_act == nullptr || mask == nullptr - || norm == nullptr || left_proj_w == nullptr || left_proj_b == nullptr || right_proj_w == nullptr || right_proj_b == nullptr)) - { - printf("kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error"); + KUTACC_CHECK(left_proj != nullptr && right_proj != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && input_act != nullptr && mask != nullptr + && norm != nullptr && left_proj_w != nullptr && left_proj_b != nullptr && right_proj_w != nullptr && right_proj_b != nullptr, + "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error\n"); + KUTACC_CHECK(c_i > 0 && c_m > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0 && mask_bias > 0, + "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args negative value error\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else if (unlikely(c_i < 0 || c_m < 0 || n_res < 0 || n_res_gather < 0 || n_seq < 0 || mask_bias < 0)) { - printf("kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args negative value error"); - return; - } else { - outer_product_mean_calc_left_and_right_mul_kernel( - *kutacc::convertKutaccTensor(left_proj), *kutacc::convertKutaccTensor(right_proj), *kutacc::convertKutaccTensor(left_proj_), - *kutacc::convertKutaccTensor(right_proj_), *kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(mask), - *kutacc::convertKutaccTensor(norm), *kutacc::convertKutaccTensor(left_proj_w), *kutacc::convertKutaccTensor(left_proj_b), - *kutacc::convertKutaccTensor(right_proj_w), *kutacc::convertKutaccTensor(right_proj_b), c_i, c_m, n_res, n_res_gather, n_seq, - mask_bias); } + outer_product_mean_calc_left_and_right_mul_kernel( + *kutacc::convertKutaccTensor(left_proj), *kutacc::convertKutaccTensor(right_proj), *kutacc::convertKutaccTensor(left_proj_), + *kutacc::convertKutaccTensor(right_proj_), *kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(mask), + *kutacc::convertKutaccTensor(norm), *kutacc::convertKutaccTensor(left_proj_w), *kutacc::convertKutaccTensor(left_proj_b), + *kutacc::convertKutaccTensor(right_proj_w), *kutacc::convertKutaccTensor(right_proj_b), c_i, c_m, n_res, n_res_gather, n_seq, + mask_bias); } void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h output_b, const kutacc_tensor_h output_w, @@ -190,15 +188,14 @@ void kutacc_export kutacc_af2_outer_product_mean_chunk(const kutacc_tensor_h out kutacc_tensor_h norm, int64_t left_block_size, int64_t right_block_size, int64_t c_i, int64_t c_z, int64_t n_res, int64_t n_res_gather, int64_t n_seq) { - if (unlikely(output_b == nullptr || output_w == nullptr || out == nullptr || left_proj_ == nullptr || right_proj_ == nullptr || norm == nullptr)) { - printf("kutacc_af2_outer_product_mean_chunk: input args nullptr error\n"); - return; - } else if (unlikely(left_block_size < 0 || right_block_size < 0 || c_i < 0 || c_z < 0 || n_res < 0 || n_res_gather < 0 || n_seq < 0)) { - printf("kutacc_af2_outer_product_mean_chunk: input args int values error, values are less than or equal to zero\n"); + KUTACC_CHECK(output_b != nullptr && output_w != nullptr && out != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && norm != nullptr, + "kutacc_af2_outer_product_mean_chunk: input args nullptr error"); + KUTACC_CHECK(left_block_size > 0 && right_block_size > 0 && c_i > 0 && c_z > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0, + "kutacc_af2_outer_product_mean_chunk: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else { - kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w), + } + kutacc::outer_product_mean_chunk_kernel(*kutacc::convertKutaccTensor(output_b), *kutacc::convertKutaccTensor(output_w), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(left_proj_), *kutacc::convertKutaccTensor(right_proj_), *kutacc::convertKutaccTensor(norm), left_block_size, right_block_size, c_i, c_z, n_res, n_res_gather, n_seq); - } } \ No newline at end of file diff --git a/src/attention/rigid.cpp b/src/attention/rigid.cpp index 5baf026..e2b205c 100644 --- a/src/attention/rigid.cpp +++ b/src/attention/rigid.cpp @@ -107,22 +107,20 @@ void rigid_rot_matmul(Tensor &a, Tensor &b, Tensor &out) void kutacc_af2_rigid_rot_vec_mul(kutacc_tensor_h pts, kutacc_tensor_h rot_mats, kutacc_tensor_h out, kutacc_tensor_h trans) { - if (unlikely(out == nullptr || pts == nullptr || rot_mats == nullptr)) { - printf("kutacc_af2_rigid_rot_vec_mul: input args nullptr error\n"); + KUTACC_CHECK(out != nullptr && pts != nullptr && rot_mats != nullptr, "kutacc_af2_rigid_rot_vec_mul: input args nullptr error"); + if (kutacc::kutacc_check_err_set == true) { return; - } else{ - if ((*kutacc::convertKutaccTensor(pts)).dtype() == kutacc::kBF16) { - kutacc::rigid_rot_vec_mul<__bf16>(*kutacc::convertKutaccTensor(pts), *kutacc::convertKutaccTensor(rot_mats), *kutacc::convertKutaccTensor(out), trans); - } + } + if ((*kutacc::convertKutaccTensor(pts)).dtype() == kutacc::kBF16) { + kutacc::rigid_rot_vec_mul<__bf16>(*kutacc::convertKutaccTensor(pts), *kutacc::convertKutaccTensor(rot_mats), *kutacc::convertKutaccTensor(out), trans); } } void kutacc_af2_rigid_rot_matmul(kutacc_tensor_h a, kutacc_tensor_h b, kutacc_tensor_h out) { - if (unlikely(a == nullptr || b == nullptr || out == nullptr)) { - printf("kutacc_af2_rigid_rot_matmul: input args nullptr error\n"); + KUTACC_CHECK(a != nullptr && b != nullptr && out != nullptr, "kutacc_af2_rigid_rot_matmul: input args nullptr error"); + if (kutacc::kutacc_check_err_set == true) { return; - } else { - kutacc::rigid_rot_matmul(*kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(out)); - } + } + kutacc::rigid_rot_matmul(*kutacc::convertKutaccTensor(a), *kutacc::convertKutaccTensor(b), *kutacc::convertKutaccTensor(out)); } \ No newline at end of file diff --git a/src/attention/transition.cpp b/src/attention/transition.cpp index 0898779..cdf28db 100644 --- a/src/attention/transition.cpp +++ b/src/attention/transition.cpp @@ -50,15 +50,14 @@ void kutacc_af2_transition(kutacc_tensor_h input_act, const kutacc_tensor_h line kutacc_tensor_h linear2_w, kutacc_tensor_h linear2_b, kutacc_tensor_h intermediate_act, kutacc_tensor_h out, int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i) { - if (unlikely(input_act == nullptr || linear1_w == nullptr || linear1_b == nullptr || linear2_w == nullptr || linear2_b == nullptr || intermediate_act == nullptr || out == nullptr)) { - printf("kutacc_af2_transition: input args nullptr error\n"); + KUTACC_CHECK(input_act != nullptr && linear1_w != nullptr && linear1_b != nullptr && linear2_w != nullptr && linear2_b != nullptr && intermediate_act != nullptr && out != nullptr, + "kutacc_af2_transition: input args nullptr error\n"); + KUTACC_CHECK(batch > 0 && n_res > 0 && c_o > 0 && c_i > 0, + "kutacc_af2_transition: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else if (unlikely(batch < 0 || n_res < 0 || c_o < 0 || c_i < 0)) { - printf("kutacc_af2_transition: input args int values error, values are less than or equal to zero\n"); - return; - } else { - kutacc::transition_kernel(*kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(linear1_w), *kutacc::convertKutaccTensor(linear1_b), - *kutacc::convertKutaccTensor(linear2_w), *kutacc::convertKutaccTensor(linear2_b), *kutacc::convertKutaccTensor(intermediate_act), *kutacc::convertKutaccTensor(out), - batch, n_res, c_o, c_i); } + kutacc::transition_kernel(*kutacc::convertKutaccTensor(input_act), *kutacc::convertKutaccTensor(linear1_w), *kutacc::convertKutaccTensor(linear1_b), + *kutacc::convertKutaccTensor(linear2_w), *kutacc::convertKutaccTensor(linear2_b), *kutacc::convertKutaccTensor(intermediate_act), *kutacc::convertKutaccTensor(out), + batch, n_res, c_o, c_i); } \ No newline at end of file diff --git a/src/attention/triangle_multiplication.cpp b/src/attention/triangle_multiplication.cpp index 79671d5..64e76f6 100644 --- a/src/attention/triangle_multiplication.cpp +++ b/src/attention/triangle_multiplication.cpp @@ -173,60 +173,55 @@ kutacc_export void kutacc_af2_triangle_multiplication_calc_proj(kutacc_tensor_h const kutacc_tensor_h proj_w, const kutacc_tensor_h proj_b, const kutacc_tensor_h gate_w, const kutacc_tensor_h gate_b, int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { - if (unlikely(proj_act == nullptr || gate == nullptr || input_act == nullptr || mask == nullptr || proj_w == nullptr || proj_b == nullptr || gate_w == nullptr || gate_b == nullptr)) { - printf("kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error\n"); + KUTACC_CHECK(proj_act != nullptr && gate != nullptr && input_act != nullptr && mask != nullptr && proj_w != nullptr && proj_b != nullptr && gate_w != nullptr && gate_b != nullptr, + "kutacc_af2_triangle_multiplication_calc_proj: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0 && c_i > 0, + "kutacc_af2_triangle_multiplication_calc_proj: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0 || c_i < 0)) { - printf("kutacc_af2_triangle_multiplication_calc_proj: input args int values error, values are less than or equal to zero\n"); - return; - } else { - kutacc::calc_proj_act(*kutacc::convertKutaccTensor(proj_act), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(input_act), - *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(proj_w), *kutacc::convertKutaccTensor(proj_b), - *kutacc::convertKutaccTensor(gate_w), *kutacc::convertKutaccTensor(gate_b), n_res, n_res_gather, c_o, c_i, input_prepack); } + kutacc::calc_proj_act(*kutacc::convertKutaccTensor(proj_act), *kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(input_act), + *kutacc::convertKutaccTensor(mask), *kutacc::convertKutaccTensor(proj_w), *kutacc::convertKutaccTensor(proj_b), + *kutacc::convertKutaccTensor(gate_w), *kutacc::convertKutaccTensor(gate_b), n_res, n_res_gather, c_o, c_i, input_prepack); } kutacc_export void kutacc_af2_triangle_multiplication_equation(kutacc_tensor_h center_act, kutacc_tensor_h left_proj_act, kutacc_tensor_h right_proj_act, int64_t n_res_gather, bool is_incoming) { - if (unlikely(center_act == nullptr || left_proj_act == nullptr || right_proj_act == nullptr)) { - printf("kutacc_af2_triangle_multiplication_equation: input args nullptr error\n"); - return; - } else if (unlikely(n_res_gather < 0)) { - printf("kutacc_af2_triangle_multiplication_equation: input args int values error, values are less than or equal to zero\n"); + KUTACC_CHECK(center_act != nullptr && left_proj_act != nullptr && right_proj_act != nullptr, + "kutacc_af2_triangle_multiplication_equation: input args nullptr error\n"); + KUTACC_CHECK(n_res_gather > 0, + "kutacc_af2_triangle_multiplication_equation: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else { - kutacc::equation(*kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(left_proj_act), *kutacc::convertKutaccTensor(right_proj_act), - n_res_gather, is_incoming); } + kutacc::equation(*kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(left_proj_act), *kutacc::convertKutaccTensor(right_proj_act), + n_res_gather, is_incoming); } kutacc_export void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_tensor_h input_act, kutacc_tensor_h center_act, const kutacc_tensor_h gating_w, const kutacc_tensor_h gating_b, const kutacc_tensor_h output_proj_w, const kutacc_tensor_h output_proj_b, int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i, bool input_prepack) { - if (unlikely(gate == nullptr || out == nullptr || input_act == nullptr || center_act == nullptr || gating_w == nullptr || gating_b == nullptr || output_proj_w == nullptr || output_proj_b == nullptr)) { - printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error\n"); + KUTACC_CHECK(gate != nullptr && out != nullptr && input_act != nullptr && center_act != nullptr && gating_w != nullptr && gating_b != nullptr && output_proj_w != nullptr && output_proj_b != nullptr, + "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0 && c_i > 0, + "kutacc_af2_triangle_multiplication_gate_and_out_linear: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0 || c_i < 0)) { - printf("kutacc_af2_triangle_multiplication_gate_and_out_linear: input args int values error, values are less than or equal to zero\n"); - return; - } else { - kutacc::gate_and_out_linear(*kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(input_act), - *kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), - *kutacc::convertKutaccTensor(output_proj_w), *kutacc::convertKutaccTensor(output_proj_b), n_res, n_res_gather, c_o, c_i, input_prepack); } + kutacc::gate_and_out_linear(*kutacc::convertKutaccTensor(gate), *kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(input_act), + *kutacc::convertKutaccTensor(center_act), *kutacc::convertKutaccTensor(gating_w), *kutacc::convertKutaccTensor(gating_b), + *kutacc::convertKutaccTensor(output_proj_w), *kutacc::convertKutaccTensor(output_proj_b), n_res, n_res_gather, c_o, c_i, input_prepack); } kutacc_export void kutacc_af2_triangle_multiplication_last(kutacc_tensor_h out, kutacc_tensor_h gate, int64_t n_res, int64_t n_res_gather, int64_t c_o) { - if (unlikely(out == nullptr || gate == nullptr)){ - printf("kutacc_af2_triangle_multiplication_last: input args nullptr error\n"); - return; - } else if (unlikely(n_res < 0 || n_res_gather < 0 || c_o < 0)) { - printf("kutacc_af2_triangle_multiplication_last: input args int values error, values are less than or equal to zero\n"); + KUTACC_CHECK(out != nullptr && gate != nullptr, "kutacc_af2_triangle_multiplication_last: input args nullptr error\n"); + KUTACC_CHECK(n_res > 0 && n_res_gather > 0 && c_o > 0, + "kutacc_af2_triangle_multiplication_last: input args int values error, values are less than or equal to zero\n"); + if (kutacc::kutacc_check_err_set == true) { return; - } else { - kutacc::last(*kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(gate), n_res, n_res_gather, c_o); } + kutacc::last(*kutacc::convertKutaccTensor(out), *kutacc::convertKutaccTensor(gate), n_res, n_res_gather, c_o); } \ No newline at end of file -- Gitee From 54332b257e834c783686454f349f9efe90af6926 Mon Sep 17 00:00:00 2001 From: XeonYZhang Date: Mon, 27 Oct 2025 22:07:36 +0800 Subject: [PATCH 4/4] fix outer_product_mean check error --- src/attention/outer_product_mean.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/attention/outer_product_mean.cpp b/src/attention/outer_product_mean.cpp index bb85393..c0ff99e 100644 --- a/src/attention/outer_product_mean.cpp +++ b/src/attention/outer_product_mean.cpp @@ -170,7 +170,7 @@ void kutacc_export kutacc_af2_outer_product_mean_calc_left_and_right_mul( KUTACC_CHECK(left_proj != nullptr && right_proj != nullptr && left_proj_ != nullptr && right_proj_ != nullptr && input_act != nullptr && mask != nullptr && norm != nullptr && left_proj_w != nullptr && left_proj_b != nullptr && right_proj_w != nullptr && right_proj_b != nullptr, "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args nullptr error\n"); - KUTACC_CHECK(c_i > 0 && c_m > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0 && mask_bias > 0, + KUTACC_CHECK(c_i > 0 && c_m > 0 && n_res > 0 && n_res_gather > 0 && n_seq > 0 && mask_bias >= 0, "kutacc_af2_outer_product_mean_calc_left_and_right_mul: input args negative value error\n"); if (kutacc::kutacc_check_err_set == true) { return; -- Gitee