1 Star 0 Fork 0

simple/simple.io

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
relay.go 4.11 KB
一键复制 编辑 原始数据 按行查看 历史
simple 提交于 2024-11-24 23:50 +08:00 . perf(optimize):
package socks
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"gitee.com/simple-set/simple.io/src/process/codec"
"net"
)
// UdpRelay UDP中继对象
type UdpRelay struct {
rsv [2]byte
frag byte
AddType AddrType // 绑定地址类型
DstAddr []byte // 目标地址
DstPort uint16 // 目标端口
serBytes []byte // 序列化字节
payload []byte // 有效载荷
buf *codec.ByteBuf // io缓冲区
}
func NewUdpRelay(buf *codec.ByteBuf) *UdpRelay {
return &UdpRelay{rsv: [2]byte{0, 0}, frag: 0, buf: buf}
}
func (u *UdpRelay) SetAddr(addr *net.UDPAddr) {
if addr != nil {
u.DstAddr = addr.IP
u.DstPort = uint16(addr.Port)
}
}
func (u *UdpRelay) GetDstAddr() *net.UDPAddr {
if u.AddType != Domain {
return &net.UDPAddr{IP: u.DstAddr, Port: int(u.DstPort)}
}
return nil
}
func (u *UdpRelay) GetDstAddrString() string {
if u.AddType != Domain {
return u.GetDstAddr().String()
}
return fmt.Sprintf("%s:%d", u.DstAddr, u.DstPort)
}
// Buf 获取io缓冲区
func (u *UdpRelay) Buf() *codec.ByteBuf {
return u.buf
}
// SetPayload 设置载荷
func (u *UdpRelay) SetPayload(payload []byte) {
u.payload = payload
}
// Payload 获取中继对象的有效载荷
func (u *UdpRelay) Payload() ([]byte, error) {
if u.payload != nil {
return u.payload, nil
}
if _, err := u.Serialize(); err != nil {
return nil, err
}
return u.payload, nil
}
// Encode 编码, 序列化并写入io缓冲区
func (u *UdpRelay) Encode() error {
if serialize, err := u.Serialize(); err != nil {
return err
} else {
if _, err := u.buf.Write(serialize); err != nil {
return err
} else {
return u.buf.Flush()
}
}
}
// Decode 解码, 从io缓冲区读取并解析
func (u *UdpRelay) Decode() error {
return u.DecodeBuff(u.buf)
}
// DecodeBuff 解码, 从指定缓冲区解析
func (u *UdpRelay) DecodeBuff(buffer *codec.ByteBuf) error {
if rsv, err := buffer.ReadBytes(2); err != nil {
return err
} else {
if !bytes.Equal(rsv, []byte{0, 0}) {
return errors.New("invalid rsv: " + string(rsv))
}
u.rsv = [2]byte{rsv[0], rsv[1]}
}
// 解析分段标志位
if readByte, err := buffer.ReadByte(); err != nil {
return err
} else {
u.frag = readByte
}
if readByte, err := buffer.ReadByte(); err != nil {
return err
} else {
u.AddType = AddrType(readByte)
}
// 解析目标地址
if u.AddType != IpV4 && u.AddType != Ipv6 && u.AddType != Domain {
return errors.New(fmt.Sprint("AddrType target type: ", u.AddType))
}
if u.AddType == IpV4 {
u.DstAddr = make([]byte, 4)
if n, err := buffer.Read(u.DstAddr); err != nil {
return err
} else if n != 4 {
return errors.New(fmt.Sprint("DstAddr size error: ", n))
}
} else if u.AddType == Ipv6 {
u.DstAddr = make([]byte, 16)
if n, err := buffer.Read(u.DstAddr); err != nil {
return err
} else if n != 16 {
return errors.New(fmt.Sprint("DstAddr size error: ", n))
}
} else if u.AddType == Domain {
domainLen, err := buffer.ReadByte()
if err != nil {
return err
}
if u.DstAddr, err = buffer.ReadBytes(int(domainLen)); err != nil {
return err
}
}
// 解析端口
if port, err := u.buf.ReadBytes(2); err != nil {
return err
} else {
u.DstPort = binary.BigEndian.Uint16(port[:2])
}
return nil
}
// Serialize 序列化, 把中继对象和有效载荷转换为字节数组
func (u *UdpRelay) Serialize() ([]byte, error) {
if u.serBytes != nil {
return u.serBytes, nil
}
// 标志符
u.serBytes = append(u.serBytes, u.rsv[:]...)
// 分段符
u.serBytes = append(u.serBytes, u.frag)
// 地址类型
u.serBytes = append(u.serBytes, byte(u.AddType))
if u.AddType == Domain {
// 如果目标地址是域名类型,则第一个字节表示域名长度
u.serBytes = append(u.serBytes, byte(len(u.DstAddr)))
}
// 目标地址
u.serBytes = append(u.serBytes, u.DstAddr...)
// 目标端口
u.serBytes = append(u.serBytes, codec.Int16ToBytes(u.DstPort)...)
if u.payload == nil {
// 获取有效载荷
if readAll, err := u.buf.ReadAll(); err != nil {
return nil, err
} else {
u.payload = readAll
}
}
u.serBytes = append(u.serBytes, u.payload...)
return u.serBytes, nil
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/simple-set/simple.io.git
git@gitee.com:simple-set/simple.io.git
simple-set
simple.io
simple.io
v1.6.5

搜索帮助