【深度学习】LSTM模型,GRU模型计算公式及其优缺点介绍

news2024/9/21 11:11:18

一.LSTM介绍

LSTM(Long Short-Term Memory)也称长短时记忆结构, 它是传统RNN的变体, 与经典RNN相比能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象. 同时LSTM的结构更复杂, 它的核心结构可以分为四个部分去解析:

  • 遗忘门
  • 输入门
  • 细胞状态
  • 输出门

 1LSTM的内部结构图 

1.1 LSTM结构分析

  • 结构解释图:

 

  • 遗忘门部分结构图与计算公式:

 

  • 遗忘门结构分析:

    • 与传统RNN的内部结构计算非常相似, 首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接, 得到[x(t), h(t-1)], 然后通过一个全连接层做变换, 最后通过sigmoid函数进行激活得到f(t), 我们可以将f(t)看作是门值, 好比一扇门开合的大小程度, 门值都将作用在通过该扇门的张量, 遗忘门门值将作用的上一层的细胞状态上, 代表遗忘过去的多少信息, 又因为遗忘门门值是由x(t), h(t-1)计算得来的, 因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息.
  • 遗忘门内部结构过程演示:

  • 激活函数sigmiod的作用:
    • 用于帮助调节流经网络的值, sigmoid函数将值压缩在0和1之间.
  • 输入门部分结构图与计算公式:

 

  • 输入门结构分析:

    • 我们看到输入门的计算公式有两个, 第一个就是产生输入门门值的公式, 它和遗忘门公式几乎相同, 区别只是在于它们之后要作用的目标上. 这个公式意味着输入信息有多少需要进行过滤. 输入门的第二个公式是与传统RNN的内部结构计算相同. 对于LSTM来讲, 它得到的是当前的细胞状态, 而不是像经典RNN一样得到的是隐含状态.
  • 输入门内部结构过程演示:

 

  • 细胞状态更新图与计算公式:

 

  • 细胞状态更新分析:

    • 细胞更新的结构与计算公式非常容易理解, 这里没有全连接层, 只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘, 再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果. 最终得到更新后的C(t)作为下一个时间步输入的一部分. 整个细胞状态更新过程就是对遗忘门和输入门的应用.
  • 细胞状态更新过程演示:

 

  • 输出门部分结构图与计算公式:

 

  • 输出门结构分析:

    • 输出门部分的公式也是两个, 第一个即是计算输出门的门值, 它和遗忘门,输入门计算方式相同. 第二个即是使用这个门值产生隐含状态h(t), 他将作用在更新后的细胞状态C(t)上, 并做tanh激活, 最终得到h(t)作为下一时间步输入的一部分. 整个输出门的过程, 就是为了产生隐含状态h(t).
  • 输出门内部结构过程演示:

 

1.2 Bi-LSTM介绍

Bi-LSTM即双向LSTM, 它没有改变LSTM本身任何的内部结构, 只是将LSTM应用两次且方向不同, 再将两次得到的LSTM结果进行拼接作为最终输出.


  • Bi-LSTM结构分析:
    • 我们看到图中对"我爱中国"这句话或者叫这个输入序列, 进行了从左到右和从右到左两次LSTM处理, 将得到的结果张量进行了拼接作为最终输出. 这种结构能够捕捉语言语法中一些特定的前置或后置特征, 增强语义关联,但是模型参数和计算复杂度也随之增加了一倍, 一般需要对语料和计算资源进行评估后决定是否使用该结构.

1.3 使用Pytorch构建LSTM模型

  • 位置: 在torch.nn工具包之中, 通过torch.nn.LSTM可调用.

  • nn.LSTM类初始化主要参数解释:

    • input_size: 输入张量x中特征维度的大小.
    • hidden_size: 隐层张量h中特征维度的大小.
    • num_layers: 隐含层的数量.
    • bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.
  • nn.LSTM类实例化对象主要参数解释:

    • input: 输入张量x.
    • h0: 初始化的隐层张量h.
    • c0: 初始化的细胞状态张量c.
  • nn.LSTM使用示例:

    # 定义LSTM的参数含义: (input_size, hidden_size, num_layers)
    # 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
    # 定义隐藏层初始张量和细胞初始状态张量的参数含义:
    # (num_layers * num_directions, batch_size, hidden_size)
    
    >>> import torch.nn as nn
    >>> import torch
    >>> rnn = nn.LSTM(5, 6, 2)
    >>> input = torch.randn(1, 3, 5)
    >>> h0 = torch.randn(2, 3, 6)
    >>> c0 = torch.randn(2, 3, 6)
    >>> output, (hn, cn) = rnn(input, (h0, c0))
    >>> output
    tensor([[[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
             [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
             [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
           grad_fn=<StackBackward>)
    >>> hn
    tensor([[[ 0.4647, -0.2364,  0.0645, -0.3996, -0.0500, -0.0152],
             [ 0.3852,  0.0704,  0.2103, -0.2524,  0.0243,  0.0477],
             [ 0.2571,  0.0608,  0.2322,  0.1815, -0.0513, -0.0291]],
    
            [[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
             [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
             [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
           grad_fn=<StackBackward>)
    >>> cn
    tensor([[[ 0.8083, -0.5500,  0.1009, -0.5806, -0.0668, -0.1161],
             [ 0.7438,  0.0957,  0.5509, -0.7725,  0.0824,  0.0626],
             [ 0.3131,  0.0920,  0.8359,  0.9187, -0.4826, -0.0717]],
    
            [[ 0.1240, -0.0526,  0.3035,  0.1099,  0.5915,  0.0828],
             [ 0.0203,  0.8367,  0.9832, -0.4454,  0.3917, -0.1983],
             [-0.2976,  0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],
           grad_fn=<StackBackward>)

    1.4 LSTM优缺点

  • LSTM优势:

    LSTM的门结构能够有效减缓长序列问题中可能出现的梯度消失或爆炸, 虽然并不能杜绝这种现象, 但在更长的序列问题上表现优于传统RNN.

  • LSTM缺点:

    由于内部结构相对较复杂, 因此训练效率在同等算力下较传统RNN低很多.

 二. GRU介绍

GRU(Gated Recurrent Unit)也称门控循环单元结构, 它也是传统RNN的变体, 同LSTM一样能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象. 同时它的结构和计算要比LSTM更简单, 它的核心结构可以分为两个部分去解析:

  • 更新门
  • 重置门

 1GRU的内部结构图

 1.1 GRU结构分析

  • 结构解释图:

 

  • GRU的更新门和重置门结构图:
  • 内部结构分析:

    • 和之前分析过的LSTM中的门控一样, 首先计算更新门和重置门的门值, 分别是z(t)和r(t), 计算方法就是使用X(t)与h(t-1)拼接进行线性变换, 再经过sigmoid激活. 之后重置门门值作用在了h(t-1)上, 代表控制上一时间步传来的信息有多少可以被利用. 接着就是使用这个重置后的h(t-1)进行基本的RNN计算, 即与x(t)拼接进行线性变化, 经过tanh激活, 得到新的h(t). 最后更新门的门值会作用在新的h(t),而1-门值会作用在h(t-1)上, 随后将两者的结果相加, 得到最终的隐含状态输出h(t), 这个过程意味着更新门有能力保留之前的结果, 当门值趋于1时, 输出就是新的h(t), 而当门值趋于0时, 输出就是上一时间步的h(t-1).
1.2 Bi-GRU介绍 

Bi-GRU与Bi-LSTM的逻辑相同, 都是不改变其内部结构, 而是将模型应用两次且方向不同, 再将两次得到的LSTM结果进行拼接作为最终输出. 具体参见上小节中的Bi-LSTM.

1.3 使用Pytorch构建GRU模型
  • 位置: 在torch.nn工具包之中, 通过torch.nn.GRU可调用.

  • nn.GRU类初始化主要参数解释:

    • input_size: 输入张量x中特征维度的大小.
    • hidden_size: 隐层张量h中特征维度的大小.
    • num_layers: 隐含层的数量.
      • bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.
  • nn.GRU类实例化对象主要参数解释:

    • input: 输入张量x.
      • h0: 初始化的隐层张量h.
  • nn.GRU使用示例:

    >>> import torch
    >>> import torch.nn as nn
    >>> rnn = nn.GRU(5, 6, 2)
    >>> input = torch.randn(1, 3, 5)
    >>> h0 = torch.randn(2, 3, 6)
    >>> output, hn = rnn(input, h0)
    >>> output
    tensor([[[-0.2097, -2.2225,  0.6204, -0.1745, -0.1749, -0.0460],
             [-0.3820,  0.0465, -0.4798,  0.6837, -0.7894,  0.5173],
             [-0.0184, -0.2758,  1.2482,  0.5514, -0.9165, -0.6667]]],
           grad_fn=<StackBackward>)
    >>> hn
    tensor([[[ 0.6578, -0.4226, -0.2129, -0.3785,  0.5070,  0.4338],
             [-0.5072,  0.5948,  0.8083,  0.4618,  0.1629, -0.1591],
             [ 0.2430, -0.4981,  0.3846, -0.4252,  0.7191,  0.5420]],
    
            [[-0.2097, -2.2225,  0.6204, -0.1745, -0.1749, -0.0460],
             [-0.3820,  0.0465, -0.4798,  0.6837, -0.7894,  0.5173],
             [-0.0184, -0.2758,  1.2482,  0.5514, -0.9165, -0.6667]]],
           grad_fn=<StackBackward>)
    1.4 GRU优缺点
  • GRU的优势:

    • GRU和LSTM作用相同, 在捕捉长序列语义关联时, 能有效抑制梯度消失或爆炸, 效果都优于传统RNN且计算复杂度相比LSTM要小.
  • GRU的缺点:

    • GRU仍然不能完全解决梯度消失问题, 同时其作用RNN的变体, 有着RNN结构本身的一大弊端, 即不可并行计算, 这在数据量和模型体量逐步增大的未来, 是RNN发展的关键瓶颈.

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2117959.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SpringBoot的智能制造云平台系统的设计与实现计算机毕设

一、选题背景与意义&#xff08;300字左右&#xff09; 根据工业4.0智能制造生态链中云工厂在实际生产当中的工作流程进行充分调研和整理出来的&#xff0c;描述最终用户在本系统中对于生产订单的处理、排产、以及生产的完整在线处理流程和业务需求的文档。 针对制造业而言&a…

WebGL系列教程三(使用缓冲区绘制三角形)

目录 1 前言2 缓冲区介绍3 声明顶点的位置和颜色4 回忆Shader的初始化5 开始缓冲区的逻辑5.1 声明顶点坐标5.2 创建并绑定缓冲区5.3 获取顶点着色器中的变量5.4 使变量从缓冲区取值5.5 绘制5.6 完整代码 7 总结 1 前言 上一篇中我们介绍了WebGL的环境搭建及Shader的初始化&…

搭建Docker私有仓库管理本地的Docker镜像,通过harbor实现Web UI访问和管理私有仓库

要在本地搭建一个Docker私有仓库&#xff0c;你可以按照以下步骤进行设置&#xff1a; 安装Docker 确保你已经安装了Docker。如果还没有安装&#xff0c;可以按照官方指南进行安装&#xff1a; 对于Ubuntu系统&#xff0c;你可以运行以下命令来安装Docker&#xff1a; sudo ap…

区块链-P2P(八)

前言 P2P网络&#xff08;Peer-to-Peer Network&#xff09;是一种点对点的网络结构&#xff0c;它没有中心化的服务器或者管理者&#xff0c;所有节点都是平等的。在P2P网络中&#xff0c;每个节点都可以既是客户端也是服务端&#xff0c;这种网络结构的优点是去中心化、可扩展…

【JAVA入门】Day36 - 异常

【JAVA入门】Day36 - 异常 文章目录 【JAVA入门】Day36 - 异常一、异常结构体系综述1.1 错误&#xff08;Error&#xff09;1.2 异常&#xff08;Exception&#xff09;1.3 运行时异常&#xff08;RuntimeException&#xff09;1.4 其他异常 二、编译时异常和运行时异常三、异常…

WebDriver与Chrome DevTools Protocol:如何在浏览器自动化中提升效率

介绍 随着互联网数据的爆炸式增长&#xff0c;爬虫技术成为了获取信息的重要工具。在实际应用中&#xff0c;如何提升浏览器自动化的效率是开发者常常面临的挑战。Chrome DevTools Protocol&#xff08;CDP&#xff09;与Selenium WebDriver相结合&#xff0c;为浏览器自动化提…

vue中ES6的属性every使用@2@

every用于判断数组中的每一项是否均符合条件&#xff0c;并返回一个布尔值&#xff0c;都符合返回true&#xff0c;有一个不符合就返回false&#xff0c;并不再继续执行 //everyvar arr2 [1, 2, 3, 4, 5] let newArr2 arr2.every((num) > {return num < 3}) consol…

安卓13禁止声音调节对话框 删除音量调节对话框弹出 屏蔽音量对话框 android13

总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析3.1 方法13.2 方法24.代码修改4.1 代码修改方法14.2 代码修改方法25.编译6.彩蛋1.前言 客户需要,调整声音,不显示声音调节对话框了。我们在系统里面隐藏这个对话框。 2.问题分析 android在调整声音的…

Chainlit集成Mem0使用一个拥有个性化AI记忆的网页聊天应用

前言 Mem0 简介&#xff0c;可以看我上一篇文章《解决LLM的永久记忆的解决方案-Mem0实现个性化AI永久记忆功能》。本篇文章是对Mem0 实战使用的一个示例。通过Chainlit 快速实现ui界面和open ai的接入&#xff0c;通过使用Mem0 实现对聊天者的对话记录的记忆。 设计实现基本原…

网络空间信息安全实验

实验1 基础实验&#xff08;加密与隐藏&#xff09; 一、实验目的 提高对加密与解密原理的认识&#xff1b;提高对信息隐藏原理的认识&#xff1b;学会使用加密与隐藏软件。 二、实验环境 Pentiuum III、600 MHz以上CPU , 128M 以上内存&#xff0c;10G 以上硬盘&#xff0…

Hoverfly api/v2/simulation 任意文件读取漏洞复现(CVE-2024-45388)

0x01 产品简介 Hoverfly是一个为开发人员和测试人员提供的轻量级服务虚拟化/API模拟/API模拟工具。 0x02 漏洞概述 Hoverfly api/v2/simulation 接口存在任意文件读取漏洞,未经身份验证攻击者可通过该漏洞读取系统重要文件(如数据库配置文件、系统配置文件)、数据库配置文…

CSS学习13--学成网例子

CSS例子 学成网 需要使用的图片&#xff1a; 代码&#xff1a; <html><head><style>/*CSS初始化*/* { /*清除内外边框*/padding: 0;margin: 0;}ul {list-style: none; /*清除列表样式*/}.clearfix:before,.clearfix:after { /*清除浮动*/content: &qu…

今日早报 每日精选15条新闻简报 每天一分钟 知晓天下事 9月9日,星期一

每天一分钟&#xff0c;知晓天下事&#xff01; 2024年9月9日 星期一 农历八月初七 1、 三部门&#xff1a;拟允许在北京、天津、上海、广州、深圳、南京等地设立外商独资医院。 2、 巴黎残奥会结束&#xff1a;中国体育代表团获得94金76银50铜&#xff0c;连续六届残奥会位列…

C语言第二周课

目录 引言: 一、数据类型大小及分类 (1)计算机中常用存储单位 (2)整体介绍一下C语言的数据类型分类。 (3)下面是我们本节课要学的基本内容----常用的数据类型 二、 数据类型的数值范围 三、打印输出类型 数据类型打印示例: 引言: 我们常常在写C语言程序时&#xff0c;总…

用AI操作电脑:使用PyQt和Python结合AI生成代码的实用指南

​ 在github上有这样一个名字叫做open-interpreter的项目&#xff0c;收获了52k个Star。该项目通过自然语言来控制电脑&#xff0c;极大简化了使用电脑的难度&#xff0c;提高了工作效率。受该项目启发&#xff0c;我们可以做一个中文版桌面AI助手。 分步思考&#xff1a; 自…

算法工程师转行大模型:时代的选择or个人的选择

算法工程师的黄金时代&#xff1a;大模型转行之路 随着人工智能技术的飞速发展&#xff0c;尤其是深度学习领域的大规模预训练模型&#xff08;大模型&#xff09;的兴起&#xff0c;算法工程师们正面临前所未有的机遇与挑战。本文旨在探讨算法工程师如何顺利过渡到大模型领域…

基于Matlab的图像去雾系统(四种方法)关于图像去雾的基本算法代码的集合,方法包括局部直方图均衡法、全部直方图均衡法、暗通道先验法、Retinex增强。

基于Matlab的图像去雾系统&#xff08;四种方法&#xff09; 关于图像去雾的基本算法代码的集合&#xff0c;方法包括局部直方图均衡法、全部直方图均衡法、暗通道先验法、Retinex增强。 所有代码整合到App designer编写的GUI界面中&#xff0c;包括导入图片&#xff0c;保存处…

吴恩达发布企业AI转型手册,AI Transformation Playbook,对公司高管、正在创业的AIer是不错的指南。

AI 时代&#xff0c;人人都在创业。 今天看到一篇吴恩达发布的 AI 转型指南&#xff0c;原文《 AI Transformation Playbook——How to lead your company into the AI era》。 对公司高管、正在创业的 AIer 是个不错的指南。 这本人工智能转型手册借鉴了领导谷歌大脑团队和百…

浅谈OLTP 与 OLAP 数据建模的差异

OLTP 与 OLAP&#xff1a;常见工作流 联机分析处理 (OLAP) 和联机事务处理 (OLTP) 是两种主要的数据处理系统。两者之间存在多种差异。 OLTP 系统旨在处理来自多个用户的多个事务&#xff0c;它们通常用于许多应用程序的后端。例如&#xff0c;在线商务网站将使用 OLTP 系统来…

chapter13-常用类——(Arrays)——day16

目录 481-Arrays排序源码解读 483-Arrays模拟排序 484-Arrays其他方法 485-Arrays课堂练习 481-Arrays排序源码解读 接口编程-compare 483-Arrays模拟排序 484-Arrays其他方法 二分搜索查找要求有序&#xff0c;如果数组中不存在这个元素&#xff0c;返回-&#xff08;low1&…