【深度学习】时空图卷积网络(STGCN),预测交通流量

news2024/11/16 7:35:13

论文地址:https://arxiv.org/abs/1709.04875

Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting

文章目录

  • 一、摘要
  • 二、数据集介绍
    • 美国洛杉矶交通数据集 METR-LA 介绍
    • 美国加利福尼亚交通数据集 PEMS-BAY 介绍
    • 美国加利福尼亚交通数据集 PEMSD7-M 介绍
    • 数据集含义
  • 三、任务目标
  • 四、训练过程
    • 基础推理过程
    • 邻接矩阵
    • 注意力机制
  • 五、 总结

一、摘要

准时准确的交通预测对城市交通控制和引导至关重要。由于交通流的高非线性和复杂性,传统方法无法满足中长期预测任务的要求,并且通常忽视空间和时间依赖关系。本文提出了一种新颖的深度学习框架,即时空图卷积网络(STGCN),用于解决交通领域的时间序列预测问题。我们不是应用常规的卷积和循环单元,而是在图上制定问题,并利用完整的卷积结构构建模型,这使得训练速度更快,参数更少。实验证明,我们的模型STGCN通过建模多尺度交通网络有效地捕获了全面的时空相关性,并在各种真实世界的交通数据集上始终优于最先进的基线模型。

二、数据集介绍

美国洛杉矶交通数据集 METR-LA 介绍

T-GCN文章选取了该数据集2012年3月1日至3月7日期间的207个传感器及其交通速度。每5分钟汇总一次交通速度。**相似性,数据总结出一个邻接矩阵和一个特征矩阵。邻接矩阵是由交通网络中传感器之间的距离计算出来的。**由于Los-loop数据集包含一些缺失的数据,使用线性插值法来填补缺失值.

SOTA算法:

https://paperswithcode.com/sota/traffic-prediction-on-metr-la

目前第一名是Spatio-Temporal Graph Mixformer for Traffic Forecasting,而本文的STGCN只能排18。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

美国加利福尼亚交通数据集 PEMS-BAY 介绍

PEMS-BAY 数据集由加利福尼亚大学伯克利分校的交通实验室发布。该数据集包含了旧金山湾区高速公路网络上 325 个传感器的实时交通流量数据。这个数据帮助对基于多源数据的交通预测算法进行评估和比较,并被广泛用于交通预测、拥堵控制、出行决策等领域的研究中。

SOTA:

https://paperswithcode.com/sota/traffic-prediction-on-pems-bay

在这里插入图片描述

美国加利福尼亚交通数据集 PEMSD7-M 介绍

SOTA:
https://paperswithcode.com/sota/traffic-prediction-on-pemsd7-m

PEMSD7-M 数据集由美国加州大学洛杉矶分校的智能交通系统实验室发布。该数据集收集了洛杉矶城市高速公路上的交通流量数据,包含了 228 个传感器(探头)每 5 分钟采样一次的方式收集数据。这个数据集旨在帮助进行基于不同时间尺度的交通预测研究,以提高道路网络可持续性和安全性。

在这里插入图片描述

数据集含义

以pems-bay为例,两个csv文件就是2个矩阵。

vel.csv的形状是(16384, 325),代表了16384次观测数据,每次数据都是325个传感器的数据(可能是车流量或者车速的一种综合表达的数据):

在这里插入图片描述
adj.mat.csv的形状是(325, 325),表示了传感器之间的关系,比如距离关系,是邻接矩阵,邻接矩阵是对角阵:
在这里插入图片描述

三、任务目标

还是以pems-bay数据集为例子,我们有16384次时序的观测数据,任务的目标就是预测这325个传感器的后续数据走向,看看哪个模型预测得准。

交通流量是一个很实际的问题,325个传感器分布在地区的各个要点中,邻接矩阵表达了各个传感器之间的关联性,而16384次时序观测数据也表达了传感器随时间变化采集到的数据。如果有一个模型可以很好预测传感器的后续数据,说明模型能很好预测出交通流量走势,这可以应用到交通疏通中或者预测拥堵的应用里去。

四、训练过程

基础推理过程

输入模型的x大小是,32是batchsize,12是时序步数,325是传感器维度。
在这里插入图片描述
输出的是下一次时序里的数据:
在这里插入图片描述

而训练数据条数一共有13108条:
在这里插入图片描述
13090/32
409.0625

所以每次训练是410个iter迭代步数。

x进去后,由self.st_blocks(一系列小波图运算)运算得到了x_stbs ,

在这里插入图片描述

而后 self.output 得到了预测结果。

    def forward(self, x):
        x_stbs = self.st_blocks(x)
        # [2, 1, 12, 207]->[2, 64, 4, 207]
        if self.Ko > 1:
            x_out = self.output(x_stbs)

邻接矩阵

在上述推理过程中,在哪里使用到邻接矩阵了?

在最开始用wavelet_basis来确定了小波基矩阵。

在这里插入图片描述

这个函数wavelet_basis的主要作用是根据输入的交通流量传感器的邻接矩阵adj(表示传感器之间的连接关系),以及参数s(小波尺度参数)和threshold(阈值),来计算并返回一组小波基矩阵,用于后续在图卷积过程中对信号进行多尺度分析。这个过程涉及图信号处理的核心步骤,包括拉普拉斯矩阵的计算、特征分解、小波变换的构建、阈值化处理和L1范数归一化。下面是各部分的详细解析:

  1. 计算拉普拉斯矩阵:

    • 函数首先调用calculate_laplacian_matrix来计算图的拉普拉斯矩阵L。拉普拉斯矩阵是图论中的重要概念,广泛用于图信号处理和图的谱分析,能体现图的拓扑结构。
  2. 特征分解:

    • 接着,使用fourier函数对拉普拉斯矩阵L进行特征分解,得到特征值lamb和特征向量矩阵U。这一步是图傅里叶变换的基础,用于将信号映射到频域进行分析。
  3. 构建图小波:

    • 利用weight_wavelet函数基于小波尺度s、特征值lamb和特征向量U计算小波权重矩阵Weight,以及通过weight_wavelet_inverse计算逆小波权重矩阵inverse_Weight。这两个矩阵是构建小波基的关键,用于对信号进行多尺度分解和重构。
  4. 阈值化处理:

    • 根据给定的threshold,将Weightinverse_Weight中小于该阈值的元素设为0,这是一种常见的去噪操作,可以减少噪声影响,提升模型的稳健性。
  5. L1范数归一化:

    • Weightinverse_Weight进行L1范数归一化,确保每个行的绝对值之和为1,这有助于保持能量守恒,并使后续的操作(如图卷积)更加稳定。
  6. 转换格式:

    • 最后,将矩阵转换为稀疏矩阵格式coo_matrix,以节省存储空间和加速后续的矩阵运算。

整个函数的输出t_k = (Weight, inverse_Weight)是一对小波基矩阵,它们在基于小波变换的图卷积神经网络(如之前提到的Graph_WaveletsConv)中扮演核心角色,用于捕捉不同尺度的图信号特征,增强模型对复杂时空数据的理解能力。

注意力机制

小波变换真是复杂的作用机制,大体来说,是将有用信息用小波变换方式给入到了模型前向传导。

为了引入自注意力机制,就得找一些层引入,小波卷积层是一个不错的层。

定义一个AttentionModule,它接受节点特征作为输入,通过两层全连接层和ReLU激活函数计算每个节点的注意力分数,然后通过softmax函数标准化这些分数,确保所有节点的注意力权重之和为1。在Graph_WaveletsConvWithAttention中,首先使用原有的小波图卷积操作处理输入,然后将得到的特征传递给注意力模块,实现特征的自适应加权。

在这里插入图片描述
在这里插入图片描述
整体结构在多个尺度会调用上注意力块Graph_WaveletsConvWithAttention:
在这里插入图片描述
效果是没有提升,没提升是正常的,如果一个简单的改进就能得到提升,那大家都好发论文了:

在这里插入图片描述

又试了一下没加注意力之前的模型,如论文所写,的确小波加进去,MAE为1.81:

在这里插入图片描述

五、 总结

综合来看,以普通方式改进STGCN是很难有提升的,而目前的SOTA算法非常厉害,比如
STD-MAE算法(Spatial-Temporal-Decoupled Masked Pre-training for Spatiotemporal Forecasting)在PEMS-BAY的指标MAE为 1.77, 或许要尝试一些别的改进方式才行,我这里就不再进行额外尝试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1658618.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ASP.NET校园新闻发布系统的设计与实现

摘 要 校园新闻发布系统是在学校区域内为学校教育提供资源共享、信息交流和协同工作的计算机网络信息系统。随着网络技术的发展和Internet应用的普及,互联网已成为人们获取信息的重要来源。由于现在各大学校的教师和学生对信息的需求越来越高,校园信息…

SQL优化详解

目录 插入数据 insert的优化(少量数据) 批量插入 手动事务提交 主键顺序插入 插入大量数据 主键优化 数据组织方式: 页分裂: 主键顺序插入的方式: 主键乱序插入: 页合并: 主键设计…

软件技术主要学什么课程

软件技术专业主要学习的课程和内容有编程语言、数据结构与算法、数据库技术等,以下是上大学网( www.sdaxue.com)整理的软件技术主要学什么课程,供大家参考! 编程语言:掌握一种或多种编程语言,如C#、Java、Python、C等&…

EdgeOne 免费证书快速实现网站 HTTPS 访问

在当今互联网环境下,HTTPS访问已经成为现代网站的必备功能。HTTPS 访问不仅能够更有效地保障用户在访问到网站时的数据安全传输,防止信息泄露、消息劫持等问题,在搜索引擎中,未实现 HTTPS 还会被浏览器提示为不安全网站&#xff0…

Windows平台PyCharm之PySide6开发环境搭建与配置

PySide6 是一个用于创建跨平台 GUI 应用程序的库,它是 Qt for Python 的官方库。Qt 是一个跨平台的 C 应用程序框架,用于开发具有图形用户界面(GUI)的应用程序。PySide6 允许开发者使用 Python 语言访问 Qt 的功能,从而…

ICode国际青少年编程竞赛- Python-2级训练场-综合练习2

ICode国际青少年编程竞赛- Python-2级训练场-综合练习2 1、 Flyer[0].step() Flyer[1].step() Dev.step(4)2、 for i in range(2):Flyer[i].step()Dev.step(2)Dev.turnLeft() Flyer[0].step(2) Dev.step(2)3、 for i in range(2):Flyer[i * 2 1].step()Dev.step(-i - 2)Dev.tu…

OpenAI 发布 AI 生成图片检测器;Meta 推出 AI 广告创意工具;Google 正式发布 Pixel 8a,主打 AI

OpenAI 发布 AI 生成图片检测器 OpenAI 昨日官宣推出专用的 AI 监测工具,用于监测图片是否由其旗下 AI 图片生成工具 DALL-E 生成,准确率高达 98.8%。 不过该公司表示,这个检测工具并非旨在检测 Midjourney 和 Stability 等其他流行生成器生…

QAnything 在mac M2 上纯python环境安装使用体验(避坑指南)

这是一篇mac m2本地纯python环境安装 qanything的文章。安装并不顺利,官方提供的模型无法在本地跑。 这篇文章记录了,使用xinference来部署本地模型,并利用openAi的通用接口的方式,可以正常使用。 记录了遇到的所有的问题&#xf…

使用Docker安装Nginx

一、Nginx介绍 Nginx 是一款高性能的开源 Web 服务器和反向代理服务器,具有高效能、高稳定性、低资源消耗等优点。可以处理大量并发请求,支持多种协议,还能实现负载均衡、缓存等功能,在互联网应用中被广泛使用。在Nginx中&#xf…

ICode国际青少年编程竞赛- Python-2级训练场-迷宫

ICode国际青少年编程竞赛- Python-2级训练场-迷宫 1、 Dev.step(3) Dev.turnLeft() for i in range(2):Dev.step(4)Dev.turnRight() for i in range(2):Dev.step(2)Dev.turnLeft() Dev.step(3) Dev.step(-9)2、 Dev.step(3) Dev.turnRight() Dev.step(2) Dev.turnLeft() for i …

AI视频教程下载:给企业管理层和商业精英的ChatGpt课程

课程内容大纲: 1-引言 2-面向初学者的生成性人工智能 3-与ChatGPT一起学习提示101 详细介绍了如何使用ChatGPT的六种沟通模式,并提供了各种实际应用场景和示例: **Q&A模式(问题与答案模式)**: - 这…

区块链 | NFT 水印:Review on Watermarking Techniques(三)

🍍原文:Review on Watermarking Techniques Aiming Authentication of Digital Image Artistic Works Minted as NFTs into Blockchains 一个 NFT 的水印认证协议 可以引入第三方实体来实现对交易的认证,即通过使用 R S A \mathsf{RSA} RSA…

96、技巧-只出现一次的数字

思路 首先不考虑额外空间的话使用一个set去重即可。第二种就是异或运算。 异或操作的性质 身份元素:任何数与0进行异或运算,结果仍然是原数,即 x ^ 0 x。自反性:任何数与自身进行异或运算,结果是0,即 x…

Spring Cloud Consul 4.1.1

该项目通过自动配置和绑定到 Spring 环境和其他 Spring 编程模型习惯用法,为 Spring Boot 应用程序提供 Consul 集成。通过一些简单的注释,您可以快速启用和配置应用程序内的常见模式,并使用基于 Consul 的组件构建大型分布式系统。提供的模式…

Python版Spark core详解

文章目录 第一章 SparkCore1.1. Spark环境部署1.1.1. Spark介绍1.1.1.1. 什么是Spark1.1.1.2. Spark与MapReduce的对比框架对比运行流程对比 1.1.1.3. Spark的组件1.1.1.4. Spark的特点 1.1.2. Spark的安装部署1.1.2.1. Spark安装包下载1.1.2.2. Spark部署模式介绍1.1.2.3. Loc…

解决Vue devtools插件数据变化不会自动刷新

我们使用devtools插件在监测vuex中表单或自定义组件的数据,发现页面数据发生变化后,但是devtools中还是老数据,必须手动点击devtools刷新才能拿到最新的数据。很烦! 解决方案: 打开chrome的设置,向下翻&…

docker自建GitLab仓库

摘要 GitLab 是一个功能强大的开源代码托管平台,它不仅提供了代码存储和版本控制的核心功能,还集成了项目管理、CI/CD 流水线、代码审查等企业级特性。本文将指导你如何在自己的服务器上搭建 GitLab 社区版,创建一个完全属于自己的开源仓库&…

##10 卷积神经网络(CNN):深度学习的视觉之眼

文章目录 前言1. CNN的诞生与发展2. CNN的核心概念3. 在PyTorch中构建CNN4. CNN的训练过程5. 应用:使用CNN进行图像分类5. 应用:使用CNN进行时序数据预测代码实例7. 总结与展望前言 在深度学习的领域中,卷积神经网络(CNN)已经成为视觉识别任务的核心技术。自从AlexNet在2…

引入RabbitMQ

前置条件 docker 安装 mq docker run \-e RABBITMQ_DEFAULT_USERdudu \-e RABBITMQ_DEFAULT_PASS123456 \-v mq-plugins:/plugins \--name mq \--hostname mq \-p 15672:15672 \-p 5672:5672 \--network hmall \-d \rabbitmq:3.8-management可能会出现:docker: Er…

FPGA+炬力ARM实现VR视频播放器方案

FPGA炬力ARM方案,单个视频源信号,同时驱动两个LCD屏显示,实现3D 沉浸式播放 客户应用:VR视频播放器 主要功能: 1.支持多种格式视频文件播放 2.支持2D/3D 效果实时切换播放 3.支持TF卡/U盘文件播放 4.支持定制化配置…