​语言模型输出端共享Embedding的重新探索

news2025/1/22 12:41:07

4a954f52b86a3e09f8f5f905d4c5cf46.gif

©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 科学空间

研究方向 | NLP、神经网络

预训练刚兴起时,在语言模型的输出端重用 Embedding 权重是很常见的操作,比如 BERT、第一版的 T5、早期的 GPT,都使用了这个操作,这是因为当模型主干部分不大且词表很大时,Embedding 层的参数量很可观,如果输出端再新增一个独立的同样大小的权重矩阵的话,会导致显存消耗的激增。

不过随着模型参数规模的增大,Embedding 层的占比相对变小了,加之《Rethinking embedding coupling in pre-trained language models》[1] 等研究表明共享 Embedding 可能会有些负面影响,所以现在共享 Embedding 的做法已经越来越少了。

本文旨在分析在共享 Embedding 权重时可能遇到的问题,并探索如何更有效地进行初始化和参数化。尽管共享 Embedding 看起来已经“过时”,但这依然不失为一道有趣的研究题目。

b3b06267bf5e57697b62f28e0223cf1e.png

共享权重

在语言模型的输出端重用 Embedding 权重的做法,英文称之为 “Tied Embeddings” 或者 “Coupled Embeddings”,其思想主要是 Embedding 矩阵跟输出端转换到 logits 的投影矩阵大小是相同的(只差个转置),并且由于这个参数矩阵比较大,所以为了避免不必要的浪费,干脆共用同一个权重,如下图所示:

15a85466f5e92e00d17a8319e954beb0.png

▲ 共享 Embedding 权重的 Transformer 示意图

共享 Embedding 最直接的后果可能是——它会导致预训练的初始损失非常大。这是因为我们通常会使用类似 DeepNorm 的技术来降低训练难度,它们都是将模型的残差分支初始化得接近于零。换言之,模型在初始阶段近似于一个恒等函数,这使得初始模型相当于共享 Embedding 的 2-gram 模型。接下来我们将推导这样的 2-gram 模型损失大的原因,以及分析一些解决方案。

934ad426010e7add4074e16028f1e316.png

准备工作

在正式开始推导之前,我们需要准备一些基础结论。

首先,要明确的是,我们主要对初始阶段的结果进行分析,此时的权重都是从某个“均值为 0、方差为 ”的分布中独立同分布地采样出来的,这允许我们通过期望来估计某些求和结果。比如对于 ,我们有

ba84a2c0ed8d25f4e8319c0659da07d2.png

因此可以取 。那么误差有多大呢?我们可以通过它的方差来感知。为此,我们先求它的二阶矩:

9e26c3194b85158fd57b954b8b6fef9f.png

如果采样分布是正态分布,那么可以直接算出 ,所以

2313daa5f43305e603b43716d7c278da.png

这个方差大小也代表着 的近似程度,也就是说原本的采样方差 越小,那么近似程度越高。特别地,常见的采样方差是 (对应 ,即单位向量),那么代入上式得到 ,意味着维度越高近似程度越高。此外,如果采样分布不是正态分布,可以另外重新计算 ,或者直接将正态分布的结果作为参考结果,反正都只是一个估算罢了。

如果 是另一个独立同分布向量,那么我们可以用同样的方法估计内积,结果是

ae3e001a2bd834c31258f1b619b0ba08.png

以及

a0d350c232c64d93bf09b7e89b111c6b.png

同样地,取 的话,那么方差是 ,维度越高近似程度越高。以上两个结果可以说是《n维空间下两个随机向量的夹角分布》[2]、《让人惊叹的Johnson-Lindenstrauss引理:理论篇》中的结论的统计版本。

e48b8ce51a3d300a53370d2ca2220642.png

损失分析

对语言模型来说,最终要输出一个逐 token 的 元分布,这里 是词表大小。假设我们直接输出均匀分布,也就是每个 token 的概率都是 ,那么不难计算交叉熵损失将会是 。这也就意味着,合理的初始化不应该使得初始损失明显超过 ,因为   代表了最朴素的均匀分布,明显超过 等价于说远远不如均匀分布,就好比是故意犯错,并不合理。

那么,为什么共享 Embedding 会出现这种情况呢?假设初始 Embedding 是 ,前面已经说了,初始阶段残差分支接近于零,所以输入输入 token ,模型输出就是经过 Normalization 之后的 Embedding 。常见的 Normalization 就是 Layer Norm 或者 RMS Norm,由于初始化分布是零均值的,所以 Layer Norm 跟 RMS Norm 大致等价,因此输出是

27800d763c53eb30480d3afbcf4525af.png

接下来重用 Embedding,内积然后 Softmax,所建立的分布实质是

575ec35920b66ff82218347280238cdc.png

对应的损失函数就是

24df994c2e419d33039a77476d4d0137.png

语言模型任务是为了预测下一个 token,而我们知道自然句子中叠词的比例很小,所以基本上可以认为 ,那么根据结果 (4) 就有 。所以,初始损失函数是

061152ad90e96a72296364f7e8ff027f.png

后面的 再次用到了式(1)和式(4)。常见的初始化方差 ,或者是一个常数,或者是 (此时 ),不管是哪一种,当 较大时,都导致 占主导,于是损失将会是 级别,这很容易就超过了均匀分布的 。

4a299245d6bfa3e95860616d155e5d2f.png

一些对策

根据上述推导结果,我们就可以针对性地设计一些对策了。比较直接的方案是调整初始化,根据式(9),我们只需要让 ,那么初始损失就是变成 级别的,也就是说初始化的标准差要改为 。

一般来说,我们会希望参数的初始化方差尽量大一些,这样梯度相对来说没那么容易下溢,而 有时候会显得过小了。为此,我们可以换一种思路:很明显,式(9)之所以会偏大,是因为出现了 ,由于两个 相同,它们内积变成了模长,从而变得很大,如果能让它们不同,那么就不会出现这一个占主导的项了。

为此,最简单的方法自然是干脆不共享 Embedding,此时是 而不是 ,用(4)而不是(1)作为近似,于是式(9)渐近于 。如果还想保留共享 Embedding,我们可以在最后的 Normalization 之后,再接一个正交初始化的投影层,这样 变成了 ,根据 Johnson-Lindenstrauss 引理,经过随机投影的向量近似于独立向量了,所以也近似于不共享的情况,这其实就是 BERT 的解决办法。特别地,这个投影层还可以一般化地加上 bias 和激活函数。

如果一丁点额外参数都不想引入,那么可以考虑在 Normalization 之后“打乱” 的各个维度,

2d6791965c5cc0bb4d0b88fb82090116.png

这里的 是拼接操作,那么 和 也接近正交了,内积自然也约等于0。这相当于(在初始阶段)将原来的 的 Embedding 矩阵劈开为两个 的矩阵然后构建不共享 Embedding 的 2-gram 模型。另外,我们还可以考虑其他打乱操作,比如 ShuffleNet [3] 中的先 reshape,然后 transpose 再 reshape 回来。

在笔者的实验中,直接改初始化标准差为 收敛速度是最慢的,其余方法收敛速度差不多,至于最终效果,所有方法似乎都差不多。

c68605c974b999e3425fbf6384a2ab07.png

文章小结

本文重温了语言模型输出端共享 Embedding 权重的操作,推导了直接重用 Embedding 来投影输出可能会导致损失过大的可能性,并探讨了一些解决办法。

outside_default.png

参考文献

outside_default.png

[1] https://arxiv.org/abs/2010.12821

[2] https://kexue.fm/archives/7076

[3] https://arxiv.org/abs/1707.01083

更多阅读

647cc6838c765afa03c8ddad58684cb4.png

77db9cc70ac308bf4997da12f772256c.png

461a89d74fdca9ae034cf7f01c54a565.png

4a4a2822b9c8e0197a93a34d92e1310a.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

d02df83dcc0f16fad14592fac40df6a9.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

30f544bfa33fd445a5c9b07abd36603b.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/796643.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每天一个电商API分享:获取淘宝商品sku接口

SKU通俗来讲就是一个产品最小的出库单位,比如说一款手机产品有红黑白三个颜色,那么一台红色手机就是一个sku。那么多销售属性的产品,再往下分,比如说一件T恤除了有颜色还有尺码,那么一件s码绿色体恤就是单个sku。 sku指…

数据学习教程:Linux基础教程(下)

本文继上一篇《Linux基础教程(上)》的下一篇,欢迎收藏。 4 Linux常用基础命令 Linux刚面世时并没有图形界面, 所有的操作全靠命令完成, 如磁盘操作、文件存取、目录操作、进程管理、文件权限设定等 在职场中,大量的服务器维护工作…

C盘满了怎么清理?最全c盘清理攻略!

“c盘怎么会那么容易满啊?而且每次清理好像也释放不了多少空间。谁懂啊?一天要清理好几次c盘!真的很麻烦。有谁能告诉我应该怎么做吗!” 电脑c盘对我们来说真的是很重要。当我们在电脑上安装软件、存储文件或者浏览网站&#xff0…

提升打印品质:解决Excel表格乱套问题的实用技巧

作为办公人员,我们经常需要打印大量的数据表格。然而,打印表格并不是一件简单的事情,如果不注意,打印效果可能会变得混乱不堪。那么该怎么办呢?在这里,我将为大家分享9个关于Excel表格打印的技巧&#xff0…

django自定义app,创建子应用

1.工程里创建apps包 ; 2.创建子应用,pycharm terminal 运行:python ./nanage.py startapp app名称; 3.子应用移动到apps包里; 4.settings.py里设置INSTALLED_APPS如“apps.users”,该名字跟子应用apps.py文…

KBYCMS框架后台使用帮助介绍

后台入口文件 后台入口文件默认是public目录下的admin.php。访问后台时加上admin.php访问,您可根据需要,重命名后台入口文件。 重命名后需要在config/app.php文件中修改配置,配置如下,如果没有以下配置那么该版本无需理会。 // 入口文件绑定,无需写index app_file …

Fastjson远程命令执行漏洞总结

## 1.FastJson 简介 ##### fastjson.jar包原始下载地址:https://github.com/alibaba/fastjson ##### fastjson用于将Java Bean序列化为JSON字符串,也可以从JSON字符串反序列化到JavaBea... 1.FastJson 简介 fastjson.jar包原始下载地址:Git…

设计模式-模版方法模式

生活中处处存在模版,模版定义了大的框架,具体内容由使用者填充即可,这给很多人的生活、工作带来了很大的遍历。比如: PPT模版:好的PPT模版提供了更全面的叙述框架,更优美的UI画面&图标,提升…

算法训练营第五十一天||309.最佳买卖股票时机含冷冻期 ● 714.买卖股票的最佳时机含手续费 ●总结

309.最佳买卖股票时机含冷冻期 这道题主要就是搞懂dp数组含义以及状态之间的转换&#xff0c;没看答案能自己做出来 class Solution { public:int maxProfit(vector<int>& prices) {vector<vector<int>> dp(prices.size(),vector<int>(5,0));//前…

jMeter使用随记

参数化BodyData 先制作参数文件 再设置一个csv data set config 最后在body data里面写上参数${xxxxx}

【外卖系统】更新员工信息

需求分析 员工管理列表界面&#xff0c;需要对某个员工的账号进行启用和禁用操作。账号禁用的员工不能登录系统&#xff0c;启用后的员工可以正常登录。只有admin可以对其他普通用户进行启用、禁用的操作&#xff0c;普通用户登录系统后启动、禁用按钮都是不显示的编辑员工信息…

家庭有必要买洗地机吗、洗地机排行榜推荐

洗地机相信大家都认识吧&#xff0c;在清洁家电领域这可谓是个“名人”。在清洁工具的名单中&#xff0c;要说一机多用&#xff0c;使用体验好的&#xff0c;洗地机绝对名列前茅。和传统清洁工具相比&#xff0c;洗地机可以很快速的就清洁干净地面&#xff0c;十多分钟就能还你…

数据库—用户权限管理(三十三)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、概述 二、用户权限类型 ​三、用户赋权 四、权限删除 五、用户删除 前言 数据库用户权限管理是指对数据库用户的权限进行控制和管理&#xff0c;确保用户只能执…

探究Spring Bean的六种作用域:了解适用场景和使用方式

这里写目录标题 单例&#xff08;Singleton&#xff09;作用域&#xff1a;原型&#xff08;Prototype&#xff09;作用域&#xff1a;请求&#xff08;Request&#xff09;作用域&#xff1a;会话&#xff08;Session&#xff09;作用域&#xff1a;全局&#xff08;applicati…

【一文搞懂】—带霍尔编码器的直流有刷减速电机

文章目录 一、直流有刷电机二、减速比三、霍尔编码器3.1 霍尔编码器3.2 霍尔编码器测速原理 四、测速程序设计4.1 跳变沿检测4.2 计算转速 一、直流有刷电机 宏观上说直流有刷电机由固定部分&#xff08;定子&#xff09;和旋转部分&#xff08;转子&#xff09;组成。在定子上…

Web Worker的概念、用法、使用场景

​ 目录 1. 简介 2. 适用场景 2.1 复杂计算 2.2 后台下载 2.3 数据处理 2.4 实时通信 3. 代码示例 3.1 Worker特性检测 3.2 Worker API 3.3 SharedWorker API 3.4 创建 JavaScript 文件 3.5 创建 Web Worker 4. 总结 1. 简介 Web Worker 使得在一个独立于 Web 应…

2023-07-27 LeetCode每日一题(删除每行中的最大值)

2023-07-27每日一题 一、题目编号 2500. 删除每行中的最大值二、题目链接 点击跳转到题目位置 三、题目描述 给你一个 m x n 大小的矩阵 grid &#xff0c;由若干正整数组成。 执行下述操作&#xff0c;直到 grid 变为空矩阵&#xff1a; 从每一行删除值最大的元素。如果…

VS2022和QT混合编程打包发布程序

1.在开始菜单输入 CMD 找到 Qt5.15.2(MSVC 64-bit) 2.输入windeployqt exe所在路径 3.运行完毕后&#xff0c;双击打开exe文件&#xff0c;可能会报错&#xff0c;缺少相关的dll,找到缺少的dll拷贝到运行文件夹下即可。

数字化管理能给企业带来哪些好处?

企业数字化管理&#xff08;EDM&#xff09;是指使用数字技术和工具来管理企业运营和流程的各个方面。如果有效实施&#xff0c;EDM 可以给企业带来多种好处&#xff0c;提高企业的整体效率、生产力和竞争力。以下是一些主要优点&#xff1a; 1.提高效率&#xff1a;EDM 通过自…

参数自定义配置比例阀放大器

模拟指令输入比例阀放大器通常使用模拟信号来控制其输出&#xff0c;例如10V, 0~5V,0~10V,4~20mA模拟量信号。它可以将输入的模拟信号放大并转换为一个与输入信号成正比的输出信号&#xff0c;从而实现对执行机构的位置或速度控制。 适配各种不带位置反馈比例阀的控制&#xf…