diff --git a/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/BufferAnalysis.h b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/BufferAnalysis.h new file mode 100644 index 0000000000000000000000000000000000000000..8803b569b803a54675346f0eee9f9b0b507c951e --- /dev/null +++ b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/BufferAnalysis.h @@ -0,0 +1,258 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AKG_DIALECT_AFFINE_ANALYSIS_BUFFERANALYSIS_H +#define AKG_DIALECT_AFFINE_ANALYSIS_BUFFERANALYSIS_H + +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +namespace mlir { +namespace akg { + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +/// Check if the operation is a memref reshaping operation. +inline bool isReshapingOp(Operation *op) { + return isa(op); +} + +/// Check if the operation is a memref slicing operation. +inline bool isSlicingOp(Operation *op) { return isa(op); } + +/// Check if the operation is a memref aliasing operation. +inline bool isMemRefAliasingOp(Operation *op) { return isReshapingOp(op) || isSlicingOp(op); } + +/// Get the source memref for an aliasing operation. +inline Value getAliasSource(Operation *op) { + return llvm::TypeSwitch(op) + .Case([](memref::ExpandShapeOp expand) { return expand.getSrc(); }) + .Case([](memref::CollapseShapeOp collapse) { return collapse.getSrc(); }) + .Case([](memref::SubViewOp subview) { return subview.getSource(); }) + .Case([](memref::ReshapeOp reshape) { return reshape.getSource(); }) + .Case([](memref::ReinterpretCastOp cast) { return cast.getSource(); }) + .Default([](Operation *op) { + llvm_unreachable("Unsupported aliasing op"); + return Value(); + }); +} + +//===----------------------------------------------------------------------===// +// Value Comparator +//===----------------------------------------------------------------------===// + +/// Value comparator for std::map. +inline bool isLessValue(const Value &a, const Value &b) { return a.getImpl() < b.getImpl(); } + +struct ValueComparator { + bool operator()(const Value &a, const Value &b) const { return isLessValue(a, b); } +}; + +//===----------------------------------------------------------------------===// +// UnionFind +//===----------------------------------------------------------------------===// + +/// Union-Find data structure for alias analysis. +/// Tracks connected components and maintains the minimum index in each set. +class UnionFind { + public: + explicit UnionFind(size_t n = 0) : minIndex(n), parent(n, -1) { std::iota(minIndex.begin(), minIndex.end(), 0); } + + /// Find the representative of the set containing x. + int find(int x); + + /// Join the sets containing a and b. + bool join(int a, int b); + + /// Minimum index in each connected component. + std::vector minIndex; + + private: + /// Ensure capacity for index n. + void ensureCapacity(size_t n); + + /// Parent array for union-find. Negative values indicate root with size. + std::vector parent; +}; + +//===----------------------------------------------------------------------===// +// Type Aliases +//===----------------------------------------------------------------------===// + +using IdxToValMap = std::map; +using IdxToOpMap = std::map; +using DataTypeWeightMap = llvm::DenseMap; +using ValToIdxMap = llvm::DenseMap; +using OpToIdxMap = llvm::DenseMap; + +//===----------------------------------------------------------------------===// +// WeightedLiveRange +//===----------------------------------------------------------------------===// + +/// Start, End, Weighted live range of operations. +struct WeightedLiveRange { + uint32_t start; + uint32_t end; + int64_t weight; + + explicit WeightedLiveRange(uint32_t s = 0, uint32_t e = 0, int64_t w = 1) : start(s), end(e), weight(w) {} + + bool operator<(const WeightedLiveRange &other) const { + return std::tie(start, end, weight) < std::tie(other.start, other.end, other.weight); + } +}; + +using LiveRanges = llvm::SmallVector; +using WeightedEndPair = std::pair; + +//===----------------------------------------------------------------------===// +// BufferAnalysisOptions +//===----------------------------------------------------------------------===// + +struct BufferAnalysisOptions { + using MultiBufferMap = std::map; + + /// Mapping from `value` to the multi-buffer count. + MultiBufferMap multiBufferCount; + + /// If enabled, the buffer used by DMA operations will not be reused by Vector + /// operations. + bool enableDmaOpt = false; + + /// If enabled, print live range information for debugging. + bool printLiveRange = false; +}; + +//===----------------------------------------------------------------------===// +// ValOperationIndexer +//===----------------------------------------------------------------------===// + +/// Class to index values and operations with sequential indices. +class ValOperationIndexer { + public: + ValToIdxMap valToIdx; + OpToIdxMap opToIdx; + IdxToValMap idxToVal; + IdxToOpMap idxToOp; + + static constexpr uint32_t kOpNotFoundLiveRange = static_cast(-1); + + mlir::FailureOr getVal(uint32_t idx) const; + mlir::FailureOr getOp(uint32_t idx) const; + uint32_t getClosestOpIdx(uint32_t idx) const; + uint32_t getIndex(Value val) const { return valToIdx.at(val); } + uint32_t getIndex(Operation *op) const { return opToIdx.at(op); } + uint32_t getCurrentCount() const { return opCount; } + bool insert(Value val); + bool insert(Operation *op); + + private: + uint32_t opCount = 0; +}; + +//===----------------------------------------------------------------------===// +// BufferAnalysis +//===----------------------------------------------------------------------===// + +/// Main buffer analysis class for affine dialect. +/// Computes the maximum buffer requirement using live range analysis. +class BufferAnalysis { + public: + BufferAnalysis(Block &block, const BufferAnalysisOptions &options, mlir::func::FuncOp op) + : block(block), options(options), liveness(op) {} + + /// Count the maximum number of buffers needed simultaneously. + int64_t countMaxBuffer(); + + private: + Block █ + BufferAnalysisOptions options; + mlir::Liveness liveness; + + DataTypeWeightMap dataTypeWeightMap; + llvm::DenseMap valToLiveRangeIdx; + LiveRanges liveRanges; + llvm::DenseMap> opToEndValIdx; + llvm::DenseMap aliasFurthest; + + /// Alias information using union-find. + UnionFind aliasSet; + ValOperationIndexer indexer; + + /// Check if a value is a buffer value (memref type). + static bool isUsingBuffer(const Value &value) { return isa(value.getType()); } + + /// Skip operations that are ignorable for buffer analysis. + static bool skippableOperation(Operation *op) { return isa(op); } + + /// Check if an operation is an affine memory read operation. + static bool isAffineReadOp(Operation *op) { return isa(op); } + + /// Check if an operation is an affine memory write operation. + static bool isAffineWriteOp(Operation *op) { return isa(op); } + + /// Check if an operation is a control flow operation (for, if, etc.) + static bool isControlFlowOp(Operation *op) { return isa(op); } + + /// Get the memref from a load/store operation. + static Value getMemRefFromOp(Operation *op); + + void adjustInplaceReuseOp(Operation *op); + void adjustCopyInCopyOut(Operation *op); + uint32_t insertValue(const Value &value, uint32_t pos, uint32_t weight = 1); + void recordDataTypeWeight(const Value &value, uint32_t *smallestTypeBits); + int64_t getExtraBufferSizeByFactor(Operation *op) const; + llvm::SmallVector getOperands(Operation &op) const; + uint32_t getValMultiBuffer(const Value &value, uint32_t def = 1) const; + uint32_t getValDataTypeWeight(const Value &value, uint32_t def = 1) const; + void gatherLiveRanges(const mlir::LivenessBlockInfo *blockInfo); + void processOperationForLiveRange(Operation *op, const mlir::LivenessBlockInfo *blockInfo); + void processOperationForPostProcess(Operation *op); + uint32_t updateAliasIntoFurthest(const Value &value, Operation *endOp); + void gatherDataTypeWeights(); + void processOperationForDataTypeWeight(Operation *op, uint32_t *smallestTypeBits); + void gatherIndexingAndAlias(); + void processOperationForIndexing(Operation *op); + void printLiveRanges() const; + void printAliasInfo(); + int64_t lineSweepRanges(); +}; + +//===----------------------------------------------------------------------===// +// Public API +//===----------------------------------------------------------------------===// + +/// Count the maximum number of buffers needed simultaneously for a function. +/// Returns -1 if the function has more than one block. +int64_t countMaxBuffer(mlir::func::FuncOp func, const BufferAnalysisOptions &options = {}); + +} // namespace akg +} // namespace mlir + +#endif // AKG_DIALECT_AFFINE_ANALYSIS_BUFFERANALYSIS_H diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Analysis/BufferAnalysis.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/BufferAnalysis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f10c0647efd9c45c84a0466ce68ef0e34211a00e --- /dev/null +++ b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/BufferAnalysis.cpp @@ -0,0 +1,564 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "akg/Dialect/Affine/Analysis/BufferAnalysis.h" + +#include +#include +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "llvm/ADT/PriorityQueue.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace akg { + +//===----------------------------------------------------------------------===// +// UnionFind Implementation +//===----------------------------------------------------------------------===// + +void UnionFind::ensureCapacity(size_t n) { + if (n + 1 > parent.size()) { + size_t oldSize = parent.size(); + parent.resize(n + 1, -1); + minIndex.resize(n + 1); + for (size_t i = oldSize; i < n + 1; ++i) { + minIndex[i] = static_cast(i); + } + } +} + +int UnionFind::find(int x) { + ensureCapacity(x); + if (parent[x] < 0) { + return x; + } + return parent[x] = find(parent[x]); +} + +bool UnionFind::join(int a, int b) { + ensureCapacity(std::max(a, b)); + a = find(a); + b = find(b); + if (a != b) { + // Union by rank: attach smaller tree under root of larger tree. + if (parent[a] > parent[b]) { + std::swap(a, b); + } + parent[a] += parent[b]; + parent[b] = a; + minIndex[a] = std::min(minIndex[b], minIndex[a]); + } + return true; +} + +//===----------------------------------------------------------------------===// +// ValOperationIndexer Implementation +//===----------------------------------------------------------------------===// + +mlir::FailureOr ValOperationIndexer::getVal(uint32_t idx) const { + auto it = idxToVal.find(idx); + if (it != idxToVal.end()) { + return it->second; + } + return failure(); +} + +mlir::FailureOr ValOperationIndexer::getOp(uint32_t idx) const { + auto it = idxToOp.find(idx); + if (it != idxToOp.end()) { + return it->second; + } + return failure(); +} + +uint32_t ValOperationIndexer::getClosestOpIdx(uint32_t idx) const { + auto it = idxToOp.lower_bound(idx); + if (it == idxToOp.end()) { + return kOpNotFoundLiveRange; + } + return it->first; +} + +bool ValOperationIndexer::insert(Value val) { + llvm::outs() << val << " " << opCount << "\n"; + if (valToIdx.count(val)) { + return false; + } + valToIdx[val] = opCount; + idxToVal[opCount] = val; + opCount++; + return true; +} + +bool ValOperationIndexer::insert(Operation *op) { + if (opToIdx.count(op)) { + return false; + } + opToIdx[op] = opCount; + idxToOp[opCount] = op; + opCount++; + return true; +} + +//===----------------------------------------------------------------------===// +// BufferAnalysis Implementation +//===----------------------------------------------------------------------===// + +Value BufferAnalysis::getMemRefFromOp(Operation *op) { + if (auto loadOp = dyn_cast(op)) { + return loadOp.getMemRef(); + } + if (auto storeOp = dyn_cast(op)) { + return storeOp.getMemRef(); + } + if (auto loadOp = dyn_cast(op)) { + return loadOp.getMemRef(); + } + if (auto storeOp = dyn_cast(op)) { + return storeOp.getMemRef(); + } + return Value(); +} + +void BufferAnalysis::adjustInplaceReuseOp(Operation *op) { + if (!isa(op)) { + return; + } + + auto forOp = cast(op); + llvm::SmallVector readMemRefs, writeMemRefs; + + forOp.walk([&](mlir::affine::AffineLoadOp loadOp) { readMemRefs.push_back(loadOp.getMemRef()); }); + + forOp.walk([&](mlir::affine::AffineStoreOp storeOp) { writeMemRefs.push_back(storeOp.getMemRef()); }); + + for (auto readMem : readMemRefs) { + for (auto writeMem : writeMemRefs) { + if (readMem == writeMem) { + llvm::outs() << "In-place operation detected for memref: " << readMem << "\n"; + } + } + } +} + +void BufferAnalysis::adjustCopyInCopyOut(Operation *op) { + if (!options.enableDmaOpt) { + return; + } + + if (auto copyOp = dyn_cast(op)) { + auto srcMemref = copyOp.getSource(); + auto dstMemref = copyOp.getTarget(); + if (valToLiveRangeIdx.count(srcMemref)) { + auto rangesIndex = valToLiveRangeIdx.at(srcMemref); + liveRanges[rangesIndex].end = indexer.getCurrentCount(); + llvm::outs() << "Extended live range of source " << srcMemref << " to " << indexer.getCurrentCount() << "\n"; + } + if (valToLiveRangeIdx.count(dstMemref)) { + auto rangesIndex = valToLiveRangeIdx.at(dstMemref); + liveRanges[rangesIndex].end = indexer.getCurrentCount(); + llvm::outs() << "Extended live range of target " << dstMemref << " to " << indexer.getCurrentCount() << "\n"; + } + } +} + +uint32_t BufferAnalysis::insertValue(const Value &value, uint32_t pos, uint32_t weight) { + llvm::outs() << "--- Inserting value " << value << " " << pos << " " << weight << "\n"; + assert(!valToLiveRangeIdx.count(value)); + liveRanges.emplace_back(pos, pos, weight); + return valToLiveRangeIdx[value] = liveRanges.size() - 1; +} + +int64_t BufferAnalysis::getExtraBufferSizeByFactor(Operation *op) const { + if (auto forOp = dyn_cast(op)) { + bool hasReduction = false; + forOp.walk([&](mlir::affine::AffineStoreOp storeOp) { + Value storedMemref = storeOp.getMemRef(); + for (auto user : storedMemref.getUsers()) { + if (auto loadOp = dyn_cast(user)) { + if (loadOp->getBlock() == storeOp->getBlock()) { + hasReduction = true; + return; + } + } + } + }); + + if (hasReduction) { + return 0; + } + } + return 0; +} + +llvm::SmallVector BufferAnalysis::getOperands(Operation &op) const { + if (auto forOp = dyn_cast(op)) { + return llvm::SmallVector(forOp.getInits().begin(), forOp.getInits().end()); + } + if (auto returnOp = dyn_cast(op)) { + return llvm::SmallVector(returnOp.getOperands().begin(), returnOp.getOperands().end()); + } + return llvm::SmallVector(op.getOperands().begin(), op.getOperands().end()); +} + +uint32_t BufferAnalysis::getValMultiBuffer(const Value &value, uint32_t def) const { + auto it = options.multiBufferCount.find(value); + if (it != options.multiBufferCount.end()) { + return it->second; + } + return def; +} + +uint32_t BufferAnalysis::getValDataTypeWeight(const Value &value, uint32_t def) const { + auto it = dataTypeWeightMap.find(value); + if (it != dataTypeWeightMap.end()) { + return it->second; + } + return def; +} + +uint32_t BufferAnalysis::updateAliasIntoFurthest(const Value &val, Operation *endOp) { + auto valIdx = indexer.getIndex(val); + llvm::outs() << "found valIdx " << valIdx << "\n"; + + auto aliasParent = aliasSet.minIndex[aliasSet.find(valIdx)]; + llvm::outs() << "found alias parent " << aliasParent << "\n"; + llvm::outs() << "Ok found endIdx " << *endOp << "\n"; + + auto endIdx = indexer.getIndex(endOp); + llvm::outs() << "Ok found endIdx " << endIdx << "\n"; + + if (!aliasFurthest.count(aliasParent)) { + aliasFurthest[aliasParent] = -1; + } + + auto &furthestPtr = aliasFurthest[aliasParent]; + llvm::outs() << endIdx << " " << furthestPtr << " end -- " << *endOp << "\n"; + + if (endIdx > furthestPtr) { + llvm::outs() << "Updating furthest " << endIdx << " " << furthestPtr << "\n"; + opToEndValIdx[furthestPtr].erase(aliasParent); + furthestPtr = endIdx; + opToEndValIdx[endIdx].insert(aliasParent); + } + return aliasParent; +} + +void BufferAnalysis::processOperationForLiveRange(Operation *op, const mlir::LivenessBlockInfo *blockInfo) { + if (isControlFlowOp(op)) { + for (Region ®ion : op->getRegions()) { + for (Block &nestedBlock : region) { + for (Operation &nestedOp : nestedBlock) { + processOperationForLiveRange(&nestedOp, blockInfo); + } + } + } + return; + } + + // if (skippableOperation(op)) { + // return; + // } + + uint32_t currentOpIndex = indexer.getIndex(op); + + for (const auto &[idx, res] : llvm::enumerate(op->getResults())) { + Operation *endOp = blockInfo->getEndOperation(res, op); + auto aliasParent = updateAliasIntoFurthest(res, endOp); + + auto currentWeight = getValMultiBuffer(res) * getValDataTypeWeight(res); + llvm::outs() << "inserting " << res << "\n"; + + // Aliased values don't contribute additional weight. + if (aliasParent != indexer.getIndex(res)) { + currentWeight = 0; + } + + // bufferization.to_memref and bufferization.to_tensor results have zero weight + // as they represent external buffers that are copied into local buffers. + bool isToMemrefOp = isa(op); + bool isToTensorOp = isa(op); + + if (isToMemrefOp || isToTensorOp) { + // Check if any operand of the operation is a Block argument + bool hasBlockArgOperand = false; + for (Value operand : op->getOperands()) { + if (llvm::is_contained(block.getArguments(), operand)) { + hasBlockArgOperand = true; + break; + } + } + + if (hasBlockArgOperand) { + llvm::outs() << (isToMemrefOp ? "ToMemrefOp" : "ToTensorOp") + << " with block argument detected, setting weight to 0 for " << res << "\n"; + currentWeight = 0; + } + } + + insertValue(res, aliasParent, currentWeight); + } + + // Update live range end points for values that die at this operation. + llvm::outs() << "Printing dead val at " << currentOpIndex << " " << *op << "\n"; + for (auto deadVal : opToEndValIdx[currentOpIndex]) { + llvm::outs() << "Here is " << deadVal << "\n"; + Value curVal = indexer.getVal(deadVal).value(); + auto indexPos = valToLiveRangeIdx[curVal]; + liveRanges[indexPos].end = currentOpIndex; + } + + // Add extra buffer for operations that need temporary buffers. + if (auto extraWeight = getExtraBufferSizeByFactor(op)) { + extraWeight *= std::max((uint32_t)1, getValMultiBuffer(op->getResult(0), 0)); + llvm::outs() << "Appending " << *op << " with " << extraWeight << "\n"; + liveRanges.emplace_back(currentOpIndex, currentOpIndex, extraWeight); + } +} + +void BufferAnalysis::processOperationForPostProcess(Operation *op) { + if (isControlFlowOp(op)) { + for (Region ®ion : op->getRegions()) { + for (Block &nestedBlock : region) { + for (Operation &nestedOp : nestedBlock) { + processOperationForPostProcess(&nestedOp); + } + } + } + return; + } + + if (skippableOperation(op)) { + return; + } + + adjustInplaceReuseOp(op); + if (options.enableDmaOpt) { + adjustCopyInCopyOut(op); + } +} + +void BufferAnalysis::gatherLiveRanges(const mlir::LivenessBlockInfo *blockInfo) { + llvm::outs() << "Gathering live range information...\n"; + + // Process operations in the block. + for (auto &op : block) { + processOperationForLiveRange(&op, blockInfo); + } + + // Post-process for in-place reuse and DMA optimizations. + for (auto &op : block) { + processOperationForPostProcess(&op); + } +} + +void BufferAnalysis::recordDataTypeWeight(const Value &value, uint32_t *smallestTypeBits) { + Type type = value.getType(); + + uint32_t currentTypeBits = 0; + if (auto memrefType = dyn_cast(type)) { + auto elementType = memrefType.getElementType(); + assert(elementType.isIntOrFloat() && "Can only handle int or float element type!"); + currentTypeBits = static_cast(elementType.getIntOrFloatBitWidth()); + } else if (auto tensorType = dyn_cast(type)) { + auto elementType = tensorType.getElementType(); + assert(elementType.isIntOrFloat() && "Can only handle int or float element type!"); + currentTypeBits = static_cast(elementType.getIntOrFloatBitWidth()); + } else if (auto intType = dyn_cast(type)) { + currentTypeBits = static_cast(intType.getWidth()); + } else if (auto floatType = dyn_cast(type)) { + currentTypeBits = static_cast(floatType.getWidth()); + } else { + return; + } + + dataTypeWeightMap[value] = currentTypeBits; + *smallestTypeBits = std::min(*smallestTypeBits, currentTypeBits); +} + +void BufferAnalysis::processOperationForDataTypeWeight(Operation *op, uint32_t *smallestTypeBits) { + if (isControlFlowOp(op)) { + for (Region ®ion : op->getRegions()) { + for (Block &nestedBlock : region) { + for (Operation &nestedOp : nestedBlock) { + processOperationForDataTypeWeight(&nestedOp, smallestTypeBits); + } + } + } + return; + } + + for (const auto &[idx, res] : llvm::enumerate(op->getResults())) { + recordDataTypeWeight(res, smallestTypeBits); + } + + for (const auto &[idx, operand] : llvm::enumerate(op->getOperands())) { + recordDataTypeWeight(operand, smallestTypeBits); + } +} + +void BufferAnalysis::gatherDataTypeWeights() { + llvm::outs() << "Gathering data type information...\n"; + uint32_t smallestTypeBits = std::numeric_limits::max(); + + // Process operations in the block. + // Note: Block arguments are now tensors, which are converted to memrefs via + // bufferization.to_memref operations. The memref results will be processed here. + for (auto &op : block) { + processOperationForDataTypeWeight(&op, &smallestTypeBits); + } + + llvm::outs() << "Smallest type bits is " << smallestTypeBits << ", normalizing weights...\n"; + + // Normalize weights based on smallest type. + for (auto &[value, bits] : dataTypeWeightMap) { + auto normalizedTypeBits = bits / smallestTypeBits; + if (bits % smallestTypeBits != 0) { + llvm::outs() << "WARN: Current type bits " << bits + << " is not divisible by the smallest type bits! Rounding up...\n"; + normalizedTypeBits = (bits + smallestTypeBits - 1) / smallestTypeBits; + } + bits = normalizedTypeBits; + } +} + +void BufferAnalysis::printLiveRanges() const { + llvm::outs() << "Considering " << valToLiveRangeIdx.size() << " and " << liveRanges.size() - valToLiveRangeIdx.size() + << " extra Live Range:\n"; + + for (size_t i = 0; i < liveRanges.size(); i++) { + llvm::outs() << "Live Range #" << i << ": \n"; + if (i == 0 || liveRanges[i].start != liveRanges[i - 1].start) { + auto currentVal = indexer.getVal(liveRanges[i].start); + if (succeeded(currentVal)) { + llvm::outs() << currentVal.value() << ": \n"; + } else { + llvm::outs() << *indexer.getOp(liveRanges[i].start).value() << ": \n"; + } + } + + llvm::outs() << liveRanges[i].start << " " << liveRanges[i].end << " " << liveRanges[i].weight << "\n"; + llvm::outs() << "Done Live Range\n"; + } +} + +int64_t BufferAnalysis::lineSweepRanges() { + // Min-heap sorted by end time. + llvm::PriorityQueue, std::greater> earlyDone; + + int64_t maxBuffer = 0; + int64_t currentBuffer = 0; + + for (const auto &liveRange : liveRanges) { + if (liveRange.start == liveRange.end) { + llvm::outs() << "WARN: dead operation or temporary buffer exists at position " << liveRange.start << "\n"; + } + + // Remove buffers that have ended before the current start. + while (!earlyDone.empty() && earlyDone.top().first < liveRange.start) { + currentBuffer -= earlyDone.top().second; + earlyDone.pop(); + } + + earlyDone.push({liveRange.end, liveRange.weight}); + currentBuffer += liveRange.weight; + maxBuffer = std::max(maxBuffer, currentBuffer); + } + return maxBuffer; +} + +void BufferAnalysis::printAliasInfo() { + for (const auto &[idx, val] : indexer.idxToVal) { + auto aliasParent = indexer.getVal(aliasSet.minIndex[aliasSet.find(idx)]); + if (aliasParent != val) { + llvm::outs() << "value: " << val << " alias parent is: " << aliasParent << "\n"; + } + } +} + +void BufferAnalysis::processOperationForIndexing(Operation *op) { + if (isControlFlowOp(op)) { + indexer.insert(op); + for (Region ®ion : op->getRegions()) { + for (Block &nestedBlock : region) { + for (Operation &nestedOp : nestedBlock) { + processOperationForIndexing(&nestedOp); + } + } + } + return; + } + + llvm::outs() << "Processing op: " << *op << "\n"; + + for (auto res : op->getResults()) { + indexer.insert(res); + + if (!isUsingBuffer(res)) { + continue; + } + + // Handle aliasing operations (subview, reshape, etc.). + if (isMemRefAliasingOp(op)) { + auto src = getAliasSource(op); + if (indexer.valToIdx.count(src)) { + auto aliasSrcPar = indexer.getIndex(src); + aliasSet.join(aliasSrcPar, indexer.getIndex(res)); + } + } + } + + llvm::outs() << "Inserting op " << *op << "\n"; + indexer.insert(op); +} + +void BufferAnalysis::gatherIndexingAndAlias() { + llvm::outs() << "Gathering alias information...\n"; + + // Process operations. + for (auto &op : block) { + processOperationForIndexing(&op); + } + printAliasInfo(); +} + +int64_t BufferAnalysis::countMaxBuffer() { + const mlir::LivenessBlockInfo *blockInfo = liveness.getLiveness(&block); + gatherIndexingAndAlias(); + gatherDataTypeWeights(); + gatherLiveRanges(blockInfo); + llvm::sort(liveRanges); + if (options.printLiveRange) { + printLiveRanges(); + } + return lineSweepRanges(); +} + +//===----------------------------------------------------------------------===// +// Public API Implementation +//===----------------------------------------------------------------------===// + +int64_t countMaxBuffer(mlir::func::FuncOp func, const BufferAnalysisOptions &options) { + if (func.getBody().getBlocks().size() != 1) { + return -1; + } + + BufferAnalysis analysis(*func.getBody().begin(), options, func); + return analysis.countMaxBuffer(); +} + +} // namespace akg +} // namespace mlir diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp index 073e95790e9dc4d88ae80474af0b75cf77edb8a3..b66e9e771430f01209593f2cd18a2000c019594c 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp @@ -21,6 +21,7 @@ #include #include #include "akg/Dialect/Affine/Analysis/AutoTiling.h" +#include "akg/Dialect/Affine/Analysis/BufferAnalysis.h" #include "akg/Utils/AKGGlobalVars.hpp" #include "akg/Utils/AnalysisCommon.hpp" #include "akg/Utils/AnalysisForGpu.hpp" @@ -148,23 +149,27 @@ std::unique_ptr> mlir::createAKGLoopTili } /// Creates a pass to perform loop tiling using auto-tiling strategy std::unique_ptr> mlir::createAKGLoopTilingPass(const std::string &target, - bool useAutoTiling) { + bool useAutoTiling) { return std::make_unique(target, useAutoTiling); } /// Creates a pass to perform loop tiling using auto-tiling strategy for dynamic shape -std::unique_ptr> mlir::createAKGLoopTilingPass( - const std::string &target, bool useAutoTiling, const std::string &tilingMode) { +std::unique_ptr> mlir::createAKGLoopTilingPass(const std::string &target, + bool useAutoTiling, + const std::string &tilingMode) { return std::make_unique(target, useAutoTiling, tilingMode); } -std::unique_ptr> mlir::createAKGLoopTilingPass( - const std::string &target, const std::string &feature, bool useAutoTiling) { +std::unique_ptr> mlir::createAKGLoopTilingPass(const std::string &target, + const std::string &feature, + bool useAutoTiling) { return std::make_unique(target, feature, useAutoTiling); } -std::unique_ptr> mlir::createAKGLoopTilingPass( - const std::string &target, bool useAutoTiling, const std::string &arch, const std::string &feature) { +std::unique_ptr> mlir::createAKGLoopTilingPass(const std::string &target, + bool useAutoTiling, + const std::string &arch, + const std::string &feature) { return std::make_unique(target, useAutoTiling, arch, feature); } @@ -401,8 +406,8 @@ void AKGLoopTiling::setNewUpperBound(mlir::MutableArrayRef intraTileLoops, - SmallVectorImpl &fullTileLoops) { +mlir::LogicalResult AKGLoopTiling::createFullBlock(mlir::MutableArrayRef intraTileLoops, + SmallVectorImpl &fullTileLoops) { if (intraTileLoops.size() == 0) { return mlir::success(); } @@ -506,8 +510,7 @@ mlir::LogicalResult AKGLoopTiling::createFullBlock( } // Add the body for the full tile loop nest. for (const auto &loopEn : llvm::enumerate(intraTileLoops)) { - mlir::replaceAllUsesInRegionWith(loopEn.value().getInductionVar(), - fullTileLoops[loopEn.index()].getInductionVar(), + mlir::replaceAllUsesInRegionWith(loopEn.value().getInductionVar(), fullTileLoops[loopEn.index()].getInductionVar(), fullTileLoops[loopEn.index()].getRegion()); } @@ -599,8 +602,7 @@ mlir::LogicalResult AKGLoopTiling::createTailBlock(mlir::affine::AffineForOp for return mlir::success(); } -mlir::LogicalResult AKGLoopTiling::createTailBlockStatic(mlir::affine::AffineForOp forOp, - int64_t differenceUbAndLb) { +mlir::LogicalResult AKGLoopTiling::createTailBlockStatic(mlir::affine::AffineForOp forOp, int64_t differenceUbAndLb) { auto origUbMap = forOp.getUpperBoundMap(); auto origLbMap = forOp.getLowerBoundMap(); int64_t origStep = forOp.getStepAsInt(); @@ -663,8 +665,7 @@ mlir::LogicalResult AKGLoopTiling::createTailBlockStatic(mlir::affine::AffineFor tailForOp.setLowerBoundMap(ubMap); tailForOp.setUpperBoundMap(origUbMap); tailForOp.setStep(tailSize); - mlir::replaceAllUsesInRegionWith(forOp.getInductionVar(), tailForOp.getInductionVar(), - tailForOp.getRegion()); + mlir::replaceAllUsesInRegionWith(forOp.getInductionVar(), tailForOp.getInductionVar(), tailForOp.getRegion()); updateForOpUsers(tailForOp, tailSize); // Recursively processes the tailForOp body. @@ -1039,9 +1040,7 @@ void AKGLoopTiling::runCpuOperation() { } } -bool AKGLoopTiling::isDynamicShape() const { - return akgglobal::ShapeAlignTool::getInstance().getFuncArgSizes() > 0; -} +bool AKGLoopTiling::isDynamicShape() const { return akgglobal::ShapeAlignTool::getInstance().getFuncArgSizes() > 0; } void AKGLoopTiling::runNpuOperation() { if (band.empty()) { @@ -1144,6 +1143,11 @@ void AKGLoopTiling::runOnOperation() { } if (useAutoTiling) { + mlir::akg::BufferAnalysisOptions options; + options.enableDmaOpt = false; + auto maxBuffer = countMaxBuffer(funcOp, options); + llvm::outs() << "maxBuffer: " << maxBuffer << "\n"; + auto initGraph = mlir::akg::autotiling::parseIr(funcOp, bands); initGraph->setHardware(target); initGraph->setFeature(feature); @@ -1154,8 +1158,7 @@ void AKGLoopTiling::runOnOperation() { mlir::OpBuilder builder(funcOp); SmallVector tileSizeAttrs; tileSizeAttrs.reserve(this->multiTileSizes.size()); - std::transform(this->multiTileSizes.begin(), this->multiTileSizes.end(), - std::back_inserter(tileSizeAttrs), + std::transform(this->multiTileSizes.begin(), this->multiTileSizes.end(), std::back_inserter(tileSizeAttrs), [&builder](unsigned size) { return builder.getI32IntegerAttr(size); }); funcOp->setAttr("npu.multiTileSizes", builder.getArrayAttr(tileSizeAttrs)); }