基于Pytorch的猫狗图片分类【深度学习CNN】

news2025/1/12 16:14:06

猫狗分类来源于Kaggle上的一个入门竞赛——Dogs vs Cats。为了加深对CNN的理解,基于Pytorch复现了LeNet,AlexNet,ResNet等经典CNN模型,源代码放在GitHub上,地址传送点击此处。项目大纲如下:
在这里插入图片描述


文章目录

  • 一、问题描述
  • 二、数据集处理
    • 1 损坏图片清洗
    • 2 抽取图片形成数据集
  • 三、图片预处理
    • (1)init 方法
    • (2)getitem方法
    • (3)len方法
    • (4)测试
  • 四、模型
    • 1 LeNet
    • 2 AlexNet模型
  • 五、训练
    • 1 开始训练
    • 2 tensorboard可视化
  • 六、不同模型训练结果分析
    • 1 LeNet模型
      • (1) 数据集数量=1000,无数据增强
      • (2) 数据集数量=4000,无数据增强
      • (3)数据集数量=4000,数据增强
      • (4)数据集=4000,数据增强
      • (5)使用dropout函数抑制过拟合
    • 2 AlexNet模型
    • 3 squeezeNet模型
    • 4 resNet模型
    • 总结
  • 七、预测


一、问题描述

基于训练集数据,训练一个模型,利用训练好的模型预测未知图片中的动物是狗或者猫的概率。

训练集有25,000张图片,测试集12,500 张图片。

数据集下载地址:https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset

截屏2024-02-19 15.56.01

二、数据集处理

1 损坏图片清洗

01_clean.py中,用多种方式来清洗损坏图片:

  1. 判断开头是否有JFIF
  2. 用imghdr库中的imghdr.what函数判断文件类型
  3. 用Image.open(filename).verify()验证图片是否损坏

结果如下:

截屏2022-04-20 下午1.54.15

2 抽取图片形成数据集

由于一万多张图片比较多,并且需要将Cat类和Dog类的图片合在一起并重新命名,方便获得每张图片的labels,所以可以从原图片文件夹复制任意给定数量图片到train的文件夹,并且重命名如下:

截屏2022-04-22 下午3.58.33

程序为:02_data_processing.py.

三、图片预处理

图片预处理部分需要完成:

  1. 对图片的裁剪:将大小不一的图片裁剪成神经网络所需的,我选择的是裁剪为**(224x224)**
  2. 转化为张量
  3. 归一化:三个方向归一化
  4. 图片数据增强
  5. 形成加载器:返回图片数据和对应的标签,利用Pytorch的Dataset包

dataset.py中定义Mydata的类,继承pytorch的Dataset,定义如下三个方法:

(1)init 方法

读取图片路径,并拆分为数据集和验证集(以下代码仅体现结构,具体见源码):

class Mydata(data.Dataset):
    """定义自己的数据集"""
    def __init__(self, root, Transforms=None, train=True):
        """进行数据集的划分"""
        if train:
            self.imgs = imgs[:int(0.8*imgs_num)]  #80%训练集
        else:
            self.imgs = imgs[int(0.8*imgs_num):]  #20%验证集
        """定义图片处理方式"""
        if Transforms is None:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                            std=[0.229, 0.224, 0.225])
            self.transforms = transforms.Compose(
                    [ transforms.CenterCrop(224), 
                    transforms.Resize([224,224]),
                    transforms.ToTensor(), normalize])

(2)getitem方法

对图片处理,返回数据和标签:

 def __getitem__(self, index):
     return data, label

(3)len方法

返回数据集大小:

    def __len__(self):
        """返回数据集中所有图片的个数"""  
        return len(self.imgs)

(4)测试

实例化数据加载器后,通过调用getitem方法,可以得到经过处理后的 3 × 244 × 244 3\times244\times244 3×244×244的图片数据

if __name__ == "__main__":
    root = "./data/train"
    train = Mydata(root, train=True)  #实例化加载器
    img,label=train.__getitem__(5)    #获取index为5的图片
    print(img.dtype)
    print(img.size(),label)   
    print(len(train))    #数据集大小
#输出
torch.float32
torch.Size([3, 224, 224]) 0
3200

裁剪处理后图片如下所示,大小为224X224

截屏2022-04-22 下午5.28.56

四、模型

模型都放在 models.py中,主要用了一些经典的CNN模型:

  1. LeNet
  2. ResNet
  3. ResNet
  4. SqueezeNet

下面给出重点关注的LeNet模型和AlexNet模型:

1 LeNet

LeNet模型是一个早期用来识别手写数字图像的卷积神经网络,这个名字来源于LeNet论文的第一作者Yann LeCun。LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当时最先进的结果,LeNet模型结构图示如下所示:

截屏2022-04-29 下午7.54.44

由上图知,LeNet分为卷积层块全连接层块两个部分,在本项目中我对LeNet模型做了相应的调整

  1. 采用三个卷积层
  2. 三个全连接层
  3. ReLu作为激活函数
  4. 在卷积后正则化
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        #三个卷积层
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=16,
                kernel_size=3,
                stride=2,
            ),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=3,
                stride=2,
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
                stride=2,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #三个全连接层
        self.fc1 = nn.Linear(3 * 3 * 64, 64)
        self.fc2 = nn.Linear(64, 10)
        self.out = nn.Linear(10, 2)   #分类类别为2,

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.out(x)
        return x

调用torchsummary库,可以观察模型的结构、参数:

截屏2022-04-30 上午12.35.15

2 AlexNet模型

2012年,AlexNet横空出世,这个模型的名字来源于论文第一作者的姓名Alex Krizhevsky。AlexNet使用了8层卷积神经网络,由5个卷积层和3个池化Pooling 层 ,其中还有3个全连接层构成。AlexNet 跟 LeNet 结构类似,但使⽤了更多的卷积层和更⼤的参数空间来拟合⼤规模数据集 ImageNet,它是浅层神经⽹络和深度神经⽹络的分界线。

特点:

  1. 在每个卷积后面添加了Relu激活函数,解决了Sigmoid的梯度消失问题,使收敛更快。
  2. 使用随机丢弃技术(dropout)选择性地忽略训练中的单个神经元,避免模型的过拟合(也使用数据增强防止过拟合)
  3. 添加了归一化LRN(Local Response Normalization,局部响应归一化)层,使准确率更高。
  4. 重叠最大池化(overlapping max pooling),即池化范围 z 与步长 s 存在关系 z>s 避免平均池化(average pooling)的平均效应

五、训练

训练在 main.py中,主要是对获取数据、训练、评估、模型的保存等功能的整合,能够实现以下功能:

  1. 指定训练模型、epoches等基本参数
  2. 是否选用预训练模型
  3. 接着从上次的中断的地方继续训练
  4. 保存最好的模型和最后一次训练的模型
  5. 对模型的评估:Loss和Accuracy
  6. 利用TensorBoard可视化

1 开始训练

main.py程序中,设置参数和模型(models.py中可以查看有哪些模型):

截屏2022-04-29 下午11.22.34

在vscode中点击运行或在命令行中输入:

python3 main.py

即可开始训练,开始训练后效果如下:

截屏2022-04-30 上午8.24.14

若程序中断,设置resume参数为True,可以接着上次的模型继续训练,可以非常方便的任意训练多少次

2 tensorboard可视化

在vscode中打开tensorboard,或者在命令行中进入当前项目文件夹下输入

tensorboard --logdir runs

即可打开训练中的可视化界面,可以很方便的观察模型的效果:

截屏2022-04-30 上午8.28.37

如上图所示,可以非常方便的观察任意一个模型训练过程的效果!

六、不同模型训练结果分析

1 LeNet模型

在用LeNet模型训练的过程中,通过调整数据集数量、是否用数据增强等不同的方法,来训练模型,并观察模型的训练效果。

(1) 数据集数量=1000,无数据增强

通过Tensorboard可视化可以观察到:

  1. 验证集准确率(Accuracy)在上升,训练30epoch左右,达到最终**63%**左右的最好效果
  2. 但验证集误差(Loss)也在上升,训练集误差一直下降
  3. 训练集误差接近于0

说明模型在训练集上效果好,验证集上效果不好,泛化能力差,可以推测出模型过拟合了。而这个原因也是比较好推测的,数据集比较少。

截屏2022-04-29 下午8.23.09

(2) 数据集数量=4000,无数据增强

同样过拟合了,但是最后的准确率能达到**68%**左右,说明数据集增加有效果

截屏2022-04-29 下午8.32.01

(3)数据集数量=4000,数据增强

这次数据集数量同上一个一样为4000,但采用了如下的数据增强:

  1. 水平翻转,概率为p=0.5
  2. 上下翻转,概率为p=0.1

我们可以看到这次一开始验证集误差是下降的,说明一开始没有过拟合,但到15个epoch之后验证集误差开始上升了,说明已经开始过拟合了,但最后的准确率在**71%**左右,说明数据增强对扩大数据集有明显的效果。

截屏2022-04-29 下午8.38.00

(4)数据集=4000,数据增强

这次数据集数量为4000,但采用了如下的数据增强:

  1. 水平翻转,概率为p=0.5
  2. 上下翻转,概率为p=0.5
  3. 亮度变化截屏2022-04-29 下午8.48.10

可以看到:

  1. 35个epoch之前,验证集误差呈下降趋势,准确率也一直上升,最高能到75%
  2. 但在35个epoch之后,验证集误差开始上升,准确率也开始下降

说明使用了更强的数据增强之后,模型效果更好了。

截屏2022-04-29 下午8.50.01

(5)使用dropout函数抑制过拟合

本次数据集和数据增强方式同(4),但是在模型的第一个全连接层加入dropout函数。

dropout原理:

训练过程中随机丢弃掉一些参数。在前向传播的时候,让某个神经元的激活值以一定的概率p(伯努利分布)停止工作,这样可以使模型泛化性更强。截屏2022-04-29 下午8.59.39

不使用dropout示意图 使用dropout示意图

这样相当于每次训练的是一个比较"瘦"的模型,更不容易过拟合

加入dropout函数后,训练85个epochs,可以观察到效果十分显著

  1. 验证集的误差总体呈现下降趋势,且最后没有反弹
  2. 训练集误差下降比较慢了!
  3. 准确率一直上升,最后可以达到76%

说明模型最后没有过拟合,并且效果还不错。

截屏2022-04-29 下午9.03.21

2 AlexNet模型

将AlexNet模型参数打印出来:

截屏2022-04-30 上午12.58.58

可以看到AlexNet相比LeNet,参数数目有数量级的上升,而在数据量比较小的情况下,很容易梯度消失,经过反复的调试:

  1. 要在卷积层加入正则化
  2. 优化器选择SGD
  3. 学习率不能过大

才能避免验证集的准确率一直在50%

经过调试,较好的一次结果如下所示,最终准确率能达到78%

截屏2022-04-30 上午1.10.08

3 squeezeNet模型

在后面两个模型中,使用迁移学习的方法。

**迁移学习(Transfer Learning)**是机器学习中的一个名词,是指一种学习对另一种学习> 的影响,或习得的经验对完成其它活动的影响。迁移广泛存在于各种知识、技能与社会规范> 的学习中,将某个领域或任务上学习到的知识或模式应用到不同但相关的领域或问题中。``截屏2022-04-29 下午11.58.32```

使用squeezeNet预训练模型,在迭代16个epoch后,准确率可以达到93%

截屏2022-04-29 下午11.51.43

4 resNet模型

使用resnet50的预训练模型,训练25个epoch后,准确率可以达到98%!

截屏2022-04-30 上午12.12.36

总结

模型测试集预测准确率
LeNet(无数据增强)68%
LeNet(数据增强)75%
LeNet(采用Dropout)76%
Alexnet78%
squeezeNet(迁移学习)93%
resNet98%

七、预测

模型训练好后,可以打开 predict.py对新图片进行预测,给定用来预测的模型和预测的图片文件夹:

 model = LeNet1() # 模型结构
    modelpath = "./runs/LeNet1_1/LeNet1_best.pth" # 训练好的模型路径
    checkpoint = torch.load(modelpath)  
    model.load_state_dict(checkpoint)  # 加载模型参数
  
    root = "test_pics"

运行 predict.py 会将预测的图片储存在 output文件夹中,如下图所示:

pre_04_cat

会给出预测的类别和概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1471234.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vue3】学习watch监视:深入了解Vue3响应式系统的核心功能(上)

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…

Linux基础命令—进程管理

基础知识 linux进程管理 什么是进程 开发写代码->代码运行起来->进程 运行起来的程序叫做进程程序与进程区别 1.程序是一个静态的概念,主要是指令集和数据的结合,可以长期存放在操作系统中 2.进程是一个动态的概念,主要是程序的运行状态,进程存在生命周期,生命周期结…

nginx.conf配置文件详解、案例,Nginx常用命令与模块

目录 一、Nginx常用命令 二、Nginx涉及的文件 2.1、Nginx 的默认文件夹 2.2、Nginx的主配置文件nginx.conf nginx.conf 配置的模块 2.2.1、全局块:全局配置,对全局生效 2.2.2、events块:配置影响 Nginx 服务器与用户的网络连接 2.2.3…

docker 容器访问 GPU 资源使用指南

概述 nvidia-docker 和 nvidia-container-runtime 是用于在 NVIDIA GPU 上运行 Docker 容器的两个相关工具。它们的作用是提供 Docker 容器与 GPU 加速硬件的集成支持,使容器中的应用程序能够充分利用 GPU 资源。 nvidia-docker 为了提高 Nvidia GPU 在 docker 中的…

Python爬虫-爬取豆瓣高分电影封面

本文是本人最近学习Python爬虫所做的小练习。如有侵权,请联系删除。 页面获取url 代码 import requests import os import re# 创建文件夹 path os.getcwd() /images if not os.path.exists(path):os.mkdir(path)# 获取全部数据 def get_data():# 地址url "…

输电线路微波覆冰监测装置助力电网应对新一轮寒潮

2月19日起,湖南迎来新一轮寒潮雨雪冰冻天气。为做好安全可靠的供电准备,国网国网湘潭供电公司迅速启动雨雪、覆冰预警应急响应,采取“人巡机巡可视化巡视”的方式,对输电线路实施三维立体巡检。该公司组织员工对1324套通道可视化装…

leetcode hot100 买卖股票的最佳时机二

注意,本题是针对股票可以进行多次交易,但是下次买入的时候必须保证上次买入的已经卖出才可以。 动态规划可以解决整个股票买卖系列问题。 dp数组含义: dp[i][0]表示第i天不持有股票的最大现金 dp[i][1]表示第i天持有股票的最大现金 递归公…

全面InfiniBand解决方案——LLM培训瓶颈问题

ChatGPT对技术的影响引发了对人工智能未来的预测,尤其是多模态技术的关注。OpenAI推出了具有突破性的多模态模型GPT-4,使各个领域取得了显著的发展。 这些AI进步是通过大规模模型训练实现的,这需要大量的计算资源和高速数据传输网络。端到端…

东莞IBM服务器维修之IBM x3630 M4阵列恢复

记录东莞某抖音电商公司送修一台IBM SYSTEM X3630 M4文档服务器RAID6故障导致数据丢失的恢复案例 时间:2024年02月20日, 服务器品牌:IBM System x3630 M4,阵列卡用的是DELL PERC H730P 服务器用途和用户位置:某抖音电…

【Flink精讲】Flink性能调优:内存调优

内存调优 内存模型 JVM 特定内存 JVM 本身使用的内存,包含 JVM 的 metaspace 和 over-head 1) JVM metaspace: JVM 元空间 taskmanager.memory.jvm-metaspace.size,默认 256mb 2) JVM over-head 执行开销&#xff1…

Spring Boot对接RocketMQ示例

部署服务 参考RocketMq入门介绍 示例 引入maven依赖 <dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-spring-boot-starter</artifactId><version>2.2.2</version></dependency>完整依赖如下&am…

C# Onnx Yolov8-OBB 旋转目标检测

目录 效果 模型信息 项目 代码 下载 C# Onnx Yolov8-OBB 旋转目标检测 效果 模型信息 Model Properties ------------------------- date&#xff1a;2024-02-26T08:38:44.171849 description&#xff1a;Ultralytics YOLOv8s-obb model trained on runs/DOTAv1.0-ms.ya…

关系型数据库事务的四性ACID:原子性(Atomicity)、一致性(Consistency)、隔离性(Isolation)和持久性(Durability)

关系型数据库事务的四性ACID:原子性&#xff08;Atomicity&#xff09;、一致性&#xff08;Consistency&#xff09;、隔离性&#xff08;Isolation&#xff09;和持久性&#xff08;Durability&#xff09; 事务的四性通常指的是数据库事务的ACID属性&#xff0c;包括原子性&…

C语言第三十一弹---自定义类型:结构体(下)

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】 目录 1、结构体内存对齐 1.1、为什么存在内存对齐? 1.2、修改默认对齐数 2、结构体传参 3、结构体实现位段 3.1、什么是位段 3.2、位段的内存分配 3.3、…

qt-C++笔记之事件过滤器

qt-C笔记之事件过滤器 —— 杭州 2024-02-25 code review! 文章目录 qt-C笔记之事件过滤器一.使用事件过滤器和不使用事件过滤器对比1.1.使用事件过滤器1.2.不使用事件过滤器1.3.比较 二.Qt 中事件过滤器存在的意义三.为什么要重写QObject的eventFilter方法&#xff1f;使用QO…

数据结构:链表的冒泡排序

法一&#xff1a;修改指针指向 //法二 void maopao_link(link_p H){if(HNULL){printf("头节点为空\n");return;}if(link_empty(H)){printf("链表为空\n");return;}link_p tailNULL;while(H->next->next!tail){link_p pH;link_p qH->next;while(q…

常见的音频与视频格式

本专栏是汇集了一些HTML常常被遗忘的知识&#xff0c;这里算是温故而知新&#xff0c;往往这些零碎的知识点&#xff0c;在你开发中能起到炸惊效果。我们每个人都没有过目不忘&#xff0c;过久不忘的本事&#xff0c;就让这一点点知识慢慢渗透你的脑海。 本专栏的风格是力求简洁…

【Unity】使用Video Player播放CG视频

1.在UI上新建一个Raw Image 2.添加Video Player 【参数详解】 Source&#xff1a;视频源类型&#xff0c;有Video Clip 和 URL两种Video Clip&#xff1a;视频片段&#xff0c;当Source选择video clip生效URL&#xff1a;视频路径&#xff0c;当Source选择URL生效Play On Awak…

洛谷C++简单题小练习day21—梦境数数小程序

day21--梦境数数--2.25 习题概述 题目背景 Bessie 处于半梦半醒的状态。过了一会儿&#xff0c;她意识到她在数数&#xff0c;不能入睡。 题目描述 Bessie 的大脑反应灵敏&#xff0c;仿佛真实地看到了她数过的一个又一个数。她开始注意每一个数码&#xff08;0…9&#x…

在IDEA中创建vue hello-world项目

工作中最近在接触vue前端项目&#xff0c;记录一下从0搭建一个vue hello world项目的步骤 1、本地电脑安装配置node、npm D:\Project\vue\hello-world>node -v v14.21.3 D:\Project\vue\hello-world>npm -v 6.14.18 D:\Project\vue\hello-world> 2、设置npm国内淘…