图解transformer中的自注意力机制

news2024/10/6 18:33:45

本文将将介绍注意力的概念从何而来,它是如何工作的以及它的简单的实现。

注意力机制

在整个注意力过程中,模型会学习了三个权重:查询、键和值。查询、键和值的思想来源于信息检索系统。所以我们先理解数据库查询的思想。

假设有一个数据库,里面有所有一些作家和他们的书籍信息。现在我想读一些Rabindranath写的书:

在数据库中,作者名字类似于键,图书类似于值。查询的关键词Rabindranath是这个问题的键。所以需要计算查询和数据库的键(数据库中的所有作者)之间的相似度,然后返回最相似作者的值(书籍)。

同样,注意力有三个矩阵,分别是查询矩阵(Q)、键矩阵(K)和值矩阵(V)。它们中的每一个都具有与输入嵌入相同的维数。模型在训练中学习这些度量的值。

我们可以假设我们从每个单词中创建一个向量,这样我们就可以处理信息。对于每个单词,生成一个512维的向量。所有3个矩阵都是512x512(因为单词嵌入的维度是512)。对于每个标记嵌入,我们将其与所有三个矩阵(Q, K, V)相乘,每个标记将有3个长度为512的中间向量。

接下来计算分数,它是查询和键向量之间的点积。分数决定了当我们在某个位置编码单词时,对输入句子的其他部分的关注程度。

然后将点积除以关键向量维数的平方根。这种缩放是为了防止点积变得太大或太小(取决于正值或负值),因为这可能导致训练期间的数值不稳定。选择比例因子是为了确保点积的方差近似等于1。

然后通过softmax操作传递结果。这将分数标准化:它们都是正的,并且加起来等于1。softmax输出决定了我们应该从不同的单词中获取多少信息或特征(值),也就是在计算权重。

这里需要注意的一点是,为什么需要其他单词的信息/特征?因为我们的语言是有上下文含义的,一个相同的单词出现在不同的语境,含义也不一样。

最后一步就是计算softmax与这些值的乘积,并将它们相加。

可视化图解

上面逻辑都是文字内容,看起来有一些枯燥,下面我们可视化它的矢量化实现。这样可以更加深入的理解。

查询键和矩阵的计算方法如下

同样的方法可以计算键向量和值向量。

最后计算得分和注意力输出。

简单代码实现

 importtorch
 importtorch.nnasnn
 fromtypingimportList
 
 defget_input_embeddings(words: List[str], embeddings_dim: int):
     # we are creating random vector of embeddings_dim size for each words
     # normally we train a tokenizer to get the embeddings.
     # check the blog on tokenizer to learn about this part
     embeddings= [torch.randn(embeddings_dim) forwordinwords]
     returnembeddings
 
 
 text="I should sleep now"
 words=text.split(" ")
 len(words) # 4
 
 
 embeddings_dim=512# 512 dim because the original paper uses it. we can use other dim also
 embeddings=get_input_embeddings(words, embeddings_dim=embeddings_dim)
 embeddings[0].shape# torch.Size([512])
 
 
 # initialize the query, key and value metrices 
 query_matrix=nn.Linear(embeddings_dim, embeddings_dim)
 key_matrix=nn.Linear(embeddings_dim, embeddings_dim)
 value_matrix=nn.Linear(embeddings_dim, embeddings_dim)
 query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape# torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])
 
 
 # query, key and value vectors computation for each words embeddings
 query_vectors=torch.stack([query_matrix(embedding) forembeddinginembeddings])
 key_vectors=torch.stack([key_matrix(embedding) forembeddinginembeddings])
 value_vectors=torch.stack([value_matrix(embedding) forembeddinginembeddings])
 query_vectors.shape, key_vectors.shape, value_vectors.shape# torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])
 
 
 # compute the score
 scores=torch.matmul(query_vectors, key_vectors.transpose(-2, -1)) /torch.sqrt(torch.tensor(embeddings_dim, dtype=torch.float32))
 scores.shape# torch.Size([4, 4])
 
 
 # compute the attention weights for each of the words with the other words
 softmax=nn.Softmax(dim=-1)
 attention_weights=softmax(scores)
 attention_weights.shape# torch.Size([4, 4])
 
 
 # attention output
 output=torch.matmul(attention_weights, value_vectors)
 output.shape# torch.Size([4, 512])

以上代码只是为了展示注意力机制的实现,并未优化。

多头注意力

上面提到的注意力是单头注意力,在原论文中有8个头。对于多头和单多头注意力计算相同,只是查询(q0-q3),键(k0-k3),值(v0-v3)中间向量会有一些区别。

之后将查询向量分成相等的部分(有多少头就分成多少)。在上图中有8个头,查询,键和值向量的维度为512。所以就变为了8个64维的向量。

把前64个向量放到第一个头,第二组向量放到第二个头,以此类推。在上面的图片中,我只展示了第一个头的计算。

这里需要注意的是:不同的框架有不同的实现方法,pytorch官方的实现是上面这种,但是tf和一些第三方的代码中是将每个头分开计算了,比如8个头会使用8个linear(tf的dense)而不是一个大linear再拆解。还记得Pytorch的transformer里面要求emb_dim能被num_heads整除吗,就是因为这个

使用哪种方式都可以,因为最终的结果都类似影响不大。

当我们在一个head中有了小查询、键和值(64 dim的)之后,计算剩下的逻辑与单个head注意相同。最后得到的64维的向量来自每个头。

我们将每个头的64个输出组合起来,得到最后的512个dim输出向量。

多头注意力可以表示数据中的复杂关系。每个头都能学习不同的模式。多个头还提供了同时处理输入表示的不同子空间(本例:64个向量表示512个原始向量)的能力。

多头注意代码实现

 num_heads=8
 # batch dim is 1 since we are processing one text.
 batch_size=1
 
 text="I should sleep now"
 words=text.split(" ")
 len(words) # 4
 
 
 embeddings_dim=512
 embeddings=get_input_embeddings(words, embeddings_dim=embeddings_dim)
 embeddings[0].shape# torch.Size([512])
 
 
 # initialize the query, key and value metrices 
 query_matrix=nn.Linear(embeddings_dim, embeddings_dim)
 key_matrix=nn.Linear(embeddings_dim, embeddings_dim)
 value_matrix=nn.Linear(embeddings_dim, embeddings_dim)
 query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape# torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])
 
 
 # query, key and value vectors computation for each words embeddings
 query_vectors=torch.stack([query_matrix(embedding) forembeddinginembeddings])
 key_vectors=torch.stack([key_matrix(embedding) forembeddinginembeddings])
 value_vectors=torch.stack([value_matrix(embedding) forembeddinginembeddings])
 query_vectors.shape, key_vectors.shape, value_vectors.shape# torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])
 
 
 # (batch_size, num_heads, seq_len, embeddings_dim)
 query_vectors_view=query_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2) 
 key_vectors_view=key_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2) 
 value_vectors_view=value_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2) 
 query_vectors_view.shape, key_vectors_view.shape, value_vectors_view.shape
 # torch.Size([1, 8, 4, 64]),
 #  torch.Size([1, 8, 4, 64]),
 #  torch.Size([1, 8, 4, 64])
 
 
 # We are splitting the each vectors into 8 heads. 
 # Assuming we have one text (batch size of 1), So we split 
 # the embedding vectors also into 8 parts. Each head will 
 # take these parts. If we do this one head at a time.
 head1_query_vector=query_vectors_view[0, 0, ...]
 head1_key_vector=key_vectors_view[0, 0, ...]
 head1_value_vector=value_vectors_view[0, 0, ...]
 head1_query_vector.shape, head1_key_vector.shape, head1_value_vector.shape
 
 
 # The above vectors are of same size as before only the feature dim is changed from 512 to 64
 # compute the score
 scores_head1=torch.matmul(head1_query_vector, head1_key_vector.permute(1, 0)) /torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
 scores_head1.shape# torch.Size([4, 4])
 
 
 # compute the attention weights for each of the words with the other words
 softmax=nn.Softmax(dim=-1)
 attention_weights_head1=softmax(scores_head1)
 attention_weights_head1.shape# torch.Size([4, 4])
 
 output_head1=torch.matmul(attention_weights_head1, head1_value_vector)
 output_head1.shape# torch.Size([4, 512])
 
 
 # we can compute the output for all the heads
 outputs= []
 forhead_idxinrange(num_heads):
     head_idx_query_vector=query_vectors_view[0, head_idx, ...]
     head_idx_key_vector=key_vectors_view[0, head_idx, ...]
     head_idx_value_vector=value_vectors_view[0, head_idx, ...]
     scores_head_idx=torch.matmul(head_idx_query_vector, head_idx_key_vector.permute(1, 0)) /torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
 
     softmax=nn.Softmax(dim=-1)
     attention_weights_idx=softmax(scores_head_idx)
     output=torch.matmul(attention_weights_idx, head_idx_value_vector)
     outputs.append(output)
 
 [out.shapeforoutinoutputs]
 # [torch.Size([4, 64]),
 #  torch.Size([4, 64]),
 #  torch.Size([4, 64]),
 #  torch.Size([4, 64]),
 #  torch.Size([4, 64]),
 #  torch.Size([4, 64]),
 #  torch.Size([4, 64]),
 #  torch.Size([4, 64])]
 
 # stack the result from each heads for the corresponding words
 word0_outputs=torch.cat([out[0] foroutinoutputs])
 word0_outputs.shape
 
 # lets do it for all the words
 attn_outputs= []
 foriinrange(len(words)):
     attn_output=torch.cat([out[i] foroutinoutputs])
     attn_outputs.append(attn_output)
 [attn_output.shapeforattn_outputinattn_outputs] # [torch.Size([512]), torch.Size([512]), torch.Size([512]), torch.Size([512])]
 
 
 # Now lets do it in vectorize way. 
 # We can not permute the last two dimension of the key vector.
 key_vectors_view.permute(0, 1, 3, 2).shape# torch.Size([1, 8, 64, 4])
 
 
 # Transpose the key vector on the last dim
 score=torch.matmul(query_vectors_view, key_vectors_view.permute(0, 1, 3, 2)) # Q*k
 score=torch.softmax(score, dim=-1)
 
 
 # reshape the results 
 attention_results=torch.matmul(score, value_vectors_view)
 attention_results.shape# [1, 8, 4, 64]
 
 # merge the results
 attention_results=attention_results.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embeddings_dim)
 attention_results.shape# torch.Size([1, 4, 512])

总结

注意力机制(attention mechanism)是Transformer模型中的重要组成部分。Transformer是一种基于自注意力机制(self-attention)的神经网络模型,广泛应用于自然语言处理任务,如机器翻译、文本生成和语言模型等。本文介绍的自注意力机制是Transformer模型的基础,在此基础之上衍生发展出了各种不同的更加高效的注意力机制,所以深入了解自注意力机制,将能够更好地理解Transformer模型的设计原理和工作机制,以及如何在具体的各种任务中应用和调整模型。这将有助于你更有效地使用Transformer模型并进行相关研究和开发。

最后有兴趣的可以看看这个,它里面包含了pytorch的transformer的完整实现:

https://avoid.overfit.cn/post/c3f0da0fd4bd4151a8f79741ebc09937

作者:Souvik Mandal

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/665971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Django-带参数的路由编写(二)【用正则表达式匹配复杂路由】

在上一篇博文中,学习了“不用正则表达式匹配的简单带参数路由”,详情见链接: https://blog.csdn.net/wenhao_ir/article/details/131225388 本篇博文学习用“用正则表达式匹配复杂路由”。 简单的参数路由用库django.urls中的函数path()就可…

内涝监测系统如何助力城市防洪抗涝

近年来,各地内涝问题愈发严重,强降雨天气导致城市内涝已经屡见不鲜了,城市内涝不仅影响城市交通、居民生活,还可能对建筑物和基础设施造成损害,给城市运行带来重大风险。内涝治理除了要解决城市“里子”问题&#xff0…

【2023,学点儿新Java-11】基础案例练习:输出个人基础信息、输出心形 | Java中 制表符\t 和 换行符\n 的简单练习

前情回顾: 【2023,学点儿新Java-10】Java17 API文档简介&获取 |详解Java核心机制:JVM |详解Java内存泄漏与溢出 |Java优缺点总结 |附:GPT3.5-turbo问答测试【2023,学点儿新Java-09】Java初学者常会犯的错误总结与…

数据库SQL Server实验报告 之 SQL语言进行数据更新(6/8)

SQL语言进行数据更新 生命的本质是一场历练 实验目的及要求: 掌握如何使用sql语句进行插入、删除和更新操作。使用sql语句进行插入操作。使用sql语句进行删除操作。使用sql语句进行更新操作。使用各种查询条件完成指定的查询操作 实验内容及步骤&#xff1a…

计算机基础--->网络(2)【TCP、UDP、IP、ARP】

文章目录 TCP与UDP的区别TCP三次握手和四次挥手为什么要三次握手?第二次握手传回了ACK,为什么还要传回SYN?为什么要四次挥手?为什么不能将服务器发送的ACK和FIN合并起来,变成三次挥手?TCP如何保证传输的可靠…

推荐召回-Swing

概述 swing 是阿里原创的 i2i 召回算法,在阿里内部的多个业务场景被验证是一种非常有效的召回方法。据笔者了解,swing 在工业界已得到比较广泛的使用,抖音,小红书,B 站等推荐系统均使用了swing i2i。 1.传统 icf 算法…

MySql常见问题(长期更新)

基于mysql 8.0.3版本 一、忘记root密码1.1 、linux 系统下忘记密码1.2、Windows 系统下忘记密码1.3 Unix 和类 Unix 系统 二、账号问题2.1 远程访问账号设置 一、忘记root密码 1.1 、linux 系统下忘记密码 啥?你问我为什么会忘记密码?别问,…

Flutter状态管理新的实践 | 京东云技术团队

1 背景介绍 1.1 声明式ui 声明式UI其实并不是近几年的新技术,但是近几年声明式UI框架非常的火热。单说移动端,跨平台方案有:RN、Flutter。iOS原生有:SwiftUI。android原生有:compose。可以看到声明式UI是以后的前端发…

大数据从0到1的完美落地之sqoop优化

Sqoop的Job与优化 Job操作 job的好处: 1、一次创建,后面不需要创建,可重复执行job即可 2、它可以帮我们记录增量导入数据的最后记录值 3、job的元数据默认存储目录:$HOME/.sqoop/ 4、job的元数据也可以存储于mysql中。 复制代码…

C# 特性总结

目录 特性是什么? 如何使用特性? (1).Net 框架预定义特性 (2)自定义特性 为什么要使用特性? 特性的应用 特性实现枚举展示描述信息 特性是什么? 特性(Attribute&…

拉新、转化、留存,一个做不好,就可能会噶?

用户周期 对于我们各个平台来说(CSDN也是),我们用户都会有一个生命周期:引入期–成长期–成熟期–休眠期–流失期。 而一般获客就在引入期,在这个时候我们会通过推广的手段进行拉新;升值期则发生在成长期…

智能制造工厂的SCADA解决方案应用

智能制造工厂是当今工业领域的一个重要趋势,它将传统的生产模式与现代信息技术相结合,实现了生产过程的智能化和自动化。 SCADA是一种监控与数据采集系统,广泛应用于工业自动化领域,它通过传感器、控制器和网络等设备&#xff0c…

基于Python机器学习算法小分子药性预测(岭回归+随机森林回归+极端森林回归+加权平均融合模型)

目录 前言总体设计系统整体结构图系统流程图 运行环境Python 环境配置工具包 模块实现1. 数据预处理2. 创建模型并编译3. 模型训练 系统测试工程源代码下载其它资料下载 前言 《麻省理工科技评论》于2020年发布了“十大突破性技术”预测,其中包括“AI药物分子发现”…

一文说透!华熙生物如何步步为营炼就品牌势能?

据华熙生物2022年财报,华熙生物2022年营收同比增长28.53%,净利润同比增长24%,成为全球最大的。同时,近年来也在C端也大展身手。华熙生物此前与故宫博物院合作,推出6 款故宫国宝色口红和2款“故宫美人面膜”。凭借精美的…

中小型企业需要官网和帮助中心吗?为何这样说?

随着互联网技术的不断发展,越来越多的中小型企业开始重视拥有自己的官网和帮助中心。但是,对于许多刚刚起步的中小型企业来说,官网和帮助中心的建设可能需要一定的成本和时间投入。那么,中小型企业是否需要官网和帮助中心呢&#…

python(11):python读取excel、csv文件

1.python读取excel文件 要读取Excel表格的指定行和列范围,可以使用Python中的第三方库pandas。pandas库提供了强大的数据分析和处理工具,包括读取和处理Excel文件的功能。以下是一个示例代码,演示了如何使用pandas库读取Excel表格中的指定行…

[Go]-Go语言第一课

1-1 Go语言特点 特点: 1. 静态类型,编译开源语言2. 脚本化的语法,支持多种编程范式(函数式,面向对象)3. 原生,给力的并发支持并发编程1-2 Go语言优势与劣势 Go语言的优势: 1.脚本化…

软考A计划-系统集成项目管理工程师-信息化知识(三)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例点击跳转>软考全系列 👉关于作者 专注于Android/Unity和各种游戏开发技巧&#xff…

加密市场与上一轮周期有何异同?五大因素探讨加密市场未来之路

数字资产市场在一季度表现不俗,但二季度的表现却出现了相反的情况。数据显示,BTC 在一季度累计上涨了 71.69%,而二季度截至目前下跌了 7.31%。这样的变化主要是由金融监管机构针对整个数字资产行业采取的监管行动造成的。虽然 BTC 今年以来仍…

uniapp中uni-popup的用法——实例讲解

uni-pop弹出层组件,在应用中弹出一个消息提示窗口、提示框等,可以设置弹出层的位置,是中间、底部、还是顶部。 如下图效果所示:白色区域则为弹出的pop层。 一、 创建一个自定义组件: 1.项目中安装下载uni-pop插件。 2.把pop内容…