1 Star 0 Fork 0

leminewx / leego

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
context.go 8.89 KB
一键复制 编辑 原始数据 按行查看 历史
leminewx 提交于 2023-09-09 14:39 . 优化处理Goway的上下文代码
package leego
import (
"encoding/json"
"fmt"
"html/template"
"io"
"net/http"
"os"
"strings"
"gitee.com/leminewx/leego/global/logging"
)
type (
	// Handler is the function type stored in a Context's middleware chain;
	// each handler receives the per-request Context.
	Handler func(*Context)

	// Json is a shorthand for a simple JSON object, e.g. Json{"ok": true}.
	Json map[string]any
)
// Context carries the state of a single HTTP request/response exchange while
// it flows through the middleware chain.
type Context struct {
	// Request information
	Path    string            // request URL path (request.URL.Path)
	Method  string            // request method
	Headers http.Header       // request headers
	Params  map[string]string // dynamic route parameters; not filled in this file (presumably by the router — confirm)

	// Response information
	StatusCode int // status code recorded by SetStatusCode / ResponseError

	// Middleware chain
	cursor      int       // index of the handler currently running; -1 means nothing has run yet
	middlewares []Handler // handlers executed in order by RunNext

	// Engine
	engine *Engine // owning engine; supplies HTML templates and template funcs

	// Underlying net/http objects
	request  *http.Request       // original request
	response http.ResponseWriter // original response writer
}
// init prepares the Context for a new request. It captures the request line,
// the headers, the middleware chain and the raw request/response objects. It
// also rewinds the chain cursor, so the first RunNext call starts at
// middlewares[0].
func (own *Context) init(response http.ResponseWriter, request *http.Request, middlewares []Handler) {
	own.Method = request.Method
	own.Path = request.URL.Path
	own.Headers = request.Header
	own.middlewares = middlewares
	own.request = request
	own.response = response

	// RunNext pre-increments the cursor, so -1 means "nothing run yet".
	// Setting it here means correctness does not depend on reset having been
	// called first: a zero-value Context (cursor 0) would otherwise silently
	// skip the first middleware.
	own.cursor = -1
}
// reset clears every piece of per-request state so the Context can be reused.
// The engine reference is the only field that survives, and the middleware
// cursor is rewound to -1 ("nothing run yet").
func (own *Context) reset() {
	*own = Context{
		cursor: -1,
		engine: own.engine,
	}
}
// RunNext advances the cursor and runs every remaining handler in order.
//
// A handler may call RunNext itself to run the rest of the chain before doing
// its own post-processing. Because the cursor is shared, the outer loop then
// finds the chain exhausted and stops. Moving the cursor to len(middlewares),
// as SetStatusCode does, aborts the remaining handlers.
func (own *Context) RunNext() {
	own.cursor++
	for own.cursor < len(own.middlewares) {
		own.middlewares[own.cursor](own)
		own.cursor++
	}
}
/****************** GET REQUEST ********************/
// GetRequest exposes the underlying *http.Request for anything the Context
// helpers do not cover.
func (own *Context) GetRequest() *http.Request {
	req := own.request
	return req
}
// GetResponse exposes the underlying http.ResponseWriter for anything the
// Context helpers do not cover.
func (own *Context) GetResponse() http.ResponseWriter {
	writer := own.response
	return writer
}
// GetUrlParam returns the dynamic route parameter named key, or "" when the
// route has no such parameter. Reading from a nil Params map also yields "".
func (own *Context) GetUrlParam(key string) string {
	return own.Params[key]
}
// GetFormValue returns the form value for key.
//
//   - GET: the value comes from the URL query string.
//   - POST: the value comes from the request body, which may be either
//     multipart/form-data or application/x-www-form-urlencoded.
//
// Any other method, or a body that cannot be parsed, yields "".
func (own *Context) GetFormValue(key string) string {
	switch own.Method {
	case http.MethodGet:
		if own.request.ParseForm() == nil {
			return own.request.FormValue(key)
		}
	case http.MethodPost:
		// ParseMultipartForm first parses url-encoded bodies into PostForm and
		// only then reports ErrNotMultipart. That error therefore still leaves
		// the url-encoded values available and must not be treated as failure.
		err := own.request.ParseMultipartForm(GetMaxRequestBodySize())
		if err == nil || err == http.ErrNotMultipart {
			return own.request.PostFormValue(key)
		}
	}
	return ""
}
// GetBytesBody reads the whole request body and returns it as raw bytes.
// The body is drained and closed, so later reads of it return no data.
func (own *Context) GetBytesBody() (data []byte, err error) {
	defer own.request.Body.Close()
	data, err = io.ReadAll(own.request.Body)
	return data, err
}
// GetJsonBody reads the request body and decodes it as JSON into model.
// model must be a non-nil pointer (e.g. &payload). Any other value makes
// json.Unmarshal return a *json.InvalidUnmarshalError.
func (own *Context) GetJsonBody(model any) error {
	body, err := own.GetBytesBody()
	if err != nil {
		return err
	}
	// Pass model itself, not &model. Decoding into *any would silently
	// overwrite the local interface value (with a map) whenever model is not
	// a pointer, hiding the caller's mistake.
	return json.Unmarshal(body, model)
}
// GetStringBody reads the whole request body and returns it as a string.
// The body is consumed by the read (see GetBytesBody).
func (own *Context) GetStringBody() (string, error) {
	body, err := own.GetBytesBody()
	if err != nil {
		return "", err
	}
	return string(body), nil
}
// GetBasicAuth returns the credentials from the request's HTTP Basic
// Authorization header. Both values are "" when the header is missing or
// malformed.
func (own *Context) GetBasicAuth() (username, password string) {
	user, pass, ok := own.request.BasicAuth()
	if ok {
		return user, pass
	}
	return "", ""
}
/****************** SET RESPONSE ********************/
// SetHeader sets the response header key to value, replacing any existing
// values. Per net/http semantics, header changes made after the status code
// has been written have no effect.
func (own *Context) SetHeader(key string, value string) {
	header := own.response.Header()
	header.Set(key, value)
}
// SetHeaders sets every name/value pair from headers on the response,
// replacing any existing values for those names.
func (own *Context) SetHeaders(headers map[string]string) {
	for name := range headers {
		own.response.Header().Set(name, headers[name])
	}
}
// SetStatusCode writes the HTTP status code to the client and records it in
// StatusCode. The response is now committed, so it also moves the middleware
// cursor past the end of the chain and RunNext runs no further handlers.
func (own *Context) SetStatusCode(code int) {
	end := len(own.middlewares)
	own.cursor = end // stop the chain: the response has been decided
	own.response.WriteHeader(code)
	own.StatusCode = code
}
/****************** RESPONSE ********************/
// ResponseBytes writes data as the response body with the given status code.
// No Content-Type is set here. If none was set earlier, the standard net/http
// writer sniffs one from the first bytes written.
//
// A failed write is only logged. The status line has already been sent, so
// the response can no longer be turned into an error page.
func (own *Context) ResponseBytes(code int, data []byte) {
	own.SetStatusCode(code)
	if _, err := own.response.Write(data); err != nil {
		// Sending a 500 here would trigger a superfluous WriteHeader and append
		// error text to an already partially written body.
		logging.Logger.Errorf("response bytes error: %v", err)
	}
}
// ResponseString writes data as a plain-text UTF-8 response with the given
// status code. A failed write is only logged: the status line has already
// been sent, so the response can no longer be replaced with an error page.
func (own *Context) ResponseString(code int, data string) {
	// Declare the charset so clients do not guess (and mis-decode non-ASCII
	// text). This matches the Content-Type that http.Error sends.
	own.response.Header().Set("Content-Type", "text/plain; charset=utf-8")
	own.SetStatusCode(code)
	if _, err := io.WriteString(own.response, data); err != nil {
		logging.Logger.Errorf("response string error: %v", err)
	}
}
// ResponseJson marshals data to JSON and writes it with the given status code
// and an application/json Content-Type.
//
// A marshalling failure is answered with a 500 before anything is written.
// A failure while writing the body is only logged, because the status line
// has already been sent by then.
func (own *Context) ResponseJson(code int, data any) {
	body, err := json.Marshal(data)
	if err != nil {
		logging.Logger.Errorf("marshal json error before response: %v", err)
		own.ResponseError500(err)
		return
	}
	// application/json is the registered JSON media type (RFC 8259);
	// "text/json" is non-standard.
	own.response.Header().Set("Content-Type", "application/json")
	own.SetStatusCode(code)
	if _, err = own.response.Write(body); err != nil {
		logging.Logger.Errorf("response json error: %v", err)
	}
}
// ResponseHtml renders the HTML template file filename with data and sends it
// with the given status code.
//
// Template lookup:
//   - when the engine has a template set, the template named filename is
//     executed from that set;
//   - otherwise the file at path filename is read and parsed on each call,
//     with the engine's template function map attached when one is set.
//
// The page is rendered into memory before anything is written. A missing
// file, a parse error or an execution error therefore produces a clean 500
// instead of a panic or a truncated page under an already-sent status.
func (own *Context) ResponseHtml(code int, filename string, data any) {
	var page strings.Builder

	if own.engine != nil && own.engine.templates != nil {
		// Use the engine's pre-parsed template set.
		if err := own.engine.templates.ExecuteTemplate(&page, filename, data); err != nil {
			logging.Logger.Errorf("response html file with engine templates error: %v", err)
			own.ResponseError500(err)
			return
		}
	} else {
		source, err := os.ReadFile(filename)
		if err != nil {
			logging.Logger.Errorf("read html file error before response: %v", err)
			own.ResponseError500(err)
			return
		}
		// Parse the contents into the template we later Execute.
		// ParseFiles would register the file under its base name, leaving the
		// executed template empty.
		tpl := template.New(filename)
		if own.engine != nil && own.engine.templateFuncs != nil {
			// Funcs must be attached before parsing.
			tpl = tpl.Funcs(own.engine.templateFuncs)
		}
		if tpl, err = tpl.Parse(string(source)); err != nil {
			logging.Logger.Errorf("parse html file error before response: %v", err)
			own.ResponseError500(err)
			return
		}
		if err = tpl.Execute(&page, data); err != nil {
			logging.Logger.Errorf("response html file error: %v", err)
			own.ResponseError500(err)
			return
		}
	}

	own.response.Header().Set("Content-Type", "text/html")
	own.SetStatusCode(code)
	if _, err := io.WriteString(own.response, page.String()); err != nil {
		// The status line is already sent; all that is left is to log.
		logging.Logger.Errorf("response html file error: %v", err)
	}
}
// ResponseHtmlString parses text as HTML template source, renders it with data
// and sends the result with the given status code. The engine's template
// function map is attached when one is configured.
//
// Rendering happens in memory first. A parse or execution error is therefore
// answered with a clean 500 instead of a panic or a half-written page.
func (own *Context) ResponseHtmlString(code int, text string, data any) {
	tpl := template.New("")
	if own.engine != nil && own.engine.templateFuncs != nil {
		// Funcs must be attached before parsing.
		tpl = tpl.Funcs(own.engine.templateFuncs)
	}
	// text is template source, so it goes through Parse; ParseFiles would
	// treat it as a file path.
	tpl, err := tpl.Parse(text)
	if err != nil {
		logging.Logger.Errorf("parse html string error before response: %v", err)
		own.ResponseError500(err)
		return
	}

	var page strings.Builder
	if err = tpl.Execute(&page, data); err != nil {
		logging.Logger.Errorf("response html string error: %v", err)
		own.ResponseError500(err)
		return
	}

	own.response.Header().Set("Content-Type", "text/html")
	own.SetStatusCode(code)
	if _, err = io.WriteString(own.response, page.String()); err != nil {
		// The status line is already sent; all that is left is to log.
		logging.Logger.Errorf("response html string error: %v", err)
	}
}
// ResponseFile sends the file named filename inside directory filepath as a
// download (Content-Disposition: attachment).
//
// Files smaller than GetLittleFileSize() are streamed with the given status
// code. Larger files are delegated to http.ServeFile. ServeFile picks the
// status itself (200, 206 for range requests, 304, ...) and sets
// Content-Length, Content-Type and Last-Modified, so code is not applied and
// StatusCode is left unchanged in that case.
//
// A file that cannot be stat'ed or opened is answered with a 500.
func (own *Context) ResponseFile(code int, filepath, filename string) {
	fullPath := strings.TrimSuffix(filepath, "/") + "/" + filename
	info, err := os.Stat(fullPath)
	if err != nil {
		logging.Logger.Errorf("not found file before response: %v", err)
		own.ResponseError500(err)
		return
	}
	disposition := fmt.Sprintf("attachment; filename=\"%s\"", filename)

	// Small files: stream with the caller's status code.
	if info.Size() < GetLittleFileSize() {
		file, err := os.Open(fullPath)
		if err != nil {
			logging.Logger.Errorf("open file error before response: %v", err)
			own.ResponseError500(err)
			return
		}
		defer file.Close()

		own.response.Header().Set("Content-Disposition", disposition)
		own.SetStatusCode(code)
		if _, err = io.Copy(own.response, file); err != nil {
			// Headers are committed; a 500 now would only be a superfluous WriteHeader.
			logging.Logger.Errorf("response file error: %v", err)
		}
		return
	}

	// Large files: ServeFile must write the status and headers itself.
	// Writing our own status first would drop the Content-Length/Content-Range
	// headers it sets and make its WriteHeader call superfluous.
	own.response.Header().Set("Content-Disposition", disposition)
	own.cursor = len(own.middlewares) // response is committed: stop the chain, as SetStatusCode does
	http.ServeFile(own.response, own.request, fullPath)
}
/****************** RESPONSE ERROR ********************/
// ResponseError records code in StatusCode, stops the middleware chain and
// sends msg as a plain-text error body with that status via http.Error.
func (own *Context) ResponseError(code int, msg string) {
	own.StatusCode = code
	own.cursor = len(own.middlewares)
	http.Error(own.response, msg, code)
}
// ResponseError400 sends a 400 Bad Request. When cause is non-nil its value
// is appended to the message, e.g. "400 BAD REQUEST: missing id".
func (own *Context) ResponseError400(cause any) {
	code := http.StatusBadRequest
	if cause == nil {
		own.ResponseError(code, fmt.Sprintf("%d BAD REQUEST", code))
		return
	}
	own.ResponseError(code, fmt.Sprintf("%d BAD REQUEST: %v", code, cause))
}
// ResponseError401 sends a 401 Unauthorized and, through WWW-Authenticate,
// asks the client for HTTP Basic credentials (realm "restricted", UTF-8).
func (own *Context) ResponseError401() {
	own.SetHeader("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
	message := fmt.Sprintf("%d UNAUTHORIZED", http.StatusUnauthorized)
	own.ResponseError(http.StatusUnauthorized, message)
}
// ResponseError404 sends a 404 Not Found whose message names the request
// method and path, e.g. "404 NOT FOUND: GET - /missing".
func (own *Context) ResponseError404() {
	const status = http.StatusNotFound
	message := fmt.Sprintf("%d NOT FOUND: %s - %s", status, own.Method, own.Path)
	own.ResponseError(status, message)
}
// ResponseError500 sends a 500 Internal Server Error. When cause is non-nil
// its value is appended to the message, e.g. "500 INTERNAL SERVER ERROR: boom".
func (own *Context) ResponseError500(cause any) {
	code := http.StatusInternalServerError
	msg := fmt.Sprintf("%d INTERNAL SERVER ERROR", code)
	if cause != nil {
		msg = fmt.Sprintf("%s: %v", msg, cause)
	}
	own.ResponseError(code, msg)
}
Go
1
https://gitee.com/leminewx/leego.git
git@gitee.com:leminewx/leego.git
leminewx
leego
leego
4d7864a3835c

搜索帮助

53164aa7 5694891 3bd8fe86 5694891