秃姐学AI系列之:残差网络 ResNet

news2025/1/22 21:41:02

目录

残差网络——ResNet

残差块思想

ResNet块细节

ResNet架构

总结

代码实现

残差块

两种 ResNet 块的情况 

ResNet 模型

QA


由上图发现,只有当较复杂的函数类包含较小的函数类时,才能确保提高它们的性能。

对于深度神经网络,如果我们能将新添加的层训练成恒等映射(identity function)f(x)=x,新模型和原模型将同样有效。 同时,由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。

针对这一问题,何恺明等人提出了残差网络(ResNet) 。它在2015年的 ImageNet 图像识别挑战赛夺魁,并深刻影响了后来的深度神经网络的设计。

残差网络——ResNet

残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。

于是,残差块(residual blocks)便诞生了,这个设计对如何建立深层神经网络产生了深远的影响。 凭借它,ResNet赢得了2015年ImageNet大规模视觉识别挑战赛。

残差块思想

残差块加入快速通道来得到f(x) = x + g(x) 的结构

如下图所示,假设我们的原始输入为x,而希望学出的理想映射为 f(x)(作为上方激活函数的输入)。

左图虚线框中的部分需要直接拟合出该映射 f(x),而右图虚线框中的部分则需要拟合出残差映射 f(x)−x。残差映射在现实中往往更容易优化

开头提到的恒等映射作为我们希望学出的理想映射 f(x),我们只需将右图虚线框内上方的加权运算(如仿射)的权重和偏置参数设成 0,那么 f(x) 即为恒等映射。 实际中,当理想映射 f(x) 极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。

右图是 ResNet 的基础架构–残差块(residual block)。 在残差块中,输入可通过跨层数据线路更快地向前传播。

相当于ResNet觉得,你就算虚线框里面所有层都没学到东西,下一层还是可以接收到这层的上一层传递来的东西(残差连接)即一个简单的直接传递的小模型。这个想法从函数的角度来说,可以认为更大、更复杂的模型里面包含一个小模型。

ResNet块细节

ResNet 是从 VGG 过来的,所以采用的是 3x3Conv

以下是 ResNet 块的两个不同的实现

右边存在的意义是:如果虚线的block对通道做了变换,那直接的X加不回去了,所以需要用卷积来对x做一个通道数的变换用于相加 。

ResNet架构

ResNet 最主要的思想就是单拎出来一条路让你可以把输入和输出加起来

抛开这个其他的你可以认为和 VGG 以及GoogLeNet 很像,也是由5个Stage拼成,只是把组合成网络的 Stage 替换成了 ResNet 块

  • 一个高宽减半的 ResNet 块(步幅为2)(那个支线上有Conv的Block,用来把输入的通道数翻一倍)
  • 重复多个高宽不变的 ResNet 块 

总结

  • 残差块使得很深的网络更加容易训练

    • 甚至可以训练以前层的网络

  • 残差网络对随后的深层神经网络设计产生了深远的影响,无论是卷积类网络还是全连接类网络 

  • 学习嵌套函数(nested function)是训练神经网络的理想情况。在深层神经网络中,学习另一层作为恒等映射(identity function)较容易(尽管这是一个极端情况)。

  • 残差映射可以更容易地学习同一函数,例如将权重层中的参数近似为零。

  • 利用残差块(residual blocks)可以训练出一个有效的深层神经网络:输入可以通过层间的残余连接更快地向前传播。

代码实现

残差块

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


class Residual(nn.Module):  #@save
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

两种 ResNet 块的情况 

输入和输出形状一致

blk = Residual(3,3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape

# 输出
torch.Size([4, 3, 6, 6])

增加输出通道数的同时,减半输出的高和宽

blk = Residual(3, 6, use_1x1conv=True, strides= 2)
blk(X).shape

# 输出
torch.Size([4, 6, 3, 3])

ResNet 模型

ResNet 的前两层跟之前介绍的 GoogLeNet 中的一样:

在输出通道数为 64、步幅为 2 的 7×7 卷积层后,

接步幅为 2 的 3×3 的最大汇聚层。

不同之处在于 ResNet 每个卷积层后增加了批量规范化层。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

GoogLeNet 在后面接了 4 个由 Inception 块组成的模块。ResNet 则使用 4 个由残差块组成的模块,

每个模块使用若干个同样输出通道数的残差块。

第一个模块的通道数同输入通道数一致。由于之前已经使用了步幅为 2 的最大汇聚层,所以无须减小高和宽。

之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))

老规矩,不同模块的数据形状变化

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

QA

  • 残差概念体现在哪里?

可以理解成,因为 f(x) 是由 x 和 g(x) 相加得来的,x 又是由上一层网络训练的来的,可以被视为一个小网络的输出。所以整个 ResNet 就是先训练小网络,然后小网络 fit 不到的(小的差距)再由上面的层去补充。这就是残差(残留的差距)的概念。

  •  为什么 BN 需要定义两个,而 ReLU 不需要?

BN 是两个独立的层,每个层都有自己需要学的不同的参数,而 ReLU 没有什么学习性,所以公用一个层就可以

  • 训练 ACC 是不是在不 overfitting 的情况下,永远大于测试 ACC?

不一定哦,后面会看到当你做了大量的数据噪音的时候,测试精度会高于训练精度,因为你测试的时候不会添加噪声。

  • 为什么 ResNet 可以训练 100 层网络? 

 假设 g(x) 是在 f(x) 之外新加的一个层,那对于梯度的计算公式根据链式求导法展开,多出来的第一项就是新套的那层的输入和输出求导。假设加的这个层的拟合能力比较强,这一项会很快的变得特别小。一个很小的值乘我们之前那一层的梯度,梯度就会变得比之前小很多。梯度变小之后可以选择增大学习率,但是很有可能增大学习率也没啥用。因为也不能增的太大,f 这一层比 g 更靠近数据,如果增加太大那 g 这一层会变得不稳定。这就是为什么之前模型变深之后会出现梯度消失的问题。

主要原因就是层数叠加,梯度是一直做乘法。回传的时候就会出现底部的梯度特别小。

而 ResNet 是怎么解决这个问题的呢?

因为 ResNet 的网络设计使得它的梯度计算是相加的,哪怕有哪一块比较小也没关系,哪怕当g(x)不存在的时候去拟合,也有 f(x) 的梯度存在。

大数 + 小数没问题,但是 大数*小数问题很大! 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2080221.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QGIS制图流程

在之前我们推送了QGIS的软件安装、插件安装、数据导入等基础操作,今天我们介绍一下QGIS的制图功能。QGIS的制图与ArcGIS Pro存在一定的区别,但是思路上相似。我们教程内容主要是参考QGIS官方文档: https://docs.qgis.org/3.34/en/docs/user_…

Android 中ebpf 的集成和调试

1. BPF 简介 BPF,是Berkeley Packet Filter的简称,最初构想提出于 1992 年。是一种网络分流器和数据包过滤器,允许早操作系统级别捕获和过滤计算机网络数据包。它为数据链路层提供了一个原始接口,允许发送和接收原始链路层数据包…

安卓中回调函数的使用

在Android开发中,回调函数是一种常见的编程模式,用于在某个任务完成时异步接收通知或数据。它们通常用于处理用户界面事件、完成网络请求、数据库操作或其他长时间运行的任务。回调(Callback)是一种允许某段代码通知另一段代码执行…

机器人学——机械臂轨迹规划-2

直线轨迹 线段转折点速度不连续 加速度状态讨论 double dot 多段直线轨迹,转折点利用二次方程转为圆弧 关键步骤 第一个线段处理 Vt V0 at , 此处的V0 0 , 利用函数连续性,左右速度相等,联立求解 sgn(x):符号函数 最后一个线段…

dubbo:dubbo服务负载均衡、集群容错、服务降级、服务直连配置详解(五)

文章目录 0. 引言1. dubbo负载均衡1.1 负载均衡算法1.2. dubbo负载均衡使用1.3 自定义负载均衡策略 2. dubbo服务容错2.1 8种服务容错策略2.2 自定义容错策略 3. dubbo服务降级(mock)4. dubbo服务直连5. 总结 0. 引言 之前我们讲解了dubbo的基本使用&am…

内部类java

内部类就是定义在一个类里面的类,里面的类可以理解成(寄生),外部类可以理解成(寄主)。 //外部类 public class people{//内部类public class heart{} } 内部类的使用场景、作用 1.当一个事物的内部&…

STM32学习笔记3---ADC,DMA

目录 ADC模拟数字转换器 规则组的四种转换模式 AD单通道 AD多通道 常用代码函数相关 DMA直接存储器 存取(访问) 两个应用 DMA存储器到存储器的转运 ADCDMA ADC模拟数字转换器 stm32数字电路,只有高低电平,无几V电压的概念…

MySQL:常用函数

MySQL:常用函数 日期时间函数字符串函数数学函数加密函数 在MySQL中,存在许多现成的函数,可以简化部分操作,本博客讲解MySQL中的常用函数。 日期时间函数 current_date current_date函数用于输出当前的日期: curren…

一道关于php文件包含的CTF题

一、源码 这是index.php的页面。 点击login后会发现url里多了action的参数&#xff0c;那么我们就可以通过它来获取源码。 ?actionphp://filter/readconvert.base64-encode/resourcelogin.php 再通过base64的解码可以查看源码。 index.php源码&#xff1a; <?php erro…

【编码解码】CyberChef v10.18.9

下载地址 【编码解码神器】CyberChef v10.18.9 在线地址 CyberChef (gchq.github.io) 简介 CyberChef 是一个简单易用的网页应用&#xff0c;&#xff0c;包含了四百多种在线编解码工具。它在浏览器中执行各种“网络安全”操作。这些操作包括简单的 XOR 和 Base64 编码、复…

基于单片机的无线空气质量检测系统设计

本设计以STC89C52单片机为核心&#xff0c;其中包含了温湿度检测模块、光照检测模块、PM2.5检测模块、报警电路、LCD显示屏显示电路、按键输入模块和无线传输模块来完成工作。首先&#xff0c;系统可以通过按键输入模块设置当前的时间和报警值&#xff1b;使用检测模块检测当前…

spring boot(学习笔记第十九课)

spring boot(学习笔记第十九课) Spring boot的batch框架&#xff0c;以及Swagger3(OpenAPI)整合 学习内容&#xff1a; Spring boot的batch框架Spring boot的Swagger3&#xff08;OpenAPI&#xff09;整合 1. Spring boot batch框架 Spring Batch是什么 Spring Batch 是一个…

个人网站免费上线

声明一下&#xff0c;小科用的是natapp&#xff0c;进行的 1.起步-下载安装 去浏览器搜索" natapp "&#xff0c;在官网下载&#xff0c;或者直接 点击下列网站 NATAPP-内网穿透 基于ngrok的国内高速内网映射工具https://natapp.cn/ 打开后下滑找到下载&#xff…

JMeter Plugins之内网插件问题解决

JMeter Plugins之内网插件问题解决 背景 在我司内部进行JMeter工具进行性能脚本开发时&#xff0c;为了提高测试效率&#xff0c;我们会用到部分JMeter提供的插件&#xff0c;但是在我司内网的情况下&#xff0c;我们如果直接点击JMeter界面右上角的插件按钮 弹出来的JMeter…

洛谷刷题(4)

P1089 [NOIP2004 提高组] 津津的储蓄计划 题目描述 津津的零花钱一直都是自己管理。每个月的月初妈妈给津津 300 元钱&#xff0c;津津会预算这个月的花销&#xff0c;并且总能做到实际花销和预算的相同。 为了让津津学习如何储蓄&#xff0c;妈妈提出&#xff0c;津津可以随…

零基础5分钟上手亚马逊云科技 - AI模型内容安全过滤

在上一篇文章中&#xff0c;小李哥带大家深入调研亚马逊云科技AI模型平台Amazon Bedrock热门开发功能&#xff0c;了解了模型平台的文字/图片生成、模型表现评估和模型内容安全审核的实践操作。这次我们将继续介绍如何利用API的形式&#xff0c;利用Python代码的形式对AI模型内…

OpenSearch的快照还原

本次测试选择把索引快照备份到Amazon S3&#xff0c;所以需要使用S3 repository plugin&#xff0c;这个插件添加了对使用 Amazon S3 作为快照/恢复存储库的支持。 OpenSearch集群自带了这个插件&#xff0c;所以无需额外安装。 由于需要和Amazon Web Services打交道&#xf…

工厂数字化转型中工业一体机起到什么作用?

近年来工厂数字化转型成为企业提升竞争力的关键路径。而在这场转型浪潮中&#xff0c;工业一体机扮演着至关重要的角色&#xff0c;它不仅是推动工厂数字化转型的关键工具&#xff0c;更是赋能企业实现更高效、智能、灵活生产的关键要素。 一、工业一体机&#xff1a;连接物理与…

CAN通信之波特率相关配置

由于 CAN 属于异步通讯&#xff0c;没有时钟信号线&#xff0c;连接在同一个总线网络中的各个节点会像串口异步通讯那样&#xff0c;节点间使用约定好的波特率进行通讯。 首先我们要明确几个概念&#xff1a; 波特率&#xff1a;can 1s传输的位数&#xff0c;其单位为bps。 T…

Vue3学习笔记之插槽

目录 前言 一、基础 (一) 默认插槽 (二) 具名插槽 (三) 作用域插槽 (四) 动态插槽 二、实战案例 前言 插槽&#xff08;Slots&#xff09;&#xff1f; 插槽可以实现父组件自定义内容传递给子组件展示&#xff0c;相当于一块画板&#xff0c;画板就是我们的子组件&…