长短期记忆网络(Long Short-Term Memory,LSTM)

news2024/10/23 14:49:38

简介:个人学习分享,如有错误,欢迎批评指正。

长短期记忆网络(Long Short-Term Memory,简称LSTM)是一种特殊的循环神经网络(Recurrent Neural Network,简称RNN)架构,专门设计用于处理和预测序列数据中的长依赖关系。LSTM由Sepp Hochreiter和Jürgen Schmidhuber在1997年提出,旨在克服传统RNN在处理长序列时面临的梯度消失和梯度爆炸问题。

背景与动机

传统的RNN在处理序列数据(如时间序列、自然语言等)时,通过其循环结构能够记忆和利用先前的信息。然而,随着序列长度的增加,RNN在训练过程中会遇到梯度消失或梯度爆炸的问题,导致模型难以学习到长期依赖关系。该限制使得RNN在许多需要捕捉长距离依赖的任务中的表现不理想。LSTM通过引入门控机制,有效地解决了这一问题,使得网络能够在更长的序列中保持信息。

一、RNN的基本结构与局限性

1. RNN的基本结构

在这里插入图片描述

RNN通过循环连接来处理序列数据。对于一个序列输入 ( x 1 , x 2 , … , x T ) (x_1, x_2, \ldots, x_T) (x1,x2,,xT),RNN在每个时间步 t t t 更新新隐藏状态 h t h_t ht:

h t = tanh ⁡ ( W x h x t + W h h h t − 1 + b h ) h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h) ht=tanh(Wxhxt+Whhht1+bh)

其中, W x h W_{xh} Wxh W h h W_{hh} Whh 是权重矩阵, b h b_h bh 是偏置项, tanh ⁡ \tanh tanh 是激活函数。

2. RNN的局限性

  • 梯度消失与梯度爆炸: 在反向传播过程中,长序列会导致梯度在时间步上传播的迅速减小或增大,使得模型难以学习长期依赖。(因为RNN在时序上共享参数,梯度在反向传播过程中,不断连乘,数值不是越来越大就是越来越小)

  • 长期依赖难以捕捉: 由于梯度衰减,RNN难以记住序列中较早的信息。(梯度小幅更新的网络层会停止学习,这些通常是较早的层。由于这些层不学习,RNN无法记住它在较长序列中学习到的内容,因此它的记忆是短期的。)

二、LSTM的核心理念

LSTM旨在解决传统循环神经网络(RNN)在处理长序列时面临的梯度消失和梯度爆炸问题。其核心思想是通过引入门控机制(Gates)来控制信息的流动,允许网络选择性地记住或遗忘信息,从而有效地捕捉长期开依赖关系

1. 信息流动与记忆保持

在RNN中,隐藏状态 h t h_t ht 通过时间步传递,理论上可以保留任意长的历史信息。然而,实际训练中,由于梯度在反向传播的逐步消失或爆炸,RNN难以有效学习到长距离的依赖关系。LSTM通过设计专门的结构,确保关键信息可以在长时间内被有效传递和更新。

2. 门控机制的引入

LSTM引入了三个主要的门控单元——遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate)。这些门通过学习动态地控制信息的保留和更新,从而实现对长期和短期记忆的有效管理。

三、LSTM的详细结构

在这里插入图片描述

1. LSTM单元的组成

一个标准的LSTM单元包括以下几个关键部分:

  1. 记忆单元(Cell State, C t C_t Ct):负责存储长期记忆
  2. 隐藏状态(Hidden State, h t h_t ht):传递短期记忆和输出
  3. 遗忘门(Forget Gate, f t f_t ft):决定遗忘多少过去的信息
  4. 输入门(Input Gate, i t i_t it):决定接受多少新信息
  5. 候选记忆单元(Candidate Cell State, C ~ t \tilde{C}_t C~t):生成新信息用于更新记忆单元。
  6. 输出门(Output Gate, o t o_t ot):决定输出多少记忆单元的信息

2. 信息流动路径

信息在LSTM单元中的流动可以分为以下几个步骤:

  1. 遗忘阶段:决定从记忆单元中遗忘多少信息。
  2. 输入阶段:决定接收多少新信息,并生成候选记忆单元。
  3. 更新记忆单元:结合遗忘门和输入门的输出,更新记忆单元。
  4. 输出阶段:决定输出多少记忆单元的信息作为隐藏状态。

3. 详细数学表示

以下是每个门控单元和记忆更新的详细数学表达:

3.1 遗忘门(Forget Gate)

遗忘门决定记忆单元中哪些信息需要被遗忘。通过一个 sigmoid 激活函数,输出值 f t f_t ft 在 0 到 1 之间,每个元素决定对应记忆单元信息的保留程度

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma (W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

  • σ \sigma σ:Sigmoid 激活函数。
  • W f W_f Wf:遗忘门的权重矩阵。
  • h t − 1 h_{t-1} ht1:前一时刻的隐藏状态。
  • x t x_t xt:当前时刻的输入。
  • b f b_f bf:遗忘门的偏置。

解释

  • f t f_t ft 接近 1 时,记忆单元中的信息被保留
  • f t f_t ft 接近 0 时,记忆单元中的信息被遗忘

3.2 输入门(Input Gate)与候选记忆单元(Candidate Cell State)

输入门控制新信息的加入。它由两个部分组成:

  1. 输入门层(Input Gate Layer)通过 sigmoid 函数确定哪些部分需要更新
  2. 候选记忆单元层(Candidate Cell State Layer)通过 tanh 函数生成新的候选记忆信息

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh (W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

  • i t i_t it:输入门的输出。
  • C ~ t \tilde{C}_t C~t:候选记忆单元。
  • 其他符号含义同上。

解释

  • i t i_t it 决定了记忆单元中哪些部分将被更新
  • C ~ t \tilde{C}_t C~t 提供了新的信息,用于更新记忆单元。

3.3 更新记忆单元

结合遗忘门和输入门的输出,更新记忆单元状态 C t C_t Ct

C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ftCt1+itC~t

  • ⊙ \odot :逐元素相乘。
  • C t − 1 C_{t-1} Ct1:前一时刻的记忆单元状态。

解释

  • f t ⊙ C t − 1 f_t \odot C_{t-1} ftCt1:保留部分记忆单元中的信息。
  • i t ⊙ C ~ t i_t \odot \tilde{C}_t itC~t:添加新信息到记忆单元中。

3.4 输出门(Output Gate)

输出门决定了下一隐藏状态 h t h_t ht 的值,即当前时刻 LSTM 单元的输出。

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma (W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ottanh(Ct)

  • o t o_t ot:输出门的输出。
  • h t h_t ht:当前时刻的隐藏状态。

解释

  • o t o_t ot 决定了记忆单元中哪些部分将被输出
  • h t h_t ht 是通过将输出门的输出与记忆单元状态的 tanh ⁡ \tanh tanh 变换相乘得到的。

四、LSTM的工作流程详解

1. 前向传播过程

在每个时间步,LSTM单元按照以下步骤进行信息处理:

  1. 输入接收:接收当前输入 x t x_t xt 和前一时刻的隐藏状态 h t − 1 h_{t-1} ht1
  2. 计算遗忘门 f t f_t ft:通过遗忘门决定从记忆单元中遗忘多少信息。
  3. 计算输入门 i t i_t it 和候选记忆单元 C ~ t \tilde{C}_t C~t:决定添加多少新信息到记忆单元中。
  4. 更新记忆单元 C t C_t Ct:结合遗忘门和输入门的输出,更新记忆单元状态。
  5. 计算输出门 o t o_t ot:决定记忆单元中的哪些信息输出。
  6. 生成隐藏状态 h t h_t ht:通过输出门控制的记忆单元状态生成当前时刻的隐藏状态。

2. 反向传播与梯度传播

LSTM通过反向传播算法(Backpropagation Through Time, BPTT)进行训练。在反向传播过程中,梯度需要通过时间步传递。LSTM的设计通过门控机制有效地缓解了梯度消失和梯度爆炸的问题。

2.1 梯度流动

  1. 直接路径:记忆单元 C t C_t Ct 通过加法操作与 f t f_t ft i t i_t it 连接,允许梯度直接在时间步上传播,从而减缓梯度消失。
  2. 门控制机制的调节:遗忘门和输入门通过 sigmoid 激活函数动态调节梯度的流动。当需要保留信息时,门的激活值接近1,允许梯度通过;反之则减小梯度流动。

2.2 反向传播的数学细节

假设损失函数为 L L L,则需要计算每个参数对 L L L 的偏导数。以下是主要梯度计算步骤:

  1. 计算损失对输出 h t h_t ht 的梯度

∂ L ∂ h t \frac{\partial L}{\partial h_t} htL

  1. 计算输出门 o t o_t ot 的梯度

∂ L ∂ o t = ∂ L ∂ h t ⊙ tanh ⁡ ( C t ) \frac{\partial L}{\partial o_t} = \frac{\partial L}{\partial h_t} \odot \tanh(C_t) otL=htLtanh(Ct)

  1. 计算记忆单元 C t C_t Ct 的梯度

∂ L ∂ C t = ∂ L ∂ h t ⊙ o t ⊙ ( 1 − tanh ⁡ 2 ( C t ) ) + ∂ L ∂ C t + 1 ⊙ f t + 1 \frac{\partial L}{\partial C_t} = \frac{\partial L}{\partial h_t} \odot o_t \odot (1 - \tanh^2(C_t)) + \frac{\partial L}{\partial C_{t+1}} \odot f_{t+1} CtL=htLot(1tanh2(Ct))+Ct+1Lft+1

  1. 计算遗忘门 f t f_t ft 和输入门 i t i_t it 的梯度

∂ L ∂ f t = ∂ L ∂ C t ⊙ C t − 1 \frac{\partial L}{\partial f_t} = \frac{\partial L}{\partial C_t} \odot C_{t-1} ftL=CtLCt1

∂ L ∂ i t = ∂ L ∂ C t ⊙ C ~ t \frac{\partial L}{\partial i_t} = \frac{\partial L}{\partial C_t} \odot \tilde{C}_t itL=CtLC~t

  1. 计算候选记忆单元 C ~ t \tilde{C}_t C~t 的梯度

∂ L ∂ C ~ t = ∂ L ∂ C t ⊙ i t \frac{\partial L}{\partial \tilde{C}_t} = \frac{\partial L}{\partial C_t} \odot i_t C~tL=CtLit

  1. 计算各个门控单元的激活函数的梯度

∂ L ∂ z = ∂ L ∂ gate output ⋅ gate output ⋅ ( 1 − gate output ) \frac{\partial L}{\partial z} = \frac{\partial L}{\partial \text{gate output}} \cdot \text{gate output} \cdot (1 - \text{gate output}) zL=gate outputLgate output(1gate output)

其中, z z z 表示门的线性组合输入。

  1. 更新权重和偏置:通过链式法则将梯度传递给权重和偏置,并使用优化算法(如Adam、RMSprop等)更新参数。

3. 参数更新

LSTM的参数包括遗忘门、输入门、候选记忆单元和输出门的权重和偏置。具体参数更新步骤如下:

  1. 计算各参数的梯度

∂ L ∂ W f , ∂ L ∂ b f , ∂ L ∂ W i , ∂ L ∂ b i , … \frac{\partial L}{\partial W_f}, \quad \frac{\partial L}{\partial b_f}, \quad \frac{\partial L}{\partial W_i}, \quad \frac{\partial L}{\partial b_i}, \ldots WfL,bfL,WiL,biL,

  1. 应用优化算法(如Adam)根据梯度更新参数

θ = θ − η ⋅ ∂ L ∂ θ \theta = \theta - \eta \cdot \frac{\partial L}{\partial \theta} θ=θηθL

其中, η \eta η 是学习率, θ \theta θ 代表参数。

五、LSTM的门控制详解

1. 遗忘门(Forget Gate)

在LSTM中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为“遗忘门”的结构完成。该遗忘门会读取上一个输出 h t − 1 h_{t-1} ht1 和当前输入 x t x_t xt,做一个Sigmoid 的非线性映射,然后输出一个向量 f t f_t ft该向量每一个维度的值都在0到1之间,1表示完全保留,0表示完全舍弃,相当于记住了重要的,忘记了无关紧要的),最后与细胞状态 C t − 1 C_{t-1} Ct1 相乘。
在这里插入图片描述

遗忘门的作用是决定从记忆单元中丢弃多少过去的信息。其通过当前输入 x t x_t xt 和前一隐藏状态 h t − 1 h_{t-1} ht1 计算得出。

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma (W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

特性

  • f t f_t ft 接近 1 时,记忆单元中的信息被保留。
  • f t f_t ft 接近 0 时,记忆单元中的信息被遗忘。

重要性

  • 允许网络动态决定保留或丢弃信息,有助于捕捉长期依赖关系

2. 输入门(Input Gate)与候选记忆单元

下一步是确定什么样的新信息被存放在细胞状态中。这里包含两个部分:

在这里插入图片描述

输入门控制当前输入的信息如何更新到记忆单元中。它包含两个部分:

  1. 输入门层

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

sigmoid层称“输入门层”决定了哪些部分的候选记忆单元将被更新。

  1. 候选记忆单元层

C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh (W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

tanh层生成新的候选记忆信息。

特性

  • 输入门 i t i_t it 控制新信息的流入。
  • 候选记忆单元 C ~ t \tilde{C}_t C~t 提供新的信息以更新记忆单元。

细胞状态

现在是更新旧细胞状态的时间了, C t − 1 C_{t-1} Ct1 更新为 C t C_{t} Ct 。我们把旧状态与 f t f_t ft相乘,丢弃掉我们确定需要丢弃的信息,接着加上 i t ∗ C ~ t i_t*\tilde{C}_t itC~t。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。

在这里插入图片描述

3. 输出门(Output Gate)

最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。
首先,我们运行一个sigmoid层来确定细胞状态的哪个部分将输出出去。
接着,我们把细胞状态通过tanh进行处理(得到一个在-1到1之间的值)并将它和sigmoid门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。

在这里插入图片描述

输出门决定了当前时刻的隐藏状态 h t h_t ht 以及输出。

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma (W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ottanh(Ct)

特性

  • 输出门 o t o_t ot 控制了记忆单元中哪些部分的信息被输出。
  • 结合记忆单元状态,通过 tanh ⁡ \tanh tanh 函数提供非线性变换。

重要性

  • 决定了隐藏状态 h t h_t ht 中包含的信息,从而影响下一时刻的计算和最终输出

六、信息流动与记忆更新的详细过程

在这里插入图片描述

1. 信息流动示例

假设有一个序列 x = [ x 1 , x 2 , … , x T ] x = [x_1, x_2, \ldots, x_T] x=[x1,x2,,xT],LSTM 在每个时间步 t t t 的计算过程如下:

  1. 时间步 t t t

    • 输入 x t x_t xt 和前一隐藏状态 h t − 1 h_{t-1} ht1
    • 计算遗忘门 f t f_t ft
    • 计算输入门 i t i_t it 和候选记忆单元 C ~ t \tilde{C}_t C~t
    • 更新记忆单元 C t C_t Ct
    • 计算输出门 o t o_t ot
    • 生成隐藏状态 h t h_t ht
  2. 信息流动

    • 记忆单元 C t C_t Ct 是通过保留部分 C t − 1 C_{t-1} Ct1 和添加新信息 C ~ t \tilde{C}_t C~t 来更新的。
    • 隐藏状态 h t h_t ht 是通过输出门控制的 C t C_t Ct tanh ⁡ \tanh tanh 变换生成的,作为当前时刻的输出,并传递到下一个时间步。

2. 记忆更新的动态

  • 保留与遗忘:通过遗忘门 f t f_t ft,网络决定了哪些历史信息需要被保留,哪些需要被遗忘。这使得网络能够保留重要的长期信息,而忽略无关的短期信息。

  • 新信息的引入:通过输入门 i t i_t it 和候选记忆单元 C ~ t \tilde{C}_t C~t,网络决定了引入多少新的信息到记忆单元中,从而更新当前的记忆状态。

  • 输出的生成:通过输出门 o t o_t ot,网络决定了当前记忆单元状态中哪些部分需要被输出,从而影响当前时刻的隐藏状态 h t h_t ht

3. 实例说明

假设我们正在处理一个文本序列,目标是预测下一个单词。LSTM在每个时间步接收当前单词的嵌入向量 x t x_t xt,并基于前一时刻的隐藏状态 h t − 1 h_{t-1} ht1 和记忆单元状态 C t − 1 C_{t-1} Ct1 进行计算:

  1. 遗忘门决定:例如,网络可能决定忘记当前一个时间步的某些主题信息(如一个名词)。
  2. 输入门决定:网络可能决定引入新的信息(如一个动词)。
  3. 记忆单元更新:结合遗忘和输入门的输出,记忆单元状态被更新为保留了重要的主题信息,并引入了新的动词信息。
  4. 输出门决定:网络根据新的记忆单元状态生成当前的隐藏状态,用于预测下一个单词。

这种动态调整的机制使得LSTM能够在处理长文本时,保持对主题的长期记忆,同时灵活地引入新的信息。

七、LSTM的变种与扩展

LSTM有许多变种和扩展,旨在改进其性能或适应特定的应用场景。以下是几种常见的变种:

1. 双向LSTM(Bidirectional LSTM)

在这里插入图片描述

概述

  • 双向LSTM由两个LSTM单元组成,一个处理序列的正向信息,另一个处理序列的反向信息
  • 最终的隐藏状态是两个方向隐藏状态的组合(通常是拼接或求和)。

公式

h t → = LSTM ( x t , h t − 1 → ) \overrightarrow{h_t} = \text{LSTM}(x_t, \overrightarrow{h_{t-1}}) ht =LSTM(xt,ht1 )

h t ← = LSTM ( x t , h t + 1 ← ) \overleftarrow{h_t} = \text{LSTM}(x_t, \overleftarrow{h_{t+1}}) ht =LSTM(xt,ht+1 )

h t = [ h t → ; h t ← ] h_t = [\overrightarrow{h_t}; \overleftarrow{h_t}] ht=[ht ;ht ]

优点

  • 能够同时利用前后文信息,提高对上下文的理解能力。

在这里插入图片描述

应用场景

  • 自然语言处理中如命名实体识别、语义理解等任务。

2. 堆叠LSTM(Stacked LSTM)

概述

  • 堆叠LSTM通过将多个LSTM层堆叠在一起,形成更深的网络结构
  • 每一层的输出作为下一层的输入,增加模型的表达能力。

公式

h t ( l ) = LSTM ( l ) ( h t ( l − 1 ) , h t − 1 ( l ) ) h_t^{(l)} = \text{LSTM}^{(l)}(h_t^{(l-1)}, h_{t-1}^{(l)}) ht(l)=LSTM(l)(ht(l1),ht1(l))

其中, l l l 表示层数, h t ( 0 ) = x t h_t^{(0)} = x_t ht(0)=xt

优势

  • 提升模型的复杂度和拟合能力。
  • 更好地捕捉高级特征和抽象信息。

应用场景

  • 需要深层特征提取的任务,如复杂的自然语言处理任务和时间序列预测。

3. 卷积LSTM(Convolutional LSTM, ConvLSTM)

概述

  • 卷积LSTM结合了卷积神经网络(CNN)和LSTM,适用于处理具有空间结构的时空数据
  • 门控制机制中的全连接操作被卷积操作取代。

公式

f t = σ ( W f ∗ [ h t − 1 , x t ] + b f ) f_t = \sigma (W_f * [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

i t = σ ( W i ∗ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i * [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

C ~ t = tanh ⁡ ( W C ∗ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh (W_C * [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ftCt1+itC~t

o t = σ ( W o ∗ [ h t − 1 , x t ] + b o ) o_t = \sigma (W_o * [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ottanh(Ct)

其中, ∗ * 表示卷积操作。

优势

  • 能够捕捉时空数据中的空间依赖和时间依赖。

应用场景

  • 视频预测、天气预报、交通流量预测等。

4. Peephole LSTM

概述

  • Peephole LSTM在门控制中引入了记忆单元状态的直接连接,使门控单元能够访问记忆单元的状态

公式

f t = σ ( W f ⋅ [ h t − 1 , x t ] + V f ⋅ C t − 1 + b f ) f_t = \sigma (W_f \cdot [h_{t-1}, x_t] + V_f \cdot C_{t-1} + b_f) ft=σ(Wf[ht1,xt]+VfCt1+bf)

i t = σ ( W i ⋅ [ h t − 1 , x t ] + V i ⋅ C t − 1 + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + V_i \cdot C_{t-1} + b_i) it=σ(Wi[ht1,xt]+ViCt1+bi)

o t = σ ( W o ⋅ [ h t − 1 , x t ] + V o ⋅ C t + b o ) o_t = \sigma (W_o \cdot [h_{t-1}, x_t] + V_o \cdot C_t + b_o) ot=σ(Wo[ht1,xt]+VoCt+bo)

优势

  • 提高模型对记忆单元状态的感知能力,增强门控机制的表现力。

应用场景

  • 需要更精细控制记忆单元状态的任务,如精确的时间序列预测。

5. 注意力机制与LSTM结合

概述

  • 将注意力机制(Attention Mechanism)与LSTM结合,使模型能够动态地关注输入序列中与当前输出最相关的部分

优势

  • 提升模型的性能和解释性,尤其是在处理长序列时能够更有效地利用重要信息。

应用场景

  • 机器翻译、文本摘要、图像描述生成等任务。

示例

  • 在机器翻译中,注意力机制使得解码器在生成每个目标词时,能够关注源句子中最相关的词,从而提高翻译质量。

八、LSTM的实现细节与优化

1. 权重矩阵的初始化

权重初始化对LSTM的训练至关重要,常用的方法包括:

  1. Xavier初始化(Glorot Initialization)

    Variance = 2 输入维度 + 输出维度 \text{Variance} = \frac{2}{\text{输入维度} + \text{输出维度}} Variance=输入维度+输出维度2

    适用于Sigmoid和tanh激活函数。

  2. He初始化

    Variance = 2 输入维度 \text{Variance} = \frac{2}{\text{输入维度}} Variance=输入维度2

    适用于ReLU激活函数。

重要性

  • 合适的初始化方法有助于加速收敛,防止梯度消失或爆炸。

2. 激活函数的选择

  1. Sigmoid函数

    σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1

    • 用于门控单元,输出范围在0到1之间。
    • 控制信息的流动。
  2. 双曲正切函数(tanh)

    tanh ⁡ ( x ) = e x − e − x e x + e − x \tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} tanh(x)=ex+exexex

    • 用于生成候选记忆单元和输出隐藏状态。
    • 输出范围在-1到1之间,提供非线性变换。

3. 正则化技术

为了防止LSTM过拟合,常用的正则化技术包括:

  1. Dropout

    • 在训练过程中随机丢弃部分神经元,防止模型过于依赖某些特征。
    • 适用于LSTM的各个门控单元和隐藏层。
  2. L2正则化

    • 在损失函数中加入权重的平方和,限制权重的大小。
    • 有助于防止权重过大,减少过拟合风险。

4. 梯度裁剪(Gradient Clipping)

梯度裁剪用于防止梯度爆炸问题,特别是在处理长序列时。

实现方法

  • 全局梯度裁剪将所有参数的梯度组合成一个向量,如果其范数超过预设阈值,则按比例缩放

    如果   ∥ g ∥ > 阈值   ,则   g = g ∥ g ∥ × 阈值 \text{如果} \, \|g\| > \text{阈值} \, \text{,则} \, g = \frac{g}{\|g\|} \times \text{阈值} 如果g>阈值,则g=gg×阈值

  • 按参数裁剪分别对每个参数的梯度进行裁剪,确保每个参数的梯度在阈值范围内

重要性

  • 防止梯度过大导致训练不稳定或参数更新过度。

5. 批量处理与序列填充

在实际应用中,为了提高训练效率,通常采用批量处理(Batch Processing)技术。然而,序列数据的长度可能不同,需要进行填充(Padding)以统一长度。

步骤

  1. 序列填充(Padding)

    • 将所有序列填充到相同的长度,通常在序列的末尾添加零向量。
    • 确定一个最大序列长度,超出的部分截断,不足的部分填充。
  2. 掩码(Masking)

    • 使用掩码标记填充的位置,使得模型在计算损失和梯度时忽略填充部分。

6. 优化算法的选择

常用的优化算法包括:

  1. Adam优化器

    • 结合了动量和自适应学习率的优势。
    • 适用于大多数深度学习任务。
  2. RMSprop

    • 适用于处理非平稳目标,常用于循环神经网络。
  3. SGD(随机梯度下降)

    • 适用于大规模数据,但通常需要较长的训练时间和学习率调整。

九、案例:多变量时间序列预测及python代码

案例概述
使用长短期记忆网络(LSTM)进行多变量时间序列预测。我们将以股票价格预测为例,利用多个相关特征(如开盘价、收盘价、最高价、最低价、成交量等)来预测未来的收盘价。

1. 数据收集与预处理

1.1 获取股票数据

我们将使用yfinance库从雅虎财经获取苹果公司(AAPL)的股票数据。首先,确保已安装必要的库:

pip install yfinance pandas numpy scikit-learn matplotlib tensorflow

1.2 导入必要的库

import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping

1.3 下载股票数据

# 下载苹果公司(AAPL)的历史股票数据
df = yf.download('AAPL', start='2010-01-01', end='2023-12-31')

# 查看数据
print(df.head())

1.4 处理缺失值和异常值

# 检查缺失值
print(df.isnull().sum())

# 填充缺失值(如果有)
df.fillna(method='ffill', inplace=True)

1.5 特征工程(可选)

在这个例子中,我们将使用原始的开盘价(Open)、最高价(High)、最低价(Low)、收盘价(Close)和成交量(Volume)作为特征。此外,可以创建一些技术指标,如移动平均线(MA)、相对强弱指数(RSI)等,但为了简化,我们将仅使用基本特征。

1.6 数据归一化

LSTM对数据的尺度敏感,因此需要对数据进行归一化处理。我们将使用MinMaxScaler将数据缩放到0到1之间。

# 选择特征
features = ['Open', 'High', 'Low', 'Close', 'Volume']
data = df[features]

# 初始化Scaler
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data)

# 将归一化后的数据转换为DataFrame
scaled_df = pd.DataFrame(scaled_data, columns=features, index=df.index)

1.7 创建时间序列样本

我们将使用过去60天的数据来预测第61天的收盘价。

def create_sequences(data, seq_length):
    X = []
    y = []
    for i in range(seq_length, len(data)):
        X.append(data[i-seq_length:i])
        y.append(data[i, 3])  # 'Close'的索引为3
    return np.array(X), np.array(y)

SEQ_LENGTH = 60

# 转换为numpy数组
scaled_array = scaled_df.values

# 创建序列
X, y = create_sequences(scaled_array, SEQ_LENGTH)

print(f'Input shape: {X.shape}')
print(f'Target shape: {y.shape}')

1.8 拆分训练集和测试集

通常,时间序列数据按时间顺序拆分,不能随机拆分。

# 定义训练集比例
TRAIN_SIZE = 0.8
train_size = int(len(X) * TRAIN_SIZE)

X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

print(f'Training samples: {X_train.shape[0]}')
print(f'Testing samples: {X_test.shape[0]}')

2. 构建LSTM模型

2.1 定义模型架构
我们将构建一个堆叠的LSTM模型,包含两个LSTM层和一个全连接层。为了防止过拟合,我们将在LSTM层之间添加Dropout层。

# 获取输入特征数量
n_features = X_train.shape[2]

# 构建模型
model = Sequential()

# 第一层LSTM
model.add(LSTM(units=50, return_sequences=True, input_shape=(SEQ_LENGTH, n_features)))
model.add(Dropout(0.2))

# 第二层LSTM
model.add(LSTM(units=50, return_sequences=False))
model.add(Dropout(0.2))

# 全连接层
model.add(Dense(units=25))
model.add(Dense(units=1))  # 输出一个值,即预测的收盘价

# 查看模型摘要
model.summary()

2.2 编译模型

我们将使用均方误差(MSE)作为损失函数,优化器选择Adam。

model.compile(optimizer='adam', loss='mean_squared_error')

3. 训练模型

3.1 训练模型

为了防止过拟合,我们将使用Early Stopping回调,当验证损失在连续5个周期内不再改善时停止训练。

# 定义Early Stopping
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# 训练模型
history = model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=64,
    validation_split=0.2,
    callbacks=[early_stop],
    verbose=1
)

3.2 可视化训练过程

# 绘制训练和验证损失
plt.figure(figsize=(12, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss During Training')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

4. 评估与预测

4.1 在测试集上评估模型

# 预测
predictions = model.predict(X_test)

# 反归一化
predictions = scaler.inverse_transform(
    np.concatenate((np.zeros((predictions.shape[0], 4)), predictions), axis=1)
)[:, 4]

# 真实值反归一化
y_test_rescaled = scaler.inverse_transform(
    np.concatenate((np.zeros((y_test.shape[0], 4)), y_test.reshape(-1, 1)), axis=1)
)[:, 4]

# 计算均方根误差(RMSE)
rmse = np.sqrt(np.mean((predictions - y_test_rescaled) ** 2))
print(f'RMSE on Test Set: {rmse:.2f}')

4.2 可视化预测结果

# 创建一个DataFrame来存储真实值和预测值
test_dates = df.index[-len(y_test):]
comparison_df = pd.DataFrame({
    'Date': test_dates,
    'Actual Close': y_test_rescaled,
    'Predicted Close': predictions
})

# 设置日期为索引
comparison_df.set_index('Date', inplace=True)

# 绘制图表
plt.figure(figsize=(14, 7))
plt.plot(comparison_df['Actual Close'], label='Actual Close Price')
plt.plot(comparison_df['Predicted Close'], label='Predicted Close Price')
plt.title('Actual vs Predicted Close Price')
plt.xlabel('Date')
plt.ylabel('Close Price USD')
plt.legend()
plt.show()

4.3 预测未来价格

为了预测未来几天的收盘价,我们需要使用最新的60天数据作为输入。

# 假设我们要预测未来5天的收盘价
future_days = 5
last_sequence = scaled_array[-SEQ_LENGTH:]

for _ in range(future_days):
    # 预测下一天的收盘价
    pred = model.predict(last_sequence.reshape(1, SEQ_LENGTH, n_features))
    # 反归一化
    pred_rescaled = scaler.inverse_transform(
        np.concatenate((np.zeros((1, 4)), pred), axis=1)
    )[:, 4][0]
    print(f'Predicted Close Price: {pred_rescaled:.2f}')

    # 更新序列,移除最早的一天,添加预测值
    # 这里我们仅更新'Close'价格,其他特征保持不变或进行合理假设
    new_entry = last_sequence[-1].copy()
    new_entry[3] = pred  # 更新'Close'价格
    # 这里简单地将新_entry的其他特征与'Close'价格相同,实际应用中应使用更合理的策略
    last_sequence = np.vstack([last_sequence[1:], new_entry])

十、LSTM的优势与局限

1. 优势

  1. 捕捉长期依赖:通过门控机制,LSTM能够有效地捕捉和保持长期依赖信息,解决了传统RNN的梯度消失问题。
  2. 灵活性高:适用于各种类型的序列数据,如文本、时间序列、音频、视频等。
  3. 稳定的训练过程:相较于传统RNN,LSTM更容易训练,梯度消失和爆炸问题得到缓解。
  4. 强大的表达能力:通过堆叠和双向等变种,LSTM能够捕捉复杂的模式和特征。

2. 局限

  1. 计算复杂度高:LSTM单元包含多个门控机制,参数较多,计算开销较大,导致训练和推理时间较长。
  2. 训练时间长:由于结构复杂,尤其在处理长序列时,训练时间相对较长。
  3. 模型解释性有限:尽管LSTM能够有效地捕捉序列中的依赖关系,但其内部工作机制对于人类来说不够直观,解释性差。
  4. 过拟合风险:在数据量不足的情况下,LSTM容易过拟合,需要采取正则化措施。

总结
长短期记忆网络(LSTM)通过引入遗忘门、输入门和输出门,有效地解决了传统RNN在处理长序列时的梯度消失和梯度爆炸问题,使其能够捕捉和保持长期依赖信息。LSTM广泛应用于自然语言处理、时间序列预测、视频分析等领域,并通过各种变种(如双向LSTM、堆叠LSTM、卷积LSTM等)进一步提升了其性能和适用性。尽管LSTM在处理序列数据方面表现出色,但其计算复杂度和训练时间仍然是需要考虑的因素。随着深度学习技术的不断发展,LSTM及其衍生模型将在更多应用场景中发挥重要作用。

参考文献:
一幅图真正理解LSTM、BiLSTM


结~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2221657.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络安全中的日志审计:为何至关重要?

在数字化时代,网络安全已成为企业和组织不可忽视的重要议题。随着网络攻击手段的不断进化,保护信息系统和数据安全变得日益复杂和具有挑战性。在这种背景下,日志审计作为一种关键的信息安全和网络管理工具,发挥着至关重要的作用。…

RHCE——例行性工作 at、crontab

一.单一执行的列行型工作:仅处理执行一次就结束了 1.at命令的工作过程 (1)/etc/at.allow,写在该文件的人可以使用at命令 (2)/etc/at.deny,黑名单 (3)两个文件如果都…

【Spring篇】Spring的Aop详解

🧸安清h:个人主页 🎥个人专栏:【计算机网络】【Mybatis篇】【Spring篇】 🚦作者简介:一个有趣爱睡觉的intp,期待和更多人分享自己所学知识的真诚大学生。 目录 🎯初始Sprig AOP及…

SVM(支持向量机)

SVM(支持向量机) 引言 支持向量机(Support Vector Machine,SVM),可以用来解答二分类问题。支持向量(Support Vector):把划分数据的决策边界叫做超平面,点到超平面的距离叫做间隔。在SVM中,距离超平面最近…

京东笔试题

和谐敏感词 🔗 题目地址 🎉 模拟 import java.util.Scanner;public class Main {public static void main(String[] args) {Scanner scanner new Scanner(System.in);int n scanner.nextInt();String s scanner.next();String[] words new String[…

Mapbox GL 加载GeoServer底图服务器的WMS source

貌似加载有点慢啊!! 1 这是底图 2 这是加载geoserver中的地图效果 3源码 3.1 geoserver中的网络请求 http://192.168.10.10:8080/geoserver/ne/wms?SERVICEWMS&VERSION1.1.1&REQUESTGetMap&formatimage/png&TRANSPARENTtrue&STYL…

Linux--epoll(ET)实现Reactor模式

Linux–多路转接之epoll Reactor反应堆模式 Reactor反应堆模式是一种事件驱动的设计模式,通常用于处理高并发的I/O操作,尤其是在服务器或网络编程中。 基本概念 Reactor模式又称之为响应器模式,基于事件多路复用机制,使得单个…

网络与信息安全工程师最新报考介绍(工信部教育与考试中心)

文章目录 前言 网络与信息安全工程师职业介绍主要的工作内容职业技能要求网络与信息安全工程师职业前景怎么样网络与信息安全工程师工作方向网络与信息安全工程师适学人群 如何入门学习网络安全 【----帮助网安学习,以下所有学习资料文末免费领取!----】…

solidworks(sw)右侧资源栏变成英文,无法点击

sw右侧资源栏变成英文,无法点击,如图 使用xxclean 的扩展功能 SW右侧栏是英文 toolbox配置无效 这个按钮 修复完成之后重新打开软件查看是否变成中文。

[linux]快速入门

学习目标 通过学习能够掌握以下的linux操作 操作系统 按照应用领域的不同, 操作系统可以分为几类 桌面操作系统服务器操作系统移动设备操作系统嵌入式操作系统 不同领域的主流操作系统 桌面操作系统 Windows(用户数量最多)MacOS(操作体验好,办公人士首选)Linux…

Spring AI : Java写人工智能(LLM)的应用框架

Spring AI:为Java开发者提供高效集成大模型能力的框架 当前Java调用大模型时,面临缺乏优质AI应用框架的挑战。Spring作为资深的Java应用框架提供者,通过推出Spring AI来解决这一问题。它借鉴了langchain的核心理念,并结合了Java面…

解密 Redis:如何通过 IO 多路复用征服高并发挑战!

文章目录 一、什么是 IO 多路复用?二、为什么 Redis 要使用 IO 多路复用?三、Redis 如何实现 IO 多路复用?四、IO 多路复用的核心机制:epoll五、IO 多路复用在 Redis 中的工作流程六、IO 多路复用的优点七、IO 多路复用使用中的注…

安装buildkit,并使用buildkit构建containerd镜像

背景 因为K8s抛弃Docker了,所以就只装了个containerd,这样就需要一个单独的镜像构建工具了,就用了buildkit,这也是Docker公司扶持的,他们公司的人出来搞的开源工具,官网在 https://github.com/moby/buildkit 简介 服务端为buildkitd,负责和runc或containerd后端连接干活,目前…

w~自动驾驶合集6

我自己的原文哦~ https://blog.51cto.com/whaosoft/12286744 #自动驾驶的技术发展路线 端到端自动驾驶 Recent Advancements in End-to-End Autonomous Driving using Deep Learning: A SurveyEnd-to-end Autonomous Driving: Challenges and Frontiers 在线高精地图 HDMa…

windows文件拷贝给wsl2的Ubuntu

参考: windows文件如何直接拖拽到wsl中_win 移到文件到wsl-CSDN博客 cp -r /mnt/盘名/目标文件 要复制到wsl中的位置e.g.cp -r /mnt/d/byt5 /home Linux文件复制、移动、删除等操作命令_linux移动命令-CSDN博客 Linux 文件、文件夹的复制、移动、删除 - Be-myse…

构建后端为etcd的CoreDNS的容器集群(二)、下载最新的etcd容器镜像

在尝试获取etcd的容器的最新版本镜像时,使用latest作为tag取到的并非最新版本,本文尝试用实际最新版本的版本号进行pull,从而取到想的最新版etcd容器镜像。 一、用latest作为tag尝试下载最新etcd的镜像 1、下载镜像 [rootlocalhost opt]# …

多品牌摄像机视频平台EasyCVR海康大华宇视视频平台如何接入多样化设备

在实际的工程项目里,我们常常会面临这样的情况:项目管理者可能会决定使用多个品牌的视频监控摄像头,或者有需求将现有的、多种类型的监控系统进行整合。现在,让我们来探讨一下如何实现不同品牌摄像头的连接和使用。 1、GB/T281协议…

2024版最新148款CTF工具整理大全(附下载安装包)含基础环境、Web 安全、加密解密、密码爆破、文件、隐写、逆向、PWN

经常会有大学生粉丝朋友私信小强,想通过打CTF比赛镀金,作为进入一线互联网大厂的门票。 但是在CTF做题很多的时候都会用到工具,所以在全网苦寻CTF比赛工具安装包! 关于我 有不少阅读过我文章的伙伴都知道,我曾就职于…

SSM框架实战小项目:打造高效用户管理系统 day3

前言 在前两篇博客中,后台已经搭建完毕,现在需要设计一下前端页面 webapp下的项目结构图 创建ftl文件夹,导入css和js 因为我们在后台的视图解析器中,设置了页面解析器,跳转路径为/ftl/*.ftl,所以需要ftl文件…

JAVA开源项目 网上订餐系统 计算机毕业设计

本文项目编号 T 018 ,文末自助获取源码 \color{red}{T018,文末自助获取源码} T018,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 新…