GNN初探

news2024/11/16 10:56:33

  测试了下网上找的一篇代码,运行成功~

# import sys
# print(sys.path)

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# 加载并预处理Cora数据集
dataset_path = './dataset/cora'
dataset = Planetoid(root=dataset_path, name='Cora')
data = dataset[0]  # 取第一个图数据

# 打印数据集信息以确认加载成功
print(f'Dataset: {dataset}')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_node_features}')
print(f'Number of classes: {dataset.num_classes}')
print(data)


class GNN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(num_features, 16)  # 第一个卷积层
        self.conv2 = GCNConv(16, num_classes)  # 第二个卷积层

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  # 获取节点特征和边索引
        x = self.conv1(x, edge_index)  # 第一个卷积
        x = F.relu(x)  # 激活函数
        x = self.conv2(x, edge_index)  # 第二个卷积
        return F.log_softmax(x, dim=1)  # 返回经过softmax的输出

model = GNN(num_features=dataset.num_node_features, num_classes=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

print(model)

# 训练过程
model.train()
for epoch in range(200):
    optimizer.zero_grad()  # 清除之前的梯度
    out = model(data)  # 前向传播  data如何传递给forward函数的?
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

model.eval()
_, pred = model(data).max(dim=1)  # 获取预测结果
correct = pred[data.test_mask] == data.y[data.test_mask]  # 比较预测与真实标签
accuracy = int(correct.sum()) / int(data.test_mask.sum())  # 计算准确率
print(f'Test Accuracy: {accuracy:.4f}')

 遇到的问题:

1.python是脚本语言,没有编译过程,因此对齐要非常小心,创建model这一行干脆不要空格,否则class的结束部分会报错。

一般语句不需要缩进,顶行书写且不留空白。

当表示分支、循环、函数、类等含义,在if,while,for,def,class等保留字所在的完整语句后通过英文冒号(:)结尾,并在之后进行缩进,表示前后代码之间的从属关系

2.Plaintoid数据本地读取的问题,从github下载后放到/dataset/cora/raw下面,运行成功~

参考: 1 .GNNpython代码实现_mob649e816347dd的技术博客_51CTO博客

         2. https://zhuanlan.zhihu.com/p/452747749

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2241455.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

初识算法 · 位运算(end)

目录 前言: 题目解析 算法原理 算法编写 前言: 本文作为初识算法 位运算的最后一篇文章,使用一道hard题目来结束这个专题,题目的链接为: 面试题 17.19. 消失的两个数字 - 力扣(LeetCode)…

3. Spring Cloud Eureka 服务注册与发现(超详细说明及使用)

3. Spring Cloud Eureka 服务注册与发现(超详细说明及使用) 文章目录 3. Spring Cloud Eureka 服务注册与发现(超详细说明及使用)前言1. Spring Cloud Eureka 的概述1.1 服务治理概述1.2 服务注册与发现 2. 实践:创建单机 Eureka Server 注册中心2.1 需求说明 图解…

springboot实现简单的数据查询接口(无实体类)

目录 前言:springboot整体架构 1、ZjGxbMapper.xml 2、ZjGxbMapper.java 3、ZjGxbService.java 4、ZjGxbController.java 5、调用接口测试数据是否正确 6、打包放到服务器即可 前言:springboot整体架构 文件架构,主要编写框选的这几类…

awk(常用)

这个有点难 O.o 一、awk # 语法 awk 参数 模式 {动作} 文件# 第一列,包含p的 $1~"p" # 第一列,不包含p的 $1!~"p" # 开始时干嘛,结束时干嘛 awk BEGIN{开始时做的事}END{结束时做的事}{print $0} 文件 1、内置变量&…

EXPLAIN优化慢SQL

项目中发现数据查询很慢,导致前端超时等待的问题。经过日志打印发现,查询sql耗时10秒以上,相关sql如下: select distincttablemodel.*from pjtask_model tablemodelJOIN buss_type_permission a ON (tablemodel.fields_data_id …

Skywalking搭建-来自于图灵课堂

Skywalking主要用于链路追踪,日志收集查看,异常日志查看,服务监控弱一些,服务器监控可以使用prometheus 一、搭建服务端,使用startup.bat启动 配置持久化,如果是用mysql持久化,拷贝mysql链接包…

ZooKeeper单机、集群模式搭建教程

单点配置 ZooKeeper在启动的时候,默认会读取/conf/zoo.cfg配置文件,该文件缺失会报错。因此,我们需要在将容器/conf/挂载出来,在制定的目录下,添加zoo.cfg文件。 zoo.cfg logback.xml 配置文件的信息可以从二进制包…

计算机网络(11)和流量控制补充

这一篇对数据链路层中的和流量控制进行详细学习 流量控制(Flow Control)是计算机网络中确保数据流平稳传输的技术,旨在防止数据发送方发送过多数据,导致接收方的缓冲区溢出,进而造成数据丢失或传输失败。流量控制通常…

二元一次不定方程@整数解问题

文章目录 二元一次不定方程|整数解定理1整数解存在充要条件定理2 通解特解知识回顾利用辗转相除法求例 使用表达式凑出通解 二元一次不定方程|整数解 二元一次不定方程的一般形式为 a x b y c ax by c axbyc(1) 其中 a a a、 b b b、 c c c 是整数,且 a a a…

深入理解Flutter生命周期函数之StatefulWidget(一)

目录 前言 1.为什么需要生命周期函数 2.开发过程中常用的生命周期函数 1.initState() 2.didChangeDependencies() 3.build() 4.didUpdateWidget() 5.setState() 6.deactivate() 7.dispose() 3.Flutter生命周期总结 1.调用顺序 2.函数调用时机以及主要作用 4.生…

llama factory lora 微调 qwen2.5 7B Instruct模型

项目背景 甲方提供一台三卡4080显卡 需要进行qwen2.5 7b Instruct模型进行微调。以下为整体设计。 要使用 LLaMA-Factory 对 Qwen2.5 7B Instruct模型 进行 LoRA(Low-Rank Adapters)微调,流程与之前提到的 Qwen2 7B Instruct 模型类似。LoRA …

机器学习day2-特征工程

四.特征工程 1.概念 一般使用pandas来进行数据清洗和数据处理、使用sklearn来进行特征工程 将任意数据(文本或图像等)转换为数字特征,对特征进行相关的处理 步骤:1.特征提取;2.无量纲化(预处理&#xf…

Llama架构及代码详解

Llama的框架图如图: 源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下: Llama的整体组成…

stm32在linux环境下的开发与调试

环境安装 注:文末提供一键脚本 下载安装stm32cubeclt 下载地址为:https://www.st.com/en/development-tools/stm32cubeclt.html 选择 linux版本下载安装 安装好后默认在家目录st下 > $ ls ~/st/stm32cubeclt_1.16.0 …

第T7周:Tensorflow实现咖啡豆识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: PyCharm 框 架: (二)具体步骤 1. 使…

亲测有效:Maven3.8.1使用Tomcat8插件启动项目

我本地maven的settings.xml文件中的配置&#xff1a; <mirror><id>aliyunmaven</id><mirrorOf>central</mirrorOf><name>阿里云公共仓库</name><url>https://maven.aliyun.com/repository/public</url> </mirror>…

LLM - 使用 LLaMA-Factory 微调大模型 Qwen2-VL SFT(LoRA) 图像数据集 教程 (2)

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/143725947 免责声明&#xff1a;本文来源于个人知识与公开资料&#xff0c;仅用于学术交流&#xff0c;欢迎讨论&#xff0c;不支持转载。 LLaMA-…

神经网络与Transformer详解

一、模型就是一个数学公式 模型可以描述为:给定一组输入数据,经过一系列数学公式计算后,输出n个概率,分别代表该用户对话属于某分类的概率。 图中 a, b 就是模型的参数,a决定斜率,b决定截距。 二、神经网络的公式结构 举例:MNIST包含了70,000张手写数字的图像,其中…

鲸鱼机器人和乐高机器人的比较

鲸鱼机器人和乐高机器人各有其独特的优势和特点&#xff0c;家长在选择时可以根据孩子的年龄、兴趣、经济能力等因素进行综合考虑&#xff0c;选择最适合孩子的教育机器人产品。 优势 鲸鱼机器人 1&#xff09;价格亲民&#xff1a;鲸鱼机器人的产品价格相对乐高更为亲民&…

Flink Source 详解

Flink Source 详解 原文 flip-27 FLIP-27 介绍了新版本Source 接口定义及架构 相比于SourceFunction&#xff0c;新版本的Source更具灵活性&#xff0c;原因是将“splits数据获取”与真“正数据获取”逻辑进行了分离 重要部件 Source 作为工厂类&#xff0c;会创建以下两…