深度学习: BatchNormlization论文详细解读

news2024/11/15 22:54:26

《Batch Normlization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》 论文详细解读


💡目录

    • <center>《Batch Normlization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》 论文详细解读
  • 基础知识
  • 面临的挑战
    • Internal covariate shift (内部协变量偏移)
  • 解决方案
    • whiten(白化)
      • PCA白化
    • Batch Normalization
      • Training
      • Testing
      • 在CNN中的运用
  • 总结
  • 代码实现

基础知识

一文读懂PCA

面临的挑战

Internal covariate shift (内部协变量偏移)

在这里插入图片描述

作者把在训练期间参数的改变而导致网络激活分布的改变叫做内部协变量偏移,对此我们有两个版本版本的解释:

  1. 如上图所示,前向计算从数据侧到损失侧,反向传播与其相反,函数更新从上到下,随着网络深度的加深,越往下梯度就越小,在学习率固定的情况下,参数更新幅度也就越来越小。靠近损失侧的神经元提取的大多是高层语义信息,这些神经元的权重往往很容易拟合,而靠近数据侧的神经元提取的是底层的纹理、线条等信息,这部分数据权重拟合较慢,因为更新参数会导致分布改变,顶部会不断的去适应底部的分布,这就会导致训练速度很慢。
  2. 如下图所示,数据x根据参数A输出a(根据链式法则,a也等于函数对B的偏导数),a通过参数B输出b,数据x从左到右前向转播计算损失,之后从后往前计算梯度,我们发现当参数A到A’的时候,参数B也更新到了B’,但是B’的梯度计算是以a为基础的,而此刻a已经变成了a’,也就是说B‘在这个模型中就不是最合适的了,BN的核心思想就是尽量的让a与a’的分布相近,这样可以缓解上面问题所带差距。

在这里插入图片描述

解决方案

whiten(白化)

PCA白化

PCA是在对观测数据进行基变换,新的坐标系使各数据维度线性无关,坐标系的重要程度从大到小衰减。

求解过程:

  1. 数据标准化(以远点为坐标原点)
  2. 求协方差矩阵
  3. 对协方差矩阵特征值分解找到最大方差的方向
  4. 对数据基变换

其中特征向量,就是最大方差方向,每个特征向量对应的特征值就是这个数据维度的方差。

PCA白化实际上就是在数据通过PCA进行基变换后再把数据进行标准化,让数据每个维度的方差全部为1。
公式推导如下:

符号定义:X:原始数据矩阵 M:原始数据协方差矩阵 设 S 1 / 2 S^{1/2} S1/2为白化矩阵

在这里插入图片描述

对M特征值分解:
在这里插入图片描述
U就是我们要找的变换矩阵,转换数据基坐标:
X P C A = U X X_{PCA}=UX XPCA=UX

然后进行白化操作:
lambda为特征值
在这里插入图片描述

其中有的特征值很小,会造成数值溢出,就给它加上了1个常数项,于是把白化矩阵改为:

在这里插入图片描述

我们发现,白化操作可以让观测数据的方差与均值固定,去除每个维度的相关性。这样确实可以加快模型的收敛,但是也面临着一个问题:
如果忽略了E[x]对b的依赖(也就是反向传播计算梯度的时候考虑均值的影响)
在这里插入图片描述
从上面案例中我们发现,更新偏置b前后函数的输出没有改变,也就是损失没有改变,反而b不断增加,这会使模型变得更糟。

我们把归一化操作定义为Norm,如果反向传播不考虑Norm,那么更新的梯度就会与Norm抵消,如果考虑,就会增加很大的计算量。

Batch Normalization

Training

由于白化的计算代价很大,作者提出了简化的版本,从对整个数据集进行归一化改成对每一个Batch的每一层神经元的output归一化来确保均值与方差固定。
在这里插入图片描述
如果把每层的输出固定下来,可能会对网络产生负面的影响,所以我们加入两个可学习的参数:贝塔与伽马使均值与方差变得可以调节。
其中伽马初始化为这一batch对应层输出的方差,贝塔初始化为其均值,从而保证整个network的capacity。(有关capacity的解释:实际上BN可以看作是在原模型上加入的“新操作”,这个新操作很大可能会改变某层原来的输入。当然也可能不改变,不改变的时候就是“还原原来输入”。如此一来,既可以改变同时也可以保持原输入,那么模型的容纳能力(capacity)就提升了。)

在这里插入图片描述

总体流程如下:
在这里插入图片描述
反向传播梯度计算公式如下:
在这里插入图片描述

Testing

  1. 在训练阶段,我们通过每个batch的数据来计算均值与方差,当在测试阶段,由于一些环境条件的限制,batch一般为1,就不能计算均值与方差了,所以在训练阶段采用指数加权平均的方式来计算所有batch的均值与方差的平均值。
  2. 为了使计算更加准确,采用无偏估计。
    在这里插入图片描述

在CNN中的运用

当BN操作应用在卷积层后,作者找到了一个符合卷积神经网络特性的方法,归一化作用在了通道维度上。
我们用代码输出结果展示一下:
用pytorch生成 Batch=2 channel = 3 hw 2 * 2 的特征图:
在这里插入图片描述
计算均值
在这里插入图片描述
计算举例:
(0+1+2+3+12+13+14+15)/8 = 7.5

总结

  1. BN使得每层网络输出分布相对稳定,可以使用更大的学习率加速模型。
  2. BN使得模型对网络中的参数不那么敏感,简化调参过程,使得网络学习更加稳定。
  3. BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题。
  4. BN具有一定的正则化效果。

代码实现

class BatchNorm(nn.Block):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims, **kwargs):
        super().__init__(**kwargs)
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = self.params.get('gamma', shape=shape, init=init.One())
        self.beta = self.params.get('beta', shape=shape, init=init.Zero())
        # 非模型参数的变量初始化为0和1
        self.moving_mean = np.zeros(shape)
        self.moving_var = np.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.ctx != X.ctx:
            self.moving_mean = self.moving_mean.copyto(X.ctx)
            self.moving_var = self.moving_var.copyto(X.ctx)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma.data(), self.beta.data(), self.moving_mean,
            self.moving_var, eps=1e-12, momentum=0.9)
        return Y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/73579.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习11支持向量机SVM(处理线性数据)

文章目录一、什么是支撑向量机&#xff1f;二、Hard Margin SVM思想逻辑推理点到直线的距离&#xff1a;推论&#xff1a;再推&#xff1a;换符号替代&#xff1a;最大化距离&#xff1a;三、Soft Margin SVM和SVM正则化Hard Margin SVM缺点&#xff1a;所以我们必须思考一个机…

YOLO v1

参考 YOLO v1 - 云社区 - 腾讯云 摘要 我们提出了一种新的目标检测方法YOLO。 先前的目标检测工作重新利用分类器来执行检测。 相反&#xff0c;我们将对象检测作为空间分离的边界框和相关类概率的回归问题。 在一次评估中&#xff0c;一个单一的神经网络直接从完整的图像预…

内核态的文件操作函数:filp_open、filp_close、vfs_read、vfs_write、set_fs、get_fs

关于用户态的文件操作函数我们知道有open、read、write这些。但是这些的实现都是依赖于库的实现&#xff0c;但是在内核态是没有库函数可用的。最近做测试&#xff0c;在内核态中&#xff0c;需要学习一下在内核态里面的文件操作函数。分为三对出现。 感谢前辈的优秀文章&…

企业网站怎么建立?【企业网站的建设】

不少的实体企业都会考虑建立一个自己的企业网站&#xff0c;那么在企业网站的建设之前需要做好功课。那么企业网站怎么建立&#xff1f;下面给大家说说大概的流程。 1、申请域名 企业可以申请一个和自己企业名称相关的域名&#xff0c;而且域名尽量不要太长&#xff0c;否则难…

Java学习之多态数组

目录 一、定义 二、举例说明 要求1 父类-Person 子类-Student 子类-Teacher main类 运行结果 要求2 思路分析 main类中的代码 运行结果 一、定义 数组的定义类型为父类类型&#xff0c; 里面保存的实际元素类型为子类类型&#xff08;也可以有父类&#xff09; 二、…

Cat.1无线数据传输终端/Cat.1 DTU/LTE Cat.1 DTU/Cat 1模组功能

LTE Cat.1无线数传终端F2C16将借助成熟的LTE网络以更好的覆盖、更快的速度、更低的延时&#xff0c;完美取代传统2G/3G网络&#xff0c;为中低速率物联网行业提供优质的无线连接服务。 工业级芯片设计&#xff0c;设备稳定联网 ●全工业级芯片设计&#xff0c;宽温宽压&#xf…

「虚拟社交」爆火,资深玩家「当道」

⬆️“政企数智办公行业研究报告及融云新品发布会”明天直播&#xff01; 一切应用都将社交化。关注【融云全球互联网通信云】回复【融云】抽取高颜值大容量高端可乐保温杯哦~ 中国政企数智办公平台行业研究报告 融入社交能力&#xff0c;创造增长奇迹。激活用户在不同场景的社…

6个改善【客户体验】的自动电子邮件营销回复示例

关键词&#xff1a;客户体验、电子邮件营销 电子邮件自动回复器是将跨境电商的客户体验 (CX) 提升到一个新水平的一种方式。为了帮助跨境电商决定应该设置哪种自动电子邮件&#xff0c;我们汇总了对客户体验影响最大的 六个电子邮件自动回复示例。 这里有一些统计数据可以正确看…

国内各行业领域是否能通过与元宇宙和虚拟数字人的结合振兴数藏经济?

在过去几年&#xff0c; NFT和数字藏品已被广泛用于数字经济。 根据中国数字藏品行业协会早在2021年发布的市场发展报告中就指出了当年中国数字藏品市场规模达到2166亿元。 今年&#xff0c;国内元宇宙概念被炒得火热&#xff0c;从故宫博物院联合腾讯、网易等推出「故宫系列」…

关于C++11

文章目录&#x1f60d;C11优势&#x1f60e; 列表初始化&#x1f601;变量类型推导&#x1f44c;为什么需要类型推导&#x1f44d;decltype类型推导&#xff08;了解&#xff09;&#x1f61c;final 与 overridefinal&#x1f91e;override❤️默认成员函数控制&#x1f929;显…

TH10-数据统计与内容审核

TH10-数据统计与内容审核1、用户冻结解冻1.1 用户冻结ManageControllerManageService1.2 用户解冻ManageControllerManageService1.3 查询数据列表UserInfoManageService1.4 探花系统修改UserFreezeService2、数据统计2.1 数据采集2.1.1 部署RabbitMQ2.1.2 消息类型说明2.1.3 实…

使用dd+hexdump命令修改环境变量的值和升级uboot

前言 这篇写的较细&#xff0c;使用dd擦除emmc本来就是比较危险的事情&#xff0c;所以一定要细致。哪里没看明白的&#xff0c;赶紧留言问我&#xff0c;可不能存有侥幸心理。 思路大概就是&#xff1a; 1 先从emmc把数据读出来&#xff0c;放一个镜像文件里&#xff0c;使…

【整理】Python全栈技术学习路线

【整理】Python全栈技术学习路线【阶段一】Python基础Linux【阶段二】多任务编程服务器前端基础【阶段三】数据库mini Web框架【阶段四】Dhango框架美多商城项目【阶段五】DRF框架美多商城后台【阶段六】项目部署Flask框架Hm头条【阶段七】人工智能基础推荐系统基础Hm头条推荐系…

带你了解extern “C“

1.extern “C” 这个语法是c的语法。我们知道在一个.c文件中调用另一个.c中实现的函数是没有任何问题的&#xff0c;一个.cpp文件调用另一个.cpp文件中实现的函数也是没有问题的。但是我们如果想要在一个.cpp文件调用另一个.c文件中实现的函数&#xff0c;或者在一个.c文件中调…

双调序列

目录 双调序列 思路: 代码: 时间复杂度: 总结: 题目链接: 双调序列 题目描述&#xff1a; XJ编程小组的童鞋们经常玩一些智力小游戏&#xff0c;某月某日&#xff0c;小朋友们又发明了一种新的序列&#xff1a;双调序列&#xff0c;所谓的双调呢主要是满足如下条件描述…

TensorFlow之分类模型-2

1 基本概念 2 文本分类与情感分析 获取数据集 加载数据集 训练数据集 性能设置 为了提升训练过程中数据处理的性能&#xff0c;keras技术框架提供数据集缓存的功能&#xff0c;使用缓存可以避免读取磁盘数据集时由于IO消耗太多而出现性能瓶颈的问题&#xff0c;如果数据集…

操作系统的主要功能是什么

操作系统的主要功能是进程管理、存储管理、设备管理、文件管理、作业管理。 计算机系统的资源可分为设备资源和信息资源两大类。 操作系统位于底层硬件与用户之间&#xff0c;是两者沟通的桥梁。 1、进程管理&#xff0c;其工作主要是进程调度&#xff0c;在单用户单任务的情…

opcj2-盘点几个常见的Java开源脚手架

很多人抱怨自己是CURDer&#xff0c;很多时候就是在简单的修修改改。如果不书序SSM&#xff08;Spring、SpringMVC和Mybatis&#xff09;套路的人可能开始的时候会感觉非常吃力。但是熟悉之后发现其实就这么回事。SpringMVC负责响应对外接口&#xff0c;Mybatis负责数据库的访问…

TF4-圈子功能

TF4-圈子功能1、首页推荐1.1、接口分析1.2、功能实现1.2.1 controller1.2.2 service1.2.3 API接口1.2.4 请求dto对象2、MongoDB集群3、圈子功能2.1、功能说明1.2、实现方案分析1.3、技术方案(重点)1.4、表结构设计4、圈子实现3.1、环境搭建3.1.1、mongo主键自增3.1.2、实体类Mo…

基于matlab的SVM支持向量机分类仿真,核函数采用RBF函数

目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 支持向量机&#xff08;support vector machines, SVM&#xff09;是二分类算法&#xff0c;所谓二分类即把具有多个特性&#xff08;属性&#xff09;的数据分为两类&#xff0c;目前主流机器学…