NGPT:在超球面上进行表示学习的归一化 Transformer

news2024/11/5 14:40:13

在超球面上进行表示学习的归一化 Transformer

  • 1. 研究背景
  • 2. nGPT 的核心贡献
    • 超球面上的网络参数优化
    • 作为超球面上的变度量优化器
    • 更快的收敛速度
  • 3. 从 GPT 到 nGPT 的演变
    • 标记嵌入和输出逻辑
  • 层和块
    • 自注意力块
    • MLP 块
    • 有效学习率在 ADAM 中的应用
    • 总结
  • 4. 实验结果
    • 训练加速
    • 网络参数检查
    • 消融研究
  • 5. 相关工作
  • 6. 讨论与结论

在人工智能领域,神经网络架构的创新不断推动着技术的进步。最近,一篇名为 “NGPT: NORMALIZED TRANSFORMER WITH REPRESENTATION LEARNING ON THE HYPERSPHERE” 的研究论文引起了广泛关注。作者是 Ilya Loshchilov、Cheng - Ping Hsieh、Simeng Sun 和 Boris Ginsburg,他们来自 NVIDIA。这篇论文提出了一种新颖的神经网络架构 —— 归一化 Transformer(nGPT),它在超球面上进行表示学习,展现出了令人瞩目的性能优势。

1. 研究背景

Transformer 架构是现代语言模型的基础,为了提高其训练稳定性、推理成本、上下文长度和鲁棒性等,研究人员提出了大量的修改方案。其中,应用各种归一化技术被认为是有益的,例如添加 LayerNorm 和 RMSNorm 等归一化层,以及通过权重衰减控制权重的范数。同时,也有研究表明在超球面上进行表示学习与更稳定的训练、更大的嵌入空间可分性以及在下游任务中更好的性能相关。在此基础上,本文作者提出了归一化 Transformer,旨在统一该领域的各种发现和观察结果。

2. nGPT 的核心贡献

超球面上的网络参数优化

作者提出将构成网络矩阵嵌入维度的所有向量归一化,使其位于单位范数超球面上。这样,矩阵 - 向量乘法就可以看作是表示在 [-1,1] 范围内的余弦相似度的点积,从而使权重衰减变得不必要。

作为超球面上的变度量优化器

归一化 Transformer 本身在超球面上执行多步优化(每层两步),其中注意力和 MLP 更新的每一步都由特征学习率(可学习的变度量矩阵的对角元素)控制。对于输入序列中的每个标记,归一化 Transformer 的优化路径从超球面上对应其输入嵌入向量的点开始,并移动到超球面上最能预测下一个标记嵌入向量的点。

更快的收敛速度

实验表明,归一化 Transformer 将达到相同精度所需的训练步骤数减少了 4 到 20 倍(取决于序列长度)。

3. 从 GPT 到 nGPT 的演变

标记嵌入和输出逻辑

在原始的仅解码器 Transformer 中,标记嵌入向量的范数不受约束,这可能导致不准确的相似度估计。在 nGPT 中,作者提出在训练算法的每一步之后,对存储在和中的嵌入向量进行归一化。同时,由于所有 nGPT 嵌入都是归一化的,原始公式中的逻辑值代表在 [-1,1] 范围内的点积,这限制了 softmax 生成的概率分布的置信度(温度)。因此,作者引入了一个可训练的缩放参数来调整。

层和块

  • 基线 Transformer:对隐藏状态应用层变换,包括交替的自注意力(ATTN)和多层感知器(MLP)块,并使用 RMSNorm 进行归一化。
  • 归一化 Transformer:对于超球面上的任意两点和,可以使用 SLERP 或其近似的 LERP 来计算沿着测地线的插值。作者将其改写为 nGPT 中的更新方程,其中涉及到注意力和 MLP 块的更新方程,通过可学习的参数和以及归一化函数 Norm 来控制更新过程。与基线 Transformer 不同,nGPT 在最后一层之后不需要额外的归一化。

自注意力块

  • 基线 Transformer:注意力机制是 Transformer 的关键组件,它允许每个标记关注序列中的其他标记。在基线 Transformer 中,首先使用 RMSNorm 对输入隐藏状态进行归一化,然后将其投影为查询、键和值,并应用旋转位置嵌入(RoPE)。通过计算查询和键向量的点积,缩放后应用 softmax 函数得到注意力权重,最后计算值向量的加权和。
  • 归一化 Transformer:作者提出对、、和沿着其嵌入维度进行归一化,使得与计算的点积可以解释为单位范数向量之间的余弦相似度。此外,还对和进行额外的归一化,以确保每个查询和键的点积在控制范围内。同时,调整了 softmax 缩放因子。

MLP 块

  • 基线 Transformer:MLP 块的输入隐藏状态首先使用 RMSNorm 进行归一化,然后通过两个单独的线性投影产生两个中间向量和,使用 SwiGLU 激活函数进行组合,最后通过一个最终的线性变换得到输出。
  • 归一化 Transformer:作者提出对矩阵和沿着嵌入维度进行归一化,使得和向量分别代表与存储在和中的向量之间的余弦相似度。为了控制它们的影响,引入了缩放因子和。

有效学习率在 ADAM 中的应用

在 nGPT 中,对于任何可训练的缩放参数向量,如,使用两个标量和来控制其有效学习率。通过调整,可以在保持全局学习率不变的情况下,控制的有效学习率。

总结

将基线 Transformer 转换为归一化 Transformer 的步骤包括:移除所有归一化层;在每次训练步骤后,对所有矩阵沿着其嵌入维度进行归一化;替换更新方程;改变注意力中的 softmax 缩放因子并对和进行重新缩放和归一化;对 MLP 块的中间状态进行重新缩放;对逻辑值进行重新缩放;移除权重衰减和学习率预热。

4. 实验结果

训练加速

作者在 OpenWebText 数据集上训练了基线 Transformer(GPT)和归一化 Transformer(nGPT),并在一组标准下游任务上对它们进行了评估。实验结果表明,在不同的上下文长度和网络大小下,nGPT 的训练速度比 GPT 快 4 到 20 倍。例如,在 4k 上下文长度下,具有 10 亿参数的 nGPT 在 20k 次迭代后达到了与 GPT 在 200k 次迭代后相同的验证损失,展示了 10 倍的迭代速度提升和使用的标记数量提升。

网络参数检查

  • 嵌入的范数分布:nGPT 保持嵌入的固定范数,而 GPT 的嵌入范数有显著变化。GPT 的输入嵌入具有较高的条件数,尤其是在 1B 模型中。嵌入之间的成对点积分布表明,即使在 nGPT 中,嵌入也不是均匀分布在超球面上,而是形成簇,这可能反映了语言数据中的自然模式。
  • 注意力和 MLP 矩阵的条件数:GPT 模型在其注意力矩阵中表现出明显更高的条件数,与 nGPT 相比,这些矩阵可能退化为低秩矩阵,潜在地降低了这些块的学习能力。
  • 特征学习率和缩放因子:注意力和 MLP 块对隐藏状态的贡献由特征学习率和控制。网络学习在和所建议的方向上只采取适度的步骤。缩放因子、和在各层之间相对稳定,它们似乎补偿了在归一化矩阵和嵌入时丢失的幅度信息。

消融研究

作者进行了大量的消融实验,结果表明,对于、、使用固定(不可学习)值,以及对于使用单个全局可学习值,只会导致准确性的轻微下降。此外,nGPT 可以在不需要对 RoPE 进行任何修改的情况下处理更长的上下文。

5. 相关工作

本文的研究与之前关于超球面表示学习的工作相关。例如,在变分自动编码器的潜在空间和用于面部验证的嵌入中,球形表示与更稳定的训练相关。同时,之前的研究也发现下游任务性能与嵌入在超球面上的对齐(紧密性)和均匀性之间存在强烈的经验相关性。作者还讨论了 nGPT 的更新与 GPT 中应用 RMSNorm 的近似关系,以及 nGPT 中 QK 归一化与之前工作的相似性。

6. 讨论与结论

这项工作建立在该领域的众多关键发现和观察结果之上,主要贡献包括对所有矩阵的嵌入维度进行归一化,以及将特征学习率从网络的其他部分解耦,使其成为可训练的参数。

通过这些创新,nGPT 作为一种变度量优化器,能够利用数据驱动的梯度信息在超球面上搜索输出解决方案。实验结果表明,nGPT 在训练速度上有显著提升,同时也为进一步探索新的算法和架构提供了基础。未来的工作可以探索将 nGPT 扩展到更大的网络规模、真实世界数据集以及更广泛的任务范围。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2231253.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AD】1-7 AD24软件扩展插件的设置与安装

1.如图所示打开扩展 2.点击齿轮后,确保离线安装位置关联了软件安装包的路径位置后,进行勾选选择后,点击应用即可安装。 注意:如果位置关联错误,则显示如图

Window on ARM解锁所有的TTS语音包供python调用

Window on ARM解锁所有的TTS语音包供python调用 可用的语音包查看查看TTS可用的语音包解锁语音包设置升级系统打开注册表导出注册表修改注册表导入新的注册表可用的语音包查看 微软的Windows 10操作系统为设备上安装的每种语言提供了一套语音。但只有部分已安装的语音能在整个…

pandas数据处理高级系列003---什么是交叉表(Cross Tabulation)以及pandas如何生成

做ab测试的时候遇到了一个新的知识点,交叉表以及如何用pandas生成交叉表 交叉表(Cross Tabulation),也称为列联表(Contingency Table),是一种用于统计分析的表格,用于显示两个或多个…

MySQL数据库之存储过程的创建与应用

存储过程 procedure 一.存储过程 作用:将经常使用的功能写成存储过程,方便后续重复使用。 二.创建存储过程 三.调用存储过程 call在计算机中是调用的意思 案例1:查看MySQL用户数 如上图所示,这是查看MySQL数据库中的user个数…

手搓简易shell

1.打印命令行 &#xff0c;接受命令行输入 命令行就是&#xff0c;“[用户名主机名 当前目录]$"获取这些信息都存储在Linux内核中的环境变量中&#xff0c;用getenv()函数取出 #include <stdio.h>2 #include <stdlib.h>3 #include <string.h>4 #include…

多个JDK版本之间的切换

首先电脑上可以同时安装多个版本的 JDK&#xff08;Java Development Kit),因为不同的应用程序可能需要不同 Java 版本的支持,安装多个 JDK 版本并不会导致冲突&#xff0c;只要设置好即可,在不同的情况下切换不同的jdk版本保证程序正常工作 很多程序jdk8 已经不支持,所以下载…

鸿蒙生态下开发挑战-鸿蒙低代码开发工具展望及优势

鸿蒙生态下开发挑战 在鸿蒙生态下开发时&#xff0c;开发者可能会遇到多方面的挑战&#xff0c;这些挑战主要涉及开发工具、技术难度、生态竞争以及市场定位等方面。以下是对这些挑战的详细分析&#xff1a; 一、开发工具不完善 尽管鸿蒙系统的开发工具DevEco Studio在逐步完…

celery在django项目中实现并发任务和定时任务

创建一个django项目 django-admin startproject celeryDemo进入项目目录 cd celeryDemo在你的 Django 项目中&#xff0c;创建一个 celery_.py 文件&#xff0c;通常放在项目的根目录&#xff08;与 settings.py 同级&#xff09;&#xff1a; from __future__ import absol…

ST算法解RMQ问题

题目 代码 #include <bits/stdc.h> using namespace std; const int N 2e510, M 20; int st[N][M]; int n, m; int main() {ios::sync_with_stdio(0);cin.tie(0);cin >> n;for(int i 1; i < n; i)cin >> st[i][0];for(int i 1; (1 << i) < …

国内版Sketchfab平台 - CG美术之家(3D编辑发布篇)

CG美术之家为您提供了一个简便的模型上传流程&#xff0c;让发布您的3D模型变得轻而易举。只需准备好通用的3D模型格式&#xff0c;如OBJ、FBX或STL&#xff0c;您就可以轻松上传并分享您的创作。我们的平台支持在线3D渲染&#xff0c;您只需花费不到一分钟的时间&#xff0c;就…

Rocky Linux 9安装后无法远程ssh密码登录解决

在Rocky Linux 9版本中&#xff0c;为了增加安全性&#xff0c;默认情况下禁用SSH root密码登录。这是系统默认设定的规则&#xff0c;我们同样也可以更改它。   允许Rocky Linux 9 root用户通过ssh登录方法&#xff1a; 1.编辑SSH配置文件 2.找到以下内容 PermitRootLogin …

C语言教程——操作符详解(1)

目录 前言 1.操作符的分类&#xff1a; 2.算数操作符 2.1除法 2.2取模 3.移位操作符 3.1二进制相关知识 3.2左移操作符 3.2.1正数 3.2.2负数 3.2.3结论 3.3右移操作符 4.位操作符 4.1 按位与 4.2按位或 4.3按位异或 ​编辑 5.赋值操作符 6.复合赋值符 7.单目操…

mfc140u.dll丢失怎么办? mfc140u.dll文件缺失的修复技巧

mfc140u.dll 是 Microsoft Foundation Classes (MFC) 库的一部分&#xff0c;它是 Visual Studio 2015 的组件之一&#xff0c;主要服务于使用 C 编写的 Windows 应用程序。这个动态链接库文件包含了 MFC 14.0 Unicode 版本的实现代码&#xff0c;为应用程序提供运行时支持。当…

Golang | Leetcode Golang题解之第520题检测大写字母

题目&#xff1a; 题解&#xff1a; func detectCapitalUse(word string) bool {// 若第 1 个字母为小写&#xff0c;则需额外判断第 2 个字母是否为小写if len(word) > 2 && unicode.IsLower(rune(word[0])) && unicode.IsUpper(rune(word[1])) {return f…

专题九——哈希表

目录 0简介 1两数之和 2判定是否互为字符重排 3存在重复元素 4存在重复元素 II 5字母异位词分组 0简介 1两数之和 oj链接&#xff1a;两数之和 解法1 class Solution { public:vector<int> twoSum(vector<int>& nums, int target) {int nnums.size()…

unet中的attn_processor的修改(用于设计新的注意力模块)

参考资料 文章目录 unet中的一些变量的数据情况attn_processorunet.configunet_sd 自己定义自己的attn Processor &#xff0c;对原始的attn Processor进行修改 IP-adapter中设置attn的方法 参考的代码&#xff1a; 腾讯ailabipadapter 的官方训练代码 unet中的一些变量的数据…

客户端时间 与 服务器时间

对客户端时间和服务器有概念&#xff0c;但从来没有这么直观地观察过。直到有一天打开了长久未使用的mac&#xff0c;第一次对时间有了直观的概念&#xff1a; 打开之后就有了上面这样的提示“您的时钟慢了”… 我看了下电脑的时间&#xff0c;然后打开F12获取了下时间&#x…

VLAN高级特性:VLAN聚合

一、VLAN聚合的概述 在一般的三层交换机中&#xff0c;通常是采用一个VLAN对应一个VLANIF接口实现广播域之间的互通&#xff0c;这导致了在一些情况下造成了IP地址的浪费。 因为一个VLAN对应的子网中&#xff0c;子网号&#xff0c;子网广播地址、子网网关地址不能用作VLAN内…

Rust 力扣 - 2653. 滑动子数组的美丽值

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 我们遍历长度为k的的窗口 因为数据范围比较小&#xff0c;所以我们可以通过计数排序找到窗口中第k小的数 如果小于0&#xff0c;则该窗口的美丽值为第k小的数如果大于等于0&#xff0c;则该窗口的美丽值为0 题…