自然语言处理: 第五章Attention注意力机制

news2025/1/18 17:08:12

自然语言处理: 第五章Attention注意力机制

理论基础

Attention(来自2017年google发表的[1706.03762] Attention Is All You Need (arxiv.org) ),顾名思义是注意力机制,字面意思就是你所关注的东西,比如我们看到一个非常非常的故事的时候,但是其实我们一般能用5W2H就能很好的归纳这个故事,所以我们在复述或者归纳一段文字的时候,我们肯定有我们所关注的点,这些关注的点就是我们的注意力,而类似How 或者when 这种不同的形式就成为了Attention里的多头的机制。 下图是引自GPT3.5对注意力的一种直观的解释,简而言之其实就是各种不同(多头)我们关注的点(注意力)构成了注意力机制,这个奠定现代人工智能基石的基础。
在这里插入图片描述



那么注意力机制的优点是什么呢? (下面的对比是相对于上一节的Seq2Seq模型)

  1. 解决了长距离依赖问题,由于Seq2Seq模型一般是以时序模型eg RNN / Lstm / GRU 作为基础, 所以就会必然导致模型更倾向新的输入 – 多头注意力机制允许模型在解码阶段关注输入序列中的不同部分
  2. 信息损失:很难将所有信息压缩到一个固定长度的向量中(encorder 输出是一个定长的向量) – 注意力机制动态地选择输入序列的关键部分
  3. 复杂度和计算成本:顺序处理序列的每个时间步 – 全部网络都是以全连接层或者点积操作,没有时序模型
  4. 对齐问题:源序列和目标序列可能存在不对齐的情况 – 注意力机制能够为模型提供更精细的词汇级别对齐信


注意力可以拆解成下面6个部分,下面会在代码实现部分逐个解释

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-npjJuAYN-1689687821801)(image/06_attention/1689603947092.png)]


  1. 缩放点积注意力

两个向量相乘可以得到其相似度, 一般常用的是直接点积也比较简单,原论文里还提出里还提出了下面两种计算相似度分数的方式, 也可以参考下图。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sAngeG15-1689687821802)(image/06_attention/1689604308442.png)]



其实相似度分数,直观理解就是两个向量点积可以得到相似度的分数,加权求和得到输出。[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-so4FVaol-1689687821802)(image/06_attention/1689604224729.png)]
在这里插入图片描述



然后就是对点积注意力的拆解首先我们要明确目标我们要求解的是X1关于X2的注意力输出,所以首先需要确定的是X1 和 X2 的特征维度以及batch_size肯定要相同,然后是seq长度可以不同。 然后我们计算原始注意力权重即X1(N , seq1 , embedding) · X2(N , seq2 , embedding) -> attention( N , seq1 , seq2) , 可以看到我们得到了X1中每个单词对X2中每个单词的注意力权重矩阵所以维度是(seq1 , seq2)。

当特征维度尺寸比较大时,注意力的值会变得非常大,从而导致后期计算softmax的时候梯度消失,所以这里会对注意力的值做一个缩小,也就是将点积的结果/scaler_factor, 这个scaler_factor一般是embedding_size 的开根号。

然后我们对X2的维度做一个softmax 得到归一化的注意力权重,至于为什么是X2,是因为我们计算的是X1关于X2的注意力,所以在下一步我会会让整个attention 权重与X2做点积也就是加权求和,这里需要把X2对应的权重做归一化所以要对X2的权重做归一化。

也就是X1 与 X2 之间相互每个单词的相似度(score) 因为我们求的是X1 关于X2的注意力,所以最后我们将归一化后的权重与X2做一个加权求和(点积)即Attention_scaled(N , seq1 , seq2) · X2(N , seq2 , embedding) -> X1_attention_X2( N , seq1 , embedding) 这个时候我们可以看到最后的输出与X1的维度相同,但是里面的信息已经是整合了X2的信息的的X1’。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2cDP4HFU-1689687821802)(image/06_attention/1689604378729.png)]



  1. 编解码注意力
    这个仅存在Seq2Seq的架构中,也就是将编码器的最后输出与解码器的隐藏状态相互结合,下图进行了解释可以看到encoder将输入的上下文进行编码后整合成一个context 向量,由于我们最后的输出是decoder 所以这里X1 是解码器的隐层状态,X2 是编码器的隐层状态。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ti3ZkUEp-1689687821803)(image/06_attention/1689684006567.png)]



4. QKV

下面介绍的就是注意力中一个经常被弄混的概念,QKV, 根据前面的只是其实query 就是X1也就是我们需要查询的目标, Key 和 Value 也就是X2,只是X2不同的表现形式,K 可以等于 V 也可以不等,上面的做法都是相等的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HiChaI3W-1689687821803)(image/06_attention/1689684377799.png)]

  1. 自注意力

    最后就是注意力机制最核心的内容,也就是自注意力机制,那么为什么多了一个自呢?其实就是X1 = X2 ,换句话说就是自己对自己做了一个升华,文本对自己的内容做了一个类似summary的机制,得到了精华。就如同下图一样,自注意力的QKV 都来自同一个输入X向量,最后得到的X’, 它是自己整合了自己全部信息向量,它读完了自己全部的内容,并不只是单独的一个字或者一段话,而是去其糟粕后的X。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VwoyKvPH-1689687821804)(image/06_attention/1689684493893.png)]

而多头自注意力也就是同样的X切分好几个QKV, 可以捕捉不同的重点(类似一个qkv捕捉when , 一个qkv捕捉how),所以多头是有助于网络的表达,然后这里需要注意的是多头是将embedding - > (n_head , embedding // n_head ) , 不是(n_head , embedding)。

所以多头中得到的注意力权重的shape也会变成(N , n_head , seq , seq ) 这里由于是自注意力 所以seq1 = seq2 = seq 。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5GkFKSvL-1689687821804)(image/06_attention/1689684868618.png)]



代码实现

这里只介绍了核心代码实现, 下面是多头注意力的实现:

import torch.nn as nn # 导入torch.nn库
# 创建一个Attention类,用于计算注意力权重
class Mult_attention(nn.Module):
    def __init__(self, n_head):
        super(Mult_attention, self).__init__()
        self.n_head = n_head

    def forward(self, decoder_context, encoder_context , dec_enc_attn_mask): # decoder_context : x1(q) , encoder_context : x2(k , v)
        # print(decoder_context.shape, encoder_context.shape) # X1(N , seq_1 , embedding) , X2(N , seq_2 , embedding)
        # 进行切分 , (N , seq_len_X , emb_dim) -> (N , num_head , seq_len_X , head_dim) / head_dim * num_head  = emb_dim
        Q = self.split_heads(decoder_context)  # X1
        K = self.split_heads(encoder_context)  # X2
        V = self.split_heads(encoder_context)  # X2
        # print(Q.shape , 0)

        # 将注意力掩码复制到多头 attn_mask: [batch_size, n_heads, len_q, len_k]
        attn_mask = dec_enc_attn_mask.unsqueeze(1).repeat(1, self.n_head, 1, 1)

        # 计算decoder_context和encoder_context的点积,得到多头注意力分数,其实就是在原有的基础上多加一个尺度
        scores = torch.matmul(Q, K.transpose(-2, -1)) # -> (N , num_head , seq_len_1 , seq_len_2)
        scores.masked_fill_(attn_mask , -1e9)
        # print(scores.shape ,1 )

        # 自注意力原始权重进行缩放
        scale_factor = K.size(-1) ** 0.5
        scaled_weights = scores / scale_factor # -> (N , num_head , seq_len_1 , seq_len_2)
        # print(scaled_weights.shape , 2)

        # 归一化分数
        attn_weights = nn.functional.softmax(scaled_weights, dim=-1) # -> (N , num_head , seq_len_1 , seq_len_2)
        # print(attn_weights.shape , 3)

        # 将注意力权重乘以encoder_context,得到加权的上下文向量
        attn_outputs = torch.matmul(attn_weights, V) # -> (N , num_head , seq_len_1 , embedding // num_head)
        # print(attn_outputs.shape , 4) 

        # 将多头合并下(output  & attention)
        attn_outputs  = self.combine_heads(attn_outputs) # 与Q的尺度是一样的
        attn_weights = self.combine_heads(attn_weights)
        # print(attn_weights.shape , attn_outputs.shape , 5) #
        return attn_outputs, attn_weights
  
    # 将所有头的结果拼接起来,就是把n_head 这个维度去掉,
    def combine_heads(self , tensor):
        # print(tensor.size())
        batch_size, num_heads, seq_len, head_dim = tensor.size()
        feature_dim = num_heads * head_dim
        return tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)


    def split_heads(self , tensor):
        batch_size, seq_len, feature_dim = tensor.size()
        head_dim = feature_dim // self.n_head
        # print(tensor.shape, head_dim , self.n_head)
        return tensor.view(batch_size, seq_len, self.n_head, head_dim).transpose(1, 2)


多头注意力的解码器,这里添加了mask机制。

class DecoderWithMutliHeadAttention(nn.Module):
    def __init__(self, hidden_size, output_size , n_head):
        super(DecoderWithMutliHeadAttention, self).__init__()
        self.hidden_size = hidden_size # 设置隐藏层大小
        self.n_head = n_head # 多头
        self.embedding = nn.Embedding(output_size, hidden_size) # 创建词嵌入层
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建RNN层
        self.multi_attention = Mult_attention(n_head = n_head)
        self.out = nn.Linear(2 * hidden_size, output_size)  # 修改线性输出层,考虑隐藏状态和上下文向量


    def forward(self, inputs, hidden, encoder_outputs , encoder_input):
        embedded = self.embedding(inputs)  # 将输入转换为嵌入向量
        rnn_output, hidden = self.rnn(embedded, hidden)  # 将嵌入向量输入RNN层并获取输出 
        dec_enc_attn_mask = self.get_attn_pad_mask(inputs, encoder_input) # 解码器-编码器掩码
        context, attn_weights = self.multi_attention(rnn_output, encoder_outputs , dec_enc_attn_mask)  # 计算注意力上下文向量
        output = torch.cat((rnn_output, context), dim=-1)  # 将上下文向量与解码器的输出拼接
        output = self.out(output)  # 使用线性层生成最终输出
        return output, hidden, attn_weights
  
    def get_attn_pad_mask(self , seq_q, seq_k):
        #-------------------------维度信息--------------------------------
        # seq_q 的维度是 [batch_size, len_q]
        # seq_k 的维度是 [batch_size, len_k]
        #-----------------------------------------------------------------
        # print(seq_q.size(), seq_k.size())
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        # 生成布尔类型张量[batch_size,1,len_k(=len_q)]
        pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  #<PAD> Token的编码值为0 
        # 变形为何注意力分数相同形状的张量 [batch_size,len_q,len_k]
        pad_attn_mask = pad_attn_mask.expand(batch_size, len_q, len_k)
        #-------------------------维度信息--------------------------------
        # pad_attn_mask 的维度是 [batch_size,len_q,len_k]
        #-----------------------------------------------------------------
        return pad_attn_mask # [batch_size,len_q,len_k]


结果

整体实验结果如下,可能是因为整体语料库太小了,所以翻译结果不是太好,但是多头注意力机制还是都跑通了:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B9v9o7Ep-1689687821804)(image/06_attention/1689687316028.png)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/768885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3中echarts的使用

效果&#xff1a; 代码&#xff1a; <div class"outcharbox"><a-row :gutter"10"><a-col :span"8" v-for" (item, index) in linesobjdata" :key"item.MonitorItemId"><monitoringItemsChart :colorI…

49天精通Java,第40天,jd-gui反编译class文件,解决jd-gui中文乱码问题

目录 专栏导读一、添加局部变量二、反编译class文件三、解决乱码问题四、产品经理就业实战1、内容简介2、作者简介 专栏导读 本专栏收录于《49天精通Java从入门到就业》&#xff0c;本专栏专门针对零基础和需要进阶提升的同学所准备的一套完整教学&#xff0c;从0开始&#xf…

【SQL应知应会】表分区(四)• MySQL版

欢迎来到爱书不爱输的程序猿的博客, 本博客致力于知识分享&#xff0c;与更多的人进行学习交流 本文收录于SQL应知应会专栏,本专栏主要用于记录对于数据库的一些学习&#xff0c;有基础也有进阶&#xff0c;有MySQL也有Oracle 分区表 • MySQL版 前言一、分区表1.非分区表2.分区…

系统学习Linux-SSH远程服务(二)

概念 安全外壳协议&#xff0c;提供安全可靠的远程连接 特点 ssh是工作在传输层和应用层的协议 ssh提供了一组管理命令 ssh 远程登陆 scp 远程拷贝 sftp 远程上传下载 ssh-copy-id ssh keygen 生成 提供了多种身份验证机制 身份验证机制 密码验证 需要提供密码 密…

vue element select下拉框回显展示数字

vue element select下拉框回显展示数字 问题截图&#xff1a; 下拉框显示数字可以从数据类型来分析错误&#xff0c;接收的数据类型是字符串&#xff0c;但是value是数字类型 <el-form-item prop"classifyLabelId" :label"$t(item.classifyLabelId)"…

051、事务设计之TiDB事务实现方式

事务在TiDB中的存储 分布式事务 提交的第一阶段&#xff0c;会用三个CF 来存放这些数据信息&#xff0c; 一类列簇对应一类键值对&#xff0c; 第一个CF(default)存放的是数据 的键值对。第二个存放的是锁信息。 第三个对应的是提交信息。 put<3_100,Frank> 3_100: prim…

LeetCode·每日一题·1851. 包含每个查询的最小区间·优先队列(小顶堆)

题目 示例 思路 离线查询&#xff1a; 输入的结果数组queries[]是无序的。如果我们按照输入的queries[]本身的顺序逐个查看&#xff0c;时间复杂度会比较高。 于是&#xff0c;我们将queries[]数组按照数值大小&#xff0c;由小到大逐个查询&#xff0c;这种方法称之为离线查询…

Go语言之接口(interface)

1.1 、多态的含义 在java里&#xff0c;多态是同一个行为具有不同表现形式或形态的能力&#xff0c;即对象多种表现形式的体现&#xff0c;就是指程序中定义的引用变量所指向的具体类型和通过该引用变量发出的方法调用在编程时并不确定&#xff0c;而是在程序运行期间才确定&am…

T5模型: Transfer Text-to-Text Transformer(谷歌)

&#x1f525; T5由谷歌发表于2019&#xff0c;《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》&#xff0c;最终版本发布在&#xff1a;JMLR。 一句话总结T5: 大一统模型&#xff0c;seq2seq形式完成各类nlp任务&#xff0c;大数据集…

Docker 的前世今生:从社区到市场,从领域到技术应用的全方位分析

博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&#x1f466;&#x1f3fb; 《java 面试题大全》 &#x1f369;惟余辈才疏学浅&#xff0c;临摹之作或有不妥之处&#xff0c;还请读者海涵指正。☕&#x1f36d; 《MYSQL从入门到精通》数据库是开发者必会基础之…

下载编译Chromium

参考&#xff1a;Mac上本地编译Chrome浏览器踩坑笔记&#xff08;2021.02最新&#xff09; - 掘金 For Mac: 一、下载编译工具链&#xff1a;deptool git clone https://chromium.googlesource.com/chromium/tools/depot_tools.git export PATH"$PATH:/Users/yumlu/co…

jib进行本地打包,并上传本地镜像仓库

使用 Jib 进行本地打包和上传到本地镜像仓库是一种方便的方式&#xff0c;而无需编写 Dockerfile。Jib 是一个开源的 Java 容器镜像构建工具&#xff0c;它可以直接将 Java 项目打包为镜像&#xff0c;并将其推送到容器镜像仓库。 gradle 进行jib的配置 import java.time.Zon…

第53步 深度学习图像识别:Bottleneck Transformer建模(Pytorch)

基于WIN10的64位系统演示 一、写在前面 &#xff08;1&#xff09;Bottleneck Transformer "Bottleneck Transformer"&#xff08;简称 "BotNet"&#xff09;是一种深度学习模型&#xff0c;在2021年由Google的研究人员在论文"Bottleneck Transfor…

MaxCompute与 Mysql 之单字段转多行

在实际数据处理中&#xff0c;可能会遇到行列转换的数据处理&#xff0c;在 MaxCompute 与 AnalyticDB MySQL 数据处理与转换 介绍过如多行转一行&#xff0c;本篇主要介绍将逗号分割的字段转成多行。 一、MaxCompute 实现方式 在MaxCompute中有TRANS_ARRAY函数&#xff0c;可…

显示一行或两行多出的文字用省略号代替

以上就是一行的效果&#xff0c;超出宽度就用...代替 .recommendContainer .scrollItem text{/* 单行文本溢出隐藏 省略号代替 */display: block;white-space: nowrap; /*溢出不换行*/overflow: hidden; /*溢出隐藏*/text-overflow: ellipsis; /*溢出的内容已...代替*/} 多…

watch中监听vuex中state改变监听不到

watch中监听vuex中state改变监听不到 https://blog.csdn.net/aliven1/article/details/100581529?utm_mediumdistribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-100581529-blog-122614448.t5_layer_targeting_sa&spm1001.2101.3001.4242…

软通动力与华秋达成生态共创合作,共同推动物联网硬件创新

7月11日&#xff0c;在2023慕尼黑上海电子展现场&#xff0c;软通动力信息技术(集团)股份有限公司(以下简称“软通动力”)与深圳华秋电子有限公司(以下简称“华秋”)签署了生态共创战略合作协议&#xff0c;共同推动物联网硬件生态繁荣发展。当前双方主要基于软通动力的产品及解…

从Vue2到Vue3【二】——Composition API(第二章)

系列文章目录 内容链接从Vue2到Vue3【零】Vue3简介及创建从Vue2到Vue3【一】Composition API&#xff08;第一章&#xff09; 文章目录 系列文章目录前言一、 生命周期二、hook三、toRef以及toRefs总结 前言 Vue3作为Vue.js框架的最新版本&#xff0c;引入了许多令人激动的新…

vue项目部署自动检测更新

前言 当我们重新部署前端项目的时候&#xff0c;如果用户一直停留在页面上并未刷新使用&#xff0c;会存在功能使用差异性的问题&#xff0c;因此&#xff0c;当前端部署项目后&#xff0c;需要提醒用户有去重新加载页面。 在以往解决方案中&#xff0c;不少人会使用websocke…

C#基础--委托

C#基础–委托 C#基础–委托 简单说它就是一个能把方法当参数传递的对象,而且还知道怎么调用这个方法,同时也是粒度更小的“接口”(约束了指向方法的签名) 一、什么是委托,委托的本质是什么? 跟方法有点类似,有参数,返回值,访问修饰符+ delegate public delegate void …