python Seq2Seq模型源码实战,超详细Encoder-Decoder模型解析实战;早期机器翻译模型源码demo

news2024/12/23 19:45:41

1.Seq2Seq(Encoder-Decoder)模型简介

Seq2Seq(Encoder-Decoder)模型是一种常用于序列到序列(sequence-to-sequence)任务的深度学习模型。它由两个主要的组件组成:编码器(Encoder)和解码器(Decoder)。

编码器负责将输入序列编码成一个固定长度的向量,该向量包含了输入序列的语义信息。常用的编码器模型有循环神经网络(RNN)和Transformer。RNN编码器通过逐个时间步处理输入序列,并将最后一个时间步的隐藏状态作为向量输出。而Transformer编码器则将输入序列同时输入到多个注意力头中,从而捕捉输入的全局依赖关系。

解码器将编码器输出的向量作为输入,逐步生成目标序列。在每一个时间步,解码器通过学习到的语境信息和当前的输入,预测下一个输出的标记。常用的解码器模型也有RNN和Transformer。RNN解码器通常使用循环单元,每个时间步依次生成输出。而Transformer解码器可以同时计算目标序列中每个位置的输出,从而并行生成序列。

Seq2Seq模型适用于多种任务,如机器翻译、文本摘要、对话生成等。在机器翻译任务中,输入序列是源语言句子,目标序列是目标语言句子,模型的目标是学习源语言到目标语言的翻译规则。在训练阶段,给定源语言句子,模型通过编码器将其编码为一个向量表示,然后通过解码器生成目标语言句子。在测试阶段,给定一个源语言句子,模型使用编码器生成其向量表示,然后利用解码器生成目标语言句子。

Seq2Seq模型的一些改进包括使用注意力机制(Attention)来解决长序列问题,以及引入更复杂的网络结构和训练技巧来提升模型性能。此外,Seq2Seq模型也可以通过引入强化学习技术进行训练优化,以更好地适应特定任务需求。

2.Seq2Seq(Encoder-Decoder)模型代码实战

机器翻译任务

2.1构建语料库 

# 构建语料库,每行包含中文、英文(解码器输入)和翻译成英文后的目标输出 3 个句子
sentences = [
    ['成哥 喜欢 小冰', '<sos> ChengGe likes XiaoBing', 'ChengGe likes XiaoBing <eos>'],
    ['我 爱 学习 人工智能', '<sos> I love studying AI', 'I love studying AI <eos>'],
    ['深度学习 改变 世界', '<sos> DL changed the world', 'DL changed the world <eos>'],
    ['自然 语言 处理 很 强大', '<sos> NLP is so powerful', 'NLP is so powerful <eos>'],
    ['神经网络 非常 复杂', '<sos> Neural-Nets are complex', 'Neural-Nets are complex <eos>']]
word_list_cn, word_list_en = [], []  # 初始化中英文词汇表
# 遍历每一个句子并将单词添加到词汇表中
for s in sentences:
    word_list_cn.extend(s[0].split())
    word_list_en.extend(s[1].split())
    word_list_en.extend(s[2].split())
# 去重,得到没有重复单词的词汇表
word_list_cn = list(set(word_list_cn))
word_list_en = list(set(word_list_en))
# 构建单词到索引的映射
word2idx_cn = {w: i for i, w in enumerate(word_list_cn)}
word2idx_en = {w: i for i, w in enumerate(word_list_en)}
# 构建索引到单词的映射
idx2word_cn = {i: w for i, w in enumerate(word_list_cn)}
idx2word_en = {i: w for i, w in enumerate(word_list_en)}
# 计算词汇表的大小
voc_size_cn = len(word_list_cn)
voc_size_en = len(word_list_en)
print(" 句子数量:", len(sentences)) # 打印句子数
print(" 中文词汇表大小:", voc_size_cn) # 打印中文词汇表大小
print(" 英文词汇表大小:", voc_size_en) # 打印英文词汇表大小
print(" 中文词汇到索引的字典:", word2idx_cn) # 打印中文词汇到索引的字典
print(" 英文词汇到索引的字典:", word2idx_en) # 打印英文词汇到索引的字典

2.2数据预处理

import numpy as np # 导入 numpy
import torch # 导入 torch
import random # 导入 random 库
# 定义一个函数,随机选择一个句子和词汇表生成输入、输出和目标数据
def make_data(sentences):
    # 随机选择一个句子进行训练
    random_sentence = random.choice(sentences)
    # 将输入句子中的单词转换为对应的索引
    encoder_input = np.array([[word2idx_cn[n] for n in random_sentence[0].split()]])
    # 将输出句子中的单词转换为对应的索引
    decoder_input = np.array([[word2idx_en[n] for n in random_sentence[1].split()]])
    # 将目标句子中的单词转换为对应的索引
    target = np.array([[word2idx_en[n] for n in random_sentence[2].split()]])
    # 将输入、输出和目标批次转换为 LongTensor
    encoder_input = torch.LongTensor(encoder_input)
    decoder_input = torch.LongTensor(decoder_input)
    target = torch.LongTensor(target)
    return encoder_input, decoder_input, target 
# 使用 make_data 函数生成输入、输出和目标张量
encoder_input, decoder_input, target = make_data(sentences)
for s in sentences: # 获取原始句子
    if all([word2idx_cn[w] in encoder_input[0] for w in s[0].split()]):
        original_sentence = s
        break
print(" 原始句子:", original_sentence) # 打印原始句子
print(" 编码器输入张量的形状:", encoder_input.shape)  # 打印输入张量形状
print(" 解码器输入张量的形状:", decoder_input.shape) # 打印输出张量形状
print(" 目标张量的形状:", target.shape) # 打印目标张量形状
print(" 编码器输入张量:", encoder_input) # 打印输入张量
print(" 解码器输入张量:", decoder_input) # 打印输出张量
print(" 目标张量:", target) # 打印目标张量

2.3定义编码层和解码层

import torch.nn as nn # 导入 torch.nn 库
# 定义编码器类,继承自 nn.Module
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()       
        self.hidden_size = hidden_size # 设置隐藏层大小       
        self.embedding = nn.Embedding(input_size, hidden_size) # 创建词嵌入层       
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建 RNN 层    
    def forward(self, inputs, hidden): # 前向传播函数
        embedded = self.embedding(inputs) # 将输入转换为嵌入向量       
        output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出
        return output, hidden
# 定义解码器类,继承自 nn.Module
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(Decoder, self).__init__()       
        self.hidden_size = hidden_size # 设置隐藏层大小       
        self.embedding = nn.Embedding(output_size, hidden_size) # 创建词嵌入层
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)  # 创建 RNN 层       
        self.out = nn.Linear(hidden_size, output_size) # 创建线性输出层    
    def forward(self, inputs, hidden):  # 前向传播函数     
        embedded = self.embedding(inputs) # 将输入转换为嵌入向量       
        output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出       
        output = self.out(output) # 使用线性层生成最终输出
        return output, hidden
n_hidden = 128 # 设置隐藏层数量
# 创建编码器和解码器
encoder = Encoder(voc_size_cn, n_hidden)
decoder = Decoder(n_hidden, voc_size_en)
print(' 编码器结构:', encoder)  # 打印编码器的结构
print(' 解码器结构:', decoder)  # 打印解码器的结构

2.4创建 Seq2Seq 架构 

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        # 初始化编码器和解码器
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, enc_input, hidden, dec_input):    # 定义前向传播函数
        # 使输入序列通过编码器并获取输出和隐藏状态
        encoder_output, encoder_hidden = self.encoder(enc_input, hidden)
        # 将编码器的隐藏状态传递给解码器作为初始隐藏状态
        decoder_hidden = encoder_hidden
        # 使解码器输入(目标序列)通过解码器并获取输出
        decoder_output, _ = self.decoder(dec_input, decoder_hidden)
        return decoder_output
# 创建 Seq2Seq 架构
model = Seq2Seq(encoder, decoder)
print('S2S 模型结构:', model)  # 打印模型的结构

2.5模型训练

# 定义训练函数
def train_seq2seq(model, criterion, optimizer, epochs):
    for epoch in range(epochs):
       encoder_input, decoder_input, target = make_data(sentences) # 训练数据的创建
       hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态      
       optimizer.zero_grad()# 梯度清零        
       output = model(encoder_input, hidden, decoder_input) # 获取模型输出        
       loss = criterion(output.view(-1, voc_size_en), target.view(-1)) # 计算损失        
       if (epoch + 1) % 40 == 0: # 打印损失
          print(f"Epoch: {epoch + 1:04d} cost = {loss:.6f}")         
       loss.backward()# 反向传播        
       optimizer.step()# 更新参数
# 训练模型
epochs = 400 # 训练轮次
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 优化器
train_seq2seq(model, criterion, optimizer, epochs) # 调用函数训练模型

 

2.6进行测试

# 定义测试函数
def test_seq2seq(model, source_sentence):
    # 将输入的句子转换为索引
    encoder_input = np.array([[word2idx_cn[n] for n in source_sentence.split()]])
    # 构建输出的句子的索引,以 '<sos>' 开始,后面跟 '<eos>',长度与输入句子相同
    decoder_input = np.array([word2idx_en['<sos>']] + [word2idx_en['<eos>']]*(len(encoder_input[0])-1))
    # 转换为 LongTensor 类型
    encoder_input = torch.LongTensor(encoder_input)
    decoder_input = torch.LongTensor(decoder_input).unsqueeze(0) # 增加一维    
    hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态    
    predict = model(encoder_input, hidden, decoder_input) # 获取模型输出    
    predict = predict.data.max(2, keepdim=True)[1] # 获取概率最大的索引
    # 打印输入的句子和预测的句子
    print(source_sentence, '->', [idx2word_en[n.item()] for n in predict.squeeze()])
# 测试模型
test_seq2seq(model, '成哥 喜欢 小冰')  
test_seq2seq(model, '自然 语言 处理 很 强大')

3.总结

        综上所述,Seq2Seq模型是一种用于序列到序列任务的深度学习模型,由编码器和解码器组成。编码器将输入序列编码为一个向量,解码器逐步生成目标序列。该模型在机器翻译、文本摘要、对话生成等任务中得到广泛应用,并有很多改进和扩展。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1405009.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

H6201L 150V降压芯片 48V 60V 72V 90V 100V 120V降33V 12V 5V 控制器

电动车控制器中的降压芯片也是一种电源管理芯片&#xff0c;用于调整和稳定电动车的电源输出电压。其工作原理与一般降压恒压芯片类似&#xff0c;但在电动车控制器中有一些特定的应用和考虑因素。 降压模式&#xff08;Buck Mode&#xff09;:输入电压通常由电动车的电池提供&…

Qt采集本地摄像头推流成rtsp/rtmp(可网页播放/支持嵌入式linux)

一、功能特点 支持各种本地视频文件和网络视频文件。支持各种网络视频流&#xff0c;网络摄像头&#xff0c;协议包括rtsp、rtmp、http。支持将本地摄像头设备推流&#xff0c;可指定分辨率和帧率等。支持将本地桌面推流&#xff0c;可指定屏幕区域和帧率等。自动启动流媒体服…

基于PHP反序列化练习

PHP创建一个以自己姓名命名的类&#xff0c;要求存在两个属性&#xff0c;name&#xff0c;age&#xff0c;进行序列化&#xff0c;输出序列化以后的数据。 <!-- PHP创建一个以自己姓名命名的类&#xff0c;要求存在两个属性&#xff0c;name&#xff0c;age --> <?…

18.鸿蒙HarmonyOS App(JAVA)日期选择器-时间选择器

18.鸿蒙HarmonyOS App(JAVA)日期选择器-时间选择器 点击button按钮触发事件显示月份与获取的时间 Button button3 (Button) findComponentById(ResourceTable.Id_button3);button3.setClickedListener(new Component.ClickedListener() {Overridepublic void onClick(Compon…

让二叉树无处可逃

志不立&#xff0c;天下无可成之事。 ——王阳明 二叉树 1、树&#xff1f;什么是树1、1、基本概念1、2、树的相关概念1、3、树的表示方式1、4、树的实际运用 2、二叉树&#xff1f;只有两个分支吗&#xff1f;2、1、基本概念2、2、二叉树的相关定义2、3、二叉树的相关性质2、4…

Java_网络编程

一、网络编程概述 网络编程的意思就是编写的应用程序可以与网络上其他设备中的应用程序进行数据交互。 网络编程有什么用呢&#xff1f;这个就不言而喻了&#xff0c;比如我们经常用的微信收发消息就需要用到网络通信的技术、在比如我们打开浏览器可以浏览各种网络、视频等也…

Pytest中conftest.py的用法

Pytest中conftest.py的用法 ​ 在官方文档中&#xff0c;描述conftest.py是一个本地插件的文件&#xff0c;简单的说就是在这个文件中编写的方法&#xff0c;可以在其他地方直接进行调用。 注意事项 只能在根目录编写conftest.py 插件加载顺序在搜集用例之前 基础用法 这里…

如何通过系统命令排查账号安全?

如何通过系统命令排查账号安全 query user 查看当前登录账号 logoff id 注销用户id net user 查看用户 net user username 查看用户登录情况 lusrmgr.msc 查看隐藏账号 winR打开regedit注册表 找到计算机\HEKY_LOCAL_MACHINE\SAM\SAM\右键给与用户读写权限 刷新打开 HKEY…

selenium代理ip可用性测试

测试代理ip是否工作正常&#xff0c;将正常的代理ip提取出来 from selenium import webdriver from fake_useragent import UserAgent def check_proxy(proxy):print("开始测试&#xff1a;"proxy)chrome_options webdriver.ChromeOptions()chrome_options.add_arg…

一、认识 JVM 规范(JVM 概述、字节码指令集、Class文件解析、ASM)

1. JVM 概述 JVM&#xff1a;Java Virtual Machine&#xff0c;也就是 Java 虚拟机 所谓虚拟机是指&#xff1a;通过软件模拟的具有完整硬件系统功能的、运行在一个完全隔离环境中的计算机系统。 即&#xff1a;虚拟机是一个计算机系统。这种计算机系统运行在完全隔离的环境中…

大数据学习之Flink,Flink的安装部署

Flink部署 一、了解它的关键组件 客户端&#xff08;Client&#xff09; 作业管理器&#xff08;JobManager&#xff09; 任务管理器&#xff08;TaskManager&#xff09; 我们的代码&#xff0c;实际上是由客户端获取并做转换&#xff0c;之后提交给 JobManger 的。所以 …

centos 7.6 进入单用户模式

1、重启服务器&#xff0c;在选择内核界面使用上下箭头移动 2、选择内核并按“e” 将“RO”改成 rw ,删除 rhgb quiet 添加 init/bin/bash Ctrl X 进入单用户模式 为防止乱码&#xff0c;修改语言为英语 修改完密码建议输入&#xff1a;touch /.autorelabel 更新系统信…

喜讯 | 华院计算摘得“2023大数据产业年度创新技术突破”奖

2024年1月17日&#xff0c; 由数据猿和上海大数据联盟主办&#xff0c;上海市经济和信息化委员会、上海市科学技术委员会指导的“第六届金猿季&魔方论坛——大数据产业发展论坛”在上海市四行仓库举行。论坛以“小趋势大未来”为主题&#xff0c;围绕大数据产业的各个领域展…

Kubernetes operator(一)client-go篇【更新中】

云原生学习路线导航页&#xff08;持续更新中&#xff09; 本文是 Kubernetes operator学习 系列第一篇&#xff0c;主要对client-go进行学习&#xff0c;从源码阅读角度&#xff0c;学习client-go各个组件的实现原理、如何协同工作等参考视频&#xff1a;Bilibili 2022年最新k…

ThinkPHP5.0.0~5.0.23路由控制不严谨导致的RCE

本次我们继续以漏洞挖掘者的视角&#xff0c;来分析thinkphp的RCE 敏感函数发现 在调用入口函数&#xff1a;/ThinkPHP_full_v5.0.22/public/index.php 时 发现了框架底层调用了\thinkphp\library\think\App.php的app类中的incokeMethod方法 注意传递的参数&#xff0c;Refle…

Java基于沙箱环境实现支付宝支付

一、支付宝沙箱环境介绍 沙箱环境是支付宝开放平台为开发者提供的安全低门槛的测试环境&#xff0c;开发者在沙箱环境中调用接口无需具备所需的商业资质&#xff0c;无需绑定和开通产品&#xff0c;同时不会对生产环境中的数据造成任何影响。合理使用沙箱环境&#xff0c;可以…

C语言中的strtok()函数进行字符串分割

引言 在处理文本或字符串数据时&#xff0c;我们常常需要将一长串连续的字符按照特定的分隔符分解成一个个独立的子串。C语言中提供了一个非常实用的库函数——strtok()&#xff0c;用于实现这一功能。本文将通过一段示例代码详细解析并演示如何使用strtok()函数进行字符串分割…

MYSQL之索引语法与使用

索引分类 分类 含义 特点 关键字 主键索引 针对表中主键创建的索引 默认自动创建&#xff0c;只能有一个 PRIMARY 唯一索引 …

春运倒计时,AR 引领铁路运输安全新风向

根据中国交通新闻网发布最新消息&#xff0c;今年春运全国跨区域人员流动量预计达 90 亿人次。 随着春运期间旅客数量不断创下新高&#xff0c;铁路运输面临着空前的挑战与压力。 图源&#xff1a;pixabay 聚焦铁路运输效率与旅客安全保障问题&#xff0c;本期行业趋势将探讨 …

leetcode:反转链表--反转链子表

题目&#xff1a;反转链表 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1] 提示&#xff1a; 链表中节点的数目范围是 [0, 5000] -5000 < Node.…