图神经网络:在KarateClub上动手实现图神经网络

news2025/1/22 15:50:00

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook。超链。提取码8888。

文章目录

    • 文献阅读:
    • 代码实操:

文献阅读:

参考文献:SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS
中文翻译:用图神经网络进行半监督的分类
我在百度网盘上传这篇文献。超链。提取码8888。

文献首先:介绍了其他前辈的工作。在损失函数中使用拉普拉斯正则化项。公式如下(打这个公式真费劲,还的学Latex): L = L 0 + λ L r e g \mathcal{L}=\mathcal{L}_{0}+\lambda\mathcal{L}_{reg} L=L0+λLreg with L r e g = ∑ i , j A i , j ∣ ∣ f ( X i ) − f ( X j ) ∣ ∣ 2 = f ( X ) T Δ f ( X ) \mathcal{L}_{reg}=\sum_{i,j}{A}_{i,j}||\mathcal{f}({X}_{i})-\mathcal{f}({X}_{j})||^{2}=\mathcal{f}(X)^{T}\Delta\mathcal{f}(X) Lreg=i,jAi,j∣∣f(Xi)f(Xj)2=f(X)TΔf(X)
符号说明: L \mathcal{L} L表示为损失函数。 L 0 \mathcal{{L}_{0}} L0表示为有标签的损失(还有没标签的毕竟是半监督)。 λ \lambda λ表示为权重系数。 A i , j {A_{i,j}} Ai,j表示为图边。 f ( ⋅ ) \mathcal{f}(\cdot) f()表示为像神经网络的可微函数。 X X X表示为特征矩阵。 Δ = D − A \Delta=D-A Δ=DA表示为非规范化的拉普拉斯算子。 D D D表示为度的矩阵, D i , i = ∑ j A i , j D_{i,i}=\sum_{j}A_{i,j} Di,i=jAi,j
文章然后:简单说明使用上述公式需要有个假设:图中连接节点共享相同标签。于是作者这篇文章便就来了,为了解决这个问题,使用神经网络模型 f ( X , A ) f(X,A) f(X,A)编码图结构,避免使用显示基于图正则化。文章有两贡献,1.提出一种简单良好直接作用于图上的神经网络传播规则并且展示它是如何从谱图卷积的一阶逼近得到反馈。2.演示了基于图神经网络是如何分类的。
文章然后:具体开始阐述理论。 H l + 1 = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H l W l ) H^{l+1}=\sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{l}W^{l}) Hl+1=σ(D~21A~D~21HlWl)。(知道核心公式就好,其他细节我们跳过因为我看不懂)
符号说明: D i , i = ∑ j A i , j D_{i,i}=\sum_{j}A_{i,j} Di,i=jAi,j表示为度的矩阵。 A ~ = A + I N \tilde{A}=A+I_{N} A~=A+IN表示为邻接矩阵加上一个单位矩阵。 W l W^{l} Wl表示为权重系数。 σ \sigma σ表示为激活函数。 H l H^{l} Hl为第 l l l层的特征矩阵。 H 0 H^{0} H0即为 X X X
文章然后:进行代码分类实操,他们这里搭建了两层GCN。所以最后的公式为 Z = f ( X , A ) = s o f t m a x ( A ~ R e l u ( A ~ X W 0 ) W l ) Z=f(X,A)=softmax(\tilde{A}Relu(\tilde{A}XW^{0})W^{l}) Z=f(X,A)=softmax(A~Relu(A~XW0)Wl)。损失函数就使用交叉熵 L = − ∑ l ∈ Y l ∑ f = 1 F Y l f ln ⁡ Z l f L=-\sum_{l \in \mathcal{Y}_{l}}\sum_{f=1}^FY_{lf}\ln{Z_{lf}} L=lYlf=1FYlflnZlf吧。
文章然后:介绍图半监督学习领域以及图上运行神经网络领域两个领域相关工作。
文章然后:进行实验展示结果。
文章然后:进行讨论。1.作者模型可以克服Skip-gram方法难以优化多步流程限制同时时间以及效果表现更好。2.未来工作1)解决内存:作者证明对于无法使用GPU大型图,用CPU是可行的。用小批量随机梯度可以缓解这个问题。但是生成小批量时应该考虑GCN的层数,对于非常大且密集连接的图可能需要进一步地近似。2)不支持有向图,但是有解决方法的(具体是什么我没看懂)3)考虑一个权衡参数 λ \lambda λ可能会有益。具体来说就是修好生成自循环图时用的 λ \lambda λ。即 A ~ = A + λ I \tilde{A}=A+\lambda I A~=A+λI
文章然后:得到结论。
文章最后:引用以及其他工作。1)WL-1算法2)深层的GCN。太深不好。
PS:以上仅是我的理解,我的理解可能不对。然后关于这个GCN以及WL算法,有两篇文章研究了它们,还是挺有趣的。我在百度网盘上传了这连篇文章。超链。提取码8888。

代码实操:

导入对应的库

import matplotlib.pyplot as plt
import networkx as nx

定义可视化的函数

def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G,seed=42),with_labels=False,node_color=color,cmap="Set2")
    plt.show()
#可视化图网络
def visualize_embedding(h,color,epoch=None,loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h=h.detach().cpu().numpy()
    plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}',fontsize=16)
    plt.show()

导入对应的库:数据集1

from torch_geometric.datasets import KarateClub
dataset=KarateClub()

KarateClub数据集简单说明:34个人的社交网络,如果在俱乐部之外两人认识连一条边。然后由于俱乐部的内部冲突,人们选择站队所以分成两派。
打印数据集的信息

print(len(dataset),dataset.num_features,dataset.num_classes)
#输出:1 34 4

简单说明:num_features:33加上1。33指,这个节点与其他的33个节点是否有边,有边为1,无边为0。1是指度。num_classer:按理应该为2,但是官方做了修改,所以为4。

data=dataset[0]
#具体到确定的图上
print(data.num_nodes,data.num_edges,data,data.train_mask.sum().item())
#输出:34 156 Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34]) 4
print(data.has_isolated_nodes(),data.has_self_loops(),data.is_undirected())
#输出:False False True
edge_index=data.edge_index
print(edge_index.t())
#输出:不表

导入对应的库

from torch_geometric.utils import to_networkx

可视化图网络

G=to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)

在这里插入图片描述
搭建模型GCN的框架

from torch_geometric.nn import GCNConv
from torch.nn import Linear
import torch
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=GCNConv(dataset.num_features,4)
        self.conv2=GCNConv(4,4)
        self.conv3=GCNConv(4,2)
        self.classifier=Linear(2,dataset.num_classes)
    def forward(self,x,edge_index):
        h=self.conv1(x,edge_index)
        h=h.tanh()
        h=self.conv2(h,edge_index)
        h=h.tanh()
        h=self.conv3(h,edge_index)
        h=h.tanh()
        out=self.classifier(h)
        return out,h
model=GCN()
print(model)
#输出
#GCN(
#  (conv1): GCNConv(34, 4)
#  (conv2): GCNConv(4, 4)
#  (conv3): GCNConv(4, 2)
#  (classifier): Linear(in_features=2, out_features=4, bias=True)
#)

简单说明: X v ( l + 1 ) = W ( l + 1 ) ∑ w ∈ N ( v ) ∪ { v } 1 c w , v ⋅ X w ( l ) X_{v}^{(l+1)}=W^{(l+1)}\sum_{w \in N(v)\cup{\{v\}}}\frac{1}{c_{w,v}}\cdot X_{w}^{(l)} Xv(l+1)=W(l+1)wN(v){v}cw,v1Xw(l)
可视化图嵌入(这里只有正向传播)

model=GCN()
_,h=model(data.x,data.edge_index)
visualize_embedding(h,color=data.y)

在这里插入图片描述
进行训练得出结果

model=GCN()
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
def train(data):
    optimizer.zero_grad()
    out,h=model(data.x,data.edge_index)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss,h
for epoch in range(401):
    loss,h=train(data)
    if epoch==400:
        visualize_embedding(h,color=data.y,epoch=epoch,loss=loss)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/482185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaWeb05(删除增加修改功能实现连接数据库)

目录 一.实现删除功能 1.1 url如何传参? xx.do?参数参数值&参数名参数值 1.2 servlet如何拿对应值? //根据参数名拿到对应的参数值 String str req.getParameter("参数名") 1.3 如何询问? οnclick"return con…

区位码-GB2312

01-09区为特殊符号 10-15区为用户自定义符号区(未编码) 16-55区为一级汉字,按拼音排序 56-87区为二级汉字,按部首/笔画排序 88-94区为用户自定义汉字区(未编码) 特殊符号 区号:01 各类符号 0 1 2 3 4 …

I/O多路转接——epoll服务器代码编写

目录 一、poll​ 二、epoll 1.epoll 2.epoll的函数接口 ①epoll_create ②epoll_ctl ③epoll_wait 3.操作原理 三、epoll服务器编写 1.日志打印 2.TCP服务器 3.Epoll ①雏形 ②InitEpollServer 与 RunServer ③HandlerEvent 四、Epoll的工作模式 1.LT模式与ET…

第二十一章 光源

光源是每个场景必不可少的部分,光源除了能够照亮场景之外,还可以产生阴影效果。 Unity中分为四种光源类型: 1. 方向光:Directional Light 用于模拟太阳光,方向光任何地方都能照射到。 2. 点光源:Point L…

JavaWeb-Servlet【内含思维导图】

目录 Servlet思维导图​编辑 1.什么是Servlet 2.Servelt概述 3.Servlet-Quickstart Your Project 3.1创建一个Web项目,导入Servlet依赖 3.1.1 选择Servlet导入依赖 3.1.2 导入Servlet依赖 3.2 在Web项目,定义类,实现Servlet接口…

Java8新特性-流式操作

在Java8中提供了新特性—流式操作,通过流式操作可以帮助我们对数据更快速的进行一些过滤、排序、去重、最大、最小等等操作并且内置了并行流将流划分成多个线程进行并行执行,提供更高效、快速的执行能力。接下来我们一起看看Java8为我们新增了哪些便捷呢…

Python基础合集 练习19(类与对象3(多态))

多态 class Horse: def init(self, name) -> None: self.name name def fature(self):return 父亲-----马的名字: {0}.format(self.name)def mover(self):print(马儿跑起来很潇洒)class Monkey: def init(self, name) -> None: self.name name def fature(self):ret…

《用于准确连续非侵入性血压监测的心跳内生物标志物》阅读笔记

目录 0 基础知识 1 论文摘要 2 论文十问 3 实验结果 4 论文亮点与不足之处 5 与其他研究的比较 6 实际应用与影响 7 个人思考与启示 参考文献 0 基础知识 非侵入性是指在进行医学检查或治疗时,不需要切开皮肤或穿刺体内组织,而是通过外部手段进…

【VQGAN论文精读】Taming Transformers for High-Resolution Image Synthesis

【VQGAN论文精读】Taming Transformers for High-Resolution Image Synthesis 0、前言Abstract1. Introduction2. Related Work3. Approach3.1. Learning an Effective Codebook of Image Constituents for Use in Transformers学习一个有效的图像成分的Codebook为了在Transfor…

高性能:负载均衡

目录 什么是负载均衡 负载均衡分类 服务端负载均衡 服务端负载均衡——软硬件分类 服务端负载均衡——OSI模型分类 客户端负载均衡 负载均衡常见算法 七层负载均衡做法 DNS解析 反向代理 什么是负载均衡 将用户请求分摊(分流) 到不同的服务器上…

小记Java调用C++开发的动态链接库(DLL)

一、背景 五一快乐吖!死肥宅正趁着五一这段时间,努力提升自己! 最近使用Java拦截Windows系统中一些默认事件时,发现了一些瓶颈。 我用Java操作浏览器、用Java最小化其他应用窗口,但是我发现这个操作,他都…

【Unity-UGUI控件全面解析】| InputField 输入框组件详解

🎬【Unity-UGUI控件全面解析】| InputField 输入框组件详解一、组件介绍二、组件属性面板2.1 Content Type(内容类型)三、代码操作组件四、组件常用方法示例4.1 代码限制输入字符4.2 校验文本输入格式4.3 校验输入文本长度💯总结🎬 博客主页:https://xiaoy.blog.csdn.…

话说【永恒之塔sf】里面最有前途的职业:商人

如果有人问我永恒之塔里面什么职业最有前途!那我告诉你就是商人! 做一个NB商人比拥有一身牛b装备要更有成就感。 在老区由于进入的比较晚,所以最后随了大流被淹死在千万基纳中。为了证明商人在永恒之塔是钱途无量的,我转到了新区—…

快解析动态域名解析,实现外网访问内网数据库

今天跟大家分享一下如何借助快解析动态域名解析,在两种特定网络环境下,实现外网访问内网mysql数据库。 第1种网络环境:路由器分配的是动态公网IP,且有路由器登录管理权限。如何实现外网访问内网mysql数据库? 针对这种…

IDEA2022版教程上()

0、前景摘要 0.1 概览 0.2 套课程适用人群 初学Java语言,熟悉了记事本、EditPlus、NotePad或Sublime Text3等简易开发工具的Java初学者熟练使用其他Java集成开发环境(IDE),需要转向IDEA工具的Java工程师们关注IDEA各方面特性的J…

Hadoop大数据分析技术(伪分布式搭建)

一.安装JDK和配置SSH免密登录 (1)准备软件 (2)解压压缩包 tar -zxvf jdk-8u221-linux-x64.tar.gz (3)在此处我们配置系统环境变量,使用命令: vim /etc/profile (4&#x…

Python入门教程(高级版)

Python用了好几年了,但似乎一直没 “系统入门” 过(o(╯□╰)o)。今年(2023年)趁着五一假期,我做了一次相对完整的 “入门” ——本文是这次学习历程的详细记录。 目录 1 Python基础1.1 Python1.1.1 认识Py…

Oracle VM VirtualBox安装centos7步骤 for win10

目录 1.安装VirtualBox 2.安装vagrant 3.安装centos7 4.查看网络与百度和物理机连通情况 5.设置IP 1.安装VirtualBox 下载的链接:Downloads – Oracle VM VirtualBox 2.安装vagrant 根据自己的操作系统选择对应的版本。 Install | Vagrant | HashiCorp Developer 我的P…

asp.net+sqlserver旅游网站zjy99A2

1.系统登录:系统登录是用户访问系统的路口,设计了系统登录界面,包括用户名、密码和验证码,然后对登录进来的用户判断身份信息,判断是管理员用户还是普通用户。 2.系统用户管理:不管是…

redis使用总结

目录 redis安装与登录redis 持久化RDB(Redis DataBase)AOF(Append Only File)RDB-AOF混合持久纯缓存模式 redis 的 keyredis 的数据类型和常见应用场景StringListHashMapSet集合ZSet有序集合bitmap位图HyperLogLog基数统计GEO 地理空间Stream 流bitfiled redis 事务事务的正常执…