From da79d7e418edf4d3be598c4cdbb7a93af970284f Mon Sep 17 00:00:00 2001 From: hujiahui8 Date: Sat, 30 Jul 2022 15:30:06 +0800 Subject: [PATCH] adapt transpose operator --- src/api/api_pass.cc | 1 + src/codegen/lower_llvm.cc | 1 + src/include/ir_pass.h | 2 + src/pass/adjust_parallel_loop.cc | 4 +- src/pass/matrix_transpose.cc | 105 +++++++++++ src/pass/reconstruct_layout.cc | 2 + src/poly/create_cluster.cc | 15 +- src/poly/create_cluster.h | 1 + .../schedule_analysis/band_node_analysis.cc | 2 +- src/poly/schedule_pass/tile_outer_band.cc | 29 +++- src/poly/schedule_pass/tile_outer_band.h | 1 + .../schedule_pass_cpu/cpu_memory_manager.cc | 36 +++- .../register_memory_manager.cc | 13 +- .../shared_memory_manager.cc | 6 +- src/poly/schedule_tree_util.cc | 48 ++++-- src/poly/schedule_tree_util.h | 9 +- src/poly/scop_info.h | 7 +- src/poly/tiling/space_analyzer.cc | 38 +++- src/poly/tiling/space_analyzer.h | 3 +- src/poly/tiling/tiling_analyzer.h | 3 + src/poly/tiling/tiling_strategy_manager.h | 12 +- .../tiling/tiling_strategy_manager_cpu.cc | 163 +++++++++++++----- tests/st/ops/test_transpose.py | 1 + .../src/codegen/llvm/codegen_llvm.cc | 79 +++++---- .../src/codegen/llvm/codegen_llvm.h | 9 +- 25 files changed, 468 insertions(+), 122 deletions(-) create mode 100644 src/pass/matrix_transpose.cc diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 80a6d529..13ae4f54 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -144,6 +144,7 @@ REGISTER_PASS(StrideKernelOp); REGISTER_PASS(UnifyLoopVars); REGISTER_PASS(TileCoverCorrect); REGISTER_PASS(ReconstructLayout); +REGISTER_PASS(MatrixTranspose); REGISTER_PASS(AdjustParallelLoop); REGISTER_PASS(ReductionFactor); } // namespace ir diff --git a/src/codegen/lower_llvm.cc b/src/codegen/lower_llvm.cc index a3efdbd1..5ed774e0 100644 --- a/src/codegen/lower_llvm.cc +++ b/src/codegen/lower_llvm.cc @@ -82,6 +82,7 @@ StageResult LLVMLowerBeforeFlattern(Stmt &stmt, LowerData &data) { } stmt = NEXT_PASS(RealizeCompress, 
stmt); stmt = NEXT_PASS(ReconstructLayout, stmt); + stmt = NEXT_PASS(MatrixTranspose, stmt); stmt = NEXT_PASS(AdjustParallelLoop, stmt); stmt = NEXT_PASS(ReductionFactor, stmt, data->binds_0); } diff --git a/src/include/ir_pass.h b/src/include/ir_pass.h index 48d4dc85..737d49fa 100644 --- a/src/include/ir_pass.h +++ b/src/include/ir_pass.h @@ -92,6 +92,8 @@ Stmt InjectTransferBufferScope(Stmt stmt); */ Stmt ReconstructLayout(const Stmt &stmt); +Stmt MatrixTranspose(const Stmt &stmt); + Stmt AdjustParallelLoop(const Stmt &stmt); Stmt ReductionFactor(const Stmt &stmt, const Map &extern_buffer); diff --git a/src/pass/adjust_parallel_loop.cc b/src/pass/adjust_parallel_loop.cc index e1b4b770..e17ed224 100644 --- a/src/pass/adjust_parallel_loop.cc +++ b/src/pass/adjust_parallel_loop.cc @@ -99,16 +99,14 @@ class FuseParallelLoop : public IRMutator { // Calculate extent after merged parallel loop. int init_i = 0; - Expr div_extend = 1; Expr extent_sum = 1; std::unordered_map vmap; for (auto item : value_map_) { if (init_i == 0) { vmap[item.first] = Mod::make(op->loop_var, item.second); - div_extend = Mul::make(div_extend, item.second); ++init_i; } else { - auto tmp_div = Div::make(op->loop_var, div_extend); + auto tmp_div = Div::make(op->loop_var, extent_sum); vmap[item.first] = Mod::make(tmp_div, item.second); } extent_sum = Mul::make(extent_sum, item.second); diff --git a/src/pass/matrix_transpose.cc b/src/pass/matrix_transpose.cc new file mode 100644 index 00000000..749b6a12 --- /dev/null +++ b/src/pass/matrix_transpose.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "ir_pass.h" + +namespace akg { +namespace ir { +static constexpr auto PROMOTE_TRANSPOSE = "promoted_transpose"; +static constexpr auto MATRIX_TRANSPOSE = "MatrixTranspose"; +static constexpr auto INT32 = 32; +static constexpr auto PARAMETER_NUM = 4; + +class MatrixTransposeMutator : public IRMutator { + public: + explicit MatrixTransposeMutator() {} + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == PROMOTE_TRANSPOSE) { + return CallTransposeInterface(op, s); + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Provide *op, const Stmt &s) final { + auto provide = IRMutator::Mutate_(op, s); + provide_ = provide; + return provide; + } + + Stmt Mutate_(const For *op, const Stmt &s) final { + extents_.push_back(op->extent); + return IRMutator::Mutate_(op, s); + } + + private: + Stmt CallTransposeInterface(const AttrStmt *op, const Stmt &s) { + extents_.clear(); + auto stmt = IRMutator::Mutate(op->body); + for (auto extent : extents_) { + auto extent_value = extent.as()->value; + if (extent_value & (extent_value - 1)) { + return s; + } + } + + Array shapes; + std::reverse(extents_.begin(), extents_.end()); + shapes.assign(extents_.begin(), extents_.end()); + extents_.clear(); + + auto provide = provide_.as(); + CHECK(provide->value.as()); + auto pro_value = provide->value.as(); + auto pro_func = provide->func; + + Array indices; + for (auto arg : pro_value->args) { + indices.push_back(make_zero(Int(INT32))); + } + + // Array indices; + Array args; + 
args.push_back(make_zero(Int(INT32))); + args.push_back(make_zero(Int(INT32))); + for (auto shape : shapes) { + args.push_back(shape); + } + CHECK(args.size() == PARAMETER_NUM) << "The number of input parameters of the transpose interface must be 4."; + Expr dst_call = Call::make(pro_value->type, pro_func->func_name(), indices, pro_value->call_type, pro_func, + pro_value->value_index); + Expr src_call = Call::make(pro_value->type, pro_value->name, indices, pro_value->call_type, pro_value->func, + pro_value->value_index); + + Expr dst_addr = Call::make(Handle(), air::ir::intrinsic::tvm_address_of, {dst_call}, Call::PureIntrinsic); + Expr src_addr = Call::make(Handle(), air::ir::intrinsic::tvm_address_of, {src_call}, Call::PureIntrinsic); + args.Set(0, dst_addr); + args.Set(1, src_addr); + return Evaluate::make(Call::make(Handle(), MATRIX_TRANSPOSE, args, Call::Intrinsic)); + } + + private: + std::vector extents_; + Stmt provide_; +}; + +Stmt MatrixTranspose(const Stmt &stmt) { return MatrixTransposeMutator().Mutate(stmt); } +} // namespace ir +} // namespace akg diff --git a/src/pass/reconstruct_layout.cc b/src/pass/reconstruct_layout.cc index 9fe9baa3..dfaf289b 100644 --- a/src/pass/reconstruct_layout.cc +++ b/src/pass/reconstruct_layout.cc @@ -547,12 +547,14 @@ class CPULocalReconstruction : public IRMutator { Array indices; Array args; args.push_back(make_zero(Int(INT32))); + args.push_back(make_zero(Int(INT32))); for (size_t i = 0; i < tensor.ndim(); i++) { indices.push_back(make_zero(Int(INT32))); args.push_back(tensor->shape[i]); } Expr addr = Call::make(Handle(), air::ir::intrinsic::tvm_address_of, {tensor(indices)}, Call::PureIntrinsic); args.Set(0, addr); + args.Set(1, addr); Expr matrix_trans = Call::make(Handle(), MATRIX_TRANSPOSE, args, Call::Intrinsic); auto block = Block::make({read, Evaluate::make(matrix_trans), write}); for (auto j : tensor->shape) { diff --git a/src/poly/create_cluster.cc b/src/poly/create_cluster.cc index f0a2c906..abcdae00 100644 
--- a/src/poly/create_cluster.cc +++ b/src/poly/create_cluster.cc @@ -629,7 +629,9 @@ bool CpuCreateCluster::CheckPromotion(const isl::schedule_node &current_node, co const TensorFootprintCluster &cluster, const std::pair &tensor_info) { auto template_type = scop_info_.analysis_result_.GetOuterBandNode(band_index_)->template_type; - return template_type == Template::MATMUL || template_type == Template::CONV; + bool need_promotion = + template_type == Template::MATMUL || template_type == Template::CONV || template_type == Template::TRANSPOSE_OP; + return need_promotion; } void CpuCreateCluster::CreateClusterListForGemm(const isl::schedule_node &node, @@ -674,6 +676,17 @@ void CpuCreateCluster::CreateClusterListForConv(const isl::schedule_node &node, RecordPromotedTensorInfo(node, mark_name, current_tensors); } } + +void CpuCreateCluster::CreateClusterListForTranspose(const isl::schedule_node &node, + const std::unordered_set &mark_names) { + auto configed_tensors = scop_info_.user_config_.GetRegisterTensors(); + // Initialize the promoted types of all tensors. + RecordInitPromotedTensorType(configed_tensors); + + for (auto mark_name : mark_names) { + RecordPromotedTensorInfo(node, mark_name, all_tensors_); + } +} } // namespace poly } // namespace ir } // namespace akg \ No newline at end of file diff --git a/src/poly/create_cluster.h b/src/poly/create_cluster.h index bb2e1f62..b1c6fc0f 100644 --- a/src/poly/create_cluster.h +++ b/src/poly/create_cluster.h @@ -130,6 +130,7 @@ class CpuCreateCluster : public CreateCluster { // Promoted tensors needed to create different types of operators. void CreateClusterListForGemm(const isl::schedule_node &orig_node, const std::unordered_set &mark_names); void CreateClusterListForConv(const isl::schedule_node &node, const std::unordered_set &mark_names); + void CreateClusterListForTranspose(const isl::schedule_node &node, const std::unordered_set &mark_names); private: // Common functions required by shared, register in gpu and cpu. 
diff --git a/src/poly/schedule_analysis/band_node_analysis.cc b/src/poly/schedule_analysis/band_node_analysis.cc index 82f377fe..153806a2 100644 --- a/src/poly/schedule_analysis/band_node_analysis.cc +++ b/src/poly/schedule_analysis/band_node_analysis.cc @@ -643,12 +643,12 @@ void AnalyzeBandNode::Run() { AnalyzeOuterBandAccessInfo(bn); if (target_ == TARGET_CPU || target_ == TARGET_CUDA) { AnalyzeAxisPosition(bn); + bn->use_register_memory = scop_info_.user_config_.GetUseRegisterMemory(); } if (target_ == TARGET_CUDA) { CheckVectorization(bn); bn->use_shared_memory = scop_info_.user_config_.GetUseSharedMemory(); - bn->use_register_memory = scop_info_.user_config_.GetUseRegisterMemory(); } } ShowBandInfo(); diff --git a/src/poly/schedule_pass/tile_outer_band.cc b/src/poly/schedule_pass/tile_outer_band.cc index 8a5f111a..d607f8a0 100644 --- a/src/poly/schedule_pass/tile_outer_band.cc +++ b/src/poly/schedule_pass/tile_outer_band.cc @@ -1215,6 +1215,10 @@ isl::schedule_node TileOuterBand::MarkOuterPermutableCpu(isl::schedule_node node return TileConvForCpu(node); } + if (template_type == Template::TRANSPOSE_OP && current_outer_bn->enable_transpose) { + return TileTransposeForCpu(node); + } + return TileElementWiseForCpu(node); } @@ -1315,7 +1319,7 @@ isl::schedule_node TileOuterBand::TileCsrForCpu(const isl::schedule_node &orig_n auto band_node = node.as(); node = band_node.n_member() <= 1 ? 
band_node : band_node.split(band_node.n_member() - 1); node = TileAccordingToTileType(node, TileType::C1); - node = InsertMarkerForLoop(node, FOR_PARALLEL); + node = InsertMarkerForLoop(node, FOR_PARALLEL).child(0); auto template_type = scop_info_.analysis_result_.GetOuterBandNode(cur_band_index_)->template_type; if (template_type == Template::REDUCTION) { node = SplitReduceStatements(node.child(0)); @@ -1535,6 +1539,29 @@ isl::schedule_node TileOuterBand::InsertMarkerForReduceY(const isl::schedule_nod return node; } +isl::schedule_node TileOuterBand::TileTransposeForCpu(const isl::schedule_node &orig_node) { + if (!orig_node.isa()) { + return orig_node; + } + + auto node = orig_node; + size_t start_depth = node.get_tree_depth(); + + // first tiling: parallel + node = TileAccordingToTileType(node, TileType::C1); + node = TileAccordingToTileType(node, TileType::C0); + + auto band_node = node.as(); + node = band_node.split(band_node.n_member() - 1).child(0); + node = node.insert_mark(FOR_VECTORIZED).parent(); + node = InsertMultiMarker(node, FOR_UNROLLED, true); + + node = node.insert_mark(PROMOTE_TRANSPOSE).parent(); + node = InsertMultiMarker(node.parent(), FOR_PARALLEL); + node = node.ancestor(node.get_tree_depth() - start_depth); + return node; +} + isl::schedule_node TileOuterBand::TileElementWiseForCpu(const isl::schedule_node &orig_node, const bool is_all_reduce) { if (!orig_node.isa()) { return orig_node; diff --git a/src/poly/schedule_pass/tile_outer_band.h b/src/poly/schedule_pass/tile_outer_band.h index 1999ccc0..91acb67f 100644 --- a/src/poly/schedule_pass/tile_outer_band.h +++ b/src/poly/schedule_pass/tile_outer_band.h @@ -112,6 +112,7 @@ class TileOuterBand : public SchedulePass { isl::schedule_node TileGemmOperatorForCpu(const isl::schedule_node &orig_node); isl::schedule_node TileElementWiseForCpu(const isl::schedule_node &orig_node, const bool is_all_reduce = false); isl::schedule_node TileConvForCpu(const isl::schedule_node &orig_node); + 
isl::schedule_node TileTransposeForCpu(const isl::schedule_node &orig_node); bool IsContainReduceStatement(const isl::schedule_node &orig_node); isl::schedule_node SplitReduceStatements(const isl::schedule_node &orig_node); diff --git a/src/poly/schedule_pass_cpu/cpu_memory_manager.cc b/src/poly/schedule_pass_cpu/cpu_memory_manager.cc index 4e14cd81..4d4509e9 100644 --- a/src/poly/schedule_pass_cpu/cpu_memory_manager.cc +++ b/src/poly/schedule_pass_cpu/cpu_memory_manager.cc @@ -28,8 +28,7 @@ namespace ir { namespace poly { isl::schedule CpuMemoryManager::Run(isl::schedule sch) { - if (!scop_info_.user_config_.GetUseSharedMemory() || - (!scop_info_.user_config_.GetEnableMatmul() && !scop_info_.user_config_.GetEnableConv2dDirect())) { + if (!scop_info_.user_config_.GetUseRegisterMemory()) { return sch; } @@ -39,7 +38,7 @@ isl::schedule CpuMemoryManager::Run(isl::schedule sch) { isl::schedule_node CpuMemoryManager::HoistCpuMemoryOnMark(const isl::schedule_node &orig_node) { current_outer_bn_ = scop_info_.analysis_result_.GetOuterBandNode(band_index_); - if (!current_outer_bn_->use_shared_memory) { + if (!current_outer_bn_->use_register_memory) { return orig_node; } @@ -56,7 +55,10 @@ isl::schedule_node CpuMemoryManager::HoistCpuMemoryOnMark(const isl::schedule_no return node; } - node = node.del().parent(); + if (current_outer_bn_->template_type != Template::TRANSPOSE_OP) { + node = node.del(); + } + node = node.parent(); return HoistClusters(node).child(0); }; auto node = orig_node; @@ -103,19 +105,39 @@ void CpuMemoryManager::CreateClusterForOperator(const isl::schedule_node &orig_n mark_names_.emplace(PROMOTE_GLOBAL_TO_REGISTER_AB); mark_names_.emplace(PROMOTE_GLOBAL_TO_REGISTER_C); create_cluster.CreateClusterListForConv(orig_node, mark_names_); + } else if (current_outer_bn_->template_type == Template::TRANSPOSE_OP) { + // transpose operator + mark_names_.emplace(PROMOTE_TRANSPOSE); + create_cluster.CreateClusterListForTranspose(orig_node, mark_names_); } } 
isl::schedule_node CpuMemoryManager::InsertMarkerForEmit(const isl::schedule_node &orig_node) { isl::schedule_node node = orig_node; + std::unordered_map filter_marker_map; if (current_outer_bn_->template_type == Template::MATMUL) { // matmul operator node = InsertMarkerForGemm(node); } else if (current_outer_bn_->template_type == Template::CONV) { // conv operator - node = InsertMarkerForPromotedNode(node, WRITE_ID_NAME, FOR_VECTORIZED, -1); - node = InsertMarkerForPromotedNode(node, WRITE_ID_NAME, FOR_UNROLLED, -1); - node = InsertMarkerForPromotedNode(node, READ_ID_NAME, FOR_VECTORIZED, -1); + PromoteMarkerInfo write_info; + write_info.markers = {FOR_VECTORIZED, FOR_UNROLLED}; + write_info.axis_pos = -1; + filter_marker_map[WRITE_ID_NAME] = write_info; + + PromoteMarkerInfo read_info; + read_info.markers = {FOR_VECTORIZED}; + read_info.axis_pos = -1; + filter_marker_map[READ_ID_NAME] = read_info; + node = InsertMarkerForPromotedNode(node, filter_marker_map); + } else if (current_outer_bn_->template_type == Template::TRANSPOSE_OP) { + // transpose operator + PromoteMarkerInfo read_write_info; + read_write_info.markers = {FOR_VECTORIZED, FOR_UNROLLED}; + read_write_info.axis_pos = -1; + filter_marker_map[WRITE_ID_NAME] = read_write_info; + filter_marker_map[READ_ID_NAME] = read_write_info; + node = InsertMarkerForPromotedNode(node, filter_marker_map); } return node; } diff --git a/src/poly/schedule_pass_gpu/register_memory_manager.cc b/src/poly/schedule_pass_gpu/register_memory_manager.cc index 41e9eae7..7e3fc965 100644 --- a/src/poly/schedule_pass_gpu/register_memory_manager.cc +++ b/src/poly/schedule_pass_gpu/register_memory_manager.cc @@ -136,6 +136,7 @@ void RegisterMemoryManager::CreateClusterForOperator(const isl::schedule_node &n isl::schedule_node RegisterMemoryManager::InsertMarkerForEmit(const isl::schedule_node &orig_node) { auto node = orig_node; + std::unordered_map filter_marker_map; if (scop_info_.user_config_.GetEnableMatmul()) { if 
(scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { node = TileTensorAccordingInterfaceValue(orig_node); @@ -144,10 +145,16 @@ isl::schedule_node RegisterMemoryManager::InsertMarkerForEmit(const isl::schedul if (write_name_ == SHARED_WRITE_ID_NAME) { marker_name = PROMOTE_REGISTER_TO_SHARED; } - node = InsertMarkerForPromotedNode(node, write_name_, marker_name); + PromoteMarkerInfo write_info; + write_info.markers = {marker_name}; + filter_marker_map[write_name_] = write_info; + node = InsertMarkerForPromotedNode(node, filter_marker_map); } else if (current_outer_bn_->enable_vectorization) { - node = InsertMarkerForPromotedNode(node, GML_READ_ID_NAME, FOR_VECTORIZED); - node = InsertMarkerForPromotedNode(node, GML_WRITE_ID_NAME, FOR_VECTORIZED); + PromoteMarkerInfo read_write_info; + read_write_info.markers = {FOR_VECTORIZED}; + filter_marker_map[GML_WRITE_ID_NAME] = read_write_info; + filter_marker_map[GML_READ_ID_NAME] = read_write_info; + node = InsertMarkerForPromotedNode(node, filter_marker_map); } return node; } diff --git a/src/poly/schedule_pass_gpu/shared_memory_manager.cc b/src/poly/schedule_pass_gpu/shared_memory_manager.cc index 8d7f226e..57c0e414 100644 --- a/src/poly/schedule_pass_gpu/shared_memory_manager.cc +++ b/src/poly/schedule_pass_gpu/shared_memory_manager.cc @@ -151,7 +151,11 @@ isl::schedule_node SharedMemoryManager::InsertMarkerForRegisterPromotion(const i if (mark_names_.find(PROMOTE_GLOBAL_TO_SHARED_C) != mark_names_.end()) { hoist_register_node = orig_node.child(0).insert_mark(PROMOTE_SHARED_TO_REGISTER_C); } - hoist_register_node = InsertMarkerForPromotedNode(hoist_register_node, WRITE_ID_NAME, PROMOTE_SHARED_TO_GLOBAL); + std::unordered_map filter_marker_map; + PromoteMarkerInfo write_info; + write_info.markers = {PROMOTE_SHARED_TO_GLOBAL}; + filter_marker_map[WRITE_ID_NAME] = write_info; + hoist_register_node = InsertMarkerForPromotedNode(hoist_register_node, filter_marker_map); return ReplaceMarker(hoist_register_node, 
PROMOTE_GLOBAL_TO_SHARED_AB, SHARED_MEM_PROMOTED_COMPLETE); } diff --git a/src/poly/schedule_tree_util.cc b/src/poly/schedule_tree_util.cc index e4b9e4d1..8a99a04f 100644 --- a/src/poly/schedule_tree_util.cc +++ b/src/poly/schedule_tree_util.cc @@ -224,22 +224,21 @@ std::vector BandsSplitAfterDepth(const std::vector 0) << "The position of the inserted axis must be greater than 0."; - auto GetPromotedFilter = [filter_name, marker_name, aixs_pos](isl::schedule_node node) -> isl::schedule_node { +isl::schedule_node InsertMarkerForPromotedNode( + const isl::schedule_node &orig_node, const std::unordered_map &filter_marker_map) { + auto GetPromotedFilter = [filter_marker_map](isl::schedule_node node) -> isl::schedule_node { if (!node.isa()) { return node; } isl::union_set uset = node.as().get_filter(); - bool is_gm_filter = false; - uset.foreach_set([&is_gm_filter, filter_name](isl::set s) { - if (s.get_tuple_name() == filter_name) { - is_gm_filter = true; + std::string filter_name = ""; + uset.foreach_set([&filter_name, filter_marker_map](isl::set s) { + std::string set_name = s.get_tuple_name(); + if (filter_marker_map.count(set_name) != 0) { + filter_name = set_name; } }); - if (!is_gm_filter) { + if (filter_name.empty()) { return node; } auto child_node = node.child(0); @@ -252,17 +251,36 @@ isl::schedule_node InsertMarkerForPromotedNode(const isl::schedule_node &orig_no if (n_member == 0) { return node; } - CHECK(std::abs(aixs_pos) <= n_member) + + size_t start_depth = node.get_tree_depth(); + PromoteMarkerInfo marker_info = filter_marker_map.at(filter_name); + int aixs_pos = marker_info.axis_pos; + auto marker_names = marker_info.markers; + + // aixs_pos: Indicates that the marker is inserted before the i-th axis, starting from 1. 
+ CHECK(std::abs(aixs_pos) <= n_member && aixs_pos != 0) << "The position of the inserted axis: " << std::abs(aixs_pos) << " cannot be greater than the total number of axes of the current band node: " << n_member << "."; int current_aixs_pos = aixs_pos - 1; if (aixs_pos < 0) { current_aixs_pos = n_member + aixs_pos; } - bool need_split = current_aixs_pos != 0; - node = need_split ? band_node.split(current_aixs_pos).child(0) : child_node; - node = node.insert_mark(marker_name).parent(); - node = need_split ? node.parent() : node; + current_aixs_pos -= (static_cast(marker_names.size()) - 1); + + node = (current_aixs_pos > 0) ? band_node.split(current_aixs_pos).child(0) : child_node; + for (auto marker_name : marker_names) { + auto cur_node = node.as(); + int band_number = cur_node.n_member(); + if (band_number > 1) { + node = cur_node.split(band_number - 1).child(0); + } + node = node.insert_mark(marker_name).parent(); + + if (band_number == 1) { + break; + } + } + node = node.ancestor(node.get_tree_depth() - start_depth); return node; }; return orig_node.map_descendant_bottom_up(GetPromotedFilter); diff --git a/src/poly/schedule_tree_util.h b/src/poly/schedule_tree_util.h index 9199d93a..ace0fbbc 100644 --- a/src/poly/schedule_tree_util.h +++ b/src/poly/schedule_tree_util.h @@ -25,6 +25,11 @@ namespace akg { namespace ir { namespace poly { +struct PromoteMarkerInfo { + std::vector markers; // Insert markers from back to front + int axis_pos{1}; +}; + isl::union_set CollectDomain(const isl::schedule_node &node); isl::schedule_node MapDescendantTopDown(isl::schedule_node node, @@ -114,8 +119,8 @@ isl::schedule_node UnrollByMarkOptions(isl::schedule_node &node, uint64_t unroll isl::map GetExtensionSpace(const isl::schedule_node &node, const isl::id &id); isl::schedule_node InsertExtensionNodeBeforeOrAfter(const isl::schedule_node &node, const isl::id &id, bool before); -isl::schedule_node InsertMarkerForPromotedNode(const isl::schedule_node &orig_node, const 
std::string &filter_name, - const std::string &marker_name, const int aixs_pos = 1); +isl::schedule_node InsertMarkerForPromotedNode( + const isl::schedule_node &orig_node, const std::unordered_map &filter_marker_map); std::string GetMarkerName(const isl::schedule_node &node, std::string find_name); isl::union_set GetMappingFilterInfo(const isl::schedule_node node, MappingCfg *mapping_cfg, diff --git a/src/poly/scop_info.h b/src/poly/scop_info.h index cf597a26..0d0ff2de 100644 --- a/src/poly/scop_info.h +++ b/src/poly/scop_info.h @@ -305,6 +305,7 @@ class UserConfig { ParseBoolAttr(attrs, "enable_vectorization", &enable_vectorization_); ParseBoolAttr(attrs, "pragma_enable_matmul", &enable_matmul_); ParseIntAttr(attrs, "vector_length", &vector_length_); + ParseBoolAttr(attrs, "use_register_memory", &use_register_memory_); if (GetTarget() == TARGET_CUDA) { ParseStringAttr(attrs, "device_type", &device_type_); @@ -317,7 +318,6 @@ class UserConfig { ParseBoolAttr(attrs, "enable_tensor_core_use_poly", &enable_tensor_core_use_poly_); ParseBoolAttr(attrs, "enable_akg_reduce_lib", &enable_akg_reduce_lib_); ParseBoolAttr(attrs, "has_tot_ops", &has_tot_ops_); - ParseBoolAttr(attrs, "use_register_memory", &use_register_memory_); ParseBoolAttr(attrs, "use_shared_memory", &use_shared_memory_); ParseBoolAttr(attrs, "enable_bank_conflict_opt", &enable_bank_conflict_); ParseBoolAttr(attrs, "enable_one_dim_thread", &enable_one_dim_thread_); @@ -582,8 +582,8 @@ class UserConfig { bool GetUseRegisterMemory() const { return use_register_memory_; } bool GetUseSharedMemory() const { return use_shared_memory_; } - void SetGetUseSharedMemory(bool use_shared_memory) { use_shared_memory_ = use_shared_memory; } - void SetGetUseRegisterMemory(bool use_register_memory) { use_register_memory_ = use_register_memory; } + void SetUseSharedMemory(bool use_shared_memory) { use_shared_memory_ = use_shared_memory; } + void SetUseRegisterMemory(bool use_register_memory) { use_register_memory_ = 
use_register_memory; } std::unordered_set GetSplitTensors(const std::string &tensor_name); void RecordSharedTensors(const std::string &tensor_name) { shared_tensors_ += (SPACE_PATTERN + tensor_name); } @@ -1009,6 +1009,7 @@ class AnalysisResult { isl::union_map reads; isl::union_map writes; std::unordered_map mnk_pos; + bool enable_transpose{false}; // user config bool use_shared_memory{true}; bool use_register_memory{true}; diff --git a/src/poly/tiling/space_analyzer.cc b/src/poly/tiling/space_analyzer.cc index f5b447f1..65364006 100644 --- a/src/poly/tiling/space_analyzer.cc +++ b/src/poly/tiling/space_analyzer.cc @@ -98,6 +98,7 @@ void SpaceAnalyzer::MarkCaredType(ProvideEntry pe) { } MarkInnerMostAxis({target}, AT_BROADCAST_INNERMOST_AXIS); } else if (ct == AT_TRANSPOSE) { + MarkTransposeAxes(pe); std::vector tensors = {pe.dst}; for (auto src : pe.src) { tensors.emplace_back(src); @@ -292,7 +293,42 @@ void SpaceAnalyzer::MarkBroadcastAxes(const ProvideEntry &pe) { } for (auto axis : broadcasted) { - axis->MarkWithAttr(AttrInfo{AT_OP_TYPE, AT_BROADCAST}); + axis->MarkWithAttr(AttrInfo{AT_AXIS_TYPE, AT_BROADCAST_AXIS}); + } +} + +void SpaceAnalyzer::MarkTransposeAxes(const ProvideEntry &pe) { + std::unordered_set dst_transpose; + for (auto dst_it : pe.dst.loops) { + for (auto l : dst_it.second) { + dst_transpose.insert(l); + } + } + + std::unordered_set src_transpose; + for (auto src : pe.src) { + if (src.loops.size() == 0 || src.loops.size() != pe.dst.loops.size()) { + continue; + } + for (auto src_it : src.loops) { + for (auto l : src_it.second) { + src_transpose.insert(l); + } + } + } + + if (dst_transpose.size() != src_transpose.size()) { + return; + } + + for (auto dst_it = dst_transpose.begin(), src_it = src_transpose.begin(); dst_it != dst_transpose.end(); + dst_it++, src_it++) { + auto dst_for = *dst_it; + auto src_for = *src_it; + if (dst_for->loop_var->name_hint != src_for->loop_var->name_hint) { + auto axis = analyzer_->Axis(dst_for); + 
axis->MarkWithAttr(AttrInfo{AT_AXIS_TYPE, AT_TRANSPOSE_AXIS}); + } } } diff --git a/src/poly/tiling/space_analyzer.h b/src/poly/tiling/space_analyzer.h index 65350571..cab3352c 100644 --- a/src/poly/tiling/space_analyzer.h +++ b/src/poly/tiling/space_analyzer.h @@ -94,6 +94,7 @@ class SpaceAnalyzer { Band loops); void MarkBroadcastAxes(const ProvideEntry &pe); + void MarkTransposeAxes(const ProvideEntry &pe); std::vector FindModConstraint(const Expr &arg, std::vector constraints); const For *GetBufferInnerAxis(const TensorEntry &t, int offset = 1); void SetAttrForAxis(int tile_band, int tile_axis, const std::string &attr_key, const std::string &attr_value); @@ -103,7 +104,7 @@ class SpaceAnalyzer { const std::string &attr_key, const std::string &attr_value, TileAxis *target); std::string ParseAllTypeExpr(const Expr constraint); std::string ParseArrayExpr(const Array constraint); - + Template GetOpTemplate(const Node *op); }; } // namespace poly diff --git a/src/poly/tiling/tiling_analyzer.h b/src/poly/tiling/tiling_analyzer.h index d79fe70a..479e2493 100755 --- a/src/poly/tiling/tiling_analyzer.h +++ b/src/poly/tiling/tiling_analyzer.h @@ -140,6 +140,7 @@ constexpr auto AT_DMA = "DMA"; constexpr auto AT_DMA2 = "DMA2"; constexpr auto AT_DMA3 = "DMA3"; constexpr auto AT_OP_TYPE = "OP_TYPE"; +constexpr auto AT_AXIS_TYPE = "AXIS_TYPE"; constexpr auto AT_REDUCE_DST_LAST = "REDUCE_DST_LAST"; constexpr auto AT_REDUCE_SRC_LAST = "REDUCE_SRC_LAST"; @@ -147,6 +148,8 @@ constexpr auto AT_TRANSPOSE_INNERMOST_AXIS = "TRANSPOSE_INNERMOST_AXIS"; constexpr auto AT_BROADCAST_INNERMOST_AXIS = "BROADCAST_INNERMOST_AXIS"; constexpr auto AT_REDUCE_FLOW = "REDUCE_FLOW"; constexpr auto AT_REDUCE_AXIS = "REDUCE_AXIS"; +constexpr auto AT_TRANSPOSE_AXIS = "TRANSPOSE_AXIS"; +constexpr auto AT_BROADCAST_AXIS = "BROADCAST_AXIS"; constexpr auto AT_COUNT_AXIS = "COUNT_AXIS"; constexpr auto AT_POST_FUSION_REDUCE_TENSOR = "POST_FUSION_REDUCE_TENSOR"; constexpr auto AT_CONV = "CONV"; diff --git 
a/src/poly/tiling/tiling_strategy_manager.h b/src/poly/tiling/tiling_strategy_manager.h index 25f77538..a1151f37 100644 --- a/src/poly/tiling/tiling_strategy_manager.h +++ b/src/poly/tiling/tiling_strategy_manager.h @@ -579,15 +579,17 @@ class CpuStrategy : public TilingStrategy { private: void BuildAxesQueue(); void RecordTileValue(); - void GenConv2dTileByAxis(const int index, int64_t &p, int64_t tile1, int64_t tile0); - void SetConv2dTileValue(int index); - void SetMatMulTileValue(int index); - bool SetReduceYTileValue(int index); + void GenConv2dTileByAxis(int64_t &p, int64_t tile1, int64_t tile0); + void SetConv2dTileValue(); + void SetMatMulTileValue(); + bool SetReduceYTileValue(); + void SetCsrTileValue(); + void SetElementWiseTileValue(); + void SetTransposeTileValue(); void SetMultiLevelTileValue(); void SetUnrollTileValue(TileAxis *axis, const int64_t axis_size, int64_t &tile_left); void SetParallelTileValue(TileAxis *axis, const int64_t axis_size, const int64_t data_size, bool is_unroll_axis = false, int64_t tile_left = 1); - void SetCsrTileValue(); std::vector>> pending_axes_; int min_exec_num_per_thread_{MIN_EXEC_NUM_PER_THREAD}; int best_parallel_num_{BEST_PARALLEL_NUM}; diff --git a/src/poly/tiling/tiling_strategy_manager_cpu.cc b/src/poly/tiling/tiling_strategy_manager_cpu.cc index 95288c6d..b238b013 100644 --- a/src/poly/tiling/tiling_strategy_manager_cpu.cc +++ b/src/poly/tiling/tiling_strategy_manager_cpu.cc @@ -119,32 +119,32 @@ void CpuStrategy::SetParallelTileValue(TileAxis *axis, const int64_t axis_size, axis->TileRestrainToSingleValue(Expr(c0_tile_value), TileLevel::CACHE0); } -void CpuStrategy::GenConv2dTileByAxis(const int index, int64_t &p, int64_t tile1, int64_t tile0) { +void CpuStrategy::GenConv2dTileByAxis(int64_t &p, int64_t tile1, int64_t tile0) { TileAxis *axis = nullptr; int64_t shape; - std::tie(axis, shape) = pending_axes_[index][p]; + std::tie(axis, shape) = pending_axes_[current_band_][p]; CHECK(axis != nullptr); 
axis->TileRestrainToSingleValue(Expr((int64_t)tile1), TileLevel::CACHE1); axis->TileRestrainToSingleValue(Expr((int64_t)tile0), TileLevel::CACHE0); p += 1; } -void CpuStrategy::SetConv2dTileValue(int index) { +void CpuStrategy::SetConv2dTileValue() { // format of conv2d tile should be: batch, oc_out, oh, ow, oc_in, ic_out. // all of them can be 1, so we use axes_names to check the exist of each axis. auto axes_names = analyzer_->scop_info_.analysis_result_.GetCpuConvolutionAxes(); int64_t p = 0; if (axes_names.find(CONV_BATCH) != std::string::npos) { // batch - GenConv2dTileByAxis(index, p, 1, 1); + GenConv2dTileByAxis(p, 1, 1); } if (axes_names.find(CONV_OC_OUT) != std::string::npos) { // oc_out - GenConv2dTileByAxis(index, p, 1, 1); + GenConv2dTileByAxis(p, 1, 1); } if (axes_names.find(CONV_OH) != std::string::npos) { // oh - GenConv2dTileByAxis(index, p, 1, 1); + GenConv2dTileByAxis(p, 1, 1); } if (axes_names.find(CONV_OW) != std::string::npos) { // ow - int64_t ow_shape = pending_axes_[index][p].second; + int64_t ow_shape = pending_axes_[current_band_][p].second; /* ow_inner should follow some strategy: 1. ow_shape % ow_tile == 0 2. 
ow_tile is smaller than simd length */ @@ -156,39 +156,39 @@ void CpuStrategy::SetConv2dTileValue(int index) { break; } } - GenConv2dTileByAxis(index, p, ow_tile, ow_tile); + GenConv2dTileByAxis(p, ow_tile, ow_tile); } if (axes_names.find(CONV_OC_IN) != std::string::npos) { // oc_in - int64_t oc_in_shape = pending_axes_[index][p].second; - GenConv2dTileByAxis(index, p, oc_in_shape, oc_in_shape); + int64_t oc_in_shape = pending_axes_[current_band_][p].second; + GenConv2dTileByAxis(p, oc_in_shape, oc_in_shape); } if (axes_names.find(CONV_IC_OUT) != std::string::npos) { // ic_out: reduction axis - int64_t ic_out_shape = pending_axes_[index][p].second; - GenConv2dTileByAxis(index, p, ic_out_shape, 1); + int64_t ic_out_shape = pending_axes_[current_band_][p].second; + GenConv2dTileByAxis(p, ic_out_shape, 1); } if (axes_names.find(CONV_IC_OUT) == std::string::npos && axes_names.find(CONV_KH) != std::string::npos) { // kh: reduction axis - GenConv2dTileByAxis(index, p, 1, 1); + GenConv2dTileByAxis(p, 1, 1); } if (axes_names.find(CONV_IC_OUT) == std::string::npos && axes_names.find(CONV_KH) == std::string::npos && axes_names.find(CONV_KW) != std::string::npos) { // kw: reduction axis - GenConv2dTileByAxis(index, p, 1, 1); + GenConv2dTileByAxis(p, 1, 1); } if (axes_names.find(CONV_IC_OUT) == std::string::npos && axes_names.find(CONV_KH) == std::string::npos && axes_names.find(CONV_KW) == std::string::npos && axes_names.find(CONV_IC_IN) != std::string::npos) { // ic_in: reduction axis - int64_t ic_in_shape = pending_axes_[index][p].second; - GenConv2dTileByAxis(index, p, ic_in_shape, ic_in_shape); + int64_t ic_in_shape = pending_axes_[current_band_][p].second; + GenConv2dTileByAxis(p, ic_in_shape, ic_in_shape); } } -void CpuStrategy::SetMatMulTileValue(int index) { +void CpuStrategy::SetMatMulTileValue() { auto pack_size = analyzer_->scop_info_.analysis_result_.GetPackBlockSize(); - for (int i = 0; i < static_cast(pending_axes_[index].size()); ++i) { + for (int i = 0; i < 
static_cast(pending_axes_[current_band_].size()); ++i) { TileAxis *axis; int64_t pack = 1; int64_t shape; - std::tie(axis, shape) = pending_axes_[index][i]; + std::tie(axis, shape) = pending_axes_[current_band_][i]; int64_t value = shape; for (const auto &attr : axis->attrs) { if (attr.attr_key != AT_GEMM) { @@ -212,14 +212,14 @@ void CpuStrategy::SetMatMulTileValue(int index) { } } -bool CpuStrategy::SetReduceYTileValue(int index) { - auto axes_num = pending_axes_[index].size(); +bool CpuStrategy::SetReduceYTileValue() { + auto axes_num = pending_axes_[current_band_].size(); CHECK(axes_num >= REDUCE_Y_LEAST_AXES_NUM) << "axes_num is less than 2"; bool is_tiled = false; TileAxis *axis1, *axis0; int64_t shape1, shape0; - std::tie(axis0, shape0) = pending_axes_[index][0]; - std::tie(axis1, shape1) = pending_axes_[index][1]; + std::tie(axis0, shape0) = pending_axes_[current_band_][0]; + std::tie(axis1, shape1) = pending_axes_[current_band_][1]; int64_t value1 = shape1; if (shape1 >= REDUCE_Y_LEAST_BLOCK_SIZE && shape0 <= REDUCE_Y_LEAST_X_SIZE) { int64_t value0 = shape0; @@ -263,6 +263,94 @@ void CpuStrategy::SetCsrTileValue() { } } +void CpuStrategy::SetElementWiseTileValue() { + size_t ori_size = pending_axes_[current_band_].size(); + int64_t data_size = 1; + for (int i = static_cast(ori_size - 1); i >= 0; i--) { + TileAxis *axis; + int64_t shape; + std::tie(axis, shape) = pending_axes_[current_band_][i]; + data_size *= shape; + int64_t tile_outer_left = 1; + + int vectorize_axis = analyzer_->scop_info_.analysis_result_.GetOuterBandNode(current_band_)->last_axis; + if (vectorize_axis == i) { + SetUnrollTileValue(axis, shape, tile_outer_left); + } + + /* Set parallel tile size on the outermost axis */ + if (i == 0) { + bool is_unroll_axis = vectorize_axis == 0 ? 
true : false; + SetParallelTileValue(axis, shape, data_size, is_unroll_axis, tile_outer_left); + } + } +} + +void CpuStrategy::SetTransposeTileValue() { + std::unordered_set write_tensor_name; + auto current_outer_bn = analyzer_->scop_info_.analysis_result_.GetOuterBandNode(current_band_); + isl::map_list access_list = current_outer_bn->writes.get_map_list(); + for (auto access : access_list) { + auto access_id = access.domain_factor_domain().get_tuple_id(isl_dim_out); + write_tensor_name.insert(access_id.get_name()); + } + + int transpose_write_axis_pos = -1; + std::unordered_set transpose_axis_pos; + size_t ori_size = pending_axes_[current_band_].size(); + for (int i = static_cast(ori_size - 1); i >= 0; i--) { + TileAxis *axis; + int64_t shape; + std::tie(axis, shape) = pending_axes_[current_band_][i]; + + bool is_inner_axis = false; + bool is_transpose_axis = false; + std::string tensor_name; + for (const auto &attr : axis->attrs) { + if (attr.attr_value == AT_TRANSPOSE_AXIS) { + is_transpose_axis = true; + } + + if (attr.attr_key == AT_TRANSPOSE_INNERMOST_AXIS) { + is_inner_axis = true; + tensor_name = attr.attr_value; + } + } + + if (is_transpose_axis && is_inner_axis) { + transpose_axis_pos.emplace(i); + if (write_tensor_name.count(tensor_name) != 0) { + transpose_write_axis_pos = i; + } + } + } + + const size_t transpose_size = 2; + const int64_t transpose_row = 8; + const int64_t transpose_col = 4; + if (transpose_axis_pos.size() == transpose_size) { + current_outer_bn->enable_transpose = true; + for (int i = static_cast(ori_size - 1); i >= 0; i--) { + TileAxis *axis; + int64_t shape; + std::tie(axis, shape) = pending_axes_[current_band_][i]; + int64_t tile_val = 1; + if (transpose_axis_pos.count(i) != 0) { + if (transpose_write_axis_pos == i) { + tile_val = shape < transpose_row ? shape : transpose_row; + } else { + tile_val = shape < transpose_col ? 
shape : transpose_col; + } + } + + axis->TileRestrainToSingleValue(Expr(tile_val), TileLevel::CACHE1); + axis->TileRestrainToSingleValue(Expr(tile_val), TileLevel::CACHE0); + } + } else { + SetElementWiseTileValue(); + } +} + void CpuStrategy::SetMultiLevelTileValue() { if (analyzer_->scop_info_.analysis_result_.GetCsr()) { SetCsrTileValue(); @@ -272,40 +360,27 @@ void CpuStrategy::SetMultiLevelTileValue() { current_band_ = idx; auto op_type = analyzer_->scop_info_.analysis_result_.GetOuterBandNode(idx)->template_type; if (op_type == Template::CONV) { - SetConv2dTileValue(idx); + SetConv2dTileValue(); continue; } if (op_type == Template::MATMUL) { - SetMatMulTileValue(idx); + SetMatMulTileValue(); continue; } auto reduce_direction = analyzer_->scop_info_.analysis_result_.GetReduceDirection(); if (op_type == Template::REDUCTION && reduce_direction == ReduceDirection::Y) { - bool is_tiled = SetReduceYTileValue(idx); + bool is_tiled = SetReduceYTileValue(); if (is_tiled) { continue; } } - size_t ori_size = pending_axes_[idx].size(); - int64_t data_size = 1; - for (int i = static_cast(ori_size - 1); i >= 0; i--) { - TileAxis *axis; - int64_t shape; - std::tie(axis, shape) = pending_axes_[idx][i]; - data_size *= shape; - int64_t tile_outer_left = 1; - int vectorize_axis = analyzer_->scop_info_.analysis_result_.GetOuterBandNode(idx)->last_axis; - if (vectorize_axis == i) { - SetUnrollTileValue(axis, shape, tile_outer_left); - } - - /* Set parallel tile size on the outermost axis */ - if (i == 0) { - bool is_unroll_axis = vectorize_axis == 0 ? 
true : false; - SetParallelTileValue(axis, shape, data_size, is_unroll_axis, tile_outer_left); - } + if (op_type == Template::TRANSPOSE_OP) { + SetTransposeTileValue(); + continue; } + + SetElementWiseTileValue(); } } diff --git a/tests/st/ops/test_transpose.py b/tests/st/ops/test_transpose.py index b2ffdbf4..12555e7f 100644 --- a/tests/st/ops/test_transpose.py +++ b/tests/st/ops/test_transpose.py @@ -122,6 +122,7 @@ class TestCase(TestBase): self.args_outhers = [ ("000_case", transpose_run, ((8, 24, 38, 38), (0, 2, 1, 3), 'float32'), ["level0"]), ("001_case", transpose_run, ((8, 24, 38, 38), (0, 2, 1, 3), 'float16'), ["level0"]), + ("002_case", transpose_run, ((8, 24, 32, 16), (0, 1, 3, 2), 'float32'), ["level0"]), ] return True diff --git a/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.cc b/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.cc index 16af5752..94473ee2 100644 --- a/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.cc +++ b/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.cc @@ -1658,7 +1658,9 @@ llvm::Value* CodeGenLLVM::EmitSgemmKernel(const Call* op) { return sgemm_ret; } -llvm::Value* CodeGenLLVM::CreateMatrixTransposeBase(llvm::Value *buffer, size_t row, size_t col, size_t bits) { +llvm::Value* CodeGenLLVM::CreateMatrixTransposeBase(llvm::Value* dst_buffer, + llvm::Value* src_buffer, size_t row, size_t col, + size_t bits) { #if TVM_LLVM_VERSION >= 110 auto align = llvm::Align(row * col); std::vector indices; @@ -1666,18 +1668,23 @@ llvm::Value* CodeGenLLVM::CreateMatrixTransposeBase(llvm::Value *buffer, size_t auto align = row * col; std::vector indices; #endif - llvm::Value* ptr = CreateBufferVecPtr(DataType(kDLUInt, bits, row * col), buffer, ConstInt32(0)); - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, align, true); + llvm::Value* src_ptr = + CreateBufferVecPtr(DataType(kDLUInt, bits, row * col), src_buffer, ConstInt32(0)); + llvm::LoadInst* load = builder_->CreateAlignedLoad(src_ptr, align, true); for 
(unsigned i = 0; i < col; i++) { for (unsigned j = 0; j < row; j++) { indices.push_back(j * col + i); } } + auto dst = builder_->CreateShuffleVector(load, indices); - return builder_->CreateAlignedStore(dst, ptr, align, true); + llvm::Value* dst_ptr = + CreateBufferVecPtr(DataType(kDLUInt, bits, row * col), dst_buffer, ConstInt32(0)); + return builder_->CreateAlignedStore(dst, dst_ptr, align, true); } -llvm::Value* CodeGenLLVM::CreateMatrixTranspose4x4(llvm::Value *buffer, size_t row, size_t col, size_t bits) { +llvm::Value* CodeGenLLVM::CreateMatrixTranspose4x4(llvm::Value* dst_buffer, llvm::Value* src_buffer, + size_t row, size_t col, size_t bits) { #if TVM_LLVM_VERSION >= 110 auto align = llvm::Align(col); std::vector low = {0, 4, 1, 5}; @@ -1693,10 +1700,10 @@ llvm::Value* CodeGenLLVM::CreateMatrixTranspose4x4(llvm::Value *buffer, size_t r #endif llvm::StoreInst* store; - llvm::Value* ptr0 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(0)); - llvm::Value* ptr1 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(1)); - llvm::Value* ptr2 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(2)); - llvm::Value* ptr3 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(3)); + llvm::Value* ptr0 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(0)); + llvm::Value* ptr1 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(1)); + llvm::Value* ptr2 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(2)); + llvm::Value* ptr3 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(3)); llvm::LoadInst* xmm0 = builder_->CreateAlignedLoad(ptr0, align, true); llvm::LoadInst* xmm1 = builder_->CreateAlignedLoad(ptr1, align, true); @@ -1712,6 +1719,11 @@ llvm::Value* CodeGenLLVM::CreateMatrixTranspose4x4(llvm::Value *buffer, size_t r tmp2 = builder_->CreateBitCast(tmp2, LLVMType(DataType(kDLUInt, bits * 2, col / 2))); tmp3 
= builder_->CreateBitCast(tmp3, LLVMType(DataType(kDLUInt, bits * 2, col / 2))); + ptr0 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), dst_buffer, ConstInt32(0)); + ptr1 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), dst_buffer, ConstInt32(1)); + ptr2 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), dst_buffer, ConstInt32(2)); + ptr3 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), dst_buffer, ConstInt32(3)); + auto row0 = builder_->CreateShuffleVector(tmp0, tmp2, low_h); store = builder_->CreateAlignedStore(row0, ptr0, llvm::Align(row), true); auto row1 = builder_->CreateShuffleVector(tmp0, tmp2, high_h); @@ -1723,7 +1735,8 @@ llvm::Value* CodeGenLLVM::CreateMatrixTranspose4x4(llvm::Value *buffer, size_t r return store; } -llvm::Value* CodeGenLLVM::CreateMatrixTranspose8x4(llvm::Value *buffer, size_t row, size_t col, size_t bits) { +llvm::Value* CodeGenLLVM::CreateMatrixTranspose8x4(llvm::Value* dst_buffer, llvm::Value* src_buffer, + size_t row, size_t col, size_t bits) { #if TVM_LLVM_VERSION >= 110 auto align = llvm::Align(col); std::vector concat = {0, 1, 2, 3, 4, 5, 6, 7}; @@ -1741,14 +1754,14 @@ llvm::Value* CodeGenLLVM::CreateMatrixTranspose8x4(llvm::Value *buffer, size_t r #endif llvm::StoreInst* store; - llvm::Value* ptr0 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(0)); - llvm::Value* ptr1 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(1)); - llvm::Value* ptr2 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(2)); - llvm::Value* ptr3 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(3)); - llvm::Value* ptr4 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(4)); - llvm::Value* ptr5 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(5)); - llvm::Value* ptr6 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, ConstInt32(6)); - llvm::Value* ptr7 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), buffer, 
ConstInt32(7)); + llvm::Value* ptr0 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(0)); + llvm::Value* ptr1 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(1)); + llvm::Value* ptr2 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(2)); + llvm::Value* ptr3 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(3)); + llvm::Value* ptr4 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(4)); + llvm::Value* ptr5 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(5)); + llvm::Value* ptr6 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(6)); + llvm::Value* ptr7 = CreateBufferVecPtr(DataType(kDLUInt, bits, col), src_buffer, ConstInt32(7)); llvm::LoadInst* xmm0 = builder_->CreateAlignedLoad(ptr0, align, true); llvm::LoadInst* xmm4 = builder_->CreateAlignedLoad(ptr4, align, true); @@ -1773,10 +1786,10 @@ llvm::Value* CodeGenLLVM::CreateMatrixTranspose8x4(llvm::Value *buffer, size_t r tmp2 = builder_->CreateBitCast(tmp2, LLVMType(DataType(kDLUInt, bits * 2, row / 2))); tmp3 = builder_->CreateBitCast(tmp3, LLVMType(DataType(kDLUInt, bits * 2, row / 2))); - ptr0 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), buffer, ConstInt32(0)); - ptr1 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), buffer, ConstInt32(1)); - ptr2 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), buffer, ConstInt32(2)); - ptr3 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), buffer, ConstInt32(3)); + ptr0 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), dst_buffer, ConstInt32(0)); + ptr1 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), dst_buffer, ConstInt32(1)); + ptr2 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), dst_buffer, ConstInt32(2)); + ptr3 = CreateBufferVecPtr(DataType(kDLUInt, bits, row), dst_buffer, ConstInt32(3)); auto row0 = builder_->CreateShuffleVector(tmp0, tmp2, low_h); store = 
builder_->CreateAlignedStore(row0, ptr0, llvm::Align(row), true); @@ -1790,22 +1803,26 @@ llvm::Value* CodeGenLLVM::CreateMatrixTranspose8x4(llvm::Value *buffer, size_t r } llvm::Value* CodeGenLLVM::CreateMatrixTranspose(const Call* op) { - llvm::Value *buffer = MakeValue(op->args[0]); - if (op->args[1].as() == nullptr || op->args[2].as() == nullptr) { + const int row_pos = 2; + const int col_pos = 3; + const int bit_pos = 4; + llvm::Value* dst_buffer = MakeValue(op->args[0]); + llvm::Value* src_buffer = MakeValue(op->args[1]); + if (op->args[row_pos].as() == nullptr || op->args[col_pos].as() == nullptr) { LOG(FATAL) << "Fail to call MatrixTranspose function, The row or col must be IntImm!"; } - unsigned row = op->args[1].as()->value; - unsigned col = op->args[2].as()->value; + unsigned row = op->args[row_pos].as()->value; + unsigned col = op->args[col_pos].as()->value; unsigned bits = 32; - if (op->args.size() > 3) { - bits = op->args[3].as()->value; + if (op->args.size() > bit_pos) { + bits = op->args[bit_pos].as()->value; } if (row == 4 && col == 4) { - return CreateMatrixTranspose4x4(buffer, row, col, bits); + return CreateMatrixTranspose4x4(dst_buffer, src_buffer, row, col, bits); } else if (row == 8 && col == 4) { - return CreateMatrixTranspose8x4(buffer, row, col, bits); + return CreateMatrixTranspose8x4(dst_buffer, src_buffer, row, col, bits); } - return CreateMatrixTransposeBase(buffer, row, col, bits); + return CreateMatrixTransposeBase(dst_buffer, src_buffer, row, col, bits); } } // namespace codegen diff --git a/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.h b/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.h index a1822356..3594cbe9 100644 --- a/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.h +++ b/third_party/incubator-tvm/src/codegen/llvm/codegen_llvm.h @@ -321,9 +321,12 @@ class CodeGenLLVM : llvm::Value* CreateLog(const Call* op); llvm::Value* CreateExp(const Call* op); llvm::Value* CreateMatrixTranspose(const Call* 
op); - llvm::Value* CreateMatrixTransposeBase(llvm::Value *buffer, size_t row, size_t col, size_t bits=32); - llvm::Value* CreateMatrixTranspose4x4(llvm::Value *buffer, size_t row, size_t col, size_t bits=32); - llvm::Value* CreateMatrixTranspose8x4(llvm::Value *buffer, size_t row, size_t col, size_t bits=32); + llvm::Value* CreateMatrixTransposeBase(llvm::Value* dst_buffer, llvm::Value* src_buffer, + size_t row, size_t col, size_t bits = 32); + llvm::Value* CreateMatrixTranspose4x4(llvm::Value* dst_buffer, llvm::Value* src_buffer, + size_t row, size_t col, size_t bits = 32); + llvm::Value* CreateMatrixTranspose8x4(llvm::Value* dst_buffer, llvm::Value* src_buffer, + size_t row, size_t col, size_t bits = 32); }; } // namespace codegen } // namespace air -- Gitee