李沐66_使用注意力机制的seq2seq——自学笔记

news2024/11/19 13:26:28

加入注意力

1.编码器对每次词的输出作为key和value

2.解码器RNN对上一个词的输出是query

3.注意力的输出和下一个词的词嵌入合并进入RNN

一个带有Bahdanau注意力的循环神经网络编码器-解码器模型

总结

1.seq2seq通过隐状态在编码器和解码器中传递信息

2.注意力机制可以根据解码器RNN的输出来匹配到合适的编码器RNN的输出来更有效的传递信息。

pip install d2l==0.17.6  ### 很重要,不要下载错了,对于colab
import torch
from torch import nn
from d2l import torch as d2l

注意力解码器

AttentionDecoder类定义了带有注意力机制解码器的基本接口

class AttentionDecoder(d2l.Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError

Seq2SeqAttentionDecoder类中 实现带有Bahdanau注意力的循环神经网络解码器。

1.编码器在所有时间步的最终层隐状态,将作为注意力的键和值;

2.上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;

3.编码器有效长度(排除在注意力池中填充词元)。

class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(num_hiddens,num_hiddens,num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # enc_outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,
        # num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # 输出X的形状为(num_steps,batch_size,embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # query的形状为(batch_size,1,num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # context的形状为(batch_size,1,num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # 在特征维度上连结
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # 将x变形为(1,batch_size,embed_size+num_hiddens)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # 全连接层变换后,outputs的形状为
        # (num_steps,batch_size,vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights

使用包含7个时间步的4个序列输入的小批量测试Bahdanau注意力解码器。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2
                                  )
decoder.eval()
X = d2l.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

实例化一个带有Bahdanau注意力的编码器和解码器, 并对这个模型进行机器翻译训练。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
loss 0.020, 7390.3 tokens/sec on cuda:0

在这里插入图片描述

模型训练后,我们用它将几个英语句子翻译成法语并计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
go . => va !,  bleu 1.000
i lost . => j'ai perdu .,  bleu 1.000
he's calm . => il est mouillé .,  bleu 0.658
i'm home . => je suis chez moi .,  bleu 1.000
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((
    1, 1, -1, num_steps))

训练结束后,下面通过可视化注意力权重 会发现,每个查询都会在键值对上分配不同的权重,这说明 在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1624603.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

stable diffusion 的controlNet 安装和使用

stable diffusion 安装controlNet需要先下载扩展 扩展地址 下载了扩展以后,需要下载相应的模型,每个模型大约1.45G,可以按需下载。 模型地址 如果下载速度太慢,可以考虑去liblib下载,但是是全量模型 liblib 模型下载完后&#…

使用windows端MySQL创建数据库

1.命令行登录数据库 命令:mysql -u用户名 -p密码; 切记命令后面要以分号结尾 2. 查看和创建数据库 查看数据库命令:show database; 创建数据库命令:mysql> create database db_classes; 创建一个名为db_classes的…

通配符HTTPS安全证书

众多类型的SSL证书,要说适用或者说省钱肯定是通配符了,因为谁都想一本SSL证书包括了整条域名,而且也不用一条一条单独管理。 通配符HTTPS安全证书,其实就是通配符SSL证书,SSL证书主流CA的参数都一样,通配符…

使用riscv-tests进行指令测试(二)

使用riscv-tests进行指令测试(二) 1 测试用例命名规则2 测试用例dump文件介绍 本文属于《 TinyEMU模拟器基础系列教程》之一,欢迎查看其它文章。 1 测试用例命名规则 用例名称 TVM Name “-” Target Environment Name “-” “指令”…

【论文浅尝】Phi-3-mini:A Highly Capable Language Model Locally on Your Phone

Phi-3-mini phi-3-mini,一个3.8亿个参数的语言模型,训练了3.3万亿个token,其总体性能,通过学术基准和内部测试进行衡量,可以与Mixtral 8x7B和GPT-3.5等模型相媲美(在MMLU上达到69%,在MT-bench上达到8.38)&…

python_django农产品物流信息服务系统6m344

Python 中存在众多的 Web 开发框架:Flask、Django、Tornado、Webpy、Web2py、Bottle、Pyramid、Zope2 等。近几年较为流行的,大概也就是 Flask 和 Django 了 Flask 是一个轻量级的 Web 框架,使用 Python 语言编写,较其他同类型框…

13 如何利用缓存实现万级并发扣减

在上一讲的实现方案里我们讨论了采用纯数据库的扣减实现方案,如果以常规的机器或者 Docker 来进行评估,此方案较难实现单机过万的 TPS。之所以介绍,是想告诉你,架构是面向业务功能、成本、实现难度、时间等因素的取舍,…

广工电工与电子技术实验报告-8路彩灯循环控制电路

实验代码 module LED_water (clk,led); input clk; output [7:0] led; reg [7:0] led; integer p; reg clk_1Hz; reg [7:0] current_state, next_state; always (posedge clk) begin if(p25000000-1)begin …

对2023年图灵奖揭晓看法

2023年图灵奖揭晓,你怎么看? 2023年图灵奖,最近刚刚颁给普林斯顿数学教授 Avi Wigderson!作为理论计算机科学领域的领军人物,他对于理解计算中的随机性和伪随机性的作用,作出了开创性贡献。这些贡献不仅推…

C++修炼之路之多态---多态的原理(虚函数表)

目录 一:多态的原理 1.虚函数表 2.原理分析 3.对于虚表存在哪里的探讨 4.对于是不是所有的虚函数都要存进虚函数表的探讨 二:多继承中的虚函数表 三:常见的问答题 接下来的日子会顺顺利利,万事胜意,生活明朗--…

PPSSPPSDL for Mac v1.17.1 PSP游戏模拟器(附500款游戏) 激活版

PPSSPPSDL for Mac是一款模拟器软件,它允许用户在Mac上运行PSP(PlayStation Portable)游戏。通过这款模拟器,用户可以体验到高清甚至更高的分辨率的游戏画面,同时还能够升级纹理以提升清晰度,并启用后处理着…

安卓手机连接电脑实用技巧:实现文件传输与共享

在手机使用过程中,我们常常需要将手机中的文件传输到电脑,或者将手机与电脑进行共享。为了实现这一需求,掌握一些实用的安卓手机连接电脑技巧就显得尤为重要。本文将为您详细介绍2种简单、高效且安全的方法,让您轻松实现安卓手机与…

【网络安全】安全事件管理处置 — 事件分级分类

专栏文章索引:网络安全 有问题可私聊:QQ:3375119339 目录 一、安全事件分级 二、应急事件分级 三、安全事件分类 四、常见安全事件原因分析 1.web入侵 2.漏洞攻击 3.网络攻击 一、安全事件分级 在对安全事件的应急响应过程中&#xf…

如何最大程度使用AWS?

随着云计算技术的不断发展,AWS已经成为众多企业的首选,为其提供了强大的基础设施和服务。那么如何最大程度地、灵活地利用AWS,成为许多企业专注的焦点。九河云作为AWS的合作伙伴,为读者们提供一些技巧和策略,帮助读者充…

物联网鸿蒙实训解决方案

一、建设背景 在数字化浪潮汹涌的时代,华为鸿蒙系统以其前瞻的技术视野和创新的开发理念,成为了引领行业发展的风向标。 据华为开发者大会2023(HDC. Together)公布的数据,鸿蒙生态系统展现出了强劲的发展动力&#x…

Qt : 禁用控件默认的鼠标滚轮事件

最近在写一个模拟器,在item中添加了很多的控件,这些控件默认是支持鼠标滚动事件的。在数据量特别大的时候,及容易不小心就把数据给修改了而不自知。所有,我们这里需要禁用掉这些控件的鼠标滚轮事件。 实现的思想很简单&#xff0c…

[Swift]组件化开发

一、组件化开发基础 1.组件定义 在软件开发中,一个组件是指一个独立的、可替换的软件单元,它封装了一组相关的功能。组件通过定义的接口与外界交互,并且这些接口隔离了组件内部的实现细节。在Swift语言中,组件可以是一个模块、一…

新品发布!无人机装调检修实训系统

近年,我国密集出台相关产业政策,推动低空经济从探索走向发展,根据新华网数据,2030年低空经济规模有望达2万亿。无人机专业属于跨学科的综合性专业,其中装调检测技术是无人机教培的重要组成部分。 天途推出无人机装调检…

easyExcel - 带图片导出

目录 前言一、情景介绍二、问题分析三、代码实现1. 单图片导出2. 多图片导出3. 多图片导出(优化) 前言 Java-easyExcel入门教程:https://blog.csdn.net/xhmico/article/details/134714025 之前有介绍过如何使用 easyExcel,以及写…

分布式文件系统--MinIO

1 MinIO安装(Docker) ●在root目录下新建docker_minio文件夹 ●在docker_minio文件夹下新建config文件夹,data文件夹 ●在root目录下新建docker_compose文件夹,在docker_compose文件夹中添加docker-compose.yaml services:minio:image: quay.io/minio/miniocontainer_name: mi…