1 Star 0 Fork 2

何吕 / volantmq

forked from JUMEI_ARCH / volantmq 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
base.go 6.82 KB
一键复制 编辑 原始数据 按行查看 历史
hawklin 提交于 2018-06-14 18:24 . bugfix
package transport
import (
"errors"
"sync"
"time"
"github.com/VolantMQ/volantmq/auth"
"github.com/VolantMQ/volantmq/clients"
"github.com/VolantMQ/volantmq/packet"
"github.com/VolantMQ/volantmq/ratelimit/type"
"github.com/VolantMQ/volantmq/routines"
"github.com/VolantMQ/volantmq/systree"
"github.com/VolantMQ/volantmq/trace"
"go.uber.org/zap"
)
// Config is the base configuration object shared by all transports.
type Config struct {
// AuthManager checks the username/password carried by a CONNECT packet and is
// handed to every session the transport creates.
AuthManager *auth.Manager
// RateLimiter is asked for a token once per incoming connection; a connection
// that gets no token is answered with a "server busy" CONNACK.
RateLimiter RateType.RateLimiter
// Port is the TCP port to listen on.
Port string
// Host is the host part of the listen address.
// NOTE(review): not read in this file; confirm exact semantics against the concrete transports.
Host string
}
// InternalConfig is used by the server implementation to configure internal
// needs that are not exposed through Config.
type InternalConfig struct {
// AllowedVersions lists the protocol versions the server will handle. A version
// that is missing from the map or mapped to false is refused during CONNECT.
// If not set then defaults to 0x3 and 0x04.
AllowedVersions map[packet.ProtocolVersion]bool
// Sessions is the client session manager; a connection is handed to it via
// NewSession once its CONNECT packet has been checked.
Sessions *clients.Manager
// Metric receives the per-packet sent/received counters.
Metric systree.Metric
// ConnectTimeout is the number of seconds to wait for the client's CONNECT
// message before disconnecting (it is applied as the initial read deadline).
// If not set then default to 2 seconds.
ConnectTimeout int
// KeepAlive The number of seconds to keep the connection live if there's no data.
// If not set then defaults to 5 minutes.
KeepAlive int
}
// baseConfig bundles the state shared by the transport implementations: the
// embedded InternalConfig, the user-facing Config and some bookkeeping fields.
type baseConfig struct {
InternalConfig
// config is the user-supplied transport configuration.
config Config
// NOTE(review): onConnection, onceStop and quit are not touched in this file;
// presumably the concrete transports use them to track live connections and
// to shut down exactly once. Confirm against those implementations.
onConnection sync.WaitGroup // nolint: structcheck
onceStop sync.Once // nolint: structcheck
quit chan struct{} // nolint: structcheck
// log is the logger used while handling incoming connections.
log *zap.Logger
// protocol is the protocol name reported by Protocol().
protocol string
}
// Provider is the interface that every transport must implement.
type Provider interface {
// Protocol returns the protocol name used by the transport.
Protocol() string
// Serve runs the transport.
// NOTE(review): blocking behaviour is defined by the implementations, not visible here.
Serve() error
// Close shuts the transport down.
Close() error
// Port returns the TCP port used by the transport.
Port() string
}
// Port reports the TCP port this transport was configured with.
func (c *baseConfig) Port() string {
	port := c.config.Port
	return port
}
// Protocol reports the name of the protocol this transport speaks.
func (c *baseConfig) Protocol() string {
	name := c.protocol
	return name
}
// handleConnection performs the MQTT connection handshake for a freshly
// accepted client connection.
//
// The steps are:
//  1. Arm a read deadline of ConnectTimeout seconds and ask the rate limiter
//     for a token.
//  2. Read the first packet from the wire and decode it.
//  3. If decoding produced a reason code, or no rate-limiter token was
//     obtained, record the failed login, answer with a CONNACK carrying the
//     reason and drop the connection.
//  4. Otherwise check the protocol version and the credentials, store the
//     resulting code in the CONNACK and hand the connection over to the
//     session manager, which takes care of the response.
//
// The deferred function closes conn whenever err is non-nil on return, so every
// failure path only has to make sure err is set.
func (c *baseConfig) handleConnection(conn conn) {
	if c == nil {
		// The logger lives on the receiver, so nothing can be logged here.
		// Drop the connection: nobody else is going to serve it.
		conn.Close() // nolint: errcheck, gas
		return
	}

	var err error
	defer func() {
		if err != nil {
			conn.Close() // nolint: errcheck, gas
		}
	}()

	remote := conn.RemoteAddr().String()

	// The client has ConnectTimeout seconds to deliver its CONNECT packet.
	conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(c.ConnectTimeout))) // nolint: errcheck, gas

	// A false result is treated below as "rate limit exceeded".
	gotToken := c.config.RateLimiter.WaitToken()

	var buf []byte
	if buf, err = routines.GetMessageBuffer(conn); err != nil {
		c.log.Error("Couldn't get CONNECT message", zap.String("remote", remote), zap.Error(err))
		return
	}

	var (
		req          packet.Provider
		reason       packet.ReasonCode
		hasErrReason bool
	)
	if req, _, err = packet.Decode(packet.ProtocolV50, buf); err != nil {
		// A reason code can be reported back to the client through CONNACK;
		// any other decode error just terminates the connection.
		if reason, hasErrReason = err.(packet.ReasonCode); !hasErrReason {
			c.log.Error("Couldn't decode message", zap.String("remote", remote), zap.Error(err))
			return
		}
	}

	if req == nil {
		// Without a packet there is neither a type to account for nor a
		// protocol version to build a CONNACK with.
		if err == nil {
			err = errors.New("decoder returned no packet")
		}
		c.log.Error("Couldn't decode message", zap.String("remote", remote), zap.Error(err))
		return
	}

	c.Metric.Packets().Received(req.Type())

	connect, isConnect := req.(*packet.Connect)

	if !gotToken {
		reason = packet.CodeServerBusy
		hasErrReason = true
		if isConnect {
			c.log.Error("rearch rate limit",
				zap.Float64("speed", c.config.RateLimiter.Limit()),
				zap.Int("burst", c.config.RateLimiter.Burst()),
				zap.String("clientID", string(connect.ClientID())),
				zap.String("remote", remote),
			)
		}
	} else if isConnect {
		c.log.Debug("pass rate limit",
			zap.Float64("speed", c.config.RateLimiter.Limit()),
			zap.Int("burst", c.config.RateLimiter.Burst()),
			zap.String("clientID", string(connect.ClientID())),
			zap.String("remote", remote),
		)
	}

	// Disable read deadline. Will set it later if keep-alive interval is bigger than 0
	conn.SetReadDeadline(time.Time{}) // nolint: errcheck

	if !isConnect {
		c.log.Error("Unexpected message type",
			zap.String("expected", "CONNECT"),
			zap.String("received", req.Type().Name()))
		err = errors.New("unexpected message type")
		return
	}

	clientID := string(connect.ClientID())

	// The error from packet.New is not inspected directly: any failure shows up
	// as a nil resp, which is handled right below.
	m, _ := packet.New(req.Version(), packet.CONNACK)
	resp, _ := m.(*packet.ConnAck)
	if resp == nil {
		c.log.Error("Couldn't create CONNACK message", zap.String("clientID", clientID), zap.String("remote", remote))
		err = errors.New("unable to create CONNACK message")
		return
	}

	if hasErrReason {
		username, _ := connect.Credentials()
		status := &systree.ClientConnectStatus{
			Username:     string(username),
			Timestamp:    time.Now().Unix(),
			Address:      remote,
			Protocol:     connect.Version(),
			ConnAckCode:  reason,
			CleanSession: connect.IsClean(),
		}

		c.Sessions.Systree.Clients().Connected(clientID, status)
		c.Sessions.Systree.Clients().Disconnected(clientID, reason, true)
		c.Metric.Packets().Sent(resp.Type())
		c.log.Error("client login fail", zap.String("clientID", clientID), zap.String("reason", reason.Error()), zap.String("remote", remote))

		resp.SetReturnCode(reason)        // nolint: errcheck
		routines.WriteMessage(conn, resp) // nolint: errcheck

		// A refused CONNECT must end with the network connection closed. In the
		// rate-limited case err is still nil at this point, so set it to make
		// the deferred close fire instead of leaking the socket.
		if err == nil {
			err = reason
		}
		return
	}

	if trace.GetInstance().Status(clientID, "") {
		c.log.Info("client connect", zap.String("clientID", clientID), zap.String("remote", remote))
	} else {
		c.log.Debug("client connect", zap.String("clientID", clientID), zap.String("remote", remote))
	}

	// If protocol version is not in allowed list then give reject and pass control to session manager
	// to handle response
	if allowed, ok := c.AllowedVersions[connect.Version()]; !ok || !allowed {
		reason = packet.CodeRefusedUnacceptableProtocolVersion
		if connect.Version() == packet.ProtocolV50 {
			reason = packet.CodeUnsupportedProtocol
		}
	} else {
		user, pass := connect.Credentials()
		if status := c.config.AuthManager.Password(string(user), string(pass)); status == auth.StatusAllow {
			reason = packet.CodeSuccess
		} else {
			reason = packet.CodeRefusedBadUsernameOrPassword
			if req.Version() == packet.ProtocolV50 {
				reason = packet.CodeBadUserOrPassword
			}
			c.log.Warn("bad username or password", zap.String("clientID", clientID))
		}
	}

	resp.SetReturnCode(reason) // nolint: errcheck

	err = c.Sessions.NewSession(
		&clients.StartConfig{
			Req:  connect,
			Resp: resp,
			Conn: conn,
			Auth: c.config.AuthManager,
		})
	if err != nil {
		c.log.Error("Failed to create session.", zap.ByteString("ClientID", connect.ClientID()), zap.Error(err))
	}
}
Go
1
https://gitee.com/kaifazhe/volantmq.git
git@gitee.com:kaifazhe/volantmq.git
kaifazhe
volantmq
volantmq
v0.0.4

搜索帮助

14c37bed 8189591 565d56ea 8189591