图神经网络教程之GCN(pyG)

news2024/11/15 23:38:34

图神经网络-pyG版本的GCN

Data(数据)

data.xdata.edge_indexdata.edge_attrdata.ydata.pos

  • 举个例子
    在这里插入图片描述
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
#代表0-1 1-0 和 1-2 2-1 ,因为是无向图,所以有双向边
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
# 代表每个节点
data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])
# 数据构成

其中edge_index也可以这么构建

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
  • 一些实用函数
print(data.keys())
>>> ['x', 'edge_index']
print(data['x'])
>>> tensor([[-1.0],
            [0.0],
            [1.0]])
for key, item in data:
    print(f'{key} found in data')
>>> x found in data
>>> edge_index found in data
'edge_attr' in data
>>> False
data.num_nodes
>>> 3
data.num_edges
>>> 4
data.num_node_features
>>> 1
data.has_isolated_nodes()
>>> False
data.has_self_loops()
>>> False
data.is_directed()
>>> False
# Transfer data object to GPU.
device = torch.device('cuda')
data = data.to(device)
  • 包含一些数据集
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
  • 数据转换

转换是torchvision中转换图像和执行增强的常见方式,pyG带有自己的转换。

#对ShapeNet数据集的转换。
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])
dataset[0]
>>> Data(pos=[2518, 3], y=[2518])

通过转换从点云生成最近邻图,将点云数据集转换为图数据集

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))
dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
  • 图表上的表示学习
  1. 导入所需的库和模块:

    • torch:PyTorch的主要库。
    • torch.nn.functional as F:PyTorch的神经网络函数模块,用于定义神经网络的层和操作。
    • torch_geometric.nn:PyTorch Geometric库中的神经网络模块,包括图卷积网络(GCN)的实现。
    • torch_geometric.datasets:PyTorch Geometric中的数据集模块,用于加载图数据集。
  2. 加载Cora数据集:

    dataset = Planetoid(root='/tmp/Cora', name='Cora')
    

    这行代码加载了Cora数据集,这是一个用于节点分类的图数据集。数据集将被下载到/tmp/Cora目录中。

  3. 定义了一个名为GCN的神经网络类:

    class GCN(torch.nn.Module):
    

    这个类继承自PyTorch的torch.nn.Module基类,表示它是一个神经网络模型。

  4. GCN类的构造函数中,定义了两个图卷积层(GCNConv):

    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
    
    • GCNConv层是图卷积层,用于从图数据中提取特征。
    • self.conv1是第一个GCNConv层,它将输入特征的维度设置为dataset.num_node_features(Cora数据集中节点的特征维度)并输出16维特征。
    • self.conv2是第二个GCNConv层,将16维特征映射到数据集的类别数。
  5. 检查并设置GPU或CPU设备:

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    

    这段代码会检查你的系统是否有可用的GPU,并将device设置为GPU或CPU,以便在相应的设备上运行模型。

  6. 创建并将模型和数据移动到所选设备上:

    model = GCN().to(device)
    data = dataset[0].to(device)
    

    这将实例化之前定义的GCN模型,并将模型的参数和计算移动到GPU或CPU上。

  7. 定义优化器(这里使用Adam优化器):

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    

    这行代码创建一个Adam优化器,并将模型的参数传递给它,用于模型参数的更新。lr是学习率,weight_decay是L2正则化项的权重。

  8. 将模型设置为训练模式:

    model.train()
    

    这行代码将模型切换到训练模式,这对于启用训练特定的层(例如,dropout)非常重要。

  9. 开始训练循环,训练模型200个epoch:

    for epoch in range(200):
    

    这是一个训练循环,将模型训练200次。

  10. 在每个epoch中,首先将优化器的梯度清零:

    optimizer.zero_grad()
    

    这行代码用于清除之前的梯度信息,以准备计算新的梯度。

  11. 通过模型前向传播计算预测结果:

    out = model(data)
    

    这会将数据传递给你的GCN模型,然后返回模型的预测结果。

  12. 计算损失函数,这里使用负对数似然损失(Negative Log-Likelihood Loss):

    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    

    这行代码计算了在训练节点子集上的负对数似然损失。data.train_mask指定了用于训练的节点子集,data.y是节点的真实标签。

  13. 反向传播和参数更新:

    loss.backward()
    optimizer.step()
    

    这两行代码用于计算梯度并执行梯度下降,更新模型的参数,以最小化损失函数。

  14. 将模型设置为评估模式:

    model.eval()
    

    这行代码将模型切换到评估模式,以便在测试数据上进行预测。

  15. 在测试集上进行预测:

    pred = model(data).argmax(dim=1)
    

    这行代码用于在测试数据上进行预测,并找到每个节点最可能的类别。

  16. 计算模型的准确性:

    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    print(f'Accuracy: {acc:.4f}')
    

    这段代码计算了模型在测试集上的准确性,并打印出来。

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# 加载 Cora 数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')

# 定义 GCN 模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

# 检查并设置 GPU 或 CPU 设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 创建并将模型和数据移动到所选设备上
model = GCN().to(device)
data = dataset[0].to(device)

# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 将模型设置为训练模式
model.train()

# 训练模型
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# 将模型设置为评估模式
model.eval()

# 在测试集上进行预测
pred = model(data).argmax(dim=1)

# 计算模型的准确性
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
   optimizer.step()

# 将模型设置为评估模式
model.eval()

# 在测试集上进行预测
pred = model(data).argmax(dim=1)

# 计算模型的准确性
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/963868.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

博客系统 Java Web 开发(Servlet)

目录 一、准备工作 二、设计数据库 三、编写数据库代码 1、建表sql 2、封装数据库的连接操作 3、创建实体类 4、封装数据库的一些增删改查 (1)BlogDao 新增博客: 根据博客 id 来查询指定博客(用于博客详情页&#xff0…

【配置环境】Visual Studio 配置 OpenCV

目录 一,环境 二,下载和配置 OpenCV 三,创建一个 Visual Studio 项目 四,配置 Visual Studio 项目 五,编写并编译 OpenCV 程序 一,环境 Windows 11 家庭中文版Microsoft Visual Studio Community 2022…

【每日运维】RockyLinux8.6升级OpenSSH9.4p1

为什么需要升级openssh呢,因为很多项目进行漏扫结果都会涉及到这个服务器核心组件,一想到以前升级openssh带来的各种依赖性问题就头疼,不管是什么发行版,升级这个东西真的很烦,这次发现可能还会有好一点的通用一点的升…

Docker最简单的来部署前端vue打包好的h5代码

Docker最简单的来部署前端vue打包好的h5代码 前言 是不是想在服务器上部署好几个前端页面,并且也不想让各个页面之间进行隔离,还有就是想要一键部署,实时更新到服务区上,那这篇文章可能帮到您 这里也得选择一个软件叫Idea&#x…

Web3的新商业综合体——SMT震撼来袭!

SMT元宇宙应用生态平台,致力于打造一个Web3.0的新商业综合体。作为一个基础公链系统,SMT各项性能能够完全满足现在当下的各种应用,以及它们的部署。 用区块链技术和新的商业模式体现P2E并实现一个共建共享的理念,重塑大众生活的衣…

Python Qt学习(九)MainWindow

源代码: # -*- coding: utf-8 -*-# Form implementation generated from reading ui file qt_mainwindow.ui # # Created by: PyQt5 UI code generator 5.15.9 # # WARNING: Any manual changes made to this file will be lost when pyuic5 is # run again. Do n…

python面试题合集(一)

python技术面试题 1、Python中的幂运算 在python中幂运算是由两个 **星号运算的,实例如下: >>> a 2 ** 2 >>> a 4我们可以看到2的平方输出结果为4。 那么 ^指的是什么呢?我们用代码进行演示: >>>…

音频——I2S 右对齐模式(四)

I2S 基本概念飞利浦(I2S)标准模式左(MSB)对齐标准模式右(LSB)对齐标准模式DSP 模式TDM 模式 文章目录 I2S right时序图逻辑分析仪抓包 I2S right I2S 右对齐标准 也叫日本格式,sony 格式。相比于标准左对齐格式,标准右对齐的不足在于接收设备必须事先知…

奇舞周刊第 504 期:谷歌浏览器 Chrome 117 Beta 又上新功能,爱了爱了!

记得点击文章末尾的“ 阅读原文 ”查看哟~ 下面先一起看下本期周刊 摘要 吧~ 奇舞推荐 ■ ■ ■ 谷歌浏览器 Chrome 117 Beta 又上新功能,爱了爱了! Chrome 117 Beta 版本新增了 CSS 网格子网格 (subgrid)、入场和出场动画支持,以及 CSS、数组…

TiDB 一栈式综合交易查询解决方案获“金鼎奖”优秀金融科技解决方案奖

日前,2023“金鼎奖”评选结果揭晓, 平凯星辰(北京)科技有限公司研发的 TiDB 一栈式综合交易查询解决方案获“金鼎奖”优秀金融科技解决方案奖 , 该方案已成功运用于 多家国有大行、城商行和头部保险企业 。 此次获奖再…

企业名片如何制作二维码?一招教你在线制作二维码名片

想要制作企业二维码名片时要怎么操作呢?现在的企业为了节省资源都开始使用无纸化办公了。当一个企业想要使用电子版名片的时候应该怎么制作呢?可以将企业联系方式、邮箱、地址等做成二维码图片,扫码就能在线查看企业信息。这时候,…

【leetcode 力扣刷题】数学题之计算次幂//次方:快速幂

利用乘法求解次幂问题—快速幂 50. Pow(x, n)372. 超级次方 50. Pow(x, n) 题目链接:50. Pow(x, n) 题目内容: 题目就是要求我们去实现计算x的n次方的功能函数,类似c的power()函数。但是我们不能使用power()函数直接得到答案,那…

AI 也有冷静期?

阅读本文大概需要 1.31分钟。 上一次听说「冷静期」这个词,还是大家都在讨论「离婚冷静期」对时候,最近时常看到「冷静期」,缘于大家所讨论的 AI 热度下降事情,俗称「AI 冷静期」。 1、 上半年,每隔一段时间&#xff0…

无涯教程-JavaScript - NORMSDIST函数

NORMSDIST函数替代Excel 2010中的NORM.S.DIST函数。 描述 该函数返回标准正态累积分布函数。分布的平均值为0(零),标准偏差为1。使用此功能代替标准法线区域的表格。 语法 NORMSDIST (z)争论 Argument描述Required/OptionalZThe value for which you want the distributio…

【OpenCV入门】第八部分——滤波器

文章结构 图像平滑处理均值滤波器中值滤波器高斯滤波器双边滤波器拉普拉斯高通滤波器 图像平滑处理 图像平滑处理是指在尽量保留原图像信息的情况下,去除掉图像内部的噪声(分布不均匀的、高亮度的像素点)。而用于图像平滑处理的工具就是滤波…

Qt +VTK+Cmake 编译和环境配置(第一篇 采坑)

VTK下载地址:https://vtk.org/download/ cmake下载地址:https://cmake.org/download/ 版本对应方面,如果你的项目对版本没有要求,就不用在意。我就是自己随机搭建的,VTK选择最新版本吧,如果后面其他的库不…

HttPClient简介及示例:学习如何与Web服务器进行通信

文章目录 前言一、引入依赖二、使用步骤1.创建被调用者2.创建调用者三、结果被调用者服务:调用者服务: 总结 前言 欢迎来到本篇博客,这是一个关于HttPClient的入门案例的指南。🎉 在今天的网络世界中,与服务器进行数据…

数据挖掘导论学习笔记1(第1 、2章)

参考:https://blog.csdn.net/u013232035/article/details/48281659?spm1001.2014.3001.5506 和《数据挖掘导论》学习笔记(第1-2章)_时机性样本_schdut的博客-CSDN博客 第1章 绪论 数据挖掘是一种技术,它将传统的数据分析方法…

【LeetCode】剑指 Offer <二刷>(4)

目录 题目:剑指 Offer 09. 用两个栈实现队列 - 力扣(LeetCode) 题目的接口: 解题思路: 代码: 过啦!!! 题目:剑指 Offer 10- I. 斐波那契数列 - 力扣&am…

FFmpeg5.0源码阅读——FFmpeg大体框架(以GIF转码为示例)

摘要:前一段时间熟悉了下FFmpeg主流程源码实现,对FFmpeg的整体框架有了个大概的认识,因此在此做一个笔记,希望以比较容易理解的文字描述FFmpeg本身的结构,加深对FFmpeg的框架进行梳理加深理解,如果文章中有…