(2024,attention,可并行计算的 RNN,并行前缀扫描)将注意力当作 RNN

news2024/10/6 14:27:55

Attention as an RNN

公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

3. 方法 

3.1 注意力作为一种(多对一的)RNN

3.2 注意力作为(多对多)RNN

3.3 Aaren: 将注意力当作 RNN

4. 实验 

附录

F. 计算时间


0. 摘要

随着 Transformer 的出现,序列建模领域取得了重大突破,为利用 GPU 并行性提供了高性能的架构。然而,Transformer 在推理时计算开销较大,限制了其在低资源环境(例如移动和嵌入式设备)中的应用。为了解决这个问题,我们采取了以下步骤:

  • 首先,我们展示了注意力机制可以被视为一种特殊的递归神经网络(RNN),能够高效地计算其多对一的 RNN 输出。
  • 接着,我们证明了流行的基于注意力的模型(如 Transformer)可以视为 RNN 的变体。然而,与传统的 RNN(如 LSTM)不同,这些模型无法高效地用新 token 更新,这在序列建模中是一个重要特性。
  • 为了解决这个问题,我们引入了一种基于并行前缀(prefix)扫描算法的新的高效计算注意力多对多 RNN 输出的方法。
  • 基于新的注意力公式,我们推出了 Attention as a recurrent neural network(Aaren),这是一种基于注意力的模块,不仅可以像 Transformer 一样并行训练,还能像传统 RNN 一样高效地用新 token 进行更新,只需要固定内存来进行推理。

在实验中,我们展示了 Aaren 在四个流行的序列问题设置中的 38 个数据集上,表现与 Transformer 相当,但在时间和内存效率上更优。这些问题设置包括强化学习、事件预测、时间序列分类和时间序列预测任务。

3. 方法 

3.1 注意力作为一种(多对一的)RNN

对一个查询向量 q 的注意力可以被看作一个函数,该函数将 N 个上下文 token x_(1:N) 通过它们的键和值 {(ki, vi)}^N_(i=1) 映射到一个单一的输出 o_N = Attention(q, k_(1:N), v_(1:N))。给定 s_i = dot(q, ki),输出 o_N 表示为: 

将注意力视为一种 RNN,我们可以将其迭代计算为滚动求和

其中 k=1,…,N。然而,在实践中,这是一种不稳定的实现,会由于有限精度表示和可能非常小或非常大的指数(即,exp(s))而遇到数值问题。为缓解这一问题,我们通过累积最大值项

重写递归关系,改为计算

值得注意的是,最终结果是相同的,即 o_N = ^a_N / ^c_N= a_N /c_N。因此,a_k、c_k 和 m_k 递归地计算如下: 

通过将 a_k​、c_k 和 m_k 的递归计算从 a_(k-1)​、c_(k-1) 和 m_(k-1) 中封装起来,我们引入了一个 RNN 单元,用于迭代计算注意力的输出(见图 2)。注意力的 RNN 单元以 (a_(k-1), c_(k-1), m_(k-1)​, q) 作为输入,并计算 (ak​,ck​,mk​,q)。请注意,查询向量 q 在 RNN 单元中被传递。注意力的 RNN 的初始隐藏状态为 (a_0, c_0, m_0, q) = (0, 0, 0, q)。

计算注意力的方法。通过将注意力视为一种 RNN,我们可以看到有不同的计算注意力的方法:

  • 在 O(1) 内存中逐 token 递归地(即,顺序地)
  • 以传统方式(即,并行地)在需要线性 O(N) 内存的情况下。由于注意力可以被看作是一种 RNN,计算注意力的传统方法也可以被看作是计算注意力的一对多 RNN 输出的有效方法,即,RNN 的输出将多个上下文 token 作为输入,但在 RNN 结束时只输出一个 token(见图 1a)
  • 一种按块处理 token 的 RNN,需要 O(b) 内存,其中 b 是块的大小。然而,这种方法超出了本工作的范围。因此,按块处理的 RNN 的描述包含在附录 A 中。

将现有基于注意力的模型视为 RNN。通过将注意力视为一种 RNN,现有的基于注意力的模型也可以被视为 RNN 的变体。例如,Transformer 的自注意力是 RNN(图 1b),其上下文 token 作为其初始隐藏状态。

Perceiver 的交叉注意力是 RNN(图 1c),其上下文相关潜变量作为其初始隐藏状态。通过利用它们注意力机制的 RNN 形式,这些现有模型可以有效地计算它们的输出内存。将注意力视为 RNN 存在的挑战。然而,将现有的基于注意力的模型,如 Transformer,视为 RNN 时,这些模型缺乏传统 RNN(如 LSTM 和 GRU)中常见的重要属性。值得注意的是,LSTM 和 GRU 能够以仅为  O(1) 的固定内存和计算高效地更新自己,这对于序列建模是一个重要特性,其中数据以流的形式接收。相比之下,将 Transformer 的 RNN 视图(见图 1b)处理新 token 的方式是添加一个新的RNN,新 token 作为其初始状态。新的 RNN 处理所有先前的 token,需要 O(N) token 数量的线性计算。在 Perceiver 中,由于其体系结构,潜变量(图 1c 中的 L_i)是与输入相关的,这意味着当接收到新 token 时,它们的值会改变。由于它们的 RNN 的初始隐藏状态(即,潜变量)会改变,因此 Perceiver 需要从头重新计算它们的 RNN,需要 O(NL)  token 数量(N)和潜变量数量(L)的线性计算。

3.2 注意力作为(多对多)RNN

针对这些限制,我们提议开发一种基于注意力的模型,能够利用 RNN 形式的能力进行高效更新。为此,我们首先引入了一种高效的并行化方法,用于将注意力计算为多对多的 RNN,即,一种并行方法来计算

为此,我们利用并行前缀扫描算法(Blelloch,1990)(见算法 1),这是一种并行计算方法,用于通过关联运算符 ⊕ 从 N 个连续数据点计算 N 个前缀计算。该算法可以从 {xk}^N_(k=1) 高效地计算

回想一下,Attention(q, x(1:k)) = ok = ak / ck。为了高效计算 Attention(q, x(1:k)),我们可以通过并行扫描算法计算 {ak}^N_(k=1),{ck}^N_(k=1) 和 {mk}^N_(k=1),然后将 ak 和 ck 结合起来计算 Attention(q, x(1:k))。

为此,我们提出以下关联运算符 ⊕,它作用于形式为(m_A, u_A, w_A)的三元组,其中 A 是索引集合,

mA = max_(i∈A) si,uA = ∑_(i∈A) exp(si − mA),wA = ∑_(i∈A) exp(si − mA)vi

并行扫描算法将 {(m{i}, u{i}, w{i})}^N_(i=1) = {(si, 1, vi)}^N_(i=1) 作为输入。该算法递归地应用运算符 ⊕,其工作方式如下:

(mA, uA, wA)⊕(mB, uB, wB) = (mA∪B, uA∪B, wA∪B)

mA∪B = max(mA, mB),uA∪B = uA·exp(mA − mA∪B) + uB exp(mB − mA∪B)

wA∪B = wA· exp(mA−mA∪B) + wB·exp(mB−mA∪B)

在递归地应用运算符完成后,该算法输出

{(m{1,...,k}, u{1,...,k}, w{1,...,k})}^N_(k=1) = {(mk, ∑_(i=1) exp(si − mk), ∑_(i=1) exp(si − mk)vi)}^N_(k=1)

也称为 {(mk, ck, ak)}^N_(k=1)。

通过组合输出元组的最后两个值,我们得到 Attention(q, x(1:k)) = ok = ak / ck,从而实现了一种高效的并行化方法,用于计算注意力作为多对多的 RNN(见图 3)。

3.3 Aaren: 将注意力当作 RNN

利用注意力的并行化多对多形式,我们提出了Aaren。Aaren 的接口与 Transformer 相同,将 N 个输入映射到 N 个输出,其中第 i 个输出是从第 1 到第 i 个输入的聚合。因此,Aaren也是自然可堆叠的,且能够为每个序列 token 计算单独的损失项

然而,与使用因果自注意力的 Transformer 不同,Aaren 使用了前述的将注意力计算为多对多 RNN 的方法,使其更加高效。Aaren 的功能如下:

与 Transformer 中的查询是注意力的一个输入 token 不同,Aaren 的查询 token q 是通过反向传播在训练过程中学习的。在图 4 中,我们包含了一个具有输入上下文 token x_(1​:3) 和输出 y_(1:3) 的堆叠 Aaren 模型的示例。值得注意的是,由于 Aaren 利用了注意力的 RNN 形式,Aaren 的堆叠也是 RNN 的堆叠。因此,Aarens 也能够以高效的方式与新 token 进行更新,即,仅需要固定计算量的 y_k 的迭代计算,因为它仅依赖于 h_(k-1) 和 x_k。与基于 Transformer 的模型不同,它们在使用 KV 缓存时需要线性内存,并且需要存储所有先前的 token,包括中间 Transformer 层中的 token,而基于 Aaren 的模型仅需要固定内存,并且不需要存储所有先前的 token,使得 Aaren 比 Transformer 更加高效。

4. 实验 

附录

F. 计算时间

我们的实验是在 Nvidia GTX 1080 Ti(12 GB)和 Nvidia Tesla P100(16 GB)GPU 的混合环境中运行的。分析是在 Nvidia GTX 1080 Ti(12 GB)上执行的。

强化学习实验大致需要相同的时间:每个实验约 2 到 4 小时。

事件预测实验的时间因数据集而异:

  • MIMIC 约需 0.5 小时
  • Wiki 约需 0.75 小时
  • Reddit 约需 3.5 小时
  • Mooc 约需 8 小时
  • StackOverflow 约需 3.5 小时
  • Sin 约需 1.5 小时
  • Uber 约需 3 小时
  • Taxi 约需1.5 小时

时间序列预测实验以 T ∈ {96, 192, 336, 720} 的单个脚本运行。实验时间因数据集而异:

  • Weather 约需 6 小时
  • Exchange 约需 0.5 小时
  • Traffic约需 1 小时
  • ECL 约需 4 小时
  • ETTh1 约需 0.75 小时
  • ETTm1 约需 11 小时
  • ETTh2 约需 0.75 小时
  • ETTm2 约需 11 小时

时间序列分类实验作为单个脚本运行。在所有数据集上运行实验总共约需 1 小时。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1699665.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

9.4 Go语言入门(运算符)

Go语言入门(运算符) 目录三、运算符1. 算术运算符2. 关系运算符3. 逻辑运算符4. 位运算符5. 赋值运算符6. 其他运算符7. 运算符优先级 目录 Go 语言(Golang)是一种静态类型、编译型语言,由 Google 开发,专注…

异步那些事01

首先我们肯定先说创建线程 1.继承Thread类 o定义一个类MyThread继承Thread类 o在MyThread类中重写run()方法 o创建MyThread类的对象 o启动线程 package Java.thread;public class first extends Thread{public void run(){for(int i0;i<50;i){System.out.println("我…

go ast语义分析实现指标计算器

什么是AST 首先我们要知道AST是什么&#xff08;Abstract Syntax Tree&#xff0c;AST&#xff09;&#xff0c;简称为语法树&#xff0c;是go语言源代码语法结构的一种抽象表示。它以树状的形式表现编程语言的语法结构&#xff0c;树上的每个节点都表示源代码中的一种结构。 …

英语四级翻译练习笔记①——大学英语四级考试2023年12月真题(第一套)——用ChatGPT修改训练四级翻译

目录 引言&#xff08;必看&#xff09; 翻译原文 我的翻译 得分&#xff08;1-3分&#xff09; 原文&#xff1a; 你的翻译&#xff1a; 修改后的翻译&#xff1a; 详细错误讲解&#xff1a; 引言&#xff08;必看&#xff09; 这是一篇英语四级翻译的练习的专栏&…

Java刷题总结(面试)

1、String类 String不可变 java 中String是 immutable的&#xff0c;也就是不可变&#xff0c;一旦初始化&#xff0c;其引用指向的内容是不可变的。 也就是说&#xff0c;String str “aa”&#xff1b;str“bb”&#xff1b;第二句不是改变“aa”所存储地址的内容&#xf…

计算机毕业设计 | SSM汽车租赁系统(附源码)

1&#xff0c; 概述 1.1 课题背景 随着社会的快速发展&#xff0c;计算机的影响是全面且深入的。用户生活水平的不断提高&#xff0c;日常生活中用户对汽车租赁系统方面的要求也在不断提高&#xff0c;需要汽车租赁系统查询的人数更是不断增加&#xff0c;使得汽车租赁系统的…

项目管理:敏捷实践框架

一、初识敏捷 什么是敏捷(Agile)?敏捷是思维方式。 传统开发模型 央企,国企50%-60%需求分析。整体是由文档控制的过程管理。 传统软件开发面临的问题: 交付周期长:3-6个月甚至更长沟通效果差:文档化沟通不及时按时发布低:技术债增多无法发版团队士气弱:死亡行军不关注…

数据库SQL语言实战(十)(最后一篇)

目录 前言 练习题 实验八 实验九 题目一 题目二 总结 前言 本篇练习题的重点有两个&#xff1a; 一、测试提交commit和回滚rollback的作用,了解锁等待、授权等知识。 二、学会复制表结构、学会插入数据&#xff0c;特别是学会如何避免重复插入&#xff0c;也就是如何避…

【云原生】K8s管理工具--Kubectl详解(一)

一、陈述式管理 1.1、陈述式资源管理方法 kubernetes 集群管理集群资源的唯一入口是通过相应的方法调用 apiserver 的接口kubectl 是官方的 CLI 命令行工具&#xff0c;用于与 apiserver 进行通信&#xff0c;将用户在命令行输入的命令&#xff0c;组织并转化为apiserver 能识…

实时通信的方式——WebRTC

文章目录 基于WebRTC实现音视频通话P2P通信原理如何发现对方&#xff1f; 不同的音视频编解码能力如何沟通&#xff1f;&#xff08;媒体协商SDP&#xff09;如何联系上对方&#xff1f;&#xff08;网络协商&#xff09; 常用的API音视频采集getUserMedia核心对象RTCPeerConne…

蓝桥杯物联网竞赛_STM32L071KBU6_关于size of函数产生的BUG

首先现象是我在用LORA发送信息的时候&#xff0c;左边显示长度是8而右边接收到的数据长度却是4 我以为是OLED显示屏坏了&#xff0c;又或者是我想搞创新用了const char* 类型强制转换数据的原因&#xff0c;结果发现都不是 void Function_SendMsg( unsigned char* data){unsi…

C语言系列文章 | 函数 (共 10209 字)

目前主要分为三个专栏&#xff0c;后续还会添加&#xff1a; 专栏如下&#xff1a; C语言刷题解析 C语言系列文章 我的成长经历 感谢阅读&#xff01; 初来乍到&#xff0c;如有错误请指出&#xff0c;感谢&#xff01; 目录 函数的概念库函数自…

匠心独运的掺Si量子势垒策略,显著提升了AlGaN基深紫外LED出光率

WHU团队凭借匠心独运的三明治式掺Si量子势垒策略&#xff0c;显著提升了AlGaN基深紫外光LED的效率&#xff0c;这一创新成果为中国武汉大学的研究团队所取得。他们巧妙地设计出一种三明治状Si掺杂&#xff08;未掺杂&#xff09;方案&#xff0c;应用于Al0.6Ga0.4N量子势垒中&a…

Android硬件渲染流程

Android硬件渲染流程 一.渲染流程1.VSync信号的监听2.VSync信号触发绘制 二.渲染原理1.画布的获取1.1 画布的创建1.2 渲染指令列表的创建 2.绘制与渲染指令2.1 矩形的绘制2.2 硬件渲染指令2.3 节点的绘制 3.绘制的提交3.1 绘制结果的保存3.2 绘制结果的获取 4.层级的构建4.1 绘…

FFmpeg的流程

文章目录 前序代码结构FFmpeg.cffmpeg_opt.c 小结 前序 之前看过FFmpeg的各种命令&#xff0c;然后不是很理解。相信很多人都不是很理解&#xff0c;毕竟&#xff0c;单纯的去记住那些命令行本身就需要很大的内存&#xff0c;我们的大脑内存又有限&#xff0c;所以&#xff0c…

spring cloud alibaba sentinel 配置过程 流控 降级热点 授权

目录 1.基础理论 2.配置 3.加入依赖和配置文件 4.流控 1.基础理论 Sentinel是阿里开源的项目&#xff0c;提供了流量控制、熔断降级、系统负载保护等多个维度来保障服务之间的稳定性。 丰富的应用场景 &#xff1a;Sentinel 承接了阿里巴巴近 10 年的双十一大促流量的核心…

锁相环的一些学习笔记--(1)

下图两组1.2.3可以对应起来&#xff1b; 一些分析&#xff1a; 1.根据这个可知最后vco_voltage停在0.5v 参考资料&#xff1a; 1. Matlab https://www.bilibili.com/video/BV1bR4y1Z7Xg/?spm_id_from333.1296.top_right_bar_window_history.content.click&vd_source555…

windows 搭建 go开发环境

go语言&#xff08;或 Golang&#xff09;是Google开发的开源编程语言&#xff0c;诞生于2006年1月2日下午15点4分5秒&#xff0c;于2009年11月开源&#xff0c;2012年发布go稳定版。Go语言在多核并发上拥有原生的设计优势&#xff0c;Go语言从底层原生支持并发&#xff0c;无须…

Google Play 提示 “您的设备与此版本不兼容“ 解决方案

一、 问题概述Google Play提示“您的设备与此版本不兼容”&#xff0c;无法安装应用。 遇到问题的设备为Xiaomi Mi A3&#xff0c;查了下这台手机的基本信息&#xff0c;Android One系统&#xff0c;版本分为9.0、10.0、11.0。 二、 问题分析Google Play的过滤器 通常有以下5种…

2024-5-10-从0到1手写配置中心Config之Spring Value热更新

定义SpringValueProcessor处理类 实现BeanPostProcessor后置处理器接口&#xff0c;扫描所有的Spring value&#xff0c;保存起来。实现ApplicationListener接口&#xff0c;在配置变更时&#xff0c;更新所有的spring value 实现BeanPostProcessor后置处理器接口 实现postPr…