使用Pytorch Geometric 进行链接预测代码示例

news2025/3/12 11:44:42

PyTorch Geometric (PyG)是构建图神经网络模型和实验各种图卷积的主要工具。在本文中我们将通过链接预测来对其进行介绍。

链接预测答了一个问题:哪两个节点应该相互链接?我们将通过执行“转换分割”,为建模准备数据。为批处理准备专用的图数据加载器。在Torch Geometric中构建一个模型,使用PyTorch Lightning进行训练,并检查模型的性能。

库准备

Torch 这个就不用多介绍了

Torch Geometric 图形神经网络的主要库,也是本文介绍的重点

PyTorch Lightning 用于训练、调优和验证模型。它简化了训练的操作

Sklearn Metrics和Torchmetrics 用于检查模型的性能。

PyTorch Geometric有一些特定的依赖关系,如果你安装有问题,请参阅其官方文档。

数据准备

我们将使用Cora ML引文数据集。数据集可以通过Torch Geometric访问。

 data = tg.datasets.CitationFull(root="data", name="Cora_ML")

默认情况下,Torch Geometric数据集可以返回多个图形。我们看看单个图是什么样子的

 data[0]
 > Data(x=[2995, 2879], edge_index=[2, 16316], y=[2995])

这里的 X是节点的特征。edge_index是2 x (n条边)矩阵(第一维= 2,被解释为:第0行-源节点/“发送方”,第1行-目标节点/“接收方”)。

链接拆分

我们将从拆分数据集中的链接开始。使用20%的图链接作为验证集,10%作为测试集。这里不会向训练数据集中添加负样本,因为这样的负链接将由批处理数据加载器实时创建。

一般来说,负采样会创建“假”样本(在我们的例子中是节点之间的链接),因此模型学习如何区分真实和虚假的链接。负抽样基于抽样的理论和数学,具有一些很好的统计性质。

首先:让我们创建一个链接拆分对象。

 link_splitter = tg.transforms.RandomLinkSplit(
     num_val=0.2, 
     num_test=0.1, 
     add_negative_train_samples=False,
     disjoint_train_ratio=0.8)

disjoint_train_ratio调节在“监督”阶段将使用多少条边作为训练信息。剩余的边将用于消息传递(网络中的信息传输阶段)。

图神经网络中至少有两种分割边的方法:归纳分割和传导分割。转换方法假设GNN需要从图结构中学习结构模式。在归纳设置中,可以使用节点/边缘标签进行学习。本文最后有两篇论文详细讨论了这些概念,并进行了额外的形式化:([1],[3])。

 train_g, val_g, test_g = link_splitter(data[0])
 
 > Data(x=[2995, 2879], edge_index=[2, 2285], y=[2995], edge_label=[9137], edge_label_index=[2, 9137])

在这个操作之后,我们有了一些新的属性:

edge_label :描述边缘是否为真/假。这是我们想要预测的。

edge_label_index 是一个2 x NUM EDGES矩阵,用于存储节点链接。

让我们看看样本的分布

 th.unique(train_g.edge_label, return_counts=True)
 > (tensor([1.]), tensor([9137]))
 
 th.unique(val_g.edge_label, return_counts=True)
 > (tensor([0., 1.]), tensor([3263, 3263]))
 
 th.unique(val_g.edge_label, return_counts=True)
 > (tensor([0., 1.]), tensor([3263, 3263]))

对于训练数据没有负边(我们将训练时创建它们),对于val/测试集——已经以50:50的比例有了一些“假”链接。

模型

现在我们可以在使用GNN进行模型的构建了一个

 class GNN(nn.Module):
     
     def __init__(
         self, 
         dim_in: int, 
         conv_sizes: Tuple[int, ...], 
         act_f: nn.Module = th.relu, 
         dropout: float = 0.1,
         *args, 
         **kwargs):
         super().__init__()
         self.dim_in = dim_in
         self.dim_out = conv_sizes[-1]
         self.dropout = dropout
         self.act_f = act_f
         last_in = dim_in
         layers = []
         
         # Here we build subsequent graph convolutions.
         for conv_sz in conv_sizes:
             # Single graph convolution layer
             conv = tgnn.SAGEConv(in_channels=last_in, out_channels=conv_sz, *args, **kwargs)
             last_in = conv_sz
             layers.append(conv)
         self.layers = nn.ModuleList(layers)
     
     def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:
         h = x
         # For every graph convolution in the network...
         for conv in self.layers:
             # ... perform node embedding via message passing
             h = conv(h, edge_index)
             h = self.act_f(h)
             if self.dropout:
                 h = nn.functional.dropout(h, p=self.dropout, training=self.training)
         return h

这个模型中值得注意的部分是一组图卷积——在我们的例子中是SAGEConv。SAGE卷积的正式定义为:

v是当前节点,节点v的N(v)个邻居。要了解更多关于这种卷积类型的信息,请查看GraphSAGE[1]的原始论文

让我们检查一下模型是否可以使用准备好的数据进行预测。这里PyG模型的输入是节点特征X的矩阵和定义edge_index的链接。

 gnn = GNN(train_g.x.size()[1], conv_sizes=[512, 256, 128])
 with th.no_grad():
     out = gnn(train_g.x, train_g.edge_index)
     
 out
 
 
 > tensor([[0.0000, 0.0000, 0.0051,  ..., 0.0997, 0.0000, 0.0000],
         [0.0107, 0.0000, 0.0576,  ..., 0.0651, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0102,  ..., 0.0973, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0549,  ..., 0.0671, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0166,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0034,  ..., 0.1111, 0.0000, 0.0000]])

我们模型的输出是一个维度为:N个节点x嵌入大小的节点嵌入矩阵。

PyTorch Lightning

PyTorch Lightning主要用作训练,但是这里我们在GNN的输出后面增加了一个Linear层做为预测是否链接的输出头。

 class LinkPredModel(pl.LightningModule):
     
     def __init__(
         self,
         dim_in: int,
         conv_sizes: Tuple[int, ...], 
         act_f: nn.Module = th.relu, 
         dropout: float = 0.1,
         lr: float = 0.01,
         *args, **kwargs):
         super().__init__()
         
         # Our inner GNN model
         self.gnn = GNN(dim_in, conv_sizes=conv_sizes, act_f=act_f, dropout=dropout)
         
         # Final prediction model on links.
         self.lin_pred = nn.Linear(self.gnn.dim_out, 1)
         self.lr = lr
     
     def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:
         # Step 1: make node embeddings using GNN.
         h = self.gnn(x, edge_index)
         
         # Take source nodes embeddings- senders
         h_src = h[edge_index[0, :]]
         # Take target node embeddings - receivers
         h_dst = h[edge_index[1, :]]
         
         # Calculate the product between them
         src_dst_mult = h_src * h_dst
         # Apply non-linearity
         out = self.lin_pred(src_dst_mult)
         return out
     
     def _step(self, batch: th.Tensor, phase: str='train') -> th.Tensor:
         yhat_edge = self(batch.x, batch.edge_label_index).squeeze()
         y = batch.edge_label
         loss = nn.functional.binary_cross_entropy_with_logits(input=yhat_edge, target=y)
         f1 = tm.functional.f1_score(preds=yhat_edge, target=y, task='binary')
         prec = tm.functional.precision(preds=yhat_edge, target=y, task='binary')
         recall = tm.functional.recall(preds=yhat_edge, target=y, task='binary')
         
         # Watch for logging here - we need to provide batch_size, as (at the time of this implementation)
         # PL cannot understand the batch size.
         self.log(f"{phase}_f1", f1, batch_size=batch.edge_label_index.shape[1])
         self.log(f"{phase}_loss", loss, batch_size=batch.edge_label_index.shape[1])
         self.log(f"{phase}_precision", prec, batch_size=batch.edge_label_index.shape[1])
         self.log(f"{phase}_recall", recall, batch_size=batch.edge_label_index.shape[1])
 
         return loss
     
     def training_step(self, batch, batch_idx):
         return self._step(batch)
     
     def validation_step(self, batch, batch_idx):
         return self._step(batch, "val")
     
     def test_step(self, batch, batch_idx):
         return self._step(batch, "test")
     
     def predict_step(self, batch):
         x, edge_index = batch
         return self(x, edge_index)
     
     def configure_optimizers(self):
         return th.optim.Adam(self.parameters(), lr=self.lr)

PyTorch Lightning的作用是帮我们简化了训练的步骤,我们只需要配置一些函数即可,我们可以使用以下命令测试模型是否可用

 model = LinkPredModel(val_g.x.size()[1], conv_sizes=[512, 256, 128])
 with th.no_grad():
     out = model.predict_step((val_g.x, val_g.edge_label_index))

训练

对于训练的步骤,需要特殊处理的是数据加载器。

图数据需要特殊处理——尤其是链接预测。PyG有一些专门的数据加载器类,它们负责正确地生成批处理。我们将使用:tg.loader.LinkNeighborLoader,它接受以下输入:

要批量加载的数据(图)。num_neighbors 每个节点在一次“跳”期间加载的最大邻居数量。指定邻居数目的列表1 - 2 - 3 -…-K。对于非常大的图形特别有用。

edge_label_index 哪个属性已经指示了真/假链接。

neg_sampling_ratio -负样本与真实样本的比例。

 train_loader = tg.loader.LinkNeighborLoader(
     train_g,
     num_neighbors=[-1, 10, 5],
     batch_size=128,
     edge_label_index=train_g.edge_label_index,
     
     # "on the fly" negative sampling creation for batch
     neg_sampling_ratio=0.5
 )
 
 val_loader = tg.loader.LinkNeighborLoader(
     val_g,
     num_neighbors=[-1, 10, 5],
     batch_size=128,
     edge_label_index=val_g.edge_label_index,
     edge_label=val_g.edge_label,
 
     # negative samples for val set are done already as ground-truth
     neg_sampling_ratio=0.0
 )
 
 test_loader = tg.loader.LinkNeighborLoader(
     test_g,
     num_neighbors=[-1, 10, 5],
     batch_size=128,
     edge_label_index=test_g.edge_label_index,
     edge_label=test_g.edge_label,
     
     # negative samples for test set are done already as ground-truth
     neg_sampling_ratio=0.0
 )

下面就是训练模型

 model = LinkPredModel(val_g.x.size()[1], conv_sizes=[512, 256, 128])
 trainer = pl.Trainer(max_epochs=20, log_every_n_steps=5)
 
 # Validate before training - we will see results of untrained model.
 trainer.validate(model, val_loader)
 
 # Train the model
 trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

试验数据核对,查看分类报告和ROC曲线。

 with th.no_grad():
     yhat_test_proba = th.sigmoid(model(test_g.x, test_g.edge_label_index)).squeeze()
     yhat_test_cls = yhat_test_proba >= 0.5
     
 print(classification_report(y_true=test_g.edge_label, y_pred=yhat_test_cls))

结果看起来还不错:

               precision    recall  f1-score   support
 
          0.0       0.68      0.70      0.69      1631
          1.0       0.69      0.66      0.68      1631
 
     accuracy                           0.68      3262
    macro avg       0.68      0.68      0.68      3262
 weighted avg       0.68      0.68      0.68      3262

ROC曲线也不错

我们训练的模型并不特别复杂,也没有经过精心调整,但它完成了工作。当然这只是一个为了演示使用的小型数据集。

总结

图神经网络尽管看起来很复杂,但是PyTorch Geometric为我们提供了一个很好的解决方案。我们可以直接使用其中内置的模型实现,这方便了我们使用和简化了入门的门槛。

本文代码:

https://avoid.overfit.cn/post/e14c4369776243d68c22c4a2a0346db2

参考:

  1. Hamilton, W., Ying, Z., & Leskovec, J. (2017). Inductive representation learning on large graphs. Advances in neural information processing systems, 30.
  2. McCormick, C. (2017). Word2Vec Tutorial Part 2 — Negative Sampling.
  3. Rossi, A., Tiezzi, M., Dimitri, G. M., Bianchini, M., Maggini, M., & Scarselli, F. (2018). Inductive–transductive learning with graph neural networks. In Artificial Neural Networks in Pattern Recognition: 8th IAPR TC3 Workshop, ANNPR 2018, Siena, Italy, September 19–21, 2018, Proceedings 8 (pp. 201–212). Springer International Publishing.

作者:Filip Wójcik

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1113233.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

书客护眼台灯好用吗?书客、柏曼、飞利浦多维度测评

护眼台灯作为一种辅助照明设备,旨在提供舒适的光线环境,以减轻眼睛疲劳和保护视力健康。它通常采用柔和的光线、调节亮度和色温的功能,以及一些附加的设计特点,如可调节灯颈、遮光罩等。虽然护眼台灯并不能完全解决眼部问题&#…

YOLOv8改进实战 | 更换主干网络Backbone(一)之轻量化模型Ghostnet

前言 轻量化网络设计是一种针对移动设备等资源受限环境的深度学习模型设计方法。下面是一些常见的轻量化网络设计方法: 网络剪枝:移除神经网络中冗余的连接和参数,以达到模型压缩和加速的目的。分组卷积:将卷积操作分解为若干个较小的卷积操作,并将它们分别作用于输入的不…

应用在冷链运输中的数字温度传感芯片

冷链运输(Cold-chain transportation)是指在运输全过程中,无论是装卸搬运、变更运输方式、更换包装设备等环节,都使所运输货物始终保持一定温度的运输。冷链运输要求在中、长途运输及短途配送等运输环节的低温状态。它主要涉及铁路…

论文笔记:Multi-Concept Customization of Text-to-Image Diffusion

0 概述 论文:Multi-Concept Customization of Text-to-Image Diffusion 源代码和数据:https://www.cs.cmu.edu/~custom-diffusion/ 当生成模型生成从大规模数据库中学习的概念的高质量图像时,用户通常希望合成他们自己的概念的实例(例如&…

Python技能树练习——python字符串转列表

一、题目与解 把下列字符串转为列表格式输出 top_ide_trend """ Rank Change IDE Share Trend 1 Visual Studio 29.24 % 3.5 % 2 Eclipse 13.91 % -2.9 % 3 Visual Studio Code 12.07 % 3.3 % 4 Android Studio 9.13 % -2.5 % 5 pyCharm 8.43 % 0.7 % 6 …

【设计模式】设计模式概述

😀大家好,我是白晨,一个不是很能熬夜😫,但是也想日更的人✈。如果喜欢这篇文章,点个赞👍,关注一下👀白晨吧!你的支持就是我最大的动力!&#x1f4…

中文编程工具开发软件实际案例:酒店饭店餐饮点餐管理系统软件编程实例

中文编程工具开发软件实际案例:酒店饭店餐饮点餐管理系统软件编程实例图片如下 软件的安装方法: 软件绿色免安装,压缩包文件解压后,将文件夹复制到电脑的D或E盘里,将软件目录下的红色程序图标按右键发送到桌面快捷方式…

泛微全新低代码平台e-builder在沪发布,超千名与会者共商数字化转型

10月18日下午,泛微低代码平台体验大会在上海顺利举办,大会以“智能、协同、全程数字化”为主题,吸引了上千位政府及企事单位的信息化负责人参与。 活动现场,参会者身临其境地体验了泛微低代码平台,了解了泛微低代码平…

【树莓派c++图像处理起航1】

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Qt OPENCV 安装测试?1. 安装qt2.安装opencv 的基础库3. 安装的路就决定了不会一帆风顺3.1.QT 安装出错3.2 运行Qt错误 4. opencv实际路径&#…

汽车辅助系统

目录 一,项目描述 二,项目 功能 三,代码实现 (1)倒车雷达 (2)AD(对 雨滴与光敏电阻传感器进行AD采集) (3)雨刷 (4)灯光 最后总结&#xf…

干货分享:网页录屏的免费方法!

“网页怎么录屏呀,在浏览器看到一篇文章,觉得挺有价值的,想保存下来,但是不能下载,也不可以复制粘贴,朋友说可以录下来保存,想问问大家,有什么好用免费的网页录屏方法推荐吗&#xf…

Python入门指南

概述: Python是一种简单易学、功能强大的编程语言,广泛应用于数据分析、Web开发、人工智能等领域。本文将为初学者提供一个Python入门指南,从安装到基本语法,帮助您开始编写Python程序。 第一部分:安装Python 1、进入…

单链表经典OJ题 :分割链表

题目: 给你一个链表的头节点 head 和一个特定值 x,请你对链表进行分隔,使得所有小于x 的节点都出现在 大于或等于 x 的节点之前。 你不需要保留 每个分区中各节点的初始相对位置。 图例: 本题的意思: 给定一个数值&am…

C# Onnx Yolov8 Detect 红绿灯检测

效果 lable GreenCircular GreenLeft GreenRight GreenStraight RedCircular RedLeft RedRight RedStraight 项目 代码 using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; using System; using System.Collections.Generic; usi…

从零开始探索C语言(十二)----预处理器、输入输出及文件读写

文章目录 1. 预处理器1.1 预处理器实例1.2 预定义宏1.3 预处理器运算符1.4 参数化的宏 2. 输入和输出2.1 getchar() & putchar() 函数2.2 gets() & puts() 函数 3. 文件读写3.1 打开文件3.2 关闭文件3.3 写入文件3.4 读取文件3.5 二进制 I/O 函数 4. typedef 和 #defin…

YOLOv8改进实战 | 更换主干网络Backbone之轻量化模型Efficientvit

前言 轻量化网络设计是一种针对移动设备等资源受限环境的深度学习模型设计方法。下面是一些常见的轻量化网络设计方法: 网络剪枝:移除神经网络中冗余的连接和参数,以达到模型压缩和加速的目的。分组卷积:将卷积操作分解为若干个较小的卷积操作,并将它们分别作用于输入的不…

【java】【MyBatisPlus】【一】快速入门程序

目录 1、创建空项目mybatisProject 2、创建springboot模块 3、删除多余文件 4、修改pom,引入mybatisplus 5、设置application.yml 6、准备实体Emp 7、创建EmpMapper接口 8、测试MybatisQuickstartApplicationTests 前言:学习MyBatisPlus的基本使…

想要隐藏Word文件内容,如何做?四个方法!

Word文件中有些内容想要隐藏,该如何隐藏?今天分享几个方法给大家 方法一: 最简单的方法,将字体颜色与背景颜色设置为一致的,这样就达到了隐藏的效果,选中文字再修改颜色就可以恢复字体 方法二&#xff1a…

MaaS,云厂商在打一场“翻身仗”

今年以来,大模型的热度,让云计算产业为之沸腾。要举出一个最有力的证明,应该是:MaaS(Model as Service)这种全新模式的出现,一座座“模型工厂”,已经建起来了。 所谓MaaS&#xff0c…