1 Star 0 Fork 0

zhoujin826/tidb

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
server.go 9.52 KB
一键复制 编辑 原始数据 按行查看 历史
// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
// For pprof
_ "net/http/pprof"
log "github.com/Sirupsen/logrus"
"github.com/juju/errors"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/arena"
)
// baseConnID holds the most recently assigned connection ID. newConn bumps it
// with atomic.AddUint32 so every new connection receives the next value.
var baseConnID uint32
// Errors raised by the server package, all in terror.ClassServer.
// init registers MySQL error numbers for codeNotAllowedCommand and
// codeAccessDenied so those two surface to clients with MySQL-compatible codes.
var (
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
// errAccessDenied reuses MySQL's own message template for ErrAccessDenied.
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
)
// defaultCapability is the set of MySQL protocol capability flags the server
// advertises when it is created with the default configuration.
// When a TLS configuration is loaded, NewServer additionally sets
// mysql.ClientSSL on top of defaultCapability.
const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows |
mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles |
mysql.ClientConnectAtts | mysql.ClientPluginAuth
// Server is the MySQL protocol server. It accepts client connections on a
// TCP address or a Unix socket and tracks the live connections.
type Server struct {
// cfg is the configuration the server was created with.
cfg *config.Config
// tlsConfig is nil when TLS is not enabled (see loadTLSCertificates).
tlsConfig *tls.Config
// driver is the IDriver supplied to NewServer (its role is defined elsewhere — presumably it opens sessions).
driver IDriver
// listener accepts incoming connections; Close sets it to nil under rwlock.
listener net.Listener
// rwlock guards clients and listener.
rwlock *sync.RWMutex
// concurrentLimiter hands out tokens via getToken/releaseToken.
concurrentLimiter *TokenLimiter
// clients maps connection IDs to connections that completed the handshake.
clients map[uint32]*clientConn
// capability is defaultCapability, plus mysql.ClientSSL when TLS is enabled.
capability uint32
// When a critical error occurs, we don't want to exit the process: a supervisor
// may automatically restart it, new client connections would then be created,
// and we would not be able to serve them. Instead we stop the listener so that
// clients are forced to choose other TiDB servers. Run polls this channel
// (buffer size 1) after every accepted connection.
stopListenerCh chan struct{}
}
// ConnectionCount returns how many client connections are currently
// registered in s.clients, read under the shared lock.
func (s *Server) ConnectionCount() int {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	return len(s.clients)
}
// getToken takes a token from the server's concurrency limiter. Callers are
// expected to hand it back with releaseToken (whether Get blocks when no token
// is available is defined by TokenLimiter — confirm there).
func (s *Server) getToken() *Token {
	limiter := s.concurrentLimiter
	return limiter.Get()
}
// releaseToken returns a token obtained from getToken to the concurrency limiter.
func (s *Server) releaseToken(token *Token) {
	limiter := s.concurrentLimiter
	limiter.Put(token)
}
// newConn wraps a freshly accepted net.Conn in a *clientConn.
//
// The connection gets the next ID from baseConnID, the default collation and a
// 32 KiB arena allocator. When TCP keep-alive is enabled in the performance
// config and conn is a *net.TCPConn, keep-alive is switched on (a failure is
// only logged). Finally a 20-byte random salt is generated for authentication.
func (s *Server) newConn(conn net.Conn) *clientConn {
	id := atomic.AddUint32(&baseConnID, 1)
	cc := &clientConn{
		server:       s,
		connectionID: id,
		collation:    mysql.DefaultCollationID,
		alloc:        arena.NewAllocator(32 * 1024),
	}
	log.Infof("[%d] new connection %s", cc.connectionID, conn.RemoteAddr().String())

	tcpConn, isTCP := conn.(*net.TCPConn)
	if isTCP && s.cfg.Performance.TCPKeepAlive {
		if keepAliveErr := tcpConn.SetKeepAlive(true); keepAliveErr != nil {
			log.Error("failed to set tcp keep alive option:", keepAliveErr)
		}
	}

	cc.setConn(conn)
	cc.salt = util.RandomBuf(20)
	return cc
}
// skipAuth reports whether authentication should be skipped, which is the case
// exactly when the server is configured to listen on a Unix socket.
func (s *Server) skipAuth() bool {
	return len(s.cfg.Socket) != 0
}
// NewServer creates a new Server from cfg and driver.
//
// It loads TLS certificates (enabling mysql.ClientSSL on success) and opens the
// listener: a Unix socket when cfg.Socket is set, otherwise TCP on
// cfg.Host:cfg.Port. A listen failure is returned as a traced error. On success
// the global math/rand source is seeded from the current time.
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
	srv := &Server{
		cfg:               cfg,
		driver:            driver,
		concurrentLimiter: NewTokenLimiter(cfg.TokenLimit),
		rwlock:            &sync.RWMutex{},
		clients:           make(map[uint32]*clientConn),
		stopListenerCh:    make(chan struct{}, 1),
	}
	srv.loadTLSCertificates()

	srv.capability = defaultCapability
	if srv.tlsConfig != nil {
		srv.capability |= mysql.ClientSSL
	}

	useSocket := cfg.Socket != ""
	network, address := "tcp", fmt.Sprintf("%s:%d", srv.cfg.Host, srv.cfg.Port)
	if useSocket {
		network, address = "unix", cfg.Socket
	}

	listener, listenErr := net.Listen(network, address)
	if listenErr != nil {
		return nil, errors.Trace(listenErr)
	}
	srv.listener = listener
	if useSocket {
		log.Infof("Server is running MySQL Protocol through Socket [%s]", address)
	} else {
		log.Infof("Server is running MySQL Protocol at [%s]", address)
	}

	// Init rand seed for randomBuf()
	rand.Seed(time.Now().UTC().UnixNano())
	return srv, nil
}
// loadTLSCertificates builds s.tlsConfig from s.cfg.Security.
//
// TLS stays disabled (s.tlsConfig == nil) when SSLCert or SSLKey is empty, or
// when the key pair cannot be loaded. If SSLCA is set and the file contains at
// least one parsable PEM certificate, client certificates are verified when a
// client presents one (tls.VerifyClientCertIfGiven); otherwise no client
// certificate is requested. On exit, the SSL-related system variables are
// updated and the resulting state is logged.
func (s *Server) loadTLSCertificates() {
	defer func() {
		if s.tlsConfig == nil {
			log.Warn("Secure connection is NOT ENABLED")
			return
		}
		// Report the policy actually in effect. ssl_ca is set whenever the CA
		// file is readable, even if it held no usable certificate, so it does
		// not prove that client verification is on.
		clientVerify := s.tlsConfig.ClientAuth != tls.NoClientCert
		log.Infof("Secure connection is enabled (client verification enabled = %v)", clientVerify)
		variable.SysVars["have_openssl"].Value = "YES"
		variable.SysVars["have_ssl"].Value = "YES"
		variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert
		variable.SysVars["ssl_key"].Value = s.cfg.Security.SSLKey
	}()

	if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 {
		s.tlsConfig = nil
		return
	}

	tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey)
	if err != nil {
		log.Warn(errors.ErrorStack(err))
		s.tlsConfig = nil
		return
	}

	// Try loading CA cert.
	clientAuthPolicy := tls.NoClientCert
	var certPool *x509.CertPool
	if len(s.cfg.Security.SSLCA) > 0 {
		caCert, caErr := ioutil.ReadFile(s.cfg.Security.SSLCA)
		if caErr != nil {
			log.Warn(errors.ErrorStack(caErr))
		} else {
			certPool = x509.NewCertPool()
			if certPool.AppendCertsFromPEM(caCert) {
				clientAuthPolicy = tls.VerifyClientCertIfGiven
			} else {
				// Previously silent: the CA file was readable but unusable, so
				// client verification is quietly left off.
				log.Warnf("no valid PEM certificate found in CA file %s, client verification is disabled", s.cfg.Security.SSLCA)
			}
			variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA
		}
	}

	s.tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{tlsCert},
		ClientCAs:    certPool,
		ClientAuth:   clientAuthPolicy,
		MinVersion:   0,
	}
}
// Run runs the server.
//
// It optionally starts the status HTTP server, then accepts client connections
// and serves each one in its own goroutine. Run returns nil when the listener
// has been closed (for example by Close) and a traced error for any other
// accept failure. If a stop is requested through stopListenerCh, the listener
// is closed and Run never returns: it logs periodically until the process is
// killed manually.
func (s *Server) Run() error {
	// Start HTTP API to report tidb info such as TPS.
	if s.cfg.Status.ReportStatus {
		s.startStatusHTTP()
	}

	// Snapshot the listener under the lock. Close may set s.listener to nil
	// concurrently, and re-reading the field on every iteration could then
	// dereference nil.
	s.rwlock.RLock()
	listener := s.listener
	s.rwlock.RUnlock()
	if listener == nil {
		// Already closed before Run started.
		return nil
	}

	for {
		conn, err := listener.Accept()
		if err != nil {
			if opErr, ok := err.(*net.OpError); ok {
				if opErr.Err.Error() == "use of closed network connection" {
					return nil
				}
			}
			log.Errorf("accept error %s", err.Error())
			return errors.Trace(err)
		}

		if s.shouldStopListener() {
			err = conn.Close()
			terror.Log(errors.Trace(err))
			break
		}

		go s.onConn(conn)
	}

	// Close the listener through Close so that clearing s.listener happens
	// under rwlock and cannot race with a concurrent Close call.
	s.Close()

	for {
		log.Errorf("listener stopped, waiting for manual kill.")
		time.Sleep(time.Minute)
	}
}
// shouldStopListener polls stopListenerCh without blocking and reports whether
// a stop signal was pending (the signal is consumed).
func (s *Server) shouldStopListener() bool {
	select {
	case <-s.stopListenerCh:
		return true
	default:
	}
	return false
}
// Close closes the server's listener, if it is still open, and clears it.
// The whole operation runs under the exclusive lock; a close error is logged.
func (s *Server) Close() {
	s.rwlock.Lock()
	defer s.rwlock.Unlock()

	if s.listener == nil {
		return
	}
	closeErr := s.listener.Close()
	terror.Log(errors.Trace(closeErr))
	s.listener = nil
}
// onConn runs in its own goroutine and handles queries from one connection.
//
// It performs the handshake first; on failure the raw connection is closed and
// nothing is registered. On success the connection is added to s.clients, the
// connection gauge is updated, and the connection's Run loop takes over.
func (s *Server) onConn(c net.Conn) {
	cc := s.newConn(c)
	defer func() {
		log.Infof("[%d] close connection", cc.connectionID)
	}()

	if handshakeErr := cc.handshake(); handshakeErr != nil {
		// Some keep alive services will send request to TiDB and disconnect immediately.
		// So we use info log level.
		log.Infof("handshake error %s", errors.ErrorStack(handshakeErr))
		closeErr := c.Close()
		terror.Log(errors.Trace(closeErr))
		return
	}

	s.rwlock.Lock()
	s.clients[cc.connectionID] = cc
	active := len(s.clients)
	s.rwlock.Unlock()

	connGauge.Set(float64(active))
	cc.Run()
}
// ShowProcessList implements the SessionManager interface.
// It collects ShowProcess() from every registered connection that has not been
// marked as killed. The result is nil when there is nothing to report.
func (s *Server) ShowProcessList() []util.ProcessInfo {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()

	var infos []util.ProcessInfo
	for _, cc := range s.clients {
		if !cc.killed {
			infos = append(infos, cc.ctx.ShowProcess())
		}
	}
	return infos
}
// Kill implements the SessionManager interface.
//
// It looks up the connection with the given ID and calls Cancel on its
// context. When query is false, the connection is also marked as killed, which
// makes ShowProcessList skip it. Unknown IDs are ignored.
func (s *Server) Kill(connectionID uint64, query bool) {
	// Connection IDs are allocated as uint32 (see newConn). An ID that does not
	// fit must not be truncated: KILL 4294967297 would otherwise target
	// connection 1.
	id := uint32(connectionID)
	if uint64(id) != connectionID {
		return
	}

	s.rwlock.Lock()
	defer s.rwlock.Unlock()
	conn, ok := s.clients[id]
	if !ok {
		return
	}
	conn.ctx.Cancel()
	if !query {
		// Written under the exclusive lock; ShowProcessList reads it under RLock.
		conn.killed = true
	}
}
// Server error codes used with terror.ClassServer.
// codeNotAllowedCommand and codeAccessDenied are mapped to MySQL error numbers
// in init; the remaining codes have no registered MySQL mapping in this file.
const (
codeUnknownFieldType = 1
codeInvalidPayloadLen = 2
codeInvalidSequence = 3
codeInvalidType = 4
codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
)
// init registers the MySQL error numbers that server-class errors are
// reported as, for the codes that have a MySQL equivalent.
func init() {
	codes := make(map[terror.ErrCode]uint16, 2)
	codes[codeNotAllowedCommand] = mysql.ErrNotAllowedCommand
	codes[codeAccessDenied] = mysql.ErrAccessDenied
	terror.ErrClassToMySQLCodes[terror.ClassServer] = codes
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhoujin826/tidb.git
git@gitee.com:zhoujin826/tidb.git
zhoujin826
tidb
tidb
v1.0.6

搜索帮助