Pytorch搭建循环神经网络RNN(简单实战)
去年写了篇《循环神经网络》,里面主要介绍了循环神经网络的结构与Tensorflow实现。而本篇博客主要介绍基于Pytorch搭建RNN。
通过Sin预测Cos
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
首先,我们定义一些超参数
TIME_STEP = 10 # rnn 时序步长数
INPUT_SIZE = 1 # rnn 的输入维度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
H_SIZE = 64 # of rnn 隐藏单元个数
EPOCHS = 100 # 总共训练次数
h_state = None # 隐藏层状态
使用Numpy生成Sin和Cos函数
steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
可视化数据
plt.figure(1)
plt.suptitle('Sin and Cos', fontsize='18')
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()
定义网络结构
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=H_SIZE,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(H_SIZE, 1)
def forward(self, x, h_state):
r_out, h_state = self.rnn(x, h_state)
outs = [] # 保存所有的预测值
for time_step in range(r_out.size(1)): # 计算每一步长的预测值
outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state
rnn = RNN().to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters()) # Adam优化,几乎不用调参
criterion = nn.MSELoss() # 因为最终的结果是一个数值,所以损失函数用均方误差
rnn.train()
plt.figure(2)
for step in range(EPOCHS):
start, end = step * np.pi, (step+1)*np.pi # 一个时间周期
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis]) # shape (batch, time_step, input_size)
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
x = x.to(DEVICE)
prediction, h_state = rnn(x, h_state) # rnn output
# 这一步非常重要
h_state = h_state.data # 重置隐藏层的状态, 切断和前一次迭代的链接
loss = criterion(prediction.cpu(), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (step+1) % 20 == 0: # 每训练20个批次可视化一下效果,并打印一下loss
print("EPOCHS: {},Loss:{:4f}".format(step, loss))
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
plt.draw()
plt.pause(0.01)
运行结果如下:
EPOCHS: 19,Loss:0.052745
EPOCHS: 39,Loss:0.016266
EPOCHS: 59,Loss:0.005471
EPOCHS: 79,Loss:0.001329
EPOCHS: 99,Loss:0.002216