ConvLSTM
- 1. 算法简介与应用场景
- 2. 算法原理
- 2.1 LSTM基础
- 2.2 ConvLSTM原理
- 2.2.1 ConvLSTM的结构
- 2.2.2 卷积操作的优点
- 2.3 LSTM与ConvLSTM的对比分析
- 2.4 ConvLSTM的应用
- 3. PyTorch代码
- 参考文献
仅需要网络源码的可以直接跳到末尾即可
1. 算法简介与应用场景
ConvLSTM(卷积长短期记忆网络)是一种结合了卷积神经网络(CNN)和长短期记忆网络(LSTM)优势的深度学习模型。它主要用于处理时空数据,特别适用于需要考虑空间特征和时间依赖关系的任务,如气象预测、视频分析、交通流量预测等。
在气象预测中,ConvLSTM可以根据过去的气象数据(如降水、温度等)预测未来的天气情况。在视频分析中,它可以帮助识别视频中的活动或事件,利用时间序列的连续性和空间信息进行更准确的分析。
2. 算法原理
2.1 LSTM基础
在介绍ConvLSTM之前,先让我们来回归一下什么是长短期记忆网络(LSTM)。LSTM是一种特殊的循环神经网络(RNN),它通过引入门控机制解决了传统RNN在长序列训练中面临的梯度消失和爆炸问题。LSTM单元主要包含三个门:输入门、遗忘门和输出门。这些门控制着信息在单元中的流动,从而有效地记住或遗忘信息。
LSTM的核心公式如下:
-
遗忘门:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf) -
输入门:
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
C ~ t = tanh ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC) -
单元状态更新:
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t Ct=ft∗Ct−1+it∗C~t -
输出门:
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
h t = o t ∗ tanh ( C t ) h_t = o_t \ast \tanh(C_t) ht=ot∗tanh(Ct)
这里, C t C_t Ct 是当前的单元状态, h t h_t ht 是当前的隐藏状态, x t x_t xt 是当前的输入。
2.2 ConvLSTM原理
ConvLSTM在LSTM的基础上引入了卷积操作。传统的LSTM使用全连接层处理输入数据,而ConvLSTM则采用卷积层来处理空间数据。这样,ConvLSTM能够更好地捕捉输入数据中的空间特征。
2.2.1 ConvLSTM的结构
ConvLSTM的单元结构与LSTM非常相似,但是在每个门的计算中使用了卷积操作。具体来说,ConvLSTM的每个门的公式可以表示为:
i
t
=
σ
(
W
x
i
∗
X
t
+
W
h
i
∗
H
t
−
1
+
W
c
i
∘
C
t
−
1
+
b
i
)
i_t = \sigma (W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i)
it=σ(Wxi∗Xt+Whi∗Ht−1+Wci∘Ct−1+bi)
f
t
=
σ
(
W
x
f
∗
X
t
+
W
h
f
∗
H
t
−
1
+
W
c
f
∘
C
t
−
1
+
b
f
)
f_t = \sigma (W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f)
ft=σ(Wxf∗Xt+Whf∗Ht−1+Wcf∘Ct−1+bf)
C
t
=
f
t
∘
C
t
−
1
+
i
t
∘
t
a
n
h
(
W
x
c
∗
X
t
+
W
h
c
∗
H
t
−
1
+
b
c
)
C_t = f_t \circ C_{t-1} + i_t \circ tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c)
Ct=ft∘Ct−1+it∘tanh(Wxc∗Xt+Whc∗Ht−1+bc)
o
t
=
σ
(
W
x
o
∗
X
t
+
W
h
o
∗
H
t
−
1
+
W
c
o
∘
C
t
+
b
o
)
o_t = \sigma (W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o)
ot=σ(Wxo∗Xt+Who∗Ht−1+Wco∘Ct+bo)
H
t
=
o
t
∘
t
a
n
h
(
C
t
)
H_t = o_t \circ tanh(C_t)
Ht=ot∘tanh(Ct)
这里的 所有
W
W
W都是是卷积权重,
b
b
b是偏置项,
σ
\sigma
σ 是 sigmoid 函数,
tanh
\tanh
tanh 是双曲正切函数。。
2.2.2 卷积操作的优点
-
空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。
-
参数共享:卷积操作通过使用相同的卷积核在不同位置计算特征,从而减少了模型参数的数量,降低了计算复杂度。
-
平移不变性:卷积网络对输入数据的平移具有不变性,即相同的特征在不同位置都会被检测到,这对于时空序列数据来说是非常重要的。
2.3 LSTM与ConvLSTM的对比分析
特性 | LSTM | ConvLSTM |
---|---|---|
输入类型 | 一维序列 | 三维数据(时序的图像数据) |
处理方式 | 全连接层 | 卷积操作 |
空间特征捕捉 | 较弱 | 较强 |
应用场景 | 自然语言处理、时间序列预测 | 图像序列预测、视频分析 |
2.4 ConvLSTM的应用
ConvLSTM在多个领域中表现出色,特别适合处理具有时空特征的数据。以下是一些主要的应用场景:
- 气象预测:利用历史气象数据(如温度、湿度、降水等)来预测未来的天气情况。
- 视频分析:对视频中的动态场景进行建模,识别和预测视频中的活动。
- 交通流量预测:基于历史交通数据预测未来的交通流量,帮助城市交通管理。
- 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。
3. PyTorch代码
以下是ConvLSTM的完整代码,可以直接拿来用:
import torch.nn as nn
import torch
class ConvLSTMCell(nn.Module):
def __init__(self, input_dim, hidden_dim, kernel_size, bias):
"""
初始化卷积 LSTM 单元。
参数:
----------
input_dim: int
输入张量的通道数。
hidden_dim: int
隐藏状态的通道数。
kernel_size: (int, int)
卷积核的大小。
bias: bool
是否添加偏置项。
"""
super(ConvLSTMCell, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
# 计算填充大小以保持输入和输出尺寸一致
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
self.bias = bias
# 定义卷积层,输入是输入维度加上隐藏维度,输出是4倍的隐藏维度(对应i, f, o, g)
self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
out_channels=4 * self.hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias)
def forward(self, input_tensor, cur_state):
h_cur, c_cur = cur_state
# 沿着通道轴进行拼接
combined = torch.cat([input_tensor, h_cur], dim=1)
combined_conv = self.conv(combined)
# 将输出分割成四个部分,分别对应输入门、遗忘门、输出门和候选单元状态
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
i = torch.sigmoid(cc_i)
f = torch.sigmoid(cc_f)
o = torch.sigmoid(cc_o)
g = torch.tanh(cc_g)
# 更新单元状态
c_next = f * c_cur + i * g
# 更新隐藏状态
h_next = o * torch.tanh(c_next)
return h_next, c_next
def init_hidden(self, batch_size, image_size):
height, width = image_size
# 初始化隐藏状态和单元状态为零
return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
class ConvLSTM(nn.Module):
"""
卷积 LSTM 层。
参数:
----------
input_dim: 输入通道数
hidden_dim: 隐藏通道数
kernel_size: 卷积核大小
num_layers: LSTM 层的数量
batch_first: 批次是否在第一维
bias: 卷积中是否有偏置项
return_all_layers: 是否返回所有层的计算结果
输入:
------
一个形状为 B, T, C, H, W 或者 T, B, C, H, W 的张量
输出:
------
元组包含两个列表(长度为 num_layers 或者长度为 1 如果 return_all_layers 为 False):
0 - layer_output_list 是长度为 T 的每个输出的列表
1 - last_state_list 是最后的状态列表,其中每个元素是一个 (h, c) 对应隐藏状态和记忆状态
示例:
>>> x = torch.rand((32, 10, 64, 128, 128))
>>> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
>>> _, last_states = convlstm(x)
>>> h = last_states[0][0] # 0 表示层索引,0 表示 h 索引
"""
def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
batch_first=False, bias=True, return_all_layers=False):
super(ConvLSTM, self).__init__()
# 检查 kernel_size 的一致性
self._check_kernel_size_consistency(kernel_size)
# 确保 kernel_size 和 hidden_dim 的长度与层数一致
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
if not len(kernel_size) == len(hidden_dim) == num_layers:
raise ValueError('不一致的列表长度。')
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.num_layers = num_layers
self.batch_first = batch_first
self.bias = bias
self.return_all_layers = return_all_layers
# 创建 ConvLSTMCell 列表
cell_list = []
for i in range(0, self.num_layers):
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
hidden_dim=self.hidden_dim[i],
kernel_size=self.kernel_size[i],
bias=self.bias))
self.cell_list = nn.ModuleList(cell_list)
def forward(self, input_tensor, hidden_state=None):
"""
前向传播函数。
参数:
----------
input_tensor: 输入张量,形状为 (t, b, c, h, w) 或者 (b, t, c, h, w)
hidden_state: 初始隐藏状态,默认为 None
返回:
-------
last_state_list, layer_output
"""
if not self.batch_first:
# 改变输入张量的顺序,如果 batch_first 为 False
input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
b, _, _, h, w = input_tensor.size()
# 实现状态化的 ConvLSTM
if hidden_state is not None:
raise NotImplementedError()
else:
# 初始化隐藏状态
hidden_state = self._init_hidden(batch_size=b,
image_size=(h, w))
layer_output_list = []
last_state_list = []
seq_len = input_tensor.size(1)
cur_layer_input = input_tensor
for layer_idx in range(self.num_layers):
h, c = hidden_state[layer_idx]
output_inner = []
for t in range(seq_len):
# 在每个时间步上更新状态
h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
cur_state=[h, c])
output_inner.append(h)
# 将输出堆叠起来
layer_output = torch.stack(output_inner, dim=1)
cur_layer_input = layer_output
layer_output_list.append(layer_output)
last_state_list.append([h, c])
if not self.return_all_layers:
# 如果不需要返回所有层,则只返回最后一层的输出和状态
layer_output_list = layer_output_list[-1:]
last_state_list = last_state_list[-1:]
return layer_output_list, last_state_list
def _init_hidden(self, batch_size, image_size):
init_states = []
for i in range(self.num_layers):
# 初始化每一层的隐藏状态
init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
return init_states
@staticmethod
def _check_kernel_size_consistency(kernel_size):
if not (isinstance(kernel_size, tuple) or
(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
raise ValueError('`kernel_size` 必须是 tuple 或者 list of tuples')
@staticmethod
def _extend_for_multilayer(param, num_layers):
if not isinstance(param, list):
param = [param] * num_layers
return param
参考文献
[1]Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. (2015). Convolutional LSTM Network: A Machine Learning [2]Approach for Precipitation Nowcasting. Advances in Neural Information Processing Systems, 28.
[3]Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.