时间序列模型SCINet模型(自定义项目)

news2024/10/6 2:30:01

前言

  • 读完代码解析篇,我们针对开源项目中的模型预测方法做一下介绍。Github源码下载地址
  • 下载数据集ETThPEMSTrafficSplar-EnergyElectricityExchange-Rate,这几类公共数据集的任意一类就行。这里以ETTh数据集为例,先在项目文件夹下新建datasets文件夹,然后将数据集移至其中
  • 打开项目文件夹下run_ETTh.py文件,只需要检查一下数据路径、名称和csv文件就行
# 数据名称
parser.add_argument('--data', type=str, required=False, default='ETTh1', choices=['ETTh1', 'ETTh2', 'ETTm1'], help='name of dataset')
# 数据路径
parser.add_argument('--root_path', type=str, default='./datasets/', help='root path of the data file')
# 数据文件
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='location of the data file')
  • 然后跑一下,看看跑不跑的通,注意一定要是在GPU环境下,否则报错,后面的自定义项目是建立在原代码能跑通的情况下。

自定义项目

参数设定修改

  • 首先将需要预测的数据集放入datasets文件夹中,时间列列名必须为date。
  • 然后我们复制run_ETTh.py文件,并粘贴在项目文件夹下,重命名为run_power.py这个名字随便取,别和已有文件重复就行。
  • 打开run_power.py文件,修改开头库导入部分,主要是最后一句,要导入Exp_power
import argparse
import os
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from experiments.exp_power import Exp_power
  • 修改数据载入部分,包括数据名称、路径、文件、目标预测列、采样间隔(我用的数据集是每1分钟收集一次数据,所以参数设为t
# 数据名称
parser.add_argument('--data', type=str, required=False, default='power data', choices=['ETTh1', 'ETTh2', 'ETTm1'], help='name of dataset')
# 数据路径
parser.add_argument('--root_path', type=str, default='./datasets/', help='root path of the data file')
# 文件名
parser.add_argument('--data_path', type=str, default='power data.csv', help='location of the data file')
# 多变量预测
parser.add_argument('--features', type=str, default='M', choices=['S', 'M'], help='features S is univariate, M is multivariate')
# 目标列
parser.add_argument('--target', type=str, default='总有功功率(kw)', help='target feature')
# 采样间隔(分钟)
parser.add_argument('--freq', type=str, default='t', help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
# 模型保存路径
parser.add_argument('--checkpoints', type=str, default='exp/ETT_checkpoints/', help='location of model checkpoints')
# 是否翻转时间序列
parser.add_argument('--inverse', type=bool, default =False, help='denorm the output data')
# 选择时间编码方式
parser.add_argument('--embed', type=str, default='timeF', help='time features encoding, options:[timeF, fixed, learned]')
  • 修改项目预测需求以及回视窗口等参数
parser.add_argument('--seq_len', type=int, default = 480, help='input sequence length of SCINet encoder, look back window')
parser.add_argument('--label_len', type=int, default = 288, help='start token length of Informer decoder')
parser.add_argument('--pred_len', type=int, default = 960, help='prediction sequence length, horizon')
  • 再修改特征数量设置,data_parser变量,在run_power.py文件中只需要更改这些。
data_parser = {'power data': {'data': 'power data.csv', 'T': '总有功功率(kw)', 'M': [5, 5, 5], 'S': [1, 1, 1], 'MS': [5, 5, 1]},}
  • 注意,如果需要模型输出中间结果,即预测值、真实值,测试值等,请将--save参数置为True
parser.add_argument('--save', type=bool, default = True, help='save the output results')
  • 同样的,打开experiments文件夹,复制exp_ETTh.py文件,并粘贴在同目录中,重命名为exp_power.py,并将其中Exp_ETTh类修改为Exp_power
class Exp_power(Exp_Basic):

数据处理

  • 打开experiments文件夹下exp_power.py文件,修改Exp_power类下_build_model函数,in_dim函数修改为数据特征数,我这里是5,所以in_dim = 5
def _build_model(self):

        if self.args.features == 'S':
            in_dim = 1
        elif self.args.features == 'M':
            # 自定义项目需要修改
            in_dim = 5
        else:
            print('Error!')
  • 再跳转到_get_data函数,修改data_dict
data_dict = {'power data': Dataset_Custom}
  • exp_power.py文件中只需要更改这些。到此为止,项目修改工作结束,这时跑一下run_power.py函数看看能否跑的通。

在kaggle上使用

  • 因为该源码只支持在GPU上运行,若使用的设备没有GPU,我们可以将项目文件搬到kaggle上进行,首先还是要根据上述说明修改好项目文件,然后打包成zip文件上传至kaggle数据集中。
  • 新建notebook文件,并将其设置为P100GPU模式下
    在这里插入图片描述

导入包

  • 加入环境变量
import sys
if not '/kaggle/input/scinet-model-data' in sys.path:
    sys.path += ['/kaggle/input/scinet-model-data']
  • 导入必要包
import argparse
import os
import torch
import numpy as np
import optuna
from torch.utils.tensorboard import SummaryWriter
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from experiments.exp_power import Exp_power

参数传导

args = argparse.ArgumentParser(description='SCINet on ETT dataset')

args.model = 'SCINet'

args.data = 'power data'
args.root_path = '/kaggle/input/scinet-model-data/datasets/'
args.data_path = 'power data.csv'
args.features ='M'
args.target = '总有功功率(kw)'
args.freq = 't'
args.checkpoints = 'exp/power_checkpoints/'
args.inverse = False
args.embed ='timeF'


### -------  device settings --------------
args.use_gpu = True
args.gpu = 0
args.use_multi_gpu = False
args.devices = '0'

### -------  input/output length settings --------------                                                                            
args.seq_len = 480
args.label_len = 288
args.pred_len = 960
args.concat_len = 0
args.single_step = 0
args.single_step_output_One = 0
args.lastWeight = 1.0

### -------  training settings --------------  
args.cols = False
args.num_workers = 0
args.itr = 0
args.train_epochs = 100
args.batch_size = 128
args.patience = 5
args.lr = 1e-4
args.loss = 'SmoothL1Loss'
args.optim = 'AdamW'
args.lradj = 1
args.use_amp = False
# 是否保存结果文件
args.save = True
args.model_name = 'SCINet'
args.resume = False
args.evaluate = False

### -------  model settings --------------  
args.hidden_size = 1.995
args.INN = 1
args.kernel = 7
args.dilation = 1
args.window_size = 480
args.dropout = 0.5
args.positionalEcoding = False
args.groups = 1
args.levels = 3
args.stacks = 2
args.num_decoder_layer = 1
args.RIN = False
args.decompose = False

检查GPU

args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False

if args.use_gpu and args.use_multi_gpu:
    args.devices = args.devices.replace(' ', '')
    device_ids = args.devices.split(',')
    args.device_ids = [int(id_) for id_ in device_ids]
    args.gpu = args.device_ids[0]

定义数据加载

data_parser = {'power data': {'data': 'power data.csv', 'T': '总有功功率(kw)', 'M': [5, 5, 5], 'S': [1, 1, 1], 'MS': [5, 5, 1]},}

if args.data in data_parser.keys():
    data_info = data_parser[args.data]
    args.data_path = data_info['data']
    args.target = data_info['T']
    args.enc_in, args.dec_in, args.c_out = data_info[args.features]

args.detail_freq = args.freq
args.freq = args.freq[-1:]

复现设置

torch.manual_seed(2023)  # reproducible
torch.cuda.manual_seed_all(2023)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True  # Can change it to False --> default: False
torch.backends.cudnn.enabled = True

训练模型

Exp = Exp_power

mae_ = []
maes_ = []
mse_ = []
mses_ = []

setting = '{}_levels {}_kernel {}_hidden {}'.format(args.model,args.levels,args.kernel,args.hidden_size)
exp = Exp(args)  # set experiments
exp.train(setting)
mae, maes, mse, mses = exp.test(setting)
print('{:s}:{:.4f},mae:{:.4f}'.format(setting, mse, mae))

模型参数调节

  • 根据论文,较为重要的几个参数分别为--kernel--levels--stacks--hidden_size其次是--lr--dropout,我的建议是先调节联合调节--kernel--levels--stacks参数,然后再调节--lr--dropout
  • 参数调节范围建议:
    • --kernel整型[1,7]
    • --levels整型[1,5]
    • --stacks整型[1,2]
    • --hidden_size浮点型[0.1,2]
  • 调参时可以选择rayoptuna等智能调参框架,也可以选择写脚本进行网格搜索等等。但该模型时空复杂度并不低,这里还是建议选用智能优化调参框架,在时间尽可能短的情况下锁定局部最优解。

后记

  • 这里放一张模型训练完成后绘制的真实值与预测值间对比图
    请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/164842.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

whistle抓包工具应用

原文地址:(67条消息) whistle抓包工具学习_BBC蟹耳总的博客-CSDN博客_w2 抓包 一、安装whistle 首先安装好whistle抓包工具,有以下两个步骤 在终端中全局安装whistle:npm install -g whistle可以通过whistle help查看相关信息,…

《零基础学机器学习》读书笔记一

《零基础学机器学习》读书笔记一 一、机器学习快速上手路径 1.1 机器学习的家族谱 人工智能,可以被简单地定义为努力将通常由人类完成的智力任务自动化。 AI效应的2个阶段: (1)AI将新技术、新体验带进人类的生活,完…

linux环境安装mysql5.7版本

目录 一、下载准备阶段 二、安装运行阶段 linux环境安装mysql是我们工程师必备的技能之一,今天我们实战分享一下安装流程: 一、下载准备阶段 1、查看linux系统是否已经安装mysql rpm -qa|grep -i mysql 显示没有 如果安装过,可以删除&…

DFS排列组合与连通性

目录 一、前言 二、DFS与排列组合 1、DFS:自写排列算法1 (1)基础模板 (2)基于(1)输出前n个数任意m个都全排列 2、DFS:自写排列算法2(这个写法更常见) (1&#xff…

Java安装详细教程

这里写自定义目录标题Java安装详细教程1.下载Java2,找到jdk8进行下载3.安装jdk4.配置环境变量5.查看是否已经成功安装Java安装详细教程 换了新电脑了,需要安装Java,如果对你也有帮助就点个赞吧~~ 文章目录Java安装详细教程1.下载Java2&#…

一阶低通滤波器学习

导读:电压型磁链观测器由于物理概念清晰、简单易用而备受关注。然而电压型磁链观测器包含一纯积分项,被积项的初始相位与直流偏置都会影响积分结果。所以对传统电压型磁链观测器的改进措施有很多,本期文章主要介绍采用一阶低通滤波器来替换掉…

Java程序设计实验2 | Java语言基础

*本文是博主对Java各种实验的再整理与详解,除了代码部分和解析部分,一些题目还增加了拓展部分(⭐)。拓展部分不是实验报告中原有的内容,而是博主本人自己的补充,以方便大家额外学习、参考。 目录 一、实验…

微信小程序使用npm包、全局数据共享和分包

文章目录导航路线使用 npm 包小程序对 npm 的支持与限制Vant Weapp1. 什么是 Vant Weapp2. 安装 Vant 组件库3. 使用 Vant 组件4. 定制全局主题样式5. 定制全局主题样式API Promise化1. 基于回调函数的异步 API 的缺点2. 什么是 API Promise 化3. 实现 API Promise 化4. 调用 P…

鸣人的影分身(动态规划 | DP | 整数划分模型)[《信息学奥赛一本通》]

题目如下: 在火影忍者的世界里,令敌人捉摸不透是非常关键的。 我们的主角漩涡鸣人所拥有的一个招数——多重影分身之术——就是一个很好的例子。 影分身是由鸣人身体的查克拉能量制造的,使用的查克拉越多,制造出的影分身越强。…

6.R语言【频数、频率统计函数】一维、二维、三维

b站课程视频链接: https://www.bilibili.com/video/BV19x411X7C6?p1 腾讯课堂(最新,但是要花钱,我花99😢😢元买了,感觉讲的没问题,就是知识点结构有点乱,有点废话)&…

PostgreSQL数据库FDW——Parquet S3 MultifileMergeExecutionStateBaseS3

MultifileMergeExecutionStateBaseS3和SingleFileExecutionStateS3、MultifileExecutionStateS3类不同,reader成员被替换为ParquetReader *类型的readers vector。新增slots_initialized布尔变量指示slots成员是否已经初始化。slots成员是Heap类,Heap用于…

重装系统Windows10纯净版操作步骤(微pe)

目录 前言 操作步骤 第一步:格式化硬盘 第二步:硬盘重新分区 固态硬盘分区 机械硬盘分区 完成效果展示 第三步:把ISO镜像文件写入固态硬盘 第四步:关机拔u盘 第五步:开机重装系统成功 前言 1.要重装系统&am…

Webpack提取页面公共资源

1. 利用html-webpack-externals-plugin 分离基础库 在做React开发时,经常需要引入react和react-dom基础库,这样在打包的时候速度就会比较慢,这种情况下我们可以将这些基础库忽略掉,将它们通过CDN的方式直接引入,而不打…

apache和IIS区别?内网本地服务器项目怎么让外网访问?

Apache和IIS是比较常用的搭建服务器的中间件,它们之间还是有一些区别差异的,下面就详细说说 Apache和IIS有哪些区别,以及如何利用快解析实现内网主机应用让外网访问。 1.安全性 首先说说apache和IIS最基本的区别。Apache运行的操作系统通常为…

Python数学建模问题总结(3)数据可视化Cookbook指南·下

概括总结:五、样式:优化图表、数据可视1.形状:形状的精确程度;2.颜色:区分类别、表示数量、突出特定数据、表示含义;3.线:点划线或不同的不透明度;4.文字排版:应用于图表…

IOC/DI配置管理第三方bean及注解开发。

目录 一、IOC/DI 配置管理第三方bean 1、配置第三方bean 2、加载properties 文件 3、核心容器 二、注解开发 1、注解开发定义bean 2、纯注解开发模式 3、注解开发bean作用范围与生命周期管理 4、注解开发依赖注入 三、IOC/DI注解开发管理第三方bean 1、注解开发管…

深度学习中有哪些从数学模型或相关理论出发, 且真正行之有效的文章?

自深度学习兴起后,深层网路对图像进行特征学习,将低层次的基础特征聚合成更高级的语义特征,取得突出的识别效果,在图像识别、分割及目标检测三大领域得到了众多应用。深度学习算法基本上是由多个网络层搭建,每个网络层…

SpringBoot自动装配

前言 Spring翻译为中文是“春天”,的确,在某段时间内,它给Java开发人员带来过春天,但是随着我们项目规模的扩大,Spring需要配置的地方就越来越多,夸张点说,“配置两小时,Coding五分…

Open3D Usage

Open3D UsageWhat is open3Dopen3D 核心功能包括:python quick start交互指令显示点云**read_point_cloud** ParametersReturnPointCloud的属性:加载ply点云:显示单帧点云:批量单帧显示点云可视化**draw_geometries** Parameters含…

Uniswap v3 详解(三):交易过程

交易过程 v3 的 UniswapV3Pool 提供了比较底层的交易接口,而在 SwapRouter 合约中封装了面向用户的交易接口: exactInput:指定交易对路径,付出的 x token 数和预期得到的最小 y token 数(x, y 可以互换)e…