diff --git a/copilot-proxy.service b/copilot-proxy.service new file mode 100644 index 0000000000000000000000000000000000000000..f07c99b036ed779c44d2f6e57ada3b22bce372d6 --- /dev/null +++ b/copilot-proxy.service @@ -0,0 +1,15 @@ +# 用服务器部署时,可以用这个systemd配置守护程序,注意env和工作目录配置 +[Unit] +Description=copilot-proxy server daemon +After=network.target + +[Service] +PIDFile=/tmp/copilot-proxy.pid +ExecStart=/usr/local/copilot/copilot-proxy +ExecReload=/bin/kill -USR1 $MAINPID +Restart=always +TimeoutStartSec=2 +WorkingDirectory=/usr/local/copilot + +[Install] +WantedBy=multi-user.target diff --git a/internal/controller/copilot/chat_completions.go b/internal/controller/copilot/chat_completions.go index 31b5721871e5874170aa0361ae9c8895e8ce9313..e3baaba048dc81acd95f0e034ac8f9713c2001fd 100644 --- a/internal/controller/copilot/chat_completions.go +++ b/internal/controller/copilot/chat_completions.go @@ -16,6 +16,7 @@ import ( "strings" ) +// Deprecated: 使用v2 流式chat响应 func chatCompletions(c *gin.Context) { ctx := c.Request.Context() diff --git a/internal/controller/copilot/chat_completions_v2.go b/internal/controller/copilot/chat_completions_v2.go new file mode 100644 index 0000000000000000000000000000000000000000..5f0ceecaac3a9704b3a9f7c5e105979525542022 --- /dev/null +++ b/internal/controller/copilot/chat_completions_v2.go @@ -0,0 +1,99 @@ +package copilot + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "io" + "log" + "net/http" + "os" + "strconv" + "strings" +) + +func chatCompletionsV2(c *gin.Context) { + ctx := c.Request.Context() + + body, err := io.ReadAll(c.Request.Body) + if nil != err { + c.AbortWithStatus(http.StatusBadRequest) + return + } + defer c.Request.Body.Close() + + body, _ = sjson.SetBytes(body, "model", os.Getenv("CHAT_API_MODEL_NAME")) + + if !gjson.GetBytes(body, "function_call").Exists() { + messages := gjson.GetBytes(body, "messages").Array() + lastIndex := len(messages) - 1 + if !strings.Contains(messages[lastIndex].Get("content").String(), "Respond in the following locale") { + body, _ = sjson.SetBytes(body, "messages."+strconv.Itoa(lastIndex)+".content", messages[lastIndex].Get("content").String()+"Respond in the following locale: "+os.Getenv("CHAT_LOCALE")+".") + } + } + + body, _ = sjson.DeleteBytes(body, "intent") + body, _ = sjson.DeleteBytes(body, "intent_threshold") + body, _ = sjson.DeleteBytes(body, "intent_content") + body, _ = sjson.SetBytes(body, "stream", true) + + ChatMaxTokens, _ := strconv.Atoi(os.Getenv("CHAT_MAX_TOKENS")) + if int(gjson.GetBytes(body, "max_tokens").Int()) > ChatMaxTokens { + body, _ = sjson.SetBytes(body, "max_tokens", ChatMaxTokens) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, os.Getenv("CHAT_API_BASE"), io.NopCloser(bytes.NewBuffer(body))) + if nil != err { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+os.Getenv("CHAT_API_KEY")) + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + resp, err := client.Do(req) + if nil != err { + if errors.Is(err, context.Canceled) { + c.AbortWithStatus(http.StatusRequestTimeout) + return + } + + log.Println("request conversation failed:", err.Error()) + c.AbortWithStatus(http.StatusInternalServerError) + return + } + defer closeIO(resp.Body) + + c.Status(resp.StatusCode) + + contentType := resp.Header.Get("Content-Type") + if "" != contentType { + c.Header("Content-Type", contentType) + } + + pr, pw := io.Pipe() // 创建管道 + + // 启动一个 goroutine 用于从 resp.Body 读取数据并写入到管道 + go func() { + defer pw.Close() + _, err := io.Copy(pw, resp.Body) + if err != nil { + log.Println("Error copying response body:", err) + } + }() + + // 从管道中读取数据并写入到客户端 + _, err = io.Copy(c.Writer, pr) + if err != nil { + log.Println("Error writing to client:", err) + } +} diff --git a/internal/controller/copilot/router_register.go b/internal/controller/copilot/router_register.go index cac79e1483b7e45095449a5b4fc3f7cdfaccd8e3..30c9083d55f92b3539642cc0dd8a9b72d8d9240d 100644 --- a/internal/controller/copilot/router_register.go +++ b/internal/controller/copilot/router_register.go @@ -18,7 +18,7 @@ func GinApi(g *gin.RouterGroup) { g.POST("/v1/engines/copilot-codex/completions", codeCompletions) g.POST("/v1/engines/copilot-codex", codeCompletions) - g.POST("/chat/completions", chatCompletions) + g.POST("/chat/completions", chatCompletionsV2) g.GET("/api/v3/meta", v3meta) g.GET("/api/v3/", cliv3)