【深度学习笔记】6_7 门控循环单元(GRU)

news2024/9/23 1:27:39

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.7 门控循环单元(GRU)

上一节介绍了循环神经网络中的梯度计算方法。我们发现,当时间步数较大或者时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减的问题。通常由于这个原因,循环神经网络在实际中较难捕捉时间序列中时间步距离较大的依赖关系。

门控循环神经网络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较大的依赖关系。它通过可以学习的门来控制信息的流动。其中,门控循环单元(gated recurrent unit,GRU)是一种常用的门控循环神经网络 [1, 2]。另一种常用的门控循环神经网络则将在下一节中介绍。

6.7.1 门控循环单元

下面将介绍门控循环单元的设计。它引入了重置门(reset gate)和更新门(update gate)的概念,从而修改了循环神经网络中隐藏状态的计算方式。

6.7.1.1 重置门和更新门

如图6.4所示,门控循环单元中的重置门和更新门的输入均为当前时间步输入 X t \boldsymbol{X}_t Xt与上一时间步隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1,输出由激活函数为sigmoid函数的全连接层计算得到。

在这里插入图片描述

图6.4 门控循环单元中重置门和更新门的计算

具体来说,假设隐藏单元个数为 h h h,给定时间步 t t t的小批量输入 X t ∈ R n × d \boldsymbol{X}_t \in \mathbb{R}^{n \times d} XtRn×d(样本数为 n n n,输入个数为 d d d)和上一时间步隐藏状态 H t − 1 ∈ R n × h \boldsymbol{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h。重置门 R t ∈ R n × h \boldsymbol{R}_t \in \mathbb{R}^{n \times h} RtRn×h和更新门 Z t ∈ R n × h \boldsymbol{Z}_t \in \mathbb{R}^{n \times h} ZtRn×h的计算如下:

R t = σ ( X t W x r + H t − 1 W h r + b r ) , Z t = σ ( X t W x z + H t − 1 W h z + b z ) , \begin{aligned} \boldsymbol{R}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xr} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hr} + \boldsymbol{b}_r),\\ \boldsymbol{Z}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xz} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hz} + \boldsymbol{b}_z), \end{aligned} Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),

其中 W x r , W x z ∈ R d × h \boldsymbol{W}_{xr}, \boldsymbol{W}_{xz} \in \mathbb{R}^{d \times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h \boldsymbol{W}_{hr}, \boldsymbol{W}_{hz} \in \mathbb{R}^{h \times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h \boldsymbol{b}_r, \boldsymbol{b}_z \in \mathbb{R}^{1 \times h} br,bzR1×h是偏差参数。3.8节(多层感知机)节中介绍过,sigmoid函数可以将元素的值变换到0和1之间。因此,重置门 R t \boldsymbol{R}_t Rt和更新门 Z t \boldsymbol{Z}_t Zt中每个元素的值域都是 [ 0 , 1 ] [0, 1] [0,1]

6.7.1.2 候选隐藏状态

接下来,门控循环单元将计算候选隐藏状态来辅助稍后的隐藏状态计算。如图6.5所示,我们将当前时间步重置门的输出与上一时间步隐藏状态做按元素乘法(符号为 ⊙ \odot )。如果重置门中元素值接近0,那么意味着重置对应隐藏状态元素为0,即丢弃上一时间步的隐藏状态。如果元素值接近1,那么表示保留上一时间步的隐藏状态。然后,将按元素乘法的结果与当前时间步的输入连结,再通过含激活函数tanh的全连接层计算出候选隐藏状态,其所有元素的值域为 [ − 1 , 1 ] [-1, 1] [1,1]

在这里插入图片描述

图6.5 门控循环单元中候选隐藏状态的计算

具体来说,时间步 t t t的候选隐藏状态 H ~ t ∈ R n × h \tilde{\boldsymbol{H}}_t \in \mathbb{R}^{n \times h} H~tRn×h的计算为

H ~ t = tanh ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) , \tilde{\boldsymbol{H}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}_{xh} + \left(\boldsymbol{R}_t \odot \boldsymbol{H}_{t-1}\right) \boldsymbol{W}_{hh} + \boldsymbol{b}_h), H~t=tanh(XtWxh+(RtHt1)Whh+bh),

其中 W x h ∈ R d × h \boldsymbol{W}_{xh} \in \mathbb{R}^{d \times h} WxhRd×h W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是权重参数, b h ∈ R 1 × h \boldsymbol{b}_h \in \mathbb{R}^{1 \times h} bhR1×h是偏差参数。从上面这个公式可以看出,重置门控制了上一时间步的隐藏状态如何流入当前时间步的候选隐藏状态。而上一时间步的隐藏状态可能包含了时间序列截至上一时间步的全部历史信息。因此,重置门可以用来丢弃与预测无关的历史信息。

6.7.1.3 隐藏状态

最后,时间步 t t t的隐藏状态 H t ∈ R n × h \boldsymbol{H}_t \in \mathbb{R}^{n \times h} HtRn×h的计算使用当前时间步的更新门 Z t \boldsymbol{Z}_t Zt来对上一时间步的隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1和当前时间步的候选隐藏状态 H ~ t \tilde{\boldsymbol{H}}_t H~t做组合:

H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t . \boldsymbol{H}_t = \boldsymbol{Z}_t \odot \boldsymbol{H}_{t-1} + (1 - \boldsymbol{Z}_t) \odot \tilde{\boldsymbol{H}}_t. Ht=ZtHt1+(1Zt)H~t.

在这里插入图片描述

图6.6 门控循环单元中隐藏状态的计算

值得注意的是,更新门可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,如图6.6所示。假设更新门在时间步 t ′ t' t t t t t ′ < t t' < t t<t)之间一直近似1。那么,在时间步 t ′ t' t t t t之间的输入信息几乎没有流入时间步 t t t的隐藏状态 H t \boldsymbol{H}_t Ht。实际上,这可以看作是较早时刻的隐藏状态 H t ′ − 1 \boldsymbol{H}_{t'-1} Ht1一直通过时间保存并传递至当前时间步 t t t。这个设计可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较大的依赖关系。

我们对门控循环单元的设计稍作总结:

  • 重置门有助于捕捉时间序列里短期的依赖关系;
  • 更新门有助于捕捉时间序列里长期的依赖关系。

6.7.2 读取数据集

为了实现并展示门控循环单元,下面依然使用周杰伦歌词数据集来训练模型作词。这里除门控循环单元以外的实现已在6.2节(循环神经网络)中介绍过。以下为读取数据集部分。

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

6.7.3 从零开始实现

我们先介绍如何从零开始实现门控循环单元。

6.7.3.1 初始化模型参数

下面的代码对模型参数进行初始化。超参数num_hiddens定义了隐藏单元的个数。

num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)

def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)
    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))
    
    W_xz, W_hz, b_z = _three()  # 更新门参数
    W_xr, W_hr, b_r = _three()  # 重置门参数
    W_xh, W_hh, b_h = _three()  # 候选隐藏状态参数
    
    # 输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])

6.7.3.2 定义模型

下面的代码定义隐藏状态初始化函数init_gru_state。同6.4节(循环神经网络的从零开始实现)中定义的init_rnn_state函数一样,它返回由一个形状为(批量大小, 隐藏单元个数)的值为0的Tensor组成的元组。

def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

下面根据门控循环单元的计算表达式定义模型。

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)
        R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)
        H_tilda = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H,)

6.7.3.3 训练模型并创作歌词

我们在训练模型时只使用相邻采样。设置好超参数后,我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。

num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

我们每过40个迭代周期便根据当前训练的模型创作一段歌词。

d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)

输出:

epoch 40, perplexity 149.477598, time 1.08 sec
 - 分开 我不不你 我想你你的爱我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你
 - 不分开 我想你你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我
epoch 80, perplexity 31.689210, time 1.10 sec
 - 分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
 - 不分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
epoch 120, perplexity 4.866115, time 1.08 sec
 - 分开 我想要这样牵着你的手不放开 爱过 让我来的肩膀 一起好酒 你来了这节秋 后知后觉 我该好好生活 我
 - 不分开 你已经不了我不要 我不要再想你 我不要再想你 我不要再想你 不知不觉 我跟了这节奏 后知后觉 又过
epoch 160, perplexity 1.442282, time 1.51 sec
 - 分开 我一定好生忧 唱着歌 一直走 我想就这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的
 - 不分开 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活

6.7.4 简洁实现

在PyTorch中我们直接调用nn模块中的GRU类即可。

lr = 1e-2 # 注意调整学习率
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                corpus_indices, idx_to_char, char_to_idx,
                                num_epochs, num_steps, lr, clipping_theta,
                                batch_size, pred_period, pred_len, prefixes)

输出:

epoch 40, perplexity 1.022157, time 1.02 sec
 - 分开手牵手 一步两步三步四步望著天 看星星 一颗两颗三颗四颗 连成线背著背默默许下心愿 看远方的星是否听
 - 不分开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不能 爱情走的太快就像龙卷风 不能承受我已无处
epoch 80, perplexity 1.014535, time 1.04 sec
 - 分开始想像 爸和妈当年的模样 说著一口吴侬软语的姑娘缓缓走过外滩 消失的 旧时光 一九四三 在回忆 的路
 - 不分开始爱像  不知不觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好
epoch 120, perplexity 1.147843, time 1.04 sec
 - 分开都靠我 你拿着球不投 又不会掩护我 选你这种队友 瞎透了我 说你说 分数怎么停留 所有回忆对着我进攻
 - 不分开球我有多烦恼多 牧草有没有危险 一场梦 我面对我 甩开球我满腔的怒火 我想揍你已经很久 别想躲 说你
epoch 160, perplexity 1.018370, time 1.05 sec
 - 分开爱上你 那场悲剧 是你完美演出的一场戏 宁愿心碎哭泣 再狠狠忘记 你爱过我的证据 让晶莹的泪滴 闪烁
 - 不分开始 担心今天的你过得好不好 整个画面是你 想你想的睡不著 嘴嘟嘟那可爱的模样 还有在你身上香香的味道

小结

  • 门控循环神经网络可以更好地捕捉时间序列中时间步距离较大的依赖关系。
  • 门控循环单元引入了门的概念,从而修改了循环神经网络中隐藏状态的计算方式。它包括重置门、更新门、候选隐藏状态和隐藏状态。
  • 重置门有助于捕捉时间序列里短期的依赖关系。
  • 更新门有助于捕捉时间序列里长期的依赖关系。

参考文献

[1] Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259.

[2] Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1510790.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[蓝桥杯 2020 省 B2] 平面切分

题目链接 [蓝桥杯 2020 省 B2] 平面切分 题目描述 平面上有 N N N 条直线, 其中第 i i i 条直线是 y A i ⋅ x B i y A_i \cdot x B_i yAi​⋅xBi​ 请计算这些直线将平面分成了几个部分。 输入格式 第一行包含一个整数 N N N。 以下 N N N 行, 每行包含两个整数…

总蛋白检测(Total Protein Assay)试剂盒--测定生物样品中胶原蛋白的含量

QuickZyme总胶原蛋白和羟脯氨酸检测试剂盒常用于测定生物样品&#xff08;如组织、组织提取物、细胞提取物和培养基&#xff09;中胶原蛋白的含量。为了正确解释所获得的数据&#xff0c;这些数据应该与一些参考量进行比较&#xff0c;如组织的湿重或干重、蛋白质含量与DNA等。…

2024蓝桥杯每日一题(二分)

一、第一题&#xff1a;教室 解题思路&#xff1a;二分差分 对天数进行二分&#xff0c;在ck函数中用差分方法优化多次区间累加。 【Python程序代码】 n,m map(int,input().split()) a [0] list(map(int,input().split())) d,s,t [0]*(m5),[0]*(m5),[0]*(m5) for…

VR全景在智慧园区中的应用

VR全景如今以及广泛的应用于生产制造业、零售、展厅、房产等领域&#xff0c;如今720云VR全景更是在智慧园区的建设中&#xff0c;以其独特的优势&#xff0c;发挥着越来越重要的作用。VR全景作为打造智慧园区的重要角色和呈现方式已经受到了越来越多智慧园区企业的选择和应用。…

adb常用指令集合

目录 1、查看应用Activity的任务栈2、局域网内无线连接设备3、启动adb服务4、结束adb服务5、查看连接的设备6、安装apk应用7、卸载指定应用8、从电脑拷贝文件到移动设备9、从移动设备拷贝文件到电脑10、重启设备11、查看版本12、调出shell&#xff0c;进入手机设置 1、查看应用…

最新:Selenium操作已经打开的Chrome(免登录)

最近重新尝试了一下&#xff0c;之前写的博客内容。重新捋了一下思路。 目的就是&#xff0c;selenium在需要登录的网站面前&#xff0c;可能就显得有些乏力&#xff0c;因此是不是有一种东西&#xff0c;可以操作它打开我们之前打开过的网站&#xff0c;这样就不用登录了。 …

python基础及网络爬虫

网络爬虫(Web crawler)&#xff0c;有时候也叫网络蜘蛛(Web spider)&#xff0c;是指这样一类程序——它们可以自动连接到互联网站点&#xff0c;并读取网页中的内容或者存放在网络上的各种信息&#xff0c;并按照某种策略对目标信息进行采集&#xff08;如对某个网站的全部页面…

基于JavaWeb开发的springboot网咖管理系统[附源码]

基于JavaWeb开发的springboot网咖管理系统[附源码] &#x1f345; 作者主页 央顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; &#x1f345; 查看下方微信号获取联系方式 承接各种定制系统 &a…

修改AVD默认存放位置

一、背景 Android Studio安装完成后&#xff0c;通常会配置SDK和AVD&#xff0c;在配置SDK时&#xff0c;可以修改SDK位置&#xff0c;所以&#xff0c;安装完成后&#xff0c;SDK的位置已经进行了修改&#xff0c;但是AVD在创建时&#xff0c;没有修改路径&#xff0c;所以默…

视频监控/云存储EasyCVR视频融合平台设备增删改操作不生效是什么原因?

国标GB28181协议EasyCVR安防平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、云存储等丰富的视频能力&#xff0c;平台支持7*24小时实时高清视频监控&#xff0c;能同时播放多路监控视频流&#xf…

【办公类-21-09】三级育婴师 视频转文字docx(等线小五单倍行距),批量改成“宋体小四、1.5倍行距、蓝色字体”

作品展示&#xff1a; 背景需求&#xff1a; 一、视频处理 1、育婴师培训的现场视频 2、下载视频&#xff0c;将视频换成考题名称 二、音频 视频用格式工厂转成MP3音频 3、转文字doc 把音频放入“网易云见外工作台”转换为“文字" 等待5分钟&#xff0c;音频文字会被写…

计算机组成原理-练手题集合【期末复习|考研复习】

前言 总结整理不易&#xff0c;希望大家点赞收藏。 给大家整理了一下计算机组成原理中的各章练手题&#xff0c;以供大家期末复习和考研复习的时候使用。 参考资料是王道的计算机组成原理和西电的计算机组成原理。 计算机组成原理系列文章传送门&#xff1a; 第一/二章 概述和数…

零基础小白也行,只用一行命令在自己的电脑跑大模型

什么是Ollama Ollama是一款免费开源的工具&#xff0c;拥有开箱即用的大模型&#xff0c;省去安装环境和下载模型的步骤&#xff0c;让零基础的人也能用起大模型。 项目地址 下载方法 通过下载链接可以找到对应的操作系统的下载版本&#xff0c;而且访问该网站不受限制&…

mysql 的一对一主从复制

一、配置主机 1、在主机mysql 配置文件my.cnf&#xff08;位置一般在/etc/my.cnf&#xff09; #在[mysqld]下面配置 #设置主机server-id(唯一) server-id1 #开启binlog文件 bin-log/var/lib/mysql/mysqlbin2、添加授权账号 #格式旧版 #GRANT REPLICATION SLAVE ON *.* TO sl…

个人博客网站前端页面的实现

博客网站前端页面的实现 博客登录页 相关代码 login.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><…

【数学】【网格】【状态压缩】782 变为棋盘

作者推荐 视频算法专题 本文涉及知识点 数学 网格 状态压缩 LeetCode:782 变为棋盘 一个 n x n 的二维网络 board 仅由 0 和 1 组成 。每次移动&#xff0c;你能任意交换两列或是两行的位置。 返回 将这个矩阵变为 “棋盘” 所需的最小移动次数 。如果不存在可行的变换&am…

PWARL CTF and others

title: 一些复杂点的题目 date: 2024-03-09 16:05:24 tags: CTF 2024年3月9日 今日习题完成&#xff1a; 1.BUU [网鼎杯 2020 半决赛]AliceWebsite 2.[RoarCTF 2019]Online Proxy 3.[Polar CTF]到底给不给flag呢 4.网鼎杯 2020 总决赛]Game Exp [RoarCTF 2019]Online Proxy …

微信小程序云开发教程——墨刀原型工具入门(常用组件)

引言 作为一个小白&#xff0c;小北要怎么在短时间内快速学会微信小程序原型设计&#xff1f; “时间紧&#xff0c;任务重”&#xff0c;这意味着学习时必须把握微信小程序原型设计中的重点、难点&#xff0c;而非面面俱到。 要在短时间内理解、掌握一个工具的使用&#xf…

初学SpringBoot——请求响应

0 引言 我们在使用SpringBoot开发Java后端项目时候&#xff0c;需要响应前端发送过来的请求&#xff0c;那后端如何响应前端的请求呢&#xff1f;以及前端发送那么多的请求&#xff0c;后端如何根据不同的请求执行不同的代码呢&#xff1f; 1 Postman 当前&#xff0c;主流的…

电影院订票选座小程序|基于微信小程序的电影院购票系统设计与实现(源码+数据库+文档)

电影院订票选座小程序目录 目录 基于微信小程序的电影院购票系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员功能实现 1、 影院信息管理 2 、电影信息管理 2、 用户功能实现 1、 影院信息 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考…