NLP Seq2Seq模型

news2024/12/27 13:45:07
🍨 本文为[🔗365天深度学习训练营学习记录博客
 
🍦 参考文章:365天深度学习训练营
 
🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

Seq2Seq模型是一种深度学习模型,用于处理序列到序列的任务,它由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。

  1. 编码器(Encoder): 编码器负责将输入序列(例如源语言句子)编码成一个固定长度的向量,通常称为上下文向量或编码器的隐藏状态。编码器可以是循环神经网络(RNN)、长短期记忆网络(LSTM)或者变种如门控循环单元(GRU)等。编码器的目标是捕捉输入序列中的语义信息,并将其编码成一个固定维度的向量表示。

  2. 解码器(Decoder): 解码器接收编码器生成的上下文向量,并根据它来生成输出序列(例如目标语言句子)。解码器也可以是RNN、LSTM、GRU等。在训练阶段,解码器一次生成一个词或一个标记,并且其隐藏状态从一个时间步传递到下一个时间步。解码器的目标是根据上下文向量生成与输入序列对应的输出序列。

在训练阶段,Seq2Seq模型的目标是最大化目标序列的条件概率给定输入序列。为了实现这一点,通常使用了一种称为教师强制(Teacher Forcing)的技术,即将目标序列中的真实标记作为解码器的输入。但是,在推理阶段(即模型用于生成新的序列时),解码器则根据先前生成的标记生成下一个标记,直到生成一个特殊的终止标记或达到最大长度为止。

下面演示了如何使用PyTorch实现一个简单的Seq2Seq模型,用于将一个序列翻译成另一个序列。这里我们将使用一个虚构的数据集来进行简单的法语到英语翻译。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# 定义数据集
class SimpleDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# 定义Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim)
        
    def forward(self, src):
        embedded = self.embedding(src)
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden

# 定义Decoder
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, input, hidden):
        input = input.unsqueeze(0)
        embedded = self.embedding(input)
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden

# 定义Seq2Seq模型
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.fc_out.out_features
        
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        encoder_outputs, hidden = self.encoder(src)
        
        input = trg[0,:]
        
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden)
            outputs[t] = output
            teacher_force = np.random.rand() < teacher_forcing_ratio
            top1 = output.argmax(1) 
            input = trg[t] if teacher_force else top1
        
        return outputs

# 设置参数
INPUT_DIM = 10
OUTPUT_DIM = 10
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
HID_DIM = 64
N_LAYERS = 1
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# 实例化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(enc, dec, device).to(device)

# 打印模型结构
print(model)

# 定义训练函数
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        src, trg = batch
        src = src.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        output_dim = output.shape[-1]
        
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

# 定义测试函数
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src, trg = batch
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg, 0) # 关闭teacher forcing

            output_dim = output.shape[-1]

            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

# 示例数据
train_data = [(torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])),
              (torch.tensor([4, 5, 6]), torch.tensor([6, 5, 4])),
              (torch.tensor([7, 8, 9]), torch.tensor([9, 8, 7]))]

# 超参数
BATCH_SIZE = 3
N_EPOCHS = 10
LEARNING_RATE = 0.001
CLIP = 1

# 数据集与迭代器
train_dataset = SimpleDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# 定义损失函数与优化器
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# 训练模型
for epoch in range(N_EPOCHS):
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1482453.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

等概率事件算法

1等概率的生成(0-8)范围内的正整数 // Math.random 数据范围[0,1) 且 是 等概率的产生随机数 // 应用&#xff1a; // 1.生成等概率的整数&#xff08;等概率的生成(0-8)范围内的正整数 int value (int) (Math.random() * 9); System.out.println("value "…

YOLOv9有效提点|加入BAM、CloFormer、Reversible Column Networks、Lskblock等几十种注意力机制(二)

专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;主力高效涨点&#xff01;&#xff01;&#xff01; 一、本文介绍 本文只有代码及注意力模块简介&#xff0c;YOLOv9中的添加教程&#xff1a;可以看这篇文章。 YOLOv9有效提点|加入SE、CBAM、ECA、SimA…

逆变器专题(13)-逆变器LCL滤波器设计

相应仿真原件请移步资源下载 随着新能源的蓬勃发展&#xff0c;逆变器作为光伏储能的核心电力电子器件之一&#xff0c;其也得到了大力发展&#xff1b;传统的L型以及LC型滤波器因其体积较大、滤波效果差等原因已经被逐渐替代&#xff0c; LCL滤波器因其具有较小的体积&#xf…

JavaWeb HTTP 请求头、请求体、响应头、响应体、响应状态码

J2EE&#xff08;Java 2 Platform Enterprise Edition&#xff09;是指“Java 2企业版”&#xff0c;B/S模式开发Web应用就是J2EE最核心的功能。 Web是全球广域网&#xff0c;也称为万维网(www)&#xff0c;能够通过浏览器访问的网站。 在日常的生活中&#xff0c;经常会使用…

通过GitHub探索Python爬虫技术

1.检索爬取内容案例。 2.找到最近更新的。(最新一般都可以直接运行) 3.选择适合自己的项目&#xff0c;目前测试下面画红圈的是可行的。 4.方便大家查看就把代码粘贴出来了。 #图中画圈一代码 import requests import os import rewhile True:music_id input("请输入歌曲…

编译链接实战(22)C/C++代码覆盖率统计报告生成

文章目录 GCOV 工具简介gcov 使用lcov相关编译选项 GCOV 工具简介 gcov是一个测试代码覆盖率的工具&#xff0c;它是 gcc 自带的查看代码覆盖率的工具。 与GCC结合使用&#xff0c;可以分析您的程序以帮助创建更高效、运行更快的代码&#xff0c;并发现程序中未经测试的部分。…

Linux 之压缩与解压相关命令的基础用法

目录 1、zip 与 unzip 2、gzip 命令 3、tar 命令 1、zip 与 unzip 在桌面新建一个文件和文件夹用于测试 在 test 目录下有一个 1.txt 文件 我们使用 zip 命令对其压缩 用法&#xff1a; zip 自定义压缩包名 被压缩文件路径位置 zip myon.zip 1.txt 因为我们这里就是在 …

数据库管理-第157期 Oracle Vector DB AI-08(20240301)

数据库管理157期 2024-03-01 数据库管理-第157期 Oracle Vector DB & AI-08&#xff08;20240301&#xff09;1 创建示例向量2 查找最近向量3 基于向量簇组的最近向量查询总结 数据库管理-第157期 Oracle Vector DB & AI-08&#xff08;20240301&#xff09; 作者&…

多层感知器(神经网络)与激活函数

单个神经元&#xff08;二分类&#xff09; 多个神经元&#xff08;多分类&#xff09; 多层感知器 多层感知器&#xff0c;他是一种深度学习模型&#xff0c;通过多层神经元的连接和激活来解决非线性问题。 激活函数 激活函数的种类包括relu&#xff0c;sigmoid和tanh等 …

C 嵌入式系统设计模式 15:基本并发概念

本书的原著为&#xff1a;《Design Patterns for Embedded Systems in C ——An Embedded Software Engineering Toolkit 》&#xff0c;讲解的是嵌入式系统设计模式&#xff0c;是一本不可多得的好书。 本系列描述我对书中内容的理解。本文章描述嵌入式并发和资源管理模式之一…

GEE数据集——GLC_FCS30D - 全球 30 米土地覆被变化数据集(1985-2022 年)

GLC_FCS30D - 全球 30 米土地覆被变化数据集&#xff08;1985-2022 年&#xff09; 注 本数据集是正在提交的论文的一部分&#xff0c;因此没有引用和 DOI 信息。请在使用本数据集时注意这一点。 GLC_FCS30D 数据集是全球土地覆被监测领域的一项开创性进展&#xff0c;它以 30…

青少年CTF2024 #Round1 wp web

web EasyMD5 MD5碰撞&#xff0c;使用工具fastcoll生成内容不同但md5值相同的两个pdf文件上传即可获得flag&#xff1b; ./fastcoll_v1.0.0.5.exe -p 1.pdf -o 2.pdf 3.pdf # -p指定任意源文件&#xff0c;-o指定生成两个内容不同但md5值相同的目标文件 工具下载&#x…

【MySQL】复合查询(重点)-- 详解

一、基本查询练习回顾 1、查询工资高于 500 或岗位为 MANAGER 的雇员&#xff0c;同时还要满足他们的姓名首字母为大写的 J 2、按照部门号升序而雇员的工资降序排序 3、使用年薪进行降序排序 4、显示工资最高的员工的名字和工作岗位 5、显示工资高于平均工资的员工信息 6、显…

【AIGC】微笑的秘密花园:红玫瑰与少女的美好相遇

在这个迷人的画面中&#xff0c;我们目睹了一个迷人的时刻&#xff0c;女子则拥有一头柔顺亮丽的秀发&#xff0c;明亮的眼睛如同星河般璀璨&#xff0c;优雅而灵动&#xff0c;她的微笑如春日暖阳&#xff0c;温暖而又迷人。站在红玫瑰花瓣的惊人洪水中。 在一片湛蓝无云的晴…

100M服务器能同时容纳多少人访问

100M服务器的并发容纳人数会受到多种因素的影响&#xff0c;这些因素包括单个用户的平均访问流量大小、每个用户的平均访问页面数、并发用户比例、服务器和网络的流量利用率以及服务器自身的处理能力。 点击以下任一云产品链接&#xff0c;跳转后登录&#xff0c;自动享有所有…

精品SSM的选修课管理系统选课签到打卡

《[含文档PPT源码等]精品基于SSM的选修课管理系统设计与实现[包运行成功]》该项目含有源码、文档、PPT、配套开发软件、软件安装教程、项目发布教程、包运行成功&#xff01; 软件开发环境及开发工具&#xff1a; Java——涉及技术&#xff1a; 前端使用技术&#xff1a;HTM…

leetcode--接雨水(双指针法,动态规划,单调栈)

目录 方法一&#xff1a;双指针法 方法二&#xff1a;动态规划 方法三&#xff1a;单调栈 42. 接雨水 - 力扣&#xff08;LeetCode&#xff09; 黑色的是柱子&#xff0c;蓝色的是雨水&#xff0c;我们先来观察一下雨水的分布情况: 雨水落在凹槽之间&#xff0c;在一个凹槽的…

Laravel Octane 和 Swoole 协程的使用分析二

又仔细研究了下 Octane 源码和 Swoole 的文档&#xff0c;关于前几天 Laravel Octane 和 Swoole 协程的使用分析中的猜想&#xff0c;得到进一步验证&#xff1a; Swoole 的 HTTP Server 启动后会创建一个 master 进程和一个 manager 进程&#xff1b;master 进程又会创建多个…

持安科技孙维伯:零信任在攻防演练下的最佳实践|DISCConf 2023

近日&#xff0c;在2023数字身份安全技术大会上&#xff0c;持安科技联合创始人孙维伯应主办方的特别邀请&#xff0c;发表了主题为“零信任在攻防演练下的最佳实践”的演讲。 孙维伯在2023数字身份安全技术大会上发表演讲 以下为本次演讲实录&#xff1a; 我是持安科技的联合…

Leetcode 第 386 场周赛题解

Leetcode 第 386 场周赛题解 Leetcode 第 386 场周赛题解题目1&#xff1a;3046. 分割数组思路代码复杂度分析 题目2&#xff1a;3047. 求交集区域内的最大正方形面积思路代码复杂度分析 题目3&#xff1a;3048. 标记所有下标的最早秒数 I思路代码复杂度分析 题目4&#xff1a;…