15 Star 69 Fork 15

He3DB / He3Proxy

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
conn.go 11.82 KB
一键复制 编辑 原始数据 按行查看 历史
wangyao 提交于 2022-12-02 07:02 . improve: write data to backend
// Copyright 2016 The kingshard Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package server
import (
"bytes"
"crypto/tls"
"encoding/binary"
"fmt"
"net"
"runtime"
"sync"
"gitee.com/he3db/he3proxy/backend"
"gitee.com/he3db/he3proxy/config"
"gitee.com/he3db/he3proxy/core/golog"
"gitee.com/he3db/he3proxy/core/hack"
"gitee.com/he3db/he3proxy/mysql"
)
//client <-> proxy
// ClientConn is the proxy's side of one client connection: it carries the
// wire-protocol state for the client, the authenticated user, the routing
// configuration snapshot and the backend connection(s) used to serve it.
type ClientConn struct {
	// NOTE(review): no method in this chunk locks this mutex; confirm which
	// fields it is meant to guard.
	sync.Mutex
	pkg *mysql.PacketIO // packet reader/writer used by readPacket/writePacket
	c net.Conn // underlying client network connection
	proxy *Server // owning proxy (users, IP allow list, config version, counters)
	capability uint32 // client capability flags read from the handshake response
	connectionId uint32 // connection id sent to the client in the initial handshake
	status uint16 // server status flags sent to the client (e.g. SERVER_STATUS_IN_TRANS)
	collation mysql.CollationId
	charset string
	user string // user name from the handshake response
	db string // default database requested in the handshake response
	table string
	salt []byte // auth scramble; salt[0:8] and salt[8:] are sent in the initial handshake
	nodes map[string]*backend.Node // backend nodes, refreshed by reloadConfig
	schema *Schema // routing schema for the user, refreshed by reloadConfig
	txConns map[*backend.Node]*backend.BackendConn // backend connections held for transactions; closed by clean
	closed bool // set once Close has run; later Close calls are no-ops
	lastInsertId int64
	affectedRows int64
	stmtId uint32
	// Prepare related, stmt from client to proxy
	stmts map[uint32]*Stmt
	configVer uint32 // proxy config version last applied; compared in Run to trigger online reload
	// new params for PG
	tlsConn *tls.Conn // TLS connection, nil if not TLS.
	backendConn *backend.BackendConn // backend connection released by Close
	beginFlag uint8 // 'begin' statement status (presumably one of the BEGIN_* constants — confirm)
	//TODO Environment variables should be placed in the connection cache, and the variables should be set first each time the statement is executed
	alwaysCurNode bool // Flag for always exec sql in same node
	// save receive data from backend, will send to client
	dataRecv []byte
	// flag for use extended query protocol
	parseFlag bool
	Parse sync.Map //parse name and parse info
	isInTrxPg bool // use for pg, true means connection in transaction
}
// Progress states of an explicit client 'begin' (see ClientConn.beginFlag).
// Values are assigned by iota (0 through 5), so the order of the names below
// is significant and must not be changed.
const (
	BEGIN_UNSTART = iota // 'begin' has not been executed
	BEGIN_PRESTART // 'begin' has been executed, but no insert/create/update/delete... statement yet.
	// 'begin' has been executed but no insert/create/update/delete... statement
	// yet, and a commit is issued automatically on the backend connection,
	// cancelling the transaction state.
	BEGIN_PRESTART_COMMIT
	// 'begin' has been executed and an insert/create/update/delete... statement is about to run.
	BEGIN_RELSTART
	// 'begin' and an insert/create/update/delete... statement have been executed;
	// 'begin' is prepended to the first such statement.
	BEGIN_RELSTART_BEGIN
	// the client has executed 'commit'
	BEGIN_COMMIT
)
var (
	// DEFAULT_CAPABILITY is the capability flag set the proxy advertises in its
	// initial handshake packet (lower and upper 16 bits are sent separately).
	DEFAULT_CAPABILITY = mysql.CLIENT_LONG_PASSWORD |
		mysql.CLIENT_LONG_FLAG |
		mysql.CLIENT_CONNECT_WITH_DB |
		mysql.CLIENT_PROTOCOL_41 |
		mysql.CLIENT_TRANSACTIONS |
		mysql.CLIENT_SECURE_CONNECTION

	// baseConnId is presumably the starting value for connection ids handed
	// out elsewhere in the package — confirm against its users.
	baseConnId uint32 = 10000
)
// IsAllowConnect reports whether the client's remote IP is permitted by the
// proxy's current IP allow list. An empty allow list admits every client.
//
// When the remote address cannot be split into host and port, the IP cannot
// be matched against a non-empty allow list, so the client is rejected and
// the parse error is logged (previously it was only printed to stdout).
func (c *ClientConn) IsAllowConnect() bool {
	current, _, _ := c.proxy.allowipsIndex.Get()
	ipVec := c.proxy.allowips[current]
	if len(ipVec) == 0 {
		return true
	}

	remoteAddr := c.c.RemoteAddr().String()
	clientHost, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		golog.Error("server", "IsAllowConnect", err.Error(), c.connectionId,
			"ip address", remoteAddr, "msg", "cannot parse client address")
		return false
	}

	clientIP := net.ParseIP(clientHost)
	for _, ip := range ipVec {
		if ip.Match(clientIP) {
			return true
		}
	}
	golog.Error("server", "IsAllowConnect", "access denied by He3Proxy", c.connectionId,
		"error", mysql.ER_ACCESS_DENIED_ERROR,
		"ip address", remoteAddr)
	return false
}
// Handshake performs the server side of the connection handshake. It sends
// the initial handshake packet, then reads and authenticates the client's
// handshake response, then replies with an OK packet. On success the packet
// sequence number is reset to 0 for the command phase. Each failure is
// logged with the connection id and returned unchanged.
func (c *ClientConn) Handshake() error {
	if err := c.writeInitialHandshake(); err != nil {
		golog.Error("server", "Handshake", err.Error(),
			c.connectionId, "msg", "send initial handshake error")
		return err
	}
	if err := c.readHandshakeResponse(); err != nil {
		golog.Error("server", "readHandshakeResponse",
			err.Error(), c.connectionId,
			"msg", "read Handshake Response error")
		return err
	}
	if err := c.writeOK(nil); err != nil {
		// Logged under Handshake: the failing step is writeOK, not readHandshakeResponse.
		golog.Error("server", "Handshake",
			err.Error(), c.connectionId,
			"msg", "write ok fail")
		return err
	}
	c.pkg.Sequence = 0
	return nil
}
// Close closes the client network connection and releases the connection's
// backend connection. The in-transaction status flag is cleared on every
// call. Only the first call does any work; later calls return nil.
//
// The error from closing the client connection is returned. Any result of
// releasing the backend connection is not reported.
func (c *ClientConn) Close() error {
	c.status &= ^mysql.SERVER_STATUS_IN_TRANS
	if c.closed {
		return nil
	}
	c.closed = true
	err := c.c.Close()
	// backendConn may not have been set (e.g. the client never got past the
	// handshake); guard so Close never dereferences a nil backend connection.
	if c.backendConn != nil {
		c.backendConn.Close()
	}
	return err
}
// writeInitialHandshake sends the server greeting (protocol version 10) that
// opens the handshake. The payload, in order, is:
//   - the protocol version and the NUL-terminated server version;
//   - the 4-byte little-endian connection id and the first 8 salt bytes;
//   - a filler byte and the lower 16 bits of DEFAULT_CAPABILITY;
//   - the default collation and the 2-byte status flags;
//   - the upper 16 bits of DEFAULT_CAPABILITY and the 0x15 filler;
//   - 10 reserved zero bytes and the remaining salt bytes;
//   - a trailing NUL.
//
// The first four bytes are left zero for the packet header, which is
// presumably filled in by writePacket / mysql.PacketIO — confirm.
func (c *ClientConn) writeInitialHandshake() error {
	pkt := make([]byte, 4, 128)

	// protocol version, then NUL-terminated server version
	pkt = append(pkt, 10)
	pkt = append(pkt, mysql.ServerVersion...)
	pkt = append(pkt, 0x00)

	// connection id, little endian
	id := c.connectionId
	pkt = append(pkt, byte(id), byte(id>>8), byte(id>>16), byte(id>>24))

	// auth-plugin-data-part-1
	pkt = append(pkt, c.salt[:8]...)

	pkt = append(pkt,
		0x00, // filler
		byte(DEFAULT_CAPABILITY), byte(DEFAULT_CAPABILITY>>8), // capability, lower 2 bytes
		uint8(mysql.DEFAULT_COLLATION_ID), // charset
		byte(c.status), byte(c.status>>8), // status flags
		byte(DEFAULT_CAPABILITY>>16), byte(DEFAULT_CAPABILITY>>24), // capability, upper 2 bytes
		0x15, // filler kept at 0x15 (for wireshark dumps)
	)

	// 10 reserved zero bytes
	pkt = append(pkt, make([]byte, 10)...)

	// auth-plugin-data-part-2, NUL-terminated
	pkt = append(pkt, c.salt[8:]...)
	pkt = append(pkt, 0x00)

	return c.writePacket(pkt)
}
// readPacket reads the next packet payload from the client via c.pkg.
func (c *ClientConn) readPacket() ([]byte, error) {
	packetIO := c.pkg
	return packetIO.ReadPacket()
}
// writePacket writes one packet (data includes the 4 reserved header bytes)
// to the client via c.pkg.
func (c *ClientConn) writePacket(data []byte) error {
	packetIO := c.pkg
	return packetIO.WritePacket(data)
}
// writePacketBatch forwards to mysql.PacketIO.WritePacketBatch and returns its
// results unchanged. NOTE(review): presumably total accumulates packets and
// direct forces the write — confirm in PacketIO.
func (c *ClientConn) writePacketBatch(total, data []byte, direct bool) ([]byte, error) {
	packetIO := c.pkg
	return packetIO.WritePacketBatch(total, data, direct)
}
// readHandshakeResponse reads the client's handshake response, authenticates
// the user against the proxy's configured users, and records the requested
// default database in c.db.
//
// Payload layout parsed here:
//   - 4 bytes: capability flags (little endian), stored in c.capability;
//   - 4 bytes: max packet size (ignored);
//   - 1 byte: charset (ignored; clients switch with SET NAMES);
//   - 23 bytes: reserved;
//   - the NUL-terminated user name, stored in c.user;
//   - a 1-byte length followed by the auth response;
//   - the NUL-terminated database name, present only with CLIENT_CONNECT_WITH_DB.
//
// A truncated or malformed packet returns an error rather than panicking on
// an out-of-range slice. An unknown user or wrong password returns
// ER_ACCESS_DENIED_ERROR. Credentials (configured password, expected
// scramble) are deliberately not logged.
func (c *ClientConn) readHandshakeResponse() error {
	data, err := c.readPacket()
	if err != nil {
		return err
	}

	// capability(4) + max packet size(4) + charset(1) + reserved(23)
	const fixedLen = 4 + 4 + 1 + 23
	if len(data) < fixedLen {
		return fmt.Errorf("malformed handshake response: %d bytes, shorter than the %d-byte fixed header",
			len(data), fixedLen)
	}
	c.capability = binary.LittleEndian.Uint32(data[:4])
	// charset byte (offset 8) is skipped; if you want to use another charset, use set names
	//c.collation = CollationId(data[8])
	pos := fixedLen

	// user name, NUL-terminated
	userLen := bytes.IndexByte(data[pos:], 0)
	if userLen < 0 {
		return fmt.Errorf("malformed handshake response: user name is not NUL-terminated")
	}
	c.user = string(data[pos : pos+userLen])
	pos += userLen + 1

	// auth response: 1-byte length followed by that many bytes
	if pos >= len(data) {
		return fmt.Errorf("malformed handshake response: missing auth response length")
	}
	authLen := int(data[pos])
	pos++
	if authLen > len(data)-pos {
		return fmt.Errorf("malformed handshake response: auth response truncated")
	}
	auth := data[pos : pos+authLen]
	pos += authLen

	// check user
	password, ok := c.proxy.users[c.user]
	if !ok {
		golog.Error("ClientConn", "readHandshakeResponse", "unknown user", c.connectionId,
			"client_user", c.user,
			"remote_addr", c.c.RemoteAddr().String())
		return mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR, c.user, c.c.RemoteAddr().String(), "Yes")
	}

	// check password
	checkAuth := mysql.CalcPassword(c.salt, []byte(password))
	if !bytes.Equal(auth, checkAuth) {
		golog.Error("ClientConn", "readHandshakeResponse", "password mismatch", c.connectionId,
			"client_user", c.user,
			"remote_addr", c.c.RemoteAddr().String())
		return mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR, c.user, c.c.RemoteAddr().String(), "Yes")
	}

	// optional default database, NUL-terminated; tolerate a missing terminator
	var db string
	if c.capability&mysql.CLIENT_CONNECT_WITH_DB > 0 && pos < len(data) {
		rest := data[pos:]
		if end := bytes.IndexByte(rest, 0); end >= 0 {
			rest = rest[:end]
		}
		db = string(rest)
	}
	c.db = db
	return nil
}
// clean closes every backend connection recorded in c.txConns. Ranging over a
// nil or empty map performs no iterations, so no explicit guard is needed.
func (c *ClientConn) clean() {
	for _, backendConn := range c.txConns {
		backendConn.Close()
	}
}
// Run serves the client connection until it is closed or a read fails.
//
// Each iteration does the following:
//  1. Reads one command packet.
//  2. Reloads the schema and node set if the proxy's configuration version
//     changed; if the reload fails, the error is sent to the client and Run
//     returns.
//  3. Dispatches the command. Dispatch errors are counted, logged and written
//     back to the client. mysql.ErrBadConn additionally closes the connection.
//
// On exit, clean releases transaction backend connections first, and then
// Close is called. Any panic while serving is recovered and logged with its
// stack trace, whatever the panic value's type.
func (c *ClientConn) Run() {
	defer func() {
		// Log every panic value, not only those that implement error
		// (e.g. panic("...")), so that none are silently swallowed.
		if r := recover(); r != nil {
			const size = 4096
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			golog.Error("ClientConn", "Run",
				fmt.Sprint(r), c.connectionId,
				"stack", string(buf))
		}
		c.Close()
	}()
	// Deferred after the recover handler, so it runs before Close.
	defer c.clean()

	for {
		// NOTE(original author): the SQL read here does not include hints, e.g.
		// "select /*master*/ * from test001;" arrives as "select * from test001".
		// Because of this, the master-hint check in conn_preshard
		// (strings.ToLower(tokens[1]) == mysql.TK_STR_MASTER_HINT) never takes effect.
		data, err := c.readPacket()
		if err != nil {
			return
		}

		if c.configVer != c.proxy.configVer {
			if err := c.reloadConfig(); err != nil {
				golog.Error("ClientConn", "Run",
					err.Error(), c.connectionId,
				)
				c.writeError(err)
				return
			}
			c.configVer = c.proxy.configVer
			golog.Debug("ClientConn", "Run",
				fmt.Sprintf("config reload ok, ver:%d", c.configVer), c.connectionId,
			)
		}

		if err := c.dispatch(data); err != nil {
			c.proxy.counter.IncrErrLogTotal()
			golog.Error("ClientConn", "Run",
				err.Error(), c.connectionId,
			)
			c.writeError(err)
			if err == mysql.ErrBadConn {
				c.Close()
			}
		}

		if c.closed {
			return
		}
		c.pkg.Sequence = 0
	}
}
// dispatch executes one client command packet. data[0] is the command byte
// (mysql.COM_*) and the rest is the command payload handed to the handler.
//
// It returns the handler's error (reported to the client by Run), or nil once
// the response has been written. An empty packet or an unsupported command
// yields an ER_UNKNOWN_ERROR error. COM_QUIT rolls back, closes the
// connection and returns nil.
func (c *ClientConn) dispatch(data []byte) error {
	c.proxy.counter.IncrClientQPS()
	// Without a command byte there is nothing to dispatch; reject the packet
	// instead of panicking on data[0].
	if len(data) == 0 {
		msg := "empty command packet"
		golog.Error("ClientConn", "dispatch", msg, c.connectionId)
		return mysql.NewError(mysql.ER_UNKNOWN_ERROR, msg)
	}

	cmd, payload := data[0], data[1:]
	switch cmd {
	case mysql.COM_QUIT:
		c.handleRollback()
		c.Close()
		return nil
	case mysql.COM_QUERY:
		return c.handleQuery(hack.String(payload))
	case mysql.COM_PING:
		return c.writeOK(nil)
	case mysql.COM_INIT_DB:
		return c.handleUseDB(hack.String(payload))
	case mysql.COM_FIELD_LIST:
		return c.handleFieldList(payload)
	case mysql.COM_STMT_PREPARE:
		return c.handleStmtPrepare(hack.String(payload))
	case mysql.COM_STMT_EXECUTE:
		return c.handleStmtExecute(payload)
	case mysql.COM_STMT_CLOSE:
		return c.handleStmtClose(payload)
	case mysql.COM_STMT_SEND_LONG_DATA:
		return c.handleStmtSendLongData(payload)
	case mysql.COM_STMT_RESET:
		return c.handleStmtReset(payload)
	case mysql.COM_SET_OPTION:
		return c.writeEOF(0)
	default:
		msg := fmt.Sprintf("command %d not supported now", cmd)
		golog.Error("ClientConn", "dispatch", msg, c.connectionId)
		return mysql.NewError(mysql.ER_UNKNOWN_ERROR, msg)
	}
}
// writeOK sends an OK packet built from r. A nil r is treated as an empty
// result carrying the connection's current status flags.
//
// The packet contains the header byte, the affected rows and the insert id
// (both length-encoded). When the client negotiated CLIENT_PROTOCOL_41, it is
// followed by the 2-byte status flags and two zero bytes (the warning count
// in the MySQL protocol).
func (c *ClientConn) writeOK(r *mysql.Result) error {
	res := r
	if res == nil {
		res = &mysql.Result{Status: c.status}
	}

	pkt := make([]byte, 4, 32)
	pkt = append(pkt, mysql.OK_HEADER)
	pkt = append(pkt, mysql.PutLengthEncodedInt(res.AffectedRows)...)
	pkt = append(pkt, mysql.PutLengthEncodedInt(res.InsertId)...)

	if c.capability&mysql.CLIENT_PROTOCOL_41 == 0 {
		return c.writePacket(pkt)
	}
	pkt = append(pkt, byte(res.Status), byte(res.Status>>8), 0, 0)
	return c.writePacket(pkt)
}
// writeError sends e to the client as an ERR packet. A *mysql.SqlError is sent
// with its own code, state and message. Any other error becomes
// ER_UNKNOWN_ERROR carrying e.Error() as the message. With CLIENT_PROTOCOL_41
// the SQL state, prefixed by '#', precedes the message.
func (c *ClientConn) writeError(e error) error {
	sqlErr, isSQLErr := e.(*mysql.SqlError)
	if !isSQLErr {
		sqlErr = mysql.NewError(mysql.ER_UNKNOWN_ERROR, e.Error())
	}

	pkt := make([]byte, 4, 16+len(sqlErr.Message))
	pkt = append(pkt, mysql.ERR_HEADER, byte(sqlErr.Code), byte(sqlErr.Code>>8))
	if c.capability&mysql.CLIENT_PROTOCOL_41 != 0 {
		pkt = append(pkt, '#')
		pkt = append(pkt, sqlErr.State...)
	}
	pkt = append(pkt, sqlErr.Message...)
	return c.writePacket(pkt)
}
// buildEOFPacket returns an EOF packet (with the 4 reserved header bytes)
// for the given status flags. With CLIENT_PROTOCOL_41 the header byte is
// followed by two zero bytes (the warning count) and the 2-byte status.
// It is shared by writeEOF and writeEOFBatch so both always emit identical
// payloads.
func (c *ClientConn) buildEOFPacket(status uint16) []byte {
	data := make([]byte, 4, 9)
	data = append(data, mysql.EOF_HEADER)
	if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
		data = append(data, 0, 0)
		data = append(data, byte(status), byte(status>>8))
	}
	return data
}

// writeEOF sends an EOF packet with the given status flags to the client.
func (c *ClientConn) writeEOF(status uint16) error {
	return c.writePacket(c.buildEOFPacket(status))
}

// writeEOFBatch hands an EOF packet with the given status flags to
// writePacketBatch, passing total and direct through unchanged and returning
// its results.
func (c *ClientConn) writeEOFBatch(total []byte, status uint16, direct bool) ([]byte, error) {
	return c.writePacketBatch(total, c.buildEOFPacket(status), direct)
}
// reloadConfig refreshes c.schema and c.nodes from the proxy while holding the
// proxy's config read lock.
//
// The schema is looked up for the connection's user first, then for
// config.DefaultHe3User. c.schema is assigned after each lookup, so it ends up
// nil on failure. c.nodes is refreshed only when a schema was found; otherwise
// an error naming the user is returned.
func (c *ClientConn) reloadConfig() error {
	c.proxy.configUpdateMutex.RLock()
	defer c.proxy.configUpdateMutex.RUnlock()

	// TODO 22.05.30 The user obtains the partition rule from the database, so
	// the default configuration is used as a fallback for now.
	for _, candidate := range [...]string{c.user, config.DefaultHe3User} {
		if c.schema = c.proxy.GetSchema(candidate); c.schema != nil {
			c.nodes = c.proxy.nodes
			return nil
		}
	}
	return fmt.Errorf("schema of user [%s] is null or user is deleted", c.user)
}
1
https://gitee.com/he3db/he3proxy.git
git@gitee.com:he3db/he3proxy.git
he3db
he3proxy
He3Proxy
v1.0.1

搜索帮助

53164aa7 5694891 3bd8fe86 5694891