ResNet 原理剖析以及代码复现

news2024/11/24 11:37:46

原理

ResNet 解决了什么问题?

一言以蔽之:解决了深度的神经网络难以训练的问题。
具体的说,理论上神经网络的深度越深,其训练效果应该越好,但实际上并非如此,层数越深会导致越差的结果并且容易产生梯度爆炸或梯度消失等问题。

ResNet 怎么解决的?

提出了一个残差学习网络的框架,该框架解决了上述问题。

残差网络的架构

在这里插入图片描述

整个架构如上图所示。

首先我们要学习的东西是 H(x),假设现在已经有了一个浅的网络,然后我们要在上面新加一些层,让网络变得更深,如果按传统的做法那么新加的层就继续跟之前一样进行学习就行了。但是现在在新加的层中我们不直接去学 H(x),而是应该去学 H(x) - x。x 就是之前比较浅的网络已经学到的那个东西,也就是在新加的层中不去重新学个东西而只是学学到的东西和真实的东西二者之间的残差 H(x) - x,然后该层最后的输出结果 F(x) 再加上原始数据 x 就是最终结果也就是 F(x) + x ,此时优化的目标就不再是原始的 H(x),而是 H(x) - x 这个东西。

这就是 ResNet 的核心思想。

我感觉有一篇文章讲的很好,可以参考一下:ResNet网络详细讲解

下面是论文原文的描述:

在本文中,我们通过引入一个深度残差学习框架来解决退化问题。我们不希望每几个堆叠层直接拟合一个期望的底层映射,而是明确地让这些层拟合一个残差映射。在形式上,我们将期望的底层映射表示为H ( x ),并让堆叠的非线性层拟合F ( x )的另一个映射:= H ( x ) - x。原始映射被重铸成F ( x ) + x。我们假设优化残差映射比优化原始的、未引用的映射更容易。在极端情况下,如果一个恒等映射是最优的,那么将残差推到零比用一堆非线性层拟合一个恒等映射更容易。

F ( x ) + x的表达式可以通过具有"捷径连接"的前馈神经网络来实现(图2 )。快捷方式连接[ 2、33、48]是那些跳过一个或多个层的连接。在我们的例子中,快捷连接只是执行身份映射,它们的输出被添加到堆叠层的输出中(图2 )。身份捷径连接既不增加额外的参数,也不增加计算复杂度。整个网络仍然可以通过反向传播的SGD进行端到端的训练,并且可以很容易地使用公共库(例如, Caffe )实现,无需修改求解器。

我们在ImageNet [ 35 ]上进行了全面的实验来展示退化问题并评估我们的方法。研究表明:

1 )我们的深度残差网络易于优化,但对应的"普通"网络(简单地堆叠层)在深度增加时表现出更高的训练误差;
2 )我们的深度残差网络可以很容易地从大幅增加的深度中获得精度增益,产生的结果明显优于以前的网络。

在ImageNet分类数据集上[ 35 ],我们通过极深的残差网络获得了优异的结果。我们的152层残差网络是ImageNet上有史以来最深层的网络,但仍比VGG网络具有更低的复杂度[ 40 ]。我们的集成在ImageNet测试集上有3.57 %的top - 5误差,并在ILSVRC 2015分类竞赛中获得第一名。在其他识别任务上也具有出色的泛化性能,并引领我们在ILSVRC & COCO 2015竞赛中进一步获得第1名:ImageNet检测、ImageNet定位、COCO检测和COCO分割。这有力的证据表明,残差学习原理具有一般性,我们预期它在其他视觉和非视觉问题中也适用。

代码复现

这里给出我自己的模型代码:

import torch
from torch import nn

# 基本残差块
class BasicBlock(nn.Module):
    expansion = 1
    """
    参数解释:
    in_ch:输入通道数
    block_ch:输出通道数
    stride:步长,通过该参数我们就可以实现网络结构中特征图Size减半、通道数增加一倍的效果
    downSample:其本身也是一个网络,用来实现残差网络中的跳跃连接(也就是论文中虚线和实线)
                同时跳跃连接也是用来区别基本残差块和瓶颈残差块的,二者区别如下:
                    基本残差块:输入输出通道数相同
                    瓶颈残差块:输入输出通道数不同,需要进行升维操作才能对位相加
                另外二者的结构不同,可以通过论文看到
    """
    def __init__(self, in_ch, block_ch, stride=1, downSample=None):
        super().__init__()
        self.downSample = downSample
        # 从网络结构图中可以看到,先进行第一层卷积
        self.conv1 = nn.Conv2d(in_ch, block_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        # 在网络模型中添加一个二维批归一化(Batch Normalization)层。
        # 批归一化是一种用于加速神经网络训练并提高其性能的技术,类似于将上面所输出的数据进行了统一整理
        self.bn1 = nn.BatchNorm2d(block_ch)
        # 激活函数
        self.relu1 = nn.ReLU()

        # 第二层卷积
        self.conv2 = nn.Conv2d(block_ch, block_ch * self.expansion, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(block_ch * self.expansion)
        self.relu2 = nn.ReLU()


    def forward(self, x):
        identity = x
        # 如果downSample参数不为空,说明其需要升维(也就是论文中虚线的样子)
        if self.downSample is not None:
            # 升维,让输入输出的通道数对齐
            identity = self.downSample(x)
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # 这里就是论文中的输出与原始输入进对位相加的步骤
        out += identity
        # 对位相加结束后再进行 relu 函数的激活,然后输出结果
        return self.relu2(out)


# 瓶颈残差块
class Bottleneck(nn.Module):
    # 从论文的网络结构图中不难发现,瓶颈残差块在第三层卷积时通道数会放大四倍
    # 因此定义一个 expansion 变量
    expansion = 4
    """
        参数解释:
        in_ch:输入通道数
        block_ch:输出通道数
        stride:步长,通过该参数我们就可以实现网络结构中特征图Size减半、通道数增加一倍的效果
        downSample:其本身也是一个网络,用来实现残差网络中的跳跃连接(也就是论文中虚线和实线)
                    同时跳跃连接也是用来区别基本残差块和瓶颈残差块的,二者区别如下:
                        基本残差块:输入输出通道数相同
                        瓶颈残差块:输入输出通道数不同,需要进行升维操作才能对位相加
                    另外二者的结构不同,可以通过论文看到
    """
    def __init__(self, in_ch, block_ch, stride=1, downSample=None):
        super().__init__()
        self.downSample = downSample
        # 从网络结构图中可以看到,先进行第一层卷积
        self.conv1 = nn.Conv2d(in_ch, block_ch, kernel_size=1, stride=stride, bias=False)
        # 在网络模型中添加一个二维批归一化(Batch Normalization)层。
        # 批归一化是一种用于加速神经网络训练并提高其性能的技术,类似于将上面所输出的数据进行了统一整理
        self.bn1 = nn.BatchNorm2d(block_ch)
        # 激活函数
        self.relu1 = nn.ReLU()

        # 第二层卷积
        self.conv2 = nn.Conv2d(block_ch, block_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(block_ch)
        self.relu2 = nn.ReLU()

        # 第三层卷积
        self.conv3 = nn.Conv2d(block_ch, block_ch * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(block_ch * self.expansion)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        identity = x
        # 如果downSample参数不为空,说明其需要升维(也就是论文中虚线的样子)
        if self.downSample is not None:
            # 升维,让输入输出的通道数对齐
            identity = self.downSample(x)
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # 这里就是论文中的输出与原始输入进对位相加的步骤
        out += identity
        # 对位相加结束后再进行 relu 函数的激活,然后输出结果
        return self.relu3(out)


# 残差网络
class ResNet(nn.Module):
    """
    in_ch: 默认为3,因为残差网络就是用来图片分类的,所以输入通道数默认为 3
    num_classes:分类的数量,默认设置为100,即 100 种分类
    block:用来区别是 基本残差块 还是 瓶颈残差块
    block_num:每个残差块所需要堆叠的次数(也是论文中提供的有)
    """
    def __init__(self, in_ch=3, num_classes=100, block=Bottleneck, block_num=[3, 4, 6, 3]):
        super().__init__()
        # 因为在各层之间通道数会发生变化,因此要进行跟踪
        self.in_ch = in_ch
        # 对于残差网络来说,不管是什么类型其一开始都要进行 7x7 的卷积和 3x3 的池化
        # 因此我们直接照搬即可(论文中已经有了)
        self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_ch = 64

        # 将残差块堆叠起来,形成一个一个的残差层,进而构建成ResNet
        self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)

        # 最后是全连接层,做预测的
        self.fc_layer = nn.Sequential(
            nn.Linear(512*block.expansion*7*7, num_classes),
            nn.Softmax(dim=-1)
        )
    def _make_layer(self, block, block_ch, block_num, stride=2):
        layers = []
        downSample = nn.Conv2d(self.in_ch, block_ch * block.expansion, kernel_size=1, stride=stride)
        layers += [block(self.in_ch, block_ch, stride=stride, downSample=downSample)]
        self.in_ch = block_ch * block.expansion

        for _ in range(1, block_num):
            layers += [block(self.in_ch, block_ch)]
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.maxpool1(self.bn1(self.conv1(x))) #(1, 3, 224, 224) -> (1, 64, 56, 56)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc_layer(out)
        return out

if __name__ == '__main__':
    # 接下来进行测试
    # 这行代码创建了一个形状为 (1, 3, 224, 224) 的四维张量 x,
    # 其中包含了一个大小为 1 的批次中的一个 224x224 像素的 RGB 图像。
    x = torch.randn(1, 3, 224, 224)
    resnet = ResNet(in_ch=3, num_classes=100, block=Bottleneck, block_num=[2, 2, 2, 2])
    y = resnet(x)
    print(y.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1714498.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Scapy:用Python编写自己的网络抓包工具

随着Python越来越流行,在安全领域的用途也越来越多。比如可以用requests 模块撰写进行Web请求工具;用sockets编写TCP网络通讯程序;解析和生成字节流可以使用struct模块。而要解析和处理网络包在网络安全领域更加普遍,时常我们会使…

Vue——事件修饰符

文章目录 前言阻止默认事件 prevent阻止事件冒泡 stop 前言 在官方文档中对于事件修饰符有一个很好的说明,本篇文章主要记录验证测试的案例。 官方文档 事件修饰符 阻止默认事件 prevent 在js原生的语言中,可以根据标签本身的事件对象进行阻止默认事件…

数组-给出最大容量,求能获得的最大值

一、问题描述 二、解题思路 这个题目其实是求给出数组中,子数组和不大于M中,和最大值的子数组。 求子数组使用双指针就可以解决问题,相对比较简单。(如果是子序列,则等价于0-1背包问题,看题目扩展中的问题…

C++笔试强训day36

目录 1.提取不重复的整数 2.【模板】哈夫曼编码 3.abb 1.提取不重复的整数 链接https://www.nowcoder.com/practice/253986e66d114d378ae8de2e6c4577c1?tpId37&tqId21232&ru/exam/oj 按照题意模拟就行&#xff0c;记得从右往左遍历 #include <iostream> usi…

Vue——计算属性 computed 与方法 methods 区别探究

文章目录 前言计算属性的由来方法实现 计算属性 同样的效果计算属性缓存 vs 方法 前言 在官方文档中&#xff0c;给出了计算属性的说明与用途&#xff0c;也讲述了计算属性与方法的区别点。本篇博客只做自己的探究记录&#xff0c;以官方文档为准。 vue 计算属性 官方文档 …

彻底理解浏览器的进程与线程

彻底理解浏览器的进程与线程 什么是进程和线程&#xff0c;两者的区别及联系浏览器的进程和线程总结浏览器核心进程有哪些浏览器进程与线程相关问题 什么是进程和线程&#xff0c;两者的区别及联系 进程和线程是操作系统中用于管理程序执行的两个基本概念进程的定义及理解 定义…

PHPSTOM配置Laradock,xdebug,phpunit

原理图&#xff1a; 片面理解&#xff1a; phpstorm启用一个9000端口&#xff0c;这个端口用来接收到信息后&#xff0c;启用xdebug功能。服务器端(docker), 当客户端访问laravel项目域名后, 并读取xdebug.ini的配置, 把调试的请求数据, 向配置里面的端口发送消息, 配置里面的端…

QGIS开发笔记(三):Windows安装版二次开发环境搭建(下):将QGis融入QtDemo,添加QGis并加载tif遥感图的Demo

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/139136356 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…

如何在phpMy管理对Joomla后台的登录密码进行重置

本周有一个客户&#xff0c;购买Hostease的虚拟主机&#xff0c;询问我们的在线客服&#xff0c;如何在phpMy管理对Joomla后台的登录密码进行重置&#xff1f;我们为用户提供相关教程&#xff0c;用户很快解决了遇到的问题。在此&#xff0c;我们分享这个操作教程&#xff0c;希…

​用 ONLYOFFICE 宏帮你自动执行任务:介绍与教程

使用 ONLYOFFICE 宏&#xff0c;可以来自动实现一些操作节省更多时间和精力。在本文中&#xff0c;我们集合了一些关于宏的教程&#xff0c;带您了解宏的工作原理&#xff0c;以及一些实例展示。 什么是 ONLYOFFICE 宏 如果您是一名资深 Microsoft Excel 用户&#xff0c;那么…

大数据报告有什么作用?查询方式一般有几种?

随着互联网金融的飞速发展&#xff0c;网络借贷已经成为了一种常见的融资方式。然而&#xff0c;如何在众多的平台中做出正确的选择&#xff0c;避免风险并实现最大利益&#xff0c;这就需要一份具有参考价值的大数据报告。本文将详细阐述大数据报告的作用及查询方式的几种方式…

ROS2贪吃龟练习工程

本文是ROS2基础知识的综合小应用&#xff0c;练习如何创建工作包&#xff0c;创建Node&#xff0c;定义Topic和Service&#xff0c;以及通过LaunchFile启动多个节点。基础知识可以参考&#xff1a;ROS2基础编程&#xff0c;ROS2 Topics和Services&#xff0c;ROS2 LaunchFile和…

模拟集成电路(5)----单级放大器(共栅级)

模拟集成电路(5)----单级放大器&#xff08;共栅级&#xff09; 有一些场合需要一些小的输入电阻&#xff08;电流放大器&#xff09; 大信号分析 − W h e n V i n ≥ V B − V T H ∙ M 1 i s o f f , V o u t V D D − F o r L o w e r V i n I d 1 2 μ n C o x W L ( V…

matlab安装及破解

一、如何下载 软件下载链接&#xff0c;密码&#xff1a;98ai 本来我想自己生成一个永久百度网盘链接的&#xff0c;但是&#xff1a; 等不住了&#xff0c;所以大家就用上面的链接吧。 二、下载花絮 百度网盘下载速度比上载速度还慢&#xff0c;我给充了个会员&#xff0c…

java调用远程接口下载文件

在postman中这样下载文件 有时下载文件太大postman会闪退&#xff0c;可以通过代码下载&#xff0c;使用hutool的http包

中华活页文选高中版投稿发表

《中华活页文选&#xff08;高中版&#xff09;》创刊于1960年&#xff0c;是中宣部所属中国出版传媒股份有限公司主管、中华书局主办的国家级基础教育期刊&#xff0c;曾获得“中国期刊方阵双效期刊”、国家新闻出版广电总局推荐的“百种优秀报刊”等荣誉称号。本刊以高中学科…

WAMP无法启动mysql

一种原因是原来安装过mysql,mysql默认是自启动的&#xff0c;而WAMP内置mysql会发生冲突&#xff0c;所以 解决方法&#xff1a; winR 输入 services.msc 将mysql关闭&#xff0c;并设为手动模式

扒出秦L三个槽点,我不考虑买它了

文 | Auto芯球 作者 | 雷慢 比亚迪的有一个王炸“秦L”&#xff0c;再一次吸引了我注意力&#xff0c; 我上一辆车刚卖不久&#xff0c;最近打算买第二辆车&#xff0c; 二手车和新车都有在看&#xff0c; 我又是一个坚定的实用主义者&#xff0c; 特别是现在的经济环境不…

深入解析 JSONPath:从入门到精通

码到三十五 &#xff1a; 个人主页 在数据处理和交换领域&#xff0c;JSON已经成为了一种广泛使用的数据格式&#xff0c; 如何有效地查询和操作这些数据也变得越来越重要。在这种情况下&#xff0c;JSONPath 应运而生&#xff0c;成为了一种在JSON数据中定位和提取信息的强大工…

老师如何对付挑事儿的家长?

身为老师&#xff0c;你有没有遇到过这样的家长&#xff1a;孩子在学校里闹点小矛盾&#xff0c;或者作业分数有点争议&#xff0c;他们就气势汹汹地来找你&#xff0c;说你偏心&#xff0c;甚至在其他家长面前说三道四&#xff1f;面对这种爱“挑事”的家长&#xff0c;老师们…