GoogleLeNet V2 V3 —— Batch Normalization

news2024/9/28 21:23:50

文章目录

  • Batch Normalization
    • internal covariate shift
    • 激活层的作用
    • BN执行的位置
    • 数据白化
    • 网络中的BN层
    • 训练过程
  • BN的实验效果
    • MNIST
    • 与GoogleLeNet V1比较

GoogleLeNet出来之后,Google在这个基础上又演进了几个版本,一般来说是说有4个版本,之前的那个是V1,然后有一个V2,V3和V4。
其实我个人感觉V2和V3应该是在一起的,都是综合了两篇论文中的一些改进点来的:

  • Accelerating deep network training by reducing internal covariate shift
  • Rethinking the Inception Architecture for Computer Vision

其中,第一篇是提出了一个重要的概念:Batch Normalization,是针对内部协变量偏移问题的,简单的说就是加速训练过程。把BN作为激活层之前的另外一个网络层,可以加速网络训练的收敛速度。
第二篇就提出了一些新的卷积方法等,然后总和第一篇论文一起就提出了一个inception v2的网络结构,没有明确提到v3,但是其中的一些变形作为了v3版本。
我们就来看一下这两篇论文说了点啥,这个v2和v3又改进了点啥。

Batch Normalization

internal covariate shift

讲BN之前,肯定要说说BN到底是解决一个什么问题,在论文中提到的就是internal covariate shift问题,翻译过来是内部协变量偏移。不明觉厉,这个看不太懂是什么东西。
原文中的描述为:
Training Deep Neural Networks is complicated by the fact that the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change.
This slows down the training by requiring lower learning rates and careful parameter initialization, and makes it notoriously hard to train models with saturating nonlinearities. We refer to this phenomenon as internal covariate shift, and address the problem by normalizing layer inputs。
大致意思是在训练的反向传播过程中,每个输入数据的分布情况回发生变化,在计算损失的之后,这一层的输出也会发生变化,从而导致下一层的输入数据分布发生变化。这种情况就叫做内部协变量偏移。
简单点说就是网络的隐藏层数据分布变化很大,容易出现梯度消失和梯度爆炸,导致训练过程很难收敛。一个梯度一下大到天上,一下就等于0,确实很难收敛。
那么BN的基本逻辑就是针对每个训练的batch数据,在每个激活层(Sigmond或者ReLU之类)前增加一个BN层,也就是做一次数据标准化,把上一层的输出线性变化到一个固定的分布内(fixed distribution)

激活层的作用

这里增加一点,就是之前一直没太弄明白激活层的作用。看完这篇论文之后大概了解了。整个网络,中间基本都是卷积和全连接层,不管是卷积还是全连接层,都是针对前一层数据的一种线性变换。也就是前一层数据的一种多项式变化,如果中间没有激活层的话,那么实际上无论增加多少层,都可以简化成一层,因为线性变化是可以叠加的。
举个例子,如果第一层的处理是:
F ( x ) = 2 x + 3 F(x)=2x+3 F(x)=2x+3

第二层的处理是:
H ( x ) = 4 x − 4 H(x)=4x-4 H(x)=4x4
这里的x就是上一层的 F ( x ) F(x) F(x),所以就是
H ( x ) = 4 ( 2 x + 3 ) − 4 = 8 x − 8 H(x)=4(2x+3)-4=8x-8 H(x)=4(2x+3)4=8x8
那么就可以简化成一层。复杂的线性变化也是一样的。但是如果增加了激活层的话,就不一样了,激活层是非线形函数,不满足
f ( x + y ) = f ( x ) + f ( y ) f(x+y)=f(x) + f(y) f(x+y)=f(x)+f(y),所以就不存在上述的变换。
这样就可以增强模型的表达能力(reprensetation power),就是对数据分布的拟合能力。
所以基本上,在网络结构里,每个卷积层后面都会跟一个非线性层(池化或者激活)。

BN执行的位置

论文中的描述是:To Batch-Normalize a network, we specify a subset of activations and insert the BN transform for each of them。
增加在所有的激活层之前。

数据白化

论文中提到:By fixing the distribution of the layer inputs x as the training progresses, we expect to improve the training speed. It has been long known that the network training converges faster if its inputs are whitened, linearly transformed to have zero
means and unit variances, and decorrelated。
这里提到就是利用了LeCun 1998年的论文中提到的,白化(whiten)的输入数据可以加速训练。而这个白化数据就是指数据分布符合均值为0,方差为1。
白化过程为:一个d维的矢量样本( x = ( x ( 1 ) , x ( 2 ) . . . . x ( d ) ) x=(x^{(1)},x^{(2)}....x^{(d)}) x=(x(1),x(2)....x(d)))的白化过程:
x ^ ( k ) = x ( k ) − E [ x ( k ) ] V a r [ x ( k ) ] \hat{x}^{(k)}=\frac{x^{(k)}-E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}} x^(k)=Var[x(k)] x(k)E[x(k)]
把每一维计算完成之后就形成了服从0-1分布的 x ^ \hat{x} x^向量。

网络中的BN层

在白化之后,实际上还需要做一个线性变换:
y ( k ) = γ ( k ) x ^ ( k ) + β ( k ) y^{(k)}=\gamma^{(k)}\hat{x}^{(k)}+\beta^{(k)} y(k)=γ(k)x^(k)+β(k)
至于为什么要增加这么一个动作,论文中是说:
Note that simply normalizing each input of a layer may change what the layer can represent. For instance, normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity.
我理解是直接标准化会降低网络的表达能力,可能是直接强行拉到一个0-1的分布,会造成一些损失吧。所以可以做一些拉伸和偏移(正态分布的那个图做一些拉伸和偏移),然后在学习的过程中去动态的调整这两个参数 γ \gamma γ β \beta β。也就是学习到底是拉伸多少,偏移多少能更好的拟合数据。

上面的数据白话相当于是把一个样本作了标准化,然后需要把一个训练batch的数据一起做标准化。
论文中是说:since we use mini-batches in stochastic gradient training, of the mean and variance each mini-batch produces estimates of each activation。
也就是说为每个批次也要做一个normalization。

计算方法为:

从上图可以看出来,前面三步就是对数据作了一次白化(类似),只是方向上需要理解一下。
比如输入为一张图像,图像为 p ∗ q p * q pq宽,总共有m张图像,那么这里的向量 x x x的长度就是 m m m,也就是沿着图像数的方向。总共有p✖️q个这样的向量。也就是要做p✖️q次normalization计算。

这样就完成了BN层的计算。计算之后相当于这个批次中的图像中的每个像素都是服从同一分布的,但是 γ \gamma γ β \beta β不同。

训练过程

以SGD,随机梯度下降的反向传播算法来说:
从输出层开始,计算完loss和随机梯度后,就会向后传播,那么这个BN层也是需要传播的。

通过下图,就可以计算出 γ \gamma γ β \beta β每次的更新量 Δ γ \Delta \gamma Δγ Δ β \Delta \beta Δβ

整个训练过程为:

针对每一个BN层,通过上述的计算过程进行训练。

BN的实验效果

MNIST

在手写上与最古老的LeNet比较,达到同样的精确度,训练次数大大减少。

与GoogleLeNet V1比较

针对V1做了一些改动

  • 增大学习率
  • 去掉DropOut
  • 去掉LRN
  • 重新打乱训练集
  • 减少图像的扩展

做了下面几个模型的比对

  • 基于上面改动,增加了BN层的基本模型:BN-Baseline
  • Baseline的基础上,学习率提升5倍到 0.0075:BN-x5
  • 学习率提升30倍到0.045:BN-x30
  • 激活层使用Sigmond,5倍学习率的BN-x5-Sigmond

结论是在BN-x5的情况下,达到v1版本的精确率,训练次数最少。
而BN-30可以达到更高的精度,但是训练次数要多一点。
Sigmond根本达不到这个精度,BN更适用于ReLU激活层。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/804631.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

12-1_Qt 5.9 C++开发指南_自定义插件和库-自定义Widget组件(提升法(promotion)创建自定义定制化组件)

当UI设计器提供的界面组件不满足实际设计需求时,可以从 QWidget 继承自定义界面组件。 有两种方法使用自定义界面组件: 一种是提升法(promotion),例如在8.3 节将一个QGraphicsView组件提升为自定义的 QWGraphicsView 类,提升法用…

python将UTC +8 时间 转换为 UTC 时间

因为在工作的时候,有时候经常使用 UTC 时间,因为北京时间是 UTC 8,有时候要自己换算一下,或者 时间戳转换的时候有问题,所以就写了这个。 import time from datetime import datetime import pytz# 输入时间字符串 # …

LLM当前状态和潜在影响;谷歌Brain2Music读取大脑活动生成音乐

🦉 AI新闻 🚀 谷歌Brain2Music利用AI读取大脑活动生成音乐 摘要:谷歌发布了名为Brain2Music的论文,通过人工智能和脑部成像技术生成个性化音乐。他们招募了5名志愿者,记录他们在听不同音乐类型时的大脑活动数据。通过…

刷完阿里 P8 面试官推荐的 Java 高并发核心编程文档后终拿蚂蚁 offer

前言 学完阿里 P8 面试官推荐的 Java 高并发核心编程文档后终于拿到了蚂蚁 p6 的 offer,这份文档包含的内容有点多。 主要包含的内容:Java NIO、Reactor 模式、高性能通信框架 Netty、分布式锁、分布式 ID、分布式缓存、高并发架构、多线程、线程池、内…

C语言IO篇(一) 输出百分号

1.百分号输出问题是什么? C语言中无法直接打印单个的%。 2.怎么解决百分号输出问题? 在C语言中,如何输出百分号呢? 1.在printf中用2个连续 %% 输出百分号。 2.将内容写入到字符串后打印 3.为什么出现百分号输出问题? …

install wxwidgets and wxPython on Linux

安装wxwidgets https://wiki.wxwidgets.org/Compiling_and_getting_startedhttps://wiki.wxwidgets.org/Compiling_and_getting_started 安装wxPython pip install wxPython 安装wxformbuilderhttps://github.com/wxFormBuilder/wxFormBuilder/releaseshttps://github.com/wx…

通达信赫尔均线 (HMA) 指标公式设置及原理详解

我们知道传统的均线存在短周期均线(如5日均线)灵敏但不够平滑,大周期均线(如120日均线)平滑但反应滞后、无法及时反映当前行情变化的缺点。(如下图)赫尔均线 (HMA) 正是为了解决这样的问题&…

AtcoderABC229场

A - First GridA - First Grid 题目大意 要求判断是否可以从每个黑色方块到达其他所有黑色方块,只能经过黑色方块,并且黑色方块之间必须相连(共享一条边)。 思路分析 据题意,不能的只有以下两种情况 .# #. #. .#…

交互式AI技术与模型部署:bert-base-chinese模型交互式问答界面设置

使用Gradio实现Question Answering交互式问答界面,首先你需要有一个已经训练好的Question Answering模型,这里你提到要使用bert-base-chinese模型。 Gradio支持PyTorch和TensorFlow模型,所以你需要将bert-base-chinese模型转换成PyTorch或Te…

双击start.bat文件闪退,运行报错“unable to access jarfile”

问题:电脑运行“start.bat”文件,无反应,闪退,管理员身份运行报错“unable to access jarfile” 解决思路: 1、由于该项目运行需要jdk环境,检查jdk版本需要是1.8.0_251版本 通过在 cmd 命令行输入java -v…

unittest 数据驱动DDT应用

前言 一般进行接口测试时,每个接口的传参都不止一种情况,一般会考虑正向、逆向等多种组合。所以在测试一个接口时通常会编写多条case,而这些case除了传参不同外,其实并没什么区别。 这个时候就可以利用ddt来管理测试数据&#xf…

智慧城市环境污染数据采集远程监控方案4G工业路由器应用

随着科技水平的发展和人民生活水平的提高,城市环境污染问题日渐严峻,尤其是在发展迅速的国家,环境污染问题便更为突出。许多发达国家将重污染工厂搬到发展中国家,这导致发展中国家的环境污染日益严重。严重的环境污染也带来了一系…

青海师范大学迎来首个虚拟IP“卓玛”,这是要怎么赋能数字教育?

虚拟数字人作为虚拟世界和现实世界之间的链接 是人们最先了解元宇宙概念的一个表达形式 相较于传统的高校吉祥物IP 在数字教育时代 爆火的虚拟数字人 会给高校带来怎样的新体验、新机遇 虚拟IP对高校的必要性 新潮化、高互动化赋能数字教育 1. 以可视化虚拟IP&#xff0c…

vue计时器

//将秒转化为时分秒 const resultTime ref();const formateSeconds function (endTime) {let secondTime parseInt(endTime); //将传入的秒的值转化为Numberlet min 0; // 初始化分let h 0; // 初始化小时// let result "";if (secondTime > 60) {//如果秒数…

Qtday4作业

思维导图 2.手动完成服务器的实现 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QTcpServer> // 服务器类 #include<QTcpSocket> // 客户端类 #include<QDebug> // 信息调试类 #include<QMessageBox> …

Chapter 8: Files | Python for Everybody 讲义笔记_En

文章目录 Python for Everybody课程简介FilesPersistenceOpening filesText files and linesReading filesSearching through a fileLetting the user choose the file nameUsing try, except, and openWriting filesDebuggingGlossary Python for Everybody Exploring Data Us…

利用vscode--sftp,将本地项目/文件上传到远程服务器中详细教程

1、首先在 vscode 中下载 sftp&#xff1a; 2、然后在 vscode 中打开本地将要上传的项目或文件&#xff1a; 3、安装完后&#xff0c;使用快捷键 ctrlshiftP 打开指令窗口&#xff0c;输入 sftp:config &#xff0c;回车&#xff0c;在当前目录中会自动生成 .vscode 文件夹及 s…

如何拥有一个自己的小程序商城?

在今天的移动互联网时代&#xff0c;拥有一个自己的小程序商城已经成为了很多企业和个人的追求。它不仅可以帮助企业提升品牌形象和销售额&#xff0c;还能够提供更好的用户体验和更高的用户粘性。那么&#xff0c;如何拥有自己的小程序商城呢&#xff1f; 第一步&#xff1a;选…

三. 多传感器标定方案(空间同步)--1

前面的内容&#xff1a; 一. 器件选型心得&#xff08;系统设计&#xff09;--1_goldqiu的博客-CSDN博客 一. 器件选型心得&#xff08;系统设计&#xff09;--2_goldqiu的博客-CSDN博客二. 多传感器时间同步方案&#xff08;时序闭环&#xff09;--1 三. 多传感器标定方案&…

Bugs记录

一、/usr/bin/ld: cannot find -l**** 参考&#xff1a;https://www.cnblogs.com/sakuraie/p/13341508.html 在ubuntu上安装软件时&#xff0c;经常出现这样的问题&#xff1a; /usr/bin/ld: cannot find -l**** 例如&#xff1a; /usr/bin/ld: cannot find -lcaffe 安装 需…