diff --git a/src/pass/reduction_factor.cc b/src/pass/reduction_factor.cc index 54476a76b15bbe320b70d8c2688da6533bc4c334..b8356bb81fc9bdef85e2431b02517dcbdc53eba0 100644 --- a/src/pass/reduction_factor.cc +++ b/src/pass/reduction_factor.cc @@ -138,7 +138,13 @@ class IdentifyReduceChance : public IRVisitor { reduce_index_++; } } - IRVisitor::Visit_(op); + const Block *block_ptr = op->body.as(); + if (block_ptr && is_reduce_y_) { + IRVisitor::Visit(block_ptr->first); + } else { + IRVisitor::Visit_(op); + } + if (cur_reduce_data_->outter_reduction_data) { auto inner_reduction_data = cur_reduce_data_; cur_reduce_data_ = cur_reduce_data_->outter_reduction_data; @@ -187,7 +193,11 @@ class IdentifyReduceChance : public IRVisitor { break; case air::ir::ForType::Parallel: is_parallel_area_ = true; - cur_reduce_data_->vector_parallel_for = op; + if (is_reduce_y_) { + cur_reduce_data_->parallel_for = op; + } else { + cur_reduce_data_->vector_parallel_for = op; + } IRVisitor::Visit_(op); is_parallel_area_ = false; return; @@ -260,6 +270,26 @@ class IdentifyReduceChance : public IRVisitor { Array output_args = op->args; Expr input_expr = GetTheInputExpr(op); auto call_ptr = input_expr.as(); + while(!call_ptr) { + if (auto min = input_expr.as()) { + input_expr = min->b; + } else if (auto max = input_expr.as()) { + input_expr = max->b; + } else if (auto and_op = input_expr.as()) { + input_expr = and_op->b; + } else if (auto or_op = input_expr.as()) { + input_expr = or_op->b; + } else if (auto add_op = input_expr.as()) { + input_expr = add_op->b; + } else if (auto mul_op = input_expr.as()) { + input_expr = mul_op->b; + } else if (auto div_op = input_expr.as
()) { + input_expr = div_op->b; + } else { + CHECK(false) << "reduce type is invalid"; + } + call_ptr = input_expr.as(); + } CHECK(call_ptr) << "call_ptr is nullptr"; Array input_args = call_ptr->args; while(call_ptr) { @@ -445,7 +475,8 @@ class MutateReduceBody : public IRMutator { cur_reduce_data_->reduce_temp.as()->args); } else if (cur_reduce_data_->isolate_reduce_provide.count(op) && (cur_reduce_data_->outter_reduction_data != nullptr || - cur_reduce_data_->pre_reduction_data != nullptr)) { + cur_reduce_data_->pre_reduction_data != nullptr || + cur_reduce_data_->is_reduce_y)) { auto value = op->value; auto temp_call = MakeUniqueTempCallForIsolate(cur_reduce_data_->reduce_temp, op); auto input_call = GetTheInputExpr(op); @@ -540,10 +571,6 @@ class ReduceVectorizeEnable : public IRMutator { return body; } - if (cur_reduce_data_->is_reduce_y && reduce_index_ == INNER_REDUCTION_DATA_INDEX) { - return MutateReduceBody(cur_reduce_data_).Mutate(body); - } - CHECK(cur_reduce_data_->body.defined()) << "cur_reduce_data_ body is not defined"; Stmt stmt = MakeReduceStmt(); stmt = VectorizedForAmend(cur_reduce_data_).Mutate(stmt); @@ -565,13 +592,6 @@ class ReduceVectorizeEnable : public IRMutator { } else if (current_level_entry->pre_reduction_data) { cur_reduce_data_ = current_level_entry->pre_reduction_data; } - if (cur_reduce_data_->is_reduce_y) { - auto inner_reduce_data = reduce_datas_[reduce_index_]; - inner_reduce_data->reduce_temp = cur_reduce_data_->reduce_temp; - inner_reduce_data->temp_tensor = cur_reduce_data_->temp_tensor; - inner_reduce_data->temp_buffer = cur_reduce_data_->temp_buffer; - return IRMutator::Mutate(stmt); - } return stmt; } else if (op->attr_key == REDUCE_Y_FLAG) { return IRMutator::Mutate(op->body); @@ -592,6 +612,7 @@ class ReduceVectorizeEnable : public IRMutator { } else { CHECK(false) << "can not get output_shape"; } + CHECK(reduce_data->parallel_for) << "parallel_for is nullptr"; shapes.push_back(reduce_data->parallel_for->extent); for (size_t i = 0; i < output_shape.size(); i++) { if (reduce_data->reduce_axis_hidden || i != reduce_data->reduce_axis) { @@ -780,9 +801,8 @@ class ReduceVectorizeEnable : public IRMutator { } // step 3: vectorize or parallel area Stmt reduce_body = cur_reduce_data_->body; - if (!cur_reduce_data_->is_reduce_y || reduce_index_ == 0) { - reduce_body = MutateReduceBody(cur_reduce_data_).Mutate(cur_reduce_data_->body); - } + reduce_body = MutateReduceBody(cur_reduce_data_).Mutate(reduce_body); + // step 4: init area const For *for_ptr = cur_reduce_data_->is_reduce_y ? cur_reduce_data_->parallel_for : vector_parallel_for; loop_var = Variable::make(for_ptr->loop_var->type, diff --git a/src/poly/schedule_pass/tile_outer_band.cc b/src/poly/schedule_pass/tile_outer_band.cc index 34b7d1d74c1b7594712112c4f1f76468c16f769a..27b3e57798b004552c38a84337da05292cfcdfee 100644 --- a/src/poly/schedule_pass/tile_outer_band.cc +++ b/src/poly/schedule_pass/tile_outer_band.cc @@ -1355,8 +1355,6 @@ isl::schedule_node TileOuterBand::InsertMarkerForReduceY(const isl::schedule_nod node = InsertMarkerForLoop(node, FOR_PARALLEL); bool is_parallel = !GetMarkerName(node, FOR_PARALLEL).empty(); if (is_parallel) { - node = node.child(0).child(0).child(0).child(0); - node = node.insert_mark(REDUCE_AREA_FLAG); node = node.ancestor(node.get_tree_depth() - start_depth); node = node.insert_mark(REDUCE_AREA_FLAG); node = node.insert_mark(REDUCE_Y_FLAG);