LSTM反向传播原理——LSTM从零实现系列(2)

news2025/1/12 12:14:13

一、LSTM反向传播介绍

        LSTM的反向传播过程相对复杂,主要因为其对应的控制门较多,而对于每一个控制门我们都需要求导,所以工作量较大。

        首先我们根据LSTM结构图分析一下每个控制门的求导过程。在讲解反向传播之前,先了解一些要用到的参数意义。

          一般来说LSTM在层后会接一个全连接层FNN,全连接层后面再接一个损失函数Loss,所以这里我将全连接层反向传回给LSTM层的总误差称之为E

        从上图LSTM的结构可以观察出,回传的总误差E其实会由两个分支进入LSTM内部,分别是C_tH_t。因此从宏观上看,每个LSTM单元误差传播的起始点为C_tH_t,终点为C_{t-1}H_{t-1},在起始点和终点之间分别夹杂着4个控制门F ,I ,\tilde{C_t}, O,上述的这些元素,其实就是LSTM整个反向传播求导过程要涉及的全部内容。

        下面我们详细介绍每个元素在求导中的处理方法。

二、反向传播过程符号定义和说明

2.1.符号意义说明

        LSTM单元中误差反向传播的过程大体分为两种情况:一种是反向传播的误差来源只包含H_t一条链,比如控制门O;另一种是反向传播的误差来源包含C_tH_t两条传播链,比如控制门F,I,\tilde{C_t}以及C_{t-1}。如下图所示的遗忘门F,其反向传播的误差就是来自于红色和绿色两条传播链,计算时两条链都要计算。

        下面我们分别列举各个元素的反向传播路径:

C_{t-1}的误差来自于C_tH_t两条传播链,传播路径为:

\left\{\begin{matrix} E\rightarrow C_t\rightarrow C_{t-1} \\ E\rightarrow H_t\rightarrow C_t\rightarrow C_{t-1} \end{matrix}\right.  

H_{t-1}的误差来自于四条传播链,分别对应下面F ,I ,\tilde{C_t}, O四个控制门:

        (1)控制门O的误差由H_t传来,传播路径为:

E\rightarrow H_t\rightarrow O\rightarrow H_{t-1}

        (2)控制门F的误差由H_tC_t传来,传播路径为:

\left\{\begin{matrix} E\rightarrow H_t\rightarrow C_t\rightarrow F\rightarrow H_{t-1} \\ E\rightarrow C_t\rightarrow F\rightarrow H_{t-1} \end{matrix}\right.

       (3)控制门I的误差由H_tC_t传来,传播路径为:

\left\{\begin{matrix} E\rightarrow H_t\rightarrow C_t\rightarrow I\rightarrow H_{t-1} \\ E\rightarrow C_t\rightarrow I\rightarrow H_{t-1} \end{matrix}\right.

        (4)控制门\tilde{C_t}的误差由H_tC_t传来,传播路径为:

\left\{\begin{matrix} E\rightarrow H_t\rightarrow C_t\rightarrow \tilde{C_t}\rightarrow H_{t-1} \\ E\rightarrow C_t\rightarrow \tilde{C_t}\rightarrow H_{t-1} \end{matrix}\right.

        列举一下每个元素求偏导过程中的符号定义

  1. LSTM单元反向传播的起点C_tH_t的误差,是由上一层传递来的,这里表示为\frac{\partial E}{\partial C_{t}}\frac{\partial E}{\partial H_{t}},是已知常量。
  2. LSTM单元中的H_t是由C_tO_t计算而来,所以求导过程中存在H_tC_t的偏导数 \frac{\partial H_t}{\partial C_{t}},以及 H_tO_t的偏导数 \frac{\partial H_t}{\partial O_{t}} 。C_t分别由F_t,I_t,\tilde{C_t}计算而来,所以求导过程中存在 C_t分别到F_t,I_t,\tilde{C_t}的偏导数\frac{\partial C_t}{\partial F_{t}}\frac{\partial C_t}{\partial I_{t}}\frac{\partial C_t}{\partial \tilde{C}_{t}}
  3. LSTM单元中从控制门F,I,\tilde{C_t},O_tH_{t-1}存在偏导数\frac{\partial F_t}{\partial H_{t-1}}\frac{\partial I_t}{\partial H_{t-1}}\frac{\partial \tilde{C}_t}{\partial H_{t-1}}\frac{\partial O_t}{\partial H_{t-1}}。从C_tH_tC_{t-1}的偏导数则为\frac{\partial C_t}{\partial C_{t-1}}\frac{\partial H_t}{\partial C_{t-1}}。 
  4. 最后在求偏导的过程中还会用到一些前向传播的数值,如C_tC_{t-1}F_tI_t等,这些参数在计算前向传播过程中都可以加以保留。
  5. 在理清上述这些反向传播中存在的关系后,下面我们就可以完成整个LSTM单元的反向传播计算了。

        LSTM计算顺序如下图所示:     

        介绍完这些符号定义后,下面就可以开始LSTM的反向传播计算了。

2.2.重要细节——特殊的传播链Ct

        C_t是一条比较特殊的传播链,在前向传播时,每个sample都中包含n个Timestep,在第一个Timestep计算时,C_t的初始值是零矩阵,在后续的Timestep计算时C_t会进行不断累计和向后传递。在当前LSTM层计算完成后,向下个LSTM层传递时,只向后传递输出的状态H_t,作为下个LSTM层的输入Xh,而当前层的C_t值不再向下个LSTM层传递。

        所以同理,在反向传播时,下一层的\Delta H_t会反向传播到上一层作为误差输入,但\Delta C_t不会回传,所以每一层LSTM按照时间步倒序计算反向传播过程中,计算第一个Timestep时\Delta C_t的初始值也是零矩阵,并且在后续时间步中进行累计和传递。

        这一规则十分重要,在这里单独强调,后续内容不再重复说明。

三、LSTM反向传播流程解析

        上一节我们说过,从反向传播链终点的角度出发,有两种类型的传播链,即 C_{t-1}链和H_{t-1}链 。而细分H_{t-1}的传播链其实又有两种类型,即包含C_t的和不包含C_t的。所以反向传播链大体上可分为三类,下面来讲解这三种流程。

3.1.第一类偏导:H_{F}H_{I}H_{\tilde{C}}

3.1.1.遗忘门F_t求导过程

        两条完整求导路径表达式如下

 \frac{\partial E}{\partial H_{t-1}}=\frac{\partial E}{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial F_{t}}\cdot \frac{\partial F_t}{\partial H_{t-1}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial F_{t}}\cdot \frac{\partial F_t}{\partial H_{t-1}}

\frac{\partial E}{\partial H_{t-1}}= (\frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}})\cdot \frac{\partial C_{t}}{\partial F_{t}}\cdot \frac{\partial F_t}{\partial H_{t-1}}

        3.1.1.1.绿色路径H_tC_t的偏导

H_t=O_t\odot tan(C_t)

\frac{\partial H_t}{\partial C_t}=O_t\odot (1-tan^{2}C_t)

        3.1.1.2.红色路径C_tF_t的偏导 

C_t=F_t\odot C_{t-1}+I_t\odot \tilde{C_t}

\frac{\partial C_t}{\partial F_t}=C_{t-1}

        3.1.1.3.遗忘门F_tH_{t-1}的偏导 

F_{t}=\sigma (x_{t}W_{xf}+h_{t-1}W_{h_{t-1}f}+b_f)

{\sigma}'(x)=\sigma (x)(1-\sigma (x))

\frac{\partial F_t}{\partial H_{t-1}}=F_t(1-F_t)

        3.1.1.4.两条路径表达式合并

\frac{\partial E}{\partial H_{t-1}}= \left \{ \frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot[O_t\odot (1-tan^{2}C_t)] \right \}\cdot C_{t-1}\cdot F_t(1-F_t)

3.1.2.输入门I_t求导过程

        原理同上,这里直接写

\frac{\partial E}{\partial H_{t-1}}=\frac{\partial E}{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial I_{t}}\cdot \frac{\partial I_t}{\partial H_{t-1}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial I_{t}}\cdot \frac{\partial I_t}{\partial H_{t-1}}

\frac{\partial E}{\partial H_{t-1}}= (\frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}})\cdot \frac{\partial C_{t}}{\partial I_{t}}\cdot \frac{\partial I_t}{\partial H_{t-1}}

        3.1.2.1.H_tC_t的偏导

H_t=O_t\odot tan(C_t)

\frac{\partial H_t}{\partial C_t}=O_t\odot (1-tan^{2}C_t)

        3.1.2.2.C_tI_t的偏导 

C_t=F_t\odot C_{t-1}+I_t\odot \tilde{C_t}

\frac{\partial C_t}{\partial I_t}=\tilde{C}_t

        3.1.2.3.输入门I_tH_{t-1}的偏导 

I_{t}=\sigma (x_{t}W_{xI}+h_{t-1}W_{h_{t-1}I}+b_I)

{\sigma}'(x)=\sigma (x)(1-\sigma (x))

\frac{\partial I_t}{\partial H_{t-1}}=I_t(1-I_t)

        3.1.2.4.两条路径表达式合并

\frac{\partial E}{\partial H_{t-1}}= \left \{ \frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot[O_t\odot (1-tan^{2}C_t)] \right \}\cdot \tilde{C}_{t}\cdot I_t(1-I_t)

3.1.3.候选记忆\tilde{C}_t求导过程

        原理同上,这里直接写

\frac{\partial E}{\partial H_{t-1}}=\frac{\partial E}{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial \tilde{C}_{t}}\cdot \frac{\partial \tilde{C}_t}{\partial H_{t-1}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial \tilde{C}_{t}}\cdot \frac{\partial \tilde{C}_t}{\partial H_{t-1}}

\frac{\partial E}{\partial H_{t-1}}= (\frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}})\cdot \frac{\partial C_{t}}{\partial \tilde{C}_{t}}\cdot \frac{\partial \tilde{C}_t}{\partial H_{t-1}}

        3.1.3.1.H_tC_t的偏导

H_t=O_t\odot tan(C_t)

\frac{\partial H_t}{\partial C_t}=O_t\odot (1-tan^{2}C_t)

        3.1.3.2.C_t\tilde{C}_t的偏导 

C_t=F_t\odot C_{t-1}+I_t\odot \tilde{C_t}

\frac{\partial C_t}{\partial \tilde{C}_t}=I_t

        3.1.3.3.候选记忆\tilde{C}_tH_{t-1}的偏导 

\tilde{C_t}=tanh(x_tW_{x\tilde{C}}+h_{t-1}W_{h_{t-1}\tilde{C}}+b_{\tilde{C}})

{\sigma}'(x)=\sigma (x)(1-\sigma (x))

\frac{\partial \tilde{C}_t}{\partial H_{t-1}}=\tilde{C}_t(1-\tilde{C}_t)

        3.1.3.4.两条路径表达式合并

\frac{\partial E}{\partial H_{t-1}}= \left \{ \frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot[O_t\odot (1-tan^{2}C_t)] \right \}\cdot I_{t}\cdot \tilde{C}_{t}(1-\tilde{C}_{t})

3.2.第二类偏导:H_{O}

        完成表达式如下 

\frac{\partial E}{\partial H_{t-1}}=\frac{\partial E}{\partial H_t}\cdot \frac{\partial H_t}{\partial O_t}\cdot \frac{\partial O_t}{\partial H_{t-1}}

        3.2.1.红色路径H_tO_t偏导

H_t=O_t\odot tan(C_t)

\frac{\partial H_t}{\partial O_t}=tan(C_t)

        3.2.2.O_tH_{t-1}偏导

O_{t}=\sigma (x_{t}W_{xo}+h_{t-1}W_{h_{t-1}o}+b_o)

{\sigma}'(x)=\sigma (x)(1-\sigma (x))

\frac{\partial O_t}{\partial H_{t-1}}=O_t(1-O_t)

        3.2.3.表达式合并

\frac{\partial E}{\partial H_{t-1}}=\frac{\partial E}{\partial H_t}\cdot tan(C_t)\cdot O_t(1-O_t)

3.3.第三类偏导:C_{t-1}

        两条完整求导路径表达式如下

\frac{\partial E}{\partial C_{t-1}}=\frac{\partial E }{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial C_{t-1}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}}\cdot \frac{\partial C_{t}}{\partial C_{t-1}}

\frac{\partial E}{\partial C_{t-1}}=(\frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot\frac{\partial H_{t}}{\partial C_{t}})\cdot \frac{\partial C_{t}}{\partial C_{t-1}}

        3.3.1.红色路径C_tC_{t-1}的偏导

C_t=F_t\odot C_{t-1}+I_t\odot \tilde{C_t}

\frac{\partial C_t}{\partial C_{t-1}}=F_{t}

        3.3.2.绿色路径H_tC_t的偏导

H_t=O_t\odot tan(C_t)

\frac{\partial H_t}{\partial C_t}=O_t\left [ 1-tan^{2}(C_t) \right ]

        3.3.3.两条路径表达式合并

\frac{\partial E}{\partial C_{t-1}}=\left \{ \frac{\partial E}{\partial C_{t}}+\frac{\partial E}{\partial H_{t}}\cdot O_t\left [ 1-tan^{2}(C_t) \right ] \right \}\cdot F_t

3.4.合并

        四个控制门计算完成后,最终合并起来,表达式如下:

\frac{\partial E}{\partial Sum}=\frac{\partial E}{\partial F}+\frac{\partial E}{\partial I}+\frac{\partial E}{\partial \tilde{C}}+\frac{\partial E}{\partial O}

四、权重的反向传播

        在上述第三节的求导过程结束后,此时误差就传递到了上图中的两个黄色圆圈部分,显然还差一步反向传播就完成了。最后一步就是将误差从黄色的圆圈部分分别传递到C_{t-1}H_{t-1}X_t。对于H_{t-1}X_t来说,要计算两个权重矩阵\Delta W_h\Delta W_x,而对于C_{t-1}来说,不必求矩阵,直接向后传递即可。

        在前向传播时计算流程如下:

\left\{\begin{matrix} H_{t-1}*W_h = Sum_h\\ X_t*W_x = Sum_x \\ Sum_h+Sum_x = Sum \end{matrix}\right.

        由上式可得,反向传播的表达式可以表达如下:

\left\{\begin{matrix} \Delta W_h = (H_{t-1})^T*Sum \\ \Delta H_{t-1}= Sum*(W_h)^T \\ \Delta W_x = (X_t)^T*Sum \\ \Delta X_t = Sum *(W_{x})^T \end{matrix}\right.

        最后还剩一个偏置项b,我们将上一节得出的\frac{\partial E}{\partial Sum}直接降维即可得到\Delta b,具体的程序实现方法会在下一篇文章中更直观的介绍。

五、总结

        上面就是一个完整的LSTM单元反向传播流程,至此本文已经对LSTM的反向传播理论基础进行了比较清楚的讲解。结合上一篇LSTM网络前向传播原理讲解,相信大家对LSTM的基本原理有了一个比较清晰的认识了。

        但这仅仅是单一的LSTM模型前向传播和反向传播原理,不足以构成一个完整的模型。但是有了上述基础,下篇文章中我们就可以1:1实现一个包含输入,隐藏层,输出,损失函数的完整的神经网络模型了。

文章正在写作中……

参考文献:

循环神经网络RNN&LSTM推导及实现 - 知乎

LSTM的推导与实现 - liujshi - 博客园

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/72856.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

web课程设计网页规划与设计:鲜花网站设计——基于HTML+CSS+JavaScript制作网上鲜花网页设计(5页)

🎉精彩专栏推荐 💭文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业: 【📚毕设项目精品实战案例 (10…

redis的热key、大key

目录 1.概述 2.查找方法 2.1.知道具体哪个key有问题 2.2.不知道具体哪个key有问题 3.处理方法 3.1.大key的处理方法 3.2.热key的处理方法 1.概述 大key: 含有较大数据或含有大量成员的Key称之为大Key,常见的大key如: String类型的Ke…

用友降运维成本实践:OceanBase替换MySQL,实现高可用

导语:随着业务模型的不断变化使运维难度越来越大,用友IT内部采用任务调度中心XXL-JOB和配置管理中心Nacos来实现公司IT分布式任务调度和微服务开发。但XXL-JOB和Nacos集群数量的增多又使其支撑系统MySQL难以招架。 为了寻找一款既能提供高可用又能统一管…

git常用命令(linux和windows通用)

本文的命令已可满足日常需求 配置用户信息 git config --global user.name “github用户名” git config --global user.email “github绑定邮箱"查看配置信息 git config --global user.name git config --global user.email~/.bashrc文件介绍 ~指当前用户的根目录&…

负载分担方式的双链路热备份配置

** 负载分担方式的双链路热备份配置 ** 实验要求和拓扑 负载分担方式的优点和主备方式的不同 负载分担可以每个ac都管理ap这样就避免了资源浪费,然后又作到了备份 主备方式则是,ap都交给一个ac,另一个ac只作为备份 实验拓扑 实验要求 配置…

全面上新!阿里 2023 版(Java 岗)面试突击手册,Github 已标星 37K

程序员面试背八股,几乎已经是互联网不可逆的一个形式了。自从面试**八股文火了之后,网上出现了不少 Java 相关的面试题,很多朋友盲目收集背诵,**但网上大部分的面试题,大多存在这几个问题:第一,…

剑指 Offer 53 - II. 0~n-1中缺失的数字

摘要 剑指 Offer 53 - II. 0~n-1中缺失的数字 一、二分法 1.1 二分法分析 排序数组中的搜索问题,首先想到 二分法 解决。 根据题意,数组可以按照以下规则划分为两部分。 左子数组: nums[i]i;右子数组:…

Eolink 11月企业与产品动态速览!

本月,Eolink IDEA 插件 “Eolink ApiKit” 最新版本 1.1.3 发布,可进行方法 API 解析的插件,可自动生成注释,可分析方法出入参等。 此外,Eolink 再获多项荣誉与认证,持续行业领先!一起来看看 11…

我今天吃了SHI,请对下联

1. 跨平台终端 Tabby(前身是 Terminus) 是一个可高度配置的终端模拟器和 SSH 或串口客户端,支持 Windows,macOS 和 Linux。 还有一些功能比较常见和易于使用的: 集成了 SSH,Telnet 客户端和连接管理器,可以在 SSH 会…

JAVA SCRIPT设计模式--行为型--设计模式之Mediator中介者模式(17)

JAVA SCRIPT设计模式是本人根据GOF的设计模式写的博客记录。使用JAVA SCRIPT语言来实现主体功能,所以不可能像C,JAVA等面向对象语言一样严谨,大部分程序都附上了JAVA SCRIPT代码,代码只是实现了设计模式的主体功能,不代…

图的初体验

最近周赛有两个差不多的题目,都是关于图的,之前也没有怎么练过关于图的题目,来记录一下。 T1 力扣T320周赛:T3:到达首都的最少油耗 class Solution {//结果long result ;public long minimumFuelCost(int[][] roads…

【推免攻略】四.2021年北交计算机学院夏令营、预推免保研经验

欢迎订阅本专栏:《北交计算机保研经验》 订阅地址:https://blog.csdn.net/m0_38068876/category_10779337.html 【推免攻略】一.北交计算机学院夏令营、预推免攻略【推免攻略】二.联系导师的前期准备及注意事项【推免攻略】三.2020年北交计算机学院夏令营…

如何能成为测试老大?先搞懂项目中的敏捷开发模式

1 什么是敏捷开发? 1、敏捷开发是以用户的需求进化为核心,采取迭代、循序渐进的方式来 进行软件项目的开发。 2、即将项目切分为多个子项目,每个子项目单独发布,保证软件较早可用。 3、及时收集用户反馈,调整未发布…

线性回归线性关系、非线性关系、常见函数导数、损失函数与优化算法、正规方程与单变量函数梯度下降、多变量函数梯度下降

一、线性回归概述 线性回归(Linear regression):是利用回归方程(函数)对一个或多个自变量(特征值)和因变量(目标值)之间关系进行建模的一种分析方式 特点:只有一个自变量的情况称为单变量回归,多于一个自变量情况的叫做多元回归 特征值与目…

机器学习笔记之受限玻尔兹曼机(五)基于含隐变量能量模型的对数似然梯度

机器学习笔记之受限玻尔兹曼机——基于含隐变量能量模型的对数似然梯度引言回顾:包含配分函数的概率分布受限玻尔兹曼机——场景构建对比散度基于含隐变量能量模型的对数似然梯度引言 上一节介绍了对比散度(Constractive Divergence)思想,本节将介绍基于…

制造型企业如何进行多项目管理?这篇文章说清楚了

受经济全球化与科技迅速发展的影响,我国很多企业早已进入了多项目管理模式。多项目管理是从企业整体出发,动态选择不具有类似性的项目,对企业所拥有的或可获得的生产要素和资源进行优化组合,有效、最优地分配企业资源,…

葡聚糖修饰金纳米颗粒(Dex-AuNps)|聚环氧氯丙烷二甲胺修饰多孔磁性葡聚糖微球

葡聚糖修饰金纳米颗粒(Dex-AuNps)|聚环氧氯丙烷二甲胺修饰多孔磁性葡聚糖微球 产品描述:通过特异性识别作用在表面等离子体共振传感器的金膜表面构建了伴刀豆球蛋白A/葡聚糖修饰金纳米颗粒自组装膜 中文名称:葡聚糖修饰金纳米颗粒 英文名称&#xff1…

CMAKE编译知识

1,Ubuntu安装了cmake之后,直接输入指令查看版本。cmake -version 我这里的版本为3.16.3 2,使用visual studio里面创建一个CMake项目是最快可以看到的。但是一般无法理解。所以我找了网上资料。根据网上所说和自己再试错下。初步了解了cmake…

[附源码]JAVA毕业设计微博网站(系统+LW)

[附源码]JAVA毕业设计微博网站(系统LW) 项目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术&#xf…

IAA游戏市场规模已达百亿,如何防范游戏安全问题?

近年来,移动休闲游戏市场发展速度迅猛,伽马数据发布的《2022年休闲游戏发展报告》称,2022年第一季度移动游戏下载量TOP200榜单中,休闲类游戏占比已达45%。 2022年第一季度下载量TOP200移动游戏占比情况丨数据来源伽马数据 相比IA…