RNN从理论到实战【理论篇】

news2025/1/12 6:02:48

来源:投稿 作者:175
编辑:学姐

要深入理解深度学习,从零开始创建的经验非常重要,从自己可以理解的角度出发,尽量不使用外部完备的框架前提下,实现我们想要的模型。本系列文章的宗旨就是通过这样的过程,让大家切实掌握深度学习底层实现,而不是仅做一个调包侠。

本文介绍RNN,一种用于处理序列数据的神经网络。

循环神经网络

循环神经网络(Recurrent Neural Network,RNN)是包含循环连接的网络,即有些单元是直接或间接地依赖于它之前的。

本文我们学习一种叫做Elman网络的循环网络,或称为简单循环网络(本文中的RNN都代表该网络)。隐藏层包含一个循环连接作为其输入。即,基于当前输入和前一时刻隐藏状态计算当前隐藏状态。

上图展示了RNN的结构,与普通前馈网络一样,表示当前输入的向量乘以权重矩阵,然后经过非线性激活函数来计算隐藏单元的值。然后用于计算相应的输出。

该网络在处理序列时,一次(一个时间步)顺序地处理序列中的一个元素,与我们之前看到的基于窗口的方法不同。我们使用下表来表示时间,这样,表示时刻(时间步)的输入向量。与前馈网络的关键区别在于上图虚线显示的循环连接。此连接使用上一个时刻隐藏层的值来增强对于当前时刻隐藏层计算的输入。

前一时刻的隐藏层提供了一种记忆(或上下文)的功能,可以提供之前的信息为未来做决定提供帮助。重要的是,这种方法理论上不需要对前文的长度进行限制,不过实际上过远的信息很难有效的保留。

前向传播

RNN中的前向传播(推理)过程和前馈网络差不多。但在使用RNN处理一个序列输入时,需要将RNN按输入时刻展开,然后将序列中的每个输入依次对应到网络不同时刻的输入上,并将当前时刻网络隐藏层的输出也作为下一时刻的输入。

循环网络处理序列输入的示意图,图片来自https://medium.com/deeplearningbrasilia/deep-learning-recurrent-neural-networks-f9482a24d010

为了计算时刻t的输入x_t对应的输出y_t(图中是o_t),我们需要先计算隐藏状态h_t。为了计算它,让输入x_t乘以权重矩阵W以及前一时刻的隐藏状态h_(t-1)乘以权重矩阵U。然后把它们的结果加起来,并经过一个激活函数g,通常为tanh函数,计算当前的隐藏状态h_t。此时,我们可以通过h_t来生成输出向量 y_t

这里要注意维度。我们用d_{in}d_hd_{out}分别代表输入、隐藏和输出层的大小。那么这三个权重矩阵的维度是:

如果是多分类问题, y_t由softmax函数计算而成:

可以看到,时刻t的计算需要前一个时刻t-1的隐藏层激活值(隐藏状态)。显然,这是一种递归形式的定义,从序列开始到序列结束。每个时刻的输入经过层层递归,对最终的输出产生一定影响,每个时刻的隐藏状态h_t承载了1~t时刻的全部输入信息,因此循环神经网络中的隐藏单元也被称为记忆单元。

上图简单神经网络的前向推理。

注意,矩阵U,W,V在每个时刻都是共享的,每个时刻都会计算一个h_iy_i

这里初始时隐藏状态h^0-1

学习

我们有三个权重要更新:输入层到隐藏层的权重W;前一时刻隐藏层到当前时刻隐藏层的权重U;隐藏层到输出层的权重V。

但更新时与前馈网络不同,主要有两点。

  1. 为了计算时刻t的损失,我们需要时刻t-1的隐藏状态;
  2. 时刻t的隐藏状态同时影响了时刻t的输出和时刻t+1的隐藏状态。

所以,也影响了时刻t+1的输出和损失。因此,要评估h_t累积的损失,我们需要知道它对当前输出以及后续输出的影响。

RNN的沿着时间反向传播,图片来自https://mmuratarat.github.io/2019-02-07/bptt-of-rnn

此时,需要修改反向传播算法,形成两阶段的算法来训练RNN中的权重。第一阶段,在第一次传播中,我们执行正向推理,如上图右边黑色箭头所代表的方向(从左到右),计算y_t,在每个时刻累积损失,同时保存隐藏状态的值,以便在第二阶段使用。

在第二阶段,我们反向处理序列,从最后的输出往前计算梯度,即从右到左,如上图红色箭头所示。比如计算了x_{t-1}处的梯度后,得到的损失还需要在前一步处使用。这种方法被称为沿着时间反向传播(Backpropagation Through Time,BPTT)。

我们说这里介绍的是Elman网络,那还有其他什么网络吗?

另一种称为Jordan网络。可以用以下公式来说明它们的区别:

Elman网络:

Jordan网络:

其中x_t为输入向量;h_t为隐藏状态;y_t为输出;W,U,b是参数;f和g为激活函数。

RNN作为语言模型

RNN的这种特性,非常适用于语言模型。可以一次处理序列中的一个单词,基于当前的单词和上一个隐藏状态来预测下一个单词。可以看到,RNN没有N-Gram中N的限制,因为隐藏状态原则上可以表示前面所有单词的信息。

输入序列包含一系列大小为的独热向量,而输出y是代表词典中所有单词概率分布的向量。在每个时刻中,模型通常使用嵌入矩阵E来查看嵌入向量(而不是直接使用独热向量),然后与前一时刻的隐藏状态拼接来计算当前的隐藏状态。然后用于生成输出,它会喂给softmax层生成整个词典上的概率分布。即,在时刻t:

计算的向量可以看成是由h_t提供的对整个词典的所有单词得分。将该得分传入sofmtax归一化后得到概率分布。某个单词i ii作为下一个单词的概率由表示,即y_t的第i个元素:

整个序列的概率就是序列中每个元素的概率之积,我们会使用代表时刻i的真实单词w_i。那么,整个句子率w_{1:n}就可以计算为:

为了训练一个RNN作为语言模型,我们使用文本语料作为训练材料,让模型在每个时刻预测下一个单词。然后训练模型最小化预测真正下一个单词的误差,使用交叉熵作为损失函数:

在语言建模任务下,正确的分布y_t单词,通常被表示为独热向量,对应正确单词位置为1,元素都为0这样,为语言建模的交叉熵损失由模型为正确单词赋予的概率决定。所以在时刻t的损失就是模型赋予下个单词的负对数概率:

 因此,在输入的每个单词t位置处,模型将正确的标记w_{1:t}序列作为输入,并使用它们来计算可能的下一个单词的概率分布,从而计算下一个标记w_{t+1}的模型损失。然后我们移动到下一个单词,此时我们忽略模型对下一个单词的预测,而是使用正确的标记w_{1:t+1}的序列来估计标记w_{t+2}的概率,这种方法被称为tearch forcing。

通过梯度下降来调整网络中的权值,以最小化训练序列上的平均交叉熵损失。上图说明了该训练过程。

可以发现,输入嵌入矩阵E和最后一层权重矩阵V(计算结果经过softmax)很相似。E的列向量代表在训练过程中学习到的词汇表中每个单词的词嵌入,目的是让具有相似含义和特征的单词具有相似的嵌入。并且,由于这些嵌入的长度对应于隐藏层d_h的大小,因此嵌入矩阵的形状为

最后一层矩阵V提供了一种方法,通过计算Vh,对词典中每个单词的可能性进行评分。这得到了一个维度。也就是说,V的行提供了第二组学习的词嵌入。这就引出了一个明显的问题——有必要同时拥有两者吗?

权重绑定(weight tying) 是一种避免这种权重冗余的方法,只需在输入和softmax层上使用同一组嵌入的方法。也就是说,我们在计算的开始和结束时都不用V,而是使用E。

这种改进,除了提升了模型困惑度之外,还显著减少了模型所需的参数量。

我们已经学习了RNN的基础知识,在实际应用上通常不是仅使用我们学到的这种RNN。而是会使用堆叠RNN和双向RNN。下面分别来了解它们。

堆叠NN

我们到此为止所学的例子中,RNN的输入都是由单词嵌入向量组成,而输出是预测单词有用的向量。但是,我们也可以使用一个RNN的整个输出作为另一个RNN的输入,通过这种方向将多个RNN网络堆叠起来。

如上图所示,我们堆叠了三个RNN。

堆叠的RNN通常优于单层RNN。可能的一个原因是,网络在不同层抽象了不同的表示。堆叠RNN的初始层产生的表示可以作为深层有用的抽象——这很难在单词RNN中产生。但是,随着堆叠层数的增加,训练成本也迅速上升。

双向RNN

另一种应用较多的是双向RNN,我们上面学到的是从左到右依次处理序列中的每个元素。但在很多情况下,如果能访问整个序列再做决定,得到的效果会更好。此时就需要双向RNN。

一种实现方式时通过两个独立的RNN网络,一个按照之前的顺序从左往右读;另一个按照逆序从右往左读。在每个时刻t tt,拼接它们生成的表示。

References

Speech and Language Processing

关注下方《学姐带你玩AI》🚀🚀🚀

神经网络系列知识持续更新中

码字不易,欢迎大家点赞评论收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/163141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaSE】数据类型与变量

数据类型与变量数据类型与变量1. 字面常量2. 数据类型3. 变量3.1 变量概念3.2 语法格式3.3.1 整型变量3.3.2 长整型变量3.3.3 短整型变量3.3.4 字节型变量3.3 浮点型变量3.4.1 双精度浮点型3.4.2 单精度浮点型3.4 字符型类型3.5 布尔型变量3.6 类型转换3.7.1 自动类型转换&…

TensorFlow 实战案例: ResNeXt 交通标志图像多分类,附Tensorflow完整代码

各位同学好,今天和大家分享一下如何使用 Tensorflow 构建 ResNeXt 神经网络模型,通过 案例实战 ResNeXt 的训练以及预测过程。每个小节的末尾有网络、训练、预测的完整代码。 ResNeXt 是 ResNet 的改进版,在 bottleneck卷积块 结构上进行了较…

阿里高级技术专家方法论:如何写复杂业务代码?

阿里妹导读:张建飞是阿里巴巴高级技术专家,一直在致力于应用架构和代码复杂度的治理。最近,他在看零售通商品域的代码。面对零售通如此复杂的业务场景,如何在架构和代码层面进行应对,是一个新课题。结合实际的业务场景…

ECM工业能耗管理云平台

在我国的能源消耗中,工业企业是能源消耗的主要群体,能源消耗量占全国能源消耗总量的70%左右,传统方式进行各类工厂能耗的计量,造成能耗数据不完整、不准确、不全面,因而无法进行能耗分析与诊断,造成普遍在各…

DFS初入门

目录 一、前言 二、搜索与暴力法 1、概念 2、搜索的基本思路 3、BFS:一群老鼠走迷宫 4、DFS:一只老鼠走迷宫 三、DFS 1、DFS访问示例 2、DFS的常见操作 3、DFS基础:递归和记忆化搜索 4、DFS的代码框架(大量编码后回头体…

一个真正的鳗,他清楚自己每天都要刷《剑指offer》(第九天)

跟着博主一起刷题 这里使用的是题库: https://leetcode.cn/problem-list/xb9nqhhg/?page1 目录剑指 Offer 57 - II. 和为s的连续正数序列剑指 Offer 59 - I. 滑动窗口的最大值剑指 Offer 60. n个骰子的点数剑指 Offer 57 - II. 和为s的连续正数序列 剑指 Offer 57 …

文旅元宇宙热潮来袭,天下秀用“科技之钥”解锁三大价值

让未来照进现实,让现实走进虚拟,元宇宙正成为通往下个时代的船票。2018年上映的电影《头号玩家》,让大部分人首次感触到元宇宙里的沉浸式体验——男主角带上VR头盔后,瞬间就能进入另一个极其逼真的虚拟世界。随着VR、AR、区块链、…

系统回顾MyBatis体验这一优秀的持久层框架

文章目录1.MyBatis2.Mapper代理3.MyBatis配置升级4.配置文件CRUD5.多条件查询6.多条件动态查询7.单条件动态条件查询8.添加数据并主键返回9.更新数据10.删除数据11.参数传递12.注解开发1.MyBatis MyBatis基本上取消了所有的JDBC硬编码,对于单独使用这样的ORM框架&a…

1585_AURIX_TC275_SMU的部分内核寄存器

全部学习汇总: GreyZhang/g_TC275: happy hacking for TC275! (github.com) 继续看SMU的资料,这次看一部分SMU的内核相关寄存器。这一次整理的内容比较少,而且优点断篇,因此按照序号来分没有保持10页的对齐。 调试相关的寄存器不…

详解外网访问内网DDNS作用 及ddns解析软件使用方法

导语:随着互联网的成熟,家庭宽带的提速,大家对外网访问家庭内网电脑,监控,服务器,存储NAS等设备的需求倍增。目前外网访问内网可以用DDNS动态域名解析的方式,以下本文就来介绍一下原理和实现工具…

ELK日志(3)

EFK日志收集 Elasticsearch: 数据库,存储数据 javalogstash: 日志收集,过滤数据 javakibana: 分析,过滤,展示 javafilebeat: 收集日志,传输到ES或logstash go redis:缓冲数据,等待logstash取数据…

高并发多级缓存架构解决方案 OpenResty、canal搭建及使用流程

高并发多级缓存架构解决方案1、缓存的常规使用方式2、请求流程拆分1、搭建tomcat集群2、搭建OpenRestyOpenResty的目录结构nginx的配置文件lua脚本的执行流程http请求反向代理到tomcat服务器3、OpenResty、Redis的单点故障问题4、防止缓存穿透java中通过redisson实现布隆过滤器…

Mac 下配置 go语言环境

Mac 下配置 go语言环境两种方法安装Go通过Homebrew安装(不太推荐)通过官网安装 (推荐)方法一安装Homebrew通过Homebrew安装Go方法二 通过官网进行安装配置go环境配置go环境国内镜像Vscode环境配置Helloworld.go两种方法安装Go 通…

LabVIEW中的VI脚本

LabVIEW中的VI脚本用户可使用VI脚本选板上的VI、函数和相关的属性、方法,通过程序创建、编辑和运行VI。通过VI脚本,可减少重复的VI编辑所需的时间,例如:创建若干类似VI对齐和分布控件显示或隐藏控件标签连接程序框图对象注: 必须先…

【13】Docker_DockerFile | 关键字

目录 1、DockerFile的定义 2、DockerFile内容基本知识 3、Docker执行DockerFile的大致流程 4、DockerFile的关键字 5、举例: 1、DockerFile的定义 Dockerfile是用来构建Docker镜像的文本文件,是由一条条构建镜像所需的指令和参数构成的脚本。 2、Do…

[前端笔记——HTML介绍] 2.开始学习HTML

[前端笔记——HTML介绍] 2.开始学习HTML1什么是HTML?2剖析一个HTML元素3块级元素和内联元素4空元素5属性6为一个元素添加属性7布尔属性8省略包围属性值的引号9单引号或双引号?10剖析HTML文档11实体引用:在 HTML 中包含特殊字符1什么是HTML? …

LeetCode 17. 电话号码的字母组合

🌈🌈😄😄 欢迎来到茶色岛独家岛屿,本期将为大家揭晓LeetCode 17. 电话号码的字母组合,做好准备了么,那么开始吧。 🌲🌲🐴🐴 一、题目名称 17.…

文件操作中的IO流——字节流与字符流

一,IO流1.什么是IO流IO流是存取和读取数据的解决方案2.IO流的作用IO流用于读写数据,这些数据包括本地文件和网络上的一些数据;比如读写本地文件的时候需要用到文件读写的IO流,读写网络上的数据时需要通过Socket套接字来调用数据流…

机器学习:公式推导与代码实现-监督学习单模型

线性回归 线性回归(linear regression)是线性模型的一种典型方法。 回归分析不再局限于线性回归这一具体模型和算法,更包含了广泛的由自变量到因变量的机器学习建模思想。 原理推导 线性回归学习的关键问题在于确定参数w和b,使得拟合输出y与真实输出yi尽可能接近 为了求…

PowerDesigner16.5配置安装与使用

PowerDesigner16.5百度云下载链接 链接:https://pan.baidu.com/s/1b9XUqxVZ8gTqk_9grptcAQ?pwd3pl7 提取码:3pl7 一:软件安装 1.下载安装包(包含安装文件、汉化包、注册文件) 2.下载后文件内容如下 3.进入安装文件中…