d2l 使用attention的seq2seq

news2024/12/23 22:10:54

这一章节与前面写好的function关联太大,建议看书P291.

这章节主要讲述了添加attention的seq2seq,且只在decoder里面添加,所以全文都在讲这个decoter

目录

1.训练

2.预测


1.训练

#@save
class AttentionDecoder(d2l.Decoder):
    """带有注意⼒机制解码器的基本接⼝"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)
        
    @property
    def attention_weights(self):
        raise self._attention_weights

  Encoder对每个词的rnn输出作为key-values,输出为(bs,k-v,h)
  Decoder对上一个词的rnn输出作为query(bs,1,h)
  要预测下一词时,当前预测的作为query,编码器对应的各个原文输出为key-values,进行attention,来找到对应预测下一个词的查询
  原seq2seq是将上一个rnn里最后一个state,现在允许所有词的输出都获取,根据上下文对应找到(bs,q,h///values)

  只有decoder使用attention
  前面的init只增加了一个attention,其他的emb、rnn、dense与seq2seq一样
  init_state相比之前增加了enc_valid_lens,告知encoder中的有效长度

class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,
        # num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # 输出X的形状为(num_steps,batch_size,embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # query的形状为(batch_size,1,num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # context的形状为(batch_size,1,num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # 在特征维度上连结
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # 将x变形为(1,batch_size,embed_size+num_hiddens)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # 全连接层变换后,outputs的形状为
        # (num_steps,batch_size,vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                            enc_valid_lens]
    
        @property
        def attention_weights(self):
            return self._attention_weights

  对于for:每一步的context会变,所以需要一步一步更新context
  query里面的hidden_state[-1]里面的-1表示获取最后一层的rnn_state(bs,h),原本为(n_layers,bs,h)
  再context中,keys=values=enc_outputs,在此虽然最后维度长度相同,但是也可用加性attention。
  T表示k-v,h为values,P289;out尺寸为(bs,T,h)表示原句子所有输出,经过attention得到的context表示拿到了(bs,q,values),在所有的outputs找到了q查询对应的值,
  每个词的context不同,所以q=1,query为上一次decoder的输出的最后一层,再用attention再encoder的全部输出outputs(bs,T,h)找到相关的查询输出注意力权重和查询值(bs,q,values)--权重如何计算?通过grad自己找loss学的。
  即context为(bs,1,h),再合并x送入decoder的rnn产生下一个词的预测。(vx
  加了attention是利用decoder上一层输出到encoder中所有的out找到相关查询值,再与当前x拼一起产生预测,(第一个x的上一层输出query是encoder最后的state再state[-1]取最后一层)

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)


2.预测

预测时,dec_x=1

可见,预测的时候dec_x就是只有一个

(bs,T)--(1,1)

再送入decoder

 经过emb后,且经过permute后,得到(T,bs,emb)

然后就是获取query,(bs,q,q_s),本次q_s=k_s=h=values;传入加性注意力拿到(bs,q,values),与当前x做拼接生成(bs,T,emb+h),再通过permute得到rnn的通识传入尺寸(T,bs,emb+h).通过rnn后得到out为(T,bs,h),再经过线性层并经过permute,得到最终输出(T,bs,len(v))

 decoder中outputs传入dense线性层操作

 原先outputs是一个list,里面的每个元素为(1,bs,h);经过cat后,将T拼接起来形成(T,bs,h),一并传入dense获得整个T的最终输出(T,bs,len(v))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/449507.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTTP与HTTPS相关介绍(详细)

HTTP与HTTPS相关介绍 HTTP与HTTPS简述HTTPS和HTTP的区别主要如下HTTPS的工作原理前言工作步骤总结 HTTPS的缺点SSL与TLSSSL:TLS:TLS和SSL的关系 对称加密与非对称加密对称加密非对称加密 HTTP与HTTPS简述 超文本传输协议(Hyper Text [Transf…

如何无侵入地引入第三方组件?

做java开发的小伙伴都知道,java的生态比较繁荣,有各种各样的第三方组件来满足我们日常的开发需求。很多常用的中间件(redis,kafka等)都提供java的开发接口,而且这些接口通常会被封装成比较好用的组件来满足我们使用这些中间件的场…

SpringBoot集成MyBatis-yml自动化配置原理详解

SpringBoot集成MyBatis-yml自动化配置原理详解 简介:spring boot整合mybatis开发web系统目前来说是市面上主流的框架,每个Java程序和springboot mybatis相处的时间可谓是比和自己女朋友相处的时间都多,但是springboot mybatis并没有得到你的真…

RabbitMQ--详情概述

一、消息队列(Rabbit Message Queue) 1、概念 消息队列是一种应用之间的通信方式,消息发送之后可以立即返回,由消息系统来确保消息的可靠传递。消息发布者只发布消息到MQ,消息使用者值从MQ中拿消息,两者不知道对方的存在。 简…

Sentinel——限流规则

目录 快速入门 簇点链路 案例 流控模式 流控模式——关联 流控模式——链路 案例 流控效果 流控效果——warm-up 流控效果——排队等待 热点参数限流 快速入门 簇点链路 簇点链路:就是项目内的调用链路,链路中被监控的每个接口就是一个资源。…

【故障检测】基于 KPCA 的故障检测【T2 和 Q 统计指数的可视化】(Matlab代码实现)

💥 💥 💞 💞 欢迎来到本博客 ❤️ ❤️ 💥 💥 🏆 博主优势: 🌞 🌞 🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 …

[前端基础]异步操作(还没写完)

1.写在前面 这篇是因为最近再写异步操作,需要点总结 因为还在学习前端的过程中嘛,所以有些东西可能会慢慢补充上来,也可能会有很多个人理解不是很到位的地方,还望各位评论区佬能帮忙指出.阿里嘎多捏 2.异步操作的概念和举例 异步操作和同步操作在408的三门课程中,都有所提及…

基于php的校园垃圾分类网站的设计与实现

摘要 近年来,随着民众环保意识的增强和资源有效利用问题的重视,全国各地市不断推进垃圾分类工作。教育部,也已于去年发布通知在全国各学校推进垃圾分类工作,以鼓励垃圾分类的有效实施。但现阶段我国校园的垃圾分类践行情况依旧问…

STATS 782 - Control Flow and Functions

文章目录 一、Control Flow1. If-Then-Else2. Loops 二、Functions1. Defining Functions2. 使用函数计算数学公式 总结 一、Control Flow 1. If-Then-Else > if (x > 0) y sqrt(x) else y -sqrt(-x)或 > y if (x > 0) sqrt(x) else -sqrt(-x)2. Loops ① fo…

数组应该怎么用?

文章目录 前言一、数组是什么?二、数组的创建1.数组的创建:2.数组的初始化 三.数组的遍历1.逐个打印2.使用for循环四.二维数组1.语法:2.遍历 五.数组的一些常用方法1.数组转换字符串2.数组拷贝3.二分查找4.冒泡排序5.数组逆序 总结 前言 为什…

动力节点Vue笔记——Vue与Ajax

四、Vue与AJAX 4.1 回顾发送AJAX异步请求的方式 发送AJAX异步请求的常见方式包括: 原生方式,使用浏览器内置的JS对象XMLHttpRequest const xhr new XMLHttpRequest()xhr.onreadystatechange function(){}xhr.open()xhr.send() 原生方式&#xff0…

_awt_container容器_演示

Component作为基类,提供了如下常用的方法来设置组件的大小、位置、可见性等。 方法签名方法功能setLocation(int x,int y)设置组件的位置setSize(int width,int heigth)设置组件的大小setBounds(int x,int y,int width,int heigth)设置组件的位置,大小。…

基于蚂蚁优化算法的BP神经网络在负荷预测中的应用研究(Matlab完整代码实现)

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥 目录 1 ACO-BP算法 2 ACO-BP算法基本思路 3 具体步骤 4 运行结果 ​ 5 参考文献 6 Matlab代码实现 1 ACO-BP算法 传统的…

数组模拟实现单链表快速操作

前言:我们都知道链表的一般模式是由结构体加指针来实现的,但是在实际的比赛中,结构体指针来实现链表的操作并不常用,原因是因为我们在增加节点时需要开辟新的内存,而比赛时给出的样例大多都是十几万个数据,…

安装配置SVN版本控制管理工具

SVN工具能帮我们做什么? 核心功能:文档版本管理系统 适合对象:个人与团队都可以使用,企业中项目资源的重要管理工具 举例:一个文件夹里面的文档管理 1.下载安装SVN服务器 VisualSVN-Server 2.下载安装SVN客户端 T…

【论文阅读】CVPR2023 IGEV-Stereo

用于立体匹配的迭代几何编码代价体 【cvhub导读】【paper】【code_openi】 代码是启智社区的镜像仓库,不需要魔法,点击这里注册 🚀贡献 1️⃣现有主流方法 基于代价滤波的方法和基于迭代优化的方法: 基于代价滤波的方法可以在c…

大小字母转换

1.代码实例: public class UpString { public static void main(String[] args) { if(args!null && args.length 1){ String str new String(args[0]); System.out.println(“原字符:” str “\n”); String newA str.toUpperCase(); System.out.prin…

C语言分支和循环语句

目录 1.什么是语句😊 2.分支语句(选择结构)😊 2.1 if语句🐾 2.2 switch语句🐾 3.循环语句 😊 3.1 while循环🐾 3.2 for循环🐾 3.3 do...while()循环&#x1f43e…

太太太太太卷了,累了

我们聊到互联网行业的时候,一个不可避免的话题就是“内卷” 在程序员职场上,什么样的人最让人反感呢? 是技术不好的人吗?并不是。技术不好的同事,我们可以帮他。 是技术太强的人吗?也不是。技术很强的同事,可遇不可求&#xff…

C++内联函数的使用

1.内联函数概念 以inline修饰的函数叫做内联函数,编译时C编译器会在调用内联函数的地方展开,没有函数调用建立栈帧的开销,内联函数提升程序运行的效率。 如果在上述函数前增加inline关键字将其改成内联函数,在编译期间编译器会用…