【深度学习 | LSTM】解开LSTM的秘密:门控机制如何控制信息流

news2024/11/25 5:40:12

在这里插入图片描述

🤵‍♂️ 个人主页: @AI_magician
📡主页地址: 作者简介:CSDN内容合伙人,全栈领域优质创作者。
👨‍💻景愿:旨在于能和更多的热爱计算机的伙伴一起成长!!🐱‍🏍
🙋‍♂️声明:本人目前大学就读于大二,研究兴趣方向人工智能&硬件(虽然硬件还没开始玩,但一直很感兴趣!希望大佬带带)

在这里插入图片描述

【深度学习 | LSTM】解开LSTM的秘密:门控机制如何控制信息流
作者: 计算机魔术师
版本: 1.0 ( 2023.8.27 )

摘要: 本系列旨在普及那些深度学习路上必经的核心概念,文章内容都是博主用心学习收集所写,欢迎大家三联支持!本系列会一直更新,核心概念系列会一直更新!欢迎大家订阅

该文章收录专栏
[✨— 《深入解析机器学习:从原理到应用的全面指南》 —✨]

LSTM最全详解,

  • LSTM
      • 原理详解(每个神经元)
        • a. 遗忘门:Forget Gate
        • b. 输入门:Input Gate
        • c. Cell State
        • d. 输出门:Output Gate
      • 参数详解
      • 参数计算
      • 实际场景

LSTM

原理详解(每个神经元)

LSTM(Long Short-Term Memory)是一种常用于处理序列数据的循环神经网络模型。LSTM的核心思想是在传递信息的过程中,通过门的控制来选择性地遗忘或更新信息。LSTM中主要包含三种门:输入门(input gate)、输出门(output gate)和遗忘门(forget gate),以及一个记忆单元(memory cell)。

在LSTM层中,有三个门控单元,即输入门、遗忘门和输出门。这些门控单元在每个时间步上控制着LSTM单元如何处理输入和记忆。在每个时间步上,LSTM单元从输入、前一个时间步的输出和前一个时间步的记忆中计算出当前时间步的输出和记忆。

在LSTM的每个时间步中,输入 x t x_t xt和前一时刻的隐状态 h t − 1 h_{t-1} ht1被馈送给门控制器,然后门控制器根据当前的输入 x t x_t xt和前一时刻的隐状态 h t − 1 h_{t-1} ht1计算出三种门的权重,然后将这些权重作用于前一时刻的记忆单元 c t − 1 c_{t-1} ct1。具体来说,门控制器计算出三个向量:**输入门的开启程度 i t i_t it、遗忘门的开启程度 f t f_t ft和输出门的开启程度 o t o_t ot,这三个向量的元素值均在[0,1]**之间。

然后,使用这些门的权重对前一时刻的记忆单元 c t − 1 c_{t-1} ct1进行更新,计算出当前时刻的记忆单元 c t c_t ct,并将它和当前时刻的输入 x t x_t xt作为LSTM的输出 y t y_t yt。最后,将当前时刻的记忆单元 c t c_t ct和隐状态 h t h_t ht一起作为下一时刻的输入,继续进行LSTM的计算。

在这里插入图片描述

如果你对LSTM以及其与反向传播算法之间的详细联系感兴趣,我建议你参考以下资源:

  1. “Understanding LSTM Networks” by Christopher Olah: https://colah.github.io/posts/2015-08-Understanding-LSTMs/ 强烈推荐!!!
  2. TensorFlow官方教程:Sequence models and long-short term memory network (https://www.tensorflow.org/tutorials/text/text_classification_rnn)
  3. PyTorch官方文档:nn.LSTM (https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html)
  4. 详细讲解RNN,LSTM,GRU https://towardsdatascience.com/a-brief-introduction-to-recurrent-neural-networks-638f64a61ff4

以上资源将为你提供更多关于LSTM及其与反向传播算法结合使用的详细解释、示例代码和进一步阅读材料。

LSTM 的核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”,记忆门一个控制信号控制门是否应该保留该信息,在实现上通常是乘1或乘0来选择保留或忘记。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。

LSTM的参数包括输入到状态的权重 W x i , W h i , b i W_{xi},W_{hi},b_i Wxi,Whi,bi,输入到遗忘门的权重 W x f , W h f , b f W_{xf},W_{hf},b_f Wxf,Whf,bf,输入到输出门的权重 W x o , W h o , b o W_{xo},W_{ho},b_o Wxo,Who,bo,以及输入到记忆单元的权重 W x c , W h c , b c W_{xc},W_{hc},b_c Wxc,Whc,bc,其中 W W W表示权重矩阵, b b b表示偏置向量。在实际应用中,LSTM模型的参数通常需要通过训练来获得,以最小化预测误差或最大化目标函数。

a. 遗忘门:Forget Gate

遗忘门的功能是决定应丢弃或保留哪些信息。来自前一个隐藏状态的信息和当前输入的信息同时传递到 sigmoid 函数中去,输出值介于 0 和 1 之间,越接近 0 意味着越应该丢弃,越接近 1 意味着越应该保留。

遗忘门的计算公式

b. 输入门:Input Gate

输入门用于更新细胞状态。首先将前一层隐藏状态的信息和当前输入的信息传递到 sigmoid 函数中去。将值调整到 0~1 之间来决定要更新哪些信息。0 表示不重要,1 表示重要。其次还要将前一层隐藏状态的信息和当前输入的信息传递到 tanh 函数中去,创造一个新的侯选值向量。最后将 sigmoid 的输出值与 tanh 的输出值相乘,sigmoid 的输出值将决定 tanh 的输出值中哪些信息是重要且需要保留下来

使用tanh作为LSTM输入层的激活函数,一定程度上可以避免梯度消失和梯度爆炸的问题。

在LSTM中,如果权重值较大或者较小,那么在反向传播时,梯度值会非常大或者非常小,导致梯度爆炸或者消失的情况。而tanh函数的导数范围在[-1, 1]之间,可以抑制梯度的放大和缩小,从而避免了梯度爆炸和消失的问题(RNN遇到的问题)。此外,tanh函数在输入为0附近的时候输出接近于线性,使得网络更容易学习到线性相关的特征。另外,tanh 函数具有对称性,在处理序列数据时能够更好地捕捉序列中的长期依赖关系。

因此,使用tanh作为LSTM输入层的激活函数是比较常见的做法。

c. Cell State

首先前一层的细胞状态与遗忘向量逐点相乘。如果它乘以接近 0 的值,意味着在新的细胞状态中,这些信息是需要丢弃掉的。然后再将该值与输入门的输出值逐点相加,将神经网络发现的新信息更新到细胞状态中去。至此,就得到了更新后的细胞状态。

d. 输出门:Output Gate

输出门用来确定下一个隐藏状态的值,隐藏状态包含了先前输入的信息。首先,我们将前一个隐藏状态和当前输入传递到 sigmoid 函数中,然后将新得到的细胞状态传递给 tanh 函数。最后将 tanh 的输出与 sigmoid 的输出相乘,以确定隐藏状态应携带的信息。再将隐藏状态作为当前细胞的输出,把新的细胞状态和新的隐藏状态传递到下一个时间步长中去。

在LSTM层中,每个时间步上的计算涉及到许多参数,包括输入、遗忘和输出门的权重,以及当前时间步和前一个时间步的输出和记忆之间的权重。这些参数在模型训练过程中通过反向传播进行学习,以最小化模型在训练数据上的损失函数。总之,LSTM通过门的控制,使得信息在传递过程中可以有选择地被遗忘或更新,从而更好地捕捉长序列之间的依赖关系,广泛应用于语音识别、自然语言处理等领域。

LSTM的输出可以是它的最终状态(最后一个时间步的隐藏状态)或者是所有时间步的隐藏状态序列。通常,LSTM的最终状态可以被看作是输入序列的一种编码,可以被送入其他层进行下一步处理。如果需要使用LSTM的中间状态,可以将return_sequences参数设置为True,这样LSTM层将返回所有时间步的隐藏状态序列,而不是仅仅最终状态。

需要注意的是,LSTM层在处理长序列时容易出现梯度消失或爆炸的问题。为了解决这个问题,通常会使用一些技巧,比如截断反向传播、梯度裁剪、残差连接等

参数详解

layers.LSTM 是一个带有内部状态的循环神经网络层,其中包含了多个可训练的参数。具体地,LSTM层的输入是一个形状为(batch_size, timesteps, input_dim)的三维张量,其中batch_size表示输入数据的批次大小,timesteps表示序列数据的时间步数,input_dim表示每个时间步的输入特征数。LSTM层的输出是一个形状为**(batch_size, timesteps, units)的三维张量,其中units表示LSTM层的输出特征数**。以下是各个参数的详细说明:

  • units:LSTM 层中的单元数,即 LSTM 层输出的维度。
  • activation:激活函数,用于计算 LSTM 层的输出和激活门。
  • recurrent_activation:循环激活函数,用于计算 LSTM 层的循环状态。
  • use_bias:是否使用偏置向量。
  • kernel_initializer:用于初始化 LSTM 层的权重矩阵的初始化器。
  • recurrent_initializer:用于初始化 LSTM 层的循环权重矩阵的初始化器。
  • bias_initializer:用于初始化 LSTM 层的偏置向量的初始化器。
  • unit_forget_bias:控制 LSTM 单元的偏置初始化,如果为 True,则将遗忘门的偏置设置为 1,否则设置为 0。
  • kernel_regularizer:LSTM 层权重的正则化方法。
  • recurrent_regularizer:LSTM 层循环权重的正则化方法。
  • bias_regularizer:LSTM 层偏置的正则化方法。
  • activity_regularizer:LSTM 层输出的正则化方法。
  • dropout:LSTM 层输出上的 Dropout 比率。
  • recurrent_dropout:LSTM 层循环状态上的 Dropout 比率。
  • return_sequences: 可以控制LSTM的输出形式。如果设置为True,则输出每个时间步的LSTM的输出,如果设置为False,则只输出最后一个时间步的LSTM的输出。因此,return_sequences的默认值为False,如果需要输出每个时间步的LSTM的输出,则需要将其设置为True。

这些参数的不同设置将直接影响到 LSTM 层的输出和学习能力。需要根据具体的应用场景和数据特点进行选择和调整。

tf.keras.layers.LSTM(
units,
activation=“tanh”,
recurrent_activation=“sigmoid”, #用于重复步骤的激活功能
use_bias=True, #是否图层使用偏置向量
kernel_initializer=“glorot_uniform”, #kernel权重矩阵的 初始化程序,用于输入的线性转换
recurrent_initializer=“orthogonal”, #权重矩阵的 初始化程序,用于递归状态的线性转换
bias_initializer=“zeros”, #偏差向量的初始化程序
unit_forget_bias=True, #则在初始化时将1加到遗忘门的偏置上
kernel_regularizer=None, #正则化函数应用于kernel权重矩阵
recurrent_regularizer=None, #正则化函数应用于 权重矩阵
bias_regularizer=None, #正则化函数应用于偏差向量
activity_regularizer=None, #正则化函数应用于图层的输出(其“激活”)
kernel_constraint=None,#约束函数应用于kernel权重矩阵
recurrent_constraint=None,#约束函数应用于 权重矩阵
bias_constraint=None,#约束函数应用于偏差向量
dropout=0.0,#要进行线性转换的输入单位的分数
recurrent_dropout=0.0,#为递归状态的线性转换而下降的单位小数
return_sequences=False,#是否返回最后一个输出。在输出序列或完整序列中
return_state=False,#除输出外,是否返回最后一个状态
go_backwards=False,#如果为True,则向后处理输入序列并返回反向的序列
stateful=False,#如果为True,则批次中索引i的每个样本的最后状态将用作下一个批次中索引i的样本的初始状态。
time_major=False,
unroll=False,#如果为True,则将展开网络,否则将使用符号循环。展开可以加快RNN的速度,尽管它通常会占用更多的内存。展开仅适用于短序列。
)

参数计算

对于一个LSTM(长短期记忆)模型,参数的计算涉及输入维度、隐藏神经元数量和输出维度。在给定输入维度为(64,32)和LSTM神经元数量为32的情况下,我们可以计算出以下参数:

  1. 输入维度:(64,32)
    这表示每个时间步长(sequence step)的输入特征维度为32,序列长度为64。

  2. 隐藏神经元数量:32
    这是指LSTM层中的隐藏神经元数量。每个时间步长都有32个隐藏神经元。

  3. 输入门参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。其中32是上一时间步的隐藏状态大小,另外32是当前时间步的输入维度。
    • 偏置向量:形状为(32,)的向量。
  4. 遗忘门参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。
    • 偏置向量:形状为(32,)的向量。
  5. 输出门参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。
    • 偏置向量:形状为(32,)的向量。
  6. 单元状态参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。
    • 偏置向量:形状为(32,)的向量。
  7. 输出参数:

    • 权重矩阵:形状为(32,32)的矩阵。将隐藏状态映射到最终的输出维度。
    • 偏置向量:形状为(32,)的向量。

因此,总共的参数数量可以通过计算上述所有矩阵和向量中的元素总数来确定。

实际场景

当使用LSTM(长短期记忆)神经网络进行时间序列预测时,可以根据输入和输出的方式将其分为四种类型:单变量单步预测、单变量多步预测、多变量单步预测和多变量多步预测。

  1. 单变量单步预测:

    • 输入:只包含单个时间序列特征的历史数据。
    • 输出:预测下一个时间步的单个时间序列值。
    • 例如,给定过去几天的某股票的收盘价,使用LSTM进行单变量单步预测将预测未来一天的收盘价。
  2. 单变量多步预测:

    • 输入:只包含单个时间序列特征的历史数据。
    • 输出:预测接下来的多个时间步的单个时间序列值。
    • 例如,给定过去几天的某股票的收盘价,使用LSTM进行单变量多步预测将预测未来三天的收盘价。
  3. 多变量单步预测:

    • 输入:包含多个时间序列特征的历史数据。
    • 输出:预测下一个时间步的一个或多个时间序列值。
    • 例如,给定过去几天的某股票的收盘价、交易量和市值等特征,使用LSTM进行多变量单步预测可以预测未来一天的收盘价。
  4. 多变量多步预测:

    • 输入:包含多个时间序列特征的历史数据。
    • 输出:预测接下来的多个时间步的一个或多个时间序列值。
    • 例如,给定过去几天的某股票的收盘价、交易量和市值等特征,使用LSTM进行多变量多步预测将预测未来三天的收盘价。

这些不同类型的时间序列预测任务在输入和输出的维度上略有差异,但都可以通过适当配置LSTM模型来实现。具体的模型架构和训练方法可能会因任务类型和数据特点而有所不同。

在这里插入图片描述

						  🤞到这里,如果还有什么疑问🤞
					🎩欢迎私信博主问题哦,博主会尽自己能力为你解答疑惑的!🎩
					 	 🥳如果对你有帮助,你的赞是对博主最大的支持!!🥳

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1022629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机视觉与深度学习-全连接神经网络-训练过程-批归一化- [北邮鲁鹏]

文章目录 思想批归一化操作批归一化与梯度消失经过BN处理 算法实现 思想 直接对神经元的输出进行批归一化 批归一化:对输出值进行归一化,将归一化结果平移缩放作为输出。 批归一化操作 小批量梯度下降算法回顾:每次迭代时会读入一批数据&am…

工信部将制定虚拟宇宙标准

中国工业和信息化部(MIIT)周一表示,随着北京寻求成为新技术的全球标准制定者,中国将成立一个工作组来制定虚拟宇宙行业的标准。 周一,该部发布了一份提案草案,旨在组建一个虚拟宇宙工作组,该工作组可以通过互联网访问共…

CHATGPT中国免费网页版有哪些-CHATGPT中文版网页

CHATGPT中国免费网页版,一个强大的人工智能聊天机器人。如果你曾经感到困惑、寻求答案,或者需要一些灵感,那么CHATGPT国内网页版可能会成为你的好朋友。 CHATGPT国内免费网页版:你的多面“好朋友” 随着人工智能技术的不断发展&a…

Java学习day04:数组

声明:该专栏本人重新过一遍java知识点时候的笔记汇总,主要是每天的知识点题解,算是让自己巩固复习,也希望能给初学的朋友们一点帮助,大佬们不喜勿喷(抱拳了老铁!) Java学习day04:数组 一、开发…

C++:new 和 delete

个人主页 : 个人主页 个人专栏 : 《数据结构》 《C语言》《C》 文章目录 前言一、C内存管理1.内置类型2.自定义类型3.delete 与 new不匹配使用问题(VS平台下) 二、operator new 与 operator delete函数三、 new 和delete的实现原理内置类型自定义类型 四…

【前端知识】Three 学习日志(十)—— 常见几何体(长方体、球体、圆柱、矩形平面、圆形平面)

Three 学习日志(十)—— 常见几何体(长方体、球体、圆柱、矩形平面、圆形平面) 一、构建常用几何体 const geometry_list []// BoxGeometry:长方体 const geometry_box new THREE.BoxGeometry(100, 100, 100); geo…

CPU性能优化

在进行CPU性能优化的时候,我们经常先需要分析出来我们的应用程序中的CPU资源在哪些函数中使用的比较多,这样才能高效地优化。一个非常好的分析工具就是《性能之巅》作者 Brendan Gregg 发明的火焰图。 我们今天就来介绍下火焰图的使用方法,以…

GeoServer地图服务器权限控制

目录 1下载相关软件 2部署软件 3配置鉴权环节 4Java工程 5测试鉴权 6测试鉴权结果分析 本文章应该会后面试验一个鉴权功能就会发布一系列测试过程(GeoServer有很多鉴权方式) 1Download - GeoServer 1下载相关软件 进入geoserver官网的下载页面 …

如何为你的Python程序配置HTTP/HTTPS爬虫IP

在编写Python程序时,有时候我们需要使用HTTP或HTTPS爬虫ip来实现网络请求和访问外部资源。本文将向您介绍如何快速入门,为您的Python程序配置HTTP/HTTPS爬虫ip,以便您能够轻松地处理爬虫ip设置并顺利运行您的程序。 一、了解HTTP/HTTPS爬虫ip…

Python Opencv实践 - ORB特征匹配

参考资料: ORB特征笔记_亦枫Leonlew的博客-CSDN博客 python opencv3 基于ORB的特征检测和 BF暴力匹配 knn匹配 flann匹配 - 知乎 Python OpenCV中的drawMatches()关键点匹配绘制方法详解_cv2.drawmatches_乔卿的博客-CSDN博客 import cv2 as cv import numpy as…

mysql 多个字段 like 同一个值怎么实现

1,需求:前端一个输入框 输入的内容要和数据库中多个字段进行匹配 前端输入内容需要和 username,realname,age,bh 这四个字段匹配 方法1(可优化);select * from rzt_user where user…

爱惨了,这个听书神器APP

我喜欢听书的原因,第一个是比较省时间,而且很方便,看小说需要花费时间,看久了,眼睛又很疼。听书的话,刷牙听、走路听、开车听、睡前听等等都可以。 最近狂爱这个爱屁屁:听书神器 1、全网资源&…

nodejs项目实战(带源码)

nodejs项目实战 主要实现功能用户模块文章分类模块文章模块核心代码 数据库完整代码 主要实现功能 本项只适合新手,是一个接口类的项目,主要涉及一些增删改查功能以及三方包的使用,主要包括用node实现写用户登录注册,添加删除文章…

机器学习——聚类算法

0、前言: 机器学习聚类算法主要就是两类:K-means和DBSCAN聚类:一种无监督的学习,事先不知道类别(相当于不用给数据提前进行标注),自动将相似的对象归到同一个簇中 1、K-means: 原理…

idea项目配置三大步

场景: 使用 idea 打开一个新项目的时候,想让项目迅速跑起来, 其实只需要下面简单三步: 1. 首先,配maven 2. 其次,配置 jdk 这里配置 project 就行了,不用管Modules中的配置。 3. 最后&#…

德纳 Dana EDI 项目案例

德纳 Dana是一家总部位于美国的公司,专门从事车辆传动和密封解决方案。它设计、制造和销售各种汽车零部件,如轴、传动系统、密封件等。该公司在汽车行业中具有悠久的历史,为各种不同类型的车辆提供关键的机械和工程解决方案。 项目背景与目标…

pythonSDK安装+Visual Studio Code

安装PythonSDK 点击去下载python的SDK:https://www.python.org/ 去下载 双击 下载好的安装包 等待安装可能会很慢… 如何验证是否成功安装了python的SDK Windows电脑 打开 CMD 窗口 如何打开 CMD 窗口 键盘 按 wind R python安装编辑器 Visual Studio Code…

在华为云服务器上CentOS 7安装单机版Redis

https://redis.io/是官网地址。 点击右上角的Download。 可以进入https://redis.io/download/——Redis官网下载最新版的网址。 然后在https://redis.io/download/页面往下拉,点击下图超链接这里。 进入https://download.redis.io/releases/下载自己需要的安装…

弱监督目标检测:ALWOD: Active Learning for Weakly-Supervised Object Detection

论文作者:Yuting Wang,Velibor Ilic,Jiatong Li,Branislav Kisacanin,Vladimir Pavlovic 作者单位:Rutgers University;The Institute for Artificial Intelligence Research and Development of Serbia;Nvidia Corporation 论文链接:http:…

如何使用微信编辑器的这个功能呢?

微信编辑器是一个非常实用的公众号工具,除了能够进行文字编辑和排版外,还有一个特别实用的功能,就是可以将图片转换成PDF格式。这个功能对于需要将多张图片合并成一份文件的人来说,无疑是一个非常方便的解决方案。 那么&#xff…