66 使用注意力机制的seq2seq_by《李沐:动手学深度学习v2》pytorch版

news2024/10/1 5:36:43

系列文章目录


文章目录

  • 系列文章目录
  • 动机
  • 加入注意力
  • 总结
  • 代码
    • 定义注意力解码器
    • 训练
    • 小结
    • 练习


我们来真的看一下实际应用中,key,value,query是什么东西,但是取决于应用场景不同,这三个东西会产生变化。先将放在seq2seq这个例子。

动机

机器翻译中,每个生成的词可能相关于源句子中不同的词。
比如RNN的最后一个时刻的最后一个输出,把所有东西压到一起,你看到的东西可能看不清楚。我想要翻译对应的词的时候,关注原句子中对应的部分,这就是将注意力机制放在seq2seq的动机,而之前的seq2seq模型中不能对此直接建模。
在这里插入图片描述

加入注意力

在这里插入图片描述
大体上就是,我在解码器RNN的输入一部分来自embedding,之前还有一部分是来自RNN最后一个时刻的最后一层作为上下文和embedding一起传进去。现在说最后一个时刻作为Context传进去不好,我应该根据我的现在预测值的不一样去选择不是最后一个时刻,可能是前面某个时刻对应的那些隐藏状态(经过注意力机制)作为输入。
编码器对每次词的输出作为key和value(它们成对的)。
解码器RNN对上一个词的输出是query。
注意力的输出和下一个词的词嵌入合并进入解码器。

总结

  1. Seq2seq中通过隐状态在编码器和解码器中传递信息
  2. 注意力机制可以根据解码器RNN的输出来匹配到合适的编码器RNN的输出来更有效的传递信息

代码

import torch
from torch import nn
from d2l import torch as d2l

定义注意力解码器

下面看看如何定义Bahdanau注意力,实现循环神经网络编码器-解码器。
其实,我们只需重新定义解码器即可。
为了更方便地显示学习的注意力权重,
以下AttentionDecoder类定义了[带有注意力机制解码器的基本接口]。

#@save
class AttentionDecoder(d2l.Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self): #画图所需代码
        raise NotImplementedError

接下来,让我们在接下来的Seq2SeqAttentionDecoder类中[实现带有Bahdanau注意力的循环神经网络解码器]。
首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
  3. 编码器有效长度(排除在注意力池中填充词元)。

在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。
因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout) #相比之前只是新增了这行代码
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU( embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args): #多了一个enc_valid_lens,之前不需要,现在需要知道英语的句子哪些是pad的。
        # outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # enc_outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state 
        X = self.embedding(X).permute(1, 0, 2) # 输出X的形状为(num_steps,batch_size,embed_size)
        outputs, self._attention_weights = [], []
        for x in X:
            # print("x.shape = "+str(x.shape))  #x.shape = torch.Size([4, 8])
            query = torch.unsqueeze(hidden_state[-1], dim=1)  # query的形状改为(batch_size,1,num_hiddens) 虽然query只有一个,但是要把query数量这个维度加进去,才能应用上一博客的函数
            # context的形状为(batch_size,1,num_hiddens)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens) # key和value是成对的,且都是encoder的output,enc_valid_lens是为了防止encoder中那些无效的pad的部分也被当作正常部分进行运算,encoder的长度是由num_steps确定好的定长的。query每次都会变。
            # print("context.shape = "+ str(context.shape))  #context.shape = torch.Size([4, 1, 16])
            # 在特征维度上连结
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # 将x变形为(1,batch_size,embed_size+num_hiddens)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights) # 将输出和注意力权重保存到列表中。
        # 全连接层变换后,outputs的形状为
        # (num_steps,batch_size,vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights

接下来,使用包含7个时间步的4个序列输入的小批量[测试Bahdanau注意力解码器]。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                  num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

训练

与 :numref:sec_seq2seq_training类似,
我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器,
并对这个模型进行机器翻译训练。
由于新增的注意力机制,训练要比没有注意力机制的
:numref:sec_seq2seq_training慢得多。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 30, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
loss 0.099, 10527.1 tokens/sec on cuda:0
<Figure size 350x250 with 1 Axes>

在这里插入图片描述
模型训练后,我们用它[将几个英语句子翻译成法语]并计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
go . => <unk> .,  bleu 0.000
i lost . => je suis parti .,  bleu 0.000
he's calm . => il est <unk> .,  bleu 0.658
i'm home . => je suis <unk> .,  bleu 0.512
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((
    1, 1, -1, num_steps))

训练结束后,下面通过[可视化注意力权重]
会发现,每个查询都会在键值对上分配不同的权重,这说明
在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

小结

  • 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
  • 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。

练习

  1. 在实验中用LSTM替换GRU。
  2. 修改实验以将加性注意力打分函数替换为缩放点积注意力,它如何影响训练效率?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2181609.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 SSH 连接 Docker 服务器:IntelliJ IDEA 高效配置与操作指南

使用 SSH 连接 Docker 服务器&#xff1a;IntelliJ IDEA 高效配置与操作指南 本文详细介绍了如何在 2375 端口未开放的情况下&#xff0c;通过 SSH 连接 Docker 服务器并在 Idea 中进行开发。通过修改用户权限、生成密钥对以及配置 SSH 访问&#xff0c;用户可以安全地远程操作…

Ubuntu 系统崩了,如何把数据拷下来

问题描述&#xff1a; Linux系统中安装输入法后&#xff0c;重启后&#xff0c;导致系统无法进入&#xff0c;进入 recovery mode下的resume 也启动不了&#xff0c;所以决定将需要的东西复制到U盘 解决方案&#xff1a; 1.重启ubuntu&#xff0c;随即点按Esc进入grub菜单&am…

Linux shell脚本set -e的作用详解

文章目录 功能详细解释示例不使用 set -e 的情况总结 set -e 是一个用于控制脚本行为的命令&#xff0c;它的作用是&#xff1a; 功能 当脚本运行时&#xff0c;set -e 会确保一旦某个命令返回非零的退出状态&#xff08;即执行失败&#xff09;&#xff0c;整个脚本会立即停止…

Docker面试-24年

1、Docker 是什么&#xff1f; Docker一个开源的应用容器引擎&#xff0c;是实现容器技术的一种工具&#xff0c;让开发者可以打包他们的应用以及环境到一个镜像中&#xff0c;可以快速的发布到任何流行的操作系统上。 2、Docker的三大核心是什么? 镜像&#xff1a;Docker的…

在 Kali Linux 中安装 Impacket

步骤 1&#xff1a;更新系统 打开终端并确保你的系统是最新的&#xff1a; sudo apt update && sudo apt upgrade -y 步骤 2&#xff1a;安装依赖 在安装 Impacket 之前&#xff0c;你需要确保安装了 Python 和一些必要的依赖。通常&#xff0c;Kali 已经预装了 Pytho…

工作日志:el-table在无数据情况下,出现横向滚动条。

1、遇到一个警告。 原因&#xff1a;中的组件不能呈现动画的非元素根节点。 也就是说&#xff0c;Transition包裹的必须是一个单根的组件。 2、el-table在无数据情况下&#xff0c;出现横向滚动条&#xff0c;大概跟边框的设置有关系。 开始排查。 给.el-scrollbar加了一个…

Linux 线程同步

前言 上一期我们介绍了线程互斥&#xff0c;并通过加锁解决了多线程并发访问下的数据不一致问题&#xff01;本期我们来介绍一下同步问题&#xff01; 目录 前言 一、线程同步 • 线程同步的引入 • 同步的概念 理解同步和饥饿问题 • 条件变量 理解条件变量 • 同步…

TypeScript 算法手册 【数组基础知识】

文章目录 1. 数组简介1.1 数组定义1.2 数组特点 2. 数组的基本操作2.1 访问元素2.2 添加元素2.3 删除元素2.4 修改元素2.5 查找元素 3. 数组的常见方法3.1 数组的创建3.2 数组的遍历3.3 数组的映射3.4 数组的过滤3.5 数组的归约3.6 数组的查找3.7 数组的排序3.8 数组的反转3.9 …

AI写作赋能数据采集,开启无限可能性

由人工智能 AI 掀起的新一轮科技革命浪潮&#xff0c;正在不断推动社会进步、各行各业升级发展&#xff0c;深刻影响人们的生活方式&#xff0c;引领我们进入一个充满无限可能的新时代。 那么在数据采集方面&#xff0c;人工智能 AI 可以做什么呢&#xff1f; 下面是搜集网络…

开源在线表结构设计工具

Free, simple, and intuitive database design tool and SQL generator. drawDB在线体验 Discord X drawDB DrawDB is a robust and user-friendly database entity relationship (DBER) editor right in your browser. Build diagrams with a few clicks, export sql scri…

若依--文件上传前端

前端 ry的前端文件上传单独写了一个FileUpload.Vue文件。在main.js中进行了全局的注册&#xff0c;可以在页面中直接使用文件上传的组件。全局导入 在main.js中 import 组件名称 from /components/FileUpLoadapp.compoent(组件名称) //全局挂载组件在项目中使用 组件命令 中…

定时器定时中断定时器外部中断

TIM的函数 // 恢复缺省设置 void TIM_DeInit(TIM_TypeDef* TIMx); // 时基单元初始化&#xff0c;第一个参数TIMx选择某个定时器&#xff0c;第二个参数是结构体&#xff0c;包含了配置时基单元的一些参数。 void TIM_TimeBaseInit(TIM_TypeDef* TIMx, TIM_TimeBaseInitTypeDe…

28 Vue3之搭建公司级项目规范

可以看到保存的时候ref这行被提到了最前面的一行 要求内置库放在组件的前面称为auto fix&#xff0c;数组new arry改成了字面量&#xff0c;这就是我们配置的规范 js规范使用的是airbnb规范模块使用的是antfu 组合prettier&eslint airbnb规范&#xff1a; https://github…

《More Effective C++》的学习

引用与指针 没有所谓的null reference reference一定需要代表某个对象&#xff0c;所以C要求reference必须有初值。 QString &s; 使用reference可能比使用pointer更高效。 因为reference一定是有效的&#xff0c;而指针可能为空&#xff08;需要多加一个判断&#xff0…

Springboot3 + MyBatis-Plus + MySql + Vue + ProTable + TS 实现后台管理商品分类(最新教程附源码)

Springboot3 MyBatis-Plus MySql Uniapp 商品加入购物车功能实现&#xff08;针对上一篇sku&#xff09; 1、效果展示2、数据库设计3、后端源码3.1 application.yml 方便 AliOssUtil.java 读取3.2 model 层3.2.1 BaseEntity3.2.1 GoodsType3.2.3 GoodsTypeSonVo3.3 Controll…

论文翻译 | LLaMA-Adapter :具有零初始化注意的语言模型的有效微调

摘要 我们提出了一种轻量级的自适应方法&#xff0c;可以有效地将LLaMA微调为指令遵循模型。lama - adapter采用52K自指导演示&#xff0c;在冻结的LLaMA 7B模型上只引入1.2M可学习参数&#xff0c;在8个A100 gpu上进行微调花费不到一个小时。具体来说&#xff0c;我们采用了一…

Vue3+Antv X6流程图基本使用

安装 antv/X6 npm i antv/x6 <template><div class"homes"><div class"Shang">上</div><div class"Zhong"><div id"container"></div></div><div class"Xia">下<…

wordpress Contact form 7发件人邮箱设置

此教程仅适用于演示站有留言的主题&#xff0c;演示站没有留言的主题&#xff0c;就别往下看了&#xff0c;免费浪费时间。 使用了Contact form 7插件的简站WordPress主题&#xff0c;在有人留言时&#xff0c;就会发邮件到网站的系统邮箱(一般与管理员邮箱为同一个)里。上面显…

Java | Leetcode Java题解之第448题找到所有数组中消失的数字

题目&#xff1a; 题解&#xff1a; class Solution {public List<Integer> findDisappearedNumbers(int[] nums) {int n nums.length;for (int num : nums) {int x (num - 1) % n;nums[x] n;}List<Integer> ret new ArrayList<Integer>();for (int i …

传奇外网架设教程带图文解说—Gee引擎

架设前准备工作&#xff1a; ①通过百度网盘下载版本、补丁、客户端和DBC2000。版本解压到D盘&#xff0c;客户端解压到D盘或是E盘&#xff0c;补丁先不解压 ②安装和配置DBC2000&#xff0c;有些版本不一定用的是DBC2000数据库&#xff0c;看引擎默认的数据库是哪个 DBC数据…