python 理解BN、LN、IN、GN归一化、分析torch.nn.LayerNorm()和torch.var()工作原理

news2025/1/16 5:06:57

目录

 

前言:

简言之BN、LN、IN、GN等归一化的区别:

批量归一化(Batch Normalization,BN)

优点

缺点

计算过程

层归一化(Layer Normalization,LN)

优点 

计算过程

总结

分析torch.nn.LayerNorm()工作原理

分析torch.var()工作原理

torch.var()函数 

参数

关键字参数

重点


前言:

最近在学习Vit(Vision Transformer)模型,在构建自注意力层(Attention)和前馈网络层(MLP)时,用到了torch.nn.LayerNorm(dim),也就是LN归一化,与常见卷积神经网络(CNN)所使用的BN归一化略有不同。

简言之BN、LN、IN、GN等归一化的区别:

假设输入样本为4张大小为240x240的彩色图片,因此样本Batch数量N为4,RGB彩色通道Channel为3,长H为240,宽W为240,样本数据矩阵为[4,3,240,240]

BN归一化相当于作用在通道维度上,一共3次归一化,分别求通道1、2、3的4张240x240照片的均值和方差,也就是分别计算3次[4,240,240]数据的均值和方差。

LN归一化相当于作用在样本数量上,一共4次归一化,分别求照片1、2、3、4的均值和方差,也就是计算4次[3,240,240]数据的均值和方差。

IN归一化相当于作用在样本数量和通道维度上,一共3x4=12次归一化,分别求照片1、2、3、4的通道1、2、3的均值和方差,也就是计算12次[240,240]数据的均值和方差。

GN归一化相当于作用在样本数量和以组为单位的通道维度上,例如将通道维度分为两组,第一组为通道1、2,第二组为通道3,一共2x4=8次归一化,分别求照片1、2、3、4的通道组1的均值和方差和照片1、2、3、4的通道组2的均值和方差,也就是计算4次[2,240,240]和4次[1,240,240]数据的均值和方差。

批量归一化(Batch Normalization,BN)

优点

1、极大提升了训练速度,收敛过程大大加快;

2、减弱对初始化的强依赖性;

3、保持隐藏层中数值的均值、方差不变,让数值更稳定,为后面网络提供坚实的基础;

4、还能增加分类效果,一种解释是这是一种防止过拟合的正则化表达方式(相当于给隐藏层加入噪声,类似Dropout),所以不用Dropout也能达到相当的效果;

5、另外调参过程也简单多了,对于初始化要求没那么高,而且可以使用大的学习率等;

缺点

1、每次是在一个batch上计算均值、方差,如果batch size太小,则计算的均值、方差不足以代表整个数据分布。

2、batch size太大会超过内存容量;需要跑更多的epoch,导致总训练时间变长;会直接固定梯度下降的方向,导致很难更新。

由于BN与mini-batch的数据分布紧密相关,故而mini-batch的数据分布需要与总体的数据分布近似相等。因此BN适用于batch size较大且各mini-batch分布相近似的场景下(训练前需进行充分的shuffle)。BN计算过程需要保存某一层神经网络batch的均值和方差等统计信息,适合定长网络结构DNN CNN,不适用动态网络结构RNN。

计算过程

1、沿着通道计算每个batch的均值μ

2、沿着通道计算每个batch的方差σ²

3、将每个值进行归一化(分母方差加了一个极小数,防止分母为0)

4、加入缩放和平移变量 γ 和 β(深度学习就是在学习变量 γ 和 β的大小)

在这里插入图片描述

详细内容补充:

笔记详情 (bilibili.com)

层归一化(Layer NormalizationLN)

优点 

LN不受batch size的影响。同时,LN可以很好地用到序列型网络RNN中。 

计算过程

针对BN不适用于深度不固定的网络(sequence长度不一致,如RNN),LN对深度网络的某一层的所有神经元的输入按以下公式进行normalization操作。

在这里插入图片描述

LN中同层神经元的输入拥有相同的均值和方差,不同的输入样本有不同的均值和方差
对于特征图在这里插入图片描述 ,LN 对每个样本的 C、H、W 维度上的数据求均值和标准差,保留 N 维度。其均值和标准差公式为:

在这里插入图片描述

Layer Normalization (LN) 的一个优势是不需要批训练,在单条数据内部就能归一化。LN不依赖于batch size和输入sequence的长度,因此可以用于batch size为1和RNN中。LN用于RNN效果比较明显,但是在CNN上,效果不如BN。

总结

我们将feature map shape 记为[N, C, H, W]。如果把特征图比喻成一摞书,这摞书总共有 N 本,每本有 C 页,每页有 H 行,每行 有W 个字符。

在这里插入图片描述

1、BN是在batch上,对N、H、W做归一化,而保留通道 C 的维度。BN 相当于把这些书按页码一一对应地加起来,再除以每个页码下的字符总数:N×H×W。

2、LN在通道方向上,对C、H、W归一化。LN 相当于把每一本书的所有字加起来,再除以这本书的字符总数:C×H×W。

3、IN在图像像素上,对H、W做归一化。IN 相当于把一页书中所有字加起来,再除以该页的总字数:H×W。

4、GN将channel分组,然后再做归一化。GN 相当于把一本 C 页的书平均分成 G 份,每份成为有 C/G 页的小册子,对每个小册子做Norm。

另外,还需要注意它们的映射参数γ和β的区别:对于 BN,IN,GN, 其γ和β都是维度等于通道数 C 的向量。而对于 LN,其γ和β都是维度等于 normalized_shape 的矩阵。

最后,BN 和 IN 可以设置参数:momentum和track_running_stats来获得在整体数据上更准确的均值和标准差。LN 和 GN 只能计算当前 batch 内数据的真实均值和标准差。

IN和GN请参考 :

(14条消息) 常用的归一化(Normalization) 方法:BN、LN、IN、GN_归一化方法_初识-CV的博客-CSDN博客

深度学习之9——逐层归一化(BN,LN) - 知乎 (zhihu.com)

其他归一化方法可见博主另一篇文章:

(14条消息) 【机器学习】数据归一化全方法总结:Max-Min归一化、Z-score归一化、数据类型归一化、标准差归一化等_daphne odera�的博客-CSDN博客

分析torch.nn.LayerNorm()工作原理

通过以下代码分析torch.nn.LayerNorm()在nlp模型中是如何工作的,计算输入数据是一批单词嵌入序列: 

import torch

batch_size, seq_size, dim = 1, 2, 3
embedding = torch.randn(batch_size, seq_size, dim)
print("x: ", embedding)

layer_norm = torch.nn.LayerNorm(dim)
print("y: ", layer_norm(embedding))

 结果如下:

x:  tensor([[[-0.5975,  2.0992,  0.1889],
         [ 0.9362,  1.2452, -0.7753]]])
y:  tensor([[[-1.0253,  1.3562, -0.3309],
         [ 0.5261,  0.8738, -1.3999]]], grad_fn=<NativeLayerNormBackward0>)

我们编写LN归一化的代码,模拟torch.nn.LayerNorm()工作流程:

def custom_layer_norm(
    x: torch.Tensor, dim: tuple[int] = -1, eps: float = 0.00001
) -> torch.Tensor:
    mean = torch.mean(embedding, dim=dim, keepdim=True)
    var = torch.square(embedding - mean).mean(dim=(-1), keepdim=True)
    return (embedding - mean) / torch.sqrt(var + eps)

print("y_custom: ", custom_layer_norm(embedding))

 结果如下(一模一样):

y_custom:  tensor([[[-1.0253,  1.3562, -0.3309],
         [ 0.5261,  0.8738, -1.3999]]])

未加入上述所说的缩放和平移变量 γ 和 β,直接通过每个样本嵌入值的均值和方差来计算:

mean = torch.mean(embedding[0, :, :])
std = torch.sqrt(torch.var(embedding[0, :, :], unbiased=False)) # 母体方差 分母为N unbiased默认为True 样本方差 无偏估计 分母为n-1
print("mean: ", mean)
print("std: ", std)
print((embedding[0, 0, :] - mean) / std)

结果如下(较为接近):

mean:  tensor(0.5161)
std:  tensor(1.0189)
tensor([-1.0929,  1.5537, -0.3212])

分析torch.var()工作原理

在计算方差时,使用了torch.var()函数,仅由一个参数决定torch.var()计算的是样本方差还是母体方差,所以着重讲解一下。

import numpy as np
print("np.var: ", np.var([[1, 2], [2, 3]]))

# 结果如下

np.var:  0.5

我们在写一个案例:

X_test = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
print("np.var unbiased=Ture: ", torch.var(X_test))
print("np.var unbiased=False: ", torch.var(X_test, unbiased=False))

# 结果如下

np.var unbiased=True:  tensor(0.6667)
np.var unbiased=False:  tensor(0.5000)

为什么结果不一样呢,因为取决于一个参数,即unbiased,无偏的意思。默认值为true,也就是说,默认是计算样本方差,当unbiased=False时,计算的是母体方差,也就是无偏估计。

torch.var()函数 

torch.var(input, dim, unbiased, keepdim=False, *, out=None) → Tensor

参数

  • input(Tensor) -输入张量。

  • dim(int或者python的元组:ints) -要减小的尺寸或尺寸。

关键字参数

  • unbiased(bool) -是否使用贝塞尔校正(δN=1)。

  • keepdim(bool) -输出张量是否保留了dim

  • out(Tensor,可选的) -输出张量。

在这里插入图片描述

重点

 当unbiased=True时(默认),计算的是样本方差,分母是样本数量-1

 当unbiased=False时,计算的是母体方差,也就是无偏估计,分母是样本数量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/415721.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue2-黑马(十一)

目录&#xff1a; &#xff08;1&#xff09;vue2-联调准备 &#xff08;2&#xff09;vue2-登录实战-国际化 &#xff08;3&#xff09;vue2实战-登录-login-index.vue &#xff08;1&#xff09;vue2-联调准备 登录这个请求&#xff0c;并不是发给后台的&#xff0c;现在还…

浙大MBA提面申请材料的三六九等……

每年浙大MBA项目提前批面试申请的每个批次中都会有部分材料因为某些原因而被淘汰&#xff0c;无缘面试资格。考生们由最初的不理解到逐渐隐约的理解&#xff0c;行至今日也可以大体接受材料被刷这个结果&#xff0c;当然其中含有一部分面上资质背景还可以的考生&#xff0c;等到…

Faster-RCNN代码解读2:快速上手使用

Faster-RCNN代码解读2&#xff1a;快速上手使用 前言 ​ 因为最近打算尝试一下Faster-RCNN的复现&#xff0c;不要多想&#xff0c;我还没有厉害到可以一个人复现所有代码。所以&#xff0c;是参考别人的代码&#xff0c;进行自己的解读。 ​ 代码来自于B站的UP主&#xff08;…

中国电子学会2023年03月份青少年软件编程Scratch图形化等级考试试卷四级真题(含答案)

2023-03 Scratch四级真题 分数&#xff1a;100 题数&#xff1a;24 测试时长&#xff1a;90min 一、单选题(共10题&#xff0c;共30分) 1.编写一段程序&#xff0c;从26个英文字母中&#xff0c;随机选出10个加入列表a。空白处应填入的代码是&#xff1f;&#xff08;C&am…

Flink (十二) --------- Flink CEP

目录一、基本概念1. CEP 是什么2. 模式 (Pattern)3. 应用场景二、快速上手1. 需要引入的依赖2. 一个简单实例三、模式 API&#xff08;Pattern API&#xff09;1. 个体模式2. 组合模式3. 模式组4. 匹配后跳过策略四、模式的检测处理1. 将模式应用到流上2. 处理匹配事件3. 处理超…

【高项】项目整体管理、范围管理与进度管理(十大管理)

【高项】项目整体管理与范围管理 文章目录1、项目整体管理1.1 整体管理的过程1.2 制定项目章程&#xff08;启动&#xff09;1.3 制订项目管理计划&#xff08;规划&#xff09;1.4 指导与管理项目执行&#xff08;执行&#xff09;1.5 监控项目工作与实施整体变更控制&#xf…

Systemverilog中operators和expression的记录

1. Equality operators Equality operators有三种&#xff1a; Logical equality&#xff1a;, !&#xff0c;该运算符中如果运算数包含有x/z态&#xff0c;那么结果就是x态。只有在两边的bit都不包含x/z态&#xff0c;最终结果才会为0(False)或1(True)Case equality&#xf…

中云盾DDoS云防护系统

中云盾 DDoS 防护系统作为公司级网络安全产品&#xff0c;为各类业务提供专业可靠的 DDoS/CC 攻击防护。在黑客攻防对抗日益激烈的环境下&#xff0c; DDoS 对抗不仅需要 “降本” 还需要 “增效”。 为什么上云&#xff1f; 云原生作为近年来相当热门的概念&#xff0c;无论…

RHCE-NTP、SSH服务器

1.配置ntp时间服务器&#xff0c;确保客户端主机能和服务主机同步时间​ 服务器端&#xff1a; &#xff08;1&#xff09;首先安装chrony软件&#xff1a; dnf install -y chrony &#xff08;2&#xff09;配置时间同步源&#xff1a; 进入vim /etc/chrony.conf &#xf…

引用和指针

总结 引用&#xff1a; 因为引用是变量的别名&#xff0c;所以引用必须初始化 因为引用不存在自己的地址&#xff0c;所以指针不能指向引用&#xff0c;即不能定义引用的指针 因为引用不是对象&#xff0c;但是引用又要绑定一个对象&#xff0c;所以不能定义引用的引用 in…

一篇文章看懂C++三大特性——多态的定义和使用

目录 前文 一&#xff0c;什么是多态&#xff1f; 1.1 多态的概念 二&#xff0c; 多态的定义及实现 2.1 多态的构成条件 2.2 虚函数 2.3 虚函数的重写 2.3.1 虚函数重写的两个例外 2.4 C override 和 final 2.5 重载,重写(覆盖),隐藏(重定义)的区别 三&#xff0c;抽…

代码随想录刷题-双指针总结篇

文章目录双指针移除元素习题我的解法双指针优化反转字符串习题我的解法剑指 Offer 05. 替换空格习题我的解法正确解法反转字符串里的单词习题我的解法反转链表习题我的解法删除链表的倒数第 N 个节点习题我的解法相交链表习题我的解法环形链表 II习题我的解法三数之和习题我的解…

Unity VFX -- (3)创建环境粒子系统

粒子系统中最常用也最重要的一种使用场景是实现天气效果。只需要做很少修改&#xff0c;场景就能很快从蓝天白云变成雪花飘舞。 和之前看到的粒子系统从一个源头发出粒子的情况不同&#xff0c;天气效果完全围绕着场景。 新增和放置一个新的粒子系统 为了创建下雨或下雪的天气…

【从零开始学Skynet】基础篇(三):服务模块常用API

1、服务模块 Skynet提供了开启服务和发送消息的API&#xff0c;必须要先掌握它们。列出了Skynet中8个最重要的API&#xff0c;PingPong程序会用到它们。 Lua API说明newservice(name, ...) 启动一个名为 name 的新服务&#xff0c;并返回服务的地址。 start(func) …

【学习笔记】unity脚本学习(二)(Time时间体系、Random随机数、Mathf数学运算)

目录Time时间体系timeScalemaximumDeltaTimefixedDeltaTimecaptureDeltaTimedeltaTime整体展示Random随机数Mathf数学运算IMathf.Round()Mathf.Ceil() Mathf.CeilToInt()Mathf.SignMathf.ClampMathf数学运算II-曲线变换Lerp 线性插值LerpAngleSmoothDamp疑问&#xff1a;Smooth…

自己动手写编译器:DFA跳转表的压缩算法

在编译器开发体系中有两套框架&#xff0c;一个叫"lex && yacc", 另一个名气更大叫llvm&#xff0c;这两都是开发编译器的框架&#xff0c;我们只要设置好配置文件&#xff0c;那么他们就会生成相应的编译器代码&#xff0c;通常是c或者c代码&#xff0c;然后…

AI自动寻路AStar算法【图示讲解原理】

文章目录AI自动寻路AStar算法背景AStar算法原理AStar寻路步骤AStar具体寻路过程AStar代码实现运行结果AI自动寻路AStar算法 背景 AI自动寻路的算法可以分为以下几种&#xff1a; 1、A*算法&#xff1a;A*算法是一种启发式搜索算法&#xff0c;它利用启发函数&#xff08;heu…

Jmeter接口测试和性能测试

目前最新版本发展到5.0版本&#xff0c;需要Java7以上版本环境&#xff0c;下载解压目录后&#xff0c;进入\apache-jmeter-5.0\bin\&#xff0c;双击ApacheJMeter.jar文件启动JMemter。 1、创建测试任务 添加线程组&#xff0c;右击测试计划&#xff0c;在快捷菜单单击添加-…

STM32F103RCT6驱动SG90舵机-完成正反转角度控制

一、SG90舵机介绍 SG90是一种微型舵机&#xff0c;也被称为伺服电机。它是一种小型、低成本的直流电机&#xff0c;通常用于模型和机器人控制等应用中。SG90舵机可以通过电子信号来控制其精确的位置和速度。它具有体积小、重量轻、响应快等特点&#xff0c;因此在各种小型机械…

亚马逊测评只能下单上好评?卖家倾向养号测评还有这些骚操作

亚马逊测评这对于绝大部分亚马逊卖家来说都不陌生&#xff0c;如今的亚马逊市场也很多卖家都在用测评科技来打造爆款。不过很多对于亚马逊测评的认知只停留在简单的刷销量&#xff0c;上好评。殊不知亚马逊养号测评还有其它强大的骚操作。 亚马逊自养号测评哪些功能呢&#xf…