From ff0c7d31ce49de3e49677f62dea03e51a499ee61 Mon Sep 17 00:00:00 2001 From: r1chardf1d0 Date: Mon, 17 May 2021 11:56:46 +0800 Subject: [PATCH] fix stitch bug when store is shared but load is not shared --- src/composite/stitch_fusion.cc | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/composite/stitch_fusion.cc b/src/composite/stitch_fusion.cc index 8e7f398f..45638112 100644 --- a/src/composite/stitch_fusion.cc +++ b/src/composite/stitch_fusion.cc @@ -137,7 +137,6 @@ class StitchMutate : public IRMutator { Stmt Mutate_(const Store *op, const Stmt &s) final { Var var = op->buffer_var; auto name = var->name_hint; - auto index = op->index; if (stitch_buffer_map_.count(name) && !IsOutput(name)) { auto info = stitch_buffer_map_[name]; if (info.type == StorageType::Shared) { @@ -146,9 +145,10 @@ class StitchMutate : public IRMutator { Var shared = new_buffer ? Var(shared_name) : vars_[shared_name]; vars_[shared_name] = shared; stitch_buffer_map_[shared_name] = info; - fix_producer_ = true; - auto stmt = Store::make(shared, this->Mutate(op->value), this->Mutate(index), op->predicate); - fix_producer_ = false; + rm_block_ = true; + auto index = this->Mutate(op->index); + rm_block_ = false; + auto stmt = Store::make(shared, this->Mutate(op->value), index, op->predicate); if (new_buffer) new_allocate_.insert(stmt.as()); return stmt; } else { @@ -156,7 +156,7 @@ class StitchMutate : public IRMutator { } } if (stitch_type_ == StitchOpType::Broadcast) - return Store::make(op->buffer_var, this->Mutate(op->value), this->Mutate(index), op->predicate); + return Store::make(op->buffer_var, this->Mutate(op->value), this->Mutate(op->index), op->predicate); return IRMutator::Mutate_(op, s); } @@ -168,10 +168,10 @@ class StitchMutate : public IRMutator { if (kv.second.name == name || kv.first == name) { auto info = kv.second; Var replace = GetReplaceVar(var, vars_, kv.first, info); - if (info.type == StorageType::Shared || info.type == StorageType::Global) { - fix_consumer_ = true; + if (info.type == StorageType::Shared) { + rm_block_ = true; index = this->Mutate(index); - fix_consumer_ = false; + rm_block_ = false; return Load::make(op->type, replace, index, op->predicate); } } @@ -183,13 +183,10 @@ class StitchMutate : public IRMutator { Expr Mutate_(const Variable *op, const Expr &e) final { auto name = op->name_hint; - if (fix_producer_ || fix_consumer_) { - if (IsBlockIdx(name)) return 0; - } // substitute idx - if (IsBlockIdxX(name)) return block_idx_x; - if (IsBlockIdxY(name)) return block_idx_y; - if (IsBlockIdxZ(name)) return block_idx_z; + if (IsBlockIdxX(name)) return rm_block_ ? Expr(0) : block_idx_x; + if (IsBlockIdxY(name)) return rm_block_ ? Expr(0) : block_idx_y; + if (IsBlockIdxZ(name)) return rm_block_ ? Expr(0) : block_idx_z; if (IsThreadIdxX(name)) return thread_idx_x; if (IsThreadIdxY(name)) return thread_idx_y; if (IsThreadIdxZ(name)) return thread_idx_z; @@ -204,8 +201,7 @@ class StitchMutate : public IRMutator { size_t phase{0}; private: - bool fix_producer_{false}; - bool fix_consumer_{false}; + bool rm_block_{false}; std::unordered_map &stitch_buffer_map_; std::unordered_map &buf_within_op_map_; std::vector &allocate_revoke_; -- Gitee