1 Star 0 Fork 0

Survivor_zzc / langchaingo

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
openai_chat.go 1.54 KB
一键复制 编辑 原始数据 按行查看 历史
Survivor_zzc 提交于 2023-10-24 17:21 . package name
package openaichat
import (
	"context"
	"errors"
	"strings"

	"gitee.com/zzcadmin/langchaingo/embeddings"
	"gitee.com/zzcadmin/langchaingo/llms/openai"
)
// ChatOpenAI is an embeddings.Embedder that sends embedding requests through
// an *openai.Chat client from this module's llms/openai package.
//
// Build instances with NewChatOpenAI. The unexported client field is not set
// anywhere in this file, so a zero-value ChatOpenAI has a nil client.
// NOTE(review): calling EmbedDocuments/EmbedQuery on a zero value presumably
// fails on the nil client — confirm against applyChatClientOptions.
type ChatOpenAI struct {
// client performs the CreateEmbedding requests.
client *openai.Chat
// StripNewLines, when true, replaces newlines in the input text(s) before
// they are embedded (EmbedQuery replaces each "\n" with a space).
StripNewLines bool
// BatchSize is passed to embeddings.BatchTexts by EmbedDocuments.
BatchSize int
}
// Compile-time check that the value type ChatOpenAI implements embeddings.Embedder.
var _ embeddings.Embedder = ChatOpenAI{}
// NewChatOpenAI builds a ChatOpenAI embedder from the supplied functional
// options (client, newline stripping, batch size).
//
// If applying the options fails, the error is returned together with a
// zero-value ChatOpenAI, never a partially configured one.
func NewChatOpenAI(opts ...ChatOption) (ChatOpenAI, error) {
	embedder, err := applyChatClientOptions(opts...)
	if err == nil {
		return embedder, nil
	}
	return ChatOpenAI{}, err
}
// EmbedDocuments computes embeddings for texts.
//
// The steps are:
//  1. Optionally rewrite newlines in each text (via
//     embeddings.MaybeRemoveNewLines and StripNewLines).
//  2. Split the texts into groups with embeddings.BatchTexts and BatchSize.
//  3. Embed each group in one CreateEmbedding call.
//  4. Merge the group's vectors into a single vector with
//     embeddings.CombineVectors, using each piece's byte length as its weight.
//
// One merged vector is produced per group.
// NOTE(review): this presumes BatchTexts yields one group per input text
// (chunking long texts) so the output lines up with texts — confirm.
//
// The first error from the client or from CombineVectors aborts the call and
// is returned with a nil result.
func (e ChatOpenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
	cleaned := embeddings.MaybeRemoveNewLines(texts, e.StripNewLines)
	groups := embeddings.BatchTexts(cleaned, e.BatchSize)

	result := make([][]float64, 0, len(texts))
	for _, group := range groups {
		vectors, err := e.client.CreateEmbedding(ctx, group)
		if err != nil {
			return nil, err
		}

		// Weight each chunk by its length so longer chunks dominate the merge.
		weights := make([]int, len(group))
		for i := range group {
			weights[i] = len(group[i])
		}

		merged, err := embeddings.CombineVectors(vectors, weights)
		if err != nil {
			return nil, err
		}
		result = append(result, merged)
	}
	return result, nil
}
// EmbedQuery returns the embedding vector for a single query string.
//
// When StripNewLines is set, every "\n" in text is replaced with a space
// before the request is sent. The client is asked to embed a one-element
// slice, and the first returned vector is the result.
//
// An error is returned if the client call fails or if the client returns no
// vectors at all. The original code indexed emb[0] unconditionally and would
// panic on an empty response.
func (e ChatOpenAI) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
	if e.StripNewLines {
		text = strings.ReplaceAll(text, "\n", " ")
	}

	emb, err := e.client.CreateEmbedding(ctx, []string{text})
	if err != nil {
		return nil, err
	}
	if len(emb) == 0 {
		return nil, errors.New("openaichat: no embedding returned for query")
	}
	return emb[0], nil
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/zzcadmin/langchaingo.git
git@gitee.com:zzcadmin/langchaingo.git
zzcadmin
langchaingo
langchaingo
v0.4.7

搜索帮助

344bd9b3 5694891 D2dac590 5694891