门控循环单元GRU

news2025/1/21 1:02:37

目录

  • 一、GRU提出的背景:
    • 1.RNN存在的问题:
    • 2.GRU的思想:
  • 二、更新门和重置门:
  • 三、GRU网络架构:
    • 1.更新门和重置门如何发挥作用:
      • 1.1候选隐藏状态H~t:
      • 1.2隐藏状态Ht:
    • 2.GRU:
  • 四、训练过程举例******:
  • 五、预测过程举例******:
  • 六、底层源码:
  • 七、Pytorch版代码:

一、GRU提出的背景:

1.RNN存在的问题:

循环神经网络讲解文章

由于RNN的隐藏状态ht用于记录每个句子之前的所有序列信息,而对于长序列问题来说ht会记录太多序列信息导致序列时序特征区分度很差(最前面的序列特征因为进行了太多轮迭代往往不太好从ht中提取),并且RNN默认当前时间步的token单词和该句子的隐藏状态ht中所有序列信息都有同等的相关度,因此一些比较靠前但与当前时间步输入的token相关性高的序列特征在ht中可能就不太被重视,而一些比较靠后但与当前时间步输入的token相关性低的序列特征在ht中被过于关注。

2.GRU的思想:

GRU的提出就是为了解决RNN默认句子中所有token之间的相关性相等问题。
GRU的思想是对于每个时间步的输入token,使用门的控制将隐藏状态ht中与当前token相关性高的序列信息拿来参与计算,而ht中与当前token相关性低的序列信息作为噪音不参与计算。

  • 对于需要关注的序列信息,使用更新门来提高关注度
  • 对于需要遗忘的序列信息,使用遗忘门来降低关注度

二、更新门和重置门:

GRU提出更新门和重置门的思想来改变隐藏状态ht中不同序列信息的关注度。
在这里插入图片描述
更新门和重置门可以分别看做一个全连接层的隐藏层,这样的话上图就等价于两个并排的隐藏层,其中:

  • 每个隐藏层都接收之前时间步的隐藏状态Ht-1和当前时间步的输入token或token集合(batch_size>1)。
  • 更新门和重置门有各自的可学习权重参数和偏置值,公式含义类似传统RNN。
  • Rt 和 Zt 都是根据过去的隐藏状态 Ht-1 和当前输入 Xt 计算得到的 [0,1] 之间的量(激活函数)。

三、GRU网络架构:

1.更新门和重置门如何发挥作用:

重置门对过去t个时间步的序列信息(Ht-1)进行选择,更新门对当前一个时间步的序列信息(Xt)进行选择。具体原理如下:

1.1候选隐藏状态H~t:

候选隐藏状态既保留了之前的隐藏状态Ht-1,又保留了当前一个时间步的序列信息Xt。
在这里插入图片描述
因为Rt是一个[0,1] 之间的量,所以Rt×Ht-1是对之前的隐藏状态Ht-1进行一次选择:Rt 在某个位置的值越趋近于0,则表示Ht-1这个位置的序列信息越倾向于被丢弃,反之保留。

综上,重置门的作用是对过去的序列信息Ht-1进行选择,Ht-1中哪些序列信息对H~T是有用的,应该被保存下来,而哪些序列信息是不重要的,应该被遗忘。

1.2隐藏状态Ht:

在这里插入图片描述
因为Zt是一个[0,1] 之间的量,如果Zt全为0,则当前隐藏状态Ht为当前候选隐藏状态,该候选隐藏状态不仅保留了之前的序列信息,还保留了当前时间步batch的序列信息;如果Zt全为1,则当前隐藏状态Ht为上一个时间步的隐藏状态。

综上,更新门的作用是决定当前一个时间步的序列信息是否保留,如果Zt全为0,则说明当前时间步token的序列信息是有用的(候选隐藏状态包含之前的序列信息和当前一个时间步的序列信息),保留下来加入到隐藏状态Ht中;如果Zt全为1,则说明当前时间步batch的序列信息是没有用的,丢弃当前token的序列信息,直接使用上一个时间步的隐藏状态Ht-1作为当前的隐藏状态Ht。(Ht-1仅包含之前的序列信息,不包含当前一个时间步的序列信息)

2.GRU:

GRU网络架构如下,可以看做是三个隐藏层并排的架构。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

四、训练过程举例******:

以下文预测问题为例,一次epoch训练过程如下。
1.对整个文本进行数据预处理,获得数据字典,这里假设字典中有vocab_size条字典序,这样就转换成了一个vocab_size分类的序列问题。
2.将每个单词token值使用独热编码转换成1×vocab_size的一维向量,作为特征,表示各分类上的概率。
3.每轮epoch输入格式为batch_num×batch_size×num_steps×vocab_size,其中batch_num表示该轮压迫训练多少个batch,batch_size表示每个batch中有多少个句子序列,每个句子有num_steps个单词token,即该batch要训练多少个时间步,即循环time_step次传统神经网络,每个单词为一个一维向量,表示在字典序上的概率。每次训练一个batch,每个时间步t使用该batch中所有batch_size个序列的第t个token集合Xt进行训练(num_steps=t的token),batch尺寸为batch_size×num_steps×vocab_size,Xt尺寸为batch_size×vocab_size
4.隐藏层参数Whh维度为num_hiddens×num_hiddens,表示隐藏层关于序列信息(隐藏状态)的权重矩阵;Whx维度为vocab_size×num_hiddens,表示隐藏层关于输入特征的权重矩阵;参数bh维度为1×num_hiddens
5.三个并行隐藏层各自的参数Whh、Whz、Whr维度计算为num_hiddens×num_hiddens,表示隐藏层关于序列信息(隐藏状态)的权重矩阵;三个并行隐藏层各自的参数Wxh、Wxz、Wxr维度计算为vocab_size×num_hiddens,表示隐藏层关于输入特征的权重矩阵;参数bh、bz、br维度计算为1×num_hiddens。这里由于三个隐藏层输出维度相同,所以隐藏内的神经元数目都是相同的=num_hiddens。
6.对于第一个batch,训练过程如下:
6.1.初始化0时刻序列信息(隐藏层输出,隐藏状态)h0,尺寸为(batch_size,神经元个数num_hiddens)
6.2.t1时间步num_steps=1,取该batch所有序列样本的第一个token组成x0,尺寸batch_size×vocab_size,每个vocab一维向量并行放入神经网络学习,首先x0中每个token和ho同时进入更新门隐藏层和重置门隐藏层,重置门隐藏层输出R1=sigmoid(Whr×h0+Wxr×x0+br)、更新门隐藏层输出Z1=sigmoid(Whz×h0+Wxz×x0+bz),两个隐藏层分别用来筛选过去和当前的序列信息,输出维度均为batch_size×num_hiddens。
6.3.重置门输出R1、隐藏状态h0和x0中每个token进入候选隐藏状态隐藏层,使用重置门对过去的序列信息进行筛选,计算出候选隐藏状态H~1。
6.4.更新门输出Z1、隐藏状态h0和候选隐藏状态H~1联合计算,使用更新门对当前的序列信息进行筛选,计算出当前时间步的隐藏状态h1,隐藏层输出维度batch_size×num_hiddens,h1作为t1时间步的输出层输入、t2时间步的隐藏层输入序列信息(隐藏状态)。
6.5.此时两个操作并行执行:t1时间步的输出层计算、t2时间步的隐藏层计算。
6.5.1首先h1作为t1时间步的输出层输入,输出层有vocab_size个神经元,会执行多分类预测,可学习参数为Woh(num_hiddens×vocab_size)和bo(1×vocab_size),每个token输出维度1×vocab_size,输出层输出维度batch_size×vocab_size,表示各个token在各个分类上的预测。
6.5.2其次,t2时间步num_steps=2,取batch中num_steps=2的token集合为x1,维度为batch_size×vocab_size,并行将每个token一维向量放入神经网络学习,隐藏层输出h2=sigmoid(Whh×h1+Whx×x1+bh),每个token输出维度1×num_hiddens,隐藏层输出维度batch_size×num_hiddens,h2作为t2时间步的输出层输入、t3时间步的隐藏层输入序列信息。
6.6.如此反复每个时间步取一个数据点token集合进行训练,并更新隐藏层输出ht作为下一个时间步的输入,直到完成所有num_steps个时间步的训练任务,整个batch就训练完成了。
6.7.对于每个时间步上的预测batch_size×vocab_size,num_steps个时间步上总的预测为(num_steps×batch_size,vocab_size),这是该batch的训练总输出。
6.8.使用损失函数计算batch中各个句子中每个token的概率损失,并取均值。
6.9.反向传播算法计算各个参数关于损失函数的梯度。
6.10.梯度裁剪修改梯度。
6.11.梯度下降算法修改参数值。
7.该batch训练完成。进行下一个batch训练,初始化隐藏状态h0…。

五、预测过程举例******:

背景定义同训练过程,模型的预测过程如下。
1.输入prefix长度的前缀,来预测接下来num_preds个token。
2.首先还是将prefix转换成字典序并进行独热编码,尺寸为1×prefix×vocab_size,其中prefix=num_steps。
3.加载模型,初始化时序信息h0。
4.batch_size为1,在每个时间步上对句子长度每个token一维向量依次作为模型一个时间步的输入,输入维度1×vocab_size,总共计算prefix个时间步,循环计算prefix个时间步后的时序信息hp,hp尺寸为1×num_hiddens(batch_size=1)。
5.将prefix最后一个token和hp作为模型输入,来预测num_preds个token的第一个token,输出预测结果pred1和时序信息hp1,然后将pred1和hp1作为输入预测pred2和hp2(即使用预测值来预测下一个预测值),直到预测num_preds个预测值。(等价于batch=1,num_steps=num_preds的训练过程)
6.将预测值使用字典转为字符串输出。

六、底层源码:

代码中num_hiddens表示隐藏层神经元个数,由于重置门、更新门的输出维度相同,所以重置门和更新门两个隐藏层的神经元个数也是一样的=num_hiddens。

import torch
from torch import nn
from d2l import torch as d2l

# 数据预处理,获取datalodaer和字典
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# 初始化可学习参数
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal(
            (num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()
    W_xr, W_hr, b_r = three()
    W_xh, W_hh, b_h = three()
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

# 初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),)

# 定义门控循环单元模型
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

# 训练
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

七、Pytorch版代码:

num_inputs = vocab_size
# 调用pytorch构建网络结构
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1976446.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

当自回归遇到Diffusion

文章目录 Autoregressive Image Generation without Vector Quantization一. 简介1.1 摘要1.1 引言二.相关工作2.1 Sequence Models for Image Generation2.2 Diffusion for Representation Learning2.3 Diffusion for Policy Learning三.方法3.1 重新思考离散值的tokens3.2 Di…

Kotlin OpenCV 图像图像50 Haar 级联分类器模型

Kotlin OpenCV 图像图像50 Haar 级联分类器模型 1 OpenCV Haar 级联分类器模型2 Kotlin OpenCV Haar 测试代码 1 OpenCV Haar 级联分类器模型 Haar级联分类器是一种用于对象检测(如人脸检测)的机器学习算法。它由Paul Viola和Michael Jones在2001年提出…

conda环境pip 安装Tensorflow-gpu 2.10.2提示nbconvert 的包依赖冲突

问题如下: ERROR: pip’s dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. nbconvert 7.16.4 requires beautifulsoup4, which is not inst…

DETR论文详解

文章目录 前言一、DETR理论二、模型架构1. CNN2. Transformer3. FFN 三、损失函数四、代码实现总结 前言 DETR是Facebook团队在2020年提出的一篇论文,名字叫做《End-to-End Object Detection with Transformers》端到端的基于Transformers的目标检测,DET…

数仓入门:数据分析模型、数仓建模、离线实时数仓、Lambda、Kappa、湖仓一体

往期推荐 大数据HBase图文简介-CSDN博客 数仓分层ODS、DWD、DWM、DWS、DIM、DM、ADS-CSDN博客 数仓常见名词解析和名词之间的关系-CSDN博客 目录 0. 前言 0.1 浅谈维度建模 0.2 数据分析模型 1. 何为数据仓库 1.1 为什么不直接用业务平台的数据而要建设数仓? …

ChatGPT能代替网络作家吗?

最强AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频百万播放量https://aitools.jurilu.com/ 当然可以!只要你玩写作AI玩得6,甚至可以达到某些大神的水平! 看看大神、小白、AI输出内容的区…

重塑企业知识库:AI搜索的深度应用与变革

在数字化浪潮的推动下,企业知识库已成为企业智慧的核心载体。而AI搜索技术的融入,让海量信息瞬间变得井然有序,触手可及。它不仅革新了传统的搜索方式,更开启了企业知识管理的新纪元,引领着企业向更加智能化、高效化的…

【人工智能】FPGA实现人工智能算法硬件加速学习笔记

一. FPGA的优势 FPGA拥有高度的重配置性和并行处理能力,能够同时处理多个运算单元和多个数据并行操作。FPGA与卷积神经网络(CNN)的结合,有助于提升CNN的部署效率和性能。由于FPGA功耗很低的特性进一步增强了其吸引力。此外,FPGA可以根据具体算法需求量身打造硬件加速器。针对动…

[CR]厚云填补_SEGDNet

Structure-transferring edge-enhanced grid dehazing network Abstract 在过去的二十年里,图像去雾问题在计算机视觉界受到了极大的关注。在雾霾条件下,由于空气中水汽和粉尘颗粒的散射,图像的清晰度严重降低,使得许多计算机视觉…

鸿蒙媒体开发【基于AVCodec能力的视频编解码】音频和视频

基于AVCodec能力的视频编解码 介绍 本实例基于AVCodec能力,提供基于视频编解码的视频播放和录制的功能。 视频播放的主要流程是将视频文件通过解封装->解码->送显/播放。视频录制的主要流程是相机采集->编码->封装成mp4文件。 播放支持的原子能力规…

【从0到1进阶Redis】Jedis 操作 Redis

笔记内容来自B站博主《遇见狂神说》:Redis视频链接 Jedis 是一个用于 Java 的 Redis 客户端库,它提供了一组 API 用于与 Redis 数据库进行交互。Redis 是一个高性能的键值存储数据库,广泛用于缓存、消息队列等场景。Jedis 使得 Java 开发者能…

图欧科技-IMYAI智能助手24年5月~7月更新日志大汇总

上一篇推文盘点了我们图欧科技团队近一年来的更新日志,可以说是跟随着人工智能时代的发展,我们的IMYAI也丝毫不落后于这场时代的浪潮!近三个月以来,我们的更新频率直线上升,现在我们AI网站已经成为一个集GPT、Claude、…

《学会 SpringMVC 系列 · 消息转换器 MessageConverters》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

Inno Setup 安装界面、卸载界面+美化

Inno Setup Inno Setup用Delphi写成,其官方网站同时也提供源程序免费下载。它虽不能与Installshield这类恐龙级的安装制作软件相比,但也当之无愧算是后起之秀。Inno Setup是一个免费的安装制作软件,小巧、简便、精美是其最大特点,…

arduino程序—模拟输出(基础知识)

arduino程序—模拟输出(基础知识) 1-25 模拟输出1-analogWrite电路效果演示模拟输出analog output复合运算符示例程序Analogwrite() 1-26 模拟输出2-PWMPWM概念(极其重要) 1-27 模拟输出3-for电路效果演示程…

【Verilog-CBB】开发与验证(2)——单比特信号CDC同步器

引言 多时钟域的设计中,CDC处理的场景还是蛮多的。单比特信号在CDC时,为保证信号采样的安全性,降低亚稳态,必须要对信号做同步处理。CDC从时钟的快慢关系来说分为两种case:快到慢、慢到快。对于脉冲型的控制信号&…

『C++实战项目 负载均衡式在线OJ』一、项目介绍与效果展示(持续更新)

文章目录 一、项目介绍二、开发环境三、第三方库四、相关技术五、项目整体框架代码目录框架 代码仓库连接 点击这里✈ 一、项目介绍 本项目是实现一个仿 leetcode 的 OJ (Online-Judge)系统。更准确的说应该称之为leetcode 的裁剪版。因为本项目只实现了leetcode中…

‘#‘ is not followed by a macro parameter 关于宏定义的错误

今天在项目代码上想定义一个这样的宏,结果编译错误,这个宏定义类似这样的: #define DELETE_FILE_DPP(key) \ #ifdef PLATFORM_DPP \delete_file(&key); \ #endif 因为有平台之分需要用到编译宏,但不想每个调用的地方都写 #i…

HTML 专业词汇与语法规则

目录 1. 专业词汇 2. 语法规则 1. 专业词汇 标签&#xff08;tag&#xff09;&#xff1a;一堆尖叫号&#xff08;<>&#xff09;&#xff0c; 属性&#xff08;attribute&#xff09;&#xff1a;对标签特征设置的方式&#xff1b; 文本&#xff08;text&#xff0…

【外排序】--- 文件归并排序的实现

Welcome to 9ilks Code World (๑•́ ₃ •̀๑) 个人主页: 9ilk (๑•́ ₃ •̀๑) 文章专栏&#xff1a; 数据结构 我们之前学习的八大排序&#xff1a;冒泡&#xff0c;快排&#xff0c;插入&#xff0c;堆排等都是内排序&#xff0c;这些排序算法处理的都是…