diff --git a/src/poly/mls.h b/src/poly/mls.h index a4a365b487275ae6ea329fb30a8e4de9a8c2dcb9..b3b5df458107904cdc751fa2e4bf058b1f01a513 100644 --- a/src/poly/mls.h +++ b/src/poly/mls.h @@ -549,7 +549,7 @@ class Hints { /// \retval false otherwise [[gnu::pure]] bool HasStatementVectorials(const char *statement) const; - /// \brief Check whether the Hints has reduces directives for a given stateemnt + /// \brief Check whether the Hints has reduces directives for a given statement /// \param[in] statement Target statement /// \return A boolean value that indicates whether the Hints has reduces directives for \a statement /// \retval true if the Hints has reduces directives for \a statement diff --git a/src/poly/schedule_pass.cc b/src/poly/schedule_pass.cc index 6a368f345d3df53aa4ea702eb3098c2319e53fdb..6e5e76cea615c93968db04f9ebb43ef25d85f035 100644 --- a/src/poly/schedule_pass.cc +++ b/src/poly/schedule_pass.cc @@ -564,6 +564,63 @@ mls::bin::Options MLSchedOptionsInit(const akg::ir::poly::PassInfo &pass_info, return result; } + +mls::bin::Hints ExtractDirectivesFromAKG(ScopInfo &scop_info) { + mls::bin::Hints hints; + + ForTypeMap directives = scop_info.analysis_result_.GetForTypeMap(); + std::map> serials_directive; + std::map> vectorials_directive; + std::map> parallels_directive; + std::map> reduces_directive; + for (const auto &[stmt, vloop_directive] : directives) { + std::string stmt_string = stmt.get_name(); + for (uint i = 0; i < vloop_directive.size(); ++i) { + switch (vloop_directive[i]) { + case ForType::Serial: + break; + case ForType::Invariant: + LOG(INFO) << stmt_string << " invariant_for"; + serials_directive[stmt_string].push_back(i); + break; + case ForType::Parallel: + LOG(INFO) << stmt_string << " parallel"; + parallels_directive[stmt_string].push_back(i); + break; + case ForType::Vectorized: + case ForType::Swizzled: // treat "Swizzled" like "Vectorized" for the moment + LOG(INFO) << stmt_string << " vectorized"; + vectorials_directive[stmt_string].push_back(i); + break; + case ForType::Reduce: + LOG(INFO) << stmt_string << " reduce"; + reduces_directive[stmt_string].push_back(i); + break; + case ForType::Unrolled: + LOG(WARNING) << stmt_string << " Do not treat ForType::Unrolled as a directives"; + break; + default: + LOG(WARNING) << stmt_string << " Unknow ForType loop"; + break; + } + } + } + + for (const auto &[key, directive] : serials_directive) { + hints.SetStatementSerials(key.c_str(), directive); + } + for (const auto &[key, directive] : vectorials_directive) { + hints.SetStatementVectorials(key.c_str(), directive); + } + for (const auto &[key, directive] : parallels_directive) { + hints.SetStatementParallels(key.c_str(), directive); + } + for (const auto &[key, directive] : reduces_directive) { + hints.SetStatementReduces(key.c_str(), directive); + } + + return hints; +} #endif } // namespace poly diff --git a/src/poly/schedule_pass.h b/src/poly/schedule_pass.h index 534ad2c4d9c2fc56b16d8fbe7a88af96c0560242..8a38510c85e79e6696f28196777c1c8b691ac0eb 100644 --- a/src/poly/schedule_pass.h +++ b/src/poly/schedule_pass.h @@ -130,6 +130,10 @@ bool MLSchedShouldBeUsed(akg::ir::poly::ScopInfo &scop_info); /// The options may be decided arbitrarily, from the environment or from \a pass_info and \a scop_info. mls::bin::Options MLSchedOptionsInit(const akg::ir::poly::PassInfo &pass_info, const akg::ir::poly::ScopInfo &scop_info); + +/// \brief Extract the directives informations from the information coming from AKG scop +/// \result return an hint object that can be used for MLSched scheduler. +mls::bin::Hints ExtractDirectivesFromAKG(ScopInfo &scop_info); #endif } // namespace poly diff --git a/src/poly/schedule_pass/compute_schedule.cc b/src/poly/schedule_pass/compute_schedule.cc index dad5cbd2f2ecee9f3bc3de1601682a5d19cd62d5..bc413a91f5f02b70c02d034ea6ccbcfe74fcf5d2 100644 --- a/src/poly/schedule_pass/compute_schedule.cc +++ b/src/poly/schedule_pass/compute_schedule.cc @@ -167,53 +167,6 @@ isl::union_pw_aff ComputeSchedule::GenerateNewAffine(const isl::union_pw_aff &sw return new_aff; }; -#ifdef AKG_USE_MLS -mls::bin::Hints ComputeSchedule::ExtractDirectivesFromAKG(void) { - mls::bin::Hints hints; - - ForTypeMap directives = scop_info_.analysis_result_.GetForTypeMap(); - std::map> serials_dir; - std::map> vectorials_dir; - std::map> parallels_dir; - for (const auto &[stmt, vloop_directive] : directives) { - std::string stmt_string = stmt.get_name(); - for (uint i = 0; i < vloop_directive.size(); ++i) { - switch (vloop_directive[i]) { - case ForType::Serial: - break; - case ForType::Invariant: - LOG(INFO) << stmt_string << "invariant_for"; - serials_dir[stmt_string].push_back(i); - break; - case ForType::Parallel: - LOG(INFO) << stmt_string << "parallel"; - parallels_dir[stmt_string].push_back(i); - break; - case ForType::Vectorized: - case ForType::Swizzled: // treat "Swizzled" like "Vectorized" for the moment - LOG(INFO) << stmt_string << "vectorized"; - vectorials_dir[stmt_string].push_back(i); - break; - case ForType::Unrolled: - LOG(WARNING) << stmt_string << "Do not treat ForType::Unrolled as a directives"; - break; - default: - break; - } - } - } - - for (const auto &[key, directive] : serials_dir) { - hints.SetStatementSerials(key.c_str(), directive); - } - for (const auto &[key, directive] : vectorials_dir) { - hints.SetStatementVectorials(key.c_str(), directive); - } - - return hints; -} -#endif - isl::schedule ComputeSchedule::Run(isl::schedule sch) { if (scop_info_.user_config_.GetModScheduleShift()) { pass_info_.dependences_ = ModDependences(pass_info_.dependences_); @@ -233,7 +186,7 @@ isl::schedule ComputeSchedule::Run(isl::schedule sch) { } const std::string &kernel_name = scop_info_.user_config_.GetKernelName(); - const mls::bin::Hints hints = ExtractDirectivesFromAKG(); + const mls::bin::Hints hints = ExtractDirectivesFromAKG(scop_info_); const isl::union_map reads = UnwrappedAccesses(scop_info_.analysis_result_.GetReads()); const isl::union_map writes = UnwrappedAccesses(scop_info_.analysis_result_.GetWrites()); isl_union_map *const dependences = pass_info_.dependences_.get(); diff --git a/src/poly/schedule_pass/compute_schedule.h b/src/poly/schedule_pass/compute_schedule.h index cc9f5f4f8f19cc95bc22d95cdf5ccb4144b4e32e..ae7006f2d8299405865f99fcfb35cc7da866e7b7 100644 --- a/src/poly/schedule_pass/compute_schedule.h +++ b/src/poly/schedule_pass/compute_schedule.h @@ -51,10 +51,6 @@ class ComputeSchedule : public SchedulePass { isl::union_pw_aff GenerateNewAffine(const isl::union_pw_aff &swap_out, const isl::union_pw_aff &swap_in, std::unordered_set swap_ids); -#ifdef AKG_USE_MLS - mls::bin::Hints ExtractDirectivesFromAKG(void); -#endif - private: PassInfo &pass_info_; diff --git a/src/poly/schedule_pass/scheduling_mind_trick.cc b/src/poly/schedule_pass/scheduling_mind_trick.cc index 447419c930aa3adb4facb7e6a47ed2045262e1dc..4653bb5ee91ecd6100958765290215eedcc913f2 100644 --- a/src/poly/schedule_pass/scheduling_mind_trick.cc +++ b/src/poly/schedule_pass/scheduling_mind_trick.cc @@ -1426,6 +1426,10 @@ bool SchedulingMindTrick::BuildInfluencedSchedule(const isl::schedule &schedule) } const std::string &kernel_name = scop_info_.user_config_.GetKernelName(); + if (!hints_.HaveDirectives()) { + mls::bin::Hints directive_hint = ExtractDirectivesFromAKG(scop_info_); + UpdateHints(directive_hint); + } mls::bin::Scop scop(initial_schedule, dependences, reads, writes, hints_, options, kernel_name.c_str()); const bool success = scop.ComputeSchedule(); @@ -1642,45 +1646,26 @@ isl::schedule SchedulingMindTrick::GpuPostProcessSchedule(const isl::schedule &s /////////////////////////////////////////////////////////////////////////// #ifdef AKG_USE_MLS -void SchedulingMindTrick::ExtractDirectivesFromAKG(void) { - ForTypeMap directives = scop_info_.analysis_result_.GetForTypeMap(); - std::map> serials_dir; - std::map> vectorials_dir; - std::map> parallels_dir; - for (const auto &[stmt, vloop_directive] : directives) { - std::string stmt_string = stmt.get_name(); - for (uint i = 0; i < vloop_directive.size(); ++i) { - switch (vloop_directive[i]) { - case ForType::Serial: - break; - case ForType::Invariant: - LOG(INFO) << "invariant_for"; - serials_dir[stmt_string].push_back(i); - break; - case ForType::Parallel: - LOG(INFO) << "parallel"; - parallels_dir[stmt_string].push_back(i); - break; - case ForType::Vectorized: - case ForType::Swizzled: // treat "Swizzled" like "Vectorized" for the moment - LOG(INFO) << "vectorized"; - vectorials_dir[stmt_string].push_back(i); - break; - case ForType::Unrolled: - LOG(WARNING) << "Do not treat ForType::Unrolled as a directives"; - break; - default: - break; +void SchedulingMindTrick::UpdateHints(mls::bin::Hints directive_hint) { + if (!hints_.HaveDirectives()) { + ForTypeMap directives = scop_info_.analysis_result_.GetForTypeMap(); + for (const auto &[stmt, vloop_directive] : directives) { + (void)vloop_directive; + const char *stmt_string = stmt.get_name().c_str(); + if (directive_hint.HasStatementSerials(stmt_string)) { + hints_.SetStatementSerials(stmt_string, directive_hint.GetStatementSerials(stmt_string)); + } + if (directive_hint.HasStatementVectorials(stmt_string)) { + hints_.SetStatementVectorials(stmt_string, directive_hint.GetStatementVectorials(stmt_string)); + } + if (directive_hint.HasStatementParallels(stmt_string)) { + hints_.SetStatementParallels(stmt_string, directive_hint.GetStatementParallels(stmt_string)); + } + if (directive_hint.HasStatementReduces(stmt_string)) { + hints_.SetStatementReduces(stmt_string, directive_hint.GetStatementReduces(stmt_string)); } } } - - for (const auto &[key, directive] : serials_dir) { - hints_.SetStatementSerials(key.c_str(), directive); - } - for (const auto &[key, directive] : vectorials_dir) { - hints_.SetStatementVectorials(key.c_str(), directive); - } } #endif diff --git a/src/poly/schedule_pass/scheduling_mind_trick.h b/src/poly/schedule_pass/scheduling_mind_trick.h index 2ecc84e67bb1739a60ced8e1b807bbba2933ad6d..93b3c66e05741d8081a08dd3c2ddde10bf04e329 100644 --- a/src/poly/schedule_pass/scheduling_mind_trick.h +++ b/src/poly/schedule_pass/scheduling_mind_trick.h @@ -200,7 +200,7 @@ class SchedulingMindTrick { /////////////////////////////////////////////////////////////////////////// #ifdef AKG_USE_MLS - void ExtractDirectivesFromAKG(void); + void UpdateHints(mls::bin::Hints directive_hint); #endif ///////////////////////////////////////////////////////////////////////////