【深度学习入门篇 ⑩】Seq2Seq模型:语言翻译

news2024/9/20 16:33:49

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


今天我们进入 Seq2Seq 的领域,了解这种更为复杂且功能强大的模型,它不仅能理解词汇(Word2Vec),还能把这些词汇串联成完整的句子。 

Seq2Seq

Seq2Seq(Sequence-to-Sequence),就是从一个序列到另一个序列的转换。它不仅仅能理解单词之间的关系,而且还能把整个句子的意思打包,并解压成另一种形式的表达。

seq2seq是一种神经网络架构,是由encoder(编码器)decoder(解码器)两个RNN的组成的。其中encoder负责对输入句子的理解,转化为context vector,decoder负责对理解后的句子的向量进行处理,解码,获得输出。  

Seq2seq模型中的encoder接受一个长度为M的序列,得到1个 context vector,之后decoder把这一个context vector转化为长度为N的序列作为输出,从而构成一个M to N的模型,能够处理很多不定长输入输出的问题,比如:文本翻译,问答,文章摘要,关键字写诗等等

  • 编码器的任务是读取并理解输入序列,然后把它转换为一个固定长度的上下文向量,也叫作状态向量。
  • 解码器的任务是接收编码器生成的上下文向量,并基于这个向量生成目标序列。 

可以加入注意力机制(Attention Mechanism):使解码器能够在生成每个输出元素时“关注”输入序列中的不同部分,从而提高模型处理长序列和捕捉复杂依赖关系的能力。 

Seq2Seq模型实现

 任务:

完成一个模型,实现往模型输入一串数字,输出这串数字+0

  •  输入12345678,输出123456780

实现流程

  • 文本转化为序列

  • 使用序列,准备数据集,准备Dataloader

  • 完成编码器

  • 完成解码器

  • 完成seq2seq模型

  • 完成模型训练的逻辑,进行训练

  • 完成模型评估的逻辑,进行模型评估

训练时可以使用GPU训练:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("训练设备为:", device)

文本转化为序列

由于输入的是数字,为了把这写数字和词典中的真实数字进行对应,可以把这些数字理解为字符串

class NumSequence:
    UNK_TAG = "UNK" 
    PAD_TAG = "PAD" 
    EOS_TAG = "EOS" #句子开始
    SOS_TAG = "SOS" #句子结束

    UNK = 0
    PAD = 1
    EOS = 2
    SOS = 3

    def __init__(self):
        self.dict = {
            self.UNK_TAG : self.UNK,
            self.PAD_TAG : self.PAD,
            self.EOS_TAG : self.EOS,
            self.SOS_TAG : self.SOS
        }
        # 字符串和数字对应的字典
        for i in range(10):
            self.dict[str(i)] = len(self.dict)
        self.index2word = dict(zip(self.dict.values(),self.dict.keys()))

    def __len__(self):
        return len(self.dict)

    def transform(self,sequence,max_len=None,add_eos=False):

        
        sequence_list = list(str(sequence))
        seq_len = len(sequence_list)+1 if add_eos else len(sequence_list)

        if add_eos and max_len is not None:
            assert max_len>= seq_len, "max_len 应该大于seq+eos的长度"
        _sequence_index = [self.dict.get(i,self.UNK) for i in sequence_list]
        if add_eos:
            _sequence_index += [self.EOS]
        if max_len is not None:
            sequence_index = [self.PAD]*max_len
            sequence_index[:seq_len] =  _sequence_index
            return sequence_index
        else:
            return _sequence_index

    def inverse_transform(self,sequence_index):
        result = []
        for i in sequence_index:
            if i==self.EOS:
                break
            result.append(self.index2word.get(int(i),self.UNK_TAG))
        return result

num_sequence = NumSequence()

if __name__ == '__main__':
    num_sequence = NumSequence()
    print(num_sequence.dict)
    print(num_sequence.index2word)
    print(num_sequence.transform("232356",add_eos=True))
准备Dataset
from torch.utils.data import Dataset,DataLoader
import numpy as np
from word_sequence import num_sequence
import torch
import config

class RandomDataset(Dataset):
    def __init__(self):
        super(RandomDataset,self).__init__()
        self.total_data_size = 500000
        np.random.seed(10)
        self.total_data = np.random.randint(1,100000000,size=[self.total_data_size])

    def __getitem__(self, idx):

        input = str(self.total_data[idx])
        return input, input+ "0",len(input),len(input)+1

    def __len__(self):
        return self.total_data_size
准备DataLoader

在准备DataLoader的过程中,可以通过定义的collate_fn来实现对dataset中batch数据的处理

def collate_fn(batch):

    batch = sorted(batch,key=lambda x:x[3],reverse=True)
    input,target,input_length,target_length = zip(*batch)


    input = torch.LongTensor([num_sequence.transform(i,max_len=config.max_len) for i in input])
    target = torch.LongTensor([num_sequence.transform(i,max_len=config.max_len,add_eos=True) for i in target])
    input_length = torch.LongTensor(input_length)
    target_length = torch.LongTensor(target_length)

    return input,target,input_length,target_length

data_loader = DataLoader(dataset=RandomDataset(),batch_size=config.batch_size,collate_fn=collate_fn,drop_last=True)

编码器

目的就是为了对文本进行编码,把编码后的结果交给后续的程序使用,使用Embedding+GRU

import torch.nn as nn
from word_sequence import num_sequence
import config


class NumEncoder(nn.Module):
    def __init__(self):
        super(NumEncoder,self).__init__()
        self.vocab_size = len(num_sequence)
        self.dropout = config.dropout
        self.embedding = nn.Embedding(num_embeddings=self.vocab_size,embedding_dim=config.embedding_dim,padding_idx=num_sequence.PAD)
        self.gru = nn.GRU(input_size=config.embedding_dim,
                          hidden_size=config.hidden_size,
                          num_layers=1,
                          batch_first=True)

    def forward(self, input,input_length):
        
        embeded = self.embedding(input) 

        embeded = nn.utils.rnn.pack_padded_sequence(embeded,lengths=input_length,batch_first=True)


        out,hidden = self.gru(embeded)
        

        out,outputs_length = nn.utils.rnn.pad_packed_sequence(out,batch_first=True,padding_value=num_sequence.PAD)
        return out,hidden

解码器

主要负责实现对编码之后结果的处理,得到预测值

import torch
import torch.nn as nn
import config
import random
import torch.nn.functional as F
from word_sequence import num_sequence

class NumDecoder(nn.Module):
    def __init__(self):
        super(NumDecoder,self).__init__()
        self.max_seq_len = config.max_len
        self.vocab_size = len(num_sequence)
        self.embedding_dim = config.embedding_dim
        self.dropout = config.dropout

        self.embedding = nn.Embedding(num_embeddings=self.vocab_size,embedding_dim=self.embedding_dim,padding_idx=num_sequence.PAD)
        self.gru = nn.GRU(input_size=self.embedding_dim,
                          hidden_size=config.hidden_size,
                          num_layers=1,
                          batch_first=True,
                          dropout=self.dropout)
        self.log_softmax = nn.LogSoftmax()

        self.fc = nn.Linear(config.hidden_size,self.vocab_size)

    def forward(self, encoder_hidden,target,target_length):
        

        decoder_input = torch.LongTensor([[num_sequence.SOS]]*config.batch_size)


        decoder_outputs = torch.zeros(config.batch_size,config.max_len,self.vocab_size) 
		
        decoder_hidden = encoder_hidden 
        for t in range(config.max_len):
            decoder_output_t , decoder_hidden = self.forward_step(decoder_input,decoder_hidden)
            

            decoder_outputs[:,t,:] = decoder_output_t
			

            use_teacher_forcing = random.random() > 0.5
            if use_teacher_forcing:

                decoder_input =target[:,t].unsqueeze(1) 
            else:

                value, index = torch.topk(decoder_output_t, 1) 
                decoder_input = index
        return decoder_outputs,decoder_hidden

    def forward_step(self,decoder_input,decoder_hidden):
        
        embeded = self.embedding(decoder_input)  

        out,decoder_hidden = self.gru(embeded,decoder_hidden) 

       	out = out.squeeze(0) 

        out = F.log_softmax(self.fc(out),dim=-1)
        out = out.squeeze(1)
        return out,decoder_hidden

seq2seq模型

完成模型的搭建

import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self,encoder,decoder):
        super(Seq2Seq,self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input,target,input_length,target_length):
  
        encoder_outputs,encoder_hidden = self.encoder(input,input_length)

        decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,target_length)
        return decoder_outputs,decoder_hidden

完成训练:

import torch
import config
from torch import optim
import torch.nn as nn
from encoder import NumEncoder
from decoder import NumDecoder
from seq2seq import Seq2Seq
from dataset import data_loader as train_dataloader
from word_sequence import num_sequence



encoder = NumEncoder()
decoder = NumDecoder()
model = Seq2Seq(encoder,decoder)
print(model)


optimizer =  optim.Adam(model.parameters())
criterion= nn.NLLLoss(ignore_index=num_sequence.PAD,reduction="mean")

def get_loss(decoder_outputs,target):

    target = target.view(-1)
    decoder_outputs = decoder_outputs.view(config.batch_size*config.max_len,-1)
    return criterion(decoder_outputs,target)


def train(epoch):
    for idx,(input,target,input_length,target_len) in enumerate(train_dataloader):
        optimizer.zero_grad()
        ##[seq_len,batch_size,vocab_size] [batch_size,seq_len]
        decoder_outputs,decoder_hidden = model(input,target,input_length,target_len)
        loss = get_loss(decoder_outputs,target)
        loss.backward()
        optimizer.step()

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, idx * len(input), len(train_dataloader.dataset),
                   100. * idx / len(train_dataloader), loss.item()))

        torch.save(model.state_dict(), "model/seq2seq_model.pkl")
        torch.save(optimizer.state_dict(), 'model/seq2seq_optimizer.pkl')

if __name__ == '__main__':
    for i in range(5):
        train(i)

Seq2Seq优点:能处理输入和输出长度不固定的序列转换任务,灵活性高

Seq2Seq缺点:使用固定上下文长度、训练和推理通常需要逐步处理输入和输出序列,以及参数量较少,面对复杂场景可能受限。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1937646.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端Vue组件技术实践:构建自定义动态宫格菜单按钮组件

随着前端技术的不断发展,复杂度和开发难度也随之增加。传统的整体式开发方式已经难以满足现代前端应用的需求,特别是在业务场景复杂、产品迭代频繁的情况下。组件化开发作为一种有效的解决方案,通过拆分和组合独立的组件,实现了单…

C语言 | Leetcode C语言题解之第240题搜索二维矩阵II

题目&#xff1a; 题解&#xff1a; bool searchMatrix(int** matrix, int matrixSize, int* matrixColSize, int target){int i 0;int j matrixColSize[0] - 1;while(j > 0 && i < matrixSize){if(target < matrix[i][j])j--;else if(target > matrix[…

监测电商热品推荐的技术心得

在当今数字化时代&#xff0c;电商行业竞争激烈&#xff0c;准确监测热门商品推荐对于电商企业的运营和决策至关重要。通过不断的实践和探索&#xff0c;我积累了以下一些关于监测电商热品推荐的技术心得。 一、数据采集与整合 多平台数据抓取 要全面了解电商市场的热门商品&am…

ORBSLAM3 ORB_SLAM3 Ubuntu18.04 ROS Melodic 虚拟镜像 下载

build.sh 和 build_ros.sh编译结果截图&#xff1a; slam测试视频&#xff1a; orbslam3 ubuntu18.04 test 下载地址&#xff08;付费使用&#xff0c;不能接受请勿下载&#xff09;&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/13YeJS4RGa3fBrG8BKfPbBw?pwds6vg 提…

使用phpMyAdmin操作MYSQL(四)

一. 学会phpMyAdmin&#xff1f; phpMyAdminhttp://water.ve-techsz.cn/phpmyadmin/ 虽然我我们可以用命令行操作数据库&#xff0c;但这样难免没有那么直观&#xff0c;方便。所以接下来我们使用phpMyAdmin来操作MySQL&#xff0c;phpMyAdmin是众多MySQL图形化管理工具中使用…

​人人开源renren-security:基于SpringBoot、Vue3、ElementPlus等框架开发的权限管理系统

摘要&#xff1a; 随着信息技术的快速发展&#xff0c;企业的信息系统安全需求日益凸显。renren-security是一套基于SpringBoot、MyBatis-Plus、Shiro、Vue3、ElementPlus等框架开发的权限管理系统&#xff0c;它旨在为企业提供高效、安全、易用的权限管理解决方案。本文详细阐…

用Wireshark观察IPsec协议的通信过程

目录 一、配置本地安全策略 二、启动Wireshark&#xff0c;设置过滤器&#xff0c;开始捕获 1. 主模式 2. Quick mode 三、心得体会 1. 碰到的问题和解决办法 2. 心得 一、配置本地安全策略 配置好IPsec如下&#xff1a; 由于在windows server2008安装wireshark失败&…

Qt实现一个简单的视频播放器

目录 1 工程配置 1.1 创建新工程 1.2 ui界面配置 1.3 .pro配置 2 代码 2.1 main.c代码 2.2 widget.c 2.3 widget.h 本文主要记述了如何使用Qt编写一个简单的视频播放器&#xff0c;整个示例采用Qt自带组件就可以完成。可以实现视频的播放和暂停等功能。 1 工程配置 1.…

2024.7.19最新详细的VMware17.0.0安装

VM官网VMware - Delivering a Digital Foundation For Businesses。现在官网无法下载&#xff0c;点击会跳转到https://access.broadcom.com/default/ui/v1/signin/ 要注册一个账号&#xff1a; 注册登录以后&#xff0c;点击Please select your identity provider. - Support …

深度学习落地实战:大模型生成图片

前言 大家好&#xff0c;我是机长 本专栏将持续收集整理市场上深度学习的相关项目&#xff0c;旨在为准备从事深度学习工作或相关科研活动的伙伴&#xff0c;储备、提升更多的实际开发经验&#xff0c;每个项目实例都可作为实际开发项目写入简历&#xff0c;且都附带完整的代…

基于RFID的课堂签到系统设计

1.简介 基于RFID的课堂签到系统设计是一种利用无线射频识别&#xff08;RFID&#xff09;技术实现课堂自动签到的系统。这种系统通过RFID标签&#xff08;通常是学生携带的卡片或手环等&#xff09;与安装在教室内的RFID读写器之间的无线电信号进行数据交换&#xff0c;从而实现…

深度学习入门——与学习相关的技巧

前言 本章将介绍神经网络的学习中的一些重要观点&#xff0c;主题涉及寻找最优权重参数的最优化方法、权重参数的初始值、超参数的设定方法等 此外&#xff0c;为了应对过拟合&#xff0c;本章还将介绍权值衰减、Dropout等正则化方法&#xff0c;并进行实现。 最后将对近年来…

【深度学习】PyTorch框架(2):激活函数

1.引言 在文中&#xff0c;我们将深入探讨流行的激活函数&#xff0c;并分析它们在神经网络优化特性中的作用。激活函数在深度学习模型中扮演着至关重要的角色&#xff0c;因为它们为网络引入了非线性特性。尽管文献中描述了众多的激活函数&#xff0c;但它们并非一视同仁&…

如何优化 PostgreSQL 中的连接查询性能?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01;&#x1f4da;领书&#xff1a;PostgreSQL 入门到精通.pdf 文章目录 如何优化 PostgreSQL 中的连接查询性能&#xff1f;一、理解连接查询的基本原理二、优化连接查询的关键…

JavaScript 模板字符串:让字符串拼接变得更优雅

在 JavaScript 开发中&#xff0c;字符串拼接是一个常见的需求。从简单的用户界面文本生成到复杂的动态数据格式化&#xff0c;字符串操作无处不在。传统的字符串拼接方法虽然功能强大&#xff0c;但往往显得冗长且难以阅读。为了解决这一问题&#xff0c;ES6&#xff08;ECMAS…

职升网:监理工程师题型都是选择题吗?

监理工程师考试科目包含的题型主要有单项选择题、多项选择题以及案例分析题三种。其中《建设工程监理基本理论和相关法规》、《建设工程合同管理》、《建设工程目标控制》三科只有选择题题型&#xff0c;而《建设工程监理案例分析》只有案例分析题。 监理工程师各科目考试题型 …

系统架构设计师教程(清华第二版) 第3章 信息系统基础知识-3.2 业务处理系统-解读

教材中,一会儿“业务处理系统”,一会儿“事务处理系统”,语法毛病一堆。真是清华的水平!!! 系统架构设计师教程 第3章 信息系统基础知识-3.2 业务处理系统 3.2.1 业务处理系统的概念3.2.2 业务处理系统的功能3.2.2.1 数据输入3.2.2.2 数据处理3.2.2.2.1 批处理 (Batch …

C++——继承和多态

1.继承 1.1 继承的概念 在过往的文章中介绍过Java的继承&#xff0c;我们这里比较学习C的继承。 继承是出现是基于对代码复用的需求&#xff0c;在我们写代码时&#xff0c;会发现两个类之间存在大量的代码重复的情况&#xff0c;这个时候继承就排上了用场。继承可以在保持原有…

在 PostgreSQL 中如何实现数据的加密存储?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01;&#x1f4da;领书&#xff1a;PostgreSQL 入门到精通.pdf 文章目录 在 PostgreSQL 中如何实现数据的加密存储&#xff1f;一、为什么要进行数据加密存储&#xff1f;二、P…

【Django】网上蛋糕商城后台-订单管理

概念 前面通过多篇文章以完全实现了用户在网上蛋糕商城平台上的所有功能和操作&#xff0c;从本文开始&#xff0c;实现网站的后台管理功能的介绍和操作。 导入静态资源 在static文件夹下&#xff0c;创建admin文件夹&#xff0c;在该文件夹下导入静态资源 在templates文件夹…