kzangv/gsf-ai-agent · llm.go
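// Package azure implements a chat LLM for gsf-ai-agent backed by the
// Azure OpenAI chat completions API (via the azopenai SDK).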
package azure

import (
    "context"
    "encoding/json"
    "fmt"

    "gitee.com/kzangv/gsf-ai-agent/llms"
    "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
    "github.com/Azure/azure-sdk-for-go/sdk/azcore"
    "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)
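// ModelType names the underlying Azure OpenAI chat model. It is recorded on
// the LLM, while requests themselves are addressed by deployment name.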
type ModelType string

const (
    ModelTypeGpt35Turbo    ModelType = "gpt-35-turbo"
    ModelTypeGpt35Turbo16k ModelType = "gpt-35-turbo-16k"
    ModelTypeGpt4Pro       ModelType = "gpt-4"
    ModelTypeGpt4Pro32K    ModelType = "gpt-4-32k"
)
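// Defaults applied by NewLLM when the corresponding option is left unset.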
const (
    defaultTemperature = 0.7
    defaultTopP        = 1
    defaultMaxTokens   = 500
)
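// _ModelRespFinishMap translates Azure finish reasons into llms finish types.
// Both tool-call and legacy function-call finishes map to OutputFinishToolCalls.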
var (
    _ModelRespFinishMap = map[azopenai.CompletionsFinishReason]llms.OutputFinishType{
        azopenai.CompletionsFinishReasonStopped:           llms.OutputFinishStop,
        azopenai.CompletionsFinishReasonTokenLimitReached: llms.OutputFinishLength,
        azopenai.CompletionsFinishReasonContentFiltered:   llms.OutputFinishContentFilter,
        azopenai.CompletionsFinishReasonToolCalls:         llms.OutputFinishToolCalls,
        azopenai.CompletionsFinishReasonFunctionCall:      llms.OutputFinishToolCalls,
    }
)
/**
Option
*/
type Option struct {
    _LLMParams
    apiEndpoint string
}
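// The Set* methods mutate the option in place and return it so calls can be chained.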
func (p *Option) SetTemperature(val float32) *Option {
    p.temperature = val
    return p
}

func (p *Option) SetTopP(val float32) *Option {
    p.topP = val
    return p
}

func (p *Option) SetMaxTokens(val int32) *Option {
    p.maxTokens = val
    return p
}

func (p *Option) SetAzureEndPoint(endpoint string) *Option {
    p.apiEndpoint = endpoint
    return p
}

func NewOption() *Option {
    return &Option{}
}
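// A typical construction (endpoint and values below are placeholders):
//
//	opt := NewOption().
//		SetAzureEndPoint("https://my-resource.openai.azure.com/").
//		SetTemperature(0.2)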
/**
LLM
*/
type _LLMParams struct {
    temperature float32
    topP        float32
    maxTokens   int32
}
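// LLM binds an azopenai.Client to a single deployment and carries the
// sampling parameters applied to every request.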
type LLM struct {
    _LLMParams
    model       string
    modelDeploy string
    apiSecret   string
    client      *azopenai.Client
}
func (llm *LLM) CheckSupport(v llms.LLMSupportType) bool {
    return v == llms.LLMSupportTypeTool
}
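// _buildMsgList converts the system message and conversation history into
// azopenai request messages. Assistant turns that carried function calls are
// replayed as a tool-call message followed by one tool-result message per call.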
func (llm *LLM) _buildMsgList(req *azopenai.ChatCompletionsOptions, input *llms.Input, msgs []llms.Message) error {
    if sysMsg := input.GetSysMsg(); sysMsg != nil {
        var name *string
        if sysMsg.Name != "" {
            name = &sysMsg.Name
        }
        req.Messages = append(req.Messages, &azopenai.ChatRequestSystemMessage{Name: name, Content: &sysMsg.Content})
    }
    for _, msg := range msgs {
        switch msg.MsgType() {
        case llms.MessageTypeUser:
            mVal := msg.(*llms.UserMessage)
            var name *string
            if nVal := mVal.MsgRoleName(); nVal != "" {
                name = &nVal
            }
            req.Messages = append(req.Messages, &azopenai.ChatRequestUserMessage{
                Name: name, Content: azopenai.NewChatRequestUserMessageContent(mVal.Content)})
        case llms.MessageTypeAssistant:
            mVal := msg.(*llms.AssistantMessage)
            var name *string
            if nVal := mVal.MsgRoleName(); nVal != "" {
                name = &nVal
            }
            if mVal.FnList == nil {
                req.Messages = append(req.Messages, &azopenai.ChatRequestAssistantMessage{
                    Name: name, Content: &mVal.Content})
            } else if len(mVal.FnList) > 0 {
                tools := make([]azopenai.ChatCompletionsToolCallClassification, 0, len(mVal.FnList))
                fnResults := make([]azopenai.ChatRequestMessageClassification, 0, len(mVal.FnList))
                for _, fn := range mVal.FnList {
                    tools = append(tools, &azopenai.ChatCompletionsFunctionToolCall{
                        ID:       &fn.ID,
                        Type:     to.Ptr("function"),
                        Function: &azopenai.FunctionCall{Name: &fn.Name, Arguments: &fn.Arguments},
                    })
                    fnResults = append(fnResults, &azopenai.ChatRequestToolMessage{ToolCallID: &fn.ID, Content: &fn.Result})
                }
                req.Messages = append(req.Messages, &azopenai.ChatRequestAssistantMessage{
                    Name: name, ToolCalls: tools})
                req.Messages = append(req.Messages, fnResults...)
            }
        }
    }
    return nil
}
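// _buildToolList registers every function tool from the tool box on the
// request, serializing each tool's parameter JSON scheme.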
func (llm *LLM) _buildToolList(req *azopenai.ChatCompletionsOptions, _ *llms.Input, toolList []llms.Tool) error {
    for _, tool := range toolList {
        if tool.Type() == llms.ToolTypeFunction {
            tFn := tool.(*llms.FunctionTool)
            paramSchemeStr, err := tFn.ParamJsonScheme()
            if err != nil {
                return fmt.Errorf("func [%s] param JSON scheme encode error: %s", tFn.Name(), err.Error())
            }
            req.Tools = append(req.Tools, &azopenai.ChatCompletionsFunctionToolDefinition{
                Type: to.Ptr("function"),
                Function: &azopenai.FunctionDefinition{
                    Name:        to.Ptr(tFn.Name()),
                    Description: to.Ptr(tFn.Desc()),
                    Parameters:  json.RawMessage(paramSchemeStr),
                },
            })
        }
    }
    // With more than one tool registered, let the model choose which to call.
    if len(req.Tools) > 1 {
        req.ToolChoice = azopenai.ChatCompletionsToolChoiceAuto
    }
    return nil
}
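// _ParserResponse2Output maps the first completion choice back onto an
// llms.Output. For a tool-call finish it also runs each requested tool via
// the tool box and stores the result on the corresponding FunctionCall.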
func (llm *LLM) _ParserResponse2Output(resp *azopenai.GetChatCompletionsResponse, tb llms.ToolBox) (*llms.Output, error) {
    var (
        fnList  []*llms.FunctionCall
        content string
    )
    rMsg := &resp.Choices[0]
    fType := llms.OutputFinishInvalid
    if ft, ok := _ModelRespFinishMap[*rMsg.FinishReason]; ok {
        fType = ft
    }
    if *rMsg.FinishReason == azopenai.CompletionsFinishReasonToolCalls {
        if len(rMsg.Message.ToolCalls) == 0 {
            return nil, fmt.Errorf("response is invalid: %v", resp)
        }
        fnList = make([]*llms.FunctionCall, 0, len(rMsg.Message.ToolCalls))
        for k := range rMsg.Message.ToolCalls {
            v := rMsg.Message.ToolCalls[k].GetChatCompletionsToolCall()
            if *v.Type == "function" {
                fnVal := rMsg.Message.ToolCalls[k].(*azopenai.ChatCompletionsFunctionToolCall)
                fn := &llms.FunctionCall{
                    ID:        *v.ID,
                    Name:      *fnVal.Function.Name,
                    Arguments: *fnVal.Function.Arguments,
                }
                if tool, ok := tb.Get(*fnVal.Function.Name); !ok {
                    return nil, fmt.Errorf("tools.ToolBox has no tool: %v", *fnVal.Function.Name)
                } else if tool.Type() != llms.ToolTypeFunction {
                    return nil, fmt.Errorf("tool type is not function (type: %d)", tool.Type())
                } else {
                    var err error
                    fn.Result, err = tool.Run(*fnVal.Function.Arguments)
                    if err != nil {
                        return nil, fmt.Errorf("call function failed, reason: %s", err.Error())
                    }
                }
                fnList = append(fnList, fn)
            }
        }
    } else {
        content = *rMsg.Message.Content
    }
    return llms.NewOutput(llms.TokenUsage{
        Prompt:     int64(*resp.Usage.PromptTokens),
        Completion: int64(*resp.Usage.CompletionTokens),
        Total:      int64(*resp.Usage.TotalTokens),
    }, []*llms.OutputMsg{llms.NewOutputMsg(llms.MessageTypeAssistant, "", content, fType, fnList...)}), nil
}
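// Request performs one chat-completions round trip: it assembles the request
// from the input and tool box, calls Azure OpenAI, and parses the response.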
func (llm *LLM) Request(ctx context.Context, input *llms.Input, tb llms.ToolBox) (*llms.Output, error) {
    msgs, tools := input.CurrValues(), tb.Tools()
    msgLen, toolLen := len(msgs), len(tools)
    var toolList []azopenai.ChatCompletionsToolDefinitionClassification
    if toolLen > 0 {
        toolList = make([]azopenai.ChatCompletionsToolDefinitionClassification, 0, toolLen)
    }
    req := azopenai.ChatCompletionsOptions{
        DeploymentName: &llm.modelDeploy,
        MaxTokens:      &llm.maxTokens,
        TopP:           &llm.topP,
        Temperature:    &llm.temperature,
        Messages:       make([]azopenai.ChatRequestMessageClassification, 0, msgLen+msgLen/4),
        Tools:          toolList,
    }
    // Build the messages and tools.
    if err := llm._buildMsgList(&req, input, msgs); err != nil {
        return nil, err
    }
    if err := llm._buildToolList(&req, input, tools); err != nil {
        return nil, err
    }
    // Issue the request.
    resp, err := llm.client.GetChatCompletions(ctx, req, nil)
    if err != nil {
        return nil, err
    }
    // Parse the response.
    return llm._ParserResponse2Output(&resp, tb)
}
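// NewLLM builds a key-credential azopenai client for the endpoint configured
// in opt and fills any unset sampling parameters with the package defaults.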
func NewLLM(apiSecret string, model ModelType, modelDeploy string, opt *Option) (*LLM, error) {
    if apiSecret == "" {
        return nil, fmt.Errorf("api_key or api_secret is empty")
    }
    if opt == nil {
        // Guard against a nil option so the endpoint lookup below cannot panic.
        opt = NewOption()
    }
    keyCredential := azcore.NewKeyCredential(apiSecret)
    client, err := azopenai.NewClientWithKeyCredential(opt.apiEndpoint, keyCredential, nil)
    if err != nil {
        return nil, err
    }
    // Fall back to package defaults for any unset sampling parameter.
    if opt.temperature == 0 {
        opt.temperature = defaultTemperature
    }
    if opt.topP == 0 {
        opt.topP = defaultTopP
    }
    if opt.maxTokens <= 0 {
        opt.maxTokens = defaultMaxTokens
    }
    ret := &LLM{
        apiSecret:   apiSecret,
        model:       string(model),
        modelDeploy: modelDeploy,
        client:      client,
        _LLMParams:  opt._LLMParams,
    }
    return ret, nil
}
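// End-to-end sketch (secret, deployment, and endpoint are placeholders, and
// building the llms.Input and llms.ToolBox is up to the caller):
//
//	llm, err := NewLLM("<api-secret>", ModelTypeGpt35Turbo, "my-deployment",
//		NewOption().SetAzureEndPoint("https://my-resource.openai.azure.com/"))
//	if err != nil {
//		// handle error
//	}
//	out, err := llm.Request(context.Background(), input, toolBox)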