使用PyG(PyTorch Geometric)实现基于图卷积神经网络(GCN)的节点分类任务

news2024/12/29 0:45:54

文章目录

  • 基本介绍
    • PyTorch Geometric
    • 图卷积神经网络GCN
  • 节点分类任务实现
    • Cora数据集
    • 搭建GCN模型
    • 训练与测试
    • 迭代并输出
    • 完整代码

基本介绍

PyTorch Geometric

PyG(PyTorch Geometric)是一个基于PyTorch的库,可以轻松编写和训练图神经网络(GNN),用于与结构化数据相关的广泛应用。

它包括从各种已发表的论文中对图和其他不规则结构进行深度学习的各种方法,也称为几何深度学习。此外,它还包括易于使用的迷你批处理加载程序,用于在许多小型和单巨型图上操作,多GPU支持,大量通用基准数据集(基于创建自己的简单接口),GraphGym实验管理器,以及有用的转换,既用于在任意图上学习,也用于在3D网格或点云上学习。

安装PyG可以参考我的博客:python安装pyg(pytorch_geometric)的两种方式:https://wang11.blog.csdn.net/article/details/128987042

图卷积神经网络GCN

GCN由Thomas N. Kipf和Max Welling在ICLR2017提出。

Semi-Supervised Classification with Graph Convolutional Networks: https://arxiv.org/abs/1609.02907
在这里插入图片描述

对于一个输入图,他有N个节点,每个节点的特征组成一个特征矩阵X,节点与节点之间的关系组成一个邻接矩阵A,X和A即为模型的输入。

GCN是一个神经网络层,它具有以下逐层传播规则:
在这里插入图片描述

其中,

  • ˜A = A + I,A为输入图的领接矩阵,I为单位矩阵。
  • ˜D为˜A的度矩阵,˜Dii = ∑j ˜Aij
  • H是每一层的特征,对于输入层H = X
  • σ是非线性激活函数
  • W为特定层的可训练权重矩阵

节点分类任务实现

Cora数据集

Cora数据集包含2708篇科学出版物,5429条边,总共7种类别。数据集中的每个出版物都由一个 0/1 值的词向量描述,表示字典中相应词的缺失/存在。 该词典由 1433 个独特的词组成。意思就是说每一个出版物都由1433个特征构成,每个特征仅由0/1表示。它是在Semi-Supervised Learning with Graph Embeddings项目中生成的,可以用于可视化和分析节点之间的连接关系。

Cora数据集的特点包括:

  1. 每个出版物都由一个0/1值的词向量描述,表示字典中相应词的缺失/存在。
  2. 该词典由1433个独特的词组成。
  3. 数据集包含以下文件:
    ind.cora.x:训练集节点特征向量,保存对象为:scipy.sparse.csr.csr_matrix,实际展开后大小为:(140,1433)
    ind.cora.tx:测试集节点特征向量,保存对为:scipy.sparse.csr.csr_matrix,实际展开后大小为:(1000,1433)
    ind.cora.allx:包含有标签和无标签的训练节点特征向量,保存对象为: scipy.sparse.csr.csr_matrix,实际展开后大小为:(1708,1433)
    ind.cora.y:one-hot表示的训练节点的标签,保存对象为:numpy.ndarray
    ind.cora.ty:one-hot表示的测试节点的标签,保存对象为:numpy.ndarray
    ind.cora.ally:one-hot表示的ind.cora.allx对应的标签,保存对象为:numpy.ndarray。
    在这里插入图片描述

使用PyG加载Cora数据集:

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
print(data)

在这里插入图片描述

print(data.x)	# 节点特征矩阵[2708,1433]
print(data.y)	# 节点类别
print(data.edge_index)	# 边
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of nodes: {data.num_nodes}')  # 节点数量
print(f'Number of edges: {data.num_edges}')  # 边数量
print(f'Number of node features: {data.num_node_features}')  # 节点特征维度
print(f'Number of node features: {data.num_features}')  # 节点特征维度
print(f'Number of edge features: {data.num_edge_features}')  # 边特征维度
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')  # 平均节点度

在这里插入图片描述

搭建GCN模型

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
# 输入通道数:dataset.num_features=1433,即节点特征维度
# 输出通道数:dataset.num_classes=7,即节点类别数
model = GCN(dataset.num_features, 16, dataset.num_classes)        

在这里插入图片描述

定义损失函数

criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.

定义优化器

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

优化器选择Adam,学习率设置为0.01。

训练与测试

训练

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

测试

def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    test_correct = pred == data.y	# 计算分类正确的节点数
    test_acc = int(test_correct.sum()) / int(data.num_nodes)	# 计算正确率
    return test_acc

迭代并输出

e, l, acc = [], [], []
for epoch in range(1, 201):
    loss = train()
    a = test()
    e. append(epoch)
    l.append(loss)
    acc.append(a)
    print(f'Epoch: {epoch:03d}, Acc: {a:04f}, Loss: {loss:.4f}')
matplotlib.rc("font", family='FangSong')
plt.plot(e, l, color='red', linewidth=2, linestyle="solid", label='loss')
plt.plot(e, acc, color='green', linewidth=2, linestyle="solid", label='acc')
plt.legend()
plt.xlabel("epoch")
plt.show()

其中,定义了两个列表lacc分别用于存储每轮迭代的损失值和准确率,便于后续使用plt可视化输出。
迭代训练过程可视化
在这里插入图片描述
经过200次迭代训练分类准确率达到0.8左右,CELoss由1.9将至0.05左右并趋于收敛。

完整代码

import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x


dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
print(data)

print(data.x)	# 节点特征矩阵[2708,1433]
print(data.y)	# 节点类别
print(data.edge_index)	# 边
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of nodes: {data.num_nodes}')  # 节点数量
print(f'Number of edges: {data.num_edges}')  # 边数量
print(f'Number of node features: {data.num_node_features}')  # 节点特征维度
print(f'Number of node features: {data.num_features}')  # 节点特征维度
print(f'Number of edge features: {data.num_edge_features}')  # 边特征维度
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')  # 平均节点度


model = GCN(dataset.num_features, 16, dataset.num_classes)

print(model)
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

# 训练
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

# 测试
def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    test_correct = pred == data.y
    test_acc = int(test_correct.sum()) / int(data.num_nodes)
    return test_acc


e, l, acc = [], [], []
for epoch in range(1, 201):
    loss = train()
    a = test()
    e. append(epoch)
    l.append(loss)
    acc.append(a)
    print(f'Epoch: {epoch:03d}, Acc: {a:04f}, Loss: {loss:.4f}')
matplotlib.rc("font", family='FangSong')
plt.plot(e, l, color='red', linewidth=2, linestyle="solid", label='loss')
plt.plot(e, acc, color='green', linewidth=2, linestyle="solid", label='acc')
plt.legend()
plt.xlabel("epoch")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/421771.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ChatGPT,开启人机交互新篇章

ChatGPT在世界掀起了生成式AI的热潮,2个月实现月活用户过亿,是人类有史以来突破1亿人用户最快的消费端互联网产品,打破了Tiktok9个月破亿用户的纪录。不少专家将其视为第四次工业革命,资本市场也贡献大量涨停。当第一波的热情消退…

Android 7.1 Toast修复之终极篇,进程不奔溃(包含apk和兼容外来dex插件)

修复android 7.1 Toast的篇章: 常规app通过ams lancet 字节编码处理:Android Lancet Aop 字节编码修复7.1系统Toast问题(WindowManager$BadTokenException)多渠道游戏app兼容性处理:Android 7.1 Toast修复之多渠道包动态使用Booster或者Lancet plugin …

在外web浏览器远程访问jupyter notebook服务器【内网穿透】

文章目录前言视频教程1. Python环境安装2. Jupyter 安装3. 启动Jupyter Notebook4. 远程访问4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口转载自远控源码文章:公网远程访问jupyter notebook【cpolar内网穿透】 前言 Jupyter Notebook,它是一个交…

未来城市的微小单元:滴滴即将量产无人车

汽车诞生之后就一直作为除了家庭与公司之外的「第三空间」存在,技术的脚步从未停止过开发汽车的更多可能。尤其无人驾驶技术的出现,进一步解放了驾驶者,也让人们对于这一能够自主移动的第三空间充满了想象。作为未来城市的微小组成单元&#…

( “树” 之 DFS) 226. 翻转二叉树 ——【Leetcode每日一题】

226. 翻转二叉树 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 示例 1: 输入:root [4,2,7,1,3,6,9] 输出:[4,7,2,9,6,3,1] 示例 2: 输入:root [2,1,3] 输出:[…

ThreadLocal源码分析及内存泄漏

ThreadLocal原理分析及内存泄漏ThreadLocal的使用ThreadLocal原理set方法解析replaceStaleEntry方法解析expungeStaleEntry方法解析cleanSomeSlots方法解析case 1: 向前有脏数据,向后找到可覆盖的Entrycase 2: 向前有脏数据,向后未找到可覆盖的Entrycase…

吴恩达机器学习--线性回归

文章目录前言一、单变量线性回归1.导入必要的库2.读取数据3.绘制散点图4.划分数据5.定义模型函数6.定义损失函数7.求权重向量w7.1 梯度下降函数7.2 最小二乘法8.训练模型9.绘制预测曲线10.试试正则化11.绘制预测曲线12.试试sklearn库二、多变量线性回归1.导入库2.读取数据3.划分…

掌握高效绘制地图的利器——LeafletJs

文章目录前言一、leafletJs是什么?二、快速入门1、安装2、快速入门三、进阶学习1、Map 控件2、Marker 标记3、Popup 弹出窗口4、图层四、项目实战封装文件4.1 基础点位图4.2 行驶轨迹图前言 GIS 作为获取、存储、分析和管理地理空间数据的重要工具,用 G…

数据结构与算法一览(树、图、排序算法、搜索算法等)- Review

算法基础简介 - OI Wiki (oi-wiki.org) 文章目录1. 数据结构介绍1.1 什么是数据结构1.2 数据结构分类2. 链表、栈、队列:略3. 哈希表:略4. 树4.1 二叉树4.2 B 树与 B 树4.3 哈夫曼(霍夫曼)树:Huffman Tree4.4 线段树&a…

编辑文件/文件夹权限 - Win系统

前言 我们经常会遇到由于权限不够无法删除文件/文件夹的情况,解决方案一般是编辑文件/文件夹的权限,使当前账户拥有文件的完全控制权限,然后再进行删除,下文介绍操作步骤。 修改权限 查看用户权限 右键文件/文件夹,…

(函数指针) 指向函数的指针

函数指针- 指向函数的指针函数指针的声明和使用通过函数指针调用函数函数指针做参数函数指针数组函数指针的声明和使用 函数指针的声明格式: 返回值类型 (*函数指针名)(参数列表); 其中: *函数指针名 表示函数指针的名称返回值类型 则表示该指针所指向…

【Kubernetes】StatefulSet对象详解

文章目录简介1. StatefulSet对象的概述、作用及优点1.1 对比Deployment对象和StatefulSet对象1.2 以下是比较Deployment对象和StatefulSet对象的优缺点:2. StatefulSet对象的基础知识2.1 StatefulSet对象的定义2.1.1 下表为StatefulSet对象的定义及其属性&#xff1…

上岸川大网安院

一些感慨 一年多没写过啥玩意了,因为考研去了嘿嘿。拟录取名单已出,经历一年多的考研之路也可以顺利打上句号了。 我的初试成绩是380,政治65,英语81,数学119,专业课115。 回顾这一路,考研似乎也…

分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测

分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测 目录分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测分类效果基本介绍模型描述程序设计参考资料分类效果 基本介绍 MATLAB实现CNN-BiLSTM-Attention多输入分类预测,CNN-BiLSTM结合注意力机制多输…

Vue3使用Vant组件库避坑总结

文章目录前言一、问题二、解决方法三、问题出现原因总结经验教训前言 本片文章主要写了,Vue3开发时运用Vant UI库的一些避坑点。让有问题的小伙伴可以快速了解是为什么。也是给自己做一个记录。 一、问题 vue3版本使用vant失败,具体是在使用组件时失效…

IPBX系统快速部署和Freeswitch 1.10.7自动安装

IPBX系统部署文档 IPPBX系统 1.10.7版本Freeswitch ,手机互联互通,SIP协议,分机互相拨打免费通话清晰,支持wifi或4G网络互相拨打电话,可以对接OLT设备,系统可以部署到本地物理机,也可以部署到阿…

工程质量之研发过程管理需要关注的点

一、背景 作为程序猿,工程质量是我们逃不开的一个话题,工程质量高带来的好处多多,我在写这篇文章的时候问了一下CHATGPT,就当娱乐一下,以下是ChatGPT的回答: 1、提高产品或服务的可靠性和稳定性。高质量的系…

光时域反射仪那个品牌的好用

光时域反射仪 哪个品牌好用 光时域反射仪要怎么选到合适自己的,这些问题 可能一直在困扰这一线的工作人员,下面小编就为大家一一解答下 首先光时域域反射仪是一款检测光纤线路的损耗 长度 以及 事件点的一款设备,在诊断 光纤线路 故障点的情…

从零开始学架构——CAP理论

CAP定理 CAP 定理(CAP theorem)又被称作布鲁尔定理(Brewer’s theorem),是加州大学伯克利分校的计算机科学家埃里克布鲁尔(Eric Brewer)在 2000 年的 ACM PODC 上提出的一个猜想。2002 年&…

Web前端 HTML、CSS

HTML与CSSHTML、CSS思维导图一、HTML1.1、HTML基础文本标签1.2、图片、音频、视频标签1.3、超链接、表格标签1.4、布局1.5、表单标签1.6、表单项标签综合使用1.7、HTML小结二、CSS(简介)2.1、引入方式2.2、选择器2.3、CSS属性Web前端开发总览 Html&…