图神经网络实战(15)——SEAL链接预测算法

news2025/1/9 2:06:11

图神经网络实战(15)——SEAL链接预测算法

    • 0. 前言
    • 1. SEAL 框架
      • 1.1 基本原理
      • 1.2 算法流程
    • 2. 实现 SEAL 框架
      • 2.1 数据预处理
      • 2.2 模型构建与训练
    • 小结
    • 系列链接

0. 前言

我们已经学习了基于节点嵌入的链接预测算法,这种方法通过学习相关的节点嵌入来计算链接可能性。接下来,我们介绍另一类方法,通过查看目标节点周围的局部邻域执行链接预测任务,这类技术称为基于子图的算法,由 SEAL 广泛使用,可以说 SEAL 表示用于链接预测的子图、嵌入和属性 (Subgraphs, Embeddings, and Attributes for Link prediction) 的缩写,但并不完全准确。在本节中,我们将介绍 SEAL 框架,并使用 PyTorch Geometric 实现该框架。

1. SEAL 框架

1.1 基本原理

SEALZhangChen2018 年提出,是一个学习图结构特征以进行链接预测的框架。它将目标节点 ( x , y ) (x,y) (x,y) 和它们的 k 跳 (k-hop) 邻居所形成的子图定义为封闭子图 (enclosing subgraph)。每个封闭子图(而非整个图)都被用作预测链接可能性的输入。从另一个角度来看,SEAL 自动学习了一种用于链接预测的局部启发式方法。

1.2 算法流程

SEAL 框架包括三个步骤:

  1. 封闭子图提取 (Enclosing subgraph extraction),包括提取一组真实链接和一组虚假链接(负抽样)来形成训练数据
  2. 节点信息矩阵构建 (Node information matrix construction),包括节点标记、节点嵌入和节点特征三个部分
  3. 图神经网络 (Graph Neural Networks, GNN) 训练 (GNN training),将节点信息矩阵作为输入,并输出链接的可能性

这些步骤可以用下图进行总结:

SEAL 框架

封闭子图提取是一个简单的过程,列出目标节点及其 k 跳邻居,以提取它们的边和特征。k 值越大,SEAL 所能学习到的启发式算法的质量就越高,但同时也会创建更大、计算开销更大的子图。
节点信息构建的第一个部分是节点标记 (node labeling)。这一过程为每个节点分配一个特定的编号,如果没有进行标记,GNN 就无法区分目标节点和上下文节点(目标节点的邻居)。它还融合了距离,用来描述节点的相对位置和结构重要性。
在实践中,目标节点 x x x y y y 必须共享一个唯一的标签,以确定它们是目标节点。对于上下文节点 i i i j j j,如果它们与目标节点的距离相同,则必须共享相同的标签—— d ( i , x ) = d ( j , x ) d(i, x) = d(j, x) d(i,x)=d(j,x) d ( i , y ) = d ( j , y ) d(i, y) = d(j, y) d(i,y)=d(j,y)。我们称这种距离为双半径 (double radius),表示为 ( d ( i , x ) , d ( i , y ) ) (d(i, x), d(i, y)) (d(i,x),d(i,y))
SEAL 中使用双半径节点标记 (Double-Radius Node Labeling, DRNL) 算法,其工作原理如下:

  1. 首先,将标签 1 分配给节点 x x x y y y
  2. 将标签 2 分配给半径为 (1,1) 的节点
  3. 将标签 3 分配给半径为 (1,2)(2,1) 的节点
  4. 将标签 4 分配给半径为 (1,3)(3,1) 的节点,以此类推

DRNL 函数的数学表达式如下:
f ( i ) = 1 + m i n ( d ( i , x ) , d ( i , y ) ) + ( d / 2 ) [ ( d / 2 ) + ( d % 2 ) − 1 ] f(i)=1+min(d(i,x),d(i,y))+(d/2)[(d/2)+(d\%2)-1] f(i)=1+min(d(i,x),d(i,y))+(d/2)[(d/2)+(d%2)1]
其中, d = d ( i , x ) + d ( i , y ) d= d(i, x) + d(i, y) d=d(i,x)+d(i,y) ( d / 2 ) (d/2) (d/2) ( d % 2 ) (d\%2) (d%2) 分别是 d d d 除以 2 2 2 的整数商和余数。最后,对这些节点标签进行独热编码 (one-hot encode)。
节点信息矩阵构建过程中的其它两个部分比较容易获得。节点嵌入是可选的,可以使用其他算法(如 Node2Vec )计算。然后,将它们与节点特征和独热编码标签连接起来,构建最终的节点信息矩阵 (node information matrix)。
最后,训练 GNN,利用封闭子图的信息和邻接矩阵来预测链接。为此,SEAL 使用了深度图卷积神经网络 (Deep Graph Convolutional Neural Network, DGCNN),该架构执行以下三个步骤:

  1. 使用数个图卷积网络 (Graph Convolutional Network, GCN) 层计算节点嵌入,然后将其串联起来,类似于图同构网络 (Graph Isomorphism Network, GIN)
  2. 使用全局排序池化层按照一致的顺序排列这些嵌入,然后再将它们深入到卷积层,而卷积层不具备置换不变性
  3. 使用传统的卷积层和全连接层应用于排序后图表示,并输出链接概率

DGCNN 模型使用二进制交叉熵损失进行训练,输出的概率介于 01 之间

2. 实现 SEAL 框架

SEAL 框架需要进行大量的预处理,以提取并标注封闭子图。接下来,我们使用 PyTorch Geometric 来实现 SEAL 框架。

2.1 数据预处理

(1) 首先,导入所有必要的库:

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.sparse.csgraph import shortest_path

import torch.nn.functional as F
from torch.nn import Conv1d, MaxPool1d, Linear, Dropout, BCEWithLogitsLoss

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, aggr
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

(2) 加载 Cora 数据集,并应用链接级随机拆分:

transform = RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True)
dataset = Planetoid('.', name='Cora', transform=transform)
train_data, val_data, test_data = dataset[0]

(3) 链接级随机拆分会在数据对象中创建新字段,用于存储每条正样本边(真实的边)和负样本边(虚假的边)的标签和索引:

print(train_data)

# Data(x=[2708, 1433], edge_index=[2, 8976], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], pos_edge_label=[4488], pos_edge_label_index=[2, 4488], neg_edge_label=[4488], neg_edge_label_index=[2, 4488])

(4) 创建函数 seal_processing() 处理拆分后的数据集,并获得带有独热编码节点标签和节点特征的封闭子图,使用列表 data_list 存储这些子图:

def seal_processing(dataset, edge_label_index, y):
    data_list = []

对于数据集中的每一对节点(源和目的节点),提取 k 跳邻居(本节中 k = 2):

    for src, dst in edge_label_index.t().tolist():
        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph([src, dst], 2, dataset.edge_index, relabel_nodes=True)
        src, dst = mapping.tolist()

使用双半径节点标记 (Double-Radius Node Labeling, DRNL) 函数计算距离。首先,从子图中删除目标节点:

        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
        mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
        sub_edge_index = sub_edge_index[:, mask1 & mask2]

根据上一个子图计算源节点和目标节点的邻接矩阵:

        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

计算每个节点与源节点/目标节点之间的距离:

        # Calculate the distance between every node and the source target node
        d_src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
        d_src = np.insert(d_src, dst, 0, axis=0)
        d_src = torch.from_numpy(d_src)

        # Calculate the distance between every node and the destination target node
        d_dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst-1)
        d_dst = np.insert(d_dst, src, 0, axis=0)
        d_dst = torch.from_numpy(d_dst)

计算子图中每个节点的节点标签 z z z

        dist = d_src + d_dst
        z = 1 + torch.min(d_src, d_dst) + dist // 2 * (dist // 2 + dist % 2 - 1)
        z[src], z[dst], z[torch.isnan(z)] = 1., 1., 0.
        z = z.to(torch.long)

在本节中,并未使用节点嵌入,但仍将特征和独热编码标签串联起来,以构建节点信息矩阵:

        node_labels = F.one_hot(z, num_classes=200).to(torch.float)
        node_emb = dataset.x[sub_nodes]
        node_x = torch.cat([node_emb, node_labels], dim=1)

创建一个 Data 对象并将其附加到列表 data_list 中,作为函数的最终输出:

        data = Data(x=node_x, z=z, edge_index=sub_edge_index, y=y)
        data_list.append(data)

    return data_list

(5) 调用 deal_processing 提取每个数据集的封闭子图。将正样本和负样本分开,以获得正确的预测标签:

train_pos_data_list = seal_processing(train_data, train_data.pos_edge_label_index, 1)
train_neg_data_list = seal_processing(train_data, train_data.neg_edge_label_index, 0)

val_pos_data_list = seal_processing(val_data, val_data.pos_edge_label_index, 1)
val_neg_data_list = seal_processing(val_data, val_data.neg_edge_label_index, 0)

test_pos_data_list = seal_processing(test_data, test_data.pos_edge_label_index, 1)
test_neg_data_list = seal_processing(test_data, test_data.neg_edge_label_index, 0)

(6) 合并正负数据列表,重建训练、验证和测试数据集:

train_dataset = train_pos_data_list + train_neg_data_list
val_dataset = val_pos_data_list + val_neg_data_list
test_dataset = test_pos_data_list + test_neg_data_list

(7) 创建数据加载器,使用批数据训练 GNN

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

2.2 模型构建与训练

(1) 定义 DGCNN 类,其中参数 k 表示每个子图的节点数:

class DGCNN(torch.nn.Module):
    def __init__(self, dim_in, k=30):
        super().__init__()

创建四个 GCN 层,设定隐藏维度为 32

        self.gcn1 = GCNConv(dim_in, 32)
        self.gcn2 = GCNConv(32, 32)
        self.gcn3 = GCNConv(32, 32)
        self.gcn4 = GCNConv(32, 1)

实例化全局排序池化层 (深度图卷积神经网络 (Deep Graph Convolutional Neural Network, DGCNN) 架构的核心):

        self.global_pool = aggr.SortAggregation(k=k)

全局排序池化层提供的节点排序使我们能够使用传统的卷积层:

        self.conv1 = Conv1d(1, 16, 97, 97)
        self.conv2 = Conv1d(16, 32, 5, 1)
        self.maxpool = MaxPool1d(2, 2)

最后,实例化多层感知机 (Multilayer Perceptron, MLP) 用于获取预测:

        self.linear1 = Linear(352, 128)
        self.dropout = Dropout(0.5)
        self.linear2 = Linear(128, 1)

forward() 方法中,计算每个 GCN 的节点嵌入,并将结果串联起来:

    def forward(self, x, edge_index, batch):
        # 1. Graph Convolutional Layers
        h1 = self.gcn1(x, edge_index).tanh()
        h2 = self.gcn2(h1, edge_index).tanh()
        h3 = self.gcn3(h2, edge_index).tanh()
        h4 = self.gcn4(h3, edge_index).tanh()
        h = torch.cat([h1, h2, h3, h4], dim=-1)

对串联结果依次应用全局排序池化、卷积层和全连接层:

        # 2. Global sort pooling
        h = self.global_pool(h, batch)

        # 3. Traditional convolutional and dense layers
        h = h.view(h.size(0), 1, h.size(-1))
        h = self.conv1(h).relu()
        h = self.maxpool(h)
        h = self.conv2(h).relu()
        h = h.view(h.size(0), -1)
        h = self.linear1(h).relu()
        h = self.dropout(h)
        h = self.linear2(h).sigmoid()

        return h

(2) 将模型实例化,并使用 Adam 优化器和二进制交叉熵损失对其进行训练:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DGCNN(train_dataset[0].num_features).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
criterion = BCEWithLogitsLoss()

(3) 创建 train() 函数用于批训练:

def train():
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_dataset)

(4)test() 函数中,计算 ROC AUC 分数和平均精度,以比较 SEAL 和变分图自编码器 (Variational Graph Autoencoder, VGAE) 的性能:

@torch.no_grad()
def test(loader):
    model.eval()
    y_pred, y_true = [], []

    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        y_pred.append(out.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

    auc = roc_auc_score(torch.cat(y_true), torch.cat(y_pred))
    ap = average_precision_score(torch.cat(y_true), torch.cat(y_pred))

    return auc, ap

(5)DGCNN 进行 31epoch 的训练:

for epoch in range(31):
    loss = train()
    val_auc, val_ap = test(val_loader)
    print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}')

模型训练过程监测

(6) 最后,在测试数据集上对其进行测试:

test_auc, test_ap = test(test_loader)
print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')

# Test AUC: 0.7899 | Test AP 0.8174

可以看到,使用 SEAL 框架得到的结果与使用 VGAE 得到的结果(AUC0.8727AP0.8620) 相似。从理论上讲,基于子图的方法(如 SEAL )比基于节点的方法(如 VGAE )更具表达能力,基于子图的方法通过明确考虑目标节点周围的整个邻域来捕捉更多信息。通过 k 参数增加所考虑的邻域数量,可以进一步提高 SEAL 的准确性。

小结

链接预测是指利用图数据中已知的节点和边的信息,来推断图中未知的连接关系或者未来可能出现的连接关系,在机器学习和数据挖掘等领域具有广泛的应用。本节中介绍了用于链接预测的 SEAL 框架,其侧重于子图表示,每个链接周围的邻域作为预测链接概率的输入。并使用边级随机分割和负采样在 Cora 数据集上实现了 SEAL 模型执行链接预测任务。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1863385.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【第三方JSON库】org.json.simple用法初探—Java编程【Eclipse平台】【不使用项目管理工具】【不添加依赖解析】

本文将重点介绍,在不使用项目管理工具,不添加依赖解析情况下,【第三方库】JSON.simple库在Java编程的应用。 JSON.simple是一种由纯java开发的开源JSON库,包含在JSON.simple.jar中。它提供了一种简单的方式来处理JSON数据和以JSO…

SQL Server 2022从入门到精通

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。…

架构是怎样练成的-楼宇监控系统案例

目录 概要 项目背景 原系统设计方案 改进后的设计方案 小结 概要 绝大多数人掌握的架构都是直接学习,慢慢地才能体会到一个架构的好处。架构是一种抽象,是为了复用目的而对代码做的抽象。通过一个项目的改造,理解架构是如何产生的&…

[C++][设计模式][抽象工厂]详细讲解

目录 1.动机2.模式定义3.要点总结4.代码感受1.代码一2.代码二 -- 工厂方法3.代码三 -- 抽象工厂 1.动机 在软件系统中,经常面临着“一系列相互依赖的对象”的创建工作;同时,由于需求的变化,往往存在更多系列对象的创建工作如何应…

【ARM】MDK工程切换高版本的编译器后出现error A1137E报错

【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 解决工程从Compiler 5切换到Compiler 6进行编译时出现一些非语法问题上的报错。 2、 问题场景 对于一些使用Compiler 5进行编译的工程,要切换到Compiler 6进行编译的时候,原本无任何报错警告…

Redis-哨兵模式-主机宕机-推选新主机的过程

文章目录 1、为哨兵模式准备配置文件2、启动哨兵3、主机6379宕机3.4、查看sentinel控制台日志3.5、查看6380主从信息 4、复活63794.1、再次查看sentinel控制台日志 1、为哨兵模式准备配置文件 [rootlocalhost redis]# ll 总用量 244 drwxr-xr-x. 2 root root 150 12月 6 2…

免费APP分发平台:小猪APP分发如何解决开发者的痛点

你是否曾为自己开发的APP找不到合适的分发平台而烦恼?你是否因为高昂的分发费用而望而却步?放心吧,你并不是一个人。很多开发者都面临同样的问题。但别担心,小猪APP分发来了,它可以帮你解决这些问题。 小猪app封装www…

微软结束将数据中心置于海底的实验

2016 年,微软 宣布了一项名为"纳蒂克项目"(Project Natick)的实验。基本而言,该项目旨在了解数据中心能否在海洋水下安装和运行。经过多次较小规模的测试运行后,该公司于 2018 年春季在苏格兰海岸外 117 英尺…

《Redis设计与实现》阅读总结-2

第 7 章 压缩列表 1. 概念: 压缩列表是列表键和哈希键的底层实现之一。当一个列表键只包含少量列表项,并且每个列表项是小整数值或长度比较短的字符串,那么Redis就会使用压缩类别来做列表键的底层实现。哈希键里面包含的所有键和值都是最小…

基于ESP8266串口WIFI模块ESP-01S在AP模式(即发射无线信号( WiFi))下实现STC单片机与手机端网路串口助手相互通信功能

基于ESP8266串口WIFI模块ESP-01S在AP模式(即发射无线信号( WiFi))下实现STC单片机与手机端网路串口助手相互通信功能 ESP8266_01S引脚功能图ESP8266_01S原理图ESP8266_01S尺寸图检验工作1、USB-TTL串口工具(推荐使用搭载CP2102芯片的安信可USB-T1串口)与ESP8266_01S WiFi…

Websocket在Java中的实践——最小可行案例

WebSocket是一种先进的网络通信协议,它允许在单个TCP连接上进行全双工通信,即数据可以在同一时间双向流动。WebSocket由IETF标准化为RFC 6455,并且已被W3C定义为JavaScript API的标准,成为现代浏览器的重要特性之一。 WebSocket的…

代码随想录——跳跃游戏(Leecode55)

题目链接 贪心 class Solution {public boolean canJump(int[] nums) {int cover 0;if(nums.length 1){return true;}// 只有一个元素可以达到for(int i 0; i < cover; i){// 在cover内选择跳跃步数cover Math.max(i nums[i],cover);if(cover > nums.length - 1)…

C++进修——C++核心编程

内存分区模型 C程序在执行时&#xff0c;将内存大方向划分为4个区域 代码区&#xff1a;存放函数体的二进制编码&#xff0c;由操作系统进行管理全局区&#xff1a;存放全局变量和静态变量以及常量栈区&#xff1a;由编译器自动分配释放&#xff0c;存放函数的参数值&#xff…

如何在FastAPI服务器中添加黑名单和白名单实现IP访问控制

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 文章内容 📒📝 添加黑名单功能步骤1:安装依赖步骤2:创建FastAPI应用步骤3:添加黑名单📝 添加白名单功能步骤1:创建白名单列表步骤2:添加白名单检查⚓️ 相关链接 ⚓️📖 介绍 📖 在现代网络应用开发中,为了增强…

【推荐】Prometheus+Grafana企业级监控预警实战

新鲜出炉&#xff01;&#xff01;&#xff01;PrometheusGrafanaAlertmanager springboot 企业级监控预警实战课程&#xff0c;从0到1快速搭建企业监控预警平台&#xff0c;实现接口调用量统计&#xff0c;接口请求耗时统计…… 详情请戳 https://edu.csdn.net/course/detai…

深度挖掘数据资产,洞察业务先机:利用先进的数据分析技术,精准把握市场趋势,洞悉客户需求,为业务决策提供有力支持,实现持续增长与创新

在当今日益激烈的商业竞争环境中&#xff0c;企业想要实现持续增长与创新&#xff0c;必须深入挖掘和有效运用自身的数据资产。数据不仅是企业运营过程中的副产品&#xff0c;更是洞察市场趋势、理解客户需求、优化业务决策的重要资源。本文将探讨如何通过利用先进的数据分析技…

黑马苍穹外卖7 用户下单+订单支付(微信小程序支付流程图)

地址簿 数据库表设计 就是基本增删改查&#xff0c;与前面的类似。 用户下单 用户点餐业务流程&#xff1a; 购物车-订单提交-订单支付-下单成功 展示购物车数据&#xff0c;不需要提交到后端 数据库设计&#xff1a;两个表【订单表orders&#xff0c;订单明细表order_d…

智慧车库管理系统

摘 要 随着城市化进程的不断加快&#xff0c;私家车数量的快速增长给城市交通带来了巨大的挑战&#xff0c;停车问题成为城市交通管理中的一大难题。车辆停车时&#xff0c;在停车场寻找停车位耗时过久&#xff0c;不仅仅浪费用户的时间&#xff0c;还可能引起交通拥堵。城市停…

考研数学一有多难?130+背后的残酷真相

考研数学一很难 大家平时在网上上看到很多人说自己考了130&#xff0c;其实这些人只占参加考研数学人数的极少部分&#xff0c;有个数据可以展示出来考研数学到底有多难&#xff1a; 在几百万考研大军中&#xff0c;能考到120分以上的考生只有2%。绝大多数人的分数集中在30到…

构建实用的Flutter文件列表:从简到繁的完美演进

前言&#xff1a;为什么我们需要文件列表&#xff1f; 在现代科技发展迅速的时代&#xff0c;我们的电脑、手机、平板等设备里积累了大量的文件&#xff0c;这些文件可能是我们的照片、文档、音频、视频等等。然而&#xff0c;当文件数量增多时&#xff0c;我们如何快速地找到…