图神经网络:在Cora数据集上动手实现图神经网络

news2024/9/28 3:30:26

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook。超链。提取码8888。

文章目录

    • 代码实操1:GCN的复杂实现
    • 代码实操2:GCN的简单实现
    • 代码实操3:GAT的简单实现

代码实操1:GCN的复杂实现

导入绘图的库,定义绘图函数。

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def visualize(h,color):
    z=TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(z[:,0],z[:,1],s=70,c=color,cmap="Set2")
    plt.show()

目前,我并不知道TSNE降维理论。所以,暂时把它作为一种降维并且可视化的技术。
导入对应的库,导入对应的数据集,导入对应的库。

from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.datasets import Planetoid
dataset=Planetoid(root='/DATA/Planetoid',name='Cora',transform=NormalizeFeatures())
data=dataset[0]
#确定具体的图

Cora数据集简单说明:特征矩阵 N × M N \times M N×M N N N表示为论文数量, M M M表示为特征维度,对于每维,如果单词在论文中,就是1,反之0。邻接矩阵 N × N N \times N N×N N N N表示为论文数量,论文间存在引用,之间就有一条边。
其他说明:这段代码会在C盘,生成一个叫做DATA的文件,并将数据集放在DATA之中,有强迫症注意一下。

import torch.nn.functional as F
from torch.nn import Linear
import torch

搭建一个多层的感知机,训练模型并且得到结果。

class MLP(torch.nn.Module):

    def __init__(self,hidden_channels):
        super().__init__()
        self.lin1=Linear(dataset.num_features,hidden_channels)
        self.lin2=Linear(hidden_channels,dataset.num_classes)

    def forward(self,x):
        x=self.lin1(x)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.lin2(x)
        return x

model=MLP(hidden_channels=16)
print(model)
#输出:
#MLP(
#  (lin1): Linear(in_features=1433, out_features=16, bias=True)
#  (lin2): Linear(in_features=16, out_features=7, bias=True)
#)
model=MLP(hidden_channels=16)
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)

def train():
      model.train()
      optimizer.zero_grad()
      out=model(data.x)
      loss=criterion(out[data.train_mask],data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out=model(data.x)
      pred=out.argmax(dim=1)
      test_correct=pred[data.test_mask]==data.y[data.test_mask]
      test_acc=int(test_correct.sum())/int(data.test_mask.sum())
      return test_acc

for epoch in range(1,201):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
#这里就不展示输出
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#输出:Test Accuracy: 0.5750

导入对应的库,搭建图神经网络GCN

from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
        self.conv1=GCNConv(dataset.num_features,hidden_channels)
        self.conv2=GCNConv(hidden_channels,dataset.num_classes)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x
model=GCN(hidden_channels=16)
print(model)
#输出:
#GCN(
#  (conv1): GCNConv(1433, 16)
#  (conv2): GCNConv(16, 7)
#)

可视化图嵌入(这里只有正向传播)

model=GCN(hidden_channels=16)
model.eval()
out=model(data.x,data.edge_index)
visualize(out,color=data.y)

在这里插入图片描述

进行训练得出结果

model=GCN(hidden_channels=16)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
criterion=torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out=model(data.x, data.edge_index)
      loss=criterion(out[data.train_mask],data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out=model(data.x,data.edge_index)
      pred=out.argmax(dim=1)
      test_correct=pred[data.test_mask]==data.y[data.test_mask]
      test_acc=int(test_correct.sum())/int(data.test_mask.sum())
      return test_acc


for epoch in range(1,101):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
#这里就不展示输出
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#输出:Test Accuracy: 0.8010

可视化图嵌入(训练过后)
在这里插入图片描述

代码实操2:GCN的简单实现

这是PYG官方文档的代码,就以难度而言其实就是少了可视化的东西。构建GCN的框架不同,使用损失函数不同。

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GCNConv(dataset.num_node_features,16)
        self.conv2=GCNConv(16,dataset.num_classes)
    def forward(self,data):
        x,edge_index=data.x,data.edge_index
        x=self.conv1(x,edge_index)
        x=F.relu(x)
        x=F.dropout(x,training=self.training)
        x=self.conv2(x,edge_index)
        return F.log_softmax(x,dim=1)
dataset=Planetoid(root='/DATA/Cora',name='Cora')
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=GCN().to(device)
data=dataset[0].to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out=model(data)
    loss=F.nll_loss(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
model.eval()
pred=model(data).argmax(dim=1)
correct=(pred[data.test_mask]==data.y[data.test_mask]).sum()
acc=int(correct)/int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
#输出:Accuracy: 0.8090

代码实操3:GAT的简单实现

这里操作同上,代码略有不同。

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
import torch.nn.functional as F
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GATConv(dataset.num_node_features,16)
        self.conv2=GATConv(16,dataset.num_classes)
    def forward(self,data):
        x,edge_index=data.x,data.edge_index
        x=F.dropout(x,p=0.6,training=self.training)
        x=self.conv1(x,edge_index)
        x=F.relu(x)
        x=F.dropout(x,p=0.6,training=self.training)
        x=self.conv2(x,edge_index)
        return x
dataset=Planetoid(root='/DATA/Cora',name='Cora')
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu');model=GCN().to(device);data=dataset[0].to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.05,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out=model(data)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
model.eval()
pred=model(data).argmax(dim=1);correct=(pred[data.test_mask]==data.y[data.test_mask]).sum();acc=int(correct)/int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
#输出:Accuracy: 0.7980

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/504364.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IntelliJ Platform-Plugins-获取方法相关信息(PsiElement/PsiMethodImpl)

PsiElement接口是文件中光标所在的那个字段,或者光标所在的那个方法的抽象,例如下图中PsiElement就是public String getName(),它的实现类是PsiMethodImpl 下面的代码会演示:如果光标在方法上,就打印方法名字&#x…

「Cpolar」内网穿透实现在外远程连接MongoDB数据库【端口映射】

💂作者简介: THUNDER王,一名热爱财税和SAP ABAP编程以及热爱分享的博主。目前于江西师范大学本科在读,同时任汉硕云(广东)科技有限公司ABAP开发顾问。在学习工作中,我通常使用偏后端的开发语言A…

第二十四章 Unity 纹理贴图

通常情况下,3D网格模型只能展示游戏对象的几何形状,而表面的细节则纹理贴图提供。纹理贴图通过UV坐标“贴附”在模型的表面。当然,这个过程不需要我们在Unity中完成,而是在建模软件中完成的。通常情况下,我们通过3ds m…

鸿蒙Hi3861学习九-Huawei LiteOS-M(互斥锁)

一、简介 互斥锁又被称为互斥型信号量,是一种特殊的二值信号量,用于实现对共享资源的独占式处理。 任意时刻互斥锁的状态只有两种:开锁或闭锁。 当有任务占用公共资源时,互斥锁处于闭锁状态,这个任务获得该互斥锁的使用…

C++系列六:一文打尽C++运算符

C运算符 1. 算术运算符2. 关系运算符3. 逻辑运算符4. 按位运算符5. 取地址运算符6. 取内容运算符7. 成员选择符8. 作用域运算符9. 总结 1. 算术运算符 算术运算符用于执行基本数学运算,例如加减乘除和取模等操作。下表列出了C中支持的算术运算符: 运算…

Oracle 19C 单机环境升级RU(19.3升级至19.12)

📢📢📢📣📣📣 哈喽!大家好,我是【IT邦德】,江湖人称jeames007,10余年DBA及大数据工作经验 一位上进心十足的【大数据领域博主】!😜&am…

鸿蒙Hi3861学习五-Huawei LiteOS-M(任务管理)

一、任务简介 关于任务的相关介绍,之前文章有比较详细的介绍,这里不做过多解释,可以参考如下文章:FreeRTOS学习二(任务)_t_guest的博客-CSDN博客 而LiteOS的主要特性可以总结为如下几点: LiteO…

〖数据挖掘〗weka3.8.6的安装与使用

目录 背景 一、安装 二、使用explorer 1. 介绍 2.打开自带的数据集(Preprocess) 1.打开步骤 2.查看属性和数据编辑 3.classify 4.Cluster 5.Associate 6.Select attributes 7.Visualize 待补充 背景 Weka的全名是怀卡托智能分析环境(Waikato Environme…

低代码平台解读—如何不写代码创建表单和维护表单

工作表新建与修改——敲敲云 新建工作表的流程包含 新建工作表/编辑公祖表为工作表添加字段,例如“员工档案”表中有姓名、性别、年龄等字段为字段设置属性工作表布局工作表预览、保存、关闭 1、新建工作表/修改工作表 新建工作表 修改工作表 2、为工作表添加字段 …

c#笔记-定义类

声明类 类可以使用帮助你管理一组相互依赖的数据,来完成某些职责。 类使用class关键字定义,并且必须在所有顶级语句之下。 类的成员只能有声明语句,不能有执行语句。 class Player1 {int Hp;int MaxHp;int Atk;int Def;int Overflow(){if (…

算法记录 | Day55 动态规划

392.判断子序列 思路: 1.确定dp数组(dp table)以及下标的含义: dp[i][j] 表示以下标i-1为结尾的字符串s,和以下标j-1为结尾的字符串t,相同子序列的长度为dp[i][j]。 2.确定递推公式: if (s[i - 1] t[…

线程同步、生产者消费模型和POSIX信号量

gitee仓库: 1.阻塞队列代码:https://gitee.com/WangZihao64/linux/tree/master/BlockQueue 2.环形队列代码:https://gitee.com/WangZihao64/linux/tree/master/ringqueue 条件变量 概念 概念: 利用线程间共享的全局变量进行同…

单片机c51中断 — 开关状态监测

项目文件 文件 关于项目的内容知识点可以见专栏单片机原理及应用 的第五章,中断 图中 P2.0引脚处接有一个发光二极管 D1,P3.2引脚处接有一个按键。要求分别采用一般方式和中断方式编程实现按键压下一次,D1 的发光状态反转一次的功能。 查询…

从C语言到C++⑦(第二章_类和对象_下篇)初始化列表+explicit+static成员+友元+内部类+匿名对象

目录 1. 构造函数的初始化列表 1.1 初始化列表概念 1.2 初始化列表注意事项 2. 构造函数的explicit关键字 2.1 C语言的隐式类型转换 2.2 explicit 关键字使用 3. static成员 3.1 static的概念 3.2 static成员特性 3.3 static成员使用场景 4. 友元(frien…

【Java 基础】类和对象 方法重载详解

《Java 零基础入门到精通》专栏持续更新中。通过本专栏你将学习到 Java 从入门到进阶再到实战的全套完整内容,所有内容均将集中于此专栏。无论是初学者还是有经验的开发人员,都可从本专栏获益。 订阅专栏后添加我微信或者进交流群,进群可找我领取 前端/Java/大数据/Python/低…

Linux 常用命令(1)

文章目录 Linux 常用命令格式 clear 清屏清屏获取当前目录的路径 pwd目录切换命令 cd进入上一级目录进入当前目录的文件夹 ta中(假设这里有一个文件夹ta)进入主目录进入根目录 显示目录内容 ls显示详细信息,包含文件属性显示全部内容,包含隐藏文件&#…

tiechui_lesson07_中断级和自旋锁

一、中断级IRQL 高级别可以打断低级别的调用,同级别不能打断同级别的调用。 中断级在软件层面分为三级,再高的级别是硬件发送的中断。 - 0 pass_level- 1 apc_level- 2 dpc_level 只有硬件中断能打断 1.获取中断级 DbgPrint("当前执行中断级为 %…

无法防范的网络攻击-DDOS

DDoS攻击(Distributed Denial of Service Attack)是一种网络攻击方式,攻击者通过利用大量的计算机或者网络设备向目标服务器发送大量的请求,使得目标服务器无法正常响应合法用户的请求,从而导致服务不可用或者服务质量…

M302H-YS-Hi3798MV300H/MV310-当贝纯净桌面卡刷固件包

M302H-YS-Hi3798MV300H/MV310-当贝纯净桌面卡刷固件包-内有教程及短接点提示 特点: 1、适用于对应型号的电视盒子刷机; 2、开放原厂固件屏蔽的市场安装和u盘安装apk; 3、修改dns,三网通用; 4、大量精简…

LicheePi4A尝鲜开箱笔记

开发板介绍 LicheePi4A是以 TH1520 主控核心,搭载 4TOPSint8 AI 算力的 NPU,支持双屏 4K 显示输出,支持 4K 摄像头接入,双千兆 POE 网口和多个 USB 接口,音频由 C906 核心处理。 LicheePi4A详细介绍可以在https://wi…