GCN,GraphSAGE 到底在训练什么呢?

news2025/1/22 16:42:28

根据DGL 来做的,按照DGL 实现来讲述

1. GCN Cora 训练代码:

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print(
                f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})"
            )




if __name__ == "__main__" :
    dataset = dgl.data.CoraGraphDataset()
    # print(f"Number of categories: {dataset.num_classes}")
    g = dataset[0]
    g = g.to('cuda')
    model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to('cuda')
    train(g, model)

一些基础python torch.tensor语法概述:

1.  

if __name__ == "__main__" :
    XXXXXXX
    XXXXXXX

当我们直接执行这个脚本时,__name__属性被设置为__main__,因此满足if条件,语句块中的代码被调用。
但如果我们将该脚本作为模块导入到另一个脚本中,则__name__属性会被设置为模块的名称(例如"example"),语句块中的代码不会被执行。

2. 

# Compute prediction
pred = logits.argmax(1)    # 返回沿着第一个维度(即维度索引为1)的最大值的索引。
                           # 即,加入有5个样本,每个样本有3个维度的评分,那么就会给出没个样本3中维度评分最高的哪个维度的索引序号

 

3. numpy 关于 tensor 的一个用法:

在DGL 中使用一串 True 或 False 组成的 一维tensor 来标识 这个节点到底是属于 train test val 哪一类

train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

而后,由于对于torch中的tensor来说:

就可以:select_label_tensor = labels[train_mask] 了

import torch

# 定义一个Tensor
tensor = torch.tensor([1, 2, 3, 4, 5])

# 定义一个布尔数组,选择索引为1和4的元素
mask = torch.tensor([False, True, False, False, True])

# 通过布尔索引选择元素
selected_tensor = tensor[mask]

print(selected_tensor)  # tensor([2, 5])

顺便,查看一个变量到底是什么类型可以使用 type() 函数:

train_mask = g.ndata["train_mask"]
print(type(train_mask))

# 输出为:
# <class 'torch.Tensor'>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1284380.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UVM验证环境 加入env

&#xff08;1&#xff09; 如何在UVM验证环境中例化reference model、scoreboard 如何在在验证平台中加入reference model、scoreboard&#xff0c;这个问题的解决方案是引入一个容器类&#xff0c;在这个容器类中实例化driver、monitor、reference model和scoreboard等。在…

Python 自动化办公:文件快速整理分类

平时桌面或文件夹内鱼龙混杂&#xff0c;各种类型的文件都有怎么办&#xff1f; 本篇文章中&#xff0c;我们将学习如何使用 Python 编写一个文件整理分类的脚本。 该脚本能够自动获取文件类型&#xff0c;并将文件按照类型整理到不同的子文件夹中。 先看下效果&#xff0c;…

新的 BLUFFS 攻击导致蓝牙连接不再私密

蓝牙是一种连接我们设备的低功耗无线技术&#xff0c;有一个新的漏洞需要解决。 中间的攻击者可以使用新的 BLUFFS 攻击轻松窥探您的通信。 法国研究中心 EURECOM 的研究员 Daniele Antonioli 演示了六种新颖的攻击&#xff0c;这些攻击被定义为 BLUFFS&#xff08;蓝牙转发和…

渗透测试学习day4

文章目录 靶机&#xff1a;SequelTask1Task2Task3Task4Task5Task6Task7Task8 靶机&#xff1a;CrocodileTask1Task2Task3Task4Task5Task6Task7Task8Task9Task10 靶机&#xff1a;ResponderTask1Task2Task3Task4Task5Task6Task7Task8Task9Task10Task11 靶机&#xff1a;ThreeTas…

使用百度开发者平台处理语音朗读

--TIME --百度开发者中心-汇聚、开放、助力、共赢 --注册账号 -- 准备工作 准备工作 更新时间&#xff1a;2023-01-13 成为开发者 三步完成账号的基本注册与认证&#xff1a; STEP1&#xff1a;点击进入控制台&#xff0c;选择需要使用的AI服务项。若为未登录状态&#xf…

CleanMyMac X2024破解注册激活码

CleanMyMac X for Mac中文2024版只需两个简单步骤就可以把系统里那些乱七八糟的无用文件统统清理掉&#xff0c;节省宝贵的磁盘空间。 cleanmymac x个人认为X代表界面上的最大升级&#xff0c;功能方面有更多增加&#xff0c;与最新macOS系统更加兼容&#xff0c;流畅地与系统性…

MacBook Pro 安装Redis【超详细图解】

目录 一、使用brew安装Redis 二、查看安装及配置文件位置 三、启动Redis 3.1 查看redis服务进程 3.2 redis-cli连接redis服务 四、关闭Redis 因项目需要&#xff0c;顺便记录安装过程 一、使用brew安装Redis brew install redis 如图所示即为安装成功&#xff01; 二…

csdn语法说明/csdn新手指导/csdn入门指导/csdn博文助手

文章目录 1、文章目录2、标题3、文本样式3.1、强调、加粗、黄色标记、删除、引用、乘方&#xff0c;化学表达式3.2、标红、按钮效果 4、功能快捷键5、注脚、注释6、链接7、图片8、列表9、表格 本篇博文主要写一下csdn博文中的语法说明。 1、文章目录 [TOC](这里写自定义目录标…

谈谈压测方案的那点事

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…

论文解读:EfficientViT-提高吞吐量

摘要 要解决的问题 Vision transformers have shown great success due to their high model capabilities. However, their remarkable performance is accompanied by heavy computation costs, which makes them unsuitable for real-time applications. vit计算开销大&a…

人工智能学习7(决策树算法)

编译工具&#xff1a;PyCharm 文章目录 编译工具&#xff1a;PyCharm 决策树算法信息熵信息熵例题计算&#xff1a; 信息增益&#xff08;决策树划分依据之一ID3&#xff09;信息增益例题计算&#xff1a; 信息增益率(决策树划分依据之一C4.5)基尼值和基尼指数(决策树划分依据之…

蓝桥杯物联网竞赛_STM32L071KBU6_全部工程及源码

包含stm32L071kbu6全部实验工程、源码、原理图、官方提供参考代码及原理图 链接&#xff1a;https://pan.baidu.com/s/1xm8mLotLBvOULQlg76ca7g?pwdp0mx 提取码&#xff1a;p0mx

边缘计算网关构建智慧楼宇新生态,打造未来建筑管理

边缘计算网关在无人值守环境中的应用十分广泛&#xff0c;尤其在智慧楼宇管理方面发挥着重要作用。它能够实现多个地点多楼宇之间的数据实时互通&#xff0c;通过边缘计算网关物联网应用构建智慧楼宇生态系统&#xff0c;解决传统楼宇管理网络布线、人员巡检以及后期运维等问题…

SQL Server 数据库,创建数据表(使用T-SQL语句)

2.3表的基本概念 表是包含数据库中所有数据的数据库对象。数据在表中的组织方式与在电子表格中相似&#xff0c;都是 按行和列的格式组织的&#xff0c;每行代表一条唯一的记录&#xff0c;每列代表记录中的一个字段.例如&#xff0c;在包含公 司员工信息的表中&#xff0c;每行…

深度学习模型部署与优化:关键考虑与实践策略

4. 深度学习模型部署与优化&#xff1a;关键考虑与实践策略 4.1 FLOPS TOPS 首先&#xff0c;我们来解释FLOPS和TOPS的含义&#xff1a; FLOPS&#xff1a;是Floating Point Operations Per Second的缩写&#xff0c;意思是每秒浮点运算次数。它是衡量计算机或计算设备在每秒…

深入理解JVM虚拟机第二十七篇:详解JVM当中InvokeDynamic字节码指令,Java是动态类型语言么?

😉😉 学习交流群: ✅✅1:这是孙哥suns给大家的福利! ✨✨2:我们免费分享Netty、Dubbo、k8s、Mybatis、Spring...应用和源码级别的视频资料 🥭🥭3:QQ群:583783824 📚📚 工作微信:BigTreeJava 拉你进微信群,免费领取! 🍎🍎4:本文章内容出自上述:Sp…

前端漏洞--front(系统有user1/user1,admin1/admin1两个用户)

任务一&#xff1a;挖掘反射型XSS漏洞&#xff08;以弹窗test13&#xff09;证明 任务二&#xff1a;复现环境中的CSRF漏洞&#xff0c;设计表单&#xff0c;当管理员点击URL后自动将自己密码重置为&#xff1a;123456 任务三&#xff1a;复现环境中的JSON Hijacking漏洞&#…

【MySQL环境配置在虚拟机中】

MySQL环境配置在虚拟机中 先检查虚拟机中是否有MySQL在线安装1.下载yum Repository2.安装yum Repository3.安装mysql5.7的服务4.开机自启动5.启动mysql6.查看状态7.获取临时密码8.登录mysql9.关闭密码复杂验证10.设置密码11.修改权限12.卸载yum Repository 离线安装1.先找一下机…

【题目】栈和队列专题

文章目录 专题一&#xff1a;栈系列1. 中缀表达式转后缀表达式&#xff08;逆波兰式&#xff09;2. 有效的括号3. 用栈实现队列4. 最小栈 专题一&#xff1a;栈系列 1. 中缀表达式转后缀表达式&#xff08;逆波兰式&#xff09; 算法原理 2. 有效的括号 题目链接 算法原理 代…

uniapp:如何使用uCharts

目录 第一章 前言 第二章 安装插件uCharts 第三章 使用uCharts 第四章 注意 第一章 前言 需求&#xff1a;这是很久之前的一个项目的需求了&#xff0c;当时我刚接触app&#xff0c;有这么一个需求&#xff0c;在uniapp写的app项目中做一些图表统计&#xff0c;最开始以为…