package skylark

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"

	"gitee.com/kzangv/gsf-ai-agent/llms"
	"github.com/volcengine/volc-sdk-golang/service/maas"
	"github.com/volcengine/volc-sdk-golang/service/maas/models/api"
	structpb "google.golang.org/protobuf/types/known/structpb"
)
type ModelType string

const (
	ModelTypeSkylarkLitePublic ModelType = "skylark-lite-public"
	ModelTypeSkylarkPlusPublic ModelType = "skylark-plus-public"
	ModelTypeSkylarkProPublic  ModelType = "skylark-pro-public"
	ModelTypeSkylark2Pro4k     ModelType = "Skylark2-pro-4k"
	ModelTypeSkylark2Lite8k    ModelType = "skylark2-lite-8k"
	ModelTypeSkylark2Pro32k    ModelType = "skylark2-pro-32k"
	ModelTypeMoonshot8KV1      ModelType = "moonshot-v1-8k"
	ModelTypeMoonshot32KV1     ModelType = "moonshot-v1-32k"
	ModelTypeMoonshot128KV1    ModelType = "moonshot-v1-128k"
)
const (
	defaultTopP            = 0.9
	defaultTopK            = 0
	defaultMinNewTokens    = 1
	defaultMaxNewTokens    = 1000
	defaultMaxPromptTokens = 4000
	defaultTemperature     = 0.7
	defaultHost            = "maas-api.ml-platform-cn-beijing.volces.com"
	defaultRegion          = "cn-beijing"
)
const (
	_FinishStop         = "stop"          // output hit one of the stop strings passed in the request and was truncated
	_FinishLength       = "length"        // the maximum token count was reached (truncated according to the is_truncated flag in the EB response)
	_FinishFunctionCall = "function_call" // the model invoked the function-call capability
)
var (
	_ModelRespFinishMap = map[string]llms.OutputFinishType{
		_FinishStop:         llms.OutputFinishStop,
		_FinishLength:       llms.OutputFinishLength,
		_FinishFunctionCall: llms.OutputFinishToolCalls,
	}
	_ModelDefaultVersionMap = map[ModelType]string{
		ModelTypeSkylark2Pro4k:  "1.0",
		ModelTypeSkylark2Pro32k: "1.0",
		ModelTypeMoonshot8KV1:   "1.0",
		ModelTypeMoonshot32KV1:  "1.0",
		ModelTypeMoonshot128KV1: "1.0",
	}
)
type Option struct {
	_Params
	host, region string
	clt          *http.Client
}

func (op *Option) SetMaxNewTokens(val int64) *Option {
	op.maxNewTokens = val
	return op
}

func (op *Option) SetMinNewTokens(val int64) *Option {
	op.minNewTokens = val
	return op
}

func (op *Option) SetMaxPromptTokens(val int64) *Option {
	op.maxPromptTokens = val
	return op
}

func (op *Option) SetTopK(val int64) *Option {
	op.topK = val
	return op
}

func (op *Option) SetTemperature(val float32) *Option {
	op.temperature = val
	return op
}

func (op *Option) SetTopP(val float32) *Option {
	op.topP = val
	return op
}

func (op *Option) SetHost(val string) *Option {
	op.host = val
	return op
}

func (op *Option) SetRegion(val string) *Option {
	op.region = val
	return op
}

func (op *Option) SetClient(val *http.Client) *Option {
	op.clt = val
	return op
}

func NewOption() *Option {
	return &Option{}
}
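
// Example (sketch): the setters above form a fluent builder. Zero-valued
// fields are replaced by the package defaults inside NewLLM, so only the
// values you want to override need to be set. The numbers below are
// illustrative, not recommendations:
//
//	opt := NewOption().
//		SetTemperature(0.6).
//		SetTopP(0.8).
//		SetMaxNewTokens(2048)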
/*
LLM
*/

type _Params struct {
	maxNewTokens, // maximum number of tokens the model may generate
	minNewTokens, // minimum number of tokens the model must generate
	maxPromptTokens, // maximum number of input tokens; if the prompt is longer, only the last max_prompt_tokens tokens are fed to the model
	topK int64 // sample from the k tokens with the highest predicted probability; range 0-1000, 0 disables top-k sampling
	temperature, // controls randomness/creativity of the generated text; larger values mean more randomness, range 0~1
	topP float32 // controls diversity of the output tokens; larger values mean more varied tokens, range 0~1
}
type LLM struct {
	_Params
	apiKey    string
	apiSecret string
	model     string
	modelVer  string
	clt       *maas.MaaS
}

func (llm *LLM) CheckSupport(_ llms.LLMSupportType) bool {
	return false
}
// _buildMsgList converts the system message and the conversation history into
// request messages. An assistant turn that carries function calls is expanded
// into paired assistant (function_call) and function (result) messages.
func (llm *LLM) _buildMsgList(req *api.ChatReq, input *llms.Input, msgs []llms.Message) error {
	if sysMsg := input.GetSysMsg(); sysMsg != nil {
		req.Messages = append(req.Messages, &api.Message{Role: maas.ChatRoleOfSystem,
			Name: sysMsg.Name, Content: sysMsg.Content,
		})
	}
	for _, msg := range msgs {
		switch msg.MsgType() {
		case llms.MessageTypeUser:
			mVal := msg.(*llms.UserMessage)
			req.Messages = append(req.Messages, &api.Message{Role: maas.ChatRoleOfUser,
				Name: mVal.MsgRoleName(), Content: mVal.Content})
		case llms.MessageTypeAssistant:
			mVal := msg.(*llms.AssistantMessage)
			if mVal.FnList == nil {
				req.Messages = append(req.Messages, &api.Message{Role: maas.ChatRoleOfAssistant,
					Name: mVal.MsgRoleName(), Content: mVal.Content})
			} else if len(mVal.FnList) > 0 {
				for _, fn := range mVal.FnList {
					req.Messages = append(req.Messages, &api.Message{Role: maas.ChatRoleOfAssistant,
						Name: mVal.MsgRoleName(),
						FunctionCall: &api.FunctionCall{
							Name:      fn.Name,
							Arguments: fn.Arguments,
						},
					})
					//fnJsonVal, _ := json.Marshal(map[string]string{"output": fn.Result})
					//fnJsonStr := string(fnJsonVal)
					req.Messages = append(req.Messages, &api.Message{
						Role:    maas.ChatRoleOfFunction,
						Name:    fn.Name,
						Content: fn.Result,
					})
				}
			}
		}
	}
	return nil
}
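
// Sketch of the expansion performed above: an AssistantMessage whose FnList
// holds one call becomes two request messages (the tool name "get_weather"
// is hypothetical, used only for illustration):
//
//	{role: assistant, function_call: {name: "get_weather", arguments: "{...}"}}
//	{role: function,  name: "get_weather", content: "<tool result>"}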
func (llm *LLM) _buildToolList(req *api.ChatReq, _ *llms.Input, toolList []llms.Tool) error {
	for _, tool := range toolList {
		if tool.Type() == llms.ToolTypeFunction {
			tFn := tool.(*llms.FunctionTool)
			paramSchemeStr, err := tFn.ParamJsonScheme()
			if err != nil {
				return fmt.Errorf("func [%s] params encode json scheme error: %s", tFn.Name(), err.Error())
			}
			//var examples []string
			//if fnExamples := tFn.Examples(); len(fnExamples) > 0 {
			//	examples = make([]string, 0, len(fnExamples)*2)
			//	for k := range fnExamples {
			//		if fnExamples[k].IsCall {
			//			examples = append(examples, format(fnExamples[k].Input, fnExamples[k].FunctionParams, ""))
			//		} else {
			//			examples = append(examples, format(fnExamples[k].Input, nil, "no tool needs to be called"))
			//		}
			//	}
			//}
			scene := &structpb.Struct{}
			// Only unmarshal a real schema; a function without parameters yields the JSON literal "null".
			if paramSchemeStr != "null" {
				if err = json.Unmarshal([]byte(paramSchemeStr), scene); err != nil {
					return err
				}
			}
			req.Functions = append(req.Functions, &api.Function{
				Name:        tFn.Name(),
				Description: tFn.Desc(),
				Parameters:  scene,
				//Examples: examples,
			})
		}
	}
	return nil
}
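
// Sketch of the parameter schema _buildToolList expects from ParamJsonScheme
// (an assumed shape, based on api.Function.Parameters being a structpb.Struct;
// the field names are illustrative):
//
//	{
//	  "type": "object",
//	  "properties": {"city": {"type": "string", "description": "..."}},
//	  "required": ["city"]
//	}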
// _ParserResponse2Output maps an api.ChatResp to an llms.Output. When the
// model requested a function call, the matching tool is looked up, executed,
// and its result attached to the output.
func (llm *LLM) _ParserResponse2Output(resp *api.ChatResp, tb llms.ToolBox) (*llms.Output, error) {
	if resp.Error != nil {
		return nil, fmt.Errorf("api response error(code: %d), %s", resp.Error.GetCodeN(), resp.Error.GetMessage())
	}
	if resp.Choice == nil {
		return nil, fmt.Errorf("api response error: Choice is nil (%+v)", *resp)
	}
	var (
		fnList  []*llms.FunctionCall
		content string
	)
	rMsg := resp.Choice
	fType := llms.OutputFinishInvalid
	if ft, ok := _ModelRespFinishMap[rMsg.FinishReason]; ok {
		fType = ft
	}
	if rMsg.FinishReason == _FinishFunctionCall {
		if rMsg.Message.FunctionCall == nil || rMsg.Message.FunctionCall.Name == "" {
			return nil, fmt.Errorf("response is invalid %v", resp)
		}
		fn := &llms.FunctionCall{
			ID:        rMsg.Message.FunctionCall.Name,
			Name:      rMsg.Message.FunctionCall.Name,
			Arguments: rMsg.Message.FunctionCall.Arguments,
		}
		if tool, ok := tb.Get(rMsg.Message.FunctionCall.Name); !ok {
			return nil, fmt.Errorf("tools.ToolBox no tool: %v", rMsg.Message.FunctionCall.Name)
		} else if tool.Type() != llms.ToolTypeFunction {
			return nil, fmt.Errorf("tool type is not function (type: %d)", tool.Type())
		} else {
			var err error
			fn.Result, err = tool.Run(rMsg.Message.FunctionCall.Arguments)
			if err != nil {
				return nil, fmt.Errorf("call function failed, reason: %s", err.Error())
			}
		}
		fnList = []*llms.FunctionCall{fn}
	} else {
		content = rMsg.Message.Content
	}
	return llms.NewOutput(llms.TokenUsage{
		Prompt:     resp.Usage.PromptTokens,
		Completion: resp.Usage.CompletionTokens,
		Total:      resp.Usage.TotalTokens,
	}, []*llms.OutputMsg{llms.NewOutputMsg(llms.MessageTypeAssistant, rMsg.Message.Name, content, fType, fnList...)}), nil
}
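
// Sketch of the function-call branch handled above (field values are
// illustrative): a response whose FinishReason is "function_call" carries
//
//	Choice.Message.FunctionCall = {Name: "get_weather", Arguments: "{\"city\":\"Beijing\"}"}
//
// The named tool is fetched from the ToolBox, run with those arguments, and
// the result travels back in the returned llms.Output as a FunctionCall entry.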
func (llm *LLM) Request(ctx context.Context, input *llms.Input, tb llms.ToolBox) (*llms.Output, error) {
	msgs, tools := input.CurrValues(), tb.Tools()
	msgLen, toolLen := len(msgs), len(tools)
	req := &api.ChatReq{
		Model: &api.Model{
			Name:    llm.model,
			Version: llm.modelVer, // uses the default version if not specified
		},
		Messages:  make([]*api.Message, 0, msgLen+msgLen/4),
		Functions: make([]*api.Function, 0, toolLen),
		Parameters: &api.Parameters{
			MaxNewTokens:    llm.maxNewTokens,    // maximum number of tokens to generate
			MinNewTokens:    llm.minNewTokens,    // minimum number of tokens to generate
			Temperature:     llm.temperature,     // randomness/creativity of the output; larger means more random, range 0~1
			TopP:            llm.topP,            // diversity of the output tokens; larger means more varied, range 0~1
			TopK:            llm.topK,            // sample from the k most likely tokens; range 0-1000, 0 disables
			MaxPromptTokens: llm.maxPromptTokens, // maximum input tokens; longer prompts keep only the last max_prompt_tokens tokens
		},
	}
	// build the message list & tools
	if err := llm._buildMsgList(req, input, msgs); err != nil {
		return nil, err
	}
	//if err := llm._buildToolList(req, input, tools); err != nil {
	//	return nil, err
	//}
	// send the request
	resp, status, err := llm.clt.Chat(req)
	if err != nil {
		errVal := &api.Error{}
		if errors.As(err, &errVal) { // the returned error is always of type *api.Error
			err = fmt.Errorf("meet maas error=%v, status=%d", errVal, status)
		}
		return nil, err
	}
	// parse the result
	return llm._ParserResponse2Output(resp, tb)
}
// NewLLM validates the credentials, fills any unset options with the package
// defaults, and builds the underlying maas client.
func NewLLM(apiKey, apiSecret string, model ModelType, modelVer string, opt *Option) (*LLM, error) {
	if apiKey == "" || apiSecret == "" {
		return nil, fmt.Errorf("api_key or api_secret is empty")
	}
	if opt == nil {
		opt = &Option{}
	}
	if opt.temperature == 0 {
		opt.temperature = defaultTemperature
	}
	if opt.topP == 0 {
		opt.topP = defaultTopP
	}
	if opt.topK == 0 {
		opt.topK = defaultTopK
	}
	if opt.minNewTokens == 0 {
		opt.minNewTokens = defaultMinNewTokens
	}
	if opt.maxNewTokens == 0 {
		opt.maxNewTokens = defaultMaxNewTokens
	}
	if opt.maxPromptTokens == 0 {
		opt.maxPromptTokens = defaultMaxPromptTokens
	}
	if opt.host == "" {
		opt.host = defaultHost
	}
	if opt.region == "" {
		opt.region = defaultRegion
	}
	clt := maas.NewInstance(opt.host, opt.region)
	clt.SetAccessKey(apiKey)
	clt.SetSecretKey(apiSecret)
	if opt.clt != nil {
		clt.Client.Client = opt.clt
	}
	if modelVer == "" {
		if v, ok := _ModelDefaultVersionMap[model]; ok {
			modelVer = v
		}
	}
	ret := &LLM{
		_Params:   opt._Params,
		apiKey:    apiKey,
		apiSecret: apiSecret,
		model:     string(model),
		modelVer:  modelVer,
		clt:       clt,
	}
	return ret, nil
}
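
// Example (sketch): constructing a client and issuing a request. Only the
// pieces defined above are assumed; how the llms.Input and llms.ToolBox
// values are built belongs to the llms package and is elided here.
//
//	llm, err := NewLLM("your-access-key", "your-secret-key",
//		ModelTypeSkylark2Pro4k, "", // empty version falls back to _ModelDefaultVersionMap ("1.0")
//		NewOption().SetTemperature(0.6))
//	if err != nil {
//		panic(err)
//	}
//	// out, err := llm.Request(ctx, input, toolBox)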