图神经网络:(大型图的有关处理)在Pumbed数据集上动手实现图神经网络

news2024/12/26 11:22:46

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook和有关文献。超链。提取码8848。

文章目录

    • Pumed数据集
    • 文献阅读
    • 继续实验

Pumed数据集

导库

from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.datasets import Planetoid

下载数据处理数据导入数据

dataset=Planetoid(root='/DATA/Planetoid',name='PubMed',transform=NormalizeFeatures())

其他说明1:这段代码会在C盘生成一个DATA的文件并将数据集放在DATA中,有强迫症注意一下。
其他说明2:如果下载发生错误直接去官网上下载。下载好了复制C:\DATA\Planetoid\PubMed\raw中。官网链接。不会有人没梯子吧。
数据描述

data=dataset[0]
print(data.num_nodes,end=" ");print(data.num_edges)
print(data.train_mask.sum().item(),end=" ");print(data.val_mask.sum().item(),end=" ");print(data.test_mask.sum().item())
print(data.has_isolated_nodes(),end=" ");print(data.has_self_loops(),end=" ");print(data.is_undirected(),end=" ")
#输出如下
#19717 88648
#60 500 1000
#False False True

其他说明:Pumbed数据集开源的生物医学文献数据库。不用细究。

文献阅读

参考文献: Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks。原文链接。

功能概述: 介绍一种方法在时间与空间完爆其他方法。其他方法是指:1)Full-batch gradient:在空间上:O(NLF)(N代表序列的长度,F代表特征的数量,L代表网络层数量)。在时间上,梯度下降收敛较慢。2)Mini-batch SGD:在时间空间上引入大量计算开销造成原因邻域扩展。计算某个节点损失需要第L-1层的嵌入然后需要第L-2层的嵌入递归下去。3)VR-GCN:克服上述领域扩展有关问题但是需要第L-1层的嵌入所以空间要求太高NG。然后Cluster-GCN就来啦。1)Cluster-GCN空间小特别是在大型图上2)Cluster-GCN在浅层GCN中速度等同于VR-GCN;深层GCN中Cluster-GCN快得多,Cluster-GCN是线性,VR-GCN是指数。3)尽管有些工作表明深层图神经网络的效果不佳,但是实验表明Cluster-GCN深层的效果不错。下图:各种方法的时空复杂度。
在这里插入图片描述

理论分析1::将节点划分为n组如 V 1 , V 2 , … V n ] \mathcal{V}_{1},\mathcal{V}_{2},\dots \mathcal{V}_{n}] V1,V2,Vn]。同组节点保存邻接不同组间直接断开。所以图被划分为 G ‾ = [ G 1 , G 2 , … , G n ] = [ { V 1 , E 1 } , { V 2 , E 2 } , … , { V n , E n } ] \overline{G}=[G_{1},G_{2},\dots,G_{n}]=[\{\mathcal{V}_{1},\mathcal{E}_{1}\},\{\mathcal{V}_{2},\mathcal{E}_{2}\},\dots,\{\mathcal{V}_{n},\mathcal{E}_{n}\}] G=[G1,G2,,Gn]=[{V1,E1},{V2,E2},,{Vn,En}]。这个相当于对图作近似。 Δ \Delta Δ保留了删除信息。特征向量以及标签按照节点划分划分。多层图神经网络便变为: Z L = A ‾ ′ σ ( A ‾ ′ σ ( … σ ( A ‾ ′ X W 0 ) W 1 ) …   ) W L − 1 Z^{L}=\overline{A}^{\prime}\sigma(\overline{A}^{\prime}\sigma(\dots\sigma(\overline{A}^{\prime}XW^{0})W^{1})\dots)W^{L-1} ZL=Aσ(Aσ(σ(AXW0)W1))WL1 A ‾ ′ \overline{A}^{\prime} A是分块对角阵 A ‾ \overline{A} A的标准化。损失函数变为: L A ‾ ′ = ∑ t ∣ V t ∣ N L A ‾ t t ′ \mathcal{L}_{\overline{A}^{\prime}}=\sum_{t}\frac{|\mathcal{V}_{t}|}{N}\mathcal{L}_{\overline{A}^{\prime}_{tt}} LA=tNVtLAtt and L A ‾ t t ′ = 1 ∣ V t ∣ ∑ i ∈ V t l o s s ( y i , z i L ) \mathcal{L}_{\overline{A}^{\prime}_{tt}}=\frac{1}{|\mathcal{V}_{t}|}\sum_{i \in \mathcal{V}_{t}}loss(y_{i},z_{i}^{L}) LAtt=Vt1iVtloss(yi,ziL) 。这个便就就是核心思想。大概就是按照下图右边那样进行分割。
在这里插入图片描述
理论分析2: 划分引入一种误差,这种误差是与 Δ \Delta Δ成正比,所以我们应该减小这种误差。于是引入Metis以及Graclus方法 。重点分析了Metics划分,比起随机划分效果更好如下:这些指标都是Accuracy_Score吧。
在这里插入图片描述
理论分析3:
还有问题。1)毕竟还是删除了一些边,模型效果可能还是受到影响。2)由于集群分配算法导致相似节点被分为了一堆,所以可能会与原始数据不同(PS:原文这样写的,我不能够理解),使用随机梯度算法可能会带来偏差吧。所以为了解决问题或者减小问题影响,提出了一个 stochastic multiple clustering方法。简单来说是这样的:随机梯度算法更新权重需要进行Batch的划分;之前邻接矩阵已经被处理成分块对角矩阵。选择m个对角矩阵进入Batch,前面删除的边重新加上。如下这样。(好吧这个图我也没看懂但是大致想法是清晰的)
在这里插入图片描述
结果如下:
在这里插入图片描述
算法的伪代码:
在这里插入图片描述

PS1:后面好像是实验的内容部分,没时间看就这样吧。PS:这是我的理解所以不一定对。

继续实验

导库

from torch_geometric.loader import ClusterData,ClusterLoader

聚类划分,构建批量

cluster_data=ClusterData(data,num_parts=128)
train_loader=ClusterLoader(cluster_data,batch_size=32,shuffle=True)

打印信息

total_num_nodes=0
for step,sub_data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of nodes in the current batch: {sub_data.num_nodes}')
    print(sub_data)
    print()
    total_num_nodes+=sub_data.num_nodes
print(f'Iterated over {total_num_nodes} of {data.num_nodes} nodes!')
#输出如下
#Step 1:
#=======
#Number of nodes in the current batch: 4924
#Data(x=[4924, 500], y=[4924], train_mask=[4924], val_mask=[4924], test_mask=[4924], edge_index=[2, 15404])
#
#Step 2:
#=======
#Number of nodes in the current batch: 4939
#Data(x=[4939, 500], y=[4939], train_mask=[4939], val_mask=[4939], test_mask=[4939], edge_index=[2, 17834])
#
#Step 3:
#=======
#Number of nodes in the current batch: 4928
#Data(x=[4928, 500], y=[4928], train_mask=[4928], val_mask=[4928], test_mask=[4928], edge_index=[2, 17524])
#
#Step 4:
#=======
#Number of nodes in the current batch: 4926
#Data(x=[4926, 500], y=[4926], train_mask=[4926], val_mask=[4926], test_mask=[4926], edge_index=[2, 16042])
#
#Iterated over 19717 of 19717 nodes!

导库

from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch

随便搭建

class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super(GCN,self).__init__()
        self.conv1=GCNConv(dataset.num_node_features,hidden_channels)
        self.conv2=GCNConv(hidden_channels,dataset.num_classes)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x

打印信息

model=GCN(hidden_channels=16)
print(model)
#输出如下
#GCN(
#  (conv1): GCNConv(500, 16)
#  (conv2): GCNConv(16, 3)
#)

开始训练

model=GCN(hidden_channels=16);optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()

def train():
      model.train()
      for sub_data in train_loader:
          out=model(sub_data.x,sub_data.edge_index)
          loss=criterion(out[sub_data.train_mask],sub_data.y[sub_data.train_mask])
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()

def test():
      model.eval()
      out=model(data.x,data.edge_index)
      pred=out.argmax(dim=1)
      accs=[]
      for mask in [data.train_mask,data.val_mask,data.test_mask]:
          correct=pred[mask]==data.y[mask]
          accs.append(int(correct.sum())/int(mask.sum()))
      return accs

for epoch in range(1,51):
    loss=train()
    train_acc,val_acc,test_acc=test()
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
#最后一次输出如下
#Epoch: 050, Train: 0.9833, Val Acc: 0.8060, Test Acc: 0.7880

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/536989.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【day2】单片机

目录 【1】GPIO 1.定义 2.应用 I - Input - 输入采集 O - Output - 输出控制 ​编辑 3.GPIO结构框图 4.功能描述 输入功能 输出功能 5.相关寄存器 【2】点亮一盏LED灯 1.实验步骤 2.编程实现 3.编译下载 4.复位上电 练习:实现LED灯闪烁 练习…

Linux - 第15节 - 网络基础(应用层)

1.再谈 "协议" 1.1.协议的概念 协议,网络协议的简称,网络协议是通信计算机双方必须共同遵从的一组约定,比如怎么建立连接、怎么互相识别等。 为了使数据在网络上能够从源到达目的,网络通信的参与方必须遵循相同的规则&…

收集数据集以训练自定义模型的 5 种方法

来源:投稿 作者:王同学 编辑:学姐 在过去的十年中,深度学习技术在计算机视觉领域中的应用逐年增加。其中当属「行人检测」和「车辆检测」最为火爆,其原因之一就是「预训练模型」的「可复用性」。 由于深度学习技术在这…

Pandas+Pyecharts | 新冠疫情数据动态时序可视化

文章目录 🏳️‍🌈 1. 导入模块🏳️‍🌈 2. Pandas数据处理2.1 读取数据2.2 按月统计数据 🏳️‍🌈 3. Pyecharts数据可视化3.1 疫情动态时序地图3.2 疫情动态时序折线图3.3 疫情动态时序柱状图3.4 疫情动态…

Maven中scope(作用范围)详解

目录 一、依赖传递二、依赖范围三、依赖范围对传递依赖的影响四、依赖调节五、可选依赖六、排除依赖七、依赖归类八、依赖管理 一、依赖传递 Maven 依赖传递是 Maven 的核心机制之一,它能够一定程度上简化 Maven 的依赖配置。 如下图所示,项目 A 依赖于…

黄牛为什么能抢走“五月天”的门票?

目录 “史上最难抢票”的五月天演唱会 黄牛为什么能抢到票 黄牛抢票带来哪些坏影响 售票平台为什么挡不住黄牛? 管理上如何有效防控黄牛 技术上如何有效防黄牛 相关技术产品推荐 随着文娱活动的复苏,大量黄牛“卷土袭来”。顶象防御云业务安全情报…

【音视频处理】MP4、FLV、HLS适用范围,在线视频播放哪个更好

大家好,欢迎来到停止重构的频道。 我们之前讨论过直播协议,本期我们讨论在线点播的视频格式。 也就是网络视频文件、短视频常用的格式 如MP4、FLV、HLS等。 我们将详细讨论在线点播场景下,这些视频格式的优劣以及原因。 我们按这样的顺序…

分享几个冷门但实用的网站!

今天给大家推荐几个冷门但实用的网站,免费又好用对于打工人来讲十分友好。 Img Cleaner https://imgcleaner.com/ 一个免费的AI智能图片去水印网站,不用注册登录就可以使用,而且操作也比较简单,上传图片之后调整画笔大小&#xf…

小黑子—Java从入门到入土过程:第十章 - 多线程

Java零基础入门10.0 Java系列第十章- 多线程1. 初识多线程2. 并发和并行3. 多线程的实现方式3.1 一:继承Thread类方式实现3.2 二:实现Runnable接口的方式实现3.3 三:利用Callable接口和Future接口方式实现 4. 多线程中常见的成员方法4.1 线程…

❤ Manifest version 2 is deprecated, and support will be removed in 2023. See..

❤谷歌插件开发遇到的问题 开发谷歌插件提示: ❤js 开发谷歌插件提示 提示 Manifest version 2 is deprecated, and support will be removed in 2023. See… 当导入到chrome,提示如下错误: Manifest version 2 is deprecated, and suppo…

物联网:智慧城市还要做的事情

摘要:本文简介关于智慧城市,还有哪些需要做的事情。 1.传统城市需要向智慧城市转型 传统的城市中心已被证明不足以满足社会当前和未来的各种需求,这增加了应用智慧城市理念的需求。智慧城市可以对健康、交通、休闲和工业等多个领域产生重大…

VLC可见光通信:2、高速LED驱动电路

背景 在VLC可见光通信中,需要高速的控制LED的通断,因此需要高速LED驱动电路。 文中出现的低压是指24V电压以下,中压是指24V~60V电压,高压是指60V ~ 160V。 低速是指500KHZ以下,高速是指2MHZ。 小功率是指20W以下,大功率指20W~100W。 低压小功率LED低速&高速:20W、…

你是时候拥抱chatgpt了

随着chatgpt热度不断上升,chatgpt已经广泛应用到各个行业了,很多人都感觉自己地位受到威胁,有人预测chatgpt会取代80%程序员的工作,我也用了chatgpt有几个月了,不得不说是真的牛逼。我甚至用它写了一个python的聊天脚本…

MapReduce计算广州2022年每月最高温度

目录 数据集 1.查询地区编号 2.数据集的下载 编写MapReduce程序输入格式 输出格式 Mapper类 确定参数 代码 Reducer类 思路 代码 Runner类 运行结果 数据集 1.查询地区编号 NCDC是美国国家气象数据中心的缩写,是一个负责收集、存储和分发全球气象和气…

C#中将32位二进制转换为float【Real】十进制类型

已知一个32位二进制字符串,转换为float【Real】十进制。 参考本人一篇博客 float数转二进制 C#关于32位浮点数Float(Real)一步步按位Bit进行解析_real32位浮点数_斯内科的博客-CSDN博客 现在是32位二进制转化为十进制浮点数,C#有…

电动汽车入网技术(V2G)调度优化(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 👨‍💻4 Matlab代码 💥1 概述 近年来我国电动汽车行业飞速发展,其中电动汽车入网技术(vehicle-to-grid,V2G)在…

RapidScada Linux安装教程

官方安装步骤:在 Linux上安装 - Rapid SCADA,安装过程中遇到一些坑,记录详细步骤。 官方推荐的Ubuntu,未测试Centos 1. 安装ASP.Net运行环境(Runtime) 下载地址:下载 .NET 6.0 (Linux、macOS 和 Windows)&a…

selenium还能这么玩:连接已经存在的浏览器

测试和爬虫对selenium并不会陌生,现有的教程已经非常多。但是因为 selenium 封装的方法比较底层,所以灵活性非常高,我们可以基于这种灵活性来实现非常丰富的定制功能。 这篇文章介绍一个操作,可以让 selenium 连接我们手动打开的…

AI绘画-Midjourney基础2-超强二次元风格模型 niji 5

niji 模型是 mj 的一种模型,可以生成二次元风格的图片。 在控制台输入 /settings 指令,进入设置页面。 选择第二行的 Niji version 5 模型,就可以创作二次元风格的图片了! 一、expressive 风格 expressive 风格是 niji 5 模型的默认风格。 Step into the world :: of a …

14个最佳创业企业WordPress主题

要创建免费网站?从易服客建站平台免费开始 500M免费空间,可升级为20GB电子商务网站 创建免费网站 您网站的设计使您能够展示产品的独特卖点。通过正确的主题,您将能够解释为什么客户应该选择您的品牌而不是其他品牌。 在本文中&#xff0…