图解transformer中的自注意力机制(备忘)

news2025/1/22 19:13:55

注意力机制

在整个注意力过程中,模型会学习了三个权重:查询、键和值。查询、键和值的思想来源于信息检索系统。所以我们先理解数据库查询的思想。

假设有一个数据库,里面有所有一些作家和他们的书籍信息。现在我想读一些Rabindranath写的书:

在数据库中,作者名字类似于键,图书类似于值。查询的关键词Rabindranath是这个问题的键。所以需要计算查询和数据库的键(数据库中的所有作者)之间的相似度,然后返回最相似作者的值(书籍)。

同样,注意力有三个矩阵,分别是查询矩阵(Q)、键矩阵(K)和值矩阵(V)。它们中的每一个都具有与输入嵌入相同的维数。模型在训练中学习这些度量的值。

我们可以假设我们从每个单词中创建一个向量,这样我们就可以处理信息。对于每个单词,生成一个512维的向量。所有3个矩阵都是512x512(因为单词嵌入的维度是512)。对于每个标记嵌入,我们将其与所有三个矩阵(Q, K, V)相乘,每个标记将有3个长度为512的中间向量。

接下来计算分数,它是查询和键向量之间的点积。分数决定了当我们在某个位置编码单词时,对输入句子的其他部分的关注程度。

然后将点积除以关键向量维数的平方根。这种缩放是为了防止点积变得太大或太小(取决于正值或负值),因为这可能导致训练期间的数值不稳定。选择比例因子是为了确保点积的方差近似等于1。

然后通过softmax操作传递结果。这将分数标准化:它们都是正的,并且加起来等于1。softmax输出决定了我们应该从不同的单词中获取多少信息或特征(值),也就是在计算权重。

这里需要注意的一点是,为什么需要其他单词的信息/特征?因为我们的语言是有上下文含义的,一个相同的单词出现在不同的语境,含义也不一样。

最后一步就是计算softmax与这些值的乘积,并将它们相加。

可视化图解

上面逻辑都是文字内容,看起来有一些枯燥,下面我们可视化它的矢量化实现。这样可以更加深入的理解。

查询键和矩阵的计算方法如下

同样的方法可以计算键向量和值向量。

最后计算得分和注意力输出。

简单代码实现

import torch
import torch.nn as nn
from typing import List

def get_input_embeddings(words: List[str], embeddings_dim: int):
# we are creating random vector of embeddings_dim size for each words
# normally we train a tokenizer to get the embeddings.
# check the blog on tokenizer to learn about this part
embeddings = [torch.randn(embeddings_dim) for word in words]
return embeddings


text = "I should sleep now"
words = text.split(" ")
len(words) # 4


embeddings_dim = 512 # 512 dim because the original paper uses it. we can use other dim also
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])


# initialize the query, key and value metrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])


# query, key and value vectors computation for each words embeddings
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])


# compute the score
scores = torch.matmul(query_vectors, key_vectors.transpose(-2, -1)) / torch.sqrt(torch.tensor(embeddings_dim, dtype=torch.float32))
scores.shape # torch.Size([4, 4])


# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights = softmax(scores)
attention_weights.shape # torch.Size([4, 4])


# attention output
output = torch.matmul(attention_weights, value_vectors)
output.shape # torch.Size([4, 512])

以上代码只是为了展示注意力机制的实现,并未优化。

多头注意力

上面提到的注意力是单头注意力,在原论文中有8个头。对于多头和单多头注意力计算相同,只是查询(q0-q3),键(k0-k3),值(v0-v3)中间向量会有一些区别。

之后将查询向量分成相等的部分(有多少头就分成多少)。在上图中有8个头,查询,键和值向量的维度为512。所以就变为了8个64维的向量。

把前64个向量放到第一个头,第二组向量放到第二个头,以此类推。在上面的图片中,我只展示了第一个头的计算。

这里需要注意的是:不同的框架有不同的实现方法,pytorch官方的实现是上面这种,但是tf和一些第三方的代码中是将每个头分开计算了,比如8个头会使用8个linear(tf的dense)而不是一个大linear再拆解。还记得Pytorch的transformer里面要求emb_dim能被num_heads整除吗,就是因为这个

使用哪种方式都可以,因为最终的结果都类似影响不大。

当我们在一个head中有了小查询、键和值(64 dim的)之后,计算剩下的逻辑与单个head注意相同。最后得到的64维的向量来自每个头。

我们将每个头的64个输出组合起来,得到最后的512个dim输出向量。

多头注意力可以表示数据中的复杂关系。每个头都能学习不同的模式。多个头还提供了同时处理输入表示的不同子空间(本例:64个向量表示512个原始向量)的能力。

多头注意代码实现

num_heads = 8
# batch dim is 1 since we are processing one text.
batch_size = 1

text = "I should sleep now"
words = text.split(" ")
len(words) # 4


embeddings_dim = 512
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])


# initialize the query, key and value metrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])


# query, key and value vectors computation for each words embeddings
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])


# (batch_size, num_heads, seq_len, embeddings_dim)
query_vectors_view = query_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
key_vectors_view = key_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
value_vectors_view = value_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
query_vectors_view.shape, key_vectors_view.shape, value_vectors_view.shape
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64])


# We are splitting the each vectors into 8 heads.
# Assuming we have one text (batch size of 1), So we split
# the embedding vectors also into 8 parts. Each head will
# take these parts. If we do this one head at a time.
head1_query_vector = query_vectors_view[0, 0, ...]
head1_key_vector = key_vectors_view[0, 0, ...]
head1_value_vector = value_vectors_view[0, 0, ...]
head1_query_vector.shape, head1_key_vector.shape, head1_value_vector.shape


# The above vectors are of same size as before only the feature dim is changed from 512 to 64
# compute the score
scores_head1 = torch.matmul(head1_query_vector, head1_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
scores_head1.shape # torch.Size([4, 4])


# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights_head1 = softmax(scores_head1)
attention_weights_head1.shape # torch.Size([4, 4])

output_head1 = torch.matmul(attention_weights_head1, head1_value_vector)
output_head1.shape # torch.Size([4, 512])


# we can compute the output for all the heads
outputs = []
for head_idx in range(num_heads):
head_idx_query_vector = query_vectors_view[0, head_idx, ...]
head_idx_key_vector = key_vectors_view[0, head_idx, ...]
head_idx_value_vector = value_vectors_view[0, head_idx, ...]
scores_head_idx = torch.matmul(head_idx_query_vector, head_idx_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))

softmax = nn.Softmax(dim=-1)
attention_weights_idx = softmax(scores_head_idx)
output = torch.matmul(attention_weights_idx, head_idx_value_vector)
outputs.append(output)

[out.shape for out in outputs]
# [torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64])]

# stack the result from each heads for the corresponding words
word0_outputs = torch.cat([out[0] for out in outputs])
word0_outputs.shape

# lets do it for all the words
attn_outputs = []
for i in range(len(words)):
attn_output = torch.cat([out[i] for out in outputs])
attn_outputs.append(attn_output)
[attn_output.shape for attn_output in attn_outputs] # [torch.Size([512]), torch.Size([512]), torch.Size([512]), torch.Size([512])]


# Now lets do it in vectorize way.
# We can not permute the last two dimension of the key vector.
key_vectors_view.permute(0, 1, 3, 2).shape # torch.Size([1, 8, 64, 4])


# Transpose the key vector on the last dim
score = torch.matmul(query_vectors_view, key_vectors_view.permute(0, 1, 3, 2)) # Q*k
score = torch.softmax(score, dim=-1)


# reshape the results
attention_results = torch.matmul(score, value_vectors_view)
attention_results.shape # [1, 8, 4, 64]

# merge the results
attention_results = attention_results.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embeddings_dim)
attention_results.shape # torch.Size([1, 4, 512])

总结

注意力机制(attention mechanism)是Transformer模型中的重要组成部分。Transformer是一种基于自注意力机制(self-attention)的神经网络模型,广泛应用于自然语言处理任务,如机器翻译、文本生成和语言模型等。本文介绍的自注意力机制是Transformer模型的基础,在此基础之上衍生发展出了各种不同的更加高效的注意力机制,所以深入了解自注意力机制,将能够更好地理解Transformer模型的设计原理和工作机制,以及如何在具体的各种任务中应用和调整模型。这将有助于你更有效地使用Transformer模型并进行相关研究和开发。

原文出处,入侵吾删

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1305354.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络编程----select 模型总结

为什么要使用select模型? 答:解决基本C/S模型中,accept()、recv()、send()阻塞的问题 select模型与C/S模型的不同点 C/S模型中accept()会阻塞一直傻等socket来链接select模型只解决accept()傻等的问题,不解决recv(),send()执行…

Android View闪烁动画AlphaAnimation,Kotlin

Android View闪烁动画AlphaAnimation,Kotlin private fun flickerAnimation(view: View?) {val animation: Animation AlphaAnimation(1f, 0f) //不透明到透明。animation.duration 500 // 1次过程时长。animation.interpolator LinearInterpolator() // 线性速…

一天搞定jmeter入门到入职全套教程之Jmeter分布式测试

随着并发量的增大,一台机器就不能满足需求了,所以我们采用分布式(Master-Slaver)的方案去执行高并发的测试 注意事项: Master机器一般我们不执测试,所以可以拿一台配置差些的机器,主要用来采集…

YOLOv8改进 | 2023主干篇 | 利用RT-DETR特征提取网络PPHGNetV2改进YOLOv8(超级轻量化精度更高)

一、本文介绍 本文给大家带来利用RT-DETR模型主干HGNet去替换YOLOv8的主干,RT-DETR是今年由百度推出的第一款实时的ViT模型,其在实时检测的领域上号称是打败了YOLO系列,其利用两个主干一个是HGNet一个是ResNet,其中HGNet就是我们…

养牛场北斗综合管理系统解决方案

1.系统架构 随着我国北斗卫星导航定位系统的快速发展和定位精度的持续不断提高,在牛身上穿戴定位终端后可以实现对牛的位置和温度的测量,在蜂窝网络正常的情况下,定位和温度数据通过蜂窝网络通信方式回传到监控云平台,在蜂窝网络缺…

使用docker编排容器

使用Dockerfile构建一个自定义的nginx 首先用docker拉一个nginx镜像 docker pull nginx拉取完成后,编辑一个Dockerfile文件 vim Dockerfile命令如下所示,FROM 后面跟的你的基础镜像,而run则是表示你构建镜像时需要执行的指令,下面的指令意…

mysql的ON DELETE CASCADE 和ON DELETE RESTRICT区别

​​ON DELETE CASCADE​​​ 和 ​​ON DELETE RESTRICT​​ 是 MySQL 中两种不同的外键约束级联操作。它们之间的主要区别在于当主表中的记录被删除时,子表中相关记录的处理方式。 ON DELETE CASCADE:当在主表中删除一条记录时,所有与之相关的子表中的匹配记录也会被自动删…

Android studio:打开应用程序闪退的问题2.0

目录 找到问题分析问题解决办法 找到问题 老生常谈,可能这东西真的很常见吧,在之前那篇文章中 linkhttp://t.csdnimg.cn/UJQNb 已经谈到了关于打开Androidstuidio开发的软件后明明没有报错却无法运行(具体表现为应用程序闪退的问题&#xff…

MySQL之创建时间类型的字段表

mysql之创建时间类型的字段表 CREATE TABLE tab(birthday DATE, -- 生日job_time DATETIME, -- 记录年月日时分秒login_time TIMESTAMP -- 时间戳NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP )解释: NOT NULL DEFAULT :默认不为空…

快速幂(C语言)

前言 快速幂算法一般用于高次幂取模的题目中,比如求3的10000次方对7取模。这时候有些同学会说:这还不简单?我直接调用pow函数然后对结果%7不得了么?可是3的10000次方这么庞大的数字,真的能储存在计算机里么&#xff1f…

Logstash输入Kafka输出Es配置

Logstash介绍 Logstash是一个开源的数据收集引擎,具有实时管道功能。它可以从各种数据源中动态地统一和标准化数据,并将其发送到你选择的目的地。Logstash的早期目标主要是用于收集日志,但现在的功能已经远远超出这个范围。任何事件类型都可…

技术资讯:VSCode大更新,这两个功能终于有了

大家好,我是大澈! 本文约1200字,整篇阅读大约需要2分钟。 感谢关注微信公众号:“程序员大澈”,然后免费加入问答群,从此让解决问题的你不再孤单! 1. 资讯速览 就在前阵子,前端人都…

【Android嵌入式开发及实训课程实验】【项目1】 图形界面——计算器项目

【项目1】 图形界面——计算器项目 需求分析界面设计实施1、创建项目2、 界面实现实现代码1.activity_main.xml2.Java代码 - MainActivity.java 3、运行测试 注意点结束~ 需求分析 开发一个简单的计算器项目,该程序只能进行加减乘除运算。要求界面美观,…

【Java 基础】27 XML 解析

文章目录 1.SAX 解析器1)什么是 SAX2)SAX 工作流程初始化实现事件处理类解析 3)示例代码 2.DOM 解析器1)什么是 DOM2)DOM 工作流程初始化解析 XML 文档操作 DOM 树 3)示例代码 总结 在项目开发中&#xff0…

电脑ffmpeg.dll丢失如何修复?3个详细修复的教程分享

在计算机使用过程中,我们经常会遇到一些错误提示,其中之一就是“ffmpeg.dll丢失”。ffmpeg.dll是FFmpeg多媒体框架中的一个重要组件,它负责处理音频和视频的编解码。当这个文件丢失或损坏时,可能会导致一些应用程序无法正常运行。…

iframe 与主应用页面之间如何互相通信传递数据

背景 当我们的Web页面需要复用现有网站的页面时,我们通常会考虑代码层面的抽离引用,但是对于一些过于复杂的页面,通过 iframe 嵌套现有的网站页面也是一种不错的方式,。目前我就职的项目组就有多个业务利用 iframe 完成业务的复用…

【数据结构】堆的模拟实现

前言:前面我们学习了顺序表、单链表、栈、队列,今天我们就开始新的学习吧,今天我们将进入堆的学习!(最近博主处于低谷期)一起加油吧各位。 💖 博主CSDN主页:卫卫卫的个人主页 💞 👉 专栏分类:数据结构 &…

在AWS EC2中部署和使用Apache Superset的方案

大纲 1 Superset部署1.1 启动AWS EC21.2 下载Superset Docker文件1.3 修改Dockerfile1.4 配置管理员1.5 结果展示1.6 检查数据库驱动1.7 常见错误处理 2 Glue(可选参考)3 IAM与安全组3.1 使用AWS Athena3.2 使用AWS RedShift或AWS RDS3.2.1 查看AWS Reds…

MySQL8.0默认配置详解--持续更新中

binlog日志的默认保留数量和大小 在MySQL 8.0中,您可以使用以下SQL命令来查询binlog日志的默认保留数量和大小: SHOW VARIABLES LIKE binlog_expire_logs_seconds; SHOW VARIABLES LIKE max_binlog_size;binlog_expire_logs_seconds 变量表示binlog日志…

食品进销存系统哪个好?亿发商品信息管理系统,操作简单好用,可定制

元旦将近,年的味道也越来越浓厚。年货置办的人越来越多,食品店也迎来年底的生意旺季。但众所周知,食品行业作为一个商品品类众多、品牌繁多且商品销售价格波动频繁的领域,常常面临商品批次管理的挑战,特别是需要注意避…