机器学习深度学习——常见循环神经网络结构(RNN、LSTM、GRU)

news2024/11/27 10:27:57

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——RNN的从零开始实现与简洁实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

常见循环神经网络结构(RNN、LSTM、GRU)

  • 引言
  • RNN
  • LSTM
    • 门控记忆元
      • 输入门、输出门和遗忘门
      • 候选记忆元
      • 记忆元
      • 隐状态
    • LSTM的简洁实现
  • GRU
    • 结构详解
    • GRU的简洁实现
  • 常用应用方式

引言

之前已经实现讲解并实现过了RNN模型,而LSTM可以弥补RNN的一些缺点,GRU是LSTM的简化版本,这里我们就回顾一下RNN模型,接着循序渐进讲解LSTM和GRU。
CNN和全连接网络的数据表示能力已经很强了,但是我们为啥还需要循环神经网络呢?这是因为现实的问题更复杂,很多数据的输入顺序对于结果都是有很大影响的。如文本数据(尤其是字母和文字的组合),先后顺序具有非常重要的意义,如果打乱,就会无法正确表示原始信息。而相比其他网络,循环神经网络因为具有记忆能力,所以更有效。

RNN

RNN循环神经网络使用torch.nn.RNN()来构建,如下图所示:
在这里插入图片描述
针对t时刻的隐状态,可以由下面公式计算:
h t = φ ( W i h x t + b i h + W h h h t − 1 + b h h ) = φ ( W i h x t + W h h h t − 1 + b h ) 其中: h t 是 t 时刻的隐藏状态; h t − 1 是 t − 1 时刻的隐藏状态 W i h 是输入到隐藏层的权重; W h h 是隐藏层到隐藏层的权重; b i h 是输入到隐藏层的偏置; b h h 是隐藏层到隐藏层的偏置; h_t=φ(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})\\ =φ(W_{ih}x_t+W_{hh}h_{t-1}+b_{h})\\ 其中:h_t是t时刻的隐藏状态;h_{t-1}是t-1时刻的隐藏状态\\ W_{ih}是输入到隐藏层的权重;W_{hh}是隐藏层到隐藏层的权重;\\ b_{ih}是输入到隐藏层的偏置;b_{hh}是隐藏层到隐藏层的偏置; ht=φ(Wihxt+bih+Whhht1+bhh)=φ(Wihxt+Whhht1+bh)其中:htt时刻的隐藏状态;ht1t1时刻的隐藏状态Wih是输入到隐藏层的权重;Whh是隐藏层到隐藏层的权重;bih是输入到隐藏层的偏置;bhh是隐藏层到隐藏层的偏置;
激活函数可以使用ReLU或tanh。
虽然在对序列数据进行建模时,RNN有一定记忆能力,但单纯的RNN会随着递归次数的增加,出现权重指数级爆炸或消失的问题,从而难以捕捉长时间关联,并导师训练时收敛困难。

LSTM

LSTM称为长短期记忆网络,是一种特殊的RNN,主要用于解决长序列训练过程中的梯度消失和爆炸问题,能在长序列中获得更好的分析效果。

门控记忆元

记忆元的目的是为了记录附加的信息,要控制记忆元,我们需要下面的几个门:
1、输出门:用来从单元中输出条目
2、输入门:决定何时将数据读入单元
3、遗忘门:重置单元的内容
接下来来看看如何工作的:

输入门、输出门和遗忘门

当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如下图:
在这里插入图片描述
上图的σ是代表由sigmoid激活函数的全连接层处理,因此三个门的值都在(0,1)范围内,显然计算方法如下:
I t = σ ( X t W x i + H t − 1 W h i + b i ) O t = σ ( X t W x o + H t − 1 W h o + b o ) F t = σ ( X t W x f + H t − 1 W h f + b f ) I_t=\sigma(X_tW_{xi}+H_{t-1}W_{hi}+b_i)\\ O_t=\sigma(X_tW_{xo}+H_{t-1}W_{ho}+b_o)\\ F_t=\sigma(X_tW_{xf}+H_{t-1}W_{hf}+b_f) It=σ(XtWxi+Ht1Whi+bi)Ot=σ(XtWxo+Ht1Who+bo)Ft=σ(XtWxf+Ht1Whf+bf)

候选记忆元

其计算与上面类似,但是使用tanh来作为激活函数,函数范围为(-1,1),计算方式为:
G t = t a n h ( X t W x g + H t − 1 W h g + b g ) G_t=tanh(X_tW_{xg}+H_{t-1}W_{hg}+b_g) Gt=tanh(XtWxg+Ht1Whg+bg)
如图所示:
在这里插入图片描述

记忆元

在LSTM中,有两个门用于实现一种输入和遗忘的机制:输入门控制采用多少来自候选记忆元的新数据,而遗忘门控制保留多少过去的记忆元的内容。使用按元素乘法,得出:
C t = F t ⨀ C t − 1 + I t ⨀ G t C_t=F_t \bigodot C_{t-1}+I_t \bigodot G_t Ct=FtCt1+ItGt
若遗忘门始终为1且输入门始终为0,则过去的记忆元 将随时间被保存并传递到当前时间步。
引入这种设计是为了缓解梯度消失问题, 并更好地捕获序列中的长距离依赖关系。
如下图所示:
在这里插入图片描述

隐状态

最后是计算隐状态,这里就是输出门的作用了。LSTM中,它是记忆元的tanh的门控版本,确保了隐状态的值在(-1,1)之间:
H t = O t ⨀ t a n h ( C t ) H_t=O_t \bigodot tanh(C_t) Ht=Ottanh(Ct)
只要输出门接近1,就能有效将所有记忆换递给预测部分,对于输出门接近0,我们只保留记忆元内的所有信息,而不需要更新隐状态。
那么整体的LSTM图示如下所示:
在这里插入图片描述

LSTM的简洁实现

使用高级API,我们可以直接实例化LSTM模型。这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节:

from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

运行结果:

perplexity 1.1, 48684.5 tokens/sec on cpu
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

运行图片:
在这里插入图片描述

GRU

结构详解

LSTM对很多需要“长期记忆”的任务来说效果显著。但是门控状态太多,导致需要训练更多的参数,使得训练难度加大。因此提出循环门控单元GRU,GRU通过将遗忘门和输入门组合在一起,减少了门的数量,并做了其他改变,在保证记忆能力同时,提升网络训练效率。其组成如下所示:
在这里插入图片描述
而每个GRU单元针对输入进行下面函数的计算:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) 候选隐状态 H t ′ = t a n h ( X t W x h + ( R t ⨀ H t − 1 ) W h h + b h ) 其中 R t ⨀ H t − 1 可以减少以往遗忘状态的影响: 每当 R t 接近 1 时,我们恢复一个传统 R N N 网络; R t 接近 0 时,候选隐状态是以 X t 作为输入的多层感知机的结果 H t = Z t ⨀ H t − 1 + ( 1 − Z t ) ⨀ H t ′ Z t 接近 1 时,模型倾向于保留旧状态; Z t 接近 0 时,倾向于候选隐状态 R_t=\sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\\ Z_t=\sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)\\ 候选隐状态H_t^{'}=tanh(X_tW_{xh}+(R_t \bigodot H_{t-1})W_{hh}+b_h)\\ 其中R_t \bigodot H_{t-1}可以减少以往遗忘状态的影响:\\ 每当R_t接近1时,我们恢复一个传统RNN网络;\\ R_t接近0时,候选隐状态是以X_t作为输入的多层感知机的结果\\ H_t=Z_t \bigodot H_{t-1}+(1-Z_t) \bigodot H_t^{'}\\ Z_t接近1时,模型倾向于保留旧状态;Z_t接近0时,倾向于候选隐状态 Rt=σ(XtWxr+Ht1Whr+br)Zt=σ(XtWxz+Ht1Whz+bz)候选隐状态Ht=tanh(XtWxh+(RtHt1)Whh+bh)其中RtHt1可以减少以往遗忘状态的影响:每当Rt接近1时,我们恢复一个传统RNN网络;Rt接近0时,候选隐状态是以Xt作为输入的多层感知机的结果Ht=ZtHt1+(1Zt)HtZt接近1时,模型倾向于保留旧状态;Zt接近0时,倾向于候选隐状态
总之,GRU有以下显著特征:
1、重置门有助于捕获序列中的短期依赖关系
2、更新门有助于捕获序列中的长期依赖关系

GRU的简洁实现

from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

运行结果:

perplexity 1.0, 12581.5 tokens/sec on cpu
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

运行图片:
在这里插入图片描述

常用应用方式

循环神经网络中的不同的输入输出对应情况都有不同的应用方式。其中,一对多的网络结构可以用于图像描述(根据输入的一张图像,自动使用文字描述图像内容);多对一的网络结构可用于文本分类;多对多的网络结构可用于语言翻译。
比如,我们可以用RNN来做手写体分类,可以用LSTM来做中文新闻分类,可以用GRU来进行情感分类等等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/858649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SD-MTSP:杨氏双缝实验优化算法YDSE求解单仓库多旅行商问题MATLAB(可更改数据集,旅行商的数量和起点)

一、杨氏双缝实验优化算法YDSE 杨氏双缝实验优化算法(Young’s double-slit experiment optimizer,YDSE)由Mohamed Abdel-Basset等人于2023年提出。 参考文献: [1]Mohamed Abdel-Basset, Doaa El-Shahat, Mohammed Jameel, Moha…

阿里云轻量应用服务器_2核2G3M_108元/年_性能测评

阿里云轻量应用服务器2核2G3M带宽108元一年,系统盘为50GB高效云盘;轻量服务器2核4G4M带宽,60GB高效云盘297.98元12个月。目前轻量应用服务器只有2核2G和2核4G有活动,阿里云百科分享阿里云轻量应用服务器入口: 目录 阿…

【Mybatis】调试查看执行的 SQL 语句

1. 问题场景: 记录日常开发过程中 Mybatis 调试 SQL 语句,想要查看Mybatis 中执行的 SQL语句,导致定位问题困难 2. 解决方式 双击shift找到mybatis源码中的 MappedStatement的getBoundSql()方法 public BoundSql getBoundSql(Object para…

贝锐蒲公英:快速搭建连锁门店监控体系,赋能企业高效管理

随着国民生活水平的提高和零售场景的变革,消费者对于餐饮类目的消费支出不断增加,线下社区生鲜商超作为下沉市场最主要的消费场景之一,蕴藏着巨大价值机会。 对于线下连锁生鲜超市而言,连锁门店多、员工多,门店管理时会…

大学生课设实训|基于springboot的在线拍卖系统

目录 项目描述 主要技术栈 功能效果 数据库设计 开发顺序 业务功能 大家好!我是龍弟-idea!需要源码资料信息可私聊我【HWL__666666】! 项目描述 本系统是一个网上商品竞拍系统,为拍卖者和竞买者提供一个在线交流平台。本项…

【一口气 Ping 1000 个 IP 地址,会发生什么事情?】

ping命令是我们检查网络中最常用的命令,作为网络人员,基本上每天都会用到,可以很好地帮助我们分析和判定网络故障,对吧? 一般来说,网工们用 ping查看网络情况,主要是检查两个指标: …

css3 实现文字横幅无缝滚动

css3 实现文字横幅无缝滚动 使用 css3 关键帧 keyframes 和 animation 属性实现文字横幅无缝滚动。 <template><div class"skiHallBanner"><div class"skiHallBanner-text"><span>{{ text }}</span></div></div>…

嵌入式开发:高薪与广阔前景

嵌入式开发是高薪且前景广阔的领域。随着物联网和智能化的快速发展&#xff0c;嵌入式开发人才需求不断增加&#xff0c;市场供应相对不足&#xff0c;导致竞争激烈&#xff0c;推动了薪资水平的提升。 嵌入式开发的复杂性和技术要求使得企业为了吸引优秀人才&#xff0c;普遍…

并发——ThreadPoolExecutor 使用示例

文章目录 1 示例代码:RunnableThreadPoolExecutor2 线程池原理分析3 几个常见的对比3.1 Runnable vs Callable3.2 execute() vs submit()3.3 shutdown()VSshutdownNow()3.2 isTerminated() VS isShutdown() 4 加餐:CallableThreadPoolExecutor示例代码 我们上面讲解了 Executor…

数据结构——时间复杂度和空间复杂度

1.算法效率 2.时间复杂度 3.空间复杂度 4. 常见时间复杂度以及复杂度oj练习 1.算法效率 1.1 如何衡量一个算法的好坏 如何衡量一个算法的好坏呢&#xff1f;比如对于以下斐波那契数的计算 long long Fib(int N) { if(N < 3) return 1; return Fib(N-1) Fib(N-2); }我们看到…

如何提高商城系统的稳定性?

电商行业的飞速发展&#xff0c;越来越多的企业开始关注电商建设。其中&#xff0c;商城系统的稳定性是企业最为关心的问题之一。 商城系统的稳定性不仅影响用户体验&#xff0c;还关系到企业的声誉和利益。因此&#xff0c;如何提高商城系统的稳定性是每一个电商企业必须要面对…

高忆管理:股票集合竞价?

股票集合竞价&#xff08;英文缩写为“SPAC”&#xff09;是股票商场开市前最终一个阶段&#xff0c;也被称为“开盘竞价”。在这个阶段&#xff0c;买卖双方能够提交订单&#xff0c;而且体系将会平衡对买卖盘进行撮合&#xff0c;以确认股票开盘价。这个阶段通常会在上午九点…

AnyCase4.0全球贸易集成平台震撼上线,免费试用赢取精美好礼!

全球贸易行业一直以来都面临着各种挑战和复杂的操作流程。然而&#xff0c;随着科技的不断进步和跨境贸易的日益发展&#xff0c;一个集物流服务、外贸服务、供应商管理和企业风控管理于一体的全新跨境贸易集成平台AnyCase4.0应运而生。经过多年的沉淀和精心打磨&#xff0c;An…

树结构转换

思路&#xff1a; 先把数组转化成一个对象&#xff08;map&#xff09;&#xff0c;对象的key值是对象的id 遍历对象&#xff1b;map[map[k].pid].children.push(map[k]),【k代表索引】&#xff0c;pid等于0代表是根节点 // 数结构转换let arr [{id: 1,pid: 0,name: "b…

【606. 根据二叉树创建字符串】

目录 1.题目描述2.算法思想3.代码实现 1.题目描述 这道题的重点其实就是要省去不影响映射的括号。如&#xff1a; 2.算法思想 3.代码实现 class Solution { public:string _tree2str(TreeNode* root,string& ret){if(rootnullptr){return "";}retto_string(ro…

考研不是在职提升的唯一的途径,还有免联考的社科院与杜兰大学金融管理硕士

社会的迅猛发展&#xff0c;职场竞争愈发激烈&#xff0c;许多人选择考研来提升自己的竞争力&#xff0c;因此考研人数也是逐年增加。根据研招网官方统计&#xff0c;2023年研究生报考人数为474万&#xff0c;相较去年22考研全国硕士研究生招生考试报名人数457万人&#xff0c;…

网盘直链下载助手

一、插件介绍 1.介绍 这是一款免费开源获取网盘文件真实下载地址的油猴脚本&#xff0c;基于 PCSAPI&#xff0c;支持 Windows&#xff0c;Mac&#xff0c;Linux 等多平台&#xff0c;支持 IDM&#xff0c;XDown&#xff0c;Aria2 等多线程下载工具&#xff0c;支持 JSON-RPC…

【网络模块】数传DTU(USR-DR150)进行MQTT通讯

文章目录 [TOC] 准备资料软件硬件硬件接线 USR-CAT1 V1.1.4配置 USR-DR15X 是一款有人物联网推出的“口红DTU”&#xff0c;也称为超小体积导轨式DTU&#xff0c;该产品具有体积小巧、集成SIM卡、蓝牙配置、导轨和挂耳安装方便的特征&#xff1b;Cat-1系列产品具备高速率、低延…

Windows使用docker desktop 安装kafka、zookeeper集群

docker-compose安装zookeeper集群 参考文章&#xff1a;http://t.csdn.cn/TtTYI https://blog.csdn.net/u010416101/article/details/122803105?spm1001.2014.3001.5501 准备工作&#xff1a; ​ 在开始新建集群之前&#xff0c;新建好文件夹&#xff0c;用来挂载kafka、z…

vue消息订阅与发布,实现任意组件间通讯

第一步&#xff1a;下载第三方消息订阅与发布库&#xff0c;例如常用的pubsub.js,他可以在任何框架中使用包括vue、react、anglar等等。 命令&#xff1a;npm i pubsub-js 注意是pubsub-js(不是点); 第二步&#xff1a;引入库&#xff1b; import pubsub from pubsub-js 第…