From 8a58f954e131f5fdf729f70df20c425d7afd78c5 Mon Sep 17 00:00:00 2001 From: softwarezhen Date: Wed, 16 Aug 2023 01:23:47 +0000 Subject: [PATCH 1/2] =?UTF-8?q?update=20llvm/lib/Transforms/Utils/Simplify?= =?UTF-8?q?LibCalls.cpp.=20=E5=A2=9E=E5=8A=A0=E4=BB=A5=E4=B8=8B=E5=BA=93?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=90=88=E5=B9=B6=E4=BC=98=E5=8C=96=EF=BC=9A?= =?UTF-8?q?=20pow=20(sqrt(x),=20y)=20->=20pow=20(x,=20y*0.5)=20pow=20(pow?= =?UTF-8?q?=20(x,=20y),=20z)=20->=20pow=20(x,=20y*z)=20sqrt=20(Nroot(x))?= =?UTF-8?q?=20->=20pow(x,1/(2*N))=20sqrt=20(pow=20(x,=20y))=20->=20pow=20(?= =?UTF-8?q?|x|,=20y*0.5)=20cbrt(exp(X))=20->=20exp(x/3)=20cbrt(exp2(X))=20?= =?UTF-8?q?->=20exp2(x/3)=20cbrt(sqrt(x))=20->=20pow(x,1/6)=20cbrt(cbrt(x)?= =?UTF-8?q?)=20->=20pow(x,1/9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: softwarezhen --- .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 337 +++++++++++++++++- 1 file changed, 334 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 245f2d4e442a..4ef213d88532 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1970,7 +1970,319 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { return Sqrt; } - +//pow(sqrt(x),y) -> pow(x,y*0.5) +Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, IRBuilderBase &B){ + Value *Base=nullptr,*y=nullptr,*NewPow=nullptr; + Base = Pow->getArgOperand(0); + y = Pow->getArgOperand(1); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType();//DoubleTyID、FloatTyID等 + CallInst *BaseFn = dyn_cast(Base); + //确定base是一个函数调用,并且开启了快速数学标志 + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()){ + Function *CalleeFn = BaseFn->getCalledFunction(); + if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CalleeFn)) + return nullptr; + LibFunc LibFn; + + //判断Base是不是一个Intrinsic函数调用、并且是一个Intrinsic为sqrt的调用 + if (CalleeFn && CalleeFn->getIntrinsicID()==Intrinsic::sqrt){ + //提取sqrt内部的参数x + Value *x=BaseFn->getOperand(0); + //创建新节点y*0.5 + Value *y_05=B.CreateFMul(y,ConstantFP::get(Ty, 0.5)); + //创建新节点NewPow的定义。代表pow(x,y*0.5) + NewPow=B.CreateCall(Intrinsic::getDeclaration(Mod,Pow->getIntrinsicID(),Ty),{x,y_05}); + } + //判断是不是库函数调用 + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)){ + LibFunc floatFn,doubleFn,longDFn; + switch (LibFn) + { + case LibFunc_sqrtf: + case LibFunc_sqrt: + case LibFunc_sqrtl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_sqrtf_finite: + case LibFunc_sqrt_finite: + case LibFunc_sqrtl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + //提取sqrt内部的参数x + Value *x=BaseFn->getOperand(0); + NewPow = emitBinaryFloatFnCall(x, B.CreateFMul(y,ConstantFP::get(Ty, 0.5)), + TLI, doubleFn, floatFn, longDFn, + B, CalleeFn->getAttributes()); + } + //使用新节点NewPow替换旧节点Pow + if (NewPow) { + Pow->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} +//pow(pow(x,y),z)-> pow(x,y*z) +//GCC未做该优化 +Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B){ + Value *Base=nullptr,*z=nullptr,*NewPow=nullptr; + Base = Pow->getArgOperand(0);//pow(x,y) + z = Pow->getArgOperand(1);//z + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType();//DoubleTyID、FloatTyID等 + CallInst *BaseFn = dyn_cast(Base); + //确定base是一个函数调用,并且开启了快速数学标志 + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()){ + Function *CalleeFn = BaseFn->getCalledFunction(); + if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CalleeFn)) + return nullptr; + LibFunc LibFn; + //判断Base是不是一个Intrinsic函数调用、并且是一个Intrinsic为 pow 的调用 + if (CalleeFn && CalleeFn->getIntrinsicID()==Intrinsic::pow){ + //提取 pow 内部的参数x,y + Value *x=BaseFn->getOperand(0); + Value *y=BaseFn->getOperand(1); + //创建新节点y*z + Value *yz=B.CreateFMul(y,z); + //创建新节点NewPow的定义。代表pow(x,y*0.5) + NewPow=B.CreateCall(Intrinsic::getDeclaration(Mod,Pow->getIntrinsicID(),Ty),{x,yz}); + } + //判断是不是库函数调用 + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)){ + LibFunc floatFn,doubleFn,longDFn; + switch (LibFn) + { + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_powf_finite: + case LibFunc_pow_finite: + case LibFunc_powl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + //提取 pow 内部的参数x,y + Value *x=BaseFn->getOperand(0); + Value *y=BaseFn->getOperand(1); + //创建新节点y*z + Value *yz=B.CreateFMul(y,z); + NewPow = emitBinaryFloatFnCall(x, yz, TLI, doubleFn, floatFn, longDFn, + B, CalleeFn->getAttributes()); + } + if (NewPow) { + Pow->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} +//sqrt(pow(x,y)) -> pow(|x|,y*0.5) +Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt,IRBuilderBase &B){ + Value *OldPow=nullptr,*NewPow=nullptr; + OldPow=Sqrt->getArgOperand(0); + Module *Mod = Sqrt->getModule(); + Type *Ty = Sqrt->getType(); + CallInst *Pow=dyn_cast(OldPow); + if(Pow && Pow->hasOneUse() && Pow->isFast() && Sqrt->isFast()){ + Function *CalleeFn = Pow->getCalledFunction(); + if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CalleeFn)) + return nullptr; + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(Sqrt->getFastMathFlags()); + LibFunc LibFn; + if (CalleeFn && CalleeFn->getIntrinsicID()==Intrinsic::pow){ + Value *x=Pow->getOperand(0); + Value *y=Pow->getOperand(1); + //不转为abs(x) + Value *y_05=B.CreateFMul(y,ConstantFP::get(Ty, 0.5)); + NewPow=B.CreateCall(Intrinsic::getDeclaration(Mod,Pow->getIntrinsicID(),Ty),{x,y_05}); + } + //判断是不是库函数调用 + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)){ + LibFunc floatFn,doubleFn,longDFn; + switch (LibFn) + { + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_powf_finite: + case LibFunc_pow_finite: + case LibFunc_powl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + //不转为abs(x) + Value *x=Pow->getOperand(0); + Value *y=Pow->getOperand(1); + Value *y_05=B.CreateFMul(y,ConstantFP::get(Ty, 0.5)); + NewPow = emitBinaryFloatFnCall(x, y_05, TLI, doubleFn, floatFn, longDFn, + B, CalleeFn->getAttributes()); + } + if (NewPow) { + Sqrt->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} +/* +* cbrt(expN(X)) -> expN(x/3) +* cbrt(sqrt(x)) -> pow(x,1/6) +* cbrt(cbrt(x)) -> pow(x,1/9) +*/ +Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B){ + Module *M = CI->getModule(); + Value *Base = CI->getArgOperand(0); + CallInst *BaseFn = dyn_cast(Base); + Type *Ty = CI->getType(); + Value *Ret=nullptr, *tempRet1=nullptr, *tempRet2=nullptr; + if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CI)) + return nullptr; + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + //确定cbrt内部传输的是一个函数调用 + //并且启用了fast-math标志 + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && CI->isFast()){ + //检查内部是否是库函数调用 + LibFunc LibFn; + Function *CalleeFn = BaseFn->getCalledFunction(); + Value *x; + //检查内部是否是intrinsic调用 + if(IntrinsicInst *II = dyn_cast(BaseFn)){ + Intrinsic::ID IntrinsicID = II->getIntrinsicID(); + switch (IntrinsicID) + { + //cbrt(exp(X)) -> exp(x/3) + //cbrt(exp2(X)) -> exp2(x/3) + case Intrinsic::exp: + case Intrinsic::exp2: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M,IntrinsicID,Ty), + B.CreateFDiv(x,ConstantFP::get(Ty, 3.0))); + break; + //cbrt(sqrt(x)) -> pow(x,1/6) + case Intrinsic::sqrt: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M,Intrinsic::pow,Ty), + {x,B.CreateFDiv(ConstantFP::get(Ty, 1.0),ConstantFP::get(Ty, 6.0))} + ); + break; + default: + return nullptr; + } + } + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(M, TLI, LibFn)){ + LibFunc floatFn,doubleFn,longDFn; + //yong + switch(LibFn){ + //cbrt(exp(X)) -> exp(x/3) + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_expl: + x=BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), TLI, + LibFunc_exp, LibFunc_expf, LibFunc_expl, B, CalleeFn->getAttributes()); + break; + case LibFunc_exp_finite: + case LibFunc_expf_finite: + case LibFunc_expl_finite: + x=BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), TLI, + LibFunc_exp_finite, LibFunc_expf_finite, LibFunc_expl_finite, + B, CalleeFn->getAttributes()); + break; + //cbrt(exp2(X)) -> exp2(x/3) + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2l: + x=BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), TLI, + LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l, B, CalleeFn->getAttributes()); + break; + case LibFunc_exp2_finite: + case LibFunc_exp2f_finite: + case LibFunc_exp2l_finite: + x=BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), TLI, + LibFunc_exp2_finite, LibFunc_exp2f_finite, LibFunc_exp2l_finite, + B, CalleeFn->getAttributes()); + break; + //cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: + x=BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall(x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), + ConstantFP::get(Ty, 6.0)), TLI, LibFunc_pow, LibFunc_powf, + LibFunc_powl, B, BaseFn->getAttributes()); + break; + //cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt_finite: + case LibFunc_sqrtf_finite: + case LibFunc_sqrtl_finite: + x=BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall(x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), + ConstantFP::get(Ty, 6.0)), TLI, LibFunc_pow_finite, LibFunc_powf_finite, + LibFunc_powl_finite, B, BaseFn->getAttributes()); + break; + //cbrt(cbrt(x)) -> pow(x,1/9) + case LibFunc_cbrt: + case LibFunc_cbrtf: + case LibFunc_cbrtl: + x=BaseFn->getOperand(0); + //x>=0时pow(x,1/9) + tempRet1 = emitBinaryFloatFnCall(x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), + ConstantFP::get(Ty, 9.0)), TLI, LibFunc_pow, LibFunc_powf, + LibFunc_powl, B, BaseFn->getAttributes()); + //x<0时-pow(-x,1/9) + tempRet2 = B.CreateFNeg( + emitBinaryFloatFnCall(B.CreateFNeg(x), B.CreateFDiv(ConstantFP::get(Ty, 1.0), + ConstantFP::get(Ty, 9.0)), TLI, LibFunc_pow, LibFunc_powf, + LibFunc_powl, B, BaseFn->getAttributes())); + Ret = B.CreateSelect( + B.CreateFCmpOGE(x,ConstantFP::get(Ty, 0.0)), + tempRet1, + tempRet2); + break; + default: + return nullptr; + } + } + if (Ret){ + CI->replaceAllUsesWith(Ret); + return Ret; + } + } + return nullptr; +} static Value *createPowWithIntegerExponent(Value *Base, Value *Expo, Module *M, IRBuilderBase &B) { Value *Args[] = {Base, Expo}; @@ -2021,6 +2333,12 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { if (Value *Sqrt = replacePowWithSqrt(Pow, B)) return Sqrt; + //pow(sqrt(x),y) -> pow(x,y*0.5) + if (Value *V = replaceNestedPowAndSqrtWithPow(Pow, B)) + return V; + //pow(pow(x,y),z)-> pow(x,y*z) + if (Value *V = replaceNestedPowAndPowWithPow(Pow, B)) + return V; // If we can approximate pow: // pow(x, n) -> powi(x, n) * sqrt(x) if n has exactly a 0.5 fraction @@ -2313,7 +2631,10 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { if (!CI->isFast()) return Ret; - + //sqrt(pow(x,y)) -> pow(|x|,y*0.5) + if(Value *V=replaceNestedSqrtAndPowWithPow(CI,B)) + return V; + Instruction *I = dyn_cast(CI->getArgOperand(0)); if (!I || I->getOpcode() != Instruction::FMul || !I->isFast()) return Ret; @@ -3274,6 +3595,9 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_powf: case LibFunc_pow: case LibFunc_powl: + case LibFunc_pow_finite: + case LibFunc_powf_finite: + case LibFunc_powl_finite: return optimizePow(CI, Builder); case LibFunc_exp2l: case LibFunc_exp2: @@ -3286,6 +3610,9 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_sqrtf: case LibFunc_sqrt: case LibFunc_sqrtl: + case LibFunc_sqrtf_finite: + case LibFunc_sqrt_finite: + case LibFunc_sqrtl_finite: return optimizeSqrt(CI, Builder); case LibFunc_logf: case LibFunc_log: @@ -3327,7 +3654,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_asinh: case LibFunc_atan: case LibFunc_atanh: - case LibFunc_cbrt: + //case LibFunc_cbrt: case LibFunc_cosh: case LibFunc_exp: case LibFunc_exp10: @@ -3354,6 +3681,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_cabsf: case LibFunc_cabsl: return optimizeCAbs(CI, Builder); + case LibFunc_cbrtf: + case LibFunc_cbrt: + case LibFunc_cbrtl: + return optimizeCbrt(CI, Builder); default: return nullptr; } -- Gitee From 6a947776536c8145b1a2c4a0e93272a1a027bcf4 Mon Sep 17 00:00:00 2001 From: softwarezhen Date: Wed, 16 Aug 2023 01:27:04 +0000 Subject: [PATCH 2/2] =?UTF-8?q?update=20llvm/include/llvm/Transforms/Utils?= =?UTF-8?q?/SimplifyLibCalls.h.=20=E5=A2=9E=E5=8A=A0=E5=A4=84=E7=90=86?= =?UTF-8?q?=E4=BB=A5=E4=B8=8B=E5=BA=93=E5=87=BD=E6=95=B0=E5=90=88=E5=B9=B6?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=9A=84=E5=87=BD=E6=95=B0=E5=A3=B0=E6=98=8E?= =?UTF-8?q?=EF=BC=9A=20pow=20(sqrt(x),=20y)=20->=20pow=20(x,=20y*0.5)=20po?= =?UTF-8?q?w=20(pow=20(x,=20y),=20z)=20->=20pow=20(x,=20y*z)=20sqrt=20(Nro?= =?UTF-8?q?ot(x))=20->=20pow(x,1/(2*N))=20sqrt=20(pow=20(x,=20y))=20->=20p?= =?UTF-8?q?ow=20(|x|,=20y*0.5)=20cbrt(exp(X))=20->=20exp(x/3)=20cbrt(exp2(?= =?UTF-8?q?X))=20->=20exp2(x/3)=20cbrt(sqrt(x))=20->=20pow(x,1/6)=20cbrt(c?= =?UTF-8?q?brt(x))=20->=20pow(x,1/9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: softwarezhen --- llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h index 1b2482a2363d..dedf8482aab8 100644 --- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h +++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h @@ -190,6 +190,10 @@ private: Value *optimizePow(CallInst *CI, IRBuilderBase &B); Value *replacePowWithExp(CallInst *Pow, IRBuilderBase &B); Value *replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedPowAndSqrtWithPow(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedSqrtAndPowWithPow(CallInst *Sqrt,IRBuilderBase &B); + Value *optimizeCbrt(CallInst *CI, IRBuilderBase &B); Value *optimizeExp2(CallInst *CI, IRBuilderBase &B); Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B); Value *optimizeLog(CallInst *CI, IRBuilderBase &B); -- Gitee