GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?

news2024/10/5 12:56:15

1. 一个端到端的 同构图(Cora数据集)节点分类代码:

import argparse

import dgl
import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # two-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))
        self.dropout = nn.Dropout(0.5)

    def forward(self, graph, x):
        h = self.dropout(x)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h


def evaluate(g, features, labels, mask, model):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def train(g, features, labels, masks, model):
    # define train/val samples, loss function and optimizer
    train_mask, val_mask = masks
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    # training loop
    for epoch in range(200):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GraphSAGE")
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed')",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GraphSage module")

    # load and preprocess dataset
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"]

    # create GraphSAGE model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = SAGE(in_size, 16, out_size).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, g.ndata["test_mask"], model)
    print("Test accuracy {:.4f}".format(acc))

2. GraphSAGE的实现 : SAGEConv 类:

我们先来介绍一下DGL对GraphSAGE这个模型的实现:SAGEConv() 在三方库的下述位置:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1305585.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

YOLOv8训练好的pt文件如何用来预测

1. 使用原版预测 代码如下: from ultralytics import YOLO# Load a model model YOLO(yolov8n.pt) # load an official model# Predict with the model results model(https://ultralytics.com/images/bus.jpg) # predict on an image命令如下: y…

linux 17day 堡垒机 堡垒机下载 堡垒机安装 堡垒机使用 堡垒机管理服务器 堡垒机管理数据库

目录 堡垒机官网堡垒机下载堡垒机安装堡垒机使用Linux系统上使用web 使用配置站点url配置邮箱创建要管理的服务器创建 特权用户用于管理创建普通用户添加命令过滤命令过滤创建好 之后就需要 给用户名 和管理添加web用户登录 添加数据库mysql 服务区创建用户创建系统用户创建mys…

CRM对小微公司的实际作用:从客户管理到业务拓展

公司作为一个组织,管理方面是重中之重。传统式的人力会是一个较为费时费力的大工程。随着科技的发展,CRM系统完全可以胜任企业管理的工作。那么,CRM有什么特点?对小微公司有哪些作用? 1、提高管理效率 传统的客户管理…

DataFunSummit:2023年数据治理在线峰会-核心PPT资料下载

一、峰会简介 数据治理(Data Governance)是组织中涉及数据使用的一整套管理行为。由企业数据治理部门发起并推行,关于如何制定和实施针对整个企业内部数据的商业应用和技术管理的一系列政策和流程。 数据治理是一个通过一系列信息相关的过程…

MySQL笔记-第09章_子查询

视频链接:【MySQL数据库入门到大牛,mysql安装到优化,百科全书级,全网天花板】 文章目录 第09章_子查询1. 需求分析与问题解决1.1 实际问题1.2 子查询的基本使用1.3 子查询的分类 2. 单行子查询2.1 单行比较操作符2.2 代码示例2.3 …

生化危机5无法启动丢失xlive.dll怎么修复?快速修复教程分享

xlive.dll丢失的5个解决方法与xlive.dll文件丢失原因以及xlive.dll丢失对电脑有什么影响介绍 一、xlive.dll文件丢失原因: 1. 病毒或恶意软件感染:某些病毒或恶意软件会删除或损坏系统文件中的xlive.dll文件,导致其丢失。 2. 误操作&#…

作业12.12

1.闹钟 主函数 #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);t new QTimer(this);idstartTimer(1000);speecher new QTextToSpeech(this); }Widget::~Wid…

QML WebEngineView 全屏和退出

作者: 一去、二三里 个人微信号: iwaleon 微信公众号: 高效程序员 在使用浏览器时,我们经常会用到全屏模式,最常见的场景有:观看视频、阅读文章、在线演示等。全屏模式的优点在于,它可以让用户充分地利用有限的屏幕空间,更好地专注于内容本身,从而提供丰富的沉浸式视觉…

视频剪辑入门:视频批量嵌套合并,成为视频编辑达人

随着数字媒体的快速发展,视频剪辑已经成为一项非常流行的技能。如果对视频剪辑感兴趣,想学习如何将多个视频批量嵌套合并,下面是云炫AI智剪批量嵌套合并视频的一些简单步骤,高效剪辑,成为视频编辑达人不再难。 准备要视…

金融银行软件测试超大型攻略,最受欢迎的金融银行大揭秘附面试题

零、为什么做金融类软件测试 举个栗子,银行里的软件测试工程师。横向跟互联网公司里的测试来说,薪资相对稳定,加班少甚至基本没有,业务稳定。实在是测试类岗位中的香饽饽! 一、什么是金融行业 金融业是指经营金融商…

简单自定义vuex的设计思路

vuex集中式存储管理应用所有组件的状态,并以响应的规则保证状态以可预测的方式 发生变化。 步骤: 1.Store类,保存选项,_mutations,_actions,getters 2.响应式状态:new Vue方式设置响应式。 …

Java网络编程,使用UDP实现TCP(三), 基本实现四次挥手

简介 四次挥手示意图 在四次挥手过程中,第一次挥手中的Seq为本次挥手的ISN, ACK为 上一次挥手的 Seq1,即最后一次数据传输的Seq1。挥手信息由客户端首先发起。 实现步骤: 下面是TCP四次挥手的步骤: 第一次挥手&…

环境变量提权

环境变量提权 借鉴文章LINUX提权之环境变量提权篇 - 知乎 (zhihu.com) 利用条件 存在一个文件,利用su权限执行,普通用户可以执行此文件,但只限制在一个目录下可以执行 利用方式 将此文件的目录添加到环境变量中 export PATH/tmp:$PATHe…

分层自动化测试的实战思考!

自动化测试的分层模型 自动化测试的分层模型,我们应该已经很熟悉了,按照分层测试理念,自动化测试的投入产出应该是一个金字塔模型。越是向下,投入/产出比就越高,但开展的难易程度/成本和技术要求就越高,但…

Linux安装Halo(个人网站)

docker安装 curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun && systemctl start docker && systemctl enable docker && sudo mkdir -p /etc/docker && sudo tee /etc/docker/daemon.json <<-EOF && sudo…

开题PPT答辩复盘

目录 总体思路加粗和红体字使用研究现状之后主要研究内容讨论 总体思路 分为五个部分&#xff0c;规定在10分钟以内讲完。这次开题答辩&#xff0c;主要是要讲清楚研究背景和意义&#xff0c;国内外研究现状。因此前两部分需要花大概6分钟重点解释&#xff0c;主要研究内容用2…

提升--22---ReentrantReadWriteLock读写锁

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 ReadWriteLock----读写锁1.读写锁介绍线程进入读锁的前提条件&#xff1a;线程进入写锁的前提条件&#xff1a;而读写锁有以下三个重要的特性&#xff1a; Reentran…

jQuery遍历与删除添加节点

个人名片&#xff1a; &#x1f60a;作者简介&#xff1a;一名大二在校生 &#x1f921; 个人主页&#xff1a;坠入暮云间x &#x1f43c;座右铭&#xff1a;懒惰受到的惩罚不仅仅是自己的失败&#xff0c;还有别人的成功。 &#x1f385;**学习目标: 坚持每一次的学习打卡 文章…

如何通过 SSH 访问 VirtualBox 的虚机

VirtualBox 是一款免费虚机软件。在用户使用它安装了 linux 以后&#xff0c;它默认只提供了控制台的管理画面。 直接使用控制台管理 Linux 没有使用诸如 putty 或者 vscode 这样的 ssh 远程管理工具方便。那么可不可以直接使用 ssh 访问 VirtualBox 上的 Linux 呢&#xff1f…

关于碰撞试验

主要参数&#xff1a; 冲击与碰撞试验的主要参数及调整方法 - 百度文库 碰撞试验的技术指标包括&#xff1a;峰值加速度、脉冲持续时间、速度变化量&#xff08;半正弦波&#xff09;、每方向碰撞次数。 加速度&#xff1a;冲击的强度&#xff0c;单位为g&#xff1b;一般为3…