From 6066b1a502c523625695a84bd65995eb6fd3c106 Mon Sep 17 00:00:00 2001 From: blunce Date: Mon, 19 May 2025 15:31:46 +0800 Subject: [PATCH] [SVE] Add container restriction for std find with sve --- gcc/gimple-ssa-expand-sve.cc | 391 ++++++++++-------- gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C | 38 ++ 2 files changed, 246 insertions(+), 183 deletions(-) diff --git a/gcc/gimple-ssa-expand-sve.cc b/gcc/gimple-ssa-expand-sve.cc index 9bac95212f2..e8c9e9abfe5 100644 --- a/gcc/gimple-ssa-expand-sve.cc +++ b/gcc/gimple-ssa-expand-sve.cc @@ -38,219 +38,244 @@ along with GCC; see the file COPYING3. If not see namespace { -#define TRACE_FUNCTION(fun) \ - if (dump_file) \ - { \ - fprintf (dump_file, "\nprocess function: \n"); \ - dump_function_to_file (fun, dump_file, TDF_NONE);\ - fprintf (dump_file, "\n"); \ - } - -#define TRACE_STMT(stmt) \ - if (dump_file) \ - { \ - fprintf (dump_file, "\nprocess stmt: \n"); \ - print_gimple_stmt (dump_file, stmt, 0, TDF_NONE);\ - fprintf (dump_file, "\n"); \ - } - -#define TRACE_REPLACE_STMT(stmt) \ - if (dump_file) \ - { \ - fprintf (dump_file, "\nprocess replace stmt: \n");\ - print_gimple_stmt (dump_file, stmt, 0, TDF_NONE);\ - fprintf (dump_file, "\n"); \ - } - -#define TRACE_ARG3_TYPE(type) \ - if (dump_file) \ - { \ - fprintf (dump_file, "\nprocess arg3 type: \n"); \ - dump_node (type, TDF_NONE, dump_file); \ - fprintf (dump_file, "\n"); \ - } +#define TRACE_FUNCTION(fun)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess function: \n");\ + dump_function_to_file (fun, dump_file, TDF_NONE);\ + fprintf (dump_file, "\n");\ + } + +#define TRACE_STMT(stmt)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess stmt: \n");\ + print_gimple_stmt (dump_file, stmt, 0, TDF_NONE);\ + fprintf (dump_file, "\n");\ + } + +#define TRACE_REPLACE_STMT(stmt)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess replace stmt: \n");\ + print_gimple_stmt (dump_file, stmt, 0, TDF_NONE);\ + fprintf (dump_file, "\n");\ + } + +#define TRACE_ARG3_TYPE(type)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess arg3 type: \n");\ + dump_node (type, TDF_NONE, dump_file);\ + fprintf (dump_file, "\n");\ + } const pass_data pass_data_find_with_sve = { - GIMPLE_PASS, /* type */ - "find_with_sve", /* name */ - OPTGROUP_NONE, /* optinfo_flags */ - TV_NONE, /* tv_id */ - 0, /* properties_required */ - 0, /* properties_provided */ - 0, /* properties_destroyed */ - 0, /* todo_flags_start */ - TODO_cleanup_cfg | TODO_update_ssa | TODO_update_address_taken - | TODO_rebuild_cgraph_edges, /* todo_flags_finish */ + GIMPLE_PASS, /* type. */ + "find_with_sve", /* name. */ + OPTGROUP_NONE, /* optinfo_flags. */ + TV_NONE, /* tv_id. */ + 0, /* properties_required. */ + 0, /* properties_provided. */ + 0, /* properties_destroyed. */ + 0, /* todo_flags_start. */ + TODO_cleanup_cfg | TODO_update_ssa | TODO_update_address_taken + | TODO_rebuild_cgraph_edges, /* todo_flags_finish. */ }; class pass_find_with_sve : public gimple_opt_pass { public: - pass_find_with_sve (gcc::context *ctx) : - gimple_opt_pass (pass_data_find_with_sve, ctx) - {} + pass_find_with_sve (gcc::context *ctx) : + gimple_opt_pass (pass_data_find_with_sve, ctx) + {} - virtual bool gate (function *fun) override - { - if (!flag_find_with_sve) - return false; + virtual bool gate (function *fun) override + { + if (!flag_find_with_sve) + return false; - if (!targetm.vector_mode_supported_p (V2DImode)) - return false; + if (!targetm.vector_mode_supported_p (V2DImode)) + return false; - return true; - } + return true; + } - virtual unsigned int execute (function *fun) override +virtual unsigned int execute (function *fun) override +{ + TRACE_FUNCTION (fun->decl); + basic_block bb; + FOR_EACH_BB_FN (bb, fun) + { + for (gimple_stmt_iterator gsi = gsi_start_bb (bb); + !gsi_end_p (gsi); gsi_next (&gsi)) { - TRACE_FUNCTION (fun->decl); - basic_block bb; - FOR_EACH_BB_FN (bb, fun) - { - for (gimple_stmt_iterator gsi = gsi_start_bb (bb); - !gsi_end_p (gsi); gsi_next (&gsi)) - { - gimple *stmt = gsi_stmt (gsi); - if (std_find_check (stmt)) - replace_std_find (gsi); - } - } - - return 0; + gimple *stmt = gsi_stmt (gsi); + if (std_find_check (stmt)) + replace_std_find (gsi); } + } -private: - uint8_t bit_width; - const char *null_name = ""; + return 0; +} - bool std_find_check (gimple *stmt) +private: + uint8_t bit_width; + const char *null_name = ""; + + bool std_find_check (gimple *stmt) + { + if (!is_gimple_call (stmt)) + return false; + + tree fndecl = gimple_call_fndecl (stmt); + if (fndecl == nullptr || DECL_NAME (fndecl) == nullptr) + return false; + + const char *fn_name = IDENTIFIER_POINTER (DECL_NAME (fndecl)); + if (fn_name == nullptr || strcmp (fn_name, "find") != 0) + return false; + + if (DECL_CONTEXT (fndecl) == nullptr + || TREE_CODE (DECL_CONTEXT (fndecl)) != NAMESPACE_DECL) + return false; + + const char *namespace_name + = IDENTIFIER_POINTER (DECL_NAME (DECL_CONTEXT (fndecl))); + if (namespace_name == nullptr || strcmp (namespace_name, "std") != 0) + return false; + + /* Exclude the scenarios : xxx::std::find. */ + if (DECL_CONTEXT (DECL_CONTEXT (fndecl)) + && TREE_CODE (DECL_CONTEXT (DECL_CONTEXT (fndecl))) + == NAMESPACE_DECL) + return false; + + if (gimple_call_num_args (stmt) != 3) + return false; + + tree arg1 = DECL_ARGUMENTS (fndecl); + tree arg2 = TREE_CHAIN (arg1); + tree arg3 = TREE_CHAIN (arg2); + + tree arg3_type = TREE_TYPE (arg3); + if (TREE_CODE (arg3_type) != REFERENCE_TYPE) + return false; + + tree main_type = TREE_TYPE (arg3_type); + TRACE_ARG3_TYPE (main_type); + if (TREE_CODE (main_type) == INTEGER_TYPE) { - if (!is_gimple_call (stmt)) - return false; - - tree fndecl = gimple_call_fndecl (stmt); - if (fndecl == nullptr || DECL_NAME (fndecl) == nullptr) - return false; - - const char *fn_name = IDENTIFIER_POINTER (DECL_NAME (fndecl)); - if (strcmp (fn_name, "find") != 0) - return false; - - if (DECL_CONTEXT (fndecl) == nullptr - || TREE_CODE (DECL_CONTEXT (fndecl)) != NAMESPACE_DECL) - return false; - - const char *namespace_name - = IDENTIFIER_POINTER (DECL_NAME (DECL_CONTEXT (fndecl))); - if (strcmp (namespace_name, "std") != 0) - return false; - - if (gimple_call_num_args (stmt) != 3) - return false; - - tree arg1 = DECL_ARGUMENTS (fndecl); - tree arg2 = TREE_CHAIN (arg1); - tree arg3 = TREE_CHAIN (arg2); - - tree arg3_type = TREE_TYPE (arg3); - if (TREE_CODE (arg3_type) != REFERENCE_TYPE) - return false; - - tree main_type = TREE_TYPE (arg3_type); - TRACE_ARG3_TYPE (main_type); - if (TREE_CODE (main_type) == INTEGER_TYPE) - { - if (TYPE_PRECISION (main_type) != 64) - return false; - - const char *type_name = get_type_name_arg3 (main_type); - if ((strcmp (type_name, "long unsigned int") != 0) - && (strcmp (type_name, "long int") != 0)) - return false; - - this->bit_width = 64; - } else if (TREE_CODE (main_type) == POINTER_TYPE) - this->bit_width = 64; - else - return false; - + if (TYPE_PRECISION (main_type) != 64) + return false; + + const char *type_name = get_type_name_arg (main_type); + if ((strcmp (type_name, "long unsigned int") != 0) + && (strcmp (type_name, "long int") != 0)) + return false; + + this->bit_width = 64; + } else if (TREE_CODE (main_type) == POINTER_TYPE) + this->bit_width = 64; + else + return false; + + tree arg1_type = TREE_TYPE (arg1); + if (TREE_CODE (arg1_type) == POINTER_TYPE) + return true; + else if (TREE_CODE (arg1_type) == RECORD_TYPE) + { + const char *type_name = get_type_name_arg (arg1_type); + if (strcmp (type_name, "__normal_iterator") == 0) return true; } - const char *get_type_name_arg3 (tree main_type) - { - enum tree_code code = TREE_CODE (main_type); - enum tree_code_class tclass = TREE_CODE_CLASS (code); + return false; + } - if (tclass == tcc_type) - { - if (TYPE_NAME (main_type)) - { - if (TREE_CODE (TYPE_NAME (main_type)) == IDENTIFIER_NODE) - return IDENTIFIER_POINTER (TYPE_NAME (main_type)); - else if (TREE_CODE (TYPE_NAME (main_type)) == TYPE_DECL - && DECL_NAME (TYPE_NAME (main_type))) - return IDENTIFIER_POINTER ( - DECL_NAME (TYPE_NAME (main_type))); - } - } + const char *get_type_name_arg (tree main_type) + { + enum tree_code code = TREE_CODE (main_type); + enum tree_code_class tclass = TREE_CODE_CLASS (code); - return null_name; - } - - void replace_std_find (gimple_stmt_iterator gsi) + if (tclass == tcc_type) { - switch (this->bit_width) + if (TYPE_NAME (main_type)) + { + if (TREE_CODE (TYPE_NAME (main_type)) == IDENTIFIER_NODE) + { + const char *type_name = IDENTIFIER_POINTER ( + TYPE_NAME (main_type)); + if (type_name) + return type_name; + } + else if (TREE_CODE (TYPE_NAME (main_type)) == TYPE_DECL + && DECL_NAME (TYPE_NAME (main_type))) { - case 64: - replace_std_find_u64 (gsi); - break; - case 32: - case 16: - case 8: - default:; + const char *type_name = IDENTIFIER_POINTER ( + DECL_NAME (TYPE_NAME (main_type))); + if (type_name) + return type_name; } + } } - void replace_std_find_u64 (gimple_stmt_iterator gsi) + return null_name; + } + + void replace_std_find (gimple_stmt_iterator gsi) + { + switch (this->bit_width) { - gimple *stmt = gsi_stmt (gsi); - tree old_fndecl = gimple_call_fndecl (stmt); - TRACE_STMT (stmt); - - // arguments list process: - auto_vec args; - for (unsigned i = 0; i < gimple_call_num_args (stmt); ++i) - args.safe_push (gimple_call_arg (stmt, i)); - tree new_arg = build_int_cst (unsigned_char_type_node, - sve_expand_std_find_threshold); - args.safe_push (new_arg); - - // functon declare process: - tree old_type = TREE_TYPE (old_fndecl); - tree ret_type = TREE_TYPE (old_type); - tree arg_types = NULL_TREE; - for (tree t = TYPE_ARG_TYPES (old_type); t; t = TREE_CHAIN (t)) - arg_types = tree_cons (NULL_TREE, TREE_VALUE (t), arg_types); - arg_types = tree_cons (NULL_TREE, unsigned_char_type_node, arg_types); - arg_types = nreverse (arg_types); - tree new_fndecl_type = build_function_type (ret_type, arg_types); - tree new_fndecl = build_fn_decl ("__sve_optimized_find_u64", - new_fndecl_type); - TREE_PUBLIC (new_fndecl) = 1; - DECL_EXTERNAL (new_fndecl) = 1; - - // call function process: - gcall *new_call = gimple_build_call_vec (new_fndecl, args); - if (gimple_has_lhs (stmt)) - gimple_call_set_lhs (new_call, gimple_call_lhs (stmt)); - gsi_replace (&gsi, new_call, true); - update_stmt (gsi_stmt (gsi)); - TRACE_REPLACE_STMT (gsi_stmt (gsi)); + case 64: + replace_std_find_u64 (gsi); + break; + case 32: + case 16: + case 8: + default:; } + } + + void replace_std_find_u64 (gimple_stmt_iterator gsi) + { + gimple *stmt = gsi_stmt (gsi); + tree old_fndecl = gimple_call_fndecl (stmt); + TRACE_STMT (stmt); + + // arguments list process: + auto_vec args; + for (unsigned i = 0; i < gimple_call_num_args (stmt); ++i) + args.safe_push (gimple_call_arg (stmt, i)); + tree new_arg = build_int_cst (unsigned_char_type_node, + sve_expand_std_find_threshold); + args.safe_push (new_arg); + + // functon declare process: + tree old_type = TREE_TYPE (old_fndecl); + tree ret_type = TREE_TYPE (old_type); + tree arg_types = NULL_TREE; + for (tree t = TYPE_ARG_TYPES (old_type); t; t = TREE_CHAIN (t)) + arg_types = tree_cons (NULL_TREE, TREE_VALUE (t), arg_types); + arg_types = tree_cons (NULL_TREE, unsigned_char_type_node, arg_types); + arg_types = nreverse (arg_types); + tree new_fndecl_type = build_function_type (ret_type, arg_types); + tree new_fndecl = build_fn_decl ("__sve_optimized_find_u64", + new_fndecl_type); + TREE_PUBLIC (new_fndecl) = 1; + DECL_EXTERNAL (new_fndecl) = 1; + + // call function process: + gcall *new_call = gimple_build_call_vec (new_fndecl, args); + if (gimple_has_lhs (stmt)) + gimple_call_set_lhs (new_call, gimple_call_lhs (stmt)); + gsi_replace (&gsi, new_call, true); + update_stmt (gsi_stmt (gsi)); + TRACE_REPLACE_STMT (gsi_stmt (gsi)); + } }; } // namespace gimple_opt_pass *make_pass_find_with_sve (gcc::context *ctx) { - return new pass_find_with_sve (ctx); + return new pass_find_with_sve (ctx); } diff --git a/gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C b/gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C index 66d03e2cfa9..e80fc91786c 100644 --- a/gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C +++ b/gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C @@ -7,6 +7,7 @@ #include #include #include +#include void test_u64() { @@ -122,6 +123,41 @@ void test_u16_point() std::cout << "fail!\n"; } +void test_set() +{ + std::set s = {1, 3, 5, 7, 9}; + std::uint64_t ask = 4; + + if (auto it = std::find (s.begin(), s.end(), ask); it != s.end()) // not matched + std::cout << "ok!\n"; + else + std::cout << "fail!\n"; +} + +namespace myspace +{ + namespace std + { + struct Basic + { + ::std::uint64_t id; + }; + + ::std::uint64_t find(Basic *, Basic *, ::std::uint64_t &x) + { + return x; + } + } +} + +void test_namespace() +{ + myspace::std::Basic b {1}; + std::uint64_t y = 1; + std::uint64_t x = find(nullptr, &b, y); + printf("x = %d\n", x); +} + int main() { test_u64(); @@ -131,6 +167,8 @@ int main() test_s32(); test_u16(); test_u16_point(); + test_set(); + test_namespace(); return 0; } -- Gitee