Ai
1 Star 0 Fork 0

hexug/goChainRestfulClient

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
request.go 16.05 KB
一键复制 编辑 原始数据 按行查看 历史
hexug 提交于 2024-10-14 17:39 +08:00 . 为http客户端添加自定义DNS解析功能
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
package rest
import (
"archive/tar"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"mime/multipart"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
"os"
"path"
"reflect"
"strconv"
"strings"
"time"
"gitee.com/hexug/go-chain-restful-client/negotiator"
"github.com/bytedance/sonic"
)
// Method sets the HTTP method that Do uses (via r.method.String()) and logs
// the choice at debug level. It returns r so calls can be chained.
func (r *Request) Method(method Method) *Request {
	r.log.L().Debugf("设置请求方法为:%s", method)
	r.method = method
	return r
}
// Prefix appends segments to the request's prefix path. Successive calls
// accumulate: each call's segments are joined after the existing prefix.
// url() places the prefix in front of the request path unless AbsURL marked
// the request as absolute, in which case the prefix is ignored.
// It is a no-op once the builder has recorded an error.
func (r *Request) Prefix(segments ...string) *Request {
	if r.err == nil {
		extra := path.Join(segments...)
		r.prePath = path.Join(r.prePath, extra)
	}
	return r
}
// Suffix appends segments to the request's suffix path. Successive calls
// accumulate. url() appends the suffix after the request path unless AbsURL
// marked the request as absolute.
// It is a no-op once the builder has recorded an error.
func (r *Request) Suffix(segments ...string) *Request {
	if r.err == nil {
		extra := path.Join(segments...)
		r.subPath = path.Join(r.subPath, extra)
	}
	return r
}
// URL sets the request path, which url() later joins onto the base path.
// p is parsed and stored in its re-serialized form. A parse error is logged
// and recorded in r.err, and the previous path is left in place.
func (r *Request) URL(p string) *Request {
	parsed, err := url.Parse(p)
	if err == nil {
		r.reqPath = parsed.String()
		return r
	}
	r.err = err
	r.log.L().Errorw(err.Error())
	return r
}
// AbsURL behaves like URL but also marks the request as absolute. url() then
// skips the prefix and suffix paths and joins p directly onto the base path.
func (r *Request) AbsURL(p string) *Request {
	r.isAbs = true
	return r.URL(p)
}
// url assembles the final request URL string.
//
// The request path comes from reqPath (set by URL/AbsURL). Unless the
// request is absolute (AbsURL), prePath is joined in front of it and subPath
// after it. The resulting path is then joined onto basePath.
//
// Query parameters are merged in this order: basePath's query, then
// reqPath's query, then the values added via Param. A key added via Param
// replaces any key of the same name coming from the URLs.
//
// On any parse/join failure the error is stored in r.err and "" is returned.
func (r *Request) url() string {
	u, err := url.Parse(r.reqPath)
	if err != nil {
		r.err = err
		return ""
	}
	// Relative requests get the prefix/suffix paths around the request path.
	if !r.isAbs {
		if r.prePath != "" {
			u.Path = path.Join(r.prePath, u.Path)
		}
		if r.subPath != "" {
			u.Path = path.Join(u.Path, r.subPath)
		}
	}
	joined, err := url.JoinPath(r.basePath, u.Path)
	if err != nil {
		r.err = err
		return ""
	}
	// Only u.Path survives JoinPath, so reqPath's query used to be dropped
	// silently. Collect it here and merge it back below.
	reqQuery := u.Query()
	if len(reqQuery) == 0 && len(r.params) == 0 {
		// Nothing to merge: keep basePath's query exactly as written.
		return joined
	}
	full, err := url.Parse(joined)
	if err != nil {
		r.err = err
		return ""
	}
	merged := full.Query()
	for k, vs := range reqQuery {
		merged[k] = append(merged[k], vs...)
	}
	for k, vs := range r.params {
		merged[k] = vs
	}
	full.RawQuery = merged.Encode()
	return full.String()
}
// Timeout records a per-request timeout in r.timeout. It is ignored once the
// builder has recorded an error. The original author's intent is for it to
// override the client-level timeout.
// NOTE(review): Do in this file never reads r.timeout, so the value appears
// to be unused. Confirm whether it is applied elsewhere.
func (r *Request) Timeout(d time.Duration) *Request {
	if r.err == nil {
		r.timeout = d
	}
	return r
}
// Header replaces every value of the request header key with values. Calling
// it with no values removes the header. Do copies these headers onto the
// outgoing request.
//
// An empty key is logged and ignored. It is not recorded in r.err, and the
// call is best-effort.
func (r *Request) Header(key string, values ...string) *Request {
	if key == "" {
		r.log.L().Errorw("设置的header不合法")
		return r
	}
	if r.headers == nil {
		r.headers = http.Header{}
	}
	// Report the header that is actually being replaced. This previously
	// always read Content-Type, whatever key was passed.
	if old := r.headers.Get(key); old != "" {
		r.log.L().Debugf("原有 %s 为 %s,已替换为新值", key, old)
	}
	r.headers.Del(key)
	for _, value := range values {
		r.headers.Add(key, value)
		r.log.L().Debugf("添加header:%s = %s", key, value)
	}
	return r
}
// Cookie queues cookies that Do attaches to the outgoing request with
// AddCookie. It is a no-op once the builder has recorded an error.
func (r *Request) Cookie(cs ...*http.Cookie) *Request {
	if r.err == nil {
		if r.cookies == nil {
			r.cookies = []*http.Cookie{}
		}
		r.cookies = append(r.cookies, cs...)
		r.log.L().Debugf("已经设置了cookies %v", cs)
	}
	return r
}
// Param appends value to the query parameter paramName. Repeated calls with
// the same name add further values. url() encodes these parameters into the
// query string. It is a no-op once the builder has recorded an error.
func (r *Request) Param(paramName, value string) *Request {
	if r.err != nil {
		return r
	}
	if r.params == nil {
		r.params = url.Values{}
	}
	r.params.Add(paramName, value)
	return r
}
// MultiFormData adds fields that Do encodes as multipart/form-data. This is
// the method to use when fields are sent together with files from
// UploadFiles. Values are appended to any values already set for key.
//
// If FormData was also used, Do sends the url-encoded form instead, and these
// fields are not sent. An empty key is logged and ignored.
func (r *Request) MultiFormData(key string, values ...string) *Request {
	if key == "" {
		r.log.L().Errorw("提交的表单不能为空")
		return r
	}
	if r.multiFormData == nil {
		r.multiFormData = make(url.Values)
	}
	for _, value := range values {
		r.multiFormData.Add(key, value)
		r.log.L().Debugf("添加表单字段:%s = %s", key, value)
	}
	return r
}
// UploadFiles registers the file at filePath to be sent by Do as a multipart
// file part under the form field fileKey. Only one path is kept per key, so a
// later call with the same key replaces the earlier path. The file is opened
// only when Do runs.
//
// An empty key or path is logged and ignored.
func (r *Request) UploadFiles(fileKey string, filePath string) *Request {
	if fileKey == "" || filePath == "" {
		// Either argument may be the empty one; the old message blamed the key only.
		r.log.L().Errorw("文件的key和路径都不能为空")
		return r
	}
	if r.files == nil {
		r.files = make(map[string]string)
	}
	r.files[fileKey] = filePath
	r.log.L().Debugf("已经设置文件表单数据:%s = %s", fileKey, filePath)
	return r
}
// FormData adds a url-encoded form field. When any such fields are set, Do
// sends them as application/x-www-form-urlencoded. This takes precedence over
// multipart data and over any body set by other builders. An empty key is
// logged and ignored.
func (r *Request) FormData(key, value string) *Request {
	if key != "" {
		if r.formData == nil {
			r.formData = url.Values{}
		}
		r.formData.Add(key, value)
		r.log.L().Debugf("已经设置表单数据:%s=%s", key, value)
		return r
	}
	r.log.L().Errorw("提交的表单不能为空")
	return r
}
// DataJs marshals data to JSON with sonic and stores the bytes in r.dataJs.
// It uses them as the request body and sets Content-Type to application/json.
// A marshal error is logged and recorded in r.err. Unlike most builders, it
// runs even when r.err is already set.
func (r *Request) DataJs(data any) *Request {
	payload, err := sonic.Marshal(data)
	if err == nil {
		r.dataJs = payload
		r.Header(CONTENT_TYPE_HEADER, "application/json")
		r.body = bytes.NewReader(payload)
		r.log.L().Debugf("已经提交JSON数据:%+v", data)
		return r
	}
	r.err = err
	r.log.L().Errorf("数据序列化失败,%s", err.Error())
	return r
}
// RowText sends t as a text/plain request body and also keeps it in
// r.rowText. An empty string is logged and ignored.
func (r *Request) RowText(t string) *Request {
	if t != "" {
		r.rowText = t
		r.Header(CONTENT_TYPE_HEADER, "text/plain")
		r.body = strings.NewReader(t)
		return r
	}
	r.log.L().Errorw("提交的文本数据无效")
	return r
}
// TarFiles packs folderPath into an in-memory tar archive, honoring the
// allow and deny lists. It uses the archive as the request body with
// Content-Type application/x-tar. It is intended for sending a build context
// to the Docker remote API.
//
// Failures are logged and recorded in r.err. It is a no-op once the builder
// has already recorded an error.
func (r *Request) TarFiles(
	folderPath string,     // base directory; for Docker, the directory holding the Dockerfile, which must sit at the top level
	fileWhitelist,         // file whitelist
	dirWhitelist,          // directory whitelist
	fileBlacklist,         // file blacklist
	dirBlacklist []string, // directory blacklist
) *Request {
	if r.err != nil {
		return r
	}
	buffer := new(bytes.Buffer)
	tarWriter := tar.NewWriter(buffer)
	if err := tarFolderToReader(
		tarWriter,
		folderPath,
		false, // NOTE(review): meaning defined by tarFolderToReader (not shown here)
		fileWhitelist,
		dirWhitelist,
		fileBlacklist,
		dirBlacklist,
	); err != nil {
		r.log.L().Errorln("创建tar归档失败:", err)
		r.err = err
		return r
	}
	// Close flushes the padding of the last entry and writes the two zero
	// blocks that terminate a tar stream; without it the archive may be
	// truncated. Close on an already-closed tar.Writer returns nil, so this is
	// safe whether or not tarFolderToReader closes the writer itself.
	if err := tarWriter.Close(); err != nil {
		r.log.L().Errorln("创建tar归档失败:", err)
		r.err = err
		return r
	}
	r.Header(CONTENT_TYPE_HEADER, "application/x-tar")
	r.body = buffer
	return r
}
// Body encodes v with the negotiator chosen from the request's current
// Content-Type header and uses the encoded bytes as the request body. The
// header is passed through HeaderFilterFlags first, presumably to strip
// parameters such as charset; TODO confirm. Body does not set Content-Type
// itself, so set it beforehand with Header.
// An encoding error is recorded in r.err. It is a no-op once the builder has
// recorded an error.
// TODO: needs improvement (carried over from the original author).
func (r *Request) Body(v any) *Request {
	if r.err != nil {
		return r
	}
	contentType := r.headers.Get(CONTENT_TYPE_HEADER)
	nt := negotiator.GetNegotiator(HeaderFilterFlags(contentType))
	encoded, err := nt.Encode(v)
	if err == nil {
		r.body = bytes.NewReader(encoded)
		return r
	}
	r.err = err
	return r
}
// Do assembles the HTTP request from the builder state, sends it through the
// client's *http.Client, and wraps the outcome in a *Response.
//
// Body selection works as follows:
//   - url-encoded form data (FormData) wins over multipart data
//     (MultiFormData/UploadFiles);
//   - otherwise r.body, as set by DataJs/RowText/TarFiles/Body, is sent
//     unchanged.
//
// Failures never panic or exit the process. Any of the following is stored in
// the returned Response's err field:
//   - an error already recorded by a builder method;
//   - an error while assembling the body or building the URL or request;
//   - an error while configuring the proxy or sending the request.
//
// NOTE(review): r.timeout (set by Timeout) is not read here. Confirm where,
// if anywhere, it is applied.
func (r *Request) Do(ctx context.Context) *Response {
	// The response object returned to the caller.
	resp := NewResponse(r.c)

	// A builder step (URL parsing, JSON marshalling, tar creation, TLS setup,
	// and so on) already failed. Sending a half-built request would hide that
	// error.
	if r.err != nil {
		resp.err = r.err
		return resp
	}

	if r.formData != nil {
		r.Header(CONTENT_TYPE_HEADER, "application/x-www-form-urlencoded")
		r.body = strings.NewReader(r.formData.Encode())
	} else if r.multiFormData != nil || r.files != nil {
		body := new(bytes.Buffer)
		writer := multipart.NewWriter(body)
		// Ranging over a nil map is a no-op, so no nil checks are needed.
		for key, values := range r.multiFormData {
			for _, value := range values {
				if err := writer.WriteField(key, value); err != nil {
					r.log.L().Errorln("写入表单字段失败:", err)
					resp.err = err
					return resp
				}
			}
		}
		for key, file := range r.files {
			// The closure scopes the deferred Close to a single file, so each
			// file is closed right after it is copied rather than when Do returns.
			err := func() error {
				f, err := os.Open(file)
				if err != nil {
					r.log.L().Errorf("文件 %s 无法打开:%s", file, err.Error())
					return err
				}
				defer f.Close()
				// NOTE(review): the full local path is sent as the part's
				// filename; servers usually expect only a base name. Confirm.
				part, err := writer.CreateFormFile(key, file)
				if err != nil {
					r.log.L().Errorln("创建文件表单字段失败:", err)
					return err
				}
				if _, err := io.Copy(part, f); err != nil {
					r.log.L().Errorf("拷贝 %s 数据失败: %s", file, err.Error())
					return err
				}
				return nil
			}()
			if err != nil {
				resp.err = err
				return resp
			}
		}
		if err := writer.Close(); err != nil {
			r.log.L().Errorln("关闭数据写入器失败:", err)
			resp.err = err
			return resp
		}
		r.body = body
		r.Header(CONTENT_TYPE_HEADER, writer.FormDataContentType())
	}

	// Build the request. url() reports failures through r.err and returns "".
	reqURL := r.url()
	if r.err != nil {
		resp.err = r.err
		r.log.L().Errorw(resp.err.Error())
		return resp
	}
	req, err := http.NewRequestWithContext(ctx, r.method.String(), reqURL, r.body)
	if err != nil {
		resp.err = err
		r.log.L().Errorw(resp.err.Error())
		return resp
	}
	// url() has already encoded r.params into reqURL. Overwriting
	// req.URL.RawQuery here would wipe any query carried by the base URL.

	// Copy headers. Use Add, not Set: Set kept only the last value of
	// multi-valued headers.
	for k, vs := range r.headers {
		for _, v := range vs {
			req.Header.Add(k, v)
		}
	}
	// Authentication.
	r.buildAuth(req)
	// Cookies.
	for _, c := range r.cookies {
		req.AddCookie(c)
	}

	// Proxy. This mutates the client's transport, so the proxy stays in
	// effect for later requests made with the same client.
	if r.proxyAdd != "" {
		proxyURL := &url.URL{Scheme: "http", Host: r.proxyAdd}
		if r.proxyScheme != "" {
			proxyURL.Scheme = r.proxyScheme
		}
		if r.proxyPort != "" {
			proxyURL.Host = r.proxyAdd + ":" + r.proxyPort
		}
		r.log.L().Debugf("代理地址为:%s", proxyURL.String())
		tv := reflect.ValueOf(r.c.client.Transport)
		tsp, isHTTPTransport := r.c.client.Transport.(*http.Transport)
		switch {
		case isHTTPTransport && tsp != nil:
			tsp.MaxIdleConns = 10
			tsp.MaxConnsPerHost = 10
			tsp.IdleConnTimeout = 10 * time.Second
			tsp.Proxy = http.ProxyURL(proxyURL)
		case !tv.IsValid() || (tv.Kind() == reflect.Pointer && tv.IsNil()):
			// No usable transport yet (nil or typed-nil): install a fresh one.
			r.c.client.Transport = &http.Transport{
				MaxIdleConns:    10,
				MaxConnsPerHost: 10,
				IdleConnTimeout: 10 * time.Second,
				Proxy:           http.ProxyURL(proxyURL),
			}
		default:
			// A custom RoundTripper has no Proxy field to set. This used to
			// panic on the type assertion.
			err := fmt.Errorf("proxy requires *http.Transport, got %T", r.c.client.Transport)
			r.log.L().Errorln(err.Error())
			resp.err = err
			return resp
		}
	}

	// Debug output.
	r.debug(req)
	// Send the request.
	raw, err := r.c.client.Do(req)
	if err != nil {
		switch msg := err.Error(); {
		case strings.HasSuffix(msg, "certificate signed by unknown authority"):
			r.log.L().Errorln("证书不安全,需要忽略证书,才能继续;忽略证书,请使用SetIgnoreCert()方法!")
		case strings.HasSuffix(msg, "tls: server selected unsupported protocol version 301"):
			r.log.L().Errorln("tls协议不支持,请降低tls版本;要降低tls版本,请使用SetMinVersionTLS()方法!")
		default:
			r.log.L().Errorln(msg)
		}
		resp.err = err
		return resp
	}
	// Populate the response.
	resp.withStatusCode(raw.StatusCode)
	resp.withHeader(raw.Header)
	resp.withBody(raw.Body)
	resp.withStatus(raw.Status)
	resp.withCookies(raw.Cookies())
	resp.withUrl(raw.Request.URL)
	return resp
}
// buildAuth adds credentials to req according to r.authType. BasicAuth uses
// r.user's username and password. BearerToken sets AUTHORIZATION_HEADER to
// "Bearer <token>". Any other auth type adds nothing.
func (r *Request) buildAuth(req *http.Request) {
	if r.authType == BasicAuth {
		req.SetBasicAuth(r.user.Username, r.user.Password)
	} else if r.authType == BearerToken {
		req.Header.Set(AUTHORIZATION_HEADER, "Bearer "+r.token)
	}
}
// debug logs the request line ("[METHOD] URL") and every header at debug
// level. Multi-valued headers are joined with commas.
// NOTE(review): this includes credentials such as Authorization and Cookie
// headers verbatim.
func (r *Request) debug(req *http.Request) {
	logger := r.log.L()
	logger.Debugf("[%s] %s", req.Method, req.URL.String())
	logger.Debugf("请求头:")
	for name, values := range req.Header {
		logger.Debugf("\t%s=%s", name, strings.Join(values, ","))
	}
}
// SetNoAutoRedirect configures the client so that 3xx responses are returned
// as-is instead of being followed. The change applies to the shared client
// and replaces any previous CheckRedirect. It is a no-op once the builder has
// recorded an error.
func (r *Request) SetNoAutoRedirect() *Request {
	if r.err != nil {
		return r
	}
	noFollow := func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}
	r.c.client.CheckRedirect = noFollow
	r.log.L().Debugln("设置禁用自动跳转")
	return r
}
// SetIgnoreCert disables TLS certificate verification on the client's
// transport. The setting is InsecureSkipVerify, and it affects every later
// request made with this client. This is insecure: it exposes traffic to
// man-in-the-middle attacks, so use it only for testing or self-signed servers.
//
// If the client has no transport (nil or typed-nil), a new *http.Transport is
// installed. A transport that is not an *http.Transport cannot be configured;
// that case is logged and recorded in r.err (the reflection-based version
// panicked there).
func (r *Request) SetIgnoreCert() *Request {
	tv := reflect.ValueOf(r.c.client.Transport)
	tsp, isHTTPTransport := r.c.client.Transport.(*http.Transport)
	switch {
	case isHTTPTransport && tsp != nil:
		if tsp.TLSClientConfig == nil {
			tsp.TLSClientConfig = &tls.Config{}
		}
		tsp.TLSClientConfig.InsecureSkipVerify = true
	case !tv.IsValid() || (tv.Kind() == reflect.Pointer && tv.IsNil()):
		r.c.client.Transport = &http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		}
	default:
		r.err = fmt.Errorf("SetIgnoreCert requires *http.Transport, got %T", r.c.client.Transport)
		r.log.L().Errorln(r.err.Error())
		return r
	}
	r.log.L().Debugln("设置忽略证书")
	return r
}
// SetMinVersionTLS sets the minimum TLS version (for example
// tls.VersionTLS10) that the client's transport accepts. It affects every
// later request made with this client.
//
// If the client has no transport (nil or typed-nil), a new *http.Transport is
// installed. A transport that is not an *http.Transport cannot be configured;
// that case is logged and recorded in r.err (the reflection-based version
// panicked there).
func (r *Request) SetMinVersionTLS(tlsVersion uint16) *Request {
	tv := reflect.ValueOf(r.c.client.Transport)
	tsp, isHTTPTransport := r.c.client.Transport.(*http.Transport)
	switch {
	case isHTTPTransport && tsp != nil:
		if tsp.TLSClientConfig == nil {
			tsp.TLSClientConfig = &tls.Config{}
		}
		tsp.TLSClientConfig.MinVersion = tlsVersion
	case !tv.IsValid() || (tv.Kind() == reflect.Pointer && tv.IsNil()):
		r.c.client.Transport = &http.Transport{
			TLSClientConfig: &tls.Config{MinVersion: tlsVersion},
		}
	default:
		r.err = fmt.Errorf("SetMinVersionTLS requires *http.Transport, got %T", r.c.client.Transport)
		r.log.L().Errorln(r.err.Error())
		return r
	}
	r.log.L().Debugln("已经修改为低版本为:", tlsVersion)
	return r
}
// 设置证书
func (r *Request) SetCert(caPath, clientCrtPath, clientKeyPath string) *Request {
cert, err := tls.LoadX509KeyPair(clientCrtPath, clientKeyPath)
if err != nil {
r.log.L().Errorln("未能加载客户端证书:", err.Error())
os.Exit(1)
}
caCert, err := os.ReadFile("ca.crt")
if err != nil {
r.log.L().Errorln("无法读取CA证书:", err.Error())
os.Exit(1)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
transport := reflect.ValueOf(r.c.client.Transport)
if transport.IsValid() && !transport.IsNil() {
field := transport.Elem().FieldByName("TLSClientConfig")
if field.IsValid() && !field.IsNil() {
tlsConfig, ok := field.Interface().(*tls.Config)
if ok {
tlsConfig.Certificates = []tls.Certificate{cert}
tlsConfig.RootCAs = caCertPool
}
} else {
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
}
field.Set(reflect.ValueOf(tlsConfig))
}
} else {
tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
},
}
r.c.client.Transport = tr
}
r.log.L().Debugln("已经添加证书:", caPath, clientCrtPath, clientKeyPath)
return r
}
// SetCarryCookies installs a fresh in-memory cookie jar on the client. This
// replaces any existing jar. Cookies set by responses are then sent on later
// requests made with that client. An error from cookiejar.New is logged.
func (r *Request) SetCarryCookies() *Request {
	jar, err := cookiejar.New(nil)
	r.c.client.Jar = jar
	if err != nil {
		r.log.L().Errorln(err.Error())
	}
	return r
}
// SetCarryQueryParameters stores severity on the client and routes redirect
// handling through r.c.retryOnRedirect. According to the original author,
// this makes the request's query parameters carry over when a redirect is
// followed, with severity controlling the degree. The actual behavior lives
// in retryOnRedirect, which is not shown here; TODO confirm.
// NOTE(review): this replaces any CheckRedirect set earlier, for example by
// SetNoAutoRedirect. severity is assigned before the method value is taken;
// keep that order in case retryOnRedirect has a value receiver.
func (r *Request) SetCarryQueryParameters(severity Severity) *Request {
	r.c.severity = severity
	r.c.client.CheckRedirect = r.c.retryOnRedirect
	return r
}
// SetProxyUrl configures the proxy that Do routes the request through.
// Accepted forms include "http://127.0.0.1:7890" and a bare "127.0.0.1:7890".
//
// Surrounding whitespace and trailing slashes are removed. "http://" is
// prepended when u does not start with http:// or https:// (case-insensitive).
// The parsed scheme and host[:port] are stored as the proxy scheme and
// address.
//
// An unparsable URL, or one without a host, is logged and recorded in r.err.
// This used to terminate the whole process with os.Exit.
func (r *Request) SetProxyUrl(u string) *Request {
	u1 := strings.TrimRight(strings.TrimSpace(u), "/")
	lower := strings.ToLower(u1)
	if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
		// Default to plain HTTP when no scheme is given.
		u1 = "http://" + u1
	}
	parsedURL, err := url.Parse(u1)
	if err != nil {
		r.log.L().Errorw("url转化失败", "error", err)
		r.err = err
		return r
	}
	// The prefix check above guarantees an http or https scheme (url.Parse
	// lowercases it), so the remaining failure mode is a missing host.
	if parsedURL.Host == "" {
		r.log.L().Errorw("无效URL:", "url", u1)
		r.err = fmt.Errorf("invalid proxy url %q: missing host", u1)
		return r
	}
	r.proxyScheme = parsedURL.Scheme
	r.proxyAdd = parsedURL.Host
	return r
}
// SetProxyPort sets the proxy port. Do appends it to the proxy host as
// host:port.
func (r *Request) SetProxyPort(port int) *Request {
	portStr := strconv.Itoa(port)
	r.proxyPort = portStr
	return r
}
// SetProxyScheme sets the scheme used to reach the proxy. Do falls back to
// "http" when it is empty.
func (r *Request) SetProxyScheme(scheme string) *Request {
	req := r
	req.proxyScheme = scheme
	return req
}
// SetProxyAdd sets the proxy host. A non-empty value is what makes Do
// configure a proxy at all.
func (r *Request) SetProxyAdd(add string) *Request {
	req := r
	req.proxyAdd = add
	return req
}
// CustomDialContext returns a dial function suitable for
// http.Transport.DialContext. It resolves host names through the DNS server
// at dnsResolver instead of the system resolver. It then dials the resolved
// addresses in order, IPv4 first and then IPv6, until one connects.
//
// dnsResolver is "host:port". A bare host or IP without a port defaults to
// port 53. Both lookups and connection attempts honor ctx cancellation and
// deadlines, and each connection attempt is capped at 5 seconds. When
// nothing can be resolved or connected, the returned error wraps the
// underlying lookup or dial error.
func CustomDialContext(dnsResolver string) func(ctx context.Context, network, address string) (net.Conn, error) {
	// Accept "8.8.8.8" and "2001:4860::8888" as well as "8.8.8.8:53".
	if dnsResolver != "" {
		if _, _, err := net.SplitHostPort(dnsResolver); err != nil {
			dnsResolver = net.JoinHostPort(strings.Trim(dnsResolver, "[]"), "53")
		}
	}
	resolver := &net.Resolver{
		// Use Go's built-in resolver so that the custom Dial below is used.
		PreferGo: true,
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			d := net.Dialer{
				Timeout:   5 * time.Second,
				KeepAlive: 5 * time.Second,
			}
			// Use the network the resolver asks for. It switches to "tcp" for
			// truncated UDP answers, and a hard-coded "udp" broke that fallback.
			return d.DialContext(ctx, network, dnsResolver)
		},
	}
	dialer := &net.Dialer{Timeout: 5 * time.Second}
	return func(ctx context.Context, network, address string) (net.Conn, error) {
		host, port, err := net.SplitHostPort(address)
		if err != nil {
			return nil, err
		}
		var ips []net.IP
		var lookupErr error
		for _, family := range [...]string{"ip4", "ip6"} {
			found, err := resolver.LookupIP(ctx, family, host)
			if err != nil {
				if lookupErr == nil {
					lookupErr = err
				}
				continue
			}
			ips = append(ips, found...)
		}
		if len(ips) == 0 {
			if lookupErr != nil {
				return nil, fmt.Errorf("unable to find IP addresses for %s: %w", host, lookupErr)
			}
			return nil, fmt.Errorf("unable to find IP addresses for %s", host)
		}
		var dialErr error
		for _, ip := range ips {
			// DialContext, unlike net.DialTimeout, stops when ctx is cancelled.
			conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
			if err == nil {
				return conn, nil
			}
			dialErr = err
			if ctx.Err() != nil {
				// The caller has given up; trying more addresses is pointless.
				break
			}
		}
		return nil, fmt.Errorf("unable to connect to %s: %w", address, dialErr)
	}
}
// SetDNS makes the client's transport resolve host names through the DNS
// server at dnsResolver (see CustomDialContext) by replacing its DialContext.
// The change applies to every later request made with this client.
//
// If the client has no transport (nil or typed-nil), a new *http.Transport is
// installed. A transport that is not an *http.Transport cannot be configured;
// that case is logged and recorded in r.err (the reflection-based version
// panicked there).
func (r *Request) SetDNS(dnsResolver string) *Request {
	dial := CustomDialContext(dnsResolver)
	tv := reflect.ValueOf(r.c.client.Transport)
	tsp, isHTTPTransport := r.c.client.Transport.(*http.Transport)
	switch {
	case isHTTPTransport && tsp != nil:
		tsp.DialContext = dial
	case !tv.IsValid() || (tv.Kind() == reflect.Pointer && tv.IsNil()):
		r.c.client.Transport = &http.Transport{DialContext: dial}
	default:
		r.err = fmt.Errorf("SetDNS requires *http.Transport, got %T", r.c.client.Transport)
		r.log.L().Errorln(r.err.Error())
		return r
	}
	r.log.L().Debugln("设置DNS Server")
	return r
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/hexug/go-chain-restful-client.git
git@gitee.com:hexug/go-chain-restful-client.git
hexug
go-chain-restful-client
goChainRestfulClient
v0.5.3

搜索帮助