基于Python的自然语言处理系列(13):TorchText + GRU + 上下文向量 + Teacher Forcing

news2025/1/22 16:11:23

        在上一篇文章中,我们使用了LSTM来构建一个序列到序列模型(seq2seq)。虽然LSTM表现良好,但我们想看看能否通过使用门控循环单元(GRU)并改进信息压缩的方式来提升模型性能。GRU和LSTM在很多场景下表现相似,但GRU更轻量,因此我们这次将尝试使用GRU,并改进解码器中上下文向量的使用方式。

1. 序列到序列模型中的信息压缩问题

        在前面的模型中,编码器会将整个输入序列压缩为一个上下文向量,并将其传递给解码器。解码器需要利用该向量生成整个输出序列。这种信息压缩在某些情况下会导致性能问题,因为模型必须将所有源序列的信息压缩到一个向量中,这对长序列的处理可能不够充分。

        为了解决这个问题,我们在解码过程中不仅仅依赖隐状态,还结合上下文向量的重用,让它在每个解码步骤都可用,从而减轻信息压缩的负担。

2. GRU:轻量级的循环神经网络

        GRU是LSTM的变体,但相对简单且更高效。研究表明,GRU和LSTM的表现非常接近        【Research】,但GRU结构更简单,因此在很多任务中更具优势。

        GRU的隐状态公式如下:

        与LSTM不同,GRU没有单独的神经元状态(cell state),因此隐状态直接携带信息。

3. 数据加载与预处理

        为了保持一致性,我们继续使用TorchText加载Multi30k数据集,并使用spacy进行标记化处理。加载数据的部分与前面的教程保持一致,因此可以直接参考。

from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer

SRC_LANGUAGE = 'en'
TRG_LANGUAGE = 'de'

train = Multi30k(split=('train'), language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))
token_transform = {}
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform[TRG_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')

        在词汇表构建和数据处理方面,我们继续使用与前面相同的流程,包括数值化和构建数据加载器。

4. 模型设计

Encoder

        我们将LSTM替换为GRU来构建编码器。与之前的LSTM不同,GRU没有细胞状态,因此只需要处理隐状态。此外,我们不使用dropout,因为我们这里只有单层GRU,dropout只在多层结构中有效。

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        return hidden

        在解码器中,我们对上下文向量进行重用。每一步解码不仅依赖前一时刻的隐状态,还结合上下文向量。这种设计减轻了信息压缩问题。我们将当前的词嵌入、上下文向量和隐状态拼接,送入GRU和全连接层进行预测。

        下图展示了我们改进后的解码器结构,其中上下文向量 zzz 在每个解码步骤中都被重复使用:

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.gru = nn.GRU(emb_dim + hid_dim, hid_dim)
        self.fc  = nn.Linear(emb_dim + hid_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_, hidden, context):
        input_ = input_.unsqueeze(0)
        embedded = self.dropout(self.embedding(input_))
        emb_con = torch.cat((embedded, context), dim=2)
        output, hidden = self.gru(emb_con, hidden)
        output = torch.cat((embedded.squeeze(0), hidden.squeeze(0), context.squeeze(0)), dim=1)
        prediction = self.fc(output)
        return prediction, hidden

Seq2Seq 模型

        我们将编码器和解码器组合在一起,构建一个完整的seq2seq模型。解码器每一步都接收当前token的嵌入、前一时刻的隐状态和上下文向量。

class Seq2SeqGRU(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        context = self.encoder(src)
        hidden = context
        input_ = trg[0,:]

        for t in range(1, trg_len):
            output, hidden = self.decoder(input_, hidden, context)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input_ = trg[t] if teacher_force else top1
        return outputs

5. 模型训练与评估

        模型训练与之前非常相似,我们继续使用Adam优化器和交叉熵损失函数。

import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# 训练函数和评估函数与前文相同
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    
    for i, (src, trg) in enumerate(iterator):
        src = src.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()
        
        # 模型前向传播
        output = model(src, trg)
        
        # trg = [trg_len, batch_size]
        # output = [trg_len, batch_size, output_dim]
        output_dim = output.shape[-1]
        
        # 将output和目标序列平展
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        
        # 计算损失
        loss = criterion(output, trg)
        
        # 反向传播和梯度裁剪
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        # 优化器更新参数
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for i, (src, trg) in enumerate(iterator):
            src = src.to(device)
            trg = trg.to(device)

            # 关闭Teacher Forcing
            output = model(src, trg, 0)

            # trg = [trg_len, batch_size]
            # output = [trg_len, batch_size, output_dim]
            output_dim = output.shape[-1]

            # 将output和目标序列平展
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            # 计算损失
            loss = criterion(output, trg)
            
            epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)

6. 测试与预测

        模型训练完成后,我们可以对新的数据进行预测。使用与之前相同的方法进行推理测试。

with torch.no_grad():
    output = model(src_text, trg_text, 0)  # 关闭Teacher Forcing
    output_max = output.argmax(1)
    for token in output_max:
        print(mapping[token.item()])

结语

        在这篇文章中,我们展示了如何通过使用GRU并重用上下文向量来改进序列到序列(seq2seq)模型。这种方法有效地减轻了信息压缩问题,使模型在处理长序列生成任务时表现更好。通过引入GRU,模型的计算效率得到了提升,且在生成过程中,每一步都能直接访问编码器生成的上下文向量,进一步提高了生成的准确性。

        然而,虽然这种方法在某些场景下表现出色,但它仍然无法完全解决长序列中信息损失的问题。为了解决这个问题,注意力机制(Attention Mechanism)应运而生。Attention可以让模型在每一步解码时,不仅仅依赖一个固定的上下文向量,而是动态地选择和关注输入序列的不同部分,从而更好地处理长序列依赖。

        在下一篇文章中,我们将深入探讨如何结合双向GRU(biGRU)与注意力机制(Attention Mechanism),并继续使用Teacher Forcing来训练模型,进一步提升序列生成的效果。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2148600.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows Server2016多用户登录破解

使用场景 很多时候&#xff0c;公司开发和测试运维会同时登录同一台windows服务器进行查询、更新、维护等操作&#xff0c;本文就来介绍一下Windows2016配置多人远程桌面登录实现&#xff0c;感兴趣的可以了解一下。 操作流程 &#xff08;1&#xff09;首先桌面需要安装远程…

etcd之etcd简介和安装(一)

1、etcd简介 1.1 etcd简介 etcd 是开源的、高可用的分布式key-value存储系统&#xff0c;可用于配置共享和服务的注册和发现&#xff0c;它专注于&#xff1a; 简单&#xff1a;定义清晰、面向用户的API&#xff08;gRPC&#xff09; 安全&#xff1a;可选的客户端TLS证书自…

uni-app功能 1. 实现点击置顶,滚动吸顶2.swiper一个轮播显示一个半内容且实现无缝滚动3.穿透修改uni-ui的样式

uni-app项目中遇到的功能 文章目录 uni-app项目中遇到的功能一、实现点击置顶&#xff0c;滚动吸顶、1.1、scroll-view设置不生效的原因和解决办法1.2 功能代码 二、swiper一个轮播显示一个半内容且实现无缝滚动三、穿透修改uni-ui的样式 一、实现点击置顶&#xff0c;滚动吸顶…

PMP--二模--解题--1-10

文章目录 4.整合管理--商业文件--商业论证&#xff08;是否值得所需投资、高管们决策的依据&#xff09;反映了&#xff1a;1、 [单选] 收到新项目的客户请求之后&#xff0c;项目经理首先应该做什么&#xff1f; 14.敏捷--角色--产品负责人PO–职责–1.创建待办列表并排序;2.确…

EmguCV学习笔记 VB.Net 12.3 OCR

版权声明&#xff1a;本文为博主原创文章&#xff0c;转载请在显著位置标明本文出处以及作者网名&#xff0c;未经作者允许不得用于商业目的。 EmguCV是一个基于OpenCV的开源免费的跨平台计算机视觉库,它向C#和VB.NET开发者提供了OpenCV库的大部分功能。 教程VB.net版本请访问…

MATLAB给一段数据加宽频噪声的方法(随机噪声+带通滤波器)

文章目录 引言方法概述完整代码:结果分析结论参考文献引言 在信号处理领域,添加噪声是模拟实际环境中信号传输时常见的操作。宽频噪声可以用于测试系统的鲁棒性和信号处理算法的有效性。本文将介绍如何使用 M A T L A B MATLAB MATLAB给一段数据添加宽频噪声,具体方法是结合…

漏洞挖掘 | Selenium Grid 中的 SSRF

Selenium 网格框架上的基本服务器端请求伪造 最近&#xff0c;我正在阅读漏洞文章看到Peter Jaric写的一篇 Selenium Grid 文章&#xff1b;他解释了 Selenium Grid 框架上缺乏身份验证和安全措施强化的问题。 在网上进行了更多搜索&#xff0c;我发现 Selenium Grid 开箱即用…

古诗词四首鉴赏

1、出自蓟北门行 唐李白 虏阵横北荒&#xff0c;胡星曜精芒。 羽书速惊电&#xff0c;烽火昼连光。 虎竹救边急&#xff0c;戎车森已行。 明主不安席&#xff0c;按剑心飞扬。 推毂出猛将&#xff0c;连旗登战场。 兵威冲绝漠&#xff0c;杀气凌穹苍。…

打开C嘎嘎的大门:你好,C嘎嘎!(1)

前言&#xff1a; 小编在学习完一些数据结构以后&#xff0c;终于&#xff0c;我还是来到了这一步&#xff0c;开始学习我小学就听说过的C&#xff0c;至于为什么标题写的C嘎嘎&#xff0c;因为小编觉着这样好念而且有意思&#xff0c;今天是小编学习C嘎嘎的第一天&#xff0c;…

零信任安全架构--最小权限原则

最小权限原则&#xff08;Principle of Least Privilege, PoLP&#xff09;是零信任安全架构中的核心理念之一&#xff0c;旨在确保用户、设备、应用等系统实体只拥有完成其任务所必需的最低权限&#xff0c;避免不必要的权限扩展&#xff0c;从而降低安全风险。 1. 概念 最小…

LabVIEW闪退

LabVIEW闪退或无法启动可能由多个原因引起&#xff0c;特别是在使用了一段时间后突然发生的问题。重启电脑后 LabVIEW 和所有 NI 软件都无法打开&#xff0c;甚至在卸载和重装时也没有反应。这种情况通常与系统环境、软件冲突或 NI 软件组件的损坏有关。 1. 检查系统和软件冲突…

Arthas dashboard(当前系统的实时数据面板)

文章目录 二、命令列表2.1 jvm相关命令2.1.1 dashboard&#xff08;当前系统的实时数据面板&#xff09; 二、命令列表 2.1 jvm相关命令 2.1.1 dashboard&#xff08;当前系统的实时数据面板&#xff09; 使用场景&#xff1a; 在 Arthas 中&#xff0c;dashboard 命令用于提…

echarts实现地图下钻并解决海南群岛显示缩略图

一、准备工作 1、echarts版本&#xff1a; ^5.5.1 2、去掉海南数据的json文件 二、获取删除过后的json文件 1、DataV.GeoAtlas地理小工具系列 (aliyun.com) 在网站输入这个复制的&#xff0c;新建一个json文件粘贴进去。 接下来需要删除两个地方&#xff0c;不要删错&…

前端vue-关于标签切换的实现

首先是循环&#xff0c;使用v-for“&#xff08;item,index) in list” :key“item.id” 然后当点击哪个的时候再切换&#xff0c;使用v-bind:class" "或者是:class" ",如果都是用active的话&#xff0c;那么每一个标签都是被选中的状态&#xff0c;…

[C高手编程] C语言宏、内置宏与预处理:深入理解与应用

&#x1f496;&#x1f496;⚡️⚡️专栏&#xff1a;C高手编程-面试宝典/技术手册/高手进阶⚡️⚡️&#x1f496;&#x1f496; 「C高手编程」专栏融合了作者十多年的C语言开发经验&#xff0c;汇集了从基础到进阶的关键知识点&#xff0c;是不可多得的知识宝典。如果你是即将…

TypeScript异常处理

1.异常的概念 程序运行中意外发生的情况就成为异常 例子&#xff1a; //除法运算function chu(num1:number,num2:number){if(num20){//throw 抛出异常throw new Error(除数不能为零)}let num:numbernum1/num2console.log(num) }//程序出现异常后会停止运行// 捕获异常try{ /…

《黑神话悟空》开发框架与战斗系统解析

本文主要围绕《黑神话悟空》的开发框架与战斗系统解析展开 主要内容 《黑神话悟空》采用的技术栈 《黑神话悟空》战斗系统的实现方式 四种攻击模式 连招系统的创建 如何实现高扩展性的战斗系统 包括角色属性系统、技能配置文件和逻辑节点的抽象等关键技术点 版权声明 本…

【他山之石】Humanize AI 简介

Humanize AI 简介 Humanize AI 官方首页截图 文章目录 Humanize AI 简介1 Humanize AI 是什么2 Humanize AI 能做什么3 Humanize AI 怎么用4 Humanize AI 怎么收费5 结论 1 Humanize AI 是什么 数字时代的当下&#xff0c;AI 人工智能已成为内容创作不可或缺的一部分。从生成文…

poi-tl的详细教程(动态表格、单元格合并)

前提了解poi-tl 链接: springboot整合poi-tl 创建word模板 实现效果 代码实现 ServerTableData import com.deepoove.poi.data.RowRenderData;import java.util.List;public class ServerTableData {/*** 携带表格中真实数据*/private List<RowRenderData> serverDat…

【Python常用模块】_PyMySQL模块详解

课 程 推 荐我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈虚 拟 环 境 搭 建 :👉👉 Python项目虚拟环境(超详细讲解) 👈👈PyQt5 系 列 教 程:👉👉 Python GUI(PyQt5)教程合集 👈👈…