NeuralForecast 超参数优化

news2025/1/8 5:50:43

NeuralForecast 超参数优化

flyfish

不使用超参数优化的方式

import numpy as np
import pandas as pd
from IPython.display import display, Markdown

import matplotlib.pyplot as plt
from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS, NHITS
from neuralforecast.utils import AirPassengersDF

# Split data and declare panel dataset
Y_df = AirPassengersDF
Y_train_df = Y_df[Y_df.ds<='1959-12-31'] # 132 train
Y_test_df = Y_df[Y_df.ds>'1959-12-31'] # 12 test

# Fit and predict with NBEATS and NHITS models
horizon = len(Y_test_df)
models = [NBEATS(input_size=2 * horizon, h=horizon, max_steps=50),
          NHITS(input_size=2 * horizon, h=horizon, max_steps=50)]
nf = NeuralForecast(models=models, freq='M')
nf.fit(df=Y_train_df)
Y_hat_df = nf.predict().reset_index()

# Plot predictions
fig, ax = plt.subplots(1, 1, figsize = (20, 7))
Y_hat_df = Y_test_df.merge(Y_hat_df, how='left', on=['unique_id', 'ds'])
plot_df = pd.concat([Y_train_df, Y_hat_df]).set_index('ds')

plot_df[['y', 'NBEATS', 'NHITS']].plot(ax=ax, linewidth=2)

ax.set_title('AirPassengers Forecast', fontsize=22)
ax.set_ylabel('Monthly Passengers', fontsize=20)
ax.set_xlabel('Timestamp [t]', fontsize=20)
ax.legend(prop={'size': 15})
ax.grid()
plt.show()

在这里插入图片描述

# # Hyperparameter Optimization


# The main steps of hyperparameter tuning are:
#
#  1. Define training and validation sets.
#  2. Define search space.
#  3. Sample configurations with a search algorithm, train models, and evaluate them on the validation set.
#  4. Select and store the best model.

##超参数优化。


#超参数调优的主要步骤如下:
#。
#1.定义培训和验证集。
#2.定义搜索空间。
#3.使用搜索算法对配置进行采样,训练模型,并在验证集上对其进行评估。
#4.选择并存储最好的型号。


# !pip install neuralforecast hyperopt

from neuralforecast.utils import AirPassengersDF

Y_df = AirPassengersDF
Y_df.head()

from ray import tune

nhits_config = {
       "max_steps": 100,                                                         # Number of SGD steps
       "input_size": 24,                                                         # Size of input window
       "learning_rate": tune.loguniform(1e-5, 1e-1),                             # Initial Learning rate
       "n_pool_kernel_size": tune.choice([[2, 2, 2], [16, 8, 1]]),               # MaxPool's Kernelsize
       "n_freq_downsample": tune.choice([[168, 24, 1], [24, 12, 1], [1, 1, 1]]), # Interpolation expressivity ratios
       "val_check_steps": 50,                                                    # Compute validation every 50 steps
       "random_seed": tune.randint(1, 10),                                       # Random seed
    }

from ray.tune.search.hyperopt import HyperOptSearch
from neuralforecast.losses.pytorch import MAE
from neuralforecast.auto import AutoNHITS


model = AutoNHITS(h=12,
                  loss=MAE(),
                  config=nhits_config,
                  search_alg=HyperOptSearch(),
                  backend='ray',
                  num_samples=10)


from neuralforecast import NeuralForecast


nf = NeuralForecast(models=[model], freq='M')
nf.fit(df=Y_df, val_size=24)


results = nf.models[0].results.get_dataframe()
results.head()


Y_hat_df = nf.predict()
Y_hat_df = Y_hat_df.reset_index()
Y_hat_df.head()


import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING) # Use this to disable training prints from optuna


def config_nhits(trial):
    return {
        "max_steps": 100,                                                                                               # Number of SGD steps
        "input_size": 24,                                                                                               # Size of input window
        "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 1e-1),                                         # Initial Learning rate
        "n_pool_kernel_size": trial.suggest_categorical("n_pool_kernel_size", [[2, 2, 2], [16, 8, 1]]),                 # MaxPool's Kernelsize
        "n_freq_downsample": trial.suggest_categorical("n_freq_downsample", [[168, 24, 1], [24, 12, 1], [1, 1, 1]]),    # Interpolation expressivity ratios
        "val_check_steps": 50,                                                                                          # Compute validation every 50 steps
        "random_seed": trial.suggest_int("random_seed", 1, 10),                                                         # Random seed
    }

model = AutoNHITS(h=12,
                  loss=MAE(),
                  config=config_nhits,
                  search_alg=optuna.samplers.TPESampler(),
                  backend='optuna',
                  num_samples=10)



nf = NeuralForecast(models=[model], freq='M')
nf.fit(df=Y_df, val_size=24)


results = nf.models[0].results.trials_dataframe()
results.drop(columns='user_attrs_ALL_PARAMS')


Y_hat_df_optuna = nf.predict()
print("Y_hat_df_optuna:\n",Y_hat_df_optuna)
Y_hat_df_optuna = Y_hat_df_optuna.reset_index()
Y_hat_df_optuna.head()
print("Y_hat_df_optuna.head():\n",Y_hat_df_optuna.head())

import pandas as pd
import matplotlib.pyplot as plt


fig, ax = plt.subplots(1, 1, figsize = (20, 7))
plot_df = pd.concat([Y_df, Y_hat_df]).reset_index()

plt.plot(plot_df['ds'], plot_df['y'], label='y')
plt.plot(plot_df['ds'], plot_df['AutoNHITS'], label='Ray')
plt.plot(Y_hat_df_optuna['ds'], Y_hat_df_optuna['AutoNHITS'], label='Optuna')

ax.set_title('AirPassengers Forecast', fontsize=22)
ax.set_ylabel('Monthly Passengers', fontsize=20)
ax.set_xlabel('Timestamp [t]', fontsize=20)
ax.legend(prop={'size': 15})
ax.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1537637.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C#事件实例详解

一、什么是事件&#xff1f; 在C#中,事件(event)是一种特殊的类成员,它允许类或对象通知其他类或对象发生了某些事情。 从语法上看,事件的声明类似于字段,但它们在功能和行为上有一些重要的区别。 从技术角度来说,事件实际上是一个封装了事件订阅和取消订阅功能的委托字段。…

通过JWT完成token登录验证

前言 什么是JWT&#xff1f; 全称是JSON Web token&#xff0c;是用于对应用程序上的用户进行身份验证的标记&#xff0c;使用 JWTS 的应用程序不再需要保存有关其用户的 cookie 或其他session数据 使用JWT的优势 提高了程序的可伸缩性&#xff0c;也极大的提高了应用程序的安全…

鸿蒙Harmony应用开发—ArkTS(@Link装饰器:父子双向同步)

子组件中被Link装饰的变量与其父组件中对应的数据源建立双向数据绑定。 说明&#xff1a; 从API version 9开始&#xff0c;该装饰器支持在ArkTS卡片中使用。 概述 Link装饰的变量与其父组件中的数据源共享相同的值。 限制条件 Link装饰器不能在Entry装饰的自定义组件中使用…

前端canvas项目实战——简历制作网站(六):加粗、斜体、下划线、删除线(上)

目录 前言一、效果展示二、实现步骤1. 视图部分&#xff1a;实现用于切换字体属性的按钮2. 逻辑部分&#xff1a;点击按钮之后要做什么&#xff1f;3. 根据Textbox的属性实时更新按钮的状态 三、Show u the code后记 前言 上一篇博文中&#xff0c;我们实现了对文字的字体、字…

ChatGLM3 Linux 部署

1.首先需要下载本仓库&#xff1a; git clone https://github.com/THUDM/ChatGLM3 2.查看显卡对应的torch 版本 官方文档说明&#xff1a; Start Locally | PyTorch 例如&#xff1a; a. 先查看显卡的CUDA版本 nvcc --version 查看对应版本 Previous PyTorch Versions …

Error:No such property: GradleVersion for class: JetGradlePlugin

Gradle版本对照表 Android Gradle 插件版本在项目的根目录&#xff08;不是App目录&#xff09;下的build.gradle文件中&#xff0c;如图 插件所需的Gradle 版本在gradle目录下的gradle-wrapper.properties文件中&#xff0c;如图

安全认证|CISSP认证是什么证书?考了有什么用?能做什么工作?

很多人总是听说CISSP是顶级的信息安全证书&#xff0c;在国内或者国外都有盛誉&#xff0c;那么CISSP到底是个什么样的证书&#xff0c;本期就给大家介绍下&#xff01; 什么是CISSP CISSP&#xff08;Certification for Information System Security Professional&#xff0…

三份天注定,七分靠XX?

文 | 螳螂观察 作者 | 陈小江 1988年&#xff0c;中国宝岛台湾&#xff0c;蒋经国过世后&#xff0c;社会运动风起云涌。在所谓“解严”的时代氛围里&#xff0c;人们对前途虽然迷茫&#xff0c;但却充满打拼的热情。 那时节&#xff0c;40岁的台湾歌手叶启田&#xff0c;开…

【消息队列开发】 实现消费者订阅消息

文章目录 &#x1f343;前言&#x1f333;关于订阅消息方法参数解析&#x1f38b;如何实现将消息推送给消费者&#x1f38d;消费者类&#x1f340;消费消息的流程&#x1f384;如何实现消息确认呢&#xff1f;⭕总结 &#x1f343;前言 本次开发任务 实现消费者订阅消息 &am…

公司内部局域网怎么适用飞书?

随着数字化办公的普及&#xff0c;企业对于内部沟通和文件传输的需求日益增长。飞书作为一款集成了即时通讯、云文档、日程管理、视频会议等多种功能的智能协作平台&#xff0c;已经成为许多企业提高工作效率的首选工具。本文将详细介绍如何在公司内部局域网中应用飞书&#xf…

电脑Wi-Fi无法连接如何排查

Wi-Fi是一个神奇的东西&#xff0c;总是能在某一天莫名其妙的连不上让我们疯狂糟心&#xff01;&#xff01;&#xff01; 呉師傅准备了几个解决方法来帮助大家解决连不上Wi-Fi的问题&#xff1b; 1、疑难解答功能 系统自带的【疑难解答】功能不妨试一试&#xff0c;也能一定…

【AAAI 2024】M2Doc:文档版面分析的可插拔多模态融合方法

一、文章介绍 文档版面分析任务是文档智能的一个关键任务。然而&#xff0c;现有的很多文档版面分析研究方法都基于通用目标检测方法&#xff0c;忽视了文档的文本特征而仅仅只关注于视觉特征。近年来&#xff0c;基于预训练的文档智能模型在很多文档下游任务中都取得了成功&a…

左旋字符串功能的实现

实现一个函数&#xff0c;可以左旋字符串中的k个字符。 例如&#xff1a; #1ABCD左旋一个字符得到BCDA #2ABCD左旋两个字符得到CDAB 由此图可知&#xff0c;其字符串长度为4&#xff0c;每次经历四次左旋后又回到了初始 位置&#xff0c;所以是以字符串长度len为一个循环&…

微服务cloud--抱团取暖吗 netflix很多停更了

抱团只会卷&#xff0c;卷卷也挺好的 DDD 高内聚 低耦合 服务间不要有业务交叉 通过接口调用 分解技术实现的复杂性&#xff0c;围绕业务概念构建领域模型&#xff1b;边界划分 业务中台&#xff1a; 数据中台&#xff1a; 技术中台&#xff1a; 核心组件 eureka&#x…

(done) ROC曲线 和 AUC值 分别是什么?

来源&#xff1a;https://www.bilibili.com/video/BV1wz4y197LU/?spm_id_from333.337.search-card.all.click&vd_source7a1a0bc74158c6993c7355c5490fc600 在二分类问题下&#xff0c;我们的模型通常会输出一个 概率值&#xff0c;通过判断 概率值 和 阈值threshold 的大小…

docker 安装部署 jenkins

今天 小☀ 给大家普及一下什么是 jenkins&#xff01;&#xff01; Jenkins是一个开源软件项目&#xff0c;基于Java开发的持续集成工具。它提供了一个开放易用的软件平台&#xff0c;使软件项目可以进行持续集成。Jenkins起源于Hudson&#xff0c;主要用于持续、自动地构建、…

动态内存数组(malloc、calloc、realloc、free)

一、为什么要创建动态内存数组 动态内存&#xff0c;顾名思义就是说在内存中非固定的申请数组 在学习该项方法前我们申请内存的方法无非就两种&#xff1a;直接创建变量/通过创建数组的方式来申请空间。 那么直接创建变量/通过创建数组的方式来申请空间的缺点就是一旦创建成…

基于python+vue拍卖行系统的设计与实现flask-django-nodejs-php

拍卖行系统的目的是让使用者可以更方便的将人、设备和场景更立体的连接在一起。能让用户以更科幻的方式使用产品&#xff0c;体验高科技时代带给人们的方便&#xff0c;同时也能让用户体会到与以往常规产品不同的体验风格。 与安卓&#xff0c;iOS相比较起来&#xff0c;拍卖行…

2024学习鸿蒙开发,未来发展如何?

一、前言 想要了解一个领域的未来发展如何&#xff0c;可以从如下几点进行&#xff0c;避免盲从&#xff1a; 国家政策落地情况就业市场如何学习 通过上述三点&#xff0c;就能分析出一个行业的趋势。大家可以看到&#xff0c;我上面的总体逻辑就是根据国家政策来分析未来方…