From f4202da1ebf043eae379f497b57f6b28261fa48c Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Tue, 26 Apr 2022 11:45:40 +0800 Subject: [PATCH] optimize high-dimentional csr --- src/pass/fuse_axis.cc | 26 +++++++---- src/pass/utils.h | 9 +++- src/poly/poly_util.h | 4 +- .../schedule_analysis/band_node_analysis.cc | 8 +++- src/poly/scop_info.h | 6 +++ src/poly/scop_make_schedule_tree.cc | 20 ++++++-- src/poly/tiling/tiling_analyzer.h | 3 ++ .../tiling/tiling_strategy_manager_gpu.cc | 46 +++++++++++++++---- 8 files changed, 96 insertions(+), 26 deletions(-) diff --git a/src/pass/fuse_axis.cc b/src/pass/fuse_axis.cc index 1dacb1ec..f0de9546 100644 --- a/src/pass/fuse_axis.cc +++ b/src/pass/fuse_axis.cc @@ -61,9 +61,10 @@ namespace ir { class FuseAxisExtern : public IRMutator { public: - std::unordered_map> var_name_with_range; - explicit FuseAxisExtern(std::unordered_map> &var_and_range) - : var_name_with_range{var_and_range} {} + std::unordered_map> var_name_with_range_; + Expr feat_len_{1}; + explicit FuseAxisExtern(std::unordered_map> &var_and_range, Expr feat_len) + : var_name_with_range_{var_and_range}, feat_len_{feat_len} {} Stmt Mutate_(const For *op, const Stmt &s) final { auto next_for_op = op->body.as(); if (next_for_op) { @@ -88,7 +89,7 @@ class FuseAxisExtern : public IRMutator { bool CheckFusible(std::string name_hint) { auto idx = name_hint.find("fused"); - return var_name_with_range.count(name_hint) || idx != std::string::npos; + return var_name_with_range_.count(name_hint) || idx != std::string::npos; } Expr Mutate_(const FloorDiv *op, const Expr &e) final { return air::ir::CanonicalSimplify(e); } @@ -98,14 +99,18 @@ class FuseAxisExtern : public IRMutator { class FusionVarCollector : public IRVisitor { public: - std::unordered_map> var_name_with_range; + std::unordered_map> var_name_with_range_; std::vector not_fused_var_name; void Visit_(const For *op) override { std::string var_name = op->loop_var->name_hint; if (op->min.as() && op->extent.as()) { auto range = Range::make_by_min_extent(op->min, op->extent); - var_name_with_range[var_name] = std::make_pair(op->loop_var, range); + var_name_with_range_[var_name] = std::make_pair(op->loop_var, range); + if (is_feature_dim_) { + feat_len_ *= op->extent; + } + is_feature_dim_ = true; } IRVisitor::Visit_(op); } @@ -121,6 +126,9 @@ class FusionVarCollector : public IRVisitor { } IRVisitor::Visit_(op); } + + bool is_feature_dim_{false}; + Expr feat_len_{1}; }; Stmt FuseAxisExternOp(Stmt stmt, air::Schedule sch) { @@ -128,11 +136,12 @@ Stmt FuseAxisExternOp(Stmt stmt, air::Schedule sch) { auto bounds = air::schedule::InferBound(sch); auto fusion_var_vollector = FusionVarCollector(); fusion_var_vollector.Visit(stmt); - auto var_name_with_range{fusion_var_vollector.var_name_with_range}; + auto var_name_with_range{fusion_var_vollector.var_name_with_range_}; + auto feat_len{fusion_var_vollector.feat_len_}; for (auto var_name : fusion_var_vollector.not_fused_var_name) { var_name_with_range.erase(var_name); } - auto fuse_axis_extern = FuseAxisExtern(var_name_with_range); + auto fuse_axis_extern = FuseAxisExtern(var_name_with_range, feat_len); // prevent infinite loop for (size_t i{0}; i < MAX_FUSE_TIMES; ++i) { auto fused_stmt = fuse_axis_extern.Mutate(stmt); @@ -141,6 +150,7 @@ Stmt FuseAxisExternOp(Stmt stmt, air::Schedule sch) { } stmt = fused_stmt; } + stmt = AttrStmt::make(Expr("INFO"), "csr_feature_length", fuse_axis_extern.feat_len_, stmt); return stmt; } } // namespace ir diff --git a/src/pass/utils.h b/src/pass/utils.h index 03de9946..c31d4d2d 100644 --- a/src/pass/utils.h +++ b/src/pass/utils.h @@ -34,7 +34,7 @@ using air::ir::substitute; static const float HALF_MIN = 5.960464e-08; // minimum number of float16 static const float HALF_MAX = 65504.0; // maximum number of float16 -struct PairHash { +struct PairHash { template size_t operator()(const std::pair &a) const { return dmlc::HashCombine(std::hash()(a.first), std::hash()(a.second)); @@ -448,12 +448,17 @@ constexpr auto AKG_INNER_TENSOR = "INNER_TENSOR"; constexpr auto AKG_TENSOR_OF_TENSOR = "TENSOR_OF_TENSOR"; constexpr auto AKG_ATOMIC_TOT = "atomic_tot"; constexpr auto AKG_REMOVE_SELF_DEPENDENCE = "REMOVE_SELF_DEPENDENCE"; +constexpr auto CSR_FEATURE_LENGTH = "csr_feature_length"; constexpr auto CSR_AVG_ROW = "csr_avg_row"; constexpr auto CSR_MAP_THREAD = "csr_map_thread"; static constexpr auto ATTR_PREFETCH_MODE = "prefetch_mode"; enum class PrefetchMode { - DEFAULT = 0, TRANSFERBUFFER, DOUBLEBUFFER, TRANSFERBUFFER_THREADGROUP, DOUBLEBUFFER_THREADGROUP + DEFAULT = 0, + TRANSFERBUFFER, + DOUBLEBUFFER, + TRANSFERBUFFER_THREADGROUP, + DOUBLEBUFFER_THREADGROUP }; constexpr auto REDUCE_LIB_TYPE_ORIGIN = "origin"; diff --git a/src/poly/poly_util.h b/src/poly/poly_util.h index 94a119a5..5ac9456f 100644 --- a/src/poly/poly_util.h +++ b/src/poly/poly_util.h @@ -500,7 +500,9 @@ const std::unordered_set AkgSupportedReduceOp = {AKG_REDUCE_SUM, AK AKG_REDUCE_AND, AKG_REDUCE_OR, AKG_REDUCE_PROD}; const std::unordered_set AkgSupportedTotOp = {AKG_ATOMIC_TOT, AKG_TENSOR_OF_TENSOR, AKG_TENSOR_NOT_PROMOTE, - AKG_INNER_TENSOR, AKG_REMOVE_SELF_DEPENDENCE}; + AKG_INNER_TENSOR}; + +const std::unordered_set AkgSupportedCsrOp = {AKG_REMOVE_SELF_DEPENDENCE, CSR_AVG_ROW, CSR_FEATURE_LENGTH}; const std::vector ConvATTRList = {ATTR_CONV_FEATURE_W, ATTR_CONV_KERNEL_H, ATTR_CONV_KERNEL_W, ATTR_CONV_STRIDE_H, ATTR_CONV_STRIDE_W, ATTR_CONV_DILATION_H, diff --git a/src/poly/schedule_analysis/band_node_analysis.cc b/src/poly/schedule_analysis/band_node_analysis.cc index 8868bdd7..f6cabf24 100644 --- a/src/poly/schedule_analysis/band_node_analysis.cc +++ b/src/poly/schedule_analysis/band_node_analysis.cc @@ -177,7 +177,13 @@ class OperatorInfoCollector { ReduceDirection reduce_direction = ReduceDirection::UNKNOWN; if (scop_info_.analysis_result_.GetCsr()) { - reduce_direction = ReduceDirection::X; + // If CSR has feature dimension, threadIdx.x should map to the feature + // dimension instead, to keep continuous memory load/save + if (scop_info_.analysis_result_.GetCsrFeatLen() > 1) { + reduce_direction = ReduceDirection::Y; + } else { + reduce_direction = ReduceDirection::X; + } } else { PostOrderVisit(op->value, [&reduce_direction, &reduce_attrs, op](const NodeRef &node) -> void { if (reduce_direction == ReduceDirection::Y) { diff --git a/src/poly/scop_info.h b/src/poly/scop_info.h index 8bb62cb6..c4b10705 100644 --- a/src/poly/scop_info.h +++ b/src/poly/scop_info.h @@ -1238,6 +1238,9 @@ class AnalysisResult { void SetCsrAvgRow(int csr_avg_row) { csr_avg_row_ = csr_avg_row; } int GetCsrAvgRow() { return csr_avg_row_; } + void SetCsrFeatLen(int csr_feat_len) { csr_feat_len_ = csr_feat_len; } + int GetCsrFeatLen() { return csr_feat_len_; } + void ResetOuterBandNode() { outer_band_nodes_.clear(); } void ResetActivateBufferFootprints() { active_buffer_footprints_.clear(); } void ResetBufferDefInfos() { buffer_def_infos_.clear(); } @@ -1323,9 +1326,12 @@ class AnalysisResult { std::unordered_set tensors_not_promote_; std::unordered_set inner_tensor_; bool is_tensor_of_tensor_{false}; + + // csr bool is_csr_{false}; bool remove_self_dependence_{false}; int csr_avg_row_{0}; + int csr_feat_len_{1}; RestartPassName restart_pass_name_{RestartPassName::NOT_RESTART}; std::unordered_map pass_schedule_map_; diff --git a/src/poly/scop_make_schedule_tree.cc b/src/poly/scop_make_schedule_tree.cc index 5f0b2fa9..b51819f4 100644 --- a/src/poly/scop_make_schedule_tree.cc +++ b/src/poly/scop_make_schedule_tree.cc @@ -447,8 +447,20 @@ class ScopMakeScheduleTree final : protected IRVisitor { } else if (op->attr_key == AKG_INNER_TENSOR) { CHECK(op->value.as()); scop_info_.analysis_result_.RecordInnerTensor(op->value.as()->value); - } else if (op->attr_key == AKG_REMOVE_SELF_DEPENDENCE) { + } + } + + void SetCsrInfo(const AttrStmt *op) { + if (op->attr_key == AKG_REMOVE_SELF_DEPENDENCE) { scop_info_.analysis_result_.SetRemoveSelfDependence(true); + } else if (op->attr_key == CSR_FEATURE_LENGTH) { + auto csr_feature_length = op->value.as(); + CHECK(csr_feature_length); + scop_info_.analysis_result_.SetCsrFeatLen(csr_feature_length->value); + } else if (op->attr_key == CSR_AVG_ROW) { + auto csr_avg_row = op->value.as(); + CHECK(csr_avg_row); + scop_info_.analysis_result_.SetCsrAvgRow(csr_avg_row->value); } } @@ -479,10 +491,8 @@ class ScopMakeScheduleTree final : protected IRVisitor { Op_buffer_bind_scope(op); } else if (op->attr_key == ATTR_IM2COL_KEY) { scop_info_.analysis_result_.RecordAttrStmt(op); - } else if (op->attr_key == CSR_AVG_ROW) { - auto csr_avg_row = op->value.as(); - CHECK(csr_avg_row); - scop_info_.analysis_result_.SetCsrAvgRow(csr_avg_row->value); + } else if (AkgSupportedCsrOp.count(op->attr_key) != 0) { + SetCsrInfo(op); } sch = MakeScheduleTreeHelper(op->body, scop_info_, set, outer, macro_stmt); diff --git a/src/poly/tiling/tiling_analyzer.h b/src/poly/tiling/tiling_analyzer.h index 89bab23e..ec1f475a 100755 --- a/src/poly/tiling/tiling_analyzer.h +++ b/src/poly/tiling/tiling_analyzer.h @@ -62,6 +62,9 @@ constexpr auto TOT_BEST_NUM_PER_BLOCK = 1024; constexpr auto OUTERMOST_AXIS = 0; constexpr auto CPU_CSR_TILING_FACTOR = 1; constexpr auto CPU_CSR_PARALLEL_CUTOFF = 4096; +constexpr auto GPU_CSR_FUSION_AXES_SIZE = 2; +constexpr auto GPU_CSR_BEST_NUM_NODES_PER_BLOCK = 8; +constexpr auto GPU_CSR_NO_TILE = 1; // Controlled by custom tiling. constexpr auto ALLOCATION_PERCENTAGE = 0.5; // reserved for double buffer in default diff --git a/src/poly/tiling/tiling_strategy_manager_gpu.cc b/src/poly/tiling/tiling_strategy_manager_gpu.cc index 64a53159..c8ebc92a 100644 --- a/src/poly/tiling/tiling_strategy_manager_gpu.cc +++ b/src/poly/tiling/tiling_strategy_manager_gpu.cc @@ -356,7 +356,8 @@ void ReduceStrategy::AddGpuConstraint() { if (auto ext = axis->range_extent.as()) { reduce_length_ *= ext->value; } else if (analyzer_->scop_info_.analysis_result_.IsCsrDynamicExtent(axis->range_extent)) { - reduce_length_ = analyzer_->scop_info_.user_config_.GetCsrThreadNum(); + int rest_thread = total_available_thread_ / analyzer_->scop_info_.analysis_result_.GetCsrFeatLen(); + reduce_length_ = std::min(analyzer_->scop_info_.user_config_.GetCsrThreadNum(), rest_thread); } if (std::count(reduce_axes_.begin(), reduce_axes_.end(), axis)) { return; @@ -1690,9 +1691,8 @@ void GpuStrategy::InjectiveSpeedup() { auto parallel_size = GetProposalParallelSize(problem_size); auto proposal_blocks = parallel_size.first; auto proposal_threads = parallel_size.second; - auto proposal_elem_per_thread = coaleasced_size < warp_sizes_ ? 1 - : total_blocks < proposal_blocks * 8 ? min_elem_for_io_bound_ - : 8; + auto proposal_elem_per_thread = + coaleasced_size < warp_sizes_ ? 1 : total_blocks < proposal_blocks * 8 ? min_elem_for_io_bound_ : 8; proposal_elem_per_thread = proposal_elem_per_thread / SafeDivisor(curr_elem_size); CHECK(proposal_threads != 0 && total_blocks != 0); int64_t shrinked_threads = @@ -2277,19 +2277,47 @@ void CsrStrategy::AddGpuConstraint() { } axes.push_back(axis); }); + auto available_threads = total_available_thread_; + int csr_thread_num = -1; + auto feat_len = analyzer_->scop_info_.analysis_result_.GetCsrFeatLen(); + if (feat_len > 1) { + // CSR schedule with feature dimension (csr.values > 1d), axis has already been + // fused to outer axis (static boundary), and inner axis (dynamic boundary). + // Feature dimension will be mapped first + CHECK(axes.size() == GPU_CSR_FUSION_AXES_SIZE); + auto outer_axis = axes[OUTERMOST_AXIS]; + auto inner_axis = axes[OUTERMOST_AXIS + 1]; + bool use_reduce_lib = analyzer_->scop_info_.analysis_result_.GetUseGpuReduceLib(); + csr_thread_num = use_reduce_lib ? analyzer_->scop_info_.user_config_.GetCsrThreadNum() : GPU_CSR_NO_TILE; + // For outer axis + CHECK(!analyzer_->scop_info_.analysis_result_.IsCsrDynamicExtent(outer_axis->range_extent)); + int max_nodes_per_block = available_threads / feat_len / csr_thread_num; + auto nodes_per_block = std::min(max_nodes_per_block, GPU_CSR_BEST_NUM_NODES_PER_BLOCK); + nodes_per_block = SafeDivisor(nodes_per_block); + outer_axis->block_constraints.map_extent_ = + std::ceil(static_cast(outer_axis->extent_val) / feat_len / nodes_per_block); + outer_axis->thread_constraints.map_extent_ = feat_len * nodes_per_block; + // For inner axis + CHECK(analyzer_->scop_info_.analysis_result_.IsCsrDynamicExtent(inner_axis->range_extent)); + inner_axis->block_constraints.map_extent_ = GPU_CSR_NO_TILE; + inner_axis->thread_constraints.map_extent_ = csr_thread_num; + inner_axis->c1_constraints.tile_min_ = GPU_CSR_NO_TILE; + inner_axis->c1_constraints.tile_extent_ = csr_thread_num; + analyzer_->scop_info_.user_config_.SetCsrThreadNum(csr_thread_num); + return; + } std::sort(axes.begin(), axes.end(), [](TileAxis *a, TileAxis *b) { if (a->dim_axis == b->dim_axis) { return a->index < b->index; } return a->dim_axis > b->dim_axis; }); - auto available_threads = total_available_thread_; - int csr_thread_num = -1; for (size_t i = 0; i < axes.size(); ++i) { + // CSR schedule without dimension (csr.values = 1d) auto axis = axes[i]; if (analyzer_->scop_info_.analysis_result_.IsCsrDynamicExtent(axis->range_extent)) { if (csr_thread_num != -1) { - axis->block_constraints.map_extent_ = 1; + axis->block_constraints.map_extent_ = GPU_CSR_NO_TILE; axis->thread_constraints.map_extent_ = csr_thread_num; axis->c1_constraints.tile_extent_ = csr_thread_num; } else { @@ -2307,14 +2335,14 @@ void CsrStrategy::AddGpuConstraint() { csr_thread_num = static_cast(std::exp2(warp_base)) * WARP_SIZE; } csr_thread_num = std::min(static_cast(csr_thread_num), available_threads); - axis->block_constraints.map_extent_ = 1; + axis->block_constraints.map_extent_ = GPU_CSR_NO_TILE; axis->thread_constraints.map_extent_ = csr_thread_num; axis->c1_constraints.tile_extent_ = csr_thread_num; analyzer_->scop_info_.user_config_.SetCsrThreadNum(csr_thread_num); available_threads /= SafeDivisor(csr_thread_num); } } else if (axis->dim_axis == 0) { - axis->thread_constraints.map_extent_ = 1; + axis->thread_constraints.map_extent_ = GPU_CSR_NO_TILE; } else { available_threads /= SafeDivisor(std::min(axis->extent_val, available_threads)); } -- Gitee