【深度学习笔记】10_11 注意力机制

news2025/1/17 21:37:29

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

10.11 注意力机制

在10.9节(编码器—解码器(seq2seq))里,解码器在各个时间步依赖相同的背景变量来获取输入序列信息。当编码器为循环神经网络时,背景变量来自它最终时间步的隐藏状态。

现在,让我们再次思考那一节提到的翻译例子:输入为英语序列“They”“are”“watching”“.”,输出为法语序列“Ils”“regardent”“.”。不难想到,解码器在生成输出序列中的每一个词时可能只需利用输入序列某一部分的信息。例如,在输出序列的时间步1,解码器可以主要依赖“They”“are”的信息来生成“Ils”,在时间步2则主要使用来自“watching”的编码信息生成“regardent”,最后在时间步3则直接映射句号“.”。这看上去就像是在解码器的每一时间步对输入序列中不同时间步的表征或编码信息分配不同的注意力一样。这也是注意力机制的由来 [1]。

仍然以循环神经网络为例,注意力机制通过对编码器所有时间步的隐藏状态做加权平均来得到背景变量。解码器在每一时间步调整这些权重,即注意力权重,从而能够在不同时间步分别关注输入序列中的不同部分并编码进相应时间步的背景变量。本节我们将讨论注意力机制是怎么工作的。

在10.9节(编码器—解码器(seq2seq))里我们区分了输入序列或编码器的索引 t t t与输出序列或解码器的索引 t ′ t' t。该节中,解码器在时间步 t ′ t' t的隐藏状态 s t ′ = g ( y t ′ − 1 , c , s t ′ − 1 ) \boldsymbol{s}_{t'} = g(\boldsymbol{y}_{t'-1}, \boldsymbol{c}, \boldsymbol{s}_{t'-1}) st=g(yt1,c,st1),其中 y t ′ − 1 \boldsymbol{y}_{t'-1} yt1是上一时间步 t ′ − 1 t'-1 t1的输出 y t ′ − 1 y_{t'-1} yt1的表征,且任一时间步 t ′ t' t使用相同的背景变量 c \boldsymbol{c} c。但在注意力机制中,解码器的每一时间步将使用可变的背景变量。记 c t ′ \boldsymbol{c}_{t'} ct是解码器在时间步 t ′ t' t的背景变量,那么解码器在该时间步的隐藏状态可以改写为

s t ′ = g ( y t ′ − 1 , c t ′ , s t ′ − 1 ) . \boldsymbol{s}_{t'} = g(\boldsymbol{y}_{t'-1}, \boldsymbol{c}_{t'}, \boldsymbol{s}_{t'-1}). st=g(yt1,ct,st1).

这里的关键是如何计算背景变量 c t ′ \boldsymbol{c}_{t'} ct和如何利用它来更新隐藏状态 s t ′ \boldsymbol{s}_{t'} st。下面将分别描述这两个关键点。

10.11.1 计算背景变量

我们先描述第一个关键点,即计算背景变量。图10.12描绘了注意力机制如何为解码器在时间步2计算背景变量。首先,函数 a a a根据解码器在时间步1的隐藏状态和编码器在各个时间步的隐藏状态计算softmax运算的输入。softmax运算输出概率分布并对编码器各个时间步的隐藏状态做加权平均,从而得到背景变量。

图10.12 编码器—解码器上的注意力机制

具体来说,令编码器在时间步 t t t的隐藏状态为 h t \boldsymbol{h}_t ht,且总时间步数为 T T T。那么解码器在时间步 t ′ t' t的背景变量为所有编码器隐藏状态的加权平均:

c t ′ = ∑ t = 1 T α t ′ t h t , \boldsymbol{c}_{t'} = \sum_{t=1}^T \alpha_{t' t} \boldsymbol{h}_t, ct=t=1Tαttht,

其中给定 t ′ t' t时,权重 α t ′ t \alpha_{t' t} αtt t = 1 , … , T t=1,\ldots,T t=1,,T的值是一个概率分布。为了得到概率分布,我们可以使用softmax运算:

α t ′ t = exp ⁡ ( e t ′ t ) ∑ k = 1 T exp ⁡ ( e t ′ k ) , t = 1 , … , T . \alpha_{t' t} = \frac{\exp(e_{t' t})}{ \sum_{k=1}^T \exp(e_{t' k}) },\quad t=1,\ldots,T. αtt=k=1Texp(etk)exp(ett),t=1,,T.

现在,我们需要定义如何计算上式中softmax运算的输入 e t ′ t e_{t' t} ett。由于 e t ′ t e_{t' t} ett同时取决于解码器的时间步 t ′ t' t和编码器的时间步 t t t,我们不妨以解码器在时间步 t ′ − 1 t'-1 t1的隐藏状态 s t ′ − 1 \boldsymbol{s}_{t' - 1} st1与编码器在时间步 t t t的隐藏状态 h t \boldsymbol{h}_t ht为输入,并通过函数 a a a计算 e t ′ t e_{t' t} ett

e t ′ t = a ( s t ′ − 1 , h t ) . e_{t' t} = a(\boldsymbol{s}_{t' - 1}, \boldsymbol{h}_t). ett=a(st1,ht).

这里函数 a a a有多种选择,如果两个输入向量长度相同,一个简单的选择是计算它们的内积 a ( s , h ) = s ⊤ h a(\boldsymbol{s}, \boldsymbol{h})=\boldsymbol{s}^\top \boldsymbol{h} a(s,h)=sh。而最早提出注意力机制的论文则将输入连结后通过含单隐藏层的多层感知机变换 [1]:

a ( s , h ) = v ⊤ tanh ⁡ ( W s s + W h h ) , a(\boldsymbol{s}, \boldsymbol{h}) = \boldsymbol{v}^\top \tanh(\boldsymbol{W}_s \boldsymbol{s} + \boldsymbol{W}_h \boldsymbol{h}), a(s,h)=vtanh(Wss+Whh),

其中 v \boldsymbol{v} v W s \boldsymbol{W}_s Ws W h \boldsymbol{W}_h Wh都是可以学习的模型参数。

10.11.1.1 矢量化计算

我们还可以对注意力机制采用更高效的矢量化计算。广义上,注意力机制的输入包括查询项以及一一对应的键项和值项,其中值项是需要加权平均的一组项。在加权平均中,值项的权重来自查询项以及与该值项对应的键项的计算。

在上面的例子中,查询项为解码器的隐藏状态,键项和值项均为编码器的隐藏状态。
让我们考虑一个常见的简单情形,即编码器和解码器的隐藏单元个数均为 h h h,且函数 a ( s , h ) = s ⊤ h a(\boldsymbol{s}, \boldsymbol{h})=\boldsymbol{s}^\top \boldsymbol{h} a(s,h)=sh。假设我们希望根据解码器单个隐藏状态 s t ′ − 1 ∈ R h \boldsymbol{s}_{t' - 1} \in \mathbb{R}^{h} st1Rh和编码器所有隐藏状态 h t ∈ R h , t = 1 , … , T \boldsymbol{h}_t \in \mathbb{R}^{h}, t = 1,\ldots,T htRh,t=1,,T来计算背景向量 c t ′ ∈ R h \boldsymbol{c}_{t'}\in \mathbb{R}^{h} ctRh
我们可以将查询项矩阵 Q ∈ R 1 × h \boldsymbol{Q} \in \mathbb{R}^{1 \times h} QR1×h设为 s t ′ − 1 ⊤ \boldsymbol{s}_{t' - 1}^\top st1,并令键项矩阵 K ∈ R T × h \boldsymbol{K} \in \mathbb{R}^{T \times h} KRT×h和值项矩阵 V ∈ R T × h \boldsymbol{V} \in \mathbb{R}^{T \times h} VRT×h相同且第 t t t行均为 h t ⊤ \boldsymbol{h}_t^\top ht。此时,我们只需要通过矢量化计算

softmax ( Q K ⊤ ) V \text{softmax}(\boldsymbol{Q}\boldsymbol{K}^\top)\boldsymbol{V} softmax(QK)V

即可算出转置后的背景向量 c t ′ ⊤ \boldsymbol{c}_{t'}^\top ct。当查询项矩阵 Q \boldsymbol{Q} Q的行数为 n n n时,上式将得到 n n n行的输出矩阵。输出矩阵与查询项矩阵在相同行上一一对应。

10.11.2 更新隐藏状态

现在我们描述第二个关键点,即更新隐藏状态。以门控循环单元为例,在解码器中我们可以对6.7节(门控循环单元(GRU))中门控循环单元的设计稍作修改,从而变换上一时间步 t ′ − 1 t'-1 t1的输出 y t ′ − 1 \boldsymbol{y}_{t'-1} yt1、隐藏状态 s t ′ − 1 \boldsymbol{s}_{t' - 1} st1和当前时间步 t ′ t' t的含注意力机制的背景变量 c t ′ \boldsymbol{c}_{t'} ct [1]。解码器在时间步 t ′ t' t的隐藏状态为

s t ′ = z t ′ ⊙ s t ′ − 1 + ( 1 − z t ′ ) ⊙ s ~ t ′ , \boldsymbol{s}_{t'} = \boldsymbol{z}_{t'} \odot \boldsymbol{s}_{t'-1} + (1 - \boldsymbol{z}_{t'}) \odot \tilde{\boldsymbol{s}}_{t'}, st=ztst1+(1zt)s~t,

其中的重置门、更新门和候选隐藏状态分别为

r t ′ = σ ( W y r y t ′ − 1 + W s r s t ′ − 1 + W c r c t ′ + b r ) , z t ′ = σ ( W y z y t ′ − 1 + W s z s t ′ − 1 + W c z c t ′ + b z ) , s ~ t ′ = tanh ( W y s y t ′ − 1 + W s s ( s t ′ − 1 ⊙ r t ′ ) + W c s c t ′ + b s ) , \begin{aligned} \boldsymbol{r}_{t'} &= \sigma(\boldsymbol{W}_{yr} \boldsymbol{y}_{t'-1} + \boldsymbol{W}_{sr} \boldsymbol{s}_{t' - 1} + \boldsymbol{W}_{cr} \boldsymbol{c}_{t'} + \boldsymbol{b}_r),\\ \boldsymbol{z}_{t'} &= \sigma(\boldsymbol{W}_{yz} \boldsymbol{y}_{t'-1} + \boldsymbol{W}_{sz} \boldsymbol{s}_{t' - 1} + \boldsymbol{W}_{cz} \boldsymbol{c}_{t'} + \boldsymbol{b}_z),\\ \tilde{\boldsymbol{s}}_{t'} &= \text{tanh}(\boldsymbol{W}_{ys} \boldsymbol{y}_{t'-1} + \boldsymbol{W}_{ss} (\boldsymbol{s}_{t' - 1} \odot \boldsymbol{r}_{t'}) + \boldsymbol{W}_{cs} \boldsymbol{c}_{t'} + \boldsymbol{b}_s), \end{aligned} rtzts~t=σ(Wyryt1+Wsrst1+Wcrct+br),=σ(Wyzyt1+Wszst1+Wczct+bz),=tanh(Wysyt1+Wss(st1rt)+Wcsct+bs),

其中含下标的 W \boldsymbol{W} W b \boldsymbol{b} b分别为门控循环单元的权重参数和偏差参数。

10.11.3 发展

本质上,注意力机制能够为表征中较有价值的部分分配较多的计算资源。这个有趣的想法自提出后得到了快速发展,特别是启发了依靠注意力机制来编码输入序列并解码出输出序列的变换器(Transformer)模型的设计 [2]。变换器抛弃了卷积神经网络和循环神经网络的架构。它在计算效率上比基于循环神经网络的编码器—解码器模型通常更具明显优势。含注意力机制的变换器的编码结构在后来的BERT预训练模型中得以应用并令后者大放异彩:微调后的模型在多达11项自然语言处理任务中取得了当时最先进的结果 [3]。不久后,同样是基于变换器设计的GPT-2模型于新收集的语料数据集预训练后,在7个未参与训练的语言模型数据集上均取得了当时最先进的结果 [4]。除了自然语言处理领域,注意力机制还被广泛用于图像分类、自动图像描述、唇语解读以及语音识别。

小结

  • 可以在解码器的每个时间步使用不同的背景变量,并对输入序列中不同时间步编码的信息分配不同的注意力。
  • 广义上,注意力机制的输入包括查询项以及一一对应的键项和值项。
  • 注意力机制可以采用更为高效的矢量化计算。

参考文献

[1] Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.

[2] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. In Advances in Neural Information Processing Systems (pp. 5998-6008).

[3] Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.

[4] Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., Sutskever I. (2019). Language Models are Unsupervised Multitask Learners. OpenAI.


注:本节与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1518894.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序--开启下拉刷新页面

1、下拉刷新获取数据enablePullDownRefresh 开启下拉刷新: enablePullDownRefreshbooleanfalse是否开启当前页面下拉刷新 案例: 下拉刷新,获取新的列表数据,其实就是进行一次新的网络请求: 第一步:在.json文件中开…

分享6款非常优质炫酷的前端动画特效(附在线演示)

分享6款非常不错的项目动画特效 其中有three.js特效、canvas特效、CSS动画、SVG特效等等 下方效果图可能不是特别的生动 那么你可以点击在线预览进行查看相应的动画特效 同时也是可以下载该资源的 CSS33D海盗船动画 基于纯CSS3的3D海盗船动画,该动画分两部分&…

BUGKU-WEB No one knows regex better than me

题目描述 题目截图如下&#xff1a; 进入场景看看&#xff1a; 解题思路 看到此类题目&#xff0c;直接代码审计 相关工具 base64 在线加密https://www.mklab.cn/utils/regex 解题步骤 代码审计 <?php error_reporting(0); # 从请求中获取了两个参数&#xff1…

Servlet的book图书表格实现(使用原生js实现)

作业内容&#xff1a; 1 建立一个book.html,实现图书入库提交 整体参考效果如下: 数据提交后&#xff0c;以窗口弹出数据结果&#xff0c;如: 2 使用正则表达式验证ISBN为x-x-x格式&#xff0c;图书名不低于2个字符&#xff0c;作者不能为空&#xff0c;单价在【10-100】之间…

深入解析Arm架构:掌握汇编与逆向工程,提升设备安全性——蓝狐卷带你入门

写在前面 与传统的CISC&#xff08;Complex Instruction Set Computer&#xff0c;复杂指令集计算机&#xff09;架构相比&#xff0c;Arm架构的指令集更加简洁明了&#xff0c;指令执行效率更高&#xff0c;能够在更低的功耗下完成同样的计算任务&#xff0c;因此在低功耗、嵌…

用pyecharts的overlap绘制叠加图时,设置的颜色不起作用

问题 用pyecharts绘制叠加图时&#xff0c;如折线图上叠加散点图时&#xff0c;分别设置了自己的颜色&#xff08;三角是绿色&#xff0c;圆形是蓝色&#xff09;&#xff0c;但是渲染颜色和图例颜色不一致&#xff0c;如下图所示&#xff0c;折线颜色和散点颜色相同。 解决…

k8s集群部署elk

一、前言 本次部署elk所有的服务都部署在k8s集群中&#xff0c;服务包含filebeat、logstash、elasticsearch、kibana&#xff0c;其中elasticsearch使用集群的方式部署&#xff0c;所有服务都是用7.17.10版本 二、部署 部署elasticsearch集群 部署elasticsearch集群需要先优化…

VMware 配置虚拟机网络

之前需要完成的任务 &#xff08;1&#xff09;、下载和安装VMware-Workstation-Pro.exe软件&#xff0c;推荐16.0版本 &#xff08;2&#xff09;、下载centOS7镜像&#xff0c;可以在阿里云下载。 &#xff08;3&#xff09;、VM创建一个虚拟机&#xff0c;并且使用本地已下载…

IDEA配置JRebel热部署

插件仓库安装 打开IDEA&#xff0c;选择File—>Settings—>Plugins—>在右侧选择Marketplace&#xff0c; 在搜索框输入jrebel—>选择搜索结果—>点击Install JRebel激活 其中&#xff0c;Team URL可以使用在线GUID地址在线生成GUID 拿到GUID串之后&#xff…

vue2点击左侧的树节点(el-tree)定位到对应右侧树形表格(el-table)的位置,树形表格懒加载

左侧树代码 <el-tree :data"treeData" node-key"id" default-expand-all"" //节点默认全部展开:expand-on-click-node"false" //是否在点击节点的时候展开或者收缩节点:props"defaultProps" node-click"handleNodeC…

(二十五)Flask之MTVMVC架构模式Demo【重点:原生session使用及易错点!】

目录&#xff1a; 每篇前言&#xff1a;MTV&MVC构建一个基于MTV模式的Demo项目&#xff1a;蹦出一个问题&#xff1a; 每篇前言&#xff1a; &#x1f3c6;&#x1f3c6;作者介绍&#xff1a;【孤寒者】—CSDN全栈领域优质创作者、HDZ核心组成员、华为云享专家Python全栈领…

社区居民医疗健康系统 微信小程序

设计原则 本社区健康医疗APP采用 Hbuildex技术&#xff0c;使用Java语言开发&#xff0c;充分保证了系统稳定性、完整性。 社区健康医疗APP的设计与实现的设计思想如下&#xff1a; &#xff08;1&#xff09;操作简单方便、系统界面安全良好、简单明了的页面布局、方便查询相…

java-模拟的例题实战

例题实战 在实际的开发工作中&#xff0c;对字符串的处理是最常见的编程惹怒我。本题目即是要求程序对用户输入的串进行处理。具体规则如下&#xff1a; 1 把每个单词的首字母变成大写 2 把数字与字母之间用下划线字符&#xff08;_&#xff09;分开&#xff0c;使得更清晰 …

下载BenchmarkSQL并使用BenchmarkSQL查看OceanBase 的执行计划

下载BenchmarkSQL并使用BenchmarkSQL查看OceanBase 的执行计划 一、什么是BenchmarkSQL二、下载BenchmarkSQL三、使用BenchmarkSQL查看OceanBase 的执行计划 一、什么是BenchmarkSQL BenchmarkSQL是一个开源的数据库基准测试工具&#xff0c;可以用来评估数据库系统的性能&…

unity3d Animal Controller的Animal组件中Stances,Advanced基础部分理解

Stances 立场 立场要求在动物动画控制器上的姿态动画参数。 你可以有多个运动状态,并根据当前的立场使用它们 过渡的条件是: Stance StanceID Default Stance默认姿势 如果调用函数Stance_Reset&#xff08;&#xff09;&#xff0c;动物将返回到的默认姿势。 Current …

webconfig-boot项目说明

1、前言 最近利用空余时间写了一个项目webconfig-boot 。该项目主要配置了web项目常用的一些配置&#xff0c;如统一参数校验、统一异常捕获、统一日期的处理、常用过滤器、常用注解等。引入依赖接口完成常规的web配置。 这里也是总结了笔者在项目开发中遇到的一些常用的配置…

力扣111---二叉树的最小深度(简单题,Java,递归+非递归)

目录 题目描述&#xff1a; &#xff08;递归&#xff09;代码&#xff1a; &#xff08;非递归、层次遍历&#xff09;代码&#xff1a; 题目描述&#xff1a; 给定一个二叉树&#xff0c;找出其最小深度。 最小深度是从根节点到最近叶子节点的最短路径上的节点数量。 说…

【源码编译】Apache SeaTunnel-Web 适配最新2.3.4版本教程

Apache SeaTunnel新版本已经发布&#xff0c;感兴趣的小伙伴可以看之前版本发布的文章 本文主要给大家介绍为使用2.3.4版本的新特性&#xff0c;需要对Apache SeaTunnel-Web依赖的版本进行升级&#xff0c;而SeaTunnel2.3.4版本部分API跟之前版本不兼容&#xff0c;所以需要对 …

备战蓝桥杯Day27 - 省赛真题-2023

题目描述 大佬代码 import os import sysdef find(n):k 0for num in range(12345678,98765433):str1 ["2","0","2","3"]for x in str(num) :if x in str1:if str1[0] x:str1.pop(0)if len(str1) ! 0:k1print(k)print(85959030) 详…

Qt 图形视图 /基于Qt示例DiagramScene解读图形视图框架

文章目录 概述从帮助文档看示例程序了解程序背景/功能理清程序概要设计 分析图形视图的协同运作机制如何嵌入到普通Widget程序中&#xff1f;形状Item和文本Item的插入和删除&#xff1f;连接线Item与形状Item的如何关联&#xff1f;如何绘制ShapeItem间的箭头线&#xff1f; 下…