图卷积神经网络分类的pytorch实现

news2024/10/5 11:54:21

  图神经网络(GNN)目前的主流实现方式就是节点之间的信息汇聚,也就是类似于卷积网络的邻域加权和,比如图卷积网络(GCN)、图注意力网络(GAT)等。下面根据GCN的实现原理使用Pytorch张量,和调用torch_geometric包,分别对Cora数据集进行节点分类实验。

  Cora是关于科学文献之间引用关系的图结构数据集。数据集包含一个图,图中包括2708篇文献(节点)和10556个引用关系(边)。其中每个节点都有一个1433维的特征向量,即文献内容的嵌入向量。文献被分为七个类别:计算机科学、物理学等。

GCN计算流程

  对于某个GCN层,假设输入图的节点特征为$X\in R^{|V|\times F_{in}}$,边索引表示为序号数组$Ei\in R^{2\times |E|}$,GCN层输出$Y\in R^{|V|\times F_{out}}$。计算流程如下:

  0、根据$Ei$获得邻接矩阵$A_0\in R^{|V|\times |V|}$。

  1、为了将节点自身信息汇聚进去,每个节点添加指向自己的边,即 $A=A_0+I$,其中$I$为单位矩阵。

  2、计算度(出或入)矩阵 $D$,其中 $D_{ii}=\sum_j A_{ij}$ 表示第 $i$ 个节点的度数。$D$为对角阵。

  3、计算对称归一化矩阵 $\hat{D}$,其中 $\hat{D}_{ii}=1/\sqrt{D_{ii}}$。

  4、构建对称归一化邻接矩阵 $\tilde{A}$,其中 $\tilde{A}= \hat{D} A \hat{D}$。

  5、计算节点特征向量的线性变换,即 $Y = \tilde{A} X W$,其中 $X$ 表示输入的节点特征向量,$W\in R^{F_{in}\times F_{out}}$ 为GCN层中待训练的权重矩阵。

  即:

$Y=D^{-0.5}(A_0+I)D^{-0.5}XW$

  在torch_geometric包中,normalize参数控制是否使用度矩阵$D$归一化;cached控制是否缓存$D$,如果每次输入都是相同结构的图,则可以设置为True,即所谓转导学习(transductive learning)。另外,可以看到GCN的实现只考虑了节点的特征,没有考虑边的特征,仅仅通过聚合引入边的连接信息。

GCN实验

调包实现

  Cora的图数据存放在torch_geometric的Data类中。Data主要包含节点特征$X\in R^{|V|\times F_v}$、边索引$Ei\in R^{2\times |E|}$、边特征$Ea\in R^{|E|\times F_e}$等变量。首先导出Cora数据:

from torch_geometric.datasets import Planetoid

cora = Planetoid(root='./data', name='Cora')[0]
print(cora)

  构建GCN,训练并测试。

import torch
from torch import nn
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch.optim import Adam


class GCN(nn.Module):
  def __init__(self, in_channels, hidden_channels, class_n):
    super(GCN, self).__init__()
    self.conv1 = GCNConv(in_channels, hidden_channels)
    self.conv2 = GCNConv(hidden_channels, class_n)

  def forward(self, x, edge_index):
    x = torch.relu(self.conv1(x, edge_index))
    x = torch.dropout(x, p=0.5, train=self.training)
    x = self.conv2(x, edge_index)
    return torch.log_softmax(x, dim=1)

model = GCN(cora.num_features, 16, cora.y.unique().shape[0]).to('cuda')
opt = Adam(model.parameters(), 0.01, weight_decay=5e-4)

def train(its):
  model.train()
  for i in range(its):
    y = model(cora.x, cora.edge_index)
    loss = F.nll_loss(y[cora.train_mask], cora.y[cora.train_mask])
    loss.backward()
    opt.step()
    opt.zero_grad()

def test():
  model.eval()
  y = model(cora.x, cora.edge_index)
  right_n = torch.argmax(y[cora.test_mask], 1) == cora.y[cora.test_mask]
  acc = right_n.sum()/cora.test_mask.sum()
  print("Acc: ", acc)

for i in range(15):
  train(1)
  test()

  仅15次迭代就收敛,测试精度如下:

张量实现

  主要区别就是自定义一个My_GCNConv来代替GCNConv,My_GCNConv定义如下:

from torch import nn
from torch_geometric.utils import to_dense_adj

class My_GCNConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(My_GCNConv, self).__init__()
    self.weight = torch.nn.Parameter(nn.init.xavier_normal(torch.zeros(in_channels, out_channels)))
    self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
  
  def forward(self, x, edge_index):
    adj = to_dense_adj(edge_index)[0]
    adj += torch.eye(x.shape[0]).to(adj)
    dgr = torch.diag(adj.sum(1)**-0.5)
    y = torch.matmul(dgr, adj)
    y = torch.matmul(y, dgr)
    y = torch.matmul(y, x)
    y = torch.matmul(y, self.weight) + self.bias
    return y

  其它代码仅将GCNConv修改为My_GCNConv。

对比实验

MLP实现

  下面不使用节点之间的引用关系,仅使用节点特征向量在MLP中进行实验,来验证GCN的有效性。

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam

class MLP(nn.Module):
  def __init__(self, in_channels, hidden_channels, class_n):
    super(MLP, self).__init__()
    self.l1 = nn.Linear(in_channels, hidden_channels)
    self.l2 = nn.Linear(hidden_channels, hidden_channels)
    self.l3 = nn.Linear(hidden_channels, class_n)

  def forward(self, x):
    x = torch.relu(self.l1(x))
    x = torch.relu(self.l2(x))
    x = torch.dropout(x, p=0.5, train=self.training)
    x = self.l3(x)
    return torch.log_softmax(x, dim=1)

model = MLP(cora.num_features, 512, cora.y.unique().shape[0]).to('cuda')
opt = Adam(model.parameters(), 0.01, weight_decay=5e-4)

def train(its):
  model.train()
  for i in range(its):
    y = model(cora.x[cora.train_mask])
    loss = F.nll_loss(y, cora.y[cora.train_mask])
    loss.backward()
    opt.step()
    opt.zero_grad()

def test():
  model.eval()
  y = model(cora.x[cora.test_mask])
  right_n = torch.argmax(y, 1) == cora.y[cora.test_mask]
  acc = right_n.sum()/cora.test_mask.sum()
  print("Acc: ", acc)

for i in range(15):
  train(30)
  test()

  可以看出MLP包含了3层,并且隐层参数比GCN多得多。结果如下:

  精度收敛在57%左右,效果比GCN的79%差。说明节点之间的链接关系对节点类别的划分有促进作用,以及GCN的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/360528.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java函数式接口

3 函数式接口 3.1 函数式接口概述 函数式接口:有且仅有一个抽象方法的接口 Java中的函数式编程体现就是Lambda表达式,所以函数式接口就是可以适用于Lambda使用的接口只有确保接口中有且仅有一个抽象方法, Java中的Lambda才能顺利地进行推导…

不容错过!飞桨深度学习与大模型产业应用专场24日等你来!

人工智能教父Hinton曾评价,“深度学习将无所不能”,从聊天机器人、自动驾驶到语音助手,深度学习早已在不知不觉中渗透进我们的生活,而AI大模型又是一项深度学习技术的新突破。深度学习、大模型作为人工智能发展的重要方向&#xf…

前端开发项目规范写法介绍

1. 基本原则 结构、样式、行为分离 尽量确保文档和模板只包含 HTML 结构,样式都放到样式表里,行为都放到脚本里。 缩进 统一两个空格缩进(总之缩进统一即可),不要使用 Tab 或者 Tab、空格混搭。 文件编码 使用不带 BOM 的 UTF-8 编码。 在 HTML中指定编码 <meta c…

C# 利用FluentFTP实现FTP上传下载功能

FTP作为日常工作学习中&#xff0c;非常重要的一个文件传输存储空间&#xff0c;想必大家都非常的熟悉了&#xff0c;那么如何快速的实现文件的上传下载功能呢&#xff0c;本文以一个简单的小例子&#xff0c;简述如何通过FluentFTP实现文件的上传和下载功能。仅供学习分享使用…

c++提高篇——queque容器

一、queque容器基本概念 Queue是一种先进先出(FIFO)的教据结构&#xff0c;它有两个出口 队列容器允许从一端新增元素&#xff0c;从另一端移除元素。队列中只有队头和队尾才可以被外界使用&#xff0c;因此队列不允许有遍历行为队列中进数据。 queque容器可以形象化为生活中…

第一个Java程序(初识Java)

个人主页&#xff1a;平行线也会相交 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 平行线也会相交 原创 收录于专栏【JavaSE_primary】 文章目录1.Java概述1.1什么是Java1.2Java之父2.0第一个Java程序编译运行.class3.0程序如何跑起来的&#xff1f;3.1J…

39、基于51单片机声控光控灯人体感应路灯照明灯系统设计

摘 要 随着社会的不断进步&#xff0c;人们对低碳生活逐步认识和接受&#xff0c;并从很多方面开始关注&#xff0c;尤其是在环保上做出了很多努力。利用声音和光线的强弱来控制开关的断开或者闭合的电子产品来能够有效的降低能耗&#xff0c;节约能源。它不仅适用于住宅区的…

CMake构建静态库与动态库以及使用

CMake构建静态库与动态库一、任务二、准备工作三、编译共享库四、ADD_LIBRARY指令五、编译静态库5.1、SET_TARGET_PROPERTIES指令5.2、GET_TARGET_PROPERTY指令六、动态库版本号七、安装共享库和头文件八、使用外部共享库和头文件8.1、准备工作8.2、引入头文件搜索路径8.3、为 …

leetcode 11~20 学习经历

LeetCode 习题 11 - 2011. 盛最多水的容器12. 整数转罗马数字13. 罗马数字转整数14. 最长公共前缀15. 三数之和16. 最接近的三数之和17. 电话号码的字母组合18. 四数之和19. 删除链表的倒数第 N 个结点20. 有效的括号小结11. 盛最多水的容器 给定一个长度为 n 的整数数组 heigh…

【Servlet篇】Request请求转发详细解读

文章目录1. 前言2. 实战案例3. 特点1. 前言 请求转发是一种在服务器内部的资源跳转方式&#xff0c;如图&#xff1a; 上图的大致过程为&#xff0c;浏览器发送请求给服务器&#xff0c;服务器中 a 资源接收到请求&#xff0c;资源 a 处理完请求后将请求发送给资源 b&#xff…

cdr最新2023版本发布更新及CorelDraw功能介绍

CDR作为一款专业的平面设计软件&#xff0c;拥有着庞大的用户群体&#xff0c;而每年春天CorelDRAW新版本的发布也牵动着每一位小伙伴的心。CorelDraw2023近期刚刚发布本人就开始试用&#xff0c;感觉非常良好&#xff0c;特别给大家提出升级的N个理由!CorelDRAW2023最新版内置…

【基础算法】差分的应用(一维差分和二维差分)

&#x1f339;作者:云小逸 &#x1f4dd;个人主页:云小逸的主页 &#x1f4dd;Github:云小逸的Github &#x1f91f;motto:要敢于一个人默默的面对自己&#xff0c;强大自己才是核心。不要等到什么都没有了&#xff0c;才下定决心去做。种一颗树&#xff0c;最好的时间是十年前…

FPGA 20个例程篇:20.USB2.0/RS232/LAN控制并行DAC输出任意频率正弦波、梯形波、三角波、方波(一)

在最后一个例程中笔者精挑细选了一个较为综合性的项目实战&#xff0c;其中覆盖了很多知识点&#xff0c;也是从一个转产产品中所提炼出来的&#xff0c;所以非常贴近实战项目。 整个工程实现了用户通过对上位机PC端人机界面的操作&#xff0c;即可达到控制豌豆开发并行DAC输出…

Java---高级流

目录 一、转换流 &#xff08;1&#xff09;指定的字符集读写数据 二、序列化流和反序列化流 三、解压缩流和压缩流 &#xff08;1&#xff09;解压缩流 &#xff08;2&#xff09;压缩流 一&#xff1a;压缩文件 二&#xff1a;压缩文件夹 注&#xff1a;本文并未介绍J…

CVE-2023-24055 KeePass信息明文传输漏洞复现

前言 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 如果文章中的漏洞出现敏感内容产生了部分影响&#xff0c;请及时联系作者&#xff0c;望谅解。 一、漏洞描述 漏洞简述 Kee…

STM32 SystemInit()函数学习总结

拿到程序后如何看系统时钟&#xff1f;User文件夹——system_stm32f4xx程序&#xff0c;先找systemcoreclock(系统时钟&#xff09;但是这里这么多个系统时钟应该如何选择?点击魔法棒&#xff0c;然后点击C/C可以看到define的是F40_41XXX.USE这一款 &#xff0c;对应着就找出了…

R语言、MaxEnt模型融合技术的物种分布模拟、参数优化方法、结果分析制图与论文写作

基于R语言、MaxEnt模型融合技术的物种分布模拟、参数优化方法、结果分析制图与论文写作技术应用第一章、理论篇以问题导入的方式&#xff0c;深入掌握原理基础什么是MaxEnt模型&#xff1f;MaxEnt模型的原理是什么&#xff1f;有哪些用途&#xff1f;MaxEnt运行需要哪些输入文件…

对云原生集群网络流量可观测性的一点思考

问题背景 在云原生技术的广泛普及和实施过程中&#xff0c;笔者接触到的很多用户需求里都涉及到对云原生集群的可观测性要求。 实现集群的可观测性&#xff0c;是进行集群安全防护的前提条件 。而在可观测性的需求中&#xff0c;集群中容器和容器之间网络流量的可观测性需求是…

别错过!4C首发直播,上届全国总冠军带你入门赛题

和志同道合的伙伴并肩作战&#xff0c;用指尖敲出奇思妙想&#xff0c;飞桨黑客马拉松PaddlePadddle Hackathon第四期全新升级&#xff0c;开放报名啦&#xff01; 玩技术&#xff0c;秀操作&#xff01;这是一场高手云集的开发者盛会。四大赛道&#xff1a;核心框架开源贡献&…

Python每日一练(20230221)

目录 1. 不同路径 II 2. 字符串转换整数 (atoi) 3. 字符串相乘 1. 不同路径 II 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为“Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中…