1 Star 0 Fork 0

kzangv / gsf-ai-agent

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
prompt_tools_chain.go 5.91 KB
一键复制 编辑 原始数据 按行查看 历史
kzangv 提交于 2024-03-19 11:14 . fixed
package chain
import (
"context"
"fmt"
"gitee.com/kzangv/gsf-ai-agent/llms"
"gitee.com/kzangv/gsf-ai-agent/prompt"
"gitee.com/kzangv/gsf-ai-agent/utils"
"strings"
)
const (
	// ToolLLMRequestLimit caps how many tool-call round trips a tool chain
	// may perform before giving up.
	// NOTE(review): unused in this file — PromptToolChainEx.Run loops on
	// LLMRequestLimit (declared elsewhere in the package); confirm which
	// limit was intended.
	ToolLLMRequestLimit = 10
)
// PromptToolChainBase extends PromptChain with a tool box so the chain
// can issue tool-enabled LLM requests. It is the shared base for
// PromptToolChain (single round trip) and PromptToolChainEx (looping).
type PromptToolChainBase struct {
	PromptChain
	toolBox llms.ToolBox // tools exposed to the LLM; lazily initialized by GetToolBox
}
// Desc returns a short human-readable description of the chain.
// NOTE(review): the text reads "llm chain", the same as the plain prompt
// chain — this looks like a copy-paste leftover; confirm whether it
// should say "llm tool chain".
func (ch *PromptToolChainBase) Desc() string {
	return "llm chain"
}
// SetLLM installs the LLM used by this chain after verifying that the
// model supports tool calling. It returns an error (and leaves the chain
// unchanged) when the model lacks tool support.
func (ch *PromptToolChainBase) SetLLM(llm llms.LLM) error {
	if !llm.CheckSupport(llms.LLMSupportTypeTool) {
		// Fixed wording of the error message ("llm no support tools").
		return fmt.Errorf("llm does not support tools")
	}
	ch.llm = llm
	return nil
}
// GetToolBox returns the chain's tool box, creating an empty one on
// first access so callers never observe a nil tool box.
func (ch *PromptToolChainBase) GetToolBox() llms.ToolBox {
	box := ch.toolBox
	if box == nil {
		box = llms.ToolBox{}
		ch.toolBox = box
	}
	return box
}
// SetToolBox replaces the chain's tool box. It never fails; the error
// return exists to keep the setter signature uniform with SetLLM and
// the chain-configuration call sites.
func (ch *PromptToolChainBase) SetToolBox(box llms.ToolBox) error {
	ch.toolBox = box
	return nil
}
// _Request issues a single LLM request carrying the chain's tool box.
// It is the shared request primitive used by both Run implementations.
func (ch *PromptToolChainBase) _Request(ctx context.Context, input *llms.Input) (*llms.Output, error) {
	return ch.llm.Request(ctx, input, ch.GetToolBox())
}
// PromptToolChain runs the prompt through a tool-enabled LLM exactly
// once; it records a tool-call response on the input but does not loop
// to resolve it (use PromptToolChainEx for automatic follow-up).
type PromptToolChain struct {
	PromptToolChainBase
}
// Run performs one tool-enabled LLM round trip.
//
// The chain's prompt is rendered with the option's parameter manager and
// pushed onto input as a user message, the LLM is called once with the
// tool box, and the response is merged into ret. If the model finished by
// requesting tool calls, the response is also pushed back onto input so a
// caller holding input can continue the conversation — but no follow-up
// request is made here (see PromptToolChainEx.Run for the looping variant).
//
// A nil input is replaced with a fresh empty llms.Input.
func (ch *PromptToolChain) Run(ctx context.Context, input *llms.Input) (ret *llms.Output, err error) {
	if input == nil {
		input = &llms.Input{}
	}
	// Render the prompt and append it to the transcript as a user-role
	// message.
	if err = input.Push(llms.ConvertPrompt2Message(
		llms.MessageTypeUser,
		ch.GetOption().GetRoleName(),
		ch.prompt,
		ch.GetOption().GetParamManager(),
	)); err != nil {
		return nil, err
	}
	// Debug payload is built lazily inside the closure, so formatting cost
	// is only paid when debug logging is enabled.
	ch.log.Debug(func() string {
		msg, _ := ch.prompt.Format(ch.GetOption().GetParamManager())
		return fmt.Sprintf("** LLM Tool Request inputs ******: \n%s\n", msg)
	})
	ret = &llms.Output{}
	// tmp and err here shadow the named results; both paths return
	// explicitly, so the shadowing is harmless. The statement before the
	// condition merges the response into ret before checking how the model
	// finished.
	if tmp, err := ch._Request(ctx, input); err != nil {
		return nil, err
	} else if ret.MergeMessage(tmp); tmp.LastMsg().FinishReason() == llms.OutputFinishToolCalls {
		// The model requested tool calls: record its message on input so
		// the transcript stays complete for any follow-up request made by
		// the caller.
		if err = input.Push(llms.ConvertOutput2Message(tmp)); err != nil {
			return nil, err
		}
		// Trace each requested function call (name, id, arguments, result).
		ch.log.Debug(func() string {
			fList := tmp.LastMsg().FunctionCallList()
			fMsg := make([]string, 0, len(fList))
			for _, v := range fList {
				fMsg = append(fMsg, fmt.Sprintf("FuncName: %s(ID: %s)\nFuncArguments: %s\nFuncResult: %s\n", v.Name, v.ID, v.Arguments, v.Result))
			}
			return fmt.Sprintf("** LLM Tool Function Call ******: \n%s\n", strings.Join(fMsg, "\n"))
		})
	}
	return ret, nil
}
// NewPromptToolChain builds a PromptToolChain from an already-constructed
// prompt, wiring in the LLM, tool box, trace log and options.
// On any setup failure it returns (nil, err).
//
// The original used a `for { ... break }` pseudo-goto for error handling;
// rewritten with idiomatic early returns — behavior is unchanged.
func NewPromptToolChain(tmp prompt.Prompt, llm llms.LLM, box llms.ToolBox, log utils.TraceLog, opt Option) (*PromptToolChain, error) {
	ret := &PromptToolChain{}
	if err := ret.SetPrompt(tmp); err != nil {
		return nil, err
	}
	if err := InitLLMChain(ret, llm, log, opt); err != nil {
		return nil, err
	}
	if err := ret.SetToolBox(box); err != nil {
		return nil, err
	}
	return ret, nil
}
// NewTemplatePromptToolChain builds a PromptToolChain whose prompt is a
// template string (rendered via prompt.TemplatePrompt). On any setup
// failure it returns (nil, err).
//
// Rewritten from the `for { ... break }` pseudo-goto to idiomatic early
// returns; NewPromptToolChain already yields (nil, err) on failure, so
// its result can be returned directly.
func NewTemplatePromptToolChain(tmp string, llm llms.LLM, box llms.ToolBox, log utils.TraceLog, opt Option) (*PromptToolChain, error) {
	tmpPrompt := &prompt.TemplatePrompt{}
	if err := tmpPrompt.Init(tmp); err != nil {
		return nil, err
	}
	return NewPromptToolChain(tmpPrompt, llm, box, log, opt)
}
// NewStringPromptToolChain builds a PromptToolChain whose prompt is a
// plain string (no template rendering). On any setup failure it returns
// (nil, err).
//
// Rewritten from the `for { ... break }` pseudo-goto to idiomatic early
// returns; NewPromptToolChain already yields (nil, err) on failure, so
// its result can be returned directly.
func NewStringPromptToolChain(tmp string, llm llms.LLM, box llms.ToolBox, log utils.TraceLog, opt Option) (*PromptToolChain, error) {
	tmpPrompt := &prompt.StringPrompt{}
	if err := tmpPrompt.Init(tmp); err != nil {
		return nil, err
	}
	return NewPromptToolChain(tmpPrompt, llm, box, log, opt)
}
// PromptToolChainEx is the looping variant of PromptToolChain: it keeps
// re-querying the LLM while responses finish with tool calls, up to a
// bounded number of iterations.
type PromptToolChainEx struct {
	PromptToolChainBase
}
// Run executes the prompt against the tool-enabled LLM and automatically
// resolves tool calls: after each response that finishes with tool calls,
// the response is pushed back onto input and the LLM is queried again,
// up to LLMRequestLimit iterations. All responses are merged into the
// returned Output.
//
// NOTE(review): the loop bound is LLMRequestLimit (declared elsewhere in
// the package) while this file declares ToolLLMRequestLimit, which is
// unused — confirm which limit was intended.
//
// A nil input is replaced with a fresh empty llms.Input.
func (ch *PromptToolChainEx) Run(ctx context.Context, input *llms.Input) (ret *llms.Output, err error) {
	if input == nil {
		input = &llms.Input{}
	}
	// Render the prompt and append it to the transcript as a user-role
	// message.
	if err = input.Push(llms.ConvertPrompt2Message(
		llms.MessageTypeUser,
		ch.GetOption().GetRoleName(),
		ch.prompt,
		ch.GetOption().GetParamManager(),
	)); err != nil {
		return nil, err
	}
	// Debug payload is built lazily inside the closure, so formatting cost
	// is only paid when debug logging is enabled.
	ch.log.Debug(func() string {
		msg, _ := ch.prompt.Format(ch.GetOption().GetParamManager())
		return fmt.Sprintf("** LLM Tool Request inputs ******: \n%s\n", msg)
	})
	ret = &llms.Output{}
	for i := 0; i < LLMRequestLimit; i++ {
		// tmp and err shadow the named results; both error paths return
		// explicitly, so the shadowing is harmless.
		if tmp, err := ch._Request(ctx, input); err != nil {
			return nil, err
		} else {
			// Merge the response into ret, then decide whether another
			// round trip is needed.
			if ret.MergeMessage(tmp); tmp.LastMsg().FinishReason() == llms.OutputFinishToolCalls {
				// Tool calls requested: append the response to input and
				// loop to let the model consume the tool results.
				if err = input.Push(llms.ConvertOutput2Message(tmp)); err != nil {
					return nil, err
				}
				// Trace each requested function call (name, id, arguments,
				// result).
				ch.log.Debug(func() string {
					fList := tmp.LastMsg().FunctionCallList()
					fMsg := make([]string, 0, len(fList))
					for _, v := range fList {
						fMsg = append(fMsg, fmt.Sprintf("FuncName: %s(ID: %s)\nFuncArguments: %s\nFuncResult: %s\n", v.Name, v.ID, v.Arguments, v.Result))
					}
					return fmt.Sprintf("** LLM Tool Function Call ******: \n%s\n", strings.Join(fMsg, "\n"))
				})
				continue
			}
		}
		// Non-tool-call finish: done.
		break
	}
	// NOTE(review): assumes at least one message was merged so
	// ret.LastMsg() is non-nil — confirm Output.LastMsg tolerates an empty
	// output if the limit is hit with only tool-call responses.
	ch.log.Debug(func() string {
		return fmt.Sprintf("** LLM Tool Request output ******: \n%s\n", ret.LastMsg().Content())
	})
	return ret, nil
}
// NewPromptToolChainEx builds a PromptToolChainEx from an
// already-constructed prompt, wiring in the LLM, tool box, trace log and
// options. On any setup failure it returns (nil, err).
//
// The original used a `for { ... break }` pseudo-goto for error handling;
// rewritten with idiomatic early returns — behavior is unchanged.
func NewPromptToolChainEx(tmp prompt.Prompt, llm llms.LLM, box llms.ToolBox, log utils.TraceLog, opt Option) (*PromptToolChainEx, error) {
	ret := &PromptToolChainEx{}
	if err := ret.SetPrompt(tmp); err != nil {
		return nil, err
	}
	if err := InitLLMChain(ret, llm, log, opt); err != nil {
		return nil, err
	}
	if err := ret.SetToolBox(box); err != nil {
		return nil, err
	}
	return ret, nil
}
// NewTemplatePromptToolChainEx builds a PromptToolChainEx whose prompt is
// a template string (rendered via prompt.TemplatePrompt). On any setup
// failure it returns (nil, err).
//
// Rewritten from the `for { ... break }` pseudo-goto to idiomatic early
// returns; NewPromptToolChainEx already yields (nil, err) on failure, so
// its result can be returned directly.
func NewTemplatePromptToolChainEx(tmp string, llm llms.LLM, box llms.ToolBox, log utils.TraceLog, opt Option) (*PromptToolChainEx, error) {
	tmpPrompt := &prompt.TemplatePrompt{}
	if err := tmpPrompt.Init(tmp); err != nil {
		return nil, err
	}
	return NewPromptToolChainEx(tmpPrompt, llm, box, log, opt)
}
// NewStringPromptToolChainEx builds a PromptToolChainEx whose prompt is a
// plain string (no template rendering). On any setup failure it returns
// (nil, err).
//
// Rewritten from the `for { ... break }` pseudo-goto to idiomatic early
// returns; NewPromptToolChainEx already yields (nil, err) on failure, so
// its result can be returned directly.
func NewStringPromptToolChainEx(tmp string, llm llms.LLM, box llms.ToolBox, log utils.TraceLog, opt Option) (*PromptToolChainEx, error) {
	tmpPrompt := &prompt.StringPrompt{}
	if err := tmpPrompt.Init(tmp); err != nil {
		return nil, err
	}
	return NewPromptToolChainEx(tmpPrompt, llm, box, log, opt)
}
Go
1
https://gitee.com/kzangv/gsf-ai-agent.git
git@gitee.com:kzangv/gsf-ai-agent.git
kzangv
gsf-ai-agent
gsf-ai-agent
v0.0.7

搜索帮助