机器学习--循环神经网络(RNN)4

news2025/1/13 10:30:54

一、RNN的学习方式

如果要做学习,需要定义一个损失函数(loss function)来评估模型的好坏,选一个参数要让损失最小。
在这里插入图片描述

以槽填充为例,如上图所示,给定一些句子,给定一些标签,告诉机器说第一个单词它是属于 other 槽,“上海”是目的地槽,“on“属于 other 槽,“June”和“1st”属于时间槽。
“抵达”丢到循环神经网络的时候,循环神经网络会得到一个输出 y1。接下来这个 y1会看它的参考向量(reference vector)算它的交叉熵。我们会期望如果丢进去的是“抵达”,其参考向量应该对应到 other 槽的维度(即other对应的维度为1,其他为 0),这个参考向量的长度就是槽的数量

把“上海”丢进去之后,因为“上海”属于目的地槽,希望把 x2 丢进去之后,y2 它要跟参考向量距离越近越好。那 y2 的参考向量是对应到目的地槽是 1,其它为 0。注意,在丢 x2 之前,一定要丢 x1(在丢“上海”之前先把“抵达”丢进去),不然就不知道存到记忆元里面的值是多少。所以在训练的时候,不能够把这些单词序列打散来看,单词序列仍然要当做一个整体来看。把“on”丢进去,参考向量对应的 other 的维度是 1,其它是 0。
RNN 的损失函数输出和参考向量的交叉熵的和就是要最小化的对象。

有了这个损失函数以后,对于训练也是用梯度下降来做。也就是现在定义出了损失函数L,要更新这个神经网络里面的某个参数 w,就是计算对 w 的偏微分,偏微分计算出来以后,就用梯度下降的方法去更新里面的参数。循环神经网络里面,为了要计算方便,提出了反向传播的进阶版,即随时间反向传播(BackPropagation Through Time,BPTT)。BPTT 跟反向传播其实是很类似的,只是循环神经网络它是在时间序列上运作,所以 BPTT 它要考虑时间上的信息,如下图所示。
在这里插入图片描述

RNN 的训练是比较困难的,如下图所示。
在这里插入图片描述
在做训练的时候,期待学习曲线是像蓝色这条线,这边的纵轴是总损失(total loss),横轴是回合的数量,我们会希望:随着回合的数量越来越多,随着参数不断的更新,损失会慢慢的下降,最后趋向收敛。
但在训练循环神经网络的时候,有时候会看到绿色这条线。如果第一次训练循环神经网络,绿色学习曲线非常剧烈的抖动。
在这里插入图片描述

如上图所示,RNN 的误差表面是总损失的变化是非常陡峭的或崎岖的。误差表面有一些地方非常平坦,一些地方非常陡峭。纵轴是总损失,x 和 y 轴代表是两个参数。
这样会造成什么样的问题呢?假设我们从橙色的点当做初始点,用梯度下降开始调整参数,更新参数,可能会跳过一个悬崖,这时候损失会突然爆长,损失会非常上下剧烈的震荡。有时候我们可能会遇到更惨的状况,就是以正好我们一脚踩到这个悬崖上,因为在悬崖上的梯度很大,之前的梯度会很小,所以可能把学习率调的比较大。很大的梯度乘上很大的学习率结果参数就更新的跨度就会很大,整个参数就飞出去了。
裁剪(clipping)可以解决该问题,当梯度大于某一个阈值的时候,不要让它超过那个阈值,比如当梯度大于 15 时,让梯度等于 15 。因为梯度不会太大,所以就算是踩着这个悬崖上,也不飞出来,会飞到一个比较近的地方,这样还可以继续做 RNN 的训练

之前说 ReLU 激活函数的时候,梯度消失(vanishing gradient)来源于 Sigmoid 函数。但 RNN 会有很平滑的误差表面不是来自于梯度消失。把 Sigmoid 函数换成 ReLU,其实在 RNN 性能通常是比较差的,所以激活函数并不是关键点。
在这里插入图片描述

有更直观的方法来知道一个梯度的大小,可以把某一个参数做小小的变化,看它对网络输出的变化有多大,就可以测出这个参数的梯度大小,如上图所示。
举一个很简单的例子,只有一个神经元,这个神经元是线性的。输入没有偏置,输入的权重是 1,输出的权重也是 1,转移的权重是 w。也就是说从记忆元接到神经元的输入的权重是 w。如下图所示
在这里插入图片描述

假设给神经网络的输入是 [1, 0, 0, 0]T,比如神经网络在最后一个时间(1000 个输出值是 w999)。假设 w 是要学习的参数,我们想要知道它的梯度,所以是改变 w 的值时候,对神经元的输出有多大的影响。假设 w = 1,y^1000 = 1,假设 w = 1.01,y^1000 ≈ 20000,w 有一点小小的变化,会对它的输出影响是非常大的。所以 w 有很大的梯度,那把学习率设小一点就好了。但把 w 设为 0.99,那 y1000 ≈ 0。如果把 w 设为 0.01,y1000 ≈ 0。也就是说在 1 的这个地方有很大的梯度,但是在 0.99 这个地方就突然变得非常小,这个时候需要一个很大的学习率。设置学习率很麻烦,误差表面很崎岖,梯度是时大时小的,在非常小的区域内,梯度有很多的变化。
从这个例子可以看出,RNN 训练的问题其实来自它把同样的东西在转移的时候,在时间按时间的时候,反复使用。所以 w 只要一有变化,它完全由可能没有造成任何影响,一旦造成影响,影响很大,梯度会很大或很小。

所以 RNN 不好训练的原因不是来自激活函数而是来自于它有时间序列同样的权重在不同的时间点被反复的使用。

二、解决RNN梯度消失的方法

广泛被使用的技巧是 LSTM,LSTM 可以让误差表面不要那么崎岖。它会把那些平坦的地方拿掉,解决梯度消失的问题,但不会解决梯度爆炸
(gradient exploding)的问题。有些地方是非常的崎岖的,有些地方是变化非常剧烈的,但是不会有特别平坦的地方。如果做 LSTM 时,大部分地方变化的很剧烈,可以把学习率设置的小一点,保证在学习率很小的情况下进行训练。

LSTM 可以处理梯度消失的问题。RNN 跟 LSTM 在面对记忆元的时候,它处理的操作其实是不一样的。在 RNN 里面,在每一个时间点,神经元的输出都要记忆元里面去,记忆元里面的值都是会被覆盖掉。但是在 LSTM 里面不一样,它是把原来记忆元里面的值乘上一个值再把输入的值加起来放到单元里面。所以它的记忆和输入是相加的。
LSTM 和 RNN 不同的是,如果权重可以影响到记忆元里面的值,一旦发生影响会永远都存在。而 RNN 在每个时间点的值都会被格式化掉,所以只要这个影响被格式化掉它就消失了。但是在 LSTM 里面,一旦对记忆元造成影响,影响一直会被留着,除非遗忘门要把记忆元的值洗掉。不然记忆元一旦有改变,只会把新的东西加进来,不会把原来的值洗掉,所以它不会有梯度消失的问题。

遗忘门可能会把记忆元的值洗掉。 LSTM 的第一个版本其实就是为了解决梯度消失的问题,所以它是没有遗忘门,遗忘门是后来才加上去的。在训练 LSTM的时候,要给遗忘门特别大的偏置,确保遗忘门在多数的情况下都是开启的,只有少数的情况是关闭的。
有另外一个版本用门操控记忆元,叫做 GRU,LSTM 有三个门,而 GRU 有两个门,所以 GRU 需要的参数是比较少的。因为它需要的参数量比较少,所以它在训练的时候是比较鲁棒的。

如果训练 LSTM 的时候,过拟合的情况很严重,可以试下 GRU。GRU 的精神就是:旧的不去,新的不来。它会把输入门跟遗忘门联动起来,也就是说当输入门打开的时候,遗忘门会自动的关闭 (格式化存在记忆元里面的值),当遗忘门没有要格式化里面的值,输入门就会被关起来。也就是要把记忆元里面的值清掉,才能把新的值放进来。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1502443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Springboot的高校宣讲会管理系统。Javaee项目,springboot项目。

演示视频: 基于Springboot的高校宣讲会管理系统。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,通过Spring Spri…

ECharts 简要介绍及简单实例代码

ECharts 是一个使用 JavaScript 实现的开源可视化库,涵盖各行业图表,满足各种需求。 ECharts 提供了丰富的图表类型和交互能力,使用户能够通过简单的配置生成各种各样的图表,包括但不限于折线图、柱状图、散点图、饼图、雷达图、…

【xv6操作系统】xv6 启动过程分析

一、调试用到的汇编代码 为了方便, Makefile 会创建.asm 文件,可以通过它来定位究竟是哪个指令导致了 bug。 可以看到, kernel 从 80000000 地址处开始执行,第二列为相应指令(如 auipc) 的 16 进制表示&a…

C++ 打印输出十六进制数 指定占位符前面填充0

C 打印十六进制数据&#xff0c;指定数据长度&#xff0c;前面不够时&#xff0c;补充0. 代码如下&#xff1a; #include <iostream> #include <iomanip> #include <cmath>using namespace std;int main() {unsigned int id 0xc01;unsigned int testCaseId…

解决虚拟机静态网址设置后还是变动的的问题

源头就是我的虚拟机静态网址设置好了以后但是网址还是会变动 这是我虚拟机的配置 这是出现的问题 然后我去把多余的ens33的文件都删了 然后还不行 后来按照这个图片进行了下 然后接解决了

string 底层模拟实现常用接口

目录 前言 什么是string? 为什么要学习使用string&#xff1f;string的优势&#xff1f; 因此&#xff0c;string类的成员变量也如图顺序表一样&#xff0c;如下图所示&#xff1a; 构造函数 拷贝构造 析构函数 size() 、capacity&#xff08;&#xff09; operato…

C语言数据结构之二叉堆

愿你千山暮雪 海棠依旧 不为岁月惊扰平添忧愁 &#x1f3a5;前期回顾-二叉树 &#x1f525;数据结构专栏 期待小伙伴们的支持与关注&#xff01;&#xff01;&#xff01; 目录 前期回顾 二叉堆的概念及结构 二叉堆的创建 顺序表的结构声明 顺序表的创建与销毁 二叉堆的插入 …

selenium常用操作汇总

本文总结使用selenium进行web/UI自动化时&#xff0c;会用到的一些常用操作。 定位元素 driver.find_element_by_xpath()#1、绝对路径 2、元素属性 3、层级和属性结合 4、使用逻辑运算符 driver.find_element_by_id()#根据id定位&#xff0c;HTML规定id属性在HTML文档中必须是唯…

Mysql -- 约束

注意:约束是作用于表中字段上的,可以在创建表/修改表的时候添加约束. -- ------------------------------------------------------------------- 约束演示 ---------------------------------------------- create table user(id int primary key auto_increment comment 主键…

CorelDRAW Standard2024适合业余爱好者和家庭企业的图形设计软件

CorelDRAW Standard 2024是一款功能强大的矢量图形设计软件&#xff0c;专为图形爱好者、家庭用户、微型企业和学生们设计。该软件在Windows平台上运行&#xff0c;并提供了智能对象、布局、插图和模板等功能&#xff0c;帮助用户快速创建高质量的设计作品。 CorelDRAW Standa…

seq2seq翻译实战-Pytorch复现

&#x1f368; 本文为[&#x1f517;365天深度学习训练营学习记录博客 &#x1f366; 参考文章&#xff1a;365天深度学习训练营 &#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制]\n&#x1f680; 文章来源&#xff1a;[K同学的学习圈子](https://www.yuque.com/…

ssm+vue的农业信息管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频&#xff1a; ssmvue的农业信息管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;ssm vue前后端分离项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&…

备考银行科技岗刷题笔记(持续更新版)

银行考试计算机部分复习 IEEE 802.11的帧格式 1.1 IEEE 802.11是什么&#xff1f; 802.11是国际电工电子工程学会&#xff08;IEEE&#xff09;为无线局域网络制定的标准。目前在802.11的基础上开发出了802.11a、802.11b、802.11g、802.11n、802.11ac。并且为了保证802.11更…

npm install没有创建node_modules文件夹

问题记录 live-server 使用时 报错&#xff1a;live-server : 无法将“live-server”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。 npm install 安装 但是 这时npm install没有创建node_modules文件夹&#xff0c;只生成package-lock.json文件 方法一&#xff1a; 手…

NineData与OceanBase完成产品兼容认证,共筑企业级数据库新生态

近日&#xff0c;云原生智能数据管理平台 NineData 和北京奥星贝斯科技有限公司的 OceanBase 数据库完成产品兼容互认证。经过严格的联合测试&#xff0c;双方软件完全相互兼容、功能完善、整体运行稳定且性能表现优异。 此次 NineData 与 OceanBase 完成产品兼容认证&#xf…

软考70-上午题-【面向对象技术2-UML】-UML中的图1

一、图的定义 图是一组元素的图形表示&#xff0c;大多数情况下把图画成顶点、弧的联通图。 顶点&#xff1a;代表事物&#xff1b; 弧&#xff1a;代表关系。 可以从不同的角度画图&#xff0c;UML提供了13种图&#xff1a;&#xff08;只看9种&#xff09; 类图&#xff…

学习c语言:顺序表

一、顺序表的概念和结构 1.1 线性表 线性表&#xff08; linearlist &#xff09;是n个具有相同特性的数据元素的有限序列。线性表是⼀种在实际中⼴泛使⽤的数据结构&#xff0c;常⻅的线性表&#xff1a;顺序表、链表、栈、队列、字符串... 线性表在逻辑上是线性结构&#x…

【网站项目】096实验室开放管理系统

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

15-单片机烧录FreeTOS操作系统后,程序的执行流程

任务创建 1、在系统上电后&#xff0c;第一个执行的是启动文件由汇编语言编写的复位函数 通过复位函数来初始化系统的时钟&#xff0c;然后再执行__main,初始化系统的堆和栈&#xff0c;然后跳转到main函数 2、在main函数中可以直接进行任务创建操作 因为在FreeRTOS中会自动…

c++ primer plus 第十五章笔记 友元,异常和其他

友元类&#xff1a; 两个类不存在继承和包含的关系&#xff0c;但是我想通过一个类的成员函数来修改另一个类的私有成员和保护成员的时候&#xff0c;可以使用友元类。 class A {private:int num;//私有成员//...public: //...friend class B;//声明一个友元类 }class…