15 Star 69 Fork 15

He3DB / He3Proxy

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
conn_pgsql.go 48.21 KB
一键复制 编辑 原始数据 按行查看 历史
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481
// Copyright (c) 2022. China Mobile(SuZhou)Software Technology Co.,Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package server
import (
"bufio"
"context"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"runtime"
"strings"
"sync"
"time"
"unsafe"
"github.com/cloudwego/netpoll"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
_ "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
//timecost "github.com/dablelv/go-huge-util"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgx/v4"
"gitee.com/he3db/he3proxy/config"
"gitee.com/he3db/he3proxy/core/errors"
"gitee.com/he3db/he3proxy/core/golog"
_ "gitee.com/he3db/he3proxy/core/hack"
"gitee.com/he3db/he3proxy/mysql"
"gitee.com/he3db/he3proxy/postgresql/hba"
)
// PostgreSQL wire-protocol codes and file-wide settings.
const (
	// protocolVersionNumber is protocol version 3.0, encoded as major<<16 | minor.
	protocolVersionNumber = 3 << 16 // 196608
	// sslRequestNumber is the request code carried by an SSLRequest packet.
	sslRequestNumber = 1234<<16 | 5679 // 80877103
	// cancelRequestCode is the request code carried by a CancelRequest packet.
	cancelRequestCode = 1234<<16 | 5678 // 80877102
	// gssEncReqNumber is the request code carried by a GSSENCRequest packet.
	gssEncReqNumber = 1234<<16 | 5680 // 80877104
	// protocolSSL reports whether SSL is offered to clients; SSL is not supported yet.
	protocolSSL = false
	// defaultWriterSize is 16 KiB; presumably a writer buffer size — confirm at use sites.
	defaultWriterSize = 16 << 10
	// moduleName is the module tag passed to every golog call in this file.
	moduleName = "CONN_PGSQL"
)

// clientConnMap maps a client connection ID to its *ClientConn so that a
// CancelRequest arriving on another connection can find the session to cancel.
var clientConnMap sync.Map
//-----------------------------------------------------------------
// Handshake of PG
//  1. A PostgreSQL client first sends an 'SSLRequest' message asking whether SSL is available.
//  2. The server replies 'S' if it will use SSL, or 'N' for no SSL.
//  3. Note that SSL is not supported yet.
//  4. After the SSL handshake, or after receiving 'N', the client sends a 'StartupMessage'.
//  5. The server handles the 'StartupMessage' and decides whether authentication is
//     required; if so it sends an 'AuthenticationRequest' to the client.
//  6. PostgreSQL supports several authentication methods; md5 is supported first. When
//     authentication finishes the server returns 'AuthenticationOk' or 'ErrorResponse'.
//  7. On success the server also sends parameters via 'ParameterStatus', such as
//     server_version and client_encoding.
//  8. Finally the server sends 'ReadyForQuery', meaning the connection is established
//     and SQL requests can be handled.
//  9. Once the client receives 'ReadyForQuery' it can run SQL.
//
// For more detail see: https://zhuanlan.zhihu.com/p/493045524
//
// handshake reads the first startup-phase message from the client and dispatches it:
//   - *pgproto3.CancelRequest (e.g. Ctrl+C in psql): the target ClientConn is looked up
//     in clientConnMap by the request's ProcessID. If it holds a backend connection, a
//     CancelRequest carrying that backend's ProcessID/SecretKey is written to it. When the
//     target was found (and the write succeeded) this cancel connection is closed. A
//     flush failure is logged and returned.
//   - *pgproto3.SSLRequest: handled by handleSSLRequest.
//   - *pgproto3.StartupMessage: handled by handleStartupMessage.
//
// Any other message yields an error.
func (cc *ClientConn) handshake(ctx context.Context) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-handshake")
		var span trace.Span
		// Format the numeric connection ID; string(id) would turn it into a single rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("handshake %v", cc.connectionId))
		defer span.End()
	}

	m, err := cc.ReceiveStartupMessage()
	if err != nil {
		return err
	}

	switch msg := m.(type) {
	case *pgproto3.CancelRequest:
		val, found := clientConnMap.Load(msg.ProcessID)
		if !found {
			return nil
		}
		if target := val.(*ClientConn); target != nil && target.backendConn != nil {
			cancelRequest := &pgproto3.CancelRequest{
				ProcessID: target.backendConn.ProcessID,
				SecretKey: target.backendConn.SecretKey,
			}
			// NOTE(review): the cancel packet is written onto the target's established
			// backend session rather than a fresh connection; confirm the backend accepts it there.
			wr := target.backendConn.ConnPg.Writer()
			if _, err = wr.WriteBinary(cancelRequest.Encode(nil)); err != nil {
				golog.Error(moduleName, "CancelRequest", "write msg err: "+err.Error(), cc.connectionId)
				return err
			}
			if err = wr.Flush(); err != nil {
				golog.Error(moduleName, "CancelRequest", "flush msg err: "+err.Error(), cc.connectionId)
			}
		}
		cc.Close()
		return err
	case *pgproto3.SSLRequest:
		return cc.handleSSLRequest(ctx)
	case *pgproto3.StartupMessage:
		return cc.handleStartupMessage(ctx, msg)
	default:
		return errors.ErrFormatStr("received is not a expected packet")
	}
}
// RunPg serves one PostgreSQL client connection and is meant to run in its own
// goroutine. It loops until a packet read fails, a configuration reload fails, or the
// connection is marked closed.
//
// Each iteration reads one client packet, reloads the configuration when the proxy's
// config version changed, and hands the packet to dispatchPg. A dispatch error is
// counted, logged, reported to the client with SQLSTATE 22000, and the backend
// connection is closed. io.EOF (returned for a Terminate message) is only counted.
//
// On exit, including after a recovered panic, the client's transaction status is copied
// to the backend connection, the client connection is cleaned up and closed, and it is
// removed from clientConnMap.
func (c *ClientConn) RunPg(ctx context.Context) {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-Runsql")
		var span trace.Span
		// Format the numeric connection ID; string(id) would turn it into a single rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("RunSQL %v", c.connectionId))
		defer span.End()
	}

	defer func() {
		// Log every panic value, not only error values, so panics with a string
		// payload are not swallowed without a trace.
		if r := recover(); r != nil {
			const size = 4096
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			golog.Error(moduleName, "RunPg",
				fmt.Sprintf("%v", r), c.connectionId,
				"stack", string(buf))
		}
		// Set transaction status for the backend connection.
		if c.backendConn != nil {
			c.backendConn.IsInTransaction = c.isInTrxPg
		}
		c.clean()
		c.Close()
		clientConnMap.Delete(c.connectionId)
	}()

	// Only use the master node, for some special cases. Simple and crude:
	// treat the session as permanently in a transaction.
	if config.SingleSession {
		c.isInTrxPg = true
		c.alwaysCurNode = true
	}

	for {
		msg, err := c.readPacketPg(ctx)
		if err != nil {
			return
		}

		// Reload configuration.
		if c.configVer != c.proxy.configVer {
			if err := c.reloadConfig(); err != nil {
				golog.Error(moduleName, "RunPg", err.Error(), c.connectionId)
				c.writeError(err)
				return
			}
			c.configVer = c.proxy.configVer
			if golog.GetLevel() <= golog.LevelDebug {
				golog.Debug(moduleName, "RunPg",
					fmt.Sprintf("config reload ok, ver: %d", c.configVer), c.connectionId)
			}
		}

		// Handle the received message.
		if err = c.dispatchPg(ctx, msg); err != nil {
			c.proxy.counter.IncrErrLogTotal()
			if err == io.EOF {
				continue
			}
			golog.Error(moduleName, "RunPg", err.Error(), c.connectionId)
			c.writePgErr(ctx, "22000", err.Error())
			if err == mysql.ErrBadConn {
				c.Close()
			}
			// backendConn is nil when acquiring a backend connection failed.
			if c.backendConn != nil {
				c.backendConn.Close()
			}
		}
		if c.closed {
			return
		}
		c.pkg.Sequence = 0
	}
}
// dispatchPg handles one client message according to its type byte (data[0]).
// data is a complete frontend message: a type byte, a 4-byte length and the payload,
// which is why payloads are decoded from data[5:]. Every message counts towards the
// client QPS counter.
//
// Extended query protocol messages arrive in sequence:
//
//	first phase:
//	  Parse     --> ParseComplete
//	  Describe  --> ParameterDescription, RowDescription
//	  Sync      --> ReadyForQuery
//	second phase:
//	  Bind      --> BindComplete
//	  Describe  --> RowDescription
//	  Execute   --> DataRow, CommandComplete
//	  Sync      --> ReadyForQuery
//	X (Terminate) ends the session with io.EOF.
//
// Parse, Bind, Describe, Execute and CopyData are written to the backend without a
// flush; Sync (handleStmtSyncPg) flushes them and relays the replies.
// Fastpath ('F') and Flush ('H') are not implemented and return an error.
func (cc *ClientConn) dispatchPg(ctx context.Context, data []byte) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-dispatchPg")
		var span trace.Span
		// Format the numeric connection ID; string(id) would turn it into a single rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("dispatchPg %v", cc.connectionId))
		defer span.End()
	}
	cc.proxy.counter.IncrClientQPS()

	// Every frontend message carries a 5-byte header; reject anything shorter
	// instead of panicking on data[0] or data[5:].
	if len(data) < 5 {
		return errors.ErrFormat("received truncated packet of %d bytes", len(data))
	}
	cmd := data[0]

	switch cmd {
	case 'Q': // simple query
		var query pgproto3.Query
		if err := query.Decode(data[5:]); err != nil {
			return err
		}
		return cc.handleQueryPg(ctx, query.String, data)

	case 'P': // parse
		var parse pgproto3.Parse
		if err := parse.Decode(data[5:]); err != nil {
			return err
		}
		// Save the parse so it can be replayed on a newly acquired backend connection
		// and removed when the connection is returned to the pool.
		if !config.ReadOnly {
			cc.Parse.Store(parse.Name, parse)
		}
		var err error
		if cc.backendConn == nil || cc.backendConn.Conn == nil {
			// The parse phase reuses the session. benchmarkSQL mixes the extended and the
			// simple protocol; if the backend connection were released on a
			// non-transactional ReadyForQuery, earlier parsed statements would be lost,
			// so keep using the current connection.
			// TODO: this defeats load balancing; a better scheme is still needed.
			cc.parseFlag = true
			cc.alwaysCurNode = true
			cc.backendConn, err = cc.preHandlePg(parse.Query, ctx)
			if err == nil && cc.backendConn != nil && golog.GetLevel() <= golog.LevelDebug {
				golog.Debug(moduleName, "parse",
					fmt.Sprintf("exec sql [%s] by node [%s]", parse.Query, cc.backendConn.GetAddr()), cc.connectionId, "dbname", cc.db)
			}
			// Save connection info to the map, used for cancel requests.
			if cc.backendConn != nil && config.CancelReq {
				clientConnMap.Store(cc.connectionId, cc)
			}
		}
		if err != nil || cc.backendConn == nil || cc.backendConn.Conn == nil {
			// Best effort: log and drop the message. A later Bind/Describe/Execute/Sync
			// retries acquiring a backend connection.
			reason := "backend connection is nil"
			if err != nil {
				reason = err.Error()
			}
			golog.Error(moduleName, "dispatchPg", reason, cc.connectionId)
			return nil
		}
		wr := cc.backendConn.Conn.ConnPg.Writer()
		if n, werr := wr.WriteBinary(data); werr != nil || n != len(data) {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "dispatchPg", fmt.Sprintf("write parse to connection err: %v", werr), cc.connectionId)
			}
		}
		return nil

	case 'B', 'D', 'E', 'd': // bind, describe, execute, copy data
		if cc.backendConn == nil || cc.backendConn.Conn == nil {
			golog.Warn(moduleName, "handle B D E d", "backend connection is null, current send data is: "+string(data), cc.connectionId)
			var err error
			// TODO: "begin;" is passed to preHandlePg presumably to select a node suited
			// to a transaction — confirm.
			cc.backendConn, err = cc.preHandlePg("begin;", ctx)
			if err != nil {
				golog.Error(moduleName, "dispatchPg", "reconnect backend err: "+err.Error(), cc.connectionId)
				return err
			}
			if cc.backendConn == nil || cc.backendConn.Conn == nil {
				return errors.ErrConnIsNil
			}
			if config.CancelReq {
				clientConnMap.Store(cc.connectionId, cc)
			}
			// Replay cached Parse messages; write errors are logged inside.
			cc.handleParsePrepare(ctx)
		}
		wr := cc.backendConn.Conn.ConnPg.Writer()
		if n, werr := wr.WriteBinary(data); werr != nil || n != len(data) {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "dispatchPg", fmt.Sprintf("write bind to connection err: %v", werr), cc.connectionId)
			}
		}
		return nil

	case 'C': // close
		var closeMsg pgproto3.Close
		if err := closeMsg.Decode(data[5:]); err != nil {
			return err
		}
		return cc.handleStmtClosePg(ctx, closeMsg)

	case 'S': // sync
		return cc.handleStmtSyncPg(ctx, data)

	case 'X': // client terminate
		return io.EOF

	case 'c', 'f': // copy done, copy fail
		return cc.handleCopy(ctx, data)

	case 'F', 'H': // fastpath function call, flush: not implemented yet
		return errors.ErrFormat("command %d not supported now", cmd)

	default:
		return errors.ErrFormat("command %d not supported now", cmd)
	}
}
// handleParsePrepare replays every Parse message cached in cc.Parse onto the current
// backend connection and flushes it, so prepared statements exist on a newly acquired
// connection. It returns nil when nothing is cached; otherwise it returns the write
// error or, if the write succeeded, the flush error.
//
// Callers must ensure cc.backendConn and cc.backendConn.Conn are non-nil.
// NOTE(review): the backend's ParseComplete replies to these replayed messages are not
// consumed here — confirm how the caller's subsequent read handles them.
func (cc *ClientConn) handleParsePrepare(ctx context.Context) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-handleParsePrepare")
		// Format the numeric connection ID; string(id) would turn it into a single rune.
		_, span := tr.Start(ctx, fmt.Sprintf("handleParsePrepare %v", cc.connectionId))
		defer span.End()
	}

	var parseData []byte
	cc.Parse.Range(func(_, value interface{}) bool {
		parse := value.(pgproto3.Parse)
		parseData = parse.Encode(parseData)
		return true
	})
	if len(parseData) == 0 {
		return nil
	}
	if golog.GetLevel() <= golog.LevelDebug {
		golog.Debug(moduleName, "handleParsePrepare", fmt.Sprintf("write cached parse data is: %s", string(parseData)), cc.connectionId)
	}

	wr := cc.backendConn.Conn.ConnPg.Writer()
	n, err := wr.WriteBinary(parseData)
	if err != nil || n != len(parseData) {
		if golog.GetLevel() <= golog.LevelError {
			// %v keeps this safe when the write was short but err is nil.
			golog.Error(moduleName, "handleParsePrepare", fmt.Sprintf("write parse to connection err: %v", err), cc.connectionId)
		}
	}
	flushErr := wr.Flush()
	if err == nil {
		err = flushErr
	}
	return err
}
// handleQueryPg executes a simple-protocol query ('Q' message) on the backend and relays
// the reply to the client. sql is the decoded query text; data is the raw 'Q' message,
// forwarded unchanged.
//
// If no backend connection is held, one is acquired with preHandlePg. The client is then
// registered in clientConnMap (when config.CancelReq is set) and cached Parse messages
// are replayed. The backend connection is handed to cc.closeConn when this returns.
//
// cc.beginFlag emulates BEGIN so read statements inside it can be load balanced:
//   - BEGIN_PRESTART: the client is answered directly with CommandComplete("BEGIN") +
//     ReadyForQuery('T') without touching the backend; flag -> BEGIN_PRESTART_COMMIT.
//   - BEGIN_PRESTART_COMMIT and another "BEGIN": the client is answered directly with the
//     "there is already a transaction in progress" warning.
//   - BEGIN_RELSTART: "BEGIN;" is sent to the backend ahead of the statement and the
//     replies to it are consumed; flag -> BEGIN_RELSTART_BEGIN.
func (cc *ClientConn) handleQueryPg(ctx context.Context, sql string, data []byte) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-handleQueryPg")
		var span trace.Span
		// Format the numeric connection ID; string(id) would turn it into a single rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("handleQueryPg %v", cc.connectionId))
		span.SetAttributes(attribute.Key("sql").String(sql))
		defer span.End()
	}

	var err error
	acquired := false
	if cc.backendConn == nil || cc.backendConn.Conn == nil {
		cc.backendConn, err = cc.preHandlePg(sql, ctx)
		if err != nil {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "handleQueryPg", err.Error(), cc.connectionId, "sql", sql)
			}
			return err
		}
		acquired = true
	} else if (cc.isInTrxPg || !config.ReadOnly) && cc.beginFlag == BEGIN_PRESTART_COMMIT {
		// Used for load balancing inside BEGIN (select first, then insert/update/delete):
		// the real BEGIN is only issued on the backend now, ahead of this statement.
		cc.beginFlag = BEGIN_RELSTART
	}

	if cc.backendConn == nil || cc.backendConn.Conn == nil || cc.backendConn.Conn.ConnPg == nil {
		return errors.ErrConnIsNil
	}
	if acquired {
		// Save connection info for cancel requests, then restore cached prepared
		// statements on the new connection (write errors are logged inside).
		if config.CancelReq {
			clientConnMap.Store(cc.connectionId, cc)
		}
		cc.handleParsePrepare(ctx)
	}
	defer cc.closeConn(cc.backendConn, false)

	if golog.GetLevel() <= golog.LevelDebug {
		golog.Debug(moduleName, "handleQueryPg",
			fmt.Sprintf("exec sql [%s] by node [%s]", sql, cc.backendConn.GetAddr()), cc.connectionId, "dbname", cc.db)
	}

	// Duplicate "BEGIN": reply with NoticeResponse(WARNING 25001 "there is already a
	// transaction in progress"), CommandComplete("BEGIN") and ReadyForQuery('T').
	if cc.beginFlag == BEGIN_PRESTART_COMMIT && strings.EqualFold(strings.ReplaceAll(sql, ";", ""), "BEGIN") {
		cc.WriteData([]byte{78, 0, 0, 0, 111, 83, 87, 65, 82, 78, 73, 78, 71, 0, 86, 87, 65, 82, 78, 73, 78, 71, 0, 67,
			50, 53, 48, 48, 49, 0, 77, 116, 104, 101, 114, 101, 32, 105, 115, 32, 97, 108, 114, 101, 97, 100, 121, 32,
			97, 32, 116, 114, 97, 110, 115, 97, 99, 116, 105, 111, 110, 32, 105, 110, 32, 112, 114, 111, 103, 114, 101,
			115, 115, 0, 70, 120, 97, 99, 116, 46, 99, 0, 76, 51, 54, 56, 57, 0, 82, 66, 101, 103, 105, 110, 84, 114,
			97, 110, 115, 97, 99, 116, 105, 111, 110, 66, 108, 111, 99, 107, 0, 0, 67, 0, 0, 0, 10, 66, 69, 71, 73, 78,
			0, 90, 0, 0, 0, 5, 84})
		return nil
	}

	reader, writer := cc.backendConn.Conn.ConnPg.Reader(), cc.backendConn.Conn.ConnPg.Writer()

	if cc.beginFlag == BEGIN_PRESTART {
		// Defer the real BEGIN: reply CommandComplete("BEGIN") + ReadyForQuery('T').
		cc.WriteData([]byte{67, 0, 0, 0, 10, 66, 69, 71, 73, 78, 0, 90, 0, 0, 0, 5, 84})
		cc.beginFlag = BEGIN_PRESTART_COMMIT
		return nil
	}

	if cc.beginFlag == BEGIN_RELSTART {
		// First write after BEGIN: send Query("BEGIN;") ahead of the current statement.
		if _, err = writer.WriteBinary([]byte{81, 0, 0, 0, 11, 66, 69, 71, 73, 78, 59, 0}); err != nil {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "handleQueryPg", fmt.Sprintf("write msg err: %s", err.Error()), cc.connectionId)
			}
			return err
		}
		if _, err = writer.WriteBinary(data); err != nil {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "handleQueryPg", fmt.Sprintf("write msg err: %s", err.Error()), cc.connectionId)
			}
			return err
		}
		writer.Flush()
		// Consume the replies to BEGIN, up to its ReadyForQuery, without returning them to
		// the client. The replies to the statement are relayed by receiveBackendMsg below.
		for {
			d, e := cc.backendConn.Conn.ReadPgPacket(reader)
			if e != nil {
				if golog.GetLevel() <= golog.LevelError {
					golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("read packet from backend err: %s", e.Error()), cc.connectionId)
				}
				return e
			}
			if d[0] == 'Z' {
				break
			}
		}
		cc.isInTrxPg = true
		cc.beginFlag = BEGIN_RELSTART_BEGIN
	} else {
		n, werr := writer.WriteBinary(data)
		if werr == nil && n != len(data) {
			werr = errors.ErrFormat("short write to backend: %d of %d bytes", n, len(data))
		}
		if werr != nil {
			// Report the write's own error; the outer err is nil at this point.
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "handleQueryPg", fmt.Sprintf("write msg err: %s", werr.Error()), cc.connectionId)
			}
			return werr
		}
		writer.Flush()
	}

	if err = cc.receiveBackendMsg(ctx, reader); err != nil {
		if golog.GetLevel() <= golog.LevelError {
			golog.Error(moduleName, "handleQueryPg", fmt.Sprintf("receiveBackend msg err: %s", err.Error()), cc.connectionId)
		}
		return err
	}
	return nil
}
//func slicePgMsg(msg []byte) (res [][]byte) {
// res = make([][]byte, 0)
// if len(msg) == 0 {
// return res
// }
// for len(msg) > 0 {
// msgLen := binary.BigEndian.Uint32(msg[1:5])
// res = append(res, msg[:1+msgLen])
// msg = msg[1+msgLen:]
// }
// return res
//}
//
// receive server connection msg, add deal with it
//
//func (cc *ClientConn) receiveBackendMsg(ctx context.Context) error {
// msg, err := cc.backendConn.Conn.ReadPgAllPacket()
// if err != nil {
// golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("read packet from backend err: %s", err.Error()), cc.connectionId)
// return err
// }
// golog.Trace(moduleName, "receiveBackendMsg", fmt.Sprintf("recv packet from backend msg type: %s", string(msg)), cc.connectionId)
// dataList := slicePgMsg(msg)
//readloop:
// for _, data := range dataList {
// // deal with copy msg
// if data[0] == 'G' || data[0] == 'W' {
// // in transaction
// cc.status &= ^mysql.SERVER_STATUS_AUTOCOMMIT
// cc.status |= mysql.SERVER_STATUS_IN_TRANS
// cc.dataRecv = append(cc.dataRecv, data...)
// cc.WriteData(cc.dataRecv)
// cc.dataRecv = make([]byte, 0)
// break readloop
// }
// if data[0] == 'H' {
// cc.status &= ^mysql.SERVER_STATUS_AUTOCOMMIT
// cc.status |= mysql.SERVER_STATUS_IN_TRANS
// cc.dataRecv = append(cc.dataRecv, data...)
// cc.WriteData(cc.dataRecv)
// cc.dataRecv = make([]byte, 0)
// continue
// }
// // add new protocol 'L' for read consistency
// if data[0] == 'L' {
// lsn := pgproto3.LsnResponse{}
// lsn.Decode(data[5:])
// addr := cc.backendConn.ConnPg.PgConn().Conn().RemoteAddr().String()
// golog.Debug("pg conn", "receiveBackendMsg", fmt.Sprintf("addr: %s, lsn: %d", addr, lsn.LSN), cc.connectionId)
// // set LSN to node
// if addr != "" {
// cc.nodes["node1"].NodeLsn.Store(strings.Split(addr, ":")[0], lsn.LSN)
// }
// // set LSN to db_table
// if cc.table != "" && cc.db != "" {
// cc.nodes["node1"].NodeLsn.Store(cc.db+"_"+cc.table, lsn.LSN)
// }
// continue
// }
//
// // deal with msg for readForQuery. return msg
// if data[0] == 'Z' {
// q := pgproto3.ReadyForQuery{}
// q.Decode(data[5:])
// // deal with 'begin-commit' statement, if begin-select will return 'T' for front,
// // means in transaction, actually backend not in transaction. Do sql with load balance
// if cc.beginFlag == BEGIN_PRESTART_COMMIT {
// data = (&pgproto3.ReadyForQuery{TxStatus: 'T'}).Encode(nil)
// } else if cc.beginFlag == BEGIN_RELSTART_BEGIN {
// if q.TxStatus == 'I' && !cc.alwaysCurNode {
// cc.status = mysql.SERVER_STATUS_AUTOCOMMIT
// } else {
// cc.status &= ^mysql.SERVER_STATUS_AUTOCOMMIT
// cc.status |= mysql.SERVER_STATUS_IN_TRANS
// }
// cc.beginFlag = BEGIN_COMMIT
// } else {
// if q.TxStatus == 'T' && !cc.isInTransaction() {
// cc.status &= ^mysql.SERVER_STATUS_AUTOCOMMIT
// cc.status |= mysql.SERVER_STATUS_IN_TRANS
// } else if q.TxStatus == 'I' && !cc.alwaysCurNode {
// //cc.status |= mysql.SERVER_STATUS_AUTOCOMMIT
// //cc.status &= ^mysql.SERVER_STATUS_IN_TRANS
// if cc.isInTransaction() {
// cc.status = mysql.SERVER_STATUS_AUTOCOMMIT
// }
// if cc.beginFlag != BEGIN_UNSTART {
// cc.beginFlag = BEGIN_UNSTART
// }
// }
// }
//
// cc.dataRecv = append(cc.dataRecv, data...)
// err = cc.WriteData(cc.dataRecv)
// if err != nil {
// golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("write data to backend err: %v", err), cc.connectionId)
// }
// cc.dataRecv = make([]byte, 0)
// break readloop
// }
// if data[0] == 'E' {
// golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("read err packet from backend: %s", string(data)), cc.connectionId)
// }
// cc.dataRecv = append(cc.dataRecv, data...)
// }
// return nil
//}
// receiveBackendMsg relays the backend's reply from reader to the client and updates the
// client's transaction tracking.
//
// With config.ReadOnly set, it polls reader.Len (sleeping a few microseconds between
// checks, and stopping if ctx is cancelled) until bytes are buffered. It then forwards
// exactly the bytes available at that moment in one write, without parsing them.
// NOTE(review): that path may forward an incomplete reply if the backend's response
// arrives in several chunks — confirm this is acceptable in read-only mode.
//
// Otherwise it reads packet by packet, accumulating them in cc.dataRecv:
//   - 'Z' (ReadyForQuery): adjusts isInTrxPg/beginFlag from the reported status (a
//     BEGIN_PRESTART_COMMIT session always reports 'T' to the client), writes the
//     accumulated data to the client and stops.
//   - 'L' (LSN response added for read consistency): stores the LSN under the backend
//     host and under "<db>_<table>" in node "node1"; not forwarded.
//   - 'E' (ErrorResponse): logged at warn level and forwarded with the rest.
//   - 'G', 'W' (copy): marks the session in-transaction, writes the data and stops.
//   - 'H' (copy): marks the session in-transaction, writes the data and keeps reading.
//
// Every other packet is accumulated and forwarded with the next write. The reader's
// buffer is released before returning nil.
func (cc *ClientConn) receiveBackendMsg(ctx context.Context, reader netpoll.Reader) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-receiveBackendMsg")
		var span trace.Span
		// Format the numeric connection ID; string(id) would turn it into a single rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("receiveBackendMsg %v", cc.connectionId))
		defer span.End()
	}
	cc.dataRecv = cc.dataRecv[:0]

	if config.ReadOnly {
		time.Sleep(time.Microsecond)
		size := reader.Len()
		for size == 0 {
			// Do not spin forever once the caller's context is cancelled.
			select {
			case <-ctx.Done():
				return ctx.Err()
			default:
			}
			time.Sleep(2 * time.Microsecond)
			size = reader.Len()
		}
		data, err := reader.Next(size)
		if err != nil {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("read packet from backend err: %s", err.Error()), cc.connectionId)
			}
			reader.Release()
			return err
		}
		// data points into the reader's buffer, so write it before Release.
		if err = cc.WriteData(data); err != nil {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("write data to client err: %v", err), cc.connectionId)
			}
		}
		reader.Release()
		return nil
	}

readloop:
	for {
		data, err := cc.backendConn.Conn.ReadPgPacket(reader)
		if err != nil {
			if golog.GetLevel() <= golog.LevelError {
				golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("read packet from backend err: %s", err.Error()), cc.connectionId)
			}
			return err
		}
		if golog.GetLevel() <= golog.LevelTrace {
			golog.Trace(moduleName, "receiveBackendMsg", fmt.Sprintf("recv packet from backend msg type: %s", string(data[0])), cc.connectionId)
		}

		switch data[0] {
		case 'Z':
			switch {
			case cc.beginFlag == BEGIN_PRESTART_COMMIT:
				// 'begin-commit' emulation: the client must see status 'T' although the
				// backend is not in a transaction, so reads can be load balanced.
				data = []byte{90, 0, 0, 0, 5, 84} // ReadyForQuery('T')
			case cc.beginFlag == BEGIN_RELSTART_BEGIN:
				cc.isInTrxPg = data[5] != 'I' || cc.alwaysCurNode
				cc.beginFlag = BEGIN_COMMIT
			case data[5] == 'T' && !cc.isInTrxPg:
				cc.isInTrxPg = true
			case data[5] == 'I' && !cc.alwaysCurNode:
				cc.isInTrxPg = false
				cc.beginFlag = BEGIN_UNSTART
			}
			cc.dataRecv = BytesCombine(cc.dataRecv, data)
			if err = cc.WriteData(cc.dataRecv); err != nil {
				if golog.GetLevel() <= golog.LevelError {
					golog.Error(moduleName, "receiveBackendMsg", fmt.Sprintf("write data to client err: %v", err), cc.connectionId)
				}
			}
			cc.dataRecv = cc.dataRecv[:0]
			break readloop
		case 'L':
			// New protocol 'L' for read consistency.
			lsn := pgproto3.LsnResponse{}
			lsn.Decode(data[5:])
			addr := cc.backendConn.ConnPg.RemoteAddr().String()
			if golog.GetLevel() <= golog.LevelDebug {
				golog.Debug("pg conn", "receiveBackendMsg", fmt.Sprintf("addr: %s, lsn: %d", addr, lsn.LSN), cc.connectionId)
			}
			// Skip the bookkeeping instead of panicking when "node1" is not configured.
			if node, ok := cc.nodes["node1"]; ok {
				// Set LSN for the node's host.
				if addr != "" {
					node.NodeLsn.Store(strings.Split(addr, ":")[0], lsn.LSN)
				}
				// Set LSN for db_table.
				if cc.table != "" && cc.db != "" {
					node.NodeLsn.Store(cc.db+"_"+cc.table, lsn.LSN)
				}
			}
			continue
		case 'E':
			if golog.GetLevel() <= golog.LevelWarn {
				golog.Warn(moduleName, "receiveBackendMsg", fmt.Sprintf("read err packet from backend: %s", string(data)), cc.connectionId)
			}
		case 'G', 'W':
			// Copy message: the session is now in a transaction.
			cc.isInTrxPg = true
			cc.dataRecv = BytesCombine(cc.dataRecv, data)
			cc.WriteData(cc.dataRecv)
			cc.dataRecv = cc.dataRecv[:0]
			break readloop
		case 'H':
			cc.isInTrxPg = true
			cc.dataRecv = BytesCombine(cc.dataRecv, data)
			cc.WriteData(cc.dataRecv)
			cc.dataRecv = cc.dataRecv[:0]
			continue
		}
		// TODO: all data is returned in one write for now; large results should be
		// flushed in batches once a size threshold is reached.
		cc.dataRecv = BytesCombine(cc.dataRecv, data)
	}
	reader.Release()
	return nil
}
// handleStmtClosePg handles a Close ('C') message of the extended query protocol.
//
// When close targets a prepared statement (ObjectType 'S'), the cached Parse with that
// name is dropped, whether or not a backend connection is held, so it is not replayed
// onto a later backend connection. Portals ('P') live in a separate namespace and never
// touch the cache. If a backend connection is held, the Close is forwarded and flushed,
// and a CloseComplete is written to the client; otherwise nothing is sent and nil is
// returned.
func (cc *ClientConn) handleStmtClosePg(ctx context.Context, close pgproto3.Close) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-handleStmtClosePg")
		// Format the numeric connection ID; string(id) would turn it into a single rune.
		_, span := tr.Start(ctx, fmt.Sprintf("handleStmtClosePg %v", cc.connectionId))
		defer span.End()
	}

	// Delete the Parse record of a closed statement (Delete is a no-op if absent).
	if close.ObjectType == 'S' {
		cc.Parse.Delete(close.Name)
	}

	if cc.backendConn == nil || cc.backendConn.Conn == nil {
		return nil
	}

	data := close.Encode(nil)
	wr := cc.backendConn.Conn.ConnPg.Writer()
	if n, err := wr.WriteBinary(data); err != nil || n != len(data) {
		if golog.GetLevel() <= golog.LevelError {
			golog.Error(moduleName, "handleStmtClosePg", fmt.Sprintf("write close to connection err: %v", err), cc.connectionId)
		}
	}
	wr.Flush()
	return cc.writeCloseComplete()
}
// writeCloseComplete acknowledges a client's Close request by sending it a
// CloseComplete message.
func (cc *ClientConn) writeCloseComplete() error {
	encoded := (&pgproto3.CloseComplete{}).Encode(nil)
	return cc.WriteData(encoded)
}
// handleStmtSyncPg forwards a Sync message (data) of pgsql's extended query
// protocol to the backend and relays the backend's replies to the client via
// receiveBackendMsg.
//
// If no backend connection is bound, one is re-established with
// preHandlePg("begin;"), the client is registered in clientConnMap and
// handleParsePrepare is called (presumably to replay cached Parse messages —
// confirm) before data is sent. Errors from reconnecting, from writing to the
// backend and from receiving its replies are returned.
func (cc *ClientConn) handleStmtSyncPg(ctx context.Context, data []byte) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-handleStmtSyncPg")
		var span trace.Span
		// Format connectionId in decimal; string(cc.connectionId) would yield a rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("handleStmtSyncPg %d", cc.connectionId))
		defer span.End()
	}
	if cc.backendConn == nil || cc.backendConn.Conn == nil {
		golog.Warn(moduleName, "handleStmtSyncPg", "backend connection is null, current send data is: "+string(data), cc.connectionId)
		var err error
		// TODO: the backend connection is reopened by running "begin;".
		cc.backendConn, err = cc.preHandlePg("begin;", ctx)
		if err != nil {
			golog.Error(moduleName, "handleStmtSyncPg", "reconnect backend err: "+err.Error(), cc.connectionId)
			return err
		}
		if cc.backendConn == nil || cc.backendConn.Conn == nil {
			// Without this guard the writer lookup below would panic.
			err = fmt.Errorf("reconnect backend returned no usable connection")
			golog.Error(moduleName, "handleStmtSyncPg", "reconnect backend err: "+err.Error(), cc.connectionId)
			return err
		}
		clientConnMap.Store(cc.connectionId, cc)
		cc.handleParsePrepare(ctx)
	}
	wr := cc.backendConn.Conn.ConnPg.Writer()
	n, err := wr.WriteBinary(data)
	if err == nil && n != len(data) {
		// A short write without an error used to reach err.Error() with a nil err.
		err = fmt.Errorf("short write to backend: wrote %d of %d bytes", n, len(data))
	}
	if err != nil {
		if golog.GetLevel() <= golog.LevelError {
			golog.Error(moduleName, "handleStmtSyncPg", fmt.Sprintf("write sync to connection err: %s", err.Error()), cc.connectionId)
		}
		// Waiting for a reply to a message that was not sent could block forever.
		return err
	}
	wr.Flush()
	if err = cc.receiveBackendMsg(ctx, cc.backendConn.Conn.ConnPg.Reader()); err != nil {
		golog.Error(moduleName, "handleStmtSyncPg", "recv backend msg err: "+err.Error(), cc.connectionId)
		return err
	}
	return nil
}
// handleCopy forwards a COPY sub-protocol message (data) from the client to the
// bound backend connection and then relays the backend's replies to the client
// via receiveBackendMsg.
//
// Without a backend connection it returns nil and does nothing. A failed or
// short write to the backend is logged and returned.
func (cc *ClientConn) handleCopy(ctx context.Context, data []byte) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-handleCopy")
		var span trace.Span
		// Format connectionId in decimal; string(cc.connectionId) would yield a rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("handleCopy %d", cc.connectionId))
		defer span.End()
	}
	if cc.backendConn == nil || cc.backendConn.Conn == nil {
		return nil
	}
	writer := cc.backendConn.Conn.ConnPg.Writer()
	n, err := writer.WriteBinary(data)
	if err == nil && n != len(data) {
		// A short write without an error used to reach err.Error() with a nil err.
		err = fmt.Errorf("short write to backend: wrote %d of %d bytes", n, len(data))
	}
	if err != nil {
		golog.Error(moduleName, "handleCopy", "write msg err: "+err.Error(), cc.connectionId)
		return err
	}
	writer.Flush()
	return cc.receiveBackendMsg(ctx, cc.backendConn.Conn.ConnPg.Reader())
}
// ReceiveStartupMessage receives the initial connection message. It is used
// instead of the normal Receive method because the initial connection message
// is "special": it does not start with a message-type byte. It returns a
// StartupMessage, SSLRequest, GSSEncRequest or CancelRequest.
//
// The length prefix comes from an unauthenticated client, so it is validated
// before the body is allocated. It must leave room for the 4-byte request code
// and may not exceed PostgreSQL's startup-packet limit. A malformed first
// packet therefore yields an error instead of a panic or a huge allocation.
func (cc *ClientConn) ReceiveStartupMessage() (pgproto3.FrontendMessage, error) {
	// Upper bound for the startup packet body, mirroring PostgreSQL's
	// MAX_STARTUP_PACKET_LENGTH.
	const maxStartupPacketLength = 10000

	header := make([]byte, 4)
	if _, err := io.ReadFull(cc.pkg.Rb, header); err != nil {
		return nil, err
	}
	// The length counts its own 4 bytes. Compute in int64 so the subtraction
	// cannot wrap around (as uint32 arithmetic did for lengths < 4).
	msgLen := int64(binary.BigEndian.Uint32(header)) - 4
	if msgLen < 4 || msgLen > maxStartupPacketLength {
		return nil, errors.ErrFormat("invalid length of startup packet: %d", msgLen)
	}
	msg := make([]byte, msgLen)
	if _, err := io.ReadFull(cc.pkg.Rb, msg); err != nil {
		return nil, err
	}
	code := binary.BigEndian.Uint32(msg)
	switch code {
	case protocolVersionNumber:
		startMessage := &pgproto3.StartupMessage{}
		if err := startMessage.Decode(msg); err != nil {
			return nil, err
		}
		return startMessage, nil
	case sslRequestNumber:
		sslRequest := &pgproto3.SSLRequest{}
		if err := sslRequest.Decode(msg); err != nil {
			return nil, err
		}
		return sslRequest, nil
	case cancelRequestCode:
		cancelRequest := &pgproto3.CancelRequest{}
		if err := cancelRequest.Decode(msg); err != nil {
			return nil, err
		}
		return cancelRequest, nil
	case gssEncReqNumber:
		gssEncRequest := &pgproto3.GSSEncRequest{}
		if err := gssEncRequest.Decode(msg); err != nil {
			return nil, err
		}
		return gssEncRequest, nil
	default:
		return nil, errors.ErrFormatStr("unknown startup message code: %s", fmt.Sprint(code))
	}
}
// loadSSLCertificates builds the server-side TLS configuration used when a
// client requests SSL.
//
// It loads the server key pair from server/certs/pgserver.crt and
// server/certs/pgserver.key, and the CA bundle from server/certs/pgroot.crt.
// Clients must present a certificate that verifies against that CA
// (tls.RequireAndVerifyClientCert).
//
// An error is returned when a file cannot be loaded or the CA file holds no
// PEM certificate. Callers therefore never receive a nil or certificate-less
// configuration, which would otherwise break the TLS handshake.
func loadSSLCertificates() (tlsConfig *tls.Config, err error) {
	const (
		certFile = "server/certs/pgserver.crt"
		keyFile  = "server/certs/pgserver.key"
		caFile   = "server/certs/pgroot.crt"
	)
	tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, fmt.Errorf("load X509 key pair failed: %w", err)
	}
	caCert, err := ioutil.ReadFile(caFile)
	if err != nil {
		return nil, fmt.Errorf("read ca file failed: %w", err)
	}
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("no PEM certificate found in ca file %s", caFile)
	}
	return &tls.Config{
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    certPool,
		Certificates: []tls.Certificate{tlsCert},
	}, nil
}
// handleSSLRequest answers a client's SSLRequest and then processes the
// StartupMessage that has to follow it.
//
// When protocolSSL is enabled, the certificates are loaded, 'S' is sent and the
// connection is upgraded to TLS. Otherwise 'N' is sent and the session stays in
// plain text. Anything other than a StartupMessage received afterwards is
// rejected with an error.
func (cc *ClientConn) handleSSLRequest(ctx context.Context) error {
	useSSL := protocolSSL

	var tlsConfig *tls.Config
	if useSSL {
		var err error
		if tlsConfig, err = loadSSLCertificates(); err != nil {
			return err
		}
	}

	// Tell the client whether SSL will be used: 'S' = yes, 'N' = no.
	answer := byte('N')
	if useSSL {
		answer = 'S'
	}
	if err := cc.writeSSLRequest(ctx, answer); err != nil {
		return err
	}
	if useSSL {
		// Switch the existing connection over to TLS.
		if err := cc.upgradeToTLS(tlsConfig); err != nil {
			return err
		}
	}

	// Once SSL negotiation is done, the client sends its real StartupMessage.
	received, err := cc.ReceiveStartupMessage()
	if err != nil {
		return err
	}
	startup, isStartup := received.(*pgproto3.StartupMessage)
	if !isStartup {
		return errors.ErrFormatStr("received is not a StartupMessage")
	}
	return cc.handleStartupMessage(ctx, startup)
}
// handleStartupMessage completes session setup for a client's StartupMessage.
//
// It records the user and database from the startup parameters; the database
// defaults to the user name. It exports the client's datestyle, timezone and
// options as PGDATESTYLE, PGTZ and PGOPTIONS, and sets HE3PROXY=true when
// config.He3Proxy is on. It then runs the HBA check and user authentication,
// and finally sends ParameterStatus, BackendKeyData and ReadyForQuery('I').
//
// During authentication the client (e.g. the SQL shell) may disconnect while it
// prompts the user for the password, and reconnect afterwards.
//
// NOTE(review): os.Setenv changes the environment of the whole process. These
// values are shared by all concurrent connections and persist for later
// clients that do not send the parameter. A per-connection driver config would
// avoid that.
func (cc *ClientConn) handleStartupMessage(ctx context.Context, startupMessage *pgproto3.StartupMessage) error {
	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-startupmsg")
		var span trace.Span
		// Format connectionId in decimal; string(cc.connectionId) would yield a rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("handleStartupMessage %d", cc.connectionId))
		defer span.End()
	}
	// get db username and connection database name
	cc.user = startupMessage.Parameters["user"]
	cc.db = startupMessage.Parameters["database"]
	datestyle := "ISO, MDY"
	timezone := "PRC"
	// set these envs, need pg driver pgconn support
	// github.com/jackc/pgconn/config.go
	if key, ok := startupMessage.Parameters["datestyle"]; ok {
		datestyle = key
		os.Setenv("PGDATESTYLE", key)
	}
	if key, ok := startupMessage.Parameters["timezone"]; ok {
		timezone = key
		os.Setenv("PGTZ", key)
	}
	if key, ok := startupMessage.Parameters["options"]; ok {
		os.Setenv("PGOPTIONS", key)
	}
	// set flag, means connection request by he3proxy
	// the flag default is on, set env 'HE3PROXY_FLAG' to false if you want close it.
	// connection env 'HE3PROXY' need support by db engine
	// if you use original postgres, please close this flag or will encounter an error like:
	// FATAL: unrecognized configuration parameter "he3proxy" (SQLSTATE 42704)
	if config.He3Proxy {
		os.Setenv("HE3PROXY", "true")
	}
	if cc.db == "" {
		cc.db = cc.user
	}
	// init session and do user auth.
	if err := cc.PgOpenSessionAndDoAuth(ctx); err != nil {
		return err
	}
	// According to pgpool. default server version 9.0.0
	// Or 14 Navicat will occurs error as below.
	/*
		SELECT d.oid, d.datname AS databasename, pg_get_userbyid(d.datdba) AS databaseowner, des.description, d.datpath,
		d.encoding, pg_encoding_to_char(d.encoding) AS encodingname FROM pg_database d
		LEFT JOIN pg_description des ON des.objoid = d.oid;
		ERROR: column d.datpath does not exist
		LINE 1: ...byid(d.datdba) AS databaseowner, des.description, d.datpath,...
		^
		HINT: Perhaps you meant to reference the column "d.datname" or the column "d.datacl".
	*/
	serverVersion := config.ServerVersion
	if serverVersion == "" {
		serverVersion = "9.0.0"
	}
	parameters := map[string]string{
		"client_encoding":               "UTF8",
		"DateStyle":                     datestyle,
		"integer_datetimes":             "on",
		"is_superuser":                  "on",
		"server_encoding":               "UTF8",
		"server_version":                serverVersion,
		"TimeZone":                      timezone,
		"standard_conforming_strings":   "on",
		"default_transaction_read_only": "off",
		"in_hot_standby":                "off",
		"IntervalStyle":                 "postgres",
		"session_authorization":         cc.user,
		"application_name":              startupMessage.Parameters["application_name"],
	}
	// send ParameterStatus
	if err := cc.writeParameterStatus(parameters); err != nil {
		return err
	}
	// send BackendKeyData
	if err := cc.writeBackendKeyData(ctx, cc.connectionId); err != nil {
		return err
	}
	// send ReadyForQuery "I"
	if err := cc.writeReadyForQuery(ctx, 'I'); err != nil {
		return err
	}
	if golog.GetLevel() <= golog.LevelInfo {
		golog.Info(moduleName, "handleStartupMessage",
			fmt.Sprintf("%s connection succeeded", cc.c.RemoteAddr().String()), cc.connectionId)
	}
	cc.pkg.Sequence = 0
	return nil
}
// writeReadyForQuery sends a ReadyForQuery message to the client.
//
// status is the transaction status of the current backend session:
// 'I' idle (not in a transaction), 'T' inside a transaction block, and
// 'E' inside a failed transaction block (queries are rejected until the
// block ends).
func (cc *ClientConn) writeReadyForQuery(ctx context.Context, status byte) error {
	msg := pgproto3.ReadyForQuery{TxStatus: status}
	return cc.WriteData(msg.Encode(nil))
}
// writeBackendKeyData sends the client a BackendKeyData message that carries
// pid as the process ID.
//
// NOTE(review): the secret key is the same constant for every connection. If
// cancel requests are ever verified against it, a per-connection random key
// would be needed.
func (cc *ClientConn) writeBackendKeyData(ctx context.Context, pid uint32) error {
	keyData := pgproto3.BackendKeyData{
		ProcessID: pid,
		SecretKey: 2204724030,
	}
	return cc.WriteData(keyData.Encode(nil))
}
// writeParameterStatus sends one ParameterStatus message per entry of
// parameters (name -> value) to the client. pgAdmin requires
// 'client_encoding' to be among them.
//
// The first write failure stops the loop and is returned with context.
func (cc *ClientConn) writeParameterStatus(parameters map[string]string) error {
	for name, value := range parameters {
		status := pgproto3.ParameterStatus{Name: name, Value: value}
		if err := cc.WriteData(status.Encode(nil)); err != nil {
			// The format string needs a verb; without one the error text was
			// rendered as "%!(EXTRA string=...)".
			return errors.ErrFormat("write ParameterStatus to client failed: %s", err.Error())
		}
	}
	return nil
}
// PgOpenSessionAndDoAuth applies the HBA rules to the connecting client and
// authenticates it.
//
// Unlike MySQL, where a client with a password sends it right away, a PgSQL
// client sends its password only after the server asks for it. The matched
// HBA method decides the outcome:
//   - "reject": the login is refused;
//   - "trust":  database access is checked and AuthenticationOk is sent;
//   - "md5":    database access is checked and DoAuth runs the MD5 exchange;
//   - anything else is reported as unsupported.
func (cc *ClientConn) PgOpenSessionAndDoAuth(ctx context.Context) error {
	authConf := GetAuthenticationConfiguration()
	// Look up the method from the HBA configuration.
	_, hbaEntry, err := cc.lookupAuthenticationMethodUsingRules(hba.ConnHostNoSSL, authConf)
	if err != nil {
		return err
	}

	switch method := hbaEntry.Method.String(); method {
	case "reject":
		return errors.ErrFormat("user %s is not allowed to login.", cc.user)
	case "trust":
		if accessErr := cc.isAceessDB(hbaEntry); accessErr != nil {
			return accessErr
		}
		return cc.writeAuthenticationOK(ctx)
	case "md5":
		if accessErr := cc.isAceessDB(hbaEntry); accessErr != nil {
			return accessErr
		}
		return cc.DoAuth(ctx, make([]byte, 0))
	default:
		return errors.ErrFormat("Unsupport auth method %s ", method)
	}
}
// DoAuth performs PostgreSQL MD5 password authentication for cc.user.
//
// The client never sends a password unprompted. The server first sends
// AuthenticationMD5Password with a 4-byte salt (the first four bytes of
// cc.salt) and then reads the client's PasswordMessage. The password carried in
// that message must equal "md5" + hex(md5(rolpassword without its "md5"
// prefix + salt)), where rolpassword is read for cc.user from the first node's
// master. On success AuthenticationOk is sent to the client.
//
// The auth argument is not read; it is overwritten with the response taken
// from the client's PasswordMessage.
func (cc *ClientConn) DoAuth(ctx context.Context, auth []byte) error {
	salt := [4]byte{cc.salt[0], cc.salt[1], cc.salt[2], cc.salt[3]}
	authRequest := pgproto3.AuthenticationMD5Password{Salt: salt}
	// AuthenticationCleartextPassword would make the client send the password unencrypted.
	//authRequest := pgproto3.AuthenticationCleartextPassword{}
	if err := cc.WriteData(authRequest.Encode(nil)); err != nil {
		return errors.ErrFormat("write AuthenticationMD5Password to client failed: %s", err.Error())
	}
	// After the auth request the SQL shell may drop the connection and
	// reconnect once the user has typed the password; the read error (e.g. EOF)
	// then ends this attempt.
	msg, err := cc.readPacketPg(ctx)
	if err != nil {
		return err
	}
	// readPacketPg returns the whole message:
	// 'p' + int32 length + password + terminating NUL.
	if len(msg) == 0 || msg[0] != 'p' {
		got := "<empty>"
		if len(msg) > 0 {
			got = string(msg[0])
		}
		// Report the received type byte; the old code indexed the (empty) auth
		// argument here and panicked.
		return errors.ErrFormat("received is not a password packet: %s", got)
	}
	if len(msg) < 6 {
		return errors.ErrFormat("received password packet is too short: %d bytes", len(msg))
	}
	// Strip the type byte, the 4-byte length and the trailing NUL. Keeping the
	// header (as before) made the comparison below always fail.
	auth = msg[5 : len(msg)-1]
	rolpassword, err := getRolPwdFromDB(cc.proxy.cfg.Nodes[0].Master, cc.user)
	if err != nil || !strings.HasPrefix(rolpassword, "md5") {
		return errors.ErrFormat("%s@%s login failed, password not correct or user not exist! time: [%s]",
			cc.user, cc.c.RemoteAddr().String(), time.Now())
	}
	// The stored md5 hash is salted again with the 4 bytes sent to the client.
	expected := "md5" + fmt.Sprintf("%x", md5.Sum([]byte(strings.TrimPrefix(rolpassword, "md5")+string(salt[:]))))
	if expected != string(auth) {
		golog.Error("server", "IsAllowConnect", "error", mysql.ER_ACCESS_DENIED_ERROR,
			"ip address", cc.c.RemoteAddr().String(), " access denied by He3-Proxy.")
		return errors.ErrFormat("%s@%s login failed, password incorrect! time: [%s]", cc.user, cc.c.RemoteAddr().String(), time.Now())
	}
	return cc.writeAuthenticationOK(ctx)
}
// getRolPwdFromDB returns rolname's rolpassword from pg_authid, queried through
// the postgres database of the server at addr.
//
// Adding the he3proxy deployment host to the backend's hba trust list lets
// this connection work without a password field.
//
// NOTE(review): the connection URL embeds a hard-coded superuser password.
// This is a credential in source control and should move to configuration.
func getRolPwdFromDB(addr string, rolname string) (string, error) {
	urlExample := "postgres://postgres:vjC2T7r!6Amf54QZ@" + addr + "/postgres"
	rolpassword := ""
	conn, err := pgx.Connect(context.Background(), urlExample)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)
		return rolpassword, err
	}
	// Defer Close only after a successful connect: on failure conn is nil, and
	// the previous deferred Close panicked on it.
	defer conn.Close(context.Background())
	err = conn.QueryRow(context.Background(), "select rolpassword from pg_authid where rolname=$1", rolname).Scan(&rolpassword)
	return rolpassword, err
}
// readPacketPg reads one complete frontend message of the PostgreSQL protocol
// from the client: the type byte, the int32 length and the body.
//
// The returned slice contains the whole message, header included. It aliases
// cc.dataRecv, so it is only valid until that buffer is reused.
//
// The length field is validated before allocating. It must be at least 4 (it
// counts itself) and at most roughly 1 GiB, PostgreSQL's own large-message
// bound. This stops a client from triggering huge allocations or a uint32
// wraparound.
func (cc *ClientConn) readPacketPg(ctx context.Context) ([]byte, error) {
	const maxPgMessageLen = 0x3fffffff

	if config.Opentracing {
		// Use the global TracerProvider.
		tr := otel.Tracer("component-readPacketPg")
		var span trace.Span
		// Format connectionId in decimal; string(cc.connectionId) would yield a rune.
		ctx, span = tr.Start(ctx, fmt.Sprintf("readPacketPg %d", cc.connectionId))
		defer span.End()
	}
	// Peek blocks until 5 bytes are buffered and reports an error whenever it
	// returns fewer, so no retry loop is needed.
	header, err := cc.pkg.Rb.Peek(5)
	if err != nil {
		return nil, err
	}
	bodyLen := int64(binary.BigEndian.Uint32(header[1:]))
	if bodyLen < 4 || bodyLen > maxPgMessageLen {
		return nil, errors.ErrFormat("invalid length %d for message type %q", bodyLen, header[0])
	}
	// +1 for the type byte, which the length field does not count.
	msgLen := int(bodyLen) + 1
	if msgLen > cap(cc.dataRecv) {
		cc.dataRecv = make([]byte, msgLen)
	} else {
		cc.dataRecv = cc.dataRecv[:msgLen]
	}
	if _, err = io.ReadFull(cc.pkg.Rb, cc.dataRecv); err != nil {
		return nil, err
	}
	return cc.dataRecv, nil
}
// writeAuthenticationOK tells the client that authentication succeeded by
// sending an AuthenticationOk message.
func (cc *ClientConn) writeAuthenticationOK(ctx context.Context) error {
	var okMsg pgproto3.AuthenticationOk
	return cc.WriteData(okMsg.Encode(nil))
}
// writePgErr reports an error to the client. It sends an ErrorResponse with
// severity "ERROR", the given error code and message, followed by
// ReadyForQuery('I').
func (cc *ClientConn) writePgErr(ctx context.Context, code string, errmsg string) error {
	//TODO The error needs to be returned according to the error code.
	resp := pgproto3.ErrorResponse{
		Severity: "ERROR",
		Code:     code,
		Message:  errmsg,
	}
	if err := cc.WriteData(resp.Encode(nil)); err != nil {
		return err
	}
	return cc.writeReadyForQuery(ctx, 'I')
}
// writeSSLRequest answers a client's SSLRequest with a single byte:
// 'S' agrees to an SSL handshake, 'N' declines SSL.
func (cc *ClientConn) writeSSLRequest(ctx context.Context, pgRequestSSL byte) error {
	reply := []byte{pgRequestSSL}
	return cc.WriteData(reply)
}
// WriteData writes data to the client through cc.pkg.Wb.
//
// On a write error the error is logged (with the payload), the client
// connection is closed and mysql.ErrBadConn is returned. A short write also
// yields mysql.ErrBadConn.
func (cc *ClientConn) WriteData(data []byte) error {
	n, err := cc.pkg.Wb.Write(data)
	switch {
	case err != nil:
		if golog.GetLevel() <= golog.LevelError {
			// This is the client-side writer; the message used to say "backend".
			golog.Error(moduleName, "WriteData", fmt.Sprintf("write data to client err: %v", err), cc.connectionId, "data", string(data))
		}
		cc.Close()
		return mysql.ErrBadConn
	case n != len(data):
		return mysql.ErrBadConn
	default:
		return nil
	}
}
// upgradeToTLS performs the server side of a TLS handshake on the client
// connection using tlsConfig. On success it routes all further client I/O
// through the TLS connection by replacing cc.c, cc.pkg.Rb, cc.pkg.Wb and
// cc.tlsConn. A handshake error is returned and none of those fields change.
func (cc *ClientConn) upgradeToTLS(tlsConfig *tls.Config) error {
	// Important: read from the buffered reader instead of the original net.Conn
	// because it may already contain data we need.
	tlsConn := tls.Server(&bufferedReadConn{Conn: cc.c, r: cc.pkg.Rb}, tlsConfig)
	if err := tlsConn.Handshake(); err != nil {
		return err
	}
	cc.c = tlsConn
	// Both directions must now go through TLS. Previously the reader stayed on
	// the raw socket, and the writer wrapped the old plain-text writer in a
	// bufio.Writer that nothing flushed, so post-handshake traffic bypassed TLS.
	cc.pkg.Rb = bufio.NewReader(tlsConn)
	// NOTE(review): assumes cc.pkg.Wb is an io.Writer used unbuffered, since
	// WriteData never flushes it.
	cc.pkg.Wb = tlsConn
	cc.tlsConn = tlsConn
	return nil
}

// bufferedReadConn is a net.Conn whose reads are served from r instead of the
// embedded Conn. Writes, Close, addresses and deadlines go to the embedded Conn.
type bufferedReadConn struct {
	net.Conn
	r io.Reader
}

// Read reads from the buffered reader, so bytes it already holds are not lost.
func (c *bufferedReadConn) Read(p []byte) (int, error) {
	return c.r.Read(p)
}
// lookupAuthenticationMethodUsingRules scans the HBA entries in auth in order.
// It returns the first one whose connection type and address match the client
// (the client IP is only consulted for non-local connections) and whose user
// list matches cc.user, together with that entry's method info.
//
// It fails when the client address is not a TCP address, when an entry's
// address match reports an error, or when no entry matches.
func (cc *ClientConn) lookupAuthenticationMethodUsingRules(
	connType hba.ConnType, auth *hba.Conf,
) (mi methodInfo, entry *hba.Entry, err error) {
	var clientIP net.IP
	if connType != hba.ConnLocal {
		// Extract the IP address of the client.
		tcpAddr, isTCP := cc.c.RemoteAddr().(*net.TCPAddr)
		if !isTCP {
			return mi, entry, errors.ErrFormat("client address type %T unsupported", cc.c.RemoteAddr())
		}
		clientIP = tcpAddr.IP
	}

	for i := 0; i < len(auth.Entries); i++ {
		entry = &auth.Entries[i]
		matched, matchErr := entry.ConnMatches(connType, clientIP)
		if matchErr != nil {
			// TODO(knz): Determine if an error should be reported
			// upon unknown address formats.
			// See: https://github.com/cockroachdb/cockroach/issues/43716
			return mi, entry, matchErr
		}
		// The user list is only checked once the address matched.
		if matched && entry.UserMatches(hba.MakeSQLUsernameFromPreNormalizedString(cc.user)) {
			return entry.MethodFn.(methodInfo), entry, nil
		}
	}

	// No entry matched.
	return mi, entry, errors.ErrFormat("no %s entry for host %q, user %q", *HbaConfigFile, clientIP, cc.user)
}
// isAceessDB checks whether the matched HBA entry allows cc.user to access
// cc.db. It does so when the entry's database list contains "all" or cc.db;
// otherwise an error is returned.
func (cc *ClientConn) isAceessDB(hbaEntry *hba.Entry) error {
	allowed := false
	for _, dbName := range hbaEntry.Database {
		if dbName.Value == cc.db || dbName.Value == "all" {
			allowed = true
			break
		}
	}
	if !allowed {
		return errors.ErrFormat("user %s is not allowed to access db %s.", cc.user, cc.db)
	}
	return nil
}
// stringTobyteSlice returns a []byte that shares s's memory without copying;
// len and cap both equal len(s). The result must be treated as read-only:
// Go strings are immutable, and writing through the slice is undefined
// behavior.
func stringTobyteSlice(s string) []byte {
	// Keep the data pointer typed as unsafe.Pointer so the garbage collector
	// keeps tracking it. The previous version copied it into a [3]uintptr,
	// which the unsafe.Pointer rules do not permit.
	type stringHeader struct {
		data unsafe.Pointer
		n    int
	}
	type sliceHeader struct {
		data unsafe.Pointer
		n    int
		c    int
	}
	sh := (*stringHeader)(unsafe.Pointer(&s))
	bh := sliceHeader{data: sh.data, n: sh.n, c: sh.n}
	return *(*[]byte)(unsafe.Pointer(&bh))
}
1
https://gitee.com/he3db/he3proxy.git
git@gitee.com:he3db/he3proxy.git
he3db
he3proxy
He3Proxy
v1.0.1

搜索帮助

53164aa7 5694891 3bd8fe86 5694891