GNN实战——KarateClub数据集

news2025/1/20 3:40:21

GNN:graph neural network 图神经网络,是⼀种连接模型,通过⽹络中节点之间的信息传递(message passing)的⽅式来获取图中的依存关系(dependence of graph),GNN通过从节点任意深度的邻居来更新该节点状态,这个状态能够表示状态信息。由于 GNN 在图节点之间强大的建模功能,使得与图分析相关的研究领域取得了突破。图神经网络(GNN)是一类基于深度学习的处理图域信息的方法。由于其较好的性能和可解释性,现已被广泛应用到各个领域。涵盖了推荐系统、组合优化、计算机视觉、物理 / 化学以及药物发现等领域。

一、数据集介绍

数据集中只有一张图。
在这里插入图片描述
该图描述了一个空手道俱乐部会员的社交关系,以34名会员作为节点,如果两位会员在俱乐部之外仍保持社交关系,则在节点间增加一条边。
每个节点具有一个34维的特征向量,一共有78条边。在收集数据的过程中,管理人员 John A 和 教练 Mr. Hi(化名)之间产生了冲突,会员们选择了站队,一半会员跟随 Mr. Hi 成立了新俱乐部,剩下一半会员找了新教练或退出了俱乐部。通过收集到的图数据,Zachary 进行了分类,除1名会员外都分类正确。将原图进行抽象可得到下图:
在这里插入图片描述

二、GNN实战

1. 导入所需的包

%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
# KarateClub是torch_geometric内置的数据集
from torch_geometric.datasets import KarateClub

注:torch_geometric库的安装不能直接pip install,具体的安装方法可以参考之前的blog:https://blog.csdn.net/m0_51339444/article/details/128611141

2. 定义可视化函数

def visualize_graph(G, color):
    plt.figure(figsize=(5,5))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap='Set2')
    plt.show()

3. 导入并查看KarateClub数据集

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print(f'Number of the graphs: {len(dataset)}')
print(f'Number of the features: {dataset.num_features}')
print(f'Number of the classes: {dataset.num_classes}')

在这里插入图片描述

data = dataset[0]
print(data)

在这里插入图片描述

# edge_index是邻接矩阵,表示每两个点之间的关联
edge_index = data.edge_index
# 打印出每个点分别和谁有关系
print(edge_index.t())

这里对上一个运行结果解释一下,这是整个数据集的全部生态环境了,x是特征,就是一个一个的点,第一个34表示一共有34个点,即34个样本,第二个34表示每个样本是34维的向量(即34个特征);edge_index是邻接矩阵,表示每两个点之间的关联,第一个元素一定是2,表示两个点之间的边,156表示一共有156个关系,即156条边;train_mask记录了34个数据中有标签与否,有标签是True,没有标签是False。

4. 使用networkx进行可视化展示

# 将处理好(对应的标准格式)的data传入to_networkx,再传入visualize_graph(最上面自己定义的)绘图
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

在这里插入图片描述

5. 搭建网络

这里会使用到torch_geometric的方法(封装好的函数),有疑问的地方可以去官网查询API,这里拍个链接:https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)  # 两个参数分别为输入特征和输出特征
        self.conv2 = GCNConv(4,4)
        self.conv3 = GCNConv(4,2)
        self.classifier = Linear(2, dataset.num_classes)   
    # x是特征,没经过一层后数据都是不断变化的,即x 变成h,h不断变成新的h,而edge_index邻接矩阵是一直不变的,谁和谁之间有联系是不变的    
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index) # 输入特征和邻接矩阵
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()        
        # 分类层
        out = self.classifier(h)        
        # out是输出,h是中间结果(conv3的输出)(一个2维的向量(方便绘图打印))
        return out, h
            
model = GCN()
print(model)

由于数据集比较小,因此搭建小网络即可,网络参数如下:
在这里插入图片描述

6. 进行embedding操作并可视化

def visualize_embedding(h, color, epoch=None, loss=None):
    plt.figure(figsize=(5,5))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap='Set2')
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()
model = GCN()
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
visualize_embedding(h, color=data.y)

在这里插入图片描述

7. 训练模型

import time 

model = GCN()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index) 
    # 这里体现了半监督的思想,只拿有标签的计算损失,没有标签的不参与计算
    loss = loss_function(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)

在这里插入图片描述
在这里插入图片描述
可以看到,随着epoch的增大,损失函数逐渐收敛,可视化结果逐渐将三种颜色分成了三个类别(类似聚类的结果)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/161457.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux网络编程 第四天

目录 学习目标 多路IO-poll 多路IO-epoll 进阶epoll 用实验验证LT和ET模式 epoll反应堆 学习目标 1 了解poll函数 2 熟练使用epoll多路IO模型 3 了解epoll ET/LT触发模式并实现 4 理解epoll边缘非阻塞模式并实现 5 了解epoll反应堆模型设计思想 6 能看懂epoll反应堆模型的…

《C++程序设计原理与实践》笔记 第10章 输入/输出流

在本章和下一章中,我们将介绍C标准库中用于处理来自各种源的输入和输出的功能:I/O流。本章关注基本模型:如何读写单个值,以及如何打开和读写整个文件。下一章将介绍具体细节。 10.1 输入和输出 如果没有数据,计算就毫…

【正点原子FPGA连载】第十三章Linux内核移植 摘自【正点原子】DFZU2EG_4EV MPSoC之嵌入式Linux开发指南

1)实验平台:正点原子MPSoC开发板 2)平台购买地址:https://detail.tmall.com/item.htm?id692450874670 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html 第十三章Linux内…

2023版软件测试学习路线图(超详细自学路线)

送福利了!超详细的软件测试学习路线图来啦,2023版是首发哟!软件测试学习路线图分为9个阶段,包含:软件测试环境配置和管理-->软件测试数据管理与数据库测试-->web前端测试技术-->通用软件测试技术-->Python…

回顾2022! 链上NFT精彩项目大盘点

过去一年,WEB3和元宇宙无疑吸引了一大波关注度和热度。不少知名品牌如耐克、GUCCI、百事可乐、星巴克、麦当劳等都纷纷加入这波浪潮,通过推出NFT、数字商品等,来尝试WEB3机制,进而塑造更好的用户消费体验和参与度。NFT兼具身份、功…

springboot,vue二手交易平台

开发工具:IDEA服务器:Tomcat9.0, jdk1.8项目构建:maven数据库:mysql5.7系统用户前台和管理后台两部分,项目采用前后端分离前端技术:vue elementUI服务端技术:springbootmybatis项目功…

0基础快速掌握正则表达式

背景 在日常开发中,我们经常会遇到使用正则表达式的场景,比如一些常见的表单校验,会让你匹配用户输入的手机号或者身份信息是否规范,这就可以用正则表达式去匹配。相信大多数人在碰到这种场景的时候都是直接去网上找,…

在 2023 ETH Denver 与 Cartesi 一起建设

我们非常高兴的加入了 2023年ETHDenver,参加了BUIDLathon 赛道和现场研讨会等活动。作为规模最大、持续时间最长的ETH 活动之一,我们将向热衷于为全球区块链生态系统做出贡献的新开发者社区分享 Cartesi 技术。你想在2023年#BUIDL 做一些有趣有意义的事情…

基于springboot的景区旅游信息管理系统(源代码+数据库)

基于springboot的景区旅游信息管理系统(源代码数据库) 一、系统介绍 本项目分为管理员与普通用户两种角色 用户登录 前台功能:旅游路线、旅游景点、旅游酒店、旅游车票、旅游保险、旅游策略管理员登录 后台功能:用户管理、旅游路线管理、旅游景点管理…

Codeforces Round #843 (Div. 2)(A~C,E)

A1/A2. Gardener and the Capybaras (easy version)三个字符串,按照顺序连在一起,三个字符串满足第二个字符串大于等于第一个和第三个,或者第二个字符串小于等于第一个和第三个,输出满足情况的三个字符串。思路:对于长…

ubuntu18.04系统下挂载新的机械硬盘

ubuntu18.04系统下挂载新的机械硬盘1.显示硬盘以及所属分区情况sudo fdisk -lDisk /dev/sda doesnt contain a valid partition table硬盘分区 对机械硬盘进行操作 sudo fdisk /dev/sda下图表示的是具体流程截图: The partition table has been altered!硬盘格式…

AWS RDS开启审计日志

问题 需要对AWS的RDS开启相关日志。先检查RDS是否开启日志,如下图: 选中一个数据库实例,查看到只开启了数据库的错误日志。但是,我们需要开启其他类型的审计日志。下面开始怎么样开启其他类型日志,来启用高级审计模…

corrosion 靶机(ffuf模糊测试,命令执行)

环境准备 靶机链接:百度网盘 请输入提取码 提取码:c2j6 虚拟机网络链接模式:桥接模式 攻击机系统:kali linux 2022.03 信息收集 1.探测目标靶机开放端口和服务情况 2.用gobuster扫描目录,并访问 gobuster dir -…

手把手编译FFmpeg

支持centos8.6、ubuntu20.04 export 建议开始之前,弄一台干净的机子,或者系统恢复到出厂设置,否则容易出问题 然后设置动态库默认加载目录(注意/usr/local/lib不是系统默认的路径,/lib和/usr/lib才是) …

jsp库存管理管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 库存管理管理系统 是一套完善的系统源码,对理解JSP java serlvet MVC编程开发语言有帮助,系统具有完整的源代码和数据库,以及相应配套的设计文档,系统主要采用B/S 模式开发。 通过本系统建设&#xff0c…

ArcGIS基础实验操作100例--实验97计算河道方向坡度

本实验专栏参考自汤国安教授《地理信息系统基础实验操作100例》一书 实验平台:ArcGIS 10.6 实验数据:请访问实验1(传送门) 空间分析篇--实验97 计算河道方向坡度 目录 一、实验背景 二、实验数据 三、实验步骤 (1&…

内存管理-模板初阶理解-string类的模拟实现

文章目录1. 内存管理operator new和operator delete面试题:malloc、free和new、delete的区别2. 内存泄漏1. 内存泄漏:2. 内存泄漏危害:3.堆内存泄漏4.系统资源泄漏3. 模板初阶函数模板类模板:模板运行时不检查数据类型&#xff0c…

黑马编程资源最新最全全清单:速来收藏~

今年是黑马坚持免费分享视频教程的第16年,每年到了这个时候,「成绩单」也不会缺席,不仅是对过往的回顾,更是对那些选择跟着黑马持续学习的小伙伴们的一种激励。 黑马视频教程2022年速报 截至年底,黑马程序员 B 站累计…

ArcGIS基础实验操作100例--实验96创建地形剖面图

本实验专栏参考自汤国安教授《地理信息系统基础实验操作100例》一书 实验平台:ArcGIS 10.6 实验数据:请访问实验1(传送门) 空间分析篇--实验96 创建地形剖面图 目录 一、实验背景 二、实验数据 三、实验步骤 (1&am…

头戴式耳机跑步方便吗、公认最好的跑步耳机排行榜

平时,我们总能看到许多运动健身的人群,在锻炼时都佩戴着耳机。但运动耳机的选择,同样是大有学问的。如果佩戴传统的真无线蓝牙耳机,有可能出现佩戴不稳、耳道肿胀等问题,影响运动体验。所以今天我们特意给大家带来几款…