From 6b65b776d4409614afe68dab3e0bbf6e3abd3059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8A=B1=E6=97=A0=E6=87=BF?= Date: Sat, 29 Nov 2025 17:24:57 +0800 Subject: [PATCH] Application scenarios of generalized vectorization --- .../Transforms/VectorTransferTensorize.cpp | 768 +++++++++++++++--- .../ut/Dialect/Affine/vector_tensor.mlir | 70 +- 2 files changed, 675 insertions(+), 163 deletions(-) diff --git a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/VectorTransferTensorize.cpp b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/VectorTransferTensorize.cpp index 5ee6f9b3..c1b824d5 100644 --- a/akg-mlir/compiler/lib/Dialect/Affine/Transforms/VectorTransferTensorize.cpp +++ b/akg-mlir/compiler/lib/Dialect/Affine/Transforms/VectorTransferTensorize.cpp @@ -27,6 +27,7 @@ #include #include +#include #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" @@ -44,8 +45,10 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "vector-transfer-tensorize" @@ -59,11 +62,8 @@ namespace memref = mlir::memref; namespace tensor = mlir::tensor; namespace func = mlir::func; namespace affine = mlir::affine; -namespace func = mlir::func; -//------------------------------------------------------------------------------ // Utilities -//------------------------------------------------------------------------------ static RankedTensorType convertMemrefToTensorType(MemRefType memrefType) { if (!memrefType || !memrefType.hasStaticShape()) return {}; @@ -71,55 +71,74 @@ static RankedTensorType convertMemrefToTensorType(MemRefType memrefType) { } static bool isNumericTypeLike(Type type) { - if (auto rankedTensorType = type.dyn_cast()) + if (auto rankedTensorType = mlir::dyn_cast(type)) return rankedTensorType.getElementType().isIntOrFloat(); - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = mlir::dyn_cast(type)) return vectorType.getElementType().isIntOrFloat(); return type.isIntOrFloat(); } -//------------------------------------------------------------------------------ // Tensorization state -//------------------------------------------------------------------------------ + +struct PendingWriteInfo { + vector::TransferWriteOp writeOp; + tensor::InsertSliceOp insertSliceOp; + Value logicalBase; +}; struct TensorizationState { explicit TensorizationState(MLIRContext *context) : builder(context) {} OpBuilder builder; - IRMapping valueMapping; // original value → tensor value + IRMapping valueMapping; llvm::DenseSet operationsToErase; - llvm::DenseSet invalidatedOps; + llvm::DenseSet invalidatedOperations; + + llvm::SmallVector pendingWrites; void map(Value original, Value tensorized) { valueMapping.map(original, tensorized); } + void replaceMapping(Value original, Value tensorized) { valueMapping.erase(original); map(original, tensorized); } + Value lookup(Value value) const { return valueMapping.lookupOrNull(value); } + void markForErasure(Operation *operation) { operationsToErase.insert(operation); } - void trackInvalid(Operation *operation) { invalidatedOps.insert(operation); } - bool isInvalid(Operation *operation) const { return invalidatedOps.contains(operation); } + + void trackInvalid(Operation *operation) { invalidatedOperations.insert(operation); } + + bool isInvalid(Operation *operation) const { return invalidatedOperations.contains(operation); } }; -//------------------------------------------------------------------------------ +static void markInvalidRecursively(Operation *op, TensorizationState &state) { + if (!op) return; + state.trackInvalid(op); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : block) { + markInvalidRecursively(&nestedOp, state); + } + } + } +} + // memref.alloc → tensor.empty -//------------------------------------------------------------------------------ static FailureOr convertAllocOp(memref::AllocOp allocOp, TensorizationState &state) { auto tensorType = convertMemrefToTensorType(allocOp.getType()); if (!tensorType) return failure(); - auto emptyTensor = + auto emptyTensorOp = state.builder.create(allocOp.getLoc(), tensorType.getShape(), tensorType.getElementType()); - state.map(allocOp.getResult(), emptyTensor.getResult()); + state.map(allocOp.getResult(), emptyTensorOp.getResult()); state.markForErasure(allocOp); - return emptyTensor; + return emptyTensorOp; } -//------------------------------------------------------------------------------ // bufferization.to_{memref|tensor} -//------------------------------------------------------------------------------ static void convertToMemrefOp(buffer::ToMemrefOp toMemrefOp, TensorizationState &state) { state.map(toMemrefOp.getResult(), toMemrefOp.getTensor()); @@ -134,96 +153,207 @@ static LogicalResult convertToTensorOp(buffer::ToTensorOp toTensorOp, Tensorizat return success(); } -//------------------------------------------------------------------------------ // tensor.extract_slice helpers -//------------------------------------------------------------------------------ static FailureOr buildOneDimensionalExtractSlice(OpBuilder &builder, Location loc, Value sourceTensor, ValueRange indices, unsigned sliceLength) { - auto tensorType = sourceTensor.getType().dyn_cast(); - if (!tensorType || !tensorType.hasStaticShape()) return failure(); + auto sourceTensorType = mlir::dyn_cast(sourceTensor.getType()); + if (!sourceTensorType || !sourceTensorType.hasStaticShape()) return failure(); - unsigned rank = tensorType.getRank(); + unsigned rank = sourceTensorType.getRank(); SmallVector strides(rank, builder.getIndexAttr(1)); SmallVector offsets(indices.begin(), indices.end()); SmallVector sizes(rank, builder.getIndexAttr(1)); - int64_t fullLength = tensorType.getShape().back(); + int64_t fullLength = sourceTensorType.getShape().back(); int64_t length = sliceLength ? static_cast(sliceLength) : fullLength; sizes.back() = builder.getIndexAttr(length); - auto resultTensorType = RankedTensorType::get({length}, tensorType.getElementType()); + auto resultTensorType = RankedTensorType::get({length}, sourceTensorType.getElementType()); return builder.create(loc, resultTensorType, sourceTensor, offsets, sizes, strides); } -//------------------------------------------------------------------------------ // const → tensor -//------------------------------------------------------------------------------ - -static FailureOr createTensorConstantFromScalar(arith::ConstantOp scalarConstant, - RankedTensorType targetTensorType, - OpBuilder &builder) { - Attribute valueAttr = scalarConstant.getValue(); - - if (auto denseAttr = valueAttr.dyn_cast()) { - if (denseAttr.getType() == targetTensorType) - return builder.create(scalarConstant.getLoc(), denseAttr); - if (denseAttr.getNumElements() == targetTensorType.getNumElements()) - return builder.create(scalarConstant.getLoc(), denseAttr.reshape(targetTensorType)); + +static FailureOr createTensorConstantFromScalar( + arith::ConstantOp scalarConstantOp, + RankedTensorType targetTensorType, + OpBuilder &builder) { + if (!targetTensorType) return failure(); + + Attribute valueAttribute = scalarConstantOp.getValue(); + + if (auto denseElements = valueAttribute.dyn_cast()) { + if (mlir::isa(denseElements.getType())) { + if (denseElements.getType() == targetTensorType) { + return builder.create(scalarConstantOp.getLoc(), denseElements); + } + if (denseElements.getNumElements() == targetTensorType.getNumElements()) { + return builder.create(scalarConstantOp.getLoc(), + denseElements.reshape(targetTensorType)); + } + return failure(); + } + + if (auto vectorType = mlir::dyn_cast(denseElements.getType())) { + if (denseElements.getNumElements() == targetTensorType.getNumElements() && + vectorType.getElementType() == targetTensorType.getElementType()) { + return builder.create(scalarConstantOp.getLoc(), + denseElements.reshape(targetTensorType)); + } + + Type elementType = targetTensorType.getElementType(); + SmallVector elementAttributes; + elementAttributes.reserve(denseElements.getNumElements()); + + if (elementType.isa()) { + auto values = denseElements.getValues(); + elementAttributes.reserve(elementAttributes.size() + std::distance(values.begin(), values.end())); + std::transform(values.begin(), values.end(), std::back_inserter(elementAttributes), + [&](const APFloat &v) { + return FloatAttr::get(elementType, v); + }); + } else if (elementType.isa()) { + auto values = denseElements.getValues(); + elementAttributes.reserve(elementAttributes.size() + std::distance(values.begin(), values.end())); + std::transform(values.begin(), values.end(), std::back_inserter(elementAttributes), + [&](const APInt &v) { + return IntegerAttr::get(elementType, v); + }); + } else { + auto attrs = denseElements.getValues(); + elementAttributes.reserve(elementAttributes.size() + std::distance(attrs.begin(), attrs.end())); + std::copy(attrs.begin(), attrs.end(), std::back_inserter(elementAttributes)); + } + + auto tensorDense = DenseElementsAttr::get(targetTensorType, elementAttributes); + auto constantOp = builder.create(scalarConstantOp.getLoc(), tensorDense); + auto newDense = dyn_cast(constantOp.getValue()); + if (!newDense || newDense.getType() != targetTensorType) { + return failure(); + } + return constantOp; + } + + if (denseElements.getNumElements() == targetTensorType.getNumElements()) { + return builder.create(scalarConstantOp.getLoc(), + denseElements.reshape(targetTensorType)); + } return failure(); } - Attribute elementAttr = valueAttr; + Attribute elementAttribute = valueAttribute; Type elementType = targetTensorType.getElementType(); - if (auto intAttr = valueAttr.dyn_cast()) - elementAttr = (intAttr.getType() == elementType) ? elementAttr : IntegerAttr::get(elementType, intAttr.getInt()); - else if (auto floatAttr = valueAttr.dyn_cast()) - elementAttr = - (floatAttr.getType() == elementType) ? elementAttr : FloatAttr::get(elementType, floatAttr.getValue()); - else + if (auto intAttr = valueAttribute.dyn_cast()) { + elementAttribute = (intAttr.getType() == elementType) + ? elementAttribute + : IntegerAttr::get(elementType, intAttr.getInt()); + } else if (auto floatAttr = valueAttribute.dyn_cast()) { + elementAttribute = (floatAttr.getType() == elementType) + ? elementAttribute + : FloatAttr::get(elementType, floatAttr.getValue()); + } else { return failure(); + } - SmallVector repeatedValues(targetTensorType.getNumElements(), elementAttr); - auto denseAttr = DenseElementsAttr::get(targetTensorType, repeatedValues); - return builder.create(scalarConstant.getLoc(), denseAttr); + SmallVector repeatedValues(targetTensorType.getNumElements(), elementAttribute); + auto tensorDense = DenseElementsAttr::get(targetTensorType, repeatedValues); + auto constantOp = builder.create(scalarConstantOp.getLoc(), tensorDense); + auto newDense = dyn_cast(constantOp.getValue()); + if (!newDense || newDense.getType() != targetTensorType) { + return failure(); + } + return constantOp; } static LogicalResult upgradeConstantToTensor(arith::ConstantOp constantOp, RankedTensorType targetTensorType, TensorizationState &state) { if (state.lookup(constantOp.getResult())) return success(); auto tensorConstantOr = createTensorConstantFromScalar(constantOp, targetTensorType, state.builder); - if (failed(tensorConstantOr)) return failure(); + if (failed(tensorConstantOr)) { + return failure(); + } state.map(constantOp.getResult(), (*tensorConstantOr).getResult()); return success(); } -//------------------------------------------------------------------------------ -// transfer_read / transfer_write -//------------------------------------------------------------------------------ +static LogicalResult replaceVectorConstantWithTensor(arith::ConstantOp constantOp, OpBuilder &builder) { + auto vectorType = mlir::dyn_cast(constantOp.getType()); + if (!vectorType) return failure(); + auto denseElements = constantOp.getValue().dyn_cast(); + if (!denseElements || !mlir::isa(denseElements.getType())) return failure(); + + auto targetTensorType = RankedTensorType::get(vectorType.getShape(), vectorType.getElementType()); + builder.setInsertionPoint(constantOp); + + if (denseElements.getNumElements() == targetTensorType.getNumElements() && + vectorType.getElementType() == targetTensorType.getElementType()) { + auto reshaped = denseElements.reshape(targetTensorType); + auto newConstant = builder.create(constantOp.getLoc(), reshaped); + constantOp.getResult().replaceAllUsesWith(newConstant.getResult()); + constantOp.erase(); + return success(); + } + + SmallVector elementAttributes; + elementAttributes.reserve(denseElements.getNumElements()); + Type elemTy = vectorType.getElementType(); + + if (elemTy.isa()) { + auto values = denseElements.getValues(); + std::transform(values.begin(), values.end(), + std::back_inserter(elementAttributes), + [&](const APFloat &v) { + return FloatAttr::get(elemTy, v); + }); + } else if (elemTy.isa()) { + auto values = denseElements.getValues(); + std::transform(values.begin(), values.end(), + std::back_inserter(elementAttributes), + [&](const APInt &v) { + return IntegerAttr::get(elemTy, v); + }); + } else { + auto attrs = denseElements.getValues(); + std::copy(attrs.begin(), attrs.end(), + std::back_inserter(elementAttributes)); + } + auto tensorDense = DenseElementsAttr::get(targetTensorType, elementAttributes); + auto newConstant = builder.create(constantOp.getLoc(), tensorDense); + + constantOp.getResult().replaceAllUsesWith(newConstant.getResult()); + constantOp.erase(); + return success(); +} + +// transfer_read static FailureOr convertTransferReadOp(vector::TransferReadOp readOp, TensorizationState &state) { Value sourceTensor = state.lookup(readOp.getSource()); if (!sourceTensor) return failure(); - arith::ConstantOp paddingConstant = - readOp.getPadding() ? readOp.getPadding().getDefiningOp() : nullptr; - unsigned sliceLength = readOp.getVectorType().getNumElements(); - auto extractSliceOr = - buildOneDimensionalExtractSlice(state.builder, readOp.getLoc(), sourceTensor, readOp.getIndices(), sliceLength); + auto extractSliceOr = buildOneDimensionalExtractSlice(state.builder, readOp.getLoc(), sourceTensor, + readOp.getIndices(), sliceLength); if (failed(extractSliceOr)) return failure(); state.map(readOp.getResult(), (*extractSliceOr).getResult()); state.markForErasure(readOp); - if (paddingConstant) state.markForErasure(paddingConstant.getOperation()); + if (readOp.getPadding()) { + if (auto paddingConstantOp = readOp.getPadding().getDefiningOp()) { + state.markForErasure(paddingConstantOp.getOperation()); + } + } return extractSliceOr->getOperation(); } +// transfer_write (phase 1) + static FailureOr convertTransferWriteOp(vector::TransferWriteOp writeOp, TensorizationState &state) { - Value destinationTensor = state.lookup(writeOp.getSource()); - if (!destinationTensor) return failure(); + Value logicalBase = state.lookup(writeOp.getSource()); + if (!logicalBase) return failure(); unsigned sliceLength = writeOp.getVectorType().getNumElements(); auto sliceTensorType = RankedTensorType::get({sliceLength}, writeOp.getVectorType().getElementType()); @@ -238,68 +368,407 @@ static FailureOr convertTransferWriteOp(vector::TransferWriteOp wri } if (!sourceSlice || sourceSlice.getType() != sliceTensorType) return failure(); - auto destinationTensorType = destinationTensor.getType().cast(); + auto destinationTensorType = logicalBase.getType().cast(); unsigned rank = destinationTensorType.getRank(); SmallVector strides(rank, state.builder.getIndexAttr(1)); SmallVector sizes(rank, state.builder.getIndexAttr(1)); sizes.back() = state.builder.getIndexAttr(sliceLength); SmallVector offsets(writeOp.getIndices().begin(), writeOp.getIndices().end()); - auto insertSliceOp = state.builder.create(writeOp.getLoc(), sourceSlice, destinationTensor, - offsets, sizes, strides); + auto insertSliceOp = + state.builder.create(writeOp.getLoc(), sourceSlice, logicalBase, offsets, sizes, strides); + + state.pendingWrites.push_back(PendingWriteInfo{ + writeOp, + insertSliceOp, + logicalBase, + }); + + state.markForErasure(writeOp); + return insertSliceOp.getOperation(); +} + +// Phase 2 helpers + +using PendingGroupMap = llvm::DenseMap>; + +struct LiveOutCandidate { + Value origBase; + Value currentValue; + affine::AffineForOp loop; +}; + +struct PropagateRequest { + Value origBase; + Value currentVal; +}; + +static LogicalResult propagateLiveOuts(SmallVectorImpl &liveOutCandidates, + TensorizationState &state) { + SmallVector currentLevel(liveOutCandidates.begin(), liveOutCandidates.end()); + SmallVector nextLevel; + + llvm::DenseSet> seen; + + while (!currentLevel.empty()) { + llvm::DenseMap> requestsPerOuterLoop; + + for (const auto &candidate : currentLevel) { + affine::AffineForOp innerLoop = candidate.loop; + if (!innerLoop) continue; + + affine::AffineForOp outerLoop = innerLoop->getParentOfType(); + if (!outerLoop) continue; + + Value currentValue = candidate.currentValue; + if (!currentValue) continue; + + bool usedOutsideOuter = false; + for (OpOperand &use : currentValue.getUses()) { + Operation *user = use.getOwner(); + if (!outerLoop->isProperAncestor(user) && user != outerLoop.getOperation()) { + usedOutsideOuter = true; + break; + } + } + if (!usedOutsideOuter) continue; + + auto key = std::make_pair(outerLoop.getOperation(), candidate.origBase); + if (!seen.insert(key).second) continue; + + auto &byBase = requestsPerOuterLoop[outerLoop]; + auto it = byBase.find(candidate.origBase); + if (it == byBase.end()) { + byBase.insert({candidate.origBase, PropagateRequest{candidate.origBase, currentValue}}); + } else { + it->second.currentVal = currentValue; + } + } + + if (requestsPerOuterLoop.empty()) break; + + nextLevel.clear(); + + for (auto &kv : requestsPerOuterLoop) { + affine::AffineForOp outerLoop = kv.first; + auto &byBase = kv.second; + if (!outerLoop || byBase.empty()) continue; + + IRRewriter rewriter(outerLoop.getContext()); + rewriter.setInsertionPoint(outerLoop); + + unsigned oldNumResults = outerLoop.getNumResults(); + Block *oldBody = outerLoop.getBody(); + unsigned oldNumBodyArgs = oldBody->getNumArguments(); + + SmallVector initOperands; + SmallVector currentValues; + SmallVector originalBases; + + for (auto &it2 : byBase) { + Value originalBase = it2.getFirst(); + PropagateRequest &request = it2.getSecond(); + Value initValue = originalBase; + initOperands.push_back(initValue); + currentValues.push_back(request.currentVal); + originalBases.push_back(originalBase); + } + + auto newOuterLoop = cast( + *outerLoop.replaceWithAdditionalYields( + rewriter, + /*initOperands=*/initOperands, + /*replaceInitOperandUsesInLoop=*/false, + [&](const OpBuilder &yieldBuilder, Location, ArrayRef) { + (void)yieldBuilder; + SmallVector yieldValues(currentValues.begin(), currentValues.end()); + return yieldValues; + })); + + unsigned numAdded = initOperands.size(); + unsigned initStartIndex = newOuterLoop->getNumOperands() - numAdded; + + Block *newBody = newOuterLoop.getBody(); + unsigned newNumBodyArgs = newBody->getNumArguments(); + assert(newNumBodyArgs == oldNumBodyArgs + numAdded && + "unexpected number of body arguments after hoisting live-outs"); + + for (size_t i = 0; i < originalBases.size(); ++i) { + Value originalBase = originalBases[i]; + Value currentValue = currentValues[i]; + Value outerResult = newOuterLoop.getResults()[oldNumResults + i]; + BlockArgument outerIterArg = newBody->getArgument(oldNumBodyArgs + i); + + auto replaceOutsideUses = [&](OpOperand &use) { + Operation *user = use.getOwner(); + if (user == newOuterLoop.getOperation()) { + unsigned operandNumber = use.getOperandNumber(); + if (operandNumber >= initStartIndex && operandNumber < initStartIndex + numAdded) + return false; + } + if (newOuterLoop->isProperAncestor(user)) return false; + return true; + }; + + currentValue.replaceUsesWithIf(outerResult, replaceOutsideUses); + originalBase.replaceUsesWithIf(outerResult, replaceOutsideUses); + + auto replaceInsideBody = [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newOuterLoop->isProperAncestor(user); + }; + originalBase.replaceUsesWithIf(outerIterArg, replaceInsideBody); + + state.replaceMapping(originalBase, outerResult); + + LiveOutCandidate nextCandidate; + nextCandidate.origBase = originalBase; + nextCandidate.currentValue = outerResult; + nextCandidate.loop = newOuterLoop; + nextLevel.push_back(nextCandidate); + } + } - affine::AffineForOp currentLoop = writeOp->getParentOfType(); - if (!currentLoop) { - state.markForErasure(writeOp); - return insertSliceOp.getOperation(); + currentLevel.swap(nextLevel); } - Value currentInsideValue = insertSliceOp.getResult(); - Value currentDestination = destinationTensor; - Value originalInitTensor = destinationTensor; - Operation *outermostChangedLoop = nullptr; + return success(); +} + +static LogicalResult processLoopAllBaseTensors(affine::AffineForOp loopOp, + PendingGroupMap &groupsByBaseTensor, + TensorizationState &state, + SmallVectorImpl &liveOutCandidates) { + if (groupsByBaseTensor.empty()) return success(); + + auto valueBefore = [](Value lhs, Value rhs) -> bool { + if (lhs == rhs) return false; + bool lhsIsArg = lhs.isa(); + bool rhsIsArg = rhs.isa(); + if (lhsIsArg != rhsIsArg) return lhsIsArg; + if (lhsIsArg && rhsIsArg) { + BlockArgument lhsArg = lhs.cast(); + BlockArgument rhsArg = rhs.cast(); + if (lhsArg.getOwner() != rhsArg.getOwner()) + return lhsArg.getOwner() < rhsArg.getOwner(); + return lhsArg.getArgNumber() < rhsArg.getArgNumber(); + } + Operation *lhsOp = lhs ? lhs.getDefiningOp() : nullptr; + Operation *rhsOp = rhs ? rhs.getDefiningOp() : nullptr; + if (lhsOp == rhsOp) { + return lhs.cast().getResultNumber() < rhs.cast().getResultNumber(); + } + Block *lhsBlock = lhsOp ? lhsOp->getBlock() : nullptr; + Block *rhsBlock = rhsOp ? rhsOp->getBlock() : nullptr; + if (lhsBlock == rhsBlock) { + return lhsOp->isBeforeInBlock(rhsOp); + } + return lhsBlock < rhsBlock; + }; + + auto opBefore = [](Operation *lhs, Operation *rhs) -> bool { + if (lhs == rhs) return false; + Block *lhsBlock = lhs ? lhs->getBlock() : nullptr; + Block *rhsBlock = rhs ? rhs->getBlock() : nullptr; + if (lhsBlock == rhsBlock) return lhs->isBeforeInBlock(rhs); + return lhsBlock < rhsBlock; + }; + + SmallVector logicalBases; + SmallVector currentBases; + SmallVector> allInsertSlicesPerBase; + + SmallVector>, 8> groupedEntries; + groupedEntries.reserve(groupsByBaseTensor.size()); + + std::transform(groupsByBaseTensor.begin(), groupsByBaseTensor.end(), + std::back_inserter(groupedEntries), + [](auto &entry) -> std::pair> { + return {entry.first, entry.second}; + }); + + llvm::sort(groupedEntries, [&](const auto &lhs, const auto &rhs) { return valueBefore(lhs.first, rhs.first); }); + + for (auto &pair : groupedEntries) { + Value logicalBase = pair.first; + Value currentBase = state.lookup(logicalBase); + if (!currentBase) currentBase = logicalBase; + + SmallVector insertSlices; + insertSlices.reserve(pair.second.size()); + std::transform(pair.second.begin(), pair.second.end(), + std::back_inserter(insertSlices), + [](PendingWriteInfo* info) { return info->insertSliceOp; }); + if (insertSlices.empty()) continue; + + llvm::sort(insertSlices, [&](tensor::InsertSliceOp a, tensor::InsertSliceOp b) { + return opBefore(a.getOperation(), b.getOperation()); + }); + + logicalBases.push_back(logicalBase); + currentBases.push_back(currentBase); + allInsertSlicesPerBase.push_back(std::move(insertSlices)); + } + + if (logicalBases.empty()) return success(); + + IRRewriter rewriter(loopOp.getContext()); + rewriter.setInsertionPoint(loopOp); + + unsigned oldNumBodyArgs = loopOp.getBody()->getNumArguments(); + unsigned oldNumResults = loopOp.getNumResults(); + + SmallVector initOperands(currentBases.begin(), currentBases.end()); + + auto newLoopOp = cast(*loopOp.replaceWithAdditionalYields( + rewriter, + /*initOperands=*/initOperands, + /*replaceInitOperandUsesInLoop=*/false, + [&](const OpBuilder &yieldBuilder, Location, ArrayRef) { + SmallVector yieldValues; + yieldValues.reserve(allInsertSlicesPerBase.size()); + std::transform(allInsertSlicesPerBase.begin(), allInsertSlicesPerBase.end(), + std::back_inserter(yieldValues), + [](auto &vec) -> Value { return vec.back().getResult(); }); + return yieldValues; + })); + + Block *body = newLoopOp.getBody(); + unsigned newNumBodyArgs = body->getNumArguments(); + assert(newNumBodyArgs == oldNumBodyArgs + logicalBases.size() && + "unexpected number of body arguments after replaceWithAdditionalYields"); + + SmallVector iterArgs; + iterArgs.reserve(logicalBases.size()); + for (size_t i = 0; i < logicalBases.size(); ++i) + iterArgs.push_back(body->getArgument(oldNumBodyArgs + i)); + + for (size_t i = 0; i < logicalBases.size(); ++i) { + Value runningBase = iterArgs[i]; + auto &insertSlices = allInsertSlicesPerBase[i]; + for (auto insertOp : insertSlices) { + insertOp.getDestMutable().assign(runningBase); + runningBase = insertOp.getResult(); + } + } - while (currentLoop) { - IRRewriter rewriter(currentLoop.getContext()); - state.trackInvalid(currentLoop.getOperation()); + unsigned newNumResults = newLoopOp.getNumResults(); + assert(newNumResults == oldNumResults + logicalBases.size() && + "unexpected number of loop results after replaceWithAdditionalYields"); - auto newLoop = cast(*currentLoop.replaceWithAdditionalYields( - rewriter, /*initOperand=*/originalInitTensor, /*replaceInitOperandUsesInLoop=*/false, - [&](OpBuilder &b, Location, ArrayRef) { return SmallVector{currentInsideValue}; })); + SmallVector loopResults; + loopResults.reserve(logicalBases.size()); + for (size_t i = 0; i < logicalBases.size(); ++i) + loopResults.push_back(newLoopOp.getResults()[oldNumResults + i]); - outermostChangedLoop = newLoop.getOperation(); + unsigned initOperandStart = newLoopOp->getNumOperands() - logicalBases.size(); - Value yieldedInsideValue = newLoop.getBody()->getArguments().back(); - Value yieldedOutsideValue = newLoop.getResults().back(); + for (size_t i = 0; i < logicalBases.size(); ++i) { + Value logicalBase = logicalBases[i]; + Value previousBase = currentBases[i]; + BlockArgument iterArg = iterArgs[i]; + Value loopResult = loopResults[i]; - auto replaceInsideUses = [&](OpOperand &use) -> bool { return newLoop->isProperAncestor(use.getOwner()); }; - currentDestination.replaceUsesWithIf(yieldedInsideValue, - mlir::function_ref(replaceInsideUses)); + auto replaceInsideUses = [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoopOp->isProperAncestor(user); + }; + previousBase.replaceUsesWithIf(iterArg, replaceInsideUses); - unsigned initOperandPosition = newLoop->getNumOperands() - 1; - auto replaceOutsideUses = [&](OpOperand &use) -> bool { + auto replaceOutsideUses = [&](OpOperand &use) { Operation *user = use.getOwner(); - if (user == newLoop.getOperation() && use.getOperandNumber() == static_cast(initOperandPosition)) - return false; - if (newLoop->isProperAncestor(user)) return false; + if (user == newLoopOp.getOperation()) { + unsigned operandNumber = use.getOperandNumber(); + if (operandNumber >= initOperandStart && operandNumber < initOperandStart + logicalBases.size()) + return false; + } + if (newLoopOp->isProperAncestor(user)) return false; return true; }; - currentDestination.replaceUsesWithIf(yieldedOutsideValue, - mlir::function_ref(replaceOutsideUses)); + previousBase.replaceUsesWithIf(loopResult, replaceOutsideUses); - currentInsideValue = yieldedOutsideValue; - currentLoop = newLoop->getParentOfType(); + state.replaceMapping(logicalBase, loopResult); } - state.replaceMapping(destinationTensor, currentInsideValue); - state.markForErasure(writeOp); - state.markForErasure(insertSliceOp.getOperation()); - return outermostChangedLoop ? outermostChangedLoop : insertSliceOp.getOperation(); + for (size_t i = 0; i < logicalBases.size(); ++i) { + Value originalBase = logicalBases[i]; + Value loopResult = loopResults[i]; + + bool usedOutsideThisLoop = false; + for (OpOperand &use : loopResult.getUses()) { + Operation *user = use.getOwner(); + if (!newLoopOp->isProperAncestor(user) && user != newLoopOp.getOperation()) { + usedOutsideThisLoop = true; + break; + } + } + + if (usedOutsideThisLoop) { + LiveOutCandidate candidate; + candidate.origBase = originalBase; + candidate.currentValue = loopResult; + candidate.loop = newLoopOp; + liveOutCandidates.push_back(candidate); + } + } + + return success(); +} + +static LogicalResult runSecondPhase(func::FuncOp funcOp, TensorizationState &state) { + llvm::DenseMap groupsByLoop; + + for (auto &info : state.pendingWrites) { + auto writeOp = info.writeOp; + auto parentLoopOp = writeOp->getParentOfType(); + if (!parentLoopOp) continue; + Value logicalBase = info.logicalBase; + groupsByLoop[parentLoopOp][logicalBase].push_back(&info); + } + + SmallVector loopOrder; + funcOp.walk([&](affine::AffineForOp loopOp) { loopOrder.push_back(loopOp); }); + std::reverse(loopOrder.begin(), loopOrder.end()); + + SmallVector liveOutCandidates; + + for (auto loopOp : loopOrder) { + auto it = groupsByLoop.find(loopOp); + if (it == groupsByLoop.end()) continue; + + PendingGroupMap &groupsByBaseTensor = it->second; + if (failed(processLoopAllBaseTensors(loopOp, groupsByBaseTensor, state, liveOutCandidates))) + return failure(); + } + + for (auto &info : state.pendingWrites) { + if (info.writeOp->getParentOfType()) continue; + + tensor::InsertSliceOp insertOp = info.insertSliceOp; + Value logicalBase = info.logicalBase; + + Value baseTensor = state.lookup(logicalBase); + if (!baseTensor) baseTensor = logicalBase; + + auto notInsertDest = [&](OpOperand &use) { + Operation *user = use.getOwner(); + if (user == insertOp.getOperation()) { + if (use.getOperandNumber() == insertOp.getDestMutable().getOperandNumber()) + return false; + } + return true; + }; + baseTensor.replaceUsesWithIf(insertOp.getResult(), notInsertDest); + state.replaceMapping(logicalBase, insertOp.getResult()); + } + + if (failed(propagateLiveOuts(liveOutCandidates, state))) + return failure(); + + return success(); } -//------------------------------------------------------------------------------ // Element-wise tensorization -//------------------------------------------------------------------------------ static FailureOr cloneElementWiseOp(Operation *originalOp, ArrayRef newOperands, TensorizationState &state) { @@ -309,11 +778,10 @@ static FailureOr cloneElementWiseOp(Operation *originalOp, ArrayRef SmallVector resultTensorTypes; for (Type originalResultType : originalOp->getResultTypes()) { - if (auto rankedTensorType = originalResultType.dyn_cast()) + if (auto rankedTensorType = mlir::dyn_cast(originalResultType)) resultTensorTypes.push_back(rankedTensorType); - else if (auto vectorType = originalResultType.dyn_cast()) - resultTensorTypes.push_back( - RankedTensorType::get(vectorType.getShape(), vectorType.getElementType())); + else if (auto vectorType = mlir::dyn_cast(originalResultType)) + resultTensorTypes.push_back(RankedTensorType::get(vectorType.getShape(), vectorType.getElementType())); else return failure(); } @@ -330,52 +798,85 @@ static FailureOr cloneElementWiseOp(Operation *originalOp, ArrayRef static FailureOr convertElementWiseOp(Operation *op, TensorizationState &state) { if (op->getNumRegions() != 0) return failure(); - if (llvm::any_of(op->getOperandTypes(), [](Type t) { return !isNumericTypeLike(t); })) { + auto isNumericLike = [](Type type) { return isNumericTypeLike(type); }; + if (llvm::any_of(op->getOperandTypes(), [&](Type type) { return !isNumericLike(type); })) { return failure(); } - - if (llvm::any_of(op->getResultTypes(), [](Type t) { return !isNumericTypeLike(t); })) { + if (llvm::any_of(op->getResultTypes(), [&](Type type) { return !isNumericLike(type); })) { return failure(); } RankedTensorType referenceTensorType = nullptr; + for (Value operand : op->getOperands()) { - if (auto mapped = state.lookup(operand)) { - if (auto rankedType = mapped.getType().dyn_cast()) { + if (auto mappedValue = state.lookup(operand)) { + if (auto rankedType = mlir::dyn_cast(mappedValue.getType())) { + referenceTensorType = rankedType; + break; + } + } + } + if (!referenceTensorType) { + for (Value operand : op->getOperands()) { + if (auto rankedType = mlir::dyn_cast(operand.getType())) { referenceTensorType = rankedType; break; } } - if (auto rankedType = operand.getType().dyn_cast()) { - referenceTensorType = rankedType; - break; + } + if (!referenceTensorType) { + for (Value operand : op->getOperands()) { + if (auto vecType = mlir::dyn_cast(operand.getType())) { + referenceTensorType = RankedTensorType::get(vecType.getShape(), vecType.getElementType()); + break; + } } } + if (!referenceTensorType) { + for (Type resultType : op->getResultTypes()) { + if (auto rankedType = mlir::dyn_cast(resultType)) { + referenceTensorType = rankedType; + break; + } + if (auto vecType = mlir::dyn_cast(resultType)) { + referenceTensorType = RankedTensorType::get(vecType.getShape(), vecType.getElementType()); + break; + } + } + } + if (!referenceTensorType) return failure(); SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { - if (auto mapped = state.lookup(operand)) { - newOperands.push_back(mapped); + if (auto mappedValue = state.lookup(operand)) { + if (mappedValue.getType() != referenceTensorType) { + return failure(); + } + newOperands.push_back(mappedValue); continue; } + if (operand.getType() == referenceTensorType) { newOperands.push_back(operand); continue; } + if (auto constantOp = operand.getDefiningOp()) { if (failed(upgradeConstantToTensor(constantOp, referenceTensorType, state))) return failure(); newOperands.push_back(state.lookup(constantOp.getResult())); continue; } + return failure(); } + return cloneElementWiseOp(op, newOperands, state); } -//------------------------------------------------------------------------------ // Dispatch -//------------------------------------------------------------------------------ static FailureOr tensorizeOperation(Operation *op, TensorizationState &state) { if (auto allocOp = dyn_cast(op)) return convertAllocOp(allocOp, state); @@ -399,8 +900,7 @@ namespace { #define GEN_PASS_DEF_VECTORTRANSFERTENSORIZE #include "akg/Dialect/Affine/Passes.h.inc" -struct VectorTransferTensorizePass - : public impl::VectorTransferTensorizeBase { +struct VectorTransferTensorizePass : public impl::VectorTransferTensorizeBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -410,19 +910,35 @@ struct VectorTransferTensorizePass func::FuncOp funcOp = getOperation(); TensorizationState state(funcOp.getContext()); + { + SmallVector vectorConstants; + funcOp.walk([&](arith::ConstantOp constantOp) { + if (mlir::isa(constantOp.getType())) vectorConstants.push_back(constantOp); + }); + for (auto constantOp : vectorConstants) { + (void)replaceVectorConstantWithTensor(constantOp, state.builder); + } + } + SmallVector workList; funcOp.walk([&](Operation *operation) { workList.push_back(operation); }); - for (size_t idx = 0; idx < workList.size(); ++idx) { - Operation *op = workList[idx]; - if (!op) continue; - if (state.operationsToErase.contains(op)) continue; - if (state.isInvalid(op)) continue; - if (op == funcOp.getOperation()) continue; - if (!op->getBlock()) continue; - if (isa(op)) continue; - state.builder.setInsertionPoint(op); - (void)tensorizeOperation(op, state); + for (size_t index = 0; index < workList.size(); ++index) { + Operation *operation = workList[index]; + if (!operation) continue; + if (state.operationsToErase.contains(operation)) continue; + if (state.isInvalid(operation)) continue; + if (operation == funcOp.getOperation()) continue; + if (!operation->getBlock()) continue; + if (isa(operation)) continue; + state.builder.setInsertionPoint(operation); + auto result = tensorizeOperation(operation, state); + if (failed(result)) continue; + } + + if (failed(runSecondPhase(funcOp, state))) { + signalPassFailure(); + return; } auto eraseDeadMarkedOps = [&]() { diff --git a/akg-mlir/tests/ut/Dialect/Affine/vector_tensor.mlir b/akg-mlir/tests/ut/Dialect/Affine/vector_tensor.mlir index 47d84aec..276bc132 100644 --- a/akg-mlir/tests/ut/Dialect/Affine/vector_tensor.mlir +++ b/akg-mlir/tests/ut/Dialect/Affine/vector_tensor.mlir @@ -1,45 +1,41 @@ // RUN: akg-opt %s --vector-transfer-tensorize -allow-unregistered-dialect | FileCheck %s -// CHECK-LABEL: module { -// CHECK-NEXT: func.func @Fused_BiasAdd_10033593016906428850(%arg0: tensor<28x3072xbf16>, %arg1: tensor<3072xbf16>) -> tensor<28x3072xbf16> { -// CHECK-NEXT: %0 = tensor.empty() : tensor<28x3072xbf16> -// CHECK-NEXT: %1 = tensor.empty() : tensor<28x3072xbf16> -// CHECK-NEXT: %2:2 = affine.for %arg2 = 0 to 28 iter_args(%arg3 = %0, %arg4 = %1) -> (tensor<28x3072xbf16>, tensor<28x3072xbf16>) { -// CHECK-NEXT: %3:2 = affine.for %arg5 = 0 to 3072 step 3072 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (tensor<28x3072xbf16>, tensor<28x3072xbf16>) { -// CHECK-NEXT: %extracted_slice = tensor.extract_slice %arg1[%arg5] [3072] [1] : tensor<3072xbf16> to tensor<3072xbf16> -// CHECK-NEXT: %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg2, %arg5] [1, 3072] [1, 1] : tensor<3072xbf16> into tensor<28x3072xbf16> -// CHECK-NEXT: %extracted_slice_0 = tensor.extract_slice %arg0[%arg2, %arg5] [1, 3072] [1, 1] : tensor<28x3072xbf16> to tensor<3072xbf16> -// CHECK-NEXT: %extracted_slice_1 = tensor.extract_slice %0[%arg2, %arg5] [1, 3072] [1, 1] : tensor<28x3072xbf16> to tensor<3072xbf16> -// CHECK-NEXT: %4 = arith.addf %extracted_slice_0, %extracted_slice_1 : tensor<3072xbf16> -// CHECK-NEXT: %inserted_slice_2 = tensor.insert_slice %4 into %arg7[%arg2, %arg5] [1, 3072] [1, 1] : tensor<3072xbf16> into tensor<28x3072xbf16> -// CHECK-NEXT: affine.yield %inserted_slice, %inserted_slice_2 : tensor<28x3072xbf16>, tensor<28x3072xbf16> -// CHECK-NEXT: } -// CHECK-NEXT: affine.yield %3#0, %3#1 : tensor<28x3072xbf16>, tensor<28x3072xbf16> +// CHECK-LABEL: func.func @Fused_BiasAdd_10033593016906428850(%arg0: tensor<28x3072xbf16>, %arg1: tensor<3072xbf16>) -> tensor<28x3072xbf16> { +// CHECK-NEXT: %0 = tensor.empty() : tensor<28x3072xbf16> +// CHECK-NEXT: %1 = tensor.empty() : tensor<28x3072xbf16> +// CHECK-NEXT: %2 = affine.for %arg2 = 0 to 28 iter_args(%arg3 = %1) -> (tensor<28x3072xbf16>) { +// CHECK-NEXT: %3:2 = affine.for %arg4 = 0 to 3072 step 3072 iter_args(%arg5 = %0, %arg6 = %arg3) -> (tensor<28x3072xbf16>, tensor<28x3072xbf16>) { +// CHECK-NEXT: %extracted_slice = tensor.extract_slice %arg1[%arg4] [3072] [1] : tensor<3072xbf16> to tensor<3072xbf16> +// CHECK-NEXT: %inserted_slice = tensor.insert_slice %extracted_slice into %arg5[%arg2, %arg4] [1, 3072] [1, 1] : tensor<3072xbf16> into tensor<28x3072xbf16> +// CHECK-NEXT: %extracted_slice_0 = tensor.extract_slice %arg0[%arg2, %arg4] [1, 3072] [1, 1] : tensor<28x3072xbf16> to tensor<3072xbf16> +// CHECK-NEXT: %extracted_slice_1 = tensor.extract_slice %arg5[%arg2, %arg4] [1, 3072] [1, 1] : tensor<28x3072xbf16> to tensor<3072xbf16> +// CHECK-NEXT: %4 = arith.addf %extracted_slice_0, %extracted_slice_1 : tensor<3072xbf16> +// CHECK-NEXT: %inserted_slice_2 = tensor.insert_slice %4 into %arg6[%arg2, %arg4] [1, 3072] [1, 1] : tensor<3072xbf16> into tensor<28x3072xbf16> +// CHECK-NEXT: affine.yield %inserted_slice, %inserted_slice_2 : tensor<28x3072xbf16>, tensor<28x3072xbf16> // CHECK-NEXT: } -// CHECK-NEXT: return %2#1 : tensor<28x3072xbf16> +// CHECK-NEXT: affine.yield %3#1 : tensor<28x3072xbf16> // CHECK-NEXT: } +// CHECK-NEXT: return %2 : tensor<28x3072xbf16> // CHECK-NEXT: } -module { - func.func @Fused_BiasAdd_10033593016906428850(%arg0: tensor<28x3072xbf16>, %arg1: tensor<3072xbf16>) -> tensor<28x3072xbf16> { - %0 = bufferization.to_memref %arg0 : memref<28x3072xbf16> - %1 = bufferization.to_memref %arg1 : memref<3072xbf16> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<28x3072xbf16> - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<28x3072xbf16> - affine.for %arg2 = 0 to 28 { - affine.for %arg3 = 0 to 3072 step 3072 { - %cst = arith.constant 0.000000e+00 : bf16 - %3 = vector.transfer_read %1[%arg3], %cst : memref<3072xbf16>, vector<3072xbf16> - vector.transfer_write %3, %alloc[%arg2, %arg3] : vector<3072xbf16>, memref<28x3072xbf16> - %cst_1 = arith.constant 0.000000e+00 : bf16 - %4 = vector.transfer_read %0[%arg2, %arg3], %cst_1 : memref<28x3072xbf16>, vector<3072xbf16> - %cst_2 = arith.constant 0.000000e+00 : bf16 - %5 = vector.transfer_read %alloc[%arg2, %arg3], %cst_2 : memref<28x3072xbf16>, vector<3072xbf16> - %6 = arith.addf %4, %5 : vector<3072xbf16> - vector.transfer_write %6, %alloc_0[%arg2, %arg3] : vector<3072xbf16>, memref<28x3072xbf16> - } - } - %2 = bufferization.to_tensor %alloc_0 : memref<28x3072xbf16> - return %2 : tensor<28x3072xbf16> +func.func @Fused_BiasAdd_10033593016906428850(%arg0: tensor<28x3072xbf16>, %arg1: tensor<3072xbf16>) -> tensor<28x3072xbf16> { + %0 = bufferization.to_memref %arg0 : memref<28x3072xbf16> + %1 = bufferization.to_memref %arg1 : memref<3072xbf16> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<28x3072xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<28x3072xbf16> + affine.for %arg2 = 0 to 28 { + affine.for %arg3 = 0 to 3072 step 3072 { + %cst = arith.constant 0.000000e+00 : bf16 + %3 = vector.transfer_read %1[%arg3], %cst : memref<3072xbf16>, vector<3072xbf16> + vector.transfer_write %3, %alloc[%arg2, %arg3] : vector<3072xbf16>, memref<28x3072xbf16> + %cst_1 = arith.constant 0.000000e+00 : bf16 + %4 = vector.transfer_read %0[%arg2, %arg3], %cst_1 : memref<28x3072xbf16>, vector<3072xbf16> + %cst_2 = arith.constant 0.000000e+00 : bf16 + %5 = vector.transfer_read %alloc[%arg2, %arg3], %cst_2 : memref<28x3072xbf16>, vector<3072xbf16> + %6 = arith.addf %4, %5 : vector<3072xbf16> + vector.transfer_write %6, %alloc_0[%arg2, %arg3] : vector<3072xbf16>, memref<28x3072xbf16> } + } + %2 = bufferization.to_tensor %alloc_0 : memref<28x3072xbf16> + return %2 : tensor<28x3072xbf16> } \ No newline at end of file -- Gitee