时间序列预测 —— DeepAR 模型

news2024/9/22 5:12:50

时间序列预测 —— DeepAR 模型

DeepAR 模型是一种专门用于处理时间序列概率预测的深度学习模型,它可以自动学习数据中的复杂模式,提高预测的准确性。本文将介绍 DeepAR 模型的理论基础、优缺点,并通过 Python 实现单步预测和多步预测的完整代码。

1. DeepAR 模型简介

DeepAR 模型是由亚马逊提出的一种概率生成模型,旨在进行时间序列预测。与传统的基于深度学习的序列模型(如 LSTM 和 GRU)不同,DeepAR 考虑了时间序列的复杂模式。它采用了一种门控循环单元(Gated Recurrent Unit, GRU)的,结合了蒙特卡洛的概率生成方法,使其在处理多维时间序列数据时表现优异。

2. DeepAR 模型理论及公式

DeepAR

2.1 基本原理

DeepAR 模型的基本原理是将历史时间序列数据作为输入,通过逐步预测未来数据。模型利用历史观测值和可变长度的上下文窗口来学习数据的动态模式,并通过概率生成方法提供置信区间。传统时序预测为点预测,DeepAR为区间概率预测。

2.2 模型公式

DeepAR 模型的核心是门控循环单元(GRU)和注意力机制,其中 GRU 的更新公式如下:

GRU 的更新步骤:
z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) h ~ t = tanh ( W ⋅ [ r t ⋅ h t − 1 , x t ] + b ) h t = ( 1 − z t ) ⋅ h t − 1 + z t ⋅ h ~ t \begin{aligned} z_t &= \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \\ r_t &= \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \\ \tilde{h}_t &= \text{tanh}(W \cdot [r_t \cdot h_{t-1}, x_t] + b) \\ h_t &= (1 - z_t) \cdot h_{t-1} + z_t \cdot \tilde{h}_t \end{aligned} ztrth~tht=σ(Wz[ht1,xt]+bz)=σ(Wr[ht1,xt]+br)=tanh(W[rtht1,xt]+b)=(1zt)ht1+zth~t

其中:

  • z t z_t zt 是更新门,控制要从候选隐藏状态 (\tilde{h}_t) 中选择多少信息传递给新的隐藏状态。
  • r t r_t rt 是重置门,控制要忽略前一时刻隐藏状态的程度。
    - h ~ t \tilde{h}_t h~t是候选隐藏状态。
  • h t h_t ht 是当前时刻的隐藏状态。

2.3 模型结构

DeepAR 模型结构包括多层 GRU 单元,每一层都有自己的权重和偏差。模型还使用了注意力机制,以便更好地处理长序列。DeepAR以蒙特卡罗样本的形式进行概率预测,该样本可用于计算预测视界内所有子范围的一致分位数估

3. DeepAR 与其他模型的区别

相较于传统的序列模型,DeepAR 在以下方面有显著的不同之处:

  • 考虑季节性和周期性: DeepAR 能够自动学习和捕捉数据中的季节性和周期性模式,而无需手动调整。

  • 多维时间序列处理: DeepAR 能够有效处理多维时间序列,同时考虑多个时间序列之间的相关性。

4. DeepAR 模型优缺点

4.1 优点

  • 自动捕捉复杂模式: DeepAR 能够自动捕捉时间序列中的复杂模式,无需手动特征工程。

  • 概率生成: 模型提供了预测的概率分布,可以用于构建置信区间。

4.2 缺点

  • 计算成本较高: 模型的训练和预测过程相对较慢,尤其

在处理大规模数据时。

5. DeepAR 模型 Python 实现

下面将通过 Python 代码演示 DeepAR 模型的单步预测和多步预测。请确保已安装必要的库:

pip install mxnet gluonts

5.1 单步预测代码

# 导入必要的库
from gluonts.model.deepar import DeepAREstimator
from gluonts.dataset.common import ListDataset
from gluonts.trainer import Trainer

# 创建示例数据(替换为你自己的数据)
train_data = [...]  # 训练数据
test_data = [...]   # 测试数据

# 将数据转换为 ListDataset 格式
train_ds = ListDataset([{"start": "2022-01-01", "target": train_data}], freq="D")
test_ds = ListDataset([{"start": "2022-01-10", "target": test_data}], freq="D")

# 定义 DeepAR 模型参数
estimator = DeepAREstimator(freq="D", prediction_length=7, trainer=Trainer(epochs=10))

# 训练模型
predictor = estimator.train(train_ds)

# 单步预测
single_step_forecast = predictor.predict(test_ds)

# 打印预测结果
print(single_step_forecast.mean)

5.2 多步预测代码

# 导入必要的库
from gluonts.model.deepar import DeepAREstimator
from gluonts.dataset.common import ListDataset
from gluonts.trainer import Trainer

# 创建示例数据(替换为你自己的数据)
train_data = [...]  # 训练数据
test_data = [...]   # 测试数据

# 将数据转换为 ListDataset 格式
train_ds = ListDataset([{"start": "2022-01-01", "target": train_data}], freq="D")
test_ds = ListDataset([{"start": "2022-01-10", "target": test_data}], freq="D")

# 定义 DeepAR 模型参数
estimator = DeepAREstimator(freq="D", prediction_length=14, trainer=Trainer(epochs=10))

# 训练模型
predictor = estimator.train(train_ds)

# 多步预测
multi_step_forecast = predictor.predict(test_ds)

# 打印预测结果
print(multi_step_forecast.mean)

6. 总结

本文介绍了时间序列预测中的 DeepAR 模型,包括其理论基础、优缺点以及与其他模型的区别。通过 Python 代码演示了 DeepAR 模型的单步预测和多步预测,希望读者能够更好地理解和应用这一强大的时间序列预测模型。在实际应用中,可以根据具体问题进行模型参数的调整和优化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1436738.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

git整合分支的两种方法——合并(Merge)、变基(Rebase)

问题描述: 初次向git上传本地代码或者更新代码时,总是会遇到以下两个选项。有时候,只是想更新一下代码,没想到,直接更新了最新的代码,但是自己本地的代码并没有和git上的代码融合,反而被覆盖了…

坚持刷题 | 二叉树的直径

文章目录 题目考察点代码实现实现总结方便用迭代的方式实现吗?迭代实现迭代实现总结 Hello,大家好,我是阿月。坚持话题,老年痴呆追不上我,今天还有时间,那就再来一题吧:二叉树的直径 题目 543.…

子集枚举介绍

集合枚举的意思是从一个集合中找出它的所有子集。集合中每个元素都可以被选或不选,含有n个元素的集合总共有个子集(包括全集和空集) 例如考虑集合和它的4个子集、、、,按照某个顺序,把全集A中的每个元素在每个子集中的…

Visual Studio 2010+C#实现信源和信息熵

1. 设计要求 以图形界面的方式设计一套程序,该程序可以实现以下功能: 从输入框输入单个或多个概率,然后使用者可以通过相关按钮的点击求解相应的对数,自信息以及信息熵程序要能够实现马尔可夫信源转移概率矩阵的输入并且可以计算…

计算机网络-差错控制(纠错编码 海明码 纠错方法)

文章目录 纠错编码-海明码海明距离1.确定校验码位数r2.确定校验码和数据的位置3.求出校验码的值4.检错并纠错纠错方法1纠错方法2 小结 纠错编码-海明码 奇偶校验码:只能发现错误不能找到错误位置和纠正错误 海明距离 如果找到码距为1,那肯定为1了&…

鸿蒙开发-UI-组件导航-Tabs

鸿蒙开发-UI-组件 鸿蒙开发-UI-组件2 鸿蒙开发-UI-组件3 鸿蒙开发-UI-气泡/菜单 鸿蒙开发-UI-页面路由 鸿蒙开发-UI-组件导航-Navigation 文章目录 一、基本概念 二、导航 1.底部导航 2.顶部导航 3.侧边导航 4.导航栏限制滑动 三、导航栏 1.固定导航栏 2.滚动导航栏 3…

莉莉与神奇花朵的冒险

现在,我将根据这些步骤编写一个对话形式的童话故事。 在很久很久以前的一个小村庄里,有一个勤劳善良的小女孩叫莉莉。她住在一间小茅屋里,和她的奶奶一起生活。奶奶年纪大了,行动不便,所以莉莉每天都要照顾她。 一天&a…

【多模态大模型】跨越视觉-语言界限:BLIP的多任务精细处理策略

BLIP 核心思想MED架构和CapFilt方法效果 总结CLIP模型 VS BLIP模型CLIP模型BLIP模型 核心思想 论文:https://proceedings.mlr.press/v162/li22n/li22n.pdf 代码:https://github.com/salesforce/BLIP BLIP(Bootstrapping Language-Image Pre…

win11安装MySql5.7

1、下载 打开下载链接:MySQL :: Download MySQL Installer 2、安装 2.1、安装界面 2.2、选择自定义安装 2.3、根据自己系统的位数进行选择是X64还是X86 2.4、选择安装路径 2.5、继续下一步 2.6、选择服务器专用,端口是3306 2.7、设置密码 2.8、设置服…

python函数入参、类成员引用支持灵活参数可配

一、背景 python编码时,有可能在不同场景下输入修改的参数,不方便直接写死,因此需要灵活配置这些函数入参,类成员 二、函数入参支持灵活可配 场景:如下场景,对于hello函数,不同场景下想要对不…

async/await使用过程中,要注意的问题

问: const getData async () >{ console.log(触发了getData接口) let resultData await getActivityInfo(activityId); console.log(resultData,resultData) let id resultData.id; let shareImg resultData.shareImg let shareSubtitle resultData.shareSubtit…

T-Sql 也能更新修改查询JSON?

今天看见一个澳洲项目里面使用了 JSON_VALUE 这样的函数解析 JSON 我倍感诧异,我印象当中Sql Server并不支持JOSN的相关操作,他最多只把JSON当成一个字符串来存储,更不要说去解析,查询和更新了 我随后查询了下此函数,…

美颜SDK是什么?探秘直播美颜SDK在视频社交平台中的应用

当今,美颜SDK正扮演着改变用户体验的“角色”,本篇文章,小编将为大家讲述美颜SDK的概念、原理以及在视频社交平台中的广泛应用。 一、什么是美颜SDK? 美颜SDK是一套软件工具包,提供了一系列美颜算法和功能&#xff…

基础面试题整理6之Redis

1.Redis的应用场景 Redis支持类型:String、hash、set、zset、list String类型 hash类型 set类型 zset类型 list类型 一般用作缓存,例如 如何同时操作同一功能 2.redis是单线程 Redis服务端(数据操作)是单线程,所以Redis是并发安全的,因…

页面模块向上渐变显示效果实现

ps: 先祝各位朋友新春快乐 ^o^/ 想要首页不那么枯燥无味吗?还在未首页过于单调而苦恼吧,来试试这个吧(大佬请忽略上述语句o) 今天要实现一个页面线上渐变显示的效果,用来丰富首页等页面: 这里先随机建立几…

哈希加密Python实现

一、代码 from cryptography.fernet import Fernet import os import bcrypt# 密钥管理和对称加密相关 def save_key_to_file(key: bytes, key_path: str):with open(key_path, wb) as file:file.write(key)def load_key_from_file(key_path: str) -> bytes:if not os.path…

linux服务器如何提高游戏帧率?

在Linux服务器上,由于硬件配置和系统的限制,提高游戏帧率变得更加困难。但是通过一些优化和调整,我们仍然可以提升Linux服务器上的游戏性能。 首先我们需要了解游戏帧率与服务器性能之间的关系。游戏帧率是指游戏每秒渲染的帧数,…

零基础学编程从哪里入手,在学习中可以线上会议答疑解惑

一、前言 零基础学编程可以先从容易学的语言入手,比如中文编程,然后再学其他编程语言则会比较轻松,初步掌握编程思路。很多IT人士一般学2到3种编程语言。 今天给大家分享的中文编程开发语言工具资料如下: 编程入门视频教程链接…

网络防御安全:2-6天笔记

第二章:防火墙 一、什么是防火墙 防火墙的主要职责在于:控制和防护。 防火墙可以根据安全策略来抓取流量之后做出对应的动作。 二、防火墙的发展 区域: Trust 区域,该区域内网络的受信任程度高,通常用来定义内部…

【cmu15445c++入门】(6)c++的迭代器

一、迭代器 C 迭代器是指向容器内元素的对象。它们可用于循环访问该容器的对象。我们知道迭代器的一个示例是指针。指针可用于循环访问 C 样式数组. 二、代码 自己实现一个迭代器 // C iterators are objects that point to an element inside a container. // They can be…