diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h index c390fe5b221fa7a83394176bcb5ea15d0925d626..3806ce0526b19b256654e79dc2fa4bd7549ff96f 100644 --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -420,6 +420,7 @@ void initializeWasmEHPreparePass(PassRegistry&); void initializeWinEHPreparePass(PassRegistry&); void initializeWriteBitcodePassPass(PassRegistry&); void initializeXRayInstrumentationPass(PassRegistry&); +void initializeSVEExpandLibCallPass(PassRegistry &); } // end namespace llvm diff --git a/llvm/include/llvm/Target/SVEExpandLibCall.h b/llvm/include/llvm/Target/SVEExpandLibCall.h new file mode 100644 index 0000000000000000000000000000000000000000..7e40b0889f0337db92c6a8e854a5947fa16727ee --- /dev/null +++ b/llvm/include/llvm/Target/SVEExpandLibCall.h @@ -0,0 +1,15 @@ +#ifndef LLVM_LIB_TARGET_AARCH64_SVEEXPANDLIBCALL_H +#define LLVM_LIB_TARGET_AARCH64_SVEEXPANDLIBCALL_H + + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +struct SVEExpandLibCallPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); +}; + +} // namespace llvm + +#endif \ No newline at end of file diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h index 46e105e76a52775bbd98bb90279411569c0524d2..fb1b253f9caaeb6aa79505d4816547ec68b7820b 100644 --- a/llvm/lib/Target/AArch64/AArch64.h +++ b/llvm/lib/Target/AArch64/AArch64.h @@ -59,6 +59,7 @@ FunctionPass *createWeakConsistencyPass(); FunctionPass *createAArch64CleanupLocalDynamicTLSPass(); FunctionPass *createAArch64CollectLOHPass(); +FunctionPass *createSVEExpandLibCallPass(); FunctionPass *createSMEABIPass(); ModulePass *createSVEIntrinsicOptsPass(); InstructionSelector * diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp index 1c0032be035c4c10face18858621c8a12cc5681d..6ceda34fbfd8dd57ac58ff1f9011fc9063c7164b 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -161,6 +161,11 @@ static cl::opt EnableGlobalISelAtO( cl::desc("Enable GlobalISel at or below an opt level (-1 to disable)"), cl::init(0)); +static cl::opt + EnableSVELibCallOpts("aarch64-enable-sve-libcall-opts", cl::Hidden, + cl::desc("Enable SVE libcall opts"), + cl::init(true)); + static cl::opt EnableSVEIntrinsicOpts("aarch64-enable-sve-intrinsic-opts", cl::Hidden, cl::desc("Enable SVE intrinsic opts"), @@ -235,6 +240,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() { initializeLDTLSCleanupPass(*PR); initializeSMEABIPass(*PR); initializeSVEIntrinsicOptsPass(*PR); + initializeSVEExpandLibCallPass(*PR); initializeAArch64SpeculationHardeningPass(*PR); initializeAArch64SLSHardeningPass(*PR); initializeAArch64StackTaggingPass(*PR); @@ -557,6 +563,9 @@ void AArch64PassConfig::addIRPasses() { if (EnableSVEIntrinsicOpts && TM->getOptLevel() == CodeGenOpt::Aggressive) addPass(createSVEIntrinsicOptsPass()); + if (EnableSVELibCallOpts && TM->getOptLevel() == CodeGenOpt::Aggressive) + addPass(createSVEExpandLibCallPass()); + // Cmpxchg instructions are often used with a subsequent comparison to // determine whether it succeeded. We can exploit existing control-flow in // ldrex/strex loops to simplify this, but it needs tidying up. diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt index 97a03e1c268b924f286fec6681a24d5bed059058..11fb27121b841ffb6411aba0cca601737435fa4a 100644 --- a/llvm/lib/Target/AArch64/CMakeLists.txt +++ b/llvm/lib/Target/AArch64/CMakeLists.txt @@ -86,6 +86,7 @@ add_llvm_target(AArch64CodeGen AArch64TargetTransformInfo.cpp SMEABIPass.cpp SVEIntrinsicOpts.cpp + SVEExpandLibCall.cpp AArch64SIMDInstrOpt.cpp WeakConsistencyPass.cpp WeakConsistencyAllowlist.cpp diff --git a/llvm/lib/Target/AArch64/SVEExpandLibCall.cpp b/llvm/lib/Target/AArch64/SVEExpandLibCall.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e3a7a378f8673d4720aeeeb53afc626e472d98d --- /dev/null +++ b/llvm/lib/Target/AArch64/SVEExpandLibCall.cpp @@ -0,0 +1,294 @@ +//===----- SVEExpandLibCall - SVE Lib Call Expansion ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass performs the optimization: +// +// 1. Expands memset and memcpy intrinsics to SVE loops. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Target/SVEExpandLibCall.h" +#include "Utils/AArch64BaseInfo.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsAArch64.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/InitializePasses.h" +#include "llvm/Support/CommandLine.h" + +using namespace llvm; + +#define DEBUG_TYPE "aarch64-sve-expandlibcall" + +STATISTIC(NumExpandedCalls, "Number of SVE libcalls expanded"); + +static cl::opt EnableMemCallInlining( + "sve-enable-mem-call-inlining", cl::init(false), cl::Hidden, + cl::desc("Replace calls to memsets/memcpy with an inline SVE loop")); + +static cl::opt ExpandMemCallThreshold( + "sve-expand-mem-call-threshold", cl::init(128), cl::Hidden, + cl::desc("Size threshold for expanding memset/memcpy calls to SVE loops")); + +static cl::opt EnableMemCallRuntimeCheck( + "sve-enable-mem-call-rtcheck", cl::init(true), cl::Hidden, + cl::desc("Enable runtime check for small memsets/memcpy calls")); + +namespace llvm { +void initializeSVEExpandLibCallPass(PassRegistry &); +} + +namespace { +class SVEExpandLibCallBase { +public: + SVEExpandLibCallBase() {} + bool ExpandMemCallToLoop(MemIntrinsic *II, Function &F); + Instruction *CreateWhile(Intrinsic::ID ID, Type *Ty, Value *Op1, Value *Op2); + +private: + Function *F; +}; +} // namespace + +namespace { +struct SVEExpandLibCall : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SVEExpandLibCall() : FunctionPass(ID) { + initializeSVEExpandLibCallPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; +}; +} // namespace + +char SVEExpandLibCall::ID = 0; +static const char *name = "SVE Vector Lib Call Expansion"; +INITIALIZE_PASS_BEGIN(SVEExpandLibCall, DEBUG_TYPE, name, false, false) +INITIALIZE_PASS_END(SVEExpandLibCall, DEBUG_TYPE, name, false, false) + +namespace llvm { +FunctionPass *createSVEExpandLibCallPass() { return new SVEExpandLibCall(); } +} // namespace llvm + +// Check if the given instruction is a memory lib call that needs expanding. +static bool isMemLibCall(Instruction *I) { + IntrinsicInst *II = dyn_cast(I); + if (!II) + return false; + + switch (II->getIntrinsicID()) { + case Intrinsic::memset: + case Intrinsic::memcpy: + if (cast(I)->getDestAlignment() == 0) + return false; + return true; + default: + break; + } + + return false; +} + +/// Create call to specified WHILE intrinsic. +Instruction *SVEExpandLibCallBase::CreateWhile(Intrinsic::ID IntID, Type *Ty, + Value *Op1, Value *Op2) { + SmallVector Types = {Ty, Op1->getType()}; + SmallVector Args{Op1, Op2}; + + Function *Intrinsic = Intrinsic::getDeclaration(F->getParent(), IntID, Types); + return CallInst::Create(Intrinsic->getFunctionType(), Intrinsic, Args); +} + +// Expand a call to memcpy or memset into an optimized SVE loop +bool SVEExpandLibCallBase::ExpandMemCallToLoop(MemIntrinsic *II, Function &F) { + if (II->isVolatile()) + return false; + + this->F = &F; + + if (auto *CI = dyn_cast(II->getLength())) { + if (CI->getZExtValue() < ExpandMemCallThreshold) + return false; + } + + LLVM_DEBUG(dbgs() << "SVEExpandLib: Expanding call to: " + << II->getCalledFunction()->getName() << "\n"); + + // Splitting basic block into expand (loop body) block and resume block. + auto *ParentBlock = II->getParent(); + auto *PHBlock = ParentBlock->splitBasicBlock(II, "mem.ph"); + auto *LoopBlock = PHBlock->splitBasicBlock(II, "mem.exploop"); + auto *LoopEpilog = LoopBlock->splitBasicBlock(II, "mem.epilog"); + auto *MemIntrBlock = LoopEpilog->splitBasicBlock(II, "mem.intrinsic"); + auto *ResumeBlock = MemIntrBlock->splitBasicBlock(II, "mem.resume"); + + // Fill in the preheader + BasicBlock::iterator InsertPt(ParentBlock->getTerminator()); + IRBuilder<> Builder(&*InsertPt); + + // Always use a 64bit iteration counter + Type *IdxTy = Builder.getInt64Ty(); + + // Possibly zero-extend the length + Value *Length = II->getLength(); + if (!Length->getType()->isIntegerTy(64)) { + Length = Builder.CreateZExt(Length, IdxTy); + } + + // Create runtime check based on runtime MemCallThreshold length + if (EnableMemCallRuntimeCheck) { + auto MinElts = Builder.getInt64(ExpandMemCallThreshold); + auto ScalarCompare = Builder.CreateICmpULE(Length, MinElts); + Builder.CreateCondBr(ScalarCompare, PHBlock, MemIntrBlock); + ParentBlock->getTerminator()->eraseFromParent(); + } + + // Create the splat (memset) + Builder.SetInsertPoint(PHBlock->getTerminator()); + auto *ValTy = VectorType::get(Builder.getInt8Ty(), 16, true); + Value *SetVal = nullptr; + auto *DestTy = II->getRawDest()->getType()->getScalarType(); + auto *NumElts = + ConstantInt::get(IdxTy, ValTy->getElementCount().getKnownMinValue()); + auto *VScale = Builder.CreateVScale(NumElts); + + Value *MinusOne = ConstantInt::get(IdxTy, -1, false); + auto *VScaleMask = Builder.CreateAdd(VScale, MinusOne); + auto *RemainLength = Builder.CreateAnd(VScaleMask, Length); + + if (auto *MS = dyn_cast(II)) + SetVal = + Builder.CreateVectorSplat(ValTy->getElementCount(), MS->getValue()); + + auto *LengthAtLatch = Builder.CreateSub(Length, VScale); + auto *EndDest = + Builder.CreateGEP(Builder.getInt8Ty(), II->getRawDest(), LengthAtLatch); + + auto EPiCompare = Builder.CreateICmpUGE(Length, VScale); + Builder.CreateCondBr(EPiCompare, LoopBlock, LoopEpilog); + PHBlock->getTerminator()->eraseFromParent(); + + // Set Insert point to loop body + Builder.SetInsertPoint(LoopBlock->getTerminator()); + + auto *DestPHI = Builder.CreatePHI(DestTy, 2); + DestPHI->addIncoming(II->getRawDest(), PHBlock); + + Value *NexSource = nullptr; + + // Create the load (in case of memcpy) + if (auto *MC = dyn_cast(II)) { + auto *SourceTy = MC->getRawSource()->getType()->getScalarType(); + auto *SourcePHI = Builder.CreatePHI(SourceTy, 2); + SourcePHI->addIncoming(MC->getRawSource(), PHBlock); + SetVal = Builder.CreateLoad(ValTy, SourcePHI); + NexSource = Builder.CreateGEP(Builder.getInt8Ty(), SourcePHI, VScale); + SourcePHI->addIncoming(NexSource, LoopBlock); + } + + assert(SetVal && "No Value to store"); + + Builder.CreateStore(SetVal, DestPHI); + auto *NextDest = Builder.CreateGEP(Builder.getInt8Ty(), DestPHI, VScale); + DestPHI->addIncoming(NextDest, LoopBlock); + + auto ScalarCompare = Builder.CreateICmpULE(NextDest, EndDest); + Builder.CreateCondBr(ScalarCompare, LoopBlock, LoopEpilog); + LoopBlock->getTerminator()->eraseFromParent(); + + Builder.SetInsertPoint(LoopEpilog->getTerminator()); + auto *EpiDestPHI = Builder.CreatePHI(DestTy, 2); + EpiDestPHI->addIncoming(II->getRawDest(), PHBlock); + EpiDestPHI->addIncoming(NextDest, LoopBlock); + + PHINode *EpiSourcePHI = nullptr; + if (auto *MC = dyn_cast(II)) { + auto *SourceTy = MC->getRawSource()->getType()->getScalarType(); + EpiSourcePHI = Builder.CreatePHI(SourceTy, 2); + EpiSourcePHI->addIncoming(MC->getRawSource(), PHBlock); + EpiSourcePHI->addIncoming(NexSource, LoopBlock); + } + + Value *Zero = ConstantInt::get(IdxTy, 0, false); + auto *PredTy = VectorType::get(Builder.getInt1Ty(), 16, true); + auto *Pred = + CreateWhile(Intrinsic::aarch64_sve_whilelo, PredTy, Zero, RemainLength); + Builder.Insert(Pred); + if (EpiSourcePHI) { + SetVal = Builder.CreateMaskedLoad(ValTy, EpiSourcePHI, + Align(II->getDestAlignment()), Pred); + } + + Builder.CreateMaskedStore(SetVal, EpiDestPHI, Align(II->getDestAlignment()), + Pred); + LoopEpilog->getTerminator()->eraseFromParent(); + BranchInst::Create(ResumeBlock, LoopEpilog); + + // Remove the original memset + II->moveBefore(MemIntrBlock->getTerminator()); + + NumExpandedCalls++; + return true; +} + +bool SVEExpandLibCall::runOnFunction(Function &F) { + bool Changed = false; + SmallVector MemWorkList; + + for (auto I = inst_begin(F), E = inst_end(F); I != E; ++I) { + if (EnableMemCallInlining && isMemLibCall(&*I)) + MemWorkList.push_back(&*I); + } + + // If the target-feature for SVE is not set, we can't generate + // explicit SVE intrinsics to optimize memsets. + bool HasSVEAttribute = F.getAttributes() + .getFnAttrs() + .getAttribute("target-features") + .getValueAsString() + .contains("+sve"); + SVEExpandLibCallBase Impl; + if (HasSVEAttribute) { + for (auto I : MemWorkList) + Changed |= Impl.ExpandMemCallToLoop(cast(I), F); + } + + return Changed; +} + +PreservedAnalyses SVEExpandLibCallPass::run(Function &F, + FunctionAnalysisManager &AM) { + bool Changed = false; + SmallVector MemWorkList; + + for (auto I = inst_begin(F), E = inst_end(F); I != E; ++I) { + if (EnableMemCallInlining && isMemLibCall(&*I)) + MemWorkList.push_back(&*I); + } + + // If the target-feature for SVE is not set, we can't generate + // explicit SVE intrinsics to optimize memsets. + bool HasSVEAttribute = F.getAttributes() + .getFnAttrs() + .getAttribute("target-features") + .getValueAsString() + .contains("+sve"); + + SVEExpandLibCallBase Impl; + if (HasSVEAttribute) { + for (auto I : MemWorkList) + Changed |= Impl.ExpandMemCallToLoop(cast(I), F); + } + + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll index fad55a823f3c617feefd644f926ee66100d3fe13..3b1e336c4b9e482a784f2a7e89a4e3ddee4929a7 100644 --- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll +++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll @@ -26,6 +26,7 @@ ; CHECK-NEXT: FunctionPass Manager ; CHECK-NEXT: Dominator Tree Construction ; CHECK-NEXT: FunctionPass Manager +; CHECK-NEXT: SVE Vector Lib Call Expansion ; CHECK-NEXT: Simplify the CFG ; CHECK-NEXT: Dominator Tree Construction ; CHECK-NEXT: Natural Loop Information diff --git a/llvm/test/CodeGen/AArch64/sve-expandmemlibcall-memcpy.ll b/llvm/test/CodeGen/AArch64/sve-expandmemlibcall-memcpy.ll new file mode 100644 index 0000000000000000000000000000000000000000..f828a26dc0454a46e0f6a3e4e6a331a8a5e640c4 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-expandmemlibcall-memcpy.ll @@ -0,0 +1,22 @@ +; RUN: opt -S -aarch64-sve-expandlibcall -sve-enable-mem-call-inlining=true -sve-enable-mem-call-rtcheck=true -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s + +@a = dso_local global [8000 x float] zeroinitializer, align 64 +@b = dso_local global [8000 x float] zeroinitializer, align 64 + +define void @test_memcpy() { +; CHECK-LABEL: @test_memcpy( +; CHECK: label %mem.ph, label %mem.intrinsic +; CHECK: mem.ph: +; CHECK-LABEL: mem.exploop +; CHECK: call @llvm.aarch64.sve.whilelo.nxv16i1.i64 +; CHECK: call @llvm.masked.load.nxv16i8.p0 +; CHECK: call void @llvm.masked.store.nxv16i8.p0 +; CHECK-LABEL: mem.resume +; CHECK-LABEL: mem.intrinsic +; +entry: + call void @llvm.memcpy.p0.p0.i64(ptr align 64 @b, ptr align 64 @a, i64 32000, i1 false) + ret void +} + +declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) diff --git a/llvm/test/CodeGen/AArch64/sve-expandmemlibcall-memset.ll b/llvm/test/CodeGen/AArch64/sve-expandmemlibcall-memset.ll new file mode 100644 index 0000000000000000000000000000000000000000..104a31fa20126ed6394a0c6653ce85b11e53d846 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-expandmemlibcall-memset.ll @@ -0,0 +1,19 @@ +; RUN: opt -S -aarch64-sve-expandlibcall -sve-enable-mem-call-inlining=true -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s + +@A = dso_local local_unnamed_addr global [1000000 x i32] zeroinitializer, align 4 + +; Function Attrs: nofree norecurse nosync nounwind uwtable writeonly vscale_range(1,16) +define dso_local void @foo(){ +; CHECK-LABEL: @foo( +; CHECK-LABEL: mem.exploop +; CHECK: call @llvm.aarch64.sve.whilelo.nxv16i1.i64 +; CHECK: call void @llvm.masked.store.nxv16i8.p0 +; CHECK-LABEL: mem.resume +; +entry: + call void @llvm.memset.p0.i64(ptr noundef nonnull align 4 dereferenceable(4000000) @A, i8 0, i64 4000000, i1 false) + ret void +} + +; Function Attrs: argmemonly nofree nounwind willreturn writeonly +declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg) \ No newline at end of file