LSTM和GRU的介绍以及Pytorch源码解析

news2024/9/22 18:25:22

介绍一下LSTM模型的结构以及源码,用作自己复习的材料。 

LSTM模型所对应的源码在:\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。

上次上一篇文章介绍了RNN序列模型,但是RNN模型存在比较严重的梯度爆炸和梯度消失问题。

本文介绍的LSTM模型解决的RNN的大部分缺陷。

首先展示LSTM的模型框架:

下面是LSTM模型的数学推导公式:

\begin{array}{ll} \\ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ \end{array}

h_t表示t时刻的隐藏状态,c_t表示t时刻的记忆细胞状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。

i_t,f_t,g_t,o_t 分别是输入门、遗忘门、单元门和输出门。

这张图片比较好的介绍了各个门之间的交互关系以及输入输出,大家可以放大看一下。

接下来展示GRU的框架模型:

下面是GRU的数学推导公式:

r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}

h_t表示t时刻的隐藏状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。r_t,n_t,z_t分别表示重置门更新门和新建门

上面的图片可以更直观的看到GRU中是如何迭代的。

接下来我们看一下源码中LSTM和GRU类的初始化(只介绍几个重要的参数):

torch.nn.LSTM(self, input_size, hidden_size, num_layers=1,
              bias=True, batch_first=False, dropout=0.0, 
              bidirectional=False, proj_size=0, device=None,
              dtype=None)
torch.nn.GRU(self, input_size, hidden_size, num_layers=1,
             bias=True, batch_first=False, dropout=0.0, 
             bidirectional=False, device=None, dtype=None)
  • input_size:输入数据中的特征数(可以理解为嵌入维度 embedding_dim)。
  • hidden_size:处于隐藏状态 h 的特征数(可以理解为输出的特征维度)。
  • num_layers:代表着RNN的层数,默认是1(层),当该参数大于零时,又称为多层RNN。
  • bidirectional:即是否启用双向LSTM(GRU),默认关闭。

LSTM与GRU都是特殊的RNN,因此输入输出可以参考的上一篇介绍RNN的文章,在这里直接进行代码举例。

lstm1 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=True)
lstm2 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=False)

gru1 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=True)
gru2 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=False)

tensor1 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)
tensor2 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)

out_lstm1,(hn, cn) = lstm1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_lstm2,(hn, cn) = lstm2(tensor2)  # (batch_size * seq_len * (hidden_size * bidirectional))

out_gru1,h_n = gru1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_gru2,h_n = gru2(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))

print(out_lstm1.shape)  # torch.Size([5, 10, 80])
print(out_lstm2.shape)  # torch.Size([5, 10, 40])

print(out_gru1.shape)  # torch.Size([5, 10, 50])
print(out_gru2.shape)  # torch.Size([5, 10, 25])

维度已经在注释中给大家标注上了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1311602.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HPM6750系列--第七篇 Visual Studio Code使用openocd调试查看外设信息

一、目的 在《HPM6750系列--第四篇 搭建Visual Studio Code开发调试环境》我们已经手把手指导大家如何在visual studio code中进行开发,包括编译调试等步骤以及相关配置文件。 但是在实际调试时发现找不到芯片寄存器实时显示的窗口,本篇主要讲解如何实现…

【Linux】信号--信号初识/信号的产生方式/信号的保存

文章目录 一、信号初步理解1.生活角度的信号2.技术应用角度的信号 二、信号的产生方式1.通过终端按键产生信号2.调用系统函数向进程发信号3.硬件异常产生信号4.由软件条件产生信号5.进程退出时的核心转储问题 三、信号的保存1.信号其他相关常见概念2.信号在内核中的表示3.sigse…

《PCL多线程加速处理》-配准-icp

《PCL多线程加速处理》-配准-icp 一、效果展示二、具体实现三、代码一、效果展示 数据越大,速度提升效果越快 1、48万点 2、十万点 3、三万点 4、9000点 配准数据 二、具体实现

使用uniapp,引入uView遇到的问题(easycom无效解决方法)

在HBuilderX通过npm安装uView时,按照官网文档配置easycom无效 官网为: 4. 配置easycom组件模式 此配置需要在项目根目录的pages.json中进行。 温馨提示 uni-app为了调试性能的原因,修改easycom规则不会实时生效,配置完后&…

【前端】vscode 相关插件

一 插件: 01、ESLint 用来识别并检查ECMAScript/JavaScript 代码的工具 02、Prettier 用来格式化代码,如.js、.vue、css等都可以进行格式化 03、Vetur 用来识别并高亮vue语法 04、EditorConfig 用来设置vscode的编程行为 二、安装依赖 01、…

SpringSecurity 手机号登录

一、工作流程 1.向手机发送验证码,第三方短信发送平台,如阿里云短信。 2.手机获取验证码后,在表单中输入验证码。 3.使用自定义过滤器​SmsCodeValidateFilter​。 4.短信校验通过后,使用自定义手机认证过滤器​SmsCodeAuthentic…

利用vue-okr-tree实现飞书OKR对齐视图

vue-okr-tree-demo 因开发需求需要做一个类似飞书OKR对齐视图的功能,参考了两位大神的代码: 开源组件vue-okr-tree作者博客地址:http://t.csdnimg.cn/5gNfd 对组件二次封装的作者博客地址:http://t.csdnimg.cn/Tjaf0 开源组件v…

oracle aq java jms使用(数据类型为XMLTYPE)

记录一次冷门技术oracle aq的使用 版本 oracle 11g 创建用户 -- 创建用户 create user testaq identified by 123456; grant connect, resource to testaq;-- 创建aq所需要的权限 grant execute on dbms_aq to testaq; grant execute on dbms_aqadm to testaq; begindbms_a…

android studio 快捷输入模板提示

在Android开发中,我们经常会遇到一些重复性的代码,例如创建一个新的Activity、定义一个Getter方法等。为了提高开发效率,Android Studio提供了Live Templates功能,可以通过简化输入来快速生成这些重复性代码。 按下图提示设置&am…

GO并发编程综合应用

一.GO并发编程综合应用 1.生产者消费者模式 1.1需求分析 ​ 生产者每秒生产一个商品,并通过物流公司取货 ​ 物流公司将商品运输到商铺 ​ 消费者阻塞等待商铺到货,需要消费10次商品 1.2实现原理 1.3代码实现: package mainimport (&q…

Google DeepMind发布Imagen 2文字到图像生成模型;微软在 HuggingFace 上发布了 Phi-2 的模型

🦉 AI新闻 🚀 Google DeepMind发布Imagen 2文字到图像生成模型 摘要:谷歌的Imagen 2是一种先进的文本到图像技术,可以生成与用户提示紧密对齐的高质量、逼真的图像。它通过使用训练数据的自然分布来生成更逼真的图像&#xff0c…

服务器上配置jupyter,提示Invalid credentials如何解决

我是按照网上教程在服务器上安装的jupyter以及进行的密码配置,我利用 passwd()这个口令生成的转译密码是"argon...."。按照教程配置jupyter notebook配置文件里面的内容,登陆网页提示"Invalid credentials"。我谷歌得到的解答是&…

07用户行为日志数据采集

用户行为数据由Flume从Kafka直接同步到HDFS,由于离线数仓采用Hive的分区表按天统计,所以目标路径要包含一层日期。具体数据流向如下图所示。 按照规划,该Flume需将Kafka中topic_log的数据发往HDFS。并且对每天产生的用户行为日志进行区分&am…

cfa一级考生复习经验分享系列(三)

从总成绩可以看出,位于90%水平之上,且置信区间全体均高于90%线。 从各科目成绩可以看出,所有科目均位于90%线上或高于90%线,其中,另类与衍生、公司金额、经济学、权益投资、固定收益、财报分析表现较好,目测…

多架构容器镜像构建实战

最近在一个国产化项目中遇到了这样一个场景,在同一个 Kubernetes 集群中的节点是混合架构的,也就是说,其中某些节点的 CPU 架构是 x86 的,而另一些节点是 ARM 的。为了让我们的镜像在这样的环境下运行,一种最简单的做法…

双端队列和优先级队列

文章目录 前言dequedeque底层设计迭代器设计 priority仿函数数组中的第k个最大元素优先级队列模拟实现pushpop调整仿函数存储自定义类型 前言 今天要介绍比较特殊的结构,双端队列。 还有一个适配器,优先级队列。 deque 栈的默认容器用了一个deque的东西…

案例课7——百度智能客服

1.公司介绍 百度智能客服是百度智能云推出的将AI技术赋能企业客服业务的一揽子解决方案。该方案基于百度世界先进的语音技术、自然语言理解技术、知识图谱等构建完备的一体化产品方案,结合各行业头部客户丰富的运营经验,持续深耕机场服务、电力调度等场…

【普中】基于51单片机简易计算器显示设计( proteus仿真+程序+设计报告+实物演示+讲解视频)

目录标题 📟1. 主要功能:📟2. 讲解视频:📟3. 设计说明书(报告)📟4. 仿真📟5. 实物烧录和现象📟6. 程序代码📟7. 设计资料内容清单 【普中开发板】基于51单片机简易计算器…

日志框架Log4j、JUL、JCL、Slf4j、Logback、Log4j2

为什么程序需要记录日志 我们不可能实时的24小时对系统进行人工监控,那么如果程序出现异常错误时要如何排查呢?并且系统在运行时做了哪些事情我们又从何得知呢?这个时候日志这个概念就出现了,日志的出现对系统监控和异常分析起着…

Jenkins 添加节点报错

报错日志 Error: A JNI error has occurred, please check your installation and try again Exception in thread "main" java.lang.UnsupportedClassVersionError: hudson/remoting/Launcher has been compiled by a more recent version of the Java Runtime (cl…