本教程通过构建并部署Java版本的LeNet网络的训练,演示MindSpore Lite端侧训练Java接口的使用。 首先指导您在本地成功训练LeNet模型,然后讲解示例代码。
首先克隆源码,然后编译MindSpore Lite端侧训练Java包,Linux
指令如下:
git clone -b r2.2 https://gitee.com/mindspore/mindspore.git
cd mindspore
bash build.sh -I x86_64 -j8
编译环境要求以及环境变量设置,请参考编译MindSpore Lite章节。
本教程使用的示例源码在mindspore/lite/examples/train_lenet_java
目录。
示例中的MNIST
数据集由10类28*28的灰度图片组成,训练数据集包含60000张图片,测试数据集包含10000张图片。
MNIST数据集官网下载地址:http://yann.lecun.com/exdb/mnist/,共4个下载链接,分别是训练数据、训练标签、测试数据和测试标签。
下载并解压到本地,解压后的训练和测试集分别存放于/PATH/MNIST_Data/train
和/PATH/MNIST_Data/test
路径下。
目录结构如下:
MNIST_Data/
├── test
│ ├── t10k-images-idx3-ubyte
│ └── t10k-labels-idx1-ubyte
└── train
├── train-images-idx3-ubyte
└── train-labels-idx1-ubyte
首先进入示例工程所在目录,运行示例程序,命令如下:
cd /codes/mindspore/mindspore/lite/examples/train_lenet_java
./prepare_and_run.sh -D /PATH/MNIST_Data/ -r ../../../../output/mindspore-lite-${version}-linux-x64.tar.gz
../resources/model/lenet_tod.ms是示例工程中预置的LeNet训练模型,您也可以参考训练模型转换,自行转换出LeNet模型。
/PATH/MNIST_Data/是MNIST数据集所在路径。
示例运行结果如下:
==========Loading Model, Create Train Session=============
Model path is ../model/lenet_tod.ms
batch_size: 4
virtual batch multiplier: 16
==========Initing DataSet================
train data cnt: 60000
test data cnt: 10000
==========Training Model===================
step_500: Loss is 0.05553353 [min=0.010149269] max_acc=0.9543269
step_1000: Loss is 0.15295759 [min=0.0018140086] max_acc=0.96594554
step_1500: Loss is 0.018035552 [min=0.0018140086] max_acc=0.9704527
step_2000: Loss is 0.029250022 [min=0.0010245014] max_acc=0.9765625
step_2500: Loss is 0.11875624 [min=7.5288175E-4] max_acc=0.9765625
step_3000: Loss is 0.046675075 [min=7.5288175E-4] max_acc=0.9765625
step_3500: Loss is 0.034442786 [min=4.3545474E-4] max_acc=0.97686297
==========Evaluating The Trained Model============
accuracy = 0.9770633
Trained model successfully saved: ./model/lenet_tod_trained.ms
train_lenet_java
├── lib
├── build.sh
├── model
│ ├── lenet_export.py
│ ├── prepare_model.sh
│ └── train_utils.sh
├── pom.xml
├── prepare_and_run.sh
├── resources
│ └── model
│ └── lenet_tod.ms # LeNet训练模型
├── src
│ └── main
│ └── java
│ └── com
│ └── mindspore
│ └── lite
│ ├── train_lenet
│ │ ├── DataSet.java # MNIST数据集处理
│ │ ├── Main.java # Main函数
│ │ └── NetRunner.java # 整体训练流程
详细的Java接口使用请参考https://www.mindspore.cn/lite/api/zh-CN/r2.2/index.html。
加载并编译MindSpore Lite模型文件,构建会话。
MSContext context = new MSContext();
// use default param init context
context.init();
boolean isSuccess = context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
TrainCfg trainCfg = new TrainCfg();
trainCfg.init();
model = new Model();
Graph graph = new Graph();
graph.load(modelPath);
model.build(graph, context, trainCfg);
model.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
切换为训练模式,循环迭代,训练模型。
model.setTrainMode(true)
float min_loss = 1000;
float max_acc = 0;
for (int i = 0; i < cycles; i++) {
for (int b = 0; b < virtualBatch; b++) {
fillInputData(ds.getTrainData(), false);
model.runStep();
float loss = getLoss();
if (min_loss > loss) {
min_loss = loss;
}
if ((b == 0) && ((i + 1) % 500 == 0)) {
float acc = calculateAccuracy(10); // only test 10 batch size
if (max_acc < acc) {
max_acc = acc;
}
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_acc=" + max_acc);
}
}
}
切换为推理模式,执行推理,评估模型精度。
model.setTrainMode(false);
for (long i = 0; i < tests; i++) {
Vector<Integer> labels = fillInputData(test_set, (maxTests == -1));
if (labels.size() != batchSize) {
System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize);
System.exit(1);
}
model.predict();
MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
if (outputsv == null) {
System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
System.exit(1);
}
float[] scores = outputsv.getFloatData();
for (int b = 0; b < batchSize; b++) {
int max_idx = 0;
float max_score = scores[(int) (numOfClasses * b)];
for (int c = 0; c < numOfClasses; c++) {
if (scores[(int) (numOfClasses * b + c)] > max_score) {
max_score = scores[(int) (numOfClasses * b + c)];
max_idx = c;
}
}
if (labels.get(b) == max_idx) {
accuracy += 1.0;
}
}
}
推理完成后,如果需要继续训练,需要切换为训练模式。
保存训练模型。
// arg 0: FileName
// arg 1: quantization type QT_DEFAULT -> 0
// arg 2: model type MT_TRAIN -> 0
// arg 3: use default output tensor names
model.export(trainedFilePath, 0, false, null);
模型训练完成后,保存到指定路径,后续可以继续加载运行。
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。