From eb42f34a08b5daa52a43e4ea40e40e4aa4dbb72e Mon Sep 17 00:00:00 2001 From: Nelson Lossing Date: Tue, 24 May 2022 12:59:03 +0000 Subject: [PATCH] CustomOp integrate usage of intrinsic directives for the scheduler for the moment need to activate manually MLSched so that AKG can take the intrisic directives into account for the scheduling export MS_DEV_POLY_SCHEDULER=mls or attrs['enable_mlsched']=True --- src/poly/mls.h | 2 +- src/poly/schedule_pass.cc | 57 +++++++++++++++++++ src/poly/schedule_pass.h | 4 ++ src/poly/schedule_pass/compute_schedule.cc | 49 +--------------- src/poly/schedule_pass/compute_schedule.h | 4 -- .../schedule_pass/scheduling_mind_trick.cc | 57 +++++++------------ .../schedule_pass/scheduling_mind_trick.h | 2 +- 7 files changed, 85 insertions(+), 90 deletions(-) diff --git a/src/poly/mls.h b/src/poly/mls.h index a4a365b4..b3b5df45 100644 --- a/src/poly/mls.h +++ b/src/poly/mls.h @@ -549,7 +549,7 @@ class Hints { /// \retval false otherwise [[gnu::pure]] bool HasStatementVectorials(const char *statement) const; - /// \brief Check whether the Hints has reduces directives for a given stateemnt + /// \brief Check whether the Hints has reduces directives for a given statement /// \param[in] statement Target statement /// \return A boolean value that indicates whether the Hints has reduces directives for \a statement /// \retval true if the Hints has reduces directives for \a statement diff --git a/src/poly/schedule_pass.cc b/src/poly/schedule_pass.cc index 6a368f34..6e5e76ce 100644 --- a/src/poly/schedule_pass.cc +++ b/src/poly/schedule_pass.cc @@ -564,6 +564,63 @@ mls::bin::Options MLSchedOptionsInit(const akg::ir::poly::PassInfo &pass_info, return result; } + +mls::bin::Hints ExtractDirectivesFromAKG(ScopInfo &scop_info) { + mls::bin::Hints hints; + + ForTypeMap directives = scop_info.analysis_result_.GetForTypeMap(); + std::map> serials_directive; + std::map> vectorials_directive; + std::map> parallels_directive; + std::map> reduces_directive; + for (const auto &[stmt, vloop_directive] : directives) { + std::string stmt_string = stmt.get_name(); + for (uint i = 0; i < vloop_directive.size(); ++i) { + switch (vloop_directive[i]) { + case ForType::Serial: + break; + case ForType::Invariant: + LOG(INFO) << stmt_string << " invariant_for"; + serials_directive[stmt_string].push_back(i); + break; + case ForType::Parallel: + LOG(INFO) << stmt_string << " parallel"; + parallels_directive[stmt_string].push_back(i); + break; + case ForType::Vectorized: + case ForType::Swizzled: // treat "Swizzled" like "Vectorized" for the moment + LOG(INFO) << stmt_string << " vectorized"; + vectorials_directive[stmt_string].push_back(i); + break; + case ForType::Reduce: + LOG(INFO) << stmt_string << " reduce"; + reduces_directive[stmt_string].push_back(i); + break; + case ForType::Unrolled: + LOG(WARNING) << stmt_string << " Do not treat ForType::Unrolled as a directives"; + break; + default: + LOG(WARNING) << stmt_string << " Unknow ForType loop"; + break; + } + } + } + + for (const auto &[key, directive] : serials_directive) { + hints.SetStatementSerials(key.c_str(), directive); + } + for (const auto &[key, directive] : vectorials_directive) { + hints.SetStatementVectorials(key.c_str(), directive); + } + for (const auto &[key, directive] : parallels_directive) { + hints.SetStatementParallels(key.c_str(), directive); + } + for (const auto &[key, directive] : reduces_directive) { + hints.SetStatementReduces(key.c_str(), directive); + } + + return hints; +} #endif } // namespace poly diff --git a/src/poly/schedule_pass.h b/src/poly/schedule_pass.h index 534ad2c4..8a38510c 100644 --- a/src/poly/schedule_pass.h +++ b/src/poly/schedule_pass.h @@ -130,6 +130,10 @@ bool MLSchedShouldBeUsed(akg::ir::poly::ScopInfo &scop_info); /// The options may be decided arbitrarily, from the environment or from \a pass_info and \a scop_info. mls::bin::Options MLSchedOptionsInit(const akg::ir::poly::PassInfo &pass_info, const akg::ir::poly::ScopInfo &scop_info); + +/// \brief Extract the directives informations from the information coming from AKG scop +/// \result return an hint object that can be used for MLSched scheduler. +mls::bin::Hints ExtractDirectivesFromAKG(ScopInfo &scop_info); #endif } // namespace poly diff --git a/src/poly/schedule_pass/compute_schedule.cc b/src/poly/schedule_pass/compute_schedule.cc index dad5cbd2..bc413a91 100644 --- a/src/poly/schedule_pass/compute_schedule.cc +++ b/src/poly/schedule_pass/compute_schedule.cc @@ -167,53 +167,6 @@ isl::union_pw_aff ComputeSchedule::GenerateNewAffine(const isl::union_pw_aff &sw return new_aff; }; -#ifdef AKG_USE_MLS -mls::bin::Hints ComputeSchedule::ExtractDirectivesFromAKG(void) { - mls::bin::Hints hints; - - ForTypeMap directives = scop_info_.analysis_result_.GetForTypeMap(); - std::map> serials_dir; - std::map> vectorials_dir; - std::map> parallels_dir; - for (const auto &[stmt, vloop_directive] : directives) { - std::string stmt_string = stmt.get_name(); - for (uint i = 0; i < vloop_directive.size(); ++i) { - switch (vloop_directive[i]) { - case ForType::Serial: - break; - case ForType::Invariant: - LOG(INFO) << stmt_string << "invariant_for"; - serials_dir[stmt_string].push_back(i); - break; - case ForType::Parallel: - LOG(INFO) << stmt_string << "parallel"; - parallels_dir[stmt_string].push_back(i); - break; - case ForType::Vectorized: - case ForType::Swizzled: // treat "Swizzled" like "Vectorized" for the moment - LOG(INFO) << stmt_string << "vectorized"; - vectorials_dir[stmt_string].push_back(i); - break; - case ForType::Unrolled: - LOG(WARNING) << stmt_string << "Do not treat ForType::Unrolled as a directives"; - break; - default: - break; - } - } - } - - for (const auto &[key, directive] : serials_dir) { - hints.SetStatementSerials(key.c_str(), directive); - } - for (const auto &[key, directive] : vectorials_dir) { - hints.SetStatementVectorials(key.c_str(), directive); - } - - return hints; -} -#endif - isl::schedule ComputeSchedule::Run(isl::schedule sch) { if (scop_info_.user_config_.GetModScheduleShift()) { pass_info_.dependences_ = ModDependences(pass_info_.dependences_); @@ -233,7 +186,7 @@ isl::schedule ComputeSchedule::Run(isl::schedule sch) { } const std::string &kernel_name = scop_info_.user_config_.GetKernelName(); - const mls::bin::Hints hints = ExtractDirectivesFromAKG(); + const mls::bin::Hints hints = ExtractDirectivesFromAKG(scop_info_); const isl::union_map reads = UnwrappedAccesses(scop_info_.analysis_result_.GetReads()); const isl::union_map writes = UnwrappedAccesses(scop_info_.analysis_result_.GetWrites()); isl_union_map *const dependences = pass_info_.dependences_.get(); diff --git a/src/poly/schedule_pass/compute_schedule.h b/src/poly/schedule_pass/compute_schedule.h index cc9f5f4f..ae7006f2 100644 --- a/src/poly/schedule_pass/compute_schedule.h +++ b/src/poly/schedule_pass/compute_schedule.h @@ -51,10 +51,6 @@ class ComputeSchedule : public SchedulePass { isl::union_pw_aff GenerateNewAffine(const isl::union_pw_aff &swap_out, const isl::union_pw_aff &swap_in, std::unordered_set swap_ids); -#ifdef AKG_USE_MLS - mls::bin::Hints ExtractDirectivesFromAKG(void); -#endif - private: PassInfo &pass_info_; diff --git a/src/poly/schedule_pass/scheduling_mind_trick.cc b/src/poly/schedule_pass/scheduling_mind_trick.cc index 447419c9..4653bb5e 100644 --- a/src/poly/schedule_pass/scheduling_mind_trick.cc +++ b/src/poly/schedule_pass/scheduling_mind_trick.cc @@ -1426,6 +1426,10 @@ bool SchedulingMindTrick::BuildInfluencedSchedule(const isl::schedule &schedule) } const std::string &kernel_name = scop_info_.user_config_.GetKernelName(); + if (!hints_.HaveDirectives()) { + mls::bin::Hints directive_hint = ExtractDirectivesFromAKG(scop_info_); + UpdateHints(directive_hint); + } mls::bin::Scop scop(initial_schedule, dependences, reads, writes, hints_, options, kernel_name.c_str()); const bool success = scop.ComputeSchedule(); @@ -1642,45 +1646,26 @@ isl::schedule SchedulingMindTrick::GpuPostProcessSchedule(const isl::schedule &s /////////////////////////////////////////////////////////////////////////// #ifdef AKG_USE_MLS -void SchedulingMindTrick::ExtractDirectivesFromAKG(void) { - ForTypeMap directives = scop_info_.analysis_result_.GetForTypeMap(); - std::map> serials_dir; - std::map> vectorials_dir; - std::map> parallels_dir; - for (const auto &[stmt, vloop_directive] : directives) { - std::string stmt_string = stmt.get_name(); - for (uint i = 0; i < vloop_directive.size(); ++i) { - switch (vloop_directive[i]) { - case ForType::Serial: - break; - case ForType::Invariant: - LOG(INFO) << "invariant_for"; - serials_dir[stmt_string].push_back(i); - break; - case ForType::Parallel: - LOG(INFO) << "parallel"; - parallels_dir[stmt_string].push_back(i); - break; - case ForType::Vectorized: - case ForType::Swizzled: // treat "Swizzled" like "Vectorized" for the moment - LOG(INFO) << "vectorized"; - vectorials_dir[stmt_string].push_back(i); - break; - case ForType::Unrolled: - LOG(WARNING) << "Do not treat ForType::Unrolled as a directives"; - break; - default: - break; +void SchedulingMindTrick::UpdateHints(mls::bin::Hints directive_hint) { + if (!hints_.HaveDirectives()) { + ForTypeMap directives = scop_info_.analysis_result_.GetForTypeMap(); + for (const auto &[stmt, vloop_directive] : directives) { + (void)vloop_directive; + const char *stmt_string = stmt.get_name().c_str(); + if (directive_hint.HasStatementSerials(stmt_string)) { + hints_.SetStatementSerials(stmt_string, directive_hint.GetStatementSerials(stmt_string)); + } + if (directive_hint.HasStatementVectorials(stmt_string)) { + hints_.SetStatementVectorials(stmt_string, directive_hint.GetStatementVectorials(stmt_string)); + } + if (directive_hint.HasStatementParallels(stmt_string)) { + hints_.SetStatementParallels(stmt_string, directive_hint.GetStatementParallels(stmt_string)); + } + if (directive_hint.HasStatementReduces(stmt_string)) { + hints_.SetStatementReduces(stmt_string, directive_hint.GetStatementReduces(stmt_string)); } } } - - for (const auto &[key, directive] : serials_dir) { - hints_.SetStatementSerials(key.c_str(), directive); - } - for (const auto &[key, directive] : vectorials_dir) { - hints_.SetStatementVectorials(key.c_str(), directive); - } } #endif diff --git a/src/poly/schedule_pass/scheduling_mind_trick.h b/src/poly/schedule_pass/scheduling_mind_trick.h index 2ecc84e6..93b3c66e 100644 --- a/src/poly/schedule_pass/scheduling_mind_trick.h +++ b/src/poly/schedule_pass/scheduling_mind_trick.h @@ -200,7 +200,7 @@ class SchedulingMindTrick { /////////////////////////////////////////////////////////////////////////// #ifdef AKG_USE_MLS - void ExtractDirectivesFromAKG(void); + void UpdateHints(mls::bin::Hints directive_hint); #endif /////////////////////////////////////////////////////////////////////////// -- Gitee