【深度学习】四大图像分类网络之ResNet

news2024/12/25 9:05:17

ResNet网络是在2015年由微软实验室中的何凯明等几位提出,在CVPR 2016发表影响深远的网络模型,由何凯明团队提出来,在ImageNet的分类比赛上将网络深度直接提高到了152层,前一年夺冠的VGG只有19层。斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名,可以说ResNet的出现对深度神经网络来说具有重大的历史意义。

论文原文:Deep Residual Learning for Image Recognition

一、网络结构

ResNet在CNN图像方面有着非常突出的表现,利用shortcut 短路连接,解决了深度网络中模型退化的问题;每两层/三层之间增加了短路机制,通过residual learning残差学习使深层的网络发挥出作用。  

其中,提出了两种残差块——两层卷积(两组3x3卷积核)和三层卷积(两组1x1卷积核和一组3x3卷积核),不同残差块组合后形成不同参数量、计算量和性能的网络架构,当网络更深时,其进行的是三层间的残差学习,三层卷积核分别是1x1,3x3和1x1,一个值得注意的是隐含层的feature map数量是比较小的,并且是输出feature map数量的1/4。

ResNet给出了五种层数的结构——18-layer、34-layer、50-layer、101-layer和152-layer。 

二、创新点

1.  超深的网络结构(超过1000层)

        在ResNet提出之前,所有的神经网络都是通过卷积层和池化层的叠加组成的。人们认为卷积层和池化层的层数越多,获取到的图片特征信息越全,学习效果也就越好。但是在实际的试验中发现,随着卷积层和池化层的叠加,不但没有出现学习效果越来越好的情况,反而出现梯度消失(若每一层的误差梯度小于1,反向传播时网络越深,梯度越趋近于0)或梯度爆炸(若每一层的误差梯度大于1,反向传播时网络越深,梯度越来越大)、退化的问题(随着层数的增加,预测效果反而越来越差)。

        因此,在 “plain” 网络中(卷积层、池化层、全连接层的堆叠,不存在残差连接),随着层数的增加,训练错误的最小值并没有提升,梯度通过反向传播逐渐回传到之前的位置,梯度越来越小以至于到最后就没什么梯度了,由于存储方式是浮点数,小于某个数值就视为0。但模型越深,在理论上效果会越好。为了解决梯度消失或梯度爆炸问题,ResNet提出通过数据的预处理以及在网络中使用BN(Batch Normalization)层来解决。为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为残差网络。最终使得网络结构超过1000层,并取得了很好的性能。

2. 提出residual(残差结构)模块

        “残差” 指的是输入和输出之间的差异。在传统的深层网络中,模型直接学习输入到输出的映射,这对于层数较深的网络来说,会导致训练困难,因为信号在网络中传播时容易丧失。残差连接通过让每一层学习“输入与目标输出之间的差异”,模型不再学习一个完整的映射 x→y ,而是学习一个小的“改动”或者“残差” F(x),有效避免了这个问题。在深层网络中,原始映射 x→y 可能非常复杂,而通过学习“差异”来调整网络参数通常会更高效。短路线相当于短路操作,在进行反向传播的时候,可以看作是将模型拆成了两个模型进行分别训练,也能更好的进行梯度的传递。

        shortcut 之后并不是通过简单的加法,由于CNN网络中做每一层的时候会发生维度的变化,所以我们需要用到1x1卷积层来调整维度,这也就是为什么有的残差连接是虚线而有些连接是实线。虚线残差结构将图像的高、宽和深度都改变了,实线残差结构的输入、输出特征矩阵维度是一样的,故可以直接进行相加。

3. Batch Normalization(BN)

        BN的目的是使一批(batch)数据所对应feature map的每一个channel的维度满足均值为0、方差为1的分布规律。通过这种方法能够加速网络的收敛并提升准确率。理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说,在计算出整个训练集的feature map之后,再进行标准化处理,但这对于一个大型数据集来说明显是不可能的。所以,我们就一个批次一个批次地进行处理,就是对于一个batch数据的feature map,分别对R、G、B三通道进行处理。在这个过程中,计算得到的 μ 和 \sigma^2 是一个向量而不是一个值,向量的每一个元素代表着一个维度的值。

        训练网络的过程是通过一个批次一个批次的数据进行训练的。但在预测过程通常都是输入一张图片进行预测,即batch_size=1,此时如果再通过上述方法计算均值和方差就没有意义了。所以,在训练过程中要不断地计算每个batch的均值和方差,并使用移动平均的方法记录统计的均值和方差,训练完后可以近似地认为所统计的均值和方差就等于整个训练集的均值和方差。在验证以及预测的过程中,就使用统计得到的均值和方差进行标准化处理。

        BN层让损失函数更平滑,添加BN层后,损失函数的landscape变得更平滑,相比高低不平上下起伏的loss surface,平滑loss surface的梯度预测性更好,可以选取较大的步长。同时,没有BN层的情况下,网络没办法直接控制每层输入的分布,其分布前面层的权重共同决定,或者说分布的均值和方差“隐藏”在前面层的每个权重中,网络若想调整其分布,需要通过复杂的反向传播过程调整前面的每个权重实现,BN层的存在相当于将分布的均值和方差从权重中剥离了出来,只需调整γ和β两个参数就可以直接调整分布,让分布和权重的配合变得更加容易。

三、代码

import torch.nn as nn
import math
from utee import misc
from collections import OrderedDict
 
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']
 
 
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
 
 
def conv3x3(in_planes, out_planes, stride=1):
    # "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
 
 
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        m = OrderedDict()
        m['conv1'] = conv3x3(inplanes, planes, stride)
        m['bn1'] = nn.BatchNorm2d(planes)
        m['relu1'] = nn.ReLU(inplace=True)
        m['conv2'] = conv3x3(planes, planes)
        m['bn2'] = nn.BatchNorm2d(planes)
        self.group1 = nn.Sequential(m)
 
        self.relu= nn.Sequential(nn.ReLU(inplace=True))
        self.downsample = downsample
 
    def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x
 
        out = self.group1(x) + residual
 
        out = self.relu(out)
 
        return out
 
 
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        m  = OrderedDict()
        m['conv1'] = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        m['bn1'] = nn.BatchNorm2d(planes)
        m['relu1'] = nn.ReLU(inplace=True)
        m['conv2'] = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        m['bn2'] = nn.BatchNorm2d(planes)
        m['relu2'] = nn.ReLU(inplace=True)
        m['conv3'] = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        m['bn3'] = nn.BatchNorm2d(planes * 4)
        self.group1 = nn.Sequential(m)
 
        self.relu= nn.Sequential(nn.ReLU(inplace=True))
        self.downsample = downsample
 
    def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x
 
        out = self.group1(x) + residual
        out = self.relu(out)
 
        return out
 
 
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
 
        m = OrderedDict()
        m['conv1'] = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        m['bn1'] = nn.BatchNorm2d(64)
        m['relu1'] = nn.ReLU(inplace=True)
        m['maxpool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.group1= nn.Sequential(m)
 
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
 
        self.avgpool = nn.Sequential(nn.AvgPool2d(7))
 
        self.group2 = nn.Sequential(
            OrderedDict([
                ('fc', nn.Linear(512 * block.expansion, num_classes))
            ])
        )
 
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
 
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
 
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
 
        return nn.Sequential(*layers)
 
    def forward(self, x):
        x = self.group1(x)
 
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
 
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.group2(x)
 
        return x
 
 
def resnet18(pretrained=False, model_root=None, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        misc.load_state_dict(model, model_urls['resnet18'], model_root)
    return model
 
 
def resnet34(pretrained=False, model_root=None, **kwargs):
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        misc.load_state_dict(model, model_urls['resnet34'], model_root)
    return model
 
 
def resnet50(pretrained=False, model_root=None, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        misc.load_state_dict(model, model_urls['resnet50'], model_root)
    return model
 
 
def resnet101(pretrained=False, model_root=None, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        misc.load_state_dict(model, model_urls['resnet101'], model_root)
    return model
 
 
def resnet152(pretrained=False, model_root=None, **kwargs):
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        misc.load_state_dict(model, model_urls['resnet152'], model_root)
    return model

参考资料:

ResNet——CNN经典网络模型详解(pytorch实现)_resnet-cnn-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/weixin_44023658/article/details/105843701Batch Normalization(BN)超详细解析_batchnorm在预测阶段需要计算吗-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/weixin_44023658/article/details/105844861你必须要知道CNN模型:ResNet - 知乎icon-default.png?t=O83Ahttps://zhuanlan.zhihu.com/p/31852747

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2252327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于树莓派的安保巡逻机器人--项目介绍

目录 一、项目简介 二、项目背景 三、作品研发技术方案 作品主要内容: 方案的科学性 设计的合理性 四、作品创新性及特点 五、作品自我评价 本篇为项目“基于树莓派的安保巡逻机器人”介绍博客 演示视频链接: 基于树莓派的安保巡逻机器人_音游…

nn.RNN解析

以下是RNN的计算公式,t时刻的隐藏状态H(t)等于前一时刻隐藏状态H(t-1)乘以参数矩阵,再加t时刻的输入x(t)乘以参数矩阵,最后再通过激活函数,等到t时刻隐藏状态。 下图是输出input和初始化的隐藏状态,当参数batch_first True时候&…

Unity网络框架对比 Mirror|FishNet|NGO

在Unity中制作非单机项目常用的免费网络框架,这里选取了三款比较火的网络框架,Mirror、FishNet和Netcode for GameObject(NGO)。 比较了最常用的免费网络解决方案。可能还有值得探索的付费选项。您需要对此进行自己的研究。数据表格更新日志截止到&#…

【C++】深度剖析 scanf 函数:原理、应用与优化

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 💯前言💯scanf 函数的基本原理💯基本用法示例常见占位符示例 💯使用 scanf 时的注意事项💯引入 cstdio 头文件💯scanf 与 cin 的对比…

YOLOv1 (You Only Look Once)

YOLO (You Only Look Once) 是一种经典的目标检测算法,旨在通过一个统一的卷积神经网络(CNN)进行目标检测,最大化检测速度并保持较高的精度。YOLO 在目标检测领域产生了巨大的影响,并且经过了多个版本的迭代。下面是 Y…

【Verilog】实验二 数据选择器的设计与vivado集成开发环境

目录 一、实验目的 二、实验环境 三、实验任务 四、实验原理 五、实验步骤 top.v mux2_1.v 一、实验目的 1. 掌握数据选择器的工作原理和逻辑功能。 2. 熟悉vivado集成开发环境。 3. 熟悉vivado中进行开发设计的流程。 二、实验环境 1. 装有vivado的计算机。 2. Sw…

【CSS in Depth 2 精译_063】10.2 深入理解 CSS 容器查询中的容器

当前内容所在位置(可进入专栏查看其他译好的章节内容) 【第十章 CSS 容器查询】 ✔️ 10.1 容器查询的一个简单示例 10.1.1 容器尺寸查询的用法 10.2 深入理解容器 ✔️ 10.2.1 容器的类型 ✔️10.2.2 容器的名称 ✔️10.2.3 容器与模块化 CSS ✔️ 10.3…

今天我们来聊聊Maven中两个高级的概念—— 插件和目标

插件&#xff08;plugin&#xff09; Maven的核心是一个插件执行框架;所有的工作都是由插件完成的。 Maven中Plugin分为两种类型&#xff1a; build类型Plugin只能在build阶段执行&#xff0c;在POM中需要在 <build/> 标签下进行配置。 reporting类型&#xff1a;在si…

【触想智能】自动售票机选择工控一体机配套的原因分析

自动售票机是现代公共交通系统中常见的设备之一&#xff0c;它能够方便、快速地为乘客提供票务服务。为了实现高效、可靠的运营&#xff0c;许多自动售票机都采用工控一体机作为核心控制硬件。 触想工控一体机TPC-W200系列 下面&#xff0c;触想智能小编为大家分析为什么自动售…

[计算机网络] HTTP/HTTPS

一. HTTP/HTTPS简介 1.1 HTTP HTTP&#xff08;超文本传输协议&#xff0c;Hypertext Transfer Protocol&#xff09;是一种用于从网络传输超文本到本地浏览器的传输协议。它定义了客户端与服务器之间请求和响应的格式。HTTP 工作在 TCP/IP 模型之上&#xff0c;通常使用端口 …

element-ui的下拉框报错:Cannot read properties of null (reading ‘disabled‘)

在使用element下拉框时&#xff0c;下拉框option必须点击输入框才关闭&#xff0c;点击其他地方报错&#xff1a;Cannot read properties of null (reading disabled) 造成报错原因&#xff1a;项目中使用了el-dropdown组件&#xff0c;但是在el-dropdown里面没有定义el-dropdo…

新一代零样本无训练目标检测

&#x1f3e1;作者主页&#xff1a;点击&#xff01; &#x1f916;编程探索专栏&#xff1a;点击&#xff01; ⏰️创作时间&#xff1a;2024年12月2日21点02分 神秘男子影, 秘而不宣藏。 泣意深不见, 男子自持重, 子夜独自沉。 论文链接 点击开启你的论文编程之旅h…

30.100ASK_T113-PRO 用QT编写视频播放器(一)

1.再buildroot中添加视频解码库 X264, 执行 make menuconfig Target packages -->Libraries --> Multimedia --> X264 CLI 还需要添加 FFmpeg 2. 保存,重新编译 make all 3.将镜像下载开发板

Python办公自动化,批量生成Excel案例数据集

在数据分析的世界里&#xff0c;数据是核心&#xff0c;而如何高效地生成和处理数据则成为每位数据分析师必备的技能之一。今天&#xff0c;我们要探讨一个有趣的话题——“造数”。 但这里的“造数”并非意味着编造数据&#xff0c;而是指在确保数据安全的前提下&#xff0c;…

在线绘制Nature Communication同款双色、四色火山图,突出感兴趣的基因

导读&#xff1a;火山图通常使用三种颜色分别表示显著上调&#xff0c;显著下调和不显著。通过为特定的数据点添加另一种颜色&#xff0c;可以创建双色或四色火山图&#xff0c;从而更直观地突出感兴趣的数据点。 《Nature Communication》文章“Molecular and functional land…

【大语言模型】ACL2024论文-23 检索增强的多语言知识编辑

【大语言模型】ACL2024论文-23 检索增强的多语言知识编辑 目录 文章目录 【大语言模型】ACL2024论文-23 检索增强的多语言知识编辑目录摘要研究背景问题与挑战如何解决核心创新点算法模型实验效果&#xff08;包含重要数据与结论&#xff09;相关工作后续优化方向 后记 检索增强…

ETSI EN 300328 标准的一些笔记

ETSI - European Telecommunications Standards Institute 欧洲电信标准化协会 ETSI EN 300328 是欧洲协调标准&#xff0c;此标准适用于工作在2.4G频段范围内运行的宽频传输系统和设备的无线电频谱。 例如 WIFI、Zigbee、蓝牙、 (国内的星闪)。不涵盖UWB。 符合了EN 300328标…

VSCode:代码格式化插件

settings.json文件中添加如下配置并保存 {"workbench.sideBar.location": "left","cssrem.rootFontSize": 80,"git.ignoreWindowsGit27Warning": true,"eslint.codeAction.showDocumentation": {"enable": true…

Redis实现限量优惠券的秒杀

核心&#xff1a;避免超卖问题&#xff0c;保证一人一单 业务逻辑 代码步骤分析 全部代码 Service public class VoucherOrderServiceImpl extends ServiceImpl<VoucherOrderMapper, VoucherOrder> implements IVoucherOrderService {Resourceprivate ISeckillVoucher…

Github提交Pull Request教程 Git基础扫盲(零基础易懂)

1 PR是什么&#xff1f; PR&#xff0c;全称Pull Request&#xff08;拉取请求&#xff09;&#xff0c;是一种非常重要的协作机制&#xff0c;它是 Git 和 GitHub 等代码托管平台中常见的功能&#xff0c;被广泛用于参与社区贡献&#xff0c;从而促进项目的发展。 PR的整个过…