diff --git a/internal/controller/copilot/chat_completions.go b/internal/controller/copilot/chat_completions.go index 61fab6a4f8d8e525cc1e26d118a1898d55b04fdc..bfe0d70f578e0ead12f0af95edf23ba839c191ea 100644 --- a/internal/controller/copilot/chat_completions.go +++ b/internal/controller/copilot/chat_completions.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "fmt" "io" "log" "net/http" @@ -28,6 +29,21 @@ func chatCompletions(c *gin.Context) { envModelName := os.Getenv("CHAT_API_MODEL_NAME") c.Header("Content-Type", "text/event-stream") body, _ = sjson.SetBytes(body, "model", envModelName) + + if !gjson.GetBytes(body, "function_call").Exists() { + messages := gjson.GetBytes(body, "messages").Array() + for i, msg := range messages { + toolCalls := msg.Get("tool_calls").Array() + if len(toolCalls) == 0 { + body, _ = sjson.DeleteBytes(body, fmt.Sprintf("messages.%d.tool_calls", i)) + } + } + lastIndex := len(messages) - 1 + if !strings.Contains(messages[lastIndex].Get("content").String(), "Respond in the following locale") { + body, _ = sjson.SetBytes(body, "messages."+strconv.Itoa(lastIndex)+".content", messages[lastIndex].Get("content").String()+"Respond in the following locale: "+os.Getenv("CHAT_LOCALE")+".") + } + } + body, _ = sjson.DeleteBytes(body, "intent") body, _ = sjson.DeleteBytes(body, "intent_threshold") body, _ = sjson.DeleteBytes(body, "intent_content") diff --git a/internal/controller/copilot/code_completions.go b/internal/controller/copilot/code_completions.go index e896c7dc0a951595a2d3d619386570a9db945d78..6f438a8a8bcbbab68e7db6083a69fd06a9f9ae09 100644 --- a/internal/controller/copilot/code_completions.go +++ b/internal/controller/copilot/code_completions.go @@ -17,8 +17,9 @@ import ( "net/http" "os" "strconv" - "strings" "time" + "math/rand" + "strings" ) // codeCompletions 代码补全 @@ -49,7 +50,25 @@ func codeCompletions(c *gin.Context) { } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+os.Getenv("CODEX_API_KEY")) + + apiKeys := strings.Split(os.Getenv("CODEX_API_KEY"), ",") + + // 检查 apiKeys 是否有效 + if len(apiKeys) == 0 || (len(apiKeys) == 1 && apiKeys[0] == "") { + abortCodex(c, http.StatusInternalServerError) + return + } + + + randGen := rand.New(rand.NewSource(time.Now().UnixNano())) + selectedKey := strings.TrimSpace(apiKeys[randGen.Intn(len(apiKeys))]) + + if selectedKey == "" { + abortCodex(c, http.StatusInternalServerError) + return + } + + req.Header.Set("Authorization", "Bearer "+selectedKey) client := &http.Client{ Timeout: 30 * time.Second, diff --git a/internal/controller/copilot/get_copilot_internal_v2_token.go b/internal/controller/copilot/get_copilot_internal_v2_token.go index 475581ae2118425c8b4350d7031f344be62d4517..155e898f9dc81c620a2d38c94dcd537cd5857561 100644 --- a/internal/controller/copilot/get_copilot_internal_v2_token.go +++ b/internal/controller/copilot/get_copilot_internal_v2_token.go @@ -11,6 +11,8 @@ import ( "ripper/internal/cache" "strconv" "time" + "math/rand" + "strings" ) // getDisguiseCopilotInternalV2Token 返回伪装的token @@ -69,7 +71,13 @@ func getDisguiseCopilotInternalV2Token(ctx *gin.Context) { // getCopilotInternalV2Token 获取github copilot官方token func getCopilotInternalV2Token(c *gin.Context) { - ghu := os.Getenv("COPILOT_GHU_TOKEN") + ghuTokens := strings.Split(os.Getenv("COPILOT_GHU_TOKEN"), ",") + if len(ghuTokens) == 0 { + return + } + + rand.Seed(time.Now().UnixNano()) + ghu := ghuTokens[rand.Intn(len(ghuTokens))] if ghu == "" { log.Println("ghu token is empty") c.JSON(http.StatusUnprocessableEntity, gin.H{ @@ -113,7 +121,7 @@ func getCopilotInternalV2Token(c *gin.Context) { return } if resp.StatusCode != 200 { - errorMsg := "获取 Token 失败, 当前 ghu_token 账户可能并未订阅 github copilot 服务!" + errorMsg := "获取 Token 失败, 当前 ghu_token 账户可能并未订阅 github copilot 服务!" + ghu c.JSON(resp.StatusCode, gin.H{"error": errorMsg}) log.Println(errorMsg) return diff --git a/internal/controller/copilot/github_completions.go b/internal/controller/copilot/github_completions.go index a585d9463ae91eafaec6198753455b58aaf31e90..3c03ae5e6a6225fa81b94037ad22aaee2da7c685 100644 --- a/internal/controller/copilot/github_completions.go +++ b/internal/controller/copilot/github_completions.go @@ -16,6 +16,8 @@ import ( "ripper/internal/cache" "strconv" "time" + "math/rand" + "strings" ) // codexCompletions 全代理GitHub的代码补全接口 @@ -143,7 +145,13 @@ func chatsCompletions(c *gin.Context) { // getAuthToken 获取GitHub Copilot的临时Token func getAuthToken() (string, error) { - ghu := os.Getenv("COPILOT_GHU_TOKEN") + ghuTokens := strings.Split(os.Getenv("COPILOT_GHU_TOKEN"), ",") + if len(ghuTokens) == 0 { + return "", fmt.Errorf("COPILOT_GHU_TOKEN environment variable is empty or malformed") + } + + rand.Seed(time.Now().UnixNano()) + ghu := ghuTokens[rand.Intn(len(ghuTokens))] cacheKey := "github:copilot_internal_v2_token:" + ghu token, err := cache.Get(cacheKey) if err != nil { @@ -181,7 +189,7 @@ func getAuthToken() (string, error) { defer res.Body.Close() if res.StatusCode != http.StatusOK { - return "", fmt.Errorf("获取 Token 失败") + return "", fmt.Errorf("获取 Token 失败" + ghu) } body, err := ioutil.ReadAll(res.Body)