【机器学习-神经网络】循环神经网络

news2024/9/22 23:38:48

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ Python机器学习 ⌋ ⌋ 机器学习是一门人工智能的分支学科,通过算法和模型让计算机从数据中学习,进行模型训练和优化,做出预测、分类和决策支持。Python成为机器学习的首选语言,依赖于强大的开源库如Scikit-learn、TensorFlow和PyTorch。本专栏介绍机器学习的相关算法以及基于Python的算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/Python_machine_learning。

文章目录

    • 一、循环神经网络的基本原理
    • 二、门控循环单元
    • 三、动手实现GRU


  在前面两篇文章中,我们分别介绍了神经网络的基础概念和最简单的MLP,以及适用于图像处理的CNN。从中我们可以意识到,不同结构的神经网络具有不同的特点,在不同任务上具有自己的优势。例如MLP复杂度低、训练简单、适用范围广,适合解决普通任务或作为大型网络的小模块;CNN可以捕捉到输入中不同尺度的关联信息,适合从图像中提取特征。而对于具有序列特征的数据,例如一年内随时间变化的温度、一篇文章中的文字等,它们具有明显的前后关联。然而这些关联的数据在序列中出现的位置可能间隔非常远,例如文章在开头和结尾描写了同一个事物,如果用CNN来提取这些关联的话,其卷积核的大小需要和序列的长度相匹配。当数据序列较长时,这种做法会大大增加网络复杂度和训练难度。因此,我们需要引入一种新的网络结构,使其能够充分利用数据的序列性质,从前到后分析数据、提取关联。这就是本文要介绍的循环神经网络(recurrent neural networks,RNN)。

一、循环神经网络的基本原理

  我们先从最简单的模型开始考虑。对于不存在序列关系的数据,我们采用一个两层的MLP来拟合它,如图1(a)所示,输入样本为 x \boldsymbol x x,经过第一个权重为 W i \boldsymbol W_i Wi b i \boldsymbol b_i bi的隐层得到中间向量 h = f h ( W i x + b i ) \boldsymbol h = \boldsymbol f_h(\boldsymbol W_i\boldsymbol x+\boldsymbol b_i) h=fh(Wix+bi),再经过权重为 W o \boldsymbol W_o Wo b o \boldsymbol b_o bo的隐层得到输出 y = f o ( W o h + b o ) \boldsymbol y = f_o(\boldsymbol W_o\boldsymbol h+\boldsymbol b_o) y=fo(Woh+bo),其中 f h f_h fh f o f_o fo为激活函数。这是一个标准的MLP的预测流程。

在这里插入图片描述

图1 从MLP到RNN

  假设数据集中的数据分别是在时刻1和时刻2采集到的,并且我们知道时刻2的结果与时刻1有关。这时,由于两个时刻的数据产生了依赖关系,如果我们用相同的模型权重来进行预测而忽略其关联,预测的准确度就会降低。为了利用上额外的关联信息,我们将MLP的结构拓展一下,如图1(b)所示,第二个MLP的中间向量与一般的MLP不同。在计算时刻2的中间向量 h 2 \boldsymbol h_2 h2时,我们将时刻1的中间向量 h 1 \boldsymbol h_1 h1也纳入进来,得到 h 2 = f h ( W h h 1 + W i x 2 + b i ) \boldsymbol h_2 = f_h(\boldsymbol W_h\boldsymbol h_1+\boldsymbol W_i\boldsymbol x_2+\boldsymbol b_i) h2=fh(Whh1+Wix2+bi),再将 h 2 \boldsymbol h_2 h2传给第二个隐层,计算出输出 y 2 = f o ( W o h 2 + b o ) y_2=f_o(\boldsymbol W_o\boldsymbol h_2+\boldsymbol b_o) y2=fo(Woh2+bo)。这样,我们就在时刻2的预测中用到了时刻1的信息。如果将这种思想进一步扩展,如图1(c)所示,我们可以将MLP沿着序列不断扩展下去,中间的每个MLP都将上一时刻的中间向量 h t − 1 \boldsymbol h_{t-1} ht1与当前的输入 x t \boldsymbol x_t xt组合得到中间向量,再进行后续处理。同时,由于序列中每一位置之间又存在对称性,为了减小网络的复杂度,每一MLP前后的权重与中间组合的权重可以共用,不随序列位置变化。因此,这样重复的网络结构可以用图2中的循环来表示,称为循环神经网络。

在这里插入图片描述

图2 RNN的循环表示

  RNN的输入与输出并不一定要像上面展示的一样,在每一时刻都有一个输入样本和一个预测输出。根据任务的不同,RNN的输入输出对应可以有多种形式。图3展示了一些不同对应形式的RNN结构,从左到右依次是一对多、多对一、同步多对多和异步多对多,它们都有合适的任务场景。例如,如果我们要根据一个关键词生成一句话,以词语作为最小单元,那么RNN的输入只有一个,而生成的句子需要有连贯的含义和语义,因此可以利用RNN在每一时刻输出一个词,从前到后连成完整的句子。这样的任务就更适合采用一对多的结构。再比如,常见的时间序列预测任务需要我们根据一段时间中收集的数据,预测接下来一定时间内数据的情况。这时,我们就可以用异步多对多的结构,先分析样本的规律和特征,再生成紧接着样本所在时间之后的结果。

在这里插入图片描述

图3 适用与不同任务的RNN结构

  当我们训练RNN时,由于每一时刻的中间向量都会组合上一时刻的中间向量,如果把时刻 t t t的中间向量全部展开,就得到
h t = f h ( W h h t − 1 + W i x t + b i ) = f h ( W h f h ( W h h t − 2 + W i x t − 1 + b i ) + W i x t + b i ) = ⋯ = f h ( W h f h ( ⋯ W h f h ( W h ( W i x 1 + b i ) + W i x 2 + b i ) ⋯   ) + W i x t + b i ) \begin{aligned} \boldsymbol h_t &= f_h(\boldsymbol W_h\boldsymbol h_{t-1}+\boldsymbol W_i\boldsymbol x_t+\boldsymbol b_i) \\ &= f_h(\boldsymbol W_hf_h(\boldsymbol W_h\boldsymbol h_{t-2}+\boldsymbol W_i\boldsymbol x_{t-1}+\boldsymbol b_i)+\boldsymbol W_i\boldsymbol x_t+\boldsymbol b_i) \\ &= \cdots \\ &= f_h(\boldsymbol W_hf_h(\cdots\boldsymbol W_hf_h(\boldsymbol W_h(\boldsymbol W_i\boldsymbol x_1+\boldsymbol b_i)+\boldsymbol W_i\boldsymbol x_2+\boldsymbol b_i)\cdots)+\boldsymbol W_i\boldsymbol x_t+\boldsymbol b_i) \end{aligned} ht=fh(Whht1+Wixt+bi)=fh(Whfh(Whht2+Wixt1+bi)+Wixt+bi)==fh(Whfh(Whfh(Wh(Wix1+bi)+Wix2+bi))+Wixt+bi)

  如果在时刻 t t t存在输出,我们可计算时刻 t t t的损失函数,并使用梯度回传方法优化参数。然而,随着反向传播的步数增加,RNN有可能会出现梯度消失或梯度爆炸的现象。为了详细解释这一现象,我们考虑时刻 t t t的损失 L t \mathcal L_t Lt关于参数 W i \boldsymbol W_i Wi的导数。根据求导的链式法则,我们可以计算如下:
∂ L t ∂ W i = ∂ L t ∂ y t ∂ y t ∂ W i = ∂ L t ∂ y t ∂ y t ∂ h t d h t d W i = ∂ L t ∂ y t ∂ y t ∂ h t ( ∂ h t ∂ W i + ∂ h t ∂ h t − 1 d h t − 1 d W i ) = ⋯ = ∂ L t ∂ y t ∂ y t ∂ h t ( ∂ h t ∂ W i + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ W i + ⋯ + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ h t − 2 ⋯ ∂ h 2 ∂ h 1 ∂ h 1 ∂ W i ) = ∂ L t ∂ y t ∂ y t ∂ h t ( ∂ h t ∂ W i + ∑ j = 1 t − 1 ( ∏ k = j + 1 t ∂ h k ∂ h k − 1 ) ∂ h j ∂ W i ) \begin{aligned} \frac{\partial\mathcal L_t}{\partial\boldsymbol W_i} &= \frac{\partial\mathcal L_t}{\partial y_t}\frac{\partial y_t}{\partial\boldsymbol W_i} \\[2ex] &= \frac{\partial\mathcal L_t}{\partial y_t}\frac{\partial y_t}{\partial\boldsymbol h_t}\frac{\mathrm d\boldsymbol h_t}{\mathrm d\boldsymbol W_i} \\[2ex] &= \frac{\partial\mathcal L_t}{\partial y_t}\frac{\partial y_t}{\partial\boldsymbol h_t}\left(\frac{\partial\boldsymbol h_t}{\partial\boldsymbol W_i}+\frac{\partial\boldsymbol h_t}{\partial\boldsymbol h_{t-1}}\frac{\mathrm d\boldsymbol h_{t-1}}{\mathrm d\boldsymbol W_i}\right) \\[2ex] &= \cdots \\[1ex] &= \frac{\partial\mathcal L_t}{\partial y_t}\frac{\partial y_t}{\partial\boldsymbol h_t}\left(\frac{\partial\boldsymbol h_t}{\partial\boldsymbol W_i}+\frac{\partial\boldsymbol h_t}{\partial\boldsymbol h_{t-1}}\frac{\partial\boldsymbol h_{t-1}}{\partial\boldsymbol W_i}+\cdots+\frac{\partial\boldsymbol h_t}{\partial\boldsymbol h_{t-1}}\frac{\partial\boldsymbol h_{t-1}}{\partial\boldsymbol h_{t-2}}\cdots\frac{\partial\boldsymbol h_2}{\partial\boldsymbol h_1}\frac{\partial\boldsymbol h_1}{\partial\boldsymbol W_i}\right) \\[2ex] &= \frac{\partial\mathcal L_t}{\partial y_t}\frac{\partial y_t}{\partial\boldsymbol h_t}\left(\frac{\partial\boldsymbol h_t}{\partial\boldsymbol W_i}+\sum_{j=1}^{t-1}\left(\prod_{k=j+1}^t\frac{\partial\boldsymbol h_k}{\partial\boldsymbol h_{k-1}}\right)\frac{\partial\boldsymbol h_j}{\partial\boldsymbol W_i}\right) \end{aligned} WiLt=ytLtWiyt=ytLthtytdWidht=ytLthtyt(Wiht+ht1htdWidht1)==ytLthtyt(Wiht+ht1htWiht1++ht1htht2ht1h1h2Wih1)=ytLthtyt Wiht+j=1t1 k=j+1thk1hk Wihj ∂ h k ∂ h k − 1 = f h ′ W h \begin{aligned}\frac{\partial\boldsymbol h_k}{\partial\boldsymbol h_{k-1}}=f'_h\boldsymbol W_h\end{aligned} hk1hk=fhWh ∂ h k ∂ W i = x k \begin{aligned}\frac{\partial\boldsymbol h_k}{\partial\boldsymbol W_i}=\boldsymbol x_k\end{aligned} Wihk=xk 代入,就得到 ∂ L t ∂ W i = ∂ L t ∂ y t ∂ y t ∂ h t ( x t + ∑ j = 1 t − 1 ( ∏ k = j + 1 t f h ′ W h ) x j ) \frac{\partial\mathcal L_t}{\partial\boldsymbol W_i} = \frac{\partial\mathcal L_t}{\partial y_t}\frac{\partial y_t}{\partial\boldsymbol h_t}\left(\boldsymbol x_t+\sum_{j=1}^{t-1}\left(\prod_{k=j+1}^tf'_h\boldsymbol W_h\right)\boldsymbol x_j\right) WiLt=ytLthtyt xt+j=1t1 k=j+1tfhWh xj

  观察上式可以发现,梯度中会出现一些 f h ′ W h f'_h\boldsymbol W_h fhWh的连乘项。如果 f h ′ W h < 1 f'_h\boldsymbol W_h<1 fhWh<1,当时刻 t t t与时刻 j j j距离较远时,该连乘的值就会趋近于0,因此由时刻 t t t的损失函数计算出的梯度在回传时会逐渐消失;反之,如果 f h ′ W h > 1 f'_h\boldsymbol W_h>1 fhWh>1,该连乘会趋于无穷大,梯度在回传时会出现发散的现象。我们将这两种情况分别称为梯度消失梯度爆炸。无论出现哪种情况,网络的参数都无法正常更新,模型的性能也会大打折扣。当出现梯度消失时,时刻 t t t的梯度只能影响时刻 t t t之前的少数几步,而无法影响到较远的位置。换句话说,距离时刻 t t t较远的信息已经丢失,模型很难捕捉到序列中的长期关联。而当出现梯度爆炸时,网络的梯度会迅速发散,出现数值溢出等错误。

  为了防止上述现象发生,最简单的做法是对梯度进行裁剪,为梯度设置上限和下限,当梯度过大或过小时,直接用上下限来代替梯度的值。但是,这种做法在复杂情况下仍然会导致信息丢失,通常只作为一种辅助手段。我们还可以选用合适的激活函数 f h f_h fh并调整网络参数 W h \boldsymbol W_h Wh初始化的值,使得乘积 f h ′ W h f'_h\boldsymbol W_h fhWh始终稳定在1附近。但是,随着网络参数不断更新, W h \boldsymbol W_h Wh总会变化,要始终控制它们的乘积比较困难。因此,我们可以将网络中关联起相邻两步的 f h f_h fh W h \boldsymbol W_h Wh扩展成一个小的网络,通过设计其结构来达到稳定梯度的目的。

二、门控循环单元

  本节,我们就来介绍一种较为简单的设计——门控循环单元(gated recurrent unit,GRU)。为了解决梯度消失与梯度爆炸的问题,GRU在普通RNN的设计上进行改进,通过门控单元来调整 h t \boldsymbol h_t ht h t − 1 \boldsymbol h_{t-1} ht1的关系。我们不妨将输入 x t \boldsymbol x_t xt理解为外部输入的信息, h t \boldsymbol h_t ht理解为网络记住的信息,它从时刻1的 h 1 \boldsymbol h_1 h1开始向后传递。然而,由于模型本身复杂度的限制,模型并不需要、也无法将所有时刻的信息都保留下来。因此,在由上一时刻的信息 h t − 1 \boldsymbol h_{t-1} ht1计算 h t \boldsymbol h_t ht时,必须有选择地进行遗忘。同时,在时刻 t t t有新的信息 x t \boldsymbol x_t xt输入进网络,我们需要在过去的信息 h t − 1 \boldsymbol h_{t-1} ht1与新信息 x t \boldsymbol x_t xt之间做到平衡。

  图4展示了GRU单元的内部结构,GRU设置的门控单元共有两个,分别称为更新门重置门。每个门控单元输出一个数值或向量,由上一时刻的信息 h t − 1 \boldsymbol h_{t-1} ht1和当前时刻的输入 x t \boldsymbol x_t xt组合计算得到
z t = σ ( W z x t + U z h t − 1 + b z ) r t = σ ( W r x t + U r h t − 1 + b r ) \begin{aligned} \boldsymbol z_t = \sigma(\boldsymbol W_z\boldsymbol x_t+\boldsymbol U_z\boldsymbol h_{t-1}+\boldsymbol b_z) \\ \boldsymbol r_t = \sigma(\boldsymbol W_r\boldsymbol x_t+\boldsymbol U_r\boldsymbol h_{t-1}+\boldsymbol b_r) \end{aligned} zt=σ(Wzxt+Uzht1+bz)rt=σ(Wrxt+Urht1+br) 其中, z t \boldsymbol z_t zt是更新单元, r t \boldsymbol r_t rt是重置单元, W z \boldsymbol W_z Wz W r \boldsymbol W_r Wr U z \boldsymbol U_z Uz U r \boldsymbol U_r Ur b z \boldsymbol b_z bz b r \boldsymbol b_r br都是网络的参数, σ \sigma σ是逻辑斯谛函数,从而门控单元的值都在 ( 0 , 1 ) (0,1) (0,1)区间内。

在这里插入图片描述

图4 GRU结构示意

  虽然这两个单元的计算方式完全相同,但是接下来它们会发挥不同的作用。利用重置单元 r t \boldsymbol r_t rt,我们对过去的信息 h t − 1 \boldsymbol h_{t-1} ht1进行选择性遗忘: h t − 1 ′ = r t ⊙ h t − 1 \boldsymbol h'_{t-1}=\boldsymbol r_t\odot\boldsymbol h_{t-1} ht1=rtht1 其中, ⊙ \odot 称为阿达马积(Hadamard product),表示向量或矩阵的逐元素相乘。例如,形状均为 m × n m\times n m×n 的矩阵 A \boldsymbol A A B \boldsymbol B B的阿达玛积为
A ⊙ B = ( a 11 b 11 a 12 b 12 ⋯ a 1 n b 1 n a 21 b 21 a 22 b 22 ⋯ a 2 n b 2 n ⋮ ⋮   ⋮ a m 1 b m 1 a m 2 b m 2 ⋯ a m n b m n ) \boldsymbol{\boldsymbol A\odot\boldsymbol B} = \begin{pmatrix} a_{11}b_{11} &a_{12}b_{12} &\cdots &a_{1n}b_{1n} \\ a_{21}b_{21} &a_{22}b_{22} &\cdots &a_{2n}b_{2n} \\ \vdots &\vdots\ & &\vdots \\ a_{m1}b_{m1} &a_{m2}b_{m2} &\cdots &a_{mn}b_{mn} \end{pmatrix} AB= a11b11a21b21am1bm1a12b12a22b22 am2bm2a1nb1na2nb2namnbmn r t \boldsymbol r_t rt某一维度的值接近0时,网络就更倾向于遗忘 h t − 1 \boldsymbol h_{t-1} ht1的相应维度;反之,当 r t \boldsymbol r_t rt某一维度的值接近1时,网络更倾向于保留 h t − 1 \boldsymbol h_{t-1} ht1的相应维度。之后,我们再将重置过的 h t − 1 ′ \boldsymbol h'_{t-1} ht1 x t \boldsymbol x_t xt组合,得到 h ^ t \hat{\boldsymbol h}_t h^t h ^ t = tanh ⁡ ( W h x t + U h h t − 1 ′ + b h ) \hat{\boldsymbol h}_t=\tanh(\boldsymbol W_h\boldsymbol x_t+\boldsymbol U_h\boldsymbol h'_{t-1}+\boldsymbol b_h) h^t=tanh(Whxt+Uhht1+bh) 这里得到的 h ^ t \hat{\boldsymbol h}_t h^t混合了当前的 x t \boldsymbol x_t xt与部分过去的信息 h t − 1 ′ \boldsymbol h'_{t-1} ht1,并由 tanh ⁡ \tanh tanh函数映射到了 ( − 1 , 1 ) (-1,1) (1,1)范围内。观察上式与普通RNN的更新方式 h t = f h ( W i x i + W h h t − 1 + b t ) \boldsymbol h_t = f_h(\boldsymbol W_i\boldsymbol x_i+\boldsymbol W_h\boldsymbol h_{t-1}+\boldsymbol b_t) ht=fh(Wixi+Whht1+bt),可以看出,普通的RNN相当于令重置单元 r t \boldsymbol r_t rt的所有维度都为1,从而保留了所有过去的信息;而 r t = 0 \boldsymbol r_t=0 rt=0 会消除所有过去的信息,使得RNN退化为与过去无关的单个MLP。可以通过这样的对比体会重置单元的意义。

  最后,我们要决定 h t \boldsymbol h_t ht是要更倾向于旧的信息 h t − 1 \boldsymbol h_{t-1} ht1,还是旧信息与新输入 x t \boldsymbol x_t xt的混合 h ^ t \hat{\boldsymbol h}_t h^t。利用更新单元 z t \boldsymbol z_t zt,我们令 h t = z t ⊙ h t − 1 + ( 1 − z t ) ⊙ h ^ t \boldsymbol h_t=\boldsymbol z_t\odot\boldsymbol h_{t-1}+(\boldsymbol1-\boldsymbol z_t)\odot\hat{\boldsymbol h}_t ht=ztht1+(1zt)h^t

  在上式中,如果更新单元 z t \boldsymbol z_t zt接近 1 \boldsymbol1 1,我们将保留更多的旧信息 h t − 1 \boldsymbol h_{t-1} ht1,而忽略 x t \boldsymbol x_t xt的影响;反之,如果 z t \boldsymbol z_t zt接近 0 \boldsymbol0 0,我们将让旧信息与新信息混合,保留 h ^ t − 1 \hat{\boldsymbol h}_{t-1} h^t1。需要注意,重置单元和更新单元的作用并不相同,两者不能合为一个单元。简单来说,重置单元控制旧信息保留的比例,而更新单元同时控制旧信息和新输入的比例。虽然理论上我们可以用类似 h t = z t ⊙ f h ( U h h t − 1 + b h ) + ( 1 − z t ) ⊙ f x ( W x x t + b x ) \boldsymbol h_t=\boldsymbol z_t\odot f_h(\boldsymbol U_h\boldsymbol h_{t-1}+\boldsymbol b_h)+(\boldsymbol1-\boldsymbol z_t)\odot f_x(\boldsymbol W_x\boldsymbol x_t+\boldsymbol b_x) ht=ztfh(Uhht1+bh)+(1zt)fx(Wxxt+bx) 这样的式子,仅用一个更新单元来计算 h t \boldsymbol h_t ht,但是其灵活性将大打折扣。

  为什么GRU的设计可以缓解梯度爆炸与梯度消失问题呢?上文我们已经提到,导致梯度问题的最大因素是 ∂ h t ∂ h t − 1 \begin{aligned}\frac{\partial\boldsymbol h_t}{\partial\boldsymbol h_{t-1}}\end{aligned} ht1ht的连乘。在GRU中,我们可以通过调整门控单元 r t \boldsymbol r_t rt z t \boldsymbol z_t zt的值,使该梯度始终保持稳定。以文本分析为例,假如某一事物在一段话的开头和结尾出现,为了让模型保留它们之间的关联,我们只需要将重置单元 r t \boldsymbol r_t rt的值减小、更新单元 z t \boldsymbol z_t zt的值增大,就可以使网络在间隔很多时间步之后,仍然保留最初的记忆信息。最极端的情况下,如果令 z 2 , ⋯   , z t − 1 = 1 \boldsymbol z_2,\cdots,\boldsymbol z_{t-1}=\boldsymbol1 z2,,zt1=1,那么从时刻2到时刻 t − 1 t-1 t1 的所有输入都将被忽略,可以直接得到 h t = h 1 \boldsymbol h_t=\boldsymbol h_1 ht=h1。这样,梯度的连乘为 ∏ k = 2 t ∂ h k ∂ h k − 1 = ∂ h t ∂ h 1 = I \prod_{k=2}^t\frac{\partial\boldsymbol h_k}{\partial\boldsymbol h_{k-1}}=\frac{\partial\boldsymbol h_t}{\partial\boldsymbol h_1}=\boldsymbol I k=2thk1hk=h1ht=I

  虽然门控单元的值也是由网络训练得到的,但是门控单元的引入使得GRU可以自我调节梯度。也就是说,如果 h 1 \boldsymbol h_1 h1非常重要,那么门控单元会让 h 1 \boldsymbol h_1 h1保留下来,其梯度较大;如果 h 1 \boldsymbol h_1 h1重要性不高,随着时间推移被遗忘,那么其梯度即使消失也不会产生什么问题。因此,GRU几乎不会发生普通RNN的梯度爆炸或梯度消失现象。

三、动手实现GRU

  本节我们使用PyTorch库中的工具来实现GRU模型,完成简单的时间序列预测任务。时间序列预测任务是指根据一段连续时间内采集的数据、分析其变化规律、预测接下来数据走向的任务。如果当前数据与历史数据存在依赖关系,或者有随时间有一定的规律性,该任务就很适合用RNN求解。本节中,我们生成了一条经过一定处理的正弦曲线作为数据集,存储在sindata_1000.csv中。该曲线包含1000个数据点。其中前800个点作为训练集,后200个点作为测试集。由于本任务是时序预测任务,我们在划分训练集和测试集时无须将其打乱。我们首先导入必要的库和数据集,并将数据集的图像绘制出来。

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn

# 导入数据集
data = np.loadtxt('sindata_1000.csv', delimiter=',')
num_data = len(data)
split = int(0.8 * num_data)
print(f'数据集大小:{num_data}')
# 数据集可视化
plt.figure()
plt.scatter(np.arange(split), data[:split], color='blue', s=10, label='training set')
plt.scatter(np.arange(split, num_data), data[split:], color='none', edgecolor='orange', s=10, label='test set')
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.legend()
plt.show()
# 分割数据集
train_data = np.array(data[:split])
test_data = np.array(data[split:])

在这里插入图片描述

  在训练RNN模型时,虽然我们可以把每个时间步 t t t单独输入,得到模型的预测值 y ^ t \hat y_t y^t,但是这样无法体现出数据的序列相关性质。因此,我们通常会把一段时间序列 x t , ⋯   , x t + k \boldsymbol x_t,\cdots,\boldsymbol x_{t+k} xt,,xt+k整体作为输入,PyTorch中的GRU模块输出这段序列对应的中间变量 h t , ⋯   , h t + k \boldsymbol h_t,\cdots,\boldsymbol h_{t+k} ht,,ht+k。下面的实现中,我们每次输入 x t , ⋯   , x t + k \boldsymbol x_t,\cdots,\boldsymbol x_{t+k} xt,,xt+k 的时间序列,预测输入向后错一步 x t + 1 , ⋯   , x t + k + 1 \boldsymbol x_{t+1},\cdots,\boldsymbol x_{t+k+1} xt+1,,xt+k+1 的数据。参照图4的结构可以发现,GRU模型只输出中间变量。如果要得到我们最后的输出,还需要将这些中间变量经过自定义的其他网络。这一点和CNN里卷积层负责提取特征、MLP负责根据特征完成特定任务的做法非常相似。因此,我们在GRU之后拼接一个全连接层,通过中间变量序列 h t + 1 , ⋯   , h t + k + 1 \boldsymbol h_{t+1},\cdots,\boldsymbol h_{t+k+1} ht+1,,ht+k+1 来预测未来的数据分布。

# 输入序列长度
seq_len = 20
# 处理训练数据,把切分序列后多余的部分去掉
train_num = len(train_data) // (seq_len + 1) * (seq_len + 1)
train_data = np.array(train_data[:train_num]).reshape(-1, seq_len + 1, 1)
np.random.seed(0)
torch.manual_seed(0)

x_train = train_data[:, :seq_len] # 形状为(num_data, seq_len, input_size)
y_train = train_data[:, 1: seq_len + 1]
print(f'训练序列数:{len(x_train)}')

# 转为PyTorch张量
x_train = torch.from_numpy(x_train).to(torch.float32)
y_train = torch.from_numpy(y_train).to(torch.float32)
x_test = torch.from_numpy(test_data[:-1]).to(torch.float32)
y_test = torch.from_numpy(test_data[1:]).to(torch.float32)

在这里插入图片描述

  考虑到GRU的模型结构较为复杂,我们直接使用在PyTorch库中封装好的GRU模型。我们只需要为该模型提供两个参数,第一个参数input_size表示输入 x \boldsymbol x x的维度,第二个参数hidden_size表示中间向量 h \boldsymbol h h的维度,其余参数我们保持默认值。在前向传播时,GRU接受序列 x \boldsymbol x x和初始的中间变量 h \boldsymbol h h。如果最开始我们不知道中间变量的值,GRU会自动将其初始化为全零。前向传播的输出是outhidden,前者是整个时间序列上中间变量的值,而后者只包含是最后一步。out[-1]hidden在GRU内部的层数不同时会有区别,但本节只使用单层网络,因此不详细展开。感兴趣的可以参考PyTorch的官方文档。我们将out作为最后全连接层的输入,得到预测值,再把预测值和hidden返回。hidden将作为下一次前向传播的初始中间变量。

class GRU(nn.Module):
    # 包含PyTorch的GRU和拼接的MLP
    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        # GRU模块
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size) 
        # 将中间变量映射到预测输出的MLP
        self.linear = nn.Linear(hidden_size, output_size)
        
    def forward(self, x, hidden):
        # 前向传播
        # x的维度为(batch_size, seq_len, input_size)
        # GRU模块接受的输入为(seq_len, batch_size, input_size)
        # 因此需要对x进行变换
        # transpose函数可以交换x的坐标轴
        # out的维度是(seq_len, batch_size, hidden_size)
        out, hidden = self.gru(torch.transpose(x, 0, 1), hidden) 
        # 取序列最后的中间变量输入给全连接层
        out = self.linear(out.view(-1, hidden_size))
        return out, hidden

  接下来,我们设置超参数并实例化GRU。在训练之前,我们还要强调时序模型在测试时与普通模型的区别。GRU在测试时,我们将输入的时间序列长度降为1,即只输入 x t \boldsymbol x_t xt,让GRU预测 t + 1 t+1 t+1 时刻的值。之后,不像普通的任务那样把所有测试数据都给模型,而是让GRU将自己预测的 x ^ t + 1 \hat{\boldsymbol x}_{t+1} x^t+1作为输入,再预测 t + 2 t+2 t+2 时刻的值,循环往复。这样的测试方式对模型在时序上的建模能力有相当高的要求,否则就会很快因为预测值的误差累积,和真实值偏差很大。

# 超参数
input_size = 1 # 输入维度
output_size = 1 # 输出维度
hidden_size = 16 # 中间变量维度
learning_rate = 5e-4

# 初始化网络
gru = GRU(input_size, output_size, hidden_size)
gru_optim = torch.optim.Adam(gru.parameters(), lr=learning_rate)

# GRU测试函数,x和hidden分别是初始的输入和中间变量
def test_gru(gru, x, hidden, pred_steps):
    pred = []
    inp = x.view(-1, input_size)
    for i in range(pred_steps):
        gru_pred, hidden = gru(inp, hidden)
        pred.append(gru_pred.detach())
        inp = gru_pred
    return torch.concat(pred).reshape(-1)

  作为对比,我们用相同的数据同步训练一个3层的MLP模型。该MLP将同样将 x t , ⋯   , x t + k \boldsymbol x_t,\cdots,\boldsymbol x_{t+k} xt,,xt+k 的数据拼接在一起作为输入,此时 k k k被理解为输入的批量大小,并输出 x t + 1 , ⋯   , x t + k + 1 \boldsymbol x_{t+1},\cdots,\boldsymbol x_{t+k+1} xt+1,,xt+k+1 的预测值,与GRU保持一致。在测试时,MLP同样只接受测试集第一个时间步的数据,以和GRU相同的方式进行自循环预测。

# MLP的超参数
hidden_1 = 32
hidden_2 = 16
mlp = nn.Sequential(
    nn.Linear(input_size, hidden_1),
    nn.ReLU(),
    nn.Linear(hidden_1, hidden_2),
    nn.ReLU(),
    nn.Linear(hidden_2, output_size)
)
mlp_optim = torch.optim.Adam(mlp.parameters(), lr=learning_rate)

# MLP测试函数,相比于GRU少了中间变量
def test_mlp(mlp, x, pred_steps):
    pred = []
    inp = x.view(-1, input_size)
    for i in range(pred_steps):
        mlp_pred = mlp(inp)
        pred.append(mlp_pred.detach())
        inp = mlp_pred
    return torch.concat(pred).reshape(-1)

  我们用完全相同的数据训练GRU和MLP。由于已经有了序列长度,我们不再设置SGD的批量大小,直接将每个训练样本单独输入模型进行优化。

max_epoch = 150
criterion = nn.functional.mse_loss
hidden = None # GRU的中间变量

# 训练损失
gru_losses = []
mlp_losses = []
gru_test_losses = []
mlp_test_losses = []
# 开始训练
with tqdm(range(max_epoch)) as pbar:
    for epoch in pbar:
        st = 0
        gru_loss = 0.0
        mlp_loss = 0.0
        # 随机梯度下降
        for X, y in zip(x_train, y_train):
            # 更新GRU模型
            # 我们不需要通过梯度回传更新中间变量
            # 因此将其从有梯度的部分分离出来
            if hidden is not None:
                hidden.detach_()
            gru_pred, hidden = gru(X[None, ...], hidden)
            gru_train_loss = criterion(gru_pred.view(y.shape), y)
            gru_optim.zero_grad()
            gru_train_loss.backward()
            gru_optim.step()
            gru_loss += gru_train_loss.item()
            # 更新MLP模型
            # 需要对输入的维度进行调整,变成(seq_len, input_size)的形式
            mlp_pred = mlp(X.view(-1, input_size))
            mlp_train_loss = criterion(mlp_pred.view(y.shape), y)
            mlp_optim.zero_grad()
            mlp_train_loss.backward()
            mlp_optim.step()
            mlp_loss += mlp_train_loss.item()
        
        gru_loss /= len(x_train)
        mlp_loss /= len(x_train)
        gru_losses.append(gru_loss)
        mlp_losses.append(mlp_loss)
        
        # 训练和测试时的中间变量序列长度不同,训练时为seq_len,测试时为1
        gru_pred = test_gru(gru, x_test[0], hidden[:, -1], len(y_test))
        mlp_pred = test_mlp(mlp, x_test[0], len(y_test))
        gru_test_loss = criterion(gru_pred, y_test).item()
        mlp_test_loss = criterion(mlp_pred, y_test).item()
        gru_test_losses.append(gru_test_loss)
        mlp_test_losses.append(mlp_test_loss)
        
        pbar.set_postfix({
            'Epoch': epoch,
            'GRU loss': f'{gru_loss:.4f}',
            'MLP loss': f'{mlp_loss:.4f}',
            'GRU test loss': f'{gru_test_loss:.4f}',
            'MLP test loss': f'{mlp_test_loss:.4f}'
        })

在这里插入图片描述

  最后,我们在测试集上对比GRU和MLP模型的效果并绘制出来。图中包含了原始数据的训练集和测试集的曲线,可以看出,GRU的预测基本符合测试集的变化规律,而MLP很快就因为缺乏足够的时序信息与测试集偏离。

# 最终测试结果
gru_preds = test_gru(gru, x_test[0], hidden[:, -1], len(y_test)).numpy()
mlp_preds = test_mlp(mlp, x_test[0], len(y_test)).numpy()

plt.figure(figsize=(13, 5))

# 绘制训练曲线
plt.subplot(121)
x_plot = np.arange(len(gru_losses)) + 1
plt.plot(x_plot, gru_losses, color='blue', label='GRU training loss')
plt.plot(x_plot, mlp_losses, color='red', ls='-.', label='MLP training loss')
plt.plot(x_plot, gru_test_losses, color='blue', ls='--', label='GRU test loss')
plt.plot(x_plot, mlp_test_losses, color='red', ls=':', label='MLP test loss')
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.legend(loc='lower left')

# 绘制真实数据与模型预测值的图像
plt.subplot(122)
plt.scatter(np.arange(split), data[:split], color='blue', s=10, label='training set')
plt.scatter(np.arange(split, num_data), data[split:], color='none', edgecolor='orange', s=10, label='test set')
plt.scatter(np.arange(split, num_data - 1), mlp_preds, color='violet', marker='x', alpha=0.4, s=20, label='MLP preds')
plt.scatter(np.arange(split, num_data - 1), gru_preds, color='green', marker='*', alpha=0.4, s=20, label='GRU preds')
plt.legend(loc='lower left')
plt.savefig('output_20_0.png')
plt.savefig('output_20_0.pdf')
plt.show()

在这里插入图片描述

:以上文中的数据集及相关资源下载地址:
链接:https://pan.quark.cn/s/b485bdc0e8eb
提取码:NAn2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2100089.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Graylog配置用户权限以及常用搜索语法

文章目录 一、Graylog配置用户管理1、用户创建2、角色权限管理 二、搜索语法 基于Docker搭建Graylog的具体步骤&#xff1a; https://blog.csdn.net/weixin_44876263/article/details/141638739?spm1001.2014.3001.5502 一、Graylog配置用户管理 1、用户创建 2、角色权限管理…

Linux--实现简易shell

文章目录 shell定义和功能myshell.cGetCwd()GetUsrName()GetHostName()MakeCommandLineAndPrint()GetUserCommand()SplitCommand()Die()ExecuteCommand()GetHome()Cd()CheckBuildin()CheckRedir()myshell.c完整代码 makefile测试函数和进程之间的相似性 Shell是一个功能强大的工…

LVS之net模式实验

总结&#xff1a; lvs #配置环境&#xff0c;两个网卡 [rootlvs ~]# cd /etc/NetworkManager/system-connections/ [rootlvs system-connections]# ls ens160.nmconnection eth0.nmconnection eth1.nmconnection [rootlvs system-connections]# vim eth0.nmconnection [co…

华为OD机试 - 猜数字 - 穷举搜索(Java 2024 E卷 100分)

华为OD机试 2024E卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;E卷D卷A卷B卷C卷&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;私信哪吒&#xff0c;备注华为OD&#xff0c;加…

【RabbitMQ之一:windows环境下安装RabbitMQ】

目录 一、下载并安装Erlang1、下载Erlang2、安装Erlang3、配置环境变量4、验证erlang是否安装成功 二、下载并安装RabbitMQ1、下载RabbitMQ2、安装RabbitMQ3、配置环境变量4、验证RabbitMQ是否安装成功5、启动RabbitMQ服务&#xff08;安装后服务默认自启动&#xff09; 三、安…

Vue2转Vue3学习历程

选项式API>组合式API vue3和vue2的差别就是选项式api改为组件式api&#xff0c;就是以前vue2要定义data、method、mounted&#xff0c;在vue3就变为了更模块化的&#xff0c;并且我感觉vue3设计思路更多是以调用方法的方式实现&#xff0c;比如我实现一个方法&#xff0c;并…

C语言深入理解指针2

1.数组名的理解 #include <stdio.h> int main() {int arr[10] { 1,2,3,4,5,6,7,8,9,10 };printf("&arr[0] %p\n", &arr[0]);printf("arr %p\n", arr);return 0; }可以发现数组名和数组首元素地址的打印结果一样&#xff0c;因此&#xf…

研究生深度学习入门的十天学习计划------第七天

第7天&#xff1a;自然语言处理&#xff08;NLP&#xff09;中的深度学习 目标&#xff1a; 掌握自然语言处理的基础知识与深度学习模型&#xff0c;理解如何应用RNN、LSTM、Transformer等模型处理文本数据。 7.1 自然语言处理的基础概念 自然语言处理&#xff08;NLP&#…

Vue学习笔记 二

4、Vue基础扩展 4.1 插槽 组件的最大特性就是复用性,而用好插槽能大大提高组件的可复用能力在Vue中插槽是很重要的存在,通过插槽,我们可以把父组件中指定的DOM作用到子组件的任意位置,后面我们坐项目用到的组件库比如element-ui,vant-ui都频繁用到的插槽,Vue的插槽主要有…

【hot100篇-python刷题记录】【在排序数组中查找元素的第一个和最后一个位置】

R7-二分查找篇 目录 双指针 二分优化 ps: 思路&#xff1a; 双指针 直接用双指针回缩啊 class Solution:def searchRange(self, nums: List[int], target: int) -> List[int]:ret[-1,-1]left,right0,len(nums)-1while left<len(nums):if nums[left]target:ret[0]…

可解释性与公平性的关系

可解释模型更有可能公平的三个原因 可解释性和公平性似乎是相辅相成的。可解释性涉及理解模型如何进行预测。公平性涉及理解预测是否偏向某些群体。负责任的人工智能框架和机器学习会议始终将这两个特征一起提及。然而&#xff0c;可解释性并不一定意味着公平。 话虽如此&…

[米联客-XILINX-H3_CZ08_7100] FPGA程序设计基础实验连载-26浅谈XILINX FIFO的基本使用

软件版本&#xff1a;VIVADO2021.1 操作系统&#xff1a;WIN10 64bit 硬件平台&#xff1a;适用 XILINX A7/K7/Z7/ZU/KU 系列 FPGA 实验平台&#xff1a;米联客-MLK-H3-CZ08-7100开发板 板卡获取平台&#xff1a;https://milianke.tmall.com/ 登录“米联客”FPGA社区 http…

9、Django Admin优化查询

如果你的Admin后台中有很多计算字段&#xff0c;那么你需要对每个对象运行多个查询&#xff0c;这会使你的Admin后台变得非常慢。要解决此问题&#xff0c;你可以重写管理模型中的get_queryset方法使用annotate聚合函数来计算相关的字段。 以下示例为Origin模型的中ModelAdmin…

Spring6梳理5——基于XML管理Bean环境搭建

以上笔记来源&#xff1a; 尚硅谷Spring零基础入门到进阶&#xff0c;一套搞定spring6全套视频教程&#xff08;源码级讲解&#xff09;https://www.bilibili.com/video/BV1kR4y1b7Qc 目录 ①搭建模块 ②引入配置文件 ③创建BeanXML文件 ④创建Java类文件&#xff08;User…

在K8s上运行GitHub Actions的自托管运行器

1&#xff1a;添加Actions Runner Controller的Helm仓库 helm repo add actions-runner-controller https://actions-runner-controller.github.io/actions-runner-controller helm repo update2&#xff1a;创建GitHub Personal Access Token (PAT) 登录到你的GitHub账户。访…

SQL语句(数据更新、查询操作)

数据库表操作 创建数据库语法格式 create table 表名(字段名1 类型 约束&#xff0c;字段名2 类型 约束&#xff0c;..... ..... )创建学生表&#xff0c;字段要求如下&#xff1a; 姓名&#xff08;长度为10&#xff09;、年龄、身高&#xff08;保留2位小数&#xff09; cre…

安卓shiply热更新入门

目录 一。我的开发环境 二。集成shiply热更新sdk 三。编写代码 1。创建一个CustomRFixLog类 2。创建一个MyApplication类 3。配置AndroidManifest.xml 4。创建一个新的Activity继承AbsRFixDevActivity 用于测试 四。登录shiply后台配置 1。创建项目 五。制作补丁 1。在app…

Ae关键帧动画基础练习-街道汽车超车

目录 1.让背景向左移动 2.让小红车匀速移动 3.实现小黄车的超车 完成街道汽车超车的一个简单动画&#xff0c;背景向左移动看起来就如同画面向右移动了一般&#xff0c;根据这个原理&#xff0c;可以完成这个动画。 导入素材时&#xff0c;要选择不同的图层&#xff0c;这样…

微软AD替代方案统一管理Windows和信创电脑的登录认证与网络准入认证

自国资委79号文明确了2027年底前信息系统全面国产化的目标后&#xff0c;金融单位、央国企集团及各子公司纷纷加大国产化改造力度。不少子、孙公司表示&#xff0c;集团要求到2024年底或2025年底国外的关键IT基础设施要停止使用&#xff0c;如微软AD、云桌面等。 信创国产化是大…

Mybatis链路分析:JDK动态代理和责任链模式的应用

背景 此前写过关于代理模式的文章&#xff0c;参考&#xff1a;代理模式 动态代理功能&#xff1a;生成一个Proxy代理类&#xff0c;Proxy代理类实现了业务接口&#xff0c;而通过调用Proxy代理类实现的业务接口&#xff0c;实际上会触发代理类的invoke增强处理方法。 责任链功…