【Graph Net学习】GNN/GCN代码实战

news2025/1/23 2:15:41

一、简介

        GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。

二、代码

import time
import random
import os
import numpy as np
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import scipy.sparse as sp

#配置项
class configs():
    def __init__(self):
        # Data
        self.data_path = r'E:\code\Graph\data'
        self.save_model_dir = r'\code\Graph'

        self.model_name = r'GCN' #GNN/GCN
        self.seed = 2023

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 64
        self.epoch = 200
        self.in_features = 1433  #core ~ feature:1433
        self.hidden_features = 16  # 隐层数量
        self.output_features = 8  # core~paper-point~ 8类

        self.learning_rate = 0.01
        self.dropout = 0.5

        self.istrain = True
        self.istest = True

cfg = configs()

def seed_everything(seed=2023):
    random.seed(seed)
    os.environ['PYTHONHASHSEED']=str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

seed_everything(seed = cfg.seed)

#数据
class Graph_Data_Loader():
    def __init__(self):
        self.adj, self.features, self.labels, self.idx_train, self.idx_val, self.idx_test = self.load_data()
        self.adj = self.adj.to(cfg.device)
        self.features = self.features.to(cfg.device)
        self.labels = self.labels.to(cfg.device)
        self.idx_train = self.idx_train.to(cfg.device)
        self.idx_val = self.idx_val.to(cfg.device)
        self.idx_test = self.idx_test.to(cfg.device)

    def load_data(self,path=cfg.data_path, dataset="cora"):
        """Load citation network dataset (cora only for now)"""
        print('Loading {} dataset...'.format(dataset))

        idx_features_labels = np.genfromtxt(os.path.join(path,dataset,dataset+'.content'),
                                            dtype=np.dtype(str))
        features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
        labels = self.encode_onehot(idx_features_labels[:, -1])

        # build graph
        idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
        idx_map = {j: i for i, j in enumerate(idx)}
        edges_unordered = np.genfromtxt(os.path.join(path,dataset,dataset+'.cites'),
                                        dtype=np.int32)
        edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                         dtype=np.int32).reshape(edges_unordered.shape)
        adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                            shape=(labels.shape[0], labels.shape[0]),
                            dtype=np.float32)

        # build symmetric adjacency matrix
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

        features = self.normalize(features)
        adj = self.normalize(adj + sp.eye(adj.shape[0]))

        idx_train = range(140)
        idx_val = range(200, 500)
        idx_test = range(500, 1500)

        features = torch.FloatTensor(np.array(features.todense()))
        labels = torch.LongTensor(np.where(labels)[1])
        adj = self.sparse_mx_to_torch_sparse_tensor(adj)

        idx_train = torch.LongTensor(idx_train)
        idx_val = torch.LongTensor(idx_val)
        idx_test = torch.LongTensor(idx_test)
        return adj, features, labels, idx_train, idx_val, idx_test

    def encode_onehot(self,labels):
        classes = set(labels)
        classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                        enumerate(classes)}
        labels_onehot = np.array(list(map(classes_dict.get, labels)),
                                 dtype=np.int32)
        return labels_onehot

    def normalize(self,mx):
        """Row-normalize sparse matrix"""
        rowsum = np.array(mx.sum(1))
        r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.
        r_mat_inv = sp.diags(r_inv)
        mx = r_mat_inv.dot(mx)
        return mx

    def sparse_mx_to_torch_sparse_tensor(self,sparse_mx):
        """Convert a scipy sparse matrix to a torch sparse tensor."""
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
        indices = torch.from_numpy(
            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
        values = torch.from_numpy(sparse_mx.data)
        shape = torch.Size(sparse_mx.shape)
        return torch.sparse.FloatTensor(indices, values, shape)

#精度评价指标
def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

#模型
#01-GNN
class GNNLayer(nn.Module):
    def __init__(self, in_features, output_features):
        super(GNNLayer, self).__init__()
        self.linear = nn.Linear(in_features, output_features)

    def forward(self, adj_matrix, features):
        hidden_features = torch.matmul(adj_matrix, features)  # GNN公式:H' = A * H
        hidden_features = self.linear(hidden_features)  # 使用线性变换
        hidden_features = F.relu(hidden_features)  # 使用ReLU作为激活函数

        return hidden_features
class GNN(nn.Module):
    def __init__(self, in_features, hidden_features, output_features, num_layers=2):
        super(GNN, self).__init__()
        #输入维度in_features、隐藏层维度hidden_features、输出维度output_features、GNN的层数num_layers
        self.layers = nn.ModuleList(
            [GNNLayer(in_features, hidden_features) if i == 0 else GNNLayer(hidden_features, hidden_features) for i in
             range(num_layers)])
        self.output_layer = nn.Linear(hidden_features, output_features)

    def forward(self, adj_matrix, features):
        hidden_features = features
        for layer in self.layers:
            hidden_features = layer(adj_matrix, hidden_features)

        output = self.output_layer(hidden_features)

        return F.log_softmax(output,dim=1)

#02-GCN
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, output_features, dropout=cfg.dropout):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(in_features, hidden_features)
        self.gc2 = GraphConvolution(hidden_features, output_features)
        self.dropout = dropout

    def forward(self, adj_matrix, features):
        x = F.relu(self.gc1(features, adj_matrix))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj_matrix)
        return F.log_softmax(x, dim=1)


class graph_run():
    def train(self):
        t = time.time()
        #Create Train Processing
        all_data = Graph_Data_Loader()

        #创建一个模型
        model = eval(cfg.model_name)(in_features=cfg.in_features,
                                     hidden_features=cfg.hidden_features,
                                     output_features=cfg.output_features).to(cfg.device)
        optimizer = optim.Adam(model.parameters(),
                               lr=cfg.learning_rate, weight_decay=5e-4)

        #Train
        model.train()
        for epoch in range(cfg.epoch):
            optimizer.zero_grad()
            output = model(all_data.adj, all_data.features)
            loss_train = F.nll_loss(output[all_data.idx_train], all_data.labels[all_data.idx_train])
            acc_train = accuracy(output[all_data.idx_train], all_data.labels[all_data.idx_train])
            loss_train.backward()
            optimizer.step()
            loss_val = F.nll_loss(output[all_data.idx_val], all_data.labels[all_data.idx_val])
            acc_val = accuracy(output[all_data.idx_val], all_data.labels[all_data.idx_val])
            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_train: {:.4f}'.format(loss_train.item()),
                  'acc_train: {:.4f}'.format(acc_train.item()),
                  'loss_val: {:.4f}'.format(loss_val.item()),
                  'acc_val: {:.4f}'.format(acc_val.item()),
                  'time: {:.4f}s'.format(time.time() - t))
        torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth'))  # 模型保存

    def infer(self):
        #Create Test Processing
        all_data = Graph_Data_Loader()
        model_path = os.path.join(cfg.save_model_dir, 'latest.pth')
        model = torch.load(model_path, map_location=torch.device(cfg.device))
        model.eval()
        output = model(all_data.adj,all_data.features)
        loss_test = F.nll_loss(output[all_data.idx_test], all_data.labels[all_data.idx_test])
        acc_test = accuracy(output[all_data.idx_test], all_data.labels[all_data.idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))

if __name__ == '__main__':
    mygraph = graph_run()
    if cfg.istrain == True:
        mygraph.train()
    if cfg.istest == True:
        mygraph.infer()

三、结果与讨论

        需要从网上下载cora数据集,数据组织形式如下图。

        测了下Params和GFLOPs,还是比较大的,发现若作为一个Net的Block还是需要优化的哈哈~

ModelParamsGFLOPs
GNN23.352K126.258M
ModelCora(/train/val/test)
GNN1.0000/0.7800/0.7620
GCN0.9714/0.7767/0.8290

四、展望

        未来可以考虑用PyG(PyTorch Geometric),毕竟PyG实现GAT等图网络、图的数据组织、加载会更加方便。Graph Net通常用可以用于属性数据的embedding模式,将属性数据可以作为一种补充特征加入Net去训练,看能不能发挥效能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1021195.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3记录

Vue3快速上手 1.Vue3简介 2020年9月18日,Vue.js发布3.0版本,代号:One Piece(海贼王)耗时2年多、2600次提交、30个RFC、600次PR、99位贡献者github上的tags地址:https://github.com/vuejs/vue-next/releas…

使用Python构建网络爬虫:从网页中提取数据

💂 个人网站:【工具大全】【游戏大全】【神级源码资源网】🤟 前端学习课程:👉【28个案例趣学前端】【400个JS面试题】💅 寻找学习交流、摸鱼划水的小伙伴,请点击【摸鱼学习交流群】 网络爬虫是一种强大的工…

libevent学习——Reactor模式

Reactor模式 Reator的事件处理机制 Reactor翻译为“反应堆”,是一种事件驱动机制。该机制和普通函数调用的不同在于:应用程序不是主动调用某个API完成处理,相反,Reactor逆置了事件处理流程,应用程序需要提供相应的接…

前后端分离毕设项目之基于springboot+vue的笔记记录分享网站设计与实现(内含源码+文档+部署教程)

博主介绍:✌全网粉丝10W,前互联网大厂软件研发、集结硕博英豪成立工作室。专注于计算机相关专业毕业设计项目实战6年之久,选择我们就是选择放心、选择安心毕业✌ 🍅由于篇幅限制,想要获取完整文章或者源码,或者代做&am…

练习敲代码速度

2023年9月18日,周一晚上 今晚不想学习,但又不想玩游戏,于是找了一些练习敲代码的网站来玩玩,顺便练习一下敲代码的速度 目录 参考资料个人推荐第一个 第二个第三个 参考资料 电脑打字慢,有哪些比较好的练打字软件&a…

xxl-job的原理(1)

xxl-job的架构 系统组成 调度中心:进行任务统一调度,可以新增和管理执行器和任务;执行器:任务执行依赖的组件,一个执行器可以关联多个任务,添加的执行器可以自动注册到调度中心上;任务&#x…

全国职业技能大赛云计算--高职组赛题卷②(私有云)

全国职业技能大赛云计算--高职组赛题卷②(私有云) 第一场次题目:OpenStack平台部署与运维任务1 基础运维任务(5分)任务2 OpenStack搭建任务(15分)任务3 OpenStack云平台运维(15分&am…

mysq 主从同步错误之 Error_code 1032 handler error HA_ERR_KEY_NOT_FOUND

错误说明: MySQL主从同步的1032错误,一般是指要更改的数据不存在,SQL_THREAD提取的日志无法应用故报错,造成同步失败 (Update、Delete、Insert一条已经delete的数据)。 1032的错误本身对数据一致性没什么影…

VScode断点调试vue

VScode断点调试vue 1、修改launch.js文件(没有这个文件就新建)。 {// Use IntelliSense to learn about possible attributes.// Hover to view descriptions of existing attributes.// For more information, visit: https://go.microsoft.com/fwlin…

ChatGLM 通俗理解大模型的各大微调方法:从LoRA、QLoRA到P-Tuning V1/V2

前言 PEFT 方法仅微调少量(额外)模型参数,同时冻结预训练 LLM 的大部分参数 第一部分 高效参数微调的发展史 1.1 Google之Adapter Tuning:嵌入在transformer里 原有参数不变 只微调新增的Adapter 谷歌的研究人员首次在论文《Parameter-Efficient Transfer Learning for N…

CSS选择器练习小游戏

请结合CSS选择器练习小游戏进行阅读(网页的动态效果是没有办法通过静态图片展示的) 网址:请点击 有些题有多种答案,本文就不一一列出了 第一题 答案:plate第二题 答案:bento第三题 答案:#fa…

前后端分离管理系统day01---Springboot+MybatisPlus

目录 目录 软件 基础知识 一创建后端项目 注意: 删除多余项 创建测试类 二 加入mybatis-plus依赖支持 1.加入依赖码 2.创建数据库实例/创建用户表/插入默认数据 创建数据库实例 创建表 插入数据 3.配置yml文件 注意:wms01必须是数据库的名字&…

JVM——8.内存分配方式

这篇文章我们来讲一下jvm的内存分配方式 目录 1.概述 1.1jvm运行时数据区 1.2堆空间的分代 1.3对象分配的整体流程 2.具体的内存分配方式 2.1指针碰撞法 2.2空闲列表法 2.3Java虚拟机选择策略 3.小结 1.概述 我们前面在GC那篇文章中写了JVM的内存分配策略&#xff0…

计算机竞赛 深度学习 opencv python 实现中国交通标志识别

文章目录 0 前言1 yolov5实现中国交通标志检测2.算法原理2.1 算法简介2.2网络架构2.3 关键代码 3 数据集处理3.1 VOC格式介绍3.2 将中国交通标志检测数据集CCTSDB数据转换成VOC数据格式3.3 手动标注数据集 4 模型训练5 实现效果5.1 视频效果 6 最后 0 前言 🔥 优质…

JWT~~

概述 回顾登录的流程: 接下来的问题是:这个出入证(令牌)里面到底存啥? 一种比较简单的办法就是直接存储用户信息的JSON串,这会造成下面的几个问题: 非浏览器环境,如何在令牌中记录…

【ABAP】一文了解如何实现ALV下拉列表编辑(附完整示例代码)

💂作者简介: THUNDER王,阿里云社区专家博主,华为云云享专家,腾讯云社区认证作者,CSDN SAP应用技术领域优质创作者。在学习工作中,我通常使用偏后端的开发语言ABAP,SQL进行任务的完成…

聚焦数据库和新兴硬件的技术合力 中科驭数受邀分享基于DPU的数据库异构加速方案

随着新型硬件成本逐渐降低,充分利用新兴硬件资源提升数据库性能是未来数据库发展的重要方向之一,SIGMOD、VLDB、CICE数据库顶会上出现越来越多新兴硬件的论文和专题。在需求侧,随着数据量暴增和实时性的要求越来越高,数据库围绕处…

【TCP】三次握手 与 四次挥手 详解

三次握手 与 四次挥手 1. 三次握手2. 四次挥手三次握手和四次挥手的区别 在正常情况下,TCP 要经过三次握手建立连接,四次挥手断开连接 1. 三次握手 服务端状态转化: [CLOSED -> LISTEN] 服务器端调用 listen 后进入 LISTEN 状态&#xff…

系统架构设计师(第二版)学习笔记----信息安全系统及信息安全技术

【原文链接】系统架构设计师(第二版)学习笔记----信息加解密技术 文章目录 一、信息安全系统的组成框架1.1 信息安全系统组成框架1.2 信息安全系统技术内容1.3 常用的基础安全设备1.4 网络安全技术内容1.5 操作系统安全内容1.6 操作系统安全机制1.7 数据…

双节履带机械臂小车实现蓝牙遥控功能

1.功能描述 本文示例所实现的功能为:采用蓝牙远程遥控双节履带机械臂小车进行运动。 2.结构说明 双节履带机械臂小车,采用履带底盘,可适用于任何复杂地形。 前节履带抬起高度不低于10cm,可用于履带车进行爬楼行进。 底盘上装有一…