代码拉取完成,页面将自动刷新
package ernie
import (
"context"
"gitee.com/zzcadmin/langchaingo/embeddings"
"gitee.com/zzcadmin/langchaingo/llms/ernie"
)
// Ernie is an embedder that obtains vectors through an *ernie.LLM client.
//
// Ernie Embedding-V1 doc: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu
type Ernie struct {
// client performs the CreateEmbedding calls.
client *ernie.LLM
// batchSize is passed to embeddings.BatchTexts when splitting input texts.
batchSize int // the length of each text must not exceed 384 tokens
// batchCount caps how many texts are sent in one CreateEmbedding call.
batchCount int // the number of texts must not exceed 16
// stripNewLines is passed to embeddings.MaybeRemoveNewLines before batching.
stripNewLines bool
}
// Compile-time check that *Ernie satisfies embeddings.Embedder.
var _ embeddings.Embedder = &Ernie{}
// NewErnie builds an Ernie embedder.
//
// The embedder starts from the package defaults for newline stripping, batch
// size and batch count; every option in opts is then applied in order. If no
// option supplied a client, one is created with ernie.New, and any error from
// that call is returned unchanged with a nil embedder.
func NewErnie(opts ...Option) (*Ernie, error) {
	embedder := &Ernie{
		batchCount:    defaultBatchCount,
		batchSize:     defaultBatchSize,
		stripNewLines: defaultStripNewLines,
	}

	for i := range opts {
		opts[i](embedder)
	}

	// An option already provided the client; nothing left to construct.
	if embedder.client != nil {
		return embedder, nil
	}

	llm, err := ernie.New()
	if err != nil {
		return nil, err
	}
	embedder.client = llm

	return embedder, nil
}
// embed requests embeddings for texts from the client, sending at most
// e.batchCount texts per CreateEmbedding call, and returns the vectors of all
// calls appended in request order.
//
// Only non-empty chunks are ever sent: when len(texts) is an exact multiple of
// e.batchCount no trailing empty request is issued, and an empty texts slice
// yields an empty result without contacting the client. A non-positive
// e.batchCount is treated as "no limit", i.e. all texts go out in one request.
//
// The first error returned by the client aborts the loop and is returned
// unchanged together with a nil result.
func (e *Ernie) embed(ctx context.Context, texts []string) ([][]float64, error) {
	step := e.batchCount
	if step <= 0 {
		// Guard against division by zero / invalid slice bounds from a
		// misconfigured batch count by falling back to a single request.
		step = len(texts)
	}

	emb := make([][]float64, 0, len(texts))
	for start := 0; start < len(texts); start += step {
		end := start + step
		if end > len(texts) {
			// Last chunk may be shorter than step.
			end = len(texts)
		}

		curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts[start:end])
		if err != nil {
			return nil, err
		}
		emb = append(emb, curTextEmbeddings...)
	}

	return emb, nil
}
// EmbedDocuments embeds texts using Ernie Embedding-V1.
//
// The input is first passed through embeddings.MaybeRemoveNewLines (controlled
// by e.stripNewLines) and then split with embeddings.BatchTexts using
// e.batchSize. Each resulting batch is embedded, and its vectors are merged
// into a single vector by embeddings.CombineVectors, which receives the byte
// length of every piece in the batch. The result holds one merged vector per
// batch, in batch order. The first error encountered is returned unchanged
// with a nil result.
func (e *Ernie) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
	cleaned := embeddings.MaybeRemoveNewLines(texts, e.stripNewLines)
	batches := embeddings.BatchTexts(cleaned, e.batchSize)

	result := make([][]float64, 0, len(texts))
	for _, batch := range batches {
		vectors, err := e.embed(ctx, batch)
		if err != nil {
			return nil, err
		}

		// Byte length of each piece, handed to CombineVectors alongside
		// the piece's vector.
		lengths := make([]int, len(batch))
		for i, piece := range batch {
			lengths[i] = len(piece)
		}

		merged, err := embeddings.CombineVectors(vectors, lengths)
		if err != nil {
			return nil, err
		}
		result = append(result, merged)
	}

	return result, nil
}
// EmbedQuery embeds a single text using Ernie Embedding-V1.
//
// It delegates to EmbedDocuments with a one-element slice and returns the
// first vector of the result; an error from EmbedDocuments is returned
// unchanged with a nil vector.
func (e *Ernie) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
	vectors, err := e.EmbedDocuments(ctx, []string{text})
	if err == nil {
		return vectors[0], nil
	}
	return nil, err
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。