时间序列预测实战(十二)DLinear模型实现滚动长期预测并可视化预测结果

news2025/1/12 12:23:33

官方论文地址->官方论文地址

官方代码地址->官方代码地址

个人修改代码->个人修改的代码已经上传CSDN免费下载

一、本文介绍

本文给大家带来是DLinear模型,DLinear是一种用于时间序列预测(TSF)的简单架构,DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测(值得一提的是DLinear的出现是为了挑战Transformer在实现序列预测中有效性)本文的讲解内容包括:模型原理、数据集介绍、参数讲解、模型训练和预测、结果可视化、训练个人数据集,讲解顺序如下->

预测类型->单元预测、多元预测

适用对象->如果你的配置不是很好这个模型应该很适合你因为参数量很小训练速度很快

二、模型原理

DLinear模型出现是为了调整Transformer的有效性从而存在,Transformer的设计都十分的复杂和需要大量的参数,所以作者提出了一种简单的结构DLinear(参数量我实验过程中确实非常小)

DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测。

具体地,DLinear如何工作的关键点如下

  1. 时间序列分解:DLinear将输入的时间序列分解为两部分——趋势部分和剩余部分。这种分解有助于分别处理时间序列中的长期趋势和短期波动。

  2. 单层线性网络:对于趋势和剩余序列,DLinear分别使用两个单层的线性网络进行建模。这种简单的架构使得DLinear在处理时间序列时既高效又有效。

  3. 预测任务:在进行预测时,DLinear结合这两个网络的输出来生成最终的时间序列预测。

总结->可以看出DLinear的核心结构真的十分简单就包括一个分解和两个线性网络进行建模最后经过一个简单的相加就输出了结果。

模型的网络结构图如下所示->

图片分析->可以看到和我们上面讲的一样,数据从输入进来经过两个分支,一个为趋势性一个为剩余序列,然后分别经过一个线性层处理(这里的提到的线性层就是普通的全连接层),然后将结果进行简单的拼接就完成了结果的输出(这就这样的简单模型结果比过程十分复杂的Transformer模型效果要好->我自己实验效果确实要好,我拿2020年的bestpaper和普通的Transformer都进行了对比效果确实要有提升)。

下面的图片是一个简单的线性层(普通的全连接层)提取数据的过程图->

这里把模型的代码结构放出来方便大家根据讲解和代码进行对比。

class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean

class Model(nn.Module):
    """
    Decomposition-Linear
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # Decompsition Kernel Size
        kernel_size = 25
        self.decompsition = series_decomp(kernel_size)
        self.individual = configs.individual
        self.channels = configs.enc_in

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            
            for i in range(self.channels):
                self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))
                self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))

                # Use this two lines if you want to visualize the weights
                # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
                # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)
            
            # Use this two lines if you want to visualize the weights
            # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
            # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        seasonal_init, trend_init = self.decompsition(x)
        seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)
        if self.individual:
            seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)
            trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)
            for i in range(self.channels):
                seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])
                trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)

        x = seasonal_output + trend_output
        return x.permute(0,2,1) # to [Batch, Output length, Channel]

我看论文的内容大比分都是对比实验,因为DLinear的产生就是为了质疑Transformer所以他和各种Transformer的模型进行对比试验,因为本篇文章就是DLinear的实战案例,对比的部分我就不讲了,大家有兴趣可以看看论文内容在最上面我已经提供了链接。 

三、数据集介绍

所用到的数据集为某公司的业务水平评估和其它参数具体的内容我就介绍了估计大家都是想用自己的数据进行训练模型,这里展示部分图片给大家提供参考。

四、参数讲解

模型的参数如下(大部分都是一些公共参数并不涉及模型)->

parser = argparse.ArgumentParser(description='DLinearNet Multivariate Time Series Forecasting')
    # basic config
    parser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')
    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
    parser.add_argument('--show_results', type=bool, default=True, help='Whether show forecast and real results graph')
    parser.add_argument('--model', type=str, default='SCINet',help='Model name')

    # data loader
    parser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='MS',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=126, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=64, help='start token length')
    parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')

    # model
    parser.add_argument('--individual', action='store_true', default=False,
                        help='DLinear: a linear layer for each variate(channel) individually')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=1, help='output size')
    parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')

    # optimization
    parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
    parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')

    # GPU
    parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
    parser.add_argument('--device', type=int, default=0, help='gpu')

模型的详细参数讲解如下(如果你想训练你自己的数据集可以仔细看看)->

参数名称参数类型参数讲解
0trainbool是否进行训练,如果你单纯只想进行预测设置为False即可,
1rollingforecastbool是否进行滚动预测,如果是则设置为True,如果不进行滚动预测则进行正常的预测
2rolling-data-pathstr如果进行滚动预测则需要添加新的和训练文件相同格式的数据
3show_resultsbool是否保存预测值和真实值的对比
4modelstr定义的模型名称
5root_pathstr这个才是你文件的路径,不要到具体的文件,到目录级别即可。
6data_pathstr这个填写你文件的具体名称。
7featuresstr这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。
8targetstr这个是你数据集中你想要预测那一列数据,假设我预测的是油温OT列就输入OT即可。
9freqstr时间的间隔,你数据集每一条数据之间的时间间隔。
10checkpointsstr训练出来的模型保存路径
11seq_lenint用过去的多少条数据来预测未来的数据
12label_lenint可以理解为更高的权重占比的部分要小于seq_len
13pred_lenint预测未来多少个时间点的数据
14enc_inint你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
15dec_inint同上
16individualbool这个就是我们上面提到的两个线性层,如果为True我们则对每一个通道用单独的线性层处理,False则为所有的通道用一个线性层
17c_outint这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
18dropoutfloat这个应该都理解不说了,丢弃的概率,防止过拟合的。
19embedstr时间特征的编码方式,默认为"timeF"
20activationstr激活函数
21num_workersint线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。
22train_epochsint训练的次数
23batch_sizeint一次往模型力输入多少条数据
24learning_ratefloat学习率。
25lossstr     损失函数,默认为"mse"
26lradjstr     学习率的调整方式,默认为"type1"
27use_gpubool是否使用GPU训练,根据自身来选择
28gpuintGPU的编号

五、模型训练和预测

1.项目目录结构

项目的目录构造如下->

其中data为训练用的数据放的地方,layers为模型结构存放的地方,models为训练保存的训练模型,results为可视化结果保存的图片和滚动预测的结果,util为一些工具。 

2.模型训练

当我们经过上面的参数讲解之后,我们可以开始训练模型了,控制台输出如下->

3.滚动预测 

这里进行滚动预测的控制台输出->

4.结果展示 

运行结果后,结果保存到同级目录下(下图为预测值和真实值的对比)-> 

5.结果分析 

可以看到预测值和真实值之间的差距还可以,但是这个模型的参数量少得可怜,不得不得质疑Transformer模型的有效性~

六、训练你个人数据集

这个模型我在写的过程中为了节省大家训练自己数据集,我基本上把大部分的参数都写好了,需要大家注意的就是如果要进行滚动预测下面的参数要设置为True。

    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')

如果上面的参数设置为True那么下面就要提供一个进行滚动预测的数据集该数据集的格式要和你训练模型的数据集格式完全一致(重要!!!),如果没有可以考虑在自己数据的尾部剪切一部分,不要粘贴否则数据模型已经训练过了的话预测就没有效果了。 

    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')

其它的没什么可以讲的了大部分的修改操作在参数讲解的部分我都详细讲过了,这里的滚动预测可能是大家想看的所以摘出来详细讲讲。 

总结

到此本文已经全部讲解完成了,希望能够帮助到大家,在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98,并且免费阅读。

概念理解 

15种时间序列预测方法总结(包含多种方法代码实现)

数据分析

时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法

机器学习——难度等级(⭐⭐)

时间序列预测实战(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)

深度学习——难度等级(⭐⭐⭐⭐)

时间序列预测实战(五)基于Bi-LSTM横向搭配LSTM进行回归问题解决

时间序列预测实战(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

时间序列预测实战(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

时间序列预测实战(十一)用SCINet实现滚动预测功能(附代码+数据集+原理介绍)

Transformer——难度等级(⭐⭐⭐⭐)

时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解

时间序列预测模型实战案例(一)深度学习华为MTS-Mixers模型

个人创新模型——难度等级(⭐⭐⭐⭐⭐)

时间序列预测实战(十)(CNN-GRU-LSTM)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

传统的时间序列预测模型(⭐⭐)

时间序列预测实战(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)

时间序列预测实战(六)深入理解ARIMA包括差分和相关性分析

融合模型——难度等级(⭐⭐⭐)

时间序列预测实战(九)PyTorch实现融合移动平均和LSTM-ARIMA进行长期预测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1198441.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Leetcode刷题详解—— 目标和

1. 题目链接:494. 目标和 2. 题目描述: 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - ,然后串联起所有整数,可以构造一个 表达式 : 例如,nums [2, 1] ,可…

【计算机网络笔记】IP分片

系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…

有没有实时检测微信聊天图片的软件,只要微信收到了有二维码的图片就把它提取出来?

10-2 如果你有需要自动并且快速地把微信收到的二维码图片保存到指定文件夹的需求,那本文章非常适合你,本文章教你如何实现自动保存微信收到的二维码图片到你指定的文件夹中,助你快速扫码,比别人领先一步。 首先需要准备好的材料…

19 异步通知

一、异步通知 1. 异步通知简介 阻塞和非阻塞两种方式都是需要应用程序去主动查询设备的使用情况。 异步通知类似于驱动可以主动报告自己可以访问,应用程序获取信号后会从驱动设备中读取或写入数据。 异步通知最核心的就是信号: #define SIGHUP 1 /* 终…

openssl研发之base64编解码实例

一、base64编码介绍 Base64编码是一种将二进制数据转换成ASCII字符的编码方式。它主要用于在文本协议中传输二进制数据,例如电子邮件的附件、XML文档、JSON数据等。 Base64编码的特点如下: 字符集: Base64编码使用64个字符来表示二进制数据…

UPLAOD-LABS2

less7 任务 拿到一个shell服务器 提示 禁止上传所有可以解析的后缀 发现所有可以解析的后缀都被禁了 查看一下源代码 $is_upload false; $msg null; if (isset($_POST[submit])) {if (file_exists($UPLOAD_ADDR)) {$deny_ext array(".php",".php5&quo…

password game

目录 password game (1-2) (3) (4) (5) (6) (7) (8) (9) (10&am…

CCF ChinaSoft 2023 论坛巡礼|机器人大模型与具身智能挑战赛

2023年CCF中国软件大会(CCF ChinaSoft 2023)由CCF主办,CCF系统软件专委会、形式化方法专委会、软件工程专委会以及复旦大学联合承办,将于2023年12月1-3日在上海国际会议中心举行。 本次大会主题是“智能化软件创新推动数字经济与社…

第一百六十九回 如何修改NavigationBar的形状

文章目录 1. 概念介绍2. 使用方法3. 代码与效果3.1 示例代码3.2 运行效果 4. 内容总结 我们在上一章回中介绍了"如何修改按钮的形状"相关的内容,本章回中将介绍NavigationBar组件.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 我们在本章…

【java进阶】集合的三种遍历(迭代器、增强for、Lambda)

目录 一、先谈集合: 二、单列集合的三种遍历方式 迭代器遍历 增强for遍历 Lambda表达式遍历 一、先谈集合: 🔥那我们平常用for循环依赖下标遍历不行嘛,这就与集合的分类有关了。 集合的体系结构: collection是单…

我的前端笔记JS

js介绍 js是编程语音,之前学的html和css是标记语言 百度搜索mdn官网就可以 语法 输出、对话框、控制台日志、输入对话框 字面量 简单理解就是看到的内容是属于什么类型,例如1232,这个是属于数字字面量

【C++】this指针讲解超详细!!!

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃个人主页 :阿然成长日记 …

如何用 GPTs 构建自己的知识分身?(进阶篇)

(注:本文为小报童精选文章,已订阅小报童或加入知识星球「玉树芝兰」用户请勿重复付费) 有了这些改进,你可以快速判断 GPT 助手给出的答案是真实还是「幻觉」了。 问题 在《如何用自然语言 5 分钟构建个人知识库应用&am…

【Git】Git使用Gui图形化界面,Git中SSH协议,Idea集成Git

一,Git使用Gui图形化界面 1.1 Gui的简介 Gui (Graphical User Interface)指的是图形用户界面,也就是指使用图形化方式来协同人和计算机进行交互的一类程序。它与传统的命令行界面相比,更加直观、易用,用户…

MATLAB的编程与应用,匿名函数、嵌套函数、蒙特卡洛法的掌握与使用

目录 1.匿名函数 1.1.匿名函数的定义与分类 1.2.匿名函数在积分和优化中应用 2.嵌套函数 2.1.嵌套函数的定义与分类 2.2.嵌套函数彼此调用关系 2.3.嵌套函数在积分和微分中应用 3.微分和积分 4.蒙特卡洛法 4.1.圆周率的模拟 4.2.计算N重积分(均匀分布&am…

Python与ArcGIS系列(一)ArcGIS中使用Python

目录 0 简述1 arcgis中的python窗口2 开始编写代码 0 简述 按照惯例,作为本系列专栏的第一篇,先简单地介绍下本系列文章的内容:通过python语言创建arcgis环境脚本、将脚本以工具箱形式存放在arcgis中、通过脚本自动执行地理处理、数据修复、…

Benchmarking Large Language Models in Retrieval-Augmented Generation-学习翻译

提检索增强生成中大型语言模型的基准测试文献学习 作者将在https://github.com/chen700564/RGB上发布本文的代码和RGB。 y ˇ \check{y} yˇ​ 文章目录 摘要IntroductionRelated workRetrieval-Augmented Generation BenchmarkRAG所需能力数据构建评估指标 ExperimentsSetting…

kubeadm部署k8s及高可用

目录 CNI 网络组件 1、flannel的功能 2、flannel的三种模式 3、flannel的UDP模式工作原理 4、flannel的VXLAN模式工作原理 5、Calico主要组成部分 6、calico的IPIP模式工作原理 7、calico的BGP模式工作原理 8、flannel 和 calico 的区别 Kubeadm部署k8s及高可用 1、…

Vue3 源码解读系列(四)——组件更新

组件更新 组件更新流程: 从头部开始同步 从尾部开始同步 挂载剩余的新节点 删除多余的旧节点 处理未知的子序列 当两个节点类型相同时,执行更新操作当新子节点中没有旧子节点中的某些节点时,执行删除操作当新子节点中多了旧子节点中没有…

小样本目标检测(Few-Shot Object Detection)综述

背景 前言:我的未来研究方向就是这个,所以会更新一系列的文章,就关于FSOD,如果有相同研究方向的同学欢迎沟通交流,我目前研一,希望能在研一发文,目前也有一些想法,但是具体能不能实现还要在做的过程中慢慢评估和实现.写文的主要目的还是记录,避免重复劳动,我想用尽量简洁的语言…