90 Star 491 Fork 151

平凯星辰(北京)科技有限公司/tidb

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
scalar_function.go 7.48 KB
一键复制 编辑 原始数据 按行查看 历史
// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"bytes"
"fmt"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/hack"
"github.com/pkg/errors"
)
// ScalarFunction is the function that returns a value.
//
// It pairs a function name and a return type with the builtin implementation
// (Function) that supplies the arguments and performs the actual evaluation.
type ScalarFunction struct {
	// FuncName is the function's name; its lower-case form (FuncName.L) is what
	// String, Equal and HashCode use.
	FuncName model.CIStr
	// RetType is the type that ScalarFunction returns.
	// TODO: Implement type inference here, now we use ast's return type temporarily.
	RetType  *types.FieldType
	// Function is the builtin implementation; GetArgs, GetCtx and the typed
	// Eval* methods all delegate to it.
	Function builtinFunc
	// hashcode caches the result of HashCode; an empty slice means it has not
	// been computed yet. Clone copies this slice header, so clones share the
	// cached bytes.
	hashcode []byte
}
// GetArgs returns the argument expressions of the function as reported by the
// underlying builtin implementation.
func (sf *ScalarFunction) GetArgs() []Expression {
	builtin := sf.Function
	return builtin.getArgs()
}
// GetCtx returns the session context held by the underlying builtin
// implementation of the function.
func (sf *ScalarFunction) GetCtx() sessionctx.Context {
	builtin := sf.Function
	return builtin.getCtx()
}
// String implements fmt.Stringer interface.
//
// The result has the form "name(arg1, arg2, ...)": the lower-case function
// name followed by each argument's own String form, separated by ", ".
func (sf *ScalarFunction) String() string {
	var out bytes.Buffer
	fmt.Fprintf(&out, "%s(", sf.FuncName.L)
	for idx, arg := range sf.GetArgs() {
		if idx > 0 {
			out.WriteString(", ")
		}
		out.WriteString(arg.String())
	}
	out.WriteByte(')')
	return out.String()
}
// MarshalJSON implements json.Marshaler interface.
//
// The function is encoded as a JSON string holding its String() form
// ("name(arg1, ...)"). The text is escaped per RFC 8259, so argument text that
// contains quotes, backslashes or control characters still yields valid JSON.
// encoding/json rejects invalid MarshalJSON output, so raw interpolation would
// make marshaling of any enclosing value fail.
func (sf *ScalarFunction) MarshalJSON() ([]byte, error) {
	return quoteScalarFunctionJSON(sf.String()), nil
}

// quoteScalarFunctionJSON returns s encoded as a JSON string literal. Invalid
// UTF-8 bytes become U+FFFD (ranging over a string yields utf8.RuneError for
// them), which keeps the output valid UTF-8 JSON. encoding/json is not used
// here because this file already binds the name json to the TiDB types/json
// package.
func quoteScalarFunctionJSON(s string) []byte {
	var buf bytes.Buffer
	buf.Grow(len(s) + 2)
	buf.WriteByte('"')
	for _, r := range s {
		switch {
		case r == '"' || r == '\\':
			buf.WriteByte('\\')
			buf.WriteRune(r)
		case r == '\n':
			buf.WriteString(`\n`)
		case r == '\r':
			buf.WriteString(`\r`)
		case r == '\t':
			buf.WriteString(`\t`)
		case r < 0x20:
			// JSON forbids raw control characters inside strings.
			fmt.Fprintf(&buf, `\u%04x`, r)
		default:
			buf.WriteRune(r)
		}
	}
	buf.WriteByte('"')
	return buf.Bytes()
}
// NewFunction creates a new scalar function or constant.
//
// funcName selects the builtin from the package-level funcs table; ast.Cast is
// handled separately through BuildCastFunction on the first argument. retType
// must be non-nil. It is replaced by the builtin's own return type when the
// builtin reports a type other than TypeUnspecified, or when retType itself is
// TypeUnspecified. The built function is passed through FoldConstant before it
// is returned.
//
// An error is returned when retType is nil, when a cast is requested with no
// argument, when funcName is not a known function, or when the builtin cannot
// be constructed from args.
func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
	if retType == nil {
		return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.")
	}
	if funcName == ast.Cast {
		// Guard the args[0] access below; an empty args would otherwise panic.
		if len(args) == 0 {
			return nil, errors.Errorf("missing argument for function %s", funcName)
		}
		return BuildCastFunction(ctx, args[0], retType), nil
	}
	fc, ok := funcs[funcName]
	if !ok {
		return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", funcName)
	}
	// Variadic args may share the caller's backing array; give the builtin its
	// own copy.
	funcArgs := make([]Expression, len(args))
	copy(funcArgs, args)
	f, err := fc.getFunction(ctx, funcArgs)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if builtinRetTp := f.getRetTp(); builtinRetTp.Tp != mysql.TypeUnspecified || retType.Tp == mysql.TypeUnspecified {
		retType = builtinRetTp
	}
	sf := &ScalarFunction{
		FuncName: model.NewCIStr(funcName),
		RetType:  retType,
		Function: f,
	}
	return FoldConstant(sf), nil
}
// NewFunctionInternal is like NewFunction but does not return an error, so it
// should only be used internally. Any error from NewFunction is passed to
// terror.Log, and the expression NewFunction produced (nil on error) is
// returned unchanged.
func NewFunctionInternal(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) Expression {
	built, buildErr := NewFunction(ctx, funcName, retType, args...)
	terror.Log(errors.Trace(buildErr))
	return built
}
// ScalarFuncs2Exprs converts []*ScalarFunction to []Expression.
//
// The result is always a non-nil slice with the same length and order as
// funcs; each element is the same *ScalarFunction pointer viewed as an
// Expression.
func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression {
	exprs := make([]Expression, len(funcs))
	for i := range funcs {
		exprs[i] = funcs[i]
	}
	return exprs
}
// Clone implements Expression interface.
//
// The builtin Function is duplicated through its own Clone method. FuncName is
// copied by value, while the RetType pointer and the cached hashcode slice are
// shared with sf rather than deep-copied.
func (sf *ScalarFunction) Clone() Expression {
	dup := new(ScalarFunction)
	dup.FuncName = sf.FuncName
	dup.RetType = sf.RetType
	dup.Function = sf.Function.Clone()
	dup.hashcode = sf.hashcode
	return dup
}
// GetType implements Expression interface.
// It returns the stored RetType pointer itself, not a copy.
func (sf *ScalarFunction) GetType() *types.FieldType {
	retTp := sf.RetType
	return retTp
}
// Equal implements Expression interface.
//
// e is equal to sf only if it is also a *ScalarFunction with the same
// lower-case function name and its builtin compares equal to sf's builtin.
// ctx is not used.
func (sf *ScalarFunction) Equal(ctx sessionctx.Context, e Expression) bool {
	other, isFunc := e.(*ScalarFunction)
	if !isFunc || other.FuncName.L != sf.FuncName.L {
		return false
	}
	return sf.Function.equal(other.Function)
}
// IsCorrelated implements Expression interface.
// It reports whether at least one argument is correlated; the check stops at
// the first such argument.
func (sf *ScalarFunction) IsCorrelated() bool {
	args := sf.GetArgs()
	for i := 0; i < len(args); i++ {
		if args[i].IsCorrelated() {
			return true
		}
	}
	return false
}
// Decorrelate implements Expression interface.
//
// Every argument is replaced in place by arg.Decorrelate(schema), and sf itself
// is returned. The cached hash code was derived from the previous arguments, so
// it is dropped here and HashCode will rebuild it from the new ones.
func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression {
	args := sf.GetArgs()
	for i, arg := range args {
		args[i] = arg.Decorrelate(schema)
	}
	// Reset to nil rather than hashcode[:0]: Clone shares the slice header, and
	// appending into a truncated shared slice would overwrite a clone's cache.
	sf.hashcode = nil
	return sf
}
// Eval implements Expression interface.
//
// It dispatches on the evaluation type of sf's return type, calls the matching
// typed evaluator with the function's own context, and wraps the result in a
// Datum. Integer results become uint64 when the return type carries the
// unsigned flag. ETDatetime and ETTimestamp are both evaluated through
// EvalTime. A NULL result or an evaluation error yields a Datum holding nil,
// together with the traced error (if any). An evaluation type not listed below
// also yields a nil Datum with no error.
func (sf *ScalarFunction) Eval(row chunk.Row) (d types.Datum, err error) {
	var (
		val  interface{}
		null bool
	)
	retTp := sf.GetType()
	switch retTp.EvalType() {
	case types.ETInt:
		var i64 int64
		i64, null, err = sf.EvalInt(sf.GetCtx(), row)
		val = i64
		if mysql.HasUnsignedFlag(retTp.Flag) {
			val = uint64(i64)
		}
	case types.ETReal:
		val, null, err = sf.EvalReal(sf.GetCtx(), row)
	case types.ETDecimal:
		val, null, err = sf.EvalDecimal(sf.GetCtx(), row)
	case types.ETDatetime, types.ETTimestamp:
		val, null, err = sf.EvalTime(sf.GetCtx(), row)
	case types.ETDuration:
		val, null, err = sf.EvalDuration(sf.GetCtx(), row)
	case types.ETJson:
		val, null, err = sf.EvalJSON(sf.GetCtx(), row)
	case types.ETString:
		val, null, err = sf.EvalString(sf.GetCtx(), row)
	}
	if null || err != nil {
		// NULL results and failures both surface as a nil Datum.
		d.SetValue(nil)
		return d, errors.Trace(err)
	}
	d.SetValue(val)
	return d, nil
}
// EvalInt implements Expression interface.
// It evaluates the function on row as an int64 through the underlying builtin;
// the bool result reports a NULL value. ctx is not used.
func (sf *ScalarFunction) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, error) {
	v, null, err := sf.Function.evalInt(row)
	return v, null, err
}
// EvalReal implements Expression interface.
// It evaluates the function on row as a float64 through the underlying
// builtin; the bool result reports a NULL value. ctx is not used.
func (sf *ScalarFunction) EvalReal(ctx sessionctx.Context, row chunk.Row) (float64, bool, error) {
	v, null, err := sf.Function.evalReal(row)
	return v, null, err
}
// EvalDecimal implements Expression interface.
// It evaluates the function on row as a *types.MyDecimal through the
// underlying builtin; the bool result reports a NULL value. ctx is not used.
func (sf *ScalarFunction) EvalDecimal(ctx sessionctx.Context, row chunk.Row) (*types.MyDecimal, bool, error) {
	v, null, err := sf.Function.evalDecimal(row)
	return v, null, err
}
// EvalString implements Expression interface.
// It evaluates the function on row as a string through the underlying builtin;
// the bool result reports a NULL value. ctx is not used.
func (sf *ScalarFunction) EvalString(ctx sessionctx.Context, row chunk.Row) (string, bool, error) {
	v, null, err := sf.Function.evalString(row)
	return v, null, err
}
// EvalTime implements Expression interface.
// It evaluates the function on row as a types.Time through the underlying
// builtin; the bool result reports a NULL value. ctx is not used.
func (sf *ScalarFunction) EvalTime(ctx sessionctx.Context, row chunk.Row) (types.Time, bool, error) {
	v, null, err := sf.Function.evalTime(row)
	return v, null, err
}
// EvalDuration implements Expression interface.
// It evaluates the function on row as a types.Duration through the underlying
// builtin; the bool result reports a NULL value. ctx is not used.
func (sf *ScalarFunction) EvalDuration(ctx sessionctx.Context, row chunk.Row) (types.Duration, bool, error) {
	v, null, err := sf.Function.evalDuration(row)
	return v, null, err
}
// EvalJSON implements Expression interface.
// It evaluates the function on row as a json.BinaryJSON through the underlying
// builtin; the bool result reports a NULL value. ctx is not used.
func (sf *ScalarFunction) EvalJSON(ctx sessionctx.Context, row chunk.Row) (json.BinaryJSON, bool, error) {
	v, null, err := sf.Function.evalJSON(row)
	return v, null, err
}
// HashCode implements Expression interface.
//
// On the first call the hash code is built and cached in sf.hashcode. It
// consists of the scalarFunctionFlag marker, then the compact-encoded
// lower-case function name, then the hash codes of the arguments concatenated
// in order. Later calls return the cached slice itself.
func (sf *ScalarFunction) HashCode(sc *stmtctx.StatementContext) []byte {
	if len(sf.hashcode) == 0 {
		sf.hashcode = append(sf.hashcode, scalarFunctionFlag)
		sf.hashcode = codec.EncodeCompactBytes(sf.hashcode, hack.Slice(sf.FuncName.L))
		args := sf.GetArgs()
		for i := range args {
			sf.hashcode = append(sf.hashcode, args[i].HashCode(sc)...)
		}
	}
	return sf.hashcode
}
// ResolveIndices implements Expression interface.
// It leaves sf untouched: a clone is made, its argument indices are resolved
// against schema, and the clone is returned.
func (sf *ScalarFunction) ResolveIndices(schema *Schema) Expression {
	resolved := sf.Clone()
	resolved.resolveIndices(schema)
	return resolved
}
// resolveIndices resolves, in place, the indices of every argument of sf
// against schema.
func (sf *ScalarFunction) resolveIndices(schema *Schema) {
	args := sf.GetArgs()
	for i := range args {
		args[i].resolveIndices(schema)
	}
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/pingcap/tidb.git
git@gitee.com:pingcap/tidb.git
pingcap
tidb
tidb
v2.1.0-rc.4

搜索帮助

0d507c66 1850385 C8b1a773 1850385