代码拉取完成,页面将自动刷新
package websocket
import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"sync"
	"time"

	"gitee.com/carlzyhuang/framework/log"
	"gitee.com/carlzyhuang/framework/rpc/websocket/protocol"
	"github.com/gorilla/websocket"
)
// Server is a WebSocket server: it owns the HTTP server that performs the
// upgrade, the listener it serves on, the session manager that tracks
// connections, and the registry of services that requests are dispatched to.
type Server struct {
// httpSrv is the underlying HTTP server; NewServer installs the mux and Addr on it.
httpSrv *http.Server
// network and address are passed to net.Listen when no listener exists yet.
network string
address string
// timeout defaults to 1s in NewServer; it is not read by the code visible in this file.
timeout time.Duration
// lis is the listener served by Start; created lazily by listenAndEndpoint.
lis net.Listener
// endpoint is the advertised ws:// URL; computed lazily by listenAndEndpoint.
endpoint *url.URL
// sessionManager receives every session created by serveWs.
sessionManager *SessionManager
mu sync.Mutex // guards following
// service holds the registered service descriptors and handles Dispatch.
service *serviceInfo
}
// serveHome handles plain HTTP requests to the root path and replies with a
// short banner pointing clients at the /ws endpoint.
//
// Any path other than "/" yields 404, and any method other than GET yields
// 405 together with an Allow header naming the supported method.
func serveHome(w http.ResponseWriter, r *http.Request) {
	log.Info(r.URL.String())
	if r.URL.Path != "/" {
		http.Error(w, "Not found", http.StatusNotFound)
		return
	}
	if r.Method != http.MethodGet {
		// A 405 response is expected to advertise the methods that are allowed.
		w.Header().Set("Allow", http.MethodGet)
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// The status line is already committed once Write is called, so a failed
	// write can only be logged, not reported to the client.
	if _, err := w.Write([]byte("WebSocket Server is running. Connect to /ws for WebSocket endpoint.")); err != nil {
		log.Errorf("write home response error: %v", err)
	}
}
// serveWs upgrades an incoming HTTP request to a WebSocket connection, wraps
// the connection in a new session, registers that session with the server's
// session manager and starts the session's write and read pumps.
//
// If the upgrade fails the error is logged and the request is abandoned.
func (s *Server) serveWs(w http.ResponseWriter, r *http.Request) {
	wsConn, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		log.Errorf("Upgrade error: %v", upgradeErr)
		return
	}

	// Build the session under a freshly generated ID (a more elaborate ID
	// strategy may be substituted in a real deployment) and register it.
	sess := NewSession(generateSessionID(), wsConn, s.sessionManager, s)
	s.sessionManager.RegisterSession(sess)

	// One goroutine per direction: writer first, then reader.
	go sess.WritePump()
	go sess.ReadPump()
}
// Start binds the listener (if not already bound), then serves HTTP/WebSocket
// traffic on it until the server is shut down.
//
// It blocks for the lifetime of the server. A graceful shutdown via Stop makes
// Serve return http.ErrServerClosed, which is treated as a normal exit and
// reported as nil; any other serve error is returned to the caller instead of
// terminating the process.
//
// ctx is currently unused by this method.
func (s *Server) Start(ctx context.Context) error {
	if err := s.listenAndEndpoint(); err != nil {
		return err
	}
	// Log the address actually bound: the configured one may be ":0".
	log.Infof("[ws] server listening on: %s", s.lis.Addr().String())
	if err := s.httpSrv.Serve(s.lis); err != nil && !errors.Is(err, http.ErrServerClosed) {
		return err
	}
	return nil
}
// Stop shuts the HTTP server down via http.Server.Shutdown, bounded by ctx.
// A shutdown failure is logged and returned; success is logged and yields nil.
func (s *Server) Stop(ctx context.Context) error {
	log.Infof("[ws] server %s stopping", s.httpSrv.Addr)

	if shutdownErr := s.httpSrv.Shutdown(ctx); shutdownErr != nil {
		log.Errorf("server close error: %v", shutdownErr)
		return shutdownErr
	}

	log.Infof("[ws] server %s stopped", s.httpSrv.Addr)
	return nil
}
// listenAndEndpoint lazily creates the network listener and the advertised
// endpoint URL. Each is created only when it is still nil, so repeated calls
// reuse what an earlier call produced.
//
// The endpoint is a ws:// URL with path /ws whose host is derived by Extract
// from the configured address and the bound listener.
func (s *Server) listenAndEndpoint() error {
	if s.lis == nil {
		listener, listenErr := net.Listen(s.network, s.address)
		if listenErr != nil {
			return listenErr
		}
		s.lis = listener
	}

	if s.endpoint != nil {
		return nil
	}

	host, extractErr := Extract(s.address, s.lis)
	if extractErr != nil {
		return extractErr
	}
	s.endpoint = &url.URL{
		Scheme: "ws",
		Host:   host,
		Path:   "/ws",
	}
	return nil
}
// Endpoint returns the server's advertised WebSocket URL, creating the
// listener and the URL first if they do not exist yet.
func (s *Server) Endpoint() (*url.URL, error) {
	err := s.listenAndEndpoint()
	if err != nil {
		return nil, err
	}
	return s.endpoint, nil
}
// NewServer builds a Server with default settings (tcp network, address ":0",
// 1s timeout, fresh session manager, HTTP server and service registry),
// applies the given options in order, and then wires up the HTTP routes:
// "/" is handled by serveHome and "/ws" by the server's WebSocket handler.
//
// The mux and listen address are installed on the HTTP server after the
// options have run, so they reflect any option-driven changes.
func NewServer(opts ...ServerOption) *Server {
	srv := new(Server)
	srv.network = "tcp"
	srv.address = ":0"
	srv.timeout = 1 * time.Second
	srv.sessionManager = NewSessionManager()
	srv.httpSrv = &http.Server{}
	srv.service = newServiceInfo()

	for i := range opts {
		opts[i](srv)
	}

	router := http.NewServeMux()
	router.HandleFunc("/", serveHome)
	router.HandleFunc("/ws", srv.serveWs)

	srv.httpSrv.Handler = router
	srv.httpSrv.Addr = srv.address
	return srv
}
var (
// upgrader is the shared gorilla/websocket Upgrader used by serveWs to turn
// HTTP requests into WebSocket connections.
upgrader = websocket.Upgrader{
// Upper bound on how long the upgrade handshake may take.
HandshakeTimeout: 10 * time.Second,
// Read/write buffer sizes handed to gorilla/websocket.
ReadBufferSize: 2048,
WriteBufferSize: 2048,
// No custom error hook is installed.
Error: nil,
// NOTE(review): this accepts every Origin, which disables the browser
// cross-origin protection for the WebSocket endpoint. Confirm this is
// intentional before exposing the server to untrusted web pages.
CheckOrigin: func(r *http.Request) bool {
return true
},
// Ask gorilla/websocket to negotiate compression when the peer supports it.
EnableCompression: true,
}
)
// generateSessionID builds a session ID (simplified scheme) of the form
// "<yyyyMMddHHmmss>-<8 random characters>", using the current local time.
func generateSessionID() string {
	stamp := time.Now().Format("20060102150405")
	return stamp + "-" + randomString(8)
}
// randomString returns a string of the given length whose characters are
// drawn uniformly from [a-zA-Z0-9] using the operating system's
// cryptographically secure random source.
//
// A non-positive length yields the empty string. The function panics if the
// system random source cannot be read, because handing out predictable
// session identifiers would be worse than failing loudly.
func randomString(length int) string {
	const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// Largest multiple of len(chars) that fits in a byte; bytes at or above
	// it are rejected so that the modulo below introduces no bias.
	const maxByte = 256 - 256%len(chars)

	if length <= 0 {
		return ""
	}

	result := make([]byte, 0, length)
	// Slightly over-provision so a single read usually suffices despite
	// the ~3% rejection rate.
	buf := make([]byte, length+length/4+8)
	for len(result) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("websocket: reading random bytes for session id: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= maxByte {
				continue
			}
			result = append(result, chars[int(b)%len(chars)])
			if len(result) == length {
				break
			}
		}
	}
	return string(result)
}
// ServiceRegistrar wraps a single method that supports service registration.
// It lets callers pass any concrete type that can register services, not only
// *Server, to the registration functions exported by IDL-generated code.
type ServiceRegistrar interface {
// RegisterService registers a service and its implementation to the
// concrete type implementing this interface. It may not be called
// once the server has started serving.
// desc describes the service and its methods and handlers. impl is the
// service implementation which is passed to the method handlers.
RegisterService(desc *ServiceDesc, impl any)
}
// RegisterService registers a service descriptor and its implementation with
// this WebSocket server by delegating to register. It is intended to be called
// from IDL-generated code, before the server starts serving.
//
// NOTE: the check that ss implements sd.HandlerType is currently disabled
// (see the commented-out block below), so ss is stored without validation.
func (s *Server) RegisterService(sd *ServiceDesc, ss any) {
// if ss != nil {
// ht := reflect.TypeOf(sd.HandlerType).Elem()
// st := reflect.TypeOf(ss)
// if !st.Implements(ht) {
// log.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
// }
// }
s.register(sd, ss)
}
// register records the service descriptor sd and its implementation ss in the
// server's service registry. The whole operation, including the log line,
// runs while holding s.mu.
func (s *Server) register(sd *ServiceDesc, ss any) {
s.mu.Lock()
defer s.mu.Unlock()
log.Infof("RegisterService(%q)", sd.ServiceName)
s.service.registerServiceInfo(sd, ss)
}
// Dispatch forwards a request packet received on a session to the server's
// service registry and returns the response packet or error it produces.
// It does not acquire s.mu.
func (s *Server) Dispatch(sessionCtx SessionContext, msg *protocol.RequestPacket) (*protocol.ResponsePacket, error) {
return s.service.Dispatch(sessionCtx, msg)
}
// Extract derives an advertisable "host:port" string from a configured
// host:port and, optionally, a bound listener.
//
// Port: when lis is non-nil its TCP port replaces the port parsed from
// hostPort (and a parse failure of hostPort is tolerated); when lis is nil,
// hostPort must parse or the parse error is returned.
//
// Host: a specific host in hostPort is used as-is. If the host is empty or a
// wildcard ("0.0.0.0", "[::]", "::"), the up network interfaces are scanned
// for an address accepted by isValidIP, favouring the interface with the
// lowest index. If no such address is found, the result is ("", nil).
func Extract(hostPort string, lis net.Listener) (string, error) {
	host, port, splitErr := net.SplitHostPort(hostPort)
	if lis == nil {
		if splitErr != nil {
			return "", splitErr
		}
	} else {
		p, ok := Port(lis)
		if !ok {
			return "", fmt.Errorf("failed to extract port: %v", lis.Addr())
		}
		port = strconv.Itoa(p)
	}

	switch host {
	case "", "0.0.0.0", "[::]", "::":
		// Unspecified host: discover one from the local interfaces below.
	default:
		return net.JoinHostPort(host, port), nil
	}

	ifaces, err := net.Interfaces()
	if err != nil {
		return "", err
	}

	lowestIndex := 0
	candidates := make([]net.IP, 0, 1)
	for _, iface := range ifaces {
		if iface.Flags&net.FlagUp == 0 {
			continue
		}
		// Once a candidate exists, only interfaces with a strictly lower
		// index are still worth inspecting.
		if len(candidates) != 0 && iface.Index >= lowestIndex {
			continue
		}
		ifaceAddrs, err := iface.Addrs()
		if err != nil {
			continue
		}
		for _, raw := range ifaceAddrs {
			var ip net.IP
			switch v := raw.(type) {
			case *net.IPAddr:
				ip = v.IP
			case *net.IPNet:
				ip = v.IP
			default:
				continue
			}
			if !isValidIP(ip.String()) {
				continue
			}
			lowestIndex = iface.Index
			candidates = append(candidates, ip)
			// An IPv4 address ends the scan of this interface.
			if ip.To4() != nil {
				break
			}
		}
	}

	if n := len(candidates); n != 0 {
		return net.JoinHostPort(candidates[n-1].String(), port), nil
	}
	return "", nil
}
// isValidIP reports whether addr parses as an IP address that is a global
// unicast address and not an interface-local multicast address.
// Unparseable input yields false.
func isValidIP(addr string) bool {
	parsed := net.ParseIP(addr)
	if !parsed.IsGlobalUnicast() {
		return false
	}
	return !parsed.IsInterfaceLocalMulticast()
}
// Port returns the port the listener is bound to. The boolean is false, and
// the port 0, when the listener's address is not a *net.TCPAddr.
func Port(lis net.Listener) (int, bool) {
	tcpAddr, isTCP := lis.Addr().(*net.TCPAddr)
	if !isTCP {
		return 0, false
	}
	return tcpAddr.Port, true
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。