目录
- 一、简介:
- 二、图片分类网络
-
- 1.记载训练数据(torch自带的cifa10数据集)
- 2.数据增强
- 3.模型构建
- 4.模型训练
- 三、完整源码及文档
一、简介:
基于残差连接的图片分类网络,本网络使用ResNet18作为基础模块,根据cifa10的特点进行改进网络,使用交叉熵损失函数和SGD优化器。本网络在cifa10数据集上不使用预训练参数,经过数据增强,训练30轮达到了85%的分类准确率。
二、图片分类网络
1.记载训练数据(torch自带的cifa10数据集)
2.数据增强
数据增强防止过拟合,将图像数据进行标准化、缩放
3.模型构建
改模型:原始的resnet18首层使用的7x7的卷积核,CIFAR10图片太小不适合,要改成3x3的,步长和padding都要一并改成1。因为图太小,最大池化层也同样没用,删掉。最后一个全连接层输出改成10。
先定义一个残差类(继承NN.module,后面重复使用残差):
分类模型构建: