代码拉取完成,页面将自动刷新
# 真实函数的参数缺省值为 w=1.2,b=0.5
import torch
from DL.实验3.nndl.linear import Linear
from DL.实验3.nndl.optimizer_lsm import optimizer_lsm
from matplotlib import pyplot as plt # matplotlib 是 Python 的绘图库
from nndl.create_data import create_toy_data
from DL.实验3.nndl.mean_squared_error import mean_squared_error
def linear_func(x, w=1.2, b=0.5):
y = w * x + b
return y
y_true= torch.tensor([[-0.2],[4.9]],dtype=torch.float32)
y_pred = torch.tensor([[1.3],[2.5]],dtype=torch.float32)
error = mean_squared_error(y_true=y_true, y_pred=y_pred).item()
print("error:",error)
# func = linear_func
# interval = (-10, 10)
# train_num = 100 # 训练样本数目
# test_num = 50 # 测试样本数目
# noise = 2
# X_train, y_train = create_toy_data(func=func, interval=interval, sample_num=train_num, noise=noise, add_outlier=False)
# X_test, y_test = create_toy_data(func=func, interval=interval, sample_num=test_num, noise=noise, add_outlier=False)
#
# X_train_large, y_train_large = create_toy_data(func=func, interval=interval, sample_num=5000, noise=noise,
# add_outlier=False)
#
# # torch.linspace返回一个Tensor,Tensor的值为在区间start和stop上均匀间隔的num个值,输出Tensor的长度为num
# X_underlying = torch.linspace(interval[0], interval[1], train_num)
# y_underlying = linear_func(X_underlying)
#
# # 绘制数据
# plt.scatter(X_train, y_train, marker='*', facecolor="none", edgecolor='#e4007f', s=50, label="train data")
# plt.scatter(X_test, y_test, facecolor="none", edgecolor='#f19ec2', s=50, label="test data")
# plt.plot(X_underlying, y_underlying, c='#000000', label=r"underlying distribution")
# plt.legend(fontsize='x-large') # 给图像加图例
# plt.savefig('ml-vis.pdf') # 保存图像到PDF文件中
# plt.show()
# input_size = 1
# model = Linear(input_size)
# model = optimizer_lsm(model, X_train.reshape([-1, 1]), y_train.reshape([-1, 1]))
# print("w_pred:", model.params['w'].item(), "b_pred: ", model.params['b'].item())
# y_train_pred = model(X_train.reshape([-1, 1])).squeeze()
# train_error = mean_squared_error(y_true=y_train, y_pred=y_train_pred).item()
# print("train error: ", train_error)
# y_test_pred = model(X_test.reshape([-1, 1])).squeeze()
# test_error = mean_squared_error(y_true=y_test, y_pred=y_test_pred).item()
# print("test error: ", test_error)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。