代码拉取完成,页面将自动刷新
import torch # 引入PyTorch,用于构建和训练神经网络模型
import torch.nn as nn # 引入torch.nn,用于定义神经网络层和损失函数
import torch.optim as optim # 引入torch.optim,用于定义优化器
from torchvision import datasets, transforms # 引入torchvision,用于加载和转换数据集
from torch.utils.data import DataLoader # 引入DataLoader,用于创建数据加载器
# 定义数据转换,将图像转换为张量并进行归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载MNIST训练集和测试集,并应用数据转换
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 将训练集划分为训练集和验证集
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
# 创建数据加载器,用于批次化和随机化数据
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# 定义多层感知机模型
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.flatten = nn.Flatten() # 定义数据展平层,将二维图像数据转换为一维向量
self.fc1 = nn.Linear(784, 20) # 定义第一个全连接层,输入维度为784,输出维度为20
self.relu = nn.ReLU() # 定义ReLU激活函数
self.fc2 = nn.Linear(20, 10) # 定义第二个全连接层,输入维度为20,输出维度为10
def forward(self, x):
x = self.flatten(x) # 将输入数据展平为一维向量
x = self.fc1(x) # 通过第一个全连接层
x = self.relu(x) # 应用ReLU激活函数
x = self.fc2(x) # 通过第二个全连接层并输出结果
return x
# 创建多层感知机模型实例
model = MLP()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 定义交叉熵损失函数,用于计算模型输出与真实标签之间的差异
optimizer = optim.Adam(model.parameters(), lr=0.001) # 定义Adam优化器,用于更新模型参数
# 训练模型
num_epochs = 20 # 设置训练的总轮数为20
for epoch in range(num_epochs): # 开始训练循环,每次循环称为一个epoch
# 训练阶段
model.train() # 将模型设置为训练模式,启用BatchNormalization和Dropout
train_loss = 0.0 # 初始化训练损失为0
train_acc = 0.0 # 初始化训练准确率为0
for images, labels in train_loader: # 遍历训练数据加载器,每次处理一个批次的数据
outputs = model(images) # 将当前批次的图像输入到模型中,得到预测输出
loss = criterion(outputs, labels) # 使用损失函数计算预测输出和真实标签之间的损失
optimizer.zero_grad() # 清零模型参数的梯度,为下一次梯度计算做准备
loss.backward() # 反向传播计算损失函数关于模型参数的梯度
optimizer.step() # 使用优化器更新模型参数,根据计算得到的梯度优化模型
train_loss += loss.item() * images.size(0) # 累加当前批次的训练损失,乘以批次大小得到样本总损失
_, predicted = torch.max(outputs.data, 1) # 获取预测概率最大的类别索引
train_acc += (predicted == labels).sum().item() # 统计预测正确的样本数
train_loss /= len(train_dataset) # 计算平均训练损失,除以训练集总样本数
train_acc /= len(train_dataset) # 计算训练准确率,除以训练集总样本数
# 验证阶段
model.eval() # 将模型设置为评估模式,禁用BatchNormalization和Dropout
val_loss = 0.0 # 初始化验证损失为0
val_acc = 0.0 # 初始化验证准确率为0
with torch.no_grad(): # 关闭梯度计算,减少内存消耗和加速计算
for images, labels in val_loader: # 遍历验证数据加载器,每次处理一个批次的数据
outputs = model(images) # 将当前批次的图像输入到模型中,得到预测输出
loss = criterion(outputs, labels) # 使用损失函数计算预测输出和真实标签之间的损失
val_loss += loss.item() * images.size(0) # 累加当前批次的验证损失,乘以批次大小得到样本总损失
_, predicted = torch.max(outputs.data, 1) # 获取预测概率最大的类别索引
val_acc += (predicted == labels).sum().item() # 统计预测正确的样本数
val_loss /= len(val_dataset) # 计算平均验证损失,除以验证集总样本数
val_acc /= len(val_dataset) # 计算验证准确率,除以验证集总样本数
# 打印当前epoch的训练损失、训练准确率、验证损失和验证准确率
print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
# 在测试集上评估模型
model.eval() # 将模型设置为评估模式
with torch.no_grad(): # 关闭梯度计算
correct = 0
total = 0
for images, labels in test_loader: # 遍历测试数据加载器
outputs = model(images) # 将测试图像输入到模型中,得到预测输出
_, predicted = torch.max(outputs.data, 1) # 获取预测概率最大的类别索引
total += labels.size(0) # 累加测试样本总数
correct += (predicted == labels).sum().item() # 统计预测正确的样本数
accuracy = correct / total # 计算测试准确率
print(f"Test Accuracy: {accuracy:.4f}") # 打印测试准确率
# 从测试集中选择前两个样本作为推理数据
x_infer, _ = next(iter(test_loader))
x_infer = x_infer[:2]
# 使用训练好的模型对推理数据进行预测
model.eval() # 将模型设置为评估模式
with torch.no_grad(): # 关闭梯度计算
y_infer = model(x_infer) # 将推理数据输入到模型中,得到预测输出
_, predicted = torch.max(y_infer.data, 1) # 获取预测概率最大的类别索引
# 保存模型参数
torch.save(model.state_dict(), 'best_mlp_model.pt')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。