神经网络基础部件-BN层详解

news2024/11/24 19:57:16

一,数学基础

1.1,概率密度函数

随机变量(random variable)是可以随机地取不同值的变量。随机变量可以是离散的或者连续的。简单起见,本文用大写字母 X X X 表示随机变量,小写字母 x x x 表示随机变量能够取到的值。例如, x 1 x_1 x1 x 2 x_2 x2 都是随机变量 X X X 可能的取值。随机变量必须伴随着一个概率分布来指定每个状态的可能性。

概率分布(probability distribution)用来描述随机变量或一簇随机变量在每一个可能取到的状态的可能性大小。我们描述概率分布的方式取决于随机变量是离散的还是连续的。

当我们研究的对象是连续型随机变量时,我们用概率密度函数(probability density function, PDF)而不是概率质量函数来描述它的概率分布。

更多内容请阅读《花书》第三章-概率与信息论,或者我的文章-深度学习数学基础-概率与信息论。

1.2,正态分布

当我们不知道数据真实分布时使用正态分布的原因之一是,正态分布拥有最大的熵,我们通过这个假设来施加尽可能少的结构。

实数上最常用的分布就是正态分布(normal distribution),也称为高斯分布 (Gaussian distribution)。

如果随机变量 X X X ,服从位置参数为 μ \mu μ、尺度参数为 σ \sigma σ 的概率分布,且其概率密度函数为:

f ( x ) = 1 σ 2 π e − ( x − μ ) 2 2 σ 2 (1) f(x)=\frac{1}{\sigma\sqrt{2 \pi} } e^{- \frac{{(x-\mu)^2}}{2\sigma^2}} \tag{1} f(x)=σ2π 1e2σ2(xμ)2(1)
则这个随机变量就称为正态随机变量,正态随机变量服从的概率分布就称为正态分布,记作:
X ∼ N ( μ , σ 2 ) (2) X \sim N(\mu,\sigma^2) \tag{2} XN(μ,σ2)(2)
如果位置参数 μ = 0 \mu = 0 μ=0,尺度参数 σ = 1 \sigma = 1 σ=1 时,则称为标准正态分布,记作:
X ∼ N ( 0 , 1 ) (3) X \sim N(0, 1) \tag{3} XN(0,1)(3)
此时,概率密度函数公式简化为:
f ( x ) = 1 2 π e − x 2 2 (4) f(x)=\frac{1}{\sqrt{2 \pi}} e^{- \frac{x^2}{2}} \tag{4} f(x)=2π 1e2x2(4)
正态分布的数学期望值或期望值 μ \mu μ 等于位置参数,决定了分布的位置;其方差 σ 2 \sigma^2 σ2 的开平方或标准差 σ \sigma σ 等于尺度参数,决定了分布的幅度。正态分布的概率密度函数曲线呈钟形,常称之为钟形曲线,如下图所示:

正太分布概率密度函数曲线

可视化正态分布,可直接通过 np.random.normal 函数生成指定均值和标准差的正态分布随机数,然后基于 matplotlib + seabornkdeplot函数绘制概率密度曲线。示例代码如下所示:

import seaborn as sns
x1 = np.random.normal(0, 1, 100)
x2 = np.random.normal(0, 1.5, 100) 
x3 = np.random.normal(2, 1.5, 100) 

plt.figure(dpi = 200)

sns.kdeplot(x1, label="μ=0, σ=1")
sns.kdeplot(x2, label="μ=0, σ=1.5")
sns.kdeplot(x3, label="μ=2, σ=2.5")

#显示图例
plt.legend()
#添加标题
plt.title("Normal distribution")
plt.show()

以上代码直接运行后,输出结果如下图:

不同参数的正态分布函数曲线

当然也可以自己实现正态分布的概率密度函数,代码和程序输出结果如下:

import numpy as np
import matplotlib.pyplot as plt
plt.figure(dpi = 200)
plt.style.use('seaborn-darkgrid') # 主题设置

def nd_func(x, sigma, mu):
  	"""自定义实现正态分布的概率密度函数
  	"""
    a = - (x-mu)**2 / (2*sigma*sigma)
    f = np.exp(a) / (sigma * np.sqrt(2*np.pi))
    return f

if __name__ == '__main__':
    x = np.linspace(-5, 5)
    f = nd_fun(x, 1, 0)
    p1, = plt.plot(x, f)

    f = nd_fun(x, 1.5, 0)
    p2, = plt.plot(x, f)

    f = nd_fun(x, 1.5, 2)
    p3, = plt.plot(x, f)

    plt.legend([p1 ,p2, p3], ["μ=0,σ=1", "μ=0,σ=1.5", "μ=2,σ=1.5"])
    plt.show()

自己实现的不同参数的正态分布函数曲线

二,背景

训练深度神经网络的复杂性在于,因为前面的层的参数会发生变化导致每层输入的分布在训练过程中会发生变化。这又导致模型需要需要较低的学习率和非常谨慎的参数初始化策略,从而减慢了训练速度,并且具有饱和非线性的模型训练起来也非常困难。

网络层输入数据分布发生变化的这种现象称为内部协变量转移,BN 就是来解决这个问题。

2.1,如何理解 Internal Covariate Shift

在深度神经网络训练的过程中,由于网络中参数变化而引起网络中间层数据分布发生变化的这一过程被称在论文中称之为内部协变量偏移(Internal Covariate Shift)。

那么,为什么网络中间层数据分布会发生变化呢?

在深度神经网络中,我们可以将每一层视为对输入的信号做了一次变换(暂时不考虑激活,因为激活函数不会改变输入数据的分布):
Z = W ⋅ X + B (5) Z = W \cdot X + B \tag{5} Z=WX+B(5)
其中 W W W B B B 是模型学习的参数,这个公式涵盖了全连接层和卷积层。

随着 SGD 算法更新参数,和网络的每一层的输入数据经过公式5的运算后,其 Z Z Z分布一直在变化,因此网络的每一层都需要不断适应新的分布,这一过程就被叫做 Internal Covariate Shift。

而深度神经网络训练的复杂性在于每层的输入受到前面所有层的参数的影响—因此当网络变得更深时,网络参数的微小变化就会被放大。

2.2,Internal Covariate Shift 带来的问题

  1. 网络层需要不断适应新的分布,导致网络学习速度的降低

  2. 网络层输入数据容易陷入到非线性的饱和状态并减慢网络收敛,这个影响随着网络深度的增加而放大。

    随着网络层的加深,后面网络输入 x x x 越来越大,而如果我们又采用 Sigmoid 型激活函数,那么每层的输入很容易移动到非线性饱和区域,此时梯度会变得很小甚至接近于 0 0 0,导致参数的更新速度就会减慢,进而又会放慢网络的收敛速度。

饱和问题和由此产生的梯度消失通常通过使用修正线性单元激活( R e L U ( x ) = m a x ( x , 0 ) ReLU(x)=max(x,0) ReLU(x)=max(x,0)),更好的参数初始化方法和小的学习率来解决。然而,如果我们能保证非线性输入的分布在网络训练时保持更稳定,那么优化器将不太可能陷入饱和状态,进而训练也将加速。

2.3,减少 Internal Covariate Shift 的一些尝试

  1. 白化(Whitening): 即输入线性变换为具有零均值和单位方差,并去相关。

    白化过程由于改变了网络每一层的分布,因而改变了网络层中本身数据的表达能力。底层网络学习到的参数信息会被白化操作丢失掉,而且白化计算成本也高。

  2. 标准化(normalization)

    Normalization 操作虽然缓解了 ICS 问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。

三,批量归一化(BN)

3.1,BN 的前向计算

论文中给出的 Batch Normalizing Transform 算法计算过程如下图所示。其中输入是一个考虑一个大小为 m m m 的小批量数据 B \cal B B

Batch Normalizing Transform

论文中的公式不太清晰,下面我给出更为清晰的 Batch Normalizing Transform 算法计算过程。

m m m 表示 batch_size 的大小, n n n 表示 features 数量,即样本特征值数量。在训练过程中,针对每一个 batch 数据,BN 过程进行的操作是,将这组数据 normalization,之后对其进行线性变换,具体算法步骤如下:

μ B = 1 m ∑ 1 m x i σ B 2 = 1 m ∑ 1 m ( x i − μ B ) 2 n i = x i − μ B σ B 2 + ϵ z i = γ n i + β = γ σ B 2 + ϵ x i + ( β − γ μ B σ B 2 + ϵ ) (6) \begin{aligned} \mu_B = \frac{1}{m}\sum_1^m x_i \\ \sigma^2_B = \frac{1}{m} \sum_1^m (x_i-\mu_B)^2 \\ n_i = \frac{x_i-\mu_B}{\sqrt{\sigma^2_B + \epsilon}} \\ z_i = \gamma n_i + \beta = \frac{\gamma}{\sqrt{\sigma^2_B + \epsilon}}x_i + (\beta - \frac{\gamma\mu_{B}}{\sqrt{\sigma^2_B + \epsilon}}) \tag{6}\\ \end{aligned} μB=m11mxiσB2=m11m(xiμB)2ni=σB2+ϵ xiμBzi=γni+β=σB2+ϵ γxi+(βσB2+ϵ γμB)(6)

以上公式乘法都为元素乘,即 element wise 的乘法。其中,参数 γ , β \gamma,\beta γ,β 是训练出来的, ϵ \epsilon ϵ 是为零防止 σ B 2 \sigma_B^2 σB2 0 0 0 ,加的一个很小的数值,通常为1e-5。公式各个符号解释如下:

符号数据类型数据形状
X X X输入数据矩阵[m, n]
x i x_i xi输入数据第i个样本[1, n]
N N N经过归一化的数据矩阵[m, n]
n i n_i ni经过归一化的单样本[1, n]
μ B \mu_B μB批数据均值[1, n]
σ B 2 \sigma^2_B σB2批数据方差[1, n]
m m m批样本数量[1]
γ \gamma γ线性变换参数[1, n]
β \beta β线性变换参数[1, n]
Z Z Z线性变换后的矩阵[1, n]
z i z_i zi线性变换后的单样本[1, n]
δ \delta δ反向传入的误差[m, n]

其中:
z i = γ n i + β = γ σ B 2 + ϵ x i + ( β − γ μ B σ B 2 + ϵ ) z_i = \gamma n_i + \beta = \frac{\gamma}{\sqrt{\sigma^2_B + \epsilon}}x_i + (\beta - \frac{\gamma\mu_{B}}{\sqrt{\sigma^2_B + \epsilon}}) zi=γni+β=σB2+ϵ γxi+(βσB2+ϵ γμB)
可以看出 BN 本质上是做线性变换。

3.2,BN 层如何工作

在论文中,训练一个带 BN 层的网络, BN 算法步骤如下图所示:

Training a Batch-Normalized Network

在训练期间,我们一次向网络提供一小批数据。在前向传播过程中,网络的每一层都处理该小批量数据。 BN 网络层按如下方式执行前向传播计算:

Batch Norm 层执行的前向计算

图片来源这里。

注意,图中计算均值与方差的无偏估计方法是吴恩达在 Coursera 上的 Deep Learning 课程上提出的方法:对 train 阶段每个 batch 计算的 mean/variance 采用指数加权平均来得到 test 阶段 mean/variance 的估计。

在训练期间,它只是计算此 EMA,但不对其执行任何操作。在训练结束时,它只是将该值保存为层状态的一部分,以供在推理阶段使用。

如下图可以展示BN 层的前向传播计算过程数据的 shape ,红色框出来的单个样本都指代单个矩阵,即运算都是在单个矩阵运算中计算的。

Batch Norm 向量的形状

图片来源 这里。

BN 的反向传播过程中,会更新 BN 层中的所有 β \beta β γ \gamma γ 参数。

3.3,训练和推理式的 BN 层

批量归一化(batch normalization)的“批量”两个字,表示在模型的迭代训练过程中,BN 首先计算小批量( mini-batch,如 32)的均值和方差。但是,在推理过程中,我们只有一个样本,而不是一个小批量。在这种情况下,我们该如何获得均值和方差呢?

第一种方法是,使用的均值和方差数据是在训练过程中样本值的平均,即:
E [ x ] = E [ μ B ] V a r [ x ] = m m − 1 E [ σ B 2 ] \begin{aligned} E[x] &= E[\mu_B] \\ Var[x] &= \frac{m}{m-1} E[\sigma^2_B] \\ \end{aligned} E[x]Var[x]=E[μB]=m1mE[σB2]
这种做法会把所有训练批次的 μ \mu μ σ \sigma σ 都保存下来,然后在最后训练完成时(或做测试时)做下平均。

第二种方法是使用类似动量的方法,训练时,加权平均每个批次的值,权值 α \alpha α 可以为0.9:
μ m o v i = α ⋅ μ m o v i + ( 1 − α ) ⋅ μ i σ m o v i = α ⋅ σ m o v i + ( 1 − α ) ⋅ σ i \begin{aligned} \mu_{mov_{i}} &= \alpha \cdot \mu_{mov_{i}} + (1-\alpha) \cdot \mu_i \\ \sigma_{mov_{i}} &= \alpha \cdot \sigma_{mov_{i}} + (1-\alpha) \cdot \sigma_i \\ \end{aligned} μmoviσmovi=αμmovi+(1α)μi=ασmovi+(1α)σi
推理或测试时,直接使用模型文件中保存的 μ m o v i \mu_{mov_{i}} μmovi σ m o v i \sigma_{mov_{i}} σmovi 的值即可。

3.4,实验

BNImageNet 分类数据集上实验结果是 SOTA 的,如下表所示:

实验结果表4

3.5,BN 层的优点

  1. BN 使得网络中每层输入数据的分布相对稳定,加速模型训练和收敛速度

  2. 批标准化可以提高学习率。在传统的深度网络中,学习率过高可能会导致梯度爆炸或梯度消失,以及陷入差的局部最小值。批标准化有助于解决这些问题。通过标准化整个网络的激活值,它可以防止层参数的微小变化随着数据在深度网络中的传播而放大。例如,这使 sigmoid 非线性更容易保持在它们的非饱和状态,这对训练深度 sigmoid 网络至关重要,但在传统上很难实现。

  3. BN 允许网络使用饱和非线性激活函数(如 sigmoid,tanh 等)进行训练,其能缓解梯度消失问题

  4. 不需要 dropoutLRN(Local Response Normalization)层来实现正则化。批标准化提供了类似丢弃的正则化收益,因为通过实验可以观察到训练样本的激活受到同一小批量样例随机选择的影响。

  5. 减少对参数初始化方法的依赖

参考资料

  1. 维基百科-正态分布
  2. Batch Norm Explained Visually — How it works, and why neural networks need it
  3. 15.5 批量归一化的原理
  4. Batch Normalization原理与实战

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/336967.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Zabbix 构建监控告警平台(二)--

Apache监控示例(图形监控)模板TemplateZabbix Items 1.Apache监控示例(图形监控) 1.1创建主机组 在“配置”->“主机群组”->“创建主机群组” 填入组名“webserver_test” 创建完成之后可以在“配置”->"主机群组&…

界面控件DevExpress WinForm中文教程 - 如何使用模板库构建类Office UI?

DevExpress WinForm拥有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。DevExpress WinForm能完美构建流畅、美观且易于使用的应用程序,无论是Office风格的界面,还是分析处理大批量的业务数据,它都能轻松胜任…

opengl glsl shader vscode安装插件glsl_canvas 和 shader languagesupportForVS Code

u_resolution 是画布尺寸,即代表画布宽高 //给内置变量gl_PointSize赋值像素大小,注意值是浮点数 gl_PointSize20.0; // 片元沿着x方向渐变 gl_FragColor vec4(gl_FragCoord.x/500.0*1.0,1.0,0.0,1.0); // 接收插值后的纹理坐标 varying vec2 v…

作为开发人员您会喜欢的 7 个免费公共 API

1. JSON 占位符JSON Placeholder是一项服务,可为您提供用于测试和原型制作的假在线REST API 。这是每个开发人员的首选 API。2.谷歌翻译Google有大量的API,但其中大部分是付费的。值得庆幸的是,Translate API提供100 多种语言的免费翻译&…

Spring面试重点(二)——Spring循环依赖

Spring循环依赖 什么是循环依赖? 从字面上来理解就是A依赖B的同时B也依赖了A,就像上面这样,或者C依赖与自己本身。体现到代码层次就是这个样子 Component public class A { // A中注入了B Autowired private B b; }---Component public cla…

@Valid注解配合属性校验注解完成参数校验并且优化异常处理

Valid注解配合属性校验注解完成参数校验并且优化参数校验异常处理1 Valid注解配合属性校验注解完成参数校验2 优化参数校验异常处理1 Valid注解配合属性校验注解完成参数校验 向数据库商品分类表中新增商品分类字段,并校验传入的参数 不使用注解的传统方法&#xf…

基于python+django社区报修维修平台

本系统主要分为前后和后台页面,前台页面主要功能有:首页,座位信息,交流论坛,公告信息,个人中心,后台管理。后台页面分为:首页,个人中心,学生管理,教师管理,座位信息管理,座位预约管理,班级信息管理,签到信息管理,离开信息管理,座位暂离管理,举报信息管理…

MLX90614红外温度计介绍

MLX90614红外温度计简介MLX90614是一款红外非接触温度计。TO-39金属封装里同时集成了红外感应热电堆探测器芯片和信号处理专用集成芯片。由于集成了低噪声放大器、17位模数转换器和强大的数字信号处理单元,使得高精度和高分辨度的温度计得以实现。温度计具备出厂校准…

如何上传文件

在页面上面&#xff0c;form 表单里面添加属性enctype"multipart/form-data" 比如&#xff1a; <form name"frm" method"post" enctype"multipart/form-data"> 添加文件选择框&#xff1a; <input type"file" na…

安全上下文

目录 文章目录目录本节实战前言1、为 Pod 设置 Security Context2、为容器设置 Security Context3、设置 Linux Capabilities1.Linux Capabilities&#xff08;1&#xff09;什么是 Capabilitie&#xff08;2&#xff09;Capabilities 的赋予和继承&#xff08;3&#xff09;如…

ctfshow 年ctf

文章目录除夕初一初二初三初四初五初六官方wp除夕 include "flag.php";$year $_GET[year];if($year2022 && $year1!2023){echo $flag; }else{highlight_file(__FILE__); }弱比较和强比较的问题 2023那里是强比较&#xff0c;还是很容易的 /?year2022.0科…

CHI协议定义的NOC组件

请求结点RN 可以向NOC发送读/写等请求事务&#xff0c;有以下几种类型的RN&#xff1a; RN-F 一般是处理器核或者核簇结点&#xff0c;包含了局部cache和一致性部件snoopee。与NOC上的一致性部件一起&#xff0c;维护“可缓存”数据的一致性&#xff08;这种可缓存数据…

实验名称:基于C/S的命名管道通信

实验名称&#xff1a;基于C/S的命名管道通信 相关知识 无名管道 无名管道&#xff08;匿名管道&#xff09;用于具有亲缘关系进程间的通信&#xff0c;其特点有 管道是半双工的&#xff0c;数据单向流动&#xff08;双方通信需建立两个通道&#xff09;管道只能用于父子进程…

2023年房地产投资-租金和IRR研究报告

第一章 概况 房地产投资租赁是指置业投资者在购买到物业后&#xff0c;首先对该物业进行适当整饰与装修&#xff0c;之后以出租人的身份&#xff0c;以口头协议或签订合同的形式&#xff0c;将房屋交付承租人占有、使用与收益&#xff0c;由承租人向出租人交付租金的行为。通过…

第一章 企业管理概论

目录 一、企业及其形式 二、企业管理概述 三、企业管理理论与实践的产生与发展 四、网络时代的企业环境 五、网络时代企业管理的变革 一、企业及其形式 1、企业的概念 企业以市场为导向&#xff0c;以价值增值作为经济活动的目的&#xff1b; 企业是从事商品生产和流通的…

BUG解决:微信小程序调用vantweapp遮罩层popup 更改show后没反应,弹框/遮罩层不隐藏,show失效

一、bug复现&#xff1a;引入popup组件&#xff0c;时间选择组件json>"usingComponents": {"van-datetime-picker": "vant/weapp/datetime-picker/index","van-popup": "vant/weapp/popup/index"}页面想实现&#xff0c;…

当我以为z-library已死的时候 它居然又活了?!!

z-library 全世界最大的图书馆What Happened To Z-lib?zlib的复活只是暂时的deepweb会让zlib得到永生&#xff01;真心祝愿zlib的Plans for 2023能够实现What Happened To Z-lib? 这是曾经的zlib&#xff0c;域名是z-lib.org&#xff0c;然而现在死了&#xff08;22年11月时…

Grafana 系列文章(十三):如何用 Loki 收集查看 Kubernetes Events

前情提要 IoT 边缘集群基于 Kubernetes Events 的告警通知实现IoT 边缘集群基于 Kubernetes Events 的告警通知实现&#xff08;二&#xff09;&#xff1a;进一步配置 概述 在分析 K8S 集群问题时&#xff0c;Kubernetes Events 是超级有用的。 Kubernetes Events 可以被当…

Windows 10 Creators版本中的11个大亮点

导读微软在近日公布了有关明年Windows 10更新部分的大量功能&#xff0c;但该公司在其Creators更新版本中悄悄隐藏了远超出11项新的功能。其实&#xff0c;在这个更新包中还将包含许多内容&#xff0c;包括增加一个新的应用程序&#xff0c;以及针对Edge浏览器、地图应用程序和…

蓝牙耳机什么牌子好用又便宜?好用不贵的蓝牙耳机推荐

随着时代的进步&#xff0c;数码产品在人们日常生活中的使用频率越来越高&#xff0c;一部手机&#xff0c;一副耳机似乎已然成为人们出行必备。蓝牙耳机的发展速度很快&#xff0c;在众多的蓝牙耳机牌子中&#xff0c;什么牌子好用又便宜&#xff1f;下面&#xff0c;我来给大…