【深度学习-注意力机制attention 在seq2seq中应用】

news2025/1/20 18:30:21

注意力机制

  • 为什么需要注意力机制
  • attention机制的架构总体设计
    • 一、attention本身实现
    • 评分函数
  • attention在网络模型的应用-Bahdanau 注意力
    • 加性注意力代码实现

为什么需要注意力机制

在这里插入图片描述

这是一个普通的seq2seq结构,用以实现机器对话,Encoder需要把一个输入的一个句子转化为一个最终的输出,上下文context vector,然后在Decoder中使用,但这里有些问题:

  1. 如果句子很长,这个向量很难包含sequence中最早输入的哪些词的信息,那么decoder的处理必然也缺失了这一部分。
  2. 对话的过程中,大部分情况下decoder第一个的输出应该关心的权重更应该是encoder的前半部分的输入,比如这里Yes,其实应该是对are you这样一个疑问的输出,但是这就要求decoder的预测的时候有区别的针对sequence的输入做输出,现在这个结构没办法实现这个功能。

你可能会想到LSTM或者GRU也是有memory记忆功能的,解决方案:
LSTM中的memory没有办法很大,假设它的memory的大小时K的话,就需要有一个K*K的矩阵,如果太大的memory,不仅计算量大,参数太多还会容易过拟合,因此不可行

attention机制就是用来解决这个问题,attention里面memory增加的话,参数并不会增加,一句话总结就是attention就是来解决长输入在decoder时,能够找到应该关注的输入部分的问题,它最初时从机器翻译发展的,后续也扩展到了其他领域

attention机制的架构总体设计

总体架构
这就是总体的架构设计,输入a1…an,输出b1…bn 对应,注意这里的b考虑了所有的输入,这个输出带有对于每个输入的attention score,score越大,证明这个输入越重要,a在这里可以是输入,也可以是输入解码器后hidden layer的输出,那么中间蓝色框部分就是attention主体实现,它用来生成的b1到bn
举个例子:输入are you free tomorrow? 输出的时候Yes更关注的是are you,那这个的attention score就需要高一些

普通的seq2seq结构
在这里插入图片描述
带有注意力的seq2seq
在这里插入图片描述

在普通的seq2seq相比,解码器使用的上下文变量C不再仅仅是编码器的输出,而是 注意力的输出

与普通的seq2seq模型对比下,带有注意力模型的修改就分为了两部分
1.attention本身的实现
2.attention应用到模型部分

以下详述这了两部分

一、attention本身实现

先不介绍内部的一些数学处理,attention的输出实际上是对某种输入的选择倾向
输入就是要被选择的数据和对应的查询线索
输出对要选择数据的权重
举个例子
输入:the dog is running across the grass
翻译:这个小狗正在穿越草地
解码翻译这 个 小 狗 这些词的时候,注意力应该放在the dog上,这时候我们给与the dog这些词更多的权重,这时候对于输入可能的权重就是0.5 0.5 0 0 0 0

在这里插入图片描述

在数学模型方面,
键key
查询Query
值 Value

要实现的是根据键和查询生成的线索,去计算对于值Value的倾向选择,数学表达是这样的:
在这里插入图片描述
这里的a(q, ki) 一般是经过一个评分函数映射成标量和然后一个softmax操作

这里可以形象的理解一下,比如下面三组数据:

id体重->Q身高->K年龄-> V
15016050
26516523
36017521

当输入体重K 63, 身高V 170,问现在的年龄大概是多少呢?
看到表中的信息,人脑会自然猜测年龄在23和21之间,也就是在id 2和3上权重比较高,0.6* 23 +0.4* 21,这个也接近于注意力的实质,其实是根据Q和V 做评分,用以对V加权取值,这些权重值,就是注意力。
a(q, k1) v1+ a(q, k2)v2

评分函数

评分函数实际有很多种,tanh, 经过一个线性变换,或者sin cos 、加 等等,目前业内没有最好的实践

attention在网络模型的应用-Bahdanau 注意力

很多的论文都涉及注意力的使用,这块的依据是比较早和出名的Bahdanau注意力讲解。
上文seq2se模型中讲过解码器的输入是编码器的输出(上下文变量)以及解码器输入,而在有注意力的网络模型中,这个上下文变成了注意力的输出,解码器示意:
在这里插入图片描述
其中的at,i 就是注意力权重的输出
在这里插入图片描述
时间步t-1 解码器的隐状态是St-1,也是所谓的查询
ht编码器隐状态,是键也是值

加性注意力代码实现

class AdditiveAttention(nn.Module):
    """加性注意力实现
    """
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 在维度扩展后,
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使用广播方式进行求和
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v仅有一个输出,因此从形状中移除最后那个维度。
        # scores的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)
        # 这部分主要是为了遮蔽填充项,理解注意力上的时候可以先忽略它
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1007474.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

损失函数loss和优化器optimizer

损失函数与优化器的关联_criterion(outputs, labels)_写代码_不错哦的博客-CSDN博客https://blog.csdn.net/shenjianhua005/article/details/123971915?ops_request_misc&request_id6583569ecbdc4daf89dbf2d43eac9242&biz_id&utm_mediumdistribute.pc_search_resu…

2023常用的原型设计软件推荐

美观易操作的产品原型可以帮助团队构建积极的用户体验,帮助团队理解产品交互逻辑。 因此,可互动、易修改的产品原型设计对产品的点击率和回访率具有重要意义。 选择专业的产品原型设计工具,可以为团队和企业带来高效的产品设计体验。本文选…

算法通关村第十四关——解析堆在数组中找第K大的元素的应用

力扣215题, 给定整数数组nums和整数k,请返回数组中第k个最大的元素。 请注意,你需要找的是数组排序后的第k个最大的元素,而不是第k个不同的元素。 分析:按照“找最大用小堆,找最小用大堆,找中间…

亲手实现:全方位解析SpringCloud Alibaba,这份全彩笔记送给你

SpringCloud Aliababa简介 大家好,这次我们来分享一个实用的开源项目—SpringCloud Alibaba。 SpringCloud是国内外微服务开发的首选框架,而SpringCloud Alibaba则是阿里巴巴为微服务架构而开发的组件,它支持SpringCloud原生组件&#xff0…

数据分析三剑客之Numpy

数据分析三剑客:Numpy,Pandas,Matplotlib 1.简介 NumPy(Numerical Python) 是 Python 语言的一个扩展程序库,支持大量的维度数组与矩阵运算,此外也针对数组运算提供大量的数学函数库。 numpy是基于c语言开发&#x…

第二章 进程与线程 一、进程的概念、组成、特征

目录 一、进程的组成 ​编辑 二、概念 程序: 进程( Process) : PID: 进程控制块(PCB): 三、特征 1、动态性(最基本的特性) 2、并发性 3、独立性 4、异步性 5、结构性 一、进程的组成 二、概念 程序: 是静…

针对电子企业的数字工厂管理系统解决方案

随着科技的飞速发展和市场竞争的日益激烈,电子企业需要一种高效、智能的数字工厂管理系统解决方案,以提升生产效率、优化资源利用、降低运营成本,并确保高品质产品的输出。本文将详细探讨针对电子企业的数字工厂管理系统解决方案。 一、数字工…

页面静态化、Freemarker入门

页面静态化介绍 页面的访问量比较大时,就会对数据库造成了很大的访问压力,并且数据库中的数据变化频率并不高。 那需要通过什么方法为数据库减压并提高系统运行性能呢?答案就是页面静态化。页面静态化其实就是将原来的动态网页(例如通过ajax…

jmeter线程组 bzm - Arrivals Thread Group 阶梯式压测

简介 BZM - Arrivals Thread Group是jmeter的一个插件,它可以模拟并发到达的用户流量、按时间加压,可以有效地帮助测试人员评估系统在高压力和高并发情况下的性能表现。 插件下载地址(jmeter版本不低于 5.2.0 ):https:…

开课吧(三)机器人系统(ros详解)

目录 常用快捷键: 常用命令: Catkin编译系统: 简析.XML文件(说明书) name指package名字 version指版本 description指描述 maintainer指拥有者 license指授权 buildtool_depend 依赖catkin编译 build_depend指依…

【C++】day6学习成果:继承、多态、栈和循环队列

1.将之前定义的栈类和队列类都实现成模板类 栈&#xff1a; #include <iostream>#define MAX 8using namespace std;template<typename T> class Stack { private:T *data; //栈的数组&#xff0c;指向堆区空间&#xff0c;用于存储栈的容器int top; …

基于元素小组的归并排序算法

问题说明 什么是针对元素小组的归并排序算法&#xff0c;举个例子&#xff1a;假如有一个数组[1,2,3,4,5,6,7,8,9]&#xff0c;{1,2,3}为一个小组&#xff0c;{4,5,6}为一个小组&#xff0c;{7,8,9}为一个小组&#xff0c;现需要根据每个小组的第一个元素来进行排序&#xff0…

upload-labs文件上传靶场实操

文章目录 1.Pass-012.Pass-023.Pass-034.Pass-045.Pass-056.Pass-067.Pass-078.Pass-089.Pass-0910.Pass-1011.Pass-1112.Pass-1213.Pass-1314.Pass-1415.Pass-1516.Pass-16 1.Pass-01 改后缀名绕过 只能上传图片&#xff0c;先上传一个jpg格式的图片&#xff0c;然后抓包改格…

如何在 Excel 中进行加,减,乘,除

在本教程中&#xff0c;我们将执行基本的算术运算&#xff0c;即加法&#xff0c;减法&#xff0c;除法和乘法。 下表显示了我们将使用的数据以及预期的结果。 | **S / N** | **算术运算符** | **第一个号码** | **第二个号码** | **结果** | | 1 | 加法&#xff08;&#xff…

企业形象片宣传片策划要从哪里展开

企业形象片宣传片是一种有效的营销工具&#xff0c;能够向潜在客户传达企业的核心价值观、品牌形象和产品服务。对于企业来说&#xff0c;一个成功的宣传片可以增加品牌知名度&#xff0c;提高销售额&#xff0c;并建立与客户的良好关系。然而&#xff0c;要想策划一部成功的企…

org.apache.hadoop.hbase.PleaseHoldException: Master is initializing

背景 CDH集群切换数据盘&#xff0c;导致服务无法启动&#xff0c;卸载重装了 hbase、hdfs、yarn、oozie、spark等服务&#xff0c;未卸载重装的zookeeper、kafka。 重装hbase后无法创建表。 报错 hbase(main):001:0> create test,cf1 ERROR: org.apache.hadoop.hbase.Pl…

【计算机网络】传输层协议——TCP(上)

文章目录 TCPTCP协议段格式报头和有效载荷如何分离&#xff1f;4位首部长度 TCP可靠性确认应答机制的提出序号和确认序号为什么序号和确认序号在不同的字段&#xff1f; 16位窗口大小 6个标志位标志位本质具体标志位PSHRSTURG 超时重传机制 文章目录 TCPTCP协议段格式报头和有效…

SAP SD之定义装运点OVL2

什么是装运点&#xff1f; 装运点是一个独立的组织实体&#xff0c;其中进行货物的发行和交付处理。 可以为每个订单商品确定一个装运点。 确定装运点取决于以下三个因素&#xff1a; 客户主记录中的运输条款和条件&#xff08;运输屏幕&#xff09;。 例如&#xff0c;公司与…

为什么选择C/C++内存检测工具AddressSanitizer?如何使用AddressSanitizer?

目录 1、C程序中的内存问题 2、AddressSanitizer是什么&#xff1f; 3、AddressSanitizer内存检测原理简述 3.1、内存映射 3.2、插桩 4、为什么选择AddressSanitizer&#xff1f; 4.1、Valgrind介绍 4.2、AddressSanitizer在速度和内存方面为什么明显优于Valgrind 4.3…

安卓最强LSPosed框架v1.9.1正式版下载-API变更-支持安卓14新系统+刷入教程

LSPosed框架自1.86以后比较稳定&#xff0c;LSPosed官方更新的也变慢了很多&#xff0c;上周开始LSP框架又开始了大版本更新&#xff0c;直接迭代到V19.1版本。单从更新日志上来看&#xff0c;这两次的更新幅度比较大&#xff0c;也修复了很多我们常见的问题。从我们正常刷入体…