Bahdanau注意力机制

news2024/11/15 22:52:01

介绍

在Bahadanu注意力机制中,本质上是序列到序列学习的注意力机制实现,在编码器-解码器结构中,解码器的每一步解码过程都依赖着整个上下文变量,通过Bahdanau注意力,使得解码器在每一步解码时,对于整个上下文变量的不同部分产生不同程度的对齐,如在文本翻译时,将“I am studying”的“studying”与“我正在学习”的“学习”进行对齐,即注意力在解码时将绝大多数注意力放在“studying”处。

原理和结构

原理

Bahdanau注意力机制本质上是将上下文变量进行转换即可,其中转换后的上下文变量计算方式如下式所示:

c_{t^{'}}=\Sigma_{t=1}^T \alpha (s_{t^{'}-1},h_t)h_t

在传统注意力机制中,一般使用的公式形如c_{t^{'}}=\Sigma_{t=1}^T \alpha (s_{t^{'}-1},k) V,在Bahdanau中,键与值是同一个变量,都是t时刻的编码器隐状态,s表示该时刻的查询,即上一时刻的解码器隐状态。

架构

下图为Bahdanau注意力机制的编码器-解码器架构示意图:

为便于理解,对上述示意结构进行说明:首先将X依次输入GRU,之后在循环过程中依次产生len个隐状态,最后一个隐状态h_{len}直接作为解码器的初始隐状态。在每个解码步骤 (t),注意力机制计算当前解码器隐藏状态 (s_t) 和编码器所有隐藏状态 (h_i) 的相似度,即应用注意力机制编写新的上下文变量,之后在解码器的循环解码过程中,都计算带注意力的上下文变量,通过此变量和上一解码隐状态计算当前时刻t的输出。

代码实现

引入库

注:本blog使用mxnet进行训练学习。

from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()

 定义注意力解码器

这里实现一个接口,只需重新定义解码器即可。 为了更方便地显示学习的注意力权重, 以下AttentionDecoder类定义了带有注意力机制解码器的基本接口。

class AttentionDecoder(d2l.Decoder):
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)
    def attention_weights(self):
        raise NotImplementedError

接下来,让我们在接下来的Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。 首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;

  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;

  3. 编码器有效长度(排除在注意力池中填充词元)。

在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。 因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

对接下来的代码实现略作补充说明:在编码器中,对一个batch每个输入(采用one-hot编码,长度为Vocab_size)依次进行嵌入层运算,得到固定embed_size个结果之后进行RNN运算,RNN使用层数为num_layers的深层循环神经网络,进行forward运算得到状态,部分过程进行闭包和解包,对于对维度大小出现疑惑的点,大多是闭包和解包造成的。

class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Dense(vocab_size, flatten=False)
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        outputs, hidden_state = enc_outputs
        return (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens)
    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_lens = state
        X = self.embedding(X).swapaxes(0, 1)
        outputs, self._attention_weights = [], []
        for x in X:
            query = np.expand_dims(hidden_state[0][-1], axis=1)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1)
            out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        outputs = self.dense(np.concatenate(outputs, axis=0))
        return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,enc_valid_lens]
    def attention_weights(self):
        return self._attention_weights

训练

我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器, 并对这个模型进行机器翻译训练。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

结果 

采用BLEU计算困惑度,代码具有较好的表现。

go . => entre ., bleu 0.000
i lost . => j'ai gagné ., bleu 0.000
he's calm . => j'ai gagné ., bleu 0.000
i'm home . => je suis chez moi <unk> !, bleu 0.719

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2089701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ET6框架(七)Excel配置工具

文章目录 一、Excel表的基本规则&#xff1a;二、特殊特殊标记三、编译路径说明四、动态获取数据五、可导表类型查看: 一、Excel表的基本规则&#xff1a; 在框架中我们的Excel配置表在ET > Excel文件夹中 1.在表结构中需要注意的是起始点必须在第三行第三列&#xff0c;且…

91.游戏的启动与多开-游戏启动

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a;易道云信息技术研究院 上一个内容&#xff1a;90.游戏安全项目-项目搭建与解析 以90.游戏安全项目-项目搭建与解析它的代码为基础进行…

[java][代码] java中date格式化输出时间字符串

Date date new Date();//具备默认的风格//DateFormat dateFormDateFormat.getDateTimeInstance(DateFormat.LONG,DateFormat.LONG);DateFormat dateFormnew SimpleDateFormat("yyyy-mm-dd");System.out.println(dateForm.format(date)); 这段Java代码演示了如何使用S…

YOLOv9改进策略【模型轻量化】| PP-LCnet

一、本文介绍 本文记录的是利用PP-LCNet中的DepSepConv模块优化YOLOv9中的RepNCSPELAN4。YOLOv9在使用辅助分支后&#xff0c;模型的参数量和计算量相对较大&#xff0c;本文利用DepSepConv模块改善模型结构&#xff0c;使模型在几乎不增加延迟的情况下提升网络准确度。 文章目…

海外新闻稿发布对区块链项目具有深远的影响

在知名媒体平台上发布新闻稿对区块链项目具有深远的影响&#xff0c;例如全球雅虎&#xff0c;彭博社这些全球知名财经媒体&#xff0c;能够显著提升项目的曝光度和信誉&#xff0c;吸引潜在投资者以及其他利益相关者的关注。以下是几个方面的详细分析&#xff1a; 1. 增强品牌…

区块链通证系统功能分析

区块链通证系统功能分析涉及多个关键方面&#xff0c;以确保系统能够满足不同的业务需求和合规性要求。 同质与非同质通证&#xff1a;区块链通证系统需要支持同质通证&#xff08;如ERC-20&#xff09;和非同质通证&#xff08;如ERC-721&#xff09;&#xff0c;以适应不同类…

如何快速掌握销售数据?一张报表就够了

在销售管理中&#xff0c;数据是企业做出战略决策的重要依据。有效的销售数据分析不仅能帮助企业精准把握市场动向&#xff0c;还能提高销售团队的工作效率&#xff0c;优化客户关系管理。然而&#xff0c;面对海量的销售数据&#xff0c;如何高效地解读这些数据呢&#xff1f;…

Java SpringBoot实现大学生平时成绩量化管理系统:一步步教你构建高效成绩统计,集成MySQL数据库,打造自动化评分流程

&#x1f34a;作者&#xff1a;计算机毕设匠心工作室 &#x1f34a;简介&#xff1a;毕业后就一直专业从事计算机软件程序开发&#xff0c;至今也有8年工作经验。擅长Java、Python、微信小程序、安卓、大数据、PHP、.NET|C#、Golang等。 擅长&#xff1a;按照需求定制化开发项目…

three.js 编辑器,动画,着色器, cesium 热力图,聚合点位,大量点线面, 图层,主题,文字,等众多案例中心

对于大多数的开发者来言&#xff0c;看了很多文档可能遇见不到什么有用的&#xff0c;就算有用从文档上看&#xff0c;把代码复制到自己的本地大多数也是不能用的&#xff0c;非常浪费时间和学习成本&#xff0c; 尤其是three.js &#xff0c; cesium.js 这种难度较高&#xff…

ThinkPHP之入门讲解

文章目录 1 ThinkPHP1.1 框架1.1.1 目录讲解1.1.1.1 5.x1.1.1.2 6.0以上 1.1.2 配置文件1.1.2.1 5.x1.1.2.2 6.0以上 1.1.3 函数文件1.1.3.1 5.x1.1.3.1 6.0以上 1.2 控制器1.2.1 控制器的后缀1.2.2 框架中的命名空间1.2.3 url访问1.2.4 调试模式1.2.4.1 5.x1.2.4.2 6.0以上 1.…

Oracle迁移至openGauss的工具:ora2op的安装配置

目录 前言 1. ora2op的下载 1.1 下载地址 1.2 ora2op 介绍 2. ora2op的安装 2.1 安装perl的依赖包 2.2 安装连接Oracle数据库的模块 2.3 安装ora2op 2.4 安装连接openGauss数据库的模块 前言 本工具是使用perl&#xff0c;在安装时会遇到各种问题&#xff0c;解决方式…

Keil5 Debug模式Watch窗口添加的监控变量被自动清除

Keil5 Debug模式Watch窗口添加的监控变量被自动清除 问题解决记录 问题描述&#xff1a;每次进入Debug模式时&#xff0c;watch窗口里面上一次调试添加的监控变量都会被全部清掉 如图&#xff1a; 退出Debug模式后&#xff0c;重新进入Debug模式&#xff1a; 解决方法&…

用户体验设计案例:提升电商网站的用户体验

在数字化时代&#xff0c;用户体验设计&#xff08;UX Design&#xff09;已成为影响品牌成功的关键因素之一。尤其是在竞争激烈的电商行业&#xff0c;如何通过优质的用户体验来吸引和留住客户&#xff0c;是每个企业都需要面对的挑战。本文将通过一个具体的电商网站设计案例&…

解析 uni-app 小程序分包策略:合理优化包体积分布

引言 微信小程序的流行使得越来越多的开发者投入到小程序的开发中。但是&#xff0c;随着项目规模的增大&#xff0c;小程序的性能也会面临一些挑战。其中&#xff0c;小程序分包策略是提升性能的重要一环。 同时&#xff0c;uni-app 的流行也使众多的移动端开发者选择使用 u…

AcWing895. 最长上升子序列

这个代码不知道怎么说&#xff0c;反正就是对着代码手算一次就懂了&#xff0c;无需多言&#xff0c;就是俩for循环里面的第二层for的循环条件是j<i,j是从下标1往下标i-1遍历的&#xff0c;每次a【j】<a【i】就在答案数组f【i】上面做出更新。基本的输入样例已经可以覆盖…

揭秘数字水印技术:使用PyQt5实现图像中的LSB隐写术

在当今的数字化世界中&#xff0c;保护信息的安全性和隐秘性变得尤为重要。无论是在保护版权的数字水印&#xff0c;还是在隐秘传输信息的过程中&#xff0c;数字隐写术&#xff08;Steganography&#xff09;都是一种不可或缺的技术。今天&#xff0c;我们将带领大家探索一种简…

关于LLC知识14

1、LLC必须工作在感性区 2、为了降低LLC进入容性区后MOS管的电流应力&#xff0c;必须要选择快管&#xff0c;对体二极管的反向恢复参数有要求&#xff1a;trr<200ns 3、对于上下管的死区时间不能太短&#xff0c;否则电容无法充放电完成&#xff0c;就无法实现ZVS导通 如…

DNN学习平台(GoogleNet、SSD、FastRCNN、Yolov3)

DNN学习平台&#xff08;GoogleNet、SSD、FastRCNN、Yolov3&#xff09; 前言相关介绍1&#xff0c;登录界面&#xff1a;2&#xff0c;主界面&#xff1a;3&#xff0c;部分功能演示如下&#xff08;1&#xff09;识别网络图片&#xff08;2&#xff09;GoogleNet分类&#xf…

短视频去哪里找素材?16个有短视频素材APP网站分享给你

短视频创作爱好者们&#xff0c;你们好&#xff01;在这个充斥视觉内容的年代&#xff0c;制作一部引人注目的短视频无疑是一项挑战&#xff0c;但也是一种艺术。要想打造出色的视频内容&#xff0c;首先需要的是高质量的素材。今天&#xff0c;我将向大家推荐几个非常棒的短视…

windows安装macos虚拟机

我用的是macOS Ventura 13.5(22G74) MH.iso 下载链接 https://macoshome.com/macos/20492.html#Down 一、下载unlocker 用于给VMware提供macos选项 下载链接 https://github.com/DrDonk/unlocker/releases/tag/v4.2.7下载好后解压&#xff0c;进入windows目录&#xff0c;双…