Deep Networks with Stochastic Depth - 动态随机网络

news2024/12/23 13:31:49

文章目录

  • 基本结构
    • ResNet的公式改造
    • 效果计算
    • 前向传播过程
  • 实验结果
    • CIFAR数据集结果
    • SVHN数据集结果
    • 训练时间的比对
    • 极深网络的对比测试
    • ImageNet的测试结果
      • 测试过程中的结果
    • 网络结构的Hyper-parameter比对测试

前面两篇是讲经典网络ResNet的:

ResNet1

ResNet2

这个残差网络结构非常的经典,所以有不少后续的研究都是基于这个网络的改进,这一篇就来说一说其中的一个经典改进:随机动态网络——Stochastic Depth。
论文:Deep Networks with Stochastic Depth
论文提供了基于lua的代码:https://github.com/yueatsprograms/Stochastic_Depth

基本结构

论文的基本出发点还是想解决如下几个深度网络中的关键问题:

  • 梯度消失问题
  • 训练时间过长的问题
  • Diminishing feature reuse,特征重用。其实我理解和梯度消失很类似,就是在前向和反向传播的过程中,信息流(information flow)或者是梯度(反向)在多层网络中传播的时候,经过多个卷积层相乘之后逐渐消失的现象。

解决这个问题的基本逻辑也很简单,就是随机的让ResNet中的一些层失效,也就是在ResNet中,让残差网络的部分失效,直接通过Skip Connection传播(训练过程)。所以,失效的部分以Residual block为单位,也就是一个block失效或者生效。
论文原文:
We start with very deep networks but during training, for each mini-batch, randomly drop a subset of layers and bypass them with the identity function。

还有一段是:

We resolve this conflict by creating deep Residual Network architectures (with hundreds or even thousands of layers) with sufficient modeling
capacity; however, during training we shorten the network significantly by randomly removing a substantial fraction of layers independently for each sample or mini-batch。

和Dropout层的逻辑很像,只是Dropout的逻辑是让反向传播中的一些权重的梯度随机的变成0。而Stochastic Depth方法是让整个残差结构失效,与dropout不同的点还在于:

  • Stochastic Depth可以调整整个网络的深度和宽度。
  • dropout没法和BN层一起使用,或者说合并在一起使用没什么用。
  • Stochastic Depth可以模拟多个层次的网络,相当于可以综合不同深度的ResNet结果。
    原文中的描述为:One explanation for our performance improvements is that training with stochastic depth can be viewed as training an ensemble of ResNets implicitly. Each of the L layers is either active or inactive, resulting in 2 L 2^L 2L possible network combinations. For each training mini-batch one of the 2 L 2^L 2L networks (with shared weights) is sampled and updated.

ResNet的公式改造

论文是基于ResNet的结构来进行改造的。

  • 基本Residual block结构为:Conv-BN-ReLU-Conv-BN。
  • CIFAR的数据集是使用这个结构,ImageNet使用的是称作Bottlenect的residual block。

论文中的重点就是如何确定哪些block失效,哪些不失效。其实逻辑非常简单:

  • 论文中是训练过程中引入了一个Bernoulli random随机概率函数,计作 b l ∈ { 0 , 1 } b_l \in \lbrace 0, 1\rbrace bl{0,1}

  • 每个block的残差部分乘以这个概率,也就是根据这个概率来确定是否生效,论文中称作“survival probability”,计作 p l = P r ( b l = 1 ) p_l = Pr(b_l=1) pl=Pr(bl=1)

  • 基于上述的逻辑,那么ResNet中的那个经典公式就可以改造为:
    H l = R e L U ( b l f l ( H l − 1 ) + i d ( H l − 1 ) ) H_l=ReLU(b_lf_l(H_{l-1})+id(H_{l-1})) Hl=ReLU(blfl(Hl1)+id(Hl1))

  • 这个 p l p_l pl在论文中提出了两种分布函数:

    • p l p_l pl为固定值,每个block的概率是一样的。
    • p l p_l pl为一个线性分布,服从函数:
      p l = 1 − l L ( 1 − p L ) p_l=1 - \frac{l}{L}(1-p_L) pl=1Ll(1pL)
      论文中给出了一个图:

大概的意思就是越到后面的层就越容易失效。

根据上面的逻辑, p L p_L pL就称为了一个超参数,并且论文中通过实验证明, p L = 0.5 p_L=0.5 pL=0.5时效果最好。

效果计算

关于论文要解决的其中一个问题,就是减少训练时间或者说计算量,论文也给出了一个估算方式:
在训练过程中的有效block数可以通过一个随机变量 L ^ \hat L L^来表示(因为其中增加了随机变量 p l p_l pl),这个随机变量的期望就可以表示为:
E ( L ^ ) = ∑ l = 1 L p l E(\hat L)=\sum_{l=1}^Lp_l E(L^)=l=1Lpl

根据上面的 p l p_l pl的线性分布,另外加上 p L = 0.5 p_L=0.5 pL=0.5的话,这个期望计算出来就是:
E ( L ^ ) = ( 3 L − 1 ) / 4 ≈ 3 L / 4 E(\hat L)=(3L-1)/4 \approx 3L/4 E(L^)=(3L1)/43L/4
后面的约等于是在L较大,也就是网络层数较深的情况下成立。

所以,在ResNet-110的网络结构下,总共有L=54(L为block数),那么在训练过程中有效的block数就是 54 ∗ 3 / 4 ≈ 40 54 * 3 /4 \approx 40 543/440。也就是相对于原网络,减少了14个block(约1/4),在训练时的运算量或者说时间。

前向传播过程

因为在反向传播时,一些卷积层中的权重有时是失效的,学到的参数不那么完全,所以在前向传播,也就是推理过程中,也要增加这个随机变量 p l p_l pl
H l T e s t = R e L U ( p l f l ( H l − 1 T e s t , W l ) + H l − 1 T e s t ) H_l^{Test} = ReLU(p_lf_l(H_{l-1}^{Test}, W_l) + H_{l-1}^{Test}) HlTest=ReLU(plfl(Hl1Test,Wl)+Hl1Test)

原文对此的描述是:From the model ensemble perspective, the update rule (5) can be interpreted
as combining all possible networks into a single test architecture, in which each
layer is weighted by its survival probability

实验结果

  • 论文中是使用的ResNet为基本网络。
  • 对于网络中的第一个Res block,也就是 p 0 = 1 p_0=1 p0=1,保证第一个block总是有效。
  • 使用的基本都是线性的概率分布,令 p L = 0.5 p_L=0.5 pL=0.5
  • 在数据集CIFAR-10,CIFAR-100使用的是110层的ResNet(ResNet的第一篇论文中提到过,不过我没有写到我的前一篇文章中,有兴趣的朋友可以去看一下原文,总共3组Block,共18个res blocks,filter number分别为16,32,64)。

总共是在CIFAR-10,CIFAR-100,SVHN,ImageNet四个数据集上对一些网络做了对比,结果如下:

test error result

CIFAR数据集结果

在test error和trainning loss上,把普通的resnet和增加了Stochastic depth的网络做了比较

  • test error有更好的效果
  • trainning loss没有下降的那么快,也就是说对保持梯度有正向的效果,对避免梯度消失问题也有正向的效果。
  • 在两个数据集上有类似的效果

SVHN数据集结果

SVHN数据集是Street View House Number,来自于Google街景。
对于SVHN使用的是ResNet-152的网络结构。

  • 和CIFAR类似的结论,只是没那么明显
  • 与普通的ResNet网络相比,Stochastic depth还不会出现过拟合的现象。普通的ResNet网络在第30个epoch往后就出现了过拟合的现象。

训练时间的比对


这个不用多说,训练时间减少了1/4,和前面计算的差不多。

极深网络的对比测试

在ResNet中,作者提出了一个1202的网络,但是表现不佳,出现了过拟合的现象,还没有101层的网络表现好。在这篇论文中,作者把Stochastic depth结构增加到了这个极深网络中,解决了这个过拟合的问题。

  • 300个epoch的训练
  • 前10个epoch的学习率为0.01,后面恢复成0.1

  • 普通的resnet,深度从101到1202的话,test error从6.41%增加到了6.67%。
  • Stochastic depth网络,深度从101到1202的话,test error从5.25%下降到了4.91%。
  • 也就是说,在极深的网络中使用Stochastic depth是会有帮助的,可以使用非常深的网络来拟合数据了。
  • 网络越深,节省的计算时间也就越多。

ImageNet的测试结果

使用的是ResNet-152的网络深度,而且使用的是bottleneck的残差结构。
论文中的实验在ImageNet上没有前面几个数据集一样好的结果。

测试过程中的结果

  • 在大概epoch=90的时候,Stochastic depth网络的错误率还是要高出一点点。
  • 但是论文指出说,我既然节省了约1/4的时间,那么我多训练30个epoch,到达120个epoch,错误率就可以下降到21.98%。
  • 然后再多训练30个epoch,也就是训练150个epoch,错误率就可以下降到21.78%。
  • 从上面的结论来看,作者的意思是对于ImageNet这种大数据集,必须继续增加深度来获得更好的结果:Although there seems to be no immediate benefit from applying stochastic depth on this particular architecture, it is possible that stochastic depth will lead to improvements on ImageNet with larger models。

网络结构的Hyper-parameter比对测试

论文中还对网络结构中的Hyper-parameter,也就是 p L p_L pL进行了一些比对测试。这个测试是基于ResNet-110的。
比对了使用固定 p L p_L pL和线性 p L p_L pL的结果:

  • p L p_L pL设置的合适的话,比基本的ResNet结果要好。
  • linear decay rule比uniform要好。
  • 最佳区域在0.4-0.8之间
  • 0.5是最好的,所以上面的测试网络都是用的0.5这个值。

论文中还提供了一个热点图:

这个图提供的是网络深度和超参数 p L p_L pL之间的关系。证明足够深的网络使用stochastic是有效的,即使深度不够,也会有一定的帮助。
原文为:A deep enough model is necessary for stochastic depth to significantly outperform the baseline (an observation we also make with the ImageNet data
set), although shorter networks can still benefit from less aggressive skipping。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/847223.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中介者模式(Mediator)

中介者模式是一种行为设计模式,可以减少对象之间混乱无序的依赖关系。该模式会限制对象之间的直接交互,迫使它们通过一个封装了对象间交互行为的中介者对象来进行合作,从而使对象间耦合松散,并可独立地改变它们之间的交互。中介者…

【设计模式——学习笔记】23种设计模式——中介者模式Observer(原理讲解+应用场景介绍+案例介绍+Java代码实现)

文章目录 案例引入案例一普通实现中介者模式 案例二 介绍基础介绍登场角色尚硅谷 《图解设计模式》 案例实现案例一:智能家庭类图实现 案例二:登录页面逻辑实现说明类图实现 总结文章说明 案例引入 案例一 普通实现 在租房过程中,客户可能…

回看雷军演讲,我对项目经理人的发展又有了2点想法……

大家好,我是老原。 最近在回看一些大佬演讲,看到雷军的演讲时,说实话我蛮激动的,最近常有粉丝朋友找我闲聊,也工作十几年,不知道怎么才能寻求突破? 而小米本身,不就是一直在突破吗…

C高级-day4

#!/bin/bash function fun1(){arr[0]id -u $1arr[1]id -g $1echo ${arr[*]} }arr(fun1 ubuntu) echo ${arr[*]}冒泡排序 void Maopao(int arr[],int len){for(int i1;i<len;i){int count0;for(int j0;j<len-i;j){if(arr[j]>arr[j1]){int tarr[j];arr[j]arr[j1];arr[j…

嵌入式Linux下LVGL的移植与配置

一.sdk源码下载路径 1.官方源码下载路径如下: ​​​​​​ https://github.com/lvgl/lvgl git下载方式 git clone https://github.com/lvgl/lvgl.git 2.个人移植好的源码8.2版本下载路径: 链接:https://pan.baidu.com/s/1jyqIennsQpv-RB4RyKvZyg?pwd=c68e 提取码:c68e…

《HeadFirst设计模式(第二版)》第七章代码——外观模式

代码文件目录&#xff1a; Subsystem: Amplifier package Chapter7_AdapterAndFacadePattern.FacadePattern.Subsystem;/*** Author 竹心* Date 2023/8/8**///扬声器 public class Amplifier {int volume 0;//音量public void on(){System.out.println("The amplifier …

redis 的副本和分片

什么是分片 分片也叫条带&#xff0c;指Redis集群的一个管理组&#xff0c;对应一个redis-server进程。一个Redis集群由若干条带组成&#xff0c;每个条带负责若干个slot&#xff08;槽&#xff09;&#xff0c;数据分布式存储在slot中。Redis集群通过条带化分区&#xff0c;实…

数据库优化:探索 SQL 中的索引

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建可编辑的3D应用场景 在一本书中搜索特定主题时&#xff0c;我们将首先访问索引页面&#xff08;该页面位于该书的开头&#xff09;&#xff0c;并找到包含我们感兴趣的主题的页码。现在&#xff0c;想象一下在没有索引页的书中…

python_面向对象基础_数据分析

主要目的 对于文本格式和JSON格式数据进行分析&#xff0c;将其中数据提炼出来绘制折线图。 主要实现步骤 1.设计一个完成对数据的封装 2.设计一个抽象类,定义数据读取相关功能,使用其子类实现具体功能 3.读取文件,生成数据对象 4.进行数据计算 5.绘制图表 定义数据封装类 &…

Annotorious.js 入门教程:图片注释工具

theme: smartblue 本文简介 【今天我必须发一个封面&#xff01;放文末&#xff01;】 最近有工友问我前端怎么给图片做标注。使用 Fabric.js 或者 Konva.js 等库确实可以实现&#xff0c;但多少觉得有点大炮打蚊的感觉&#xff0c;好奇有没有专门做图片标注的工具呢&#xff1…

剑指offer(C++)-JZ16:数值的整数次方(算法-位运算)

作者&#xff1a;翟天保Steven 版权声明&#xff1a;著作权归作者所有&#xff0c;商业转载请联系作者获得授权&#xff0c;非商业转载请注明出处 题目描述&#xff1a; 实现函数 double Power(double base, int exponent)&#xff0c;求base的exponent次方。 注意&#xff1…

stm32项目(5)——基于stm32的工地噪声扬尘检测系统

目录 1.功能设计 2.硬件方案 1.单片机选择 2.声音传感器 3.PM2.5传感器 4.显示器 3.程序设计 4.课题意义 1.功能设计 本次系统实现的功能如下所示&#xff1a; 采用声音传感器检测环境噪音&#xff0c;采用PM2.5传感器检测环境灰尘浓度。若噪声超过阈值或者PM2.5超过阈…

ROS Navigation Stack安装

Navigation导航包是做导航几乎都要用的&#xff0c;大家可以先去ROS Wiki上学习下 我们先Git下对应版本的软件包&#xff0c;我是Kinetic的&#xff0c;所以是Kinetic-devel 下载后发现目录下并没有CMakeLists.txt&#xff0c;所以直接在ROS工作目录下catkin_make并不会产生可…

【FPGA】UART串口通信——奇偶校验实现

文章目录 一、奇偶校验位二、设计思路三、仿真测试 一、奇偶校验位 奇偶校验位是基于uart的数据上进行一个判断 奇校验&#xff1a;数据1个数为奇时&#xff0c;校验为0&#xff0c;反之为1 偶校验&#xff1a;数据0个数为偶时&#xff0c;校验为0&#xff0c;反之为1 Uart…

MySQL 事务原理:事务概述、隔离级别、MVCC

文章目录 一、事务1.1 事务概述1.2 事务控制语句1.3 ACID特性 二、隔离级别2.1 隔离级别的分类2.1.1 读未提交&#xff08;RU&#xff09;2.1.2 读已提交&#xff08;RC&#xff09;2.1.3 可重复读&#xff08;RR&#xff09;2.1.4 串行化 2.2 命令2.3 并发读异常2.3.1 脏读2.3…

Babylon.js着色器简明简称【Shader】

推荐&#xff1a;用 NSDT设计器 快速搭建可编程3D场景 为了生成 BabylonJS 场景&#xff0c;需要用 Javascript 编写代码&#xff0c;BabylonJS 引擎会处理该代码并将结果显示在屏幕上。 场景可以通过改变网格、灯光或摄像机位置来改变。 为了及时显示可能的变化&#xff0c;屏…

借助gopsutil库,获取机器相关信息

使用github.com/shirou/gopsutil/disk这个库&#xff0c;如何获取机器下不同磁盘分区的内容 使用 github.com/shirou/gopsutil/disk 库获取机器下不同磁盘分区的内容&#xff0c;可按如下&#xff1a; import "github.com/shirou/gopsutil/disk"//调用 disk.Partitio…

【瑞吉外卖】Git部分学习

Git简介 Git是一个分布式版本控制工具&#xff0c;通常用来对软件开发过程中的源代码文件进行管理。通过Git仓库来存储和管理这些文件&#xff0c;Git仓库分为两种&#xff1a; 本地仓库&#xff1a;开发人员自己电脑上的Git仓库 远程仓库&#xff1a;远程服务器上的Git仓库…

git原理与使用

目录 引入基本操作分支管理远程操作标签管理 引入 假设你的老板要你设计一个文档&#xff0c;当你设计好了&#xff0c;拿给他看时&#xff0c;他并不是很满意&#xff0c;就要你拿回去修改&#xff0c;你修改完后&#xff0c;再给他看时&#xff0c;他还是不满意&#xff0c;…

ERP、APS、MES 三者之间的关系

ERP&#xff08;Enterprise Resource Planning&#xff09; APS&#xff08;Advanced Planning and Scheduling&#xff09; MES&#xff08;Manufacturing Execution System&#xff09; 这是三种不同类型的软件系统&#xff0c;它们主要用于企业内部管理和自动化运营流程。…