SOFTS: 时间序列预测的最新模型以及Python使用示例

news2025/2/22 18:15:30

近年来,深度学习一直在时间序列预测中追赶着提升树模型,其中新的架构已经逐渐为最先进的性能设定了新的标准。

这一切都始于2020年的N-BEATS,然后是2022年的NHITS。2023年,PatchTST和TSMixer被提出,最近的iTransformer进一步提高了深度学习预测模型的性能。

这是2024年4月《SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion》中提出的新模型,采用集中策略来学习不同序列之间的交互,从而在多变量预测任务中获得最先进的性能。

在本文中,我们详细探讨了SOFTS的体系结构,并介绍新的STar聚合调度(STAD)模块,该模块负责学习时间序列之间的交互。然后,我们测试将该模型应用于单变量和多变量预测场景,并与其他模型作为对比。

SOFTS介绍

SOFTS是 Series-cOre Fused Time Series的缩写,背后的动机来自于长期多元预测对决策至关重要的认识:

首先我们一直研究Transformer的模型,它们试图通过使用补丁嵌入和通道独立等技术(如PatchTST)来降低Transformer的复杂性。但是由于通道独立性,消除了每个序列之间的相互作用,因此可能会忽略预测信息。

iTransformer 通过嵌入整个序列部分地解决了这个问题,并通过注意机制处理它们。但是基于transformer的模型在计算上是复杂的,并且需要更多的时间来训练非常大的数据集。

另一方面有一些基于mlp的模型。这些模型通常很快,并产生非常强的结果,但当存在许多序列时,它们的性能往往会下降。

所以出现了SOFTS:研究人员建议使用基于mlp的STAD模块。由于是基于MLP的,所以训练速度很快。并且STAD模块,它允许学习每个序列之间的关系,就像注意力机制一样,但计算效率更高。

SOFTS架构

在上图中可以看到每个序列都是单独嵌入的,就像在iTransformer 中一样。

然后将嵌入发送到STAD模块。每个序列之间的交互都是集中学习的,然后再分配到各个系列并融合在一起。

最后再通过线性层产生预测。

这个体系结构中有很多东西需要分析,我们下面更详细地研究每个组件。

1、归一化与嵌入

首先使用归一化来校准输入序列的分布。使用了可逆实例的归一化(RevIn)。它将数据以单位方差的平均值为中心。然后每个系列分别进行嵌入,就像在iTransformer 模型。

在上图中我们可以看到,嵌入整个序列就像应用补丁嵌入,其中补丁长度等于输入序列的长度。

这样,嵌入就包含了整个序列在所有时间步长的信息。

然后将嵌入式系列发送到STAD模块。

2、STar Aggregate-Dispatch (STAD)

STAD模块是soft模型与其他预测方法的真正区别。使用集中式策略来查找所有时间序列之间的相互作用。

嵌入的序列首先通过MLP和池化层,然后将这个学习到的表示连接起来形成核(上图中的黄色块表示)。

核构建好了以后就进入了“重复”和“连接”的步骤,在这个步骤中,核表示被分派给每个系列。

MLP和池化层未捕获的信息还可以通过残差连接添加到核表示中。然后在融合(fuse)操作的过程中,核表示及其对应系列的残差都通过MLP层发送。最后的线性层采用STAD模块的输出来生成每个序列的最终预测。

与其他捕获通道交互的方法(如注意力机制)相比,STAD模块的主要优点之一是它降低了复杂性。

因为STAD模块具有线性复杂度,而注意力机制具有二次复杂度,这意味着STAD在技术上可以更有效地处理具有多个序列的大型数据集。

下面我们来实际使用SOFTS进行单变量和多变量场景的测试。

使用SOFTS预测

这里,我们使用 Electricity Transformer dataset 数据集。

这个数据集跟踪了中国某省两个地区的变压器油温。每小时和每15分钟采样一个数据集,总共有四个数据集。

我门使用neuralforecast库中的SOFTS实现,这是官方认可的库,并且这样我们可以直接使用和测试不同预测模型的进行对比。

在撰写本文时,SOFTS还没有集成在的neuralforecast版本中,所以我们需要使用源代码进行安装。

 pip install git+https://github.com/Nixtla/neuralforecast.git

然后就是从导入包开始。使用datasetsforecast以所需格式加载数据集,以便使用neuralforecast训练模型,并使用utilsforecast评估模型的性能。这就是我们使用neuralforecast的原因,因为他都是一套的

 import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
 
 from datasetsforecast.long_horizon import LongHorizon
 
 from neuralforecast.core import NeuralForecast
 from neuralforecast.losses.pytorch import MAE, MSE
 from neuralforecast.models import SOFTS, PatchTST, TSMixer, iTransformer
 
 from utilsforecast.losses import mae, mse
 from utilsforecast.evaluation import evaluate

编写一个函数来帮助加载数据集,以及它们的标准测试大小、验证大小和频率。

 def load_data(name):
     if name == "ettm1":
         Y_df, *_ = LongHorizon.load(directory='./', group='ETTm1')
         Y_df = Y_df[Y_df['unique_id'] == 'OT'] # univariate dataset
         Y_df['ds'] = pd.to_datetime(Y_df['ds'])
         val_size = 11520
         test_size = 11520
         freq = '15T'
     elif name == "ettm2":
         Y_df, *_ = LongHorizon.load(directory='./', group='ETTm2')
         Y_df['ds'] = pd.to_datetime(Y_df['ds']) 
         val_size = 11520
         test_size = 11520
         freq = '15T'
 
     return Y_df, val_size, test_size, freq

然后就可以对ETTm1数据集进行单变量预测。

1、单变量预测

加载ETTm1数据集,将预测范围设置为96个时间步长。

可以测试更多的预测长度,但我们这里只使用96。

 Y_df, val_size, test_size, freq = load_data('ettm1')
 
 horizon = 96

然后初始化不同的模型,我们将soft与TSMixer, iTransformer和PatchTST进行比较。

所有模型都使用的默认配置将最大训练步数设置为1000,如果三次后验证损失没有改善,则停止训练。

 models = [
     SOFTS(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),
     TSMixer(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),
     iTransformer(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),
     PatchTST(h=horizon, input_size=3*horizon, max_steps=1000, early_stop_patience_steps=3)
 ]

然后初始化NeuralForecast对象训练模型。并使用交叉验证来获得多个预测窗口,更好地评估每个模型的性能。

 nf = NeuralForecast(models=models, freq=freq)
 nf_preds = nf.cross_validation(df=Y_df, val_size=val_size, test_size=test_size, n_windows=None)
 nf_preds = nf_preds.reset_index()

评估计算了每个模型的平均绝对误差(MAE)和均方误差(MSE)。因为之前的数据是缩放的,因此报告的指标也是缩放的。

 ettm1_evaluation = evaluate(df=nf_preds, metrics=[mae, mse], models=['SOFTS', 'TSMixer', 'iTransformer', 'PatchTST'])

从上图可以看出,PatchTST的MAE最低,而softts、TSMixer和PatchTST的MSE是一样的。在这种特殊情况下,PatchTST仍然是总体上最好的模型。

这并不奇怪,因为PatchTST在这个数据集中是出了名的好,特别是对于单变量任务。下面我们开始测试多变量场景。

2、多变量预测

使用相同的load_data函数,我们现在为这个多变量场景使用ETTm2数据集。

 Y_df, val_size, test_size, freq = load_data('ettm2')
 
 horizon = 96

然后简单地初始化每个模型。我们只使用多变量模型来学习序列之间的相互作用,所以不会使用PatchTST,因为它应用通道独立性(意味着每个序列被单独处理)。

然后保留了与单变量场景中相同的超参数。只将n_series更改为7,因为有7个时间序列相互作用。

 models = [SOFTS(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE()),
           TSMixer(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE()),
           iTransformer(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE())]

训练所有的模型并进行预测。

 nf = NeuralForecast(models=models, freq='15min')
 
 nf_preds = nf.cross_validation(df=Y_df, val_size=val_size, test_size=test_size, n_windows=None)
 nf_preds = nf_preds.reset_index()

最后使用MAE和MSE来评估每个模型的性能。

 ettm2_evaluation = evaluate(df=nf_preds, metrics=[mae, mse], models=['SOFTS', 'TSMixer', 'iTransformer'])

上图中可以看到到当在96的水平上预测时,TSMixer large在ETTm2数据集上的表现优于iTransformer和soft。

虽然这与soft论文的结果相矛盾,这是因为我们没有进行超参数优化,并且使用了96个时间步长的固定范围。

这个实验的结果可能不太令人印象深刻,我们只在固定预测范围的单个数据集上进行了测试,所以这不是SOFTS性能的稳健基准,同时也说明了SOFTS在使用时可能需要更多的时间来进行超参数的优化。

总结

SOFTS是一个很有前途的基于mlp的多元预测模型,STAD模块是一种集中式方法,用于学习时间序列之间的相互作用,其计算强度低于注意力机制。这使得模型能够有效地处理具有许多并发时间序列的大型数据集。

虽然在我们的实验中,SOFTS的性能可能看起来有点平淡无奇,但请记住,这并不代表其性能的稳健基准,因为我们只在固定视界的单个数据集上进行了测试。

但是SOFTS的思路还是非常好的,比如使用集中式学习时间序列之间的相互作用,并且使用低强度的计算来保证数据计算的效率,这都是值得我们学习的地方。

并且每个问题都需要其独特的解决方案,所以将SOFTS作为特定场景的一个测试选项是一个明智的选择。

SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion

https://avoid.overfit.cn/post/6254097fd18d479ba7fd85efcc49abac

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1824540.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【编程技巧】降低程序复杂度:控制逻辑与业务逻辑分离

为什么要降低代码复杂度 好的项目都是迭代出来的,所以代码肯定是会被人维护的 降低代码复杂度就是为了降低下一个维护人的维护成本,更简单地理解跟修改代码 代码组成 代码逻辑 控制逻辑 业务逻辑 控制逻辑 控制业务逻辑的代码 例如:加缓存…

计算机网络(7) 错误检测

一.校验和 使用补码计算校验和是一种常见的错误检测方法,应用于网络协议如IP和TCP。补码是二进制数的一种表示方法,可以有效地处理符号位和进位。下面是如何利用补码计算校验和的详细步骤和算数例子。 ### 计算步骤 1. **将数据分块**:将数…

七个备受欢迎的IntelliJ IDEA实用插件

有了Lombok插件,IntelliJ就能完全理解Lombok注解,使它们能如预期般工作,防止出现错误,并改善IDE的自动完成功能。 作为IntelliJ IDEA的常用用户,会非常喜欢使用它,但我们必须承认,有时这个IDE&…

Linux---系统的初步学习【 项目二 管理Linux文件和目录】

项目二 管理Linux文件和目录 2.1项目知识准备 ​ 文件是存储在计算机上的数据集合。在Windows系统中,我们理解的文件可以是文本文档、图片、程序、音乐、视频等。在Linux中,一切皆文件,也就是除了Windows中所理解的文件,目录、字…

AI模型部署:Triton Inference Server部署ChatGLM3-6B实践

前言 内容摘要 本篇先将搭建基础Triton设置模块,将ChatGLM3-6B部署为服务跑通,再加入动态批处理和模型预热来提升服务的性能和效率,包括以下几个模块 Docker镜像环境准备模型基础配置config.pbtxt自定义Python后端model.py模型服务加载卸载…

人工智能历史与现状

1 人工智能历史与现状 1.1 人工智能的概念和起源 1.1.1 人工智能的概念 人工智能 (Artificial Intelligence ,AI)是一门研究如何使计算机 能够模拟人类智能行为的科学和技术,目标在于开发能够感知、理解、 学习、推理、决策和解决问题的智能机器。人工智能的概念主要包含 以…

Stable Diffusion本地化部署详细攻略

一、硬件要求 内存:至少16GB 硬盘:至少60GB以上的磁盘空间,推荐SSD固态硬盘 显卡:推荐NVIDIA显卡 显存:至少4GB Stabl Diffusion因为是在本地部署,对显卡的要求比较高,如果经济能力可以的话…

如何打造电力全域知识中心:知识库融合知识图谱

前言 随着人工智能技术的进步,智能化成为产业转型升级的关键抓手,国家电网在“十四五”发展规划中提出加快公司数字化转型进程、推进能源互联网企业建设的要求。知识管理能力建设作为强化企如何打造电力全域知识中心:知识库融合知识图谱业运…

荣耀笔记本IP地址查看方法详解:轻松掌握网络配置技巧

在数字化时代的浪潮中,笔记本电脑已经成为我们生活和工作中不可或缺的重要工具。对于荣耀笔记本用户而言,掌握基本的网络配置技巧显得尤为重要。其中,查看IP地址是连接网络、配置设备、排除故障等场景下的关键步骤。本文将详细介绍荣耀笔记本…

Python 全栈系列252 一些小计划

说明 最近整体进展还比较顺利,不过也因为这样,好几个线头怎么继续平衡和推进需要稍微捋一下。 内容 按重要|紧急方法来看,线头1是重要且紧急的,QTV200也算重要且紧急,其他都算是重要不紧急。 线头1: 数据清洗 虽然…

电子行业实施MES管理系统的时机是什么

随着信息技术的飞速发展,MES生产管理系统逐渐成为电子企业实现自动化生产和信息化管理的必备工具。那么,何时是电子企业实施MES管理系统的最佳时机呢? 1.生产过程中出现了问题,需要优化和改进。 2.企业需要提高产品交付和响应速…

5月产品更新 | 10大更新汇总,快来看看你的需求上线了吗?

5月,Smartbi从客户需求出发,并结合企业在数据分析、处理等方面遇到的问题,对数据模型、数据指标等数十项功能进行了优化升级。 Smartbi用户可以在官网下载下载PC端,更新后便可以使用相关功能,也可以在体验中心体验相关…

第二十三节:带你梳理Vue2:Vue插槽的认识和基本使用

前言: 通过上一节的学习,我们知道了如何将数据从父组件中传递到子组件中, 除了除了将数据作为props传入到组件中,Vue还允许传入HTML, Vue 实现了一套内容分发的 API&#xff0c;这套 API 的设计灵感源自 Web Components 规范草案&#xff0c;将 <slot> 元素作为承载分发…

Rust 实战丨并发构建倒排索引

引言 继上篇 Rust 实战丨倒排索引&#xff0c;本篇我们将参考《Rust 程序设计&#xff08;第二版&#xff09;》中并发编程篇章来实现高并发构建倒排索引。 本篇主要分为以下几个部分&#xff1a; 功能展示&#xff1a;展示我们最终实现的 2 个工具的效果&#xff08;构建索…

linux系统宝塔服务器temp文件夹里总是被上传病毒php脚本

目录 简介 上传过程 修复上传漏洞 tmp文件夹总是被上传病毒文件如下图: 简介 服务器时不时的会发送短信说你服务器有病毒, 找到了这个tmp文件, 删除了之后又有了。 确实是有很多人就这么无聊, 每天都攻击你的服务器。 找了很久的原因, 网上也提供了一大堆方法,…

力扣 面试题17.04.消失的数字

数组nums包含从0到n的所有整数&#xff0c;但其中缺了一个。请编写代码找出那个缺失的整数。你有办法在O(n)时间内完成吗&#xff1f; 示例 1&#xff1a; 输入&#xff1a;[3,0,1] 输出&#xff1a;2 示例 2&#xff1a; 输入&#xff1a;[9,6,4,2,3,5,7,0,1] 输出&#x…

【qt】平面CAD(计算机辅助设计 )项目 上

CAD 一.前言二.界面设计三.提升类四.接受槽函数五.实现图形action1.矩形2.椭圆3.圆形4.三角形5.梯形6.直线7.文本 六.总结 一.前言 用我们上节课刚刚学过的GraphicsView架构来绘制一个可以交互的CAD项目! 效果图: 二.界面设计 添加2个工具栏 需要蔬菜的dd我! 添加action: …

Vue 若依框架常见问题

获取当前用户id或其它信息 user.js import { login, logout, getInfo } from /api/login import { getToken, setToken, removeToken } from /utils/authconst user {state: {token: getToken(),id: ,name: ,avatar: ,roles: [],permissions: [], shop: [] // 店铺列表},mu…

Zig标准库:最全数据结构深度解析(1)

最近新闻看到17岁中专女生拿下阿里全球数学竞赛第12名。咱们学习标准库中的数据结构是和学习数学是一脉相承的&#xff0c;结构体很多&#xff0c;也非常枯燥&#xff0c;但是不能全面解读过一遍&#xff0c;你很难写出合理的代码。所以&#xff0c;这一章节我们开始深度解析Zi…

HTML静态网页成品作业(HTML+CSS)—— 校园贷主题网页(2个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有2个页面。 二、作品演示 三、代…