2 Star 2 Fork 9

王布衣/gox

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
matrix.go 1.71 KB
一键复制 编辑 原始数据 按行查看 历史
王布衣 提交于 2023-06-03 06:07 . 调整vek目录为num
package functions
import (
"golang.org/x/exp/constraints"
"runtime"
"sync"
)
// numCPU holds the value of runtime.NumCPU captured once at package
// initialization. matMulParallel uses it as the upper bound on the number of
// goroutines it starts.
var numCPU = runtime.NumCPU()
// matMulParallel runs a matrix multiply, splitting the rows of x (and the
// matching rows of dst) into contiguous bands that are processed concurrently.
//
// x is indexed as m rows of n elements and dst as m rows of p elements; y is
// passed whole to every kernel call. When p == 1 the vecMul kernel is used,
// otherwise matMul.
//
// If m < 4 or m*p*n < 100_000 the kernel is called once, directly, on the
// full inputs and no goroutines are started (presumably because the work is
// too small to benefit from being split). Otherwise at most numCPU goroutines
// are started; the first m%numCPU bands receive one extra row so that all m
// rows are covered. The function returns once every goroutine has finished.
func matMulParallel[T constraints.Float](
	dst, x, y []T, m, n, p int,
	vecMul func(dst, x, y []T, m, n int),
	matMul func(dst, x, y []T, m, n, p int),
) {
	// run hands a band of rows to the kernel that matches the output width.
	run := func(dstBand, xBand []T, rows int) {
		if p == 1 {
			vecMul(dstBand, xBand, y, rows, n)
			return
		}
		matMul(dstBand, xBand, y, rows, n, p)
	}

	if m < 4 || m*p*n < 100_000 {
		run(dst, x, m)
		return
	}

	base, extra := m/numCPU, m%numCPU
	var wg sync.WaitGroup
	next := 0 // index of the first row not yet assigned to a band
	for worker := 0; worker < numCPU && next < m; worker++ {
		rows := base
		if worker < extra {
			rows++
		}
		// lo, hi and rows are declared inside the loop body so that each
		// goroutine captures its own copies.
		lo, hi := next, next+rows
		wg.Add(1)
		go func() {
			run(dst[lo*p:hi*p], x[lo*n:hi*n], rows)
			wg.Done()
		}()
		next = hi
	}
	wg.Wait()
}
// MatMul_Go accumulates the product of x and y into dst.
//
// The slices are indexed as flat row-major matrices: x as m rows of n
// elements, y as n rows of p elements and dst as m rows of p elements.
// Each product term is added to the value already stored in dst; dst is not
// cleared first, so the caller must zero it to obtain a plain product.
func MatMul_Go[T constraints.Float](dst, x, y []T, m, n, p int) {
	for row := 0; row < m; row++ {
		outBase, lhsBase := row*p, row*n
		for inner := 0; inner < n; inner++ {
			rhsBase := inner * p
			for col := 0; col < p; col++ {
				dst[outBase+col] += x[lhsBase+inner] * y[rhsBase+col]
			}
		}
	}
}
// MatMulVec_Go accumulates the matrix-vector product of x and y into dst.
//
// x is indexed as a flat row-major matrix of m rows and n columns, y supplies
// n elements and dst receives m elements. Each product term is added to the
// value already stored in dst; dst is not cleared first, so the caller must
// zero it to obtain a plain product.
func MatMulVec_Go[T constraints.Float](dst, x, y []T, m, n int) {
	rowStart := 0 // offset of the current row within x, always row*n
	for row := 0; row < m; row++ {
		for col := 0; col < n; col++ {
			dst[row] += x[rowStart+col] * y[col]
		}
		rowStart += n
	}
}
// Mat4Mul_Go stores the product of two 4x4 matrices in dst.
//
// dst, x and y are each indexed as 16 elements in row-major order. Unlike
// MatMul_Go, every element of dst is overwritten rather than accumulated
// into, so dst does not need to be zeroed beforehand.
func Mat4Mul_Go[T constraints.Float](dst, x, y []T) {
	const dim = 4
	for row := 0; row < dim; row++ {
		base := row * dim // offset of this row in both x and dst
		for col := 0; col < dim; col++ {
			dst[base+col] = x[base]*y[col] + x[base+1]*y[dim+col] +
				x[base+2]*y[2*dim+col] + x[base+3]*y[3*dim+col]
		}
	}
}
// MatMul_Parallel_Go multiplies x by y into dst through matMulParallel,
// supplying the pure-Go kernels MatMulVec_Go and MatMul_Go. The arguments
// are forwarded unchanged.
func MatMul_Parallel_Go[T constraints.Float](dst, x, y []T, m, n, p int) {
	var (
		vecKernel = MatMulVec_Go[T]
		matKernel = MatMul_Go[T]
	)
	matMulParallel(dst, x, y, m, n, p, vecKernel, matKernel)
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/quant1x/gox.git
git@gitee.com:quant1x/gox.git
quant1x
gox
gox
v1.15.1

搜索帮助