代码拉取完成,页面将自动刷新
package ernie
import (
"context"
"encoding/json"
"fmt"
"gitee.com/kzangv/gsf-ai-agent/llms"
"gitee.com/kzangv/gsf-ai-agent/utils"
"net/http"
)
// Option carries optional settings for an ERNIE client created by NewLLM.
// NewLLM substitutes package defaults for any field left at its zero value.
type Option struct {
	// clt is the HTTP client used for API calls.
	clt *http.Client
	// cache is a key/value store kept on the client; how it is used is
	// outside this section of the file.
	cache utils.KeyValueCache
	// msgLimit defaults to _DefaultEmptyMessagesLimit when zero.
	msgLimit uint
	// Sampling parameters copied into every request.
	temperature  float32
	topP         float32
	penaltyScore float32
}

// SetHttpClient sets the HTTP client and returns o for chaining.
func (o *Option) SetHttpClient(val *http.Client) *Option {
	o.clt = val
	return o
}

// SetCache sets the key/value cache and returns o for chaining.
func (o *Option) SetCache(val utils.KeyValueCache) *Option {
	o.cache = val
	return o
}

// SetMsgLimit sets the message limit and returns o for chaining.
func (o *Option) SetMsgLimit(val uint) *Option {
	o.msgLimit = val
	return o
}

// SetTemperature sets the sampling temperature and returns o for chaining.
func (o *Option) SetTemperature(val float32) *Option {
	o.temperature = val
	return o
}

// SetTopP sets the top-p sampling value and returns o for chaining.
func (o *Option) SetTopP(val float32) *Option {
	o.topP = val
	return o
}

// SetPenaltyScore sets the repetition penalty score and returns o for chaining.
func (o *Option) SetPenaltyScore(val float32) *Option {
	o.penaltyScore = val
	return o
}

// NewOption returns an empty Option; every field starts at its zero value.
func NewOption() *Option {
	return new(Option)
}
// LLM is the ERNIE chat client constructed by NewLLM.
type LLM struct {
	// model selects which ERNIE model is called.
	model ModelType
	// modelReqUrl is the request URL looked up for model in _ModelReqUrlMap.
	modelReqUrl string
	// apiKey and apiSecret are the credentials passed to NewLLM.
	apiKey    string
	apiSecret string
	// options holds the resolved client settings.
	options *Option
}
// CheckSupport reports whether this client offers the capability v.
// Tool (function) calling is the only supported capability.
func (llm *LLM) CheckSupport(v llms.LLMSupportType) bool {
	return v == llms.LLMSupportTypeTool
}
// _buildMsgList fills req with the system prompt and the conversation history.
//
// The system prompt returned by input.GetSysMsg (if any) is stored in
// req.System. Each entry of msgs is then translated:
//   - a user message becomes one user-role entry carrying its text;
//   - an assistant message without function calls becomes one
//     assistant-role entry carrying its text;
//   - an assistant message with function calls becomes, for each call, an
//     assistant-role entry holding the FunctionCall followed by a
//     function-role entry whose content is the JSON object {"output": result}.
//     The assistant's text content is not sent in this case.
//
// Message types other than user and assistant are skipped. An error is
// returned if a message's concrete type does not match what its Type()
// reports, or if a function result cannot be JSON-encoded.
func (llm *LLM) _buildMsgList(req *Request, input llms.Input, msgs []llms.Message) error {
	if sysMsg := input.GetSysMsg(); sysMsg != nil {
		req.System = sysMsg.Content()
	}
	for _, msg := range msgs {
		switch msg.Type() {
		case llms.MessageTypeUser:
			mVal, ok := msg.(*llms.UserMessage)
			if !ok {
				return fmt.Errorf("user message has unexpected type %T", msg)
			}
			msgTxt := mVal.Content()
			req.Messages = append(req.Messages, &ReqChatCompMsg{
				Role:    _MessageRoleUser,
				Name:    mVal.Role(),
				Content: &msgTxt,
			})
		case llms.MessageTypeAssistant:
			mVal, ok := msg.(*llms.AssistantMessage)
			if !ok {
				return fmt.Errorf("assistant message has unexpected type %T", msg)
			}
			// Test the length, not nil-ness: an empty but non-nil FnList must
			// still be sent as a plain text message instead of being dropped.
			if len(mVal.FnList) == 0 {
				msgTxt := mVal.Content()
				req.Messages = append(req.Messages, &ReqChatCompMsg{
					Role:    _MessageRoleAssistant,
					Name:    mVal.Role(),
					Content: &msgTxt,
				})
				continue
			}
			for _, fn := range mVal.FnList {
				req.Messages = append(req.Messages, &ReqChatCompMsg{
					Role: _MessageRoleAssistant,
					Name: mVal.Role(),
					FunctionCall: &ErnieFunctionCall{
						Name:      fn.Name,
						Arguments: fn.Arguments,
					},
				})
				fnJSON, err := json.Marshal(map[string]string{"output": fn.Result})
				if err != nil {
					return fmt.Errorf("func [%s] result encode json error: %w", fn.Name, err)
				}
				fnJSONStr := string(fnJSON)
				req.Messages = append(req.Messages, &ReqChatCompMsg{
					Role:    _MessageRoleFunction,
					Name:    fn.Name,
					Content: &fnJSONStr,
				})
			}
		}
	}
	return nil
}
// _buildToolList appends every function-type tool in toolList to
// req.Functions, together with its JSON parameter schema and any examples.
//
// Each tool example becomes a user-role entry (the example input) followed by
// an assistant-role entry. That entry is either the expected function call, or
// an empty call whose Thoughts say no tool is needed. Tools of other types are
// skipped.
//
// An error is returned if a function-type tool is not an *llms.FunctionTool,
// or if its parameter schema cannot be produced. The latter error is wrapped
// with %w so callers can inspect it.
func (llm *LLM) _buildToolList(req *Request, _ llms.Input, toolList []llms.Tool) error {
	for _, tool := range toolList {
		if tool.Type() != llms.ToolTypeFunction {
			continue
		}
		tFn, ok := tool.(*llms.FunctionTool)
		if !ok {
			return fmt.Errorf("function tool has unexpected type %T", tool)
		}
		paramSchemeStr, err := tFn.ParamJsonScheme()
		if err != nil {
			return fmt.Errorf("func [%s] params encode json scheme error: %w", tFn.Name(), err)
		}
		var examples []*ErnieFunctionExample
		if fnExamples := tFn.Examples(); len(fnExamples) > 0 {
			// Two entries (user + assistant) per example.
			examples = make([]*ErnieFunctionExample, 0, len(fnExamples)*2)
			for k := range fnExamples {
				// Index instead of a range value so the pointer refers to the
				// slice element itself.
				examples = append(examples, &ErnieFunctionExample{Role: _MessageRoleUser, Content: &fnExamples[k].Input})
				if fnExamples[k].IsCall {
					examples = append(examples, &ErnieFunctionExample{Role: _MessageRoleAssistant,
						FunctionCall: &ErnieFunctionCall{Name: tFn.Name(), Arguments: fnExamples[k].FunctionParams}})
				} else {
					// The Thoughts text means "no tool needs to be called".
					examples = append(examples, &ErnieFunctionExample{Role: _MessageRoleAssistant,
						FunctionCall: &ErnieFunctionCall{Name: "", Arguments: "{}", Thoughts: "无需调用任何工具"}})
				}
			}
		}
		req.Functions = append(req.Functions, &ErnieFunction{
			Name:        tFn.Name(),
			Description: tFn.Desc(),
			Parameters:  json.RawMessage(paramSchemeStr),
			Examples:    examples,
		})
	}
	return nil
}
// _ParserResponse2Output converts an ERNIE completion response into an
// llms.Output holding one assistant message plus the token usage.
//
// If the model finished with a function call, the named tool is looked up in
// tb and run with the model-supplied arguments. The call and its result are
// attached to the output message, whose text content is then empty. Otherwise
// the response text becomes the message content. The finish reason is mapped
// through _ModelRespFinishMap, falling back to llms.OutputFinishInvalid.
//
// Errors are returned for:
//   - a nil response;
//   - a function-call response without a function name;
//   - a tool that is unknown or is not a function;
//   - a tool run that fails (that error is wrapped with %w).
func (llm *LLM) _ParserResponse2Output(resp *Response, tb llms.ToolBox) (*llms.Output, error) {
	if resp == nil {
		return nil, fmt.Errorf("response is nil")
	}
	var (
		fnList  []*llms.FunctionCall
		content string
	)
	if resp.FinishReason == _FinishFunctionCall {
		if resp.FunctionCall == nil || resp.FunctionCall.Name == "" {
			return nil, fmt.Errorf("response is invalid %v", resp)
		}
		name := resp.FunctionCall.Name
		tool, ok := tb.Get(name)
		if !ok {
			return nil, fmt.Errorf("tools.ToolBox no tool: %v", name)
		}
		if tool.Type() != llms.ToolTypeFunction {
			return nil, fmt.Errorf("tool type is no function (type: %d)", tool.Type())
		}
		// The response exposes no separate call ID here, so the function name
		// doubles as the ID.
		fn := &llms.FunctionCall{
			ID:        name,
			Name:      name,
			Arguments: resp.FunctionCall.Arguments,
		}
		var err error
		fn.Result, err = tool.Run(resp.FunctionCall.Arguments)
		if err != nil {
			return nil, fmt.Errorf("call function failed reason: %w", err)
		}
		fnList = []*llms.FunctionCall{fn}
	} else {
		content = resp.Result
	}
	fType := llms.OutputFinishInvalid
	if ft, ok := _ModelRespFinishMap[resp.FinishReason]; ok {
		fType = ft
	}
	return llms.NewOutput(llms.TokenUsage{
		Prompt:     int64(resp.Usage.PromptTokens),
		Completion: int64(resp.Usage.CompletionTokens),
		Total:      int64(resp.Usage.TotalTokens),
	}, []*llms.OutputMsg{llms.NewOutputMsg(llms.MessageTypeAssistant, "", content, fType, fnList...)}), nil
}
// Request runs one chat-completion round trip.
//
// It builds a request from input's messages and tb's tools, using the
// sampling settings stored in the client options, and sends it via
// _FetchCompletion with ctx. It then converts the reply into an llms.Output;
// if the model asked for a function call, that tool is run during this
// conversion. Any error from building, fetching, or converting is returned
// unchanged.
func (llm *LLM) Request(ctx context.Context, input llms.Input, tb llms.ToolBox) (*llms.Output, error) {
	history := input.GetAll()
	toolList := tb.Tools()
	opts := llm.options

	req := new(Request)
	// Extra 25% capacity, presumably headroom for assistant turns that expand
	// into several request messages.
	req.Messages = make([]*ReqChatCompMsg, 0, len(history)+len(history)/4)
	req.Functions = make([]*ErnieFunction, 0, len(toolList))
	req.Temperature = opts.temperature
	req.TopP = opts.topP
	req.PenaltyScore = opts.penaltyScore

	var err error
	if err = llm._buildMsgList(req, input, history); err != nil {
		return nil, err
	}
	if err = llm._buildToolList(req, input, toolList); err != nil {
		return nil, err
	}

	resp, err := llm._FetchCompletion(ctx, req)
	if err != nil {
		return nil, err
	}
	return llm._ParserResponse2Output(resp, tb)
}
// NewLLM creates an ERNIE client for model, authenticated with apiKey and
// apiSecret.
//
// opt may be nil. NewLLM works on a copy of *opt, so later changes made
// through opt's setters do not affect the returned client, and opt itself is
// never modified. Zero-valued option fields are replaced with package
// defaults. The default HTTP client sets no Timeout; pass one via
// Option.SetHttpClient if calls must be bounded that way.
//
// An error is returned if either credential is empty, or if model has no
// request URL in _ModelReqUrlMap.
func NewLLM(apiKey, apiSecret string, model ModelType, opt *Option) (*LLM, error) {
	if apiKey == "" || apiSecret == "" {
		return nil, fmt.Errorf("api_key or api_secret is empty")
	}
	// An unknown model would otherwise produce an empty request URL and fail
	// only later, at request time, with a confusing error.
	reqURL := _ModelReqUrlMap[model]
	if reqURL == "" {
		return nil, fmt.Errorf("unsupported model: %v", model)
	}

	// Copy so the client never aliases caller-owned configuration.
	var o Option
	if opt != nil {
		o = *opt
	}
	if o.clt == nil {
		o.clt = &http.Client{}
	}
	if o.cache == nil {
		o.cache = utils.NewMapCache()
	}
	if o.msgLimit == 0 {
		o.msgLimit = _DefaultEmptyMessagesLimit
	}
	if o.temperature == 0 {
		o.temperature = defaultTemperature
	}
	if o.topP == 0 {
		o.topP = defaultTopP
	}
	if o.penaltyScore == 0 {
		o.penaltyScore = defaultPenaltyScore
	}

	return &LLM{
		apiKey:      apiKey,
		apiSecret:   apiSecret,
		model:       model,
		modelReqUrl: reqURL,
		options:     &o,
	}, nil
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。