1 Star 0 Fork 0

Survivor_zzc/langchaingo

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
vector_math.go 1.70 KB
一键复制 编辑 原始数据 按行查看 历史
Survivor_zzc 提交于 2023-10-24 15:15 . mark
package embeddings
import (
"errors"
"math"
)
// ErrVectorsNotSameSize is returned when the vectors produced by the
// embeddings API do not all have the same length.
var ErrVectorsNotSameSize = errors.New("vectors gotten not the same size")

// ErrAllTextsLenZero is returned when the texts to be embedded have a
// combined length of zero.
var ErrAllTextsLenZero = errors.New("all texts have length 0")
// CombineVectors merges vectors into a single vector by taking their weighted
// element-wise average (weights[j] is the weight of vectors[j]) and then
// dividing the result by its Euclidean norm, yielding a unit-length vector.
//
// Any error from the averaging step is returned unchanged (e.g.
// ErrVectorsNotSameSize or ErrAllTextsLenZero). If the averaged vector has a
// norm of zero (for instance every input is a zero vector, or vectors is
// empty), it is returned as-is: dividing by a zero norm would otherwise fill
// the result with NaN values.
func CombineVectors(vectors [][]float64, weights []int) ([]float64, error) {
	average, err := getAverage(vectors, weights)
	if err != nil {
		return nil, err
	}

	norm := getNorm(average)
	if norm == 0 {
		// A zero vector cannot be normalized; avoid 0/0 = NaN.
		return average, nil
	}
	for i := range average {
		average[i] /= norm
	}
	return average, nil
}
// errWeightsLenMismatch is returned by getAverage when the number of weights
// differs from the number of vectors.
var errWeightsLenMismatch = errors.New("number of weights does not match number of vectors")

// getAverage computes the weighted element-wise mean of vectors:
//
//	avg = sum(vectors * weights) / sum(weights).
//
// weights[j] is the weight applied to vectors[j], so both slices must have the
// same length; otherwise errWeightsLenMismatch is returned (previously a
// shorter weights slice panicked and a longer one silently skewed the sum).
// All vectors must have the same length, otherwise ErrVectorsNotSameSize is
// returned. If vectors is empty, an empty non-nil slice is returned. If the
// weights sum to zero, ErrAllTextsLenZero is returned.
func getAverage(vectors [][]float64, weights []int) ([]float64, error) {
	if len(weights) != len(vectors) {
		return nil, errWeightsLenMismatch
	}
	if len(vectors) == 0 {
		return []float64{}, nil
	}

	// Check that all vectors are the same size and get that size.
	vectorLen := len(vectors[0])
	for _, vector := range vectors[1:] {
		if len(vector) != vectorLen {
			return nil, ErrVectorsNotSameSize
		}
	}

	// Get the sum of the weights.
	weightSum := 0
	for _, weight := range weights {
		weightSum += weight
	}
	if weightSum == 0 {
		return nil, ErrAllTextsLenZero
	}

	// Accumulate row by row so each vector is read sequentially. For every
	// element the additions still happen in vector order, so the result is
	// identical to a column-wise accumulation.
	average := make([]float64, vectorLen)
	for j, vector := range vectors {
		w := float64(weights[j])
		for i, x := range vector {
			average[i] += x * w
		}
	}

	total := float64(weightSum)
	for i := range average {
		average[i] /= total
	}
	return average, nil
}
// getNorm returns the Euclidean (L2) norm of v: the square root of the sum of
// its squared components. An empty or nil v has norm 0.
func getNorm(v []float64) float64 {
	squares := 0.0
	for _, component := range v {
		squares += component * component
	}
	return math.Sqrt(squares)
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/zzcadmin/langchaingo.git
git@gitee.com:zzcadmin/langchaingo.git
zzcadmin
langchaingo
langchaingo
v0.2.3

搜索帮助

344bd9b3 5694891 D2dac590 5694891