基于LSTM encoder-decoder模型实现英文转中文的翻译机器

news2024/12/25 12:53:16

前言

神经网络机器翻译(NMT, neuro machine tranlation)是AIGC发展道路上的一个重要应用。正是对这个应用的研究,发展出了注意力机制,在此基础上产生了AIGC领域的霸主transformer。我们今天先把注意力机制这些东西放一边,介绍一个对机器翻译起到重要里程碑作用的模型:LSTM encoder-decoder模型(sutskever et al. 2014)。根据这篇文章的描述,这个模型不需要特别的优化,就可以取得超过其他NMT模型的效果,所以我们也来动手实现一下,看看是不是真的有这么厉害。

模型

原文作者采用了4层LSTM模型,每层有1000个单元(每个单元有输入门,输出门,遗忘门和细胞状态更新共计4组状态),采用1000维单词向量,纯RNN部分,就有64M参数。同时,在encoder的输出,和decoder的输出后放一个长度为80000的softmax层(因为论文的输出字典长80000),用于softmax的参数量为320M。整个模型共计320M + 64M = 384M。该模型用了8GPU的服务器训练了10天。
模型大概长这样:
在这里插入图片描述
按照现在的算力价格,用8张4090的主机训练每小时要花20多块钱,训练一轮下来需要花费小5000,笔者当然没有这么土豪,所以我们会使用一个参数量小得多的模型,主要为了记录整个搭建过程使用到的工具链和技术。另外,由于笔者使用了一个预训练的词向量库,包含了中英文单词共计128万多条,其中中文90多万,英文30多万,要像论文中一样用一个超大的softmax来预测每个词的概率并不现实,因此先使用一个linear层再加上relu来简化,加快训练过程,只求能看到收敛。

笔者的模型看起来像这样:
在这里插入图片描述
该模型的主要参数如下:
词向量维度:300
LSTM隐藏层个数:600
LSTM层数:4
linear层输入:600
linear层输出:300
模型参数个数如下为:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Seq2Seq                                  [1, 11, 300]              --
├─Encoder: 1-1                           [1, 300]                  --
│    └─LSTM: 2-1                         [1, 10, 600]              10,819,200
│    └─Linear: 2-2                       [1, 300]                  180,300
│    └─ReLU: 2-3                         [1, 300]                  --
├─Decoder: 1-2                           [1, 11, 300]              --
│    └─LSTM: 2-4                         [1, 11, 600]              10,819,200
│    └─Linear: 2-5                       [1, 11, 300]              180,300
│    └─ReLU: 2-6                         [1, 11, 300]              --
==========================================================================================
Total params: 21,999,000
Trainable params: 21,999,000
Non-trainable params: 0
Total mult-adds (M): 227.56
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.13
Params size (MB): 88.00
Estimated Total Size (MB): 88.15
==========================================================================================

如果大家希望了解LSTM层的10,819,200个参数如何计算出来,可以参考pytorch源码 pytorch/torch/csrc/api/src/nn/modules/rnn.cpp中方法void RNNImplBase::reset()的实现。笔者如果日后有空也可能会写一写。

3 单词向量及语料

3.1 语料

先说语料,NMT需要大量的平行语料,语料可以从这里获取。另外有个语料天涯网站大全分享给大家。

3.2 词向量

首先需要对句子进行分词,中英文都需要做分词。中文分词工具本例采用jieba,可直接安装。

$ pip install jieba
...
$ python
Python 3.11.6 (tags/v3.11.6:8b6ee5b, Oct  2 2023, 14:57:12) [MSC v.1935 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> for token in jieba.cut("我爱踢足球!", cut_all=True):
...     print(token)
... 
我
爱
踢足球
足球
!

英文分词采用nltk,安装之后,需要下载一个分词模型。

$ pip install nltk
...
$ python
Python 3.11.6 (tags/v3.11.6:8b6ee5b, Oct  2 2023, 14:57:12) [MSC v.1935 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import nltk
>>> nltk.download("punkt")
...
>>> from nltk import word_tokenize
>>> word_tokenize('i love you')
['i', 'love', 'you']

国内有墙,一般下载不了,所以可以到这里找到punkt文件并下载,解压到~/nltk_data/tokenizers/下边。

3.3 加载语料代码

import xml.etree.ElementTree as ET

class TmxHandler():

    def __init__(self):
        self.tag=None
        self.lang=None
        self.corpus={}

    def handleStartTu(self, tag):
        self.tag=tag
        self.lang=None
        self.corpus={}

    def handleStartTuv(self, tag, attributes):
        if self.tag == 'tu':
            if attributes['{http://www.w3.org/XML/1998/namespace}lang']:
                self.lang=attributes['{http://www.w3.org/XML/1998/namespace}lang']
            else:
                raise Exception('tuv element must has a xml:lang attribute')
            self.tag = tag
        else:
            raise Exception('tuv element must go under tu, not ' + tag)
    
    def handleStartSeg(self, tag, elem):
        if self.tag == 'tuv':
            self.tag = tag
            if self.lang:
                if elem.text:
                    self.corpus[self.lang]=elem.text
            else:
                raise Exception('lang must not be none')
        else:
            raise Exception('seg element must go under tuv, not ' + tag)

    def startElement(self, tag, attributes, elem):
        if tag== 'tu':
            self.handleStartTu(tag)
        elif tag == 'tuv':
            self.handleStartTuv(tag, attributes)
        elif tag == 'seg':
            self.handleStartSeg(tag, elem)

    def endElem(self, tag):
        if self.tag and self.tag != tag:
            raise Exception(self.tag + ' could not end with ' + tag)

        if tag == 'tu':
            self.tag=None
            self.lang=None
            self.corpus={}
        elif tag == 'tuv':
            self.tag='tu'
            self.lang=None
        elif tag == 'seg':
            self.tag='tuv'

    def parse(self, filename):
        for event, elem in ET.iterparse(filename, events=('start','end')):
            if event == 'start':
                self.startElement(elem.tag, elem.attrib, elem)
            elif event == 'end':
                if elem.tag=='tu':
                    yield self.corpus
                self.endElem(elem.tag)

3.4 句子转词向量代码

from gensim.models import KeyedVectors
import torch
import jieba
from nltk import word_tokenize
import numpy as np

class WordEmbeddingLoader():
    def __init__(self):
        pass

    def load(self, fname):
        self.model = KeyedVectors.load_word2vec_format(fname)

    def get_embeddings(self, word:str):
        if self.model:
            try:
                return self.model.get_vector(word)
            except(KeyError):
                return None
        else:
            return None
    
    def get_scentence_embeddings(self, scent:str, lang:str):
        embeddings = []
        ws = []
        if(lang == 'zh'):
            ws = jieba.cut(scent, cut_all=True)
        elif lang == 'en':
            ws = word_tokenize(scent)
        else:
            raise Exception('Unsupported language ' + lang)

        for w in ws:
            embedding = self.get_embeddings(w.lower())
            if embedding is None:
                embedding = np.zeros(self.model.vector_size)

            embedding = torch.from_numpy(embedding).float()
            embeddings.append(embedding.unsqueeze(0))
        return torch.cat(embeddings, dim=0)

4 模型代码实现

4.1 encoder

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, device, embeddings=300, hidden_size=600, num_layers=4):
        super().__init__()
        self.device = device
        self.hidden_layer_size = hidden_size
        self.n_layers = num_layers
        self.embedding_size = embeddings
        self.lstm = nn.LSTM(embeddings, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, embeddings)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: [batch size, seq length, embeddings]
        # lstm_out: [batch size, x length, hidden size]
        lstm_out, (hidden, cell) = self.lstm(x)
        
		# linear input is the lstm output of the last word
        lineared = self.linear(lstm_out[:,-1,:].squeeze(1))
        out = self.relu(lineared)

        # hidden: [n_layer, batch size, hidden size]
        # cell: [n_layer, batch size, hidden size]
        return out, hidden, cell

4.2 decoder

import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, device, embedding_size=300, hidden_size=900, num_layers=4):
        super().__init__()
        self.device = device
        self.hidden_layer_size = hidden_size
        self.n_layers = num_layers
        self.embedding_size = embedding_size
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, embedding_size)
        self.relu = nn.ReLU()

    def forward(self, x, hidden_in, cell_in):

        # x: [batch_size, x length, embeddings]
        # hidden: [n_layers, batch size, hidden size]
        # cell: [n_layers, batch size, hidden size]
        # lstm_out: [seq length, batch size, hidden size]
        lstm_out, (hidden,cell) = self.lstm(x, (hidden_in, cell_in))

        # prediction: [seq length, batch size, embeddings]
        prediction=self.relu(self.linear(lstm_out))
        return prediction, hidden, cell

4.3 encoder-decoder

接下来把encoder和decoder串联起来。

import torch
import encoder as enc
import decoder as dec
import torch.nn as nn
import time

class Seq2Seq(nn.Module):
    def __init__(self, device, embeddings, hiddens, n_layers):
        super().__init__()
        self.device = device
        self.encoder = enc.Encoder(device, embeddings, hiddens, n_layers)
        self.decoder= dec.Decoder(device, embeddings, hiddens, n_layers)
        self.embeddings = self.encoder.embedding_size
        assert self.encoder.n_layers == self.decoder.n_layers, "Number of layers of encoder and decoder must be equal!"
        assert self.decoder.hidden_layer_size==self.decoder.hidden_layer_size, "Hidden layer size of encoder and decoder must be equal!"

    # x: [batches, x length, embeddings]
    # x is the source scentences
    # y: [batches, y length, embeddings]
    # y is the target scentences
    def forward(self, x, y):

        # encoder_out: [batches, n_layers, embeddings]
        # hidden, cell: [n layers, batch size, embeddings]
        encoder_out, hidden, cell = self.encoder(x)

        # use encoder output as the first word of the decode sequence
        decoder_input = torch.cat((encoder_out.unsqueeze(0), y), dim=1)

        # predicted: [batches, y length, embeddings]
        predicted, hidden, cell = self.decoder(decoder_input, hidden, cell)

        return predicted

5 模型训练

5.1 训练代码


def do_train(model:Seq2Seq, train_set, optimizer, loss_function):

    step = 0

    model.train()
    
    # seq: [seq length, embeddings]
    # labels: [label length, embeddings]
    for seq, labels in train_set:
        step = step + 1

        # ignore the last word of the label scentence
        # because it is to be predicted
        label_input = labels[:-1].unsqueeze(0)

        # seq_input: [1, seq length, embeddings]
        seq_input = seq.unsqueeze(0)

        # y_pred: [1, seq length + 1, embeddings]
        y_pred = model(seq_input, label_input)

        # single_loss = loss_function(y_pred.squeeze(0), labels.to(self.device))
        single_loss = loss_function(y_pred.squeeze(0), labels)
        
        optimizer.zero_grad()
        single_loss.backward()
        optimizer.step()

        print_steps = 100
        if print_steps != 0 and step%print_steps==1:
            print(f'[step: {step} - {time.asctime(time.localtime(time.time()))}] - loss:{single_loss.item():10.8f}')

def train(device, model, embedding_loader, corpus_fname, batch_size:int, batches: int):
    reader = corpus_reader.TmxHandler()
    loss = torch.nn.MSELoss()
    # summary(model, input_size=[(1, 10, 300),(1,10,300)])

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    generator = reader.parse(corpus_fname)
    for _b in range(batches):
        batch = []
        try:
            for _c in range(batch_size):
                try:
                    corpus = next(generator)
                    if 'en' in corpus and 'zh' in corpus:
                        en = embedding_loader.get_scentence_embeddings(corpus['en'], 'en').to(device)
                        zh = embedding_loader.get_scentence_embeddings(corpus['zh'], 'zh').to(device)
                        batch.append((en,zh))
                except (StopIteration):
                    break
        finally:
            print(time.localtime())
            print("batch: " + str(_b))
            do_train(model, batch, optimizer, loss)
            torch.save(model, "./models/seq2seq_" + str(time.time()))
        
if __name__=="__main__":
    # device = torch.device('cuda')
    device = torch.device('cpu')
    embeddings = 300
    hiddens = 600
    n_layers = 4

    embedding_loader = word2vec.WordEmbeddingLoader()
    print("loading embedding")
    # a full vocabulary takes too long to load, a baby vocabulary is used for demo purpose
    embedding_loader.load("../sgns.merge.word.toy")
    print("load embedding finished")

	# if there is an existing model, load the existing model from file
    # model_fname = "./models/_seq2seq_1698000846.3281412"
    model_fname = None
    model = None
    if not model_fname is None:
        print('loading model from ' + model_fname)
        model = torch.load(model_fname, map_location=device)
        print('model loaded')
    else:
        model = Seq2Seq(device, embeddings, hiddens, n_layers).to(device)

    train(device, model, embedding_loader, "../News-Commentary_v16.tmx", 1000, 100)

5.2 使用CPU进行训练

让我们先来体验一下CPU的龟速训练。下图是每100句话的训练输出。每次打印的间隔大约为2-3分钟。

[step: 1 - Thu Oct 26 05:14:13 2023] - loss:0.00952744
[step: 101 - Thu Oct 26 05:17:11 2023] - loss:0.00855174
[step: 201 - Thu Oct 26 05:20:07 2023] - loss:0.00831730
[step: 301 - Thu Oct 26 05:23:09 2023] - loss:0.00032693
[step: 401 - Thu Oct 26 05:25:55 2023] - loss:0.00907284
[step: 501 - Thu Oct 26 05:28:55 2023] - loss:0.00937218
[step: 601 - Thu Oct 26 05:32:00 2023] - loss:0.00823146

5.3 使用GPU进行训练

如果把main函数的第一行中的"cpu"改成“cuda”,则可以使用显卡进行训练。笔者使用的是一张GTX1660显卡,打印间隔缩短为15秒。

[step: 1 - Thu Oct 26 06:38:45 2023] - loss:0.00955237
[step: 101 - Thu Oct 26 06:38:50 2023] - loss:0.00844441
[step: 201 - Thu Oct 26 06:38:56 2023] - loss:0.00820994
[step: 301 - Thu Oct 26 06:39:01 2023] - loss:0.00030389
[step: 401 - Thu Oct 26 06:39:06 2023] - loss:0.00896622
[step: 501 - Thu Oct 26 06:39:11 2023] - loss:0.00929985
[step: 601 - Thu Oct 26 06:39:17 2023] - loss:0.00813591

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1138463.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[论文阅读]Point Density-Aware Voxels for LiDAR 3D Object Detection(PDV)

PDV Point Density-Aware Voxels for LiDAR 3D Object Detection 论文网址:PDV 论文代码:PDV 简读论文 摘要 LiDAR 已成为自动驾驶中主要的 3D 目标检测传感器之一。然而,激光雷达的发散点模式随着距离的增加而导致采样点云不均匀&#x…

云原生架构设计理论与实践

云原生架构设计理论与实践 云原生架构概述 云原生的背景 云原生定义和特征 云原生架构的设计原则 架构模式 服务化架构模式 Mesh化架构模式 Serverless模式 存储计算分离模式 分布式事务模式 可观测架构 事件驱动架构 云原生架构相关技术 容器技术 云原生微服务技术 无服务…

orm连接mysql

7.2 ORM ORM可以帮助我们做两件事 创建、修改、删除数据库中的表(不用写SQL语句)。无法创建数据库操作表中的数据(操作表中的数据)。 1.自己创建数据库 启动自己的mysql服务自带的工具创建数据库 create database gx_day5 DE…

CMake多文件构建初步

前面学习了cmake,不熟悉,只是记录了操作过程;下面再继续; 略有一点进步,增加一个代码文件,之前是1个代码文件; 如下图,prj是空文件夹, CMakeLists.txt如下;…

MySQL多表关联on和where速度对比实测谁更快

MySQL多表关联on和where速度对比实测谁更快 背景 今天发现有人在讨论:两张MySQL的数据表按照某一个字段进行关联的时候查询,我们使用on和where哪种查询方式更快。百闻不如一见,我们来亲自测试下。 先说结论 Where、对等查询的join速度基本…

vue使用smooth-signature实现移动端电子签字,包括横竖屏

vue使用smooth-signature实现移动端电子签字&#xff0c;包括横竖屏 1.使用smooth-signature npm install --save smooth-signature二.页面引入插件 import SmoothSignature from "smooth-signature";三.实现效果 四.完整代码 <template><div class&quo…

第13期 | GPTSecurity周报

GPTSecurity是一个涵盖了前沿学术研究和实践经验分享的社区&#xff0c;集成了生成预训练 Transformer&#xff08;GPT&#xff09;、人工智能生成内容&#xff08;AIGC&#xff09;以及大型语言模型&#xff08;LLM&#xff09;等安全领域应用的知识。在这里&#xff0c;您可以…

asp.net学生考试报名管理系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net学生考试报名管理系统是一套完善的web设计管理系统系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使 用c#语言开发 应用技术&#xff1a;asp…

YouTrack 中如何设置邮件通知

在 YouTrack 中&#xff0c;默认是不会邮件通知的。 你可以为你的账号设置邮件通知。 设置的方法为单击用户属性&#xff0c;然后在弹出的小窗口中选择属性选项。 设置邮件通知 在通知 Tab 页面中&#xff0c;选择发送邮件的方式&#xff0c;默认这个选项是不选择的。 用户…

React之如何捕获错误

一、是什么 错误在我们日常编写代码是非常常见的 举个例子&#xff0c;在react项目中去编写组件内JavaScript代码错误会导致 React 的内部状态被破坏&#xff0c;导致整个应用崩溃&#xff0c;这是不应该出现的现象 作为一个框架&#xff0c;react也有自身对于错误的处理的解…

spring boot利用redis作为缓存

一、缓存介绍 在 Spring Boot 中&#xff0c;可以使用 Spring Cache abstraction 来实现缓存功能。Spring Cache abstraction 是 Spring 框架提供的一个抽象层&#xff0c;它对底层缓存实现&#xff08;如 Redis、Ehcache、Caffeine 等&#xff09;进行了封装&#xff0c;使得在…

vue实现多时间文字条件查询

// 搜索功能 // 获取搜索框内容 const dateOneref() const dateTworef() // 重新创建数组 const itemList ref([]); const query()>{console.log(formattedDateTime.value);console.log(list.value);itemList.value list.value.filter(item > {//获取数据框的内容是否与…

垃圾回收系统小程序

在当今社会&#xff0c;废品回收不仅有利于环境保护&#xff0c;也有利于资源的再利用。随着互联网技术的发展&#xff0c;个人废品回收也可以通过小程序来实现。本文将介绍如何使用乔拓云网制作个人废品回收小程序。 1. 找一个合适的第三方制作平台/工具&#xff0c;比如乔拓云…

C++之Linux syscall实例总结(二百四十六)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

进程和多线程

目录 进程 1. 如何管理进程 2. 进程调度 3. 内存管理 4. 进程间通信 多线程 线程和进程的关系&#xff1a; 线程安全问题 进程 一个正在运行的程序,就是一个 进程,进程是一个重要的 "软件资源",是由操作系统内核负责管理的。每个进程都对应一些资源,在上图中…

【面试经典150 | 栈】简化路径

文章目录 Tag题目来源题目解读解题思路方法一&#xff1a;字符串数组模拟栈 其他语言python3 写在最后 Tag 【栈】【字符串】 题目来源 71. 简化路径 题目解读 将 Unix 风格的绝对路径转化成更加简洁的规范路径。字符串中会出现 字母、数字、/、_、. 和 .. 这几种字符&#…

c语言之源码反码和补码

c语言源码反码和补码 c语言之源码反码和补码 c语言源码反码和补码一、源码反码补码的介绍二、源码反码补码例子三、源码反码补码练习 一、源码反码补码的介绍 原码、反码、补码是计算机中对数字的二进制表示方法。 原码&#xff1a;将最高位作为符号位&#xff08;0表示正&…

sipp3.6多方案压测脚本

概述 SIP压测工具sipp&#xff0c;免费&#xff0c;开源&#xff0c;功能足够强大&#xff0c;配置灵活&#xff0c;优点多。 有时候我们需要模拟现网的生产环境来压测&#xff0c;就需要同时有多个sipp脚本运行&#xff0c;并且需要不断的调整呼叫并发。 通过python脚本的子…

一文讲透 “中间层” 思想

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《EffectiveJava》独家解析》专栏作者。 热门文章推荐&…

【打靶】vulhub打靶复现系列3---Chronos

【打靶】vulhub打靶复现系列3---Chronos 一、主机探测 结合之前的方法&#xff08;arp探测、ping检测&#xff09;&#xff0c;因为我们的靶机和攻击机都在第二层&#xff0c;所以打靶时候我们更依赖arp协议 tips&#xff1a;我在运行期间发现&#xff0c;netdiscover窗口没关…