Transformer中的Self-Attention和Multi-Head Attention

news2024/12/25 1:10:10

2017 Google 在Computation and Language发表

当时主要针对于自然语言处理(之前的RNN模型记忆长度有限且无法并行化,只有计算完ti时刻后的数据才能计算ti+1时刻的数据,但Transformer都可以做到)

文章提出Self-Attention概念,在此基础上提出Multi-Head Atterntion

下面借鉴霹雳吧啦博主的视频进行学习:


Self-Attention

假设输入的序列长度为2,输入就两个节点x1,x2,然后通过Input Embedding也就是图中的f(x)将输入映射到a1,a2。紧接着分别将a1,a2分别通过三个变换矩阵Wq,Wk,Wv(这三个参数是可训练的,是共享的)得到对应的q^{i},k^{i},v^{i}(直接使用全连接层实现)。

其中:

q代表query,后续会去和每一个k进行匹配

k代表key,后续会被每个q匹配

v代表从a中提取得到的信息

后续q和k匹配的过程可以理解成计算两者的相关性,相关性越大对应v的权重也越大。

假设a_{1}=(1,1),a_{2}=(1,0),W^{q}=\begin{pmatrix} 1,&1 \\ 0,&1 \end{pmatrix}

那么q^{1}=(1,1)\begin{pmatrix} 1, &1 \\ 0, & 1 \end{pmatrix}=(1,2), q^{2}=(1,0)\begin{pmatrix} 1, &1 \\ 0, & 1 \end{pmatrix}=(1,1)

因为Transformer是并行化的,可以直接写成:

\begin{pmatrix} q^{1}\\ q^{2} \end{pmatrix}=\begin{pmatrix} 1, &1 \\ 1, &0 \end{pmatrix}\begin{pmatrix} 1, &1 \\ 0, &1 \end{pmatrix}=\begin{pmatrix} 1, &2 \\ 1, &1 \end{pmatrix}

同理可以得到\begin{pmatrix} k^{1}\\ k^{2} \end{pmatrix}\begin{pmatrix} v^{1}\\ v^{2} \end{pmatrix},那么求得的\begin{pmatrix} q^{1}\\ q^{2} \end{pmatrix}就是原论文中的Q,\begin{pmatrix} k^{1}\\ k^{2} \end{pmatrix}是K,\begin{pmatrix} v^{1}\\ v^{2} \end{pmatrix}是V。接着q^{1}和每个k进行match,点乘操作,接着除以\sqrt{d}得到对应的\alpha,其中d代表向量k^{i}的长度,除以\sqrt{d}的原因是在论文中的解释“进行点乘后数值很大,导致通过softmax后梯度变得很小”,所以通过\sqrt{d}进行缩放。

\alpha _{1,1}=\frac{q^{1}\cdot k^{1}}{\sqrt{d}}=\frac{1*1+2*0}{\sqrt{2}}=0.71\\ \alpha _{1,2}=\frac{q^{1}\cdot k^{2}}{\sqrt{d}}=\frac{1*0+2*1}{\sqrt{2}}=1.41

同理q^{2}去匹配所有的k能得到\alpha _{2,i},统一写成乘法矩阵形式:

\begin{pmatrix} \alpha _{1,1} & \alpha _{1,2} \\ \alpha _{2,1} & \alpha _{2,2} \end{pmatrix}=\frac{\begin{pmatrix} q^{1}\\ q^{2} \end{pmatrix}\begin{pmatrix} k^{1}\\ k^{2} \end{pmatrix}^{T}}{\sqrt{d}}

接着对每一行即(\alpha _{1,1},\alpha _{1,2}),(\alpha _{2,1},\alpha _{2,2})分别进行softmax处理得到(\widehat{\alpha} _{1,1},\widehat{\alpha} _{1,2}),(\widehat{\alpha} _{2,1},\widehat{\alpha} _{2,2}),这里的\widehat{\alpha }相当于计算得到针对每个v的权重。到这里完成了Attention(Q,K,V)公式中的softmax(\frac{QK^{T}}{\sqrt{d_{k}}})部分。

上面已经计算得到\alpha,即针对每个v的权重,接着进行加权得到最终结果

b_{1}=\widehat{\alpha }_{1,1}\times v^{1}+\widehat{\alpha }_{1,2}\times v^{2}=(0.33,0.67)\\ b_{2}=\widehat{\alpha }_{2,1}\times v^{1}+\widehat{\alpha }_{2,2}\times v^{2}=(0.50,0.50)

统一写成矩阵乘法形式:

\begin{pmatrix} b_{1}\\ b_{2} \end{pmatrix}=\begin{pmatrix} \widehat{\alpha }_{1,1} & \widehat{\alpha }_{1,2}\\ \widehat{\alpha }_{2,1}& \widehat{\alpha }_{2,2} \end{pmatrix}\begin{pmatrix} v^{1}\\ v^{2} \end{pmatrix}

Self-Attention的内容就结束了,总结下来就是论文中一个公式:

 Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V


Multi-Head Attention

多头注意力机制能联合来自不同head部分学习到的信息。

首先还是和Self-Attention模块一样将a_{i}分别通过W^{q},W^{k},W^{v}得到对应的q^{i},k^{i},v^{i},然后再根据使用的head的数目h进一步把得到的q^{i},k^{i},v^{i}均分成h份。比如下图中假设的h=2然后q^{1}拆分成q^{1,1},q^{1,2},那么q^{1,1}就属于head1,q^{1,2}属于head2。

论文中写的通过W_{i}^{Q},W_{i}^{K},W_{i}^{V}映射得到每个head的Q_{i},K_{i},V_{i}:

head_{i}=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

其实简单的均分也可以将W_{i}^{Q},W_{i}^{K},W_{i}^{V}设置成对应值来实现均分,比如下图中的Q通过W_{1}^{Q}就能得到均分后的Q_{1}

通过上述方法就能得到每个headi对应的Q_{i},K_{i},V_{i}参数,接下来针对每个head使用Self-Atttention中相同的方法即可得到对应的结果。

Attention(Q_{i},K_{i},V_{i})=softmax(\frac{Q_{i}K_{i}^{T}}{\sqrt{d_{k}}})V_{i}

接着将每个head得到的结果进行concat拼接,比如下图中b1,1(head1得到的b1)和b1,2(head2得到的b1)拼接在一起,b2,1(head得到的b2)和b2,2(head得到的b2)拼接在一起。

接着将拼接后的结果通过W^{O}(可学习的参数)进行融合,如下图,融合后得到最终的结果b1,b2

到这,总结下来就是论文中的两个公式:

MultiHead(Q,K,V)=Concat(head_{1},...,heah_{h})W^{O}\\ where head_{i}=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

import torch
from fvcore.nn import FlopCountAnalysis

def main():
    #Self-Attention
    a1 = torch.nn.MultiheadAttention(embed_dim=512, num_heads=1)
    a1.proj = torch.nn.Identity() #removr Wo

    #Multi-Head Attention
    a2 = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8)

    #[batch_szie,num_tokens,total_embed_dim]
    t = torch.rand(32, 1024, 512)

    flops1 = FlopCountAnalysis(a1, t)
    print("Self-Attention FLOPs:", flops1.total())

    flops2 = FlopCountAnalysis(a2, t)
    print("Multi-Head Attention FLOPs:",flops2.total())

if __name__ == '__main__':
    main()
Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 68719476736

其实两者FLOPs的差异只是在最后的W^{O}上,如果把Multi-Head Attentio的W^{O}也删除(即把a2的proj也设置成Identity),可以看出两者FLOPs是一样的:

Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 60129542144

Positional Encoding

刚才计算是没有考虑到位置信息的。假设在Self-Attention模块中,输入a1,a2,a3得到b1,b2,b3。对于a1而言,a2和a3离它都是一样近且没有先后顺序。假设将输入的顺序改为a1,a2,a3,对结果b1是没有任何影响的。下面是Pytorch的实验,首先使用nn.MultiheadAttention创建一个Self-Attention模块(num_heads=1),注意这里在正向传播过程中直接传入QKV,接着创建两个顺序不同的QKV变量t1和t2(主要是将q2,k2,v2和q3,k3,v3的顺序换了下),分别将这两个变量输入Self-Attention模块进行正向传播。

import torch
import torch.nn as nn

m = nn.MultiheadAttention(embed_dim=2, num_heads=1)

t1 = [[[1., 2.], #q1,k1,v1
            [2., 3.], #q2,k2,v2
            [3., 4.]]] #q3,k3,v3

t2 = [[[1., 2.], #q1,k1,v1
            [3., 4.], #q3,k3,v3
            [2., 3.]]] #q2,k2,v2

q, k, v  = torch.as_tensor(t1), torch.as_tensor(t1), torch.as_tensor(t1)
print("result:\n", m(q, k, v))

q, k, v = torch.as_tensor(t2), torch.as_tensor(t2), torch.as_tensor(t2)
print("result2:\n", m(q, k , v))

即使调换了qkv顺序,但对b1是没有影响的。

为了引入位置信息,原论文引入了位置编码positional encoding。如下图所示,位置编码是直接加在输入的a={a1,...,an}中的,即pe={pe1,...,pen}和a={a1,...,an}拥有相同维度大小。关于位置编码在原论文有提出两种方案,一种是原论文中使用的固定编码,即论文中给出的sine and cosine funtions方法,按照该方法可计算出位置编码;另一种是可训练的位置编码。ViT论文中使用的是可训练的位置编码。positional encoding


超参对比

关于Transformer中的一些超参数的实验对比可以参考原论文,其中:

N表示重复堆叠的Transformer Block的次数

dmodel表示Multi-Head Self-Attention输入输出的token维度(向量长度)

dff表示在MLP(feed forward)中隐层的节点个数

h表示Multi-Head Self-Attention中的head的个数

dk,dv表示Multi-Head Self-Attention 中每个head的key(K)以及query(Q)的维度

Pdrop表示dropout层的drop_rate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1832088.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python学习笔记-06

函数进阶 1.无参数无返回值:这类函数往往用于提示信息打印 2.无参数有返回值:这类函数往往用于数据采集过程中 3.有参数有返回值:这类函数一般是计算型的 4.有参数无返回值:这类函数多用于设置某些不需要返回值的参数设置1.局部变…

实验2:RIPv2的配置

由于RIPv1是有类别的路由协议,路由更新不携带子网信息,不支持不连续子网、VLSM、手工汇总和验证等,本书重点讨论RIPv2。 1、实验目的 通过本实验可以掌握: RIPv1和 RIPv2的区别。在路由器上启动RIPv2路由进程。激活参与RIPv2路由协议的接口。auto-sum…

一个提问高下立见?国产AI大模型冲上扣子广场PK

以“国产GPTs”出名的扣子,做出了GPT没有的功能。 6月12日,字节跳动旗下的AI应用开发平台“扣子”(Coze国内版)悄悄上线了新功能“模型广场”。 扣子是AI应用开发平台,无论用户是否有编程基础,都可以在扣子…

OpenTiny CCF开源创新大赛赛事指南,助力你赢取10W赛事奖金

第七届CCF开源创新大赛在国家自然科学基金委信息科学部的指导下,由中国计算机学会(CCF)主办,长沙理工大学、CCF 开源发展委员会联合承办。大赛面向国家“十四五”开源生态发展战略布局,聚焦“卡脖子”软件领域以及人工…

clickhouse学习笔记(四)库、表、分区相关DDL操作

目录 一、数据库操作 1、创建数据库 2、查询及选择数据库 3、删除数据库 二、数据表操作 1、创建表 2、删除表 3、基本操作 ①追加新字段 ②修改字段类型或默认值 ③修改字段注释 ④删除已有字段 ⑤移动数据表(重命名) ⑥清空表 三、默认值…

【leetcode刷题】面试经典150题 , 27. 移除元素

leetcode刷题 面试经典150 27. 移除元素 难度:简单 文章目录 一、题目内容二、自己实现代码2.1 方法一:直接硬找2.1.1 实现思路2.1.2 实现代码2.1.3 结果分析 2.2 方法二:排序整体删除再补充2.1.1 实现思路2.1.2 实现代码2.1.3 结果分析 三、…

day12--150. 逆波兰表达式求值+239. 滑动窗口最大值+ 347. 前 K 个高频元素

一、150. 逆波兰表达式求值 题目链接:https://leetcode.cn/problems/evaluate-reverse-polish-notation/description/ 文章讲解:https://programmercarl.com/0150.%E9%80%86%E6%B3%A2%E5%85%B0%E8%A1%A8%E8%BE%BE%E5%BC%8F%E6%B1%82%E5%80%BC.html 视频…

QT 的文件

QT 和C、linux 一样,也有自带的文件系统. 它的操作和C、c差不多,不过也需要我们来了解一下。 输入输出设备类 QObject 有一个子类,名为 QIODevice 类,如其名字,该类是管理所有输入输出设备的类。 比如文件、网络套…

Java获取本机IP地址的方法(内网、公网)

起因是公司一个springboot项目启动类打印了本机IP地址加端口号,方便访问项目页面,但是发现打印出来的不是“无线局域网”的ip而是“以太网适配器”ip,如下图所示 这样就导致后续本地起项目连接xxl-job注册节点的时候因为不在同个局域网下ping…

Arcgis投影问题

今天下载数据,右键查看属性,发现只有地理坐标系,在arcgis里面进行展示有点丑 怎么变成下面的? 步骤1:加载数据 打开ArcGIS Pro或ArcMap。在目录窗口中,右键点击“文件夹连接”或“文件夹”选项&#xff0c…

苹果的后来者居上策略:靠隐私保护打脸微软

01.苹果与微软相比更注重用户隐私 我一直是Windows的忠实用户,但微软疯狂地将人工智能融入一切,让我开始觉得应该咬咬牙换成Mac。 自小我几乎只用Windows电脑,所以我对MacOS一直不太适应。虽然Windows 11有其缺点,但总的来说&am…

设计四大基本原则的全面解析

每每问起设计四大基本原则,无论是蜚荣全球的业内大咖还是初出茅庐的张三李四,都会不约而同地告诉你一个答案:亲密性、对齐、重复、对比。 自罗宾威廉姆斯于《写给大家看的设计书》中提出后,四大基本原则涵盖了品牌、电商、包装、…

【数据结构初阶】--- 堆

文章目录 一、什么是堆?树二叉树完全二叉树堆的分类堆的实现方法 二、堆的操作堆的定义初始化插入数据(包含向上调整详细讲解)向上调整删除堆顶元素(包含向下调整详细讲解)向下调整返回堆顶元素判断堆是否为空销毁 三、…

时间同步概念及常见的时间同步协议NTP PTP

一、前言 前面几篇文章介绍了Linux中的各种各样的时间、时钟源以及时间维护的方式,其中在timekeeper等数据结构中,我们当时略过了NTP相关的字段,为了补充这一段内容,从本篇开始会介绍时间同步的基本概念、及常见的时间同步协议&am…

2024年春季学期《算法分析与设计》练习15

问题 A: 简单递归求和 题目描述 使用递归编写一个程序求如下表达式前n项的计算结果&#xff1a; (n<100) 1 - 3 5 - 7 9 - 11 ...... 输入n&#xff0c;输出表达式的计算结果。 输入 多组输入&#xff0c;每组输入一个n&#xff0c;n<100。 输出 输出表达式的计…

定时器介绍之8253芯片

目录 定时器简介 8253功能介绍 组成 工作原理 相关引脚 启动方法 计数方式 实现 读取计数值 定时器简介 8253功能介绍 内部结构 相关引脚 计数器组成 工作原理 启动方法 计数方式 初始化&#xff1a;写入控制字——>写入计数初值 实现 计数长度选择&#xff1a…

Python 全栈系列254 异步服务与并发调用

说明 发现对于异步(IO)还是太陌生了&#xff0c;熟悉一下。 内容 今天搞了一整天&#xff0c;感觉有一个long story to tell&#xff0c;但是不知道从何说起&#xff0c;哈哈。 异步(协程)需要保证链路上的所有环节都是异步(协程)的&#xff0c;任何一个环节没这么做都会导致…

CSS文本超限后使用省略号代替

方案一&#xff1a; 只显示一行&#xff0c;超限后使用省略号代替 .detail {overflow: hidden;text-overflow: ellipsis;white-space: nowrap; }方案二&#xff1a; 显示多行&#xff0c;到最后一行还没有显示完&#xff0c;则最后一行多出来的部分使用省略号代替。 .detai…

如何通过Appium连接真机调试

1、打开appium&#xff0c;点击启动appium服务器&#xff08;如图1&#xff09; 2、appium启动成功后&#xff0c;点击放大镜启动检查会话&#xff08;如图2&#xff09; 3、填写真机设备信息和APP的package、activity,点击启动会话&#xff08;如图3&#xff09; 4、打开运行A…

C#——字典diction详情

字典 字典: 包含一个key(键)和这个key所以对应的value&#xff08;值&#xff09;&#xff0c;字典是是无序的&#xff0c;key是唯一的&#xff0c;可以根据key获取值。 定义字典: new Diction<key的类型&#xff0c;value的类型>() 方法 添加 var dic new Dictionar…