4.将图神经网络应用于大规模图数据(Cluster-GCN)

news2025/2/12 7:54:03

       到目前为止,我们已经为节点分类任务单独以全批方式训练了图神经网络。特别是,这意味着每个节点的隐藏表示都是并行计算的,并且可以在下一层中重复使用。

       然而,一旦我们想在更大的图上操作,由于内存消耗爆炸,这种方案就不再可行。例如,一个具有大约1000万个节点和128个隐藏特征维度的图已经为每层消耗了大约5GB的GPU内存

       因此,最近有一些努力让图神经网络扩展到更大的图。其中一种方法被称为Cluster-GCN (Chiang et al. (2019),它基于将图预先划分为子图,可以在子图上以小批量的方式进行操作。

       为了展示,让我们从 Planetoid 节点分类基准套件(Yang et al. (2016))中加载PubMed 图:

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='PubMed', transform=NormalizeFeatures())

print()
print(f'Dataset: {dataset}:')
print('==================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('===============================================================================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.3f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

在这里插入图片描述
       可以看出,该图大约有19717个节点。虽然这个数量的节点应该可以轻松地放入GPU内存,但它仍然是一个很好的例子,可以展示如何在PyTorch Geometric中扩展GNN

       Cluster-GCN的工作原理是首先基于图划分算法将图划分为子图。因此,GNN被限制为仅在其特定子图内进行卷积,从而省略了邻域爆炸的问题。
在这里插入图片描述
       然而,在对图进行分区后,会删除一些链接,这可能会由于有偏差的估计而限制模型的性能。为了解决这个问题,Cluster-GCN还将类别之间的连接合并到一个小批量中,这导致了以下随机划分方案
在这里插入图片描述

       这里,颜色表示每批维护的邻接信息(对于每个epoch可能不同)。PyTorch Geometric提供了Cluster-GCN算法的两阶段实现

  1. ClusterData 将一个 Data 对象转换为包含num_parts分区的子图的数据集。
  2. 给定一个用户定义的batch_size, ClusterLoader 实现随机划分方案以创建小批量。

然后,制作小批量的程序如下:

from torch_geometric.loader import ClusterData, ClusterLoader

torch.manual_seed(12345)
cluster_data = ClusterData(data, num_parts=128)  # 1. Create subgraphs.
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)  # 2. Stochastic partioning scheme.

print()
total_num_nodes = 0
for step, sub_data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of nodes in the current batch: {sub_data.num_nodes}')
    print(sub_data)
    print()
    total_num_nodes += sub_data.num_nodes

print(f'Iterated over {total_num_nodes} of {data.num_nodes} nodes!')

在这里插入图片描述

       在这里,我们将初始图划分为128个分区,并使用32个子图batch_size来形成mini-batches(每个epoch4批)。正如我们所看到的,在一个epoch之后,每个节点都被精确地看到了一次。

       Cluster-GCN的伟大之处在于它不会使GNN模型的实现复杂化。我们构造如下一个简单模型:
在这里插入图片描述
       这种图神经网络的训练与用于图分类任务的图神经网络训练非常相似。我们现在不再以全批处理的方式对图进行操作,而是对每个小批进行迭代,并相互独立地优化每个批:

model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for sub_data in train_loader:  # Iterate over each mini-batch.
        out = model(sub_data.x, sub_data.edge_index)  # Perform a single forward pass.
        loss = criterion(out[sub_data.train_mask], sub_data.y[sub_data.train_mask])  # Compute the loss solely based on the training nodes.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
      
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.
        accs.append(int(correct.sum()) / int(mask.sum()))  # Derive ratio of correct predictions.
    return accs

for epoch in range(1, 51):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

在这里插入图片描述
       在本文中,我们向您介绍了一种将 scale GNNs to large graphs的方法,否则这些图将不适合GPU内存。

本文内容参考:PyG官网

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/639946.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python 生成器与迭代器】零基础也能轻松掌握的学习路线与参考资料

一、Python生成器与迭代器概述 Python是一种高级编程语言,其中非常重要的概念就是生成器和迭代器。Python生成器和迭代器联合使用,能够实现高效的迭代操作,避免增加额外的内存消耗,同时提高代码的可读性。Python中常见的生成器和…

单机多节点 elasticsearch 集群安全认证

es 版本:7.6.2 部署环境:CentOS Linux release 7.6.1810 (Core) 一:生成 ca 证书 cd 到 es 的安装目录,并执行下面的命令来生成 ca 证书: ./bin/elasticsearch-certutil ca Elasticsearch碰到第一个直接回车&#xf…

面试专题:Mysql

1.说说自己对于 MySQL 常见的两种存储引擎:MyISAM与InnoDB的理解 关于二者的对比与总结: 1.count运算上的区别:因为MyISAM缓存有表meta-data(行数等),因此在做COUNT(*)时对于一个结构很好的查询是不需要消耗多少资源的…

[CKA]考试之K8s 版本升级

由于最新的CKA考试改版,不允许存储书签,本博客致力怎么一步步从官网把答案找到,如何修改把题做对,下面开始我们的 CKA之旅 题目为: Task 现有的Kubernetes 集群正在运行版本1.22.0。仅将master节点上的所有 Kuberne…

【Python】集合 set ② ( 集合常用操作 | 集合中添加元素 | 集合中移除元素 | 集合中随机取出元素 )

文章目录 一、集合中添加元素二、集合中移除元素三、集合中随机取出元素 在 Python 中 , 集合 set 是无序的 , 因此 集合 数据容器 不支持 使用 下标索引 访问 集合元素 ; 一、集合中添加元素 调用 集合#add(新元素) 函数 , 可以将新元素添加到 集合 数据容器中 ; 集合添加元素…

Vue- ref属性

ref属性 被用来给元素或者子组件注册引用信息(id的替代者) 通过案例来演示_ref属性 1 编写案例 如图:有一个按钮,点击按钮可以输出dom元素 备注:虽然vue不用我们亲自操作dom,但是有的特殊的情况下就要…

【2023华中杯】B题 小学教学应用题 相似性度量及难度评估 29页论文及MATLAB代码

1 题目 B 题 小学数学应用题相似性度量及难度评估 某 MOOC 在线教育平台希望能够进行个性化教学,实现用户自主学习。在用户学习时,系统从题库中随机抽取若干道与例题同步的随堂测试题,记录、分析学生的学习和答题信息,并且课后会自…

【Pytest实战】解决ModuleNotFoundError: No module named ‘pytest’问题

😄作者简介: 小曾同学.com,一个致力于测试开发的博主⛽️,主要职责:测试开发、CI/CD 如果文章知识点有错误的地方,还请大家指正,让我们一起学习,一起进步。😊 座右铭:不想…

JAVA程序的性能优化实践总结

1、 衡量程序性能的指标 可以从常用的性能评估指标入手: 并发:同一时间有多少请求访问TPS:transaction per second(每秒的事物数)QPS:query per second(每秒请求数)耗时:端到端耗时,服务端耗时&#xff…

并行计算——MPI编程

目录 基础知识 进程与线程,并行与并发 奇偶排序 MPI实现 odd-even sort 思路 环境部署 编程实现(C) “若干”的问题 参考链接 一个偶然的机会,我接触到了国立清华大学的MPI编程作业,也就接触到了并行计算。这…

基于Python3接口自动化测试初探

自动化测试是什么? 自动化测试简单来说就是借助工具的方式来辅助手动测试的行为就可以看做是自动化测试。 自动化测试工具有哪些? 现在常用的自动化测试工具包括: QTP:主要用于回归测试和测试同一软件的新版本 Robot Framewor…

大数据ETL工具Kettle

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言最近公司在搞大数据数字化,有MES,CIM,WorkFlow等等N多的系统,不同的数据源DB,需要将这些不同的数据源DB里的数据进行整治统一…

【算法】模拟,高精度

高精度加法 P1601 AB Problem(高精) - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 思路就是模拟,值得注意的就是要用字符串类型输入。存进自己的int数组时要倒着存,因为如果是正着存的话,进位会有点trouble。 时间…

Spread.NET v16.0.20222.0 ASP.NET cRACK

关于 Spread.NET 提供类似 Excel 的电子表格体验。 Spread.NET 可帮助您创建电子表格、网格、仪表板和表单。它包括一个强大的计算引擎,具有450 函数以及导入和导出Excel电子表格的能力。利用广泛的 .NET 电子表格 API 和强大的计算引擎来创建分析、预算、仪表板、…

【C++ 基础篇:24】:【重要模板】C++ 输入输出运算符重载【以 Date 日期类为例】

系列文章说明 本系列 C 相关文章 仅为笔者学习笔记记录,用自己的理解记录学习!C 学习系列将分为三个阶段:基础篇、STL 篇、高阶数据结构与算法篇,相关重点内容如下: 基础篇:类与对象(涉及C的三大…

Mysql Access denied for user ‘root‘@ ‘*.*.*.*‘ (using password: YES)异常处理

目录 一、异常错误二、原因三、解决方法 一、异常错误 PS C:\Users\10568> mysql -u root -p Enter password: **** ERROR 1045 (28000): Access denied for user rootlocalhost (using password: YES)Access denied表示拒绝访问,using password:NO/…

计算机视觉 | 语义分割与Segmentation

前 言 「MMSegmentation」 是一个基于 PyTorch 的语义分割开源工具箱。它是 OpenMMLab 项目的一部分。 MMSegmentation v1.x 在 0.x 版本的基础上有了显著的提升,提供了更加灵活和功能丰富的体验。 主要特性 统一的基准平台 我们将各种各样的语义分割算法集成到了…

Linux权限维持

SSH后门&VIM后门 ssh后门: 创建一个软链接: ln -sf /usr/sbin/sshd /tmp/su 拓展:软链接相当于一个快捷键,硬链接相当于一个指针指向文件地址,也类似于复制 开启后门: /tmp/su -oport12345 开启后…

chatgpt赋能python:Python另存为:如何保存你的程序代码

Python另存为:如何保存你的程序代码 简介 Python是一种高级编程语言,最初由Guido van Rossum于1991年创建。自创建以来,Python已被广泛应用于Web开发、数据分析、人工智能等领域。作为一名有10年Python编程经验的工程师,我发现在…

万物的算法日记|第一天

笔者自述: 一直有一个声音也一直能听到身边的大佬经常说,要把算法学习搞好,一定要重视平时的算法学习,虽然每天也在学算法,但是感觉自己一直在假装努力表面功夫骗了自己,没有规划好自己的算法学习和总结&am…