1 Star 0 Fork 0

Guo/基于线性回归的房价预测

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
regression.py 2.35 KB
一键复制 编辑 原始数据 按行查看 历史
Guo 提交于 2024-07-07 21:26 . Initial commit
# 真实函数的参数缺省值为 w=1.2,b=0.5
import torch
from DL.实验3.nndl.linear import Linear
from DL.实验3.nndl.optimizer_lsm import optimizer_lsm
from matplotlib import pyplot as plt # matplotlib 是 Python 的绘图库
from nndl.create_data import create_toy_data
from DL.实验3.nndl.mean_squared_error import mean_squared_error
def linear_func(x, w=1.2, b=0.5):
y = w * x + b
return y
y_true= torch.tensor([[-0.2],[4.9]],dtype=torch.float32)
y_pred = torch.tensor([[1.3],[2.5]],dtype=torch.float32)
error = mean_squared_error(y_true=y_true, y_pred=y_pred).item()
print("error:",error)
# func = linear_func
# interval = (-10, 10)
# train_num = 100 # 训练样本数目
# test_num = 50 # 测试样本数目
# noise = 2
# X_train, y_train = create_toy_data(func=func, interval=interval, sample_num=train_num, noise=noise, add_outlier=False)
# X_test, y_test = create_toy_data(func=func, interval=interval, sample_num=test_num, noise=noise, add_outlier=False)
#
# X_train_large, y_train_large = create_toy_data(func=func, interval=interval, sample_num=5000, noise=noise,
# add_outlier=False)
#
# # torch.linspace返回一个Tensor,Tensor的值为在区间start和stop上均匀间隔的num个值,输出Tensor的长度为num
# X_underlying = torch.linspace(interval[0], interval[1], train_num)
# y_underlying = linear_func(X_underlying)
#
# # 绘制数据
# plt.scatter(X_train, y_train, marker='*', facecolor="none", edgecolor='#e4007f', s=50, label="train data")
# plt.scatter(X_test, y_test, facecolor="none", edgecolor='#f19ec2', s=50, label="test data")
# plt.plot(X_underlying, y_underlying, c='#000000', label=r"underlying distribution")
# plt.legend(fontsize='x-large') # 给图像加图例
# plt.savefig('ml-vis.pdf') # 保存图像到PDF文件中
# plt.show()
# input_size = 1
# model = Linear(input_size)
# model = optimizer_lsm(model, X_train.reshape([-1, 1]), y_train.reshape([-1, 1]))
# print("w_pred:", model.params['w'].item(), "b_pred: ", model.params['b'].item())
# y_train_pred = model(X_train.reshape([-1, 1])).squeeze()
# train_error = mean_squared_error(y_true=y_train, y_pred=y_train_pred).item()
# print("train error: ", train_error)
# y_test_pred = model(X_test.reshape([-1, 1])).squeeze()
# test_error = mean_squared_error(y_true=y_test, y_pred=y_test_pred).item()
# print("test error: ", test_error)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gjx521/linear_predict.git
git@gitee.com:gjx521/linear_predict.git
gjx521
linear_predict
基于线性回归的房价预测
master

搜索帮助