15 Star 69 Fork 15

He3DB / He3Proxy

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
conn_select.go 17.07 KB
一键复制 编辑 原始数据 按行查看 历史
wangyao 提交于 2022-08-30 16:43 . improve: change repo addr
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780
// Copyright 2016 The kingshard Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package server
import (
"bytes"
"fmt"
"strconv"
"strings"
"gitee.com/he3db/he3proxy/core/errors"
"gitee.com/he3db/he3proxy/core/golog"
"gitee.com/he3db/he3proxy/core/hack"
"gitee.com/he3db/he3proxy/mysql"
"gitee.com/he3db/he3proxy/sqlparser"
)
const (
MasterComment = "/*master*/"
SumFunc = "sum"
CountFunc = "count"
MaxFunc = "max"
MinFunc = "min"
LastInsertIdFunc = "last_insert_id"
FUNC_EXIST = 1
)
var funcNameMap = map[string]int{
"sum": FUNC_EXIST,
"count": FUNC_EXIST,
"max": FUNC_EXIST,
"min": FUNC_EXIST,
"last_insert_id": FUNC_EXIST,
}
func (c *ClientConn) handleFieldList(data []byte) error {
index := bytes.IndexByte(data, 0x00)
table := string(data[0:index])
wildcard := string(data[index+1:])
if c.schema == nil {
return mysql.NewDefaultError(mysql.ER_NO_DB_ERROR)
}
nodeName := c.schema.rule.GetRule(c.db, table).Nodes[0]
n := c.proxy.GetNode(nodeName)
co, err := c.getBackendConn(n, false)
defer c.closeConn(co, false)
if err != nil {
return err
}
if err = co.UseDB(c.db); err != nil {
//reset the database to null
c.db = ""
return err
}
if fs, err := co.FieldList(table, wildcard); err != nil {
return err
} else {
return c.writeFieldList(c.status, fs)
}
}
func (c *ClientConn) writeFieldList(status uint16, fs []*mysql.Field) error {
c.affectedRows = int64(-1)
var err error
total := make([]byte, 0, 1024)
data := make([]byte, 4, 512)
for _, v := range fs {
data = data[0:4]
data = append(data, v.Dump()...)
total, err = c.writePacketBatch(total, data, false)
if err != nil {
return err
}
}
_, err = c.writeEOFBatch(total, status, true)
return err
}
//处理select语句
func (c *ClientConn) handleSelect(stmt *sqlparser.Select, args []interface{}) error {
var fromSlave = true
plan, err := c.schema.rule.BuildPlan(c.db, stmt)
if err != nil {
return err
}
if 0 < len(stmt.Comments) {
comment := string(stmt.Comments[0])
if 0 < len(comment) && strings.ToLower(comment) == MasterComment {
fromSlave = false
}
}
conns, err := c.getShardConns(fromSlave, plan)
if err != nil {
golog.Error("ClientConn", "handleSelect", err.Error(), c.connectionId)
return err
}
if conns == nil {
r := c.newEmptyResultset(stmt)
return c.writeResultset(c.status, r)
}
var rs []*mysql.Result
rs, err = c.executeInMultiNodes(conns, plan.RewrittenSqls, args)
c.closeShardConns(conns, false)
if err != nil {
golog.Error("ClientConn", "handleSelect", err.Error(), c.connectionId)
return err
}
err = c.mergeSelectResult(rs, stmt)
if err != nil {
golog.Error("ClientConn", "handleSelect", err.Error(), c.connectionId)
}
return err
}
func (c *ClientConn) mergeSelectResult(rs []*mysql.Result, stmt *sqlparser.Select) error {
var r *mysql.Result
var err error
if len(stmt.GroupBy) == 0 {
r, err = c.buildSelectOnlyResult(rs, stmt)
} else {
//group by
r, err = c.buildSelectGroupByResult(rs, stmt)
}
if err != nil {
return err
}
c.sortSelectResult(r.Resultset, stmt)
//to do, add log here, sort may error because order by key not exist in resultset fields
if err := c.limitSelectResult(r.Resultset, stmt); err != nil {
return err
}
return c.writeResultset(r.Status, r.Resultset)
}
//only process last_inser_id
func (c *ClientConn) handleSimpleSelect(stmt *sqlparser.SimpleSelect) error {
nonStarExpr, _ := stmt.SelectExprs[0].(*sqlparser.NonStarExpr)
var name = hack.String(nonStarExpr.As)
if name == "" {
name = "last_insert_id()"
}
var column = 1
var rows [][]string
var names = []string{
name,
}
var t = fmt.Sprintf("%d", c.lastInsertId)
rows = append(rows, []string{t})
r := new(mysql.Resultset)
var values = make([][]interface{}, len(rows))
for i := range rows {
values[i] = make([]interface{}, column)
for j := range rows[i] {
values[i][j] = rows[i][j]
}
}
r, _ = c.buildResultset(nil, names, values)
return c.writeResultset(c.status, r)
}
//build select result with group by opt
func (c *ClientConn) buildSelectGroupByResult(rs []*mysql.Result,
stmt *sqlparser.Select) (*mysql.Result, error) {
var err error
var r *mysql.Result
var groupByIndexs []int
fieldLen := len(rs[0].Fields)
startIndex := fieldLen - len(stmt.GroupBy)
for startIndex < fieldLen {
groupByIndexs = append(groupByIndexs, startIndex)
startIndex++
}
funcExprs := c.getFuncExprs(stmt)
if len(funcExprs) == 0 {
r, err = c.mergeGroupByWithoutFunc(rs, groupByIndexs)
} else {
r, err = c.mergeGroupByWithFunc(rs, groupByIndexs, funcExprs)
}
if err != nil {
return nil, err
}
//build result
names := make([]string, 0, 2)
if 0 < len(r.Values) {
r.Fields = r.Fields[:groupByIndexs[0]]
for i := 0; i < len(r.Fields) && i < groupByIndexs[0]; i++ {
names = append(names, string(r.Fields[i].Name))
}
//delete group by columns in Values
for i := 0; i < len(r.Values); i++ {
r.Values[i] = r.Values[i][:groupByIndexs[0]]
}
r.Resultset, err = c.buildResultset(r.Fields, names, r.Values)
if err != nil {
return nil, err
}
} else {
r.Resultset = c.newEmptyResultset(stmt)
}
return r, nil
}
//only merge result with aggregate function in group by opt
func (c *ClientConn) mergeGroupByWithFunc(rs []*mysql.Result, groupByIndexs []int,
funcExprs map[int]string) (*mysql.Result, error) {
r := rs[0]
//load rs into a map, in order to make group
resultMap, err := c.loadResultWithFuncIntoMap(rs, groupByIndexs, funcExprs)
if err != nil {
return nil, err
}
//set status
status := c.status
for i := 0; i < len(rs); i++ {
status = status | rs[i].Status
}
//change map into Resultset
r.Values = nil
r.RowDatas = nil
for _, v := range resultMap {
r.Values = append(r.Values, v.Value)
r.RowDatas = append(r.RowDatas, v.RowData)
}
r.Status = status
return r, nil
}
//only merge result without aggregate function in group by opt
func (c *ClientConn) mergeGroupByWithoutFunc(rs []*mysql.Result,
groupByIndexs []int) (*mysql.Result, error) {
r := rs[0]
//load rs into a map
resultMap, err := c.loadResultIntoMap(rs, groupByIndexs)
if err != nil {
return nil, err
}
//set status
status := c.status
for i := 0; i < len(rs); i++ {
status = status | rs[i].Status
}
//load map into Resultset
r.Values = nil
r.RowDatas = nil
for _, v := range resultMap {
r.Values = append(r.Values, v.Value)
r.RowDatas = append(r.RowDatas, v.RowData)
}
r.Status = status
return r, nil
}
type ResultRow struct {
Value []interface{}
RowData mysql.RowData
}
func (c *ClientConn) generateMapKey(groupColumns []interface{}) (string, error) {
bk := make([]byte, 0, 8)
separatorBuf, err := formatValue("+")
if err != nil {
return "", err
}
for _, v := range groupColumns {
b, err := formatValue(v)
if err != nil {
return "", err
}
bk = append(bk, b...)
bk = append(bk, separatorBuf...)
}
return string(bk), nil
}
func (c *ClientConn) loadResultIntoMap(rs []*mysql.Result,
groupByIndexs []int) (map[string]*ResultRow, error) {
//load Result into map
resultMap := make(map[string]*ResultRow)
for _, r := range rs {
for i := 0; i < len(r.Values); i++ {
keySlice := r.Values[i][groupByIndexs[0]:]
mk, err := c.generateMapKey(keySlice)
if err != nil {
return nil, err
}
resultMap[mk] = &ResultRow{
Value: r.Values[i],
RowData: r.RowDatas[i],
}
}
}
return resultMap, nil
}
func (c *ClientConn) loadResultWithFuncIntoMap(rs []*mysql.Result,
groupByIndexs []int, funcExprs map[int]string) (map[string]*ResultRow, error) {
resultMap := make(map[string]*ResultRow)
rt := new(mysql.Result)
rt.Resultset = new(mysql.Resultset)
rt.Fields = rs[0].Fields
//change Result into map
for _, r := range rs {
for i := 0; i < len(r.Values); i++ {
keySlice := r.Values[i][groupByIndexs[0]:]
mk, err := c.generateMapKey(keySlice)
if err != nil {
return nil, err
}
if v, ok := resultMap[mk]; ok {
//init rt
rt.Values = nil
rt.RowDatas = nil
//append v and r into rt, and calculate the function value
rt.Values = append(rt.Values, r.Values[i], v.Value)
rt.RowDatas = append(rt.RowDatas, r.RowDatas[i], v.RowData)
resultTmp := []*mysql.Result{rt}
for funcIndex, funcName := range funcExprs {
funcValue, err := c.calFuncExprValue(funcName, resultTmp, funcIndex)
if err != nil {
return nil, err
}
//set the function value in group by
resultMap[mk].Value[funcIndex] = funcValue
}
} else { //key is not exist
resultMap[mk] = &ResultRow{
Value: r.Values[i],
RowData: r.RowDatas[i],
}
}
}
}
return resultMap, nil
}
//build select result without group by opt
func (c *ClientConn) buildSelectOnlyResult(rs []*mysql.Result,
stmt *sqlparser.Select) (*mysql.Result, error) {
var err error
r := rs[0].Resultset
status := c.status | rs[0].Status
funcExprs := c.getFuncExprs(stmt)
if len(funcExprs) == 0 {
for i := 1; i < len(rs); i++ {
status |= rs[i].Status
for j := range rs[i].Values {
r.Values = append(r.Values, rs[i].Values[j])
r.RowDatas = append(r.RowDatas, rs[i].RowDatas[j])
}
}
} else {
//result only one row, status doesn't need set
r, err = c.buildFuncExprResult(stmt, rs, funcExprs)
if err != nil {
return nil, err
}
}
return &mysql.Result{
Status: status,
Resultset: r,
}, nil
}
func (c *ClientConn) sortSelectResult(r *mysql.Resultset, stmt *sqlparser.Select) error {
if stmt.OrderBy == nil {
return nil
}
sk := make([]mysql.SortKey, len(stmt.OrderBy))
for i, o := range stmt.OrderBy {
sk[i].Name = nstring(o.Expr)
sk[i].Direction = o.Direction
}
return r.Sort(sk)
}
func (c *ClientConn) limitSelectResult(r *mysql.Resultset, stmt *sqlparser.Select) error {
if stmt.Limit == nil {
return nil
}
var offset, count int64
var err error
if stmt.Limit.Offset == nil {
offset = 0
} else {
if o, ok := stmt.Limit.Offset.(sqlparser.NumVal); !ok {
return fmt.Errorf("invalid select limit %s", nstring(stmt.Limit))
} else {
if offset, err = strconv.ParseInt(hack.String([]byte(o)), 10, 64); err != nil {
return err
}
}
}
if o, ok := stmt.Limit.Rowcount.(sqlparser.NumVal); !ok {
return fmt.Errorf("invalid limit %s", nstring(stmt.Limit))
} else {
if count, err = strconv.ParseInt(hack.String([]byte(o)), 10, 64); err != nil {
return err
} else if count < 0 {
return fmt.Errorf("invalid limit %s", nstring(stmt.Limit))
}
}
if offset > int64(len(r.Values)) {
r.Values = nil
r.RowDatas = nil
return nil
}
if offset+count > int64(len(r.Values)) {
count = int64(len(r.Values)) - offset
}
r.Values = r.Values[offset : offset+count]
r.RowDatas = r.RowDatas[offset : offset+count]
return nil
}
func (c *ClientConn) buildFuncExprResult(stmt *sqlparser.Select,
rs []*mysql.Result, funcExprs map[int]string) (*mysql.Resultset, error) {
var names []string
var err error
r := rs[0].Resultset
funcExprValues := make(map[int]interface{})
for index, funcName := range funcExprs {
funcExprValue, err := c.calFuncExprValue(
funcName,
rs,
index,
)
if err != nil {
return nil, err
}
funcExprValues[index] = funcExprValue
}
r.Values, err = c.buildFuncExprValues(rs, funcExprValues)
if 0 < len(r.Values) {
for _, field := range rs[0].Fields {
names = append(names, string(field.Name))
}
r, err = c.buildResultset(rs[0].Fields, names, r.Values)
if err != nil {
return nil, err
}
} else {
r = c.newEmptyResultset(stmt)
}
return r, nil
}
//get the index of funcExpr, the value is function name
func (c *ClientConn) getFuncExprs(stmt *sqlparser.Select) map[int]string {
var f *sqlparser.FuncExpr
funcExprs := make(map[int]string)
for i, expr := range stmt.SelectExprs {
nonStarExpr, ok := expr.(*sqlparser.NonStarExpr)
if !ok {
continue
}
f, ok = nonStarExpr.Expr.(*sqlparser.FuncExpr)
if !ok {
continue
} else {
f = nonStarExpr.Expr.(*sqlparser.FuncExpr)
funcName := strings.ToLower(string(f.Name))
switch funcNameMap[funcName] {
case FUNC_EXIST:
funcExprs[i] = funcName
}
}
}
return funcExprs
}
func (c *ClientConn) getSumFuncExprValue(rs []*mysql.Result,
index int) (interface{}, error) {
var sumf float64
var sumi int64
var IsInt bool
var err error
var result interface{}
for _, r := range rs {
for k := range r.Values {
result, err = r.GetValue(k, index)
if err != nil {
return nil, err
}
if result == nil {
continue
}
switch v := result.(type) {
case int:
sumi = sumi + int64(v)
IsInt = true
case int32:
sumi = sumi + int64(v)
IsInt = true
case int64:
sumi = sumi + v
IsInt = true
case float32:
sumf = sumf + float64(v)
case float64:
sumf = sumf + v
case []byte:
tmp, err := strconv.ParseFloat(string(v), 64)
if err != nil {
return nil, err
}
sumf = sumf + tmp
default:
return nil, errors.ErrSumColumnType
}
}
}
if IsInt {
return sumi, nil
} else {
return sumf, nil
}
}
func (c *ClientConn) getMaxFuncExprValue(rs []*mysql.Result,
index int) (interface{}, error) {
var max interface{}
var findNotNull bool
if len(rs) == 0 {
return nil, nil
} else {
for _, r := range rs {
for k := range r.Values {
result, err := r.GetValue(k, index)
if err != nil {
return nil, err
}
if result != nil {
max = result
findNotNull = true
break
}
}
if findNotNull {
break
}
}
}
for _, r := range rs {
for k := range r.Values {
result, err := r.GetValue(k, index)
if err != nil {
return nil, err
}
if result == nil {
continue
}
switch result.(type) {
case int64:
if max.(int64) < result.(int64) {
max = result
}
case uint64:
if max.(uint64) < result.(uint64) {
max = result
}
case float64:
if max.(float64) < result.(float64) {
max = result
}
case string:
if max.(string) < result.(string) {
max = result
}
}
}
}
return max, nil
}
func (c *ClientConn) getMinFuncExprValue(
rs []*mysql.Result, index int) (interface{}, error) {
var min interface{}
var findNotNull bool
if len(rs) == 0 {
return nil, nil
} else {
for _, r := range rs {
for k := range r.Values {
result, err := r.GetValue(k, index)
if err != nil {
return nil, err
}
if result != nil {
min = result
findNotNull = true
break
}
}
if findNotNull {
break
}
}
}
for _, r := range rs {
for k := range r.Values {
result, err := r.GetValue(k, index)
if err != nil {
return nil, err
}
if result == nil {
continue
}
switch result.(type) {
case int64:
if min.(int64) > result.(int64) {
min = result
}
case uint64:
if min.(uint64) > result.(uint64) {
min = result
}
case float64:
if min.(float64) > result.(float64) {
min = result
}
case string:
if min.(string) > result.(string) {
min = result
}
}
}
}
return min, nil
}
//calculate the the value funcExpr(sum or count)
func (c *ClientConn) calFuncExprValue(funcName string,
rs []*mysql.Result, index int) (interface{}, error) {
var num int64
switch strings.ToLower(funcName) {
case CountFunc:
if len(rs) == 0 {
return nil, nil
}
for _, r := range rs {
if r != nil {
for k := range r.Values {
result, err := r.GetInt(k, index)
if err != nil {
return nil, err
}
num += result
}
}
}
return num, nil
case SumFunc:
return c.getSumFuncExprValue(rs, index)
case MaxFunc:
return c.getMaxFuncExprValue(rs, index)
case MinFunc:
return c.getMinFuncExprValue(rs, index)
case LastInsertIdFunc:
return c.lastInsertId, nil
default:
if len(rs) == 0 {
return nil, nil
}
//get a non-null value of funcExpr
for _, r := range rs {
for k := range r.Values {
result, err := r.GetValue(k, index)
if err != nil {
return nil, err
}
if result != nil {
return result, nil
}
}
}
}
return nil, nil
}
//build values of resultset,only build one row
func (c *ClientConn) buildFuncExprValues(rs []*mysql.Result,
funcExprValues map[int]interface{}) ([][]interface{}, error) {
values := make([][]interface{}, 0, 1)
//build a row in one result
for i := range rs {
for j := range rs[i].Values {
for k := range funcExprValues {
rs[i].Values[j][k] = funcExprValues[k]
}
values = append(values, rs[i].Values[j])
if len(values) == 1 {
break
}
}
break
}
//generate one row just for sum or count
if len(values) == 0 {
value := make([]interface{}, len(rs[0].Fields))
for k := range funcExprValues {
value[k] = funcExprValues[k]
}
values = append(values, value)
}
return values, nil
}
1
https://gitee.com/he3db/he3proxy.git
git@gitee.com:he3db/he3proxy.git
he3db
he3proxy
He3Proxy
v1.0.1

搜索帮助

53164aa7 5694891 3bd8fe86 5694891