1 Star 0 Fork 0

kzangv/gsf-ai-agent

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
llm.go 6.96 KB
一键复制 编辑 原始数据 按行查看 历史
kzangv 提交于 2024-07-25 14:38 +08:00 . fixed
package doubao
import (
"context"
"encoding/json"
"fmt"
"gitee.com/kzangv/gsf-ai-agent/llms"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/volcengine/volcengine-go-sdk/volcengine"
"net/http"
)
var (
	// _ModelRespFinishMap translates Ark (Doubao) finish reasons into the
	// framework-level llms finish types. Both legacy function-call and
	// tool-call finishes map to OutputFinishToolCalls.
	_ModelRespFinishMap = map[model.FinishReason]llms.OutputFinishType{
		model.FinishReasonStop:          llms.OutputFinishStop,
		model.FinishReasonLength:        llms.OutputFinishLength,
		model.FinishReasonContentFilter: llms.OutputFinishContentFilter,
		model.FinishReasonFunctionCall:  llms.OutputFinishToolCalls,
		model.FinishReasonToolCalls:     llms.OutputFinishToolCalls,
	}
)
// Option carries optional construction-time configuration for an LLM:
// the shared sampling parameters plus transport settings.
type Option struct {
	_Params
	host, region string // NOTE(review): no setters for host/region are visible in this chunk — presumably populated elsewhere; confirm
	clt          *http.Client // custom HTTP client, installed via SetClient
}
// SetTemperature configures the sampling temperature (range 0~1; larger
// values yield more random output) and returns the receiver for chaining.
func (o *Option) SetTemperature(t float32) *Option {
	o._Params.temperature = t
	return o
}
// SetTopP configures nucleus sampling (range 0~1; larger values allow a
// richer token distribution) and returns the receiver for chaining.
func (o *Option) SetTopP(p float32) *Option {
	o._Params.topP = p
	return o
}
// SetClient installs a custom *http.Client for API requests and returns
// the receiver for chaining.
func (o *Option) SetClient(c *http.Client) *Option {
	o.clt = c
	return o
}
// NewOption returns a zero-valued Option ready for fluent configuration.
func NewOption() *Option {
	return new(Option)
}
/*
LLM
*/

// _Params holds the sampling parameters shared by Option and LLM.
type _Params struct {
	temperature, // controls randomness/creativity of generated text; larger = more random, range 0~1
	topP float32 // controls diversity of output tokens; larger = richer token variety, range 0~1
}
// LLM is a Doubao (Volcengine Ark) chat-completion client implementing
// the llms adapter interface.
type LLM struct {
	_Params
	apiKey string // Ark API key used to authenticate requests
	model  string // endpoint/model identifier sent with every request
	clt    *arkruntime.Client // underlying Ark SDK client
}
// CheckSupport reports whether this LLM implementation supports the given
// capability; tool calling is the only supported feature.
func (llm *LLM) CheckSupport(v llms.LLMSupportType) bool {
	return v == llms.LLMSupportTypeTool
}
// _buildMsgList translates the framework conversation (optional system
// message plus user/assistant history) into Ark ChatCompletionMessages,
// appending them to req.Messages in order. An assistant message that
// carried tool calls expands into one assistant message holding the
// ToolCalls followed by one tool-role result message per call.
// Always returns nil.
func (llm *LLM) _buildMsgList(req *model.ChatCompletionRequest, input llms.Input, msgs []llms.Message) error {
	// System prompt, when present, is emitted first.
	if sysMsg := input.GetSysMsg(); sysMsg != nil {
		req.Messages = append(req.Messages, &model.ChatCompletionMessage{
			Role:    model.ChatMessageRoleSystem,
			Content: &model.ChatCompletionMessageContent{StringValue: volcengine.String(sysMsg.Content())},
		})
	}
	for _, msg := range msgs {
		switch msg.Type() {
		case llms.MessageTypeUser:
			mVal := msg.(*llms.UserMessage)
			req.Messages = append(req.Messages, &model.ChatCompletionMessage{
				Role:    model.ChatMessageRoleUser,
				Content: &model.ChatCompletionMessageContent{StringValue: volcengine.String(mVal.Content())},
			})
		case llms.MessageTypeAssistant:
			mVal := msg.(*llms.AssistantMessage)
			if mVal.FnList == nil {
				// Plain assistant text reply.
				req.Messages = append(req.Messages, &model.ChatCompletionMessage{
					Role:    model.ChatMessageRoleAssistant,
					Content: &model.ChatCompletionMessageContent{StringValue: volcengine.String(mVal.Content())},
				})
			} else if len(mVal.FnList) > 0 {
				// Tool-call turn: rebuild the assistant's call request and the
				// corresponding tool results so the model sees the full exchange.
				// NOTE(review): an assistant message with a non-nil but empty
				// FnList is silently dropped — confirm that is intentional.
				tools := make([]*model.ToolCall, 0, len(mVal.FnList))
				fnResults := make([]*model.ChatCompletionMessage, 0, len(mVal.FnList))
				for _, fn := range mVal.FnList {
					tools = append(tools, &model.ToolCall{
						ID:       fn.ID,
						Type:     model.ToolTypeFunction,
						Function: model.FunctionCall{Name: fn.Name, Arguments: fn.Arguments},
					})
					fnResults = append(fnResults, &model.ChatCompletionMessage{
						Role:       model.ChatMessageRoleTool,
						ToolCallID: fn.ID,
						Content: &model.ChatCompletionMessageContent{
							StringValue: &fn.Result,
						},
					})
				}
				// Assistant message carrying the tool calls (empty content),
				// immediately followed by the tool result messages.
				req.Messages = append(req.Messages, &model.ChatCompletionMessage{
					Role:      model.ChatMessageRoleAssistant,
					Content:   &model.ChatCompletionMessageContent{},
					ToolCalls: tools,
				})
				req.Messages = append(req.Messages, fnResults...)
			}
		}
	}
	return nil
}
// _buildToolList registers every function tool from toolList on the
// request, encoding each tool's parameter definition as a JSON schema.
// Non-function tools are skipped.
func (llm *LLM) _buildToolList(req *model.ChatCompletionRequest, _ llms.Input, toolList []llms.Tool) error {
	for _, t := range toolList {
		if t.Type() != llms.ToolTypeFunction {
			continue
		}
		fn := t.(*llms.FunctionTool)
		schema, err := fn.ParamJsonScheme()
		if err != nil {
			return fmt.Errorf("func [%s] params encode json scheme error: %s", fn.Name(), err.Error())
		}
		req.Tools = append(req.Tools, &model.Tool{
			Type: model.ToolTypeFunction,
			Function: &model.FunctionDefinition{
				Name:        fn.Name(),
				Description: fn.Desc(),
				Parameters:  json.RawMessage(schema),
			},
		})
	}
	return nil
}
// _ParserResponse2Output converts an Ark chat-completion response into an
// llms.Output. For tool-call finishes it resolves each requested function
// through the ToolBox, executes it, and stores the result on the call;
// otherwise it extracts the plain text content. Token usage is always
// copied from the response.
//
// Fixes vs. the original: guards against an empty Choices slice (which
// previously caused an index panic), drops the redundant nil check before
// len(ToolCalls), removes the duplicate loop variable, and corrects the
// "faild" typo in the error message.
func (llm *LLM) _ParserResponse2Output(resp *model.ChatCompletionResponse, tb llms.ToolBox) (*llms.Output, error) {
	if len(resp.Choices) == 0 {
		return nil, fmt.Errorf("response is invalid %v", resp)
	}
	var (
		fnList  []*llms.FunctionCall
		content string
	)
	rMsg := resp.Choices[0]
	// Map the provider finish reason onto the framework finish type.
	fType := llms.OutputFinishInvalid
	if ft, ok := _ModelRespFinishMap[rMsg.FinishReason]; ok {
		fType = ft
	}
	if rMsg.FinishReason == model.FinishReasonToolCalls {
		// len(nil) == 0, so no separate nil check is needed.
		if len(rMsg.Message.ToolCalls) == 0 {
			return nil, fmt.Errorf("response is invalid %v", resp)
		}
		fnList = make([]*llms.FunctionCall, 0, len(rMsg.Message.ToolCalls))
		for _, tc := range rMsg.Message.ToolCalls {
			if tc.Type != model.ToolTypeFunction {
				continue
			}
			fn := &llms.FunctionCall{
				ID:        tc.ID,
				Name:      tc.Function.Name,
				Arguments: tc.Function.Arguments,
			}
			tool, ok := tb.Get(tc.Function.Name)
			if !ok {
				return nil, fmt.Errorf("tools.ToolBox no tool: %v", tc.Function.Name)
			}
			if tool.Type() != llms.ToolTypeFunction {
				return nil, fmt.Errorf("tool type is no function (type: %d)", tool.Type())
			}
			// Execute the tool immediately and attach its result.
			result, err := tool.Run(tc.Function.Arguments)
			if err != nil {
				return nil, fmt.Errorf("call function failed reason: %s", err.Error())
			}
			fn.Result = result
			fnList = append(fnList, fn)
		}
	} else if rMsg.Message.Content != nil && rMsg.Message.Content.StringValue != nil {
		content = *rMsg.Message.Content.StringValue
	}
	return llms.NewOutput(llms.TokenUsage{
		Prompt:     int64(resp.Usage.PromptTokens),
		Completion: int64(resp.Usage.CompletionTokens),
		Total:      int64(resp.Usage.TotalTokens),
	}, []*llms.OutputMsg{llms.NewOutputMsg(llms.MessageTypeAssistant, "", content, fType, fnList...)}), nil
}
// Request sends the accumulated conversation plus the available tools to
// the Doubao chat-completion endpoint and parses the reply into an
// llms.Output.
func (llm *LLM) Request(ctx context.Context, input llms.Input, tb llms.ToolBox) (*llms.Output, error) {
	msgs := input.GetAll()
	tools := tb.Tools()
	req := model.ChatCompletionRequest{
		Model: llm.model,
		// Over-allocate slightly: tool-call turns expand into extra messages.
		Messages:    make([]*model.ChatCompletionMessage, 0, len(msgs)+len(msgs)/4),
		Tools:       make([]*model.Tool, 0, len(tools)),
		Temperature: llm.temperature,
		TopP:        llm.topP,
	}
	// Assemble the message history and tool definitions.
	if err := llm._buildMsgList(&req, input, msgs); err != nil {
		return nil, err
	}
	if err := llm._buildToolList(&req, input, tools); err != nil {
		return nil, err
	}
	// Perform the API call.
	resp, err := llm.clt.CreateChatCompletion(ctx, req)
	if err != nil {
		return nil, err
	}
	// Decode the provider response.
	return llm._ParserResponse2Output(&resp, tb)
}
// NewLLM builds a Doubao LLM client. apiKey and model are required; opt
// may be nil. The Option's custom HTTP client, host and region — which the
// original implementation silently ignored (SetClient had no effect) — are
// now forwarded to the Ark SDK client as config options.
func NewLLM(apiKey, model string, opt *Option) (*LLM, error) {
	if apiKey == "" || model == "" {
		return nil, fmt.Errorf("api_key or model is empty")
	}
	if opt == nil {
		opt = &Option{}
	}
	// Forward optional transport settings; zero values are skipped so the
	// SDK's defaults still apply.
	cfg := make([]arkruntime.ConfigOption, 0, 3)
	if opt.clt != nil {
		cfg = append(cfg, arkruntime.WithHTTPClient(opt.clt))
	}
	if opt.host != "" {
		cfg = append(cfg, arkruntime.WithBaseUrl(opt.host))
	}
	if opt.region != "" {
		cfg = append(cfg, arkruntime.WithRegion(opt.region))
	}
	clt := arkruntime.NewClientWithApiKey(apiKey, cfg...)
	ret := &LLM{
		_Params: opt._Params,
		apiKey:  apiKey,
		model:   model,
		clt:     clt,
	}
	return ret, nil
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/kzangv/gsf-ai-agent.git
git@gitee.com:kzangv/gsf-ai-agent.git
kzangv
gsf-ai-agent
gsf-ai-agent
v0.0.9

搜索帮助