图神经网络:在Cora上动手实现图神经网络

news2025/1/11 14:57:21

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook。超链。提取码8888。

文章目录

    • 代码实操1:GCN的复杂实现
    • 代码实操2:GCN的简单实现
    • 代码实操3:GAT的简单实现

代码实操1:GCN的复杂实现

导入绘图的库,定义绘图函数。

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def visualize(h,color):
    z=TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(z[:,0],z[:,1],s=70,c=color,cmap="Set2")
    plt.show()

目前,我并不知道TSNE降维理论。所以,暂时把它作为一种降维并且可视化的技术。
导入对应的库,导入对应的数据集,导入对应的库。

from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.datasets import Planetoid
dataset=Planetoid(root='/tmp/Planetoid',name='Cora',transform=NormalizeFeatures())
data=dataset[0]
#确定具体的图

Cora数据集简单说明:特征矩阵 N × M N \times M N×M N N N表示为论文数量, M M M表示为特征维度,对于每维,如果单词在论文中,就是1,反之0。邻接矩阵 N × N N \times N N×N N N N表示为论文数量,论文间存在引用,之间就有一条边。
其他说明:这段代码会在C盘,生成一个叫做DATA的文件,并将数据集放在DATA之中,有强迫症注意一下。

import torch.nn.functional as F
from torch.nn import Linear
import torch

搭建一个多层的感知机,训练模型并且得到结果。

class MLP(torch.nn.Module):

    def __init__(self,hidden_channels):
        super().__init__()
        self.lin1=Linear(dataset.num_features,hidden_channels)
        self.lin2=Linear(hidden_channels,dataset.num_classes)

    def forward(self,x):
        x=self.lin1(x)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.lin2(x)
        return x

model=MLP(hidden_channels=16)
print(model)
#输出:
#MLP(
#  (lin1): Linear(in_features=1433, out_features=16, bias=True)
#  (lin2): Linear(in_features=16, out_features=7, bias=True)
#)
model=MLP(hidden_channels=16)
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)

def train():
      model.train()
      optimizer.zero_grad()
      out=model(data.x)
      loss=criterion(out[data.train_mask],data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out=model(data.x)
      pred=out.argmax(dim=1)
      test_correct=pred[data.test_mask]==data.y[data.test_mask]
      test_acc=int(test_correct.sum())/int(data.test_mask.sum())
      return test_acc

for epoch in range(1,201):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
#这里就不展示输出
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#输出:Test Accuracy: 0.5750

导入对应的库,搭建图神经网络GCN

from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
        self.conv1=GCNConv(dataset.num_features,hidden_channels)
        self.conv2=GCNConv(hidden_channels,dataset.num_classes)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x
model=GCN(hidden_channels=16)
print(model)
#输出:
#GCN(
#  (conv1): GCNConv(1433, 16)
#  (conv2): GCNConv(16, 7)
#)

可视化图嵌入(这里只有正向传播)

model=GCN(hidden_channels=16)
model.eval()
out=model(data.x,data.edge_index)
visualize(out,color=data.y)

在这里插入图片描述

进行训练得出结果

model=GCN(hidden_channels=16)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
criterion=torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out=model(data.x, data.edge_index)
      loss=criterion(out[data.train_mask],data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out=model(data.x,data.edge_index)
      pred=out.argmax(dim=1)
      test_correct=pred[data.test_mask]==data.y[data.test_mask]
      test_acc=int(test_correct.sum())/int(data.test_mask.sum())
      return test_acc


for epoch in range(1,101):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
#这里就不展示输出
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#输出:Test Accuracy: 0.8010

可视化图嵌入(训练过后)
在这里插入图片描述

代码实操2:GCN的简单实现

这是PYG官方文档的代码,就以难度而言其实就是少了可视化的东西。构建GCN的框架不同,使用损失函数不同。

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GCNConv(dataset.num_node_features,16)
        self.conv2=GCNConv(16,dataset.num_classes)
    def forward(self,data):
        x,edge_index=data.x,data.edge_index
        x=self.conv1(x,edge_index)
        x=F.relu(x)
        x=F.dropout(x,training=self.training)
        x=self.conv2(x,edge_index)
        return F.log_softmax(x,dim=1)
dataset=Planetoid(root='/DATA/Cora',name='Cora')
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=GCN().to(device)
data=dataset[0].to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out=model(data)
    loss=F.nll_loss(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
model.eval()
pred=model(data).argmax(dim=1)
correct=(pred[data.test_mask]==data.y[data.test_mask]).sum()
acc=int(correct)/int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
#输出:Accuracy: 0.8090

代码实操3:GAT的简单实现

这里操作同上,代码略有不同。

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
import torch.nn.functional as F
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GATConv(dataset.num_node_features,16)
        self.conv2=GATConv(16,dataset.num_classes)
    def forward(self,data):
        x,edge_index=data.x,data.edge_index
        x=F.dropout(x,p=0.6,training=self.training)
        x=self.conv1(x,edge_index)
        x=F.relu(x)
        x=F.dropout(x,p=0.6,training=self.training)
        x=self.conv2(x,edge_index)
        return x
dataset=Planetoid(root='/DATA/Cora',name='Cora')
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu');model=GCN().to(device);data=dataset[0].to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.05,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out=model(data)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
model.eval()
pred=model(data).argmax(dim=1);correct=(pred[data.test_mask]==data.y[data.test_mask]).sum();acc=int(correct)/int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
#输出:Accuracy: 0.7980

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/482478.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++语言练习题位运算

位运算(01)基础 位运算(02)从一个 16 位的单元中取出某几位 题目描述 从一个 16 位的单元中取出某几位(即该几位保留原值,其余位为 0. 使用 value 存放该 16 位的数,n1 为欲取出的起始位,n2 为欲取出的结束位。&#xff…

thinkphp6 JWT报错 ‘“kid“ empty, unable to lookup correct key‘解决办法

文章目录 JWT简介安装问题先前的代码解决办法修改后的完整代码 JWT简介 JWT全称为Json Web Token,是一种用于在网络应用之间传递信息的简洁、安全的方式。JWT标准定义了一种简洁的、自包含的方法用于通信双方之间以JSON对象的形式安全的传递信息。由于它的简洁性、可…

[论文笔记] In Search of an Understandable Consensus Algorithm (Extended Version)

In Search of an Understandable Consensus Algorithm (Extended Version) 寻找可理解的共识算法 (扩展版) [Extended Paper] [Original Paper] ATC’14 (Original) 摘要 Raft 是一个用于管理复制日志的共识算法. Raft 更易于理解, 且为构建实际的系统提供了更好的基础. Raf…

apache hive release notes

hive release notes位置 https://github.com/apache/hive/blob/master/RELEASE_NOTES.txt 如何查看不同版本的release note

计算机是如何工作的

一、冯诺依曼体系: CPU中央处理器(运算器控制器):CPU是计算机最核心的部分,进行算数运算和逻辑判断。CPU最重要的指标是“主频”,如:2.5Ghz,描述了CPU的运算速度,可以近…

【React】redux和React-redux

🎀个人主页:努力学习前端知识的小羊 感谢你们的支持:收藏🎄 点赞🍬 加关注🪐 Redux和React-redux reduxredux的使用Redux的工作流Redux APIstoreactionreducerstore.dispatch()redux的方法使用 React-Redux…

python人工智能【隔空手势控制鼠标】“解放双手“

大家好,我是csdn的博主:lqj_本人 这是我的个人博客主页: lqj_本人的博客_CSDN博客-微信小程序,前端,python领域博主lqj_本人擅长微信小程序,前端,python,等方面的知识https://blog.csdn.net/lbcyllqj?spm1011.2415.3001.5343哔哩哔哩欢迎关注…

【计算机图形学基础教程】MFC上机操作步骤

MFC上机操作步骤 步骤1 在Visual Studio界面,选择文件-新建-项目: 步骤2 在新建项目对话框,选择MFC-MFC应用程序: 步骤3 创建一个带有下列特征的新控制台工程框架,主要内容如下: 基于Win32的单文档…

PMP/高项 05-项目进度管理

项目进度管理 概念 项目进度管理(Schedule Management) 项目进度管理又叫项目工期管理(Duration Management)或项目的时间管理(Time Management) 是一种为管理项目按时完成项目所需的各个过程 进度管理过程 规划进度管理 定义活动 排列活动顺序 估算活…

前端web3入门脚本五:decode input data

一、前言 作为一个前端,在调用合约调试的时候,在区块浏览器里拿到一串 hex 格式的 input data,我们应该怎么decode呢? 二、举例 解码交易需要拥有 对应合约的 abi 以及 input data 下面举例介绍怎么获得这两个信息: 参…

二叉搜索树中的众数

1题目 给你一个含重复值的二叉搜索树(BST)的根节点 root ,找出并返回 BST 中的所有 众数(即,出现频率最高的元素)。 如果树中有不止一个众数,可以按 任意顺序 返回。 假定 BST 满足如下定义&…

存储资源调优技术——智能缓存分区

SmartPratition智能缓存分区 基本概念 本质上就是一种Cache分区技术 通过对系统核心资源的分区(隔离不同业务所需要的缓存资源),保证关键应用的性能 工作原理 用户可以以LUN或文件系统为单位设置SmartPartition分区 每个SmartPartition分区的…

Qt文件系统源码分析—第二篇QSaveFile

范围 深度 首先指定深度分析深度,否者会陷入代码海洋之中。 本文只分析到Win32 API/Windows Com组件/STL库函数层次,再下层代码不做探究 本文主要了解QSaveFile及其具体实现,使用到父类数据的地方只讨论关键点 QT Private类 大部分Qt类有…

基础篇-设计模式

单例模式: 注意:这里的唯一实例不是使用时候才创建,而是构造时候就会创建; 注意:提前创建了对象,并不是调用时候才创建 解决方法: 枚举饿汉单例: 注意: 饿汉式枚举不会通过反序列化破坏单例 懒汉模式&…

SQL笔记(3)——MySQL数据类型

学习MySQL,通常应该是先学习数据类型的,因为不管是开发还是MySQL中,每个数据对象都有其对应的数据类型,MySQL提供了丰富的数据类型,如在创建表的时候就需要指定列的数据类型,在向表中插入数据时&#xff0c…

ElasticSearch(一)下载及安装(windows)

1. 官网 ElasticSearch官网地址ElasticSearch生态组件下载地址Kibana下载地址ik中文分词插件 备注:网址打不开,或者打开速度慢是正常情况。 2. 解压后目录结构 bin :脚本文件,包括启动elasticsearch,安装插件&#…

目录打开显示提示文件或目录损坏且无法读取、文件或目录损坏且无法读取的破解之道

咱们在平日工作时,通常都会将资料放进不同的目录中,方便咱们找到,随着时间的推移就会产生有越来越多目录。最近有位用户了这样一个问题,就是目录无论怎么都无法打开,这样就无法浏览、使用里面的资料了,影响…

springboot sharding-jdbc 主从 读写分离

目录 1 mysql 主从搭建 1.1 docker mysql 主从搭建 1.2 非docker mysql 主从搭建 2 springboot sharding-jdbc 主从 读写分离 2.1 pom 加依赖 2.1 yml 配置文件 3 测试 -> 直接使用 就是读写分离 3.1 实体类User -> 数据字段 对象字典 3.2 Mapper -> 增删改查…

Nomogram | 盘点一下绘制列线图的几个R包!~(二)

1写在前面 不知道各位小伙伴的五一假期过的在怎么样,可怜的我感冒了。😷 今天继续之前没有写完的列线图教程吧,再介绍几个制作列线图的R包。🤠 2用到的包 rm(list ls())library(tidyverse)library(survival)library(rms)library(…

新闻文本关键词提取有哪些算法,这些算法的特点以及应用,以及不足方面的解决办法

目录 一、新闻文本关键词提取算法 1. TF-IDF(Term Frequency-Inverse Document Frequency)算法 2. TextRank算法 3. 词向量算法 4. 深度学习算法 5. 主题模型算法 二、这些算法的不足方面的解决办法 1. TF-IDF算法: 2. TextRank算法&…