PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型预测)

news2025/1/17 3:05:19

文章目录

  • 引言
  • 数据预处理
    • 加载字典对象`en2id`和`zh2id`
    • 文本分词
  • 加载训练好的Seq2Seq模型
  • 模型预测完整代码
  • 结束语

引言

随着全球化的深入,翻译需求日益增长。传统的人工翻译方式虽然质量高,但效率低,成本高。机器翻译的出现,为解决这一问题提供了可能。英译中机器翻译任务是机器翻译领域的一个重要分支,旨在将英文文本自动翻译成中文。本博客以《PyTorch自然语言处理入门与实战》第九章的Seq2seq模型处理英译中翻译任务作为基础,附上模型预测模块。

模型的训练及验证模块的详细解析见PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型训练及验证)

数据预处理

加载字典对象en2idzh2id

在预测阶段中,需要加载模型训练及验证阶段保存的字典对象en2idzh2id

代码如下:

import pickle

with open("en2id.pkl", 'rb') as f:
    en2id = pickle.load(f)
with open("zh2id.pkl", 'rb') as f:
    zh2id = pickle.load(f)

文本分词

在对输入文本进行预测时,需要先将文本进行分词操作。参考代码如下:

def extract_words(sentence):  
    """  
    从给定的英文句子中提取单词,并去除单词后的标点符号。  
      
    Args:  
        sentence (str): 要提取单词的英文句子。  
          
    Returns:  
        List[str]: 提取并处理后的单词列表。  
    """  
    en_words = []  
    for w in sentence.split(' '):  # 将英文句子按空格分词  
        w = w.replace('.', '').replace(',', '')  # 去除跟单词连着的标点符号  
        w = w.lower()  # 统一单词大小写  
        if w:  
            en_words.append(w)  
    return en_words  
  
# 测试函数  
sentence = 'I am Dave Gallo.'  
print(extract_words(sentence))

运行结果:

加载训练好的Seq2Seq模型

代码如下:

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = (src len, batch size)
        embedded = self.dropout(self.embedding(src))
        # embedded = (src len, batch size, emb dim)
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = (src len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)
        # rnn的输出总是来自顶部的隐藏层
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # 各输入的形状
        # input = (batch size)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # LSTM是单向的  ==> n directions == 1
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)

        embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # LSTM理论上的输出形状
        # output = (seq len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # 解码器中的序列长度 seq len == 1
        # 解码器的LSTM是单向的 n directions == 1 则实际上
        # output = (1, batch size, hid dim)
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        prediction = self.fc_out(output.squeeze(0))

        # prediction = (batch size, output dim)

        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,
                 encode_dropout, decode_dropout, device):
        """

        :param input_word_count:    英文词表的长度     34737
        :param output_word_count:   中文词表的长度     4015
        :param encode_dim:          编码器的词嵌入维度
        :param decode_dim:          解码器的词嵌入维度
        :param hidden_dim:          LSTM的隐藏层维度
        :param n_layers:            采用n层LSTM
        :param encode_dropout:      编码器的dropout概率
        :param decode_dropout:      编码器的dropout概率
        :param device:              cuda / cpu
        """
        super().__init__()
        self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)
        self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)
        self.device = device

    def forward(self, src):
        # src = (src len, batch size)

        # 编码器的隐藏层输出将作为解码器的第一个隐藏层输入
        hidden, cell = self.encoder(src)

        # 解码器的第一个输入应该是起始标识符<sos>
        input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引
        pred = [0] # 预测的第一个输出应该是起始标识符
        top1 = 0
        while top1 != 1 and len(pred) < 100:
            # 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states
            # 解码器的输出包括:输出张量(predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            top1 = output.argmax(dim=1)  # (batch size, )
            pred.append(top1.item())
            input = top1

        return pred

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU
# Seq2Seq模型实例化
source_word_count = 34737  # 英文词表的长度     34737
target_word_count = 4015  # 中文词表的长度     4015
encode_dim = 256  # 编码器的词嵌入维度
decode_dim = 256  # 解码器的词嵌入维度
hidden_dim = 512  # LSTM的隐藏层维度
n_layers = 2  # 采用n层LSTM
encode_dropout = 0.5  # 编码器的dropout概率
decode_dropout = 0.5  # 编码器的dropout概率
model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,
                decode_dropout, device).to(device)

# 加载训练好的模型
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

模型预测完整代码

提示预测代码是我们基于训练及验证代码进行改造的,不一定完全正确,可以参考后自行修改~

import torch
import torch.nn as nn
import pickle


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = (src len, batch size)
        embedded = self.dropout(self.embedding(src))
        # embedded = (src len, batch size, emb dim)
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = (src len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)
        # rnn的输出总是来自顶部的隐藏层
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # 各输入的形状
        # input = (batch size)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # LSTM是单向的  ==> n directions == 1
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)

        embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # LSTM理论上的输出形状
        # output = (seq len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # 解码器中的序列长度 seq len == 1
        # 解码器的LSTM是单向的 n directions == 1 则实际上
        # output = (1, batch size, hid dim)
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        prediction = self.fc_out(output.squeeze(0))

        # prediction = (batch size, output dim)

        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,
                 encode_dropout, decode_dropout, device):
        """

        :param input_word_count:    英文词表的长度     34737
        :param output_word_count:   中文词表的长度     4015
        :param encode_dim:          编码器的词嵌入维度
        :param decode_dim:          解码器的词嵌入维度
        :param hidden_dim:          LSTM的隐藏层维度
        :param n_layers:            采用n层LSTM
        :param encode_dropout:      编码器的dropout概率
        :param decode_dropout:      编码器的dropout概率
        :param device:              cuda / cpu
        """
        super().__init__()
        self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)
        self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)
        self.device = device

    def forward(self, src):
        # src = (src len, batch size)

        # 编码器的隐藏层输出将作为解码器的第一个隐藏层输入
        hidden, cell = self.encoder(src)

        # 解码器的第一个输入应该是起始标识符<sos>
        input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引
        pred = [0] # 预测的第一个输出应该是起始标识符
        top1 = 0
        while top1 != 1 and len(pred) < 100:
            # 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states
            # 解码器的输出包括:输出张量(predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            top1 = output.argmax(dim=1)  # (batch size, )
            pred.append(top1.item())
            input = top1

        return pred


if __name__ == '__main__':
    sentence = 'I am Dave Gallo.'
    en_words = []

    for w in sentence.split(' '):  # 英文内容按照空格字符进行分词
        # 按照空格进行分词后,某些单词后面会跟着标点符号 "." 和 “,”
        w = w.replace('.', '').replace(',', '')  # 去掉跟单词连着的标点符号
        w = w.lower()  # 统一单词大小写
        if w:
            en_words.append(w)

    print(en_words)

    with open("en2id.pkl", 'rb') as f:
        en2id = pickle.load(f)
    with open("zh2id.pkl", 'rb') as f:
        zh2id = pickle.load(f)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU
    # Seq2Seq模型实例化
    source_word_count = 34737  # 英文词表的长度     34737
    target_word_count = 4015  # 中文词表的长度     4015
    encode_dim = 256  # 编码器的词嵌入维度
    decode_dim = 256  # 解码器的词嵌入维度
    hidden_dim = 512  # LSTM的隐藏层维度
    n_layers = 2  # 采用n层LSTM
    encode_dropout = 0.5  # 编码器的dropout概率
    decode_dropout = 0.5  # 编码器的dropout概率
    model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,
                    decode_dropout, device).to(device)

    model.load_state_dict(torch.load("best_model.pth"))
    model.eval()

    src = [0] # 0 --> 起始标识符的编码
    for i in range(len(en_words)):
        src.append(en2id[en_words[i]])
    src = src + [1] # 1 --> 终止标识符的编码

    text_input = torch.LongTensor(src)
    text_input = text_input.unsqueeze(-1).to(device)

    text_output = model(text_input)
    print(text_output)
    id2zh = dict()
    for k, v in zh2id.items():
        id2zh[v] = k

    text_output = [id2zh[index] for index in text_output]
    text_output = " ".join(text_output)
    print(text_output)

结束语

  • 亲爱的读者,感谢您花时间阅读我们的博客。我们非常重视您的反馈和意见,因此在这里鼓励您对我们的博客进行评论。
  • 您的建议和看法对我们来说非常重要,这有助于我们更好地了解您的需求,并提供更高质量的内容和服务。
  • 无论您是喜欢我们的博客还是对其有任何疑问或建议,我们都非常期待您的留言。让我们一起互动,共同进步!谢谢您的支持和参与!
  • 我会坚持不懈地创作,并持续优化博文质量,为您提供更好的阅读体验。
  • 谢谢您的阅读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1340474.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows搭建FTP服务器教学以及计算机端口介绍

目录 一. FTP服务器介绍 FTP服务器是什么意思&#xff1f; 二.Windows Service 2012 搭建FTP服务器 1.开启防火墙 2.创建组 ​编辑3.创建用户 4.用户绑定组 5.安装ftp服务器 ​编辑6.配置ftp服务器 7.配置ftp文件夹的权限 8.连接测试 三.计算机端口介绍 什么是网络…

最新AI系统ChatGPT网站H5系统源码,支持Midjourney绘画,GPT语音对话+ChatFile文档对话总结+DALL-E3文生图

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Ch…

12.27_黑马数据结构与算法笔记Java(补1)

目录 266 活动选择问题 分析 267 活动选择问题 贪心 268 分数背包问题 贪心 269 0-1 背包问题 贪心 270 斐波那契 动态规划 271 斐波那契 动态规划 降维 272 Bellman Ford 动态规划 分析 273 Bellman Ford 动态规划 实现1 274 Bellman Ford 动态规划 实现2 275 Leetco…

文献速递:人工智能医学影像分割---一个用于 COVID-19 CT 图像的粗细分割网络

文献速递&#xff1a;人工智能医学影像分割—一个用于 COVID-19 CT 图像的粗细分割网络 01 文献速递介绍 2019 年新型冠状病毒疾病&#xff08;COVID-19&#xff09;正在全球迅速传播。自 2019 年以来&#xff0c;已有超过一千万人感染&#xff0c;其中数十万人死亡。COVID-…

apisix 插件配置 未生效 未起作用

插件配置完成&#xff0c;却没生效&#xff0c;请检查插件的启用状态是否是启用状态&#xff0c; 以某个route配置的限速插件&#xff08;limit-req&#xff09;为例 1.打开dashboad-->路由-->某个路由-->更多-->查看&#xff0c; 查看配置&#xff0c;实际未启用…

leaflet学习笔记-初始化vue项目(一)

leaflet简介 Leaflet是一款开源的轻量级交互式地图可视化JavaScript库&#xff0c;能够满足大多数开发者的地图可视化需求&#xff0c;其最早的版本大小仅仅38 KB。Leaflet能够在主流的计算机或移动设备上高效运行&#xff0c;其功能可通过插件进行扩展&#xff0c;拥有易于使用…

Day20 222完全二叉树的节点个数 110平衡二叉树 257二叉树的所有路径

222 完全二叉树的结点个数 本题先不把它当成完全二叉树来看&#xff0c;用广度优先和深度优先搜索分别遍历&#xff0c;也能达到目的&#xff0c;只要将之前的代码稍加修改即可。注意后序遍历时的result要加上自身本身的那个结点。 //后序递归遍历 class Solution { public:in…

CGAL的D维包围盒相交计算

包围盒相交测试是一种用于快速判断两个三维对象是否相交的方法&#xff0c;而AABB树则是一种数据结构&#xff0c;常用于加速场景中的射线检测和碰撞检测。 首先&#xff0c;让我们了解一下包围盒相交测试。这种测试的目的是为了快速判断两个三维对象是否相交&#xff0c;而不需…

数据仓库 基本信息

数据仓库基本理论 数据仓库&#xff08;英语&#xff1a;Data Warehouse&#xff0c;简称数仓、DW&#xff09;,是一个用于存储、分析、报告的数据系统。数据仓库的目的是构建面向分析的集成化数据环境&#xff0c;为企业提供决策支持&#xff08;Decision Support&#xff09…

【轻松入门】OpenCV4.8 + QT5.x开发环境搭建

引言 大家好&#xff0c;今天给大家分享一下最新版本OpenCV4.8 QT5 如何一起配置&#xff0c;完成环境搭建的。 下载OpenCV4.8并解压缩 软件版本支持 CMake3.13 或者以上版本 https://cmake.org/ VS2017专业版或者以上版本 QT5.15.2 OpenCV4.8源码包 https://github.com/op…

主浏览器优化之路1——你现在在用的是什么浏览器?Edge?谷歌?火狐?360!?

上一世&#xff0c;我的浏览器之路 引言为什么要用两个浏览器为什么一定要放弃火狐结尾给大家一个猜数字小游戏&#xff08;测运气&#xff09; 引言 小时候&#xff0c;我一开始上网的浏览器是2345王牌浏览器吧&#xff0c; 因为上面集成了很多网站&#xff0c;我记得上面有7…

【MySQL】多表连接查询

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; 数 据 库 ⛳️ 功不唐捐&#xff0c;玉汝于成 目录 前言 正文 1. 交叉连接&#xff08;CROSS JOIN&#xff09; 2. 内连接&#xff08;INNER JOIN&#xff09; 3. 外连接 结语 我的…

几款软件助您事半功倍

在如今繁忙而竞争激烈的工作环境中&#xff0c;寻找适合自己的工作软件是提高工作效率、优化工作流程的重要一环。为了帮助你更好地管理任务、组织工作和提高生产力&#xff0c;我将向你推荐四款备受推崇的工作软件&#xff0c;并详细介绍它们各自的功能和特点。 1. Zoom&#…

记录使用minikube部署web程序,并灰度发布不同版本

1. 安装软件 1.1安装docker desktop 下载地址 重点&#xff1a;配置镜像加速 1.2 安装k8s&minikube 这里参考阿里社区的配置 minikube1.24.0版本下载地址 重点&#xff1a;安装版本问题【因为后面要用阿里云的服务来获取所需Docker镜像&#xff0c;一直不成功使用的高版…

软件测试/测试开发丨Pytest学习笔记

Pytest 格式要求 文件: 以 test_ 开头或以 _test 结尾类: 以 Test 开头方法/函数: 以 _test 开头测试类中不可以添加构造函数, 若添加构造函数将导致Pytest无法识别类下的测试方法 断言 与Unittest不同, 在Pytest中我们需要使用python自带的 assert 关键字进行断言 assert…

JOSEF约瑟 双位置继电器 DCS-12/110V 线圈电压直流110V 板前安装

系列型号&#xff1a; DCS-11双位置继电器&#xff1b; DCS-12双位置继电器&#xff1b; DCS-13双位置继电器&#xff1b; RXMVB2 RK 251 204双位置继电器&#xff1b; RXMVB2 RK 251 205双位置继电器&#xff1b; RXMVB2 RK 251 106双位置继电器&#xff1b; 一、用途 …

Flink项目实战篇 基于Flink的城市交通监控平台(下)

系列文章目录 Flink项目实战篇 基于Flink的城市交通监控平台&#xff08;上&#xff09; Flink项目实战篇 基于Flink的城市交通监控平台&#xff08;下&#xff09; 文章目录 系列文章目录4. 智能实时报警4.1 实时套牌分析4.2 实时危险驾驶分析4.3 出警分析4.4 违法车辆轨迹跟…

6.Nacos

1.单机部署 1.1 官网 https://nacos.io/zh-cn/index.html https://github.com/alibaba/Nacos 1.2.版本说明 https://github.com/alibaba/spring-cloud-alibaba/wiki/%E7%89%88%E6%9C%AC%E8%AF%B4%E6%98%8E 1.3.下载地址 https://github.com/alibaba/nacos/releases/tag/2.2.…

百度CTO王海峰:飞桨开发者已达1070万

目录 写在前面 飞桨开发者已达1070万 文心一言用户规模破亿&#xff0c;日提问量快速增长 写在前面 “文心一言用户规模突破1亿。”12月28日&#xff0c;百度首席技术官、深度学习技术及应用国家工程研究中心主任王海峰在第十届WAVE SUMMIT深度学习开发者大会上宣布。会上&…

全平台去水印系统源码:画质高清无损害,一键下载 支持目前主流80多个平台无水印下载 带完整的安装部署教程

在数字内容爆炸的时代&#xff0c;图片和视频的传播和使用越来越频繁。然而&#xff0c;许多优质资源都带有水印&#xff0c;不仅影响了美观&#xff0c;也在一定程度上限制了资源的再利用。传统的去水印方法往往操作复杂&#xff0c;效果不尽如人意&#xff0c;甚至可能损害原…