图神经网络:在自定义数据集上动手实现图神经网络

news2024/9/30 5:26:00

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。

文章目录

  • 自定义数据集动手实现图神经网络
    • 自定义数据集
    • 训验测集拆分,创建Data的数据结构,观察Data的基本信息,可视化图网络
    • 搭建模型,训练前的准备,训练模型得出结果并可视化
  • 结果分析
  • 完整代码
  • 后记

自定义数据集动手实现图神经网络

自定义数据集

导库

from random import randint,sample

数据背景描述:这段代码生成北京化工大学三个学院(国际教育学院(100个学生),数理学院(300个学生),信息学院(500个学生))共计900个学生社交网络。每个学生,学院内部随机认识随机个人;学院外部随机认识随机个人,从而搭建边的关系。具体如何随机只能请看代码,鉴于篇幅原因不再过多赘述。

class dataset:

    def __init__(self):
        self.data_x=[];self.data_y=[]
        for i in range(100):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(0)
        for i in range(100,400):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(1)
        for i in range(400,900):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(2)
        self.data_edge=[[],[]]
        lt1=[i for i in range(100)];lt2=[i for i in range(100,400)];lt3=[i for i in range(400,900)]
        lt4=lt2+lt3;lt5=lt1+lt3;lt6=lt1+lt2;lt7=lt1+lt2+lt3
        for i in range(100):
            j=randint(30,70)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[:i]+lt7[i+1:100],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt4,j))
        for i in range(100,400):
            j=randint(50,100)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[100:i]+lt7[i+1:400],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt5,j))
        for i in range(400,900):
            j=randint(75,125)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[400:i]+lt7[i+1:],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt6,j))
    
    def x(self):
        return self.data_x
    
    def edge(self):
        return self.data_edge
    
    def y(self):
        return self.data_y

训验测集拆分,创建Data的数据结构,观察Data的基本信息,可视化图网络

导库

from torch_geometric.data import Data
import torch

训验测集拆分:对数据集随机拆分。train_val_test_split中,a代表训练集比例,b代表验证集比例,所以测试集比例为1-a-b啦。c代表了至少保证训练集中每个类别至少c个,否则一直循环随机去构建训练集,直到满足每个类别至少c个。最后返回三个长度为900的布尔列表吧。

def caozuo(lt):
    lt_new=[False for i in range(900)]
    for i in lt:
        lt_new[i]=True
    return lt_new

def T_or_F(lt,d):
    a,b,c=0,0,0
    for i in lt:
        if i<100:
            a+=1
        elif i<400:
            b+=1
        elif i<900:
            c+=1
        if a>=d and b>=d and c>=d:
            return True
    return False

def train_val_test_split(a,b,c):
    while True:
        lt=sample([i for i in range(900)],900)
        train_index=lt[:int(900*a)];val_index=lt[int(900*a):int(900*a+900*b)];test_index=lt[int(900*a+900*b):]
        if T_or_F(train_index,c)==False:
            continue
        else:
            return caozuo(train_index),caozuo(val_index),caozuo(test_index)
data=dataset();x=torch.Tensor(data.x());edge=torch.LongTensor(data.edge());y=torch.LongTensor(data.y());lt1,lt2,lt3=train_val_test_split(0.1,0.3,5)

创建Data的数据结构,观察Data的基本信息:简单说明一下 x x x 吧 , x : 900 × 900 x:900 \times 900 x:900×900 单位阵,表示第i个节点在第i维度。这个特征矩阵不算是通常意义上特征矩阵,信息有限。

data=Data(x=x,edge_index=edge,y=y,train_mask=torch.BoolTensor(lt1),val_mask=torch.BoolTensor(lt2),test_mask=torch.BoolTensor(lt3));print(data)
#输出:Data(x=[900, 900], edge_index=[2, 81246], y=[900], train_mask=[900], val_mask=[900], test_mask=[900])

导库

from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx

可视化图网络

def visualize_graph(G,color):
    plt.style.use("seaborn");plt.rcParams['font.family']='SimHei';plt.rcParams['font.sans-serif']=['SimHei']
    plt.figure(figsize=(16,9));plt.title("国教,数理,信息学院学生社交网络",size=20);plt.xticks([]);plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_size=3,node_color=color,width=0.01,edge_color="black",cmap="Set2")
    legend_dict={"red":"国教","blue": "数理","green":"信息"};plt.legend(handles=[plt.Line2D([],[],color=c,label=l,linestyle='None',marker='o') for c, l in zip(legend_dict.keys(),legend_dict.values())],loc='upper right')
    plt.savefig("figure",dpi=1000)
G=to_networkx(data,to_undirected=True);visualize_graph(G,color=["red"]*100+["blue"]*300+["green"]*500)

在这里插入图片描述

搭建模型,训练前的准备,训练模型得出结果并可视化

导库

from torch_geometric.nn import GCNConv
import torch.nn.functional as F

搭建模型

class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
        self.conv1=GCNConv(900,hidden_channels)
        self.conv2=GCNConv(hidden_channels,3)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x

训练前的准备

model=GCN(9);optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()

训练模型得出结果并可视化

def train():
    model.train()
    optimizer.zero_grad()
    out=model(data.x,data.edge_index)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

def test():
    model.eval()
    out=model(data.x,data.edge_index)
    pred=out.argmax(dim=1)
    test_correct=pred[data.test_mask]==data.y[data.test_mask]
    test_acc=int(test_correct.sum())/int(data.test_mask.sum())
    return test_acc
for epoch in range(1,101):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#这里就只展示测试集的结果
#输出:1.00

损失函数下降曲线
在这里插入图片描述

T-SNE可视化
模型未训练只经过一次正向传播
在这里插入图片描述

训练好后
在这里插入图片描述

结果分析

模型在测试集表现很好为1,那么效果为什么会这么好呢。数据集就很好,创建数据集的时候忽略了很多的现实因素,比如,有的同学它是社交恐怖分子认识全校80%的人,或者有的同学刚刚转院等等现实情况,所以这里创建的数据集十分理想,加之模型本身很好,最终导致了很好的结果。我们可以在数据集构建几个转院同学看看结果但是我就不做了吧。

完整代码

from torch_geometric.utils import to_networkx
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from random import randint,sample
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch.nn.functional as F
import networkx as nx
import numpy as np
import torch

def train():
    model.train()
    optimizer.zero_grad()
    out=model(data.x,data.edge_index)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

def test():
    model.eval()
    out=model(data.x,data.edge_index)
    pred=out.argmax(dim=1)
    test_correct=pred[data.test_mask]==data.y[data.test_mask]
    test_acc=int(test_correct.sum())/int(data.test_mask.sum())
    return test_acc

def visualize(h,color,s):
    z=TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.style.use("seaborn");plt.figure(figsize=(16,9));plt.xticks([]);plt.yticks([])
    plt.scatter(z[:,0],z[:,1],s=3,c=color)
    plt.savefig("figure"+str(s),dpi=1000)

class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
        self.conv1=GCNConv(900,hidden_channels)
        self.conv2=GCNConv(hidden_channels,3)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x
    
def visualize_graph(G,color):
    plt.style.use("seaborn");plt.rcParams['font.family']='SimHei';plt.rcParams['font.sans-serif']=['SimHei']
    plt.figure(figsize=(16,9));plt.title("国教,数理,信息学院学生社交网络",size=20);plt.xticks([]);plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_size=3,node_color=color,width=0.01,edge_color="black",cmap="Set2")
    legend_dict={"red":"国教","blue": "数理","green":"信息"};plt.legend(handles=[plt.Line2D([],[],color=c,label=l,linestyle='None',marker='o') for c, l in zip(legend_dict.keys(),legend_dict.values())],loc='upper right')
    plt.savefig("figure1",dpi=1000)

def caozuo(lt):
    lt_new=[False for i in range(900)]
    for i in lt:
        lt_new[i]=True
    return lt_new

def T_or_F(lt,d):
    a,b,c=0,0,0
    for i in lt:
        if i<100:
            a+=1
        elif i<400:
            b+=1
        elif i<900:
            c+=1
        if a>=d and b>=d and c>=d:
            return True
    return False

def train_val_test_split(a,b,c):
    while True:
        lt=sample([i for i in range(900)],900)
        train_index=lt[:int(900*a)];val_index=lt[int(900*a):int(900*a+900*b)];test_index=lt[int(900*a+900*b):]
        if T_or_F(train_index,c)==False:
            continue
        else:
            return caozuo(train_index),caozuo(val_index),caozuo(test_index)

class dataset:

    def __init__(self):
        self.data_x=[];self.data_y=[]
        for i in range(100):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(0)
        for i in range(100,400):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(1)
        for i in range(400,900):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(2)
        self.data_edge=[[],[]]
        lt1=[i for i in range(100)];lt2=[i for i in range(100,400)];lt3=[i for i in range(400,900)]
        lt4=lt2+lt3;lt5=lt1+lt3;lt6=lt1+lt2;lt7=lt1+lt2+lt3
        for i in range(100):
            j=randint(30,70)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[:i]+lt7[i+1:100],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt4,j))
        for i in range(100,400):
            j=randint(50,100)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[100:i]+lt7[i+1:400],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt5,j))
        for i in range(400,900):
            j=randint(75,125)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[400:i]+lt7[i+1:],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt6,j))
    
    def x(self):
        return self.data_x
    
    def edge(self):
        return self.data_edge
    
    def y(self):
        return self.data_y
    
if __name__=="__main__":
    data=dataset();x=torch.Tensor(data.x());edge=torch.LongTensor(data.edge());y=torch.LongTensor(data.y());lt1,lt2,lt3=train_val_test_split(0.1,0.3,5)
    data=Data(x=x,edge_index=edge,y=y,train_mask=torch.BoolTensor(lt1),val_mask=torch.BoolTensor(lt2),test_mask=torch.BoolTensor(lt3));print(data)
    G=to_networkx(data,to_undirected=True);visualize_graph(G,color=["red"]*100+["blue"]*300+["green"]*500)
    model=GCN(9);optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()
    model.eval();out=model(data.x,data.edge_index);visualize(out,color=["red"]*100+["blue"]*300+["green"]*500,s=2)
    lt=[]
    for epoch in range(1,101):
        loss=train();lt.append(loss.item())
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    test_acc=test()
    print(f'Test Accuracy: {test_acc:.4f}')
    plt.style.use("seaborn");plt.rcParams['font.family']='SimHei';plt.rcParams['font.sans-serif']=['SimHei']
    plt.figure(figsize=(16,9));plt.title("损失函数下降曲线",size=20)
    plt.plot([i for i in range(len(lt))],lt,marker="o",ms=3,color="red",linewidth=1,label="交叉熵")
    plt.xticks([i for i in range(0,len(lt),10)],[i+1 for i in range(0,len(lt),10)]);plt.legend();plt.savefig("figure3",dpi=1000)
    model.eval();out=model(data.x,data.edge_index);visualize(out,color=["red"]*100+["blue"]*300+["green"]*500,s=4)

后记

应该会再写一篇文章吧,我们从底层实现PYG,就不导库了吧,然后具体讲讲GCN是怎么操作。然后,我们结束图神经网络在无向图中节点分类这个话题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/485711.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Golang项目实战】用Go写一个学生信息管理系统,真的太酷啦| 保姆级详解,附源码——建议收藏

博主简介&#xff1a;努力学习的大一在校计算机专业学生&#xff0c;热爱学习和创作。目前在学习和分享&#xff1a;数据结构、Go&#xff0c;Java等相关知识。博主主页&#xff1a; 是瑶瑶子啦所属专栏: Go语言核心编程近期目标&#xff1a;写好专栏的每一篇文章 学习了Go的基…

Java 基础进阶篇(十)—— Java集合详细总结

文章目录 一、集合类体系结构二、Collection系列集合2.1 Collection 集合体系2.2 Collection 集合体系特点2.3 Collection 常用API2.4 Collection 集合的遍历方式2.4.1 方式一&#xff1a;迭代器2.4.2 方式二&#xff1a;foreach&#xff08;增强for循环&#xff09;2.4.3 方式…

Python系列之Windows环境安装配置

目录 一、Python安装 1.1下载 1.2 安装 1.3增加环境变量 二、PyCharm安装 2.1 PyCharm简介 2.2 PyCharm下载安装 一、Python安装 1.1下载 python 官网The official home of the Python Programming Languagehttps://www.python.org/downloads/ 1.2 安装 要勾选选项 Ad…

校园兼职平台系统的设计与实现

技术栈&#xff1a; Spring、SpringMVC、MyBatis、HikariCP、fastjson、slf4j、EL和JSTL 系统功能&#xff1a; 前台&#xff1a; &#xff08;1&#xff09;用户注册&#xff1a;这里的用户分为职位发布者和职位应聘者&#xff0c;他们都需要注册本大学生兼职管理系统才能进…

为什么 OpenAI 团队采用 Python 开发他们的后端服务?

Python&#xff0c;年龄可能比很多读者都要大&#xff0c;但是它在更新快速的编程界却一直表现出色&#xff0c;甚至有人把它比作是编程界的《葵花宝典》&#xff0c;只是Python的速成之法相较《葵花宝典》有过之而无不及。 Python简洁&#xff0c;高效的特点&#xff0c;大大…

196页11万字智慧水务平台建设方案

本资料来源公开网络&#xff0c;仅供个人学习&#xff0c;请勿商用&#xff0c;如有侵权请联系删除。 业务需求分析 3.1 主要业务描述 &#xff08;1&#xff09;调度中心主要业务描述 配套工程调度中心为一级调度机构&#xff0c;同时也是水务集团原水供水的统一调度中心。…

python-pandas库

目录 目录 目录 1.pandas库简介&#xff08;https://www.gairuo.com/p/pandas-overview&#xff09; 2.pandas库read_csv方法&#xff08;https://zhuanlan.zhihu.com/p/340441922?utm_mediumsocial&utm_oi27819925045248&#xff09; 1.pandas库简介&#xff08;http…

第七章 使用ssh服务管理远程主机

第七章 使用ssh服务管理远程主机 一、配置网卡服务 1、配置网卡参数 &#xff08;1&#xff09;、执行nmtui命令运行网络配置工具 [rootcentos ~]# nmtui&#xff08;2&#xff09;、选择编辑连接并按回车 &#xff08;3&#xff09;、选择以太网中网卡名称并编辑 &#xf…

JavaWeb06(三层架构连接数据库)

目录 三层架构 1.什么是三层架构 三层架构 就是将整个业务划分为三层&#xff1a;表示层、业务逻辑层、数据访问层。 2. 层与层之间的关系 3.怎么理解三层架构 4.为什么需要三层架构 区分层次的目的是为了“高内聚&#xff0c;低耦合”的思想&#xff1b; 简单来说&…

从零开始学习Linux运维,成为IT领域翘楚(五)

文章目录 &#x1f525;Linux打包压缩与搜索命令&#x1f525;Linux常用系统工作命令&#x1f525;Linux管道符、重定向与环境变量&#x1f525;管道命令符 &#x1f525;Linux打包压缩与搜索命令 tar 命令 语法&#xff1a; tar [选项] [文件]选项: &#x1f41f; -c 产生.t…

牛客网---CM11 链表分割 代码详解+哨兵位的比较

文章目录 前言CM11 链表分割链接&#xff1a;方法一&#xff1a;尾插(带哨兵位)1.1 思路&#xff1a;1.2 代码&#xff1a;1.3 流程图1.4 注意点 方法二&#xff1a;尾插(不带哨兵位)2.1代码&#xff1a; 对比&#xff1a; 总结 前言 独处未必孤独喜欢就是自由 本章的内容是牛…

Chapter4:频率响应法(上)

第四章:频率响应法 Exercise4.1 已知微分网络和积分网络电路图如下图所示,求网络的频率特性。 解: 【图 ( a ) ({\rm a}) (a)微分网络】 由微分网络电路图可得:

c# 运算符重载

1.概要 1.1可重载运算符 可重载运算符说明 x, -x, !x, ~x, , --, true, falsetrue和 false 运算符必须一起重载。 x y, x - y, x * y, x / y, x % y, x & y, x | y, x ^ y, x << y, x >> y, x >>> y x y, x ! y, x < y, x > y, x < y,…

使用NNI对BERT模型进行粗剪枝、蒸馏与微调

前言 模型剪枝&#xff08;Model Pruning&#xff09;是一种用于减少神经网络模型尺寸和计算复杂度的技术。通过剪枝&#xff0c;可以去除模型中冗余的参数和连接&#xff0c;从而减小模型的存储需求和推理时间&#xff0c;同时保持模型的性能。模型剪枝的一般步骤&#xff1a…

OpenAI文本生成器-怎么解决openai只写一半

openai写文案写一半没了怎么解决 如果您正在使用 OpenAI 写文案的服务&#xff0c;在撰写文案的过程中遇到了意外中断或者其他问题导致文案未保存&#xff0c;以下是一些有用的解决方法&#xff1a; 重新调用 API 去生成文案。您可以调用 OpenAI 的 API 重新获取您所需的文案…

Three.js--》几何体顶点知识讲解

目录 几何体顶点位置数据 点线定义几何体顶点数据 网格模型定义几何体顶点数据 顶点法线数据 实现阵列立方体与相机适配 常见几何体简介 几何体的旋转、缩放、平移方法 几何体顶点位置数据 本篇文章主要讲解几何体的顶点概念&#xff0c;相对偏底层一些&#xff0c;不过…

魔兽世界商业服务端定制商人自定义NPC教程

魔兽世界自定义NPC教程 大家好&#xff0c;我是艾西今天跟大家聊一下自定义NPC&#xff0c;自定义NPC可以添加自己想要售卖的物品以及定价等可以更好的将一个游戏设定以及游戏的拓展性有质的提升 creature表是游戏所有生物人物等表格 Creature_template是所有生物模板&#xf…

kafka快的原因(四)

四、kafka快的原因 4.1 顺序读写page cache 见上一节文件系统 使用6个7200rpm、SATA接口、RAID-5的磁盘阵列在JBOD配置下的顺序写入的性能约为600MB/秒&#xff0c;但随机写入的性能仅约为100k/秒&#xff0c;相差6000倍以上。 4.2 网络模型 4.2.1 reactor模型 4.2.2 epo…

kubernetes项目部署

目录 ​一、容器交付流程 二、k8s平台部署项目流程 三、在K8s平台部署项目 一、容器交付流程 容器交付流程通常分为四个阶&#xff1a;开发阶段、持续集成阶段、应用部署阶段和运维阶段。 开发阶段&#xff1a;开发应用程序&#xff0c;编写Dockerfile; 持续集成阶段&#…

gradle 模块

目录 ​settings.gradle文件的作用 SourceSet类的作用 Plugin 插件 Java 对 Plugin 的扩展 settings.gradle文件的作用 settings用于配置哪些工程是要被gradle集成的&#xff0c;gradle 通过 Settings.java 类来处理 settings.gradle 文件。 gradle的初始化阶段&#xff0c…