InstanceNorm LayerNorm

news2025/2/24 23:13:41

InstanceNorm && LayerNorm

@author: SUFEHeisenberg

@date: 2023/01/26

先说结论:

  • 将Transformer类比于RNN:一个token就是一层layer,对一整句不如token有意义
  • 原生Bert代码或huggingface中用的都是InstanceNorm instead of LayerNorm,但都是torch.nn.LayerNorm实现的。

1. 对NLP数据的理解

NLP input data的为[batch_size, sequence_len, dim] 表示为[K, N, D]

关键就是这个形如 [K, N, D] 的 tensor 它其实不是一层,而是N 个形如 [K, D] 的层拼接的结果。用 RNN 来想就很明白了,计算完时间步 t 以后才能计算时间步 t+1,比如 h_{t+1}=tanh(Wh_t+b),h_{t+1} 和 h_t 在计算图上的深度都不同,显然第 t 个词的 D 维向量和第 t+1 个词的 D维向量属于两个不同的层。只不过为了方便使用,会把所有的 h_{1:T} 都拼接起来组成一个 tensor 返回。在 xfmr 里,因为各个时间步可以同时计算,所以这一点不够不明显了。

简单来说,由于RNN在每个时间步都共享同一套参数(其实Transformer也是一样,同一层的不同token共享同一套QKV),BatchNorm是跨时间步进行的(换句话说就是跨token进行的,因为一个batch中所有句子在同一位置的token属于同一个时间步),而LayerNorm是只取决于当前时间步(或者说当前这个token)。

从这样的视角来看,或者说从网络的实际计算流程来看,对于一批文本输入[K, N, D],实际上可以看作是由N个[K, d]的输入拼接而来的,其中每个[K, D]代表的是一个batch中所有句子在某一位置(或说某个时间步)的token嵌入组成的。在RNN中,[K, D]按时序依次输入网络,在Transformer中则是通过并行计算,但本质上都是通过同一套参数来计算。 因此,在[K, N, D]上进行LN,就是在SN个[K, D]这样的批数据上进行LN,只不过这N个LN共享同一套 gain( γ \gamma γ)和bias( β \beta β)。

在此直接照搬知乎大佬的讲解,通过举例已经非常浅显透彻了。

震惊!BERT用LayerNorm的可能不是你认为的那个Layer Norm?

2. 结合公式举个栗子🌰

2.1 生成demo data

import torch
K, N, D = 2, 3, 4
# 生成demo data
embedding = torch.randn(K, N, D)
Out[1]: 
tensor([[[ 2.3833,  0.1780,  1.0667,  0.2227],
         [ 0.2482, -0.3889,  0.7117,  0.9091],
         [ 0.4513,  1.6905,  0.5648, -1.2175]],
        [[ 0.1469, -0.9727,  2.5195, -1.3820],
         [-0.0406,  0.4197,  1.8440,  1.2459],
         [ 0.0238,  0.4803, -1.0974, -0.3951]]])

2.2 验证LayerNorm

LayerNorm是K*N*D固定了K,在每一个batch会生成K个mean均值 μ 1 , ⋯   , μ K \mu_1,\cdots,\mu_K μ1,,μK,每个batch中的得到标准差 σ 1 , ⋯   , σ K \sigma_1,\cdots, \sigma_K σ1,,σK

所以,对于第k个batch ∈ R N × D \in\mathbb{R}^{N\times D} RN×D而言(对于bert而言是 R N + 2 , D \mathbb{R}^{N+2,D} RN+2,D, 在此不细究讨论):
X n d ( k ) ′ = X n d ( k ) − μ k σ k X_{nd}^{(k)\prime} = \frac{X^{(k)}_{nd}-\mu_k}{\sigma_k} Xnd(k)=σkXnd(k)μk
μ \mu μ σ \sigma σ是每个batch中N*D的均值方差。

# layer_normalization
layer_norm = torch.nn.LayerNorm([N,dim], elementwise_affine = False)
print("layer_norm: ", layer_norm(embedding))
layer_norm:  tensor([[[ 2.0472, -0.4403,  0.5621, -0.3898],
                      [-0.3610, -1.0797,  0.1617,  0.3843],
                      [-0.1320,  1.2658, -0.0039, -2.0143]],
                     [[-0.0760, -1.0675,  2.0254, -1.4301],
                      [-0.2420,  0.1656,  1.4271,  0.8974],
                      [-0.1850,  0.2193, -1.1780, -0.5560]]])

验证第一行元素:

mean = embedding.mean(dim=(1,2))
# tensor([0.5683, 0.2327])
std = embedding.std(dim=(1,2), unbiased=False) #一定要记得unbiased=False
# tensor([0.8866, 1.1291])
# or 用Var的数学期望定义
var = torch.square(embedding-mean).mean(dim=(1,2))
#tensor([0.7860, 1.2748])
(embedding[0][0]-mean[0])/std[0]
# Out[189]: tensor([[ 2.0472, -0.4403,  0.5621, -0.3898]])

验证所有元素

eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

2.3 验证InstanceNorm

LayerNorm是K*N*D固定了K*N,在每一个token会生成K*N个mean均值 μ 11 , ⋯   , μ k n , ⋯   , μ K N \mu_{11},\cdots,\mu_{kn},\cdots,\mu_{KN} μ11,,μkn,,μKN,每个batch中的得到标准差 σ 1 , ⋯   , σ k n , ⋯   , σ K N \sigma_1,\cdots,\sigma_{}kn,\cdots, \sigma_{KN} σ1,,σkn,,σKN

所以,对于第kn个token ∈ R 1 × D \in\mathbb{R}^{1\times D} R1×D而言:
X d ( k n ) ′ = X d ( k n ) − μ k n σ k n X_{d}^{(kn)\prime} = \frac{X^{(kn)}_{d}-\mu_{kn}}{\sigma_{kn}} Xd(kn)=σknXd(kn)μkn
μ \mu μ σ \sigma σ是每个batch中每个seq_len的token中D里面的均值方差。

# instance_normalization,可以看到二者其实都是通过nn.LayerNorm实现的
instance_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
(embedding)
print("instance_norm: ", instance_norm(embedding))
instance_norm:  tensor([[[ 1.5902, -0.8784,  0.1164, -0.8283],
                         [-0.2438, -1.5193,  0.6840,  1.0791],
                         [ 0.0761,  1.2702,  0.1855, -1.5318]],
                        [[ 0.0454, -0.6927,  1.6098, -0.9626],
                         [-1.2464, -0.6145,  1.3411,  0.5199],
                         [ 0.4668,  1.2533, -1.4650, -0.2550]]])

验证第一行元素:

(embedding[0][0]-embedding[0][0].mean())/embedding[0][0].std(unbiased=False)
# tensor([ 1.5902, -0.8784,  0.1164, -0.8283])

验证所有元素

eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

Reference

震惊!BERT用LayerNorm的可能不是你认为的那个Layer Norm?

关于BatchNorm与LayerNorm的一点认识

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/179602.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AAAI2023】Head-Free Lightweight Semantic Segmentation with Linear Transformer

论文:【AAAI2023】Head-Free Lightweight Semantic Segmentation with Linear Transformer 代码:https://github.com/dongbo811/AFFormer 这是来自阿里巴巴的工作,作者构建了一个轻量级的Transformer网络用于语义分割,主要有两点…

发现下属的学历造假,但是他的工作能力又很强,该开除他吗?

在职场上混,学历是敲门砖还是定音锤呢?一位网友问:发现下属的学历造假,但是他的工作能力又很强,该开除他吗?有人觉得一定要开除,这就是钻空子,受影响最大的人不是他,而是那些真才实…

上采样与下采样

数据分析中的上采样和下采样 背景: 在分类问题中,由于各种原因,我们所获取到的数据集很容易出现正负样本的不平衡,或者某些数据特别多,有些数据则特别少,在这样的数据集中,进行训练&#xff0c…

OpenCV直方图Java 演示程序

直方图Java 演示程序以下文件编码为utf-8 为佳。代码文件名:OpenCvMain.javapackage org.opencv;import java.net.URL;import java.util.LinkedList;import java.util.List;import org.opencv.core.Core;import org.opencv.core.CvType;import org.opencv.core.Mat;…

Linux常用命令——setpci命令

在线Linux命令查询工具(http://www.lzltool.com/LinuxCommand) setpci 查询和配置PCI设备的使用工具 补充说明 setpci命令是一个查询和配置PCI设备的使用工具。 语法 setpci(选项)(参数)选项 -v:显示指令执行的细节信息; -f:当没有任何…

Opencv形态学操作——腐蚀、膨胀、梯度、开运算、闭运算、礼帽、黑帽(附案例详细讲解及可执行代码)

Opencv形态学操作 腐蚀膨胀梯度开运算闭运算礼帽黑帽总结腐蚀 在地理或者化学中,我们学习过腐蚀,是指在某种 作用下产生损耗与破坏的过程。你也可以理解为减肥。在Opencv中,腐蚀操作可以使白色轮廓变小,也就是说可以去除一些白色的噪声。 如果你接触过卷积核的话,腐蚀就更…

【JavaSE专栏8】运算符、表达式和语句

作者主页:Designer 小郑 作者简介:Java全栈软件工程师一枚,来自浙江宁波,负责开发管理公司OA项目,专注软件前后端开发(Vue、SpringBoot和微信小程序)、系统定制、远程技术指导。CSDN学院、蓝桥云…

盖子的c++小课堂——第十三讲:二维数组

前言 过了几天了,终于有时间更新了,有个通知,以后我不用颜色区分了,不然换了背景看不见,理解一下,蟹蟹~~ 举例 作者:一下是某次奥运会的奖牌榜,你知道如何储存奖牌榜吗~~ 粉丝&am…

机器学习中软投票和硬投票的不同含义和理解

设置一个场景,比如对于今天音乐会韩红会出现的概率三个人三个观点 A:韩红出现的概率为47% B:韩红出现的概率为57% C:韩红出现的概率为97% 软投票:软投票会认为韩红出现的概率为1/3*(47%57%97%)67% 硬投票:…

“子序列问题”系列总结,一文读懂(Java实现)

目录 前言 一、最长递增子序列 1.1、dp定义 1.2、递推公式 1.3、初始化 1.4、注意 1.5、解题代码 二、最长连续递增序列 2.1、分析 2.2、解题代码 三、最长重复子数组 3.1、dp定义 3.2、递推公式 3.3、初始化 3.4、解题代码 四、最长公共子序列 4.1、分析 4.2…

Opencv项目实战:20 单手识别数字0到5

目录 0、项目介绍 1、效果展示 2、项目搭建 3、项目代码展示 HandTrackingModule.py Figures_counter.py 4、项目资源 5、项目总结 0、项目介绍 今天要做的是单手识别数字0到5,通过在窗口展示,实时的展示相应的图片以及文字。 在网上找了很久的…

硬核来袭!!!一篇文章教你入门Python爬虫网页解析神器——BeautifulSoup详细讲解

文章目录一、BeautifulSoup介绍二、安装三、bs4数据解析的原理四、bs4 常用的方法和属性1、BeautifulSoup构建1.1 通过字符串构建1.2 从文件加载2、BeautifulSoup四种对象2.1 Tag对象2.2 NavigableString对象2.3 BeautifulSoup对象2.4 Comment对象五、contents、children与desc…

springboot自定义拦截器的简单使用和一个小例子

springboot自定义拦截器的使用1. 自定义拦截器2. 拦截器登录验证的小demo2.1 配置pom.xml2.2 创建User的bean组件2.3 创建需要的表单页面以及登录成功的页面2.4 编写controller映射关系2.5 自定义拦截器类,实现intercepetor接口2.6注册添加拦截器,自定义…

【SpringCloud】Nacos集群搭建

集群结构图官方给出的Nacos集群图如下:其中包含3个nacos节点,然后一个负载均衡器代理3个Nacos。这里负载均衡器可以使用nginx。我们接下来要尝试 Nacos集群搭建,效果图如下所示:三个nacos节点的地址:节点ipportnacos1l…

二、Java框架之Spring注解开发

文章目录1. IOC/DI注解开发1.1 Component注解ComponentController Service Repository1.2 纯注解开发模式1.3 注解开发bean管理ScopePostConstruct PreDestroy1.4 注解开发依赖注入Autowired QualifierValuePropertySource1.5 第三方bean管理Beanimport(多个Config类…

Redisson 完成分布式锁

1、简介 Redisson 是架设在 Redis 基础上的一个 Java 驻内存数据网格(In-Memory Data Grid)。充分 的利用了 Redis 键值数据库提供的一系列优势,基于 Java 实用工具包中常用接口,为使用者 提供了一系列具有分布式特性的常用工具类…

JavaWeb | 揭开SQL注入问题的神秘面纱

本专栏主要是记录学习完JavaSE后学习JavaWeb部分的一些知识点总结以及遇到的一些问题等,如果刚开始学习Java的小伙伴可以点击下方连接查看专栏 本专栏地址:🔥JDBC Java入门篇: 🔥Java基础学习篇 Java进阶学习篇&#x…

MyEclipse提示过期,MyEclipse Subscription Expired解决方案

一、错误描述 某一天打开MyEclipse,突然发现出现如下提示框: 1.错误日志 Thank you for choosing MyEclipse Your license expired 1091 days ago. To continue use of MyEclipse please choose "Buy" to purchase a MyEclipse license. I…

离散系统的数字PID控制仿真-3

离散PID控制的封装界面如图1所示,在该界面中可设定PID的三个系数、采样时间及控制输入的上下界。仿真结果如图2所示。图1 离散PID控制的封装界面图2 阶跃响应结果仿真图:离散PID控制的比例、积分和微分三项分别由Simulink模块实现。离散PID控制器仿真图&…

【servlet篇】servlet相关类介绍

目录 servlet对象什么时候被创建? 2.servlet接口中各个方法的作用 3.相关类和接口介绍 GenericServlet ServletConfig ServletContext HttpServlet servlet对象什么时候被创建? 1,通常情况下,tomcat启动时,并没有…