91 Star 502 Fork 154

平凯星辰(北京)科技有限公司/tidb

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
util.go 11.28 KB
一键复制 编辑 原始数据 按行查看 历史
// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"strconv"
"strings"
"time"
"unicode"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/hack"
)
// Filter appends every expression of input for which filter returns true to
// result and returns the extended slice. The relative order of input is kept.
func Filter(result []Expression, input []Expression, filter func(Expression) bool) []Expression {
	for i := range input {
		if !filter(input[i]) {
			continue
		}
		result = append(result, input[i])
	}
	return result
}
// ExtractColumns returns every column referenced by expr, collected by a
// recursive walk over scalar function arguments. Duplicates are kept.
func ExtractColumns(expr Expression) []*Column {
	// The initial capacity is only a guess to avoid early reallocations;
	// the value has no special meaning.
	const initialCap = 8
	return extractColumns(make([]*Column, 0, initialCap), expr, nil)
}
// ExtractColumnsFromExpressions is a batch version of ExtractColumns: it appends
// the columns referenced by each expression of exprs to result and returns the
// extended slice.
//
// filter may be nil; otherwise only columns for which filter returns true are
// appended. Callers often write
//
//	cols := ExtractColumns(...)
//	for _, col := range cols {
//		if xxx(col) {...}
//	}
//
// Passing the predicate as filter does the same in one step and avoids
// allocating space for columns that would be thrown away.
func ExtractColumnsFromExpressions(result []*Column, exprs []Expression, filter func(*Column) bool) []*Column {
	for i := 0; i < len(exprs); i++ {
		result = extractColumns(result, exprs[i], filter)
	}
	return result
}
// extractColumns appends to result every *Column found in expr that passes
// filter (every column when filter is nil). Arguments of a *ScalarFunction are
// visited recursively in order; any other expression kind contributes nothing.
func extractColumns(result []*Column, expr Expression, filter func(*Column) bool) []*Column {
	if col, isCol := expr.(*Column); isCol {
		if filter != nil && !filter(col) {
			return result
		}
		return append(result, col)
	}
	sf, isFunc := expr.(*ScalarFunction)
	if !isFunc {
		return result
	}
	for _, child := range sf.GetArgs() {
		result = extractColumns(result, child, filter)
	}
	return result
}
// ColumnSubstitute replaces each column of expr that belongs to schema with a
// clone of the expression at the same index in newExprs, e.g.
// select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k.
//
// Columns not found in schema are returned unchanged. For a CAST, only the first
// argument is substituted, on a clone of the function. Any other scalar function
// is rebuilt with NewFunctionInternal from its substituted arguments. All other
// expressions are returned as they are.
func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression {
	if col, ok := expr.(*Column); ok {
		if idx := schema.ColumnIndex(col); idx != -1 {
			return newExprs[idx].Clone()
		}
		return col
	}
	sf, ok := expr.(*ScalarFunction)
	if !ok {
		return expr
	}
	if sf.FuncName.L == ast.Cast {
		castFunc := sf.Clone().(*ScalarFunction)
		castArgs := castFunc.GetArgs()
		castArgs[0] = ColumnSubstitute(castArgs[0], schema, newExprs)
		return castFunc
	}
	args := sf.GetArgs()
	substituted := make([]Expression, len(args))
	for i, arg := range args {
		substituted[i] = ColumnSubstitute(arg, schema, newExprs)
	}
	return NewFunctionInternal(sf.GetCtx(), sf.FuncName.L, sf.RetType, substituted...)
}
// getValidPrefix returns the longest prefix of s that can be parsed as a number
// in the given base. Valid bases are 2 to 36 inclusive; for any other base
// (including 0, 1 and negative values) it returns "".
//
// A single leading '+' or '-' is accepted. A leading '+' is stripped from the
// result, while a leading '-' is kept. Letters are matched case-insensitively.
// If no valid digit follows the optional sign, "" is returned.
func getValidPrefix(s string, base int64) string {
	// Bases below 2 used to fall into the "base <= 36" letter branch and produce
	// a meaningless upper bound (e.g. '7' for base 0), so reject them explicitly.
	if base < 2 || base > 36 {
		return ""
	}
	// upper is the first (upper-cased) character that is NOT a valid digit.
	var upper rune
	if base <= 9 {
		upper = rune('0' + base)
	} else {
		upper = rune('A' + base - 10)
	}

	validLen := 0
Loop:
	for i := 0; i < len(s); i++ {
		c := rune(s[i])
		switch {
		case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c):
			if unicode.ToUpper(c) >= upper {
				break Loop
			}
			validLen = i + 1
		case c == '+' || c == '-':
			// A sign is only allowed as the very first character.
			if i != 0 {
				break Loop
			}
		default:
			break Loop
		}
	}
	if validLen > 1 && s[0] == '+' {
		return s[1:validLen]
	}
	return s[:validLen]
}
// SubstituteCorCol2Constant returns a copy of expr in which every correlated
// column is replaced by a Constant holding the value it currently points to
// (*Data).
//
// A scalar function whose substituted arguments are all constants is evaluated
// and replaced by a Constant. Otherwise it is rebuilt from the substituted
// arguments: a CAST through BuildCastFunction, anything else through
// NewFunctionInternal. A Constant carrying a DeferredExpr is folded through
// FoldConstant. Every other expression is cloned.
func SubstituteCorCol2Constant(expr Expression) (Expression, error) {
	if corCol, ok := expr.(*CorrelatedColumn); ok {
		return &Constant{Value: *corCol.Data, RetType: corCol.GetType()}, nil
	}
	if con, ok := expr.(*Constant); ok && con.DeferredExpr != nil {
		folded := FoldConstant(con).(*Constant)
		return &Constant{Value: folded.Value, RetType: con.GetType()}, nil
	}
	sf, ok := expr.(*ScalarFunction)
	if !ok {
		return expr.Clone(), nil
	}

	args := sf.GetArgs()
	substituted := make([]Expression, 0, len(args))
	everyArgConst := true
	for _, arg := range args {
		newArg, err := SubstituteCorCol2Constant(arg)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if _, isConst := newArg.(*Constant); !isConst {
			everyArgConst = false
		}
		substituted = append(substituted, newArg)
	}

	if everyArgConst {
		val, err := sf.Eval(nil)
		if err != nil {
			return nil, errors.Trace(err)
		}
		return &Constant{Value: val, RetType: sf.GetType()}, nil
	}
	if sf.FuncName.L == ast.Cast {
		return BuildCastFunction(sf.GetCtx(), substituted[0], sf.RetType), nil
	}
	return NewFunctionInternal(sf.GetCtx(), sf.FuncName.L, sf.GetType(), substituted...), nil
}
// timeZone2Duration converts a time zone offset such as "+08:00" or "-5:30" to a
// time.Duration. tz is expected to match
// `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)`; the sign applies to both the
// hour and the minute part.
//
// Callers must validate the format first. Errors from strconv.Atoi are only
// logged via terror.Log, not returned. If tz contains no ':', the slice
// expression below panics.
func timeZone2Duration(tz string) time.Duration {
	sign := time.Duration(1)
	if strings.HasPrefix(tz, "-") {
		sign = -1
	}
	colon := strings.IndexByte(tz, ':')
	hours, err := strconv.Atoi(tz[1:colon])
	terror.Log(errors.Trace(err))
	minutes, err := strconv.Atoi(tz[colon+1:])
	terror.Log(errors.Trace(err))
	offset := time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute
	return sign * offset
}
var (
	// oppositeOp maps a comparison function name to its negated counterpart.
	// PushDownNot uses it to rewrite NOT (a < b) into a >= b.
	oppositeOp = map[string]string{
		ast.LT: ast.GE,
		ast.LE: ast.GT,
		ast.GT: ast.LE,
		ast.GE: ast.LT,
		ast.EQ: ast.NE,
		ast.NE: ast.EQ,
	}

	// symmetricOp maps a comparison operator to the operator obtained by swapping
	// its operands: `a op b` is equivalent to `b symmetricOp[op] a`.
	symmetricOp = map[opcode.Op]opcode.Op{
		opcode.LT:     opcode.GT,
		opcode.LE:     opcode.GE,
		opcode.GT:     opcode.LT,
		opcode.GE:     opcode.LE,
		opcode.EQ:     opcode.EQ,
		opcode.NE:     opcode.NE,
		opcode.NullEQ: opcode.NullEQ,
	}
)
// PushDownNot pushes the `not` function down to the expression's arguments.
//
// not tells whether expr sits under a pending NOT. With a pending NOT:
//   - a comparison (<, >=, >, <=, =, !=) is replaced by its opposite (oppositeOp);
//   - AND / OR is distributed by De Morgan's laws (AND becomes OR and vice versa);
//   - a nested NOT cancels the pending one;
//   - any other expression is wrapped in a new NOT function.
//
// Without a pending NOT, a NOT function is pushed into its argument, and the
// arguments of comparison, AND and OR functions are processed recursively.
//
// The input expression is never modified. Rewritten arguments go into fresh
// slices and new functions are built from them, so expression trees shared with
// other callers stay intact. When nothing needs rewriting, expr itself is
// returned.
func PushDownNot(ctx sessionctx.Context, expr Expression, not bool) Expression {
	if f, ok := expr.(*ScalarFunction); ok {
		switch f.FuncName.L {
		case ast.UnaryNot:
			return PushDownNot(f.GetCtx(), f.GetArgs()[0], !not)
		case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE:
			if not {
				return NewFunctionInternal(f.GetCtx(), oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...)
			}
			return pushDownNotIntoArgs(f)
		case ast.LogicAnd, ast.LogicOr:
			if not {
				// De Morgan: NOT (a AND b) => (NOT a) OR (NOT b), and vice versa.
				newName := ast.LogicOr
				if f.FuncName.L == ast.LogicOr {
					newName = ast.LogicAnd
				}
				// Use a fresh slice: the old code overwrote f's own argument slice
				// and then shared it with the newly built function.
				args := f.GetArgs()
				newArgs := make([]Expression, len(args))
				for i, a := range args {
					newArgs[i] = PushDownNot(f.GetCtx(), a, true)
				}
				return NewFunctionInternal(f.GetCtx(), newName, f.GetType(), newArgs...)
			}
			return pushDownNotIntoArgs(f)
		}
	}
	if not {
		expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr)
	}
	return expr
}

// pushDownNotIntoArgs applies PushDownNot(arg, false) to every argument of f. It
// returns f itself when no argument changes; otherwise it builds a new function
// with the rewritten arguments instead of overwriting f's argument slice in place.
func pushDownNotIntoArgs(f *ScalarFunction) Expression {
	args := f.GetArgs()
	var newArgs []Expression
	for i, arg := range args {
		newArg := PushDownNot(f.GetCtx(), arg, false)
		if newArgs == nil {
			if newArg == arg {
				continue
			}
			// First changed argument: copy the unchanged prefix once.
			newArgs = make([]Expression, len(args))
			copy(newArgs, args[:i])
		}
		newArgs[i] = newArg
	}
	if newArgs == nil {
		return f
	}
	return NewFunctionInternal(f.GetCtx(), f.FuncName.L, f.GetType(), newArgs...)
}
// Contains reports whether e is one of exprs. Elements are compared with ==,
// i.e. by identity of the interface values, not by structural equality.
func Contains(exprs []Expression, e Expression) bool {
	for i := range exprs {
		if exprs[i] == e {
			return true
		}
	}
	return false
}
// ExtractFiltersFromDNFs looks for DNF conditions (top-level OR functions) in
// conditions and pulls out the sub-conditions that occur in every branch of each
// DNF.
//
// Each DNF is replaced in place by its remaining part, or deleted when nothing
// remains. All extracted sub-conditions are appended to the end of the returned
// slice. The backing array of conditions is reused, so the caller's slice may be
// modified.
func ExtractFiltersFromDNFs(ctx sessionctx.Context, conditions []Expression) []Expression {
	var extractedAll []Expression
	for idx := len(conditions) - 1; idx >= 0; idx-- {
		dnf, isFunc := conditions[idx].(*ScalarFunction)
		if !isFunc || dnf.FuncName.L != ast.LogicOr {
			continue
		}
		extracted, remained := extractFiltersFromDNF(ctx, dnf)
		extractedAll = append(extractedAll, extracted...)
		if remained != nil {
			conditions[idx] = remained
			continue
		}
		conditions = append(conditions[:idx], conditions[idx+1:]...)
	}
	return append(conditions, extractedAll...)
}
// extractFiltersFromDNF extracts the CNF items that occur in every DNF item of
// dnfFunc and removes them from the DNF items. For example,
// (a AND b) OR (a AND c) yields the extracted part [a] and the remaining part
// (b OR c).
//
// It returns the extracted expressions and the remaining DNF expression:
//   - when no CNF item is common to all DNF items, it returns (nil, dnfFunc);
//   - when some DNF item consists only of extracted items, the whole DNF reduces
//     to the extracted part, and the remaining part is nil.
//
// CNF items are compared by their hash codes. The extracted expressions are
// returned in the order in which they first appear in the first DNF item. The
// result is therefore deterministic, which iterating a Go map would not give.
func extractFiltersFromDNF(ctx sessionctx.Context, dnfFunc *ScalarFunction) ([]Expression, Expression) {
	dnfItems := FlattenDNFConditions(dnfFunc)
	sc := ctx.GetSessionVars().StmtCtx
	// codeMap counts, for each CNF item of the first DNF item, how many DNF items contain it.
	codeMap := make(map[string]int)
	hashcode2Expr := make(map[string]Expression)
	// firstOrder keeps the distinct hash codes of the first DNF item in their original order.
	var firstOrder []string
	for i, dnfItem := range dnfItems {
		innerMap := make(map[string]struct{})
		cnfItems := SplitCNFItems(dnfItem)
		for _, cnfItem := range cnfItems {
			code := hack.String(cnfItem.HashCode(sc))
			if i == 0 {
				if _, seen := codeMap[code]; !seen {
					firstOrder = append(firstOrder, code)
				}
				codeMap[code] = 1
				hashcode2Expr[code] = cnfItem
			} else if _, ok := codeMap[code]; ok {
				// We need this check because there may be the case like `select * from t, t1 where (t.a=t1.a and t.a=t1.a) or (something).
				// We should make sure that the two `t.a=t1.a` contributes only once.
				// TODO: do this out of this function.
				if _, ok = innerMap[code]; !ok {
					codeMap[code]++
					innerMap[code] = struct{}{}
				}
			}
		}
	}
	// We should make sure that this item occurs in every DNF item.
	for code, cnt := range codeMap {
		if cnt < len(dnfItems) {
			delete(hashcode2Expr, code)
		}
	}
	if len(hashcode2Expr) == 0 {
		return nil, dnfFunc
	}

	newDNFItems := make([]Expression, 0, len(dnfItems))
	onlyNeedExtracted := false
	for _, dnfItem := range dnfItems {
		cnfItems := SplitCNFItems(dnfItem)
		newCNFItems := make([]Expression, 0, len(cnfItems))
		for _, cnfItem := range cnfItems {
			if _, ok := hashcode2Expr[hack.String(cnfItem.HashCode(sc))]; !ok {
				newCNFItems = append(newCNFItems, cnfItem)
			}
		}
		// If the extracted part is just one leaf of the DNF expression, then the
		// value of the whole DNF expression always equals the value of the
		// extracted part.
		if len(newCNFItems) == 0 {
			onlyNeedExtracted = true
			break
		}
		newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...))
	}

	// Collect in first-DNF-item order rather than ranging over hashcode2Expr,
	// whose iteration order is randomized by the Go runtime.
	extractedExpr := make([]Expression, 0, len(hashcode2Expr))
	for _, code := range firstOrder {
		if expr, ok := hashcode2Expr[code]; ok {
			extractedExpr = append(extractedExpr, expr)
		}
	}
	if onlyNeedExtracted {
		return extractedExpr, nil
	}
	return extractedExpr, ComposeDNFCondition(ctx, newDNFItems...)
}
// DisableParseJSONFlag4Expr clears ParseToJSONFlag in the return type of expr,
// unless expr is a *Column.
//
// Under some scenarios a string must not be *PARSED* as JSON. Columns are
// skipped for two reasons. First, their ParseToJSONFlag is already 0. Second,
// Column.RetType refers to the infoschema, so modifying it could cause a data
// race with another goroutine reading the infoschema at the same time.
func DisableParseJSONFlag4Expr(expr Expression) {
	if _, isColumn := expr.(*Column); !isColumn {
		tp := expr.GetType()
		tp.Flag &^= mysql.ParseToJSONFlag
	}
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/pingcap/tidb.git
git@gitee.com:pingcap/tidb.git
pingcap
tidb
tidb
v2.0.11

搜索帮助