From 6b8aa79325dc7339e6371c54923604310c28971a Mon Sep 17 00:00:00 2001 From: Nelson Lossing Date: Fri, 3 Jun 2022 15:54:55 +0000 Subject: [PATCH] FIX to_three_address pass LoopMutator function modify the schedule order of the generated ThreeAdressStatements that was sometimes wrong. - Remove this wrong rescheduling - Generate the right iterator loop --- src/pass/to_three_address.cc | 47 +++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/src/pass/to_three_address.cc b/src/pass/to_three_address.cc index d6bf815d..4c82ccbb 100644 --- a/src/pass/to_three_address.cc +++ b/src/pass/to_three_address.cc @@ -1711,6 +1711,39 @@ class ThreeAddressStmtMutator : public IRMutator { bool cross_stmt_simplify_; }; +class ExprArgsExtract : public IRVisitor { + public: + explicit ExprArgsExtract(Array args) : args_(args) {} + ~ExprArgsExtract() override = default; + + Array GetArgs(const Expr &e) { + Visit(e); + return args_; + } + + void Visit_(const Call *op) override { + if (op->call_type == Call::CallType::Halide) { + for (Expr arg : op->args) { + if (!Contain(args_, arg)) { + args_.push_back(arg); + } + } + } + } + + private: + bool Contain(const Array &args, const Expr &arg) { + for (Expr e : args) { + if (e.same_as(arg)) { + return true; + } + } + return false; + } + + Array args_; +}; + class LoopMutator : public IRMutator { public: Stmt Mutate_(const For *op, const Stmt &s) final { @@ -1720,9 +1753,14 @@ class LoopMutator : public IRMutator { loop_vars_.push_back(op); Stmt stmt = IRMutator::Mutate(op->body); if (!provides_.empty()) { - provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); }); + // This sort can generate a wrong schedule order, + // sometimes it the statement wit small number of iterator must be at the beginning + // sometimes it the statement wit small number of iterator must be at the end + // sometimes it the statement wit small number of iterator must be mixe at the begin and at the end + // need a dependence analyze to be sure that they can be move + // provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); }); while (!provides_.empty()) { - SplitProides(); + SplitProvides(); } } for (size_t index = 0; index < stmts_.size(); ++index) { @@ -1750,7 +1788,7 @@ class LoopMutator : public IRMutator { } private: - void SplitProides() { + void SplitProvides() { const Provide *provide = provides_.back(); Stmt stmt = Provide::make(provide->func, provide->value_index, provide->value, provide->args); provides_.pop_back(); @@ -1763,7 +1801,8 @@ class LoopMutator : public IRMutator { provides_.pop_back(); } stmts_.insert(stmts_.begin(), stmt); - args_.insert(args_.begin(), provide->args); + Array all_args = ExprArgsExtract(provide->args).GetArgs(provide->value); + args_.insert(args_.begin(), all_args); } bool IsContain(const Array &args, const Var &var) { -- Gitee