【从零开始学习深度学习】35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍、Pytorch实现GRU并进行训练预测

news2024/12/27 12:13:35

在循环神经网络中,当时间步数较大或者时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。上一篇文章中介绍的裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减的问题。因此,循环神经网络在实际中较难捕捉时间序列中时间步距离较大的依赖关系。

为了更好地捕捉时间序列中时间步距离较大的依赖关系,从而提出了门控循环神经网络(gated recurrent neural network)。它可以通过学习的门来控制信息的流动。其中,门控循环单元(gated recurrent unit,GRU)是一种常用的门控循环神经网络 。

目录

    • 1. 门控循环单元设计
      • 1.1 重置门和更新门
      • 1.2 候选隐藏状态
      • 1.3 隐藏状态
    • 2 读取数据集
    • 3 从零实现门控循环单元并进行歌词训练与预测
      • 3.1 初始化模型参数
      • 3.2 定义模型
      • 3.3 训练模型并创作歌词
    • 4 基于Pytorch的nn.GRU模块实现GRU并进行歌词训练与预测
    • 总结

1. 门控循环单元设计

门控循环单元的设计在原始RNN的基础上引入了重置门(reset gate)和更新门(update gate)的概念,从而修改了循环神经网络中隐藏状态的计算方式。

1.1 重置门和更新门

如下图所示,门控循环单元中的重置门和更新门的输入均为当前时间步输入 X t \boldsymbol{X}_t Xt与上一时间步隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1,输出由激活函数为sigmoid函数的全连接层计算得到。

在这里插入图片描述

假设隐藏单元个数为 h h h,给定时间步 t t t的小批量输入 X t ∈ R n × d \boldsymbol{X}_t \in \mathbb{R}^{n \times d} XtRn×d(样本数为 n n n,输入个数为 d d d)和上一时间步隐藏状态 H t − 1 ∈ R n × h \boldsymbol{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h。重置门 R t ∈ R n × h \boldsymbol{R}_t \in \mathbb{R}^{n \times h} RtRn×h和更新门 Z t ∈ R n × h \boldsymbol{Z}_t \in \mathbb{R}^{n \times h} ZtRn×h的计算如下:

R t = σ ( X t W x r + H t − 1 W h r + b r ) , Z t = σ ( X t W x z + H t − 1 W h z + b z ) , \begin{aligned} \boldsymbol{R}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xr} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hr} + \boldsymbol{b}_r),\\ \boldsymbol{Z}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xz} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hz} + \boldsymbol{b}_z), \end{aligned} Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),

其中 W x r , W x z ∈ R d × h \boldsymbol{W}_{xr}, \boldsymbol{W}_{xz} \in \mathbb{R}^{d \times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h \boldsymbol{W}_{hr}, \boldsymbol{W}_{hz} \in \mathbb{R}^{h \times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h \boldsymbol{b}_r, \boldsymbol{b}_z \in \mathbb{R}^{1 \times h} br,bzR1×h是偏差参数。sigmoid函数可以将元素的值变换到0和1之间。因此,重置门 R t \boldsymbol{R}_t Rt和更新门 Z t \boldsymbol{Z}_t Zt中每个元素的值域都是 [ 0 , 1 ] [0, 1] [0,1]

1.2 候选隐藏状态

接下来,门控循环单元将计算候选隐藏状态来辅助稍后的隐藏状态计算。如下图所示,我们将当前时间步重置门的输出与上一时间步隐藏状态做按元素乘法(符号为 ⊙ \odot )。如果重置门中元素值接近0,那么意味着重置对应隐藏状态元素为0,即丢弃上一时间步的隐藏状态。如果元素值接近1,那么表示保留上一时间步的隐藏状态。然后,将按元素乘法的结果与当前时间步的输入连结,再通过含激活函数tanh的全连接层计算出候选隐藏状态,其所有元素的值域为 [ − 1 , 1 ] [-1, 1] [1,1]
在这里插入图片描述

具体来说,时间步 t t t的候选隐藏状态 H ~ t ∈ R n × h \tilde{\boldsymbol{H}}_t \in \mathbb{R}^{n \times h} H~tRn×h的计算为

H ~ t = tanh ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) , \tilde{\boldsymbol{H}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}_{xh} + \left(\boldsymbol{R}_t \odot \boldsymbol{H}_{t-1}\right) \boldsymbol{W}_{hh} + \boldsymbol{b}_h), H~t=tanh(XtWxh+(RtHt1)Whh+bh),

其中 W x h ∈ R d × h \boldsymbol{W}_{xh} \in \mathbb{R}^{d \times h} WxhRd×h W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是权重参数, b h ∈ R 1 × h \boldsymbol{b}_h \in \mathbb{R}^{1 \times h} bhR1×h是偏差参数。从上面这个公式可以看出,重置门控制了上一时间步的隐藏状态如何流入当前时间步的候选隐藏状态。而上一时间步的隐藏状态可能包含了时间序列截至上一时间步的全部历史信息。因此,重置门可以用来丢弃与预测无关的历史信息。

1.3 隐藏状态

最后,时间步 t t t的隐藏状态 H t ∈ R n × h \boldsymbol{H}_t \in \mathbb{R}^{n \times h} HtRn×h的计算使用当前时间步的更新门 Z t \boldsymbol{Z}_t Zt来对上一时间步的隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1和当前时间步的候选隐藏状态 H ~ t \tilde{\boldsymbol{H}}_t H~t做组合:

H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t . \boldsymbol{H}_t = \boldsymbol{Z}_t \odot \boldsymbol{H}_{t-1} + (1 - \boldsymbol{Z}_t) \odot \tilde{\boldsymbol{H}}_t. Ht=ZtHt1+(1Zt)H~t.

在这里插入图片描述

更新门可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,如上图所示。假设更新门在时间步 t ′ t' t t t t t ′ < t t' < t t<t)之间一直近似1。那么,在时间步 t ′ t' t t t t之间的输入信息几乎没有流入时间步 t t t的隐藏状态 H t \boldsymbol{H}_t Ht。实际上,这可以看作是较早时刻的隐藏状态 H t ′ − 1 \boldsymbol{H}_{t'-1} Ht1一直通过时间保存并传递至当前时间步 t t t。这个设计可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较大的依赖关系。

总结:

  • 重置门有助于捕捉时间序列里短期的依赖关系;
  • 更新门有助于捕捉时间序列里长期的依赖关系。

2 读取数据集

为了实现并展示门控循环单元,下面依然使用上一篇文章中的周杰伦歌词专辑数据集来训练模型作词。

数据集获取参见上一篇文章《【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例–使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】》。

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

3 从零实现门控循环单元并进行歌词训练与预测

3.1 初始化模型参数

对模型参数进行初始化,超参数num_hiddens定义了隐藏单元的个数。

num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)

def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)
    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))
    
    W_xz, W_hz, b_z = _three()  # 更新门参数
    W_xr, W_hr, b_r = _three()  # 重置门参数
    W_xh, W_hh, b_h = _three()  # 候选隐藏状态参数
    
    # 输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])

3.2 定义模型

定义隐藏状态初始化函数init_gru_state,它返回由一个形状为(批量大小, 隐藏单元个数)的值为0的Tensor组成的元组。

def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

下面根据门控循环单元的计算表达式定义模型。

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)
        R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)
        H_tilda = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H,)

3.3 训练模型并创作歌词

我们在训练模型时只使用相邻采样。设置好超参数后,我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。

num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

我们每过40个迭代周期便根据当前训练的模型创作一段歌词。

d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)

输出:

epoch 40, perplexity 152.550790, time 2.29 sec
 - 分开 我不不 你不我 你不我 你不我 你不我 你不我 你不我 你不我 你不我 你不我 你不我 你不我 你
 - 不分开 一哼我 我不不 你不了我 你不我 你不我 你不我 你不我 你不我 你不我 你不我 你不我 你不我 
epoch 80, perplexity 32.991306, time 2.28 sec
 - 分开 我想要这样的微笑在人人卷戏 爱不再再我 你的美美 你在完人  你在在人的溪边默默默默默著著我 娘子
 - 不分开 我不能再想 我不要再想 我不能再想 我不能再想 我不能再想 我不能再想 我不能再想 我不能再想 我
epoch 120, perplexity 6.238240, time 2.39 sec
 - 分开 我想就这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的肩膀 你 在我胸口睡著 一定个
 - 不分开 不知再觉 你是一个人演慢 一直风 三步三步步步四望 连成线背著背默默许下心愿 看远方的星如下听的见
epoch 160, perplexity 1.926641, time 2.64 sec
 - 分开 我不要再宣牵我对你 感感 让给我抬起你有 从杰去真医 你在过人 何都没有 说我该轻的证  从情着头
 - 不分开 不知再觉 你是心蒙 迷迷了中留的寻找 停堡里一只点芜 长满杂草的泥剩 不会骑扫二的胖女还 用拉丁文

4 基于Pytorch的nn.GRU模块实现GRU并进行歌词训练与预测

在PyTorch中我们直接调用nn模块中的GRU类即可。

lr = 1e-2 # 注意调整学习率
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                corpus_indices, idx_to_char, char_to_idx,
                                num_epochs, num_steps, lr, clipping_theta,
                                batch_size, pred_period, pred_len, prefixes)

输出:

epoch 40, perplexity 1.017262, time 0.87 sec
 - 分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让
 - 不分开始打呼 管家是一只会说法语举止优雅的猪 吸血前会念约翰福音做为弥补 拥有一双蓝色眼睛的凯萨琳公主 专
epoch 80, perplexity 1.015187, time 1.22 sec
 - 分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让
 - 不分开 它一定实现 娘子 娘子却依旧每日 折一枝杨柳 你在那里 在小村外的溪边河口默默等著我 娘子依旧每日
epoch 120, perplexity 1.013440, time 0.85 sec
 - 分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让
 - 不分开 陷入了危险边缘Baby  我的世界已狂风暴雨 Wu  爱情来的太快就像龙卷风 离不开暴风圈来不及逃
epoch 160, perplexity 1.910635, time 0.82 sec
 - 分开的话你甘会听 有教堂有城堡 每天忙碌地的寻找 到底什么我有多烦恼  没有你烦我有多烦恼  没有多烦恼
 - 不分开 别发抖 快给我抬起头 有话去对医药箱说 别怪我 别发抖 快给我抬起头 有话去对医药箱说 别怪我 别

总结

  • 门控循环神经网络可以更好地捕捉时间序列中时间步距离较大的依赖关系。
  • 门控循环单元引入了门的概念,从而修改了循环神经网络中隐藏状态的计算方式。它包括重置门、更新门、候选隐藏状态和隐藏状态。
  • 重置门有助于捕捉时间序列里短期的依赖关系。
  • 更新门有助于捕捉时间序列里长期的依赖关系。

如果文章内容对你有帮助,感谢点赞+关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/136100.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Elastic-Job分布式任务调度(1):概述

1 什么是任务调度 我们可以先思考一下下面业务场景的解决方案&#xff1a; 某电商系统需要在每天上午10点&#xff0c;下午3点&#xff0c;晚上8点发放一批优惠券。某银行系统需要在信用卡到期还款日的前三天进行短信提醒。某财务系统需要在每天凌晨0:10结算前一天的财务数据…

【自学Python】Linux安装Python

Linux安装Python Python下载 Python下载地址 https://www.python.org/ftp/python/3.7.4/Python-3.7.4.tar.xzPython下载 我们在 Linux 终端中&#xff0c;直接使用 wget 命令&#xff0c;下载 Linux 版 Python 的安装包&#xff0c;我们在终端输入以下命令&#xff1a; wg…

PAT乙级|1094 谷歌的招聘

题源https://pintia.cn/problem-sets/994805260223102976/exam/problems/1071785997033074688 提交1&#xff1a;一个用例没过 提交2&#xff1a;AC 错因&#xff1a;输出需为字符串&#xff0c;例如在 200236 中找 4 位素数&#xff0c;解是0023 关键&#xff1a;第33行代码…

linphone android sdk 源码下载编译

前言 前面的有写过Android 使用Linphone SDK开发SIP客户端相关的文章, 在后续的开发过程中, 为了更深入了解linphone, 便尝试下载SDK源码自行编译. 关于linphone这里不作过多介绍, 可以参考前面的文章. Linphone-SDK 是一个将 Liblinphone 及其依赖项捆绑为 git 子模块的项目&a…

HTC FOCUS3在PC端串流FOHEART H1数据手套(手柄)

本教程介绍使用FOHEART H1数据手套与HTC手柄驱动VR中的虚拟手运动&#xff0c;实现手部的追踪及定位。 本教程内容与之前使用腕带定位&#xff08;HTC FOCUS3在PC端串流FOHEART H1数据手套&#xff08;腕带&#xff09;&#xff09;不同&#xff0c;这次我们使用头显中自带的…

【Kuangbin简单DP】挤奶时间

4561. 挤奶时间 - AcWing题库 题意&#xff1a; 思路&#xff1a; 一开始的思路是把这么多的区间当作物品&#xff0c;然后选与不选&#xff0c;这样去搞线性DP 显然是不行的&#xff0c;因为这样答案就不知道怎么统计了 而且&#xff0c;我们是设阶段&#xff01;&#xf…

HSK汉语考试变革,您需要了解以下几点

2023年HSK考试可能有哪些变化汉语考试难度增加了还是减低了&#xff1f; 对现在的课程和教材有影响&#xff1f; 汉语老师怎么样应对&#xff1f;HSK考试变化猜想1.HSK3级考试和HSKK初级结合在一起 2.HSK4级考试和HSKK中级结合在一起 3.HSK5,6级考试和HSKK高级结合在一起HSKK考…

INTERSPEECH 2022|面向零样本声音克隆的内容相关细粒度说话人表征方法

本文由清华大学与腾讯 AI Lab、香港中文大学合作。 零样本说话人自适应&#xff08;zero-shot speaker adaptation&#xff09;&#xff0c;或称为零样本声音克隆&#xff0c;旨在根据任意一条参考语音&#xff08;reference speech&#xff09;合成训练过程中从未见过的说话人…

Leetcode:239. 滑动窗口最大值(C++)

目录 问题描述&#xff1a; 实现代码和解析&#xff1a; 暴力法&#xff08;会超时&#xff09;&#xff1a; 原理思路&#xff1a; 单调队列法&#xff1a; 原理思路&#xff1a; 单调队列&#xff1a; 模拟过程&#xff1a; 问题描述&#xff1a; 给你一个整数数组…

Python基础知识(二)

目录 顺序语句 条件语句 条件语句书写格式一及对比&#xff1a;if条件语句 条件语句书写格式二及对比&#xff1a;if...else...语句 条件语句书写格式三及对比&#xff1a;if...elif...else语句 空语句pass 条件语句的总结&#xff1a; 循环语句 while循环 与c/java/…

对于Muduo主从Reactor模式的理解

从12月20号开始看Muduo网络库&#xff0c;到28号的时候弄懂了EventLoop, Poller, Channel是怎么一回事&#xff0c;一番琢磨之后觉得还是应该发到博客上跟大家分享&#xff0c;特此记录。 对照linyacool那个webserver的实现&#xff0c;再看了一遍muduo的EventLoop, Poller ,C…

IDEA快速启动多个微服务模块 -idea如何开启Run DashBoard

文章目录 缘起 Run DashBoard面板如何开启开启 Run DashBoard 注意&#xff1a; 缘起 在idea里面如果需要启动多个项目的话&#xff0c;尤其是是比如微服务项目&#xff0c;动辄要启动五六个七八个应用&#xff0c;如果通过右上角那边启动会很不方便&#xff0c;你需要选择…

基于GIS简单处理世界土壤数据库(HWSD)的中国土壤数据集

来源&#xff1a;GIS前沿 一、 数据介绍 土壤属性表主要字段包括&#xff08;图1&#xff09;&#xff1a;详细描述请参考Harmonized World Soil Database (version 1.1).pdf文件&#xff0c;其中以T开头的土壤属性表示土壤上层的属性&#xff08;0-30cm&#xff09;&#xff…

【曲线全局逼近】

曲线全局逼近 本文是基于 这篇文章 翻译而来的&#xff0c;仅学习。 在插值中&#xff0c;插值曲线以给定的顺序通过所有给定的数据点。正如在全局插值页面中所讨论的&#xff0c;插值曲线可能会在所有数据点上摆动&#xff0c;而不是紧紧跟随数据多边形。为了克服这个问题&…

包装类的使用

文章目录一、单元测试方法的使用步骤二、包装类的使用基本数据类型、包装类、String类型之间的相互转化基本数据类型——>包装类注意包装类——>基本数据类型自动装箱与自动拆箱&#xff08;jdk5.0后&#xff09;基本数据类型、包装类——>String类型String类型——&g…

史上最全 Appium 自动化测试从基础到框架实战精华学习笔记(一)

1080402 31.8 KB 对测试人来说&#xff0c;Appium 是非常重要的一个开源跨平台自动化测试工具&#xff0c;它允许测试人员在不同的平台&#xff08;iOS、Android 等&#xff09;使用同一套 API 来写自动化测试脚本&#xff0c;这样可大幅提升代码复用率和工作效率。 本文汇总了…

郭盛华:警惕家庭智能扬声器中潜在的窃听风险

一名安全研究人员因识别Google Home智能扬声器中的安全问题而获得了107500美元的漏洞赏金&#xff0c;这些问题可能被用来安装后门并将其变成窃听设备。 国际知名网络黑客安全专家、东方联盟创始人郭盛华在一篇技术文章中透露&#xff1a;这些漏洞“允许无线附近的攻击者在设备…

服务的雪崩以及解决方案

文章目录一、什么是服务的雪崩二、服务雪崩形成的原因三、雪崩解决方案3.1 设置超时时间3.2 线程隔离&#xff08;舱壁模式&#xff09;3.3 熔断器&#xff08;断路器&#xff09;3.4 限流四、总结一、什么是服务的雪崩 服务的雪崩效应是一种因服务提供者不可用导致服务调用者…

从源码角度带你清楚分析Spring 的Lazy-init 延迟加载机制原理

lazy-init 延迟加载应用 ApplicationContext 容器的默认值行为是在启动服务器时将所有Singleton Bean 提前进行实例,提前实例化意味着作为初始化过程的一部分,ApplicationContext 实例会创建并配置所有的singleton Bean. 例如: <bean id"testBean" class"c…

张力控制PID增益(Kp)自适应算法详解(含SCL和梯形图完整源代码)

有关收放卷张力控制的详细内容,请参看下面的文章链接,这里不再赘述。 变频器简单张力控制(线缆收放卷应用)_RXXW_Dor的博客-CSDN博客张力控制的开闭环算法,可以查看专栏的其它文章,链接地址如下:PLC张力控制(开环闭环算法分析)_RXXW_Dor的博客-CSDN博客。https://blo…