PyG-GCN-Cora(在Cora数据集上应用GCN做节点分类)

news2024/11/27 18:32:59

文章目录

  • model.py
  • main.py
  • 参数设置
  • 注意事项
  • 运行图

model.py

import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class gcn_cls(nn.Module):
    def __init__(self,in_dim,hid_dim,out_dim,dropout_size=0.5):
        super(gcn_cls,self).__init__()
        self.conv1 = GCNConv(in_dim,hid_dim)
        self.conv2 = GCNConv(hid_dim,hid_dim)
        self.fc = nn.Linear(hid_dim,out_dim)
        self.relu  = nn.ReLU()
        self.dropout_size = dropout_size
    def forward(self,x,edge_index):
        x = self.conv1(x,edge_index)
        x = F.dropout(x,p=self.dropout_size,training=self.training)
        x = self.relu(x)
        x = self.conv2(x,edge_index)
        x = self.relu(x)
        x = self.fc(x)
        return x

main.py

import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from model import gcn_cls
import torch.optim as optim
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset[0])
cora_data = dataset[0]

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

net = gcn_cls(cora_data.x.shape[1],hidden_dim,output_dim)
optimizer = optim.AdamW(net.parameters(),lr=lr,weight_decay=weight_decay)
#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=momentum)
criterion = nn.CrossEntropyLoss()
print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    out = net(cora_data.x,cora_data.edge_index)
    optimizer.zero_grad()
    loss_train = criterion(out[cora_data.train_mask],cora_data.y[cora_data.train_mask])
    loss_val   = criterion(out[cora_data.val_mask],cora_data.y[cora_data.val_mask])
    loss_train.backward()
    print('epoch',epoch+1,'loss-train {:.2f}'.format(loss_train),'loss-val {:.2f}'.format(loss_val))
    optimizer.step()

net.eval()
out = net(cora_data.x,cora_data.edge_index)
loss_test = criterion(out[cora_data.test_mask],cora_data.y[cora_data.test_mask])
_,pred = torch.max(out,dim=1)
pred_label = pred[cora_data.test_mask]
true_label = cora_data.y[cora_data.test_mask]
acc = sum(pred_label==true_label)/len(pred_label)
print("****************Begin Testing****************")
print('loss-test {:.2f}'.format(loss_test),'acc {:.2f}'.format(acc))

参数设置

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

output_dim是输出维度,也就是有多少可能的类别。

注意事项

1.发现loss不下降:
建议改一改lr(学习率),我做的时候开始用的SGD,学习率设的0.01发现loss不下降,改成0.1后好了很多。如果用AdamW,0.001(1e-3)基本就够用了

运行图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1023250.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【PyTorch 攻略】(6-7/7)

一、说明 本篇介绍模型模型的参数,模型推理和使用,保存加载。 二、训练参数和模型 在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网…

CockroachDB集群部署

CockroachDB集群部署 1、CockroachDB简介 CockroachDB(有时简称为CRDB)是一个免费的、开源的分布式 SQL 数据库,它建立在一个事务性和强一致性的键 值存储之上。它由 PebbleDB(一个受 RocksDB/leveldb 启发的 K/B 存储库)支持,并使用 Raft 分布式共识…

利用Java EE相关技术实现一个简单的Web聊天室系统

利用Java EE相关技术实现一个简单的Web聊天室系统 (1)编写一个登录页面,登录信息中有用户名和密码,分别用两个按钮来提交和重置登录信息。 (2)通过请求指派来处理用户提交的登录信息,如果用户名…

基于YOLOv8模型的烟火目标检测系统(PyTorch+Pyside6+YOLOv8模型)

摘要:基于YOLOv8模型的烟火目标检测系统可用于日常生活中检测与定位烟火目标,利用深度学习算法可实现图片、视频、摄像头等方式的目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数据集…

linux中的开发工具

在刚开始使用linux的时候,我们需要在系统上写一些简单的代码,来熟悉环境以及各种指令 并且熟悉属于linux的一套开发的环境,而这对于c来说需要三个软件就可以进行简单的编码 和使用,让我们来认识一下下列工具,以及工具的…

【Java 基础篇】Java字符打印流详解:文本数据的输出利器

在Java编程中,我们经常需要将数据输出到文件或其他输出源中。Java提供了多种输出流来帮助我们完成这项任务,其中字符打印流是一个非常有用的工具。本文将详细介绍Java字符打印流的用法,以及如何在实际编程中充分利用它。 什么是字符打印流&a…

电脑丢失d3dcompiler47.dll怎么办,这个四个修复方法都可以解决

d3dcompiler_47.dll 是一个与 DirectX 相关的动态链接库文件,它包含了 DirectX 编译器的一些函数和类,对于许多应用程序和游戏来说都是必需的。如果您的系统中缺失了这个文件,可能会导致程序无法正常运行。下面我们将介绍四个修复 d3dcompile…

(图论) 1020. 飞地的数量 ——【Leetcode每日一题】

❓ 1020. 飞地的数量 难度:中等 给你一个大小为 m x n 的二进制矩阵 grid ,其中 0 表示一个 海洋单元格、1 表示一个 陆地单元格。 一次 移动 是指从一个陆地单元格走到另一个相邻(上、下、左、右)的陆地单元格或跨过 grid 的边…

vant 组件库的基本使用

文章目录 vant组件库1、什么是组件库2、vant组件 全部导入 和 按需导入的区别3、全部导入的使用步骤:4、按需导入的使用步骤:5、封装vant文件包 vant组件库 该项目将使用到vant-ui组件库,这里的目标就是认识他,铺垫知识 1、什么…

PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类)

文章目录 model.pymain.py参数设置运行图 model.py import torch.nn as nn from torch_geometric.nn import GATConv import torch.nn.functional as F class gat_cls(nn.Module):def __init__(self,in_dim,hid_dim,out_dim,dropout_size0.5):super(gat_cls,self).__init__()s…

安达发APS|国货品牌崛起,制造业迎来智能排产新机遇

随着国货品牌的不断崛起,制造业的生产也面临着巨大的挑战。为应对这一挑战,越来越多的企业开始引入APS智能排产技术,以优化生产线布局、提升设备利用率、缩短生产周期、减少生产成本,从而增强市场竞争力。本文将为您详细解读APS智…

数据结构-----栈(栈的初始化、建立、入栈、出栈、遍历、清空等操作)

目录 前言 栈 1.定义 2.栈的特点 3.栈的储存方式 3.1数组栈 3.2链栈 4.栈的基本操作(C语言) 4.1初始化 4.2判断是否满栈 4.3判断空栈 4.4 入栈 4.5 出栈 4.6获取栈顶元素 4.7遍历栈 4.8清空栈 完整代码示例 前言 大家好呀!今天我…

登录业务实现

登录业务实现: 登录成功/失败实现 -> pinia管理用户数据及数据持久化 -> 不同登录状态的模板适配 -> 请求拦截器携带token -> 退出登录实现 -> token失效(401响应拦截) 1. 登录成功/失败实现 当表单校验通过时&a…

华为云云耀云服务器L实例评测|云耀云服务器L实例部署odoo开源ERP平台

华为云云耀云服务器L实例评测|云耀云服务器L实例部署odoo开源ERP平台 一、云耀云服务器L实例介绍1.1 云耀云服务器L实例简介1.2 云耀云服务器L实例使用场景1.3 云耀云服务器L实例特点 二、odoo介绍2.1 odoo简介2.2 odoo特点 三、本次实践介绍3.1 本次实践简介3.2 本…

解决Java应用程序中的SQLException:服务器时区值未识别问题;MySQL连接问题:服务器时区值 ‘Öйú±ê׼ʱ¼ä‘ 未被识别的解决方法

目录 ​编辑 问题背景 解决方案 问题背景 今天遇见一个这个问题,解决后发出来分享一下: java.sql.SQLException: The server time zone value is unrecognized or represents more than one time zone. You must configure either the server or J…

STP介绍

目录 STP概述 二层环路带来的问题 1.广播风暴 2.MAC地址漂移问题 3.多帧复制---这个好理解,同一个数据帧被重复收到多次,被称为多帧复制。 802.1D生成树 STP的BPDU BPDU主要分为两大类 配置BPDU RPC COST 配置BPDU的工作过程 TCN BPDU TCN…

【python爬虫】——历史天气信息爬取

文章目录 1、任务描述1.1、需求分析1.2 页面分析 2、获取网页源码、解析、保存数据3、结果展示 1、任务描述 1.1、需求分析 在2345天气信息网2345天气网依据地点和时间对相关城市的历史天气信息进行爬取。 1.2 页面分析 网页使用get方式发送请求,所需参数包括a…

c语言练习63:用malloc开辟二维数组的三种办法

用malloc开辟二维数组的三种办法 使用malloc函数模拟开辟一个3*5的整型二维数组&#xff0c;开辟好后&#xff0c;使用二维数组的下标访问形式&#xff0c;访问空间。 第一种办法&#xff1a;用指针数组&#xff1a; #include<stdio.h> int main() {int** p (int**)m…

2023-09-19 LeetCode每日一题(打家劫舍 IV)

2023-09-19每日一题 一、题目编号 2560. 打家劫舍 IV二、题目链接 点击跳转到题目位置 三、题目描述 沿街有一排连续的房屋。每间房屋内都藏有一定的现金。现在有一位小偷计划从这些房屋中窃取现金。 由于相邻的房屋装有相互连通的防盗系统&#xff0c;所以小偷 不会窃取…

【C++代码】二叉树的最大深度,二叉树的最小深度,完全二叉树的节点个数--代码随想录

题目&#xff1a;二叉树的最大深度 给定一个二叉树 root &#xff0c;返回其最大深度。二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 题解 如果我们知道了左子树和右子树的最大深度 l 和 r&#xff0c;那么该二叉树的最大深度即为 m a x ( l , r ) …