代码拉取完成,页面将自动刷新
package websocket
import (
"bytes"
"compress/flate"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strconv"
"sync"
"time"
"gitee.com/ccconnor/exchange-api/exchange/okex/future"
"gitee.com/ccconnor/exchange-api/exchange/okex/internal"
"github.com/gorilla/websocket"
)
// MessageHandler receives every decoded table (channel data) message; see
// handleTableMessage for which tables are decoded.
type MessageHandler func(message *TableMessage)
// Client is a websocket client that dials an endpoint, optionally logs in,
// subscribes to a set of topics, pings while idle and reconnects when no
// frame has been read for deadInterval.
//
// NOTE(review): connected, conn, closeCh and topics are read and written both
// from caller goroutines and from timer goroutines (ping/reconnect started by
// time.AfterFunc in keepAlive) without a lock — confirm with `go test -race`.
type Client struct {
endpoint string // URL dialed by Connect
apiKey string // sent in the login command; login is skipped if apiKey, secretKey or password is empty
secretKey string // used by login only to compute the request signature
password string // sent in the login command
conn *websocket.Conn // current connection; replaced by each Connect, closed by stop
handler MessageHandler // receives decoded table messages; nil disables delivery
topics map[string]bool // recorded topics; all are subscribed on each successful Connect
loginCh chan error // login result handed from the read loop to login (unbuffered, created in NewClient)
closeCh chan struct{} // closed by Close to stop reconnecting; nil until the first successful Connect, recreated on each one
writeMutex sync.Mutex // serialises writes to conn
connected bool // true after a successful Connect, false after stop
helloTimer *time.Timer // calls ping after helloInterval without an incoming frame
deadTimer *time.Timer // calls reconnect after deadInterval without an incoming frame
wgReader sync.WaitGroup // lets stop wait for the reader goroutine to exit
wgReconnect sync.WaitGroup // lets Close wait for an in-flight reconnect
}
// Timing parameters for the connection lifecycle. The two intervals are
// measured from the most recent read-loop iteration (see keepAlive).
const (
helloInterval = 5 * time.Second // idle time after which ping sends a "ping" frame
deadInterval = 10 * time.Second // idle time after which reconnect is triggered
connectTimeout = 20 * time.Second // websocket handshake timeout used by Connect
loginTimeout = 5 * time.Second // how long Connect waits for the login response
writeTimeout = 5 * time.Second // write deadline applied to every outgoing frame
)
// NewClient builds a Client for endpoint without doing any network I/O; call
// Connect to dial. apiKey, secretKey and password are only used for login,
// which is skipped when any of them is empty. topics (duplicates collapse)
// are subscribed on every successful Connect, and handler receives decoded
// table messages (nil disables delivery).
func NewClient(endpoint, apiKey, secretKey, password string, topics []string, handler MessageHandler) *Client {
	initial := make(map[string]bool, len(topics))
	for _, topic := range topics {
		initial[topic] = true
	}
	return &Client{
		endpoint:  endpoint,
		apiKey:    apiKey,
		secretKey: secretKey,
		password:  password,
		topics:    initial,
		handler:   handler,
		loginCh:   make(chan error),
	}
}
// Connect dials the endpoint, starts the read loop, logs in when credentials
// are configured, and subscribes to every recorded topic. It returns nil
// immediately if the client is already connected. On a dial or login failure
// the error is returned and the client stays disconnected.
func (c *Client) Connect() error {
	if c.connected {
		return nil
	}
	dialer := &websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: connectTimeout,
	}
	conn, resp, err := dialer.Dial(c.endpoint, nil)
	if err != nil {
		switch {
		case resp == nil:
			log.Println("dial:", err)
		case resp.Body == nil:
			log.Printf("dial:%v, status:%v", err, resp.StatusCode)
		default:
			defer resp.Body.Close()
			respBody, _ := io.ReadAll(resp.Body)
			log.Printf("dial:%v, status:%v, info:%v", err, resp.StatusCode, string(respBody))
		}
		return err
	}
	c.conn = conn
	// The read loop must already be running: the login response reaches
	// login() through loginCh, which only the reader feeds.
	go c.reader()
	if err = c.login(loginTimeout); err != nil {
		log.Println("login failed", err)
		// Closing the socket makes ReadMessage fail, which ends the reader.
		_ = conn.Close()
		return err
	}
	c.connected = true
	c.closeCh = make(chan struct{})
	if len(c.topics) > 0 {
		// The session is established even if this write fails, so the error
		// is reported rather than returned; previously it was dropped silently.
		if subErr := c.sendCommand(EventTypeSubscribe, c.getTopics()); subErr != nil {
			log.Println("subscribe failed", subErr)
		}
	}
	return nil
}
// Close shuts the client down. It closes closeCh (which ends any reconnect
// loop), waits for an in-flight reconnect to return, and then tears the
// connection down via stop. A second Close is a no-op. Close is also a
// no-op on a client that has never connected successfully: closeCh is
// still nil then, and closing a nil channel would panic.
func (c *Client) Close() {
	if c.closeCh == nil {
		return
	}
	select {
	case <-c.closeCh:
		// Already closed.
		return
	default:
	}
	close(c.closeCh)
	c.wgReconnect.Wait()
	c.stop()
	log.Println("websocket is closed")
}
// isClosed reports, without blocking, whether closeCh has been closed by
// Close. A nil closeCh (never connected) reads as not closed.
func (c *Client) isClosed() bool {
	closed := false
	select {
	case <-c.closeCh:
		closed = true
	default:
	}
	return closed
}
// Connected reports whether the client holds an established session: it
// becomes true at the end of a successful Connect and false whenever stop
// tears the connection down.
func (c *Client) Connected() bool {
	up := c.connected
	return up
}
// Subscribe records topics that are not yet known and sends one subscribe
// command for just those. It returns nil without sending anything when every
// topic is already recorded, and fails with "not connected" (recording
// nothing) while disconnected.
func (c *Client) Subscribe(topics []string) error {
	if c.connected {
		var fresh []string
		for _, topic := range topics {
			if _, known := c.topics[topic]; known {
				continue
			}
			c.topics[topic] = true
			fresh = append(fresh, topic)
		}
		if len(fresh) > 0 {
			return c.sendCommand(EventTypeSubscribe, fresh)
		}
		return nil
	}
	return fmt.Errorf("not connected")
}
// Unsubscribe drops the given topics from the recorded set and sends one
// unsubscribe command for those that were actually recorded. Unknown topics
// are ignored; if none remain nothing is sent. While disconnected it fails
// with "not connected" and leaves the set untouched.
func (c *Client) Unsubscribe(topics []string) error {
	if !c.connected {
		return fmt.Errorf("not connected")
	}
	var removed []string
	for _, topic := range topics {
		if _, known := c.topics[topic]; !known {
			continue
		}
		delete(c.topics, topic)
		removed = append(removed, topic)
	}
	if len(removed) == 0 {
		return nil
	}
	return c.sendCommand(EventTypeUnsubscribe, removed)
}
// login authenticates the current connection when apiKey, secretKey and
// password are all set; otherwise it succeeds immediately. It sends the
// login command (apiKey, password, Unix timestamp, and a signature computed
// by internal.Sign over GET /users/self/verify at that timestamp), then waits
// up to timeout for the read loop to deliver the verdict on loginCh.
func (c *Client) login(timeout time.Duration) error {
	haveCredentials := c.apiKey != "" && c.secretKey != "" && c.password != ""
	if !haveCredentials {
		return nil
	}
	ts := strconv.FormatInt(time.Now().Unix(), 10)
	sig := internal.Sign(c.secretKey, "GET", "/users/self/verify", ts, "")
	if err := c.sendCommand(EventTypeLogin, []string{c.apiKey, c.password, ts, sig}); err != nil {
		log.Println("write error", err)
		return err
	}
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	select {
	case result := <-c.loginCh:
		return result
	case <-deadline.C:
		return fmt.Errorf("login timeout")
	}
}
// ping writes a literal "ping" text frame while connected and does nothing
// otherwise. Write errors are discarded; if the connection has gone quiet,
// the dead timer armed in keepAlive triggers reconnect.
func (c *Client) ping() {
	if !c.connected {
		return
	}
	_ = c.sendMessage([]byte("ping"))
}
// reader is the read loop for the current connection. Before every read it
// re-arms the keep-alive timers. Text frames are used as-is and binary frames
// are inflated with decompress (raw DEFLATE). Each frame is then dispatched:
//   - a JSON object with a string "event" goes to handleEventMessage;
//   - a JSON object with a "table" key goes to handleTableMessage;
//   - anything else is logged.
// The loop ends on the first read error, which is how stop, Close and a
// failed login terminate it: they close the socket.
func (c *Client) reader() {
	// NOTE(review): Add runs inside the goroutine, so a stop() racing with a
	// just-started reader could Wait before this Add. Moving the Add to the
	// `go c.reader()` call site in Connect would close that window; it must
	// be removed here at the same time.
	c.wgReader.Add(1)
	defer c.wgReader.Done()
	for {
		c.keepAlive()
		msgType, message, err := c.conn.ReadMessage()
		if err != nil {
			log.Println("read error:", err)
			break
		}
		var messageJSON []byte
		switch msgType {
		case websocket.TextMessage:
			messageJSON = message
		case websocket.BinaryMessage:
			messageJSON, err = c.decompress(message)
			if err != nil {
				// A partially inflated payload cannot be parsed reliably; skip
				// the frame instead of dispatching it.
				log.Println("decompress message failed", err)
				continue
			}
		default:
			log.Println("received message type", msgType)
			continue
		}
		var messageMap map[string]interface{}
		// Non-JSON frames leave messageMap nil and fall through to the final log.
		_ = json.Unmarshal(messageJSON, &messageMap)
		if event, found := messageMap["event"]; found {
			eventName, ok := event.(string)
			if !ok {
				// Previously an unchecked assertion that panicked the reader.
				log.Println("received message with non-string event:", string(messageJSON))
				continue
			}
			c.handleEventMessage(eventName, messageJSON)
			continue
		}
		if _, found := messageMap["table"]; found {
			c.handleTableMessage(messageMap)
			continue
		}
		log.Println("received message:", string(messageJSON))
	}
	log.Println("exit websocket read loop")
}
// keepAlive (re)arms the idle timers while connected. helloTimer calls ping
// after helloInterval and deadTimer calls reconnect after deadInterval, both
// counted from this call. Each timer is created on first use and Reset on
// later calls. While disconnected nothing is armed or reset.
func (c *Client) keepAlive() {
	if !c.connected {
		return
	}
	if t := c.helloTimer; t != nil {
		t.Reset(helloInterval)
	} else {
		c.helloTimer = time.AfterFunc(helloInterval, c.ping)
	}
	if t := c.deadTimer; t != nil {
		t.Reset(deadInterval)
	} else {
		c.deadTimer = time.AfterFunc(deadInterval, c.reconnect)
	}
}
// reconnect tears down the current connection and then retries Connect every
// 10 seconds until it succeeds or Close is called. keepAlive installs it as
// the deadTimer callback, so it runs on a goroutine started by time.AfterFunc.
// It registers in wgReconnect so Close can wait for it to return.
func (c *Client) reconnect() {
	c.wgReconnect.Add(1)
	defer c.wgReconnect.Done()
	c.stop()
	log.Println("websocket is disconnected, try reconnect...")
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			// select chooses randomly when both channels are ready, so a tick
			// that coincides with Close must not start a new connection.
			if c.isClosed() {
				log.Println("stop reconnecting")
				return
			}
			if c.Connect() == nil {
				return
			}
		case <-c.closeCh:
			log.Println("stop reconnecting")
			return
		}
	}
}
// stop tears down the current connection. It halts both keep-alive timers,
// closes the socket (ending the read loop), marks the client disconnected,
// and blocks until the reader goroutine has exited. Missing timers or a nil
// connection are skipped.
func (c *Client) stop() {
	if t := c.helloTimer; t != nil {
		t.Stop()
	}
	if t := c.deadTimer; t != nil {
		t.Stop()
	}
	if conn := c.conn; conn != nil {
		_ = conn.Close()
	}
	c.connected = false
	c.wgReader.Wait()
}
// handleEventMessage processes an {"event": ...} frame.
//   - Login results are forwarded to a waiting login call.
//   - An error event with ErrorInvalidSign is forwarded the same way.
//   - Subscribe/unsubscribe acknowledgements, other errors and unknown events
//     are only logged.
func (c *Client) handleEventMessage(eventName string, messageJSON []byte) {
	switch eventName {
	case EventTypeLogin:
		var event EventLogin
		_ = json.Unmarshal(messageJSON, &event)
		log.Println("login return", event.Success)
		if event.Success {
			c.deliverLoginResult(nil)
		} else {
			c.deliverLoginResult(fmt.Errorf("login failed"))
		}
	case EventTypeSubscribe:
		var event EventSubscribe
		_ = json.Unmarshal(messageJSON, &event)
		log.Println("subscribed:", event.Channel)
	case EventTypeUnsubscribe:
		var event EventUnsubscribe
		_ = json.Unmarshal(messageJSON, &event)
		log.Println("unsubscribed:", event.Channel)
	case EventTypeError:
		var event EventError
		_ = json.Unmarshal(messageJSON, &event)
		log.Println("received error:", event.ErrorCode, event.Message)
		if event.ErrorCode == ErrorInvalidSign {
			c.deliverLoginResult(fmt.Errorf("%v", event.Message))
		}
	default:
		log.Println("received event:", string(messageJSON))
	}
}

// deliverLoginResult hands err to a pending login call. loginCh is
// unbuffered (see NewClient). When no login is waiting — it already timed
// out, or the error event was unsolicited — a plain send would block the
// read loop forever. Instead, give up after loginTimeout and log the
// dropped result.
func (c *Client) deliverLoginResult(err error) {
	timer := time.NewTimer(loginTimeout)
	defer timer.Stop()
	select {
	case c.loginCh <- err:
	case <-timer.C:
		log.Println("login result dropped, no pending login:", err)
	}
}
// handleTableMessage decodes the "data" payload of a table message into the
// typed slice for the known futures channels. It then passes that slice, the
// table name and the optional action to the user handler. It does nothing
// when no handler is set. The message is logged and dropped when:
//   - "table" is missing or not a string (previously a panic);
//   - the table is not one of the handled channels;
//   - the payload does not decode, to avoid delivering partially filled rows.
func (c *Client) handleTableMessage(messageMap map[string]interface{}) {
	if c.handler == nil {
		return
	}
	table, ok := messageMap["table"].(string)
	if !ok {
		log.Printf("table message without a string table name: %v", messageMap["table"])
		return
	}
	// A missing or non-string action yields "" rather than panicking.
	action, _ := messageMap["action"].(string)
	// Re-encode just the "data" part so it can be decoded into the concrete
	// row type. messageMap is expected to come from json.Unmarshal (see
	// reader), so its values are JSON-native and Marshal cannot fail.
	data, _ := json.Marshal(messageMap["data"])
	decode := func(target interface{}) bool {
		if err := json.Unmarshal(data, target); err != nil {
			log.Printf("decode table %v data failed: %v", table, err)
			return false
		}
		return true
	}
	tableMessage := &TableMessage{
		Action: action,
		Table:  table,
	}
	switch table {
	case ChnlFuturesPriceRange:
		var priceList []*future.TablePriceRange
		if !decode(&priceList) {
			return
		}
		tableMessage.Data = priceList
	case ChnlFuturesDepth5:
		var depthList []*future.TableDepth5
		if !decode(&depthList) {
			return
		}
		tableMessage.Data = depthList
	case ChnlFuturesPosition:
		var positionList []*future.TablePosition
		if !decode(&positionList) {
			return
		}
		tableMessage.Data = positionList
	case ChnlFuturesOrder:
		var orderList []*future.TableOrder
		if !decode(&orderList) {
			return
		}
		tableMessage.Data = orderList
	default:
		log.Printf("table %v not handled", table)
		return
	}
	c.handler(tableMessage)
}
// sendCommand serialises {"op": command, "args": args} and writes it with
// sendMessage. The Marshal error is discarded: a map holding a string and a
// []string always encodes.
func (c *Client) sendCommand(command string, args []string) error {
	payload, _ := json.Marshal(map[string]interface{}{
		"args": args,
		"op":   command,
	})
	return c.sendMessage(payload)
}
// sendMessage writes message as one text frame on the current connection.
// writeMutex serialises concurrent writers (ping runs on a timer goroutine,
// commands come from callers). Each write gets a fresh writeTimeout
// deadline, and the SetWriteDeadline error is discarded.
func (c *Client) sendMessage(message []byte) error {
	c.writeMutex.Lock()
	defer c.writeMutex.Unlock()
	conn := c.conn
	deadline := time.Now().Add(writeTimeout)
	_ = conn.SetWriteDeadline(deadline)
	return conn.WriteMessage(websocket.TextMessage, message)
}
// decompress inflates a raw DEFLATE payload (no zlib or gzip header), as
// used for binary frames. It returns whatever was inflated together with
// any decoding error.
func (c *Client) decompress(message []byte) ([]byte, error) {
	inflater := flate.NewReader(bytes.NewReader(message))
	out, err := io.ReadAll(inflater)
	_ = inflater.Close()
	return out, err
}
// getTopics returns the recorded topic names in unspecified (map iteration)
// order; Connect sends them in one subscribe command. The slice was
// previously created with length len(c.topics) and then appended to. That
// produced len(c.topics) leading empty strings, which were sent to the
// server as bogus topics; it is now pre-sized by capacity only.
func (c *Client) getTopics() []string {
	topics := make([]string, 0, len(c.topics))
	for topic := range c.topics {
		topics = append(topics, topic)
	}
	return topics
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。