文章目录
- 前言
- 简介
- 数据集
- 项目结构
- utils模块
- dataLoader
- models
- plotShow
- train模块
- predict模块
- 下载地址
前言
在前面的两篇文章中我们介绍了现代计算机视觉中常见的结构化和非结构化的CNN模型,本篇我们将使用这些CNN模型在手写数学符号数据集上进行识别。
CNN模型的介绍请参照之前的两篇文章,源码放到最后。
pytorch深度学习基础(十)——常用线性CNN模型的结构与训练
pytorch深度学习基础(十一)——常用结构化CNN模型构建
简介
数据集
所用的数据集是来自kaggle
的Handwritten math symbols dataset
,其中包括超过30w张图片,共有82个类别。解压后的数据放到extracted_images中作为数据
数据集下载地址:
Handwritten math symbols dataset
项目结构
项目的组织形式如下
utils模块
utils中包括数据的加载,模型以及画图展示
dataLoader
idxPrepare
传入数据所在的路径,获取标签与索引的对应关系并以字典的形式保存,并返回由(图片,类别)组成的列表
image2txt
传入由(图片,类别)组成的列表,将数据集划分成训练数据和测试数据。并将路径以及对应的标签存放到txt文件中
MyLoader
使用torchvsion加载图片
MyDataLoader
由于数据量稍微有些大,直接使用torch的数据加载有可能会导致OOM,所以构建了一个数据加载的类,这个类并不会将所有的数据一次性全部加载到内存,而是将存有数据路径和对应的类别先全部加载到内存中,当需要用到数据时,再将数据从磁盘中读到内存当中。
LoadDataset
为了方便加载数据,构建了一个数据加载器,传入数据的路径,数据的批量大小和图片的大小,返回训练数据和测试数据的迭代器以供模型的训练
models
包含各种常见的CNN分类模型,包括LeNet、alexNet、vgg11、NiN、GoogLeNet、resNet18、denseNet模型的构建
select_model
用于选择模型,传入模型的名称以及模型的参数,返回选择的分类模型
plotShow
传入由训练损失、训练精度、测试损失、测试精度组成的字典,然后绘制出图像
train模块
用于训练模型
accuracy
传入预测结果和标签,用于评估预测精度
train
传入选择使用的模型,模型参数,训练数据、测试数据、训练轮次、学习率、训练设备、提前终止训练的阈值以及是否保存检查点,进行模型的训练,最终训练结果将会保存在model_weights
文件夹中
predict模块
根据不同的参数设置可以进行预测,可选的预测模式有单张图片预测,从文件夹中预测以及随机选取测试集中的数据进行预测
下载地址
GitHub地址:Handwritten-math-symbols-recognition