【Transformer】Post-Norm和Pre-Norm

news2024/9/20 14:28:38

文章目录

  • Transformer中使用的Post-Norm
  • 大模型常用的Pre-Norm
  • Pre-Norm一定比Post-Norm好吗?
  • 二者区别总结
  • 参考资料

Pre-Norm和Post-Norm的区别,是面试官非常喜欢问的问题。下面我们按照时间线,尽可能直白地讲清楚二者的区别。

在这里插入图片描述
直观来讲,Pre-Norm 和 Post-Norm 的区别就是 Layer Norm 和 Residual Connections 组合方式的不同。

  • Post-Norm:传统的Layer Norm放在残差之后,做完Add再进行归一化。早期的很多模型都用的是 Post-Norm,比如著名的 Bert
  • Pre-Norm:目前大模型大多数的做法是 先对输入做Layer Norm,然后再进行函数计算(例如attention和FFN)以及Add相加。Pre Norm 的训练更快,且更加稳定,所以之后的大模型架构大多都是 Pre Norm 了,比如 GPT,MPT,Falcon,和Llama。

Transformer中使用的Post-Norm

在原始Transformer中,对数据进行Normalization的方法是Post-Layer Norm(想要对Normalization进行深入理解的小伙伴可以看这篇博客:【Transformer】Normalization),如下图所示:

在这里插入图片描述

Post-Norm用公式表示为:

y = N o r m ( x l + a t t n ( x l ) ) x l + 1 = N o r m ( y + F F N ( y ) ) \begin{aligned} y&=Norm(x_l + attn(x_l)) \\ x_{l+1}&=Norm(y + FFN(y)) \end{aligned} yxl+1=Norm(xl+attn(xl))=Norm(y+FFN(y))

Post-Norm 之所以这么设计,是把 Normalization 放在一个模块的最后,这样下一个模块接收到的总是归一化后的结果。这比较符合 Normalization 的初衷,就是为了降低梯度的方差。但是层层堆叠起来,从上图可以看出,深度学习的基建 ResNet 的结构其实被破坏了,也就是Residual Connections在梯度反向传播时消失了。这就导致训练 Transformer 并不是那么容易的事情,需要加上各种补偿措施,例如 learning rate warm up, 初始化等。

那么为什么Post-Norm会导致Residual Connections在梯度反向传播时消失呢?

这里我们假设输入 x l x_l xl A t t e n t i o n ( x l ) Attention(x_l) Attention(xl) F N N ( x l ) FNN(x_l) FNN(xl) 的均值为0,方差都为1,且相互独立。事实上可能没有那么理想,因为权重矩阵的分布在学习过程中并不一定能保持理想的分布,这里为了说明问题对建模进行了简化。

我们知道,对于两个均值为0,方差为1且相互独立的分布 x 1 , x 2 ∈ N ( 0 , 1 ) x_1, x_2 \in N(0, 1) x1,x2N(0,1),那么 ( x 1 + x 2 ) (x_1+x_2) (x1+x2)就是均值为0,方差为2的分布,那么 Layer Norm对 ( x 1 + x 2 ) (x_1+x_2) (x1+x2)的计算就可以表示为 y = x 1 + x 2 2 y=\frac{x_1+x_2}{\sqrt{2}} y=2 x1+x2

所以,Post-Norm的计算公式可以简化为:
y = N o r m ( x l + a t t n ( x l ) ) = x l + a t t n ( x l ) 2 x l + 1 = N o r m ( y + F F N ( y ) ) = y + F F N ( y ) 2 = x l + a t t n ( x l ) 2 + F F N ( y ) 2 = x l + a t t n ( x l ) + 2 F F N ( y ) 2 = x l + f ( x l ) 2 \begin{aligned} y&=Norm(x_l+attn(x_l)) \\ &=\frac{x_l+attn(x_l)}{\sqrt{2}} \\ x_{l+1}&=Norm(y + FFN(y)) \\ &=\frac{y+FFN(y)}{\sqrt{2}} \\ &=\frac{\frac{x_l+attn(x_l)}{\sqrt{2}}+FFN(y)}{\sqrt{2}} \\ &=\frac{x_l+attn(x_l)+\sqrt{2}FFN(y)}{2} \\ &=\frac{x_l+f(x_l)}{2} \end{aligned} yxl+1=Norm(xl+attn(xl))=2 xl+attn(xl)=Norm(y+FFN(y))=2 y+FFN(y)=2 2 xl+attn(xl)+FFN(y)=2xl+attn(xl)+2 FFN(y)=2xl+f(xl)

这里我们用 f ( x l ) f(x_l) f(xl)来表示对输入 x l x_l xl的复杂计算(包括attention和FFN)。这里可以看出,输入 x l x_l xl每经过一层,输出就变成 x l + f ( x l ) 2 \frac{x_l+f(x_l)}{2} 2xl+f(xl)。那么最终的输出对于最开始的输入来说:

o u t p u t = x l − 1 + f l ( x l − 1 ) 2 = x l − 1 2 + f l ( x l − 1 ) 2 = x l − 2 2 2 + f l − 1 ( x l − 2 ) 2 2 + f l ( x l − 1 ) 2 = x l − 3 2 3 + f l − 2 ( x l − 3 ) 2 3 + f l − 1 ( x l − 2 ) 2 2 + f l ( x l − 1 ) 2 = x 1 2 l − 1 + f 2 ( x 1 ) 2 l − 1 + . . . + f l − 1 ( x l − 2 ) 2 2 + f l ( x l − 1 ) 2 = x 1 2 l − 1 + g ( x ) \begin{aligned} output&=\frac{x_{l-1}+f_{l}(x_{l-1})}{2}\\ &=\frac{x_{l-1}}{2}+\frac{f_{l}(x_{l-1})}{2} \\ &=\frac{x_{l-2}}{2^2}+\frac{f_{l-1}(x_{l-2})}{2^2}+\frac{f_l(x_{l-1})}{2} \\ &=\frac{x_{l-3}}{2^3}+\frac{f_{l-2}(x_{l-3})}{2^3}+\frac{f_{l-1}(x_{l-2})}{2^2}+\frac{f_l(x_{l-1})}{2} \\ &=\frac{x_{1}}{2^{l-1}}+\frac{f_{2}(x_{1})}{2^{l-1}}+...+\frac{f_{l-1}(x_{l-2})}{2^2}+\frac{f_l(x_{l-1})}{2} \\ &=\frac{x_{1}}{2^{l-1}}+g(\bold{x}) \end{aligned} output=2xl1+fl(xl1)=2xl1+2fl(xl1)=22xl2+22fl1(xl2)+2fl(xl1)=23xl3+23fl2(xl3)+22fl1(xl2)+2fl(xl1)=2l1x1+2l1f2(x1)+...+22fl1(xl2)+2fl(xl1)=2l1x1+g(x)

这里我们用 g ( x ) g(\bold{x}) g(x)表示全部网络层对输入 x 1 x_1 x1的作用结果。然后我们对最终的输出求导可得:
∂ ( x 2 l − 1 + g ( x ) ) ∂ x = 1 2 l − 1 + ∂ g ( x ) ∂ x \frac{\partial\left(\frac{x}{2^{l-1}}+g(x)\right)}{\partial x}=\frac{1}{2^{l-1}}+\frac{\partial g(x)}{\partial x} x(2l1x+g(x))=2l11+xg(x)

看到了吗!!Resnet的Residual Connections没了,因为一般ResNet求导结果应该是: ∂ ( f ( x ) + x ) ∂ x = 1 + ∂ f ( x ) ∂ x \frac{\partial(f(x)+x)}{\partial x}=1+\frac{\partial f(x)}{\partial x} x(f(x)+x)=1+xf(x),这里的1起到了防止梯度消失的作用,因此可以稳定训练更深的网路模型。但是经过我们推导发现,Post-Norm求导结果中的第一项,会随着网络层数的增加而指数递减。层数较低还好,如果像是现在的大模型一样堆叠32甚至64层,那几乎和0没什么区别了,也就丧失了 ResNet 的意义。

没了 ResNet 的架构,就导致在训练 Transforemr 的时候,需要小心翼翼。一般都要加一个 learning rate warm up 的过程,先让模型在小学习率上适应一段时间,然后再正常训练。warm up 的过程虽然在 Transformers 的论文里就提了一嘴,但是真正训练的时候会发现真的很重要。

大模型常用的Pre-Norm

发表在ACL 2019上的Learning Deep Transformer Models for Machine Translation 这篇文章首次提出了Layer Normalization位置对训练一个深层的Transformer模型至关重要,并且也开启了后续大家对Layer Normalization的探索。

同样的方式,让我们来看看Pre-Norm是什么计算的?下图左侧就是我们刚推导过的,Transformer原始论文中提到的Post-Norm,而右侧,则是现在大模型常用的Pre-Norm。
在这里插入图片描述

Pre-Norm用公式表示为:

y = x l + a t t n ( N o r m ( x l ) ) x l + 1 = y + F F N ( N o r m ( y ) ) \begin{aligned} y&=x_l + attn(Norm(x_l)) \\ x_{l+1}&=y + FFN(Norm(y)) \end{aligned} yxl+1=xl+attn(Norm(xl))=y+FFN(Norm(y))

很明显,Pre-Norm很好的保留了ResNet的核心Residual Connections,反向传播计算梯度时很好的缓解了梯度消失的问题。

因此,到这里可以总结出Pre-Norm相比于Post-Norm是有优势的,也就是:

  • Post Norm,对模型,尤其是较深的模型训练不稳定,梯度容易爆炸,学习率敏感,初始化权重敏感,收敛困难。因此需要做大量调参工作,以及learning rate warm up的必要工作,费时费力。
  • Pre Norm 则在训练稳定和收敛性方面有明显的优势,所以大模型时代基本都无脑使用 Pre Norm 了。

Pre-Norm一定比Post-Norm好吗?

但是 Pre Norm 也并不是都是好的,2020年的Understanding the Difficulty of Training Transformers这篇论文指出,Pre Norm 有潜在的 Representation Collapse (表示塌陷)问题,具体来说就是靠近输出位置的层会变得非常相似,从而对模型的贡献会变小。

因此,2023年微软提出的ResiDual: Transformer with Dual Residual Connections,就试图融合 Pre Norm 和 Post Norm 的优点。

在这里插入图片描述

在这里插入图片描述
这也就暗示着 Post Norm 虽然不好训练,但是潜力可能比 Pre Norm 更好。同时这篇论文中提到的在 Layer Norm 的时候,调整 x x x f ( x ) f(x) f(x) 的比重,其思路被 DeepNorm 借鉴。只不过这里是可学习的权重,而 DeepNorm 则是超参数。

二者区别总结

  • Post Norm

    • 对模型,尤其是较深的模型训练不稳定,梯度容易爆炸,学习率敏感,初始化权重敏感,收敛困难。因此需要做大量调参工作,以及learning rate warm up的必要工作,费时费力
    • 潜在好处是,在效果上的优势,但是这个事情还需要大量专业的实验来验证,毕竟现在大模型训练太费钱了,Post Norm 在效果上带来的提升很可能不如多扔点数据让 Pre Norm 更快的训练出来
  • Pre Norm

    • 在训练稳定和收敛性方面有明显的优势,所以大模型时代基本都无脑使用 Pre Norm 了
    • 但是其可能有潜在的Representation Collapse(表示塌陷) 问题,也就是上限可能不如 Post Norm

参考资料

  • [1] https://note.mowen.cn/note/detail?noteUuid=xFoeBN-Ez4OcjRDEs1b51
  • [2] https://zhuanlan.zhihu.com/p/474988236

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2121048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL】MySQL表的增删改查(进阶篇)——之约束条件

前言: 🌟🌟本期讲解关于MySQL表增删查改进阶篇,希望能帮到屏幕前的你。 🌈上期博客在这里:http://t.csdnimg.cn/cF0Mf 🌈感兴趣的小伙伴看一看小编主页:GGBondlctrl-CSDN博客 目录 …

什么是监督学习(Supervised Learning)

一、监督学习概述 监督学习(Supervised Learning)是一种极具威力的机器学习方法,能够训练算法以识别数据中的模式,并据此进行精准的预测或分类。借助已有的标记数据,监督学习模型学会了从输入到输出的映射关系&#x…

导弹制导方式简介-其实跟卫星定位系统关系不大

导弹制导方式其实跟卫星定位系统关系不大,所以所谓关闭卫星定位系统导弹就不能打是谣言! 导弹制导是指利用不同的方式,选择飞行路线,将具有动力飞行的弹头移动一段距离之后,击中预先设定的目标。导弹制导系统利用其中…

【网络】十大网络协议

文章目录 1. HTTP(HyperText Transfer Protocol,超文本传输协议)2. HTTPS(Secure Hypertext Transfer Protocol,安全超文本传输协议)3. HTTP/34. TCP(Transmission Control Protocol&#xff0c…

树莓派5-番外篇-GPU相关-学习记录2

树莓派5-番外篇-GPU相关 要查看你的树莓派5是否支持GPU计算,以及如何启用和使用它,你需要了解树莓派5的硬件配置和当前的驱动支持情况。以下是查看树莓派5的GPU支持情况的步骤。 树莓派5 GPU 支持概述 树莓派5 使用的是 Broadcom BCM2712 处理器&…

大数据新视界 --大数据大厂之Hive与大数据融合:构建强大数据仓库实战指南

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

Ps初步使用教程

1.画面快捷键 Ctrl鼠标滚轮:画面左右移动 Shift鼠标滚轮:画面上下快速移动(不加Shift正常速度移动) Alt鼠标滚轮:画面大小缩放 2.工具快捷键 ShiftG:渐变、油漆桶、3D材质施放 切换 CtrlO&#xff1a…

【Unity3D】如何用MMD4Mecanim插件将pmx格式模型转换为fbx格式模型

文章目录 概要一、下载MMD4Mecanim插件并导入U3D1.1 下载链接1.2 导入过程 二、将.pmx模型转换为.fbx模型三、其他参数设置3.1 VMD参数3.2 Animations 概要 在Unity的环境下,想要将.pmx格式的3D模型转换为.fbx是有可以用的插件的,并不需要某些教程中那么…

Vue3+TypeScript二次封装axios

安装如下 npm install axios 第一步:创建config配置文件,用于存放请求后端的ip地址,用于后期打包后便于修改ip地址。 注:typescript要求参数要有类型。(ES6 定义对象 属性 类型 修改的是属性的值) inte…

超级干货|AI产品经理6大知识体系,【附零基础小白入门指南】

想要转行AI产品经理的宝子,这6大知识体系是你入门的基础 💥基础知识:AI产品的根基 💥平台和硬件支持:AI产品的技术基础设施 💥AI核心技术:推动产品创新的引擎 💥行业实践应用&#…

Python计算机视觉 第8章-图像内容分类

Python计算机视觉 第8章-图像内容分类 8.1 K邻近分类法(KNN) 在分类方法中,最简单且用得最多的一种方法之一就是 KNN(K-Nearest Neighbor ,K邻近分类法),这种算法把要分类的对象(例如一个特征…

知网合作商AEPH出版,学生/教师均可投稿,优先录用教育社科领域,往期最快2周见刊

AEPH出版社旗下有5本学术期刊,专门出版自然科学、社会科学研究与教育领域论文的高影响力期刊,拥有正规ISSN号,出版类型涉及应用和理论方面的原创和未曾公开发表的研究论文,分配独立DOI号。AEPH作为中国知网(CNKI&#…

COCOS:(飞机大战01)背景图无线循环向下滚动

飞机大战知识点总结 背景图宽高:480*852 将背景图移动到Canvas中 设置图2的Y轴为852,这样图1和图2就衔接上了 创建控制背景的ts文件 import { _decorator, Component, Node } from cc; const { ccclass, property } _decorator;ccclass(Bg) export cla…

HTTPS证书申请

🌐 JoySSL CA机构 机构介绍:JoySSL是网盾安全基于全球可信顶级根创新推出的新一代https数字证书,也是目前为数不多的中国自主品牌SSL证书。 服务特点:JoySSL携手全球权威CA机构,全球多节点服务器验证签发,安…

MySQL系列—8.存储结构

目录 1.系统表空间 ibdata 2.通用表空间 .ibd 3.独立表空间 4.Undo 表空间 5.临时表空间 6.Redo Log File 1.系统表空间 ibdata 系统表空间由参数innodb_data_file_path定义路径、初始化大小、自动扩展策略 如: innodb_data_file_path/dayta/mysql/ibdata1:…

【机器学习】C++与OpenCV实战:创建你的第一个图片显示程序

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 前言 在计算机视觉领域,OpenCV是一个非常强大的开源库,它提供了大量的图像处理和计算机视觉功能。C与Op…

这次我可真没手写代码

我是一个平平无奇的程序员,兢兢业业地做着公司的需求,直到那天,一位十年见过两面的亲戚突然找上门,他说: 小Z啊,听说你是学计算机的对吧。 听完这话,我当场汗流浃背,这不是让我上门修…

Web组件:Servlet Listener Filter

1 前言 1.1 内容概要 掌握ServletContextListener的使用,并且理解其执行时机掌握Filter的使用,并且理解其执行时机能够使用Filter解决一些实际的问题 1.2 前置知识准备 Servlet的执行 ServletContext的功能和使用 2 Web组件 JavaEE的三大Web组件 …

Aigtek功率放大器在超声检测陶瓷复合材料内部缺陷中的应用

2023年5月30日,神舟十六号载人飞船于9时28分左右在酒泉基地正式发射。本次神舟十六号最引人注目的一点就是它的元件国产率相较之前是大大提高了。选择提高自主研发能力,一方面是防范他国在技术、贸易上的“卡脖子”隐患,一方面也是我国制造实…

【智慧物流】新中地智慧城市实训:优秀学生项目作品1

实训结束后,同学们在最后的答辩中纷纷展现了自己的优秀成果,并以小组的形势进行汇报。今天截取部分学生优秀作品给大家进行展示,帮助大家快速了解智慧系列项目的效果。 智慧城市开发项目主题:智慧物流 (为保护学生隐…