To solve the MNIST classification problem, we build a simple multilayer perceptron out of torch.nn.Linear layers. This chapter uses ReLU as the activation function and two Linear (fully connected) layers as hidden layers.
import torch
from torch import nn
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,6"  # only use the idle GPUs
## 3.1.1 Building the Model
# Recall from the previous lecture that each batch from train_dl has shape torch.Size([64, 1, 28, 28])
class Model(nn.Module):  # define the model as a subclass of nn.Module
    def __init__(self):
        super().__init__()
        # First layer: input is the flattened image (28*28 features), 120 neurons
        self.liner_1 = nn.Linear(28*28, 120)
        # Second layer: takes the previous layer's output, 84 neurons
        self.liner_2 = nn.Linear(120, 84)
        # Output layer: takes the 84 features and outputs the 10 class scores
        self.liner_3 = nn.Linear(84, 10)
    def forward(self, input):
        x = input.view(-1, 28*28)  # flatten each (1, 28, 28) image into a 784-dim vector
        x = torch.relu(self.liner_1(x))
        x = torch.relu(self.liner_2(x))
        x = self.liner_3(x)
        return x
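For comparison, the same architecture can also be written with nn.Sequential and nn.Flatten. The block below is only an illustrative sketch; the subclass-style Model above is what the rest of the chapter uses.
# Sketch only: an equivalent model built with nn.Sequential (not used later).
# nn.Flatten() replaces the manual input.view(-1, 28*28).
seq_model = nn.Sequential(
    nn.Flatten(),           # (N, 1, 28, 28) -> (N, 784)
    nn.Linear(28*28, 120),
    nn.ReLU(),
    nn.Linear(120, 84),
    nn.ReLU(),
    nn.Linear(84, 10),      # raw logits for the 10 classes
)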
## 3.1.2 The Loss Function for Classification
loss_fn=nn.CrossEntropyLoss() # initialize the cross-entropy loss
# It can be thought of as combining the softmax computation with the cross-entropy loss
'''
Note two parameters:
1. weight
2. ignore_index
Also, it expects the targets to be integer class indices, not one-hot vectors
'''
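A minimal check of these conventions (the numbers below are made up for illustration): CrossEntropyLoss takes raw logits of shape (N, C) and integer class labels of shape (N,), and the optional weight argument rescales the per-class losses.
# Sketch with made-up numbers: raw logits (no softmax) and integer labels.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])  # shape (N=2, C=3)
labels = torch.tensor([0, 2])                                # class indices, not one-hot
print(nn.CrossEntropyLoss()(logits, labels))                 # scalar loss averaged over the batch
# A per-class weight can be passed at construction, e.g. to upweight a rare class:
weighted_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0, 2.0]))
print(weighted_fn(logits, labels))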
## 3.1.3 Initializing the Model
device="cuda" if torch.cuda.is_available() else "cpu"
model=Model().to(device)
# print(device)
## 3.1.4 Optimizers (all under the torch.optim module)
# Stochastic gradient descent (SGD) optimizer
optimizer=torch.optim.SGD(model.parameters(),lr=0.005)
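Any other optimizer under torch.optim can be swapped in the same way; the commented line below is only a sketch of an alternative and is not used for this chapter's results.
# Sketch only: an alternative optimizer from torch.optim (not used below).
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)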
## 3.1.5 Writing the Training Loop
### 3.1.5.1 The Training Function
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # total number of samples in the dataset
    num_batches = len(dataloader)   # total number of batches in the dataloader
    # train_loss accumulates the loss over all batches; correct counts correctly classified samples
    train_loss, correct = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Make predictions and compute the loss for the current batch
        pred = model(X)
        loss = loss_fn(pred, y)
        # Use backpropagation to update the model parameters based on the loss
        optimizer.zero_grad()  # clear the accumulated gradients first
        loss.backward()        # backpropagate the loss to compute the parameter gradients
        optimizer.step()       # update the parameters using the gradients
        with torch.no_grad():
            # correct accumulates the number of correctly predicted samples
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            # train_loss accumulates the loss of every batch
            train_loss += loss.item()
    # train_loss is the sum over all batches, so divide by the number of batches for the average loss
    train_loss /= num_batches
    # correct is the total number of correct predictions; divide by the sample count for the epoch accuracy
    correct /= size
    return train_loss, correct
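To see why pred.argmax(1) == y counts the correct predictions, here is a tiny sketch with made-up logits:
# Sketch with made-up values: argmax(1) picks the predicted class for each row.
demo_pred = torch.tensor([[0.1, 2.0, -0.3], [1.5, 0.2, 0.1]])  # 2 samples, 3 classes
demo_y = torch.tensor([1, 2])
print(demo_pred.argmax(1))  # tensor([1, 0])
print((demo_pred.argmax(1) == demo_y).type(torch.float).sum().item())  # 1.0 correct prediction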
| Expression | Meaning | Typical use |
| --- | --- | --- |
| len(dataset) | total number of samples in the dataset (e.g. 100) | dataset statistics and splitting |
| len(dataloader) | total number of batches (e.g. 4) | controlling the training loop |
| len(dataloader.dataset) | same as len(dataset) | when the underlying dataset must be accessed |

(These relationships are checked with a short snippet right after the data-loading code below.)
### 3.1.5.2 The Test Function
def test(dataloader, model):
    size = len(dataloader.dataset)  # total number of test samples
    num_batches = len(dataloader)   # total number of test batches
    test_loss, correct = 0, 0
    with torch.no_grad():  # no gradients are needed during evaluation
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches  # average loss per batch
    correct /= size           # overall accuracy on the test set
    return test_loss, correct
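One common refinement, not part of the code in this chapter, is to switch the model to evaluation mode during testing and back to training mode during training. With only Linear and ReLU layers this changes nothing, but it matters once layers such as Dropout or BatchNorm are added. A minimal sketch:
# Optional refinement (not used in this chapter's code):
# model.eval()   # call before the test loop; disables Dropout/BatchNorm training behaviour
# model.train()  # call before the training loop to switch back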
# First, load the datasets
import torchvision
from torchvision.transforms import ToTensor
train_ds=torchvision.datasets.MNIST("data/",train=True,transform=ToTensor(),download=True)
test_ds=torchvision.datasets.MNIST("data/",train=False,transform=ToTensor(),download=True)
train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl=torch.utils.data.DataLoader(test_ds,batch_size=64)
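With the loaders built, the len() relationships from the table above can be checked directly. The numbers in the comments assume the standard MNIST split (60,000 training images) and the batch size of 64 used here.
# Quick check of the len() relationships (assumes the standard MNIST split and batch_size=64).
print(len(train_ds))          # 60000 samples
print(len(train_dl.dataset))  # 60000, the same underlying dataset
print(len(train_dl))          # 938 batches = ceil(60000 / 64)
imgs, labels = next(iter(train_dl))
print(imgs.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])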
### 3.1.5.3 Writing the Full Training Loop
# Train on the full dataset for 50 epochs (one epoch = one full pass over the data)
epochs=50
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
for epoch in range(epochs):
    # Call train() to train for one epoch
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    # Call test() to evaluate on the test set
    epoch_test_loss, epoch_test_acc = test(test_dl, model)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    # Define a template for printing the metrics
    template = ("epoch:{:2d},train_loss:{:5f},train_acc:{:.1f}%,"
                "test_loss:{:.5f},test_acc:{:.1f}%")
    print(template.format(epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))
print("Done")
epoch: 0,train_loss:2.157364,train_acc:46.7%,test_loss:1.83506,test_acc:63.7%
epoch: 1,train_loss:1.222660,train_acc:74.3%,test_loss:0.74291,test_acc:81.8%
epoch: 2,train_loss:0.612381,train_acc:84.0%,test_loss:0.49773,test_acc:86.3%
epoch: 3,train_loss:0.470115,train_acc:87.0%,test_loss:0.41321,test_acc:88.3%
epoch: 4,train_loss:0.409083,train_acc:88.5%,test_loss:0.37054,test_acc:89.3%
epoch: 5,train_loss:0.372433,train_acc:89.5%,test_loss:0.34050,test_acc:90.4%
epoch: 6,train_loss:0.347929,train_acc:90.1%,test_loss:0.32208,test_acc:90.8%
epoch: 7,train_loss:0.329210,train_acc:90.6%,test_loss:0.30710,test_acc:91.1%
epoch: 8,train_loss:0.314265,train_acc:90.9%,test_loss:0.29423,test_acc:91.5%
epoch: 9,train_loss:0.301913,train_acc:91.3%,test_loss:0.28404,test_acc:92.0%
epoch:10,train_loss:0.290773,train_acc:91.7%,test_loss:0.27225,test_acc:92.2%
epoch:11,train_loss:0.280149,train_acc:91.9%,test_loss:0.26680,test_acc:92.4%
epoch:12,train_loss:0.270489,train_acc:92.2%,test_loss:0.25432,test_acc:92.6%
epoch:13,train_loss:0.261573,train_acc:92.5%,test_loss:0.24773,test_acc:92.9%
epoch:14,train_loss:0.252997,train_acc:92.7%,test_loss:0.23850,test_acc:93.2%
epoch:15,train_loss:0.244789,train_acc:93.0%,test_loss:0.23177,test_acc:93.3%
epoch:16,train_loss:0.237236,train_acc:93.2%,test_loss:0.22550,test_acc:93.5%
epoch:17,train_loss:0.229649,train_acc:93.4%,test_loss:0.21949,test_acc:93.8%
epoch:18,train_loss:0.222713,train_acc:93.6%,test_loss:0.21277,test_acc:93.8%
epoch:19,train_loss:0.215915,train_acc:93.9%,test_loss:0.20566,test_acc:93.9%
epoch:20,train_loss:0.209662,train_acc:94.0%,test_loss:0.19952,test_acc:94.1%
epoch:21,train_loss:0.203692,train_acc:94.2%,test_loss:0.19441,test_acc:94.3%
epoch:22,train_loss:0.197726,train_acc:94.4%,test_loss:0.19170,test_acc:94.5%
epoch:23,train_loss:0.192294,train_acc:94.5%,test_loss:0.18431,test_acc:94.6%
epoch:24,train_loss:0.187299,train_acc:94.7%,test_loss:0.18081,test_acc:94.7%
epoch:25,train_loss:0.182354,train_acc:94.8%,test_loss:0.17475,test_acc:94.8%
epoch:26,train_loss:0.177634,train_acc:94.9%,test_loss:0.17086,test_acc:95.0%
epoch:27,train_loss:0.172992,train_acc:95.0%,test_loss:0.16722,test_acc:95.1%
epoch:28,train_loss:0.168787,train_acc:95.1%,test_loss:0.16434,test_acc:95.0%
epoch:29,train_loss:0.164727,train_acc:95.3%,test_loss:0.16026,test_acc:95.2%
epoch:30,train_loss:0.160919,train_acc:95.4%,test_loss:0.15702,test_acc:95.4%
epoch:31,train_loss:0.156920,train_acc:95.5%,test_loss:0.15420,test_acc:95.4%
epoch:32,train_loss:0.153295,train_acc:95.7%,test_loss:0.15116,test_acc:95.6%
epoch:33,train_loss:0.150030,train_acc:95.7%,test_loss:0.14879,test_acc:95.6%
epoch:34,train_loss:0.146605,train_acc:95.8%,test_loss:0.14557,test_acc:95.7%
epoch:35,train_loss:0.143500,train_acc:95.9%,test_loss:0.14370,test_acc:95.8%
epoch:36,train_loss:0.140310,train_acc:96.0%,test_loss:0.14032,test_acc:95.7%
epoch:37,train_loss:0.137446,train_acc:96.1%,test_loss:0.13812,test_acc:95.9%
epoch:38,train_loss:0.134580,train_acc:96.2%,test_loss:0.13612,test_acc:96.0%
epoch:39,train_loss:0.131750,train_acc:96.3%,test_loss:0.13432,test_acc:95.9%
epoch:40,train_loss:0.129338,train_acc:96.4%,test_loss:0.13162,test_acc:96.0%
epoch:41,train_loss:0.126620,train_acc:96.4%,test_loss:0.13051,test_acc:96.1%
epoch:42,train_loss:0.124075,train_acc:96.5%,test_loss:0.12820,test_acc:96.3%
epoch:43,train_loss:0.121680,train_acc:96.6%,test_loss:0.12575,test_acc:96.2%
epoch:44,train_loss:0.119471,train_acc:96.6%,test_loss:0.12521,test_acc:96.3%
epoch:45,train_loss:0.117078,train_acc:96.7%,test_loss:0.12264,test_acc:96.3%
epoch:46,train_loss:0.114974,train_acc:96.7%,test_loss:0.12109,test_acc:96.4%
epoch:47,train_loss:0.113026,train_acc:96.8%,test_loss:0.11942,test_acc:96.4%
epoch:48,train_loss:0.110716,train_acc:96.9%,test_loss:0.12003,test_acc:96.4%
epoch:49,train_loss:0.108877,train_acc:97.0%,test_loss:0.11783,test_acc:96.5%
Done
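After training, it is often useful to persist the learned weights. The snippet below is a hedged sketch using torch.save and load_state_dict; the file name mlp_mnist.pth is an arbitrary choice, not something defined elsewhere in this chapter.
# Sketch only: saving and restoring the trained weights (the file name is arbitrary).
torch.save(model.state_dict(), "mlp_mnist.pth")
restored = Model().to(device)
restored.load_state_dict(torch.load("mlp_mnist.pth", map_location=device))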
## Plotting the Loss Curves
import matplotlib.pyplot as plt
plt.plot(range(1,epochs+1),train_loss,label="train_Loss")
plt.plot(range(1,epochs+1),test_loss,label="test_Loss",ls="--")
plt.xlabel("epoch")
plt.legend()
plt.show()
## Plotting the Accuracy Curves
plt.plot(range(1,epochs+1),train_acc,label="train_acc")
plt.plot(range(1,epochs+1),test_acc,label="test_acc")
plt.xlabel("epoch")
plt.legend()
plt.show()
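To close the loop, here is a hedged sketch of running the trained model on a single test image; the index 0 is an arbitrary choice for illustration.
# Sketch: predicting a single test image (index 0 is arbitrary).
img, label = test_ds[0]  # img: (1, 28, 28) tensor, label: int
with torch.no_grad():
    logits = model(img.unsqueeze(0).to(device))  # add a batch dimension -> (1, 1, 28, 28)
    pred_class = logits.argmax(1).item()
print("predicted:", pred_class, "actual:", label)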