深度学习 GNN图神经网络(三)模型思想及文献分类案例实战

news2025/1/8 4:39:09

如果你有一定神经网络的知识基础,想学习GNN图神经网络,可以按顺序参考系列文章:
深度学习 GNN图神经网络(一)图的基本知识
深度学习 GNN图神经网络(二)PyTorch Geometric(PyG)安装
深度学习 GNN图神经网络(三)模型原理及文献分类案例实战

一、前言

本文介绍GNN图神经网络的思想原理,然后使用Cora数据集对其中的2708篇文献进行分类。用普通的神经网络与GNN图神经网络分别实现,并对比两者之间的效果。

二、总体思想

GNN的作用就是对节点进行特征提取,可以看下这个几分钟的视频《简单粗暴带你快速理解GNN》。
比如说这里有一张图,包含5个节点,每个节点有三个特征值:
在这里插入图片描述
节点A的特征值 x a = [ 1 , 1 , 1 ] x_a=[1,1,1] xa=[1,1,1],节点B的特征值 x b = [ 2 , 2 , 2 ] x_b=[2,2,2] xb=[2,2,2]

我们依次对所有节点的特征值进行更新:
新的信息=自身的信息 + 所有邻居点的信息
所有邻居点信息的表达有几种:

  • 求和Sum
  • 求平均Mean
  • 求最大Max
  • 求最小Min

我们以求和为例:
x ^ a = σ ( w a x a + w b x b + w c x c ) \hat{x}_a=\sigma(w_ax_a+w_bx_b+w_cx_c) x^a=σ(waxa+wbxb+wcxc)
x ^ b = σ ( w b x b + w a x a ) \hat{x}_b=\sigma(w_bx_b+w_ax_a) x^b=σ(wbxb+waxa)
x ^ c = σ ( w c x c + w a x a + w d x d ) \hat{x}_c=\sigma(w_cx_c+w_ax_a+w_dx_d) x^c=σ(wcxc+waxa+wdxd)
x ^ d = σ ( w d x d + w a x a + w c x c ) \hat{x}_d=\sigma(w_dx_d+w_ax_a+w_cx_c) x^d=σ(wdxd+waxa+wcxc)
x ^ e = σ ( w e x e + w d x d ) \hat{x}_e=\sigma(w_ex_e+w_dx_d) x^e=σ(wexe+wdxd)
其中, w w w是各自节点的权重参数, σ \sigma σ是激活函数。

求所有邻居点信息的操作叫做消息传递(或信息聚合)
整个特征值更新过程叫做图卷积(跟CNN卷积神经网络中的卷积是两回事),整个神经网络叫做图卷积网络(GCN)。

在经历第一次更新操作后:
A中有B、C、D的信息;
B中有A的信息;
C中有A、D的信息;
D中有A、C、E的信息;
E中有D的信息;

在经历第二次更新操作后:
A中有B、C、D、E的信息;
⋮ \vdots
E中有A、C、D、E的信息;

如此循环,节点逐渐包含更多其他节点的信息,只是权重不同。

PS:过年了,这段写得有点仓促,如有错误恳请纠正。作者也会在这留下TODO,后续参考更多的资料进行校验纠正。祝兔年快乐~ 😃

三、数据集介绍

Cora数据集由2708篇机器学习论文组成。 这些论文分为七类:

  1. 基于案例
  2. 遗传算法
  3. 神经网络
  4. 概率方法
  5. 强化学习
  6. 规则学习
  7. 理论

每个论文样本包含1433个特征值,由0/1组成,表示论文内容是否包含某关键字。
数据集中的边表示论文引用关系。

四、实战案例

4.1、引入数据集

我们首先引入Cora数据集,看看图数据集的格式:

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 手动下载https://gitee.com/jiajiewu/planetoid
# 或者https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
dataset=Planetoid(root="./data/Planetoid",name='Cora',transform=NormalizeFeatures())
print(f'num_features={dataset.num_features}')
print(f'num_classes={dataset.num_classes}')
print(dataset.data)
print(dataset.data.edge_index.T)

输出结果:

num_features=1433
num_classes=7
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
tensor([[   0,  633],
        [   0, 1862],
        [   0, 2582],
        ...,
        [2707,  598],
        [2707, 1473],
        [2707, 2706]])

num_features=1433:有1433个特征值
num_classes=7:有7种类型
x=[2708,1433]:数据包含2708篇论文,每篇论文有1433个特征值
edge_index=[2, 10556]:每条边连接两篇论文,存在10556条边,即论文间有10556次引用关系
y=[2708]:有2708个标签(0-6)

4.2 多层感知器分类测试

首先,我们使用多层感知器,即普通的神经网络进行分类测试。
定义网络模型:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


class MLP(nn.Module):
    
    def __init__(self):
        # 初始化Pytorch父类
        super().__init__()
        
        # 定义神经网络层
        self.model = nn.Sequential(
            nn.Linear(dataset.num_features, 16),
            nn.ReLU(),
            nn.Linear(16, dataset.num_classes),
        )
        
        # 创建损失函数,使用交叉熵误差
        self.loss_function = nn.CrossEntropyLoss()

        # 创建优化器,使用Adam梯度下降
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.01,weight_decay=5e-4)

        # 训练次数计数器
        self.counter = 0
        # 训练过程中损失值记录
        self.progress = []
    
    # 前向传播函数
    def forward(self, inputs):
        return self.model(inputs)
    
    # 训练函数
    def train(self, inputs, targets):
        # 前向传播计算,获得网络输出
        outputs = self.forward(inputs)
        
        # 计算损失值
        loss = self.loss_function(outputs[dataset.data.train_mask], targets)

        # 累加训练次数
        self.counter += 1

        # 每10次训练记录损失值
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        # 每10000次输出训练次数   
        if (self.counter % 100 == 0):
            print(f"counter={self.counter}, loss={loss.item()}")
            
        # 梯度清零, 反向传播, 更新权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    
    # 测试函数
    def test(self, inputs, targets):
        # 前向传播计算,获得网络输出
        outputs = self.forward(inputs)
        
        pred=outputs.argmax(dim=1)
        test_correct=pred[dataset.data.test_mask]==targets
        return (test_correct.sum()/dataset.data.test_mask.sum()).item()

    # 绘制损失变化图
    def plot_progress(self):
        plt.plot(range(100),self.progress)
    

迭代训练:

M = MLP()
for i in range(1000):
    M.train(dataset.data.x,dataset.data.y[dataset.data.train_mask])

运行结果:

counter=100, loss=0.0084211565554142
counter=200, loss=0.0063483878038823605
counter=300, loss=0.0051103029400110245
counter=400, loss=0.004452046472579241
counter=500, loss=0.0040738522075116634
counter=600, loss=0.0038454567547887564
counter=700, loss=0.003702200250700116
counter=800, loss=0.0036090961657464504
counter=900, loss=0.0035553970374166965
counter=1000, loss=0.0035170542541891336

输出损失值变化图:

M.plot_progress()

在这里插入图片描述
测试结果:

M.test(dataset.data.x,dataset.data.y[dataset.data.test_mask])

运行结果:

0.5730000138282776

可以看到,准确率大概为57.3%,效果比较差。

4.3 GNN分类测试

现在我们构建GNN图神经网络进行分类测试。

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt

class GNN(nn.Module):
    
    def __init__(self):
        # 初始化Pytorch父类
        super().__init__()
        
        # 定义神经网络层,torch_geometric有自己的Sequential实现
        # 报错信息https://github.com/pyg-team/pytorch_geometric/discussions/3726
        # 见https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.sequential.Sequential
        # self.model = nn.Sequential(
        #     GCNConv(dataset.num_features, 16),
        #     nn.ReLU(),
        #     GCNConv(16, dataset.num_classes),
        # )

        self.conv1=GCNConv(dataset.num_features, 16)
        self.conv2=GCNConv(16, dataset.num_classes)
        
        # 创建损失函数,使用交叉熵误差
        self.loss_function = nn.CrossEntropyLoss()

        # 创建优化器,使用Adam梯度下降
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.01,weight_decay=5e-4)

        # 训练次数计数器
        self.counter = 0
        # 训练过程中损失值记录
        self.progress = []
    
    # 前向传播函数
    def forward(self, x, edge_index):
        # return self.model(x, edge_index)
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=self.conv2(x, edge_index)
        return x
    
    # 训练函数
    def train(self, x, edge_index, targets):

        # 前向传播计算,获得网络输出
        outputs = self.forward(x, edge_index)
        
        # 计算损失值
        loss = self.loss_function(outputs[dataset.data.train_mask], targets)

        # 累加训练次数
        self.counter += 1

        # 每10次训练记录损失值
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        # 每10000次输出训练次数   
        if (self.counter % 100 == 0):
            print(f"counter={self.counter}, loss={loss.item()}")
            
        # 梯度清零, 反向传播, 更新权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    
    # 测试函数
    def test(self, x, edge_index, targets):
        # 前向传播计算,获得网络输出
        outputs = self.forward(x, edge_index)
        
        pred=outputs.argmax(dim=1)
        test_correct=pred[dataset.data.test_mask]==targets
        return (test_correct.sum()/dataset.data.test_mask.sum()).item()

    # 绘制损失变化图
    def plot_progress(self):
        plt.plot(range(100),self.progress)
    

迭代训练:

G = GNN()
for i in range(1000):
    G.train(dataset.data.x,dataset.data.edge_index,dataset.data.y[dataset.data.train_mask])

运行结果:

counter=100, loss=0.01617591269314289
counter=200, loss=0.010460852645337582
counter=300, loss=0.008510907180607319
counter=400, loss=0.007648027036339045
counter=500, loss=0.007218983490020037
counter=600, loss=0.006993760820478201
counter=700, loss=0.0068700965493917465
counter=800, loss=0.006797503679990768
counter=900, loss=0.006750799715518951
counter=1000, loss=0.006724677048623562

输出损失值变化图:

G.plot_progress()

在这里插入图片描述
测试结果:

G.test(dataset.data.x,dataset.data.edge_index,dataset.data.y[dataset.data.test_mask])

运行结果:

0.8059999942779541

可以看到,准确率大概为80.6%,效果好了很多。

五、参考资料

简单粗暴带你快速理解GNN
【唐博士带你学AI】图神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/174841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nginx入门与应用

NginxNginx概述Nginx介绍Nginx下载和安装windowsLinuxNginx目录结构Nginx命令查看版本检查配置文件正确性启动和停止重新加载配置文件Nginx环境变量(Linux)Nginx配置文件结构Nginx具体应用部署静态资源反向代理负载均衡Nginx概述 Nginx介绍 Nginx是一款…

Linux系统——基础IO

要努力,但不要着急,繁花锦簇,硕果累累,都需要过程! 目录 1.文件基础必备概念 2.文件系统调用接口 1.open && close 2.write 3.read 3.文件描述符fd 3.1什么是文件描述符 3.2文件描述符意义 3.3文件描述符的分配…

【C++】map和set的模拟实现

​🌠 作者:阿亮joy. 🎆专栏:《吃透西嘎嘎》 🎇 座右铭:每个优秀的人都有一段沉默的时光,那段时光是付出了很多努力却得不到结果的日子,我们把它叫做扎根 目录👉红黑树的…

一个线程如何处理多个连接?(非阻塞IO)

从BIO到NIO的转变 五种IO模型BIO的缺陷非阻塞非阻塞IO(NIO)非阻塞读非阻塞写非阻塞IO模型php NIO 实现适用场景什么是C10k问题?C10K问题的由来五种IO模型 在《UNIX 网络编程》一书中介绍了五种IO模型: 分别是 BIO,NIO…

无线电基础电路 > RLC阻尼系数计算仿真

随机搭建电路如下图所示&#xff1a; 阻尼系数的希腊字母符号“ ζ ”读作泽塔。 阻尼系数ζ (R/2) * √C/L 1000/2 * √0.00001 1.58 包括三种情况&#xff1a; ζ>1&#xff1a;过阻尼&#xff0c;频响不利落&#xff0c;需要较长时间才能消失。 ζ<1&#xff…

MinIO基本使用(实现上传、下载功能)

MinIO基本使用&#xff08;实现上传、下载功能&#xff09;1.简介2.下载和安装3.启动服务端4.创建User和Bucket4.1 创建User4.1.1 生成accessKey和secretKey4.2 创建Bucket5.在SpringBoot中使用MinIO5.1 引入依赖5.2 配置文件定义5.3 定义实体类5.4 定义业务类5.5 定义测试类5.…

vivado中block design遇到的error总结

Error1.[BD 41-1356] Address block </processing_system7_0/S_AXI_HP0/HP0_DDR_LOWOCM> is not mapped into </axi_vdma_0/Data_MM2S>. Please use Address Editor to either map or exclude it. 修改方法. a、点击Address Editor. b、在Address Editor页面右击失…

【Ajax】了解Ajax与jQuery中的Ajax

一、了解Ajax什么是AjaxAjax 的全称是 Asynchronous Javascript And XML&#xff08;异步 JavaScript 和 XML&#xff09;。通俗的理解&#xff1a;在网页中利用 XMLHttpRequest 对象和服务器进行数据交互的方式&#xff0c;就是Ajax。2. 为什么要学Ajax之前所学的技术&#xf…

使用MQTT fx测试云服务器的 mosquitto 通讯

文章目录一.MQTT.fx介绍二.MQTT.fx安装教程三.使用MQTT.fx测试云服务器的 mosquitto 通讯一.MQTT.fx介绍 MQTT.fx是一款基于Eclipse Paho&#xff0c;使用Java语言编写的MQTT客户端工具。支持通过Topic订阅和发布消息&#xff0c;用来前期和物理云平台调试非常方便。 二.MQTT…

【数据结构——顺序表的实现】

前言&#xff1a; 在之前我们已经对复杂度进行的相关了解&#xff0c;因此现在我们将直接进入数据结构的顺序表的相关知识的学习。 目录1.线性表2.顺序表2.1概念及结构2.2 接口实现2.2.1.打印顺序表2.2.2初始化顺序表2.2.3.容量的检查2.2.4.销毁顺序表2.2.5.尾插操作2.2.6.尾删…

Ubuntu下的LGT8F328P MiniEVB Arduino开发和烧录环境

基于 LGT8F328P LQFP32 的 Arduino MiniEVB, 这个板型资料较少, 记录一下开发环境和烧录过程以及当中遇到的问题. 关于 LGT8F328P 芯片参数 8位RISC内核32K字节 Flash, 2K字节 SRAM最大支持32MHz工作频率 集成32MHz RC振荡器集成32KHz RC振荡器 SWD片上调试器工作电压: 1.8V…

C语言文件操作(3)

TIPS 1. 文件是不是二进制文件&#xff0c;不是后缀说了算&#xff0c;而是内容说了算 2. 文件的随机读写 文件的随机读写也就是说我指哪打哪 fseek() 人为调整指针指向的位置 1. 根据文件指针FILE*的当前位置和你给出的偏移量来让它这个文件指针呢定位到你想要的位置上…

Flutter 这一年:2022 亮点时刻

回看 2022&#xff0c;展望 Flutter Forward 2022 年&#xff0c;我们非常兴奋的看到 Flutter 社区持续发展壮大&#xff0c;也因此让更多人体验到了令人难以置信的体验。每天有超过 1000 款使用 Flutter 的新移动应用发布到 App Store 和 Google Play&#xff0c;Web 平台和桌…

实战打靶集锦-002-SolidState

**写在前面&#xff1a;**谨以此文纪念不完美的一次打靶经历。 目录1. 锁定主机与端口2. 服务枚举3. 服务探查3.1 Apache探查3.1.1 浏览器手工探查3.1.2 目录枚举3.2 JAMES探查3.2.1 搜索公共EXP3.2.2 EXP利用3.2.2.1 构建payload3.2.2.2 netcat构建反弹shell3.2.3 探查JAMES控…

三十一、Kubernetes中Service详解、实例第一篇

1、概述 在kubernetes中&#xff0c;pod是应用程序的载体&#xff0c;我们可以通过pod的ip来访问应用程序&#xff0c;但是pod的ip地址不是固定的&#xff0c;这也就意味着不方便直接采用pod的ip对服务进行访问。 为了解决这个问题&#xff0c;kubernetes提供了Service资源&…

NX二开ufun函数UF_MODL_ask_curve_points(获取曲线信息)

根据曲线tag&#xff0c;返回曲线相关信息&#xff1a;弦宽容、弧度、最大步长、点数组的点。 实例返回结果截图如下&#xff1a; 实例创建曲线截图如下&#xff1a; 1、函数结构 int UF_MODL_ask_curve_points &#xff08;tag_t curve_id&#xff0c; double ctol&#xf…

【SpringCloud19】SpringCloud Alibaba Sentinel实现熔断与限流

1.概述 官网 中文文档 1.1 是什么 一句话解释&#xff0c;之前我们讲解过的Hystrix 1.2 怎么下 下载网址 1.3 作用 1.4 如何使用 官网学习 服务使用中的各种问题&#xff1a; 服务雪崩服务降级服务熔断服务限流 2.安装Sentinel控制台 2.1 组成部分 核心库&#x…

Golang之实战篇(1)

"千篇一律&#xff0c;高手寂寞。几十不惑&#xff0c;全都白扯"上篇介绍了golang这门新的语言的一些语法。那么我们能用golang简单地写些什么代码出来呢&#xff1f;一、猜数字这个游戏的逻辑很简单。系统随机给你生成一个数&#xff0c;然后读取你猜的数字&#xf…

老杨说运维 | AIOps如何助力实现全面可观测性(上)

前言&#xff1a; 嗨&#xff0c;今天是大年三十&#xff0c;大家是不是已经在家坐享团圆之乐了&#xff1f;还是说在奔向团圆的路上呢&#xff1f;不论如何&#xff0c;小编先祝大家新年如意安康&#xff0c;平安顺遂~ 熟悉我们的朋友肯定都知道&#xff0c;关于《老杨说运维…

30.字符串处理函数

文章目录1.测字符串长度函数2.字符串拷贝函数1.strcpy函数2.strncpy函数3.字符串追加函数1.strcat函数2.strncat函数4.字符串比较函数1.strcmp函数2.strncmp函数5.字符查找函数1.strchr函数2.strrchr函数6.字符串匹配函数7.空间设定函数8.字符串转换数值9.字符串切割函数strtok…