diff --git a/gcc/Makefile.in b/gcc/Makefile.in index f21bc5f9a97333e6d74644026685383da3a90aa4..cca3027e67eaa3753e5b469f931c1b7a4fc10109 100644 --- a/gcc/Makefile.in +++ b/gcc/Makefile.in @@ -1358,6 +1358,7 @@ OBJS = \ gimple-ssa-backprop.o \ gimple-ssa-evrp.o \ gimple-ssa-evrp-analyze.o \ + gimple-ssa-expand-sve.o \ gimple-ssa-isolate-paths.o \ gimple-ssa-nonnull-compare.o \ gimple-ssa-split-paths.o \ diff --git a/gcc/common.opt b/gcc/common.opt index c0c1de754dd9abb9f39e87fc38b0787196def2da..a2621fad78942baaffe58b41211fc1d755e59bf8 100644 --- a/gcc/common.opt +++ b/gcc/common.opt @@ -3593,4 +3593,12 @@ fifcvt-allow-complicated-cmps Common Report Var(flag_ifcvt_allow_complicated_cmps) Optimization Allow RTL if-conversion pass to deal with complicated cmps (can increase compilation time). +ffind-with-sve +Common Var(flag_find_with_sve) Init(0) Optimization +Enable replace std::find with sve + +fsve-expand-std-find-threshold +Common Var(sve_expand_std_find_threshold) Init(8) Optimization +Minimal length of the array to search + ; This comment is to ensure we retain the blank line above. diff --git a/gcc/gimple-ssa-expand-sve.cc b/gcc/gimple-ssa-expand-sve.cc new file mode 100644 index 0000000000000000000000000000000000000000..3aadedff82c3f1ec567f087e10547f57f0e44e38 --- /dev/null +++ b/gcc/gimple-ssa-expand-sve.cc @@ -0,0 +1,281 @@ +/* replace the std::find with sve. + Copyright (C) 2005-2022 Free Software Foundation, Inc. + +This file is part of GCC. + +GCC is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 3, or (at your option) +any later version. + +GCC is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with GCC; see the file COPYING3. If not see +. */ + +#include "config.h" +#include "system.h" +#include "coretypes.h" +#include "tree.h" +#include "function.h" +#include "basic-block.h" +#include "gimple.h" +#include "gimple-iterator.h" +#include "tree-pass.h" +#include "context.h" +#include "target.h" +#include "toplev.h" +#include "cfghooks.h" +#include "cfg.h" +#include "tree-cfg.h" +#include "cfgloop.h" +#include "gimple-ssa.h" +#include "gimple-pretty-print.h" + +namespace { + +#define TRACE_FUNCTION(fun)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess function: \n");\ + dump_function_to_file (fun, dump_file, TDF_NONE);\ + fprintf (dump_file, "\n");\ + } + +#define TRACE_STMT(stmt)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess stmt: \n");\ + print_gimple_stmt (dump_file, stmt, 0, TDF_NONE);\ + fprintf (dump_file, "\n");\ + } + +#define TRACE_REPLACE_STMT(stmt)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess replace stmt: \n");\ + print_gimple_stmt (dump_file, stmt, 0, TDF_NONE);\ + fprintf (dump_file, "\n");\ + } + +#define TRACE_ARG3_TYPE(type)\ + if (dump_file)\ + {\ + fprintf (dump_file, "\nprocess arg3 type: \n");\ + dump_node (type, TDF_NONE, dump_file);\ + fprintf (dump_file, "\n");\ + } + +const pass_data pass_data_find_with_sve = { + GIMPLE_PASS, /* type. */ + "find_with_sve", /* name. */ + OPTGROUP_NONE, /* optinfo_flags. */ + TV_NONE, /* tv_id. */ + 0, /* properties_required. */ + 0, /* properties_provided. */ + 0, /* properties_destroyed. */ + 0, /* todo_flags_start. */ + TODO_cleanup_cfg | TODO_update_ssa | TODO_update_address_taken + | TODO_rebuild_cgraph_edges, /* todo_flags_finish. */ +}; + +class pass_find_with_sve : public gimple_opt_pass { +public: + pass_find_with_sve (gcc::context *ctx) : + gimple_opt_pass (pass_data_find_with_sve, ctx) + {} + + virtual bool gate (function *fun) override + { + if (!flag_find_with_sve) + return false; + + if (!targetm.vector_mode_supported_p (V2DImode)) + return false; + + return true; + } + +virtual unsigned int execute (function *fun) override +{ + TRACE_FUNCTION (fun->decl); + basic_block bb; + FOR_EACH_BB_FN (bb, fun) + { + for (gimple_stmt_iterator gsi = gsi_start_bb (bb); + !gsi_end_p (gsi); gsi_next (&gsi)) + { + gimple *stmt = gsi_stmt (gsi); + if (std_find_check (stmt)) + replace_std_find (gsi); + } + } + + return 0; +} + +private: + uint8_t bit_width; + const char *null_name = ""; + + bool std_find_check (gimple *stmt) + { + if (!is_gimple_call (stmt)) + return false; + + tree fndecl = gimple_call_fndecl (stmt); + if (fndecl == nullptr || DECL_NAME (fndecl) == nullptr) + return false; + + const char *fn_name = IDENTIFIER_POINTER (DECL_NAME (fndecl)); + if (fn_name == nullptr || strcmp (fn_name, "find") != 0) + return false; + + if (DECL_CONTEXT (fndecl) == nullptr + || TREE_CODE (DECL_CONTEXT (fndecl)) != NAMESPACE_DECL) + return false; + + const char *namespace_name + = IDENTIFIER_POINTER (DECL_NAME (DECL_CONTEXT (fndecl))); + if (namespace_name == nullptr || strcmp (namespace_name, "std") != 0) + return false; + + /* Exclude the scenarios : xxx::std::find. */ + if (DECL_CONTEXT (DECL_CONTEXT (fndecl)) + && TREE_CODE (DECL_CONTEXT (DECL_CONTEXT (fndecl))) + == NAMESPACE_DECL) + return false; + + if (gimple_call_num_args (stmt) != 3) + return false; + + tree arg1 = DECL_ARGUMENTS (fndecl); + tree arg2 = TREE_CHAIN (arg1); + tree arg3 = TREE_CHAIN (arg2); + + tree arg3_type = TREE_TYPE (arg3); + if (TREE_CODE (arg3_type) != REFERENCE_TYPE) + return false; + + tree main_type = TREE_TYPE (arg3_type); + TRACE_ARG3_TYPE (main_type); + if (TREE_CODE (main_type) == INTEGER_TYPE) + { + if (TYPE_PRECISION (main_type) != 64) + return false; + + const char *type_name = get_type_name_arg (main_type); + if ((strcmp (type_name, "long unsigned int") != 0) + && (strcmp (type_name, "long int") != 0)) + return false; + + this->bit_width = 64; + } else if (TREE_CODE (main_type) == POINTER_TYPE) + this->bit_width = 64; + else + return false; + + tree arg1_type = TREE_TYPE (arg1); + if (TREE_CODE (arg1_type) == POINTER_TYPE) + return true; + else if (TREE_CODE (arg1_type) == RECORD_TYPE) + { + const char *type_name = get_type_name_arg (arg1_type); + if (strcmp (type_name, "__normal_iterator") == 0) + return true; + } + + return false; + } + + const char *get_type_name_arg (tree main_type) + { + enum tree_code code = TREE_CODE (main_type); + enum tree_code_class tclass = TREE_CODE_CLASS (code); + + if (tclass == tcc_type) + { + if (TYPE_NAME (main_type)) + { + if (TREE_CODE (TYPE_NAME (main_type)) == IDENTIFIER_NODE) + { + const char *type_name = IDENTIFIER_POINTER ( + TYPE_NAME (main_type)); + if (type_name) + return type_name; + } + else if (TREE_CODE (TYPE_NAME (main_type)) == TYPE_DECL + && DECL_NAME (TYPE_NAME (main_type))) + { + const char *type_name = IDENTIFIER_POINTER ( + DECL_NAME (TYPE_NAME (main_type))); + if (type_name) + return type_name; + } + } + } + + return null_name; + } + + void replace_std_find (gimple_stmt_iterator gsi) + { + switch (this->bit_width) + { + case 64: + replace_std_find_u64 (gsi); + break; + case 32: + case 16: + case 8: + default:; + } + } + + void replace_std_find_u64 (gimple_stmt_iterator gsi) + { + gimple *stmt = gsi_stmt (gsi); + tree old_fndecl = gimple_call_fndecl (stmt); + TRACE_STMT (stmt); + + // arguments list process: + auto_vec args; + for (unsigned i = 0; i < gimple_call_num_args (stmt); ++i) + args.safe_push (gimple_call_arg (stmt, i)); + tree new_arg = build_int_cst (unsigned_char_type_node, + sve_expand_std_find_threshold); + args.safe_push (new_arg); + + // functon declare process: + tree old_type = TREE_TYPE (old_fndecl); + tree ret_type = TREE_TYPE (old_type); + tree arg_types = NULL_TREE; + for (tree t = TYPE_ARG_TYPES (old_type); t; t = TREE_CHAIN (t)) + arg_types = tree_cons (NULL_TREE, TREE_VALUE (t), arg_types); + arg_types = tree_cons (NULL_TREE, unsigned_char_type_node, arg_types); + arg_types = nreverse (arg_types); + tree new_fndecl_type = build_function_type (ret_type, arg_types); + tree new_fndecl = build_fn_decl ("__sve_optimized_find_u64", + new_fndecl_type); + TREE_PUBLIC (new_fndecl) = 1; + DECL_EXTERNAL (new_fndecl) = 1; + + // call function process: + gcall *new_call = gimple_build_call_vec (new_fndecl, args); + if (gimple_has_lhs (stmt)) + gimple_call_set_lhs (new_call, gimple_call_lhs (stmt)); + gsi_replace (&gsi, new_call, true); + update_stmt (gsi_stmt (gsi)); + TRACE_REPLACE_STMT (gsi_stmt (gsi)); + } +}; +} // namespace + +gimple_opt_pass *make_pass_find_with_sve (gcc::context *ctx) +{ + return new pass_find_with_sve (ctx); +} diff --git a/gcc/passes.def b/gcc/passes.def index 8898b72fcdc679feb088e9f8440bbfd75e3442dd..bcf477cd02192ec8b62e894d8eeb5c41e808a0bf 100644 --- a/gcc/passes.def +++ b/gcc/passes.def @@ -68,6 +68,7 @@ along with GCC; see the file COPYING3. If not see PUSH_INSERT_PASSES_WITHIN (pass_local_optimization_passes) NEXT_PASS (pass_fixup_cfg); NEXT_PASS (pass_rebuild_cgraph_edges); + NEXT_PASS (pass_find_with_sve); NEXT_PASS (pass_local_fn_summary); NEXT_PASS (pass_early_inline); NEXT_PASS (pass_all_early_optimizations); diff --git a/gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C b/gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C new file mode 100644 index 0000000000000000000000000000000000000000..42c51668f5f8b86fb3eaf89154108e344c1e63d1 --- /dev/null +++ b/gcc/testsuite/g++.dg/tree-ssa/find-with-sve.C @@ -0,0 +1,175 @@ +/* { dg-do compile } */ +/* { dg-options "-std=c++17 -O3 -ffind-with-sve -march=armv8-a+sve -fdump-tree-optimized" } */ + +#include +#include +#include +#include +#include +#include +#include + +void test_u64() +{ + std::vector v = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + + std::uint64_t x = 100; + std::cin >> x; + + auto it = std::find(v.begin(), v.end(), x); // matched : No.1 + + if (it != v.end()) + std::cout << "ok!\n"; + else + std::cout << "fail!\n"; +} + +void test_s64() +{ + std::vector v = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + + std::int64_t x = 100; + std::cin >> x; + + auto it = std::find(v.begin(), v.end(), x); // matched : No.2 + + if (it != v.end()) + std::cout << "ok!\n"; + else + std::cout << "fail!\n"; +} + +void test_array() +{ + const unsigned N = 1024 * 1024 * 16; + long *arr = new long[N]; + long *p; + for (unsigned i = 0; i < N; ++i) + arr[i] = i; + for (unsigned i = N - 1000; i < N - 1; ++i) { + p = std::find(arr, arr + N, arr[i]); // matched : No.3 + assert(p == arr + i); + unsigned j = i - 10; + p = std::find(arr + j, arr + j + 1, arr[j]); // matched : No.4 + assert(p == arr + j); + p = std::find(arr + j + 1, arr + j + 1, arr[j + 2]); // matched : No.5 + assert(p == arr + j + 1); + p = std::find(arr + j + 2, arr + j + 1, arr[j + 2]); // matched : No.6 + assert(p == arr + j + 1); + } + p = std::find(arr, arr + N, (long)-1); // matched : No.7 + assert(p == arr + N); +} + +void test_string() +{ + std::vector v; + for (int i = 0; i < 5; i++) + v.push_back(std::to_string(123 + i)); + + for (int i = 0; i < 5; i++) { + auto it = std::find(v.begin(), v.end(), std::to_string(124 + i)); // not matched + + if (it != v.end()) + std::cout << "ok!\n"; + else + std::cout << "failed!\n"; + } +} + +void test_s32() +{ + std::vector v = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + + std::int32_t x = 100; + + auto it = std::find(v.begin(), v.end(), x); // not matched + + if (it != v.end()) + std::cout << "ok!\n"; + else + std::cout << "failed!\n"; +} + +void test_u16() +{ + std::vector v = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + + std::uint16_t x = 100; + + auto it = std::find(v.begin(), v.end(), x); // not matched + + if (it != v.end()) + std::cout << "ok!\n"; + else + std::cout << "failed!\n"; +} + +void test_u16_point() +{ + std::vector v = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector v_ptr; + + for (auto &item : v) + v_ptr.push_back(&item); + + std::uint16_t x = 100; + + auto it = std::find(v_ptr.begin(), v_ptr.end(), &x); // matched : No.8 + + if (it != v_ptr.end()) + std::cout << "ok!\n"; + else + std::cout << "fail!\n"; +} + +void test_set() +{ + std::set s = {1, 3, 5, 7, 9}; + std::uint64_t ask = 4; + + if (auto it = std::find (s.begin(), s.end(), ask); it != s.end()) // not matched + std::cout << "ok!\n"; + else + std::cout << "fail!\n"; +} + +namespace myspace +{ + namespace std + { + struct Basic + { + ::std::uint64_t id; + }; + + ::std::uint64_t find(Basic *, Basic *, ::std::uint64_t &x) + { + return x; + } + } +} + +void test_namespace() +{ + myspace::std::Basic b {1}; + std::uint64_t y = 1; + std::uint64_t x = find(nullptr, &b, y); + printf("x = %d\n", x); +} + +int main() +{ + test_u64(); + test_s64(); + test_array(); + test_string(); + test_s32(); + test_u16(); + test_u16_point(); + test_set(); + test_namespace(); + return 0; +} + +/* { dg-final { scan-tree-dump-times "__sve_optimized_find_u64" 8 "optimized" } } */ diff --git a/gcc/tree-pass.h b/gcc/tree-pass.h index d3a41d0d51e82bf2179c070d3fa383bbb0bb1e66..3045fbe2fc54ac5e84a06a87213f4446fd5c09b3 100644 --- a/gcc/tree-pass.h +++ b/gcc/tree-pass.h @@ -483,6 +483,7 @@ extern gimple_opt_pass *make_pass_sprintf_length (gcc::context *ctxt); extern gimple_opt_pass *make_pass_walloca (gcc::context *ctxt); extern gimple_opt_pass *make_pass_coroutine_lower_builtins (gcc::context *ctxt); extern gimple_opt_pass *make_pass_coroutine_early_expand_ifns (gcc::context *ctxt); +extern gimple_opt_pass *make_pass_find_with_sve (gcc::context *ctx); /* IPA Passes */ extern simple_ipa_opt_pass *make_pass_ipa_lower_emutls (gcc::context *ctxt); diff --git a/libgcc/config/aarch64/sve_std_find.c b/libgcc/config/aarch64/sve_std_find.c new file mode 100644 index 0000000000000000000000000000000000000000..86ff4cb5a64a6871924d16ff0a31ee16c60d59cb --- /dev/null +++ b/libgcc/config/aarch64/sve_std_find.c @@ -0,0 +1,38 @@ +#include +#include + +#pragma GCC target ("+sve") + +uint64_t *__sve_optimized_find_u64 (uint64_t *first, uint64_t *last, + uint64_t const *value, uint8_t threshold) +{ + if (first + threshold > last) + { + goto Tail; + } + + uint64_t m = svcntd (); + uint64_t n = (last - first) / m; + svbool_t TRUE = svptrue_b64 (); + for (; n-- > 0;) + { + svuint64_t v3 = svld1_u64 (TRUE, (uint64_t *)first); + svbool_t v4 = svcmpeq_n_u64 (TRUE, v3, (uint64_t)*value); + if (svptest_any (TRUE, v4)) + { + break; + } + first += m; + } + +Tail: + while (first < last) + { + if (*first == *value) + { + return first; + } + ++first; + } + return last; +} diff --git a/libgcc/config/aarch64/t-aarch64 b/libgcc/config/aarch64/t-aarch64 index fce36be7480b3c6885ec6f6d069ce9a71c5b754c..9ad607adfc6e62d58ea5d23a4200ca3a2843d0c2 100644 --- a/libgcc/config/aarch64/t-aarch64 +++ b/libgcc/config/aarch64/t-aarch64 @@ -19,3 +19,4 @@ # . LIB2ADD += $(srcdir)/config/aarch64/sync-cache.c +LIB2ADD += $(srcdir)/config/aarch64/sve_std_find.c