1 Star 0 Fork 0

zhujian_nwpu/My_All_Learning

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Model_Training.ipynb 64.42 KB
一键复制 编辑 原始数据 按行查看 历史

3 多层感知机模型与模型训练

为了解决MNIST分类问题,我们创建一个简单的多层感知器 使用torch.nn.Linear层创建, 本章使用ReLU函数作为激活函数,并使用两个Linear层(全连接层)作为隐藏层

import torch
from torch import nn
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,6"  # 只使用空闲的 GPU
## 3.1.1 模型的建立
# 记得吗,上一讲,我们的train_dl的形状为torch.Size([64,1,28,28])
class Model(nn.Module): # 创建模型,继承自nn.Module
    def __init__(self):
        super().__init__()
        # 第一层输入展平后的特征长度28乘28,创建120个神经元
        self.liner_1=nn.Linear(28*28,120)
        # 第二层输入的是前一层的输出,创建84个神经元
        self.liner_2=nn.Linear(120,84)
        # 输出层接受第二层的输入84,输出分类个数10
        self.liner_3=nn.Linear(84,10)

    def forward(self,input):
        x=input.view(-1,28*28) # 将输入展平为二维(1,28,28)->(28*28)
        x=torch.relu(self.liner_1(x))
        x=torch.relu(self.liner_2(x))
        x=self.liner_3(x)
        return x
## 3.1.2 分类问题的损失函数
loss_fn=nn.CrossEntropyLoss() # 初始化交叉熵损失函数
# 可以认为它综合了softmax计算和交叉熵损失计算
'''
注意两个参数
1. weight
2. ignore_index
另外,它要求实际类别为数值编码,而不是独热编码
'''
'\n注意两个参数\n1. weight\n2. ignore_index\n另外,它要求实际类别为数值编码,而不是独热编码\n'
## 3.1.3 初始化模型
device="cuda" if torch.cuda.is_available() else "cpu"
model=Model().to(device)
# print(device)
## 3.1.4 优化器 均在torch.optim模块下
# 随机梯度下降优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.005)
  • 它有两个重要的参数,第一个params表示需要优化的模型参数,第二个为学习率类型为float
## 3.1.5 编写训练循环
### 3.1.5.1 训练函数
def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset) #获取当前数据集样本总数量
    num_batches=len(dataloader) #获取当前data loader总批次数
    
    # train_loss用于累计所有批次的损失之和, correct用于累计预测正确的样本总数
    train_loss,correct=0,0
    for X,y in dataloader:
        X,y=X.to(device),y.to(device)
        
        # 进行预测,并计算第一个批次的损失
        pred=model(X)
        loss=loss_fn(pred,y)
        # 利用反向传播算法,根据损失优化模型参数
        optimizer.zero_grad() #先将梯度清零
        loss.backward() # 损失反向传播,计算模型参数梯度
        optimizer.step() #根据梯度优化参数
        
        with torch.no_grad():
            # correct用于累计预测正确的样本总数
            correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
            # train_loss用于累计所有批次的损失之和
            train_loss+=loss.item()
            
        # train_loss 是所有批次的损失之和,所以计算全部样本的平均损失时需要除以总的批次数
    train_loss/=num_batches
        # correct 是预测正确的样本总数,若计算整个apoch总体正确率,需要除以样本总数量
    correct/=size
    return train_loss,correct
        
  • 方法 返回内容 适用场景
  • len(dataset) 数据集总样本数(如100) 数据统计、划分
  • len(dataloader) 总批次数(如4) 训练循环控制
  • len(dataloader.dataset) 等同于 len(dataset) 需要访问原始数据时
### 3.1.5.2 测试函数
def test(dataloader,model):
    size=len(dataloader.dataset)
    num_batches=len(dataloader)
    test_loss,correct=0,0
    with torch.no_grad():
        for X,y in dataloader:
            X,y=X.to(device),y.to(device)
            pred=model(X)
            test_loss+=loss_fn(pred,y).item()
            correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
    test_loss/=num_batches
    correct/=size
    return test_loss,correct
# 我们先把数据集加载进来
import torchvision
from torchvision.transforms import ToTensor
train_ds=torchvision.datasets.MNIST("data/",train=True,transform=ToTensor(),download=True)
test_ds=torchvision.datasets.MNIST("data/",train=False,transform=ToTensor(),download=True)

train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl=torch.utils.data.DataLoader(test_ds,batch_size=64)
    ### 3.1.5.3 正式编写训练循环
    # 我们对全部的数据集训练50个epoch(一个epoch表示对全部数据训练一遍)
    epochs=50 
    train_loss=[]
    train_acc=[]
    test_loss=[]
    test_acc=[]
    
    for epoch in range(epochs):
        # 调用train()函数训练
        epoch_loss,epoch_acc=train(train_dl,model,loss_fn,optimizer)
        # 调用test()函数测试
        epoch_test_loss,epoch_test_acc=test(test_dl,model)
    
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        test_loss.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)
    
        # 定义一个打印模板
        template=("epoch:{:2d},train_loss:{:5f},train_acc:{:.1f}%,""test_loss:{:.5f},test_acc:{:.1f}%")
    
        print(template.format(epoch,epoch_loss,epoch_acc*100,epoch_test_loss,epoch_test_acc*100))
    
    print("Done")
epoch: 0,train_loss:2.157364,train_acc:46.7%,test_loss:1.83506,test_acc:63.7%
epoch: 1,train_loss:1.222660,train_acc:74.3%,test_loss:0.74291,test_acc:81.8%
epoch: 2,train_loss:0.612381,train_acc:84.0%,test_loss:0.49773,test_acc:86.3%
epoch: 3,train_loss:0.470115,train_acc:87.0%,test_loss:0.41321,test_acc:88.3%
epoch: 4,train_loss:0.409083,train_acc:88.5%,test_loss:0.37054,test_acc:89.3%
epoch: 5,train_loss:0.372433,train_acc:89.5%,test_loss:0.34050,test_acc:90.4%
epoch: 6,train_loss:0.347929,train_acc:90.1%,test_loss:0.32208,test_acc:90.8%
epoch: 7,train_loss:0.329210,train_acc:90.6%,test_loss:0.30710,test_acc:91.1%
epoch: 8,train_loss:0.314265,train_acc:90.9%,test_loss:0.29423,test_acc:91.5%
epoch: 9,train_loss:0.301913,train_acc:91.3%,test_loss:0.28404,test_acc:92.0%
epoch:10,train_loss:0.290773,train_acc:91.7%,test_loss:0.27225,test_acc:92.2%
epoch:11,train_loss:0.280149,train_acc:91.9%,test_loss:0.26680,test_acc:92.4%
epoch:12,train_loss:0.270489,train_acc:92.2%,test_loss:0.25432,test_acc:92.6%
epoch:13,train_loss:0.261573,train_acc:92.5%,test_loss:0.24773,test_acc:92.9%
epoch:14,train_loss:0.252997,train_acc:92.7%,test_loss:0.23850,test_acc:93.2%
epoch:15,train_loss:0.244789,train_acc:93.0%,test_loss:0.23177,test_acc:93.3%
epoch:16,train_loss:0.237236,train_acc:93.2%,test_loss:0.22550,test_acc:93.5%
epoch:17,train_loss:0.229649,train_acc:93.4%,test_loss:0.21949,test_acc:93.8%
epoch:18,train_loss:0.222713,train_acc:93.6%,test_loss:0.21277,test_acc:93.8%
epoch:19,train_loss:0.215915,train_acc:93.9%,test_loss:0.20566,test_acc:93.9%
epoch:20,train_loss:0.209662,train_acc:94.0%,test_loss:0.19952,test_acc:94.1%
epoch:21,train_loss:0.203692,train_acc:94.2%,test_loss:0.19441,test_acc:94.3%
epoch:22,train_loss:0.197726,train_acc:94.4%,test_loss:0.19170,test_acc:94.5%
epoch:23,train_loss:0.192294,train_acc:94.5%,test_loss:0.18431,test_acc:94.6%
epoch:24,train_loss:0.187299,train_acc:94.7%,test_loss:0.18081,test_acc:94.7%
epoch:25,train_loss:0.182354,train_acc:94.8%,test_loss:0.17475,test_acc:94.8%
epoch:26,train_loss:0.177634,train_acc:94.9%,test_loss:0.17086,test_acc:95.0%
epoch:27,train_loss:0.172992,train_acc:95.0%,test_loss:0.16722,test_acc:95.1%
epoch:28,train_loss:0.168787,train_acc:95.1%,test_loss:0.16434,test_acc:95.0%
epoch:29,train_loss:0.164727,train_acc:95.3%,test_loss:0.16026,test_acc:95.2%
epoch:30,train_loss:0.160919,train_acc:95.4%,test_loss:0.15702,test_acc:95.4%
epoch:31,train_loss:0.156920,train_acc:95.5%,test_loss:0.15420,test_acc:95.4%
epoch:32,train_loss:0.153295,train_acc:95.7%,test_loss:0.15116,test_acc:95.6%
epoch:33,train_loss:0.150030,train_acc:95.7%,test_loss:0.14879,test_acc:95.6%
epoch:34,train_loss:0.146605,train_acc:95.8%,test_loss:0.14557,test_acc:95.7%
epoch:35,train_loss:0.143500,train_acc:95.9%,test_loss:0.14370,test_acc:95.8%
epoch:36,train_loss:0.140310,train_acc:96.0%,test_loss:0.14032,test_acc:95.7%
epoch:37,train_loss:0.137446,train_acc:96.1%,test_loss:0.13812,test_acc:95.9%
epoch:38,train_loss:0.134580,train_acc:96.2%,test_loss:0.13612,test_acc:96.0%
epoch:39,train_loss:0.131750,train_acc:96.3%,test_loss:0.13432,test_acc:95.9%
epoch:40,train_loss:0.129338,train_acc:96.4%,test_loss:0.13162,test_acc:96.0%
epoch:41,train_loss:0.126620,train_acc:96.4%,test_loss:0.13051,test_acc:96.1%
epoch:42,train_loss:0.124075,train_acc:96.5%,test_loss:0.12820,test_acc:96.3%
epoch:43,train_loss:0.121680,train_acc:96.6%,test_loss:0.12575,test_acc:96.2%
epoch:44,train_loss:0.119471,train_acc:96.6%,test_loss:0.12521,test_acc:96.3%
epoch:45,train_loss:0.117078,train_acc:96.7%,test_loss:0.12264,test_acc:96.3%
epoch:46,train_loss:0.114974,train_acc:96.7%,test_loss:0.12109,test_acc:96.4%
epoch:47,train_loss:0.113026,train_acc:96.8%,test_loss:0.11942,test_acc:96.4%
epoch:48,train_loss:0.110716,train_acc:96.9%,test_loss:0.12003,test_acc:96.4%
epoch:49,train_loss:0.108877,train_acc:97.0%,test_loss:0.11783,test_acc:96.5%
Done
## 绘制图像
import matplotlib.pyplot as plt
plt.plot(range(1,epochs+1),train_loss,label="train_Loss")
plt.plot(range(1,epochs+1),test_loss,label="test_Loss",ls="--")
plt.xlabel("epoch")
plt.legend()
plt.show()
## 绘制图像
plt.plot(range(1,epochs+1),train_acc,label="train_acc")
plt.plot(range(1,epochs+1),test_acc,label="test_acc")
plt.xlabel("epoch")
plt.legend()
plt.show()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhujiannwpu/my_-all_-learning.git
git@gitee.com:zhujiannwpu/my_-all_-learning.git
zhujiannwpu
my_-all_-learning
My_All_Learning
pytorch

搜索帮助