// Package openai adapts the sashabaranov/go-openai chat-completion client
// to the provider-agnostic llms interfaces.
package openai
import (
"context"
"encoding/json"
"fmt"
"gitee.com/kzangv/gsf-ai-agent/llms"
openai "github.com/sashabaranov/go-openai"
"net/http"
)
// ModelType enumerates the OpenAI chat model identifiers this adapter accepts.
// The value is passed verbatim as the request's model field.
type ModelType string

const (
	ModelTypeGpt3Dot5Turbo ModelType = "gpt-3.5-turbo"
	// NOTE(review): "gpt-35-turbo-16k" is the Azure-style deployment name; the
	// public OpenAI API spells it "gpt-3.5-turbo-16k" — confirm which endpoint
	// this constant is meant for.
	ModelTypeGpt35Turbo16k ModelType = "gpt-35-turbo-16k"
	ModelTypeGpt4Pro       ModelType = "gpt-4"
	ModelTypeGpt4Pro32K    ModelType = "gpt-4-32k"
	ModelTypeGpt4Omni      ModelType = "gpt-4o"
)
// Fallback sampling parameters applied by NewLLM when the corresponding
// Option field is left at its zero value.
const (
	defaultTemperature = 0.7
	defaultTopP        = 1
	defaultMaxTokens   = 500
)
var (
	// _ModelRespFinishMap translates go-openai finish reasons into the
	// provider-agnostic llms finish types; reasons absent from the map are
	// reported as llms.OutputFinishInvalid by _ParserResponse2Output.
	_ModelRespFinishMap = map[openai.FinishReason]llms.OutputFinishType{
		openai.FinishReasonStop:          llms.OutputFinishStop,
		openai.FinishReasonLength:        llms.OutputFinishLength,
		openai.FinishReasonContentFilter: llms.OutputFinishContentFilter,
		openai.FinishReasonToolCalls:     llms.OutputFinishToolCalls,
	}
)
// Option collects optional configuration for NewLLM.
// Build one with NewOption and chain the fluent Set* methods.
type Option struct {
	_LLMParams              // sampling parameters; zero values fall back to defaults in NewLLM
	clt        *http.Client // custom HTTP client; nil keeps the go-openai default
}
// SetHttpClient overrides the HTTP client used for API requests and
// returns the receiver for chaining.
func (p *Option) SetHttpClient(c *http.Client) *Option {
	p.clt = c
	return p
}
// SetTemperature sets the sampling temperature and returns the receiver
// for chaining.
func (p *Option) SetTemperature(t float32) *Option {
	p.temperature = t
	return p
}
// SetTopP sets the nucleus-sampling probability mass and returns the
// receiver for chaining.
func (p *Option) SetTopP(tp float32) *Option {
	p.topP = tp
	return p
}
// SetMaxTokens caps the completion length in tokens and returns the
// receiver for chaining.
func (p *Option) SetMaxTokens(n int) *Option {
	p.maxTokens = n
	return p
}
// NewOption returns a zero-valued Option ready for fluent configuration.
func NewOption() *Option {
	return new(Option)
}
// _LLMParams holds the sampling parameters shared between Option (caller
// configuration) and LLM (the resolved values used on every request).
type _LLMParams struct {
	temperature,
	topP float32
	maxTokens int
}
// LLM is a chat-completion client bound to a single OpenAI model.
// Construct it with NewLLM; the zero value is not usable.
type LLM struct {
	_LLMParams            // resolved sampling parameters applied to every request
	model     string      // model identifier sent as the request's Model field
	apiSecret string      // API key; also baked into client via openai.DefaultConfig
	client    *openai.Client
}
// CheckSupport reports whether this backend implements the given llms
// capability; only tool calling is supported.
func (llm *LLM) CheckSupport(v llms.LLMSupportType) bool {
	return v == llms.LLMSupportTypeTool
}
// _buildMsgList appends the system prompt (if any) followed by the
// conversation history to req.Messages in OpenAI chat format.
//
// User messages map 1:1. Assistant messages map one of two ways:
//   - FnList == nil: a plain text assistant message;
//   - len(FnList) > 0: one assistant message carrying all tool calls,
//     immediately followed by one role-"tool" message per call result,
//     linked back via ToolCallID.
//
// NOTE(review): an assistant message with a non-nil but empty FnList is
// silently dropped, and message types other than user/assistant are
// ignored — confirm both are intentional.
//
// The error return is currently always nil; it is kept for symmetry with
// _buildToolList.
func (llm *LLM) _buildMsgList(req *openai.ChatCompletionRequest, input llms.Input, msgs []llms.Message) error {
	if sysMsg := input.GetSysMsg(); sysMsg != nil {
		req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleSystem,
			Name: sysMsg.Role(), Content: sysMsg.Content()})
	}
	for _, msg := range msgs {
		switch msg.Type() {
		case llms.MessageTypeUser:
			mVal := msg.(*llms.UserMessage)
			req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser,
				Name: mVal.Role(), Content: mVal.Content()})
		case llms.MessageTypeAssistant:
			mVal := msg.(*llms.AssistantMessage)
			if mVal.FnList == nil {
				req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant,
					Name: mVal.Role(), Content: mVal.Content()})
			} else if len(mVal.FnList) > 0 {
				// Collect the tool-call descriptors and their results in one
				// pass; the results must follow the assistant message so the
				// API can pair them by ToolCallID.
				tools := make([]openai.ToolCall, 0, len(mVal.FnList))
				fnResults := make([]openai.ChatCompletionMessage, 0, len(mVal.FnList))
				for _, fn := range mVal.FnList {
					tools = append(tools, openai.ToolCall{
						ID:       fn.ID,
						Type:     openai.ToolTypeFunction,
						Function: openai.FunctionCall{Name: fn.Name, Arguments: fn.Arguments},
					})
					fnResults = append(fnResults, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleTool,
						ToolCallID: fn.ID, Content: fn.Result})
				}
				req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant,
					Name: mVal.Role(), ToolCalls: tools})
				req.Messages = append(req.Messages, fnResults...)
			}
		}
	}
	return nil
}
// _buildToolList converts the function tools in toolList into OpenAI tool
// definitions on req; tools of any other type are skipped. When at least
// one tool was registered, ToolChoice is set to "auto" so the model may
// decide whether to call one.
//
// Returns an error when a tool's parameter JSON schema cannot be produced.
func (llm *LLM) _buildToolList(req *openai.ChatCompletionRequest, _ llms.Input, toolList []llms.Tool) error {
	for _, tool := range toolList {
		if tool.Type() != llms.ToolTypeFunction {
			continue
		}
		tFn := tool.(*llms.FunctionTool)
		paramSchemeStr, err := tFn.ParamJsonScheme()
		if err != nil {
			return fmt.Errorf("func [%s] params encode json scheme error: %s", tFn.Name(), err.Error())
		}
		req.Tools = append(req.Tools, openai.Tool{
			Type: openai.ToolTypeFunction,
			Function: openai.FunctionDefinition{
				Name:        tFn.Name(),
				Description: tFn.Desc(),
				Parameters:  json.RawMessage(paramSchemeStr),
			},
		})
	}
	// BUG FIX: the condition was `> 1`, so a single registered tool never got
	// an explicit ToolChoice. Any registered tool should enable "auto".
	if len(req.Tools) > 0 {
		req.ToolChoice = "auto"
	}
	return nil
}
// _ParserResponse2Output converts a raw chat-completion response into the
// provider-agnostic llms.Output. Only the first choice is considered. When
// the model finished with tool calls, each requested function is looked up
// in tb and executed synchronously, and its result is attached to the
// returned FunctionCall list; otherwise the plain text content is forwarded.
// Token usage is copied through, and finish reasons are mapped via
// _ModelRespFinishMap (unknown reasons become llms.OutputFinishInvalid).
//
// Returns an error when the response has no choices, declares tool calls
// without listing any, references an unknown or non-function tool, or a
// tool execution fails.
func (llm *LLM) _ParserResponse2Output(resp *openai.ChatCompletionResponse, tb llms.ToolBox) (*llms.Output, error) {
	// BUG FIX: guard before indexing; an empty Choices slice used to panic.
	if len(resp.Choices) == 0 {
		return nil, fmt.Errorf("response is invalid %v", resp)
	}
	var (
		fnList  []*llms.FunctionCall
		content string
	)
	rMsg := &resp.Choices[0]
	fType := llms.OutputFinishInvalid
	if ft, ok := _ModelRespFinishMap[rMsg.FinishReason]; ok {
		fType = ft
	}
	if rMsg.FinishReason == openai.FinishReasonToolCalls {
		if len(rMsg.Message.ToolCalls) == 0 { // len() on a nil slice is 0, no separate nil check needed
			return nil, fmt.Errorf("response is invalid %v", resp)
		}
		fnList = make([]*llms.FunctionCall, 0, len(rMsg.Message.ToolCalls))
		for k := range rMsg.Message.ToolCalls {
			v := &rMsg.Message.ToolCalls[k]
			if v.Type != openai.ToolTypeFunction {
				continue // non-function calls are ignored, as before
			}
			fn := &llms.FunctionCall{
				ID:        v.ID,
				Name:      v.Function.Name,
				Arguments: v.Function.Arguments,
			}
			tool, ok := tb.Get(v.Function.Name)
			if !ok {
				return nil, fmt.Errorf("tools.ToolBox no tool: %v", v.Function.Name)
			}
			if tool.Type() != llms.ToolTypeFunction {
				return nil, fmt.Errorf("tool type is no function (type: %d)", tool.Type())
			}
			var err error
			fn.Result, err = tool.Run(v.Function.Arguments)
			if err != nil {
				// typo fix in message: "faild" -> "failed"
				return nil, fmt.Errorf("call function failed reason: %s", err.Error())
			}
			fnList = append(fnList, fn)
		}
	} else {
		content = rMsg.Message.Content
	}
	return llms.NewOutput(llms.TokenUsage{
		Prompt:     int64(resp.Usage.PromptTokens),
		Completion: int64(resp.Usage.CompletionTokens),
		Total:      int64(resp.Usage.TotalTokens),
	}, []*llms.OutputMsg{llms.NewOutputMsg(llms.MessageTypeAssistant, rMsg.Message.Name, content, fType, fnList...)}), nil
}
// Request sends the accumulated conversation and registered tools to the
// chat-completion API and parses the reply, executing any tool calls
// through tb. It returns the provider-agnostic output or the first error
// encountered while building the request, calling the API, or parsing.
func (llm *LLM) Request(ctx context.Context, input llms.Input, tb llms.ToolBox) (*llms.Output, error) {
	history := input.GetAll()
	registered := tb.Tools()
	req := openai.ChatCompletionRequest{
		Model:       llm.model,
		Temperature: llm.temperature,
		TopP:        llm.topP,
		MaxTokens:   llm.maxTokens,
		// history length plus 25% headroom for tool-result companion messages
		Messages: make([]openai.ChatCompletionMessage, 0, len(history)+len(history)/4),
		Tools:    make([]openai.Tool, 0, len(registered)),
	}
	// Assemble messages and tool definitions.
	if err := llm._buildMsgList(&req, input, history); err != nil {
		return nil, err
	}
	if err := llm._buildToolList(&req, input, registered); err != nil {
		return nil, err
	}
	// Issue the API call.
	resp, err := llm.client.CreateChatCompletion(ctx, req)
	if err != nil {
		return nil, err
	}
	// Translate the provider response.
	return llm._ParserResponse2Output(&resp, tb)
}
// NewLLM builds an LLM client for the given model.
//
// apiSecret must be non-empty. opt may be nil, in which case all defaults
// apply. Zero-valued sampling fields fall back to the package defaults
// (temperature 0.7, topP 1, maxTokens 500). NOTE: because 0 is the
// "unset" sentinel, a deliberate temperature/topP of exactly 0 cannot be
// expressed through Option.
func NewLLM(apiSecret string, model ModelType, opt *Option) (*LLM, error) {
	if apiSecret == "" {
		return nil, fmt.Errorf("api_key or api_secret is empty")
	}
	// BUG FIX: a nil opt previously caused a nil-pointer dereference below.
	if opt == nil {
		opt = NewOption()
	}
	config := openai.DefaultConfig(apiSecret)
	if opt.clt != nil {
		config.HTTPClient = opt.clt
	}
	// Apply defaults to a copy so the caller's Option is not mutated.
	params := opt._LLMParams
	if params.temperature == 0 {
		params.temperature = defaultTemperature
	}
	if params.topP == 0 {
		params.topP = defaultTopP
	}
	if params.maxTokens <= 0 {
		params.maxTokens = defaultMaxTokens
	}
	return &LLM{
		apiSecret:  apiSecret,
		model:      string(model),
		client:     openai.NewClientWithConfig(config),
		_LLMParams: params,
	}, nil
}