【图神经网络】手把手带你快速上手OpenHGNN

news2024/11/24 9:31:34

手把手带你快速上手OpenHGNN

  • 1. 评估新的数据集
    • 1.1 如何构建一个新的数据集
  • 2. 使用一个新的模型
    • 2.1 如何构建一个新模型
  • 3. 应用到一个新场景
    • 3.1 如何构建一个新任务
    • 3.2 如何构建一个新的trainerflow
  • 内容来源

1. 评估新的数据集

如果需要,可以指定自己的数据集。本节中,我们使用HGBn-ACM作为节点分类数据集的示例。

1.1 如何构建一个新的数据集

第一步:预处理数据集
这里给出了一个处理HGBn-ACM的演示,这是一个节点分类数据集

首先,下载HGBn-ACM数据集:HGB数据集。下载完成后,需要将其处理为一个dgl.heterograph

以下代码片段是在DGL中创建异构图的示例。

import dgl
import torch as th

graph_data = {
    ('drug','interacts', 'drug'): (th.tensor([0,1]), th.tensor([1,2])),
    ('drug','interacts', 'gene'): (th.tensor([0,1]), th.tensor([2,3])),
    ('drug','treats','disease'): (th.tensor([1]), th.tensor([2]))
}
graph_data

graph_data
canonical_etypes
推荐将feature name设置为h

g.nodes['drug'].data['h'] = th.ones(3, 1)

DGL提供了dgl.save_graphs()dgl.load_graphs()分别表示保存和加载二进制形式的异质图。因此,这里使用dgl.save_graphs保存graphs到磁盘中:

dgl.save_graphs('demo_graph.bin',g)

第二步:增加额外的信息
经过第一步,得到一个demo_graph.bin的二进制文件,然后我们将其移动到openhgnn/dataset/目录下,下一步的具体信息在NodeClassificationDataset.py

例如,我们将category,num_classes和multi_label(if necessary) 设置为paper3True,分别表示要预测类的节点类型、类的数量以及任务是否为多标签分类。有关详细信息,请参阅基本节点分类数据集。
加载dgl
增加额外的信息:

if name_dataset == 'demo_graph':
    data_path = './openhgnn/dataset/demo_graph.bin'
    g, _ = load_graphs(data_path)
    g = g[0].long()
    self.category = 'author'  # 增加额外的信息
    self.num_classes = 4
    self.multi_label = False

第三步:可选
使用demo_graph作为数据集,评估一个存在的模型:

python main.py -m GTN -d demo_graph -t node_classification -g 0 --use_best_config

如果有另一个数据集名称,那需要修改代码build_dataset

2. 使用一个新的模型

这一部分,我们创建一个模型,名为RGAT,它不在我们的模型package <api-model>。

2.1 如何构建一个新模型

第一步:注册器模型
我们创建一个继承基本模型(Base Model)的类RGAT,并使用@register_model(str)注册该模型。

from openhgnn.models import BaseModel, register_model
@register_model('RGAT')
class RGAT(BaseModel):
    ...

第二步:实现函数
必须实现类方法build_model_from_args,其他函数像__init__,forward

...
class RGAT(BaseModel):
    @classmethod
    def build_model_from_args(cls, args, hg):
        return cls(in_dim=args.hidden_dim,
                   out_dim=args.hidden_dim,
                   h_dim=args.out_dim,
                   etypes=hg.etypes,
                   num_heads=args.num_heads,
                   dropout=args.dropout)

    def __init__(self, in_dim, out_dim, h_dim, etypes, num_heads, dropout):
        super(RGAT, self).__init__()
        self.rel_names = list(set(etypes))
        self.layers = nn.ModuleList()
        self.layers.append(RGATLayer(
            in_dim, h_dim, num_heads, self.rel_names, activation=F.relu, dropout=dropout))
        self.layers.append(RGATLayer(
            h_dim, out_dim, num_heads, self.rel_names, activation=None))
        return

    def forward(self, hg, h_dict=None):
        if hasattr(hg, 'ntypes'):
            # full graph training,
            for layer in self.layers:
                h_dict = layer(hg, h_dict)
        else:
            # minibatch training, block
            for layer, block in zip(self.layers, hg):
                h_dict = layer(block, h_dict)
        return h_dict

这里我们没有给出RGATLayer的实现细节。有关更多阅读,请查看:RGATLayer。
在OpenHGNN中,我们在模型之外对数据集的特征进行预处理。具体来说,使用每个节点类型都有偏差的线性层来将所有节点特征映射到共享特征空间。因此,模型中forward的参数h_dict不是原始特征,您的模型不需要进行特征预处理。
第三步:添加到支持的模型字典
我们应该在 model/init.py中向 SUPPORTED _ MODELS 添加一个新条目。

3. 应用到一个新场景

在本节中,我们将应用于一个推荐场景,该场景涉及构建一个新任务和训练流。

3.1 如何构建一个新任务

第一步:注册任务
创建一个类Recommendation,继承内置的BaseTask并用@register_task(str)注册它。

from openhgnn.tasks import BaseTask, register_task
@register_task('recommendation')
class Recommendation(BaseTask):
    ...

第二步:实现方法
我们应该实现与评估指标和损失函数相关的方法。

class Recommendation(BaseTask):
    """Recommendation tasks."""
    def __init__(self, args):
        super(Recommendation, self).__init__()
        self.n_dataset = args.dataset
        self.dataset = build_dataset(args.dataset, 'recommendation')
        self.train_hg, self.train_neg_hg, self.val_hg, self.test_hg = self.dataset.get_split()
        self.evaluator = Evaluator(args.seed)

    def get_loss_fn(self):
        return F.binary_cross_entropy_with_logits

    def evaluate(self, y_true, y_score, name):
        if name == 'ndcg':
            return self.evaluator.ndcg(y_true, y_score)

最后
在task/init.py中,增加一个新的实体到SUPPORTED_TASKS.

3.2 如何构建一个新的trainerflow

第一步:注册trainerflow
创建一个类,继承BaseFlow,并用@register_trainer(str)去注册trainerflow。

from openhgnn.trainerflow import BaseFlow, register_flow
@register_flow('Recommendation')
class Recommendation(BaseFlow):
    ...

第二步:实现方法
我们将函数train()声明为一个抽象方法。因此,train()必须被重写,否则trainerflow就无法实例化。下面给出了一个训练循环的示例。

...
class Recommendation(BaseFlow):
    def __init__(self, args=None):
        super(Recommendation, self).__init__(args)
        self.target_link = self.task.dataset.target_link
        self.model = build_model(self.model).build_model_from_args(self.args, self.hg)
        self.evaluator = self.task.get_evaluator(self.metric)

    def train(self,):
        for epoch in epoch_iter:
            self._full_train_step()
            self._full_test_step()

    def _full_train_step(self):
        self.model.train()
        logits = self.model(self.hg)[self.category]
        loss = self.loss_fn(logits[self.train_idx], self.labels[self.train_idx])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _full_test_step(self, modes=None, logits=None):
        self.model.eval()
        with torch.no_grad():
            loss = self.loss_fn(logits[mask], self.labels[mask]).item()
            metric = self.task.evaluate(pred, name=self.metric, mask=mask)
            return metric, loss

最终
在trainerflow/init.py中增加一个新的实体到SUPPORT_FLOWS

内容来源

  1. Developer_Guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/579512.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ROS】服务通信、话题通信的应用

Halo&#xff0c;这里是Ppeua。平时主要更新C语言&#xff0c;C&#xff0c;数据结构算法…感兴趣就关注我吧&#xff01;你定不会失望。 服务通信、话题通信的应用 0. 话题发布1.话题订阅2.服务调用3.话题通信与服务通信的比较 本章将来学习如何利用话题通信&#xff0c;服务…

【软件分析/静态分析】学习笔记02——中间表示Intermediate Representation

&#x1f517; 课程链接&#xff1a;李樾老师和谭天老师的&#xff1a;南京大学《软件分析》课程02&#xff08;Intermediate Representation&#xff09;_哔哩哔哩_bilibili 目录 第二章 Intermediate Representation 2.1 编译器与静态分析器的关系(Compilers & Static …

SpringCloudAlibaba(简介及核心组件使用)

微服务架构常见的问题 一旦采用微服务系统架构&#xff0c;就势必会遇到这样几个问题&#xff1a; 这么多小服务&#xff0c;如何管理他们&#xff1f;服务发现/服务注册---》注册中心 这么多小服务&#xff0c;他们之间如何通讯&#xff1f;Feign -> 基于 http 的微服务调…

使用【Python+Appium】实现自动化测试

一、环境准备 1.脚本语言&#xff1a;Python3.x IDE&#xff1a;安装Pycharm 2.安装Java JDK 、Android SDK 3.adb环境&#xff0c;path添加E:\Software\Android_SDK\platform-tools 4.安装Appium for windows&#xff0c;官网地址 Redirecting 点击下载按钮会到GitHub的…

使用golang 基于 OpenAI Embedding + qdrant 实现k8s本地知识库

使用golang 基于 OpenAI Embedding qdrant 实现k8s本地知识库 文章博客地址:套路猿-使用golang 基于 OpenAI Embedding qdrant 实现k8s本地知识库 流程 将数据集 通过 openai embedding 得到向量组装payload,存入 qdrant用户进行问题搜索,通过 openai embedding 得到向量,从…

“Jmeter WebSocket协议压测”,助你轻松应对高并发场景!

目录 引言 背景说明 步骤1&#xff1a;安装插件JMeter WebSocket Samplers 步骤2&#xff1a;采集器使用 步骤3&#xff1a;脚本执行 结语 引言 在当今高并发的网络环境下&#xff0c;WebSocket协议已经成为了最受欢迎的实时通信技术之一。然而&#xff0c;对于开发人员来…

CorelDRAW2023序列号及下载安装条件

始于1989年并不断推陈出新,致力为设计工作者提供更高效的设计工具&#xff01;CorelDRAW滋养并见证了一代设计师的成长&#xff01;在最短的时间内交付作品&#xff0c;CorelDRAW的智能高效会让你一见钟情&#xff01;CorelDRAW 全称“CorelDRAW Graphics Suite“&#xff0c;也…

Linux:命令tar、zip、unzip对文件或文件夹进行压缩与解压

Linux&#xff1a;命令tar、zip、unzip对文件或文件夹进行压缩与解压 .tar压缩操作&#xff1a; 创建要进行压缩的文件&#xff1a; 对文件进行压缩&#xff1a; 将三个文件压缩成text.tar文件&#xff0c;压缩到当前路径下(默认也是在当前路径) 对比体积&#xff1a; 发现&…

关于f-stack转发框架的几点分析思考

使用DPDK收包&#xff0c;想要用到TCP协议栈&#xff0c;可选的方案有linux原生的tun/tap口以及DPDK自带的KNI驱动&#xff0c;这两种都是通过将DPDK收到的报文注入到linux内核来使用TCP协议栈的功能&#xff0c;然后&#xff0c;用户态协议栈可以考虑开源的f-stack&#xff0c…

在页面使用富文本编译器

富文本编译器的选择 Editor.mdTinyMCESimpleMDECKEditor 还有一些&#xff0c;这里讲的是我用的TinyMCE 1、下载 下载地址&#xff1a;下载tiny | TinyMCE中文文档中文手册 下载开发版本&#xff0c;我下载的最新版 tinymce_6.4.2_dev.zip 将压缩包解压后可以看到下面目录&…

(哈希表 ) 202. 快乐数——【Leetcode每日一题】

❓202. 快乐数 难度&#xff1a;简单 编写一个算法来判断一个数 n 是不是快乐数。 「快乐数」 定义为&#xff1a; 对于一个正整数&#xff0c;每一次将该数替换为它每个位置上的数字的平方和。然后重复这个过程直到这个数变为 1&#xff0c;也可能是 无限循环 但始终变不到…

Groovy系列一 Groovy基础语法

目录 为什么要学习Groovy Groovy 介绍 Groovy 特点 Groovy 实战 动态类型 简单明了的list,map类型 在groovy世界任何东西都是对象 属性操作变得更容易 GString 闭包 委派&#xff1a;delegate Switch变得更简洁 元编程 强制类型检查 Elvis Operator 安全访问 为…

【五】设计模式~~~创建型模式~~~单例模式(Java)

【学习难度&#xff1a;★☆☆☆☆&#xff0c;使用频率&#xff1a;★★★★☆】 5.1. 模式动机 对于系统中的某些类来说&#xff0c;只有一个实例很重要&#xff0c;例如&#xff0c;一个系统中可以存在多个打印任务&#xff0c;但是只能有一个正在工作的任务&#xff1b;一…

一波三折,终于找到 src 漏洞挖掘的方法了【建议收藏】

0x01 信息收集 1、Google Hack 实用语法 迅速查找信息泄露、管理后台暴露等漏洞语法&#xff0c;例如&#xff1a; filetype:txt 登录 filetype:xls 登录 filetype:doc 登录 intitle:后台管理 intitle:login intitle:后台管理 inurl:admin intitle:index of /查找指定网站&…

C++:征服C指针:指针(二)

指针二 1. 指向数组的指针2. 多维数组三级目录 上一篇文章我们介绍了&#xff1a;什么是指针&#xff0c;指针常见的问题&#xff0c;本篇我们主要介绍 &#xff1a;指针与数组。 1. 指向数组的指针 int *p[n] : 指针数组&#xff0c; 它包括 n 个成员&#xff0c;每个成员都是…

探索Maven创建项目全过程(超详细~~~)

文章目录 1.Maven介绍2.Servlet介绍2.1 Servlet定义2.2 Servlet的主要任务 3.创建Servlet程序步骤3.1 创建项目3.2 引入依赖3.3 创建目录3.4编写代码3.5 打包程序3.6 部署程序3.7 验证结果 4.更方便的部署方式4.1.下载Tomcat插件4.2 配置Tomcat插件4.3运行项目 1.Maven介绍 Ma…

认识Tomcat

hi,大家好,今天为大家带来Tomcat的相关知识 &#x1f36d;1.Tomcat是什么 &#x1f36d;2.Tomcat的下载安装 &#x1f36d;3.Tomcat的目录结构 &#x1f36d;4.启动Tomcat &#x1f36d;5.部署博客系统到Tomcat &#x1f349;1.Tomcat是什么 我们之前也已经学了http,http…

【JAVAWEB】HTML的常见标签

目录 1.HTML结构 1.1认识HTML标签 1.2HTML文件基本结构 1.3标签层次结构 1.4快速生成代码框架 2.HTML常见标签 注释标签 标题标签&#xff1a;h1-h6 段落标签:p 换行标签&#xff1a;br 格式化标签 图片标签 超链接标签&#xff1a;a 表格标签 列表标签 表单标…

Windows 同时安装 MySQL5 和 MySQL8 版本

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是Rockey&#xff0c;不知名企业的不知名开发着 &#x1f525;如果感觉博主的文章还不错的话&#xff0c;请&#x1f44d;三连支持&#x1f44d;一下博主哦 &#x1f4dd;联系方式&#xff1a;he18339193956&#xff0c;…

MySQL 00 : MySQL_数据库shell登录时遇到的问题

问题1描述&#xff1a;输入链接数据块的命令提示 sh:mysgl:command not found 解决:第一步配置环境变量来解决 1、vim /etc/profile 2、末尾写入export PATH$PATH:/usr/local/mysql/bin 3、保存 4、执行 source /etc/profile 第二部 问题描述 Mac通过MAMP安装MySQL时&#…