图神经网络:(节点分类)在Cora数据集上动手实现图神经网络

news2025/1/18 19:14:35

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook。超链。提取码8888。

文章目录

    • 代码实操1:GCN的复杂实现
    • 代码实操2:GCN的简单实现
    • 代码实操3:GAT的简单实现

代码实操1:GCN的复杂实现

导入绘图的库,定义绘图函数。

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def visualize(h,color):
    z=TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(z[:,0],z[:,1],s=70,c=color,cmap="Set2")
    plt.show()

目前,我并不知道TSNE降维理论。所以,暂时把它作为一种降维并且可视化的技术。
导入对应的库,导入对应的数据集,导入对应的库。

from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.datasets import Planetoid
dataset=Planetoid(root='/DATA/Planetoid',name='Cora',transform=NormalizeFeatures())
data=dataset[0]
#确定具体的图

Cora数据集简单说明:特征矩阵 N × M N \times M N×M N N N表示为论文数量, M M M表示为特征维度,对于每维,如果单词在论文中,就是1,反之0。邻接矩阵 N × N N \times N N×N N N N表示为论文数量,论文间存在引用,之间就有一条边。
其他说明:这段代码会在C盘,生成一个叫做DATA的文件,并将数据集放在DATA之中,有强迫症注意一下。

import torch.nn.functional as F
from torch.nn import Linear
import torch

搭建一个多层的感知机,训练模型并且得到结果。

class MLP(torch.nn.Module):

    def __init__(self,hidden_channels):
        super().__init__()
        self.lin1=Linear(dataset.num_features,hidden_channels)
        self.lin2=Linear(hidden_channels,dataset.num_classes)

    def forward(self,x):
        x=self.lin1(x)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.lin2(x)
        return x

model=MLP(hidden_channels=16)
print(model)
#输出:
#MLP(
#  (lin1): Linear(in_features=1433, out_features=16, bias=True)
#  (lin2): Linear(in_features=16, out_features=7, bias=True)
#)
model=MLP(hidden_channels=16)
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)

def train():
      model.train()
      optimizer.zero_grad()
      out=model(data.x)
      loss=criterion(out[data.train_mask],data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out=model(data.x)
      pred=out.argmax(dim=1)
      test_correct=pred[data.test_mask]==data.y[data.test_mask]
      test_acc=int(test_correct.sum())/int(data.test_mask.sum())
      return test_acc

for epoch in range(1,201):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
#这里就不展示输出
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#输出:Test Accuracy: 0.5750

导入对应的库,搭建图神经网络GCN

from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
        self.conv1=GCNConv(dataset.num_features,hidden_channels)
        self.conv2=GCNConv(hidden_channels,dataset.num_classes)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x
model=GCN(hidden_channels=16)
print(model)
#输出:
#GCN(
#  (conv1): GCNConv(1433, 16)
#  (conv2): GCNConv(16, 7)
#)

可视化图嵌入(这里只有正向传播)

model=GCN(hidden_channels=16)
model.eval()
out=model(data.x,data.edge_index)
visualize(out,color=data.y)

在这里插入图片描述

进行训练得出结果

model=GCN(hidden_channels=16)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
criterion=torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out=model(data.x, data.edge_index)
      loss=criterion(out[data.train_mask],data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out=model(data.x,data.edge_index)
      pred=out.argmax(dim=1)
      test_correct=pred[data.test_mask]==data.y[data.test_mask]
      test_acc=int(test_correct.sum())/int(data.test_mask.sum())
      return test_acc


for epoch in range(1,101):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
#这里就不展示输出
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#输出:Test Accuracy: 0.8010

可视化图嵌入(训练过后)
在这里插入图片描述

代码实操2:GCN的简单实现

这是PYG官方文档的代码,就以难度而言其实就是少了可视化的东西。构建GCN的框架不同,使用损失函数不同。

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GCNConv(dataset.num_node_features,16)
        self.conv2=GCNConv(16,dataset.num_classes)
    def forward(self,data):
        x,edge_index=data.x,data.edge_index
        x=self.conv1(x,edge_index)
        x=F.relu(x)
        x=F.dropout(x,training=self.training)
        x=self.conv2(x,edge_index)
        return F.log_softmax(x,dim=1)
dataset=Planetoid(root='/DATA/Cora',name='Cora')
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=GCN().to(device)
data=dataset[0].to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out=model(data)
    loss=F.nll_loss(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
model.eval()
pred=model(data).argmax(dim=1)
correct=(pred[data.test_mask]==data.y[data.test_mask]).sum()
acc=int(correct)/int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
#输出:Accuracy: 0.8090

代码实操3:GAT的简单实现

这里操作同上,代码略有不同。

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
import torch.nn.functional as F
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GATConv(dataset.num_node_features,16)
        self.conv2=GATConv(16,dataset.num_classes)
    def forward(self,data):
        x,edge_index=data.x,data.edge_index
        x=F.dropout(x,p=0.6,training=self.training)
        x=self.conv1(x,edge_index)
        x=F.relu(x)
        x=F.dropout(x,p=0.6,training=self.training)
        x=self.conv2(x,edge_index)
        return x
dataset=Planetoid(root='/DATA/Cora',name='Cora')
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu');model=GCN().to(device);data=dataset[0].to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.05,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out=model(data)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
model.eval()
pred=model(data).argmax(dim=1);correct=(pred[data.test_mask]==data.y[data.test_mask]).sum();acc=int(correct)/int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
#输出:Accuracy: 0.7980

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/527947.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从零开始Vue3+Element Plus的后台管理系统(二)——Layout页面布局的实现

项目搭建好之后,开始写基本的布局。后台管理系统的布局3大元素:头部、侧栏、主要内容,各种布局结构相差不大,我选择了下图所示的布局,其中头部、侧栏、页签在页面中是固定的,只有主要内容容器会跟随页面滚动…

如何从计算机或 SD 卡中恢复已删除的音乐文件?

与我们中的许多人一样,您可能已经从喜爱的专辑中下载并保存了多个音乐文件以供离线收听,但如果您不小心或意外删除了这些音乐文件怎么办?不用担心,我们在这里列出了几种从计算机或 SD 卡中恢复已删除或丢失的音乐文件的方法。 您…

001+limou+Git的安装与入门

0.前言 您好,这里是limou3434的一篇个人博文,感兴趣的话您也可以看看我的其他文章。本系列主要深入讲解有关Git的基础知识和基础使用,在文章中会结合部分Git网站上推荐的电子书《Pro Git》来对Git进行解读,意在补充书中对您“不友…

Java 面试 | RabbitMQ(2023版)

文章目录 rabbitmq1、为什么要使用rabbitmq2、rabbitmq如何确保消息发送?消息接收?3、RabbitMQ的构造4、Exchange交换器的类型5、RabbitMQ的持久化6、RabbitMQ消息发送和接收过程7、如何保证消息队列的高可用8、如何处理消息丢失的情况9、如何保证消息没有重复消费10、如何保…

Shell系统编程三剑客之----sed编辑器

目录 一:sed编辑器 1.sed编辑器概述 2.sed的工作流程 3.sed的命令格式 4.常用选项 5.常用操作 二:sed操作事例 1.查询 (1)打印内容 ​(2)打印行数 ​(3)打印特殊字符、ASCII码 &…

python爬虫简述

Python爬虫是一种自动化获取互联网数据的技术,它可以通过编写程序自动访问网站并抓取所需的数据。在本文中,我们将介绍Python爬虫的基础知识、常用库和实际应用。 一、Python爬虫的基础知识 爬虫的定义 爬虫是一种自动化获取互联网数据的技术&#xf…

屏幕录像怎么录?分享3个简单实用的方法!

案例:怎么录制电脑屏幕? 【对于我这种不太熟悉电脑的人来说,想要录制电脑屏幕十分困难。听说录制电脑屏幕,需要用到录屏工具。有没有小伙伴有好的录屏软件介绍,顺便附带一下教程!求!】 屏幕录…

【冶金轧钢、电厂 JL-8B/E集成电路电流继电器 CMOS运算 JOSEF约瑟】

JL-8B/E集成电路电流继电器名称:集成电路电流继电器型号:JL-8B/E触点容量250V5A功率消耗<5W返回系数过电流:0.90.97;欠电流:1.051.15整定范围0.03~60A 系列型号: JL-8A/E集成电路电流继电器; JL-8B/E集成电路电流继电器; JL-8A/E11-004集成电…

[离散数学]命题逻辑与推理

目录 主析取范式 主合取范式推理理论(假设前提条件为真推出的结论)真值表法直接证明法** 常用推理公式 ** 间接证明 CP规则--附加前提证明法,证明比较方便 单条件形式,提取前件间接法 归谬法 结论是单命题,取反前提引入 常用 latex 定义 主析…

Druid未授权漏洞进一步的利用

一、漏洞描述 Druid是阿里巴巴数据库出品的为监控而生的数据库连接池。并且Druid提供的监控功能包括监控SQL的执行时间、监控Web URI的请求、Session监控等。Druid本身是不存在什么漏洞的,但当开发者配置不当时就可能造成未授权访问。本文除了介绍Druid未授权漏洞之…

js 使用正则获取 html中 所有span标签

let html <p>艾迪莎测试但大家还是</p><h1>你好啊</h1><p>啊是多久啊合适的<span style"text-decoration: underline;">静安寺</span>大家哈圣诞节<span style"text-decoration: underline;">哈桑</s…

企业需要专业电子邮件地址的4大原因

专业的企业电子邮件地址具有贵公司的自定义域名&#xff0c;而不是通用的Zoho Mail 、gmail或yahoo帐户&#xff0c;例如&#xff1a;john stargardening.com 大多数初学者使用不带域名的通用免费企业电子邮件帐户&#xff0c;这不是很专业。例如&#xff1a;zhangsan2022zoho.…

从零开始Vue3+Element Plus后台管理系统(六)——状态管理Pinia和持久化

Pinia 官网&#xff1a;https://pinia.vuejs.org/zh/ Pinia 是 Vue 的专属状态管理库&#xff0c;相比Vuex更好用&#xff0c;优点不多了说官网有&#xff0c;用起来最重要&#xff01; 在应用的根部注入创建的 pinia // main.ts import { createApp } from vue import { c…

CLMP证书:让你在职场中脱颖而出的秘密武器!

CLMP证书是一种精益管理专业证书&#xff0c;是针对精益管理领域的专业人士和学生的培训项目&#xff0c;旨在提高他们在精益管理方面的技能和知识。那么&#xff0c;CLMP证书的含金量高吗&#xff1f;接下来我们来探讨一下。 CLMP证书的优势体现 首先&#xff0c;CLMP证书的…

Android Jetpack Compose之使用脚手架快速搭建APP布局结构

概述 现在市场上大多数的手机APP的通用布局结构都是顶部有个顶部导航栏&#xff0c;底部有个底部导航栏&#xff0c;例如抖音的布局结构&#xff1a; 点击导航栏里面的各个项又可以跳转到相应的页面&#xff0c;现在这种结构特别流行&#xff0c;如果我们使用传统的View来实现…

Web自动化测试-如何进行Selenium页面数据及元素交互?教你一步不漏。

目录 前言&#xff1a; 一、Selenium简介 二、安装Selenium 1.Windows用户安装Selenium 2.安装Chrome浏览器驱动 三、使用Selenium进行页面数据及元素交互 1.启动浏览器 2.访问网页 3.查找元素 4.输入文本 5.点击按钮 6.提交表单 四、完整代码示例 五、总结 Web自…

(MIT6.045)自动机、可计算性和复杂性-DFA和NFA

毕业论文写完了。找点事干干。 佛系更新。 这是一门讲述 什么是计算&#xff1f;什么能被计算&#xff1f;怎么高效计算&#xff1f; 的哲学、数学和工程问题的课程。 主要包括&#xff1a; 有限状态机&#xff08;Finite Avtomata&#xff09;&#xff1a;简单的模型。 可…

【OpenCV-Python】——机器学习kNN算法SVM算法k均值聚类算法深度学习图像识别对象检测

目录 前言&#xff1a; 1、机器学习 1.1 kNN算法 1.2 SVM算法&#xff08;支持向量机&#xff09; 1.3 k均值聚类算法 2、深度学习 2.1 基于深度学习的图像识别 2.2 基于深度学习的对象检测 总结&#xff1a; 前言&#xff1a; 机器学习&#xff08;ML&#xff09;是人…

Linux-权限

1. 认识Linux下用户的分类 root普通用户1.1用户切换 普通用户 转 root su //当前路径切换rootsu - //重新登陆到/root退出 crtl d / exit root 转 普通用户 不需要输入密码 su 用户名退出 ctrl d 1. 2 指令暂时提权 sudo command目前我们用adduser新建的用户&#xff0…

100个软件开发领域必须掌握的关键词,掌握一个都难啊

需要完整xmind文件&#xff0c;私信获取 100个软件开发领域必须掌握的关键词 基础编程语言 JavaPythonC#JavaScriptPHPRubyCObjective-CSwiftKotlin Web 开发 HTMLCSSJavaScriptReactAngularVue.jsjQueryBootstrapNode.jsExpress.js 移动应用开发 AndroidiOSFlutterRea…