自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)

news2024/11/18 5:45:19

自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)


作者:安静到无声 个人主页

作者简介:人工智能和硬件设计博士生、CSDN与阿里云开发者博客专家,多项比赛获奖者,发表SCI论文多篇。

Thanks♪(・ω・)ノ 如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦! o( ̄▽ ̄)d

欢迎大家来到安静到无声的 《基于pytorch的自然语言处理入门与实践》,如果对所写内容感兴趣请看《基于pytorch的自然语言处理入门与实践》系列讲解 - 总目录,同时这也可以作为大家学习的参考。欢迎订阅,请多多支持!

目录标题

  • 自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)
  • 1. Scheduled Sampling(计划采样)
    • 1.1 概念解释
    • 1.2 代码实现
  • 2. Teacher forcing(教师强制)
    • 2.1 概念解释
    • 2.1 代码实现
  • 参考

1. Scheduled Sampling(计划采样)

1.1 概念解释

Scheduled Sampling是一种用于训练序列生成模型的策略,旨在缓解曝光偏差(Exposure Bias)问题。曝光偏差是指模型在训练时接触到的数据分布与测试时的数据分布不一致,导致性能下降。

在Scheduled Sampling中,模型在每个时间步骤都有一定的概率选择使用真实目标序列中的单词作为输入,而不是使用前一个时间步骤生成的单词。这样可以使模型更好地适应真实数据分布,减少曝光偏差问题。

具体来说,Scheduled Sampling使用以下公式计算每个时间步骤生成当前单词的概率:

P ( y t ∣ y 1 , . . . , y t − 1 ) = ( 1 − ϵ ) ∗ P model ( y t ∣ y 1 , . . . , y t − 1 ) + ϵ ∗ P data ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) = (1 - \epsilon) * P_{\text{model}}(y_t|y_1, ..., y_{t-1}) + \epsilon * P_{\text{data}}(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)=(1ϵ)Pmodel(yty1,...,yt1)+ϵPdata(yty1,...,yt1)其中, P ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)表示在给定前面的生成序列条件下生成当前单词 y t y_t yt的概率, P model ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{model}}(y_t|y_1, ..., y_{t-1}) Pmodel(yty1,...,yt1)表示模型生成该单词的概率, P data ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{data}}(y_t|y_1, ..., y_{t-1}) Pdata(yty1,...,yt1)表示真实目标序列中该单词的概率。参数 ϵ \epsilon ϵ用于控制采样策略,可以随着训练的进行而逐渐增加。

1.2 代码实现

下面是一个使用Python实现Scheduled Sampling的示例代码:

在这里插入图片描述

其中, P ( y t ∣ y < t , x ) P(y_t | y_{<t}, x) P(yty<t,x)表示在给定前文和输入的条件下,生成当前时间步的输出的概率。 P model P_{\text{model}} Pmodel表示由模型生成的概率分布, P prev P_{\text{prev}} Pprev表示根据上一个时间步的真实输出计算得到的概率分布。sample是从均匀分布中采样得到的一个随机数,threshold是一个控制Scheduled Sampling引入程度的超参数。

2. Teacher forcing(教师强制)

2.1 概念解释

Teacher forcing(教师强制)是一种在序列生成模型中使用的训练技术。具体来说,当使用RNN(循环神经网络)或类似架构的模型进行序列生成时,每个时间步都会根据前一个时间步的输入和隐藏状态生成输出。在训练期间,如果使用teacher forcing,那么每个时间步的输入将是真实的目标序列(而不是模型自身生成的序列)。这意味着模型在每个时间步都能够观察到正确的答案,从而更容易地学习到正确的模式和规律。

2.1 代码实现

import torch
import torch.nn as nn

# 定义序列到序列模型
class Seq2SeqModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Seq2SeqModel, self).__init__()
        self.hidden_dim = hidden_dim
        
        # 定义编码器
        self.encoder = nn.RNN(input_dim, hidden_dim)
        
        # 定义解码器
        self.decoder = nn.RNN(output_dim, hidden_dim)
        
        # 定义全连接层,将解码器的输出映射为目标序列
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, input_seq, target_seq):
        # 编码器计算输入序列的隐藏状态
        _, hidden_state = self.encoder(input_seq)
        
        # 解码器初始化隐藏状态
        decoder_hidden_state = hidden_state
        
        # 用真实目标序列作为输入来指导解码器的生成过程
        decoder_outputs, _ = self.decoder(target_seq, decoder_hidden_state)
        
        # 对解码器的输出应用全连接层进行映射
        output_seq = self.fc(decoder_outputs)
        
        return output_seq

# 创建模型实例
input_dim = 10
hidden_dim = 20
output_dim = 10
model = Seq2SeqModel(input_dim, hidden_dim, output_dim)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    # 步骤1:将模型设为训练模式
    model.train()
    
    # 步骤2:清零梯度
    optimizer.zero_grad()
    
    # 步骤3:前向传播
    input_seq = torch.randn(5, 3, input_dim)  # 输入序列
    target_seq = torch.randn(5, 3, output_dim)  # 目标序列
    output_seq = model(input_seq, target_seq)
    
    # 步骤4:计算损失
    loss = criterion(output_seq, target_seq)
    
    # 步骤5:反向传播和优化
    loss.backward()
    optimizer.step()
    
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

上述代码中,我们定义了一个简单的序列到序列模型Seq2SeqModel,其中包括一个RNN编码器、一个RNN解码器和一个全连接层。在forward方法中,我们首先使用编码器计算输入序列的隐藏状态,然后将隐藏状态作为解码器的初始隐藏状态。接下来,我们使用真实目标序列来指导解码器的生成过程,并将解码器的输出映射为目标序列。在训练阶段,我们使用真实目标序列作为输入来指导模型的生成过程。最后,我们定义了损失函数和优化器,并进行训练。

需要注意的是,在实际应用中,模型的推理阶段并不会使用真实目标序列来指导生成过程。在推理阶段,可以将前一个时间步的模型输出作为下一个时间步的输入,从而进行序列的自我生成。

--------推荐专栏--------
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践

参考

Scheduled Sampling的搜索结果_百度图片搜索 (baidu.com)
Teacher forcing RNN的搜索结果_百度图片搜索 (baidu.com)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/756699.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C/C++动态内存开辟(详解)

目录 一&#xff0c;mallloc 函数参数&#xff1a; 函数原理&#xff1a; 二&#xff0c;calloc 函数参数&#xff1a; 函数原理&#xff1a; 三&#xff0c;realloc 函数参数&#xff1a; 函数原理: 五&#xff0c;小结 2&#xff09;对开辟空间的越界访问 3&#x…

cnn分类图像cifar10

使用CNN模型来分类图像&#xff0c;数据集采用的cifar10&#xff0c;cifar10共有6万张&#xff0c;这些图像共分为10类。 命名的格式大概是这样的&#xff1a;0_19761.jpg&#xff0c;它的第一个数字表示的就是图像所属的类&#xff0c;分成清楚的就知道了&#xff0c;第0类就是…

Flutter:EasyLoading(loading加载、消息提示)

前言 官方虽然提供了内置的加载指示器和提示信息&#xff0c;但是功能比较简陋&#xff0c;这里推荐&#xff1a;flutter_easyloading CircularProgressIndicator CircularProgressIndicator()加粗样式 ScaffoldMessenger.of(context).showSnackBar(const SnackBar(// 提示…

MySQL(三)SQL优化、Buffer pool、Change buffer

MySQL系列文章 MySQL&#xff08;一&#xff09;基本架构、SQL语句操作、试图 MySQL&#xff08;二&#xff09;索引原理以及优化 MySQL&#xff08;三&#xff09;SQL优化、Buffer pool、Change buffer MySQL&#xff08;四&#xff09;事务原理及分析 MySQL&#xff08;五&a…

泛积木-低代码 搭建 增删改查

文章首发于 增删改查 。 这里我们以增删改查作为示例&#xff0c;演示下从页面创建到各个功能齐全。创建页面的时候&#xff0c;建议接口先写好&#xff0c;当然也可以一边联调一边写接口&#xff0c;当前对增删改查提供以下测试接口&#xff1a; 测试接口 /contactsList 列…

【数据结构】非线性结构之树结构(含堆)

前言 前面的三篇文章已经将线性结构讲述完毕了&#xff0c;下面的文章将会为大家将讲点新东西&#xff1a;非线性结构中的树结构。萌新对这里的知识点相对陌生&#xff0c;建议反复观看&#xff01;&#xff01; 关于线性结构的三篇文章放在下面&#xff1a; 线性表之顺序表 线…

数组与指针

博客内容&#xff1a;数组与指针 文章目录 一、 数组&#xff1f;指针&#xff1f;1.区别与联系大小赋值存储位置 二、指针数组、数组指针&#xff1f;二维数组和二级指针&数组名与数组的区别总结 一、 数组&#xff1f;指针&#xff1f; 数组 相同类型数据的集合 指针 指…

谷歌Bard更新:支持中文提问和语音朗读

ChatGPT不断更新功能&#xff0c;从GPT-3到3.5&#xff0c;再到GPT-4&#xff0c;甚至最新的plus版已经支持图像处理和图表生成&#xff0c;而谷歌Bard却自从推出后就一直很安静&#xff0c;没有什么大动作。眼见被ChatGPT、Claude甚至是文心一言抢去了风头&#xff0c;自然心有…

springcache的使用(小白也看得懂)

简介 SpringCache整合Redis可以使用Spring提供的Cacheable注解来实现对Redis的缓存操作。使用这种方式可以轻松地在应用程序中启用缓存&#xff0c;并且不需要手动编写访问Redis的代码。在配置文件中需要配置Redis的连接信息以及缓存管理器。使用这种方式可以做到轻松配置&…

C++报错:二进制“心<“没有找到接受“std:string“类型的右操作数的运算符(或没有可接受的转换)

1、问题&#xff1a;在进行二维数组的相关计算时报错&#xff1a; 二进制"心<"没有找到接受"std:string"类型的右操作数的运算符(或没有可接受的转换) 2、原因&#xff1a;没有加入头文件——String; 3、解决办法&#xff1a;加上头文件——String; 4、…

GNN学习笔记:A Gentle Introduction to Graph Neural Networks

原文地址&#xff1a; https://distill.pub/2021/gnn-intro/ 不同形式来源的图 Images as graphs 论文中提到将图像建模为一张拓扑图的方法是将图像的每一个像素看作图的一个结点&#xff0c;并将单个像素结点与其相邻的所有像素之间建立一条边。 每一个非边缘的像素结点具…

Linux下做性能分析4:怎么开始

战地分析 性能分析常常是一种战地分析&#xff0c;所以&#xff0c;在我们可以端起咖啡慢慢想怎么进行分析之前&#xff0c;我们要先说说我们在战地上的套路。 战地分析是说在实用环境中发现问题&#xff0c;我们真正需要进行性能分析的场合&#xff0c;通常都没有机会让你反…

LeetCode: 18. 四数之和 | 双指针专题

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

Java中的几种关键字this、super、static和final介绍

Java中的几种关键字this、super、static和final介绍 在Java编程语言中&#xff0c;关键字是具有特殊含义的预定义标识符。关键字是Java编程语言中具有特殊用途的保留单词&#xff0c;用于表示语法结构和程序行为。关键字在语法上具有特定的用途&#xff0c;不能用作变量名、方…

HTTP1.1、HTTPS、HTTP2.0 、HTTP3.0

HTTP1.1 优点&#xff1a; 整体方面&#xff1a;简单、灵活和易于扩展、应用广泛和跨平台 性能方面&#xff1a;长连接、管道网络传输解决请求队头阻塞&#xff08;没有使用&#xff09; 缺点&#xff1a; 安全方面&#xff1a;无状态、明文窃听、伪装、篡改 性能方面&am…

进程间通信之匿名管道

进程间通信—管道 一、进程间通信介绍二、管道1.匿名管道1.1父进程和一个子进程之间的通信1.2父进程和多个子进程之间的通信 一、进程间通信介绍 1.进程间为什么要进行通信&#xff1f; 进程间通信的是为了协调不同的进程&#xff0c;使之能在一个操作系统里同时运行&#xff…

代码随想录day4 | 24. 两两交换链表中的节点 19.删除链表的倒数第N个节点 02.07.链表相交 142.环形链表II

文章目录 一、两两交换链表中的节点二、删除链表的倒数第N个节点三、链表相交四、环形链表 24. 两两交换链表中的节点 19.删除链表的倒数第N个节点 面试题 02.07. 链表相交 142.环形链表II 一、两两交换链表中的节点 两两交换链表中的节点 注意是两两交换&#xff0c;采用虚拟…

Global symbol “%data“ requires explicit package name

Global symbol “%data” requires explicit package name 如图编写demo的时候出现了如图的问题&#xff0c;在网上查找到的原因是&#xff1a; 一&#xff0c;使用use strict; &#xff0c;修改其他代码&#xff0c;如下&#xff1a; 1&#xff0c;首先&#xff0c;检查你是不…

静态库和动态库的区别与优缺点

文章目录 静态库与动态库的区别动态库与静态库的优缺点 静态库与动态库的区别 静态库直接打包链接到可执行程序 动态库将不会链接到可执行文件 &#xff0c;执行文件运行时需要动态加载 动态库 &#xff0c;所以需要提前知道动态库路径&#xff0c;需要将路径保存到环境变量或…