From 6b6f29e67c03f957be5b14c5b49dfdd22e2a1cd0 Mon Sep 17 00:00:00 2001 From: hujiahui8 Date: Thu, 3 Mar 2022 15:00:23 +0800 Subject: [PATCH] fix the bug of gemm operator tiling error --- src/poly/schedule_pass/tile_outer_band.cc | 3 ++- src/poly/schedule_tree_util.cc | 23 +++++++++++-------- .../tiling/tiling_strategy_manager_gpu.cc | 9 +++++++- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/poly/schedule_pass/tile_outer_band.cc b/src/poly/schedule_pass/tile_outer_band.cc index 259664a2..45a1465b 100644 --- a/src/poly/schedule_pass/tile_outer_band.cc +++ b/src/poly/schedule_pass/tile_outer_band.cc @@ -1246,6 +1246,8 @@ isl::schedule_node TileOuterBand::TileThreadAndBlockConfig(const isl::schedule_n isl::schedule_node TileOuterBand::TileMatmulOperatorForCuda(const isl::schedule_node &node) { auto tile_node = node; size_t start_depth = tile_node.get_tree_depth(); + + tile_node = TileThreadAndBlockConfig(tile_node, true); tile_node = TileBand(tile_node, GetLevelTileSize(tile_node, TILE_WITH_C1)); isl::schedule_node_band band_node = tile_node.as(); @@ -1259,7 +1261,6 @@ isl::schedule_node TileOuterBand::TileMatmulOperatorForCuda(const isl::schedule_ // split the k axis tile_node = band_node.split(count_coincident); - tile_node = TileThreadAndBlockConfig(tile_node, true); tile_node = InsertPromoteMarker(tile_node); if (scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { diff --git a/src/poly/schedule_tree_util.cc b/src/poly/schedule_tree_util.cc index 1060c1ed..8477389a 100644 --- a/src/poly/schedule_tree_util.cc +++ b/src/poly/schedule_tree_util.cc @@ -906,27 +906,32 @@ isl::multi_val CheckAndGetMapSize(const isl::schedule_node &mapping_root, const int aff_size = static_cast(aff_list.size()); for (int i = aff_size - 1; i >= 0; --i) { auto aff = aff_list.get_at(i).floor(); - auto extent = aff.max_val().get_num_si() + 1; - auto map_size = extent; + int extent = static_cast(aff.max_val().get_num_si()) + 1; + int map_size = extent; if (required_mapping_strategy.count(static_cast(i)) != 0) { std::string mapping_idx = required_mapping_strategy[static_cast(i)].mapping_idx; + if (static_cast(additional_tile_size.size()) > i) { + mapping_cfg_map[mapping_idx] *= additional_tile_size[i]; + } + auto current_mapping_size = mapping_cfg_map[mapping_idx] - required_mapping_strategy[static_cast(i)].offset; if (non_repeated_idx.find(mapping_idx) != non_repeated_idx.end()) { - map_size = mapping_cfg_map[mapping_idx] - required_mapping_strategy[static_cast(i)].offset; + map_size = current_mapping_size; } else { - map_size = mapping_cfg_map[mapping_idx] - required_mapping_strategy[static_cast(i)].offset; + map_size = std::min(map_size, current_mapping_size); + CHECK(map_size != 0); mapping_cfg_map[mapping_idx] = mapping_cfg_map[mapping_idx] / map_size; } - + } else { + map_size = 1; if (static_cast(additional_tile_size.size()) > i) { map_size *= additional_tile_size[i]; } - CHECK(map_size != 0); - mapping_sizes.emplace_back(map_size); - } else { - mapping_sizes.emplace_back(1); } + CHECK(map_size != 0); + mapping_sizes.emplace_back(map_size); + if (mapping_cfg->type == MappingType::THREADS || mapping_cfg->type == MappingType::REPLACE_THREADS) { need_tile = need_tile || extent > map_size; } else { diff --git a/src/poly/tiling/tiling_strategy_manager_gpu.cc b/src/poly/tiling/tiling_strategy_manager_gpu.cc index a90ecaa5..f471c00a 100644 --- a/src/poly/tiling/tiling_strategy_manager_gpu.cc +++ b/src/poly/tiling/tiling_strategy_manager_gpu.cc @@ -1020,7 +1020,14 @@ bool GpuStrategy::IsVectorized() { for (auto i : tensors_shape) { CHECK(i[i.size() - 1].as()); - if (i[i.size() - 1].as()->value % SafeDivisor(vectorized_bytes_) != 0) { + auto i_value = i[i.size() - 1].as()->value; + if (i_value % SafeDivisor(vectorized_bytes_) != 0) { + return false; + } + + auto first_tensor = tensors_shape[0]; + auto first_value = first_tensor[first_tensor.size() - 1].as()->value; + if(template_ == Template::PAD_OP && i_value != first_value) { return false; } } -- Gitee