代码拉取完成,页面将自动刷新
// ++++++++++++++++++++++++++++++++++++++++
// 《零基础Go语言算法实战》源码
// ++++++++++++++++++++++++++++++++++++++++
// Author:廖显东(ShirDon)
// Blog:https://www.shirdon.com/
// Gitee:https://gitee.com/shirdonl/goAlgorithms.git
// Buy link :https://item.jd.com/14101229.html
// ++++++++++++++++++++++++++++++++++++++++
package main
import (
"fmt"
"math"
)
// 逻辑回归结构体,包含权重、学习率和迭代次数
type LogisticRegression struct {
weights []float64
lr float64
iterations int
}
// 创建新的逻辑回归对象并返回
func NewLogisticRegression(lr float64, iterations int) *LogisticRegression {
return &LogisticRegression{lr: lr, iterations: iterations}
}
// 计算z的 sigmoid 函数值并返回
func (l *LogisticRegression) sigmoid(z float64) float64 {
return 1.0 / (1.0 + math.Exp(-z))
}
// 预测方法,给定输入向量X,预测输出并返回
func (l *LogisticRegression) predict(X []float64) float64 {
var z float64
for i, xi := range X {
z += xi * l.weights[i]
}
return l.sigmoid(z)
}
// 训练方法,给定输入矩阵X和输出向量y,训练模型的权重
func (l *LogisticRegression) train(X [][]float64, y []float64) {
nSamples := len(X)
nFeatures := len(X[0])
l.weights = make([]float64, nFeatures)
// 进行多次迭代,更新权重
for i := 0; i < l.iterations; i++ {
for j := 0; j < nSamples; j++ {
yPred := l.predict(X[j])
res := y[j] - yPred
for k := 0; k < nFeatures; k++ {
l.weights[k] += l.lr * res * X[j][k]
}
}
}
}
// 预测方法,给定输入矩阵X,预测输出并返回
func (l *LogisticRegression) predictAll(X [][]float64) []float64 {
nSamples := len(X)
yPred := make([]float64, nSamples)
for i, xi := range X {
yPred[i] = l.predict(xi)
}
return yPred
}
func main() {
// 输入矩阵X
X := [][]float64{
{1, 2},
{2, 1},
{3, 4},
{4, 3},
}
//输出向量y
y := []float64{0, 0, 1, 1}
// 创建逻辑回归对象并训练模型
lr := NewLogisticRegression(0.1, 100)
lr.train(X, y)
// 使用训练好的模型进行预测
yPred := lr.predictAll(X)
// 输出预测结果
fmt.Println(yPred)
}
//$ go run logisticRegression.go
//[0.7203202998607646 0.6688645965102862 0.8854875088641785 0.858447853849889]
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。