图神经网络 pytorch GCN torch_geometric KarateClub 数据集

news2024/12/25 13:21:20

图神经网络

安装Pyg

首先安装torch_geometric需要安装pytorch然后查看一下自己电脑Pytorch的版本

import torch
print(torch.__version__)
#1.12.0+cu113

然后进入官网文档网站

链接: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
安装自己的版本选择安装命令,我python用的稳定的3.8版本。如果安装失败可以考虑降低python的版本

在这里插入图片描述
因为我之前安装过所以显示如下
在这里插入图片描述

图信号数据集初入门

本次入门选用Karateclub数据集
在这里插入图片描述
这个数据集讲诉的是一个空手道俱乐部之间人和人的关系,每个节点代表一个人说俱乐部的两个教练吵架了,要每一个节点所代表的人进行站队通过图信号预测。
首先读取数据集

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')
#Number of graphs:1
#Number of features:34
#Number of classes:4

可以看到只有一张图每个节点有34个特征,每个特征代表的应该是每一个会员的信息,分成四类我们可以暂时理解成跟了教练A的,跟了教练B的,换了一个新教练的,和退出俱乐部的这四类。

然后我们将这个图打出来进行观察可以看到节点是分成了四类

import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx

dataset = KarateClub()
print(f'Dataset{dataset}')
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')


def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G,pos = nx.spring_layout(G,seed = 42),with_labels=False,node_color=color,cmap="Set2")
    plt.savefig("net.jpg")
    plt.show()


data = dataset[0]
print(data)
G = to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)

在这里插入图片描述
然后我们观察一个图的数据可以观察到一共有34个节点每个节点有34个数据一共有156条边

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
data = dataset[0]
print(data)
#Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

训练代码

接下来使用pyg进行训练

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
import matplotlib.pyplot as plt

dataset = KarateClub()
data = dataset[0]


def visualize_embedding(h,color,epoch=None,loss=None):
    global i
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f"Epoch:{epoch},Loss: {loss.item():.4f}",fontsize = 16,)
    plt.show()

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features,4)
        self.conv2 = GCNConv(4,4)
        self.conv3 = GCNConv(4,2)
        self.classifier = Linear(2,dataset.num_classes)

    def forward(self,x,edge_index):
        h = self.conv1(x,edge_index)
        h = h.tanh()
        h = self.conv2(h,edge_index)
        h = h.tanh()
        h = self.conv3(h,edge_index)
        h = h.tanh()
        out = self.classifier(h)
        # return F.softmax(out,dim=1),h
        return out,h

model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

loss_list = []

def train(data):
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    loss_list.append(loss.item())
    optimizer.step()
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 1:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)

plt.plot(loss_list)
plt.show()

损失曲线如下
在这里插入图片描述

训练集可视化动图如下
![在这里插入图片描述](https://img-blog.csdnimg.cn/e1cff7bf8ba0498cb8abd6f61a2f83a6.gif在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/389334.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【小破站下载工具】Python tkinter 实现网站下载工具,所有数据一键获取

目录前言开发环境本次项目案例步骤先展示下完成品的效果界面导入模块先创建个窗口功能按键主要功能代码编写功能一功能二功能三前言 最近很多同学想问我,怎么把几个代码的功能集合到一起? 很简单,写一个界面就行了,想要哪个代码…

黑马程序员7

算数运算符重载 运算符重载概念:对已有的运算符重新进行定义,赋予其另一种功能,以适应不同的数据类型 加号运算符 通过自己写函数,实现两个对象相加属性后返回新的对象 两种方式重载 成员函数方式重载 全局函数重载 上来 perso…

Go 内存分布

Go内存分布方式在C中,每个值在内存中只占据一个内存块(一段连续内存);但是,一些Go类型的值可能占据多个内存块。以后,我们称一个Go值分布在不同内存块上的部分为此值的各个值部(value part&…

网络安全平台测试赛 easyphp(phar脏数据处理)

昨天的比赛,14.00-17.00.时间有点紧张,比赛期间没拿下来这道 😭非常痛苦,很顺畅的思路 一步步想下来,卡在最后一步末尾脏数据处理了,最后时间到了 没打通,还需多练 这里本地复现一下&#xff1…

linux 进程及调度基础知识

引用Linux进程管理专题Linux进程管理与调度-之-目录导航Linux下0号进程的前世(init_task进程)今生(idle进程)----Linux进程的管理与调度(五)蜗窝科技-进程管理郭健: Linux进程调度技术的前世今生之“前世”郭健: Linux进程调度技术…

1.7 古典概型问题类型一——随机取数问题

(1)我的答案:一、信息首先7个数字全不相同二、分析七个数字全不相同意味着每次取出来的数都不一样,然后每次取出后选择少一种,为简单排列不含10和1这意味从8个数里面选且可以重复为重复排列3.10恰好出现两次隐含着两个问题,第一&a…

备考考研2数学

进度说明,开始形成自己的复习进度说明! 14:21 2023年2月28日星期二 武忠祥数学 截止目前,看完了01.高数基础01 14:21 2023年2月28日星期二 开始看02. 现在15:04 2023年2月28日星期二, 因为这2天的百度网盘不能进行解析了&…

初识HTML技术

文章目录一、为什么学习前端?二、第一个HTML文件VSCode三. HTML元素四. HTML页面一、为什么学习前端? 我们作为一个后端程序员,为什么还要学习前端,因为我们的终极目的是实现web开发,搭建网站,网站 前端 后端 比如我们随便…

最近几篇较好论文实现代码(附源代码下载)

《Towards Layer-wise Image Vectorization》(CVPR 2022) GitHub: github.com/ma-xu/LIVEInstallationWe suggest users to use the conda for creating new python environment.Requirement: 5.0<GCC<6.0; nvcc >10.0.git clone gitgithub.com:ma-xu/LIVE.gitcd LIVE…

一步一步学会给Fritzing添加元器件-丰富你的器件库

文章目录1、获取元器件文件2、单个添加元器件3、批量加入&#xff08;1&#xff09;、通过别人发布的bin文件加载&#xff08;2&#xff09;、终极大招&#xff08;拖&#xff09;4、制作自己器件文章出处&#xff1a; https://blog.csdn.net/haigear/article/details/12931545…

【C++】类和对象——六大默认成员函数

&#x1f3d6;️作者&#xff1a;malloc不出对象 ⛺专栏&#xff1a;C的学习之路 &#x1f466;个人简介&#xff1a;一名双非本科院校大二在读的科班编程菜鸟&#xff0c;努力编程只为赶上各位大佬的步伐&#x1f648;&#x1f648; 目录前言一、类的6个默认成员函数二、构造…

错误异常捕获

1、React中错误异常捕获 在 React 中&#xff0c;可以通过 Error Boundaries&#xff08;错误边界&#xff09;来捕获错误异常。Error Boundaries 是一种 React 组件&#xff0c;它可以在其子组件树的渲染期间捕获 JavaScript 异常&#xff0c;并且可以渲染出备用 UI。React 提…

802.11 service服务类型

802.11 serviceservice定义service分类按照模块分为两类按照功能分为六类数据传输相关服务分布式服务DS&#xff08;Distribution Service&#xff09;整合服务IS&#xff08;Integration Service&#xff09;关联&#xff08;association&#xff09;重关联&#xff08;reasso…

RAD 11.3 delphi和C++改进后新增、废弃及优化的功能

RAD 11.3 delphi和C改进后新增和废弃的功能 目录 RAD 11.3 delphi和C改进后新增和废弃的功能 一、版本RAD 11.3 delphi和C改进后新增功能 1、官方视频位置&#xff1a; 2、官方文档的链接位置&#xff1a; 二、版本RAD 11.3 delphi和C改进后废弃的功能 2.1、编译器不再使…

Eureka注册中心和Nacos注册中心详解以及Nacos与Eureka有什么区别?

目录&#xff1a;前言Eureka注册中心Nacos注册中心Nacos与Eureka有什么区别&#xff1f;前言提供接口给其它微服务调用的微服务叫做服务提供者&#xff0c;而调用其它微服务提供的接口的微服务则是服务消费者。如果服务A调用了服务B&#xff0c;而服务B又调用了服务C&#xff0…

【iOS】设置背景渐变色

drawRect函数 主要负责iOS的绘图操作&#xff0c;程序会自动调用此方法进行绘图。我在这个函数中绘制渐变背景色。 方法定义&#xff1a; -(void)drawRect:(CGRect)rect; 重写此方法&#xff0c;执行重绘任务-(void)setNeedsDisplay; 标记为需要重绘&#xff0c;异步调用dra…

Mysql开发

Mysql开发 可以使用MySQL直接存储文件吗&#xff1f; 可以使用 BLOB (binary large object)&#xff0c;用来存储二进制大对象的字段类型。 TinyBlob 255 值的长度加上用于记录长度的1个字节(8位) Blob 65K值的长度加上用于记录长度的2个字节(16位) MediumBlob 16M值的长度加…

vue-v-for列表渲染中key的作用

1.虚拟DOM中key的作用: key是点拟DON对象的标识&#xff0c;当状态中的数据发生变化时&#xff0c;Vue会根据【新数据】生成【新的虚拟DOM】,随后Vue进行【新虚拟DOM】与【旧虚拟DOM】的差异比较&#xff0c;比较规则如下 2.对比规则: 旧虚拟DOM中找到了与新虚拟DOM相同的ke…

【NLP相关】ChatGPT的前世今生:GPT模型的原理、研究进展和案例

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️&#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…

【二分查找】分巧克力、机器人跳跃、数的范围

Halo&#xff0c;这里是Ppeua。平时主要更新C语言&#xff0c;C&#xff0c;数据结构算法......感兴趣就关注我吧&#xff01;你定不会失望。 &#x1f308;个人主页&#xff1a;主页链接 &#x1f308;算法专栏&#xff1a;专栏链接 我会一直往里填充内容哒&#xff01; &…