llm_old.go (gitee.com/kzangv/gsf-ai-agent)
package azure

import (
	"context"
	"encoding/json"
	"fmt"

	"gitee.com/kzangv/gsf-ai-agent/llms"
	"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)
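
// LLMOld is an Azure OpenAI chat client that targets the legacy
// (pre-tools) function-call variant of the chat-completions API.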
type LLMOld struct {
	_LLMParams
	model       string
	modelDeploy string
	apiSecret   string
	client      *azopenai.Client
}
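
// CheckSupport reports whether this backend supports the given capability;
// only tool (function) calls are supported.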
func (llm *LLMOld) CheckSupport(v llms.LLMSupportType) bool {
	return v == llms.LLMSupportTypeTool
}
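
// _buildMsgList converts the input's system message plus the user and
// assistant history into azopenai chat request messages on req.Messages.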
func (llm *LLMOld) _buildMsgList(req *azopenai.ChatCompletionsOptions, input *llms.Input, msgs []llms.Message) error {
	if sysMsg := input.GetSysMsg(); sysMsg != nil {
		var name *string
		if sysMsg.Name != "" {
			name = &sysMsg.Name
		}
		req.Messages = append(req.Messages, &azopenai.ChatRequestSystemMessage{Name: name, Content: &sysMsg.Content})
	}
	for _, msg := range msgs {
		switch msg.MsgType() {
		case llms.MessageTypeUser:
			mVal := msg.(*llms.UserMessage)
			var name *string
			if nVal := mVal.MsgRoleName(); nVal != "" {
				name = &nVal
			}
			req.Messages = append(req.Messages, &azopenai.ChatRequestUserMessage{
				Name: name, Content: azopenai.NewChatRequestUserMessageContent(mVal.Content)})
		case llms.MessageTypeAssistant:
			mVal := msg.(*llms.AssistantMessage)
			var name *string
			if nVal := mVal.MsgRoleName(); nVal != "" {
				name = &nVal
			}
			if mVal.FnList == nil {
				req.Messages = append(req.Messages, &azopenai.ChatRequestAssistantMessage{
					Name: name, Content: &mVal.Content})
			} else if len(mVal.FnList) > 0 {
				// The legacy function-call API carries at most one function
				// call per assistant message, so only the first entry of
				// FnList is replayed: first the assistant's call, then the
				// function result message.
				fn := mVal.FnList[0]
				funcCall := &azopenai.FunctionCall{Name: &fn.Name, Arguments: &fn.Arguments}
				req.Messages = append(req.Messages,
					&azopenai.ChatRequestAssistantMessage{Name: name, FunctionCall: funcCall},
					&azopenai.ChatRequestFunctionMessage{Name: &fn.Name, Content: &fn.Result})
			}
		}
	}
	return nil
}
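
// _buildToolList registers every function tool in toolList as a legacy
// FunctionDefinition on req and enables automatic function selection.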
func (llm *LLMOld) _buildToolList(req *azopenai.ChatCompletionsOptions, _ *llms.Input, toolList []llms.Tool) error {
	for _, tool := range toolList {
		if tool.Type() == llms.ToolTypeFunction {
			tFn := tool.(*llms.FunctionTool)
			paramSchemeStr, err := tFn.ParamJsonScheme()
			if err != nil {
				return fmt.Errorf("func [%s] params encode json scheme error: %s", tFn.Name(), err.Error())
			}
			req.Functions = append(req.Functions, azopenai.FunctionDefinition{
				Name:        to.Ptr(tFn.Name()),
				Description: to.Ptr(tFn.Desc()),
				Parameters:  json.RawMessage(paramSchemeStr),
			})
		}
	}
	// Let the model decide when to call a function. This method fills
	// req.Functions (the legacy API), never req.Tools, so gate on Functions.
	if len(req.Functions) > 0 {
		req.FunctionCall = &azopenai.ChatCompletionsOptionsFunctionCall{Value: to.Ptr("auto")}
	}
	return nil
}
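
// _ParserResponse2Output converts an Azure chat-completions response into an
// llms.Output. When the model requested a function call, the matching tool is
// executed immediately and its result is attached to the output message.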
func (llm *LLMOld) _ParserResponse2Output(resp *azopenai.GetChatCompletionsResponse, tb llms.ToolBox) (*llms.Output, error) {
	var (
		fnList  []*llms.FunctionCall
		content string
	)
	rMsg := &resp.Choices[0]
	fType := llms.OutputFinishInvalid
	if ft, ok := _ModelRespFinishMap[*rMsg.FinishReason]; ok {
		fType = ft
	}
	if *rMsg.FinishReason == azopenai.CompletionsFinishReasonFunctionCall {
		if rMsg.Message.FunctionCall == nil {
			return nil, fmt.Errorf("response is invalid: %v", resp)
		}
		fnList = make([]*llms.FunctionCall, 0, 1)
		fnVal := rMsg.Message.FunctionCall
		fn := &llms.FunctionCall{
			Name:      *fnVal.Name,
			Arguments: *fnVal.Arguments,
		}
		if tool, ok := tb.Get(*fnVal.Name); !ok {
			return nil, fmt.Errorf("tools.ToolBox has no tool: %v", *fnVal.Name)
		} else if tool.Type() != llms.ToolTypeFunction {
			return nil, fmt.Errorf("tool type is not function (type: %d)", tool.Type())
		} else {
			var err error
			fn.Result, err = tool.Run(*fnVal.Arguments)
			if err != nil {
				return nil, fmt.Errorf("call function failed, reason: %s", err.Error())
			}
		}
		fnList = append(fnList, fn)
	} else {
		content = *rMsg.Message.Content
	}
	return llms.NewOutput(llms.TokenUsage{
		Prompt:     int64(*resp.Usage.PromptTokens),
		Completion: int64(*resp.Usage.CompletionTokens),
		Total:      int64(*resp.Usage.TotalTokens),
	}, []*llms.OutputMsg{llms.NewOutputMsg(llms.MessageTypeAssistant, "", content, fType, fnList...)}), nil
}
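
// Request runs one chat-completions round trip: it builds the message and
// function lists, calls the Azure endpoint, and parses the response.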
func (llm *LLMOld) Request(ctx context.Context, input *llms.Input, tb llms.ToolBox) (*llms.Output, error) {
	msgs, tools := input.CurrValues(), tb.Tools()
	msgLen, toolLen := len(msgs), len(tools)
	var funcList []azopenai.FunctionDefinition
	if toolLen > 0 {
		funcList = make([]azopenai.FunctionDefinition, 0, toolLen)
	}
	req := azopenai.ChatCompletionsOptions{
		DeploymentName: &llm.modelDeploy,
		MaxTokens:      &llm.maxTokens,
		TopP:           &llm.topP,
		Temperature:    &llm.temperature,
		// Reserve ~25% headroom for the extra system/function messages.
		Messages:  make([]azopenai.ChatRequestMessageClassification, 0, msgLen+msgLen/4),
		Functions: funcList,
	}
	// Build the messages and tools
	if err := llm._buildMsgList(&req, input, msgs); err != nil {
		return nil, err
	}
	if err := llm._buildToolList(&req, input, tools); err != nil {
		return nil, err
	}
	// Send the request
	resp, err := llm.client.GetChatCompletions(ctx, req, nil)
	if err != nil {
		return nil, err
	}
	// Parse the result
	return llm._ParserResponse2Output(&resp, tb)
}
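
// NewLLMOld creates a client for the legacy function-call API, filling in
// default temperature, top_p, and max_tokens when opt leaves them unset.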
func NewLLMOld(apiSecret string, model ModelType, modelDeploy string, opt *Option) (*LLMOld, error) {
	if apiSecret == "" {
		return nil, fmt.Errorf("api_key or api_secret is empty")
	}
	keyCredential := azcore.NewKeyCredential(apiSecret)
	client, err := azopenai.NewClientWithKeyCredential(opt.apiEndpoint, keyCredential, nil)
	if err != nil {
		return nil, err
	}
	if opt.temperature == 0 {
		opt.temperature = defaultTemperature
	}
	if opt.topP == 0 {
		opt.topP = defaultTopP
	}
	if opt.maxTokens <= 0 {
		opt.maxTokens = defaultMaxTokens
	}
	ret := &LLMOld{
		apiSecret:   apiSecret,
		model:       string(model),
		modelDeploy: modelDeploy,
		client:      client,
		_LLMParams:  opt._LLMParams,
	}
	return ret, nil
}
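
// Usage sketch (illustrative only): the model name, deployment name, and the
// way Option is constructed below are assumptions, since Option's fields are
// unexported in this file; a real caller would use whatever constructor or
// setters the package exposes to supply the endpoint and sampling params.
//
//	opt := &Option{ /* hypothetical: endpoint and params set via package helpers */ }
//	llm, err := NewLLMOld(os.Getenv("AZURE_OPENAI_KEY"), ModelType("gpt-35-turbo"), "my-deployment", opt)
//	if err != nil {
//		log.Fatal(err)
//	}
//	out, err := llm.Request(context.Background(), input, toolBox)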