门控循环单元(GRU)

news2025/1/16 20:02:51
  • 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。

  • 重置门有助于捕获序列中的短期依赖关系。

  • 更新门有助于捕获序列中的长期依赖关系。

  • 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。

矩阵连续乘积可以导致梯度消失或梯度爆炸的问题,这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。 考虑一个极端情况,其中第一个观测值包含一个校验和, 目标是在序列的末尾辨别校验和是否正确。 在这种情况下,第一个词元的影响至关重要。 我们希望有某些机制能够在一个记忆元里存储重要的早期信息。 如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度, 因为它会影响所有后续的观测值。

  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。 例如,在对网页内容进行情感分析时, 可能有一些辅助HTML代码与网页传达的情绪无关。 我们希望有一些机制来跳过隐状态表示中的此类词元。

  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。 例如,书的章节之间可能会有过渡存在, 或者证券的熊市和牛市之间可能会有过渡存在。 在这种情况下,最好有一种方法来重置我们的内部状态表示。

在学术界已经提出了许多方法来解决这类问题。 其中最早的方法是”长短期记忆”(long-short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997), 我们将在下一次讨论。 门控循环单元(gated recurrent unit,GRU) (Cho et al., 2014) 是一个稍微简化的变体,通常能够提供同等的效果, 并且计算 (Chung et al., 2014)的速度明显更快。 由于门控循环单元更简单,我们从它开始解读。

1.门控隐状态

门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。 这些机制是可学习的,并且能够解决了上面列出的问题。 例如,如果第一个词元非常重要, 模型将学会在第一次观测之后不更新隐状态。 同样,模型也可以学会跳过不相关的临时观测。 最后,模型还将学会在需要的时候重置隐状态。 下面我们将详细讨论各类门控。

1.1重置门和更新门

我们首先介绍重置门(reset gate)和更新门(update gate)。 我们把它们设计成(0,1)区间中的向量, 这样我们就可以进行凸组合。 重置门允许我们控制“可能还想记住”的过去状态的数量; 更新门将允许我们控制新状态中有多少个是旧状态的副本

我们从构造这些门控开始。 图9.1.1 描述了门控循环单元中的重置门和更新门的输入, 输入是由当前时间步的输入和前一时间步的隐状态给出。 两个门的输出是由使用sigmoid激活函数的两个全连接层给出。

 

 1.2候选隐状态

 图9.1.2说明了应用重置门之后的计算流程.

 1.3隐状态

 这些设计可以帮助我们处理循环神经网络中的梯度消失问题, 并更好地捕获时间步距离很长的序列的依赖关系。 例如,如果整个子序列的所有时间步的更新门都接近于1, 则无论序列的长度如何,在序列起始时间步的旧隐状态都将很容易保留并传递到序列结束。

图9.1.3说明了更新门起作用后的计算流。

总之,门控循环单元具有以下两个显著特征:

  • 重置门有助于捕获序列中的短期依赖关系;

  • 更新门有助于捕获序列中的长期依赖关系。

 2.从零开始实现

为了更好地理解门控循环单元模型,我们从零开始实现它。 首先,我们读取 8.5节中使用的时间机器数据集:

pip install mxnet==1.7.0.post1
pip install d2l==0.15.0
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l

npx.set_np()

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

 2.1初始化模型参数

下一步是初始化模型参数。 我们从标准差为0.01的高斯分布中提取权重, 并将偏置项设为0,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return np.random.normal(scale=0.01, size=shape, ctx=device)

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                np.zeros(num_hiddens, ctx=device))

    W_xz, W_hz, b_z = three()  # 更新门参数
    W_xr, W_hr, b_r = three()  # 重置门参数
    W_xh, W_hh, b_h = three()  # 候选隐状态参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = np.zeros(num_outputs, ctx=device)
    # 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params

2.2定义模型

现在我们将定义隐状态的初始化函数init_gru_state。 与 8.5节中定义的init_rnn_state函数一样, 此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):
    return (np.zeros(shape=(batch_size, num_hiddens), ctx=device), )

现在我们准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
        H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = np.dot(H, W_hq) + b_q
        outputs.append(Y)
    return np.concatenate(outputs, axis=0), (H,)

2.3训练与预测

训练和预测的工作方式与 8.5节完全相同。 训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

2.4简洁实现

高级API包含了前文介绍的所有配置细节, 所以我们可以直接实例化门控循环单元模型。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/117243.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年广东数据分析师CPDA认证招生简章(理论+实战)

CPDA数据分析师认证是中国大数据领域有一定权威度的中高端人才认证,它不仅是中国较早大数据专业技术人才认证、更是中国大数据时代先行者,具有广泛的社会认知度和权威性。 无论是地方政府引进人才、公务员报考、各大企业选聘人才,还是招投标加…

Github上排名前五的开源网络监控工具,附详细的图文说明和项目下载地址

Github上排名前五的开源网络监控工具,附详细的图文说明和项目下载地址。 维护网站正常运行是系统管理员最基本的任务之一,所以对系统进行监视,并保持网络的最佳运行状态至关重要。 在现代的网络中,有许多不同的方法来监视&#…

LeetCode刷题之406 根据身高重建队列(贪心算法)

题目描述 假设有打乱顺序的一群人站成一个队列,数组 people 表示队列中一些人的属性(不一定按顺序)。每个 people[i] [hi, ki] 表示第 i 个人的身高为 hi ,前面 正好 有 ki 个身高大于或等于 hi 的人。 请你重新构造并返回输入数…

元宇宙数字产品制作厂商广州华锐互动如何?

近两年,元宇宙引发全球的高度关注,很多玩家也都开始参与进来。作为虚拟世界的重要组成部分,元宇宙也从一个概念逐渐演变成了一种全新的商业模式以及行业发展趋势。 元宇宙具有极致沉浸和交互的体验,可以提升生活、工业、社会、科…

云原生之使用Docker部署OneNav个人书签管理器

云原生之使用Docker部署OneNav个人书签管理器一、OneNav介绍1.OneNav简介2.OneNav特点二、检查本地docker环境1.检查docker版本2.检查docker状态三、下载onenav镜像四、部署OneNav应用1.创建数据目录2.创建OneNav容器3.查看OneNav容器状态五、访问OneNav首页六、访问OneNav后台…

难受的这两天,你们怎么样?

12月23号周五,下班回来小云说中午下楼买菜碰到小区认识的一个妈妈,两个人在楼下聊了一会,晚上那个妈妈检测出阳性,周五晚上睡觉前,小云没任何不适,周末看朋友圈和小区已经很多很多人中招,我那时…

大数据系列——什么是hdfs?hdfs用来干什么的?

目录 一、什么是HDFS 二、hdfs用来干什么的 三、hdfs适用场景 四、hdfs不适合的场景 五、hdfs 架构 基本概念 六、HDFS基础命令 七、hdfs业务中应用 一、什么是HDFS HDFS全称是Hadoop Distributed File System是一种分布式文件系统(HDFS使用多台计算机存储文件&#xff…

易基因|深度综述:癌症中RNA修饰机制的遗传和表观遗传失调(m6A+m1A+m5C+ψ)

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 2022年11月12日,《Trends Genet》杂志发表了题为“Genetic and epigenetic defects of the RNA modification machinery in cancer”的综述文章,讨论了m6A、m5C、…

CSS滤镜 filter 网站灰色设置

-webkit-filter: grayscale(100%); -moz-filter: grayscale(100%); -ms-filter: grayscale(100%); -o-filter: grayscale(100%); filter: grayscale(100%); filter: progid:DXImageTransform.Microsoft.BasicImage(grayscale1);例如: CSS滤镜 filter 详情链接 blu…

【K3s】第6篇 详解yaml文件,pod端口供外部访问

目录 1、yaml文件详解 2、pod 容器对外透露端口号 1、yaml文件详解 apiVersion: v1 #创建一个新的命名空间 kind: Namespace metadata:name: test --- apiVersion: apps/v1 #资源版本,可使用 kubectl api-versions命令查看有哪些,只有指定具体的资源…

嵌入式:ARM常用开发编译软件介绍

文章目录编译器介绍1、ADS1.22、ARM RealView Developer Suite (RVDS)3、IAR EWARM4、KEIL ARM-MDKARM5、WIN ARM-GCC ARM编译器介绍 1、ADS1.2 ADS(ARM Developer Suite),是在1993年由Metrowerks公司开发是ARM处理器下最主要的开发工具。 他…

结合今年的考试难度浅析浙大MBA/MPA/MEM三项目的分数线趋势…

最近一些天,不少考生都在关注今年的分数线问题,按照往年惯例,这个问题的答案要等到3月10日前后才会出来,包括国家线和浙大自划线两个复试的标尺。但从历史数据和经验方面倒也可以提前针对这个问题做一些趋势分析,特别是…

client-go源码学习(一):client-go源码结构、Client客户端对象

本文基于Kubernetes v1.22.4版本进行源码学习,对应的client-go版本为v0.22.4 1、client-go源码结构 client-go的代码库已经集成到Kubernetes源码中了,源码结构示例如下: $ tree vendor/k8s.io/client-go -L 1 vendor/k8s.io/client-go ├─…

《业务安全大讲堂》——2022全年大回顾!

数字化的深入普及,让企业的业务愈加开放互联。企业的核心业务、关键数据、用户信息、基础设施、运营过程等均处于边界模糊且日益开放的环境中,在电商、支付、信贷、账户、交互、交易等各种形态的业务场景中,存在着形式多样的欺诈行为。而业务…

企业数字化转型迫切,团队协同工具何以成为“杀手锏”?

不久前,2022世界互联网大会乌镇峰会开幕,360创始人周鸿祎以“构建SaaS生态,助力数字化共同富裕”为主题发表分论坛演讲,并宣布360集团正式上线SaaS商店,为中小微企业和实体产业提供一站式数字化转型服务,填…

elasticsearch之metric聚合

文章目录1、背景2、准备数据2.1 准备mapping2.2 准备数据3、metric聚合3.1 max 平均值3.1.1 dsl3.1.2 java代码3.2 min最小值3.2.1 dsl3.2.2 java3.3 min最小值3.3.1 dsl3.3.2 java3.4 min最小值3.4.1 dsl3.4.2 java3.5 count(*)3.5.1 dsl3.5.2 java3.6 count(distinct)3.6.1 d…

一款非常萌的桌面工具 --- Bongo Cat Mver 附使用教程

最近看B站的时候发现了一个很好玩的桌面工具,Bongo Cat Mver 通过多方查找资源,终于找到了,并且已经下载使用 O(∩_∩)O 不知道这只小猫是不是特别好看呢?放在你的桌面上 Bongo Cat Mver简介 Bongo Cat Mver 是一款画风非常萌的…

[论文解析] NeRF-Art: Text-Driven Neural Radiance Fields Stylization

文章目录OverviewWhat problem is addressed in the paper?Is it a new problem? If so, why does it matter? If not, why does it still matter?What is the key to the solution?What is the main contribution?What can we learn from ablation studies?P…

vue3 路由 vite方式新建项目【适合新手】

一 配环境、并初始化项目 安装nodejs https://blog.csdn.net/lh155136/article/details/128444850 参考官网https://cn.vuejs.org/guide/quick-start.html#creating-a-vue-application 找个空目录cmd打开黑窗口 初始化项目 npm init vuelatest输入y 输入项目名字&#xff…

大聪明教你学Java | 带你了解 Binlog 实现 MySQL 主从同步的原理及实现方式

前言 🍊作者简介: 不肯过江东丶,一个来自二线城市的程序员,致力于用“猥琐”办法解决繁琐问题,让复杂的问题变得通俗易懂。 🍊支持作者: 点赞👍、关注💖、留言&#x1f4…