时间序列预测实战(十一)用SCINet实现滚动预测功能(附代码+数据集+原理介绍)

news2024/12/31 5:11:08

论文地址->SCINet官方论文地址

官方代码地址-> 官方代码下载地址

个人整理的代码地址->免费分享给大家创作不易请大家给文章点点赞

一、本文介绍

这篇文章给大家带来的是关于SCINet实现时间序列滚动预测功能的讲解,SCINet是样本卷积交换网络的缩写(Sample Convolutional Interchange Network),SCINet号称是比现有的卷积模型和基于Transformer的模型准确率都有提升(我实验了几次效果确实不错)。本篇文章讲解的代码是我个人根据官方的代码总结出来的模型结构并且进行改进增加了滚动预测的功能。本篇实战案例中包括->详细的参数讲解、改进方向、数据集介绍、模型框架原理、项目结构、如何训练个人数据集的教程、以及结果分析和结果展示。本篇文章的讲解流程为->

适用对象->适合对精度有比较高要求的学习者

预测类型->单元变量预测、多元变量预测

二、模型框架原理

1.SCINet基本原理

SCINet是一个层次化的降采样-卷积-交互TSF框架,有效地对具有复杂时间动态的时间序列进行建模。通过在多个时间分辨率上迭代提取和交换信息,可以学习到具有增强可预测性的有效表示。此外,SCINet的基础构件,SCI-Block,通过将输入数据/特征降采样为两个子序列,然后使用不同的卷积滤波器提取每个子序列的特征。为了补偿降采样过程中的信息损失,每个SCI-Block内部都加入了两种卷积特征之间的交互学习。

个人总结:SCINet就是在不同的维度上面对数据进行处理进行特征提取工作,从而获得不同层次的特征。这点有点类似于目标检测的YOLO系列模型,一张图片进行不断的缩放和扩大获取不同层次的特征,然后对这些特征进行操作,既节省算力又提高精度。(SCINet引入了一个新的概念时间分辨率大家可以注意一下)

2.SCINet基本组件

 下图为SCINet网络结构图

SCINet采用编码器-解码器架构。编码器是一个分层卷积网络,通过丰富的卷积滤波器捕捉多分辨率下的动态时间依赖性。其基本构件SCI-Block将输入数据或特征降采样为两个子序列,然后用一组卷积滤波器处理每个子序列,从每部分中提取独特但有价值的时间特征。为了补偿降采样中的信息损失,它允许两个子序列之间的交互学习。SCINet通过将多个SCI-Blocks排列成二叉树结构来构建。这种设计的一个显著优势是每个SCI-Block都对整个时间序列有局部和全局视角,从而有助于提取有用的时间特征。经过所有降采样-卷积-交互操作后,将提取的特征重新排列成新的序列表示,并将其加入原始时间序列中,用全连接网络作为解码器进行预测。

个人总结:SCINet就是将多个SCI-Block用二叉树的结构堆叠起来,然后提取不从层次的特征,然后从新排列起来,然后经过一个全连接层进行预测。

改进方案:这里其实有改进的空间,经过我的训练过程我发现这个模型训练时间还是比较长的,就是因为他堆叠多个层的SCI-Block这里我觉得可以结和一些新的结构进行改造的,类似于不进行二叉树的操作,但是将不同层次的特征融合起来,有兴趣的小伙伴可以研究一下,没准能发个论文毕竟文章的简写就是SCI~~。

三、数据集介绍 

这个模型用了两个数据集进行测试,一个是某个公司的话务员接线量一个是油温效果都不错,下面讲解用油温的数据集来进行讲解和结果分析。

数据集的部分截图如下->其具有八列数据‘OT’其中间的关系为化学关系比较固定,为油温度。

四、参数讲解 

模型的全部参数如下->

    parser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')
    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
    parser.add_argument('--model', type=str, default='Transformer',
                        help='model name, options: [Transformer, Linear, NLinear, DLinear, SCINet, ConvFC, MTSMixer, MTSMatrix, FNet]')

    # data loader
    parser.add_argument('--root_path', type=str, default='./', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='MS',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=32, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=8, help='start token length')
    parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')

    # model
    parser.add_argument('--rev', action='store_true', default=False, help='whether to apply RevIN')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=1, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')

    parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')

    # optimization
    parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
    parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')

    # GPU
    parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
    parser.add_argument('--device', type=int, default=0, help='gpu')

模型的详细参数讲解如下-> 

参数名称参数类型参数讲解
0trainbool是否进行训练,如果你单纯只想进行预测设置为False即可,
1rollingforecastbool是否进行滚动预测,如果是则设置为True,如果不进行滚动预测则进行正常的预测
2rolling-data-pathstr如果进行滚动预测则需要添加新的和训练文件相同格式的数据
3modelstr定义的模型名称
4root_pathstr这个才是你文件的路径,不要到具体的文件,到目录级别即可。
5data_pathstr这个填写你文件的具体名称。
6featuresstr这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。
7targetstr这个是你数据集中你想要预测那一列数据,假设我预测的是油温OT列就输入OT即可。
8freqstr时间的间隔,你数据集每一条数据之间的时间间隔。
9checkpointsstr训练出来的模型保存路径
10seq_lenint用过去的多少条数据来预测未来的数据
11label_lenint可以理解为更高的权重占比的部分要小于seq_len
12pred_lenint预测未来多少个时间点的数据
13enc_inint你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
14dec_inint同上
15c_outint这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
16d_modelint用于设置模型的维度,默认值为512。可以根据需要调整该参数的数值来改变模型的维度
15n_headsint用于设置模型中的注意力头数。默认值为8,表示模型会使用8个注意力头,我建议和的输入数据的总体保持一致,列如我输入的是8列数据不用刨去时间的那一列就输入8即可。
17e_layersint用于设置编码器的层数
18d_layersint用于设置解码器的层数
19s_layersstr用于设置堆叠编码器的层数
20dropoutfloat这个应该都理解不说了,丢弃的概率,防止过拟合的。
21embedstr时间特征的编码方式,默认为"timeF"
22activationstr激活函数
23num_workersint线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。
24train_epochsint训练的次数
25batch_sizeint一次往模型力输入多少条数据
26learning_ratefloat学习率。
27lossstr     损失函数,默认为"mse"
28lradjstr     学习率的调整方式,默认为"type1"
29use_gpubool是否使用GPU训练,根据自身来选择
30gpuintGPU的编号

五、项目结构 

项目的构造目录如下->

其中data用于方法训练数据,layers用于存放模型,models用于存放训练的保存结果,results用于存放模型的预测结果为CSV的输出格式文件,util用于存放一些工具。 

六、训练和预测

1.训练模型

经过参数的讲解我们已经定义好了所有的参数,可以开始训练了,我的完整main.py文件调好参数的内容如下->

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SCINet Multivariate Time Series Forecasting')
    # basic config
    parser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')
    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
    parser.add_argument('--model', type=str, default='SCINet',help='Model name')

    # data loader
    parser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='MS',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=32, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=8, help='start token length')
    parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')

    # model
    parser.add_argument('--rev', action='store_true', default=False, help='whether to apply RevIN')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=1, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')

    parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')

    # optimization
    parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
    parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')

    # GPU
    parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
    parser.add_argument('--device', type=int, default=0, help='gpu')
    args = parser.parse_args()
    Exp = SCINetinitialization
    # setting record of experiments
    setting = 'predict-{}-data-{}'.format(args.model, args.data_path[:-4])

    SCI = SCINetinitialization(args)  # 实例化模型
    if args.train:
        print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(args.model))
        SCI.train(setting)
    print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(args.model))
    SCI.predict(setting, True)

我们进行执行控制台开始输出训练结果->

 训练完成后,模型保存到该目录下->

2.开始预测 

训练完成之后,开始进行预测,控制台会进行结果的输出,最后会生成结果文件。

结果文件输出在如下目录->

3.1结果展示 

下面的图片是预测值和真实值的对比图,大家可以看出预测结果还是非常不错的。 

下面的图片为MAE的损失图->

3.2结果分析 

可以看出虽然预测结果还可以接受,但是其中存在明显的数据滞后性,这个问题其实是时间序列预测的通病,目前想要解决两种方法:

  • 一种方法是通过损失精度然后进行数据的预处理操作
  • 另一种就是结合其它能够处理数据滞后性的模型进行模型融合的操作。

后期我也会进行模型融合尝试如果大家需要可以在评论区留言想要看和其它什么模型结合。

七、训练你个人数据集 

这个模型我在写的过程中为了节省大家训练自己数据集,我基本上把大部分的参数都写好了,需要大家注意的就是如果要进行滚动预测下面的参数要设置为True。

    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')

如果上面的参数设置为True那么下面就要提供一个进行滚动预测的数据集该数据集的格式要和你训练模型的数据集格式完全一致(重要!!!),如果没有可以考虑在自己数据的尾部剪切一部分,不要粘贴否则数据模型已经训练过了的话预测就没有效果了。 

    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')

其它的没什么可以讲的了大部分的修改操作在参数讲解的部分我都详细讲过了,这里的滚动预测可能是大家想看的所以摘出来详细讲讲。 

总结  

到此本文已经全部讲解完成了,希望能够帮助到大家,如果你用我的代码可能会存在一些Bug但是肯定不影响运行,如果大家有发现任何的bug可以和我私信沟通,或者评论区留言我可以,大家可以进行探讨。在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98,并且免费阅读。

时间序列预测:深度学习、机器学习、融合模型、创新模型实战案例(附代码+数据集+原理介绍)

时间序列预测模型实战案例(十)(个人创新模型)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法

时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解

时间序列预测模型实战案例(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

时间序列预测模型实战案例(六)深入理解机器学习ARIMA包括差分和相关性分析

时间序列预测模型实战案例(五)基于双向LSTM横向搭配单向LSTM进行回归问题解决

时间序列预测模型实战案例(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)

时间序列预测模型实战案例(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

【全网首发】(MTS-Mixers)(Python)(Pytorch)最新由华为发布的时间序列预测模型实战案例(一)(包括代码讲解)实现企业级预测精度包括官方代码BUG修复Transform模型

时间序列预测模型实战案例(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)

最后希望大家工作顺利学业有成!

​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1194770.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# .NET Core API 注入Swagger

C# .NET Core API 注入Swagger 环境 Windows 10Visual Studio 2019(2017就有可以集中发布到publish目录的功能了吧)C#.NET Core 可跨平台发布代码,超级奈斯NuGet 套件管理dll将方法封装(据说可以提高效率,就像是我们用的dll那种感觉)Swagger 让接口可视化编写时间2020-12-09 …

【Python爬虫】网页抓取实例之淘宝商品信息抓取

之前我们已经说过网页抓取的相关内容 上次我们是以亚马逊某网页的产品为例 抓取价格、品牌、型号、样式等 该网页上价格、品牌、型号、样式等 都只有一个 如果网页上的目标内容 根据不同规格有多个 又该怎么提取呢&#xff1f; ▼如下图所示 当机身颜色、套餐、存储容量…

【MATLAB源码-第73期】基于matlab的OFDM-IM索引调制系统不同子载波数目误码率对比,对比OFDM系统。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 OFDM-IM索引调制技术是一种新型的无线通信技术&#xff0c;它将正交频分复用&#xff08;OFDM&#xff09;和索引调制&#xff08;IM&#xff09;相结合&#xff0c;以提高频谱效率和系统容量。OFDM-IM索引调制技术的基本思想…

ARM IMX6ULL 基础学习记录 / ARM 寄存器介绍

编辑整理 by Staok。 本文大部分内容摘自“100ask imx6ull”开发板的配套资料&#xff08;如《IMX6ULL裸机开发完全手册》等等&#xff09;&#xff0c;侵删。进行了精髓提取&#xff0c;方便日后查阅。过于基础的内容不会在此提及。如有错误恭谢指出&#xff01; 注&#xf…

Django ModelSerializer 实现自定义验证详解

随着 Web 开发的日益复杂化&#xff0c;对数据验证的需求也日益增加。Django REST framework 提供了一套强大的、灵活的验证系统&#xff0c;帮助开发者轻松处理各种复杂情况。本文将重点探讨 Django ModelSerializer 中如何实现自定义验证。 1. 简介 Django ModelSerializer…

openinstall携手途虎养车,赋能汽车服务数字化

近日&#xff0c;openinstall与中国领先的一站式汽车服务平台途虎养车再次续约&#xff0c;双方将开启第三年合作。过去两年&#xff0c;途虎在建设线上线下一体化数字平台的过程中&#xff0c;深度结合openinstall传参归因与渠道统计技术&#xff0c;打造出了一套高效的渠道来…

第12章 PyTorch图像分割代码框架-3:推理与部署

推理模块 模型训练完成后&#xff0c;需要单独再写一个推理模块来供用户测试或者使用&#xff0c;该模块可以命名为test.py或者inference.py&#xff0c;导入训练好的模型文件和待测试的图像&#xff0c;输出该图像的分割结果。inference.py主体部分如代码11-7所示。 代码11-7 …

【MATLAB源码-第74期】基于matlab的OFDM-IM索引调制系统不同频偏误码率对比,对比OFDM系统。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 OFDM-IM索引调制技术是一种新型的无线通信技术&#xff0c;它将正交频分复用&#xff08;OFDM&#xff09;和索引调制&#xff08;IM&#xff09;相结合&#xff0c;以提高频谱效率和系统容量。OFDM-IM索引调制技术的基本思想…

Spring -Spring之依赖注入源码解析(下)--实践(流程图)

IOC依赖注入流程图 注入的顺序及优先级&#xff1a;type-->Qualifier-->Primary-->PriOriry-->name

如何使用HadSky搭配内网穿透工具打造个人站点并公网访问

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;Cpolar杂谈、数据结构、算法模板 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 前言一. 网站搭建1.1 网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 二. 本地网页发…

[工业自动化-10]:西门子S7-15xxx编程 - PLC主站 - 信号量:数字量

目录 前言&#xff1a; 一、工业现场常见信号的分类 二、IO数字量模块 2.1 概述 2.2 PLC的数字量是24V还是5V电压&#xff1f; 2.2 数字量模块的安装与接线 2.3 数字量模的注意事项 前言&#xff1a; 一、工业现场常见信号的分类 在工业自动化领域&#xff0c;常常需要使…

3DMAX汽车绑定动画模拟插件MadCar疯狂汽车使用教程

3DMAX汽车绑定动画模拟插件MadCar疯狂的汽车&#xff0c;用于通过模拟控制来快速装配轮式车辆及其动画。这个新版本允许装配任何数量的车轮的车辆&#xff0c;以及包括摩托车在内的任何相互布置。还支持任意数量的拖车。 每个车轮和悬架都有简化的行为设置以及微调&#xff0c…

xss 通过秘籍

终极测试代码 <sCr<ScRiPt>IPT>OonN"\/(hrHRefEF)</sCr</ScRiPt>IPT> 第一关&#xff08;没有任何过滤&#xff09; 使用终极测试代码&#xff0c;查看源码 发现没有任何过滤&#xff0c;直接使用javascrupt中的alert弹框 <script>aler…

企业级操作之STM32项目版本管理方法

在MCU开发过程中&#xff0c;有时候需要软件的迭代&#xff0c;比如从V1.9升级到V1.10&#xff0c;或者从V23.09.23升级到V23.09.24&#xff0c;我们常常通过手动改动字符串或者数组来实现这个功能&#xff0c;从现在开始&#xff0c;我们会使用Keil的内置宏__DATE__和__TIME__…

局域网内部服务器访问外部网络

​ 一、环境说明 如下图所示&#xff0c;局域网1中的服务器是可以访问外网的&#xff0c;局域网2中的服务器发出的数据包经过中间路由可以到达局域网1中的服务器。现在有一种需求需要使局域网2中的服务器也要能访问外网&#xff0c;这里考虑采用如下方法来实现。 ​​ 二、软…

基于element-plus定义表单配置化

文章目录 前言一、配置化的前提二、配置的相关组件1、新建form.vue组件2、新建input.vue组件3、新建select.vue组件4、新建v-html.vue组件5、新建upload.vue组件6、新建switch.vue组件7、新建radio.vue组件8、新建checkbox.vue组件9、新建date.vue组件10、新建time-picker.vue组…

Pytorch实战教程(一)-神经网络与模型训练

0. 前言 人工神经网络 (Artificial Neural Network, ANN) 是一种监督学习算法,其灵感来自人类大脑的运作方式。类似于人脑中神经元连接和激活的方式,神经网络接受输入,通过某些函数在网络中进行传递,导致某些后续神经元被激活,从而产生输出。函数越复杂,网络对于输入的数…

传统企业数字化转型都要面临哪些挑战?_数据治理平台_光点科技

数字化转型已经成为传统企业发展的必经之路&#xff0c;但在这个过程中&#xff0c;企业往往会遭遇多方面的挑战。 1.文化和组织惯性 最大的挑战之一是企业文化和组织惯性的阻力。传统企业往往有着深厚的历史和根深蒂固的工作方式&#xff0c;员工和管理层可能对新的数字化工作…

FFMPEG库实现mp4/flv文件(H264+AAC)的封装与分离

ffmepeg 4.4&#xff08;亲测可用&#xff09; 一、使用FFMPEG库封装264视频和acc音频数据到 mp4/flv 文件中 封装流程 1.使用avformat_open_input分别打开视频和音频文件&#xff0c;初始化其AVFormatContext&#xff0c;使用avformat_find_stream_info获取编码器基本信息 2.使…

一文入门Springboot+actuator+Prometheus+Grafana

环境介绍 技术栈 springbootmybatis-plusmysqloracleactuatorPrometheusGrafana 软件 版本 mysql 8 IDEA IntelliJ IDEA 2022.2.1 JDK 1.8 Spring Boot 2.7.13 mybatis-plus 3.5.3.2 本地主机应用 192.168.1.9:8007 PrometheusGrafana安装在同一台主机 http://…