上一篇 | 下一篇 |
---|---|
RNN(下集) | 待编写 |
LSTM(长短期记忆)
参考知乎文章《人人都能看懂的LSTM介绍及反向传播算法推导(非常详细) - 知乎》,部分图片也进行了引用。
参考视频教程《3.结合例子理解LSTM_哔哩哔哩_bilibili》,举的例子还是比较贴切的。不过反向传播的完整推导,非必要不用深入看。
一、相关知识早知道
-
————————————————该模型提出的原因是为了解决什么问题的?———————————————
答:为了解决 R N N RNN RNN 的
- 梯度爆炸和梯度消失问题 (尽管激活函数 t a n h tanh tanh 已经在一定程度上缓解了梯度消失和梯度爆炸,但是当连乘次数过多时,依然会有此问题)。
- 长期依赖捕捉困难问题 ( R N N RNN RNN 的隐藏状态仅能保留短期历史信息,难以建模长序列中的复杂依赖)。
- 信息筛选问题 ( R N N RNN RNN 对历史信息是全部接收的,但是无法筛选出历史信息中真正有用的内容,间接导致了前两个问题)。
-
————————————为什么 L S T M \pmb{LSTM} LSTM 相较于 R N N \pmb{RNN} RNN 能缓解梯度爆炸和梯度消失?———————————
答:首先梯度存在于反向传播中。 R N N RNN RNN 的前向传播是连乘式结构,这就导致其反向传播也是连乘式结构,在反向传播过程中, R N N RNN RNN 的梯度要么始终大于 1 1 1 ,要么始终小于 1 1 1 ,连乘时便会轻而易举地引起梯度爆炸和梯度消失(是针对较远时间步的梯度来说)。而在 L S T M LSTM LSTM 中,由于其独特的多门结构,可以使得导数值可以在 1 1 1 上下浮动(上一时间步的梯度大于 1 1 1 ,下一时刻的梯度也可能小于 1 1 1 ,连乘时就不至于过大或过小)。并且通过对多个门结构参数的学习,可以实现信息筛选,来决定何时让梯度消失,何时保持梯度(也就是所谓的梯度截断)。
-
———————————————————— L S T M \pmb{LSTM} LSTM 的缺点有哪些?———————————————————
答:尽管 L S T M LSTM LSTM 显著优于 R N N RNN RNN ,但仍存在以下问题:
- 计算复杂度高 :三个门控结构引入更多参数,训练耗时较长。
- 长序列性能衰减 :处理极长序列(如数万步)时仍可能出现记忆衰退。
- 超参数敏感 :初始化策略和门控权重需精细调优,否则易导致训练不稳定。
- 参数多,容易造成过拟合。
- 无法并行计算 :每个时间步需依赖前序结果,影响处理效率(后续的 T r a n s f o r m e r Transformer Transformer 解决了这个问题)。
- 可解释性 一直是个难点,简单来说就是很难完全说清楚为什么内部结构这样设计,内部结构复杂使其决策逻辑难以被直观解释。
二、结构图
对比传统 R N N RNN RNN , L S T M LSTM LSTM 的突出特点是多了个记忆细胞,能选择性地保留之前和当下的语义信息。
1)整体结构图
L S T M LSTM LSTM 网络中的 σ \sigma σ 就是 s i g m o i d sigmoid sigmoid 函数。它在这里也被叫做 “门单元” ,因为 s i g m o i d sigmoid sigmoid 函数的值为 [ 0 , 1 ] [0,1] [0,1] ,类似阀门,开的口大进的就多,开的口小进的就少。
2)单个时刻结构图
其内部包含四个网络层(其中三个门单元,带 σ \sigma σ 的就是门),分别是:遗忘门、更新门、细胞状态更新层、输出门。
后续符号提示: ⊙ ⊙ ⊙ (向量或矩阵的对应元素相乘)。
令 x t \large x_t xt 尺寸为 m × 1 m×1 m×1 , h t − 1 \large h_{t-1} ht−1 尺寸为 n × 1 n×1 n×1 ,则 C t − 1 C_{t-1} Ct−1 尺寸为 n × 1 n×1 n×1 。
①遗忘门:
公式:
f
t
=
σ
(
W
h
f
⋅
h
t
−
1
+
W
x
f
⋅
x
t
+
b
f
)
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
\large f_t=\sigma(W_{hf}·h_{t-1}+W_{xf}·x_t+b_f)=\sigma(W_f·[h_{t-1},x_t]+b_f)
ft=σ(Whf⋅ht−1+Wxf⋅xt+bf)=σ(Wf⋅[ht−1,xt]+bf)
f
t
\large f_t
ft :我叫它 “遗忘矩阵” ,由
x
t
\large x_t
xt 和
h
t
−
1
\large h_{t-1}
ht−1 计算而来,
σ
\large \sigma
σ 函数使
f
t
\large f_t
ft 的元素处于
0
∼
1
0\sim1
0∼1 ,使其对
C
t
−
1
\large C_{t-1}
Ct−1 具有遗忘功能,
1
1
1 表示 “完全接受”,
0
0
0 表示 “完全忽略”。
f
t
\large f_t
ft 的尺寸和
h
t
−
1
\large h_{t-1}
ht−1 、
C
t
−
1
\large C_{t-1}
Ct−1 一样,同为
n
×
1
n×1
n×1 ------------------------则可推出参数
W
h
f
\large W_{hf}
Whf 尺寸为
n
×
n
n×n
n×n ,
W
x
f
\large W_{xf}
Wxf 尺寸为
n
×
m
n×m
n×m ,
b
f
\large b_{f}
bf 尺寸为
n
×
1
n×1
n×1 。
②输入门:
公式:
i
t
=
σ
(
W
h
i
⋅
h
t
−
1
+
W
x
i
⋅
x
t
+
b
i
)
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
+
b
i
)
C
t
~
=
t
a
n
h
(
W
h
C
⋅
h
t
−
1
+
W
x
C
⋅
x
t
+
b
C
)
=
t
a
n
h
(
W
C
⋅
[
h
t
−
1
,
x
t
]
+
b
C
)
\large i_t=\sigma(W_{hi}·h_{t-1}+W_{xi}·x_t+b_i)=\sigma(W_i·[h_{t-1},x_t]+b_i)\\ \large \tilde{C_t}=tanh(W_{hC}·h_{t-1}+W_{xC}·x_t+b_C)=tanh(W_C·[h_{t-1},x_t]+b_C)
it=σ(Whi⋅ht−1+Wxi⋅xt+bi)=σ(Wi⋅[ht−1,xt]+bi)Ct~=tanh(WhC⋅ht−1+WxC⋅xt+bC)=tanh(WC⋅[ht−1,xt]+bC)
这里的
C
t
~
\large \tilde{C_t}
Ct~ 代表的是此时刻生成的新记忆,只不过是初始版。
这里的 i t \large i_t it :我叫它 “新记忆筛选矩阵” , 元素均处于 0 ∼ 1 0\sim1 0∼1 ,使其对 C t ~ \large \tilde{C_t} Ct~ 具有筛选功能, 1 1 1 表示 “完全通过”, 0 0 0 表示 “完全忽略”。
i t \large i_t it 的尺寸为 n × 1 n×1 n×1 , C t ~ \large \tilde{C_t} Ct~ 的尺寸为 n × 1 n×1 n×1 ------------------------则可推出参数 W h i 、 W h C \large W_{hi}、W_{hC} Whi、WhC 尺寸为 n × n n×n n×n , W x i 、 W x C \large W_{xi}、W_{xC} Wxi、WxC 尺寸为 n × m n×m n×m , b i 、 b C \large b_{i}、b_{C} bi、bC 尺寸为 n × 1 n×1 n×1 。
③细胞状态更新
公式:
C
t
=
f
t
⊙
C
t
−
1
+
i
t
⊙
C
t
~
\large C_t=f_t⊙C_{t-1}+i_t⊙\tilde{C_t}
Ct=ft⊙Ct−1+it⊙Ct~
这里的
C
t
\large C_t
Ct 便是最终要输出的新记忆。
新记忆由经过遗忘的旧记忆,以及经过筛选的原始新记忆,相加而得到。
C t \large C_t Ct 的尺寸为 n × 1 n×1 n×1 。
④输出门:
公式:
o
t
=
σ
(
W
h
o
⋅
h
t
−
1
+
W
x
o
⋅
x
t
+
b
o
)
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
h
t
=
o
t
⊙
t
a
n
h
(
C
t
)
\large o_t=\sigma(W_{ho}·h_{t-1}+W_{xo}·x_t+b_o)=\sigma(W_o·[h_{t-1},x_t]+b_o)\\ \large h_t=o_t⊙tanh(C_t)
ot=σ(Who⋅ht−1+Wxo⋅xt+bo)=σ(Wo⋅[ht−1,xt]+bo)ht=ot⊙tanh(Ct)
这里的
h
t
h_t
ht 便是隐层状态输出,由新记忆
C
t
\large C_t
Ct 通过
t
a
n
h
tanh
tanh 函数调整到
−
1
∼
1
-1\sim1
−1∼1 之间,再与
o
t
o_t
ot 逐元素相乘得到。
这里的 o t o_t ot :我叫它 “隐层状态输出提炼矩阵” ,元素均处于 0 ∼ 1 0\sim1 0∼1 ,使其对 C t \large C_t Ct 具有提炼功能, 1 1 1 表示 “完全通过”, 0 0 0 表示 “完全忽略”。
从 C t \large C_t Ct 中提炼出 h t h_t ht ,是因为细胞状态 C t \large C_t Ct 存储了经过遗忘门和输入门筛选后的所有长期信息(如历史趋势或主题),但并非所有内容都需直接传送给后续网络,通过提炼,仅保留与当前任务相关的部分。并且经过 t a n h tanh tanh 函数压缩之后,可以避免数值爆炸,并增强非线性表达能力。
3)单个时刻公式图
4)补充:
传统 R N N RNN RNN 的同步多对多结构图:
三、损失函数及反向传播
想要深入研究的,就参考知乎那篇文章,以及视频的最后一个分集(可以稍稍看看矩阵求导,对反向传播的计算有帮助)。
只是想用的,知道使用的是链式法则、梯度下降法即可。