小白学视觉 | 各种各样神奇的自注意力机制(Self-attention) 建议收藏!

news2024/12/25 0:53:35

本文来源公众号“小白学视觉”,仅用于学术分享,侵权删,干货满满。

原文链接:收藏!各种各样神奇的自注意力机制(Self-attention)

编者荐语

文章总结了关于李宏毅老师在 2022 年春季机器学习课程中关于各种注意力机制介绍的主要内容,也是相对于 2021 年课程的补充内容。

转载自丨PaperWeekly

参考视频见:

https://www.bilibili.com/video/BV1Wv411h7kN/?p=51&vd_source=eaef25ec79a284858ac2a990307e06ae

在 2021 年课程的 transformer 视频中,李老师详细介绍了部分 self-attention 内容,但是 self-attention 其实还有各种各样的变化形式:

先简单复习下之前的 self-attention。假设输入序列(query)长度是 N,为了捕捉每个 value 或者 token 之间的关系,需要对应产生 N 个 key 与之对应,并将  query 与 key 之间做 dot-product,就可以产生一个 Attention Matrix(注意力矩阵),维度 N*N。这种方式最大的问题就是当序列长度太长的时候,对应的 Attention Matrix 维度太大,会给计算带来麻烦。

对于 transformer 来说,self-attention 只是大的网络架构中的一个 module。由上述分析我们知道,对于 self-attention 的运算量是跟 N 的平方成正比的。当 N 很小的时候,单纯增加 self-attention 的运算效率可能并不会对整个网络的计算效率有太大的影响。因此,提高 self-attention 的计算效率从而大幅度提高整个网络的效率的前提是 N 特别大的时候,比如做图像识别(影像辨识、image processing)。

如何加快 self-attention 的求解速度呢?根据上述分析可以知道,影响 self-attention 效率最大的一个问题就是 Attention Matrix 的计算。如果我们可以根据一些人类的知识或经验,选择性的计算 Attention Matrix 中的某些数值或者某些数值不需要计算就可以知道数值,理论上可以减小计算量,提高计算效率。

举个例子,比如我们在做文本翻译的时候,有时候在翻译当前的 token 时不需要给出整个 sequence,其实只需要知道这个 token 两边的邻居,就可以翻译的很准,也就是做局部的 attention(local attention)。这样可以大大提升运算效率,但是缺点就是只关注周围局部的值,这样做法其实跟 CNN 就没有太大的区别了。

如果觉得上述这种 local attention 不好,也可以换一种思路,就是在翻译当前 token 的时候,给它空一定间隔(stride)的左右邻居,从而捕获当前与过去和未来的关系。当然stride的数值可以自己确定。

还有一种 global attention 的方式,就是选择 sequence 中的某些 token 作为 special token(比如标点符号),或者在原始的 sequence 中增加 special token。让 special token 与序列产生全局的关系,但是其他不是 special token 的 token 之间没有 attention。以在原始 sequence 前面增加两个 special token 为例:

到底哪种 attention 最好呢?小孩子才做选择...对于一个网络,有的 head 可以做 local attention,有的 head 可以做 global attention... 这样就不需要做选择了。看下面几个例子:

  • Longformer 就是组合了上面的三种 attention

  • Big Bird 就是在 Longformer 基础上随机选择 attention 赋值,进一步提高计算效率

上面集中方法都是人为设定的哪些地方需要算 attention,哪些地方不需要算 attention,但是这样算是最好的方法吗?并不一定。对于 Attention Matrix 来说,如果某些位置值非常小,我们可以直接把这些位置置 0,这样对实际预测的结果也不会有太大的影响。也就是说我们只需要找出 Attention Matrix 中 attention 的值相对较大的值。但是如何找出哪些位置的值非常小/非常大呢?

下面这两个文献中给出一种 Clustering(聚类)的方案,即先对 query 和 key 进行聚类。属于同一类的 query 和 key 来计算 attention,不属于同一类的就不参与计算,这样就可以加快 Attention Matrix 的计算。比如下面这个例子中,分为 4 类:1(红框)、2(紫框)、3(绿框)、4(黄框)。在下面两个文献中介绍了可以快速粗略聚类的方法。

有没有一种将要不要算 attention 的事情用 learn 的方式学习出来呢?有可能的。我们再训练一个网络,输入是 input sequence,输出是相同长度的 weight sequence。将所有 weight sequence 拼接起来,再经过转换,就可以得到一个哪些地方需要算 attention,哪些地方不需要算 attention 的矩阵。有一个细节是:某些不同的 sequence 可能经过 NN 输出同一个 weight sequence,这样可以大大减小计算量。

上述我们所讲的都是 N*N 的 Matrix,但是实际来说,这样的 Matrix 通常来说并不是满秩的,也就是说我们可以对原始 N*N 的矩阵降维,将重复的 column 去掉,得到一个比较小的 Matrix。

具体来说,从 N 个 key 中选出 K 个具有代表的 key,每个 key 对应一个 value,然后跟 query 做点乘。然后做 gradient-decent,更新 value。

为什么选有代表性的 key 不选有代表性的 query 呢?因为 query 跟 output 是对应的,这样会 output 就会缩短从而损失信息。

怎么选出有代表性的 key 呢?这里介绍两种方法,一种是直接对 key 做卷积(conv),一种是对 key 跟一个矩阵做矩阵乘法。

回顾一下注意力机制的计算过程,其中 I 为输入矩阵,O 为输出矩阵。

先忽略 softmax,那么可以化成如下表示形式:

上述过程是可以加速的。如果先 V*K^T 再乘 Q 的话相比于 K^T*Q 再乘 V 结果是相同的,但是计算量会大幅度减少。

附:线性代数关于这部分的说明

还是对上面的例子进行说明。如果 K^T*Q,会执行 N*d*N 次乘法。V*A,会再执行 d'*N*N 次乘法,那么一共需要执行的计算量是(d+d')N^2。

如果 V*K^T,会执行 d'*N*d 次乘法。再乘以 Q,会执行 d'*d*N 次乘法,所以总共需要执行的计算量是 2*d'*d*N。

而(d+d')N^2>>2*d'*d*N,所以通过改变运算顺序就可以大幅度提升运算效率。

现在我们把 softmax 拿回来。原来的 self-attention 是这个样子,以计算b1为例:

如果我们可以将 exp(q*k) 转换成两个映射相乘的形式,那么可以对上式进行进一步简化:

▲ 分母部分化简

▲ 分子化简

将括号里面的东西当做一个向量,M 个向量组成 M 维的矩阵,在乘以 φ(q1),得到分子。

用图形化表示如下:

由上面可以看出蓝色的 vector 和黄色的 vector 其实跟 b1 中的 1 是没有关系的。也就是说,当我们算 b2、b3... 时,蓝色的 vector 和黄色的 vector 不需要再重复计算。

self-attention 还可以用另一种方法来看待。这个计算的方法跟原来的 self-attention 计算出的结果几乎一样,但是运算量会大幅度减少。简单来说,先找到一个转换的方式 φ(),首先将 k 进行转换,然后跟 v 做 dot-product 得到 M 维的 vector。再对 q 做转换,跟 M 对应维度相乘。其中 M 维的 vector 只需要计算一次。

b1 计算如下:

b2 计算如下:

可以这样去理解,将 φ(k) 跟 v 计算的 vector 当做一个 template,然后通过 φ(q) 去寻找哪个 template 是最重要的,并进行矩阵的运算,得到输出 b。

那么 φ 到底如何选择呢?不同的文献有不同的做法:

在计算 self-attention 的时候一定需要 q 和 k 吗?不一定。在 Synthesizer 文献里面,对于 attention matrix 不是通过 q 和 k 得到的,而是作为网络参数学习得到。虽然不同的 input sequence 对应的 attention weight 是一样的,但是 performance 不会变差太多。其实这也引发一个思考,attention 的价值到底是什么?

处理 sequence 一定要用 attention 吗?可不可以尝试把 attention 丢掉?有没有 attention-free 的方法?下面有几个用 mlp 的方法用于代替 attention 来处理 sequence。

最后这页图为今天所有讲述的方法的总结。下图中,纵轴的 LRA score 数值越大,网络表现越好;横轴表示每秒可以处理多少 sequence,越往右速度越快;圈圈越大,代表用到的 memory 越多(计算量越大)。

以上就是关于李老师对于《各种各样神奇的自注意力机制(Self-attention)变形》这节课总结的全部内容。

THE END!

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1591858.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MAC(M1芯片)编译Java项目慢且发热严重问题解决方案

目录 一、背景二、排查三、解决四、效果以及结果展示五、总结 一、背景 使用idea编译项目等操作,经常性发热严重,并且时间慢。直到昨天编译一个项目用时30分钟,电脑温度很高,并且有烧灼的味道,于是有了此篇文章。 二、…

关于Jar包提示找不到主类 找不到或无法加载主类

关于Jar包提示找不到主类 找不到或无法加载主类 当时看到教程打包一个正常的小型增删改查为50MB,也就是几十MB,可我打包得到的Jar包只有几MB,一直提示找不到主类application。经检查 根据方法,只需使用mvn clean 和 mvn packa…

大米自动化生产线设备:现代粮食加工的核心力量

随着科技的不断进步和粮食加工行业的快速发展,大米自动化生产线设备在现代粮食加工中的地位愈发重要。这些设备不仅大大提高了生产效率,还保证了产品的质量和安全,成为了现代粮食加工行业不可或缺的核心力量。 一、自动化生产线设备助力效率提…

厂房起火3D消防灭火安全救援模拟演练

深圳VR公司华锐视点依托前沿的VR虚拟现实制作、三维仿真和图形图像渲染技术,将参训者带入栩栩如生的火灾现场。佩戴VR头盔,参训者将真切体验火势蔓延的紧张与危机,身临其境地感受火灾的恐怖。 并且消防安全VR虚拟现实演练系统精心模拟了住宅、…

基于单链表实现通讯管理系统!(有完整源码!)

​ 个人主页:秋风起,再归来~ 文章专栏:C语言实战项目 个人格言:悟已往之不谏,知来者犹可追 克心守己,律己则安! 1、前言 友友们,这篇文章是基于单链…

005Node.js模块URL的使用

引入 URL 模块 要使用 URL 模块,首先需要在代码中引入它。可以使用以下代码将 URL 模块导入到你的脚本中: const url require(url);实例代码 const urlrequire(url); var apihttp://www.baidu.com?nameshixiaobin&age20; console.log(url.parse(…

Android应用和开发环境

🌈个人主页:小新_- 🎈个人座右铭:“成功者不是从不失败的人,而是从不放弃的人!”🎈 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝 🏆所属专栏&#xff1…

70 个常用的GIS Python 库

由于其多功能性、广泛的库生态系统和用户友好的语法,Python 已成为地理信息系统 (GIS) 和遥感领域的主导语言。这个 70 个地理空间 Python 库的汇编展示了可用于 GIS 和遥感数据处理和分析的丰富工具包。 Python 在 GIS 中的重要性源于它处理复杂地理空间数据的能力…

rancher踩坑日志:prometheus访问kubelet 10250端口提示鉴权失败

该原因是因为kubectl禁止了非授权用户访问10250端口来获取node的数据。 解决思路: 添加prometheus访问kubelet时带上证书进行验证匹配 --> 由于我的prometheus是rancher安装的,不知道要怎么修改所以研究了一会没研究明白就放弃了。设置prometheus访问…

运动听歌哪款耳机靠谱?精选五款热门开放式耳机

随着人们对运动健康的重视,越来越多的运动爱好者开始关注如何在运动中享受音乐。开放式蓝牙耳机凭借其独特的设计,成为了户外运动的理想选择。它不仅让你在运动时能够清晰听到周围环境的声音,保持警觉,还能让你在需要时与他人轻松…

gym界面修改

资料:https://blog.csdn.net/weixin_46178278/article/details/135962782 在gym环境中使用mujoco的时候,有一个很难受的地方,界面上没有实时显示动作空间和状态空间状态的地方。 gym自己原始带的环境是用pygame画的图,所以在定义…

【爬虫+数据清洗+可视化分析】Python文本分析《狂飙》电视剧的哔哩哔哩评论

一、背景介绍 把《狂飙》换成其他影视剧,套用代码即可得分析结论! 2023《狂飙》热播剧引发全民追剧,不仅全员演技在线,且符合主旋律,创下多个收视记录! 基于此热门事件,我用python抓取了B站上千…

【SpringBoot】获取参数

获取参数 传递单个参数传递多个参数传递对象后端参数重命名传递数组传递 json 数据获取 URL 中参数上传文件获取 cookie 和 session获取cookie获取session 传递单个参数 RequestMapping("/user") RestController public class UserController {// 传递单个参数Reque…

力扣 | 160. 相交链表

import ListNodeInfo.ListNode;import java.util.HashSet; import java.util.Set;public class Problem_160_IntersectionOfTwoLinkedList {//双指针方法 public ListNode getIntersectionListNode(ListNode headA, ListNode headB){if(headA null || headB null) return nul…

S32K324 CANFD报文接收超限分析

文章目录 前言问题描述原因分析问题处理总结 前言 随着汽车软件复杂度越来越高,传输的数据越来越多,CAN总线到CANFD总线已经是发展的必然了。CANFD总线中单个报文ID可以传递至多64byte数据,对CAN Driver来说,所需的MCU资源也将变…

数据集学习

1,CIFAR-10数据集 CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。 数据集分为五个训练批次和一个测试批次,每个批次有10000个图像。测试批次包含来自每个类别的恰好1000个随机…

Ubuntu桌面系统安装成功后的简单配置

目录 更换软件源设置系统时间同步 更换软件源 本文使用的换源方法仅限于桌面版 前提有网络连接,能够访问互联网 在Ubuntu18.04桌面版中,点击左下角“显示应用程序”,搜索“软件更新器”,点击进入。 暂时不要点击立即安装&#xff…

ES6: promise对象与回调地狱

ES6: promise对象与回调地狱 一、回调地狱二、Promise概述三、Promise的组成四、用函数封装Promise读取文件操作 一、回调地狱 在js中大量使用回调函数进行异步操作,而异步操作什么时候返回结果是不可控的,所以希望一段程序按我们制定的顺序执…

免费ssl证书能一直续签吗?如何获取SSL免费证书?

免费SSL证书是否可以一直续签。我们需要了解SSL证书的基本工作原理。当你访问一个使用HTTPS协议的网站时,该网站实际上在使用一个SSL证书。这个证书相当于一个数字身份证明,它验证了网站的真实性和安全性。而这个证明是由受信任的第三方机构——通常是证…

Unity之C#面试题(二)

内容将会持续更新,有错误的地方欢迎指正,谢谢! Unity之C#面试题(二) TechX 坚持将创新的科技带给世界! 拥有更好的学习体验 —— 不断努力,不断进步,不断探索 TechX —— 心探索、心进取&a…