darts 时序预测入门

news2024/10/5 17:16:25

darts是一个强大而易用的Python时间序列建模工具包。在github上目前拥有超过7k颗stars。

它主要支持以下任务:

  • 时间序列预测 (包含 ARIMA, LightGBM模型, TCN, N-BEATS, TFT, DLinear, TiDE等等)

  • 时序异常检测 (包括 分位数检测 等等)

  • 时间序列滤波 (包括 卡尔曼滤波,高斯过程滤波)

本文演示使用darts构建N-BEATS模型对 牛奶月销量数据进行预测~

公众号算法美食屋后台回复关键词:源码,获取本文notebook源码和数据集~

!pip install darts

一,  准备数据

首先,你需要准备时间序列数据。如果数据有缺失,需要进行数据填充。

这里示范的是一个每月牛奶销量数据集。

 
 
import numpy as np 
import pandas as pd 
from matplotlib import pyplot as plt


import darts
from darts import TimeSeries
from darts.dataprocessing.transformers import MissingValuesFiller
from darts.dataprocessing.transformers import Scaler 




# 1,读取数据集
df = pd.read_csv('month_milk.csv')
df.columns = ['ds','y']
df = df.sort_values(by='ds')
#df['y'] = df['y'].interpolate(method='linear')




# 2,填充数据集
ts_raw = TimeSeries.from_dataframe(df,time_col='ds',value_cols=['y'])
#df = ts_raw.pd_dataframe() #timeseries转成dataframe
fig, ax1 = plt.subplots(1, 1, figsize=(10,5))
ts_raw.plot(color='brown',ax = ax1)
ax1.set_title('before fill')


#from darts.utils.missing_values import fill_missing_values
#ts = fill_missing_values(ts_raw)


filler = MissingValuesFiller()
ts = filler.transform(ts_raw)


fig, ax2 = plt.subplots(1, 1, figsize=(10,5))
ts.plot(color='cyan', ax = ax2)
ax2.set_title('after fill')


# 3,分割数据集
ts_train, ts_test= ts.split_after(pd.Timestamp("1973-01-01")) 
fig, ax3 = plt.subplots(1,1,figsize=(10,5))
ts_train.plot(color='blue',label='train',ax=ax3)
ts_test.plot(color='red',label='test',ax=ax3)
ax3.set_title('after split');




# 4,缩放数据集
scaler = Scaler()
ts_train_scaled = scaler.fit_transform(ts_train)
ts_test_scaled = scaler.transform(ts_test)
ts_scaled = scaler.transform(ts)

2724e31a0807564598b405d1427fbf05.png

0a0854fc32b002bd5207824d4847c037.png

f98d248ece22c91ccd55587a1c1c9941.png

二, 定义模型

接下来,定义一个时间序列模型。Darts支持多种模型。

包括统计模型(ARIMA, FFT, ExponentialSmoothing等)

机器学习模型(Prophet,LightGBM等)

神经网络模型(RNN类型,Transformer类型,CNN类型,MLP类型)。

此处我们使用 NBeats模型。

 
 
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
import wandb
wandb.login()
 
 
from darts.models import NBEATSModel 
import pytorch_lightning as pl 
from darts.utils.callbacks import TFMProgressBar
from pytorch_lightning.loggers import WandbLogger


wandb_logger = WandbLogger(project='MILK')
progress_bar = TFMProgressBar(enable_train_bar_only=True)
early_stopper = pl.callbacks.EarlyStopping(monitor="val_loss",patience=10,
        min_delta=1e-5,mode="min")
pl_trainer_kwargs = dict(max_epochs=100, 
     accelerator='cpu',
     callbacks = [progress_bar,early_stopper],
     logger = wandb_logger
) 


settings = dict(
    input_chunk_length=10,
    output_chunk_length=3,
    generic_architecture=True,
    num_stacks=10,
    num_blocks=1,
    num_layers=4,
    layer_widths=512,
    save_checkpoints=True,
    
    batch_size=800,
    random_state=42,
    model_name='nbeats',
    force_reset=True
)


settings['pl_trainer_kwargs'] = pl_trainer_kwargs


model = NBEATSModel(**settings)

三,训练模型

model.fit(series = ts_train_scaled, val_series= ts_train_scaled)
model = model.load_from_checkpoint(model_name=model.model_name,best=True)

四,评估模型

 
 
# 历史数据逐段回测,使用真实历史数据作为特征,不做滚动预测
ts_test_forecast = model.historical_forecasts(
    series=ts_scaled,
    start=ts_test_scaled.start_time(),
    forecast_horizon=3,
    stride=3,
    last_points_only=False,
    retrain=False,
    verbose=True,
)
 
 
from darts import concatenate
ts_test_forecast  = scaler.inverse_transform(concatenate(ts_test_forecast))
ts_test_true = ts[ts_test_forecast.time_index]
test_score = r2_score(ts_test_true, ts_test_preds)
 
 
import matplotlib.pyplot  as plt 
plt.figure(figsize=(8, 5))
ts_test_true.plot(label="y",color='blue')
ts_test_forecast.plot(label='yhat-forecast',color='cyan')
plt.title(
    "yhat-forecast R2-score: {}".format(test_score)
)
plt.legend()

9a18ae60157acae54c79a44dda143c4a.png

五,使用模型

#滚动预测,使用预测的数据作为后面预测步骤的特征 
# (注意:当预测步数 n 小于等于模型的output_chunk_length,无需滚动)
ts_preds = model.predict(n = len(ts_test_scaled),series = ts_train_scaled, num_samples=1)
ts_test_preds  = scaler.inverse_transform(ts_preds)
 
 
from darts.metrics import r2_score 
test_score = r2_score(ts_test_true, ts_test_preds)
 
 
import matplotlib.pyplot  as plt 
plt.figure(figsize=(8, 5))
ts_test_true.plot(label="y",color='blue')
ts_test_forecast.plot(label="yhat-forecast",color='cyan')
ts_test_preds.plot(label='yhat',color='red')
plt.title(
    "yhat R2-score: {}".format(test_score)
)
plt.legend()

0dd218379c73417198932b8ab071c98c.png

六,保存模型

 
 
model.save('nbeats')
model_loaded = NBEATSModel.load('nbeats') #重新加载

4872ae24e820f45e89741844e10f1021.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1809995.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Rd-03E】使用CH340给Rd03_E雷达模块烧录固件

Rd03_E 指导手册 安信可新品雷达模组Rd-03搭配STM32制作简易人体感应雷达灯教程 http://t.csdnimg.cn/mqhkE 测距指导手册网址: https://docs.ai-thinker.com/_media/rd-03e%E7%B2%BE%E5%87%86%E6%B5%8B%E8%B7%9D%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C%E4%B8%AD%…

【Android面试八股文】一图展示 Android生命周期:从Activity到Fragment,以及完整的Android Fragment生命周期

图片来源于:https://github.com/xxv/android-lifecycle Android生命周期:从Activity到Fragment 图:android-lifecycle-activity-to-fragments.png 完整的Android Fragment生命周期 图:complete_android_fragment_lifecycle.png…

cve_2022_0543-redis沙盒漏洞复现 vulfocus

1. 原理 该漏洞的存在是因为Debian/Ubuntu中的Lua库是作为动态库提供的。自动填充了一个package变量,该变量又允许访问任意 Lua 功能。 2.复现 我们可以尝试payload: eval local io_l package.loadlib("/usr/lib/x86_64-linux-gnu/liblua5.1.so…

AWT常用组件

AWT中常用组件 前言一、基本组件组件名标签(Label类)Label类的构造方法注意要点 按钮(Button)Button的构造方法注意要点 文本框(TextField)TextField类的构造方法注意要点 文本域(TextArea)TextArea 的构造方法参数scrollbars的静态常量值 复选框&#x…

文心一言 VS 讯飞星火 VS chatgpt (278)-- 算法导论20.3 5题

五、假设我们创建一个包含 u 1 k u^\frac{1}{k} uk1​ 个簇(而不是全域大小为 x ↓ {\sqrt[↓]{x}} ↓x ​ 的 x ↑ {\sqrt[↑]{x}} ↑x ​ 个簇)的 vEB 树,其每个簇的全域大小为 u 1 − 1 k u ^ {1-\frac{1}{k}} u1−k1​ ,其中 k>1 &#xff0c…

【UML用户指南】-13-对高级结构建模-包

目录 1、名称 2、元素 3、可见性 4、引入与引出 用包把建模元素安排成可作为一个组来处理的较大组块。可以控制这些元素的可见性,使一些元素在包外是可见的,而另一些元素要隐藏在包内。也可以用包表示系统体系结构的不同视图。 狗窝并不复杂&#x…

数据库管理-第200期 身边的数据库从业者(20240610)

数据库管理200期 2024-06-10 数据库管理-第200期 身边的数据库从业者(20240610)首席-薛晓刚院长-施嘉伟邦德-王丁丁强哥-徐小强会长-吴洋灿神-熊灿灿所长-严少安探长-张震总结活动预告 数据库管理-第200期 身边的数据库从业者(20240610&#…

匈牙利匹配算法

一 什么是匈牙利匹配算法 匈牙利算法是一种解决二分图最大匹配问题的算法。在二分图中,将左边的点称为X集合,将右边的点称为Y集合,我们需要在X集合和Y集合之间建立一个双向边集合,使得所有的边都不相交。如果我们能够找到一个最大…

前端工程化工具系列(十)—— Browserslist:浏览器兼容性配置工具

Browserslist 是一个能够在不同的前端工具间共享目标浏览器的配置,各工具根据该配置进行代码转译等操作。 具体的这些前端工具为:Autoprefixer、Babel、postcss-preset-env、eslint-plugin-compat、stylelint-no-unsupported-browser-features、postcss-…

文件的基础必备知识(初学者入门)

1. 为什么使用文件 2. 什么是文件 3. 二进制文件和文本文件 4. 文件的打开和关闭 1.为什么使用文件 我们写的程序数据是存储在电脑内存中,如果程序退出,内存回收,数据就丢失,等程序再次运行时,上次的数据已经消失。面…

安利一款非常不错浏览器文本翻译插件(效果很不错,值得一试)

官网地址:https://immersivetranslate.com/ “沉浸式翻译”这个词,由我们发明创造。如今,它已然成为“双语对照翻译”的代名词。自2023年上线以来,这款备受赞誉的 AI 双语对照网页翻译扩展,已帮助超过 100 万用户跨越语…

鸿蒙开发必备:《DevEco Studio 系列一:实用功能解析与常用快捷键大全》

系列文章目录 文章目录 系列文章目录前言一、下载与安装小黑板 二、IDE被忽略的实用功能-帮助(Help)1.Quick Start2. API Reference3.FAQ 三、常用快捷键一、编辑二、查找或替换三、编译与运行四、调试五、其他 前言 DevEco Studio)是基于In…

NLP——电影评论情感分析

python-tensorflow2.0 numpy 1.19.1 tensorflow 2.0.0 导入库 数据加载 数据处理 构建模型 训练 评估 预测 1.基于2层dropout神经网络 2.基于LSTM的网络 #导入需要用到的库 import os import tarfile import urllib. request import tensorflow as tf import numpy a…

W25Q64简介

W25Q64介绍 本节使用的是:W25Q64: 64Mbit / 8MByte。存储器分为易失性存储器和非易失性存储器,易失性存储器一般是SRAM,DRAM。非易失性存储器一般是E2PROM,Flash等。非易失性存储器,掉电不丢失。 字库存储…

单片机嵌入式计算器(带程序EXE)

单片机嵌入式计算器 主要功能:完成PWM占空比计算,T溢出时间(延时); [!NOTE] 两个程序EXE; [!CAUTION] 百度网盘链接:链接:https://pan.baidu.com/s/1VJ0G7W5AEQw8_MiagM7g8A?pwdg8…

java线程生命周期介绍

Java线程的生命周期包含以下几个状态: 1.新建(New):线程对象被创建,但是还没有调用start()方法。 1.运行(Runnable):线程正在运行或者是就绪状态,等待CPU时间片。 1.阻塞(Blocked):线程暂时停止执行&…

Nginx 精解:正则表达式、location 匹配与 rewrite 重写

一、常见的 Nginx 正则表达式 在 Nginx 配置中,正则表达式用于匹配和重写 URL 请求。以下是一些常见的 Nginx 正则表达式示例: 当涉及正则表达式时,理解各个特殊字符的含义是非常重要的。以下是每个特殊字符的例子: ^&#xff1…

修复损坏的Excel文件比你想象的要简单,这里提供几种常见的修复方法

打开重要的Excel文件时遇到问题吗?Microsoft Excel是否要求你验证文件是否已损坏?Excel文件可能由于各种原因而损坏,从而无法打开。但不要失去希望;你可以轻松修复损坏的Excel文件。 更改Excel信任中心设置 Microsoft Excel有一个内置的安全功能,可以在受限模式下打开有…

Failed to start gdm.servide - GNOME Display manager

启动虚拟机时,卡在Failed to start gdm.servide - GNOME Display manager不动。 解决方法: 1.重新启动Ubuntu,在进度条未结束之前,长按shift直到界面跳转到选项菜单。 2.选择第二个,按enter进入 3.选择最新版本(后面…

MySQL时间和日期类型详解(零基础入门篇)

目录 1. DATE 2. DATETIME 3. TIMESTAMP 4. TIME 5. YEAR 6. 日期和时间的使用示例 以下SQL语句的测试可以使用命令行,或是使用SQL工具比如MySQL Workbench或SQLynx等。 在 MySQL 中,时间和日期数据类型用于存储与时间相关的数据,如何…