【论文解读】2017 STGCN: Spatio-Temporal Graph Convolutional Networks

news2024/10/4 16:26:16

一、简介

使用历史速度数据预测未来时间的速度。同时用于序列学习的RNN(GRU、LSTM等)网络需要迭代训练,它引入了逐步累积的误差,并且RNN模型较难训练。为了解决以上问题,我们提出了新颖的深度学习框架STGCN,用于交通预测。

二、STGCN模型架构

2.1 整体架构图示

在这里插入图片描述

2.2 ST-Conv blocks

符号含义
M历史时间序列长度
n节点数
C i C_i Ci输入的channel 数
C o C_o Co输出的channel 数

2.2.1 TemporalConv: Gated CNNs 用于提取时间特征

Note: nn.Conv2d的输入 channel在第一维度

[ P Q ] = C o n v ( x ) ; o u t = P ⊙ σ ( Q ) [P Q] = Conv(x); \\ out = P \odot \sigma (Q) [PQ]=Conv(x);out=Pσ(Q)

  • x ∈ R C i × M × n x \in \mathbb{R}^{C_i \times M \times n } xRCi×M×n
  • [ P Q ] ∈ R 2 C o ∗ ( M − K t + 1 ) × n [\text{P Q}] \in \mathbb{R}^{2C_o * (M - K_t + 1) \times n } [P Q]R2Co(MKt+1)×n

示例代码:

class TCN(nn.Module):
    def __init__(self, c_in: int, c_out: int, dia: int=1):
        """TemporalConvLayer
        input_dim:  (batch_size, 1, his_time_seires_len, node_num)
        sample:     [b, 1, 144, 207]
        Args:
            c_in (int): channel in
            c_out (int): channel out
            dia (int, optional): 空洞卷积大小. Defaults to 1.
        """
        super(TCN, self).__init__()
        self.c_out = c_out * 2
        self.c_in = c_in
        self.conv = nn.Conv2d(
            c_in, self.c_out, (2, 1), 1, padding=(0, 0), dilation=dia
        )

    def forward(self, x):
        # [batch, channel, his_n, node_num] 
        #  仅在时间维度上进行卷积 
        c = self.c_out//2
        out = self.conv(x)
        if len(x.shape) == 3: # channel, his_n, node_num
            P = out[:c, :, :]
            Q = out[c:, :, :]
        else:
            P = out[:, :c, :, :]
            Q = out[:, c:, :, :]
        return P * torch.sigmoid(Q)

2.2.2 SpatialConv: Graph CNNs 提取空间信息

迭代定义的切比雪夫多项式

o u t = Θ ∗ G x = ∑ k = 0 K − 1 θ k T k ( L ~ ) x = ∑ k = 0 K − 1 W K , l z k , l out= \Theta_{* \mathcal{G}} x = \sum_{k=0}^{K-1}\theta_k T_k(\tilde{L})x=\sum_{k=0}^{K-1}W^{K, l}z^{k, l} out=ΘGx=k=0K1θkTk(L~)x=k=0K1WK,lzk,l

  • Z 0 , l = H l Z^{0, l} = H^{l} Z0,l=Hl
  • Z 1 , l = L ~ ⋅ H l Z^{1, l} = \tilde{L} \cdot H^{l} Z1,l=L~Hl
  • Z k , l = 2 ⋅ L ~ ⋅ Z k − 1 , l − Z k − 2 , l Z^{k, l} = 2 \cdot \tilde{L} \cdot Z^{k-1, l} - Z^{k-2, l} Zk,l=2L~Zk1,lZk2,l
  • L ~ = 2 ( I − D ~ − 1 / 2 A ~ D ~ − 1 / 2 ) / λ m a x − I \tilde{L} = 2\left(I - \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\right)/\lambda_{max} - I L~=2(ID~1/2A~D~1/2)/λmaxI

论文: Recursive formulation for fast filtering

示例代码:

class STCN_Cheb(nn.Module):
    def __init__(self, c, A, K=2):
        """spation cov layer
        Args:
            c (int): hidden dimension
            A (adj matrix): adj matrix
        """
        super(STCN_Cheb, self).__init__()
        self.K = K
        self.lambda_max = 2
        self.tilde_L = self.get_tilde_L(A)
        self.weight = nn.Parameter(torch.empty((K * c, c)))
        self.bias = nn.Parameter(torch.empty(c))
        stdv = 1.0 / np.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def get_tilde_L(self, A):
        I = torch.diag(torch.Tensor([1] * A.size(0))).float().to(A.device)
        tilde_A = A + I 
        tilde_D = torch.diag(torch.pow(tilde_A.sum(axis=1), -0.5))
        return 2 / self.lambda_max * (I - tilde_D @ tilde_A @ tilde_D) - I

    def forward(self, x):
        # [batch, channel, his_n, node_num] -> [batch, node_num, his_n, channel] -> [batch, his_n, node_num, channel] 
        x = x.transpose(1, 3)
        x = x.transpose(1, 2)
        output = self.m_unnlpp(x)
        output = output @ self.weight + self.bias
        output = output.transpose(1, 2)
        output = output.transpose(1, 3)
        return torch.relu(output) 

    def m_unnlpp(self, feat):
        K = self.K
        X_0 = feat
        Xt = [X_0]
        # X_1(f)
        if K > 1:
            X_1 = self.tilde_L @ X_0
            # Append X_1 to Xt
            Xt.append(X_1)
        # Xi(x), i = 2...k
        for _ in range(2, K):
            X_i =  2 * self.tilde_L @ X_1 - X_0
            # Add X_1 to Xt
            Xt.append(X_i)
            X_1, X_0 = X_i, X_1
        # 合并数据
        Xt = torch.cat(Xt, dim=-1)
        return Xt

2.2.3 ST-Block

组合TCNSTCN_Cheb
v l + 1 = Γ 1 ∗ T l ReLU ( Θ ∗ G l ( Γ 0 ∗ T l v l ) ) v^{l+1} = \Gamma ^{l} _{1*\mathcal{T}} \text{ReLU}( \Theta ^l_{*\mathcal{G}} (\Gamma ^{l} _{0*\mathcal{T}} v^l) ) vl+1=Γ1TlReLU(ΘGl(Γ0Tlvl))

  • Γ 0 ∗ T l v l \Gamma ^{l} _{0*\mathcal{T}} v^l Γ0Tlvl: 第一个TCN
  • Θ ∗ G l \Theta ^l_{*\mathcal{G}} ΘGl : STCN_Cheb
  • Γ 1 ∗ T l v l \Gamma ^{l} _{1*\mathcal{T}} v^l Γ1Tlvl: 第二个TCN
class STBlock(nn.Module):
    def __init__(
        self,
        A,
        K=2,
        TST_channel: List=[64, 16, 64]
        T_dia: List=[2, 4]
    ):
        # St-Conv Block1[  TCN(64, 16*2)->SCN(16, 16)->TCN(16, 64*2) ] 
        super(STBlock, self).__init__()
        self.T1 = TCN(TST_channel[0], TST_channel[1], dia=T_dia[0])
        # STCN_Cheb out have relu
        self.S = STCN_Cheb(TST_channel[1], Lk=A, K=K)
        self.T2 = TCN(TST_channel[1], TST_channel[2], dia=T_dia[1])

    def forward(self, x):
        return self.T2(self.S(self.T1(x)))

三、简单复现

复现可以看笔者的github: train.ipynb
用的数据是metr-la.h5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/769438.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【云原生|Docker系列第4篇】Docker的容器的入门实践

欢迎来到云原生系列的第4篇博客!在前面的两篇博客中,我们已经学习了Docker镜像的基本概念和入门实践。本篇博客将带您深入了解Docker容器,探索如何使用Docker容器来构建、运行和管理应用程序。无论您是新手还是有一定经验的开发者&#xff0c…

PHP循环

PHP while 循环 PHP SwitchPHP For 循环 PHP while 循环在指定条件为 true 时执行代码块。 PHP 循环 在您编写代码时,经常需要反复运行同一代码块。我们可以使用循环来执行这样的任务,而不是在脚本中添加若干几乎相等的代码行。 在 PHP 中&#xff…

linux安装mysql以及使用navicat连接mysql

目录 一、下载mysql 二、安装mysql 三、使用Navicat连接MySQL 四、常见问题 1、启动服务时报错 Failed to start mysql.service: Unit not found. 的解决方法。 2、登录过程出现:access denied for user’root’‘localhost’(using password:Yes) 的解决方…

Moonbeam生态说|探索Web3链游生态Seascape

「Moonbeam生态说」是Moonbeam中文爱好者社区组织的社区AMA活动。该活动为媒体和已部署Moonriver或Moonbeam的项目方提供了在主流Moonbeam非官方中文社区内介绍自己的项目信息,包括:项目介绍、团队介绍、技术优势和行业发展等,帮助社区内的Mo…

Spring Boot进阶(57):Spring中什么时候不要用@Autowired注入 | 超级详细,建议收藏

1. 前言🔥 注解Autowired,相信对于我们Java开发者而言并不陌生吧,在SpringBoot或SpringCloud框架中使用那是非常的广泛。但是当我们使用IDEA编辑器开发代码的时候,经常会发现Autowired 注解下面提示小黄线警告,我们把小…

Unity 任意数据在Scene窗口Debug

任意数据在Scene窗口Debug 🍔效果🥪食用方法 🍔效果 如下所示可以很方便的把需要Debug的数据绘制到Scene中(普通的Editor脚本只能够对MonoBehaviour进行Debug) 🥪食用方法 💡. 新建脚本继承Z…

127.0.0.1/linux常用dos命令

ls 命令 LS 命令用于查看目录的内容。默认情况下,此命令将显示当前工作目录的内容。如果要查看其他目录的内容,请键入 ls,然后键入目录的路径。例如,输入 LS / 家 / 用户名 / 文档查看的内容的文件。 ls命令是文件列表命令&#…

Cisco学习笔记(CCNA)——IP Subnetting

IP Subnetting 目录 IP地址 数制的计算 IP地址的分类 特殊IPv4地址 IP地址的组成 子网掩码 IPv4网络中的地址类型 划分子网 ​编辑CIDR IP地址 网络层概念 主机唯一的标识,保证主机间正常通信 一种网络编码,用来确定网络中一个节点 IP地址…

运维:Centos7安装解压版mysql5.7

目录 1、卸载Centos7默认自带的mariadb数据库,避免冲突 2、下载解压版mysql并安装 3、配置mysql 4、mysql客户端访问 MySQL 是一种开源的关系型数据库管理系统(RDBMS),它具有许多优点和一些缺点。以下是 MySQL 的主要优缺点&am…

【C语言+sqlite3 API接口】实现水果超市

实验内容: 假如我家开了个水果超市,有以下水果,想实现自动化管理,扫描二维码就能知道当前的水果状态,进货几天了, 好久需要再次进货,那些水果畅销,那些水果不畅销,那些水…

Python 算法基础篇:深度优先搜索( DFS )和广度优先搜索( BFS )

Python 算法基础篇:深度优先搜索( DFS )和广度优先搜索( BFS ) 引言 1. 深度优先搜索( DFS )算法概述2. 深度优先搜索( DFS )算法实现实例1:图的 DFS 遍历实例…

理性对道德的作用是很小的

从某种程度上说,理性对道德的作用是很重要的。理性能够帮助我们思考和评估道德问题,并提供合理的解决方案。它使我们能够运用逻辑和推理能力来分析情况,权衡利益和后果,并做出更明智的决策。 理性有助于我们超越个人感受或冲动&am…

如何在医疗器械行业运用IPD?

医疗器械是指单独或者组合使用于人体的仪器、设备、器具、材料或其他物品,包括所需要的软件。按安全性可分为低风险器械、中风险器械和高风险器械。其中低风险器械大都属于低值耗材,其中包括绷带、纱布、海绵、消毒液等;中度风险器械类包括体…

结构型模式 - 适配器模式

概述 如果去欧洲国家去旅游的话,他们的插座如下图最左边,是欧洲标准。而我们使用的插头如下图最右边的。因此我们的笔记本电脑,手机在当地不能直接充电。所以就需要一个插座转换器,转换器第1面插入当地的插座,第2面供…

linux之Ubuntu系列(五)用户管理 终端命令 su 切换用户

# 切换用户 zenxx:su - sup # 录入sup 密码 supxx:$ 切换root用户

增强匿名性:了解 HTTP 代理的作用

当你在浏览网页时,你的个人信息和上网痕迹都是暴露在公共网络之中的,任何人都可以轻松地获取到这些信息。而这种信息泄露不仅会威胁你的个人隐私,还会对你的网络安全带来潜在的风险。为了解决这个问题,HTTP代理应运而生。 而 IPR…

v-model指令获取常见表单项的内容 input,textarea,radio,checkbox,select

v-model指令获取常见表单项的内容 1. v-model 作用和语法2. v-model 获取常见表单项 1. v-model 作用和语法 作用: 给 表单元素 使用, 双向数据绑定 → 可以快速 获取 或 设置 表单元素内容 ① 数据变化 → 视图自动更新 ② 视图变化 → 数据自动更新语法: v-model ‘变量’ …

Day13 01-Linux介绍与安装教程

文章目录 第一章 Linux简介【了解】1.1 Linux的介绍1.2 Linux的两大阵营1.3 CentOS社区版介绍 第二章 Linux的安装【重要】2.1 VMWare&Parallels Desktop的安装2.1.1 VMWare的简介2.1.2 VMWare安装的注意事项2.1.3 Parallels Desktop的简介 2.2 VMWare安装Linux2.2.1 准备事…

UG\NX二次开发 获取工作部件的所有表达式,以及值

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan 简介: 获取工作部件的所有表达式,以及值。 效果: 代码: #include "me.hpp" #include <iostream> #include <sstream> #include <string> //double转string保留所有小…

Vue找到package.json中没有用到依赖并删除

引言 一切都是由于强迫症&#xff0c;我想把一个Vue项目中没有用到的依赖删除掉。 解决方法 depcheck Depcheck is a tool for analyzing the dependencies in a project to see: how each dependency is used, which dependencies are useless, and which dependencies are…