Ai
1 Star 1 Fork 0

carlzyhuang/framework

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
session.go 10.50 KB
一键复制 编辑 原始数据 按行查看 历史
huangzhiyong 提交于 2025-10-24 15:00 +08:00 . 仓库地址修改
package websocket
import (
"context"
"fmt"
"sync"
"time"
"gitee.com/carlzyhuang/framework/log"
"gitee.com/carlzyhuang/framework/rpc/websocket/protocol"
"github.com/gorilla/websocket"
"google.golang.org/protobuf/proto"
)
// ServiceDispatcher routes a decoded request packet to its handler. The
// returned response (if non-nil) is marshalled and queued back to the client
// by Session.ReadPump; a returned error is logged and no response is sent.
type ServiceDispatcher interface {
	Dispatch(ctx SessionContext, req *protocol.RequestPacket) (*protocol.ResponsePacket, error)
}
// Session is a single WebSocket connection together with its outbound
// message queue and the per-connection context handed to the dispatcher.
type Session struct {
	// ID identifies the session within its SessionManager.
	ID string

	conn    *websocket.Conn // underlying WebSocket connection
	send    chan []byte     // outbound queue drained by WritePump
	manager *SessionManager // owner, notified from Close

	// mu guards isClosed and lastActivity, and is held by Send for the
	// whole enqueue attempt.
	mu           sync.Mutex
	isClosed     bool
	lastActivity time.Time

	disp    ServiceDispatcher // receives every decoded request from ReadPump
	context SessionContext    // per-session context passed to disp
}
// SessionManager tracks every live Session and, once a player has been bound
// through SetPlayer, which session belongs to each player ID. Registration,
// unregistration and broadcasts are queued on channels and applied by the
// loop launched with Start.
type SessionManager struct {
	sessions       map[string]*Session // keyed by Session.ID
	playerSessions map[int64]*Session  // keyed by player ID
	broadcast      chan []byte
	register       chan *Session
	unregister     chan *Session
	mu             sync.RWMutex // guards sessions and playerSessions
}
// NewSessionManager creates an empty SessionManager. Start must be called so
// that the register and unregister queues (each buffered to 256 entries)
// are drained.
func NewSessionManager() *SessionManager {
	return &SessionManager{
		sessions:       make(map[string]*Session),
		playerSessions: make(map[int64]*Session), // SetPlayer writes here; a nil map would panic
		broadcast:      make(chan []byte),
		register:       make(chan *Session, 256),
		unregister:     make(chan *Session, 256),
	}
}
// Start runs the manager's event loop on a new goroutine. The loop has no
// stop signal; it runs for the remaining lifetime of the process.
func (sm *SessionManager) Start() {
	go func() {
		sm.run()
	}()
}
// run is the manager's event loop. It applies queued registrations and
// unregistrations to the session maps and fans broadcast messages out to
// every registered session. Unregistration is where a session's send
// channel is closed, which is what lets its WritePump terminate.
func (sm *SessionManager) run() {
	for {
		select {
		case session := <-sm.register:
			sm.mu.Lock()
			sm.sessions[session.ID] = session
			// The player mapping is not added here; SetPlayer adds it once a
			// player is bound to the session.
			total := len(sm.sessions) // read under the lock to avoid racing CloseAll/readers
			sm.mu.Unlock()
			log.Infof("Session %s registered. Total sessions: %d", session.ID, total)

		case session := <-sm.unregister:
			sm.mu.Lock()
			if _, ok := sm.sessions[session.ID]; ok {
				delete(sm.sessions, session.ID)
				// Remove the player mapping only if it still points at this
				// session: after a reconnect the same player ID may already be
				// bound to a newer session that must stay reachable.
				if player := session.context.GetPlayer(); player != nil {
					pid := player.GetPlayerID()
					if sm.playerSessions[pid] == session {
						delete(sm.playerSessions, pid)
					}
				}
				close(session.send)
				log.Infof("Session %s unregistered. Total sessions: %d players: %d", session.ID, len(sm.sessions), len(sm.playerSessions))
			}
			sm.mu.Unlock()

		case message := <-sm.broadcast:
			sm.mu.RLock()
			for _, session := range sm.sessions {
				select {
				case session.send <- message:
				default:
					// The session's queue is full, so drop that client rather
					// than block everyone else. Close runs on its own goroutine
					// because it re-enters this loop via UnregisterSession.
					go session.Close()
				}
			}
			sm.mu.RUnlock()
		}
	}
}
// RegisterSession hands session to the run loop, which adds it to the
// manager. The call blocks while the register queue is full.
func (sm *SessionManager) RegisterSession(session *Session) {
	queue := sm.register
	queue <- session
}
// UnregisterSession hands session to the run loop, which removes it from the
// manager and closes its send queue. The call blocks while the unregister
// queue is full.
func (sm *SessionManager) UnregisterSession(session *Session) {
	queue := sm.unregister
	queue <- session
}
// Broadcast hands message to the run loop, which queues it on every
// registered session and closes any session whose queue is full. The call
// blocks until the run loop receives the message.
func (sm *SessionManager) Broadcast(message []byte) {
	out := sm.broadcast
	out <- message
}
// GetSession returns the registered session with the given ID, or nil if the
// manager has no such session.
func (sm *SessionManager) GetSession(id string) *Session {
	sm.mu.RLock()
	found := sm.sessions[id]
	sm.mu.RUnlock()
	return found
}
// GetSessionByPlayerID returns the session that SetPlayer bound to playerID,
// or nil if no session is bound to that player.
func (sm *SessionManager) GetSessionByPlayerID(playerID int64) *Session {
	sm.mu.RLock()
	found := sm.playerSessions[playerID]
	sm.mu.RUnlock()
	return found
}
// GetSessionCount reports how many sessions are currently registered.
func (sm *SessionManager) GetSessionCount() int {
	sm.mu.RLock()
	count := len(sm.sessions)
	sm.mu.RUnlock()
	return count
}
// CloseAll closes every session currently registered with the manager.
//
// The sessions are snapshotted under the lock and closed only after it is
// released. Session.Close enqueues the session on the unregister channel,
// and that channel is drained by the run loop, which itself takes sm.mu.
// Closing while holding the lock would therefore deadlock once that queue
// fills. The run loop then removes each session from both maps and closes
// its send channel, which lets the session's WritePump exit. As a result the
// maps empty asynchronously, and the run loop (Start) must be running.
func (sm *SessionManager) CloseAll() {
	sm.mu.RLock()
	snapshot := make([]*Session, 0, len(sm.sessions))
	for _, session := range sm.sessions {
		snapshot = append(snapshot, session)
	}
	sm.mu.RUnlock()

	for _, session := range snapshot {
		session.Close()
	}
}
// NewSession builds a Session for conn owned by manager, with a 256-entry
// outbound queue and a fresh SessionContext. It neither registers the
// session nor starts its read/write pumps.
func NewSession(id string, conn *websocket.Conn, manager *SessionManager, disp ServiceDispatcher) *Session {
	s := new(Session)
	s.ID = id
	s.conn = conn
	s.send = make(chan []byte, 256)
	s.manager = manager
	s.lastActivity = time.Now()
	s.disp = disp
	s.context = NewSessionContext(s)
	return s
}
// ReadPump reads frames from the client until reading fails. It decodes each
// frame as a protocol.RequestPacket, passes it to the dispatcher and queues
// the marshalled response. The session is closed when ReadPump returns.
//
// Incoming messages are capped at 4 MiB. The read deadline starts at 60s and
// is pushed back on every pong. Requests are dispatched synchronously on this
// goroutine, so a slow handler stalls reading for this session.
//
// NOTE(review): the newErrorResponse results below are discarded, so the
// client never receives these error responses. Confirm whether they were
// meant to be sent via s.Send.
func (s *Session) ReadPump() {
	defer s.Close()

	const (
		maxMessageSize = 4 * 1024 * 1024
		readWait       = 60 * time.Second
	)

	s.conn.SetReadLimit(maxMessageSize)
	s.conn.SetReadDeadline(time.Now().Add(readWait))
	s.conn.SetPongHandler(func(string) error {
		s.conn.SetReadDeadline(time.Now().Add(readWait))
		s.updateActivity()
		return nil
	})

	// handle processes one raw frame; returning moves on to the next frame.
	handle := func(raw []byte) {
		request := &protocol.RequestPacket{}
		if err := proto.Unmarshal(raw, request); err != nil {
			newErrorResponse(request.Header, protocol.StatusCode_WS_UNMARSHAL_FAIL, err.Error())
			log.Errorf("Failed to unmarshal message for session %s: %v", s.ID, err)
			return
		}

		// TODO(optimization): consider handing requests to a buffered queue so
		// a single slow request does not block the whole read loop.
		resp, err := s.disp.Dispatch(s.context, request)
		if err != nil {
			newErrorResponse(request.Header, protocol.StatusCode_WS_SERVER_ERROR, err.Error())
			log.Errorf("Dispatch error for session %s: %v", s.ID, err)
			return
		}
		if resp == nil {
			return
		}

		encoded, err := proto.Marshal(resp)
		if err != nil {
			newErrorResponse(request.Header, protocol.StatusCode_WS_SERVER_UNMARSHAL_FAIL, err.Error())
			log.Errorf("Failed to marshal response for session %s: %v", s.ID, err)
			return
		}
		if !s.Send(encoded) {
			log.Errorf("Failed to send response to session %s", s.ID)
		}
	}

	for {
		_, raw, err := s.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Errorf("Read error for session %s: %v", s.ID, err)
			}
			return
		}
		s.updateActivity()
		log.Debugf("Received message from session %s", s.ID)
		handle(raw)
	}
}
// WritePump drains the send queue onto the connection as binary messages. It
// also pings the peer every 54s, inside the 60s pong deadline that ReadPump
// enforces. Every write gets a 10s deadline. It returns, closing the session,
// on any write error or once the manager closes the send queue. In the
// second case it first attempts to send a close frame.
func (s *Session) WritePump() {
	const (
		pingPeriod = 54 * time.Second
		writeWait  = 10 * time.Second
	)

	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		s.Close()
	}()

	for {
		select {
		case message, ok := <-s.send:
			s.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The manager closed the queue. Sending the close frame is
				// best effort: the connection may already be closed.
				_ = s.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			w, err := s.conn.NextWriter(websocket.BinaryMessage)
			if err != nil {
				return
			}
			log.Debugf("session %s send msg %d", s.ID, len(message))
			if _, err := w.Write(message); err != nil {
				return
			}
			// Close flushes the complete message to the network. Without it
			// the message stays buffered until a later write closes the
			// writer, which delays responses to the client.
			if err := w.Close(); err != nil {
				return
			}

		case <-ticker.C:
			s.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// Send queues message for delivery by WritePump. It returns false if the
// session is already closed. It also returns false if the queue stays full
// for 10s; in that case the session is closed asynchronously as well.
//
// s.mu is held for the whole attempt. Close sets isClosed under s.mu before
// asking the manager to unregister the session (which closes s.send), so
// holding the lock keeps the channel from being closed mid-send.
func (s *Session) Send(message []byte) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.isClosed {
		return false
	}

	// Fast path: the queue has room, so no timer is allocated.
	select {
	case s.send <- message:
		return true
	default:
	}

	// Slow path: wait up to 10s. Use a timer that is stopped on return instead
	// of time.After, so a timer does not stay live for 10s after every send.
	timer := time.NewTimer(10 * time.Second)
	defer timer.Stop()
	select {
	case s.send <- message:
		return true
	case <-timer.C:
		// The queue stayed full; give up on this client.
		id := int64(0)
		if player := s.context.GetPlayer(); player != nil {
			id = player.GetPlayerID()
		}
		log.Errorf("session %s player %d send channel timeout, close!", s.ID, id)
		// Close needs s.mu, which is held until this function returns.
		go s.Close()
		return false
	}
}
// Close marks the session closed, closes the connection and asks the manager
// to unregister the session. Unregistering closes the send queue, which stops
// WritePump. Only the first call has any effect.
func (s *Session) Close() {
	s.mu.Lock()
	if s.isClosed {
		s.mu.Unlock()
		return
	}
	s.isClosed = true
	s.mu.Unlock()

	// Only the flag flip must happen under s.mu. The steps below run without
	// it, so a full unregister queue cannot leave the lock held and stall
	// Send or the pong handler's updateActivity. The connection is closed
	// first so the pumps are unblocked even if UnregisterSession has to wait
	// for the run loop.
	_ = s.conn.Close() // best effort; the session is being torn down regardless
	s.manager.UnregisterSession(s)
	log.Infof("Session %s closed", s.ID)
}
// updateActivity stamps the current time as the session's last activity.
// ReadPump calls it for every received message and every pong.
func (s *Session) updateActivity() {
	s.mu.Lock()
	s.lastActivity = time.Now()
	s.mu.Unlock()
}
// BasePlayer is the minimal view of a player that the session layer needs:
// the numeric ID under which SetPlayer indexes the player's session.
type BasePlayer interface {
	GetPlayerID() int64
}
// SessionContext is what a ServiceDispatcher sees of the session a request
// arrived on: the attached context.Context, the bound player, and helpers
// for sending messages to this session or to another player's session.
type SessionContext interface {
	// Context returns the context.Context currently attached (read-only data).
	Context() context.Context
	// WithContext attaches ctx and returns the resulting SessionContext.
	WithContext(ctx context.Context) SessionContext

	// GetPlayer returns the player bound to the session (read-only data).
	GetPlayer() BasePlayer
	// SetPlayer associates player with the session once it has been bound.
	SetPlayer(player BasePlayer)

	// Send marshals data and sends it to this session as message msgID.
	Send(msgID uint32, data proto.Message) bool
	// SendRaw sends already-encoded data to this session as message msgID.
	SendRaw(msgID uint32, data []byte) bool
	// SendByPlayer marshals data and sends it to the given player's session.
	SendByPlayer(playerID int64, msgID uint32, data proto.Message) (bool, error)
	// SendRawByPlayer sends already-encoded data to the given player's session.
	SendRawByPlayer(playerID int64, msgID uint32, data []byte) (bool, error)
}
// sessionContext is the SessionContext implementation attached to each
// Session. It holds the attached context, the owning session, and the bound
// player.
type sessionContext struct {
	ctx     context.Context // replaced in place by WithContext
	session *Session        // owning session; used for sends and manager lookups
	Player  BasePlayer      // emptyPlayer() when built by NewSessionContext, until SetPlayer
}
// WithContext implements SessionContext. It replaces the stored context in
// place and returns the same value, not a copy.
//
// NOTE(review): because the receiver is mutated, the new ctx becomes visible
// to every holder of this SessionContext, not only the caller. Confirm this
// is intended.
func (sc *sessionContext) WithContext(ctx context.Context) SessionContext {
	sc.ctx = ctx
	return SessionContext(sc)
}
// Send implements SessionContext. It marshals data and queues it on this
// context's own session as message msgID with status WS_OK. It returns false
// if marshalling fails (the failure is logged) or the session does not
// accept the message.
func (sc *sessionContext) Send(msgID uint32, data proto.Message) bool {
	payload, err := proto.Marshal(data)
	if err == nil {
		return sc.session.Send(newResponse(msgID, protocol.StatusCode_WS_OK, payload))
	}
	log.Errorf("failed to marshal message %d: %v", msgID, err)
	return false
}
// SendByPlayer implements SessionContext. It marshals data and delivers it,
// as message msgID with status WS_OK, to the session bound to playerID. A
// marshalling failure is logged and returned. If no session is bound to
// playerID, the result is false plus a "session not found" error.
// Otherwise the bool reports whether that session accepted the message.
func (sc *sessionContext) SendByPlayer(playerID int64, msgID uint32, data proto.Message) (bool, error) {
	payload, err := proto.Marshal(data)
	if err != nil {
		log.Errorf("failed to marshal message %d: %v", msgID, err)
		return false, err
	}
	// From here on the behaviour is exactly the raw variant's.
	return sc.SendRawByPlayer(playerID, msgID, payload)
}
// SendRaw implements SessionContext. It queues already-encoded data on this
// context's own session as message msgID with status WS_OK and reports
// whether the session accepted it.
func (sc *sessionContext) SendRaw(msgID uint32, data []byte) bool {
	packet := newResponse(msgID, protocol.StatusCode_WS_OK, data)
	return sc.session.Send(packet)
}
// SendRawByPlayer implements SessionContext. It delivers already-encoded
// data, as message msgID with status WS_OK, to the session bound to
// playerID. If no session is bound, it returns false plus a "session not
// found" error. Otherwise the bool reports whether that session accepted
// the message.
func (sc *sessionContext) SendRawByPlayer(playerID int64, msgID uint32, data []byte) (bool, error) {
	target := sc.session.manager.GetSessionByPlayerID(playerID)
	if target != nil {
		return target.Send(newResponse(msgID, protocol.StatusCode_WS_OK, data)), nil
	}
	return false, fmt.Errorf("session not found for player %d", playerID)
}
// Context implements SessionContext. It returns the context most recently
// set by WithContext, or the context.Background() installed by
// NewSessionContext if it was never replaced.
func (sc *sessionContext) Context() context.Context {
	current := sc.ctx
	return current
}
// emptyPlayerStruct is the placeholder player installed by NewSessionContext
// before any real player is bound; it always reports player ID 0.
type emptyPlayerStruct struct{}

// GetPlayerID implements BasePlayer and always returns 0.
func (*emptyPlayerStruct) GetPlayerID() int64 { return 0 }
// emptyPlayer returns the placeholder BasePlayer (player ID 0) used for
// sessions that have no bound player.
func emptyPlayer() BasePlayer {
	var placeholder BasePlayer = new(emptyPlayerStruct)
	return placeholder
}
// NewSessionContext creates the SessionContext for session. It starts with
// context.Background() and the placeholder player (ID 0).
func NewSessionContext(session *Session) SessionContext {
	sc := &sessionContext{ctx: context.Background()}
	sc.session = session
	sc.Player = emptyPlayer()
	return sc
}
// GetPlayer implements SessionContext. It returns the player stored by
// SetPlayer, or the placeholder installed by NewSessionContext if none has
// been bound.
func (sc *sessionContext) GetPlayer() BasePlayer {
	current := sc.Player
	return current
}
// SetPlayer binds player to this session and records the session in the
// manager's player index, so SendByPlayer and SendRawByPlayer can reach it.
//
// Re-binding to a different player removes the previous player's index
// entry if it still points at this session. A nil player unbinds the
// session and resets it to the placeholder player. Both the field and the
// index are updated under the manager's lock. The run loop holds the same
// lock when it reads the player during unregistration.
func (sc *sessionContext) SetPlayer(player BasePlayer) {
	m := sc.session.manager
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.playerSessions == nil {
		// Writing to a nil map panics; tolerate a manager whose player index
		// was never initialized.
		m.playerSessions = make(map[int64]*Session)
	}

	// Remove a stale entry left behind by this session's previous binding.
	if prev := sc.Player; prev != nil {
		if prevID := prev.GetPlayerID(); m.playerSessions[prevID] == sc.session {
			delete(m.playerSessions, prevID)
		}
	}

	if player == nil {
		sc.Player = emptyPlayer()
		return
	}
	sc.Player = player
	m.playerSessions[player.GetPlayerID()] = sc.session
}
// newResponse encodes a ResponsePacket carrying payload with the given
// message ID and status code. The header is stamped with the current Unix
// time in milliseconds, and Sequence and Version are always 0. It returns
// nil, after logging, if encoding fails.
func newResponse(msgID uint32, status protocol.StatusCode, payload []byte) []byte {
	packet := &protocol.ResponsePacket{
		Header: &protocol.Header{
			MessageId: msgID,
			Sequence:  0,
			Timestamp: uint64(time.Now().UnixMilli()),
			Version:   0,
		},
		Code:    int32(status),
		Payload: payload,
	}

	encoded, err := proto.Marshal(packet)
	if err != nil {
		log.Errorf("failed to marshal response message for msg %d: %v", msgID, err)
		return nil
	}
	return encoded
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/carlzyhuang/framework.git
git@gitee.com:carlzyhuang/framework.git
carlzyhuang
framework
framework
v0.0.18

搜索帮助