经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用

news2024/12/26 21:36:37

经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用

1 ResNet的简述

  1. ResNet 提出了一种残差学习框架来解决网络退化问题,从而训练更深的网络。这种框架可以结合已有的各种网络结构,充分发挥二者的优势。

  2. ResNet以三种方式挑战了传统的神经网络架构:

    • ResNet 通过引入跳跃连接来绕过残差层,这允许数据直接流向任何后续层。

      这与传统的、顺序的pipeline 形成鲜明对比:传统的架构中,网络依次处理低级feature 到高级feature

    • ResNet 的层数非常深,高达1202层。而ALexNet 这样的架构,网络层数要小两个量级。

    • 通过实验发现,训练好的 ResNet 中去掉单个层并不会影响其预测性能。而训练好的AlexNet 等网络中,移除层会导致预测性能损失。

  3. ImageNet分类数据集中,拥有152层的残差网络,以3.75% top-5 的错误率获得了ILSVRC 2015 分类比赛的冠军。

  4. 很多证据表明:残差学习是通用的,不仅可以应用于视觉问题,也可应用于非视觉问题。

  5. 论文地址: https://arxiv.org/pdf/1512.03385.pdf

  6. 卷积神经网络领域的两次技术爆炸,第一次是AlexNet,第二次就是ResNet了。

1.1 网络退化问题

  • 1、理论上来讲网络深度越深越好。网络越深,提取的图片特征越多越丰富,但随之会带来很多的问题(通过Batch Normalization 在很大程度上解决),比如过拟合或者计算量爆炸、梯度消失、梯度爆炸等,导致网络在一定深度下就达到了局部最优解。

  • 2、ResNet 论文作者发现:随着网络的深度的增加,准确率达到饱和之后迅速下降,而这种下降不是由过拟合引起的。这称作网络退化问题。如果更深的网络训练误差更大,则说明是由于优化算法引起的:越深的网络,求解优化问题越难。如下所示:更深的网络导致更高的训练误差和测试误差。

在这里插入图片描述

  • 3、理论上讲,较深的模型不应该比和它对应的、较浅的模型更差。因为较深的模型是较浅的模型的超空间。较深的模型可以这样得到:先构建较浅的模型,然后添加很多恒等映射的网络层。实际上我们的较深的模型后面添加的不是恒等映射,而是一些非线性层。因此,退化问题表明:通过多个非线性层来近似横等映射可能是困难的

在这里插入图片描述

  • 4、针对这⼀问题,何恺明等⼈提出了残差⽹络(ResNet)。它在2015年的ImageNet图像识别挑战赛夺魁,并深刻影响了后来的深度神经⽹络的设计。残差⽹络的核⼼思想是:每个附加层都应该更容易地包含原始函数作为其元素之⼀

1.2 残差块(residual blocks)

1.2.1 残差块的理解

在这里插入图片描述

1、假设需要学习的是映射 y = H(x),残差块使用堆叠的非线性层拟合残差:y = F(x,W) + x 。

其中:

  • x 和 y 是块的输入和输出向量。
  • F(x,W)是要学习的残差映射。因为 F(x,W) = H(x) - x,因此称F为残差。
  • + :通过快捷连接逐个元素相加来执行。快捷连接 指的是那些跳过一层或者更多层的连接。
    • 快捷连接简单的执行恒等映射,并将其输出添加到堆叠层的输出。
    • 快捷连接既不增加额外的参数,也不增加计算复杂度。
  • 相加之后通过非线性激活函数,这可以视作对整个残差块添加非线性,即 relu(y)

2、残差映射易于捕捉恒等映射的细微波动。比如5正常映射为5.1,加入残差后变成 5+0.1。此时输入变成5.2,对于没有残差结构的结果,影响仅为0.1/5.1 = 2%。而对于残差结构,变成 5+0.2 , 由0.1变成了0.2 影响为100%。

3、残差映射 H ( x ) = F ( x ) + x ,在反向传播的时候就变成了 H ′ ( x ) = F ′ ( x ) + 1,这里的加1也可以保证梯度消失现象

4、作者也证明了退化问题在任何数据集上都普遍存在。在imagenet上拿到冠军之后,迁移学习用到了coco同样拿到了好几个赛道的冠军,说明残差结构是普适的。最后又和VGG比了一下,比VGG深了8倍,计算复杂性却还比VGG小 。

1.2.2 残差函数F的形式的可变性

  • 层数可变:论文中的实验包含有两层堆叠、三层堆叠,实际任务中也可以包含更多层的堆叠。

    如果F只有一层,则残差块退化线性层:y = Wx + x 。此时对网络并没有什么提升。

  • 连接形式可变:不仅可用于全连接层,可也用于卷积层。此时F代表多个卷积层的堆叠,而最终的逐元素加法+ 在两个feature map 上逐通道进行。

    此时 x 也是一个feature map,而不再是一个向量。

1.2.3 残差学习成功的原因

学习残差F(x,W)比学习原始映射H(x)要更容易。

  • 1、当原始映射H就是一个恒等映射时, 就是一个F零映射。此时求解器只需要简单的将堆叠的非线性连接的权重推向零即可。

    实际任务中原始映射 H可能不是一个恒等映射:

    • 如果H 更偏向于恒等映射(而不是更偏向于非恒等映射),则F就是关于恒等映射的抖动,会更容易学习。
    • 如果原始映射H 更偏向于零映射,那么学习 本身要更容易。但是在实际应用中,零映射非常少见,因为它会导致输出全为0。
  • 2、如果原始映射H是一个非恒等映射,则可以考虑对残差模块使用缩放因子。如Inception-Resnet 中:在残差模块与快捷连接叠加之前,对残差进行缩放。注意:ResNet 作者在随后的论文中指出:不应该对恒等映射进行缩放。

  • 3、可以通过观察残差 F的输出来判断:如果F的输出均为0附近的、较小的数,则说明原始映射H更偏向于恒等映射;否则,说明原始映射H更偏向于非横等映射。

1.2.4 残差块代码实现

在这里插入图片描述

from torch import nn
from torch.nn import functional as F
import torch

'''
⼀种是当use_1x1conv=False时,应⽤ReLU⾮线性函数之前,将输⼊添加到输出。
另⼀种是当use_1x1conv=True时,添加通过1 × 1卷积调整通道和分辨率



ResNet沿⽤了VGG完整的3 × 3卷积层设计。
残差块⾥⾸先有2个有相同输出通道数的3 × 3卷积层。
每个卷积层后接⼀个批量规范化层和ReLU激活函数。
然后我们通过跨层数据通路,跳过这2个卷积运算,将输⼊直接加在最后的ReLU激活函数前。
这样的设计要求2个卷积层的输出与输⼊形状⼀样,从⽽使它们可以相加。

如果想改变通道数,就需要引⼊⼀个额外的1 × 1卷积层来将输⼊变换成需要的形状后再做相加运算。
'''
class Residual(nn.Module):

    def __init__(self,input_channels, num_channels,use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)


if __name__ == '__main__':
    blk = Residual(3, 3)
    X = torch.rand(4, 3, 6, 6)
    Y = blk(X)
    print(Y.shape)  # 输⼊和输出形状⼀致 torch.Size([4, 3, 6, 6])
    blk = Residual(3, 6, use_1x1conv=True, strides=2)
    Y = blk(X)
    print(Y.shape)  # 在增加输出通道数的同时,减半输出的高和宽 torch.Size([4, 6, 3, 3])

1.3 ResNet网络

1.3.1 四种plain 网络

plain 网络:一些简单网络结构的叠加,如下图所示。图中给出了四种plain 网络,它们的区别主要是网络深度不同。其中,输入图片尺寸 224x224 。

ResNet 简单的在plain 网络上添加快捷连接来实现。

FLOPsfloating point operations 的缩写,意思是浮点运算量,用于衡量算法/模型的复杂度。

FLOPSfloating point per second的缩写,意思是每秒浮点运算次数,用于衡量计算速度。

在这里插入图片描述

相对于输入的feature map,残差块的输出feature map 尺寸可能会发生变化:

  • 输出 feature map 的通道数增加,此时需要扩充快捷连接的输出feature map 。否则快捷连接的输出 feature map 无法和残差块的feature map 累加。

    有两种扩充方式:

    • 直接通过 0 来填充需要扩充的维度。
    • 通过1x1 卷积来扩充维度。
  • 输出 feature map 的尺寸减半。此时需要对快捷连接执行步长为 2 的池化/卷积:如果快捷连接已经采用 1x1 卷积,则该卷积步长为2 ;否则采用步长为 2 的最大池化 。

1.3.2 模型预测能力

VGG-1934层 plain 网络Resnet-34
计算复杂度(FLOPs)19.6 billion3.5 billion3.6 billion

ImageNet 验证集上执行10-crop 测试的结果。

  • A 类模型:快捷连接中,所有需要扩充的维度的填充 0 。
  • B 类模型:快捷连接中,所有需要扩充的维度通过1x1 卷积来扩充。
  • C 类模型:所有快捷连接都通过1x1 卷积来执行线性变换。

C 优于BB 优于A。但是 C 引入更多的参数,相对于这种微弱的提升,性价比较低。所以后续的ResNet 均采用 B 类模型。

模型top-1 误差率top-5 误差率
VGG-1628.07%9.33%
GoogleNet-9.15%
PReLU-net24.27%7.38%
plain-3428.54%10.02%
ResNet-34 A25.03%7.76%
ResNet-34 B24.52%7.46%
ResNet-34 C24.19%7.40%
ResNet-5022.85%6.71%
ResNet-10121.75%6.05%
ResNet-15221.43%5.71%

1.3.3 ResNet-18实现

import torch.nn as nn
import torch
from _06_Residual import Residual


class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()
        self.model = self.get_net()

    def forward(self, X):
        X = self.model(X)
        return X


    def get_net(self):
        '''
        ResNet的前两层跟GoogLeNet中的⼀样:
           在输出通道数为64、步幅为2的7 × 7卷积层后,接步幅为2的3 × 3的最⼤汇聚层。
           不同之处在于ResNet每个卷积层后增加了批量规范化层。
        '''
        b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                           nn.BatchNorm2d(64), nn.ReLU(),
                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        '''
        GoogLeNet在后⾯接了4个由Inception块组成的模块。
        
        ResNet则使⽤4个由残差块组成的模块,每个模块使⽤若⼲个同样输出通道数的残差块。
        第⼀个模块的通道数同输⼊通道数⼀致。由于之前已经使⽤了步幅为2的最⼤汇聚层,所以⽆须减⼩⾼和宽。
        之后的每个模块在第⼀个残差块⾥将上⼀个模块的通道数翻倍,并将⾼和宽减半。
        '''
        b2 = nn.Sequential(*self.resnet_block(64, 64, 2, first_block=True))
        b3 = nn.Sequential(*self.resnet_block(64, 128, 2))
        b4 = nn.Sequential(*self.resnet_block(128, 256, 2))
        b5 = nn.Sequential(*self.resnet_block(256, 512, 2))
        net = nn.Sequential(b1, b2, b3, b4, b5,
                            nn.AdaptiveAvgPool2d((1, 1)),
                            nn.Flatten(), nn.Linear(512, 10))
        return net

    def resnet_block(self, input_channels, num_channels, num_residuals, first_block=False):
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
            else:
                blk.append(Residual(num_channels, num_channels))
        return blk

if __name__ == '__main__':
    net = ResNet18()
    X = torch.rand(size=(1, 1, 224, 224), dtype=torch.float32)
    for layer in net.model:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:', X.shape)
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 128, 28, 28])
Sequential output shape: torch.Size([1, 256, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape: torch.Size([1, 512, 1, 1])
Flatten output shape: torch.Size([1, 512])
Linear output shape: torch.Size([1, 10])

2 ResNet-18在Fashion-MNIST数据集上的应用示例

2.1 创建ResNet网络模型

如1.2.4及1.3.3代码所示。

2.2 读取Fashion-MNIST数据集

其他所有的函数,与经典神经网络(1)LeNet及其在Fashion-MNIST数据集上的应用完全一致。

batch_size = 256

# 为了使Fashion-MNIST上的训练短⼩精悍,将输⼊的⾼和宽从224降到96,简化计算
train_iter,test_iter = get_mnist_data(batch_size,resize=96)

2.3 在GPU上进行模型训练

from _06_ResNet18 import ResNet18

# 初始化模型
net = ResNet18()

lr, num_epochs = 0.05, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/592461.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode】342. 4的幂

342. 4的幂(简单) 方法一:二进制 思路 首先考虑一个数字是不是 2 的整数次方:如果一个数字 n 是 2 的整数次方,那么它的二进制一定是 0...010...0 这样的形式,将它和 -n 按位与的结果一定是它本身。如果 …

前沿质谱应用沙龙分享会暨苏州百趣落成仪式即将开幕!

质谱作为一项医学检验新技术,凭借高特异性、高灵敏度、多指标检测等优势,成为了体外诊断领域最富生命力的新技术之一。目前质谱技术能够准确的测定多种生物小分子代谢物,且质谱在大分子物质例如蛋白质方面也应用的非常广泛。目前,…

要电脑重装系统装在哪个盘最好

在进行电脑重装系统时,选择一个合适的系统安装盘是非常重要的。本文将为您介绍如何选择最佳的系统安装盘,以确保系统性能和稳定性的最佳表现。 工具/原料: 系统版本:windows系统 品牌型号:华硕VivoBook14 软件版本…

张小龙发明了小程序,是否意味着失败?

今天微信小程序上线,从开发到上线仅仅用了四天时间,这是一个了不起的成就。 小程序诞生以来,一直存在着一种声音:它是张小龙“伟大的发明”,是微信“伟大的创新”。然而,张小龙在小程序发布会上宣布&#…

Spark SQL概述、数据帧与数据集

文章目录 一、准备工作1、准备数据文件2、启动Spark Shell 二、加载数据为Dataset1、读文件得数据集 三、给数据集添加元数据信息1、定义学生样例类2、导入隐式转换3、将数据集转换成学生数据集4、对学生数据集进行操作(1)显示数据集内容(2&a…

认识熟悉 Stable Diffusion(SD)基本参数

界面样式 界面参数 界面参数说明prompt希望生成的图片的描述negative prompt不希望在图片中出现的描述Batch size每次生成的图片个数Width图片宽度Height图片高度 这里需要注意的就是尺寸,尺寸并非越大越好,需要根据自己的配置和需求适当调整&#xff…

node.js+vue学生读书笔记共享分享系统

从上面的描述中可以基本可以实现软件的功能: 1、开发实现读书笔记共享平台的整个系统程序; 2、管理员;首页、个人中心、用户管理、笔记分享管理、个人笔记管理、管理员管理、交流互动、系统管理等。 3、用户:首页、个人中心、笔记分享管理、个人笔记管理、我的收藏管理。 4、前…

Window10配置Maven详细教程

文章目录 一、Maven概述二、Maven下载三、配置Maven环境变量四、查看Maven是否配置成功五、为Maven配置本地仓库以及指定远程仓库5.1 Maven构件搜索顺序5.2 Maven配置本地仓库5.3 Maven指定远程仓库 一、Maven概述 Maven是专门管理和构建Java项目的工具,Maven的主要…

Linux nohup-后台挂起运行程序神器

一. 场景描述 天黑了,我得离开实验室去吃饭了。为了环保,我必须关闭电脑,减少不必要的浪费!正常情况下当我关闭终端或电脑时,上面运行的任务代码即会自动停止,但我依旧希望保持代码的正常运行,此…

Android 更新后跑不起来?快来适配 AGP8 和 Flamingo/JDK 17

随着 Android Studio Flamingo 正式版的发布,AGP 8(Android Gradle Plugin 8)也正式进入大家的视野,这次 AGP 8 相关更新属于「断代式」更新,同时如果想体验 AGP 8,就需要升级到 Android Studio Flamingo 版…

揭秘速卖通卖家成功的绝佳秘籍,助您打造畅销店铺!

在竞争激烈的速卖通市场中,如何让您的店铺脱颖而出并实现畅销?林哥今天就跟大家讲一讲一些成功速卖通卖家的绝佳秘籍,帮助您引导高流量和高转化率,成就一个畅销的店铺。 ​一、精确定位目标受众 成功的速卖通店铺离不开精确的目标…

自动生成作文的软件有哪些?盘点五种自动生成作文软件

写作是一项需要花费大量时间和精力的任务,而自动生成作文的软件可以帮助我们节省大量的时间。这些软件通过分析和归纳大量的素材和语言模型,能够快速生成高质量的文章。相比于传统的写作方式,使用自动生成作文软件可以更快地完成文章&#xf…

一套完整的客户管理系统应该包含哪些模块呢?

一套完整的客户管理系统应该包含哪些模块呢? 想要弄清楚一个完整的客户管理系统应该具备哪些功能,首先得清楚系统使用者、使用场景以及主要功能这三个因素。 以我们公司为例: 主要使用者:运营人员、市场人员、产品人员。主要目…

Android Settings中Preference的理解以及使用

Preference 是Android App 中重要的控件之一,Settings 模块大部分都是通过 Preference 实现 优点: Preference 可以自动显示我们上次设置的数据,Android提供preference这个键值对的方式来处理这种情况,自动保存这些数据&#xff…

链接生成二维码怎么弄?这些制作方法分享给大家

在现代社会中,链接生成二维码已经成为了一个非常实用的工具。通过将链接转换为二维码,我们可以将它们轻松地分享给朋友、家人或同事,而无需手动输入URL或复制粘贴。这使得信息的传播变得更加快捷和高效。例如,你正在计划一个聚会&…

Spring第三方bean管理

文章目录 1.第三方bean管理1.1 Bean1.2 小结 2.第三方bean依赖注入2.1 简单类型:成员变量2.2 引用类型:方法形参2.3 小结 3.总结 1.第三方bean管理 1.1 Bean 首先看一下目录结构,APP里面就初始化了SpringConfig文件 SpringConifg中就一句话…

【vue】8个非常实用的Vue自定义指令:

文章目录 一、批量注册指令,新建 directives/index.js 文件二、在 main.js 引入并调用【1】v-copy【2】v-longpress【3】v-debounce【4】v-emoji【5】v-LazyLoad【6】v-permission【7】vue-waterMarker【8】v-draggable 复制粘贴指令 v-copy 长按指令 v-longpress 输…

JUnit单元测试之旅

目录 1. 什么是单元测试和JUnit2. JUnit入门与基本注解2.1测试类的定义:2.2 生命周期注解:2.3断言注解:2.4 参数化单参数多参数通过方法获取参数 2.5 测试套件 三.用到的依赖包 1. 什么是单元测试和JUnit 单元测试(Unit Testing)是对软件中的最小可测试单元进行检查和验证。它…

PyCharm使用指南 - 如何创建密码短语生成器(上)

PyCharm是一种Python IDE,其带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具。此外,该IDE提供了一些高级功能,以用于Django框架下的专业Web开发。 PyCharm 最新下载 本文将展示如何使用免费的 PyCharm Community Edition 开…

Java 泛型的介绍

文章目录 1.学习目标2.什么是泛型3.引入泛型语法 4.泛型类的使用语法示例 6.泛型的上界语法示例 7.泛型的方法定义语法示例 8.通配符通配符解决什么问题通配符上界通配符下界 9.包装类基本数据类型和对应的包装类装箱和拆箱自动装箱和自动拆箱 1.学习目标 1.以能阅读 java 集合…