1 Star 0 Fork 0

siliworks / common-package

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
tcp.go 6.75 KB
一键复制 编辑 原始数据 按行查看 历史
545403892 提交于 2023-06-06 00:08 . init
package server
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"gitee.com/micro-tools/wf/container/gmap"
	"gitee.com/micro-tools/wf/errors/gerror"
	"gitee.com/micro-tools/wf/net/gtcp"
	"gitee.com/micro-tools/wf/os/gcron"
	"gitee.com/micro-tools/wf/os/glog"
	"gitee.com/micro-tools/wf/os/gmutex"
	"gitee.com/micro-tools/wf/os/gtime"
	"gitee.com/micro-tools/wf/util/gconv"
	"gitee.com/micro-tools/wf/util/guid"
	"gitee.com/siliworks/common-package/collector/module/define"
)
// tcp is the TCP server implementation. Connected clients are kept in
// clientCache, keyed by the string key they were bound under; bind, unbind
// and updateHeartbeat serialise their check-then-modify sequences with
// clientCacheMux.
type tcp struct {
// options carries the configuration and callbacks read in this file:
// Address, KeepAlive, AuthenticationEnable, OnConnect, OnDisconnect,
// OnAuthentication and OnData.
options Options
// clientCacheMux guards multi-step updates of clientCache.
clientCacheMux *gmutex.Mutex
// clientCache maps a client key (string) to its *Client.
clientCache *gmap.HashMap
}
// RemoveClientFromBlacklist removes key from the client blacklist.
//
// Blacklisting is not implemented by the TCP server yet. Instead of panicking
// (which would bring down the caller's goroutine), it reports that through the
// returned error so callers can handle it.
func (t *tcp) RemoveClientFromBlacklist(key string) error {
	return errors.New("tcp server: RemoveClientFromBlacklist is not implemented")
}
// Send sends data to the client bound to key.
//
// It is not implemented yet and returns an error instead of panicking.
// NOTE(review): decide whether data should be written raw or wrapped in a
// (possibly encrypted) define.Agreement, as registerReplay/codeRegisterReplay
// do, before implementing this.
func (t *tcp) Send(key string, data []byte) error {
	return errors.New("tcp server: Send is not implemented")
}
// GetClient returns the client bound to key, or nil when no client is bound.
//
// A comma-ok type assertion is used because clientCache.Get returns a nil
// interface for unknown keys, and asserting that to *Client would panic.
func (t *tcp) GetClient(key string) *Client {
	cli, _ := t.clientCache.Get(key).(*Client)
	return cli
}
// CloseClient closes the connection of the client bound to key and removes it
// from the client cache. Unknown keys are ignored. Whatever unbind reports is
// passed straight back to the caller.
func (t *tcp) CloseClient(key string) error {
	return t.unbind(key)
}
// CloseAllClient closes every cached client connection and empties the client
// cache. Errors from closing individual connections are ignored; it always
// returns nil.
//
// The whole sweep runs under clientCacheMux, like bind/unbind, so it cannot
// interleave with them. Values are asserted with comma-ok: a key can vanish
// between Keys() and Get() if another path removes it, and asserting the
// resulting nil interface would panic.
func (t *tcp) CloseAllClient() error {
	t.clientCacheMux.Lock()
	defer t.clientCacheMux.Unlock()
	for _, k := range t.clientCache.Keys() {
		if cli, ok := t.clientCache.Get(k).(*Client); ok && cli != nil && cli.Conn != nil {
			// Best effort: the entry is dropped even if Close fails.
			_ = cli.Conn.Close()
		}
		t.clientCache.Remove(k)
	}
	return nil
}
// ClientList returns a snapshot of all cached clients. The result is nil when
// no client is cached.
func (t *tcp) ClientList() []*Client {
	cached := t.clientCache.Values()
	var list []*Client
	for i := range cached {
		list = append(list, cached[i].(*Client))
	}
	return list
}
// bind associates key with cli in the client cache.
//
// If key is already bound to a client, that client's connection is closed and
// its entry replaced, so a key maps to at most one connection. The whole
// sequence runs under clientCacheMux. It always returns nil.
func (t *tcp) bind(key string, cli *Client) error {
	t.clientCacheMux.Lock()
	defer t.clientCacheMux.Unlock()
	if t.clientCache.Contains(key) {
		if previous := t.clientCache.Get(key); previous != nil {
			// Kick the connection that previously owned this key.
			_ = previous.(*Client).Conn.Close()
			t.clientCache.Remove(key)
		}
	}
	t.clientCache.Set(key, cli)
	return nil
}
// unbind closes the connection of the client bound to key (if a non-nil value
// is cached) and removes key from the cache. Unknown keys are a no-op. Runs
// under clientCacheMux and always returns nil.
func (t *tcp) unbind(key string) error {
	t.clientCacheMux.Lock()
	defer t.clientCacheMux.Unlock()
	if !t.clientCache.Contains(key) {
		return nil
	}
	if cached := t.clientCache.Get(key); cached != nil {
		_ = cached.(*Client).Conn.Close()
	}
	// The entry is removed whether or not a client value was present.
	t.clientCache.Remove(key)
	return nil
}
// updateHeartbeat sets HeartBeatAt of the client bound to key to the current
// time. Unknown keys and nil entries are ignored. Runs under clientCacheMux
// and always returns nil.
func (t *tcp) updateHeartbeat(key string) error {
	t.clientCacheMux.Lock()
	defer t.clientCacheMux.Unlock()
	if !t.clientCache.Contains(key) {
		return nil
	}
	entry := t.clientCache.Get(key)
	if entry == nil {
		return nil
	}
	cli := entry.(*Client)
	cli.HeartBeatAt = gtime.Now()
	t.clientCache.Set(key, cli)
	return nil
}
// Run starts the TCP server on options.Address and returns the result of the
// underlying gtcp server's Run.
//
// Before serving, it registers the idle-client reaper (ConnectTimeout). For
// each accepted connection the handler:
//   - calls options.OnConnect (if set) with &conn.Conn;
//   - parses every received package with define.ParseAgreement, skipping
//     malformed ones;
//   - runs the first valid package through authentication, and dispatches
//     every later one through parse under the key authentication returned;
//   - drops the connection on an authentication or parse error (logged; no
//     OnDisconnect call, as before);
//   - on a receive error calls options.OnDisconnect (if set) with an error
//     carrying define.DisconnectByClient, the remote address and the cause.
//
// On exit the handler closes the connection and removes its key from the
// client cache, but only while the key is still bound to this connection. A
// newer login that re-used the key (bind closes the old connection) therefore
// survives the old handler's cleanup.
func (t *tcp) Run() error {
	glog.Infof("tcp server is running:[%s]", t.options.Address)
	t.ConnectTimeout()
	return gtcp.NewServer(t.options.Address, func(conn *gtcp.Conn) {
		// TODO: reject the connection when the number of clients exceeds a limit.
		if t.options.OnConnect != nil {
			t.options.OnConnect(t, &conn.Conn)
		}
		ctx := context.Background()
		registered := false
		var key string
		defer conn.Close()
		// Deferred after conn.Close, so it runs first. The closure reads key at
		// exit time, i.e. the key authentication eventually returned.
		defer func() { t.unbindConn(key, conn) }()
		for {
			data, err := conn.RecvPkg()
			if len(data) > 0 {
				var res define.Agreement
				res, err = define.ParseAgreement(data)
				if err != nil {
					// Skip malformed packages but keep the connection open.
					glog.Warning(ctx, err)
					continue
				}
				if !registered {
					// The first valid package authenticates (or auto-registers) the client.
					registered, key, err = t.authentication(conn, res)
				} else {
					// Later packages are data/heartbeats for the bound key.
					err = t.parse(key, res)
				}
				if err != nil {
					glog.Warning(ctx, err)
					return
				}
			}
			if err != nil {
				if t.options.OnDisconnect != nil {
					t.options.OnDisconnect(t, gerror.New(fmt.Sprintf("%s,%s %s", define.DisconnectByClient, conn.RemoteAddr(), err.Error())))
				}
				return
			}
		}
	}).Run()
}

// unbindConn removes key from the client cache only while it is still bound to
// conn, so a handler whose connection was replaced cannot evict the newcomer.
func (t *tcp) unbindConn(key string, conn *gtcp.Conn) {
	t.clientCacheMux.Lock()
	defer t.clientCacheMux.Unlock()
	if cli, ok := t.clientCache.Get(key).(*Client); ok && cli != nil && cli.Conn == conn {
		t.clientCache.Remove(key)
	}
}
// authentication handles the first valid package received on conn and, on
// success, binds conn into the client cache.
//
// With options.AuthenticationEnable == false, the remote address is the client
// key. The package is dispatched through parse like any later package, and the
// connection is then bound under that same address, which is also the key
// that parse/updateHeartbeat use for this connection.
//
// With authentication enabled:
//   - the package's Data is decrypted and handed to options.OnAuthentication;
//   - if the callback accepts, a register reply is sent and the connection is
//     bound under the decrypted Key;
//   - if the callback rejects, conn is closed;
//   - a decrypt failure is returned without invoking the callback.
//
// It returns whether the connection is registered, the key it is bound under,
// and any decrypt/parse error. registered is still true when parse fails in
// the unauthenticated mode; the caller drops the connection on any error.
func (t *tcp) authentication(conn *gtcp.Conn, req define.Agreement) (registered bool, key string, err error) {
	remote := conn.RemoteAddr().String()
	if !t.options.AuthenticationEnable {
		if err = t.parse(remote, req); err != nil {
			return true, "", err
		}
		// Bind under the remote address. Binding under an (always empty)
		// decrypted key made every new connection evict the previous one.
		_ = t.bind(remote, &Client{
			Address:     remote,
			ConnectAt:   gtime.Now(),
			HeartBeatAt: gtime.Now(),
			Conn:        conn,
		})
		//t.codeRegisterReplay(conn, remote)
		return true, remote, nil
	}

	var data define.AgreementData
	if data, err = define.Decrypt(req.Data); err != nil {
		// Never authenticate against a payload that failed to decrypt.
		return false, "", err
	}
	if t.options.OnAuthentication == nil {
		// NOTE(review): with authentication enabled but no callback, no client can
		// ever register; consider rejecting such a configuration up front.
		return false, "", nil
	}
	if !t.options.OnAuthentication(t, data) {
		_ = conn.Close()
		return false, "", nil
	}
	// Bind only once the client has been told it is registered. A failed reply
	// leaves it unregistered, and the next receive on a broken conn fails.
	if t.registerReplay(conn) != nil {
		return false, "", nil
	}
	_ = t.bind(data.Key, &Client{
		Address:     remote,
		ConnectAt:   gtime.Now(),
		HeartBeatAt: gtime.Now(),
		Conn:        conn,
	})
	return true, data.Key, nil
}
// parse dispatches one package from the client bound to key.
//
// A non-empty Data field is decrypted first; a decrypt error is returned.
//   - define.CodeData: calls options.OnData (if set) with the decrypted Payload
//     and Key.
//   - define.CodeHeartBeat: refreshes the client's heartbeat time.
//   - Any other code: returns an "unknown code" error.
func (t *tcp) parse(key string, req define.Agreement) (err error) {
	var body define.AgreementData
	if req.Data != "" {
		if body, err = define.Decrypt(req.Data); err != nil {
			return err
		}
	}
	if req.Code == define.CodeData {
		if t.options.OnData != nil {
			t.options.OnData(t, body.Payload, body.Key)
		}
		return nil
	}
	if req.Code == define.CodeHeartBeat {
		_ = t.updateHeartbeat(key)
		return nil
	}
	return gerror.New("unknown code")
}
// registerReplay sends a define.CodeRegisterReplay package, stamped with the
// current timestamp and carrying no Data, to conn. authentication calls it
// after OnAuthentication accepts the client. It returns the send error, if any.
func (t *tcp) registerReplay(conn *gtcp.Conn) (err error) {
	ack := define.Agreement{
		Code:      define.CodeRegisterReplay,
		Timestamp: gtime.Now().Timestamp(),
	}
	err = t.send(conn, ack)
	return
}
// send JSON-encodes data and writes it to conn with SendPkg.
//
// It returns an error when conn is nil, when data cannot be marshalled (the
// marshal error used to be discarded, which would have sent an empty package),
// or when the write fails.
func (t *tcp) send(conn *gtcp.Conn, data define.Agreement) error {
	if conn == nil {
		return gerror.New("conn is nil")
	}
	bs, err := json.Marshal(data)
	if err != nil {
		return fmt.Errorf("marshal agreement: %w", err)
	}
	return conn.SendPkg(bs)
}
// ConnectTimeout registers a singleton gcron job that evicts idle clients.
//
// The job fires on the pattern "*/N * * * * *", where N is options.KeepAlive
// in whole seconds. On each run it unbinds every client whose last heartbeat
// (HeartBeatAt) is at least N seconds old, as well as any key whose cached
// value is nil. Unbinding closes the connection and removes the cache entry.
//
// The heartbeat is read under clientCacheMux, the same lock updateHeartbeat
// writes it under. A KeepAlive below one second cannot form a valid step, so
// no job is registered in that case; registration failures are logged rather
// than silently dropped.
//
// NOTE(review): the seconds field of a cron pattern only spans 0-59, so a
// KeepAlive of 60s or more probably does not give the intended interval;
// confirm against gcron's pattern parser.
func (t *tcp) ConnectTimeout() {
	keepAlive := gconv.Int64(t.options.KeepAlive.Seconds())
	if keepAlive <= 0 {
		glog.Warning(context.Background(), "tcp server: KeepAlive is below 1s, idle clients will not be reaped")
		return
	}
	// expired reports whether the entry under k should be evicted.
	expired := func(k interface{}) bool {
		t.clientCacheMux.Lock()
		defer t.clientCacheMux.Unlock()
		cli, ok := t.clientCache.Get(k).(*Client)
		if !ok || cli == nil {
			return true
		}
		return gconv.Int64(gtime.Now().Sub(cli.HeartBeatAt).Seconds()) >= keepAlive
	}
	pattern := fmt.Sprintf("*/%d * * * * *", keepAlive)
	_, err := gcron.AddSingleton(pattern, func() {
		for _, k := range t.clientCache.Keys() {
			if !expired(k) {
				continue
			}
			key := gconv.String(k)
			if err := t.unbind(key); err != nil {
				glog.Warning(context.Background(), err)
				continue
			}
			glog.Info(context.Background(), "unbind success", key)
		}
	})
	if err != nil {
		glog.Warning(context.Background(), err)
	}
}
// codeRegisterReplay sends conn a define.CodeRegisterReplay package whose Data
// is the encrypted AgreementData{Key: key, RandomStr: <new guid>}. Encryption
// and send failures are only logged.
//
// NOTE(review): nothing in this file calls it.
func (t *tcp) codeRegisterReplay(conn *gtcp.Conn, key string) {
	stamp := gtime.Now().Timestamp()
	body := define.AgreementData{
		Key:       key,
		RandomStr: guid.S(),
	}
	encrypted, err := define.Encrypt(gconv.String(body))
	if err != nil {
		glog.Warning(err)
		return
	}
	reply := define.Agreement{
		Code:      define.CodeRegisterReplay,
		Timestamp: stamp,
		Data:      encrypted,
	}
	if err = t.send(conn, reply); err != nil {
		glog.Warning(err)
	}
}
1
https://gitee.com/siliworks/common-package.git
git@gitee.com:siliworks/common-package.git
siliworks
common-package
common-package
v1.0.3

搜索帮助