(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN

news2024/11/26 6:59:27

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

3. 线性 Transformers

3.1. Transformer

3.2. 线性注意力机制

3.2.1. 特征映射与计算成本

3.3. 因果掩码

3.3.1. 梯度计算

3.3.2. 训练和推理

3.4. transformer 是 RNN

4. 实验


0. 摘要

Transformer 在多项任务中表现出色,但由于其对输入长度的二次复杂度,对于非常长的序列来说,速度极慢。为了解决这一限制,我们将自注意力表示为核特征映射(kernel feature maps)的线性点积,并利用矩阵乘积的结合性将复杂度从 O(N^2) 降低到 O(N),其中 N 是序列长度。我们证明了这种表达方式允许一种迭代实现,大大加速了自回归 Transformer,并揭示了它们与递归神经网络的关系。我们的线性 Transformer 在性能上与普通 Transformer 相似,并且在非常长序列的自回归预测中速度快达 4000 倍。 

3. 线性 Transformers

在本节中,我们提出了线性 Transformer。我们展示了将传统的 softmax 注意力机制改为基于特征映射的点积注意力,可以改善时间和内存复杂度,并且可以实现类似于 RNN 的线性时间序列生成模型。

3.1. Transformer

3.2. 线性注意力机制

公式 2 中的注意力定义是通用的,可以用于定义多种其他注意力实现,例如多项式注意力或 RBF 核注意力(Tsai等人,2019)。注意,为了使公式 3 定义的注意力函数有效,我们需要对 sim(·) 施加的唯一约束是非负性。这包括所有核函数 k(x, y): R^(2 × F) → R_+。

给定具有特征表示 ϕ(x) 的核函数,我们可以将公式 2 重写为:

然后利用矩阵乘法的结合性进一步简化为:

当分子以向量形式书写时,上述公式更容易理解,如下所示:

注意,特征映射 ϕ(·) 是逐行应用于矩阵 Q 和 K 的。

从公式 2 可以看出,softmax 注意力的计算成本随 O(N^2) 缩放,其中 N 表示序列长度。内存需求也是如此,因为必须存储完整的注意力矩阵以计算查询、键和值的梯度。相比之下,我们在公式 5 中提出的线性 transformer 具有 O(N) 的时间和内存复杂度,因为我们可以计算

一次,并在每个查询中重复使用它们。

3.2.1. 特征映射与计算成本

对于 softmax 注意力,就乘法和加法的总成本而言,随着 O(N^2·max(D, M)) 缩放,其中 D 是查询和键的维度,M 是值的维度。相反,对于线性注意力,我们首先计算维度为 C 的特征映射。随后,计算新值需要 O(NCM) 次加法和乘法。

上述分析未考虑核函数和特征函数的选择。需要注意的是,对应于指数核的特征函数是无限维的,这使得精确 softmax 注意力的线性化不可行。另一方面,例如多项式核具有精确的有限维特征映射,并且已证明与指数或 RBF 核(Tsai等人,2019)同样有效。线性化多项式 transformer 的计算成本为 O(N·D^2·M)。当 N > D^2 时,这使得计算复杂度更具优势。实际上,由于我们希望能够处理成千上万元素的序列,这一情况是成立的。

对于我们的实验,处理较小的序列,我们采用了一个结果为正相似函数的特征映射,如下定义:

其中 elu(·) 表示指数线性单元(Clevert等人,2015)的激活函数。我们更喜欢 elu(·) 而不是relu(·),以避免在 x 为负时将梯度设置为 0。这种特征映射导致的注意力函数需要 O(NDM) 次乘法和加法。在我们的实验部分,我们展示了公式 7 的特征映射在性能上与完整 transformer 相当,同时显著减少了计算和内存需求。

3.3. 因果掩码

transformer  架构可以通过掩蔽(masking)注意力计算来高效地训练自回归模型,使得第 i 个位置只能被第 j 个位置影响当且仅当 j ≤ i,即一个位置不能被后续位置影响。形式上,这种因果掩码将公式 3 修改如下:

按照3.2节的推理,我们如下所述对掩码注意力进行线性化:

通过引入 Si 和 Zi 如下所示:

我们可以将公式 9 简化为:

注意,Si 和 Zi 可以从 S_(i-1) 和 Z_(i-1) 在固定时间内计算得出,因此使得具有因果掩码的线性 transformer 的计算复杂度相对于序列长度为线性。

3.3.1. 梯度计算

在任何深度学习框架中,公式 12 的朴素实现需要存储所有中间值 Si,以计算梯度。这会增加max(D, M) 倍的内存消耗,从而阻碍因果线性注意力在更长序列或更深模型中的应用。为了解决这个问题,我们将公式 9 中的分子(numerator)的梯度导出为累积和。这使我们能够在线性时间和固定内存中计算因果线性注意力的前向和后向传播。详细推导见附录材料。

给定分子 ¯V_i 和标量损失函数相对于分子的梯度

推导可得:

累计和项在公式 9 和 13-15 中以线性时间计算,并且相对于序列长度需要常量内存。这导致的算法在给定维度为 C 的特征映射下,其计算复杂度为 O(NCM),内存复杂度为 O(N·max (C, M))。算法 1 是分子部分前向和后向传播的伪代码实现。

3.3.2. 训练和推理

在训练自回归 transformer 模型时,可以使用完整的真实序列。这使得公式 1 中的函数 φ(·) 和注意力计算都可以进行分层并行化。因此,transformer 比 RNN 更高效地进行训练。然而,在推理过程中,时间步 i 的输出是时间步 i + 1 的输入。这使得自回归模型无法并行化。此外,transformer 每个时间步的成本不是常量,而是随着当前序列长度的平方增长,因为必须为所有先前的时间步计算注意力。

我们提出的线性 transformer 模型结合了这两者的优点。在训练时,计算可以并行化并充分利用 GPU 或其他加速器。在推理时,我们模型的每次预测在时间和内存上的成本是常量的。这意味着我们可以简单地将

矩阵存储为内部状态,并在每个时间步像递归神经网络一样更新它。这使得推理速度比其他 transformer 模型快数千倍。

3.4. transformer 是 RNN

在文献中,transformer 模型被认为是一种与递归神经网络(RNN)根本不同的方法。然而,从 3.3 节中的因果掩码公式和前一节的讨论可以看出,任何具有因果掩码的 transformer 层都可以被表示为一种模型,该模型在给定输入后修改内部状态,然后预测输出,即 RNN。注意,与通用变压器(Universal Transformers)(Dehghani等人,2018)不同,我们考虑的是时间上的递归,而不是深度上的递归。

在以下公式中,我们将公式 1 的 Transformer 层形式化为 RNN。所得的 RNN 有两个隐藏状态,即注意力记忆 s 和归一化记忆 z。我们用下标表示递归中的时间步。

在上述公式中,x_i 表示特定 Transformer 层的第 i 个输入,y_i 表示第 i 个输出。需要注意的是,我们的公式对特征函数没有任何约束,因此可以用于表示任何 Transformer 模型,理论上甚至包括使用 softmax 注意力的模型。这一公式是更好理解 Transformer 与流行的 RNN(Hochreiter & Schmidhuber, 1997)及其存储和检索信息过程之间关系的第一步。 

4. 实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1708988.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣62 不同路径 Java版本

文章目录 题目描述代码 题目描述 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为 “Finish” )。 问总共有多少…

NLP技术发展和相关书籍分享

自然语言处理(Natural Language Processing,NLP)是计算机科学领域和人工智能领域的重要研究方向之一,旨在探索实现人与计算机之间用自然语言进行有效交流的理论与方法。它融合了语言学、计算机科学、机器学习、数学、认知心理学等…

场景文本检测识别学习 day10(MMdetection)

配置文件(config) 由于在大型项目中,一种模型需要分:tiny、small、big等很多种,而它们的区别主要在网络结构,数据的加载,训练策略等,且差别很多都很小,所以如果每个模型都手动从头写一份&#…

ssm150旅游网站的设计与实现+jsp

旅游网站设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本旅游网站就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞…

鸿蒙OS开发:【一次开发,多端部署】(音乐专辑主页)

一多音乐专辑主页 介绍 本示例展示了音乐专辑主页。 头部返回栏: 因元素单一、位置固定在顶部,因此适合采用自适应拉伸,充分利用顶部区域。专辑封面: 使用栅格组件控制占比,在小尺寸屏幕下封面图与歌单描述在同一行。歌曲列表: 使用栅格组…

汽车电子零部件(14):TMS热管理系统

前言: TMS(thermal management system)热管理系统,这是新能源汽车诞生后随之而产生的一种新汽车零部件,一旦热管理失控会触发自燃,这种现象也是对EV来说是件头疼的事。汽车的热管理系统(TMS)是一个关键部件,有助于调节汽车电池组、车厢和其他车辆系统的温度。TMS的主要…

假象和谎言

原创 | 刘教链 隔夜BTC(比特币)徘徊在69k一线。5.25教链内参报告,《BTC ETF持仓即将超越中本聪》。ETH ETF的尘嚣逐渐散去,复归于平静。戏刚唱了个开头,结尾还留着悬念。4000刀之于ETH看来是个关键阻力位,最…

JavaEE-Spring Controller(服务器控制以及Controller的实现和配置)

Spring Controller 服务器控制 响应架构 Spring Boot 内集成了 Tomcat 服务器,也可以外接 Tomcat 服务器。通过控制层接收浏览器的 URL 请求进行操作并返回数据。 底层和浏览器的信息交互仍旧由 servlet 完成,服务器整体架构如下: Server&…

[9] CUDA性能测量与错误处理

CUDA性能测量与错误处理 讨论如何通过CUDA事件来测量它的性能如何通过CUDA代码进行调试 1.测量CUDA程序的性能 1.1 CUDA事件 CPU端的计时器可能无法给出正确的内核执行时间CUDA事件等于是在你的CUDA应用运行的特定时刻被记录的时间戳,通过使用CUDA事件API&#…

第十四届蓝桥杯c++研究生组

A 关键思路是求每个十进制数的数字以及怎么在一个数组中让判断所有的数字次数相等。 求每个十进制的数字 while(n!0){int x n%10;//x获取了n的每一个位数字n/10;}扩展:求二进制的每位数字 (注意:进制转换、1的个数、位运算) x…

rk3568_semaphore

文章目录 前言1 什么是信号量1.1 信号量API函数2、信号量实验2.1 实验目的2.2函数源码2.3 运行结果图前言 本文记录rk3568开发板的信号量实验 1 什么是信号量 信号量是同步的一种方式,常常用于控制对共享资源的访问。 举个例子:停车场的停车位有100个,这100个停车位就是共…

js的学习

什么是JavaScript? JavaScript(简称:JS)是一门跨平台、面向对象的脚本语言。是用来控制网页行为的,”它能使网页可交互。 JavaScript 和Java 是完全不同的语言,不论是概念还是设计。但是基础语法类似。 JavaScript在1995 年由 Brendan Eich 发明&#x…

【OpenCV】图像通道合并与分离,ROI

介绍可以实现图像通道合并与分离的API,这只是一种方式,后续还会介绍其他的合并与分离方法,以及ROI区域截取的方法。相关API: split() merge() Mat对象() 代码: #include "iostream" #include "ope…

VUE3 学习笔记(6):data数据的监听、表单绑定、操作DOM

data数据的监听&#xff08;侦听&#xff09; 对于data的值的监听&#xff0c;可以用watch中与data中的参数命名一致的值做为函数进行获取监听变动前后的值再做逻辑判断&#xff0c;如下图所示。 示例代码 <template><div><p :class"classDemo">{…

【SQL学习进阶】从入门到高级应用(二)

文章目录 简单查询查一个字段查多个字段查所有字段查询时字段可参与数学运算查询时字段可起别名as关键字省略as关键字别名中有空格别名中有中文 &#x1f308;你好呀&#xff01;我是 山顶风景独好 &#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xf…

【测评】香橙派 AIpro上手初体验

AI毋庸置疑是近年来&#xff0c;热度最高的技术之一&#xff0c;作为一名工程师拥抱新技术的同时不可或缺的需要一块强悍的开发板&#xff0c;香橙派 AIpro除了拥有好看的皮囊之外&#xff0c;还拥有一个有趣且充满魅力的灵魂。作为一位长期活跃在嵌入式开发领域的工程师&#…

Autodl服务器中Faster-rcnn(jwyang)复现(一)

前言 在做实验时需要用到faster-rcnn做对比,本节首先完成代码复现,用的数据集是VOC2007~ 项目地址:https://github.com/jwyang/faster-rcnn.pytorch/tree/pytorch-1.0 复现环境:autodl服务器+python3.6+cuda11.3+Ubuntu20.04+Pytorch1.10.0 目录 一、环境配置二、编译cud…

杀死那个进程

一、场景 eclipse在启动tomcat时&#xff0c;出现端口被占用的情况。我寻思着“任务管理器”没出现相应程序在跑啊。 1.1问题&#xff1a;端口和进程的关系 端口和进程之间存在着一种关系&#xff0c;端口是一个逻辑概念&#xff0c;它用于标识网络通信中的一个终点&#xff0…

二分答案思想下的二进制问题

序列合并 题目描述 给定一个长度为 n n n 的非负整数序列 { a n } \{a_n\} {an​}&#xff0c;你可以进行 k k k 次操作&#xff0c;每次操作你选择两个相邻的数&#xff0c;把它们合并成它们的按位或。 形式化地&#xff0c;一次操作中&#xff0c;你选择一个下标 i i …

【算法】【二叉树,DFS,哈希集合,分类讨论】力扣1110. 删点成林

1110. 删点成林 文章目录 【算法】力扣【二叉树&#xff0c;DFS&#xff0c;哈希集合&#xff0c;分类讨论】1110. 删点成林题目描述示例 1&#xff1a;示例 2&#xff1a; 输入输出示例解释思路解析核心思想算法步骤复杂度分析 代码实现总结 【算法】力扣【二叉树&#xff0c…