LSTM内部结构及前向传播原理——LSTM从零实现系列(1)

news2024/11/15 23:39:12

一、前言

        作为专注于时间序列分析的玩家,虽然LSTM用了很久但一直没有写过一篇自己的LSTM原理详解,所以这次要写一个LSTM的从0到1的系列,从模型原理讲解到最后不借助三方框架自己手写代码来实现LSTM模型。本文本身没有特别独到之处,因为网上其实已经有很多优秀的关于LSTM讲解的文章,文章也做了很多借鉴;本文一方面是作为后续讲解如何从0实现LSTM的前置阅读内容,另一方面是作为自己多年对LSTM学习和理解的总结,同时也为作者的后续在此方向进一步的理论研究做铺垫。

        LSTM是建立在普通的神经网络之上,准确说是建立在循环神经网络RNN之上,加上其可叠加的特性,一般归为深度学习范畴。对于神经网络和RNN这些前置知识之前已经写过很多文章,可参考下面链接。本文重点讲解LSTM的算法机制,为后续实现手写LSTM模型做铺垫。

https://blog.csdn.net/yangwohenmai1/category_9126892.html

        LSTM主要解决了RNN的梯度消失和梯度爆炸问题,实现了时间序列的长期预测功能,这一机制的实现方案类似残差网络,利用一条单独通路,将过去的记忆尽可能多的向未来传递,再通过LSTM模型内部的各种控制门来对过去的记忆内容进行提取,舍弃,修正,变成当前的输出结果。再将当前时间步的结果和修改过的记忆信息向下一个LSTM单元传递。

        LSTM的缺点是如果我们输入的历史序列很长,则LSTM模型的内存和训练时间会大大增加,这是由于LSTM内部的计算过程复杂和监督学习数据的特殊结构所导致的。所以后来推出了简化版的LSTM,即GRU,我将其视为LSTM的儿子,LSTM则是爸爸。下面就来了解一下LSTM的原理,图片大多是从网络上借鉴而来,文尾附参考文献链接。

二、LSTM结构特点

        一般的RNN可以理解为多个相同的单元循环链接,将本次循环的状态向后传递,作为下次循环输入的一部分。缺点是不论激活函数如何选择,都会出现梯度爆炸或梯度消失的问题,因为最后链式求导过程都是连乘的形式。常见的RNN结构如下:

        LSTM比RNN内部结构复杂的多,总的来说有三大控制门,分别是遗忘门、候选记忆门、输出门,还有一条负责传送长期记忆的传送带。就是这些门对历史信息进行提取、舍弃、更新这些精细化控制,同时也保证了模型不会出现梯度消失和梯度爆炸的情况。LSTM网络结构如下:

        LSTM最核心的机制就是这个历史信息“传送带”,一方面它贯穿了整个“循环网络”,将历史的信息顺畅的向后传递,另一方面每个单元在更新历史信息时使用的是“+”而不是“×”,这使得链式求导时避免了连乘的出现,从而解决了梯度爆炸和梯度消失的问题。对应功能模块如下:

        LSTM整体结构的中文示意图如下,这个结构中有两个重要参数分别是“记忆细胞”C和“隐藏状态”H,一般来说我们称C为长期记忆,H为短期记忆。原则上不建议也不应该用拟人的方式去解释神经网络,但是为了方便理解各个功能模块后续就沿用这种描述方式。下面我们就对每个过程进行拆分讲解。

三、LSTM原理分步解析

3.1.遗忘门的原理:

        遗忘门f_t可表示为下式,其中W_{xf}表示输入x传递到f_t对应的权重矩阵,W_{h_{t-1}f}表示上一时间步状态h_{t-1}传递到f_t对应的权重矩阵,b_f表示偏置项。通过激活函数\sigmaf_t的计算结果限定在(0,1)之间。

f_{t}=\sigma (x_{t}W_{xf}+h_{t-1}W_{h_{t-1}f}+b_f)

 3.2.输入门的原理:

        输入门i_t可表示为下式,其中W_{xi}表示输入x_t传递到i_t对应的权重矩阵,W_{h_{t-1}i}表示上层状态h_{t-1}传递到i_t对应的权重矩阵,b_i表示偏置项。通过激活函数\sigmai_t的计算结果限定在(0,1)之间。

i_{t}=\sigma (x_{t}W_{xi}+h_{t-1}W_{h_{t-1}i}+b_i)

        中间状态\tilde{C_t}可表示为下式, 其中W_{xC}表示输入x_t传递到\tilde{C_t}对应的权重矩阵,W_{h_{t-1}C}表示上层状态h_{t-1}传递到\tilde{C_t}对应的权重矩阵,b_C表示偏置项。通过激活函数tanh\tilde{C_t}的计算结果限定在(-1,1)之间。这里为什么使用tanh而不用\sigma其实没有定论,根据经验激活函数的选择应该是作者通过超参数后得出的较优结论,不具有明确的可解释行。

\tilde{C_t}=tanh(x_tW_{xC}+h_{t-1}W_{h_{t-1}C}+b_{C})

3.3.候选记忆的原理:

        输出状态C_t可表示为下式,其中C_{t-1}是上一时间步传递过来的输出状态,f_ti_t\tilde{C_t}是遗忘门、输入门、中间状态的计算结果。

C_t=f_t\odot C_{t-1}+i_t\odot \tilde{C_t}

        f_t\odot C_{t-1}表示遗忘门f_t和上一时间步的状态C_{t-1}做逐点相乘,f_t\in (0,1),使矩阵中接近0的位置的内容被遗忘,接近1位置的部分被保留,达到选择性遗忘的效果。

        i_t\odot \tilde{C_t}表示输出入门i_t和中间状态\tilde{C_t}进行逐点相乘,得到新的候选记忆,也即需要记忆的新特征。i_t\in (0,1),决定了\tilde{C_t}的信息是完全保留还是全部忘记。

        最后我们用f_t遗忘掉历史记忆C_{t-1}中的信息,将新的记忆i_t\odot \tilde{C_t}添加到C_{t-1}中,构成了当前时间步的新记忆C_{t}

3.4.输出门的原理:

        输出门o_t可表示为下式,其中W_{xo}表示输入x_t传递到o_t对应的权重矩阵,W_{h_{t-1}o}表示上个时间步的状态h_{t-1}传递到o_t对应的权重矩阵,b_o表示偏置项。通过激活函数\sigmao_t的计算结果限定在(0,1)之间。

o_{t}=\sigma (x_{t}W_{xo}+h_{t-1}W_{h_{t-1}o}+b_o)

        输出状态h_t可表示为下式,将输出门o_ttanh(C_t)逐点相乘,得到当前时间步新的输出状态,作为下个时间步输入的一部分。

h_t=o_t\odot tanh(C_t)

3.5.总结:

        此时我们就得到一套完整的大脑机制,遗忘门忘掉没用的信息,输入门和候选记忆产生新的记忆信息,输出门负责更新记忆,再将新的记忆送给下一个神经元。

        上述内容基本就是LSTM核心的工作流程,通过将每个功能模块组合起来,就构成了完整的LSTM模型,下面介绍一下数据如何在LSTM模型中流转。

四、多层LSTM以及前向传播的细节

        下面问题来了,LSTM一般属于深度学习范畴,那深度学习必然会有多层模型进行叠加传参的问题。说到模型叠加一般网上就会给出类似下面这张图,这种图可能除了画图人自己,没多少人能看明白参数是怎么传递的。 那么对于LSTM这种更加复杂的RNN类模型,作者讲解一下参数是如何传递的。

4.1.多层LSTM结构

        下图是一个双层LSTM的示例,我们一起分析一下参数是如何在模型中流转的。首先要明确的是图中双层LSTM分别指第一行的3个LSTM单元和第二行的3个LSTM单元,而每一行中的3个LSTM单元表示的是当前LSTM层包含3个时间步。传播方向由第二行指向第一行。

4.2.第一层LSTM计算

        首先看下输入数据x在第一层LSTM中的流转,蓝色方框表示第一层LSTM。

  • 我们知道LSTM单元包含三个输入参数x,c,h,首先x1作为第一个时间步,输入到第一个LSTM单元中,此时输入的初始c0和h0都是0矩阵,计算完成后,第一个LSTM单元输出新的一组h1,c1,作为本层LSTM的第二个时间步的输入参数。
  • 因此第二个时间步的输入就是h1,c1,x2,而输出是h2,c2
  • 因此第三个时间步的输入就是h2,c2,x3,而输出是h3,c3

        至此第一层LSTM的三个时间步计算完成,我们得到了三个输出结果(h1,h2,h3)。

4.3.第二层LSTM计算

        在下图中,蓝色方框对应的是第二层LSTM,本层没有输入参数x1,x2,x3,所以我们将第一层LSTM输出的(h1,h2,h3),作为第二层LSTM的输入x1,x2,x3。

  • 第一个时间步输入的初始c0和h0都为0矩阵,计算完成后,第一个时间步输出新的一组h1,c1,作为本层LSTM的第二个时间步的输入参数。
  • 因此第二个时间步的输入就是h1,c1,x2,而输出是h2,c2
  • 因此第三个时间步的输入就是h2,c2,x3,而输出是h3,c3

        直到LSTM的三个时间步计算完成,又会得到三个新输出(h1,h2,h3),可以继续向后传播,或者直接传入全连接层,将其映射到对应的输出结果。

         上述流程就是多层LSTM内部实现前向传播的数据流转过程。更具体的流程后续我们会在源码实现中讲解。

五、总结

        本文讲解了LSTM的内部结构,原理机制,以及前向传播流程。下片文章讲解LSTM的反向传播流程。链接如下:

写作中。。。。

参考文献:

从零开始实现循环神经网络(无框架) - 知乎

Stateful LSTM in Keras – Philippe Remy – My Blog.

https://colah.github.io/posts/2015-08-Understanding-LSTMs/

白话--长短期记忆(LSTM)的几个步骤,附代码!_mantchs的博客-CSDN博客

LSTM 为何如此有效?这五个秘密是你要知道的

循环神经网络RNN&LSTM推导及实现 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/64079.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue学习:el 与data的两种写法

el两种写法 法一&#xff1a;建立了联系 <!-- 准备容器 --><div id"root"><h1>hello,{{name}} </h1> <!-- {{插值语法}} --></div><script>new Vue({ el: #root,data: {name:Amy},});</script> 法二&#xff1a…

论文投稿指南——中国(中文EI)期刊推荐(第1期)

&#x1f680; EI是国际知名三大检索系统之一&#xff0c;在学术界的知名度和认可度仅次于SCI&#xff01;&#x1f384;&#x1f388; 【前言】 想发论文怎么办&#xff1f;手把手教你论文如何投稿&#xff01;那么&#xff0c;首先要搞懂投稿目标——论文期刊。其中&#xf…

java计算机毕业设计ssm特大城市地铁站卫生防疫系统5i80c(附源码、数据库)

java计算机毕业设计ssm特大城市地铁站卫生防疫系统5i80c&#xff08;附源码、数据库&#xff09; 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat8.5 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持…

UDS服务基础篇之14

前言 你知道如果系统产生了DTC&#xff0c;应当如何清除呢&#xff1f;14服务具体的执行流程如何&#xff1f;14服务在使用过程中的常见bug又有哪些&#xff1f; 这篇&#xff0c;我们来一起探索并回答这些问题。为了便于大家理解&#xff0c;以下是本文的主题大纲&#xff1…

相控阵天线(十二):天线校准技术仿真介绍之旋转矢量法

目录简介旋转矢量法算法介绍旋转矢量法校准对方向图的影响旋转矢量法算法仿真移相器位数对旋转矢量法的影响多通道旋转矢量法算法仿真分区旋转矢量法算法仿真简介 由于制造公差和天线互耦的影响&#xff0c;天线各通道会呈现出较大的幅相误差&#xff0c;因此需对天线进行校准…

光阑,像差和成像光学仪器

人眼 人眼成像过程 空气-角膜 水状液-晶状体 晶状体-玻璃体 三个界面的折射成像 瞳孔 2-8mm 可变光阑,调节入射光强弱 睫状肌 改变晶状体曲率---调焦 人眼的调节 远点—眼睛完全松弛状态下看清楚的最远点&#xff0c;正常眼的远点在无穷远 近点—睫状肌最大收缩(焦…

【Redis】解决全局唯一 id 问题

永远要记得坚持的意义 一、全局唯一 id 场景 概念&#xff1a; 以订单表的 id 为例 使用自增 id 会产生的问题&#xff1a; id 的规律性太明显&#xff0c;容易让用户猜测到一些信息受表单数据量的限制 —— 分布式存储时&#xff0c;会产生问题 &#xff08;自增长&#x…

讲理论,重实战!阿里内部SpringBoot王者晋级之路全彩小册开源

大家都知道&#xff0c;Spring Boot框架目前不仅是微服务框架的最佳选择之一&#xff0c;还是现在企业招聘人才肯定会考察的点&#xff1b;很多公司甚至已经将SpringBoot作为了必备技能。但&#xff0c;现在面试这么卷的情况下&#xff0c;很多人面试时还只是背背面试题&#x…

基于KDtree的电路故障检测算法的MATLAB仿真

目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 k-d树是每个节点都为k维点的二叉树。所有非叶子节点可以视作用一个超平面把空间分割成两个半空间。节点左边的子树代表在超平面左边的点&#xff0c;节点右边的子树代表在超平面右边的点。选择超…

企业数据安全如何落实?私有化知识文档管理系统效率部署

编者按&#xff1a;本文分析了数据安全性企业的重要性&#xff0c;特别是高保密企业单位&#xff0c;介绍了天翎知识文档管理群晖NA是如何保护企业数据安全的。 关键词&#xff1a;私有化部署&#xff0c;安全技术&#xff0c;数据备份&#xff0c;病毒防护&#xff0c;全网隔…

【zeriotier】win10安装zeriotier的辛酸泪

目录概述问题1&#xff1a;waiting for zeriotier system service问题2&#xff1a;Zerotier One 出现Node ID “unknown”问题3&#xff1a;一切正常&#xff0c;但是连不上服务器最终解决方法附录概述 背景&#xff1a;实验室的服务器是使用zeriotier组网的&#xff0c;因此…

字符串-模板编译

模板编译 编译就是一种格式转换成另一种格式的过程&#xff0c;这里主要讨论一下模板编译。模板字符串对比普通的字符串有很多的不同&#xff0c;模板字符串可以嵌套&#xff0c;并且模板字符串可以在内部使用${xxx}进行表达式运算以及函数调用&#xff0c;这些其实都是模板编…

DPDK Ring

无锁环ring是DPDK提供的一种较为基础的数据结构&#xff0c;其支持多生产者和多消费者同时访问。 经过我的经验&#xff0c;无锁结构的实现主要依靠两方面&#xff1a; 最终的数据交换一定要是原子级的操作&#xff0c;最常用到的自然就是比较后交换&#xff08;Compare And S…

Java项目:SSM个人博客网站管理系统

作者主页&#xff1a;源码空间站2022 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 本项目包含管理员与游客两种角色&#xff1b; 管理员角色包含以下功能&#xff1a; 发表文章,查看文章,类别管理,添加类别,个人信息管理,评论…

DeepSort目标跟踪算法

DeepSort目标跟踪算法是在Sort算法基础上改进的。 首先介绍一下Sort算法 Sort算法的核心便是卡尔曼滤波与匈牙利匹配算法 卡尔曼滤波是一种通过运动特征来预测目标运动轨迹的算法 其核心为五个公式&#xff0c;包含两个过程&#xff1a; 其分为先验估计&#xff08;预测&…

[附源码]计算机毕业设计人事管理系统Springboot程序

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

UE4中抛体物理模拟UProjectileMovementComponent

UE4中抛体物理模拟UProjectileMovementComponent1.简述2.使用方法3.绘制抛物曲线4.绘制抛物曲线1.简述 背景&#xff1a;实现抛体运动&#xff0c;反弹效果&#xff0c;抛物曲线等功能 通用实现可以使用spline绘制&#xff0c;物体按照下图接口可以根据时间更新位置 USplineC…

CN_MAC介质访问控制子层@CSMA协议

文章目录常用方法静态方法信道划分MAC特点动态方法随机访问MACCSMA协议CSMA/CD多点接入(或多点访问):载波监听Note:&#x1f388;碰撞检测碰撞:碰撞冲突过程传播时延对载波侦听的影响&#x1f388;争用期发现碰撞的最迟情况电磁波的速率是有限最短帧长&#x1f388;小结&#x…

CAD重复圆绘制机械图形

这次CAD必练图形第四个&#xff0c;这个图形主要用到了CAD圆、直线、修剪、旋转等多个命令&#xff0c;看着不简单&#xff0c;等绘制出来后就觉得还是挺简单的。 目标图形 操作步骤 1.使用CAD直线命令绘制一条水平的直线和四条垂直的直线&#xff0c;四条垂直的直线之间的距…

【网络层】DHCP协议(应用层)、ICMP、IPv6详解

注&#xff1a;最后有面试挑战&#xff0c;看看自己掌握了吗 文章目录DHCP------DHCP服务器来动态分配IP--------应用层协议----允许地址重用ICMP字段----差错报文、询问报文差错报文-----终点不可达无法交付--------源点抑制、拥塞丢数据&#xff08;现在废弃&#xff09;----…