From 113e939fec1c5add8c30e473d17b93848357d34e Mon Sep 17 00:00:00 2001 From: liuchao Date: Wed, 3 Dec 2025 11:36:29 +0800 Subject: [PATCH] add affine load&store remove pass --- .../Affine/Analysis/DependenceAnalysis.h | 4 +- .../include/akg/Utils/AnalysisCommon.hpp | 1 + .../Affine/Transforms/AKGLoopTiling.cpp | 4 +- .../Transforms/AffineIteratorConversion.cpp | 38 ++++++++----------- .../Transforms/AffineReductionAnnotation.cpp | 4 +- .../Dialect/GPU/Transforms/AKGGPUMapping.cpp | 14 +++---- .../Transforms/MatchAndMarkReductionOps.cpp | 14 +++---- .../RewriteReduceInMultiLevelMemory.cpp | 14 +++---- .../Pipelines/AscendPipelines/AscendOpt.cpp | 4 +- .../Affine/affine_reduction_op_matcher.mlir | 2 +- .../Affine/remove-redundant-loops.mlir | 4 +- .../match_and_mark_reduction_ops_affine.mlir | 4 +- .../ut/Transforms/affine_load_removal.mlir | 33 ++++++++++++++++ 13 files changed, 77 insertions(+), 63 deletions(-) create mode 100644 akg-mlir/tests/ut/Transforms/affine_load_removal.mlir diff --git a/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h index fb6fb031..2ac240f2 100644 --- a/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h +++ b/akg-mlir/compiler/include/akg/Dialect/Affine/Analysis/DependenceAnalysis.h @@ -64,8 +64,6 @@ struct Edge { // MemRefDependenceGraph is a graph data structure where graph nodes are // top-level operations in a `Block` which contain load/store ops, and edges // are memref dependences between the nodes. -// TODO: Add a more flexible dependence graph representation. -// TODO: Add a depth parameter to dependence graph construction. struct MemRefDependenceGraph { public: // Map from node id to Node. @@ -84,7 +82,7 @@ struct MemRefDependenceGraph { // Whether to use AKG-specific analysis. bool useAKGAnalysis; - explicit MemRefDependenceGraph(Block *block, bool useAKGAnalysis = true) + explicit MemRefDependenceGraph(Block *block, bool useAKGAnalysis = false) : block(block), useAKGAnalysis(useAKGAnalysis) {} void createInitNode(DenseMap> &memrefAccesses); diff --git a/akg-mlir/compiler/include/akg/Utils/AnalysisCommon.hpp b/akg-mlir/compiler/include/akg/Utils/AnalysisCommon.hpp index 19a813cb..45b1a2ff 100644 --- a/akg-mlir/compiler/include/akg/Utils/AnalysisCommon.hpp +++ b/akg-mlir/compiler/include/akg/Utils/AnalysisCommon.hpp @@ -57,6 +57,7 @@ constexpr auto kOperatorTypeStr = "OperatorType"; constexpr auto kReduceStr = "Reduce"; constexpr auto kReductionAxesStr = "reduction_axes"; constexpr auto kReductionTypeStr = "reduction_type"; +constexpr auto kReductionLoopAttr = "reduction_loop"; constexpr auto kVectorize128Bit = 128; constexpr auto kVectorize256Bit = 256; constexpr auto kVectorize512Bit = 512; diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp index 6d29ffaa..073e9579 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AKGLoopTiling.cpp @@ -125,9 +125,9 @@ class AKGLoopTiling : public impl::AKGAffineLoopTilingBase { bool avoidMaxMinBounds{true}; // hardware information std::string target{mlir::kTargetCpu}; - std::string feature{mlir::kNEONInstructionSet}; std::string tilingMode{"auto"}; [[maybe_unused]] std::string arch{}; + std::string feature{mlir::kNEONInstructionSet}; mlir::akg::autotiling::TilingSolverPtr solver{nullptr}; size_t levelToTile{1}; @@ -1010,7 +1010,7 @@ void AKGLoopTiling::runCpuOperation() { if (opType == mlir::OperatorTemplate::Reduce) { SmallVector reduceLoops = mlir::CommonUtils::collectReductionAxes(funcOp); for (auto reduceLoop : reduceLoops) { - reduceLoop->setAttr("reduceLoop", b.getUnitAttr()); + reduceLoop->setAttr(kReductionLoopAttr, b.getUnitAttr()); } } else if (opType == mlir::OperatorTemplate::Broadcast) { llvm::SmallSet allBroadcastFor; diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineIteratorConversion.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineIteratorConversion.cpp index c3a1346f..7501ca21 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineIteratorConversion.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineIteratorConversion.cpp @@ -1,5 +1,5 @@ /** - * Copyright 2023 Huawei Technologies Co., Ltd + * Copyright 2023-2025 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,9 @@ */ #include "akg/Dialect/Affine/Transforms/AffineIteratorConversion.h" -#include "akg/Utils/AnalysisCommon.hpp" +#include +#include "akg/Utils/AnalysisCommon.hpp" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" @@ -33,6 +34,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" + namespace mlir { #define GEN_PASS_DECL_AFFINEITERATORCONVERSION #define GEN_PASS_DEF_AFFINEITERATORCONVERSION @@ -129,15 +131,7 @@ static affine::AffineLoadOp getLoadOp(affine::AffineForOp reduceForOp, Operation if (op != lhs.getDefiningOp() && op != rhs.getDefiningOp()) { return; } - auto indices = op.getIndices(); - bool flag = false; - for (auto index : indices) { - if (index == iv) { - flag = true; - break; - } - } - if (!flag) { + if (llvm::find(op.getIndices(), iv) == op.getIndices().end()) { loadOp = op; } }); @@ -147,7 +141,7 @@ static affine::AffineLoadOp getLoadOp(affine::AffineForOp reduceForOp, Operation static Operation *getInnermostReduceOp(Operation *curOp) { Operation *innermostReduceOp = nullptr; curOp->walk([&](affine::AffineForOp op) -> WalkResult { - if (op->getAttr("reduceLoop")) { + if (op->getAttr(kReductionLoopAttr)) { innermostReduceOp = op.getOperation(); return WalkResult::interrupt(); } @@ -176,7 +170,7 @@ void AffineIteratorConversion::loadRemoveEachBand(Operation *curOp) { return; } affine::AffineStoreOp initStoreOp = nullptr; - while (isa(reduceLoopOp) && reduceLoopOp->getAttr("reduceLoop")) { + while (isa(reduceLoopOp) && reduceLoopOp->getAttr(kReductionLoopAttr)) { affine::AffineForOp reduceLoop = cast(reduceLoopOp); // init load statement affine::AffineLoadOp loadOp = getLoadOp(reduceLoop, reduceArithOp); @@ -212,16 +206,15 @@ void AffineIteratorConversion::loadRemoveEachBand(Operation *curOp) { Operation *user = use.getOwner(); return newLoop->isProperAncestor(user); }); - b.setInsertionPoint(reduceLoop); + b.setInsertionPointAfter(newLoop.getOperation()); auto parentOp = newLoop.getOperation()->getParentOp(); - if (isa(parentOp) && parentOp->getAttr("reduceLoop")) { - CreateArithOp rewriter(b, newLoop, loadOp, storeOp, reduceArithOp); - reduceArithOp = identifyAndCreateArithOp(rewriter); + if (isa(parentOp) && parentOp->getAttr(kReductionLoopAttr)) { + CreateArithOp arithOpCreater(b, newLoop, loadOp, storeOp, reduceArithOp); + reduceArithOp = identifyAndCreateArithOp(arithOpCreater); } else { b.create(storeOp.getLoc(), newLoop.getResults().back(), storeOp.getMemRef(), storeOp.getAffineMapAttr().getValue(), storeOp.getIndices()); } - reduceLoop.erase(); loadOp.erase(); storeOp.erase(); reduceLoopOp = newLoop.getOperation()->getParentOp(); @@ -245,16 +238,15 @@ void AffineIteratorConversion::runOnOperation() { } removeInitMemoryCopy(func); - // todo(yanzhi): bugfix this function + SmallVector reduceLoops = CommonUtils::collectReductionAxes(func); for (auto reduceLoop : reduceLoops) { - reduceLoop->setAttr("reduceLoop", b.getUnitAttr()); + reduceLoop->setAttr(kReductionLoopAttr, b.getUnitAttr()); } SmallVector bands; - for (auto band : func.getOps()) { - bands.push_back(band); - } + (void)std::copy(func.getOps().begin(), func.getOps().end(), + std::back_inserter(bands)); for (auto band : bands) { loadRemoveEachBand(band); } diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineReductionAnnotation.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineReductionAnnotation.cpp index 8eb2742d..8f935500 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineReductionAnnotation.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/AffineReductionAnnotation.cpp @@ -168,7 +168,7 @@ void AffineReductionAnnotation::annotateReductionOps(Operation *funcOp) { // collect reduction axes and update attributes SmallVector reduceLoops = CommonUtils::collectReductionAxes(funcOp); for (auto reduceLoop : reduceLoops) { - reduceLoop->setAttr("reduceLoop", builder.getUnitAttr()); + reduceLoop->setAttr(kReductionLoopAttr, builder.getUnitAttr()); } (void)funcOp->walk([&](Operation *redOp) { @@ -188,7 +188,7 @@ void AffineReductionAnnotation::annotateReductionOps(Operation *funcOp) { } } } - if (curOp->hasAttr("reduceLoop")) { + if (curOp->hasAttr(kReductionLoopAttr)) { redFlags.push_back(true); } else { redFlags.push_back(false); diff --git a/akg-mlir/compiler/lib/Dialect/GPU/Transforms/AKGGPUMapping.cpp b/akg-mlir/compiler/lib/Dialect/GPU/Transforms/AKGGPUMapping.cpp index 71b493fd..76c951d8 100644 --- a/akg-mlir/compiler/lib/Dialect/GPU/Transforms/AKGGPUMapping.cpp +++ b/akg-mlir/compiler/lib/Dialect/GPU/Transforms/AKGGPUMapping.cpp @@ -18,9 +18,9 @@ #include #include -#include #include #include +#include #include "akg/Dialect/MindSpore/IR/MindSporeOps.h" #include "akg/Utils/AKGGlobalVars.hpp" #include "akg/Utils/AnalysisCommon.hpp" @@ -274,7 +274,7 @@ bool isPostFusionSingleStmt(Operation *op) { } bool isPostFusionMultiStmt(Operation *op) { - if (auto andi = dyn_cast(op)) { + if (dyn_cast(op)) { for (auto operand : op->getOperands()) { if (isPostFusionMultiStmt(operand.getDefiningOp())) { return true; @@ -324,10 +324,8 @@ static bool canMoveOpOutOfTarget(Operation *op, Operation *targetOp) { for (auto operand : op->getOperands()) { SmallVector axesA; CommonUtils::collectRelatedAxes(operand, axesA); - for (auto a : axesA) { - if (targetOp == a) { - return false; - } + if (llvm::any_of(axesA, [targetOp](Operation *op) { return op == targetOp; })) { + return false; } } @@ -687,7 +685,7 @@ static void SetRedutionMarkToParallelOp(Operation *funcOp) { std::reverse(parallelOps.begin(), parallelOps.end()); for (auto attr : attrs) { auto idx = dyn_cast(attr).getInt(); - parallelOps[idx]->setAttr("reduceLoop", builder.getUnitAttr()); + parallelOps[idx]->setAttr(kReductionLoopAttr, builder.getUnitAttr()); } if (!redOp->hasAttr(kEnableParallelReduce)) { (void)redOp->emitWarning("This reduction op does not have a \"gpu_parallel_reduce\" mark, set to false."); @@ -790,7 +788,7 @@ void AKGGPUMappingLoops::createMappingTask(ParallelOp parallelOp) { for (auto [loopVar, lowerBoundVar, upperBoundVar, stepVar] : llvm::zip( parallelOp.getInductionVars(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep())) { size_t dim = getNestedNum(parallelOp.getOperation()); - bool isReduceAxis = (parallelOp.getOperation()->hasAttr("reduceLoop")) ? true : false; + bool isReduceAxis = (parallelOp.getOperation()->hasAttr(kReductionLoopAttr)) ? true : false; int reductionDim = isReduceAxis ? static_cast(dim) : -1; auto lbConst = getMaxIntConst(lowerBoundVar); auto ubConst = getMaxIntConst(upperBoundVar); diff --git a/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MatchAndMarkReductionOps.cpp b/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MatchAndMarkReductionOps.cpp index 950b64fc..9a7a6377 100644 --- a/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MatchAndMarkReductionOps.cpp +++ b/akg-mlir/compiler/lib/Dialect/Linalg/Transforms/MatchAndMarkReductionOps.cpp @@ -42,9 +42,6 @@ namespace mlir { #include "akg/Dialect/Linalg/Passes.h.inc" } // namespace mlir -using namespace mlir; -using namespace akgglobal; - namespace mlir { namespace linalg { namespace { @@ -84,7 +81,7 @@ static void MatchAndMarkRedOpInLinalg(Operation *funcOp) { } auto strAttr = builder.getStringAttr(reduceDirectionMap.at(reduceDirection)); op->setAttr(kReductionTypeStr, strAttr); - GpuScheduleTool::getInstance().setReduceDirection((size_t)reduceDirection); + akgglobal::GpuScheduleTool::getInstance().setReduceDirection((size_t)reduceDirection); } }); } @@ -94,7 +91,7 @@ static void MatchAndMarkRedOpInAffine(Operation *funcOp) { OpBuilder builder(funcOp); SmallVector reduceLoops = CommonUtils::collectReductionAxes(funcOp); for (auto reduceLoop : reduceLoops) { - reduceLoop->setAttr("reduceLoop", builder.getUnitAttr()); + reduceLoop->setAttr(kReductionLoopAttr, builder.getUnitAttr()); } (void)funcOp->walk([&](Operation *redOp) { if (!isa(redOp) && redOp->hasAttr(kReductionAxesStr)) { @@ -102,7 +99,7 @@ static void MatchAndMarkRedOpInAffine(Operation *funcOp) { auto curOp = redOp; while (curOp) { if (isa(curOp)) { - if (curOp->hasAttr("reduceLoop")) { + if (curOp->hasAttr(kReductionLoopAttr)) { redFlags.push_back(true); } else { redFlags.push_back(false); @@ -128,12 +125,11 @@ static void MatchAndMarkRedOpInAffine(Operation *funcOp) { struct MatchAndMarkReductionOps : public impl::MatchAndMarkReductionOpsBase { MatchAndMarkReductionOps() = default; - explicit MatchAndMarkReductionOps(const std::string dialect) { this->dialect = dialect; } + explicit MatchAndMarkReductionOps(const std::string &dialect) { this->dialect = dialect; } void runOnOperation() override { Operation *funcOp = getOperation(); - if (!(funcOp->hasAttr(kOperatorTypeStr) && - dyn_cast(funcOp->getAttr(kOperatorTypeStr)) == kReduceStr)) { + if (!(funcOp->hasAttr(kOperatorTypeStr) && dyn_cast(funcOp->getAttr(kOperatorTypeStr)) == kReduceStr)) { return; } if (this->dialect == "linalg") { diff --git a/akg-mlir/compiler/lib/Dialect/SCF/Transforms/RewriteReduceInMultiLevelMemory.cpp b/akg-mlir/compiler/lib/Dialect/SCF/Transforms/RewriteReduceInMultiLevelMemory.cpp index f2499226..f9e253c8 100644 --- a/akg-mlir/compiler/lib/Dialect/SCF/Transforms/RewriteReduceInMultiLevelMemory.cpp +++ b/akg-mlir/compiler/lib/Dialect/SCF/Transforms/RewriteReduceInMultiLevelMemory.cpp @@ -35,21 +35,17 @@ #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/DialectConversion.h" -using namespace akgglobal; - namespace mlir { #define GEN_PASS_DECL_REWRITEREDUCEINMULTILEVELMEMORY #define GEN_PASS_DEF_REWRITEREDUCEINMULTILEVELMEMORY #include "akg/Dialect/SCF/Passes.h.inc" } // namespace mlir -using namespace mlir; - namespace mlir { namespace scf { namespace { -void cloneAndReplaceOps(SmallVector &ops, OpBuilder &builder) { +void cloneAndReplaceOps(const SmallVector &ops, OpBuilder &builder) { for (Operation *op : ops) { if (op) { Operation *clonedOp = builder.clone(*op); @@ -86,8 +82,8 @@ Value createInitialValue(Operation *op, mlir::Location loc, OpBuilder &builder) initialValue = builder.create(loc, std::numeric_limits::max(), cast(elementType)); } else if (isa(op)) { - initialValue = builder.create(loc, std::numeric_limits::lowest(), - cast(elementType)); + initialValue = + builder.create(loc, std::numeric_limits::lowest(), cast(elementType)); } else if (isa(op)) { initialValue = builder.create(loc, std::numeric_limits::max(), cast(elementType)); @@ -136,7 +132,7 @@ static Operation *getOutermostSeqLoop(Operation *redOp) { Operation *outerSeqReduceLoop = nullptr; auto curOp = redOp; while (curOp) { - if (isa(curOp) && curOp->getAttr("reduceLoop")) { + if (isa(curOp) && curOp->getAttr(kReductionLoopAttr)) { if (gpu::GpuAttrUtils::getProcessorFromParallelOp(curOp) == gpu::Processor::Sequential) { outerSeqReduceLoop = curOp; } else { @@ -172,7 +168,7 @@ struct RewriteReduceInMultiLevelMemory void runOnOperation() override { auto funcOp = getOperation(); SmallVector redOps; - bool isReduceY = GpuScheduleTool::getInstance().getReduceDirection() == (unsigned long)ReduceDirection::Y; + bool isReduceY = akgglobal::GpuScheduleTool::getInstance().getReduceDirection() == (size_t)ReduceDirection::Y; funcOp.walk([&](Operation *op) { if (!isa(op)) { bool parallelReduce = (op->hasAttr(akg::utils::kEnableParallelReduce) && diff --git a/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp b/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp index 464364cf..a2b20a25 100644 --- a/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp +++ b/akg-mlir/compiler/lib/Pipelines/AscendPipelines/AscendOpt.cpp @@ -1,5 +1,5 @@ /** - * Copyright 2024 Huawei Technologies Co., Ltd + * Copyright 2024-2025 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,7 +99,7 @@ void createAscendOptPipelineImpl(OpPassManager &pm, const mlir::AscendOptPipelin nestedFusionPM.addPass(mlir::createCanonicalizerPass()); // vector - // nestedFusionPM.addPass(mlir::createAffineIteratorConversionPass()); + nestedFusionPM.addPass(mlir::createAffineIteratorConversionPass()); // nestedFusionPM.addPass(mlir::createExtractIfOpPass(options.target)); nestedFusionPM.addPass(mlir::affine::createAffineForVectPass()); diff --git a/akg-mlir/tests/ut/Dialect/Affine/affine_reduction_op_matcher.mlir b/akg-mlir/tests/ut/Dialect/Affine/affine_reduction_op_matcher.mlir index 4b32f244..7af1d085 100644 --- a/akg-mlir/tests/ut/Dialect/Affine/affine_reduction_op_matcher.mlir +++ b/akg-mlir/tests/ut/Dialect/Affine/affine_reduction_op_matcher.mlir @@ -27,7 +27,7 @@ // CHECK-NEXT: %2 = arith.addf %0, %1 {reduction_axes = [1 : index], reduction_type = "y"} : f32 // CHECK-NEXT: affine.store %2, %alloc_0[%arg1, %arg3] : memref<1x3072xf32> // CHECK-NEXT: } -// CHECK-NEXT: } {reduceLoop} +// CHECK-NEXT: } {reduction_loop} // CHECK-NEXT: } // CHECK-NEXT: %expand_shape = memref.expand_shape %alloc_0 {{\[\[0\], \[1, 2\]\]}} output_shape [1, 1, 3072] : memref<1x3072xf32> into memref<1x1x3072xf32> // CHECK-NEXT: return %expand_shape : memref<1x1x3072xf32> diff --git a/akg-mlir/tests/ut/Dialect/Affine/remove-redundant-loops.mlir b/akg-mlir/tests/ut/Dialect/Affine/remove-redundant-loops.mlir index aa3b2eb1..50e9a6e7 100644 --- a/akg-mlir/tests/ut/Dialect/Affine/remove-redundant-loops.mlir +++ b/akg-mlir/tests/ut/Dialect/Affine/remove-redundant-loops.mlir @@ -16,7 +16,7 @@ // CHECK-NEXT: %1 = vector.transfer_read %arg1[%arg2, %arg3], %cst_2 : memref<4954x3xf32>, vector<4xf32> // CHECK-NEXT: %2 = arith.addf %0, %1 {reduction_axes = [2 : index], reduction_type = "y"} : vector<4xf32> // CHECK-NEXT: vector.transfer_write %2, %arg1[%arg2, %arg3] : vector<4xf32>, memref<4954x3xf32> -// CHECK-NEXT: } {reduceLoop} +// CHECK-NEXT: } {reduction_loop} // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -46,7 +46,7 @@ func.func @Fused_ReduceSum_split_2679919397770605199(%arg0: memref<4954x4953x3xf } } } - } {reduceLoop} + } {reduction_loop} } } return diff --git a/akg-mlir/tests/ut/Dialect/Linalg/match_and_mark_reduction_ops_affine.mlir b/akg-mlir/tests/ut/Dialect/Linalg/match_and_mark_reduction_ops_affine.mlir index a02663fd..a0f6d5a7 100644 --- a/akg-mlir/tests/ut/Dialect/Linalg/match_and_mark_reduction_ops_affine.mlir +++ b/akg-mlir/tests/ut/Dialect/Linalg/match_and_mark_reduction_ops_affine.mlir @@ -30,9 +30,9 @@ // CHECK-NEXT: %8 = math.log %7 : f32 // CHECK-NEXT: affine.store %8, %arg3[%arg6] : memref<233008xf32> // CHECK-NEXT: } -// CHECK-NEXT: } {reduceLoop} +// CHECK-NEXT: } {reduction_loop} // CHECK-NEXT: } -// CHECK-NEXT: } {reduceLoop} +// CHECK-NEXT: } {reduction_loop} // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } diff --git a/akg-mlir/tests/ut/Transforms/affine_load_removal.mlir b/akg-mlir/tests/ut/Transforms/affine_load_removal.mlir new file mode 100644 index 00000000..f90fb3a7 --- /dev/null +++ b/akg-mlir/tests/ut/Transforms/affine_load_removal.mlir @@ -0,0 +1,33 @@ +// RUN: akg-opt %s --affine-load-removal | FileCheck %s + +// CHECK-LABEL: func.func @reduction(%arg0: memref<256x512x1024xf32>, %arg1: memref<256x512xf32>) attributes {OperatorType = "Reduce"} { +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: affine.for %arg2 = 0 to 256 { +// CHECK-NEXT: affine.for %arg3 = 0 to 512 { +// CHECK-NEXT: %0 = affine.for %arg4 = 0 to 1024 iter_args(%arg5 = %cst) -> (f32) { +// CHECK-NEXT: %1 = affine.load %arg0[%arg2, %arg3, %arg4] : memref<256x512x1024xf32> +// CHECK-NEXT: %2 = arith.addf %arg5, %1 {reduction_axes = [2 : index], reduction_type = "x"} : f32 +// CHECK-NEXT: affine.yield %2 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.store %0, %arg1[%arg2, %arg3] : memref<256x512xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + + +func.func @reduction(%in: memref<256x512x1024xf32>, %out: memref<256x512xf32>) attributes {OperatorType = "Reduce"} { + %cst = arith.constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + affine.for %j = 0 to 512 { + affine.store %cst, %out[%i, %j] : memref<256x512xf32> + affine.for %k = 0 to 1024 { + %ld = affine.load %in[%i, %j, %k] : memref<256x512x1024xf32> + %sum = affine.load %out[%i, %j] : memref<256x512xf32> + %add = arith.addf %sum, %ld {reduction_axes = [2:index], reduction_type = "x"}: f32 + affine.store %add, %out[%i, %j] : memref<256x512xf32> + } {reduction_loop} + } + } + return +} -- Gitee