机器翻译之创建Seq2Seq的编码器、解码器

news2024/12/24 9:11:57

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn


#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Module
    def __init__(self, **kwargs):   #**kwargs:不定常的关键字参数
        super().__init__(**kwargs)
        
    def forward(self, X, *args):  #*args:不定常的位置参数
        #若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错
        raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    #初始化state
    def init_state(self, enc_outputs, *args):
        raise NotImplementedError
    
    #前向传播,解码器比编码器多传入一个state
    def forward(self, X, state):
        raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, enc_X, dec_X, *args):
        """
        enc_X:编码器需传入的数据
        dec_X:解码器需传入的数据
        """
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoder
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        """
        vocab_size:词汇表大小
        embed_size:嵌入层大小
        num_hiddens:隐藏层的神经元数量
        num_layers:隐藏层的层数
        dropout=0 : 默认所有的神经元参与计算
        """
        #初始化嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        #初始化神经网络层
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)
        
    def forward(self, X, *args):
        #在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X) 
        #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)  
        #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)
        
        #此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入state
        outputs, state = self.rnn(X)
        #outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量
        #state的shape=(num_layers, batch_size, num_hiddens)
        return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)

print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        #初始化嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        #初始化神经网络层
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        #初始化输出层
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    #定义函数:获取状态state
    def init_state(self, enc_outputs, *args):
        #编码器输出的结果有两个,第二个为state
        return enc_outputs[1]
    
    #前向传播
    def forward(self, X, state):
        #X的原始shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)
       
        # 把X和state拼接到一起. 方便计算. 
        # X现在的形状(num_steps, batch_size, embed_size) , 
        # state的形状(batch_size, num_hiddens)
        # 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)
        context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变
        X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并
        #此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)
        #神经网络层
        outputs, state = self.rnn(X_and_context, state)
        #输出层
        outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来
        #outputs的shape=(batch_size, num_steps, vocab_size)
        #state的shape=(num_layers, batch_size, num_hiddens)
        return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2151588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[产品管理-29]:NPDP新产品开发 - 27 - 战略 - 分层战略与示例

目录 1. 公司战略 2. 经营战略 3. 创新战略 4. 新产品组合战略 5. 新产品开发战略 战略分层是企业规划和管理的重要组成部分,它涉及不同层级的战略制定和实施。以下是根据您的要求,对公司战略、经营战略、创新战略、新产品组合战略、新产品开发战略…

闯关leetcode——66. Plus One

大纲 题目地址内容 解题代码地址 题目 地址 https://leetcode.com/problems/plus-one/description/ 内容 You are given a large integer represented as an integer array digits, where each digits[i] is the ith digit of the integer. The digits are ordered from mo…

2016年国赛高教杯数学建模B题小区开放对道路通行的影响解题全过程文档及程序

2016年国赛高教杯数学建模 B题 小区开放对道路通行的影响 2016年2月21日,国务院发布《关于进一步加强城市规划建设管理工作的若干意见》,其中第十六条关于推广街区制,原则上不再建设封闭住宅小区,已建成的住宅小区和单位大院要逐…

TinkerTool System for Mac实用软件系统维护工具

TinkerTool System 是一款功能全面且强大的 Mac 实用软件,具有以下特点和功能: 软件下载地址 维护功能: 磁盘清理:能够快速扫描并清理系统中的垃圾文件、临时文件以及其他无用文件,释放宝贵的磁盘空间,保…

c++优先级队列自定义排序实现方式

1、使用常规方法实现 使用结构体实现自定义排序函数 2、使用lambda表达式实现 使用lambda表达式实现自定义排序函数 3、具体实现如下&#xff1a; #include <iostream> #include <queue> #include <vector>using namespace std; using Pair pair<in…

FL Studio 24.1.1.4285中文破解完整版免费下载FL 2024注册密钥完整版crack百度云安装包下载

FL Studio 24.1.1.4285中文破解版是一个强大的软件选项&#xff0c;可以使用专业应用程序&#xff08;如最先进的录音机、均衡器、内置工具等&#xff09;制作循环和歌曲。它由数百个合成器和混响等效果以及均衡器组成&#xff0c;除此之外&#xff0c;您还可以在新音乐制作的方…

solidwork找不到曲面

如果找不到曲面 则右键找到选项卡&#xff0c;选择曲面

Kafka3.8.0+Centos7.9的安装参考

Kafka3.8.0Centos7.9的安装参考 环境准备 操作系统版本&#xff1a;centos7.9 用户/密码&#xff1a;root/1qazXSW 主机名 IP地址 安装软件 下载地址 k1 192.168.207.131 jdk1.8zookeeper3.9.2kafka_2.13-3.8.0efak-web-3.0.1 1&#xff09; Java Downloads | Oracle …

双击热备 Electron网页客户端

安装流程&#xff1a; 1.下载node.js安装包进行安装 2.点击Next; 3.勾选&#xff0c;点击Next; 4.选择安装目录 5.选择Online 模式 6.下一步执行安装 。 7.运行cmd,执行命令 path 和 node --version&#xff0c;查看配置路径和版本 8.Goland安装插件node.js 9.配置运行…

Vue 修饰符 | 指令 区别

Vue 修饰符 | 指令 区别 在Vue.js这个前端框架中&#xff0c;修饰符&#xff08;modifiers&#xff09;和指令&#xff08;directives&#xff09;是两个非常重要的概念。这里我们深度讨论下他们区别。 文章目录 Vue 修饰符 | 指令 区别一、什么是修饰符修饰符案例常见修饰符列…

2024年“华为杯”研赛第二十一届中国研究生数学建模竞赛解题思路|完整代码论文集合

我是Tina表姐&#xff0c;毕业于中国人民大学&#xff0c;对数学建模的热爱让我在这一领域深耕多年。我的建模思路已经帮助了百余位学习者和参赛者在数学建模的道路上取得了显著的进步和成就。现在&#xff0c;我将这份宝贵的经验和知识凝练成一份全面的解题思路与代码论文集合…

Elasticsearch:检索增强生成背后的重要思想

作者&#xff1a;来自 Elastic Jessica L. Moszkowicz 星期天晚上 10 点&#xff0c;我九年级的女儿哭着冲进我的房间。她说她对代数一无所知&#xff0c;注定要失败。我进入超级妈妈模式&#xff0c;却发现我一点高中数学知识都不记得了。于是&#xff0c;我做了任何一位超级妈…

ActivityManagerService 分发广播(6)

ActivityManagerService 分发广播 简述 上一节我们看了发送广播流程&#xff0c;主要是将广播信息封装至BroadcastRecord&#xff0c;然后通过BroadcastQueueImpl/BroadcastQueueModernImpl的enqueueBroadcastLocked来发送广播。 这一节我们来看一下AMS是怎么分发广播的流程&…

【论文笔记】Are Large Kernels Better Teacheres than Transformers for ConvNets

Abstract 本文提出蒸馏中小核ConvNet做学生时&#xff0c;与Transformer相比&#xff0c;大核ConvNet因其高效的卷积操作和紧凑的权重共享&#xff0c;使得其做教师效果更好&#xff0c;更适合资源受限的应用。 用蒸馏从Transformers蒸到小核ConvNet的效果并不好&#xff0c;原…

视频去噪技术分享

视频去噪是一种视频处理技术&#xff0c;旨在从视频帧中移除噪声和干扰&#xff0c;提高视频质量。噪声可能由多种因素引起&#xff0c;包括低光照条件、高ISO设置、传感器缺陷等。视频去噪对于提升视频内容的可视性和可用性至关重要&#xff0c;特别是在安全监控、医疗成像和视…

001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归 0. 背景介绍 我们构造一个简单的人工训练数据集&#xff0c;它可以使我们能够直观比较学到的参数和真实的模型参数的区别。 设训练数据集样本数为1000&#xff0c;输入个数&#xff08;特征数&#xff09;为2。给定随机生成的批量样本特征 X∈R10002 …

正点原子阿尔法ARM开发板-IMX6ULL(八)——串口通信(寄存器解释)(补:有源蜂鸣器)

文章目录 一、蜂鸣器&#xff08;待&#xff0c;理解&#xff09;1.1 第一行1.2 第二行1.3 第三行 二、串口原理2.1 通信格式2.2 UART寄存器 一、蜂鸣器&#xff08;待&#xff0c;理解&#xff09; 1.1 第一行 对于第一行&#xff0c;首先先到fsl_iomuxc文件里面寻找IOMUXC_S…

探索C语言与Linux编程:获取当前用户ID与进程ID

探索C语言与Linux编程:获取当前用户ID与进程ID 一、Linux系统概述与用户、进程概念二、C语言与系统调用三、获取当前用户ID四、获取当前进程ID五、综合应用:同时获取用户ID和进程ID六、深入理解与扩展七、结语在操作系统与编程语言的交汇点,Linux作为开源操作系统的典范,为…

01-Mac OS系统如何下载安装Python解释器

目录 Mac安装Python的教程 mac下载并安装python解释器 如何下载和安装最新的python解释器 访问python.org&#xff08;受国内网速的影响&#xff0c;访问速度会比较慢&#xff0c;不过也可以去我博客的资源下载&#xff09; 打开历史发布版本页面 进入下载页 鼠标拖到页面…

安装Kali Linux后8件需要马上安排的事

目录 一、更新升级 二、 编辑器 三、用户与权限 四、 下载TOR 五、下载终端 一、更新升级 sudo apt update -y && sudo apt upgrade -y && sudo apt autoremove 二、 编辑器 VScode或者vim&#xff1b;点击.deb就会下载了 一般都会下载到Downloads文件夹中…