图神经网络:(节点分类)在KarateClub数据集上动手实现图神经网络

news2024/12/23 13:13:50

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook。超链。提取码8888。

文章目录

    • 文献阅读:
    • 代码实操:

文献阅读:

参考文献:SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS
中文翻译:用图神经网络进行半监督的分类
我在百度网盘上传这篇文献。超链。提取码8888。

文献首先:介绍了其他前辈的工作。在损失函数中使用拉普拉斯正则化项。公式如下(打这个公式真费劲,还的学Latex): L = L 0 + λ L r e g \mathcal{L}=\mathcal{L}_{0}+\lambda\mathcal{L}_{reg} L=L0+λLreg with L r e g = ∑ i , j A i , j ∣ ∣ f ( X i ) − f ( X j ) ∣ ∣ 2 = f ( X ) T Δ f ( X ) \mathcal{L}_{reg}=\sum_{i,j}{A}_{i,j}||\mathcal{f}({X}_{i})-\mathcal{f}({X}_{j})||^{2}=\mathcal{f}(X)^{T}\Delta\mathcal{f}(X) Lreg=i,jAi,j∣∣f(Xi)f(Xj)2=f(X)TΔf(X)
符号说明: L \mathcal{L} L表示为损失函数。 L 0 \mathcal{{L}_{0}} L0表示为有标签的损失(还有没标签的毕竟是半监督)。 λ \lambda λ表示为权重系数。 A i , j {A_{i,j}} Ai,j表示为图边。 f ( ⋅ ) \mathcal{f}(\cdot) f()表示为像神经网络的可微函数。 X X X表示为特征矩阵。 Δ = D − A \Delta=D-A Δ=DA表示为非规范化的拉普拉斯算子。 D D D表示为度的矩阵, D i , i = ∑ j A i , j D_{i,i}=\sum_{j}A_{i,j} Di,i=jAi,j
文章然后:简单说明使用上述公式需要有个假设:图中连接节点共享相同标签。于是作者这篇文章便就来了,为了解决这个问题,使用神经网络模型 f ( X , A ) f(X,A) f(X,A)编码图结构,避免使用显示基于图正则化。文章有两贡献,1.提出一种简单良好直接作用于图上的神经网络传播规则并且展示它是如何从谱图卷积的一阶逼近得到反馈。2.演示了基于图神经网络是如何分类的。
文章然后:具体开始阐述理论。 H l + 1 = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H l W l ) H^{l+1}=\sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{l}W^{l}) Hl+1=σ(D~21A~D~21HlWl)。(知道核心公式就好,其他细节跳过因为我看不懂)
符号说明: D i , i = ∑ j A i , j D_{i,i}=\sum_{j}A_{i,j} Di,i=jAi,j表示为度的矩阵。 A ~ = A + I N \tilde{A}=A+I_{N} A~=A+IN表示为邻接矩阵加上一个单位矩阵。 W l W^{l} Wl表示为权重系数。 σ \sigma σ表示为激活函数。 H l H^{l} Hl为第 l l l层的特征矩阵。 H 0 H^{0} H0即为 X X X
文章然后:进行代码分类实操,他们这里搭建了两层GCN。所以最后的公式为 Z = f ( X , A ) = s o f t m a x ( A ^ R e l u ( A ^ X W 0 ) W 1 ) Z=f(X,A)=softmax(\widehat{A}Relu(\widehat{A}XW^{0})W^{1}) Z=f(X,A)=softmax(A Relu(A XW0)W1)。这里 A ^ = D ~ − 1 2 A ~ D ~ − 1 2 \widehat{A}=\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}} A =D~21A~D~21。损失函数就使用交叉熵 L = − ∑ l ∈ Y l ∑ f = 1 F Y l f ln ⁡ Z l f L=-\sum_{l \in \mathcal{Y}_{l}}\sum_{f=1}^FY_{lf}\ln{Z_{lf}} L=lYlf=1FYlflnZlf吧。
文章然后:介绍图半监督学习领域以及图上运行神经网络领域两个领域相关工作。
文章然后:进行实验展示结果。
文章然后:进行讨论。1.作者模型可以克服Skip-gram方法难以优化多步流程限制同时时间以及效果表现更好。2.未来工作1)解决内存:作者证明对于无法使用GPU大型图,用CPU是可行的。用小批量随机梯度可以缓解这个问题。但是生成小批量时应该考虑GCN的层数,对于非常大且密集连接的图可能需要进一步地近似。2)不支持有向图,但是有解决方法的(具体是什么我没看懂)3)考虑一个权衡参数 λ \lambda λ可能会有益。具体来说就是修改生成自循环图时用的 λ \lambda λ。即 A ~ = A + λ I \tilde{A}=A+\lambda I A~=A+λI
文章然后:得到结论。
文章最后:引用以及其他工作。1)WL-1算法2)深层的GCN。太深不好。
PS:以上仅是我的理解,我的理解可能不对。然后关于这个GCN以及WL算法,有两篇文章研究了它们,还是挺有趣的。我在百度网盘上传了这连篇文章。超链。提取码8888。

代码实操:

导入对应的库

import matplotlib.pyplot as plt
import networkx as nx

定义可视化的函数

def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G,seed=42),with_labels=False,node_color=color,cmap="Set2")
    plt.show()
#可视化图网络
def visualize_embedding(h,color,epoch=None,loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h=h.detach().cpu().numpy()
    plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}',fontsize=16)
    plt.show()

导入对应的库:数据集1

from torch_geometric.datasets import KarateClub
dataset=KarateClub()

KarateClub数据集简单说明:34个人的社交网络,如果在俱乐部之外两人认识连一条边。然后由于俱乐部的内部冲突,人们选择站队所以分成两派。
打印数据集的信息

print(len(dataset),dataset.num_features,dataset.num_classes)
#输出:1 34 4

简单说明:num_features:33加上1。33指,这个节点与其他的33个节点是否有边,有边为1,无边为0。1是指度。num_classer:按理应该为2,但是官方做了修改,所以为4。

data=dataset[0]
#具体到确定的图上
print(data.num_nodes,data.num_edges,data,data.train_mask.sum().item())
#输出:34 156 Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34]) 4
print(data.has_isolated_nodes(),data.has_self_loops(),data.is_undirected())
#输出:False False True
edge_index=data.edge_index
print(edge_index.t())
#输出:不表

导入对应的库

from torch_geometric.utils import to_networkx

可视化图网络

G=to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)

在这里插入图片描述
搭建模型GCN的框架

from torch_geometric.nn import GCNConv
from torch.nn import Linear
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GCNConv(dataset.num_features,4)
        self.conv2=GCNConv(4,4)
        self.conv3=GCNConv(4,2)
        self.classifier=Linear(2,dataset.num_classes)
    def forward(self,x,edge_index):
        h=self.conv1(x,edge_index)
        h=h.tanh()
        h=self.conv2(h,edge_index)
        h=h.tanh()
        h=self.conv3(h,edge_index)
        h=h.tanh()
        out=self.classifier(h)
        return out,h
model=GCN()
print(model)
#输出
#GCN(
#  (conv1): GCNConv(34, 4)
#  (conv2): GCNConv(4, 4)
#  (conv3): GCNConv(4, 2)
#  (classifier): Linear(in_features=2, out_features=4, bias=True)
#)

简单说明: X v ( l + 1 ) = W ( l + 1 ) ∑ w ∈ N ( v ) ∪ { v } 1 c w , v ⋅ X w ( l ) X_{v}^{(l+1)}=W^{(l+1)}\sum_{w \in N(v)\cup{\{v\}}}\frac{1}{c_{w,v}}\cdot X_{w}^{(l)} Xv(l+1)=W(l+1)wN(v){v}cw,v1Xw(l)
可视化图嵌入(这里只有正向传播)

model=GCN()
_,h=model(data.x,data.edge_index)
visualize_embedding(h,color=data.y)

在这里插入图片描述
进行训练得出结果

model=GCN()
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
def train(data):
    optimizer.zero_grad()
    out,h=model(data.x,data.edge_index)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss,h
for epoch in range(401):
    loss,h=train(data)
    if epoch==400:
        visualize_embedding(h,color=data.y,epoch=epoch,loss=loss)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/537239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一个由“API未授权漏洞”引发的百万级敏感数据泄露

2023年4月的某一天,腾讯安全专家Leo正在为某家医院的重保防护做第一轮的安全风险排查。 医院的专用APP是外部网络访问最高的,也就是最大的风险敞口,需要重点排查。 Leo下载APP进行测试后,发现该医院存在一个严重的问题&#xff…

图像复原与重建MATLAB实验

文章目录 一、实验目的二、实验内容1. 噪声图像及其直方图。2. 空间噪声滤波器。3. 逆滤波。 一、实验目的 了解一些常用随机噪声的生成方法。掌握根据指定退化函数对图像进行退化的方法。掌握当模糊图像只存在噪声时的几种滤波复原方法。掌握当模糊图像同时存在线性退化和噪声…

学会搭建小程序生鲜商城,开启生鲜电商新模式

电商平台的出现,为人们带来了极大的便利。然而,传统的电商平台已经不能满足消费者对于购物体验的要求。如今,小程序生鲜商城因其轻量化、高效率等特点,成为了众多卖家的首选。本文将介绍如何学会搭建小程序生鲜商城,并…

二分特训上------刷题部分----Week4(附带LeetCode特训)

二分特训上------理论部分----Week4(附带LeetCode特训)_小杰312的博客-CSDN博客 如果需要理论,请移步上一篇. /***** 注意:我们把 0000001111111模型中:0称呼为左边区间,1称呼为右边区间 (答案第一个1在右区间) 1111…

浅谈Redis

一、Redis的简介 1.开源免费的缓存中间件,性能高,读可达110000次/s,写可达81000次/s。 2.redis的单线程讨论: V4.0之前:是单线程的,所有任务处理都在一个线程内完成. V4.0:引入多线程,异步线程用于处理一些耗…

机器学习基础学习之线性回归

文章目录 首先从**目标函数**开始梯度下降法结合两个公式,让目标函数梯度下降多项式回归,多重回归解决办法:随机梯度下降 首先从目标函数开始 假设下图反映了 投入多少广告费,产生了多少销售量的关系 图中每个点都是一个数据&a…

Spring Security

1、这是securityConfigpackage com.ruoyi.framework.config;import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.http.HttpMethod; import org.springframework.security.authe…

【JAVA】Java中方法的使用,理解方法重载和递归

目录 1.方法的概念及使用 1.1什么是方法 1.2方法的定义 1.3方法调用的执行过程 1.4实参和形参 2.方法重载 2.1为什么需要使用方法重载 2.2什么是方法重载 3.递归 3.1什么是递归 3.2递归执行的过程 3.3递归的使用 1.方法的概念及使用 1.1什么是方法 方法就是一个代…

消息队列:RabbitMQ

文章目录 消息队列(RabbitMQ)概念优势技术亮点可靠性灵活的路由集群联合高可用的队列多协议广泛的客户端可视化管理工具追踪插件系统 原理:AMQP 0-9-1 模型简介 消息队列(RabbitMQ) 概念 一种异步通信中间件 优势 消…

RHCSA 作业三

1. 2. [rootserver yum.repos.d]# mount /dev/sr0 /media mount: /media: /dev/sr0 已挂载于 /media. [rootserver yum.repos.d]# ls redhat.repo [rootserver yum.repos.d]# vim /etc/yum.repos.d/redhat.repo [rootserver yum.repos.d]# yum makecache 正在更新 Subscripti…

如何管理好团队的工时表?

工时表管理对所有团队来说都是一项具有挑战性的任务。它是确保每个团队成员高效工作并获得最大时间的关键工具。团队工时表是任何项目经理武器库中的一个重要工具。它们提供了对团队表现的宝贵见解。 一个成功的工时表管理系统对于希望最大限度提高生产力和利润的团队成员是必…

360+ChatGLM联手研发中国版“微软+OpenAI”

文章目录 前言360与智谱AI强强联合什么是智谱AI360智脑360GLM与360GPT大模型战略布局写在最后 前言 5月16日,三六零集团(下称“360”)与智谱AI宣布达成战略合作,双方共同研发的千亿级大模型“360GLM”已具备新一代认知智能通用模…

LLMs 诸神之战:LangChain ,以【奥德赛】之名

LLMs 一出,谁与争锋? 毫无疑问,大语言模型(LLM)掀起了新一轮的技术浪潮,成为全球各科技公司争相布局的领域。诚然,技术浪潮源起于 ChatGPT,不过要提及 LLMs 的技术发展的高潮&#x…

React的表单数据绑定

当我们在页面中使用表单提交数据时,react是如何拿取表单数据的呢 这里通过两种方式来实现 非受控组件实现 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" conte…

在Ubuntu 22.04 LTS Jammy Linux 系统上安装MySQL

在Ubuntu 22.04 LTS Jammy Linux 系统上安装MySQL 1. Update Apt Package Index2. Install MySQL Server & client on Ubuntu 22.043. To Check the version4. Run the Security script to secure MySQL5. Login Database Server as the root user6. Manage MySQL service7…

C-认识指针

认识指针 内容来自《深入理解C指针》 声明指针 在数据类型后面跟上星号*&#xff0c;如下的声明都是等价的 int* pi; int * pi; int *pi; int*pi;阅读声明 如下&#xff1a; const int *pci;1.pci是一个变量 const int *pci; 2.pci是一个指针变量 const int *pci; 3.pci是一…

FMC篇-SDRAM(IS42S16400J)

IS42S16400J 这个东西太常见啦&#xff0c;长方形的。不会过多解释&#xff0c;详细请阅读它的数据手册。 IS42S16400J是一种高速同步动态随机存储器(SDRAM)&#xff0c;64Mb的存储容量&#xff0c;采用4个bank&#xff0c;每个bank大小为16Mb&#xff0c;总线宽度为16位&…

eDiary-白日梦电子记事本基本使用说明【记事本导出和导入方法、本地数据迁移方法、记录工作日报、日历代办等】

文章目录 说明笔记导出与导入导出导入 本地数据迁移及备份本地备份说明恢复 记录工作日报记录今天发生美事等日历代办 说明 因为公司大佬分享资料&#xff0c;需要用到白日梦这个电子记事本&#xff0c;所以才了解到这个软件&#xff0c;体量小&#xff0c;功能高级&#xff0…

图数据库 NebulaGraph 的内存管理实践之 Memory Tracker

数据库的内存管理是数据库内核设计中的重要模块&#xff0c;内存的可度量、可管控是数据库稳定性的重要保障。同样的&#xff0c;内存管理对图数据库 NebulaGraph 也至关重要。 图数据库的多度关联查询特性&#xff0c;往往使图数据库执行层对内存的需求量巨大。本文主要介绍 …