基于N-HiTS神经层次插值模型的时间序列预测——cross validation交叉验证与ray tune超参数优化

news2025/1/5 17:39:01

论文链接:https://arxiv.org/pdf/2201.12886v3


N-HiTS: Neural Hierarchical Interpolation for TimeSeries Forecasting \begin{aligned} &\text{\large \color{#CDA59E}N-HiTS: Neural Hierarchical Interpolation for TimeSeries Forecasting}\\ \end{aligned} N-HiTS: Neural Hierarchical Interpolation for TimeSeries Forecasting
NHITS builds upon NBEATS and specializes its partial outputs in the different frequencies of the time series through hierarchical interpolation and multi-rate input processing. On the long-horizon forecasting task NHITS improved accuracy by 25% on AAAI’s best paper award the Informer, while being 50x faster.

References
-Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). “N-BEATS: Neural basis expansion analysis for interpretable time series forecasting”.
-Cristian Challu, Kin G. Olivares, Boris N. Oreshkin, Federico Garza, Max Mergenthaler-Canseco, Artur Dubrawski (2023). “NHITS: Neural Hierarchical Interpolation for Time Series Forecasting”. Accepted at the Thirty-Seventh AAAI Conference on Artificial Intelligence.
-Zhou, H.; Zhang, S.; Peng, J.; Zhang, S.; Li, J.; Xiong, H.; and Zhang, W. (2020). “Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting”. Association for the Advancement of Artificial Intelligence Conference 2021 (AAAI 2021).

在这里插入图片描述


前言

系列专栏:【深度学习:算法项目实战】✨︎
涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆、自然语言处理、深度强化学习、大型语言模型和迁移学习。

NHITS是一种解决时间序列长期预测中波动性和计算复杂性的模型。它采用了分层插值和多率数据采样技术,通过构建分层结构来降低计算成本并提高预测精度‌。相较于最新的Transformer架构,NHITS在平均精度上提升了16%,同时计算时间减少了50倍‌。这种模型能够更有效地处理时间序列数据,为时间序列分析提供了新的方法。

具体来说,NHITS通过结合新的分层插值和多率数据采样技术,解决了长期预测中的两个常见挑战:预测的波动性和计算复杂性。这些技术使NHITS能够依次组装其预测,强调具有不同频率和尺度的分量,同时分解输入信号并合成预测‌。这种独特的处理方式使得NHITS在长期预测任务中表现出色。

文章目录

  • 1. 数据集加载
  • 2. 数据预处理
  • 3. 数据可视化
  • 4. 定义超参数
  • 5. 构建模型
  • 6. 交叉验证
  • 7. 预测结果
  • 8. 模型评估

import pandas as pd
import matplotlib.pyplot as plt

from ray import tune
from neuralforecast.auto import AutoNHITS
from neuralforecast.core import NeuralForecast
from neuralforecast.losses.numpy import mae, mse, mape, rmse

from datasetsforecast.long_horizon import LongHorizon

1. 数据集加载

datasetsforecast 是一个用于处理时间序列预测相关数据集的库。它的主要目的是方便用户获取、加载和预处理适合于时间序列预测任务的数据集。在时间序列分析和预测领域,拥有高质量、合适的数据集是非常关键的一步,这个库能够帮助我们更高效地开展工作。

# Change this to your own data to try the model
Y_df, X_df, _ = LongHorizon.load(directory='./', group='ETTm2')

2. 数据预处理

Y_df['ds'] = pd.to_datetime(Y_df['ds'])
# For this excercise we are going to take 20% of the DataSet
n_time = len(Y_df.ds.unique())
val_size = int(.2 * n_time)
test_size = int(.2 * n_time)

Y_df.groupby('unique_id').head(2)

3. 数据可视化

# We are going to plot the temperature of the transformer
# and marking the validation and train splits
u_id = 'HUFL'
x_plot = pd.to_datetime(Y_df[Y_df.unique_id==u_id].ds)
y_plot = Y_df[Y_df.unique_id==u_id].y.values

x_val = x_plot[n_time - val_size - test_size]
x_test = x_plot[n_time - test_size]

fig = plt.figure(figsize=(10, 5))
fig.tight_layout()

plt.plot(x_plot, y_plot)
plt.xlabel('Date', fontsize=17)
plt.ylabel('HUFL [15 min temperature]', fontsize=17)

plt.axvline(x_val, color='black', linestyle='-.')
plt.axvline(x_test, color='black', linestyle='-.')
plt.text(x_val, 5, '  Validation', fontsize=12)
plt.text(x_test, 5, '  Test', fontsize=12)

plt.grid()

HUFL

4. 定义超参数

Ray Tune 是一个用于超参数优化的库,它是基于 Ray 框架的一部分。Ray 是一个开源的分布式计算框架,旨在简化并行和分布式Python编程。Ray Tune 专门设计用来帮助开发者高效地搜索机器学习模型的超参数空间,以找到性能最佳的模型配置

horizon = 96 # 24hrs = 4 * 15 min.

# Use your own config or AutoNHITS.default_config
nhits_config = {
       "learning_rate": tune.choice([1e-3]),                                     # Initial Learning rate
       "max_steps": tune.choice([1000]),                                         # Number of SGD steps
       "input_size": tune.choice([5 * horizon]),                                 # input_size = multiplier * horizon
       "batch_size": tune.choice([7]),                                           # Number of series in windows
       "windows_batch_size": tune.choice([256]),                                 # Number of windows in batch
       "n_pool_kernel_size": tune.choice([[2, 2, 2], [16, 8, 1]]),               # MaxPool's Kernel size
       "n_freq_downsample": tune.choice([[168, 24, 1], [24, 12, 1], [1, 1, 1]]), # Interpolation expressivity ratios
       "activation": tune.choice(['ReLU']),                                      # Type of non-linear activation
       "n_blocks":  tune.choice([[1, 1, 1]]),                                    # Blocks per each 3 stacks
       "mlp_units":  tune.choice([[[512, 512], [512, 512], [512, 512]]]),        # 2 512-Layers per block for each stack
       "interpolation_mode": tune.choice(['linear']),                            # Type of Multi-step interpolation
       "val_check_steps": tune.choice([100]),                                    # Compute validation every 100 epochs
       "random_seed": tune.randint(3, 5),
    }

5. 构建模型

nf = NeuralForecast(
    models = [
        AutoNHITS(h=horizon,
                  config=nhits_config,
                  num_samples=5
                  )
    ],
    freq='15min')

6. 交叉验证

交叉验证方法 cross_validation 将返回模型在测试集上的预测结果。

Y_hat_df = nf.cross_validation(df=Y_df, val_size=val_size,
                               test_size=test_size, n_windows=None)
nf.models[0].results.get_best_result().config
{'learning_rate': 0.001,
 'max_steps': 1000,
 'input_size': 480,
 'batch_size': 7,
 'windows_batch_size': 256,
 'n_pool_kernel_size': [2, 2, 2],
 'n_freq_downsample': [1, 1, 1],
 'activation': 'ReLU',
 'n_blocks': [1, 1, 1],
 'mlp_units': [[512, 512], [512, 512], [512, 512]],
 'interpolation_mode': 'linear',
 'val_check_steps': 100,
 'random_seed': 3,
 'h': 96,
 'loss': MAE(),
 'valid_loss': MAE()}

7. 预测结果

y_true = Y_hat_df.y.values
y_hat = Y_hat_df['AutoNHITS'].values

n_series = len(Y_df.unique_id.unique())

y_true = y_true.reshape(n_series, -1, horizon)
y_hat = y_hat.reshape(n_series, -1, horizon)

print('Parsed results')
print('2. y_true.shape (n_series, n_windows, n_time_out):\t', y_true.shape)
print('2. y_hat.shape  (n_series, n_windows, n_time_out):\t', y_hat.shape)
Parsed results
2. y_true.shape (n_series, n_windows, n_time_out):	 (7, 11425, 96)
2. y_hat.shape  (n_series, n_windows, n_time_out):	 (7, 11425, 96)
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(10, 11))
fig.tight_layout()

series = ['HUFL','HULL','LUFL','LULL','MUFL','MULL','OT']
series_idx = 3

for idx, w_idx in enumerate([200, 300, 400]):
  axs[idx].plot(y_true[series_idx, w_idx,:],label='True')
  axs[idx].plot(y_hat[series_idx, w_idx,:],label='Forecast')
  axs[idx].grid()
  axs[idx].set_ylabel(series[series_idx]+f' window {w_idx}',
                      fontsize=17)
  if idx==2:
    axs[idx].set_xlabel('Forecast Horizon', fontsize=17)
plt.legend()
plt.show()
#plt.savefig('./results/HUFL_window.png', dpi=300)
plt.close()

在这里插入图片描述

8. 模型评估

以下代码使用了一些常见的评估指标:平均绝对误差(MAE)、平均绝对百分比误差(MAPE)、均方误差(MSE)、均方根误差(RMSE)来衡量模型预测的性能。这里我们将调用 neuralforecast.losses.numpy 模块中的 mae, mse, mape, rmse 函数来对模型的预测效果进行评估。

mae = mae(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"MAE: {mae:.4f}")

mape = mape(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"MAPE: {mape * 100:.4f}%")

mse = mse(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"MSE: {mse:.4f}")

rmse = rmse(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"RMSE: {rmse:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2270412.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Lumos学习王佩丰Excel第二十三讲:Excel图表与PPT

一、双坐标柱形图的补充知识 1、主次坐标设置 2、主次坐标柱形避让(通过增加两个系列,挤压使得两个柱形挨在一起) 增加两个系列 将一个系列设置成主坐标轴,另一个设成次坐标轴 调整系列位置 二、饼图美化 1、饼图美化常见设置 …

YK人工智能(三)——万字长文学会torch深度学习

2.1 张量 本节主要内容: 张量的简介PyTorch如何创建张量PyTorch中张量的操作PyTorch中张量的广播机制 2.1.1 简介 几何代数中定义的张量是基于向量和矩阵的推广,比如我们可以将标量视为零阶张量,矢量可以视为一阶张量,矩阵就是…

Java基于SpringBoot的甘肃非物质文化网站的设计与实现,附源码

博主介绍:✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…

计算机网络:网络层知识点及习题(一)

网课资源: 湖科大教书匠 1、概述 网络层实现主机到主机的传输,主要有分组转发和路由选择两大功能 路由选择处理机得出路由表,路由表再生成转发表,从而实现分组从不同的端口转发 网络层向上层提供的两种服务:面向连接…

ACL的注意事项

ACL只对数据进行抓取和匹配,ACl本身不对数据做拒绝和允许的操作,只有在接口方向上应用后才对数据进行拒绝或允许的操作。 ACl只在packetfilter包过滤时默认动作是允许,这个时候至少需要有一条deny规则,否则全都是允许的规则&…

【Cesium】九、Cesium点击地图获取点击位置的坐标,并在地图上添加图标

文章目录 一、前言二、实现方法三、App.vue 一、前言 查找发现好几种方法可以获取到点击位置的坐标。这里我实现需求就不深究学习了。将几位大佬的方法学习过来稍微整合了一下。 本文参考文章: cesium 4种拾取坐标的方法 【Cesium基础学习】拾取坐标 cesium拾取当…

OpenStack的核心组件、主要特点和使用场景

OpenStack 是一个开源的云计算平台,主要用于构建和管理公共及私有云环境。它由多个模块组成,提供虚拟化资源管理、存储管理、网络配置等功能,旨在为数据中心提供自动化的、灵活的云基础设施服务。OpenStack最初由NASA和Rackspace共同开发&…

51c自动驾驶~合集44

我自己的原文哦~ https://blog.51cto.com/whaosoft/12969097 #Towards Generalist Robot Policies 清华大学&字节 | 迈向通用机器人策略:如何选择VLA? 论文标题:Towards Generalist Robot Policies: What Matters in Building Vision…

17爬虫:关于DrissionPage相关内容的学习01

概述 前面我们已经大致了解了selenium的用法,DerssionPage同selenium一样,也是一个基于Python的网页自动化工具。 DrissionPage既可以实现网页的自动化操作,也能够实现收发数据包,也可以把两者的功能合二为一。 DressionPage的…

SSM-Spring-AOP

目录 1 AOP实现步骤(以前打印当前系统的时间为例) 2 AOP工作流程 3 AOP核心概念 4 AOP配置管理 4-1 AOP切入点表达式 4-1-1 语法格式 4-1-2 通配符 4-2 AOP通知类型 五种通知类型 AOP通知获取数据 获取参数 获取返回值 获取异常 总结 5 …

【Linux】:线程安全 + 死锁问题

📃个人主页:island1314 🔥个人专栏:Linux—登神长阶 ⛺️ 欢迎关注:👍点赞 👂🏽留言 😍收藏 💞 💞 💞 1. 线程安全和重入问题&…

数字电路期末复习

*前言:*写的东西不太全面,更多的是一个复习大纲,让你发现自己有哪些不懂的问题(不懂的地方就去翻书或者问AI),如果能够解决提出的所有问题,那么过期末考一定不是问题。 这里写目录标题 数制和码…

python数据分析:使用pandas库读取和编辑Excel表

使用 Pandas,我们可以轻松地读取和写入Excel 文件,之前文章我们介绍了其他多种方法。 使用前确保已经安装pandas和 openpyxl库(默认使用该库处理Excel文件)。没有安装的可以使用pip命令安装: pip install pandas ope…

“AI人工智能软件开发公司:创新技术,引领未来

大家好!今天我们来聊聊一个充满未来感的话题——AI人工智能软件开发公司。这个公司,用大白话说,就是专门研究和开发人工智能软件的地方,它们用最新的技术帮我们解决问题,让生活和工作变得更智能、更便捷。听起来是不是…

uniapp中使用ruoyiPlus中的加密使用(crypto-js)

package.json中添加 "crypto-js": "^4.2.0", "jsencrypt": "^3.3.2",但是vue2中使用 import CryptoJS from cryptojs; 这一步就会报错 参照 参照这里:vue2使用CryptoJS实现信息加解密 根目录下的js文档中新增一个AESwork.…

【SQL Server】教材数据库(1)

1 利用sql建立教材数据库,并定义以下基本表: 学生(学号,年龄,性别,系名) 教材(编号,书名,出版社编号,价格) 订购(学号…

全国计算机设计大赛大数据主题赛(和鲸赛道)经验分享

全国计算机设计大赛大数据主题赛(和鲸赛道)经验分享 这是“和鲸杯”辽宁省普通高等学校本科大学生计算机设计竞赛启动会汇报—大数据主题赛的文档总结。想要参加2025年此比赛的可以借鉴。 一、关于我 人工智能专业 计赛相关奖项: 2022年计…

AI对接之JSON Output

AI的JSON Output 实际对接指南 前言 本系列AI的API对接均以 DeepSeek 为例,其他大模型的对接方式类似。 在现代软件开发中,JSON(JavaScript Object Notation)作为一种轻量级的数据交换格式,因其简洁和易于人阅读的特…

Vue3实现PDF在线预览功能

​🌈个人主页:前端青山 🔥系列专栏:Vue篇 🔖人终将被年少不可得之物困其一生 依旧青山,本期给大家带来Vue篇专栏内容:Vue3现PDF在线预览功能 前言 在开发中,PDF预览和交互功能是一个常见的需求。无论是管理…

SpringBootWeb案例-1

文章目录 SpringBootWeb案例1. 准备工作1.1 需求&环境搭建1.1.1 需求说明1.1.2 环境搭建 1.2 开发规范 2. 部门管理2.1 查询部门2.1.1 原型和需求2.1.2 接口文档2.1.3 思路分析2.1.4 功能开发2.1.5 功能测试 2.2 前后端联调2.3 删除部门2.3.1 需求2.3.2 接口文档2.3.3 思路…