DGL库之HGTConv的使用

news2024/10/11 0:22:59

DGL库之HGTConv的使用

  • 论文地址和异构图构建教程
  • HGTConv语法格式
  • HGTConv的使用

论文地址和异构图构建教程

论文地址:https://arxiv.org/pdf/2003.01332
异构图构建教程:异构图构建
异构图转同构图:异构图转同构图

HGTConv语法格式

dgl.nn.pytorch.conv.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes, dropout=0.2, use_norm=False)

参数说明:

  • in_size (int): 输入节点特征的大小。
  • head_size (int): 输出头的大小。输出节点特征的大小为 head_size * num_heads。
  • num_heads (int): 头的数量。输出节点特征的大小为 head_size * num_heads。
  • num_ntypes (int): 节点类型的数量。
  • num_etypes (int): 边类型的数量。
  • dropout (可选, float): dropout 比率,用于防止过拟合。
  • use_norm (可选, bool): 如果为 True,则在输出节点特征上应用层归一化。
forward(g, x, ntype, etype, *, presorted=False)

参数说明:

  • g (DGLGraph): 输入的图对象。

  • x (torch.Tensor): 一个 2D 张量,表示节点特征。其形状应为 (num_nodes, in_size),num_nodes 是节点数量,in_size 是输入特征的维度。

  • ntype (torch.Tensor): 一个 1D 整数张量,表示节点类型。其形状应为 (num_nodes,),对应每个节点的类型索引。

  • etype (torch.Tensor): 一个 1D 整数张量,表示边类型。其形状应为 (num_edges,),对应每条边的类型索引。

  • presorted (bool, 可选): 指示输入图的节点和边是否已经按照类型排序。如果输入图是预排序的,则前向传播可能会更快。通过调用 to_homogeneous()创建的图会自动满足此条件。也可以使用 reorder_graph() 方法手动重新排序节点和边。

返回值:

  • 返回的新节点特征: 返回的特征是一个 2D 张量,其形状为 (num_nodes, head_size * num_heads),表示经过HGTConv 处理后的新节点特征,返回的张量类型为 torch.Tensor。

HGTConv的使用

使用的异构图如下:
在这里插入图片描述
在使用HGTConv时,一定要使用dgl.to_homogeneous将异构图转为同构图,否则不能使用,代码如下:

import dgl
import torch
import torch.nn as nn
import dgl.nn.pytorch

# 定义一个简单的异构图
def create_hetero_graph():
    # 定义两个类型的节点:drug(药物)和 disease(疾病)
    data_dict = {
        ('drug', 'd_interacts', 'drug'): (torch.tensor([0, 1]), torch.tensor([1, 2])),  # 药物间的相互作用
        ('drug', 'g_interacts', 'gene'): (torch.tensor([0, 1]), torch.tensor([2, 3])),  # 药物与基因间的相互作用
        ('drug', 'treats', 'disease'): (torch.tensor([1]), torch.tensor([2]))           # 药物与疾病的关系
    }

    # 创建一个异构图
    hetero_graph = dgl.heterograph(data_dict)

    # 设置节点和边的特征
    hetero_graph.nodes['drug'].data['h'] = torch.ones(3, 320)  # 假设药物特征是320维的
    hetero_graph.nodes['disease'].data['h'] = torch.zeros(3, 320)  # 假设疾病特征是320维的
    hetero_graph.nodes['gene'].data['h'] = torch.ones(4, 320)  # 假设基因特征是320维的
    return hetero_graph

# 定义一个HGT模型类
class HGTModel(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads, num_layers, num_node_types, num_edge_types, dropout=0.2):
        super(HGTModel, self).__init__()
        # 使用 dgl.nn.pytorch.conv.HGTConv 初始化 HGT 卷积层
        self.layers = nn.ModuleList()  # 创建一个空的层列表
        for _ in range(num_layers):
            layer = dgl.nn.pytorch.conv.HGTConv(
                in_dim,  # 输入维度
                out_dim,  # 输出维度
                num_heads,  # 注意力头的数量
                num_node_types,  # 节点类型数量
                num_edge_types,  # 边类型数量
                dropout=dropout  # dropout比率
            )
            self.layers.append(layer)  # 将层添加到列表中

    def forward(self, g):
        with g.local_scope():  # 创建一个局部作用域,‌确保对图的操作不会影响原始图。‌
            for layer in self.layers:
                # 使用HGTConv层进行卷积操作
                h = layer(g, g.ndata['h'], g.ndata['_TYPE'], g.edata['_TYPE'], presorted=True)
                g.ndata['h'] = h  # 更新节点特征
            return g.ndata['h']  # 返回最后一层的节点特征

# 创建一个异构图
hetero_graph = create_hetero_graph()

print('异构图为:\n',hetero_graph)  # 输出异构图的信息
# 将异构图转换为同构图
homogeneous_graph = dgl.to_homogeneous(hetero_graph, ndata=['h'])
print(f"节点特征矩阵为:\n{homogeneous_graph.ndata['h']}")  # 打印节点特征的类型

# 创建模型并移动到 CPU 设备
hgt_model = HGTModel(in_dim=320, out_dim=80, num_heads=4, num_layers=2,
                     num_node_types=3, num_edge_types=3, dropout=0.3).to(torch.device('cpu'))

# 前向传播
output_features = hgt_model(homogeneous_graph)

print("更新后的特征:\n", output_features)  # 输出特征的形状

结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2203628.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

极客兔兔Gee-Cache Day7

protobuf配置: 从 Protobuf Releases 下载最先版本的发布包安装。解压后将解压路径下的 bin 目录 加入到环境变量即可。 如果能正常显示版本,则表示安装成功。 $ protoc --version libprotoc 3.11.2在Golang中使用protobuf,还需要protoc-g…

【单链表的模拟实现Java】

【单链表的模拟实现Java】 1. 了解单链表的功能2. 模拟实现单链表的功能2.1 单链表的创建2.2 链表的头插2.3 链表的尾插2.3 链表的长度2.4 链表的打印2.5 在指定位置插入2.6 查找2.7 删除第一个出现的节点2.8 删除出现的所有节点2.9 清空链表 3. 正确使用模拟单链表 1. 了解单链…

重头开始嵌入式第四十八天(Linux内核驱动 linux启动流程)

目录 什么是操作系统? 一、管理硬件资源 二、提供用户接口 三、管理软件资源 什么是操作系统内核? 一、主要功能 1. 进程管理: 2. 内存管理: 3. 设备管理: 4. 文件系统管理: 二、特点 什么是驱动…

WebGoat JAVA反序列化漏洞源码分析

目录 InsecureDeserializationTask.java 代码分析 反序列化漏洞知识补充 VulnerableTaskHolder类分析 poc 编写 WebGoat 靶场地址:GitHub - WebGoat/WebGoat: WebGoat is a deliberately insecure application 这里就不介绍怎么搭建了,可以参考其他…

yq 工具

文章目录 yq命令快速 Recipes查找数组中的项目查找并更新数组中的项目深度修剪一棵树对数组中的项目进行多次或复杂的更新按字段对数组进行排序 OperatorsOmitOmit keys from mapOmit indices from array DeleteDelete entry in mapDelete nested entry in mapDelete entry in …

【重学 MySQL】六十三、唯一约束的使用

【重学 MySQL】六十三、唯一约束的使用 创建表时定义唯一约束示例 在已存在的表上添加唯一约束示例 删除唯一约束示例 复合唯一约束案例背景创建表并添加复合唯一约束插入数据测试总结 特点注意事项 在 MySQL 中,唯一约束(UNIQUE Constraint)…

butterfly主题留言板 报错记录 未解决

新建留言板,在博客根目录执行下面的命令 hexo new page messageboard 在博客/source/messageboard的文件夹下找到index.md文件并修改 --- title: 留言板 date: 2018-01-05 00:00:00 type: messageboard ---找到butterfly主题下的_config.yml文件 把留言板的注释…

基于springboot+小程序的智慧物流管理系统(物流1)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1、项目介绍 基于springboot小程序的智慧物流管理系统实现了管理员、司机及用户。 1、管理员实现了司机管理、用户管理、车辆管理、商品管理、物流信息管理、基础数据管理、论坛管理、公告信息管理等。…

帮助自闭症孩子融入社会,寄宿学校是明智选择

在广州这座充满活力与温情的城市,有一群特殊的孩子,他们被称为“星星的孩子”——自闭症儿童。自闭症,这个让人既陌生又熟悉的名词,背后承载的是无数家庭的辛酸与希望。对于自闭症儿童来说,融入社会、与人交流、理解世…

【Linux第一弹】- 基本指令

🌈 个人主页:白子寰 🔥 分类专栏:重生之我在学Linux,C打怪之路,python从入门到精通,数据结构,C语言,C语言题集👈 希望得到您的订阅和支持~ 💡 坚持…

blender 记一下lattice

这个工具能够辅助你捏形状 这里演示如何操作BOX shift A分别创建俩对象一个BOX 一个就是lattice对象 然后在BOX的修改器内 创建一个叫做lattice的修改器 然后指定object为刚刚创建的lattice对象 这样就算绑定好了 接下来 进入lattice的编辑模式下 你选取一个点进行运动&#…

量化交易与基础投资工具介绍

🌟作者简介:热爱数据分析,学习Python、Stata、SPSS等统计语言的小高同学~🍊个人主页:小高要坚强的博客🍓当前专栏:Python之机器学习《Python之量化交易》Python之机器学习🍎本文内容…

谈谈留学生毕业论文如何分析问卷采访数据

留学生毕业论文在设计好采访问题并且顺利进行了采访之后,我们便需要将得到的采访答案进行必要的分析,从而得出一些结论。我们可以通过这些结论回答研究问题,或者提出进一步的思考等等。那么我们应当如何分析采访数据呢?以下有若干…

python3开头如何设置utf-8

编码格式1&#xff1a; 在源文件第一行或者第二行定义&#xff1a; # coding<encoding name> 例如&#xff1a; # codingutf-8 编码格式2&#xff1a;&#xff08;这种最流行&#xff09; 格式如下&#xff1a; #!/usr/bin/python # -*- coding: <encoding name>…

信息安全工程师(43)入侵检测概述

一、定义与目的 入侵检测&#xff08;Intrusion Detection&#xff09;是指通过对行为、安全日志、审计数据或其他网络上可以获得的信息进行操作&#xff0c;检测到对系统的闯入或闯入的企图。其主要目的是确保网络安全和信息安全&#xff0c;保护个人和机构的敏感数据免受未经…

论文阅读:Split-Aperture 2-in-1 Computational Cameras (二)

Split-Aperture 2-in-1 Computational Cameras (一) Coded Optics for High Dynamic Range Imaging 接下来&#xff0c;文章介绍了二合一相机在几种场景下的应用&#xff0c;首先是高动态范围成像&#xff0c;现有的快照高动态范围&#xff08;HDR&#xff09;成像工作已经证…

FreeRTOS——任务创建(静态、动态创建)、任务删除以及内部实现剖析

任务创建和删除的API函数 任务的创建和删除本质就是调用FreeRTOS的API函数 API函数描述xTaskCreate()动态方式创建任务xTaskCreateStatic()静态方式创建任务vTaskDelete()删除任务 动态创建任务&#xff1a;任务的任务控制块以及任务的栈空间所需的内存&#xff0c;均有FreeR…

考研代码题:10.10 汉诺塔 爬楼梯 取球 猴子吃桃

汉诺塔 C语言 - 汉诺塔详解&#xff08;最简单的方法&#xff0c;进来看看就懂&#xff09;_汉诺塔c语言程序详解-CSDN博客 #include <stdio.h>void move(char begin,char end){printf("%c->%c\n",begin,end); } //begin开始杆&#xff0c;help辅助杆&am…

BUU刷题-Pwn-axb_2019_mips(MIPS跳转bss段执行shellcode)

解题所涉知识点&#xff1a; 泄露或修改内存数据&#xff1a; 堆地址&#xff1a;栈地址&#xff1a;libc地址&#xff1a;BSS段地址&#xff1a; 劫持程序执行流程&#xff1a;MIPS_ROP 获得shell或flag&#xff1a;[[MIPS_Shellcode]] && [[MIPS劫持RA寄存器]] 题…

开源文件管理工具File Browser本地部署并一键发布公网远程传输文件

文章目录 前言1.下载安装File Browser2.启动访问File Browser3.安装cpolar内网穿透3.1 注册账号3.2 下载cpolar客户端3.3 登录cpolar web ui管理界面3.4 创建公网地址 4.固定公网地址访问 前言 File Browser是一个开源的文件管理器和文件共享工具&#xff0c;它可以帮助用户轻…