1 Star 0 Fork 0

zhoujin826/tidb

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
builtin_control.go 21.71 KB
一键复制 编辑 原始数据 按行查看 历史
coocood 提交于 2017-10-22 22:36 . *: Add Row interface (#4859)
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"github.com/cznic/mathutil"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
"github.com/pingcap/tidb/util/types/json"
"github.com/pingcap/tipb/go-tipb"
)
// Compile-time checks that every control-flow function class declared in this
// file satisfies the functionClass interface.
var (
	_ functionClass = (*caseWhenFunctionClass)(nil)
	_ functionClass = (*ifFunctionClass)(nil)
	_ functionClass = (*ifNullFunctionClass)(nil)
)
// Compile-time checks that every signature type declared in this file
// satisfies the builtinFunc interface.
var (
	// CASE WHEN signatures.
	_ builtinFunc = (*builtinCaseWhenIntSig)(nil)
	_ builtinFunc = (*builtinCaseWhenRealSig)(nil)
	_ builtinFunc = (*builtinCaseWhenDecimalSig)(nil)
	_ builtinFunc = (*builtinCaseWhenStringSig)(nil)
	_ builtinFunc = (*builtinCaseWhenTimeSig)(nil)
	_ builtinFunc = (*builtinCaseWhenDurationSig)(nil)

	// IFNULL signatures.
	_ builtinFunc = (*builtinIfNullIntSig)(nil)
	_ builtinFunc = (*builtinIfNullRealSig)(nil)
	_ builtinFunc = (*builtinIfNullDecimalSig)(nil)
	_ builtinFunc = (*builtinIfNullStringSig)(nil)
	_ builtinFunc = (*builtinIfNullTimeSig)(nil)
	_ builtinFunc = (*builtinIfNullDurationSig)(nil)
	_ builtinFunc = (*builtinIfNullJSONSig)(nil)

	// IF signatures.
	_ builtinFunc = (*builtinIfIntSig)(nil)
	_ builtinFunc = (*builtinIfRealSig)(nil)
	_ builtinFunc = (*builtinIfDecimalSig)(nil)
	_ builtinFunc = (*builtinIfStringSig)(nil)
	_ builtinFunc = (*builtinIfTimeSig)(nil)
	_ builtinFunc = (*builtinIfDurationSig)(nil)
	_ builtinFunc = (*builtinIfJSONSig)(nil)
)
// caseWhenFunctionClass is the function class for CASE WHEN; its getFunction
// method builds one of the typed builtinCaseWhen*Sig signatures.
type caseWhenFunctionClass struct {
baseFunctionClass
}
// inferType4ControlFuncs infers the result field type for the builtin IF,
// IFNULL and NULLIF functions from the field types of the two candidate
// result expressions, lhs and rhs.
//
// Rules, in order:
//   - If lhs is NULL-typed the result copies rhs; if rhs is NULL-typed too the
//     result becomes BINARY(0).
//   - If only rhs is NULL-typed the result copies lhs.
//   - Otherwise the two types are aggregated, and Decimal, charset/collation/
//     flag and Flen are derived from both operands.
//
// Finally Decimal is normalised: 0 for integer results and UnspecifiedLength
// for string results (unless both operands are NULL-typed).
func inferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType {
	resultFieldType := &types.FieldType{}
	if lhs.Tp == mysql.TypeNull {
		*resultFieldType = *rhs
		// If both arguments are NULL, make resulting type BINARY(0).
		if rhs.Tp == mysql.TypeNull {
			resultFieldType.Tp = mysql.TypeString
			resultFieldType.Flen, resultFieldType.Decimal = 0, 0
			types.SetBinChsClnFlag(resultFieldType)
		}
	} else if rhs.Tp == mysql.TypeNull {
		*resultFieldType = *lhs
	} else {
		var unsignedFlag uint
		evalType := types.AggregateEvalType([]*types.FieldType{lhs, rhs}, &unsignedFlag)
		resultFieldType = types.AggFieldType([]*types.FieldType{lhs, rhs})

		// Fractional digits: none for integers; otherwise the wider of the two,
		// or unspecified as soon as either side is unspecified.
		if evalType == types.ETInt {
			resultFieldType.Decimal = 0
		} else {
			if lhs.Decimal == types.UnspecifiedLength || rhs.Decimal == types.UnspecifiedLength {
				resultFieldType.Decimal = types.UnspecifiedLength
			} else {
				resultFieldType.Decimal = mathutil.Max(lhs.Decimal, rhs.Decimal)
			}
		}

		// Charset, collation and flag.
		if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) {
			resultFieldType.Charset, resultFieldType.Collate, resultFieldType.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
			if mysql.HasBinaryFlag(lhs.Flag) {
				resultFieldType.Flag |= mysql.BinaryFlag
			}
		} else if types.IsNonBinaryStr(rhs) && !types.IsBinaryStr(lhs) {
			resultFieldType.Charset, resultFieldType.Collate, resultFieldType.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
			if mysql.HasBinaryFlag(rhs.Flag) {
				resultFieldType.Flag |= mysql.BinaryFlag
			}
		} else if types.IsBinaryStr(lhs) || types.IsBinaryStr(rhs) || !evalType.IsStringKind() {
			types.SetBinChsClnFlag(resultFieldType)
		} else {
			resultFieldType.Charset, resultFieldType.Collate, resultFieldType.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
		}

		// Display length.
		if evalType == types.ETDecimal || evalType == types.ETInt {
			lhsUnsignedFlag, rhsUnsignedFlag := mysql.HasUnsignedFlag(lhs.Flag), mysql.HasUnsignedFlag(rhs.Flag)
			// One position is deducted for a signed operand (presumably the
			// slot for the sign character).
			lhsFlagLen, rhsFlagLen := 0, 0
			if !lhsUnsignedFlag {
				lhsFlagLen = 1
			}
			if !rhsUnsignedFlag {
				rhsFlagLen = 1
			}
			// Reduce each Flen to the width left after removing the sign slot
			// and that operand's own fractional digits.
			lhsFlen := lhs.Flen - lhsFlagLen
			rhsFlen := rhs.Flen - rhsFlagLen
			if lhs.Decimal != types.UnspecifiedLength {
				lhsFlen -= lhs.Decimal
			}
			// Each side must be guarded by its own Decimal: testing lhs here
			// would subtract an unspecified rhs.Decimal (widening rhsFlen) or
			// skip a valid one.
			if rhs.Decimal != types.UnspecifiedLength {
				rhsFlen -= rhs.Decimal
			}
			resultFieldType.Flen = mathutil.Max(lhsFlen, rhsFlen) + resultFieldType.Decimal + 1
		} else {
			resultFieldType.Flen = mathutil.Max(lhs.Flen, rhs.Flen)
		}
	}
	// Fix decimal for int and string.
	resultEvalType := resultFieldType.EvalType()
	if resultEvalType == types.ETInt {
		resultFieldType.Decimal = 0
	} else if resultEvalType == types.ETString {
		if lhs.Tp != mysql.TypeNull || rhs.Tp != mysql.TypeNull {
			resultFieldType.Decimal = types.UnspecifiedLength
		}
	}
	return resultFieldType
}
// getFunction validates args and builds the CASE WHEN signature matching the
// aggregated type of the result expressions.
//
// args is laid out as (condition, result) pairs, optionally followed by a
// single ELSE expression, in which case len(args) is odd. No signature is
// selected for evaluation types outside the switch below; sig is then nil.
func (c *caseWhenFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
	if err = c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	argCnt := len(args)
	hasElse := argCnt%2 == 1

	// Collect the type of every result expression (each THEN plus the ELSE)
	// while tracking the widest Decimal/Flen and whether any is a binary string.
	resultTps := make([]*types.FieldType, 0, (argCnt+1)/2)
	maxDecimal, maxFlen, anyBinaryStr := args[1].GetType().Decimal, 0, false
	collect := func(ft *types.FieldType) {
		resultTps = append(resultTps, ft)
		maxDecimal = mathutil.Max(maxDecimal, ft.Decimal)
		maxFlen = mathutil.Max(maxFlen, ft.Flen)
		anyBinaryStr = anyBinaryStr || types.IsBinaryStr(ft)
	}
	for pos := 1; pos < argCnt; pos += 2 {
		collect(args[pos].GetType())
	}
	if hasElse {
		collect(args[argCnt-1].GetType())
	}

	fieldTp := types.AggFieldType(resultTps)
	evalTp := fieldTp.EvalType()
	if evalTp == types.ETInt {
		maxDecimal = 0
	}
	fieldTp.Decimal, fieldTp.Flen = maxDecimal, maxFlen
	if fieldTp.EvalType().IsStringKind() && !anyBinaryStr {
		fieldTp.Charset, fieldTp.Collate = mysql.DefaultCharset, mysql.DefaultCollationName
	}
	// Set retType to BINARY(0) if all arguments are of type NULL.
	if fieldTp.Tp == mysql.TypeNull {
		fieldTp.Flen, fieldTp.Decimal = 0, -1
		types.SetBinChsClnFlag(fieldTp)
	}

	// Conditions are evaluated as integers, results as the aggregated type.
	argEvalTps := make([]types.EvalType, 0, argCnt)
	for pos := 0; pos+1 < argCnt; pos += 2 {
		argEvalTps = append(argEvalTps, types.ETInt, evalTp)
	}
	if hasElse {
		argEvalTps = append(argEvalTps, evalTp)
	}

	bf := newBaseBuiltinFuncWithTp(ctx, args, evalTp, argEvalTps...)
	bf.tp = fieldTp
	switch evalTp {
	case types.ETInt:
		bf.tp.Decimal = 0
		sig = &builtinCaseWhenIntSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_CaseWhenInt)
	case types.ETReal:
		sig = &builtinCaseWhenRealSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_CaseWhenReal)
	case types.ETDecimal:
		sig = &builtinCaseWhenDecimalSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_CaseWhenDecimal)
	case types.ETString:
		bf.tp.Decimal = types.UnspecifiedLength
		sig = &builtinCaseWhenStringSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_CaseWhenString)
	case types.ETDatetime, types.ETTimestamp:
		sig = &builtinCaseWhenTimeSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_CaseWhenTime)
	case types.ETDuration:
		sig = &builtinCaseWhenDurationSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_CaseWhenDuration)
	}
	return sig, nil
}
// builtinCaseWhenIntSig is the CASE WHEN signature producing an integer.
type builtinCaseWhenIntSig struct {
	baseBuiltinFunc
}

// evalInt evaluates CASE WHEN for an integer result.
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
//
// Arguments are (condition, result) pairs, optionally followed by an ELSE
// expression (odd argument count). The result of the first pair whose
// condition is non-NULL and non-zero is returned; otherwise the ELSE value,
// or NULL when there is no ELSE.
func (b *builtinCaseWhenIntSig) evalInt(row types.Row) (ret int64, isNull bool, err error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	exprs := b.getArgs()
	n := len(exprs)
	for idx := 0; idx+1 < n; idx += 2 {
		cond, condNull, condErr := exprs[idx].EvalInt(row, stmtCtx)
		if condErr != nil {
			return 0, condNull, errors.Trace(condErr)
		}
		if !condNull && cond != 0 {
			ret, isNull, err = exprs[idx+1].EvalInt(row, stmtCtx)
			return ret, isNull, errors.Trace(err)
		}
	}
	// No WHEN matched and there is no trailing ELSE: the result is NULL.
	if n%2 == 0 {
		return ret, true, nil
	}
	ret, isNull, err = exprs[n-1].EvalInt(row, stmtCtx)
	return ret, isNull, errors.Trace(err)
}
// builtinCaseWhenRealSig is the CASE WHEN signature producing a float64.
type builtinCaseWhenRealSig struct {
	baseBuiltinFunc
}

// evalReal evaluates CASE WHEN for a real result.
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
//
// Arguments are (condition, result) pairs, optionally followed by an ELSE
// expression (odd argument count). The result of the first pair whose
// condition is non-NULL and non-zero is returned; otherwise the ELSE value,
// or NULL when there is no ELSE.
func (b *builtinCaseWhenRealSig) evalReal(row types.Row) (ret float64, isNull bool, err error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	exprs := b.getArgs()
	n := len(exprs)
	for idx := 0; idx+1 < n; idx += 2 {
		cond, condNull, condErr := exprs[idx].EvalInt(row, stmtCtx)
		if condErr != nil {
			return 0, condNull, errors.Trace(condErr)
		}
		if !condNull && cond != 0 {
			ret, isNull, err = exprs[idx+1].EvalReal(row, stmtCtx)
			return ret, isNull, errors.Trace(err)
		}
	}
	// No WHEN matched and there is no trailing ELSE: the result is NULL.
	if n%2 == 0 {
		return ret, true, nil
	}
	ret, isNull, err = exprs[n-1].EvalReal(row, stmtCtx)
	return ret, isNull, errors.Trace(err)
}
// builtinCaseWhenDecimalSig is the CASE WHEN signature producing a decimal.
type builtinCaseWhenDecimalSig struct {
	baseBuiltinFunc
}

// evalDecimal evaluates CASE WHEN for a decimal result.
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
//
// Arguments are (condition, result) pairs, optionally followed by an ELSE
// expression (odd argument count). The result of the first pair whose
// condition is non-NULL and non-zero is returned; otherwise the ELSE value,
// or NULL when there is no ELSE.
func (b *builtinCaseWhenDecimalSig) evalDecimal(row types.Row) (ret *types.MyDecimal, isNull bool, err error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	exprs := b.getArgs()
	n := len(exprs)
	for idx := 0; idx+1 < n; idx += 2 {
		cond, condNull, condErr := exprs[idx].EvalInt(row, stmtCtx)
		if condErr != nil {
			return nil, condNull, errors.Trace(condErr)
		}
		if !condNull && cond != 0 {
			ret, isNull, err = exprs[idx+1].EvalDecimal(row, stmtCtx)
			return ret, isNull, errors.Trace(err)
		}
	}
	// No WHEN matched and there is no trailing ELSE: the result is NULL.
	if n%2 == 0 {
		return ret, true, nil
	}
	ret, isNull, err = exprs[n-1].EvalDecimal(row, stmtCtx)
	return ret, isNull, errors.Trace(err)
}
// builtinCaseWhenStringSig is the CASE WHEN signature producing a string.
type builtinCaseWhenStringSig struct {
	baseBuiltinFunc
}

// evalString evaluates CASE WHEN for a string result.
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
//
// Arguments are (condition, result) pairs, optionally followed by an ELSE
// expression (odd argument count). The result of the first pair whose
// condition is non-NULL and non-zero is returned; otherwise the ELSE value,
// or NULL when there is no ELSE.
func (b *builtinCaseWhenStringSig) evalString(row types.Row) (ret string, isNull bool, err error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	exprs := b.getArgs()
	n := len(exprs)
	for idx := 0; idx+1 < n; idx += 2 {
		cond, condNull, condErr := exprs[idx].EvalInt(row, stmtCtx)
		if condErr != nil {
			return "", condNull, errors.Trace(condErr)
		}
		if !condNull && cond != 0 {
			ret, isNull, err = exprs[idx+1].EvalString(row, stmtCtx)
			return ret, isNull, errors.Trace(err)
		}
	}
	// No WHEN matched and there is no trailing ELSE: the result is NULL.
	if n%2 == 0 {
		return ret, true, nil
	}
	ret, isNull, err = exprs[n-1].EvalString(row, stmtCtx)
	return ret, isNull, errors.Trace(err)
}
// builtinCaseWhenTimeSig is the CASE WHEN signature producing a types.Time.
type builtinCaseWhenTimeSig struct {
	baseBuiltinFunc
}

// evalTime evaluates CASE WHEN for a time result.
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
//
// Arguments are (condition, result) pairs, optionally followed by an ELSE
// expression (odd argument count). The result of the first pair whose
// condition is non-NULL and non-zero is returned; otherwise the ELSE value,
// or NULL when there is no ELSE.
func (b *builtinCaseWhenTimeSig) evalTime(row types.Row) (ret types.Time, isNull bool, err error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	exprs := b.getArgs()
	n := len(exprs)
	for idx := 0; idx+1 < n; idx += 2 {
		cond, condNull, condErr := exprs[idx].EvalInt(row, stmtCtx)
		if condErr != nil {
			// ret is still the zero types.Time here.
			return ret, condNull, errors.Trace(condErr)
		}
		if !condNull && cond != 0 {
			ret, isNull, err = exprs[idx+1].EvalTime(row, stmtCtx)
			return ret, isNull, errors.Trace(err)
		}
	}
	// No WHEN matched and there is no trailing ELSE: the result is NULL.
	if n%2 == 0 {
		return ret, true, nil
	}
	ret, isNull, err = exprs[n-1].EvalTime(row, stmtCtx)
	return ret, isNull, errors.Trace(err)
}
// builtinCaseWhenDurationSig is the CASE WHEN signature producing a
// types.Duration.
type builtinCaseWhenDurationSig struct {
	baseBuiltinFunc
}

// evalDuration evaluates CASE WHEN for a duration result.
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
//
// Arguments are (condition, result) pairs, optionally followed by an ELSE
// expression (odd argument count). The result of the first pair whose
// condition is non-NULL and non-zero is returned; otherwise the ELSE value,
// or NULL when there is no ELSE. A failing condition is reported with
// isNull set to true.
func (b *builtinCaseWhenDurationSig) evalDuration(row types.Row) (ret types.Duration, isNull bool, err error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	exprs := b.getArgs()
	n := len(exprs)
	for idx := 0; idx+1 < n; idx += 2 {
		cond, condNull, condErr := exprs[idx].EvalInt(row, stmtCtx)
		if condErr != nil {
			// ret is still the zero types.Duration here.
			return ret, true, errors.Trace(condErr)
		}
		if !condNull && cond != 0 {
			ret, isNull, err = exprs[idx+1].EvalDuration(row, stmtCtx)
			return ret, isNull, errors.Trace(err)
		}
	}
	// No WHEN matched and there is no trailing ELSE: the result is NULL.
	if n%2 == 0 {
		return ret, true, nil
	}
	ret, isNull, err = exprs[n-1].EvalDuration(row, stmtCtx)
	return ret, isNull, errors.Trace(err)
}
// ifFunctionClass is the function class for the builtin IF function.
type ifFunctionClass struct {
	baseFunctionClass
}

// getFunction validates args and builds the IF signature whose evaluation
// type is inferred from the second and third arguments; the first argument
// (the condition) is always evaluated as an integer.
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_if
func (c *ifFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
	if err = c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	thenTp, elseTp := args[1].GetType(), args[2].GetType()
	resultTp := inferType4ControlFuncs(thenTp, elseTp)
	evalTp := resultTp.EvalType()

	base := newBaseBuiltinFuncWithTp(ctx, args, evalTp, types.ETInt, evalTp, evalTp)
	base.tp = resultTp

	// Evaluation types not listed here leave sig nil.
	switch evalTp {
	case types.ETInt:
		sig = &builtinIfIntSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfInt)
	case types.ETReal:
		sig = &builtinIfRealSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfReal)
	case types.ETDecimal:
		sig = &builtinIfDecimalSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfDecimal)
	case types.ETString:
		sig = &builtinIfStringSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfString)
	case types.ETDatetime, types.ETTimestamp:
		sig = &builtinIfTimeSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfTime)
	case types.ETDuration:
		sig = &builtinIfDurationSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfDuration)
	case types.ETJson:
		sig = &builtinIfJSONSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfJson)
	}
	return sig, nil
}
// builtinIfIntSig is the IF(cond, a, b) signature producing an integer.
type builtinIfIntSig struct {
	baseBuiltinFunc
}

// evalInt evaluates IF for an integer result.
//
// args[0] is the condition, evaluated as an integer. args[1] is returned when
// the condition is non-NULL and non-zero, args[2] otherwise. Only the selected
// branch is evaluated, so an error raised by the unselected branch cannot fail
// the expression; this matches the lazy evaluation of the CASE WHEN
// signatures in this file.
func (b *builtinIfIntSig) evalInt(row types.Row) (ret int64, isNull bool, err error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	cond, condIsNull, err := b.args[0].EvalInt(row, sc)
	if err != nil {
		return 0, true, errors.Trace(err)
	}
	branch := b.args[2]
	if !condIsNull && cond != 0 {
		branch = b.args[1]
	}
	ret, isNull, err = branch.EvalInt(row, sc)
	return ret, isNull, errors.Trace(err)
}
// builtinIfRealSig is the IF(cond, a, b) signature producing a float64.
type builtinIfRealSig struct {
	baseBuiltinFunc
}

// evalReal evaluates IF for a real result.
//
// args[0] is the condition, evaluated as an integer. args[1] is returned when
// the condition is non-NULL and non-zero, args[2] otherwise. Only the selected
// branch is evaluated, so an error raised by the unselected branch cannot fail
// the expression; this matches the lazy evaluation of the CASE WHEN
// signatures in this file.
func (b *builtinIfRealSig) evalReal(row types.Row) (ret float64, isNull bool, err error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	cond, condIsNull, err := b.args[0].EvalInt(row, sc)
	if err != nil {
		return 0, true, errors.Trace(err)
	}
	branch := b.args[2]
	if !condIsNull && cond != 0 {
		branch = b.args[1]
	}
	ret, isNull, err = branch.EvalReal(row, sc)
	return ret, isNull, errors.Trace(err)
}
// builtinIfDecimalSig is the IF(cond, a, b) signature producing a decimal.
type builtinIfDecimalSig struct {
	baseBuiltinFunc
}

// evalDecimal evaluates IF for a decimal result.
//
// args[0] is the condition, evaluated as an integer. args[1] is returned when
// the condition is non-NULL and non-zero, args[2] otherwise. Only the selected
// branch is evaluated, so an error raised by the unselected branch cannot fail
// the expression; this matches the lazy evaluation of the CASE WHEN
// signatures in this file.
func (b *builtinIfDecimalSig) evalDecimal(row types.Row) (ret *types.MyDecimal, isNull bool, err error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	cond, condIsNull, err := b.args[0].EvalInt(row, sc)
	if err != nil {
		return nil, true, errors.Trace(err)
	}
	branch := b.args[2]
	if !condIsNull && cond != 0 {
		branch = b.args[1]
	}
	ret, isNull, err = branch.EvalDecimal(row, sc)
	return ret, isNull, errors.Trace(err)
}
// builtinIfStringSig is the IF(cond, a, b) signature producing a string.
type builtinIfStringSig struct {
	baseBuiltinFunc
}

// evalString evaluates IF for a string result.
//
// args[0] is the condition, evaluated as an integer. args[1] is returned when
// the condition is non-NULL and non-zero, args[2] otherwise. Only the selected
// branch is evaluated, so an error raised by the unselected branch cannot fail
// the expression; this matches the lazy evaluation of the CASE WHEN
// signatures in this file.
func (b *builtinIfStringSig) evalString(row types.Row) (ret string, isNull bool, err error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	cond, condIsNull, err := b.args[0].EvalInt(row, sc)
	if err != nil {
		return "", true, errors.Trace(err)
	}
	branch := b.args[2]
	if !condIsNull && cond != 0 {
		branch = b.args[1]
	}
	ret, isNull, err = branch.EvalString(row, sc)
	return ret, isNull, errors.Trace(err)
}
// builtinIfTimeSig is the IF(cond, a, b) signature producing a types.Time.
type builtinIfTimeSig struct {
	baseBuiltinFunc
}

// evalTime evaluates IF for a time result.
//
// args[0] is the condition, evaluated as an integer. args[1] is returned when
// the condition is non-NULL and non-zero, args[2] otherwise. Only the selected
// branch is evaluated, so an error raised by the unselected branch cannot fail
// the expression; this matches the lazy evaluation of the CASE WHEN
// signatures in this file.
func (b *builtinIfTimeSig) evalTime(row types.Row) (ret types.Time, isNull bool, err error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	cond, condIsNull, err := b.args[0].EvalInt(row, sc)
	if err != nil {
		// ret is still the zero types.Time here.
		return ret, true, errors.Trace(err)
	}
	branch := b.args[2]
	if !condIsNull && cond != 0 {
		branch = b.args[1]
	}
	ret, isNull, err = branch.EvalTime(row, sc)
	return ret, isNull, errors.Trace(err)
}
// builtinIfDurationSig is the IF(cond, a, b) signature producing a
// types.Duration.
type builtinIfDurationSig struct {
	baseBuiltinFunc
}

// evalDuration evaluates IF for a duration result.
//
// args[0] is the condition, evaluated as an integer. args[1] is returned when
// the condition is non-NULL and non-zero, args[2] otherwise. Only the selected
// branch is evaluated, so an error raised by the unselected branch cannot fail
// the expression; this matches the lazy evaluation of the CASE WHEN
// signatures in this file.
func (b *builtinIfDurationSig) evalDuration(row types.Row) (ret types.Duration, isNull bool, err error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	cond, condIsNull, err := b.args[0].EvalInt(row, sc)
	if err != nil {
		// ret is still the zero types.Duration here.
		return ret, true, errors.Trace(err)
	}
	branch := b.args[2]
	if !condIsNull && cond != 0 {
		branch = b.args[1]
	}
	ret, isNull, err = branch.EvalDuration(row, sc)
	return ret, isNull, errors.Trace(err)
}
// builtinIfJSONSig is the IF(cond, a, b) signature producing a json.JSON.
type builtinIfJSONSig struct {
	baseBuiltinFunc
}

// evalJSON evaluates IF for a JSON result.
//
// args[0] is the condition, evaluated as an integer. args[1] is returned when
// the condition is non-NULL and non-zero, args[2] otherwise. Only the selected
// branch is evaluated, so an error raised by the unselected branch cannot fail
// the expression. On any error the zero json.JSON is returned with isNull set
// to true.
func (b *builtinIfJSONSig) evalJSON(row types.Row) (ret json.JSON, isNull bool, err error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	cond, condIsNull, err := b.args[0].EvalInt(row, sc)
	if err != nil {
		return ret, true, errors.Trace(err)
	}
	branch := b.args[2]
	if !condIsNull && cond != 0 {
		branch = b.args[1]
	}
	val, valIsNull, err := branch.EvalJSON(row, sc)
	if err != nil {
		return ret, true, errors.Trace(err)
	}
	return val, valIsNull, nil
}
// ifNullFunctionClass is the function class for the builtin IFNULL function.
type ifNullFunctionClass struct {
	baseFunctionClass
}

// getFunction validates args and builds the IFNULL signature whose evaluation
// type is inferred from both arguments. The result carries NotNullFlag when
// either argument does, and is typed NULL when both arguments are typed NULL.
func (c *ifNullFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
	if err = errors.Trace(c.verifyArgs(args)); err != nil {
		return nil, errors.Trace(err)
	}
	firstTp, secondTp := args[0].GetType(), args[1].GetType()
	resultTp := inferType4ControlFuncs(firstTp, secondTp)
	resultTp.Flag |= (firstTp.Flag & mysql.NotNullFlag) | (secondTp.Flag & mysql.NotNullFlag)

	bothNull := firstTp.Tp == mysql.TypeNull && secondTp.Tp == mysql.TypeNull
	if bothNull {
		resultTp.Tp = mysql.TypeNull
		resultTp.Flen, resultTp.Decimal = 0, -1
		types.SetBinChsClnFlag(resultTp)
	}

	evalTp := resultTp.EvalType()
	base := newBaseBuiltinFuncWithTp(ctx, args, evalTp, evalTp, evalTp)
	base.tp = resultTp

	// Evaluation types not listed here leave sig nil.
	switch evalTp {
	case types.ETInt:
		sig = &builtinIfNullIntSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfNullInt)
	case types.ETReal:
		sig = &builtinIfNullRealSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfNullReal)
	case types.ETDecimal:
		sig = &builtinIfNullDecimalSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfNullDecimal)
	case types.ETString:
		sig = &builtinIfNullStringSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfNullString)
	case types.ETDatetime, types.ETTimestamp:
		sig = &builtinIfNullTimeSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfNullTime)
	case types.ETDuration:
		sig = &builtinIfNullDurationSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfNullDuration)
	case types.ETJson:
		sig = &builtinIfNullJSONSig{base}
		sig.setPbCode(tipb.ScalarFuncSig_IfNullJson)
	}
	return sig, nil
}
// builtinIfNullIntSig is the IFNULL(a, b) signature producing an integer.
type builtinIfNullIntSig struct {
	baseBuiltinFunc
}

// evalInt returns args[0] when it is not NULL and args[1] otherwise. args[1]
// is evaluated only when args[0] evaluated to NULL without error. Whenever an
// error is returned the NULL flag is true.
func (b *builtinIfNullIntSig) evalInt(row types.Row) (int64, bool, error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	first, firstNull, err := b.args[0].EvalInt(row, stmtCtx)
	switch {
	case err != nil:
		return first, true, errors.Trace(err)
	case !firstNull:
		return first, false, nil
	}
	second, secondNull, err := b.args[1].EvalInt(row, stmtCtx)
	if err != nil {
		return second, true, errors.Trace(err)
	}
	return second, secondNull, nil
}
// builtinIfNullRealSig is the IFNULL(a, b) signature producing a float64.
type builtinIfNullRealSig struct {
	baseBuiltinFunc
}

// evalReal returns args[0] when it is not NULL and args[1] otherwise. args[1]
// is evaluated only when args[0] evaluated to NULL without error. Whenever an
// error is returned the NULL flag is true.
func (b *builtinIfNullRealSig) evalReal(row types.Row) (float64, bool, error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	first, firstNull, err := b.args[0].EvalReal(row, stmtCtx)
	switch {
	case err != nil:
		return first, true, errors.Trace(err)
	case !firstNull:
		return first, false, nil
	}
	second, secondNull, err := b.args[1].EvalReal(row, stmtCtx)
	if err != nil {
		return second, true, errors.Trace(err)
	}
	return second, secondNull, nil
}
// builtinIfNullDecimalSig is the IFNULL(a, b) signature producing a decimal.
type builtinIfNullDecimalSig struct {
	baseBuiltinFunc
}

// evalDecimal returns args[0] when it is not NULL and args[1] otherwise.
// args[1] is evaluated only when args[0] evaluated to NULL without error.
// Whenever an error is returned the NULL flag is true.
func (b *builtinIfNullDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal, bool, error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	first, firstNull, err := b.args[0].EvalDecimal(row, stmtCtx)
	switch {
	case err != nil:
		return first, true, errors.Trace(err)
	case !firstNull:
		return first, false, nil
	}
	second, secondNull, err := b.args[1].EvalDecimal(row, stmtCtx)
	if err != nil {
		return second, true, errors.Trace(err)
	}
	return second, secondNull, nil
}
// builtinIfNullStringSig is the IFNULL(a, b) signature producing a string.
type builtinIfNullStringSig struct {
	baseBuiltinFunc
}

// evalString returns args[0] when it is not NULL and args[1] otherwise.
// args[1] is evaluated only when args[0] evaluated to NULL without error.
// Whenever an error is returned the NULL flag is true.
func (b *builtinIfNullStringSig) evalString(row types.Row) (string, bool, error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	first, firstNull, err := b.args[0].EvalString(row, stmtCtx)
	switch {
	case err != nil:
		return first, true, errors.Trace(err)
	case !firstNull:
		return first, false, nil
	}
	second, secondNull, err := b.args[1].EvalString(row, stmtCtx)
	if err != nil {
		return second, true, errors.Trace(err)
	}
	return second, secondNull, nil
}
// builtinIfNullTimeSig is the IFNULL(a, b) signature producing a types.Time.
type builtinIfNullTimeSig struct {
	baseBuiltinFunc
}

// evalTime returns args[0] when it is not NULL and args[1] otherwise. args[1]
// is evaluated only when args[0] evaluated to NULL without error. Whenever an
// error is returned the NULL flag is true.
func (b *builtinIfNullTimeSig) evalTime(row types.Row) (types.Time, bool, error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	first, firstNull, err := b.args[0].EvalTime(row, stmtCtx)
	switch {
	case err != nil:
		return first, true, errors.Trace(err)
	case !firstNull:
		return first, false, nil
	}
	second, secondNull, err := b.args[1].EvalTime(row, stmtCtx)
	if err != nil {
		return second, true, errors.Trace(err)
	}
	return second, secondNull, nil
}
// builtinIfNullDurationSig is the IFNULL(a, b) signature producing a
// types.Duration.
type builtinIfNullDurationSig struct {
	baseBuiltinFunc
}

// evalDuration returns args[0] when it is not NULL and args[1] otherwise.
// args[1] is evaluated only when args[0] evaluated to NULL without error.
// Whenever an error is returned the NULL flag is true.
func (b *builtinIfNullDurationSig) evalDuration(row types.Row) (types.Duration, bool, error) {
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	first, firstNull, err := b.args[0].EvalDuration(row, stmtCtx)
	switch {
	case err != nil:
		return first, true, errors.Trace(err)
	case !firstNull:
		return first, false, nil
	}
	second, secondNull, err := b.args[1].EvalDuration(row, stmtCtx)
	if err != nil {
		return second, true, errors.Trace(err)
	}
	return second, secondNull, nil
}
// builtinIfNullJSONSig is the IFNULL(a, b) signature producing a json.JSON.
type builtinIfNullJSONSig struct {
	baseBuiltinFunc
}

// evalJSON returns args[0] when it is not NULL and args[1] otherwise. args[1]
// is evaluated only when args[0] evaluated to NULL without error. Whenever an
// error is returned the NULL flag is true.
func (b *builtinIfNullJSONSig) evalJSON(row types.Row) (json.JSON, bool, error) {
	sc := b.ctx.GetSessionVars().StmtCtx
	arg0, isNull, err := b.args[0].EvalJSON(row, sc)
	// The error must be checked together with the NULL flag, as the other
	// IFNULL signatures do: a failed evaluation reported as NULL would
	// otherwise fall through and have its error overwritten by args[1].
	if !isNull || err != nil {
		return arg0, err != nil, errors.Trace(err)
	}
	arg1, isNull, err := b.args[1].EvalJSON(row, sc)
	return arg1, isNull || err != nil, errors.Trace(err)
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhoujin826/tidb.git
git@gitee.com:zhoujin826/tidb.git
zhoujin826
tidb
tidb
v1.1.0-alpha

搜索帮助