90 Star 492 Fork 151

平凯星辰(北京)科技有限公司/tidb

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
server.go 11.94 KB
一键复制 编辑 原始数据 按行查看 历史
// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
// For pprof
_ "net/http/pprof"
"sync"
"sync/atomic"
"time"
"github.com/blacktear23/go-proxyprotocol"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
var (
	// baseConnID is package-level connection-ID state. It is not referenced in
	// this file — presumably used by the connection-ID allocator elsewhere in
	// the package (TODO confirm in conn.go).
	baseConnID uint32

	// Errors raised by the MySQL protocol layer; all belong to terror.ClassServer.
	errUnknownFieldType  = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
	errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
	errInvalidSequence   = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
	errInvalidType       = terror.ClassServer.New(codeInvalidType, "invalid type")
	errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
	// errAccessDenied reuses MySQL's own message text for ER_ACCESS_DENIED_ERROR.
	errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
)
// defaultCapability is the set of MySQL protocol capability flags the server
// advertises when it is created with the default configuration. When TLS is
// configured, NewServer adds mysql.ClientSSL on top of these flags.
const defaultCapability = mysql.ClientProtocol41 |
	mysql.ClientLongPassword |
	mysql.ClientLongFlag |
	mysql.ClientConnectWithDB |
	mysql.ClientSecureConnection |
	mysql.ClientPluginAuth |
	mysql.ClientConnectAtts |
	mysql.ClientTransactions |
	mysql.ClientFoundRows |
	mysql.ClientLocalFiles |
	mysql.ClientMultiStatements |
	mysql.ClientMultiResults
// Server is the MySQL protocol server.
//
// It owns the listening socket, the registry of live client connections and
// the optional status HTTP server.
type Server struct {
	// cfg is the configuration the server was created with.
	cfg *config.Config
	// tlsConfig is built by loadTLSCertificates; it stays nil when TLS is not
	// configured or the certificates could not be loaded.
	tlsConfig *tls.Config
	driver    IDriver
	// listener accepts MySQL client connections. It is set in NewServer and
	// reset to nil by Close and by Run after a stop request.
	listener net.Listener
	// rwlock guards clients; Close also holds it while tearing down listener
	// and statusServer.
	rwlock *sync.RWMutex
	// concurrentLimiter hands out tokens through getToken/releaseToken.
	concurrentLimiter *TokenLimiter
	// clients maps connection IDs to the connections registered by onConn.
	clients map[uint32]*clientConn
	// capability is the capability flag set advertised to clients:
	// defaultCapability, plus mysql.ClientSSL when TLS is enabled.
	capability uint32
	// stopListenerCh is used when a critical error occurs. We don't want to exit the process, because
	// a supervisor may restart it automatically; new client connections would then be created that we cannot serve.
	// So we just stop the listener and store, forcing clients to choose other TiDB servers.
	// Run checks it (non-blocking) after every accepted connection.
	stopListenerCh chan struct{}
	// statusServer is the status HTTP server; Close shuts it down when non-nil.
	// NOTE(review): presumably assigned by startStatusHTTP — confirm.
	statusServer *http.Server
}
// ConnectionCount returns the number of client connections currently
// registered in s.clients, read under the shared lock.
func (s *Server) ConnectionCount() int {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	return len(s.clients)
}
// getToken acquires a token from the concurrency limiter, which may have to
// wait. The time spent waiting is recorded in microseconds in
// metrics.GetTokenDurationHistogram.
func (s *Server) getToken() *Token {
	begin := time.Now()
	token := s.concurrentLimiter.Get()
	// Integer division drops anything below one microsecond; such short waits
	// are effectively non-blocking and are deliberately recorded as 0.
	waitedMicros := time.Since(begin).Nanoseconds() / 1e3
	metrics.GetTokenDurationHistogram.Observe(float64(waitedMicros))
	return token
}
// releaseToken returns token, previously obtained via getToken, to the
// server's concurrency limiter.
func (s *Server) releaseToken(token *Token) {
	limiter := s.concurrentLimiter
	limiter.Put(token)
}
// newConn wraps the accepted net.Conn in a new *clientConn.
//
// When TCP keep-alive is enabled in the performance config and conn is a TCP
// connection, keep-alive is switched on; a failure to do so is only logged.
// The connection also gets 20 bytes of random salt for the authentication
// handshake. (Connection-ID allocation presumably happens in newClientConn —
// confirm.)
func (s *Server) newConn(conn net.Conn) *clientConn {
	cc := newClientConn(s)

	tcpConn, isTCP := conn.(*net.TCPConn)
	if isTCP && s.cfg.Performance.TCPKeepAlive {
		if keepAliveErr := tcpConn.SetKeepAlive(true); keepAliveErr != nil {
			log.Error("failed to set tcp keep alive option:", keepAliveErr)
		}
	}

	const saltLen = 20
	cc.setConn(conn)
	cc.salt = util.RandomBuf(saltLen)
	return cc
}
// skipAuth reports true when the server is configured with a Unix socket path
// (cfg.Socket non-empty); as the name says, callers skip authentication then.
func (s *Server) skipAuth() bool {
	socketPath := s.cfg.Socket
	return len(socketPath) > 0
}
// NewServer creates a new Server.
//
// The server listens on the Unix socket cfg.Socket when it is set, and on
// cfg.Host:cfg.Port over TCP otherwise. When cfg.ProxyProtocol.Networks is
// set, the listener is wrapped so that PROXY protocol headers are accepted
// from those networks. TLS is configured from cfg.Security, and ClientSSL is
// advertised only when TLS was set up successfully.
//
// On error no listener is left open.
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
	s := &Server{
		cfg:               cfg,
		driver:            driver,
		concurrentLimiter: NewTokenLimiter(cfg.TokenLimit),
		rwlock:            &sync.RWMutex{},
		clients:           make(map[uint32]*clientConn),
		stopListenerCh:    make(chan struct{}, 1),
	}
	s.loadTLSCertificates()

	s.capability = defaultCapability
	if s.tlsConfig != nil {
		s.capability |= mysql.ClientSSL
	}

	var err error
	if cfg.Socket != "" {
		if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
			log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket)
		}
	} else {
		addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
		if s.listener, err = net.Listen("tcp", addr); err == nil {
			log.Infof("Server is running MySQL Protocol at [%s]", addr)
		}
	}
	// Check the listen error before anything else: the PROXY protocol wrapper
	// below must never be built around a nil listener.
	if err != nil {
		return nil, errors.Trace(err)
	}

	if cfg.ProxyProtocol.Networks != "" {
		pplistener, errProxy := proxyprotocol.NewListener(s.listener, cfg.ProxyProtocol.Networks,
			int(cfg.ProxyProtocol.HeaderTimeout))
		if errProxy != nil {
			log.Error("ProxyProtocol Networks parameter invalid")
			// The Server is never returned, so nobody else could close the socket
			// bound above; release it instead of leaking it (and the port).
			terror.Log(errors.Trace(s.listener.Close()))
			return nil, errors.Trace(errProxy)
		}
		log.Infof("Server is running MySQL Protocol (through PROXY Protocol) at [%s]", s.cfg.Host)
		s.listener = pplistener
	}

	// Init rand seed for randomBuf()
	rand.Seed(time.Now().UTC().UnixNano())
	return s, nil
}
// loadTLSCertificates builds s.tlsConfig from s.cfg.Security.
//
// TLS is enabled only when both SSLCert and SSLKey are set and the key pair
// loads; otherwise s.tlsConfig is left nil. When SSLCA is set and the file
// contains at least one PEM certificate, clients that present a certificate
// have it verified (tls.VerifyClientCertIfGiven). The have_openssl, have_ssl,
// ssl_cert and ssl_key system variables are updated when TLS ends up enabled;
// ssl_ca is updated whenever the CA file could be read.
func (s *Server) loadTLSCertificates() {
	defer func() {
		if s.tlsConfig != nil {
			// Derive the reported state from the policy actually in effect. The
			// ssl_ca sysvar is not a reliable signal: it is set even when the CA
			// file held no usable certificate, and it is global state that may
			// carry a value from an earlier call.
			clientVerify := s.tlsConfig.ClientAuth != tls.NoClientCert
			log.Infof("Secure connection is enabled (client verification enabled = %v)", clientVerify)
			variable.SysVars["have_openssl"].Value = "YES"
			variable.SysVars["have_ssl"].Value = "YES"
			variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert
			variable.SysVars["ssl_key"].Value = s.cfg.Security.SSLKey
		} else {
			log.Warn("Secure connection is NOT ENABLED")
		}
	}()

	if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 {
		s.tlsConfig = nil
		return
	}

	tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey)
	if err != nil {
		log.Warn(errors.ErrorStack(err))
		s.tlsConfig = nil
		return
	}

	// Try loading CA cert.
	clientAuthPolicy := tls.NoClientCert
	var certPool *x509.CertPool
	if len(s.cfg.Security.SSLCA) > 0 {
		caCert, err := ioutil.ReadFile(s.cfg.Security.SSLCA)
		if err != nil {
			log.Warn(errors.ErrorStack(err))
		} else {
			certPool = x509.NewCertPool()
			if certPool.AppendCertsFromPEM(caCert) {
				clientAuthPolicy = tls.VerifyClientCertIfGiven
			} else {
				// Previously silent: a CA file without any parsable PEM
				// certificate leaves client verification disabled.
				log.Warnf("no valid certificate found in CA file %s, client certificates will not be verified", s.cfg.Security.SSLCA)
			}
			variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA
		}
	}

	s.tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{tlsCert},
		ClientCAs:    certPool,
		ClientAuth:   clientAuthPolicy,
		// 0 leaves the minimum protocol version at crypto/tls's default.
		MinVersion: 0,
	}
}
// Run runs the server: it starts the status HTTP API if enabled, then
// accepts MySQL client connections and serves each on its own goroutine.
//
// Run returns nil once the listener reports "use of closed network
// connection" (e.g. after Close). PROXY protocol errors are logged and
// skipped, and any other accept error is logged and returned. If a stop is
// requested through stopListenerCh, Run closes the listener and then never
// returns, logging once a minute until the process is killed manually.
func (s *Server) Run() error {
	metrics.ServerEventCounter.WithLabelValues(metrics.EventStart).Inc()

	// Start HTTP API to report tidb info such as TPS.
	if s.cfg.Status.ReportStatus {
		s.startStatusHTTP()
	}

	for {
		conn, acceptErr := s.listener.Accept()
		if acceptErr != nil {
			opErr, isOpErr := acceptErr.(*net.OpError)
			switch {
			case isOpErr && opErr.Err.Error() == "use of closed network connection":
				// The listener was closed: normal shutdown.
				return nil
			case proxyprotocol.IsProxyProtocolError(acceptErr):
				// A bad PROXY header affects only that client; keep accepting.
				log.Errorf("PROXY protocol error: %s", acceptErr.Error())
				continue
			default:
				log.Errorf("accept error %s", acceptErr.Error())
				return errors.Trace(acceptErr)
			}
		}

		if s.shouldStopListener() {
			terror.Log(errors.Trace(conn.Close()))
			break
		}
		go s.onConn(conn)
	}

	terror.Log(errors.Trace(s.listener.Close()))
	s.listener = nil
	// Stay alive instead of exiting so a supervisor does not restart the
	// process; see the stopListenerCh field comment.
	for {
		metrics.ServerEventCounter.WithLabelValues(metrics.EventHang).Inc()
		log.Errorf("listener stopped, waiting for manual kill.")
		time.Sleep(time.Minute)
	}
}
// shouldStopListener reports, without blocking, whether a value can be
// received from stopListenerCh. A successful check consumes that value.
func (s *Server) shouldStopListener() bool {
	select {
	case <-s.stopListenerCh:
		return true
	default:
	}
	return false
}
// Close closes the server.
//
// While holding the write lock it closes the MySQL listener and the status
// HTTP server (whichever are non-nil), logging any close error and clearing
// the fields. It then counts a close event.
func (s *Server) Close() {
	s.rwlock.Lock()
	defer s.rwlock.Unlock()

	if l := s.listener; l != nil {
		terror.Log(errors.Trace(l.Close()))
		s.listener = nil
	}

	if srv := s.statusServer; srv != nil {
		terror.Log(errors.Trace(srv.Close()))
		s.statusServer = nil
	}

	metrics.ServerEventCounter.WithLabelValues(metrics.EventClose).Inc()
}
// onConn runs in its own goroutine and handles queries from this connection.
//
// A failed handshake only bumps HandShakeErrorCounter and closes the socket:
// some keep-alive probes connect and disconnect immediately, so this is not
// logged. After a successful handshake the connection is registered in
// s.clients, the connection gauge is updated, and the connection's Run loop
// takes over. (Removal from s.clients is not done here — presumably by
// clientConn's close path; confirm.)
func (s *Server) onConn(c net.Conn) {
	cc := s.newConn(c)
	if handshakeErr := cc.handshake(); handshakeErr != nil {
		metrics.HandShakeErrorCounter.Inc()
		terror.Log(errors.Trace(c.Close()))
		return
	}

	log.Infof("con:%d new connection %s", cc.connectionID, c.RemoteAddr().String())
	defer func() {
		log.Infof("con:%d close connection", cc.connectionID)
	}()

	s.rwlock.Lock()
	s.clients[cc.connectionID] = cc
	total := len(s.clients)
	s.rwlock.Unlock()
	metrics.ConnGauge.Set(float64(total))

	cc.Run()
}
// ShowProcessList implements the SessionManager interface.
//
// It returns the process info of every registered connection, keyed by the
// info's ID. Connections already marked connStatusWaitShutdown are omitted.
func (s *Server) ShowProcessList() map[uint64]util.ProcessInfo {
	s.rwlock.RLock()
	infos := make(map[uint64]util.ProcessInfo, len(s.clients))
	for _, cc := range s.clients {
		if atomic.LoadInt32(&cc.status) != connStatusWaitShutdown {
			info := cc.ctx.ShowProcess()
			infos[info.ID] = info
		}
	}
	s.rwlock.RUnlock()
	return infos
}
// Kill implements the SessionManager interface.
//
// It invokes the registered cancel function of the connection identified by
// connectionID, if any (presumably cancelling its running statement). When
// query is false it also marks the connection connStatusWaitShutdown so its
// dispatch loop exits. Unknown IDs, including IDs that do not fit in the
// uint32 key space of s.clients, are ignored.
func (s *Server) Kill(connectionID uint64, query bool) {
	s.rwlock.Lock()
	defer s.rwlock.Unlock()
	log.Infof("[server] Kill connectionID %d, query %t", connectionID, query)
	metrics.ServerEventCounter.WithLabelValues(metrics.EventKill).Inc()

	// Connections are keyed by uint32. A plain conversion would silently
	// truncate larger IDs and could kill an unrelated connection
	// (e.g. 1<<32+1 would become 1), so such IDs never match.
	id := uint32(connectionID)
	if uint64(id) != connectionID {
		return
	}

	conn, ok := s.clients[id]
	if !ok {
		return
	}

	conn.mu.RLock()
	cancelFunc := conn.mu.cancelFunc
	conn.mu.RUnlock()
	if cancelFunc != nil {
		cancelFunc()
	}

	if !query {
		// Mark the client connection status as WaitShutdown; when the goroutine
		// detects this, it ends the dispatch loop and exits.
		atomic.StoreInt32(&conn.status, connStatusWaitShutdown)
	}
}
// GracefulDown waits for all clients to close.
//
// About once per second it asks idle connections to go away (see
// kickIdleConnection) and re-counts. Progress is logged every 30 rounds, and
// the function returns once no connection remains.
func (s *Server) GracefulDown() {
	log.Info("[server] graceful shutdown.")
	metrics.ServerEventCounter.WithLabelValues(metrics.EventGracefulDown).Inc()

	for round, remaining := 0, s.ConnectionCount(); remaining > 0; round++ {
		time.Sleep(time.Second)
		s.kickIdleConnection()
		remaining = s.ConnectionCount()
		if round%30 == 0 {
			log.Infof("graceful shutdown...connection count %d\n", remaining)
		}
	}
}
// kickIdleConnection asks every registered connection to shut down or to be
// notified (via ShutdownOrNotify). Connections for which it returns true
// were shut down and are closed here; notified connections exit by
// themselves. Close errors are logged, not returned.
func (s *Server) kickIdleConnection() {
	var toClose []*clientConn

	s.rwlock.RLock()
	for _, cc := range s.clients {
		if cc.ShutdownOrNotify() {
			toClose = append(toClose, cc)
		}
	}
	s.rwlock.RUnlock()

	// Closing happens after the read lock is released.
	// NOTE(review): presumably because Close needs s.rwlock itself — confirm.
	for _, cc := range toClose {
		if closeErr := cc.Close(); closeErr != nil {
			log.Error("close connection error:", closeErr)
		}
	}
}
// Server error codes.
const (
	codeUnknownFieldType  = 1
	codeInvalidPayloadLen = 2
	codeInvalidSequence   = 3
	codeInvalidType       = 4
	codeNotAllowedCommand = 1148
	codeAccessDenied      = mysql.ErrAccessDenied
)

// init registers the MySQL error numbers for the server error class. Only
// codeNotAllowedCommand and codeAccessDenied get a MySQL mapping.
func init() {
	codes := make(map[terror.ErrCode]uint16, 2)
	codes[codeNotAllowedCommand] = mysql.ErrNotAllowedCommand
	codes[codeAccessDenied] = mysql.ErrAccessDenied
	terror.ErrClassToMySQLCodes[terror.ClassServer] = codes
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/pingcap/tidb.git
git@gitee.com:pingcap/tidb.git
pingcap
tidb
tidb
v2.1.0-rc.4

搜索帮助