深度网络学习笔记(一)——self-attention机制介绍和计算步骤

news2025/1/20 13:22:41

self-attention机制介绍及其计算步骤

  • 前言
  • 一、介绍和意义
  • 二、 计算细节
    • 2.1 计算Attention Score
    • 2.2 计算value
    • 2.3 计算关联结果b
    • 2.4 统一计算
  • 三、总结

前言

Transformer是一种非常常见且强大的深度学习网络架构,尤其擅长处理输出为可变长度向量序列的任务,如自然语言处理(NLP)和图(Graph)处理。Transformer采用了self-attention机制,克服了传统循环神经网络(RNN)在处理长序列时存在的梯度消失和并行计算困难的问题。本文将详细介绍self-attention机制的定义和计算细节。

如果觉得该笔记对您有用的话,可以点个小小的赞,或者点赞收藏关注一键三连ヾ(◍’౪`◍) ~ 谢谢!!

一、介绍和意义

Self-attention是一种计算输入序列中每个元素对其他所有元素的重要性权重的方法。它主要用于捕捉序列数据(如文本)中的长期依赖关系和上下文信息。尽管类似的功能可以通过全连接层(fully-connected layer)实现,但当输入长度非常大时,全连接层的参数计算量会爆炸性增加,导致计算负担过大。以下图为例,当输入为4个节点时,全连接层的计算量已经非常巨大(箭头数量),可以想象当输入增多后的计算量。在这里插入图片描述
为了解决以上问题,Self-attention便被提了出来。主要计算步骤如下(注意在本文中,所有的计算公式都是基于上图中四个输入(这四个输入将在一个向量中)的例子,即输入为 a 1 , a 2 , a 3 , a 4 a_1,a_2,a_3, a_4 a1a2a3a4)。

二、 计算细节

2.1 计算Attention Score

Self-attention用于捕捉上下文信息,因此引入了注意力分数(Attention Score) α \alpha α。给 α \alpha α加上下标后,就可以代表任意两个节点之间的关联程度了。比如 α 1 , 2 \alpha{_1,_2} α1,2代表了节点a1和a2之间的关联程度。为计算关联程度因子 α \alpha α,我们引入了三个向量Query, Key和Value,分别用首字母q,k,v来表示。这里首先介绍q,k和 α \alpha α的计算公式:
q i = W q ⋅ a i k j = W k ⋅ a j α i , j = q i ⋅ k j q^i = W^q \cdot a^i \\ k^j = W^k \cdot a^j \\ \alpha{_i,_j} = q^i \cdot k^j qi=Wqaikj=Wkajαi,j=qikj
解释:首先对于每个输入 a i a^i ai(如单词的嵌入),通过乘上 W q W^q Wq获得对应的 q i q^i qi,之后对任意输入 a j a^j aj乘上 W k W^k Wk获得 k j k^j kj,将这两个值相乘即可获得输入 a i a^i ai a j a^j aj的关联程度,即注意力分数为 α i , j = q i ⋅ k j \alpha{_i,_j} = q^i \cdot k^j αi,j=qikj。如下图所示。在这里插入图片描述
当然输入 a i a^i ai也可以用自己的q乘k得到自己与自己的关联程度,比如 α 1 , 1 \alpha{_1,_1} α1,1可由 q 1 ⋅ k 1 q^1 \cdot k^1 q1k1计算得到。最后,将得到的结果输入一个softmax层中,使用softmax函数对注意力分数进行归一化,得到每个输入节点对其他输入向量的权重,即获得处理后的attention sore α ′ \alpha^{\prime} α,如下图所示(右上角是softmax计算公式)。
在这里插入图片描述

2.2 计算value

再将向量 W q W_q Wq乘上对应的输入向量a,得到对应的v值,公式为:
v i = W v ⋅ a i v^i = W_v \cdot a^i vi=Wvai这与 Query 和 Key 的计算方式一致。将每个输入向量对应的v值全部计算出来,如下图所示。

在这里插入图片描述

2.3 计算关联结果b

最终我们可以得到Self-attention针对单个输入 a 1 a_1 a1的输出结果 b 1 b^1 b1
b 1 = ∑ i α 1 , i ′ ⋅ v i b^1 = \sum_{i} \alpha_{1,i}^{\prime} \cdot v^i b1=iα1,ivi根据这个公式,我们可以看出,某个上下文向量同a1的关联程度越高,对应的值在b中的占比就越大。
在这里插入图片描述

2.4 统一计算

上述计算都是针对单个输入进行的,在这里我们将将输入看成一个整体,即一整个输入向量,来再次梳理整个计算步骤。

  1. 首先我们可以将所有输入向量统合到一起,形成一个向量为I,这里依旧用图中的4个输入为例,公式为:
    I = [ a 1   a 2   a 3   a 4 ] I = [ a^{1} \ a^{2} \ a^{3} \ a^{4} ] I=[a1 a2 a3 a4]
  2. 其次我们可以计算统一的Q,K和V:
    Q = [ q 1   q 2   q 3   q 4 ] = W q ⋅ [ a 1   a 2   a 3   a 4 ] = W q ⋅ I K = [ k 1   k 2   k 3   k 4 ] = W k ⋅ [ a 1   a 2   a 3   a 4 ] = W k ⋅ I V = [ v 1   v 2   v 3   v 4 ] = W v ⋅ [ a 1   a 2   a 3   a 4 ] = W v ⋅ I Q = [ q^{1} \ q^{2} \ q^{3} \ q^{4}] = W^q \cdot [ a^{1} \ a^{2} \ a^{3} \ a^{4} ] = W^q \cdot I \\ K = [ k^{1} \ k^{2} \ k^{3} \ k^{4} ] = W^k \cdot [ a^{1} \ a^{2} \ a^{3} \ a^{4} ] = W^k \cdot I\\ V = [ v^{1} \ v^{2} \ v^{3} \ v^{4} ] = W^v \cdot [ a^{1} \ a^{2} \ a^{3} \ a^{4} ] = W^v \cdot I Q=[q1 q2 q3 q4]=Wq[a1 a2 a3 a4]=WqIK=[k1 k2 k3 k4]=Wk[a1 a2 a3 a4]=WkIV=[v1 v2 v3 v4]=Wv[a1 a2 a3 a4]=WvI
  3. 此时,我们可以把输入节点a1的关联分数 α \alpha α 写成如下公式:
    [ α 1 , 1   α 1 , 2   α 1 , 3   α 1 , 4   ] = q 1 ⋅ [ k 1   k 2   k 3   k 4 ] [\alpha_{1,1} \ \alpha_{1,2} \ \alpha_{1,3} \ \alpha_{1,4} \ ] = q^1 \cdot [k^{1} \ k^{2} \ k^{3} \ k^{4}] [α1,1 α1,2 α1,3 α1,4 ]=q1[k1 k2 k3 k4]同理,剩下三个输入向量的关联分数公式为:
    [ α 2 , 1   α 2 , 2   α 2 , 3   α 2 , 4   ] = q 2 ⋅ [ k 1   k 2   k 3   k 4 ] [ α 3 , 1   α 3 , 2   α 3 , 3   α 3 , 4   ] = q 3 ⋅ [ k 1   k 2   k 3   k 4 ] [ α 4 , 1   α 4 , 2   α 4 , 3   α 4 , 4   ] = q 4 ⋅ [ k 1   k 2   k 3   k 4 ] [\alpha_{2,1} \ \alpha_{2,2} \ \alpha_{2,3} \ \alpha_{2,4} \ ] = q^2 \cdot [k^{1} \ k^{2} \ k^{3} \ k^{4}]\\ [\alpha_{3,1} \ \alpha_{3,2} \ \alpha_{3,3} \ \alpha_{3,4} \ ] = q^3 \cdot [k^{1} \ k^{2} \ k^{3} \ k^{4}]\\ [\alpha_{4,1} \ \alpha_{4,2} \ \alpha_{4,3} \ \alpha_{4,4} \ ] = q^4 \cdot [k^{1} \ k^{2} \ k^{3} \ k^{4}] [α2,1 α2,2 α2,3 α2,4 ]=q2[k1 k2 k3 k4][α3,1 α3,2 α3,3 α3,4 ]=q3[k1 k2 k3 k4][α4,1 α4,2 α4,3 α4,4 ]=q4[k1 k2 k3 k4]我们可以发现,当我们把所有输入向量的关联分数放进一个向量A中时,我们会得到:
    A = [ α 1 , 1 α 1 , 2 α 1 , 3 α 1 , 4 α 2 , 1 α 2 , 2 α 2 , 3 α 2 , 4 α 3 , 1 α 3 , 2 α 3 , 3 α 3 , 4 α 4 , 1 α 4 , 2 α 4 , 3 α 4 , 4 ] = [ q 1 q 2 q 3 q 4 ] ⋅ [ k 1 k 2 k 3 k 4 ] = Q T ⋅ K A = \begin{bmatrix} \alpha_{1,1} & \alpha_{1,2} & \alpha_{1,3} & \alpha_{1,4} \\ \alpha_{2,1} & \alpha_{2,2} & \alpha_{2,3} & \alpha_{2,4} \\ \alpha_{3,1} & \alpha_{3,2} & \alpha_{3,3} & \alpha_{3,4} \\ \alpha_{4,1} & \alpha_{4,2} & \alpha_{4,3} & \alpha_{4,4} \\ \end{bmatrix} = \begin{bmatrix} q^1 \\ q^2 \\ q^3 \\ q^4 \\ \end{bmatrix} \cdot \begin{bmatrix} k^1 & k^2 & k^3 & k^4 \end{bmatrix} =Q^T \cdot K A= α1,1α2,1α3,1α4,1α1,2α2,2α3,2α4,2α1,3α2,3α3,3α4,3α1,4α2,4α3,4α4,4 = q1q2q3q4 [k1k2k3k4]=QTK
    最后,通过softmax得到最终的注意力矩阵(Attention Matrix) A ′ A^\prime A值:
    A ′ = [ α 1 , 1 ′ α 1 , 2 ′ α 1 , 3 ′ α 1 , 4 ′ α 2 , 1 ′ α 2 , 2 ′ α 2 , 3 ′ α 2 , 4 ′ α 3 , 1 ′ α 3 , 2 ′ α 3 , 3 ′ α 3 , 4 ′ α 4 , 1 ′ α 4 , 2 ′ α 4 , 3 ′ α 4 , 4 ′ ] = s o f t m a x ( A ) A^\prime = \begin{bmatrix} \alpha_{1,1}^\prime & \alpha_{1,2}^\prime & \alpha_{1,3}^\prime & \alpha_{1,4}^\prime \\ \alpha_{2,1}^\prime & \alpha_{2,2}^\prime & \alpha_{2,3}^\prime & \alpha_{2,4}^\prime \\ \alpha_{3,1}^\prime & \alpha_{3,2}^\prime & \alpha_{3,3}^\prime & \alpha_{3,4} ^\prime\\ \alpha_{4,1}^\prime & \alpha_{4,2}^\prime & \alpha_{4,3}^\prime & \alpha_{4,4}^\prime \\ \end{bmatrix} =softmax( A ) A= α1,1α2,1α3,1α4,1α1,2α2,2α3,2α4,2α1,3α2,3α3,3α4,3α1,4α2,4α3,4α4,4 =softmax(A)
  4. 此时我们已经得到了整体的关联分数向量 A ′ A^\prime A和整体的Value向量,就可以通过相乘得到对应每输入个位置的加权和,就是整体的输出结果向量O:
    O = [ b 1   b 2   b 3   b 4 ] = [ v 1   v 2   v 3   v 4 ] ⋅ A ′ = V ⋅ A ′ O = [b^1\ b^2\ b^3\ b^4] = [v^1\ v^2\ v^3\ v^4] \cdot A^\prime =V \cdot A^\prime O=[b1 b2 b3 b4]=[v1 v2 v3 v4]A=VA

三、总结

以上就是 self-attention 层的计算步骤。尽管看上去复杂,但实际上在这些计算中只有 W q W^q Wq W k W^k Wk W v W^v Wv是需要在网络中学习的参数,输入 I 在输入层就会传递给网络,剩下的都是基于这些参数的计算。
self-attention 机制的核心在于能够并行计算,极大地提升了训练和推理效率,特别适合 GPU 加速。这是 self-attention 相比 RNN 的一个重要优势。通过 self-attention 机制,Transformer 可以在处理长序列时有效地捕捉到序列中的长期依赖关系和上下文信息,解决了传统 RNN 的一些主要问题。在下一篇文章中,我们将深入探讨 Transformer 的整体架构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1790060.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

杂项——STM32ZET6要注意的一些问题——高级定时器问题和PB3,PB4引脚问题

ZET6可能会用到定时器,高级定时器要输出PWM要加上这样一行代码,否则无法正常输出PWM波 TIM_CtrlPWMOutputs(TIM8, ENABLE); // 主输出使能,当使用的是通用定时器时,这句不需要 ZET6中PB3,PB4引脚默认功能是JTDO和NJTRST,如果想将…

python基础篇(2):字符串扩展知识点

1 字符串的三种定义方式 字符串在Python中有多种定义形式: (1)单引号定义法 name 博主帅绝上下五千年 print(name) print(type(name)) 效果如下: (2)双引号定义法 name "博主帅绝上下五千年&qu…

代课老师可以评职称吗?

代课老师可以评职称吗?这个问题颇具争议。代课老师由于其工作性质的特殊性,往往处于职称评审的边缘地带 代课老师,承担着临时或短期的教学任务,填补因各种原因造成的教师空缺。他们的工作性质决定了他们与正式教师在职责和角色上存…

【uni-app】开发问题汇总

文章目录 1、APP获取dom2、添加页面,参考其他页面,国际化就是对应页面的导航的国际化"navigationBarTitleText": "%m.i.ForgetPaymentPassword.bartitle%",3、setStatusBarStyle这个导航栏设置方法不要了,导航栏现在都用…

66、API攻防——接口安全阿里云KEYPostmanDVWS

文章目录 一、工具使用——Postman自动化测试二、安全问题——Dvws泄露&鉴权&XXE三、安全问题——阿里KEY信息泄露利用 dvws-node 一、工具使用——Postman自动化测试 二、安全问题——Dvws泄露&鉴权&XXE 路径中出现/api/,一般都是接口。 请求包是…

Cesium项目报错An error occurred while rendering. Rendering has stopped.

一般就是本地打开会报错,改成用本地服务器打开 全局安装一个live-server sudo cnpm i live-server -g然后新增一个package.json文件 npm init -y然后在package.json的scripts中增加一个命令 "server": "live-server ./ --port8181 --hostlocalhos…

香橙派 AIpro 的系统评测

0. 前言 你好,我是悦创。 今天受邀测评 Orange Pi AIpro开发板,我将准备用这个测试简单的代码来看看这块开发版的性能体验。 分别从:Sysbench、Stress-ng、PyPerformance、RPi.GPIO Benchmark、Geekbench 等方面来测试和分析结果。 下面就…

第15章 面向服务架构设计理论实践

服务是一个由服务提供者提供的,用于满足使用者请求的业务单元。服务的提供者和使用者都是软件代理为了各自的利益而产生的角色。 在面向服务的体系结构(Service-Oriented Architecture,SOA)中,服务的概念有了延伸,泛指系统对外提供的功能集。…

纷享销客BI智能分析平台技术架构介绍

纷享销客BI智能分析平台致力于降低用户上手门槛,无缝继承纷享销客PaaS平台的对象关系模型和权限体系,让使用纷享CRM的营销人员、销售人员、服务人员等各类角色人员都能够将分析场景与业务场景相融合,将数据思维融合到自己的日常工作、团队工作…

【笔记】半车垂向振动模型

模型如下 半车振动模型,被动悬架 力学推导 代码 %书第135页 clc clear close all %% 1.车辆参数 ms=690; Isy=1222; mwf=40.5; mwr=45.4; Ksf=17000; Ksr=22000; Csf=1500; Csr= 1500; Kwf= 192000; Kwr= 192000; a= 1.25; b= 1.51; L=a+b; %% 2.状态方程 A=[0 0 0…

今日科普:了解、预防、控制高血压

高血压,常被称为“隐形的健康威胁”,许多患者可能在毫无预警的情况下发病,且患病率逐年攀升,同时患者群体逐渐年轻化,高血压虽然难以根治,但并不可怕,真正可怕的是血压长期居高不下,…

【一百零一】【算法分析与设计】差分,1109. 航班预订统计,P4231 三步必杀,P5026 Lycanthropy

1109. 航班预订统计 这里有 n 个航班,它们分别从 1 到 n 进行编号。 有一份航班预订表 bookings ,表中第 i 条预订记录 bookings[i] [first(i), last(i), seats(i)] 意味着在从 first(i) 到 last(i) (包含 first(i) 和 last(i) )…

【ArcGISPro SDK】构建多面体要素

结果展示 每个面构建顺序 代码 using ArcGIS.Core.CIM; using ArcGIS.Core.Data; using ArcGIS.Core.Geometry; using ArcGIS.Desktop.Catalog; using ArcGIS.Desktop.Core; using ArcGIS.Desktop.Editing; using ArcGIS.Desktop.Extensions; using ArcGIS.Desktop.Framework;…

JVM学习-内存泄漏

内存泄漏的理解和分类 可达性分析算法来判断对象是否是不再使用的对象,本质都是判断一上对象是否还被引用,对于这种情况下,由于代码的实现不同就会出现很多内存泄漏问题(让JVM误以为此对象还在引用,无法回收,造成内存泄…

写一个盲盒模拟器

最近想写一个小程序,随便写一个玩吧,先想了下功能: 1.有很多盲盒,可以选择模拟开启 2.自定义盲盒,我们可以自定义制作盲盒自己玩 3.用户界面,记录盲盒历史,可以给坏越提意见 所用技术栈&…

数据库 mysql 的彻底卸载

MySQL卸载步骤如下: (1)按 winr 快捷键,在弹出的窗口输入 services.msc,打开服务列表。 (2)在服务列表中, 找到 mysql 开头的所有服务, 右键停止,终止对应的…

拉普拉斯算子

问Chat GPT两种不同拉普拉斯算子的区别:

打印机的ip不同且连不上

打印机的ip不同且连不上 1.问题分析2.修改网段3.验证网络 1.问题分析 主要是打印机的网段和电脑不在同一个网段 2.修改网段 3.验证网络

springcloudalibaba项目注册nacos1.4.2,在nacos上修改配置项不生效问题

背景 之前的项目启动正常,后来发现springcloudalibaba的各版本匹配不正确,于是对项目中的springboot、springcloud、springcloudalibaba版本进行匹配升级,nacos1.4.2匹配的springboot、springcloud、springcloudalibaba版本与我的项目中的版本比较接近,于是我便重新安装了…

汉化PyCharm 2021.1.1 x64超详细教程

1.先打开PyCharm 然后去按住ctrlalts就会打开设置 然后去官网下载 官网地址:Chinese (Simplified) Language Pack / 中文语言包 Plugin for JetBrains IDEs | JetBrains Marketplace 点击进去找到对应的版本进行下载然后将里面的jar包放到 然后重启PyCharm就好了…