Batch Normalization学习笔记

news2025/1/28 1:09:13

文章目录

  • 一、为何引入 Batch Normalization
  • 二、具体步骤
    • 1、训练阶段
    • 2、预测阶段
  • 三、关键代码实现
  • 四、补充
  • 五、参考文献

一、为何引入 Batch Normalization

  现在主流的卷积神经网络几乎都使用了批量归一化(Batch Normalization,BN)1,它是一种逐层归一化方法,可以对神经网络中任意的中间层进行归一化操作。我们可以从不同角度来理解为什么要引入 Batch Normalization:

① 训练时的误差表面(error surface) 可能会十分崎岖,使得做优化时容易陷入局部最优值或鞍点等。通常我们会使用各种算法如Adam等进行优化,那么能不能直接改误差表面的地貌,“把山铲平”,让它变得比较好训练呢?Batch Normalization 就是其中一个“把山铲平”的想法2。另外一个好处是,误差表面变得没那么崎岖后,我们在训练时便可以增大学习率,使得网络更快收敛。

② 对于典型的多层感知机或卷积神经网络,在训练时中间层中的变量可能具有更广的变化范围。也就是说,随着训练时间的推移,每一层的模型参数分布范围变化莫测(比如一个深层网络,反向传播更新参数时,顶层与最底层数据范围差异会比较大,因为最底层相当于通过链式法则乘了一堆偏导数,导致数据范围非常大或小):

在这里插入图片描述

变量分布中的不规则的偏移可能会阻碍网络的收敛,因此为了使各层拥有适当的数据范围,通过 Batch Normalization“强制性”地调整数据分布使其约束到更小的范围(标准正态分布),这样便可以使得训练更加稳定,且对于初始值的设置没那么敏感。调整之后示意图如下:

在这里插入图片描述

③ 深层的网络很复杂,容易过拟合。而 Batch Normalization可以作为一种隐形的正则化方法,减轻过拟合(因此有时候使用BN后,dropout显得没那么必要使用)。由于Batch Normalization是基于一个 mini batch的,因此在训练时,神经网络对一个样本的预测不仅和该样本自身相关,也和同一批次中的其他样本相关,这种选取批次的随机性,使得神经网络不会“过拟合”到某个特定样本,从而提高网络的泛化能力。

总而言之,Batch Normalization 的优点如下3

  • 不那么依赖初始值(对于初始值不用那么神经质)。
  • 可以使学习快速进行(可以增大学习率)。
  • 抑制过拟合(降低Dropout等的必要性)。


二、具体步骤

batch normalization本质是对不同样本的同一特征做标准化。

1、训练阶段

  在训练时,Batch Normalization会逐步对每个mini-batch进行归一化。具体步骤如下:

设一个mini-batch中有 m m m 个输入数据,记为集合 B = { x 1 , x 2 , ⋯   , x m } B=\{x_1,x_2,\cdots,x_m\} B={x1,x2,,xm},对该集合求均值 μ B \mu_B μB 和方差 σ B 2 \sigma_B^2 σB2
μ B ← 1 m ∑ i = 1 m x i \begin{aligned}\mu_B\leftarrow\frac{1}{m}\sum_{i=1}^mx_i\end{aligned} μBm1i=1mxi

σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 \begin{aligned}\sigma_B^2\leftarrow\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2\end{aligned} σB2m1i=1m(xiμB)2
接下来利用求得的均值和方差对输入数据进行归一化:
x ^ i ← x i − μ B σ B 2 + ε \hat{x}_i\leftarrow\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\varepsilon}} x^iσB2+ε xiμB
其中 ε \varepsilon ε 是一个微小值(如 10 e − 7 10e^{-7} 10e7 等),以防止出现除以0的情况。

于是便可以将输入数据转换为均值为0,方差为1的数据 { x ^ 1 , x ^ 2 , ⋯   , x ^ m } \left\{\hat{x}_1,\hat{x}_2,\cdots,\hat{x}_m\right\} {x^1,x^2,,x^m} 了。

  为了使得归一化不对网络的表示能力造成负面影响,再通过一个附加的缩放和平移变换改变新数据的取值区间(虽然归一化加快了训练速度和稳定性,但它改变了数据的原始分布。对于某些任务来说,直接使用归一化的数据可能会限制模型的表达能力,因此引入可以学习的超参数 γ \gamma γ β \beta β ,使得模型可以灵活地调整归一化后的数据分布,恢复其自由度):
y i ← γ x ^ i + β y_i\leftarrow\gamma\hat{x}_i+\beta yiγx^i+β

最后把上述所有处理插入到激活函数的前面即可(整个过程相当于一个BatchNorm层),示意图如下:

在这里插入图片描述


示意图二(其中 W W W 是全连接层, L ^ \widehat{\mathcal{L}} L 是损失函数)4

在这里插入图片描述


2、预测阶段

  在训练过程中,我们无法得知整个数据集来估计平均值和方差,所以只能根据每个小批次(mini-batch)的平均值和方差不断训练模型。 而在预测模式下,一般使用整个预测数据集的均值和方差(因为这时候已经经过完整的训练了,因此可以得知全局信息)。为了节省存储资源,实际中大多采用**移动平均(moving average)**的方式来计算全局的均值和方差。移动平均的计算过程如下式所示:
μ t o t a l = λ ∗ μ t o t a l + ( 1 − λ ) ∗ μ B σ t o t a l 2 = λ ∗ σ t o t a l 2 + ( 1 − λ ) ∗ σ B 2 \begin{aligned}\mu_{total}&=\lambda*\mu_{total}+(1-\lambda)*\mu_{\mathcal{B}}\\\sigma_{total}^2&=\lambda*\sigma_{total}^2+(1-\lambda)*\sigma_{\mathcal{B}}^2\end{aligned} μtotalσtotal2=λμtotal+(1λ)μB=λσtotal2+(1λ)σB2



三、关键代码实现

以动手学深度学习第二版5的代码为例(Pytorch):

import torch
from torch import nn
from d2l import torch as d2l


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

解释几个可能的疑惑点:

为什么分为全连接层和卷积层两种情况?

  全连接层和卷积层比较典型,它们的批量规范化实现略有不同:当作用在全连接层时,实际上是作用在特征维;当作用在卷积层上时,实际上是作用在通道维(将通道维当成是卷积层的特征维)。

为什么作用在通道维?因为每个通道都有自己的拉伸参数偏移参数,并且都是标量。例如下图6所示:

在这里插入图片描述

上图各颜色通道中的像素值通常具有不同的分布和范围,这种不一致性可能会导致训练出错或网络不收敛等问题。因此需要通过Normalize操作,将每个通道的像素值标准化为均值为0、标准差为1的分布,使得所有通道的像素值范围和分布一致。(因此假如扩展到n维张量,你也只需对通道维求均值即可)



为什么全连接层设置 dim=0,而卷积层设置 dim=(0,2,3)

  全连接层是二维的,即(batch_size, feature) ,计算全连接层时,计算的是特征维的均值和方差,而每个行代表一个样本,每列代表一个特征。

dim=0dim=1 的含义

  • dim=0 表示沿着 “行” 的方向进行操作(也就是跨样本的操作),即对每个特征维的所有样本值进行聚合计算,比如求均值、方差等。
  • dim=1 表示沿着 “列” 的方向进行操作(也就是跨特征的操作),即对每个样本的所有特征值聚合计算。

下图重量/甜度/颜色评分为苹果的特征维,我们来计算特征维的均值:

苹果编号重量(克)甜度(°Bx)颜色评分(1 - 10)
苹果 1200127
苹果 2180106
苹果 3220148

dim=0代表行,dim=1代表列,既然我们要求特征维的均值,那么需要让 dim=0 ,也就是沿着行的方向“拍扁”。上图沿着行方向“拍扁”后得到的特征维的均值如下:

计算结果
重量:(200 + 180 + 220) / 3 = 200
甜度:(12 + 10 + 14) / 3 = 12
颜色评分: (7 + 6 + 8) / 3 = 7

那么卷积层 (batch_size, channels, height, width)设dim=(0,2,3)也很好理解了,我们需要得到通道维的均值,那么就得把其它几个维都“拍扁”。



为什么全连接层无需设置keepdim=True 而卷积层需设置keepdim=True

  由于pytorch的广播机制,只会从左边补1,换个说法即只会补齐最外层的维度,因此前者无需设置而后者需设置keepdim=True来保证广播机制的正常启动。

有点抽象,举例子说明:

# 构造一个形状为 (2, 3, 4, 5, 6) 的五维张量
A = torch.randn(2, 3, 4, 5, 6)

# 打印张量 A 的形状
print("张量 A 的形状:", A.shape)

# 构造一个形状为 (3, 4, 5, 6) 的四维张量
B = torch.randn(3, 4, 5, 6)
print("张量 B 的形状:", B.shape)

try:
    # 尝试执行 A + B
    A + B
    print("可以成功输出")
except Exception as e:
    # 如果发生异常,打印失败信息
    print("失败输出:", e)

输出结果为:

张量 A 的形状: torch.Size([2, 3, 4, 5, 6])
张量 B 的形状: torch.Size([3, 4, 5, 6])
可以成功输出

因为广播机制会让B的维度补齐成(1,3,4,5,6),也就是最左边补“1”,于是就可以执行 A+B操作了。

而如下情况,即仅仅稍微改变一下B的形状:

# 构造一个形状为 (2, 3, 4, 5, 6) 的五维张量
A = torch.randn(2, 3, 4, 5, 6)

# 打印张量 A 的形状
print("张量 A 的形状:", A.shape)

# 构造一个形状为 (2, 3, 4, 5) 的四维张量
B = torch.randn(2, 3, 4, 5)
print("张量 B 的形状:", B.shape)

try:
    # 尝试执行 A + B
    A + B
    print("可以成功输出")
except Exception as e:
    # 如果发生异常,打印失败信息
    print("失败输出:", e)

输出结果为:

张量 A 的形状: torch.Size([2, 3, 4, 5, 6])
张量 B 的形状: torch.Size([2, 3, 4, 5])
失败输出: The size of tensor a (6) must match the size of tensor b (5) at non-singleton dimension 4

因为广播机制只会往最左边补“1”,而这里B补“1”后形状变成(1,2,3,4,5),依旧和张量A的形状不一致,所以不能做相加操作。

回到 Batch-Normalization 的代码:

 if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

我们知道, dim等于哪个维,就是将那个维进行“拍扁”。

对于全连接层(batch_size, feature),设置 dim=0时,相当于将第 0 维“拍扁”,拍扁了相当于那个维直接“消失”了,此时meanvar的形状为(feature)。于是直接可以通过广播机制,在最左边补“1”,变成(1, feature),便可以和变量 X 一起计算了【X的形状(batch_size, feature)】。

而卷积层 (batch_size, channels, height, width)dim=(0,2,3)时,相当于将第 0,2,3 维“拍扁”,此时meanvar的形状为(channels),而 X 的形状是 (batch_size, channels, height, width),你得将meanvar的形状扩展到和 X 一致才可以进行计算,而广播机制只能往最左边补“1”,因此(channels)无法扩展成和X一致的形状,顶多扩展成(1, channels),所以无法和 X 进行计算,程序报错。

因此需要对卷积层使用 keepdim=True这个参数,这样meanvar的形状就可以扩展成 (1, channels, 1, 1),与X一致,才能进行接下来的计算。



if not torch.is_grad_enabled() 为什么可以判断是训练还是预测模式?

  反向传播时会涉及梯度的计算,而只有训练时才会进行反向传播,因此可以通过是否进行梯度的计算来判断训练模式还是预测模式。



四、补充

  原论文中提出Batch-Normalization的优点是减少了内部协变量转移(internal covariate shift,简单来说就是变量值的分布在训练过程中会发生变化,但是这种解释在后续论文被证实比较不严谨,发现它并没有减少内部协变量的转移 [Santurkar et al.,2018]。



五、参考文献


  1. Ioffe S. Batch normalization: Accelerating deep network training by reducing internal covariate shift[J]. arXiv preprint arXiv:1502.03167, 2015. ↩︎

  2. 王琦, 杨毅远, 江季, 深度学习详解, 北京:人民邮电出版社, 2024 ↩︎

  3. (日)斋藤康毅著, 陆宇杰译, 深度学习入门基于Python的理论与实现, 北京:人民邮电出版社, 2018.07 ↩︎

  4. Santurkar S, Tsipras D, Ilyas A, et al. How does batch normalization help optimization?[J]. Advances in neural information processing systems, 2018, 31. ↩︎

  5. 阿斯顿·张(Aston Zhang), 李沐(Mu Li), [美] 扎卡里·C. 立顿(Zachary C. Lipton), 等. 动手学深度学习(PyTorch版)[M]. 第二版. 人民邮电出版社, 2023-2. ↩︎

  6. 【Batch Normalization】 https://www.bilibili.com/video/BV11s4y1c7pg/?share_source=copy_web&vd_source=199a3f4e3a9db6061e1523e94505165a ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2282993.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高等数学学习笔记 ☞ 微分方程

1. 微分方程的基本概念 1. 微分方程的基本概念: (1)微分方程:含有未知函数及其导数或微分的方程。 举例说明微分方程:;。 (2)微分方程的阶:指微分方程中未知函数的导数…

第19个项目:蛇年特别版贪吃蛇H5小游戏

下载地址:https://download.csdn.net/download/mosquito_lover1/90308956 游戏玩法: 点击"开始游戏"按钮开始 使用键盘方向键控制蛇的移动 吃到红色食物可以得分 撞到墙壁或自己会结束游戏 核心源码: class SnakeGame { constructor() { this.canvas = docum…

Tensor 基本操作4 理解 indexing,加减乘除和 broadcasting 运算 | PyTorch 深度学习实战

前一篇文章,Tensor 基本操作3 理解 shape, stride, storage, view,is_contiguous 和 reshape 操作 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started Tensor 基本使用 索引 indexing示例代码 加减…

算法中的移动窗帘——C++滑动窗口算法详解

1. 滑动窗口简介 滑动窗口是一种在算法中常用的技巧,主要用来处理具有连续性的子数组或子序列问题。通过滑动窗口,可以在一维数组或字符串上维护一个固定或可变长度的窗口,逐步移动窗口,避免重复计算,从而提升效率。常…

技术总结:FPGA基于GTX+RIFFA架构实现多功能SDI视频转PCIE采集卡设计方案

目录 1、前言工程概述免责声明 3、详细设计方案设计框图SDI 输入设备Gv8601a 均衡器GTX 解串与串化SMPTE SD/HD/3G SDI IP核BT1120转RGBFDMA图像缓存RIFFA用户数据控制RIFFA架构详解Xilinx 7 Series Integrated Block for PCI ExpressRIFFA驱动及其安装QT上位机HDMI输出RGB转BT…

二叉树的最大深度(C语言详解版)

一、摘要 嗨喽呀大家,leetcode每日一题又和大家见面啦,今天要讲的是104.二叉树的最大深度,思路互相学习,有什么不足的地方欢迎指正!好啦让我们开始吧!!! 二、题目简介 给定一个二…

知识图谱抽取三元组技术介绍

知识图谱三元组抽取是知识图谱构建的重要步骤之一,其目的是从文本或数据中提取出结构化的信息,以形成实体、属性和关系之间的联系。这些三元组(Subject-Predicate-Object)是知识图谱的基本单元,用于描述实体之间的语义…

矩阵的秩在机器学习中具有广泛的应用

矩阵的秩在机器学习中具有广泛的应用,主要体现在以下几个方面: 一、数据降维与特征提取 主成分分析(PCA): PCA是一种常用的数据降维技术,它通过寻找数据中的主成分(即最大方差方向&#xff09…

【以音频软件FFmpeg为例】通过Python脚本将软件路径添加到Windows系统环境变量中的实现与原理分析

在Windows系统中,你可以通过修改环境变量 PATH 来使得 ffmpeg.exe 可在任意路径下直接使用。要通过Python修改环境变量并立即生效,如图: 你可以使用以下代码: import os import winreg as reg# ffmpeg.exe的路径 ffmpeg_path …

Linux-rt下卡死之hrtimer分析

Linux-rt下卡死之hrtimer分析 日志 超时读过程分析 #define readl_poll_timeout(addr, val, cond, delay_us, timeout_us) \readx_poll_timeout(readl, addr, val, cond, delay_us, timeout_us)34 #define readx_poll_timeout(op, addr, val, cond, sleep_us, timeout_us) \…

PHP防伪溯源一体化管理系统小程序

🔍 防伪溯源一体化管理系统,品质之光,根源之锁 🚀 引领防伪技术革命,重塑品牌信任基石 我们自豪地站在防伪技术的前沿,为您呈现基于ThinkPHP和Uniapp精心锻造的多平台(微信小程序、H5网页&…

WPF2-在xaml为对象的属性赋值

1. AttributeValue方式 1.1. 简单属性赋值1.2. 对象属性赋值 2. 属性标签的方式给属性赋值3. 标签扩展 (Markup Extensions) 3.1. StaticResource3.2. Binding 3.2.1. 普通 Binding3.2.2. ElementName Binding3.2.3. RelativeSource Binding3.2.4. StaticResource Binding (带参…

uart iic spi三种总线的用法

1、uart串口通信 这种连接方式抗干扰能力弱,旁边有干扰源就会对收发的电平数据造成干扰,进而导致数据失真 这种连接方式一般适用于一块板子上面的两个芯片之间进行数据传输 ,属于异步全双工模式。 1.空闲位:当不进行数据收发时&am…

Java 高级工程师面试高频题:JVM+Redis+ 并发 + 算法 + 框架

前言 在过 2 个月即将进入 3 月了,然而面对今年的大环境而言,跳槽成功的难度比往年高了很多,很明显的感受就是:对于今年的 java 开发朋友跳槽面试,无论一面还是二面,都开始考验一个 Java 程序员的技术功底…

【图文详解】lnmp架构搭建Discuz论坛

安装部署LNMP 系统及软件版本信息 软件名称版本nginx1.24.0mysql5.7.41php5.6.27安装nginx 我们对Markdown编辑器进行了一些功能拓展与语法支持,除了标准的Markdown编辑器功能,我们增加了如下几点新功能,帮助你用它写博客: 关闭防火墙 systemctl stop firewalld &&a…

【番外篇】排列组合实现算法2(Java版)

一、说明 在牛客网的很多算法试题中,很多试题底层都是基于排列组合算法实现的,比如动态规划、最优解、最大值等常见问题。排列组合算法有一定的难度,并不能用一般的多重嵌套循环解决,没有提前做针对性的学习和研究,考…

PAT甲级-1020 Tree Traversals

题目 题目大意 给出一棵树的后序遍历和中序遍历,要求输出该树的层序遍历。 思路 非常典型的树的构建与遍历问题。后序遍历和中序遍历可以得出一个树的结构,用递归锁定根节点,然后再遍历左右子树,我之前发过类似题目的博客&…

GIS 中的 SQLAlchemy:空间数据与数据库之间的桥梁

利用 SQLAlchemy 在现代应用程序中无缝集成地理空间数据导言 地理信息系统(GIS)在管理城市规划、环境监测和导航系统等各种应用的空间数据方面发挥着至关重要的作用。虽然 PostGIS 或 SpatiaLite 等专业地理空间数据库在处理空间数据方面非常出色&#…

【AppleID】注册M区AppleID 2025年

注册(一台电脑一天只能注册一个) https://account.apple.com/ 需任意邮箱,任意手机 手机上登录后填地址 示例

frida的常用api

1、Hook普通方法、打印参数和修改返回值 Hook函数 Hook代码 function hookTest1(){var utils Java.use("com.zj.wuaipojie.Demo");utils.a.implementation function(str){// a "test";var retval this.a(str);console.log(str , retval);return retva…