PyTorch深度学习实战——基于ResNet模型实现猫狗分类

news2025/4/19 3:07:36

PyTorch深度学习实战——基于ResNet模型实现猫狗分类

    • 0. 前言
    • 1. ResNet 架构
    • 2. 基于预训练 ResNet 模型实现猫狗分类
    • 相关链接

0. 前言

VGG11VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非仅增加网络层数,就可以获得更准确的结果,随着网络层数的增加可能会出现以下问题:

  • 梯度消失和爆炸:在网络层次过深的情况下,反向传播可能会面临梯度消失和爆炸的问题,导致训练网络时无法收敛
  • 过拟合:增加网络深度会带来更多的参数,如果数据样本过少或网络过于复杂,会导致网络过拟合,降低模型的泛化能力

总之,在构建的神经网络过深时,有两个问题:前向传播中,网络的最后几层几乎没有学习到有关原始图像的任何信息;在反向传播中,由于梯度消失(梯度值几乎为零),靠近输入的前几层几乎没有任何梯度更新。
深度残差网络 (ResNet) 的提出就是为了解决上述问题。在 ResNet 中,如果模型没有什么要学习的,那么卷积层可以什么也不做,只是将上一层的输出传递给下一层。但是,如果模型需要学习其他一些特征,则卷积层将前一层的输出作为输入,并学习完成目标任务所需的其它特征。

1. ResNet 架构

ResNet 通过残差结构解决网络过深时出现的问题,让模型能够训练得更深。经典的 ResNet 架构如下所示:

ResNet架构
残差结构的基本思想是:每一个残差块都不是直接映射输入信号到输出信号,而是通过学习残差映射来实现:
F ( x ) = H ( x ) − x F(x)=H(x)−x F(x)=H(x)x

其中, x x x 是输入, H ( x ) H(x) H(x) 是一个表示所需映射的基本块,而 F ( x ) F(x) F(x) 是残差块学习到的映射。换句话说,输入 x x x 通过卷积层,得到特征变换后的输出 F ( x ) F(x) F(x),与输入 x x x 进行逐元素的相加运算,得到最终输出 H ( x ) H(x) H(x)

H ( x ) = x + F ( x ) H(x) = x + F(x) H(x)=x+F(x)

如果某个基本块为恒等映射,则残差块的学习目标就变为学习 F ( x ) = 0 F(x)=0 F(x)=0,也就是让输入信号直接到达残差块的输出层。这样就可以解决梯度消失的问题,可以训练更深的神经网络。
实现过程中 ResNet 中使用 Shortcut Connection (也称跳跃连接, Skip Connection )在残差块中实现跨层连接,从而实现信息的直接传递,跨层连接可以绕过一个或多个卷积层,直接将网络中的浅层信息传递到深层中。
ResNet 的残差块中,Shortcut Connection 经常与卷积层或批归一化 (Batch Normalization) 相结合。通过该连接,残差块的激活张量可以直接和下一层的输出相加,理论上,即使是最后一层可能拥有原始图像的全部信息,并且反向传播过程中梯度将可以在几乎没有修改的情况下自由地流向浅层。典型的残差块如下所示:

残差模块

在传统中顺序堆叠的神经网络中,神经网络通常直接学习 F ( x ) F(x) F(x),其中 x 是来自前一层的输出值,而在残差网络中,利用跳跃连接,将残差信号 F ( x ) F(x) F(x) 加上恒等映射 x x x 得到最终的输出 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x。接下来,我们通过在 PyTorch 中构建残差块来深入了解残差网络。

2. 基于预训练 ResNet 模型实现猫狗分类

(1)__init__ 方法中定义一个带有卷积操作的类:

from torch import nn

class ResLayer(nn.Module):
    def __init__(self,ni,no,kernel_size,stride=1):
        super(ResLayer, self).__init__()
        padding = kernel_size - 2
        self.conv = nn.Sequential(
            nn.Conv2d(ni, no, kernel_size, stride, 
                    padding=padding),
            nn.ReLU()
        )

在以上代码中,为了确保通过卷积后输出的尺寸保持不变,以便于将输入与卷结果相加,我们通过 padding 控制卷积时输出的尺寸。

(2) 定义 forward 方法:

    def forward(self, x):
        return self.conv(x) + x

在以上代码中,得到的输出是通过卷积操作的输入和原始输入之和。

PyTorch 中预训练的基于残差块的 ResNet18 架构如下:

请添加图片描述
该架构有 18 个可训练网络层,因此被称为 ResNet18 架构。此外,需要注意的是,ResNet18 并不是每个卷积层都会添加跳跃连接,而是在每两层之后使用跳跃连接。
了解了 ResNet 架构之后,构建一个基于预训练 ResNet18 架构的模型来执行狗猫分类任务。构建分类器的流程可以参考在迁移学习中使用预训练 VGG16 模型构建的猫狗分类器。

(3) 加载预训练 ResNet18 模型并检查模型中的模块:

model = models.resnet18(pretrained=True)

ResNet18 模型架构包含以下组件:

  • 卷积层
  • 批归一化
  • ReLU 激活
  • 最大池化层
  • 4ResNet
  • 平均池化 (avgpool) 层
  • 全连接层 (fc) 层

冻结特征提取模块的网络权重,仅替换 avgpoolfc 层并更新其中的参数。

(4) 定义模型架构、损失函数和优化器:

def get_model():
    model = models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    model.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))
    model.fc = nn.Sequential(nn.Flatten(),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 1),
    nn.Sigmoid())
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr= 1e-3)
    return model.to(device), loss_fn, optimizer

在模型中,fc 模块的输入形状为 512,因为 avgpool 的输出形状为 batch size x 512 x 1 x 1。定义了模型后,训练模型,随着 epoch 的增加,模型训练和验证准确率的变化(对应模型分别为 ResNet18ResNet34ResNet50ResNet101ResNet152) 如下:

模型训练和验证准确率
仅对 1000 张图像进行训练时,模型的准确率就可以达到 98% 左右,且准确率随着 ResNet 层数的增加而增加。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1016605.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux学习第12天:基于API函数的字符设备驱动开发:一字一符总见情

本节学习的内容主要为基于LinuxAPI函数的字符设备驱动的开发,还包括在驱动模块加载的时候如何自动创建设备节点。总结的脑图如下: 一、驱动原理 1.分配和释放设备号 申请设备号函数: int alloc_chrdev_region(dev_t *dev, unsigned basemin…

改进YOLOv5小目标检测:构建多尺度骨干和特征增强模块,提升小目标检测

构建多尺度骨干和特征增强模块,提升小目标检测 背景代码使用配置文件如下🔥🔥🔥 提升小目标检测,创新提升 🔥🔥🔥 测试在小目标数据集进行提点 👉👉👉: 新设计的创新想法,包含详细的代码和说明,具备有效的创新组合 🐤🐤🐤 1. 本文包含两个创新改…

SQL优化--count优化

select count(*) from tb_user ;在之前的测试中,我们发现,如果数据量很大,在执行count操作时,是非常耗时的。 MyISAM 引擎把一个表的总行数存在了磁盘上,因此执行 count(*) 的时候会直接返回这个 数,效率很…

档案管理系统设计与实现

摘 要 近年来,随着企业彼此间的竞争日趋激烈,信息技术在企业的发展中占据着越来越重要的地位。在企业的运输生产中,档案已成为企业运输经营中不可或缺的一部分,为管理者进行管理决策和进行各种经营活动提供了重要的依据&#xf…

前后端分离--Vue的入门基础版

目录 一.前后端分离 二.Vue的简介 三.Vue的入门案例 四.Vue的生命周期 一.前后端分离 前后端分离是一种软件架构模式,将应用程序的前端(用户界面)和后端(数据处理和业务逻辑)独立开发、独立部署。在前后端分离的架…

【数据结构】AVL树的删除(解析有点东西哦)

文章目录 前言一、普通二叉搜索树的删除1. 删除结点的左右结点都不为空2. 删除结点的左结点为空,右节点不为空3. 删除结点的右结点为空,左节点不为空4. 删除结点的左右结点都不为空 二、AVL树的删除1. 删除结点,整棵树的高度不变化1.1 parent…

RISV-V架构的寄存器介绍

1、RISC-V的通用寄存器 (1)在编写汇编代码时,使用寄存器的ABI名字,一般不直接使用寄存器的编号; (2)x0-x31是用来做整形运算的寄存器,f0-f31是用来做浮点数运算的寄存器;…

傅里叶变换应用 (01/2):频域和相位

一、说明 我努力理解傅里叶变换,直到我将这个概念映射到现实世界的直觉上。这是一系列技术性越来越强的解释中的第一篇文章。我希望直觉也能帮助你! 二、傅里叶变换中频域简介 声音是一种机械波,是空气中的振动或其他介质。音符对应于波的频率…

【LeetCode75】第五十七题 电话号码的字母组合

目录 题目: 示例: 分析: 代码: 题目: 示例: 分析: 给我们按下的按键,让我们返回对应按键可能产生的所有可能。 这是一道很经典的递归题,我们首先先拿一个数组把每个…

day45:C++ day5,运算符重载剩余部分、静态成员、继承

#include <iostream> #include <cstring> #define pi 3.14 using namespace std;class Shape { protected:double round;double area; public://无参构造Shape():round(40),area(100){cout<<"Shape::无参构造函数&#xff0c;默认周长为40&#xff0c;面…

C语言入门Day_21 函数的使用

目录 前言&#xff1a; 1.变量作用域 2.代码执行顺序 3.易错点 4.思维导图 前言&#xff1a; 我们是先定义函数&#xff0c;再调用函数。完成了函数的定义以后&#xff0c;我们就可以开始调用函数了&#xff0c;让我们来回顾一下&#xff1a; 调用函数分为两部分&#…

1131. 绝对值表达式的最大值

1131. 绝对值表达式的最大值 原题链接&#xff1a;完成情况&#xff1a;解题思路&#xff1a;求方向一次遍历两度统计 参考代码&#xff1a;求方向一次遍历两度统计 原题链接&#xff1a; 1131. 绝对值表达式的最大值 https://leetcode.cn/problems/maximum-of-absolute-val…

网络安全深入学习第二课——热门框架漏洞(RCE—Thinkphp5.0.23 代码执行)

文章目录 一、什么是框架&#xff1f;二、导致框架漏洞原因二、使用步骤三、ThinkPHP介绍四、Thinkphp框架特征五、Thinkphp5.0.23 远程代码执行1、漏洞影响范围2、漏洞成因 六、POC数据包Windows下的Linux下的 七、漏洞手工复现1、先Burp抓包&#xff0c;把抓到的请求包发送到…

【AI语言大模型】文心一言功能使用介绍

一、前言 文心一言是一个知识增强的大语言模型,基于飞桨深度学习平台和文心知识增强大模型,持续从海量数据和大规模知识中融合学习具备知识增强、检索增强和对话增强的技术特色。 最近收到百度旗下产品【文心一言】的产品,抱着试一试的心态体验了一下,整体感觉:还行! 二…

自动化测试API【软件测试】

自动化测试 selenium 1. 为什么使用selenium&#xff1f; 开源免费支持多浏览器支持多系统支持多语言编程提供了丰富的web自动化测试API 2. API 查找页面元素 find Element() find Elements() 元素定位的方式 xpath selector 通常情况下&#xff0c;不需要手动来编写xpath…

Learn Prompt-经验法则

还记得我们在“基础用法”当中提到的三个经验法则吗&#xff1f; 尝试提示的多种表述以获得最佳结果使用清晰简短的提示&#xff0c;避免不必要的词语减少不精确的描述 现在经过了几页的学习&#xff0c;我认为是时候引入一些新的原则了。 3. 一个话题对应一个chat​ ChatG…

Kafka开篇

前言 从本篇开始对个人Kafka学习做一个总结, 目标有这么几个。 从概念架构角度, 对消息中间件形成概要认知;从使用角度, 掌握其常见用法;从性能角度, 探究其高性能实现机制; 消息中间件的用途 从消息生产和消费的角度, 平衡消费者和消费者的速率差。基于该点可以做到削峰填…

白炽灯对新生儿视力有影响吗?推荐专业的儿童台灯

大家都知道婴儿还在成长发育的重要阶段&#xff0c;身体各方面都是比较脆弱的&#xff0c;对外界事务的感知也很敏感&#xff0c;一点点的刺激都会影响的婴儿。而白炽灯是否适合婴儿使用这个问题&#xff0c;我的建议是尽量不要用白炽灯。 因为白炽灯光线不是很柔和&#xff0c…

周易算卦流程c++实现

代码 #include<iostream> using namespace std; #include<vector> #include<cstdlib> #include<ctime> #include<Windows.h>int huaYiXiangLiang(int all, int& left) {Sleep(3000);srand(time(0));left rand() % all 1;while (true) {if…

许可分析 license分析 第十二章

许可分析是指对软件许可证进行详细的分析和评估&#xff0c;以了解组织内部对软件许可的需求和使用情况。通过许可分析&#xff0c;可以帮助组织更好地管理和优化软件许可证的使用。以下是一些可能的许可分析方法和步骤&#xff1a; 软件许可证的分配和使用权限&#xff1a;制定…