文章目录
- LSTM 简单介绍
- LSTM的基本结构
- LSTM的工作原理
- 输入门
- 遗忘门
- 输出门
- 细胞状态更新
- 输出计算
- 总结
- 代码实例
LSTM 简单介绍
在自然语言处理、语音识别等领域,长短时记忆网络 (Long Short-Term Memory, LSTM) 已经成为了常用的模型之一。本文将介绍 LSTM 的基本结构和工作原理,帮助读者入门 LSTM。
LSTM的基本结构
LSTM 是一种递归神经网络(Recurrent Neural Network, RNN)。与传统 RNN 不同,LSTM 在网络的循环单元中引入了三个门控,以控制输入、输出和记忆的流程。一个 LSTM 单元的基本结构如下图所示:
图片来自博客:Pytorch入门之RNN_pytorch rnn_Ton10的博客-CSDN博客
其中, x t x_t xt 表示输入, h t h_t ht 表示输出, c t c_t ct 表示细胞状态(cell state)。 σ \sigma σ 表示 sigmoid 函数, ⊙ \odot ⊙ 表示逐元素相乘。
从图中可以看到,LSTM 单元有三个门控:输入门、遗忘门和输出门。它们分别控制输入、遗忘和输出的流程。
LSTM的工作原理
下面我们将详细介绍 LSTM 的工作原理。在讲解之前,我们先定义一些符号:
- x t x_t xt:时刻 t t t 的输入;
- h t h_t ht:时刻 t t t 的输出;(输出t时刻的隐藏状态,该状态在最终时刻是LSTM系统的输出)
- c t c_t ct:时刻 t t t 的细胞状态(cell state);
- i t i_t it:时刻 t t t 的输入门(input gate);
- f t f_t ft:时刻 t t t 的遗忘门(forget gate);
- o t o_t ot:时刻 t t t 的输出门(output gate);
- W W W:权重矩阵;
- b b b:偏置向量;
- σ ( x ) \sigma(x) σ(x):sigmoid 函数;
- tanh ( x ) \tanh(x) tanh(x):tanh 函数。
输入门
输入门控制着当前输入
x
t
x_t
xt 对于细胞状态的影响程度。输入门的计算方式如下:
i
t
=
σ
(
W
i
x
t
+
U
i
h
t
−
1
+
b
i
)
i_t=\sigma(W_ix_t+U_ih_{t-1}+b_i)
it=σ(Wixt+Uiht−1+bi)
其中,
W
i
W_i
Wi 和
U
i
U_i
Ui 分别是输入
x
t
x_t
xt 和上一时刻输出
h
t
−
1
h_{t-1}
ht−1 的权重矩阵,
b
i
b_i
bi 是输入门的偏置向量。
σ
\sigma
σ 表示 sigmoid 函数。
输入门的输出结果将与
c
t
~
\tilde{c_t}
ct~ 相乘,作为新的细胞状态的一部分。其中,
c
t
~
\tilde{c_t}
ct~ 的计算方式如下:
c
t
~
=
t
a
n
h
(
W
c
x
t
+
U
c
h
t
−
1
+
b
c
)
\tilde{c_t}=\mathrm{tanh}(W_cx_t+U_ch_{t-1}+b_c)
ct~=tanh(Wcxt+Ucht−1+bc)
公式中的符号代表的含义如下:
- x t x_t xt:表示当前时刻的输入,即该时刻的输入特征向量。
- h t − 1 h_{t-1} ht−1:表示上一个时刻的输出,即输出特征向量。
- W i W_i Wi 和 U i U_i Ui:分别是输入 x t x_t xt 和上一时刻输出 h t − 1 h_{t-1} ht−1 的权重矩阵,用于计算输入门的输出 i t i_t it。
- b i b_i bi:表示输入门的偏置向量,用于计算输入门的输出 i t i_t it。
- σ \sigma σ:表示 sigmoid 函数,用于将输入门的输入值 W i x t + U i h t − 1 + b i W_ix_t+U_ih_{t-1}+b_i Wixt+Uiht−1+bi 映射到一个介于 0 到 1 之间的范围内,表示当前输入 x t x_t xt 对于细胞状态的影响程度。
- i t i_t it:表示输入门的输出,代表当前输入对于细胞状态的影响程度。
- c t ~ \tilde{c_t} ct~:表示更新后的细胞状态的一部分,用于更新当前时刻的细胞状态。
- W c W_c Wc 和 U c U_c Uc:分别是输入 x t x_t xt 和上一时刻输出 h t − 1 h_{t-1} ht−1 的权重矩阵,用于计算 c t ~ \tilde{c_t} ct~。
- b c b_c bc:表示 c t ~ \tilde{c_t} ct~ 的偏置向量。
- t a n h \mathrm{tanh} tanh:表示双曲正切函数,用于将 c t ~ \tilde{c_t} ct~ 的输入值 W c x t + U c h t − 1 + b c W_cx_t+U_ch_{t-1}+b_c Wcxt+Ucht−1+bc 映射到介于 -1 到 1 之间的范围内,表示当前时刻的输入和上一时刻的输出对于更新后的细胞状态的影响程度。
遗忘门
遗忘门控制着上一时刻的细胞状态
c
t
−
1
c_{t-1}
ct−1 对于当前细胞状态
c
t
c_t
ct 的影响程度。遗忘门的计算方式如下:
f
t
=
σ
(
W
f
x
t
+
U
f
h
t
−
1
+
b
f
)
f_t=\sigma(W_fx_t+U_fh_{t-1}+b_f)
ft=σ(Wfxt+Ufht−1+bf)
其中,
W
f
W_f
Wf 和
U
f
U_f
Uf 分别是输入
x
t
x_t
xt 和上一时刻输出
h
t
−
1
h_{t-1}
ht−1 的权重矩阵,
b
f
b_f
bf 是遗忘门的偏置向量。
σ
\sigma
σ 表示 sigmoid 函数。
遗忘门的输出结果将与上一时刻的细胞状态 c t − 1 c_{t-1} ct−1 相乘,作为新的细胞状态的一部分。
输出门
输出门控制着当前细胞状态
c
t
c_t
ct 对于输出
h
t
h_t
ht 的影响程度。输出门的计算方式如下:
o
t
=
σ
(
W
o
x
t
+
W
o
h
t
−
1
+
b
o
)
o_t=\sigma(W_ox_t+W_oh_{t-1}+b_o)
ot=σ(Woxt+Woht−1+bo)
其中,
W
o
W_o
Wo 和
U
o
U_o
Uo 分别是输入
x
t
x_t
xt 和上一时刻输出
h
t
−
1
h_{t-1}
ht−1 的权重矩阵,
b
o
b_o
bo 是输出门的偏置向量。
σ
\sigma
σ 表示 sigmoid 函数。
输出门的输出结果将与 tanh ( c t ) \tanh(c_t) tanh(ct) 相乘,作为当前时刻的输出 h t h_t ht。其中, tanh \tanh tanh 表示 tanh 函数。
细胞状态更新
在输入门、遗忘门和输出门的计算过程中,LSTM 还需要更新当前的细胞状态
c
t
c_t
ct。细胞状态的更新方式如下:
c
t
=
f
t
⊙
c
t
−
1
+
i
t
⊙
c
t
~
c_t=f_t\odot c_{t-1}+i_t\odot \tilde{c_t}
ct=ft⊙ct−1+it⊙ct~
其中,
⊙
\odot
⊙ 表示逐元素相乘。
f
t
⊙
c
t
−
1
f_t \odot c_{t-1}
ft⊙ct−1 表示上一时刻的细胞状态通过遗忘门传递到当前时刻,
i
t
⊙
c
t
~
i_t \odot \tilde{c_t}
it⊙ct~ 表示当前时刻的输入通过输入门影响到细胞状态。最终得到的
c
t
c_t
ct 即为当前时刻的细胞状态。
输出计算
最后,LSTM 的输出计算方式如下:
h
t
=
o
t
⊙
t
a
n
h
(
c
t
)
h_t=o_t\odot \mathrm{tanh}(c_t)
ht=ot⊙tanh(ct)
其中,
o
t
o_t
ot 表示输出门的输出结果,
tanh
(
c
t
)
\tanh(c_t)
tanh(ct) 表示细胞状态经过 tanh 函数处理后的结果。最终得到的
h
t
h_t
ht 即为当前时刻的输出。
总结
LSTM 是一种递归神经网络,用于处理序列数据。LSTM 在循环单元中引入了三个门控,以控制输入、输出和记忆的流程。LSTM 的基本结构包括输入门、遗忘门、输出门和细胞状态,通过门控机制实现了对于输入、输出和记忆的精细控制。
希望这份 LSTM 的入门教程能够帮助你更好地理解 LSTM 的原理和运作方式。如果你想深入学习 LSTM,可以了解 LSTM 的变种结构,如 peephole LSTM、GRU 等,以及应用场景和实现细节。
同时,如果你想进一步了解深度学习,还可以学习其他类型的神经网络,如卷积神经网络、自编码器、生成对抗网络等。神经网络在计算机视觉、自然语言处理、语音识别等领域都得到了广泛的应用,是人工智能领域中不可或缺的技术之一。
代码实例
使用PyTorch框架写的LSTM网络,用于对FashionMNIST数据处理。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# 设置设备为GPU,只有当有可用GPU时才使用,否则使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 定义LSTM模型
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(LSTM, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
# 初始化隐藏状态和细胞状态
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
# 前向传播 LSTM
out, _ = self.lstm(x, (h0, c0))
# 取最后一个时间步的输出
out = self.fc(out[:, -1, :])
return out
# 超参数
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 5
# 加载FashionMNIST数据集并进行预处理
train_dataset = torchvision.datasets.FashionMNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.FashionMNIST(root='./data',
train=False,
transform=transforms.ToTensor())
# 数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# 初始化模型并移动模型到GPU(如果可用)
model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# 将图像数据移动到GPU(如果可用)
images = images.reshape(-1, 28, 28).to(device)
labels = labels.to(device)
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化器步骤
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 每隔100个批次打印一次训练信息
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# 测试模型
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
# 将图像数据移动到GPU(如果可用)
images = images.reshape(-1, 28, 28).to(device)
labels = labels.to(device)
# 前向传播
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
# 统计预测正确的样本数和总样本数
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))