Training a complex network often takes dozens or even hundreds of epochs, and it is hard to know in advance at which epoch the model will reach the required accuracy. A common practice is therefore to validate the model at fixed epoch intervals during training and save the corresponding checkpoints; once training finishes, inspecting how accuracy evolved across those checkpoints makes it easy to pick the relatively best model. This tutorial demonstrates the approach using the LeNet network as an example.
The overall workflow is as follows:
Training dataset download addresses: {"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"}.
Test dataset: {"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}.
Place the datasets under *Jupyter working directory*\MNIST_Data, with the following structure:
MNIST_Data
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte
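If you would rather fetch the files from a script, here is a minimal sketch (assuming the URLs above are still reachable; the helper name download_mnist is introduced here purely for illustration):

import gzip
import os
import shutil
import urllib.request

# file names per subfolder, matching the tree above
FILES = {
    "train": ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz"],
    "test": ["t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"],
}
BASE_URL = "http://yann.lecun.com/exdb/mnist/"

def download_mnist(root="./MNIST_Data"):
    for subdir, names in FILES.items():
        target_dir = os.path.join(root, subdir)
        os.makedirs(target_dir, exist_ok=True)
        for name in names:
            gz_path = os.path.join(target_dir, name)
            raw_path = gz_path[:-3]  # strip the ".gz" suffix
            if os.path.exists(raw_path):
                continue  # already downloaded and unpacked
            urllib.request.urlretrieve(BASE_URL + name, gz_path)
            with gzip.open(gz_path, "rb") as src, open(raw_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            os.remove(gz_path)

download_mnist()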
After downloading, the dataset must be processed into a form usable by the MindSpore framework via mindspore.dataset, and then a series of augmentation operators provided by the framework are applied to meet the data-processing requirements of the LeNet network.
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    # define the dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define map operations: resize to 32x32, rescale to [0, 1],
    # normalize with the MNIST mean/std, and move channels to CHW
    resize_op = CV.Resize((32, 32), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
    rescale_op = CV.Rescale(1 / 255.0, 0.0)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on labels and images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=[resize_op, rescale_op, rescale_nml_op, hwc2chw_op],
                            input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply dataset operations
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds
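As a quick sanity check of the pipeline (a sketch, assuming the training files already sit under ./MNIST_Data/train as described above), one batch should come out as 32 single-channel 32x32 images:

# inspect the shape of one processed batch
check_ds = create_dataset("./MNIST_Data/train")
batch = next(check_ds.create_dict_iterator())
print(batch["image"].shape)  # expected: (32, 1, 32, 32)
print(batch["label"].shape)  # expected: (32,)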
LeNet is a 7-layer neural network involving convolutional layers, fully connected layers, and activation functions; MindSpore already provides the corresponding operators, which only need to be imported. Below, the convolution, fully connected, and weight-initialization helpers are defined first, then the neural network is defined in LeNet5 and assembled in construct.
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer weight initial."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")

def fc_with_initialize(input_channels, out_channels):
    """Fc layer weight initial."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)

def weight_variable():
    """Weight initial."""
    return TruncatedNormal(0.02)

class LeNet5(nn.Cell):
    """Lenet network structure."""
    # define the operators required
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # use the preceding operators to construct the network
    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
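Before training, a dummy forward pass can confirm the wiring; this is a sketch, assuming the float32 NCHW batches produced by create_dataset above:

import numpy as np
from mindspore import Tensor

# feed one all-zero batch through the untrained network
check_net = LeNet5()
dummy = Tensor(np.zeros((32, 1, 32, 32), dtype=np.float32))
print(check_net(dummy).shape)  # expected: (32, 10), one logit per digit class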
Implementation idea: validate the model's accuracy every n epochs. This is implemented in a custom callback function; for detailed usage of custom callbacks, refer to the API documentation.
Core implementation: a validation point is set inside the callback's epoch_end method, as follows:

cur_epoch % eval_per_epoch == 0: that is, the model's accuracy is validated at the end of every eval_per_epoch epochs.

- cur_epoch: the epoch number of the current training run.
- eval_per_epoch: a user-defined value, i.e. the validation frequency.

Other parameters:

- model: the MindSpore Model instance.
- eval_dataset: the validation dataset.
- epoch_per_eval: records the validation accuracy and the corresponding epoch number, in the form {"epoch": [], "acc": []}.

from mindspore.train.callback import Callback

class EvalCallBack(Callback):
    """Validate the model every eval_per_epoch epochs and record the result."""
    def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.eval_per_epoch = eval_per_epoch
        self.epoch_per_eval = epoch_per_eval

    def epoch_end(self, run_context):
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        if cur_epoch % self.eval_per_epoch == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)
            self.epoch_per_eval["epoch"].append(cur_epoch)
            self.epoch_per_eval["acc"].append(acc["Accuracy"])
            print(acc)
In CheckpointConfig, the parameter collection that controls checkpoint saving, the number of steps in a single epoch must be worked out so it can be matched to the desired validation frequency. The training set has 60000 images and the batch size is 32, so (with drop_remainder=True) each epoch contains 60000 // 32 = 1875 steps. To validate once every two epochs, set save_checkpoint_steps=eval_per_epoch*1875, where the variable eval_per_epoch equals 2.
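The arithmetic behind that interval, spelled out as a small sketch (the numbers are the ones used in this tutorial):

# how save_checkpoint_steps is derived
num_train_images = 60000                          # MNIST training set size
batch_size = 32
steps_per_epoch = num_train_images // batch_size  # 1875, since drop_remainder=True
eval_per_epoch = 2
print(eval_per_epoch * steps_per_epoch)           # 3750 -> save_checkpoint_steps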
Parameter descriptions:

- train_data_path: path to the training dataset.
- eval_data_path: path to the validation dataset.
- train_data: the training dataset.
- eval_data: the validation dataset.
- net_loss: defines the loss function.
- net_opt: defines the optimizer.
- config_ck: defines the checkpoint configuration.
  - save_checkpoint_steps: save a checkpoint every this many steps.
  - keep_checkpoint_max: the maximum number of checkpoints to keep.
- ckpoint_cb: defines the name and path of the saved checkpoints.
- model: defines the model.
- model.train: the model training function.
- epoch_per_eval: a dictionary collecting epoch numbers and the corresponding model accuracies.

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    train_data_path = "./MNIST_Data/train"
    eval_data_path = "./MNIST_Data/test"
    ckpt_save_dir = "./lenet_ckpt"
    epoch_size = 10
    eval_per_epoch = 2
    repeat_size = 1
    network = LeNet5()

    train_data = create_dataset(train_data_path, repeat_size=repeat_size)
    eval_data = create_dataset(eval_data_path, repeat_size=repeat_size)
    # define the loss function
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
    config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch * 1875, keep_checkpoint_max=15)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=ckpt_save_dir, config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    epoch_per_eval = {"epoch": [], "acc": []}
    eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)
    model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],
                dataset_sink_mode=True)
epoch: 1 step: 375, loss is 2.3058078
epoch: 1 step: 750, loss is 2.3073978
epoch: 1 step: 1125, loss is 2.3103657
epoch: 1 step: 1500, loss is 0.65018296
epoch: 1 step: 1875, loss is 0.07800862
epoch: 2 step: 375, loss is 0.010344766
epoch: 2 step: 750, loss is 0.052723818
epoch: 2 step: 1125, loss is 0.08183526
epoch: 2 step: 1500, loss is 0.007430988
epoch: 2 step: 1875, loss is 0.0076965275
{'Accuracy': 0.9753605769230769}
epoch: 3 step: 375, loss is 0.11964749
epoch: 3 step: 750, loss is 0.04522314
epoch: 3 step: 1125, loss is 0.018271001
epoch: 3 step: 1500, loss is 0.006928641
epoch: 3 step: 1875, loss is 0.15374172
epoch: 4 step: 375, loss is 0.12120275
epoch: 4 step: 750, loss is 0.122824214
epoch: 4 step: 1125, loss is 0.0023852547
epoch: 4 step: 1500, loss is 0.018273383
epoch: 4 step: 1875, loss is 0.08102103
{'Accuracy': 0.9821714743589743}
epoch: 5 step: 375, loss is 0.12944886
epoch: 5 step: 750, loss is 0.0010141768
epoch: 5 step: 1125, loss is 0.0054096584
epoch: 5 step: 1500, loss is 0.0022614016
epoch: 5 step: 1875, loss is 0.07229582
epoch: 6 step: 375, loss is 0.0025749032
epoch: 6 step: 750, loss is 0.06261393
epoch: 6 step: 1125, loss is 0.021273317
epoch: 6 step: 1500, loss is 0.011360342
epoch: 6 step: 1875, loss is 0.12855275
{'Accuracy': 0.9853766025641025}
epoch: 7 step: 375, loss is 0.09330422
epoch: 7 step: 750, loss is 0.002063415
epoch: 7 step: 1125, loss is 0.0047940286
epoch: 7 step: 1500, loss is 0.0052507296
epoch: 7 step: 1875, loss is 0.018066114
epoch: 8 step: 375, loss is 0.08678668
epoch: 8 step: 750, loss is 0.02440551
epoch: 8 step: 1125, loss is 0.0017507032
epoch: 8 step: 1500, loss is 0.02957578
epoch: 8 step: 1875, loss is 0.0023948685
{'Accuracy': 0.9863782051282052}
epoch: 9 step: 375, loss is 0.012376097
epoch: 9 step: 750, loss is 0.029711302
epoch: 9 step: 1125, loss is 0.017438065
epoch: 9 step: 1500, loss is 0.015443239
epoch: 9 step: 1875, loss is 0.0031764025
epoch: 10 step: 375, loss is 0.0005294987
epoch: 10 step: 750, loss is 0.0015696918
epoch: 10 step: 1125, loss is 0.019949459
epoch: 10 step: 1500, loss is 0.004248183
epoch: 10 step: 1875, loss is 0.07389321
{'Accuracy': 0.9824719551282052}
In the lenet_ckpt folder under the same directory, 5 checkpoints are saved along with one file of computational-graph data, structured as follows:
lenet_ckpt
├── checkpoint_lenet-10_1875.ckpt
├── checkpoint_lenet-2_1875.ckpt
├── checkpoint_lenet-4_1875.ckpt
├── checkpoint_lenet-6_1875.ckpt
├── checkpoint_lenet-8_1875.ckpt
└── checkpoint_lenet-graph.meta
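Any of these checkpoints can be reloaded for re-evaluation. Below is a sketch (not part of the original walkthrough) that reuses net_loss and eval_data from the training script and picks the epoch-8 file, which recorded the highest validation accuracy above:

from mindspore.train.serialization import load_checkpoint, load_param_into_net

# rebuild the network and restore the saved parameters
best_net = LeNet5()
param_dict = load_checkpoint("./lenet_ckpt/checkpoint_lenet-8_1875.ckpt")
load_param_into_net(best_net, param_dict)
# no optimizer is needed for evaluation only
best_model = Model(best_net, net_loss, metrics={"Accuracy": Accuracy()})
print(best_model.eval(eval_data, dataset_sink_mode=True))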
Define the plotting function eval_show, pass epoch_per_eval into it, and draw a line chart of the model's validation accuracy at different epochs.
import matplotlib.pyplot as plt

def eval_show(epoch_per_eval):
    plt.xlabel("epoch number")
    plt.ylabel("Model accuracy")
    plt.title("Model accuracy variation chart")
    plt.plot(epoch_per_eval["epoch"], epoch_per_eval["acc"], "red")
    plt.show()

eval_show(epoch_per_eval)
From the chart above, the best model can be picked out at a glance.
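If you prefer a programmatic pick over reading the chart, the dictionary collected by the callback already holds everything needed (a small sketch):

# pick the epoch with the highest recorded validation accuracy
best_acc = max(epoch_per_eval["acc"])
best_epoch = epoch_per_eval["epoch"][epoch_per_eval["acc"].index(best_acc)]
print("best epoch: %d, accuracy: %.4f" % (best_epoch, best_acc))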
This example trained the LeNet5 convolutional neural network on the MNIST dataset, focusing on how to validate the model during training via a callback, save the checkpoint for each corresponding epoch, and then pick the optimal model from among them.