1 Star 0 Fork 0

kzangv / gsf-ai-agent

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
router_chain.go 4.37 KB
一键复制 编辑 原始数据 按行查看 历史
kzangv 提交于 2024-05-15 09:35 . fixed
package chain
import (
"context"
"fmt"
"gitee.com/kzangv/gsf-ai-agent/chain/router"
"gitee.com/kzangv/gsf-ai-agent/llms"
"gitee.com/kzangv/gsf-ai-agent/utils"
)
// RouterOption is the Option implementation used by RouterChain.
//
// On top of the embedded SimpleOption (which holds the role, input keys,
// parameter manager and output key accessed by the methods below) it carries
// the router.Parser used to build the routing prompt and to decode the LLM's
// routing answer.
type RouterOption struct {
	SimpleOption
	// router formats the routing prompt (Format) and parses the LLM reply
	// (Parser). RouterChain.Init falls back to router.DefaultParser when nil.
	router router.Parser
}
// GetRoleName reports the role name stored in the embedded SimpleOption.
func (op *RouterOption) GetRoleName() string {
	return op.SimpleOption.role
}

// SetRoleName records v as the role name and returns the receiver (as an
// Option) so setter calls can be chained.
func (op *RouterOption) SetRoleName(v string) Option {
	op.SimpleOption.role = v
	return op
}
// GetInput returns the configured input keys. RouterChain.Run reads its
// routing query from the parameter manager using the first key.
func (op *RouterOption) GetInput() []string {
	return op.inputs
}

// SetInput records the parameter key the router reads its query from and
// returns the receiver (as an Option) for chaining.
//
// Only the first key is kept, because the router consumes a single query
// value. The key is copied into a fresh slice. Otherwise the option would
// alias the caller's variadic slice, and a later write to that slice would
// silently change the configured input.
func (op *RouterOption) SetInput(v ...string) Option {
	if len(v) > 1 {
		v = v[:1]
	}
	op.inputs = append([]string(nil), v...)
	return op
}
// GetParamManager returns the parameter manager held by the embedded
// SimpleOption.
func (op *RouterOption) GetParamManager() utils.ParamsChainManager {
	return op.SimpleOption.paramsMgt
}

// SetParamManager replaces the parameter manager and returns the receiver
// (as an Option) so setter calls can be chained.
func (op *RouterOption) SetParamManager(pm utils.ParamsChainManager) Option {
	op.SimpleOption.paramsMgt = pm
	return op
}
// GetOutput returns the parameter key under which this chain publishes its
// result. RouterChain.Run stores the parsed route input there.
func (op *RouterOption) GetOutput() string {
	return op.SimpleOption.output
}

// SetOutput sets the output parameter key and returns the receiver (as an
// Option) so setter calls can be chained.
func (op *RouterOption) SetOutput(v string) Option {
	op.SimpleOption.output = v
	return op
}
// SetRouterParser installs the parser used to build the routing prompt and to
// decode the LLM's routing answer. It returns the concrete *RouterOption so
// router-specific configuration can be chained.
func (op *RouterOption) SetRouterParser(p router.Parser) *RouterOption {
	op.router = p
	return op
}

// NewRouterOption returns an empty RouterOption. When no parser has been set,
// RouterChain.Init installs router.DefaultParser.
func NewRouterOption() *RouterOption {
	return new(RouterOption)
}
// RouterChain asks an LLM to pick one of several registered sub-chains and
// then runs the chosen one.
//
// The embedded SimpleLLMChain supplies the option, the LLM client and the
// trace log (used below as ch.opt, ch.llm and ch.log).
type RouterChain struct {
	SimpleLLMChain
	// chMap maps a route name to its sub-chain node. It is populated by Add,
	// pruned by Delete, and may be nil until the first Add.
	chMap map[string]*_ChainNode
}
// Desc returns a short human-readable description of this chain type.
func (ch *RouterChain) Desc() string {
	const desc = "router chain"
	return desc
}

// SetOption accepts only *RouterOption values. Any other Option (including a
// nil interface) is rejected before it reaches the embedded SimpleLLMChain.
func (ch *RouterChain) SetOption(opt Option) error {
	switch opt.(type) {
	case *RouterOption:
		return ch.SimpleLLMChain.SetOption(opt)
	default:
		return fmt.Errorf("RouterChain Option type is invalid")
	}
}
// Add registers node under name as a routing target, replacing any existing
// entry with the same name.
//
// desc is passed to the router parser when building the routing prompt.
// h is stored on the node and, when non-nil, is used by Run to turn the
// node's output into the value published under the node's output key.
func (ch *RouterChain) Add(name, desc string, node Chain, h FormatHandle) {
	entry := NewChainNode(desc, node, h)
	if ch.chMap == nil {
		ch.chMap = map[string]*_ChainNode{name: entry}
		return
	}
	ch.chMap[name] = entry
}

// Delete removes the route registered under name. Deleting an unknown name,
// or deleting before any route was added, does nothing.
func (ch *RouterChain) Delete(name string) {
	// The built-in delete is a no-op on a nil map, so no guard is needed.
	delete(ch.chMap, name)
}
// Init prepares the chain for use.
//
// It installs a fresh RouterOption when none has been set and rejects any
// non-RouterOption. If no router parser is configured, it falls back to
// router.DefaultParser. Finally it initialises the embedded SimpleLLMChain.
func (ch *RouterChain) Init() error {
	if ch.opt == nil {
		ch.opt = NewRouterOption()
	}
	opt, ok := ch.opt.(*RouterOption)
	if !ok {
		return fmt.Errorf("RouterChain Option type is invalid")
	}
	if opt.router == nil {
		opt.SetRouterParser(router.DefaultParser)
	}
	return ch.SimpleLLMChain.Init()
}
// Run routes one request to a registered sub-chain.
//
// Steps:
//  1. Read the routing query from the parameter manager, using the first
//     configured input key.
//  2. Build a routing prompt from the query and the registered routes with
//     the option's router parser. Push it onto input as a user message and
//     send it to the LLM.
//  3. Parse the LLM reply into a route name plus route input. Store the route
//     input under this chain's output key, then run the chosen sub-chain
//     with input.Next().
//  4. If the sub-chain declares an output key, publish its result there. The
//     node's Handle is used when set; otherwise the last non-empty message
//     is published.
//
// The returned Output carries the sub-chain's last message. The router LLM
// response and the sub-chain output are attached as its children.
//
// An error is returned when:
//   - the option is not a *RouterOption,
//   - no route is registered,
//   - no input key is configured,
//   - the query is empty,
//   - the LLM reply is empty, or
//   - the LLM picks an unknown route.
//
// Errors from the parser, the LLM and the sub-chain are passed through.
//
// NOTE(review): Run mutates the chosen sub-chain (SetTraceLog, SetPrev) and
// the shared parameter manager. Concurrent Runs on one RouterChain are
// therefore presumably unsafe; confirm before sharing an instance.
func (ch *RouterChain) Run(ctx context.Context, input *llms.Input) (*llms.Output, error) {
	opt, ok := ch.GetOption().(*RouterOption)
	if !ok {
		return nil, fmt.Errorf("RouterChain Option type is invalid")
	}
	// Routing is meaningless without targets, so fail before spending an LLM call.
	if len(ch.chMap) == 0 {
		return nil, fmt.Errorf("router chain has no route registered")
	}
	inputKeys := opt.GetInput()
	if len(inputKeys) == 0 {
		return nil, fmt.Errorf("router chain has no input key configured")
	}
	params := opt.GetParamManager()
	// A missing key is treated like an empty value and rejected below.
	query, _ := params.Get(inputKeys[0])
	if query == "" {
		return nil, fmt.Errorf("Router Input is Empty")
	}
	if input == nil {
		input = &llms.Input{}
	}

	// Describe every route to the parser. Map iteration order is random, so
	// the route order in the prompt can vary between calls.
	routes := make([]router.Node, 0, len(ch.chMap))
	for name, node := range ch.chMap {
		routes = append(routes, router.Node{Name: name, Desc: node.Desc})
	}
	prompt, err := opt.router.Format(query, routes)
	if err != nil {
		return nil, err
	}

	// Ask the LLM which route to take.
	ch.log.Debug(func() string {
		return fmt.Sprintf("** Router Request ******: \n%s\n", query)
	})
	if err = input.Push(llms.ConvertString2Message(llms.MessageTypeUser, "", prompt)); err != nil {
		return nil, err
	}
	rOut, err := ch.llm.Request(ctx, input, nil)
	if err != nil {
		return nil, err
	}
	// Reject an empty reply instead of dereferencing a nil message.
	if rOut == nil || rOut.LastMsg() == nil {
		return nil, fmt.Errorf("router LLM returned no message")
	}
	choice, err := opt.router.Parser(rOut.LastMsg().Content())
	if err != nil {
		return nil, err
	}
	chNode, ok := ch.chMap[choice.GetName()]
	if !ok {
		return nil, fmt.Errorf("no Router Path be choose: %q", choice.GetName())
	}

	// Hand the route input to the sub-chain through the shared parameter chain.
	params.Add(opt.GetOutput(), choice.GetInput())
	if err = chNode.Chain.SetTraceLog(ch.log); err != nil {
		return nil, err
	}
	chNode.Chain.GetOption().GetParamManager().SetPrev(params)
	ret, err := chNode.Chain.Run(ctx, input.Next())
	if err != nil {
		return nil, err
	}

	// Publish the sub-chain's result under its own output key, if it has one.
	if outKey := chNode.Chain.GetOption().GetOutput(); outKey != "" {
		if chNode.Handle != nil {
			msg, herr := chNode.Handle(ret)
			if herr != nil {
				return nil, herr
			}
			params.Add(outKey, msg)
		} else if last := ret.LastMsg(); last != nil && last.Content() != "" {
			params.Add(outKey, last.Content())
		}
	}

	out := &llms.Output{}
	out.MergeMessage(llms.NewOutput(llms.TokenUsage{}, []*llms.OutputMsg{ret.LastMsg()}))
	out.SetChild(rOut, ret)
	return out, nil
}
// NewRouterChain builds a RouterChain and initialises it through
// InitLLMChain with the given LLM, trace log and option. If initialisation
// fails, it returns a nil chain together with the error.
func NewRouterChain(llm llms.LLM, log utils.TraceLog, opt Option) (*RouterChain, error) {
	rc := new(RouterChain)
	if err := InitLLMChain(rc, llm, log, opt); err != nil {
		return nil, err
	}
	return rc, nil
}
Go
1
https://gitee.com/kzangv/gsf-ai-agent.git
git@gitee.com:kzangv/gsf-ai-agent.git
kzangv
gsf-ai-agent
gsf-ai-agent
v0.0.7

搜索帮助