From 0457de9a9458ee387795c144deae4b9f6106f32a Mon Sep 17 00:00:00 2001 From: hujiahui8 Date: Thu, 16 Oct 2025 14:22:06 +0800 Subject: [PATCH] MLIR: fix bug && add ut case --- .../Affine/Analysis/AKGLoopFusionBuilder.h | 6 ++- .../Affine/Analysis/DependenceAnalysis.h | 9 ++-- .../Affine/Analysis/AKGLoopFusionAnalyzer.cpp | 44 +++------------- .../Affine/Analysis/AKGLoopFusionBuilder.cpp | 50 +++++++++++++++++-- .../Affine/Analysis/DependenceAnalysis.cpp | 32 ++++++++---- .../Linalg/Transforms/MemrefCopyToLoops.cpp | 31 +++--------- .../Pipelines/AscendPipelines/AscendOpt.cpp | 5 +- .../ut/Dialect/Affine/akg_loop_fusion.mlir | 49 ++++++++++++++++++ 8 files changed, 145 insertions(+), 81 deletions(-) create mode 100644 akg-mlir/tests/ut/Dialect/Affine/akg_loop_fusion.mlir diff --git a/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/AKGLoopFusionBuilder.h b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/AKGLoopFusionBuilder.h index 4689356d..e89badac 100644 --- a/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/AKGLoopFusionBuilder.h +++ b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/AKGLoopFusionBuilder.h @@ -57,6 +57,7 @@ public: void setIsGlobalOut(bool isOut) { isGlobalOut = isOut; } void dump() const { print(llvm::errs()); } void print(llvm::raw_ostream &os) const; + std::string getGroupTemplateString() const; private: std::unordered_map loopTransformToStr{ @@ -84,13 +85,14 @@ struct LoopNestStateCollector { // Inherits from MemRefDependenceGraph and adds fusion-specific functionality struct MemRefDependenceGraphForFusion : public MemRefDependenceGraph { public: - explicit MemRefDependenceGraphForFusion(Block *block) : MemRefDependenceGraph(block) {} + explicit MemRefDependenceGraphForFusion(Block *block) : MemRefDependenceGraph(block, false) {} GroupPtr getGroup(unsigned groupId); GroupPtr getGroupByNode(unsigned nodeId); std::unordered_set getGroupsByNode(llvm::DenseSet nodeIds); bool init(); - void print(llvm::raw_ostream &os) const; + void print(llvm::raw_ostream &os) const override; + void dump() const override { print(llvm::errs()); } void createInitNode(llvm::DenseMap> &memrefAccesses); OperatorTemplate getGroupType(const std::vector &nodes); bool elementwiseMatch(Operation *op); diff --git a/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h index e9e4ca6c..fb6fb031 100644 --- a/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h +++ b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h @@ -81,8 +81,11 @@ struct MemRefDependenceGraph { unsigned nextNodeId = 0; // The block for which this graph is created to perform fusion. Block *block; + // Whether to use AKG-specific analysis. + bool useAKGAnalysis; - explicit MemRefDependenceGraph(Block *block) : block(block) {} + explicit MemRefDependenceGraph(Block *block, bool useAKGAnalysis = true) + : block(block), useAKGAnalysis(useAKGAnalysis) {} void createInitNode(DenseMap> &memrefAccesses); bool createEdges(const DenseMap> &memrefAccesses); @@ -103,8 +106,8 @@ struct MemRefDependenceGraph { void getSuccessorNodes(unsigned id, DenseSet &dependentNodes); void getSuccessorNodes(unsigned id, std::vector &dependentNodes); - void print(raw_ostream &os) const; - void dump() const { print(llvm::errs()); } + virtual void print(raw_ostream &os) const; + virtual void dump() const { print(llvm::errs()); } }; } // namespace akg diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionAnalyzer.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionAnalyzer.cpp index 27f6709d..c887583c 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionAnalyzer.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionAnalyzer.cpp @@ -127,8 +127,10 @@ void FusionAnalyzer::checkAndFixMultiOut(FusionPlan &fusePlan) { auto orderGroups = [&]() -> std::pair { auto oldGroup = depGraph.getGroup(oldPlan.fusedGroup.to); auto newGroup = depGraph.getGroup(fusePlan.fusedGroup.to); - llvm::outs() << "oldGroup " << oldPlan.fusedGroup.to << " template " << oldGroup->groupTemplate - << ", new group " << fusePlan.fusedGroup.to << " template " << newGroup->groupTemplate << "\n"; + std::string oldTemplateString = oldGroup->getGroupTemplateString(); + std::string newTemplateString = newGroup->getGroupTemplateString(); + llvm::outs() << "oldGroup " << oldPlan.fusedGroup.to << " groupTemplate " << oldTemplateString + << ", new group " << fusePlan.fusedGroup.to << " groupTemplate " << newTemplateString << "\n"; if (oldGroup->groupTemplate == OperatorTemplate::Reduce) { return std::make_pair(newGroup, oldGroup); } @@ -286,41 +288,6 @@ void FusionAnalyzer::plan() { fusionPlans = sortedPlan; } -void Group::print(raw_ostream &os) const { - std::string indent = " "; - os << "[Group " << groupId << "]\n"; - auto it3 = operatorTemplateMap.find(static_cast(groupTemplate)); - if (it3 != operatorTemplateMap.end()) { - os << indent << ">> GroupTemplate: " << it3->second << "\n"; - } else { - os << indent << ">> GroupTemplate: " << static_cast(groupTemplate) << "\n"; - } - os << indent << ">> FusedGroups: ["; - for (auto gid : fusedGroupId) { - os << gid << ", "; - } - os << "]\n"; - os << indent << ">> Nodes: ["; - for (auto nid : nodesId) { - os << nid << ", "; - } - os << "]\n"; - os << indent << ">> LoopTransforms: [\n"; - for (auto it : nodeTransformRecords) { - os << indent << indent << ">> Node " << it.first << ": ["; - for (LoopTransform lt : it.second) { - int ltI = static_cast(lt); - auto it2 = loopTransformToStr.find(ltI); - if (it2 != loopTransformToStr.end()) { - os << it2->second << " -> "; - } else { - os << ltI << " -> "; - } - } - os << "]\n"; - } - os << indent << indent << "]\n"; -} void FusionAnalyzer::initGroups() { groups = depGraph.groups; } @@ -332,7 +299,8 @@ void FusionAnalyzer::topoSort() { } std::sort(allGroups.begin(), allGroups.end(), GroupCmp); for (auto g : allGroups) { - llvm::outs() << "group " << g->groupId << " temp " << g->groupTemplate << "\n"; + std::string groupTemplateString = g->getGroupTemplateString(); + llvm::outs() << "group " << g->groupId << " groupTemplate " << groupTemplateString << "\n"; for (auto node : g->nodesId) { topoSortNodeIds.push_back(node); } diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionBuilder.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionBuilder.cpp index 076afe31..d187d0aa 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionBuilder.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/AKGLoopFusionBuilder.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ -#include "akg/Dialect/Affine/Analysis/AKGLoopFusionAnalyzer.h" +#include "akg/Dialect/Affine/Analysis/AKGLoopFusionBuilder.h" #include "akg/Utils/AKGGlobalVars.hpp" +#include "akg/Utils/AnalysisCommon.hpp" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" @@ -141,7 +142,7 @@ bool MemRefDependenceGraphForFusion::init() { DenseMap> memrefAccesses; createInitNode(memrefAccesses); // Create Node Edges - MemRefDependenceGraph::createEdges(memrefAccesses); + createEdges(memrefAccesses); return true; } @@ -220,7 +221,8 @@ void MemRefDependenceGraphForFusion::print(raw_ostream &os) const { MemRefDependenceGraph::print(os); for (auto it : groups) { auto g = it.second; - os << "Group " << g->groupId << " (Type " << g->groupTemplate << ") IsGlobalOut (" << g->isGlobalOut << ") root is " + std::string groupTemplateString = g->getGroupTemplateString(); + os << "Group " << g->groupId << " (GroupTemplate " << groupTemplateString << ") IsGlobalOut (" << g->isGlobalOut << ") root is " << g->rootId << " has " << g->nodesId.size() << " nodes inside: ["; for (auto nid : g->nodesId) { os << nid << ", "; @@ -487,5 +489,47 @@ void FusionCodeGenHelper::doHFuse(unsigned srcId, unsigned dstId, affine::Affine // srcNode = nullptr; nodeAlias[srcId] = dstId; } + +std::string Group::getGroupTemplateString() const { + auto it = operatorTemplateMap.find(static_cast(groupTemplate)); + if (it != operatorTemplateMap.end()) { + return it->second; + } else { + return std::to_string(static_cast(groupTemplate)); + } +} + +void Group::print(raw_ostream &os) const { + std::string indent = " "; + os << "[Group " << groupId << "]\n"; + std::string groupTemplateString = getGroupTemplateString(); + os << indent << ">> GroupTemplate: " << groupTemplateString << "\n"; + os << indent << ">> FusedGroups: ["; + for (auto gid : fusedGroupId) { + os << gid << ", "; + } + os << "]\n"; + os << indent << ">> Nodes: ["; + for (auto nid : nodesId) { + os << nid << ", "; + } + os << "]\n"; + os << indent << ">> LoopTransforms: [\n"; + for (auto it : nodeTransformRecords) { + os << indent << indent << ">> Node " << it.first << ": ["; + for (LoopTransform lt : it.second) { + int ltI = static_cast(lt); + auto it2 = loopTransformToStr.find(ltI); + if (it2 != loopTransformToStr.end()) { + os << it2->second << " -> "; + } else { + os << ltI << " -> "; + } + } + os << "]\n"; + } + os << indent << indent << "]\n"; +} + } // namespace akg } // namespace mlir diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Analysis/DependenceAnalysis.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/DependenceAnalysis.cpp index 8eea951f..7b121cb6 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Analysis/DependenceAnalysis.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Analysis/DependenceAnalysis.cpp @@ -192,15 +192,29 @@ bool MemRefDependenceGraph::hasMemrefAccessDependence(unsigned srcId, unsigned d Operation *srcOp = getNode(srcId)->op; Operation *dstOp = getNode(dstId)->op; unsigned numCommonLoops = affine::getNumCommonSurroundingLoops(*srcOp, *dstOp); - affine::AKGMemRefAccess srcAccess(srcOp); - affine::AKGMemRefAccess dstAccess(dstOp); - for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { - affine::FlatAffineValueConstraints dependenceConstraints; - // todo: Cache dependence analysis results, check cache here. - affine::DependenceResult result = - mlir::affine::checkMemrefAccessDependenceAKG(srcAccess, dstAccess, d, &dependenceConstraints, nullptr); - if (result.value == affine::DependenceResult::HasDependence) { - return true; + + if (useAKGAnalysis) { + affine::AKGMemRefAccess srcAccess(srcOp); + affine::AKGMemRefAccess dstAccess(dstOp); + for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + affine::FlatAffineValueConstraints dependenceConstraints; + // todo: Cache dependence analysis results, check cache here. + affine::DependenceResult result = + affine::checkMemrefAccessDependenceAKG(srcAccess, dstAccess, d, &dependenceConstraints, nullptr); + if (result.value == affine::DependenceResult::HasDependence) { + return true; + } + } + } else { + affine::MemRefAccess srcAccess(srcOp); + affine::MemRefAccess dstAccess(dstOp); + for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + affine::FlatAffineValueConstraints dependenceConstraints; + affine::DependenceResult result = + affine::checkMemrefAccessDependence(srcAccess, dstAccess, d, &dependenceConstraints, nullptr); + if (result.value == affine::DependenceResult::HasDependence) { + return true; + } } } return false; diff --git a/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MemrefCopyToLoops.cpp b/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MemrefCopyToLoops.cpp index 653f4be4..b3b057bd 100644 --- a/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MemrefCopyToLoops.cpp +++ b/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MemrefCopyToLoops.cpp @@ -43,9 +43,7 @@ struct MemrefCopyToLoops : public impl::MemrefCopyToLoopsBase MemrefCopyToLoops() = default; void runOnOperation() override{ - llvm::outs() << "createMemrefCopyToLoopsPass\n"; func::FuncOp funcOp = getOperation(); - // auto isReduceOp = CommonUtils::getOperatorType(funcOp) == OperatorTemplate::Reduce; SmallVector needConvert; (void)funcOp->walk([&](memref::CopyOp copyOp) { @@ -53,47 +51,30 @@ struct MemrefCopyToLoops : public impl::MemrefCopyToLoopsBase if (!srcOp){ return; } - llvm::outs() << "srcOp\n"; - srcOp->dump(); + if (auto tomem = dyn_cast(srcOp)){ - llvm::outs() << "YES\n"; auto mem = tomem.getMemref(); - llvm::outs() << "mem\n"; - mem.dump(); - llvm::outs() << "getTensor\n"; - tomem.getTensor().dump(); - llvm::outs() << "getTensor memref\n"; auto totensor = tomem.getTensor().getDefiningOp(); - if (totensor){ - totensor->dump(); + if (!totensor){ + return; } if (auto tt = dyn_cast(totensor)) { - llvm::outs() << "IS TENSOR\n"; - tt.dump(); auto tensormem = tt.getMemref().getDefiningOp(); - llvm::outs() << "tensormem\n"; - tensormem->dump(); + if (!tensormem){ + return; + } if (isa(tensormem)) { needConvert.emplace_back(copyOp); } } } - // if (!isa(srcOp) && !copyOp.getTarget().getDefiningOp()) { - // return; - // } - // if (!copyOp.getSource().getDefiningOp() && !copyOp.getTarget().getDefiningOp() && isReduceOp) { - // llvm::outs() << "Copy output, don't change\n"; - // return; - // } }); OpBuilder builder(funcOp); for (auto copyOp : needConvert) { - copyOp.dump(); builder.setInsertionPoint(copyOp); auto newCopyOp = makeMemRefCopyOp(builder, copyOp->getLoc(), copyOp.getSource(), copyOp.getTarget()); - newCopyOp.dump(); copyOp.getOperation()->replaceAllUsesWith(newCopyOp.getOperation()); copyOp.erase(); } diff --git a/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp b/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp index 5ac81598..17a36dd9 100644 --- a/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp +++ b/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp @@ -60,8 +60,11 @@ void createAscendOptPipelineImpl(OpPassManager &pm, const AscendOptPipelineOptio nestedFunctionPM.addPass(tosa::createTosaToLinalg()); if (options.enableAKGLoopFusion) { - pm.addPass(bufferization::createBufferResultsToOutParamsPass()); + bool keepFakeOuts = true; + nestedFunctionPM.addPass(createLinalgCopyBufferizePass(keepFakeOuts)); + pm.addPass(bufferization::createEmptyTensorToAllocTensorPass()); pm.addPass(bufferization::createOneShotBufferizePass()); + pm.addPass(createCanonicalizerPass()); pm.addPass(createMemrefCopyToLoopsPass()); pm.addPass(createMatchAndMarkReductionOpsPass()); diff --git a/akg-mlir/tests/ut/Dialect/Affine/akg_loop_fusion.mlir b/akg-mlir/tests/ut/Dialect/Affine/akg_loop_fusion.mlir new file mode 100644 index 00000000..b42dac9a --- /dev/null +++ b/akg-mlir/tests/ut/Dialect/Affine/akg_loop_fusion.mlir @@ -0,0 +1,49 @@ +// RUN: akg-opt %s -akg-loop-fusion | FileCheck %s + +// CHECK-LABEL: func.func @Fused_Mul_ReduceSum(%arg0: tensor<2x3072xbf16>) -> tensor<2x1xbf16> attributes {OperatorType = "Reduce", compute_capability = "", hacc.function_kind = #hacc.function_kind, mindspore_kernel, process = "aicore"} { +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : bf16 +// CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3072xbf16, strided<[?, ?], offset: ?>> +// CHECK-NEXT: %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2xbf16> +// CHECK-NEXT: affine.for %arg1 = 0 to 2 { +// CHECK-NEXT: affine.store %cst, %alloc_0[%arg1] : memref<2xbf16> +// CHECK-NEXT: affine.for %arg2 = 0 to 3072 { +// CHECK-NEXT: %2 = affine.load %0[%arg1, %arg2] : memref<2x3072xbf16, strided<[?, ?], offset: ?>> +// CHECK-NEXT: %3 = arith.mulf %2, %2 : bf16 +// CHECK-NEXT: %4 = affine.load %alloc_0[%arg1] : memref<2xbf16> +// CHECK-NEXT: %5 = arith.addf %3, %4 : bf16 +// CHECK-NEXT: affine.store %5, %alloc_0[%arg1] : memref<2xbf16> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %expand_shape = memref.expand_shape %alloc_0 [[0, 1]] output_shape [2, 1] : memref<2xbf16> into memref<2x1xbf16> +// CHECK-NEXT: %1 = bufferization.to_tensor %expand_shape : memref<2x1xbf16> +// CHECK-NEXT: return %1 : tensor<2x1xbf16> +// CHECK-NEXT: } + + +func.func @Fused_Mul_ReduceSum_split(%arg0: tensor<2x3072xbf16>) -> tensor<2x1xbf16> attributes {OperatorType = "Reduce", compute_capability = "", hacc.function_kind = #hacc.function_kind, mindspore_kernel, process = "aicore"} { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = bufferization.to_memref %arg0 : memref<2x3072xbf16, strided<[?, ?], offset: ?>> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x3072xbf16> + affine.for %arg1 = 0 to 2 { + affine.for %arg2 = 0 to 3072 { + %2 = affine.load %0[%arg1, %arg2] : memref<2x3072xbf16, strided<[?, ?], offset: ?>> + %3 = arith.mulf %2, %2 : bf16 + affine.store %3, %alloc[%arg1, %arg2] : memref<2x3072xbf16> + } + } + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2xbf16> + affine.for %arg1 = 0 to 2 { + affine.store %cst, %alloc_0[%arg1] : memref<2xbf16> + } + affine.for %arg1 = 0 to 2 { + affine.for %arg2 = 0 to 3072 { + %2 = affine.load %alloc[%arg1, %arg2] : memref<2x3072xbf16> + %3 = affine.load %alloc_0[%arg1] : memref<2xbf16> + %4 = arith.addf %2, %3 : bf16 + affine.store %4, %alloc_0[%arg1] : memref<2xbf16> + } + } + %expand_shape = memref.expand_shape %alloc_0 [[0, 1]] output_shape [2, 1] : memref<2xbf16> into memref<2x1xbf16> + %1 = bufferization.to_tensor %expand_shape : memref<2x1xbf16> + return %1 : tensor<2x1xbf16> +} \ No newline at end of file -- Gitee