【复杂网络建模】——使用PyTorch和DGL库实现图神经网络进行链路预测

news2024/11/19 17:47:13

🤵‍♂️ 个人主页:@Lingxw_w的个人主页

✍🏻作者简介:计算机科学与技术研究生在读
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+ 

目录

1、常见的链路预测方法

2、图神经网络上的链路预测

3、使用PyTorch和DGL库实现图神经网络进行链路预测


链路预测是指在一个给定的网络中,根据已有的网络结构信息,尝试预测两个节点之间是否存在连接或者可能会建立连接的概率。这在社交网络分析、生物信息学、推荐系统等领域中都有广泛的应用。

在复杂网络中,链路预测可以帮助我们理解网络的演化过程、发现隐藏的关系和未知的连接,以及预测未来的网络演化趋势。

1、常见的链路预测方法

  1. 基于相似性的方法:这类方法假设具有相似性的节点之间更有可能存在连接。常见的相似性度量方法包括共同邻居数、Jaccard系数、Adamic/Adar指数等。

  2. 基于路径的方法:这类方法考虑节点之间的路径信息,比如最短路径、随机游走路径等。通过分析节点之间的路径特征,可以预测节点间的连接概率。

  3. 基于机器学习的方法:这类方法使用机器学习算法来建模和预测网络中的链路。常见的机器学习算法包括决策树、随机森林、支持向量机(SVM)、神经网络等。

  4. 基于深度学习的方法:这是近年来兴起的一种方法,使用深度学习模型(如图神经网络)来学习节点的表征,并通过这些表征来进行链路预测。

链路预测并非一种绝对准确的预测方法,因为网络的演化和连接行为具有一定的随机性。 

2、图神经网络上的链路预测

图神经网络(Graph Neural Networks,简称GNN)可以用于链路预测任务。GNN是一类专门用于处理图结构数据的深度学习模型,能够学习节点和边的特征表示,并在此基础上进行预测任务。

步骤:

  1. 图表示构建:首先,将原始的网络数据表示为图结构,其中节点表示网络中的实体(如用户、物品),边表示节点之间的连接关系(如关注、交互)。

  2. 节点表征学习:GNN通过多轮的消息传递和聚合操作,从节点和边的特征中学习节点的表征。这样,每个节点都会得到一个向量表示,用于捕捉其在网络中的特征和上下文信息。

  3. 边预测模型构建:在节点表征学习的基础上,可以构建一个边预测模型来预测节点之间的连接概率。一种常见的方法是使用一个全连接层或多层感知机(MLP)来将节点表征映射到一个预测分数或概率。可以使用二元分类任务来预测节点间是否存在连接,或者使用回归任务来预测连接的强度或权重。

  4. 模型训练和评估:使用已知的网络结构数据进行模型的训练,并通过验证集或交叉验证进行模型的选择和调优。评估时,可以使用一些常见的指标,如准确率、精确度、召回率、F1分数等来评估链路预测的性能。

3、使用PyTorch和DGL库实现图神经网络进行链路预测

导入必要的库,包括PyTorch和DGL。

import torch
import torch.nn as nn
import dgl

定义图神经网络模型 GNNLinkPredict,模型包含两个图卷积层,输入特征维度为2,输出特征维度为1。

# 定义图神经网络模型
class GNNLinkPredict(nn.Module):
    def __init__(self, in_feats, hidden_size, out_feats):
        super(GNNLinkPredict, self).__init__()
        self.conv1 = dgl.nn.GraphConv(in_feats, hidden_size)
        self.conv2 = dgl.nn.GraphConv(hidden_size, out_feats)
    
    def forward(self, g, features):
        x = torch.relu(self.conv1(g, features))
        x = torch.relu(self.conv2(g, x))
        return x

创建示例图数据 g,其中包括5个节点和7条边。定义节点特征 features,每个节点有两个特征值。定义标签 labels,表示边的连接情况。

# 构建示例图数据
# 创建一个有向图
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 4, 3, 4])

# 定义节点特征
features = torch.tensor([
    [0.2, 0.4],
    [0.3, 0.5],
    [0.4, 0.6],
    [0.5, 0.7],
    [0.6, 0.8]
])

# 定义标签(边是否存在连接)
labels = torch.tensor([1, 1, 1, 0, 0, 1, 0], dtype=torch.float32)

划分训练集和测试集,使用布尔类型的掩码 train_masktest_mask 表示。

# 划分训练集和测试集
train_mask = torch.tensor([True, True, True, False, False])
test_mask = torch.tensor([False, False, False, True, True])

创建图神经网络模型实例 model

定义优化器损失函数,这里使用Adam优化器和二分类的交叉熵损失函数。

# 创建图神经网络模型
model = GNNLinkPredict(in_feats=2, hidden_size=16, out_feats=1)

# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

进行模型训练。循环迭代多个epoch,在每个epoch中执行以下步骤

  • 将模型设置为训练模式 model.train()
  • 前向传播计算预测结果 logits
  • 计算预测结果与标签之间的损失。
  • 清空优化器的梯度。
  • 反向传播计算梯度。
  • 更新模型参数。
# 训练模型
for epoch in range(50):
    model.train()
    logits = model(g, features)
    pred = logits.squeeze()
    loss = criterion(pred[train_mask], labels[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 打印训练损失
    print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")

在测试集上评估模型。将模型设置为评估模式 model.eval(),然后使用训练好的模型对测试集进行预测。通过将预测结果应用sigmoid函数将其映射到0-1之间,并使用四舍五入将其转换为0或1的预测标签。计算预测准确率并输出。

# 在测试集上评估模型
model.eval()
with torch.no_grad():
    logits = model(g, features)
    pred = logits.squeeze()
    pred = torch.sigmoid(pred)  # 使用sigmoid函数将预测值映射到0-1之间
    pred_labels = torch.round(pred)  # 四舍五入为0或1的预测标签
    accuracy = (pred_labels[test_mask] == labels[test_mask]).float().mean()
    print(f"Accuracy: {accuracy.item()}")

汇总的代码:

# https://www.dgl.ai/pages/start.html

import torch
import torch.nn as nn
import dgl


# 定义图神经网络模型
class GNNLinkPredict(nn.Module):
    def __init__(self, in_feats, hidden_size, out_feats):
        super(GNNLinkPredict, self).__init__()
        self.conv1 = dgl.nn.GraphConv(in_feats, hidden_size)
        self.conv2 = dgl.nn.GraphConv(hidden_size, out_feats)

    def forward(self, g, features):
        x = torch.relu(self.conv1(g, features))
        x = torch.relu(self.conv2(g, x))
        return x


# 构建示例图数据
# 创建一个有向图
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 4, 3, 4])

# 添加自环
g = dgl.add_self_loop(g)

# 定义节点特征
features = torch.tensor([
    [0.2, 0.4],
    [0.3, 0.5],
    [0.4, 0.6],
    [0.5, 0.7],
    [0.6, 0.8]
])

# 定义标签(边是否存在连接)
labels = torch.tensor([1, 1, 1, 0, 0, 1, 0], dtype=torch.float32)

# 划分训练集和测试集
train_mask = torch.tensor([True, True, True, False, False, False, False])
test_mask = torch.tensor([False, False, False, True, True, True, True])

# 创建图神经网络模型
model = GNNLinkPredict(in_feats=2, hidden_size=16, out_feats=1)

# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# 训练模型
for epoch in range(50):
    model.train()
    logits = model(g, features)
    pred = logits.squeeze()
    loss = criterion(pred[train_mask], labels[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 打印训练损失
    print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")

# 在测试集上评估模型
model.eval()
with torch.no_grad():
    logits = model(g, features)
    pred = logits.squeeze()
    pred = torch.sigmoid(pred)  # 使用sigmoid函数将预测值映射到0-1之间
    pred_labels = torch.round(pred)  # 四舍五入为0或1的预测标签
    accuracy = (pred_labels[test_mask] == labels[test_mask]).float().mean()
    print(f"Accuracy: {accuracy.item()}")

 留下个问题有空再解决。

关于复杂网络建模,我前面写了很多,大家可以学习参考。

【复杂网络建模】——常用绘图软件和库_图论画图软件

【复杂网络建模】——Pytmnet进行多层网络分析与可视化

【复杂网络建模】——Python通过平均度和随机概率构建ER网络

【复杂网络建模】——通过图神经网络来建模分析复杂网络

【复杂网络建模】——Python可视化重要节点识别(PageRank算法)

【复杂网络建模】——基于Pytorch构建图注意力网络模型

【复杂网络建模】——基于微博数据的影响力最大化算法(PageRank)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/653911.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

当老板问:软件质量怎么样,能上线发布吗?阁下该如何应对

说在前面 每当你和团队完成了一款软件产品的开发,是否很容易被问到这样一个问题:质量怎么样?或者是能上线发布吗?如果你是团队的负责人,你会如何回答这样的问题呢?对软件质量的评判标准,不见得…

【Airtest】UI自动化测试的数据分离实践

目录 前言 1. 示例介绍 2. 读取Excel单元格里的数据 1)安装 xlrd 第三方库 2)读取表格数据存储到列表中 3)封装成读取控件信息的函数 3. 处理控件信息并实现控件操作 小结 前言 在UI自动化测试中,测试数据的管理和组织是…

Spring-Retry(重试机制)

Spring-Retry(重试机制) 在实际工作中,重处理是一个非常常见的场景,比如: 发送消息失败。 调用远程服务失败。 争抢锁失败。 这些错误可能是因为网络波动造成的,等待过后重处理就能成功。通常来说,会用try…

Redis入门 - 5种基本数据类型

原文首更地址,阅读效果更佳! Redis入门 - 5种基本数据类型 | CoderMast编程桅杆https://www.codermast.com/database/redis/five-base-datatype.html 说明 在我们平常的业务中基本只会使用到Redis的基本数据类型(String、List、Hash、Set、…

重新学树结构

树 图一 图二 相关术语 前驱:某结点上一层结点,图中H结点的前驱结点是F后继:某结点紧跟的后面的结点,图中F结点的后继是G、H、I三个结点根结点:非空树没有前驱结点的结点,图中的R结点结点的度&#x…

019+limou+C语言预处理

0.前言 您好,这里是limou3434的一篇博客,感兴趣您可以看看我的其他博文系列。本次我主要给您带来了C语言有关预处理的知识。 1.宏的深度理解与使用 1.1.数值宏常量 #define PI 3.1415926注意define和#之间是可以留有空格的 1.2.字符宏常量 #includ…

设置论文中的图、表的题注

参考b站:毕业论文图表如何自动编号/word图表自动编号/图表编号自动更新 其中,更新图表序号 视频使用ctrl 设置论文中的图、表的题注 step1:设置章节1.1: 章节设置字体样式,选择标题11.2:章节添加序号1.3 修改序号 和字之间的缩进&…

Linux->线程基本概念

目录 前言: 1. 线程的基本概念 2 线程的优点 3 线程的缺点 4 数据块大小为4KB大小的真正原因 前言: 本篇文章讲解了线程与进程之间的区别和联系,线程的优缺点,还有内存的数据管理与磁盘之间的关系,虚拟内存到内存…

阿里云服务器提供哪些操作系统和软件支持?是否与常用软件兼容?

阿里云服务器提供哪些操作系统和软件支持?是否与常用软件兼容?    阿里云服务器支持的操作系统   为了满足不同用户需求,阿里云服务器(ECS)提供了丰富的操作系统选择。以下是阿里云服务器支持的主要操作系统&#…

Linux 配置MySQL环境(三)

Linux配置MySQL环境 一、下载1. 官网下载MySQL2. 百度网盘快速下载MySQL 二、安装1、通过 Xftp 将 MySQL 安装包拷贝到 Linux2、解压缩3、安装 common、libs、client、server4、初步连接 三、卸载四、常用设置1. 修改 root 用户密码 五、使用新密码登录六、开启远程访问七、开放…

PHP设计模式21-工厂模式的讲解及应用

文章目录 前言基础知识简单工厂模式工厂方法模式抽象工厂模式 详解工厂模式普通的实现更加优雅的实现 总结 前言 本文已收录于PHP全栈系列专栏:PHP快速入门与实战 学会好设计模式,能够对我们的技术水平得到非常大的提升。同时也会让我们的代码写的非常…

OpenCV 笔记_5

文章目录 笔记_5特征点匹配DMatch 存放匹配结果DescriptorMatcher::match 特征点描述子(一对一)匹配DescriptorMatcher::knnMatch 特征点描述子(一对多)匹配DescriptorMatcher::radiusMatch 特征点描述子(一对多&#…

Frontiers in Microbiology:DAP-seq技术在猪苓C2H2转录因子PuCRZ1调控菌丝生长及渗透胁迫耐受性机制研究中的应用

猪苓(Polyporus umbellatus)是一种可食用的蘑菇,也是我国常用的菌类药材之一,至今已有2000多年的药用历史,在《神农本草经》、《本草纲目》、《本草求真》等典籍中均有记载。猪苓具有利尿、抗菌作用,近年来…

SpringBatch从入门到实战(二):HelloWorld

一:HelloWorld 1.1 配置Job、Step、Tasklet Configuration public class HelloWorldJobConfig {Autowiredprivate JobBuilderFactory jobBuilderFactory;Autowiredprivate StepBuilderFactory stepBuilderFactory;Beanpublic Job helloWorldJob() {return jobBuild…

代码随想录算法训练营第五十九天|503.下一个更大元素II 42. 接雨水

目录 LeeCode 503.下一个更大元素II LeeCode 42. 接雨水 暴力解法 优化双指针法 单调栈法 LeeCode 503.下一个更大元素II 503. 下一个更大元素 II - 力扣(LeetCode) 【思路】 相较于前两道题目,这道题目将数组改为循环数组&#x…

python获取度娘热搜数据并保存成Excel

python获取百度热搜数据 一、获取目标、准备工作二、开始编码三、总结 一、获取目标、准备工作 1、获取目标: 本次获取教程目标:某度热搜 2、准备工作 环境python3.xrequestspandas requests跟pandas为本次教程所需的库,requests用于模拟h…

在读博士怎么申请公派访学?

作为在读博士生,申请公派访学是一项重要而有益的经历。下面知识人网将为您介绍一些关于如何申请公派访学的步骤和注意事项。 首先,您需要找到一个合适的公派访学机会。可以通过与导师、教授或其他相关人士进行交流来获取相关信息。还可以参考学术会议、研…

【Linux】linux | 服务响应慢、问题排查 | 带宽问题导致

一、说明 1、项目使用云服务器,服务器配置:5M带宽、4核、32G,1T,CentOS7 2、CPU、内存、磁盘IO都没有达到瓶颈,猜测是带宽问题 3、应用比较多,应用中间件,十几个差不多 4、同时在线人数30 5、已…

继承~~~

1:继承概述,使用继承的好处 1:什么是继承? Java中提供一个关键字extends,用这个关键字,我们可以让一个类和另一类建立起父子关系。 public class Student extends People{} Student称为子类&#xff08…

乘势而起:机载航电·显控显示系统仿真

改革开放以来,我国国民经济与科学技术取得了长足的发展,信息化、工业成熟度与自动化程度不断深化,极大地增强了国家的综合实力、在世界范围内显示了大国地位。在当前科技产业的发展和变革的历史性交汇期,“工业4.0”、“中国制造2…