torch_geometric使用手册-Creating Message Passing Networks(专题二)

news2024/11/24 9:31:57

创建消息传递网络 (Message Passing Networks)

在图神经网络中,将卷积操作推广到不规则域通常表现为一种邻域聚合 (neighborhood aggregation)消息传递 (message passing) 机制。
这一机制通过聚合节点的邻居信息,更新每个节点的特征。

以下公式描述了消息传递机制的基本形式:

公式解释

x i ( k ) = γ ( k ) ( x i ( k − 1 ) , ⨁ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) xi(k)=γ(k) xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)

  • x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k1): 表示第 k − 1 k-1 k1 层时节点 i i i 的特征。
  • e j , i \mathbf{e}_{j,i} ej,i: 表示从节点 j j j 到节点 i i i 的边特征(可选)。
  • N ( i ) \mathcal{N}(i) N(i): 节点 i i i 的邻居节点集合。
  • ϕ ( k ) \phi^{(k)} ϕ(k): 消息函数 (message function),生成从邻居节点 j j j 到节点 i i i 的消息。
  • ⨁ \bigoplus : 聚合函数 (aggregation function),例如加和 (sum)、均值 (mean) 或最大值 (max)。
  • γ ( k ) \gamma^{(k)} γ(k): 更新函数 (update function),结合节点本身的特征与聚合后的消息。

PyTorch Geometric (PyG) 提供了一个名为 MessagePassing 的基类,专门用于实现基于消息传递机制的图神经网络(GNN)。这个类封装了消息传递中的许多细节,开发者只需要定义核心函数,例如消息构造(message)、特征更新(update),以及选择合适的聚合方式(aggr),即可实现复杂的 GNN 算法。


核心概念与方法解析

1. 构造 MessagePassing 基类

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
功能
  • 定义消息传递的聚合方式:

    • aggr: 表示如何将来自邻居节点的消息聚合到目标节点。
      • add(加和):计算邻居节点消息的加权和。
      • mean(平均):取邻居节点消息的加权平均值。
      • max(最大值):选择邻居节点消息的最大值。
  • 定义消息的传递方向:

    • flow:
      • "source_to_target":从源节点传递消息到目标节点。
      • "target_to_source":从目标节点向源节点传递消息。
  • node_dim:

    • 指定在哪一维度上传递节点特征。通常是倒数第二维(默认为 -2),适配节点特征张量。

2. 消息传递的入口:propagate 方法

MessagePassing.propagate(edge_index, size=None, **kwargs)
功能
  • 触发消息传递过程,从边索引和输入特征开始,依次执行:
    1. 消息构造message):生成从邻居节点传来的消息。
    2. 消息聚合aggregate,自动完成):将邻居节点的消息聚合到目标节点。
    3. 特征更新update):更新目标节点的最终特征。

注意: 这是入口函数,类似forward的操作,会调用messageaggregateupdate函数.

参数
  • edge_index:

    • 图的边索引,形状为 [2, num_edges]
    • 第一行表示源节点,第二行表示目标节点。
  • size:

    • 图中节点的数量或维度。
    • 对于普通图,默认为 [num_nodes, num_nodes];对于二分图(bipartite graph),可以传递 (N, M),分别表示源节点和目标节点数量。
  • kwargs:

    • 其他参数,如节点特征 x,边特征 edge_attr 等。

3. 消息生成:message 方法

MessagePassing.message(...)
功能
  • 根据每条边的两端节点特征(源节点和目标节点)以及边特征,生成要传递的消息。
参数
  • 默认情况下:
    • x_j: 源节点的特征。
    • x_i: 目标节点的特征。
    • edge_attr: 边的特征(如果存在)。
自动变量映射

propagate 内部,会根据 edge_index 自动将输入特征分为:

  • x_j:从源节点出发的特征。
  • x_i:传递到目标节点的特征。

4. 特征更新:update 方法

MessagePassing.update(aggr_out, ...)
功能
  • 根据聚合后的结果 aggr_out,计算目标节点的最终特征。
参数
  • aggr_out: 聚合后的邻居节点消息。
  • 可以使用其他参数,例如目标节点本身的初始特征。

5. 应用流程总结

  1. 消息生成

    • 根据边和节点特征,生成从邻居节点传递的消息(通过 message 方法)。
  2. 消息聚合

    • 使用选定的聚合方式(aggr 参数,如加和或平均),将消息聚合到目标节点。
  3. 特征更新

    • 在目标节点上应用更新规则,生成最终的节点特征(通过 update 方法)。

示例:实现经典的 GCN 和 EdgeConv

实现 GCN 层(Graph Convolutional Layer)

GCN 层的数学定义如下:

x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( W ⊤ ⋅ x j ( k − 1 ) ) + b \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b} xi(k)=jN(i){i}deg(i) deg(j) 1(Wxj(k1))+b

  • 邻居节点特征通过一个权重矩阵 W \mathbf{W} W 进行变换。
  • 然后,按照节点度进行归一化。
  • 最后,对邻居节点特征进行聚合并添加偏置项 b \mathbf{b} b

这个公式可以拆解为以下几个步骤:

  1. 为邻接矩阵添加自环(self-loops)
  2. 对节点特征矩阵进行线性变换
  3. 计算归一化系数
  4. 对特征进行归一化处理
  5. 聚合邻居节点特征(使用"加和"操作,"add" 聚合)。
  6. 对聚合结果加上最终的偏置项

在实现过程中:

  • 步骤 1-3 通常在消息传递(message passing)前完成。
  • 步骤 4-5 使用 MessagePassing 基类轻松实现。

以下是完整的 GCN 层实现代码:

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "加和"聚合 (Step 5)
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # Step 1: 添加自环到邻接矩阵
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: 对节点特征进行线性变换
        x = self.lin(x)

        # Step 3: 计算归一化系数
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: 开始消息传递
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: 添加最终的偏置项
        out = out + self.bias

        return out

    def message(self, x_j, norm):
        # 对节点特征进行归一化 (Step 4)
        return norm.view(-1, 1) * x_j

实现 EdgeConv(边卷积)

边卷积用于处理图结构或点云,其数学定义为:

x i ( k ) = max ⁡ j ∈ N ( i ) h Θ ( x i ( k − 1 ) , x j ( k − 1 ) − x i ( k − 1 ) ) \mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right) xi(k)=jN(i)maxhΘ(xi(k1),xj(k1)xi(k1))

其中, h Θ h_{\mathbf{\Theta}} hΘ 是一个多层感知机(MLP)。
与 GCN 类似,EdgeConv 层也基于 MessagePassing 实现,但使用的是 "max" 聚合方式。

以下是 EdgeConv 的实现代码:

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # "最大值" 聚合
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # 计算相对特征并输入到 MLP
        tmp = torch.cat([x_i, x_j - x_i], dim=1)
        return self.mlp(tmp)

EdgeConv 实际上是一个动态卷积,每一层都在特征空间中根据最近邻重新计算图。
PyG 提供了一个 GPU 加速的 k-NN 图生成方法 knn_graph

from torch_geometric.nn import knn_graph

class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, edge_index)

DynamicEdgeConv 动态生成 k-NN 图,然后调用 EdgeConvforward 方法。


练习题翻译

关于 GCNConv:

  1. rowcol 包含什么信息?
  2. degree 方法的作用是什么?
  3. 为什么用 degree(col, ...) 而不是 degree(row, ...)
  4. deg_inv_sqrt[col]deg_inv_sqrt[row] 的作用是什么?
  5. message 方法中,x_j 包含什么信息?如果 self.lin 是恒等函数,x_j 的内容具体是什么?
  6. 添加一个 update 方法,使其将变换后的中心节点特征添加到聚合输出中。

关于 EdgeConv:

  1. x_ix_j - x_i 是什么?
  2. torch.cat([x_i, x_j - x_i], dim=1) 的作用是什么?为什么是 dim=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2246583.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【大数据学习 | Spark-Core】spark-shell开发

spark的代码分为两种 本地代码在driver端直接解析执行没有后续 集群代码,会在driver端进行解析,然后让多个机器进行集群形式的执行计算 spark-shell --master spark://nn1:7077 --executor-cores 2 --executor-memory 2G sc.textFile("/home/ha…

增量预训练(Pretrain)样本拼接篇

增量预训练(Pretrain)样本拼接篇 一、Pretrain阶段,为什么需要拼接拼接? 为了提高pretrain效率、拓展LLM最大长度,随机将若干条短文本进行拼接是pretrain阶段常见手段。 二、有哪些拼接方式? 拼接方式一…

【AI最前线】DP双像素sensor相关的AI算法全集:深度估计、图像去模糊去雨去雾恢复、图像重建、自动对焦

Dual Pixel 简介 双像素是成像系统的感光元器件中单帧同时生成的图像:通过双像素可以实现:深度估计、图像去模糊去雨去雾恢复、图像重建 成像原理来源如上,也有遮罩等方式的pd生成,如图双像素视图可以看到光圈的不同一半&#x…

从零开始-VitePress 构建个人博客上传GitHub自动构建访问

从零开始-VitePress 构建个人博客上传GitHub自动构建访问 序言 VitePress 官网:VitePress 中文版 1. 什么是 VitePress VitePress 是一个静态站点生成器 (SSG),专为构建快速、以内容为中心的站点而设计。简而言之,VitePress 获取用 Markdown…

使用uniapp编写APP的文件上传

使用uniapp插件文件选择、文件上传组件(图片,视频,文件等) - DCloud 插件市场 实用效果: 缺陷是只能一个一个单独上传

【51单片机】红外遥控

学习使用的开发板:STC89C52RC/LE52RC 编程软件:Keil5 烧录软件:stc-isp 开发板实图: 文章目录 红外遥控硬件电路 NEC协议编码编程实例LCD1602显示Data红外遥控控制扇叶转速 红外遥控 红外遥控是利用红外光进行通信的设备&#…

【解决】Unity TMPro字体中文显示错误/不全问题

问题描述:字体变成方块 原因:字体资源所承载的长度有限 1.找一个中文字体放入Assets中 2.选中字体创建为TMPro 字体资源 3.选中创建好的字体资源(蓝色的大F) 在右边的属性中找到Atlas Width h和 Atlas Heigth,修改的大一点&…

深度学习:GPT-1的MindSpore实践

GPT-1简介 GPT-1(Generative Pre-trained Transformer)是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构,具有如下创新点: NLP领域的迁移学习:通过最…

CKA认证 | Day2 K8s内部监控与日志

第三章 Kubernetes监控与日志 1、查看集群资源状态 在 Kubernetes 集群中,查看集群资源状态和组件状态是非常重要的操作。以下是一些常用的命令和解释,帮助你更好地管理和监控 Kubernetes 集群。 1.1 查看master组件状态 Kubernetes 的 Master 组件包…

概念解读|K8s/容器云/裸金属/云原生...这些都有什么区别?

随着容器技术的日渐成熟,不少企业用户都对应用系统开展了容器化改造。而在容器基础架构层面,很多运维人员都更熟悉虚拟化环境,对“容器圈”的各种概念容易混淆:容器就是 Kubernetes 吗?容器云又是什么?容器…

JDBC编程---Java

目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源,并设置数据库所在的位置,三条固定写法 2.建立和数据库服务器之间的连接,连接好了后&#xff…

移动充储机器人“小奥”的多场景应用(上)

在当前现代化城市交通体系中,移动充储机器人“小奥”发挥着至关重要的作用。该机器人不仅是一个简单的设备,而是一个集成了高科技的移动充电站,为新能源汽车提供了一种前所未有的便捷充电解决方案。该机器人配备了先进的电池管理系统&#xf…

element dialog会隐藏body scroll 导致tab抖动 解决方案如下

element dialog会隐藏body scroll 导致tab抖动 解决方案如下 在dialog标签添加 :lockScroll"false"搞定

Android 功耗分析(底层篇)

最近在网上发现关于功耗分析系列的文章很少,介绍详细的更少,于是便想记录总结一下功耗分析的相关知识,有不对的地方希望大家多指出,互相学习。本系列分为底层篇和上层篇。 大概从基础知识,测试手法,以及案例…

Bugku CTF_Web——my-first-sqli

Bugku CTF_Web——my-first-sqli 进入靶场 随便输一个看看 点login没有任何回显 方法一: 上bp抓包 放到repeter测试 试试万能密码(靶机过期了重新开了个靶机) admin or 11--shellmates{SQLi_goeS_BrrRrRR}方法二: 拿包直接梭…

BUUCTF—Reverse—easyre(1)

非常简单的逆向 拿到exe文件先查下信息,是一个64位程序,没有加壳(壳是对代码的加密,起混淆保护的作用,一般用来阻止逆向)。 然后拖进IDA(64位)进行反汇编 打开以后就可以看到flag flag{this_Is_a_EaSyRe}

全面击破工程级复杂缓存难题

目录 一、走进业务中的缓存 (一)本地缓存 (二)分布式缓存 二、缓存更新模式分析 (一)Cache Aside Pattern(旁路缓存模式) 读操作流程 写操作流程 流程问题思考 问题1&#…

React基础知识一

写的东西太多了,照成csdn文档编辑器都开始卡顿了,所以分篇写。 1.安装React 需要安装下面三个包。 react:react核心包 react-dom:渲染需要用到的核心包 babel:将jsx语法转换成React代码的工具。(没使用jsx可以不装)1.1 在html中…

Vue3中使用:deep修改element-plus的样式无效怎么办?

前言:当我们用 vue3 :deep() 处理 elementui 中 el-dialog_body和el-dislog__header 的时候样式一直无法生效,遇到这种情况怎么办? 解决办法: 1.直接在 dialog 上面增加class 我试过,也不起作用,最后用这种…

鸿蒙进阶-状态管理

大家好啊,这里是鸿蒙开天组,今天我们来学习状态管理。 开始组件化开发之后,如何管理组件的状态会变得尤为重要,咱们接下来系统的学习一下这部分的内容 状态管理机制 在声明式UI编程框架中,UI是程序状态的运行结果&a…