diff --git a/src/pass/to_three_address.cc b/src/pass/to_three_address.cc index d6bf815dcbdec58451574b86058341241004f01e..4c82ccbb66d7f61da89e537998da1041c3b97366 100644 --- a/src/pass/to_three_address.cc +++ b/src/pass/to_three_address.cc @@ -1711,6 +1711,39 @@ class ThreeAddressStmtMutator : public IRMutator { bool cross_stmt_simplify_; }; +class ExprArgsExtract : public IRVisitor { + public: + explicit ExprArgsExtract(Array args) : args_(args) {} + ~ExprArgsExtract() override = default; + + Array GetArgs(const Expr &e) { + Visit(e); + return args_; + } + + void Visit_(const Call *op) override { + if (op->call_type == Call::CallType::Halide) { + for (Expr arg : op->args) { + if (!Contain(args_, arg)) { + args_.push_back(arg); + } + } + } + } + + private: + bool Contain(const Array &args, const Expr &arg) { + for (Expr e : args) { + if (e.same_as(arg)) { + return true; + } + } + return false; + } + + Array args_; +}; + class LoopMutator : public IRMutator { public: Stmt Mutate_(const For *op, const Stmt &s) final { @@ -1720,9 +1753,14 @@ class LoopMutator : public IRMutator { loop_vars_.push_back(op); Stmt stmt = IRMutator::Mutate(op->body); if (!provides_.empty()) { - provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); }); + // This sort can generate a wrong schedule order, + // sometimes it the statement wit small number of iterator must be at the beginning + // sometimes it the statement wit small number of iterator must be at the end + // sometimes it the statement wit small number of iterator must be mixe at the begin and at the end + // need a dependence analyze to be sure that they can be move + // provides_.sort([](const Provide *s1, const Provide *s2) -> bool { return s1->args.size() < s2->args.size(); }); while (!provides_.empty()) { - SplitProides(); + SplitProvides(); } } for (size_t index = 0; index < stmts_.size(); ++index) { @@ -1750,7 +1788,7 @@ class LoopMutator : public IRMutator { } private: - void SplitProides() { + void SplitProvides() { const Provide *provide = provides_.back(); Stmt stmt = Provide::make(provide->func, provide->value_index, provide->value, provide->args); provides_.pop_back(); @@ -1763,7 +1801,8 @@ class LoopMutator : public IRMutator { provides_.pop_back(); } stmts_.insert(stmts_.begin(), stmt); - args_.insert(args_.begin(), provide->args); + Array all_args = ExprArgsExtract(provide->args).GetArgs(provide->value); + args_.insert(args_.begin(), all_args); } bool IsContain(const Array &args, const Var &var) {