理解BatchNormalization层的作用

news2024/11/19 1:40:12

深度学习


文章目录

  • 深度学习
  • 前言
  • 一、“Internal Covariate Shift”问题
  • 二、BatchNorm的本质思想
  • 三、训练阶段如何做BatchNorm
  • 四、BatchNorm的推理(Inference)过程
  • 五、BatchNorm的好处
  • 六、机器学习中mini-batch和batch有什么区别


前言

Batch Normalization作为最近一年来DL的重要成果,已经广泛被证明其有效性和重要性。虽然有些细节处理还解释不清其理论原因,但是实践证明好用才是真的好,别忘了DL从Hinton对深层网络做Pre-Train开始就是一个经验领先于理论分析的偏经验的一门学问。本文是对论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》的导读。
机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。那BatchNorm的作用是什么呢?BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。

接下来一步一步的理解什么是BN。

为什么深度神经网络随着网络深度加深,训练起来越困难,收敛越来越慢?这是个在DL领域很接近本质的好问题。很多论文都是解决这个问题的,比如ReLU激活函数,再比如Residual Network,BN本质上也是解释并从某个不同的角度来解决这个问题的。

一、“Internal Covariate Shift”问题

从论文名字可以看出,BN是用来解决“Internal Covariate Shift”问题的,那么首先得理解什么是“Internal Covariate Shift”?

论文首先说明Mini-Batch SGD相对于One Example SGD的两个优势:梯度更新方向更准确;并行计算速度快;(为什么要说这些?因为BatchNorm是基于Mini-Batch SGD的,所以先夸下Mini-Batch SGD,当然也是大实话);然后吐槽下SGD训练的缺点:超参数调起来很麻烦。(作者隐含意思是用BN就能解决很多SGD的缺点)

接着引入covariate shift的概念:如果ML系统实例集合<X,Y>中的输入值X的分布老是变,这不符合IID假设,网络模型很难稳定的学规律,这不得引入迁移学习才能搞定吗,我们的ML系统还得去学习怎么迎合这种分布变化啊。对于深度学习这种包含很多隐层的网络结构,在训练过程中,因为各层参数不停在变化,所以每个隐层都会面临covariate shift的问题,也就是在训练过程中,隐层的输入分布老是变来变去,这就是所谓的“Internal Covariate Shift”,Internal指的是深层网络的隐层,是发生在网络内部的事情,而不是covariate shift问题只发生在输入层。

然后提出了BatchNorm的基本思想:能不能让每个隐层节点的激活输入分布固定下来呢?这样就避免了“Internal Covariate Shift”问题了。

BN不是凭空拍脑袋拍出来的好点子,它是有启发来源的:之前的研究表明如果在图像处理中对输入图像进行白化(Whiten)操作的话——所谓白化,就是对输入数据分布变换到0均值,单位方差的正态分布——那么神经网络会较快收敛,那么BN作者就开始推论了:图像是深度神经网络的输入层,做白化能加快收敛,那么其实对于深度网络来说,其中某个隐层的神经元是下一层的输入,意思是其实深度神经网络的每一个隐层都是输入层,不过是相对下一层来说而已,那么能不能对每个隐层都做白化呢?这就是启发BN产生的原初想法,而BN也确实就是这么做的,可以理解为对深层神经网络每个隐层神经元的激活值做简化版本的白化操作。

二、BatchNorm的本质思想

BN的基本思想其实相当直观:因为深层神经网络在做非线性变换前的激活输入值(就是那个x=WU+B,U是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。

THAT’S IT。其实一句话就是:对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,也就是说收敛地快。BN说到底就是这么个机制,方法很简单,道理很深刻。

上面说得还是显得抽象,下面更形象地表达下这种调整到底代表什么含义。

在这里插入图片描述
图1 几个正态分布

假设某个隐层神经元原先的激活输入x取值符合正态分布,正态分布均值是-2,方差是0.5,对应上图中最左端的浅蓝色曲线,通过BN后转换为均值为0,方差是1的正态分布(对应上图中的深蓝色图形),意味着什么,意味着输入x的取值正态分布整体右移2(均值的变化),图形曲线更平缓了(方差增大的变化)。这个图的意思是,BN其实就是把每个隐层神经元的激活输入分布从偏离均值为0方差为1的正态分布通过平移均值压缩或者扩大曲线尖锐程度,调整为均值为0方差为1的正态分布。

那么把激活输入x调整到这个正态分布有什么用?首先我们看下均值为0,方差为1的标准正态分布代表什么含义:

在这里插入图片描述
图2 均值为0方差为1的标准正态分布图

这意味着在一个标准差范围内,也就是说64%的概率x其值落在[-1,1]的范围内,在两个标准差范围内,也就是说95%的概率x其值落在了[-2,2]的范围内。那么这又意味着什么?我们知道,激活值x=WU+B,U是真正的输入,x是某个神经元的激活值,假设非线性函数是sigmoid,那么看下sigmoid(x)其图形:
在这里插入图片描述
图3. Sigmoid(x)

及sigmoid(x)的导数为:G’=f(x)*(1-f(x)),因为f(x)=sigmoid(x)在0到1之间,所以G’在0到0.25之间,其对应的图如下:

在这里插入图片描述
图4 Sigmoid(x)导数图

假设没有经过BN调整前x的原先正态分布均值是-6,方差是1,那么意味着95%的值落在了[-8,-4]之间,那么对应的Sigmoid(x)函数的值明显接近于0,这是典型的梯度饱和区,在这个区域里梯度变化很慢,为什么是梯度饱和区?请看下sigmoid(x)如果取值接近0或者接近于1的时候对应导数函数取值,接近于0,意味着梯度变化很小甚至消失。而假设经过BN后,均值是0,方差是1,那么意味着95%的x值落在了[-2,2]区间内,很明显这一段是sigmoid(x)函数接近于线性变换的区域,意味着x的小变化会导致非线性函数值较大的变化,也即是梯度变化较大,对应导数函数图中明显大于0的区域,就是梯度非饱和区。

从上面几个图应该看出来BN在干什么了吧?其实就是把隐层神经元激活输入x=WU+B从变化不拘一格的正态分布通过BN操作拉回到了均值为0,方差为1的正态分布,即原始正态分布中心左移或者右移到以0为均值,拉伸或者缩减形态形成以1为方差的图形。什么意思?就是说经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程。

但是很明显,看到这里,稍微了解神经网络的读者一般会提出一个疑问:如果都通过BN,那么不就跟把非线性函数替换成线性函数效果相同了?这意味着什么?我们知道,如果是多层的线性函数变换其实这个深层是没有意义的,因为多层线性网络跟一层线性网络是等价的。

文章中举了个例子,在sigmoid激活函数的中间部分,函数近似于一个线性函数(如下图所示),使用BN后会使归一化后的数据仅使用这一段线性的部分。
在这里插入图片描述
可以看到,在[0.2, 0.8]范围内,sigmoid函数基本呈线性递增,甚至在[0.1, 0.9]范围内,sigmoid函数都是类似于线性函数的,如果只用这一段,那网络不就成了线性网络了么,这显然不是大家愿意见到的。

这意味着网络的表达能力下降了,这也意味着深度的意义就没有了。所以BN为了保证非线性的获得,对变换后的满足均值为0方差为1的x又进行了scale加上shift操作(y=scale*x+shift),每个神经元增加了两个参数scale和shift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并长胖一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动。核心思想应该是想找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢。当然,这是我的理解,论文作者并未明确这样说。但是很明显这里的scale和shift操作是会有争议的,因为按照论文作者论文里写的理想状态,就会又通过scale和shift操作把变换后的x调整回未变换的状态,那不是饶了一圈又绕回去原始的“Internal Covariate Shift”问题里去了吗,感觉论文作者并未能够清楚地解释scale和shift操作的理论原因。

三、训练阶段如何做BatchNorm

上面是对BN的抽象分析和解释,具体在Mini-Batch SGD下做BN怎么做?其实论文里面这块写得很清楚也容易理解。为了保证这篇文章完整性,这里简单说明下。

假设对于一个深层神经网络来说,其中两层结构如下:
  在这里插入图片描述
图5 DNN其中两层

要对每个隐层神经元的激活值做BN,可以想象成每个隐层又加上了一层BN操作层,它位于X=WU+B激活值获得之后,非线性函数变换之前,其图示如下:
  在这里插入图片描述
图6. BN操作

对于Mini-Batch SGD来说,一次训练过程里面包含m个训练实例,其具体BN操作就是对于隐层内每个神经元的激活值来说,进行如下变换:
  在这里插入图片描述
  要注意,这里t层某个神经元的x(k)不是指原始输入,就是说不是t-1层每个神经元的输出,而是t层这个神经元的线性激活x=WU+B,这里的U才是t-1层神经元的输出。变换的意思是:某个神经元对应的原始的激活x通过减去mini-Batch内m个实例获得的m个激活x求得的均值E(x)并除以求得的方差Var(x)来进行转换。

上文说过经过这个变换后某个神经元的激活x形成了均值为0,方差为1的正态分布,目的是把值往后续要进行的非线性变换的线性区拉动,增大导数值,增强反向传播信息流动性,加快训练收敛速度。但是这样会导致网络表达能力下降,为了防止这一点,每个神经元增加两个调节参数(scale和shift),这两个参数是通过训练来学习到的,用来对变换后的激活反变换,使得网络表达能力增强,即对变换后的激活进行如下的scale和shift操作,这其实是变换的反操作:
在这里插入图片描述
BN其具体操作流程,如论文中描述的一样:

在这里插入图片描述
过程非常清楚,就是上述公式的流程化描述,这里不解释了,直接应该能看懂。

四、BatchNorm的推理(Inference)过程

BN在训练的时候可以根据Mini-Batch里的若干训练实例进行激活数值调整,但是在推理(inference)的过程中,很明显输入就只有一个实例,看不到Mini-Batch其它实例,那么这时候怎么对输入做BN呢?因为很明显一个实例是没法算实例集合求出的均值和方差的。这可如何是好?

既然没有从Mini-Batch数据里可以得到的统计量,那就想其它办法来获得这个统计量,就是均值和方差。可以用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量,因为本来就打算用全局的统计量,只是因为计算量等太大所以才会用Mini-Batch这种简化方式的,那么在推理的时候直接用全局统计量即可。

决定了获得统计量的数据范围,那么接下来的问题是如何获得均值和方差的问题。很简单,因为每次做Mini-Batch训练时,都会有那个Mini-Batch里m个训练实例获得的均值和方差,现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量,即:
在这里插入图片描述
(在测试时,所使用的均值和方差是整个训练集的均值和方差。整个训练集的均值和方差的值通常是在训练的同时用移动平均法来计算的。)

关于滑动平均值怎么求,请看链接:深度学习中的Batch Normalization_whitesilence的博客-CSDN博客

有了均值和方差,每个隐层神经元也已经有对应训练好的Scaling参数和Shift参数,就可以在推导的时候对每个神经元的激活数据计算NB进行变换了,在推理过程中进行BN采取如下方式:
在这里插入图片描述
这个公式其实和训练时
在这里插入图片描述
是等价的,通过简单的合并计算推导就可以得出这个结论。那么为啥要写成这个变换形式呢?我猜作者这么写的意思是:在实际运行的时候,按照这种变体形式可以减少计算量,为啥呢?因为对于每个隐层节点来说:
在这里插入图片描述
都是固定值,这样这两个值可以事先算好存起来,在推理的时候直接用就行了,这样比原始的公式每一步骤都现算少了除法的运算过程,乍一看也没少多少计算量,但是如果隐层节点个数多的话节省的计算量就比较多了。

这里其实是和训练保持一致,因为都要先减去训练集的均值再除以方差。这里你把训练时候的x^(k)带进去算下就能理解测试时候的了。

五、BatchNorm的好处

BatchNorm为什么NB呢,关键还是效果好。①不仅仅极大提升了训练速度,收敛过程大大加快;②还能增加分类效果,一种解释是这是类似于Dropout的一种防止过拟合的正则化表达方式,所以不用Dropout也能达到相当的效果;③另外调参过程也简单多了,对于初始化要求没那么高,而且可以使用大的学习率等。总而言之,经过这么简单的变换,带来的好处多得很,这也是为何现在BN这么快流行起来的原因。

六、机器学习中mini-batch和batch有什么区别

在机器学习和深度学习中,“mini-batch” 和 “batch” 是两个常用的术语,它们之间存在一些区别。

Mini-batch(小批量):Mini-batch 是指从训练数据集中选择的较小的数据子集。在训练模型时,通常将整个训练数据集划分为多个 mini-batch。每个 mini-batch 包含一定数量的训练样本,通常是2的幂次方,例如 32、64 或 128。模型使用每个 mini-batch 的样本来进行前向传播、计算损失和反向传播,然后根据这些样本的梯度更新模型的参数。使用 mini-batch 的主要目的是减少计算开销和内存占用,并提高训练的效率。

Batch(批量):Batch 是指将整个训练数据集作为一个大批量进行训练。在每次迭代时,模型使用整个训练数据集的样本进行前向传播、计算损失和反向传播,然后根据这些样本的梯度更新模型的参数。相比于 mini-batch,使用 batch 的训练过程可能会占用更多的内存和计算资源,因为需要同时处理整个数据集。

因此,mini-batch 和 batch 的区别在于处理的数据规模不同。mini-batch 是一个相对较小的数据子集,用于训练过程中的迭代更新,而 batch 是整个训练数据集的一次性处理。选择使用 mini-batch 还是 batch 取决于数据集的大小、计算资源的限制以及训练的效率要求。通常情况下,mini-batch 是更常用和常见的训练方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1276122.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

漏洞复现--Tenda路由器DownloadCfg信息泄露

免责声明&#xff1a; 文章中涉及的漏洞均已修复&#xff0c;敏感信息均已做打码处理&#xff0c;文章仅做经验分享用途&#xff0c;切勿当真&#xff0c;未授权的攻击属于非法行为&#xff01;文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

IntelliJ IDEA创建springboot项目时不能选择java8的问题解决方案

最近博主也有创建springboot项目&#xff0c;发现了IntelliJ IDEA在通过Spring Initilizer初始化项目的时候已经没有java8版本的选项了。 基于这个问题&#xff0c;有了这篇文章的分享&#xff0c;希望能够帮助大家克服这个困难。 如图&#xff0c;现在创建springboot项目的时…

BLIoTLink工业协议转换软件功能和使用教程

1.功能简介 BLIoTLink 是一款各种 PLC 协议、Modbus RTU 、Modbus TCP、DL/T645 等多 种协议转换为 Modbus TCP、OPC UA、MQTT、BACnet IP、华为云 IoT、亚 马逊云 IoT、阿里云 IoT、ThingsBoard、钡铼云 IoT 等协议的软件。 BLIoTLink 下行支持&#xff1a;各种 PLC 协议、Mod…

ES6知识

作用域 局部作用域 局部作用域分为函数作用域和块作用域 函数作用域 在函数内部声明的变量只能在函数内部被访问&#xff0c;外部无法直接访问。函数的参数也是函数内部的局部变量。不同函数内部声明的变量无法互相访问。函数执行完毕后&#xff0c;函数内部的变量实际被清空…

【代码】考虑差异性充电模式的电动汽车充放电优化调度matlab-yalmip-cplex/gurobi

程序名称&#xff1a;考虑差异性充电模式的电动汽车充放电优化调度 实现平台&#xff1a;matlab-yalmip-cplex/gurobi 代码简介&#xff1a;提出了一种微电网中电动汽车的协调充电调度方法&#xff0c;以将负荷需求从高峰期转移到低谷期。在所提出的方法中&#xff0c;基于充…

(一)Tiki-taka算法(TTA)求解无人机三维路径规划研究(MATLAB)

一、无人机模型简介&#xff1a; 单个无人机三维路径规划问题及其建模_IT猿手的博客-CSDN博客 参考文献&#xff1a; [1]胡观凯,钟建华,李永正,黎万洪.基于IPSO-GA算法的无人机三维路径规划[J].现代电子技术,2023,46(07):115-120 二、Tiki-taka算法&#xff08;TTA&#xf…

AutoDL 使用记录

AutoDL 使用记录 1.租用新实例 创建实例需要依次选择&#xff1a;计费方式 → \to → 地区 → \to → GPU型号与数量 → \to → 主机 注意事项&#xff1a; 主机 ID&#xff1a;一个吉利的机号有助于炼丹成功价格&#xff1a;哪个便宜选哪个最高 CUDA 版本&#xff1a;影响…

操作系统-输入输出管理

I/O设备的基本概念和分类 I/O就是输入/输出 I/O设备就是可以将数据输入到计算机&#xff0c;或者可以接收计算机输出数据的外部设备&#xff0c;属于计算机中的硬件部件。 I/O设备按使用特性分类 人机交互类外部设备存储设备网络通信设备 I/O设备按传输速率分类 低速设备中…

小米智能摄像头mp4多碎片手工恢复案例

小米智能摄像头mp4多碎片手工恢复案例 智能摄像头目前在市场上极为常见&#xff0c;仅需要一张存储卡即可实现视频、音频的采集&#xff0c;同时可以通过手机APP进行远程控制&#xff0c;相比传统安防品牌成本更低、更容易部署。在智能摄像头品牌中小米算是绝对的大厂&#xf…

HTTP协议、Java前后端交互、Servlet

文章目录 抓包工具 FiddlerHTTP 请求和响应结构URL 唯一资源定位符HTTP 协议中的方法请求报头&#xff08;header&#xff09;HTTP响应构造 HTTP 请求基于 form 标签基于 ajax使用 Postman HTTPS和 HTTP 的区别对称密钥和非对称密钥数字证书 TomcatServlet创建 Maven 项目引入依…

SSM框架(四):SSM整合 案例 + 异常处理器 +拦截器

文章目录 一、整合流程图1.1 Spring整合Mybatis1.2 Spring整合SpringMVC 二、表现层数据封装2.1 问题引出2.2 统一返回结果数据格式 代码设计 三、异常处理器3.1 概述3.2 异常处理方案 四、前端五、拦截器5.1 概念5.2 入门案例5.3 拦截器参数5.4 拦截器链 一、整合流程图 1.1 S…

2.qml 3D-View3D类学习

本章我们来学习View3D类。 View3D是用来渲染3D场景并显示在2D平面的类&#xff0c;并且该类可以放在QML2D下继承于Item子类的任何场景中&#xff0c;比如将View3D放在Rectangle中: Rectangle {width: 200 height: 200color: "red"View3D { anchors.fill: parent…

STM32CubeIDE(CUBE-MX hal库)----蓝牙模块HC-05(详细配置)

系列文章目录 STM32CubeIDE(CUBE-MX hal库)----初尝点亮小灯 STM32CubeIDE(CUBE-MX hal库)----按键控制 STM32CubeIDE(CUBE-MX hal库)----串口通信 STM32CubeIDE(CUBE-MX hal库)----定时器 文章目录 系列文章目录前言一、蓝牙配置二、CUBE-MX可视化配置三、蓝牙APP调试助手四、…

mysql在linux环境下安装(rpm)以及初始化后的登录配置

注&#xff1a;该安装步骤转载于CSDN,下方配置为原创 按照图片安装并初始化完成MySQL等操作后进行&#xff1b; 安装对于rpm包集合 1-查看安装情况&#xff08;有4个路径&#xff09; whereis mysql 2-查看服务状态 systemctl status mysql 3-初始化数据库 mysqld --initial…

6.5 Windows驱动开发:内核枚举PspCidTable句柄表

在 Windows 操作系统内核中&#xff0c;PspCidTable 通常是与进程&#xff08;Process&#xff09;管理相关的数据结构之一。它与进程的标识和管理有关&#xff0c;每个进程都有一个唯一的标识符&#xff0c;称为进程 ID&#xff08;PID&#xff09;。与之相关的是客户端 ID&am…

【蓝桥杯选拔赛真题71】Scratch绘制彩虹 少儿编程scratch图形化编程 蓝桥杯创意编程选拔赛真题解析

目录 scratch绘制彩虹 一、题目要求 编程实现 二、案例分析 1、角色分析

Python+Requests对图片验证码的处理

Requests对图片验证码的处理 在web端的登录接口经常会有图片验证码的输入&#xff0c;而且每次登录时图片验证码都是随机的&#xff1b;当通过request做接口登录的时候要对图片验证码进行识别出图片中的字段&#xff0c;然后再登录接口中使用&#xff1b; 通过request对图片验…

ChatGPT成为“帮凶”:生成虚假数据集支持未知科学假设

ChatGPT 自发布以来&#xff0c;就成为了大家的好帮手&#xff0c;学生党和打工人更是每天都离不开。 然而这次好帮手 ChatGPT 却帮过头了&#xff0c;莫名奇妙的成为了“帮凶”&#xff0c;一位研究人员利用 ChatGPT 创建了虚假的数据集&#xff0c;用来支持未知的科学假设。…

Windows环境 dockertopdesk 部署gitlab

1.在dockertopdesk里搜索 gitlab镜像 (pull)拉取镜像 2.运行镜像到容器 mkdir gitlab gitlab/etc gitlab/log gitlab/opt docker run -id -p 3000:80 -p 9922:22 -v /root/gitlab/etc:/etc/gitlab -v /root/gitlab/log:/var/log/gitlab -v /root/gitlab/opt:/var/opt/gitla…

Linux系统之centos7编译安装Python 3.8

前言 CentOS (Community Enterprise Operating System) 是一种基于 Red Hat Enterprise Linux (RHEL) 进行源代码再编译并免费提供给用户的 Linux 操作系统。 CentOS 7 采用了最新的技术和软件包&#xff0c;并提供了强大的功能和稳定性。它适用于各种服务器和工作站应用场景&a…