论文代码解读STPGNN

news2024/11/17 10:01:59

1.前言

本次代码文章来自于《2024-AAAI-Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting》,基本模型结构如下图所示:

文章讲解视频链接

代码开源链接

接下来就开始代码解读了。

 

2.代码解读 

class nconv(nn.Module):
    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        x = torch.einsum('ncvl,nwv->ncwl', (x, A))
        return x.contiguous()

让我们逐行分析:

  1. def __init__(self): 这是构造函数,初始化 nconv 类的实例。这里没有额外的初始化参数,因为它没有定义任何需要学习的参数。

  2. super(nconv, self).__init__(): 这一行调用了父类 nn.Module 的构造函数,确保了所有必要的初始化步骤得以执行。

  3. def forward(self, x, A): 定义了前向传播方法,这是每个 nn.Module 子类必须实现的方法。这个方法接受两个输入参数:

    • x: 输入张量,形状为 (N, C, V, L),其中 N 是批量大小,C 是通道数,V 是顶点数,L 是序列长度。
    • A: 图的邻接矩阵,形状为 (N, W, V),其中 W 是边的权重数,V 是顶点数。这里的 W 和 V 应该对应于图中的权重和顶点。
  4. x = torch.einsum('ncvl,nwv->ncwl', (x, A)) 这一行是核心计算部分,使用了 torch.einsum 函数来执行一个高效的多维数组乘法和求和操作。einsum 的第一个参数是一个字符串,描述了输入张量的维度标签和输出张量的维度标签。这里的标签解释如下:

    • 'ncvl' 表示输入张量 x 的四个维度:N(批量大小),C(通道数),V(顶点数),L(序列长度)。
    • 'nwv' 表示输入张量 A 的三个维度:N(批量大小),W(边的权重数),V(顶点数)。
    • 'ncwl' 表示输出张量的四个维度:N(批量大小),C(通道数),W(边的权重数),L(序列长度)。

    这个表达式实际上是在进行类似于图卷积的操作,其中输入特征 x 与图的邻接矩阵 A 相乘,以传播信息通过图的边。

  5. return x.contiguous() 最后返回处理后的张量。contiguous() 方法用于确保返回的张量在内存中是连续存储的,这对于后续可能的操作(如索引或视图转换)来说是必要的。

总的来说,nconv 模块接收输入特征和图的邻接矩阵,然后通过 torch.einsum 实现了一种特定的卷积操作,用于处理图结构数据。

class pconv(nn.Module):
    def __init__(self):
        super(pconv, self).__init__()

    def forward(self, x, A):
        x = torch.einsum('bcnt, bmn->bc', (x, A))
        return x.contiguous()

pconv 类定义了一个自定义的PyTorch模块,该模块实现了一种特定类型的卷积操作,其中输入张量与一个可学习的或预定义的邻接矩阵(A)进行乘法运算。这种类型的卷积通常在图神经网络(Graph Neural Networks, GNNs)中使用,其中A可以代表图的邻接矩阵,用于编码节点之间的连接性。下面是对 pconv 类的详细解释:

__init__ 方法

pconv 类继承自 nn.Module,这是所有PyTorch神经网络模块的基类。构造函数 __init__ 中没有定义任何额外的参数或层,这意味着 pconv 不包含任何可学习的参数,即它不会在训练过程中更新其权重。

forward 方法

forward 方法定义了当数据通过这个模块时的操作。它接受两个参数:

  • x: 输入张量,形状为 (batch_size, channels, nodes, time_steps)。其中:
    • batch_size 表示一个批次中的样本数量。
    • channels 表示每个节点在每个时间步上的特征数量。
    • nodes 表示图中的节点数量。
    • time_steps 表示时间序列的长度。
  • A: 邻接矩阵,形状为 (batch_size, nodes, nodes)A 可以是预定义的,也可以是可学习的,它编码了图中节点之间的关系。

内部操作

forward 方法内部,使用了 torch.einsum 函数来执行一个高效的矩阵乘法操作。einsum 是一个通用的函数,用于执行各种类型的张量运算,这里用来实现输入张量 x 与邻接矩阵 A 的乘法。

torch.einsum('bcnt, bmn->bc', (x, A)) 这行代码中,字符串 'bcnt, bmn->bc' 定义了输入张量的子标模式以及期望的输出模式。具体来说:

  • 'bcnt' 指代 x 的四个维度,分别对应于 batch size (b)、channels (c)、nodes (n) 和 time steps (t)。
  • 'bmn' 指代 A 的三个维度,分别对应于 batch size (b)、源节点 (m) 和目标节点 (n)。
  • 'bc' 是输出张量的模式,意味着输出将是一个二维张量,其维度为 batch size 和 channels。

输出

x = torch.einsum('bcnt, bmn->bc', (x, A)) 计算的结果是一个形状为 (batch_size, channels) 的张量,这表明对于每一个样本,我们得到了一个压缩后的特征表示,其中时间步和节点维度被聚合掉了。

最后,return x.contiguous() 确保返回的张量是连续存储的,这对于后续的某些操作可能很重要,例如当张量需要在GPU上进行高效计算时。这是因为非连续的内存布局可能会导致性能下降。

 

class linear(nn.Module):
    def __init__(self, c_in, c_out):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)

    def forward(self, x):
        return self.mlp(x)

linear 类是一个自定义的 PyTorch 模块,它实质上实现了一个线性变换。

 

class gcn(nn.Module):
    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super(gcn, self).__init__()
        self.nconv = nconv()
        c_in = (order * support_len + 1) * c_in
        self.mlp = linear(c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)
            out.append(x1)
            for k in range(2, self.order + 1):
                x2 = self.nconv(x1, a)
                out.append(x2)
                x1 = x2

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        return h

gcn 类定义了一个基于图卷积网络(Graph Convolutional Network, GCN)的模块,它在图结构数据上执行多阶卷积操作,以捕获不同层次的节点间关联。下面是对 gcn 类的详细解析:

初始化方法 __init__

在构造函数 __init__ 中,gcn 类继承自 nn.Module 并初始化以下组件:

  • nconv: 实例化 nconv 类,用于执行图卷积操作。
  • mlp: 实例化 linear 类,用于线性变换和聚合来自不同阶卷积的结果。
  • dropout: 设置 dropout 比率,用于正则化和防止过拟合。
  • order: 设置图卷积的阶数,控制卷积操作的深度,即卷积在图上扩展的层数。

c_in 的值被重新定义为 (order * support_len + 1) * c_in,这考虑到了 support_len 个支持矩阵在 order 阶的卷积中产生的特征通道数。+1 是因为原始输入 x 也会被拼接到最终的输出中。

前向传播方法 forward

forward 方法中,gcn 类执行以下操作:

  1. 初始化一个列表 out 来保存每一阶卷积的结果,首先添加原始输入 x
  2. 对于 support 中的每一个邻接矩阵 a,执行以下操作:
    • 使用 nconv 对输入 x 和邻接矩阵 a 进行一次卷积,结果存储在 x1 中,并添加到 out
    • 接下来,对于 order 中的每一阶(从 2 开始),重复使用 nconv 对前一阶的结果 x1 和同一个邻接矩阵 a 进行卷积,结果存储在 x2 中,再添加到 out,并将 x2 设为下一次迭代的输入 x1
  3. 将 out 中的所有结果在通道维度(dim=1)上进行拼接,形成一个包含所有阶卷积结果的张量 h
  4. 将 h 传递给 mlp 层,进行线性变换和通道数的调整,最终输出调整后的特征表示。

总结

gcn 类通过多次调用 nconv 模块来执行多阶图卷积,捕捉图中节点间的多层次关系。通过将不同阶的卷积结果拼接起来,它能够整合从局部到全局的节点信息。最后,mlp 层负责将这些多阶特征映射到期望的输出维度,以便进一步的处理或分类。这种设计使得 gcn 能够有效处理复杂图结构数据,并在诸如社交网络分析、分子结构预测等任务中发挥重要作用。

 

class pgcn(nn.Module):
    def __init__(self, c_in, c_out, dropout, support_len=3, order=2, temp=1):
        super(pgcn, self).__init__()
        self.nconv = nconv()
        self.temp = temp
        c_in = (order * support_len + 1) * c_in
        self.mlp = linear(c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)
            out.append(x1)
            for k in range(2, self.order + 1):
                x2 = self.nconv(x1, a)
                out.append(x2)
                x1 = x2

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = h[:,:,:,-h.size(3):-self.temp]
        return h

pgcn 类定义了一个个性化的图卷积网络(Personalized Graph Convolutional Network)模块,它在图卷积的基础上引入了个性化参数,允许模型在处理图数据时考虑到更加细致的节点特性或时间序列特性。下面是对 pgcn 类的详细解析:

初始化方法 __init__

pgcn 类继承自 nn.Module 并初始化以下组件:

  • nconv: 实例化 nconv 类,用于执行图卷积操作。
  • temp: 一个个性化参数,用于在输出中裁剪时间序列数据,这可能用于处理具有周期性或季节性模式的时间序列数据,通过移除某些时间点的数据来增强模型对特定时间模式的学习能力。
  • mlp: 实例化 linear 类,用于线性变换和聚合来自不同阶卷积的结果。
  • dropout: 设置 dropout 比率,用于正则化和防止过拟合。
  • order: 设置图卷积的阶数,控制卷积操作的深度。

gcn 类似,c_in 的值被重新定义为 (order * support_len + 1) * c_in,考虑到了多阶卷积产生的特征通道数。

前向传播方法 forward

forward 方法中,pgcn 类执行的操作与 gcn 类似,但在输出阶段有一个关键的区别:

  1. 初始化一个列表 out 来保存每一阶卷积的结果,首先添加原始输入 x
  2. 对于 support 中的每一个邻接矩阵 a,执行多阶卷积操作,将结果存储在 out 中。
  3. 将 out 中的所有结果在通道维度(dim=1)上进行拼接,形成一个包含所有阶卷积结果的张量 h
  4. 将 h 传递给 mlp 层,进行线性变换和通道数的调整。
  5. 个性化裁剪:在 h 上执行一个个性化裁剪操作,通过 h = h[:,:,:,-h.size(3):-self.temp],这将从 h 的最后一个维度(通常是时间序列的长度)开始,去除从末尾开始的 self.temp 个时间点的数据。这种裁剪可以用于去除不需要的时间点,例如去除最近的短期波动,以便模型更专注于长期趋势或周期性模式。

总结

pgcn 类通过在标准图卷积网络的基础上引入个性化参数 temp,增强了模型处理时间序列图数据的能力。通过裁剪时间序列的末端,模型可以更好地聚焦于数据中的长期模式,这对于处理具有季节性或周期性特性的数据集尤为重要。

 

    def __init__(self, device, num_nodes, dropout=0.3, topk=35,
                 out_dim=12, residual_channels=16, dilation_channels=16, end_channels=512,
                 kernel_size=2, blocks=4, layers=2, days=288, dims=40, order=2, in_dim=9, normalization="batch"):
        super(STPGNN, self).__init__()
        skip_channels = 8
        self.alpha = nn.Parameter(torch.tensor(-5.0))  
        self.topk = topk
        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.normal = nn.ModuleList()
        self.gconv = nn.ModuleList()

        self.residual_convs_a = nn.ModuleList()
        self.skip_convs_a = nn.ModuleList()
        self.normal_a = nn.ModuleList()
        self.pgconv = nn.ModuleList()

        self.start_conv_a = nn.Conv2d(in_channels=in_dim,
                                      out_channels=1,
                                      kernel_size=(1, 1))
        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=residual_channels,
                                    kernel_size=(1, 1))

        receptive_field = 1

        self.supports_len = 1
        self.nodevec_p1 = nn.Parameter(torch.randn(days, dims).to(device), requires_grad=True).to(device)
        self.nodevec_p2 = nn.Parameter(torch.randn(num_nodes, dims).to(device), requires_grad=True).to(device)
        self.nodevec_p3 = nn.Parameter(torch.randn(num_nodes, dims).to(device), requires_grad=True).to(device)
        self.nodevec_pk = nn.Parameter(torch.randn(dims, dims, dims).to(device), requires_grad=True).to(device)

这段代码是 STPGNN 类的初始化方法 __init__ 的一部分,它主要负责构建模型的架构和初始化必要的参数。下面是详细的解析:

构建网络组件

  • Convolution Layers and Residual Connections:

    • self.filter_convsself.gate_convs: 这两个列表存储了因果卷积(Causal Convolution)层,它们用于处理时间序列数据,通过滤波器(filter)和门控(gate)机制捕捉时间依赖性。
    • self.residual_convsself.skip_convs: 这些列表分别存储了残差卷积和跳跃连接卷积层,用于在网络中建立残差连接和跳跃连接,有助于梯度传播并避免深度网络中的梯度消失/爆炸问题。
    • self.normal: 这个列表包含了归一化层,如批量归一化(Batch Normalization)或层归一化(Layer Normalization),用于加速训练过程和提升模型性能。
  • Graph Convolution Layers:

    • self.gconvself.pgconv: 这两个列表分别存储了图卷积(Graph Convolution)层和个性化图卷积(Personalized Graph Convolution)层,用于处理图结构数据,捕捉节点间的空间依赖性。
    • self.residual_convs_aself.skip_convs_aself.normal_a: 这些组件与前面提到的组件类似,但是专门用于辅助分支,可能是为了处理特定类型的信息或者用于构建个性化的图卷积。
  • Input Layers:

    • self.start_conv_aself.start_conv: 这两个卷积层用于调整输入数据的维度,self.start_conv_a 可能用于特定的辅助特征提取,而 self.start_conv 则是主输入层,用于调整输入特征至残差通道数。

参数初始化

  • Receptive Field: receptive_field 是一个变量,初始化为1,它表示网络能够感知的时间序列的宽度。随着网络的深入,这个值会增加,表示网络可以捕捉到更远的历史信息。

  • Node Embeddings and Adjacency Matrix Parameters:

    • self.nodevec_p1self.nodevec_p2self.nodevec_p3self.nodevec_pk: 这些参数是节点嵌入向量和用于构建动态邻接矩阵的参数,它们在训练过程中是可学习的。self.nodevec_p1 代表时间相关的节点嵌入,self.nodevec_p2 和 self.nodevec_p3 代表空间相关的节点嵌入,而 self.nodevec_pk 用于构建核心节点之间的关联,这些参数一起用于构建一个适应性更强的图结构,使得模型能够根据输入数据动态调整节点之间的关联强度。

通过上述组件和参数的初始化,STPGNN 构建了一个能够处理时空序列数据的深度学习模型,结合了时间序列分析和图结构数据处理的优势,适用于如交通流量预测、环境监测等需要同时考虑时间和空间依赖性的任务。

 
        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # dilated convolutions
                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=(1, kernel_size), dilation=new_dilation))

                self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=(1, kernel_size), dilation=new_dilation))

                self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=(1, 1)))

                self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=(1, 1)))

                self.residual_convs_a.append(nn.Conv1d(in_channels=dilation_channels,
                                                       out_channels=residual_channels,
                                                       kernel_size=(1, 1)))
                
                self.pgconv.append(
                    pgcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len, order=order, temp=new_dilation))
                
                self.gconv.append(
                    gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len, order=order))
                
                if normalization == "batch":
                    self.normal.append(nn.BatchNorm2d(residual_channels))
                    self.normal_a.append(nn.BatchNorm2d(residual_channels))
                elif normalization == "layer":
                    self.normal.append(nn.LayerNorm([residual_channels, num_nodes, 13 - receptive_field - new_dilation + 1]))
                    self.normal_a.append(nn.LayerNorm([residual_channels, num_nodes, 13 - receptive_field - new_dilation + 1]))
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2

这段代码是 STPGNN 类初始化方法的一部分,它主要负责构建多层因果卷积块,这些块是构成整个网络的基础单元。以下是详细解析:

构建因果卷积块

  • Looping through blocks and layers:
    • 外层循环 for b in range(blocks) 控制着构建的残差块数量,每个块由多个层组成。
    • 内层循环 for i in range(layers) 控制着每个残差块内的层数量。

卷积层的配置

  • Dilated Convolutions:
    • self.filter_convs 和 self.gate_convs 分别存储了滤波器和门控机制的扩张卷积层,用于捕捉时间序列数据中的长期依赖关系。扩张卷积(Dilated Convolution)通过增加卷积核之间的空洞来扩大感受野,而无需增加网络深度或输入尺寸。
    • self.residual_convs 存储了用于残差连接的1x1卷积层,它们用于将输入与扩张卷积的输出相加,形成残差块的核心部分。
    • self.skip_convs 存储了用于跳跃连接的1x1卷积层,它们将中间层的输出传递到网络的最后阶段,帮助网络学习长期依赖。

图卷积层的配置

  • Graph Convolution Layers:
    • self.pgconv 和 self.gconv 分别存储了个性化图卷积(Personalized Graph Convolution)和图卷积(Graph Convolution)层,用于处理图结构数据,捕捉节点间的空间依赖性。这些层在每个因果卷积层之后被调用,将时间序列特征与图结构特征相结合。

归一化层的配置

  • Normalization Layers:
    • 根据 normalization 参数的值,选择批量归一化(nn.BatchNorm2d)或层归一化(nn.LayerNorm)。归一化层有助于加速训练过程,减少内部协变量偏移,提高模型的泛化能力。

扩张因子和感受野的更新

  • Updating Dilation Factor and Receptive Field:
    • new_dilation *= 2 更新了扩张因子,每次内层循环都会翻倍,这样扩张卷积的感受野会随着层数的增加而指数级增长。
    • receptive_field += additional_scope 和 additional_scope *= 2 更新了网络的感受野,反映了随着扩张卷积的深入,网络能够捕捉到的时间序列的宽度也在增加。

通过这种方式,STPGNN 构建了一个能够同时处理时间序列数据和图结构数据的深度学习模型,能够捕捉到数据中的长期依赖和空间依赖,非常适合应用于如交通流量预测等需要同时考虑时间和空间因素的任务。

 

    def dgconstruct(self, time_embedding, source_embedding, target_embedding, core_embedding):
        adp = torch.einsum('ai, ijk->ajk', time_embedding, core_embedding)
        adp = torch.einsum('bj, ajk->abk', source_embedding, adp)
        adp = torch.einsum('ck, abk->abc', target_embedding, adp)
        adp = F.softmax(F.relu(adp), dim=2)
        return adp
    
    def pivotalconstruct(self, x, adj, k):
        x = x.squeeze(1)
        x = x.sum(dim=0)
        y = x.sum(dim=1).unsqueeze(0)
        adjp = torch.einsum('ij, jk->ik', x[:,:-1], x.transpose(0, 1)[1:,:]) / y
        adjp = adjp * adj
        score = adjp.sum(dim=0) + adjp.sum(dim=1)
        N = x.size(0)
        _, topk_indices = torch.topk(score,k)
        mask = torch.zeros(N, dtype=torch.bool,device=x.device)
        mask[topk_indices] = True
        masked_matrix = adjp * mask.unsqueeze(1) * mask.unsqueeze(0)
        adjp = F.softmax(F.relu(masked_matrix), dim=1)
        return (adjp.unsqueeze(0))

这段代码定义了两个函数,dgconstructpivotalconstruct,它们分别用于构建动态图结构和识别关键节点。

dgconstruct函数接受四个参数:time_embedding(时间嵌入),source_embedding(源节点嵌入),target_embedding(目标节点嵌入),和core_embedding(核心嵌入)。此函数的目标是通过四者间的交互作用来构建动态的邻接矩阵adp,这个矩阵描述了在特定时间下,源节点与目标节点之间的影响强度。具体步骤如下:

  1. 首先,使用torch.einsum函数,将时间嵌入与核心嵌入相乘,生成一个中间矩阵adp
  2. 接下来,将源节点嵌入与上一步得到的adp相乘,进一步细化节点间的影响关系。
  3. 最后,目标节点嵌入与当前的adp相乘,完成动态邻接矩阵的构建。
  4. 应用ReLU激活函数和Softmax归一化函数,使矩阵元素非负且按列归一化,确保每个源节点到所有目标节点的边权总和为1。

pivotalconstruct函数则用于识别交通网络中的关键节点。它接受三个参数:x(输入特征矩阵),adj(静态邻接矩阵),和k(关键节点数量)。以下是详细步骤:

  1. 将输入特征矩阵x的维度调整,使其变为二维,然后沿列方向求和,得到节点的时间序列特征。
  2. 对节点的时间序列特征进行行求和,得到节点的总流量,然后将其转置并扩展维度,便于后续计算。
  3. 利用torch.einsum计算节点间的时间序列特征相互作用矩阵adjp,并通过除以节点总流量进行标准化。
  4. adjp与静态邻接矩阵adj相乘,过滤掉不存在物理连接的节点间关系。
  5. 计算每个节点的“重要性”分数,这是通过将adjp矩阵的行和列求和得到的。
  6. 使用torch.topk函数找到具有最高分数的前k个节点,这些节点即为关键节点。
  7. 创建一个布尔掩码mask,用于标记哪些节点是关键节点。
  8. 应用掩码到adjp矩阵,仅保留关键节点间的关系。
  9. 最后,对关键节点的邻接矩阵应用ReLU和Softmax,确保矩阵非负且按列归一化,得到最终的关键节点邻接矩阵adjp,并增加一个维度以适应后续操作。

通过以上两个函数,dgconstruct构建了基于动态特征的邻接矩阵,而pivotalconstruct则识别出了网络中对交通流动有重要影响的关键节点及其相互关系。这两个矩阵将用于后续的图神经网络层,以捕捉交通网络中的空间和时间依赖性。

    def forward(self, inputs, ind):
        """
        input: (B, F, N, T)
        """
        in_len = inputs.size(3)
        num_nodes = inputs.size(2)
        if in_len < self.receptive_field:
            xo = nn.functional.pad(inputs, (self.receptive_field - in_len, 0, 0, 0))
        else:
            xo = inputs
        x = self.start_conv(xo[:, [0]])
        x_a = self.start_conv_a(xo[:, [0]])
        skip = 0
        adj = self.dgconstruct(self.nodevec_p1[ind], self.nodevec_p2, self.nodevec_p3, self.nodevec_pk)
        pivweight = nn.Parameter(torch.randn(num_nodes, num_nodes).to(x.device), requires_grad=True).to(x.device)
        adj_p = self.pivotalconstruct(x_a, pivweight, self.topk)
        supports = [adj]
        supports_a = [adj_p]
    
        for i in range(self.blocks * self.layers):
            residual = x
            filter = self.filter_convs[i](residual)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](residual)
            gate = torch.sigmoid(gate)
            x = filter * gate
            x_a = self.pgconv[i](residual, supports_a)
            x = self.gconv[i](x, supports)
            alpha_sigmoid = torch.sigmoid(self.alpha)  
            x = alpha_sigmoid * x_a +  (1 - alpha_sigmoid) * x
            x = x + residual[:, :, :, -x.size(3):]
            s = x
            s = self.skip_convs[i](s)
            if isinstance(skip, int):  # B F N T
                skip = s.transpose(2, 3).reshape([s.shape[0], -1, s.shape[2], 1]).contiguous()
            else:
                skip = torch.cat([s.transpose(2, 3).reshape([s.shape[0], -1, s.shape[2], 1]), skip], dim=1).contiguous()
            x = self.normal[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x

这段代码实现了一个深度学习模型的前向传播过程,该模型被设计用于处理时序数据,如交通流量预测。模型的输入是一个四维张量,形状为(B, F, N, T),其中B代表批量大小,F代表特征数,N代表节点数,T代表时间步长。模型的架构包含了卷积、门控机制、残差连接、跳过连接以及图神经网络组件。

  1. 输入预处理

    • 首先检查输入的时间长度T是否小于模型的受感野(receptive_field),如果小于,则使用nn.functional.pad对输入进行填充,确保输入的时间序列长度满足要求。
  2. 起始卷积

    • 使用start_convstart_conv_a进行起始卷积,分别对输入的首个特征通道进行处理,得到x和x_a。
  3. 动态图构建

    • dgconstruct函数用于构建动态邻接矩阵,根据节点特征构建图结构。这将用于图卷积操作。
    • pivotalconstruct函数用于构建关键节点图,它使用x_a和关键节点权重矩阵pivweight来构造关键节点的邻接矩阵adj_p。
  4. 多层残差模块

    • 模型包含多个残差块,每个残差块由多层组成。每层首先应用残差连接,之后进行卷积操作,包括滤波器和门控机制。
    • 滤波器和门控卷积的结果分别经过tanh和sigmoid激活函数,之后相乘,产生门控信号控制信息流。
    • 使用pgconv进行关键节点图上的卷积,并使用gconv进行常规图卷积。
    • 引入一个可学习的参数alpha_sigmoid,通过sigmoid函数得到一个0到1之间的值,用于加权融合关键节点图卷积和常规图卷积的结果。
    • 结果再与残差项相加,之后进行跳过连接,将结果存储在skip变量中,用于后续的跳跃连接操作。
  5. 跳跃连接与输出

    • 跳跃连接将每一层的输出收集起来,进行整合,形成skip变量。
    • 经过跳跃连接后,结果经过end_conv_1end_conv_2卷积层处理,最终得到模型的输出。

整个模型通过这种结构能够同时捕捉空间和时间依赖性,特别是在处理像交通流量预测这样的问题时,它能有效利用图结构和时序特性,从而做出更准确的预测。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1799962.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

离散数学答疑 3

&#xff5e;A&#xff1a;A的补集 有时候空集是元素&#xff0c;有时候就是纯粹的空集 A-B的定义&#xff1a; 笛卡尔积&#xff1a; 求等价关系&#xff1a;先求划分再一一列举 不同划分&#xff1a;分几块。一块&#xff1a;两块&#xff1a;三块&#xff1a;分别计算 Ix是…

2024-5-7 石群电路-26

2024-6-7&#xff0c;星期五&#xff0c;15:00&#xff0c;天气&#xff1a;阴转小雨&#xff0c;心情&#xff1a;晴。今天虽然是阴雨天&#xff0c;但是心情不能差哦&#xff0c;离答辩越来越近了&#xff0c;今天学完习好好准备准备ppt&#xff0c;加油学习喽~ 今日观看了石…

Vue 学习笔记 总结

Vue.js 教程 | 菜鸟教程 (runoob.com) 放一下课上的内容 Vue练习 1、练习要求和实验2的用户注册一样&#xff0c;当用户输入后&#xff0c;能在下方显示用户输入的各项内容&#xff08;不需要实现【重置】按钮&#xff09; 2、实验报告中的实验小结部分来谈谈用JS、jQuery和…

接口幂等性设计(5 大方案罗列)

结合案例、列举场景的接口幂等性设计方案。 方案 1. 状态机 业务场景&#xff0c;数据审核成功后进行短信通知&#xff0c;或者是订单状态变成已支付后&#xff0c;短信通知用户订单生成的详细信息&#xff0c;等等和状态有关的操作。 假设 status&#xff1a;0&#xff08;待…

vue改造四级树状可输入table

vue改造四级树状可输入table <template><div class"dimension_wary"><div class"itemHeader"><div class"target"></div><div class"sort">X2</div><div class"weight">…

xiaolingcoding 图解网络笔记——基础篇

文章目录 参考一、网络模型有哪几层DMANAPI 机制二、键入网址到网页显示&#xff0c;期间发生了什么&#xff1f;1. HTTP2. DNS3. 协议栈4. TCP5. IP6. MAC7. 网卡8. 交换机9. 路由器10. 服务器 与 客户端的互相扒皮&#xff08;添加、删除头部信息&#xff09;参考图HTTP 请求…

部署kubesphere报错

安装kubesphere报错命名空间terminted [rootk8smaster ~]# kubectl apply -f kubesphere-installer.yaml Warning: apiextensions.k8s.io/v1beta1 CustomResourceDefinition is deprecated in v1.16, unavailable in v1.22; use apiextensions.k8s.io/v1 CustomResourceDefini…

【数据结构初阶】--- 顺序表

顺序表&#xff0c;好像学C语言时从来没听过&#xff0c;实际上就是给数组穿了层衣服&#xff0c;本质是一模一样的。 这里的顺序表实际是定义了一个结构体&#xff0c;设计各种函数来实现它的功能&#xff0c;比如说数组中的增删改查插入&#xff0c;这些基本操作其实平时就会…

【YOLOV8】4.图片分类-训练自己的数据集

Yolo8出来一段时间了,包含了目标检测、实例分割、人体姿态预测、旋转目标检测、图像分类等功能,所以想花点时间总结记录一下这几个功能的使用方法和自定义数据集需要注意的一些问题,本篇是第四篇,图像分类功能,自定义数据集的训练。 YOLO(You Only Look Once)是一种流行的…

拥抱生态农业,享受绿色生活

随着人们对健康生活的追求日益增强&#xff0c;生态农业逐渐成为人们关注的焦点。我们深知生态农业对于保护生态环境、提高农产品品质的重要性&#xff0c;因此&#xff0c;我们积极推广生态农业理念&#xff0c;让更多的人了解并参与到生态农业的实践中来。 生态农业的蓝总说&…

ALSA 用例配置

ALSA 用例配置。参考 ALSA 用例配置 来了解更详细信息。 ALSA 用例配置 用例配置文件使用 配置文件 语法来定义静态配置树。该树在运行时根据配置树中的条件和动态变量进行评估&#xff08;修改&#xff09;。使用 用例接口 API 解析结果并将其导出到应用程序。 配置目录和主…

苹果手机618大降价重登销量榜首 红米K70pro为何成京东618国产手机之光

今天的618已经好几天了&#xff0c;很多买有机的已经下单&#xff0c;不出意外苹果15系列手机仍然是最卖座的手机&#xff0c;大家虽然口号喊得很响身体却是诚实的。但令人感到意外的是&#xff0c;今年618国产手机的第一把交椅确实红米K70系列&#xff0c;说好的支持华为呢&am…

给孩子的端午节礼物:最新初中数学思维导图大合集+衡水高考学霸笔记,可下载打印!

大家好哇&#xff01;端午节到了&#xff0c;阿星给家里有孩子的伙伴们一份礼物哦&#xff01;今天给大家带来一个超级实用的学习神器——思维导图法&#xff0c;最新版的初中数学思维导图大合集&#xff01; 这可不是我吹哦&#xff0c;连哈佛、剑桥大学都在用的高级学习方法…

常见硬件工程师面试题(一)

大家好&#xff0c;我是山羊君Goat。 对于硬件工程师&#xff0c;学习的东西主要和电路硬件相关&#xff0c;所以在硬件工程师的面试中&#xff0c;对于经验是十分看重的&#xff0c;像PCB设计&#xff0c;电路设计原理&#xff0c;模拟电路&#xff0c;数字电路等等相关的知识…

webman中创建udp服务

webman是workerman的web开发框架 可以很容易的开启udp服务 tcp建议使用gatewayworker webman GatewayWorker插件 创建udp服务: config/process.php中加入: return [// File update detection and automatic reloadmonitor > [ ...........], udp > [handler > p…

转速传感器介绍

一、概述 RPM&#xff08;Revolutions Per Minute&#xff09;转速传感器是一种用于测量旋转机械设备转速的传感器。它可以检测旋转部件上的特定位置标记&#xff08;如齿轮、凸起或磁铁&#xff09;&#xff0c;并根据这些标记的通过频率来计算转速。发电额定频率是50hz和60z…

链表题目练习----重排链表

这道题会联系到前面写的一篇文章----快慢指针相关经典问题。 重排链表 指针法 这道题乍一看&#xff0c;好像有点难处理&#xff0c;但如果仔细观察就会发现&#xff0c;这道题是查找中间节点反转链表链表的合并问题&#xff0c;具体细节有些不同&#xff0c;这个在反装中间链…

Apache Doris 基础 -- 数据表设计(表索引)

1、索引概述 索引用于帮助快速过滤或搜索数据。目前&#xff0c;Doris支持两种类型的索引:内置智能索引和用户创建的二级索引。 内置智能索引 排序键和前缀索引:Apache Doris基于排序键以有序的方式存储数据。它为每1024行数据创建一个前缀索引。索引中的键是当前1024行组的…

国产主流软硬件厂商生态分析

国产领域主流厂商汇总 信创&#xff0c;即信息技术应用创新&#xff0c;由“信息技术应用创新工作委员会”于2016年3月4日发起&#xff0c;是专注于软硬件关键技术研发、应用与服务的非营利性组织。作为科技自强的关键力量&#xff0c;信创在我国信息化建设中占据核心地位&…

小白必学!场外期权的交易模式

场外期权的交易模式 随着金融市场的深化与创新&#xff0c;场外期权交易作为一种灵活多样的金融衍生品交易方式&#xff0c;正逐渐成为投资者关注的焦点。场外期权&#xff0c;顾名思义&#xff0c;是在非交易所市场进行的期权交易&#xff0c;与交易所期权有着显著的区别。那…