# tensorflow_classfication **Repository Path**: lakuite/tensorflow_classfication ## Basic Information - **Project Name**: tensorflow_classfication - **Description**: tensorflow1.14实现分类任务 - **Primary Language**: Python - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 2 - **Forks**: 0 - **Created**: 2020-06-03 - **Last Updated**: 2025-05-26 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # tensorflow_classfication ** 详细说明参考博客:https://blog.csdn.net/qq_35756383/article/details/106534906 ** #### 1. 利用全连接网络识别手写体数字 (1) 全连接网络模型 ![全连接网络模型](https://images.gitee.com/uploads/images/2020/0604/012628_8eb816b3_1068231.png "全连接网络模型.png") (2) 实验过程描述与实验结果 - 实验要求:用全连接网络实现对mnist数据集的分类识别任务,对训练集训练1000step,每个step的batch_size=100,每训练50step打印一次测试集的准确率。 - 实验过程:对输入图标和输出标签设置占位符xs和ys,定义权重Weights和偏置biases,求得xs*Weights+biases的值Wx_plus_b,最后对Wx_plus_b进行softmax操作,获得一个预测的概率分布。选择概率最大的一项作为预测标签,将预测标签与真实标签对比,求得整个测试集的准确率。 - 实验结果: ![fulayer_res1](https://images.gitee.com/uploads/images/2020/0604/012652_2a505353_1068231.png "屏幕截图.png") ![fullayer_res2](https://images.gitee.com/uploads/images/2020/0604/012714_0d956519_1068231.png "屏幕截图.png") #### 2. 利用CNN网络识别手写体数字 (1) CNN网络模型 ![CNN网络模型](https://images.gitee.com/uploads/images/2020/0604/013045_790b8429_1068231.png "屏幕截图.png") (2) 实验过程描述与实验结果 - 实验要求:用简单cnn网络(2层卷积池化+2层全连接)实现对mnist数据集的分类识别任务,batch_size=100,keep_prob=0.5,每100step输出一次训练集acc和loss,训练完成后输出测试集的acc。改变lr,step,优化器,输出其训练完成后的训练集loss和测试集acc的对比结果。 - 实验过程:设置2层卷积池化层,使得输入图片大小由28x28x1变为7x7x64,再经过2层全连接层,获得大小为10的一个概率分布。选择概率最大的一项作为预测标签,将预测标签与真实标签对比,求得整个数据集的准确率。使用交叉熵损失函数,求得训练过程中的loss值。 - 实验结果: 训练过程(lr=0.0001, step=2000, Optimizer=Adam): ![cnn_mnist1](https://images.gitee.com/uploads/images/2020/0604/013116_668edc41_1068231.png "屏幕截图.png") ![cnn_mnist2](https://images.gitee.com/uploads/images/2020/0604/013136_3021bc48_1068231.png "屏幕截图.png") 对比表格: ![cnn_mnist_res_table](https://images.gitee.com/uploads/images/2020/0604/013214_e42d87aa_1068231.png "屏幕截图.png") ##### 3. 利用CNN网络识别人脸 (1) 制作训练集标签 使用orl人脸数据集制作标签,该数据集共400张图片,共40类,每类图片10张,图片大小56x46灰度图。选择每类图片的前9张作为训练集(共360张),最后一张作为测试集(共40张)。并构建one-hot编码的数据集标签。 (2) 实验过程描述与实验结果 - 实验要求:使用制作好的数据集,用3中的cnn网络结构来实现对orl人脸数据集的分类识别。训练完成后输出测试集的准确率,以及测试集识别错误的图片的标签(参数设置:lr=0.001,epoch=30,batch_size=40)。 - 实验过程:先预处理原orl数据集,实现对训练集和测试集的划分。将处理后的训练集和测试集图片数据输入到数组train_data和test_data中,并制作相应的label。最后使用该数据进行训练和测试。 - 实验结果: ![cnn_orl1](https://images.gitee.com/uploads/images/2020/0604/013304_c26c58a5_1068231.png "屏幕截图.png") ![cnn_orl2](https://images.gitee.com/uploads/images/2020/0604/013310_d6e65ce9_1068231.png "屏幕截图.png") ![cnn_orl3](https://images.gitee.com/uploads/images/2020/0604/013315_f1c7491c_1068231.png "屏幕截图.png") pic后的数表示图片的真实类别,label后的数表示识别错误的图片误识别成的类别。 #### 4. 训练集样本增强 (1) 数据增强方法 a. 基于基本图像处理技术的数据增强算法和基于深度学习的数据增强算法; b. 离线增强(对数据集全部处理完再训练)和在线增强(每次获取一个batch的数据来进行增强); c. 方法: 添加噪声:椒盐噪声、高斯噪声、均值滤波等; 图形处理:调整大小、裁剪填充、图片翻转、调整亮度对比度饱和度、图像标准化等。 (2) 实验过程描述与对比实验结果 - 实验要求:处理GT数据集(训练集70张,测试10张,rgb图,每张图大小不同),用3的网络实现对其的人脸识别分类任务,并对比使用和不使用数据增强的结果。 - 实验过程:预处理GT数据集,统一图片大小为83x57的灰度图。使用4的方法对该数据集进行训练,获取结果。对训练集70张图片均进行2中数据增强操作:上采样(放大任意倍数再裁剪回原本大小)和180°旋转,获得210张训练集图片,使用同样的参数训练,获取结果。 - 实验结果(测试集): 数据增强前: ![dataAug_pre](https://images.gitee.com/uploads/images/2020/0604/013428_a59eff10_1068231.png "屏幕截图.png") 数据增强后: ![dataAug_after](https://images.gitee.com/uploads/images/2020/0604/013436_47b180bc_1068231.png "屏幕截图.png") #### 参考文档 - [MNIST机器学习入门](http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html) - [深入MNIST](http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html) - [CNN Explaniner](https://poloclub.github.io/cnn-explainer/) - [python文件处理常用代码](https://blog.csdn.net/qq_35756383/article/details/105084180) - [tf.one_hot()函数简介](https://blog.csdn.net/nini_coded/article/details/79250600) - [Tensorflow将自己的数据分割成batch训练](https://blog.csdn.net/sinat_35821976/article/details/82668555?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase) - [数据增强的方法总结及代码实现](https://blog.csdn.net/comway_Li/article/details/82928974) - [Python中读取,显示,保存图片的方法](https://www.cnblogs.com/Terrypython/p/9925885.html)