PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测

news2025/1/16 1:00:04

目录

  • I. 前言
  • II. STGCN
  • III. PyG Temporal
  • IV. 模型训练/测试
  • V. 代码

I. 前言

前面已经写过不少时间序列预测的文章:

  1. 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)
  2. PyTorch搭建LSTM实现时间序列预测(负荷预测)
  3. PyTorch中利用LSTMCell搭建多层LSTM实现时间序列预测
  4. PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)
  5. PyTorch搭建双向LSTM实现时间序列预测(负荷预测)
  6. PyTorch搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  7. PyTorch搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  8. PyTorch搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  9. PyTorch搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  10. PyTorch搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  11. PyTorch中实现LSTM多步长时间序列预测的几种方法总结(负荷预测)
  12. PyTorch-LSTM时间序列预测中如何预测真正的未来值
  13. PyTorch搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  14. PyTorch搭建ANN实现时间序列预测(风速预测)
  15. PyTorch搭建CNN实现时间序列预测(风速预测)
  16. PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
  17. PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
  18. PyTorch时间序列预测系列文章总结(代码使用方法)
  19. TensorFlow搭建LSTM实现时间序列预测(负荷预测)
  20. TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)
  21. TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
  22. TensorFlow搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  23. TensorFlow搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  24. TensorFlow搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  25. TensorFlow搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  26. TensorFlow搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  27. TensorFlow搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  28. TensorFlow搭建ANN实现时间序列预测(风速预测)
  29. TensorFlow搭建CNN实现时间序列预测(风速预测)
  30. TensorFlow搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
  31. PyG搭建图神经网络实现多变量输入多变量输出时间序列预测
  32. PyTorch搭建GNN-LSTM和LSTM-GNN模型实现多变量输入多变量输出时间序列预测

从第31篇文章起,本系列开始更新时空预测模型,其中前两篇文章都不是属于论文中的模型,今天介绍一个使用较为广泛的用于时序预测的时空图卷积网络STGCN。

II. STGCN

STGCN是北大发表在IJCAI 2018上的论文Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting中提出来的,其目的是用于实时的交通预测。

在该论文中使用的数据集为美国加州PeMSD7数据集,里面包含了分布在不同地方的228个传感器观测到的车流量,文章中使用这228个节点构成了一个无向图,然后根据历史的车流量信息预测未来某个时间段的所有传感器所在地的车流量信息。

可以看出,STGCN要解决的问题与前两篇文章要解决的问题基本一致。前两篇问题中,我们给出了13个变量前24小时的数据,目的是预测13个变量未来某几个小时的数据。在这里13个变量类比于228个传感器。

STGCN的原理也较为简单,STGCN由两个时空图卷积块(ST-Conv Block)和一个输出全连接层(Output Layer组成。其中ST-Conv Block又由两个时间门控卷积和中间的一个空间图卷积组成:
在这里插入图片描述
从图右边可知,两个Temporal Gated-Conv使用的是1-D卷积,和CNN处理一维时序信号类似,即进行seq_len维度上的卷积。Spatial Graph-Conv进行的是空域上的卷积,模型为GCN。

关于STGCN详细的原理可以阅读原论文,原理也比较简单。本篇文章不做太多详细的推导过程,主要讲解如何利用STGCN进行多变量输入多变量输出的时间序列预测。

III. PyG Temporal

PyG Temporal是PyG的一个扩展库,其主要用于处理时空信号数据,里面实现了许多使用较为广泛的时空图卷积模型如STGCN、DCRNN、T-GCN、LRGCN等。

PyG Temporal的安装也比较简单:

pip install torch-geometric-temporal

PyG Temporal中STGCN的实现如下:
在这里插入图片描述
参数解释如下:

  1. in_channels:节点输入特征的维度大小,这里为1,即每个节点都只有一个特征,我们需要预测的也是该特征。
  2. hidden_channels:字面意思。
  3. out_channels:字面意思。
  4. kernel_size:时域卷积时的卷积核大小,类比CNN即可。
  5. K:将切比雪夫多项式作为图卷积核时的卷积核大小,具体可以参考我之前写的一篇文章:ICML 2019 | SGC:简单图卷积网络。
  6. normalization:拉普拉斯矩阵的归一化选项,前面也讲过了。
  7. bias:无需多述。

一个STConv所能接受的输入格式为:
在这里插入图片描述
可以看出,一个STConv需要接受三个输入:

  1. X:维度大小为(batch_size, seq_len, num_nodes, in_channels),在本文中即X=(256, 24, 13, 1)
  2. edge_index:图的邻接矩阵。
  3. edge_weight:边权重矩阵(可选)。

为此,我们可以首先搭建一个STGCN:

class STGCN(nn.Module):
    def __init__(self, num_nodes, size, K):
        super(STGCN, self).__init__()
        self.conv1 = STConv(num_nodes=num_nodes, in_channels=1, hidden_channels=16,
                            out_channels=32, kernel_size=size, K=K)
        self.conv2 = STConv(num_nodes=num_nodes, in_channels=32, hidden_channels=16,
                            out_channels=32, kernel_size=size, K=K)

    def forward(self, x, edge_index):
        # x(batch_size, seq_len, num_nodes, in_channels)
        x, edge_index = x.to(device), edge_index.to(device)
        x = F.elu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        return x

然后一个用于多变量输入多变量输出的STGCN模型搭建如下:

class STGCN_MLP(nn.Module):
    def __init__(self, args):
        super(STGCN_MLP, self).__init__()
        self.args = args
        self.out_feats = 128
        self.stgcn = STGCN(num_nodes=args.input_size, size=3, K=1)
        self.fcs = nn.ModuleList()
        for k in range(args.input_size):
            self.fcs.append(nn.Sequential(
                nn.Linear(16 * 32, 64),
                nn.ReLU(),
                nn.Linear(64, args.output_size)
            ))

    def forward(self, x, edge_index):
        # x(batch_size, seq_len, input_size)
        # x(512, 24, 13)--->(512, 24, 13, 1)
        x = x.unsqueeze(3)
        x = self.stgcn(x, edge_index)
        preds = []
        for k in range(x.shape[2]):
            preds.append(self.fcs[k](torch.flatten(x[:, :, k, :], start_dim=1)))

        pred = torch.stack(preds, dim=0)

        return pred

照例简单分析一下模型的处理过程:

首先我们有x=(batch_size=256, seq_len=24, input_size=13),为了满足STGCN的输入要求(batch_size, seq_len, num_nodes, in_channels=1),我们需要将x扩展一个维度:

x = x.unsqueeze(3)

然后经过STGCN:

x = self.stgcn(x, edge_index)

得到x=(256, 16, 13, 32)。操作过程与CNN类似,一维卷积作用在seq_len=24维度,最终变成16。随后,为了得到每个变量的输出,我们简单地将13个变量各自的(16, 32)经过13个不同的全连接层。

IV. 模型训练/测试

这点与前面一致,不再赘述。

预测效果相当不错:
在这里插入图片描述
预测效果示意图(只给出前6个变量):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

V. 代码

后续考虑整理公开。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/156084.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++——类和对象1

目录 1. 类和对象认识 2. 类的引入 3. 类的定义 4. 类的访问限定符及封装 4.1 访问限定符 4.2 封装 5. 类的作用域 6. 类的实例化 7. 类对象模型 7.1 如何计算类对象的大小 7.2 类对象的存储方式猜测 7.3 结构体内存对齐规则 8. this指针 8.1 this指针的…

cv-cuda (cvcuda、nvcv)教程——Python安装

由于当前版本安装后,大家反应import nvcv cvcuda 失败,看官方文档,当前还不是很规范,特此记录当前版本的安装方法。 官方安装文档:Installation — CV-CUDA Alpha documentation 方法一、如果你有权限推荐deb安装方式…

机器学习第15章-规则学习

机器学习第15章-规则学习 以下列出我觉得重要,在编码的思路中可以参考的地方 冲突消融 当一条规则的判断出现不同的结果时,解决冲突的方法 1.投票法 2.排序法 3.无规则法 序贯覆盖 生成规则过程中去除当前规则所能覆盖的数据 生成方式 自顶向下…

双软认证的好处,赶紧来看看吧

1、“双软件”认可对企业有什么好处? 对于认定的软件企业,从盈利年度起,第一年和第二年免征企业所得税,第三年至第五年减半征收企业所得税,即两免三减。对认定软件产品的企业,对实际增值税负担超过3%的部分…

【ONE·C++ || vector (二)】

总言 主要讲述vector的模拟实现。 文章目录总言1、基本框架搭建:成员变量2、对构造函数、析构函数3、增删查改、空间扩容3.1、vector::push_back、vector::pop_back3.2、vector::reserve、vector::capacity、vector::size3.3、operator[ ]3.4、遍历:迭代…

记录robosense RS-LIDAR-16使用过程1

拿到设备,首先对照型号去官网下载相关资料(用户手册/软件/SDK),需要填写资料https://www.robosense.ai/resources-27工业相机通常也有出厂SDK文件,之前有使用知微传感的D130相机,也是先安装SDK、看手册然后使用。大型厂…

【Java集合】Map接口常用方法及实现子类

文章目录01 Map 接口实现类的特点02 Map 接口和常用方法03 Map 接口遍历方法04 HashMap 用例 小结05 HashMap 底层&扩容机制06 Hashtable07 PropertiesMap为双列集合,Set集合的底层也是Map,只不过有一列是常量所占,只使用到了一列。 01 …

国科大《高级人工智能》沈老师部分——行为主义笔记

国科大《高级人工智能》沈老师部分——行为主义笔记 沈华伟老师yyds,每次上他的课都有一种深入浅出的感觉,他能够把很难的东西讲的很简单,听完就是醍醐灌顶,理解起来特别清晰今年考试题目这部分跟往年基本一样,沈老师画…

长城汽车2022年销量106万辆,20万以上车型占比15%

2023年,长城汽车预计将推出超10款新能源车型,发力新能源和智能化。1. 年度销量:超106万辆 根据长城最新发布的产销数据:•2022年,长城汽车全年销售1,067,523辆; •其中,海外市场累计销售173,180…

2022CTF培训(十二)IOT 相关 CVE 漏洞分析

附件下载链接 NETGEAR R7800(CVE-2020-11790) NETGEAR R7800 存在命令注入漏洞,下面以 V1.0.2.62 版本固件为例进行介绍。 固件仿真 漏洞存在于 uhttpd 中,由于该功能比较独立,可以直接用 qemu user mode 仿真。 /…

在 anaconda 中安装 tensorflow models (gpu)

环境:Windows; Intel CPU Nvidia GPU 1. 创建环境 不推荐单次安装过多的库,可能导致安装失败(如超出终端缓存等)注意添加库的顺序 tensorflow-gpu 需要在 cudatoolkit 之前否则下载的 tensorflow-gpu 不支持 gpu 「实测」 TODO…

设备注册挂载流程(包含上电、使能、i2c通讯介绍)

目录 简介 上电时序 电压不同 时序不同 使能与复位 CLK时钟 I2C通讯 主从关系 识别设备 通讯格式 简介 任何相对于主板芯片的外挂设备都需要一定的注册挂载流程 (外挂设备:比如摄像头、nfc芯片、显示屏等等) 设备的挂载则需要满足…

JAVAEE-多线程(4)

目录 定时器 实现自己的Timer 线程池 常见的锁策略: 乐观锁和悲观锁 读写锁和普通互斥锁 重量级锁和轻量锁 自旋锁和挂起等待锁 公平锁和非公平锁 可重入锁和不可重入锁 synchronized CAS CAS和ABA问题 锁粗化 JUC 原子类 Semaphore CountDownLatc…

CAN总线控制器MCP2515 替代芯片 DP2515 DP2515-I/ST

汽车K总线与CAN的区别是什么 1、功能不同   K线一般用于检测系统,属单线模式,与诊断仪器连接并相互传递数据。CAN线主要用于控制单元与控制单元之间传递数据、属双线模式,分高位线和地位线。   2、通讯速度不同   K线通讯速率较低&…

101.对称二叉树 | 递归 + 迭代

对称二叉树 leetcode : https://leetcode.cn/problems/symmetric-tree/ 参考 对称二叉树 递归思路 首先在开始时, 一定要注意, 对称二叉树对比的并不是一个节点的左右子树, 而是两棵树, 这个很关键! 对比时是内侧和内侧对比, 外侧和外侧对比, 递归三步 : 确定递归的参数以…

1.1.2 了解JAVA语言

文章目录1 JAVA语言发展史2 面向对象的概念3 跨平台性4 JDK1 JAVA语言发展史 JAVA是由詹姆斯•高斯林(James Gosling)所创建的,其1977年获得了加拿大卡尔加里大学计算机科学学士学位,1983年 获得了美国卡内基梅隆大学计算机科学博…

4)Mybatis数据源以及事务实现

1. Mybatis数据源分为两种,一种直接连接数据库,一种使用连接池连接数据库,具体代码实现在包目录下 org.apache.ibatis.datasource 数据源接口: javax.sql.DataSource 池化数据源: org.apache.ibatis.datasource.…

OpenGL集锦(1)-安装与概述

目录概述fedora下安装编写OpenGL应用程序测试hello,world概述 OpenGL(英语:Open Graphics Library,译名:开放图形库或者“开放式图形库”)是用于…

Lichee_RV学习系列--CoreMark-Pro移植

Lichee_RV学习系列文章目录 Lichee_RV学习系列—认识Lichee Rv Dock、环境搭建和编译第一个程序 Lichee_RV学习系列—移植dhrystone 文章目录Lichee_RV学习系列文章目录一、CoreMark-Pro简介二、获取源码三、编译coremark-pro1、配置coremark-pro2、编译coremark-pro四、开发板…

各种树的总结

1.B树和B树 数据库的大量数据用什么存储?为什么是B树和B树?使用二叉树不行吗?先来说说他们的演变吧,首先如果用二叉树的话都为排好序的树查询起来是不是效率不高?所以此时我们提出了对树排序,就变成了二叉…