时间序列预测——GRU模型

news2025/1/18 7:33:05

时间序列预测——GRU模型

在深度学习领域,循环神经网络(RNN)是处理时间序列数据的一种常见选择。上期已介绍了LSTM的单步和多步预测。本文将深入介绍一种LSTM变体——门控循环单元(GRU)模型,包括其理论基础、公式、优缺点,并通过Python实现单步预测的示例。同时,将与长短时记忆网络(LSTM)进行比较,以更好地理解GRU的特性。

1. 引言

循环神经网络(RNN)是一类专门用于处理序列数据的神经网络。然而,传统的RNN存在梯度消失和梯度爆炸等问题,这导致了对长序列的有效建模变得困难。为了解决这些问题,门控循环单元(GRU)被提出。

2. GRU模型的理论

2.1 简介

GRU cell

门控循环单元(GRU)是由Cho等人于2014年提出的,旨在解决长短时记忆网络(LSTM)的一些问题。与LSTM相似,GRU也具有长期依赖性建模的能力,但其结构更加简单。GRU通过更新门和重置门来控制信息的流动,减少了参数数量,使得训练更加高效。

2.2 GRU的结构

GRU由两个门控制:更新门(Update Gate)和重置门(Reset Gate)。与LSTM不同,GRU没有细胞状态,而是直接使用隐藏状态。

GRU的隐藏状态更新公式为:

h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t \begin{equation} h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{equation} ht=(1zt)ht1+zth~t

其中:

  • h t h_t ht是当前时间步的隐藏状态。
  • z t z_t zt 是更新门的输出。
  • ⊙ \odot 是逐元素相乘操作。
  • h ~ t \tilde{h}_t h~t 是当前时间步的候选隐藏状态。

2.3 更新门和重置门

更新门(Update Gate)和重置门(Reset Gate)的计算分别为:

z t = σ ( W z ⋅ [ h t − 1 , x t ] ) \begin{equation} z_t = \sigma(W_z \cdot [h_{t-1}, x_t]) \end{equation} zt=σ(Wz[ht1,xt])

r t = σ ( W r ⋅ [ h t − 1 , x t ] ) \begin{equation} r_t = \sigma(W_r \cdot [h_{t-1}, x_t]) \end{equation} rt=σ(Wr[ht1,xt])
其中:

  • W z W_z Wz W r W_r Wr 是权重矩阵。
  • σ \sigma σ 是sigmoid激活函数。
  • [ h t − 1 , x t ] [h_{t-1}, x_t] [ht1,xt] 是当前时间步的隐藏状态和输入拼接而成的向量。

2.4 候选隐藏状态

候选隐藏状态(Candidate Hidden State)的计算为:

h ~ t = tanh ⁡ ( W ⋅ [ r t ⊙ h t − 1 , x t ] ) \begin{equation} \tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t]) \end{equation} h~t=tanh(W[rtht1,xt])

其中:

  • W W W 是权重矩阵。

3. GRU模型与LSTM的区别

GRU与LSTM有相似之处,都采用了门控制机制,但它们在结构上存在一些区别。

  • 参数数量:GRU的参数数量相对较少,因为它没有细胞状态,直接使用隐藏状态。
  • 计算效率:由于参数较少,GRU在训练和预测时通常更加高效。
  • 表达能力:LSTM的细胞状态允许更好地保留和传递信息,适用于更复杂的序列建模任务。但在某些场景下,GRU由于其简单性能够表达一些简单序列的依赖关系。

4. Python实现GRU的单步预测

接下来,将使用Python和深度学习库Keras实现GRU的单步预测。将使用一个简单的时间序列数据集,以便清晰展示模型的训练和预测过程。

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense

# 创建示例时间序列数据
np.random.seed(42)
data = np.arange(0, 100, 0.1)
noise = np.random.normal(0, 1, len(data))
data += noise

# 准备训练数据
seq_length = 10
x, y = [], []

for i in range(len(data) - seq_length):
    x.append(data[i:i + seq_length])
    y.append(data[i + seq_length])

x = np.array(x)
y = np.array(y)

x = x.reshape((x.shape[0], x.shape[1], 1))

# 构建GRU模型
model = Sequential()
model.add(GRU(50, activation='relu', input_shape=(seq_length, 1)))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')

# 训练GRU模型
model.fit(x, y, epochs=50, verbose=0)

# 使用训练好的模型进行单步预测
input_data = data[-seq_length:].reshape((1, seq_length, 1))
predicted_value = model.predict(input_data, verbose=0)

# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(data, label='Original Data')
plt.scatter(len(data) - 1, predicted_value, color='red', marker='o', label='GRU Prediction (Single Step)')
plt.title('GRU Model - Single Step Prediction')
plt.legend()
plt.show()

多步预测其实就是修改输入输出的维度,这里不再赘述,可参考LSTM的单步和多步预测。

6. 总结

本文深入介绍了GRU模型的理论基础和相关公式,分析了其优缺点,并通过Python实现了单步预测的示例。GRU作为一种高效而强大的深度学习模型,在时间序列预测中展现了出色的性能。在实际应用中,可以根据具体任务的要求进行调整和优化,以达到更好的预测效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1432301.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ArcGIS Pro 按照字段进行融合或拆分

ArcGIS Pro 按字段融合 在ArcGIS Pro中,通过使用“融合”工具可以轻松地合并具有相同字段的图层。 步骤一:打开ArcGIS Pro 启动ArcGIS Pro应用程序,确保您已经登录并打开您的项目。 步骤二:添加图层 将包含相同字段的图层添加到…

【C++】C++入门 — 类和对象初步介绍

类和对象 1 类的作用域2 类的实例化3 类对象模型4 this指针介绍:特性: Thanks♪(・ω・)ノ谢谢阅读!下一篇文章见!!! 1 类的作用域 类定义了一个新的作用域,类的…

项目安全问题及解决方法-----xss处理

XSS 问题的根源在于,原本是让用户传入或输入正常数据的地方,被黑客替换为了 JavaScript 脚本,页面没有经过转义直接显示了这个数据,然后脚本就被 执行了。更严重的是,脚本没有经过转义就保存到了数据库中,随…

Redis之基础篇

Redis简介 Redis是一种基于键值对(Key-Value)的NoSQL数据库,它支持string(字符串)、hash(哈希)、list(列表)、set(集合)、zset(有序集…

matplotlib-中文乱码问题解决方案

前言 本文主要解决matplotlib在画图时,出现的中文乱码问题,具体问题示意如下: 下面将针对这个问题直接给出具体的解决步骤。 具体步骤 1、首先去网上下载并安装SimHei字体,其它字体也行,如下 并将它安装在此目录下…

面试150 位1的个数 位运算

Problem: 191. 位1的个数 文章目录 思路复杂度Code 思路 👨‍🏫 参考 复杂度 Code public class Solution {// you need to treat n as an unsigned valuepublic int hammingWeight(int n){int res 0;while (n ! 0){res 1;n & n - 1;// 把最后…

海康IPC摄像机接入国标平台,发现一直不在线(离线)的处理方式

目 录 一、问题 二、问题分析 (一)常见设备离线问题的原因 (二)原因分析 三、问题查处 (一)设备端排查故障(设备端自查) 1、检查GB28181参数配置是否有误 2、…

【算法与数据结构】739、LeetCode每日温度

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引,可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析:   程序如下: 复杂度分析: 时间复杂度: O ( ) O() O()。空间复…

【2月比赛合集】28场可报名的数据挖掘大奖赛,任君挑选!

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…)比赛。本账号会推送最新的比赛消息,欢迎关注! 以下信息仅供参考,以比赛官网为准 目录 Kaggle(2场比赛)阿里天池(…

Elasticsearch:集群故障排除和优化综合指南

Elasticsearch 是一个强大的搜索和分析引擎,是许多数据驱动应用程序和服务的核心。 它实时处理、分析和存储大量数据的能力使其成为当今快节奏的数字世界中不可或缺的工具。 然而,与任何复杂的系统一样,Elasticsearch 可能会遇到影响其性能和…

【Vue项目中使用videojs播放本地mp4的项目】

目录 以下是一个使用video.js播放本地mp4文件的Vue项目代码示例:1. 首先,在终端中使用以下命令安装video.js和video.js插件:2. 在Vue组件中,引入video.js和videojs-youtube插件:3. 配置video-js.css文件,可…

python给word插入脚注

1.需求 最近因为工作需要,需要给大量文本的脚注插入内容,我就写了个小程序。 2.实现 下面程序是我已经给所有脚注插入了两次文本“幸福”,给脚注2到4再插入文本“幸福” from win32com import clientdef add_text_to_specific_footnotes(…

1-2 动手学深度学习v2-基础优化方法-笔记

最常见的算法——梯度下降 当一个模型没有显示解的时候,该怎么办呢? 首先挑选一个参数的随机初始值,可以随便在什么地方都没关系,然后记为 w 0 \pmb{w_{0}} w0​在接下来的时刻里面,我们不断的去更新 w 0 \pmb{w_{0}…

Unity制作随风摇摆的植物

今天记录一下如何实现随风摇摆的植物,之前项目里面的植物摇摆实现是使用骨骼动画实现的,这种方式太消耗性能,植物这种东西没必要,直接使用顶点动画即可。 准备 植物不需要使用标准的PBR流程,基础的颜色贴图加上法向贴…

使用_NT_SYMBOL_PATH在启动VS前设置PDB路径

一、背景 由于公司相关项目的开发管理方式,导致公司会存在多个分支的版本正在开发/测试中。 在这样的背景下,我的日常工作中有时会出现存在某个分支的项目软件的某个功能出现了问题需要我去排查解决,而我当前并不在该分支上开发。于是只能安装…

嵌入式linux移植篇之根文件系统(rootfs)

根文件系统首先是内核启动时所 mount(挂载)的第一个文件系统,系统引导启动程序会在根文件系统挂载之后从中把一些基本的初始化脚本和服务等加载到内存中去运行。单独的 Linux 内核是没法正常工作的,必须要搭配根文件系统。 根文件系统的目录结构 根文…

【SpringBoot】RBAC权限控制

📝个页人主:五敷有你 🔥系列专栏:SpringBoot⛺️稳重求进,晒太阳 权限系统与RBAC模型 权限 为了解决用户和资源的操作关系, 让指定的用户,只能操作指定的资源。 权限功能 菜单权限&a…

自建服务器监控工具uptime kuma

web服务器使用 雨云 提供的2核2g 这里使用1panel的uptime kuma 首先,如果你使用雨云,那么可以直接省去安装1panel的烦恼 直接选择预装后,等待部署完成即可看到面板信息,进入面板,点击应用商店 在应用商店里找到upti…

视云闪播截图

视云闪播截图 1. 截图设置2. 热键设置3. 视频截取3.1. 保存 -> 完成 References 深度学习图像数据获取工具。 视云闪播 https://www.netposa.com/Service/Download.html 1. 截图设置 视云闪播 -> 系统设置 -> 截图设置 2. 热键设置 视云闪播 -> 系统设置 ->…

C# CAD界面-自定义工具栏(三)

运行环境 vs2022 c# cad2016 调试成功 一、引用 二、开发代码进行详细的说明 初始化与获取AutoCAD核心对象: Database db HostApplicationServices.WorkingDatabase;:这行代码获取当前工作中的AutoCAD数据库对象。在AutoCAD中,所有图形数…