图神经网络教程之HAN-异构图模型

news2025/1/19 20:26:17

异构图

包含不同类型节点和链接的异构图

异构图的定义:节点类别数量和边的类别数量加起来大于2就叫异构图。

meta-path元路径的定义:连接两个对象的复合关系,比如,节点类型A和节点类型B,A-B-A和B-A-B都是一种元路径。

meta-path下的邻居节点的定义:如下图所示。

在这里插入图片描述

其中m1-a1-m2,m1-a3-m3都是一种meta-path,所以m1的邻居有m2、m3以及本身m1

在这里插入图片描述

节点级别的attention和语义级别的attention

节点级别:简单来说就是单种meta-path求得节点embeddings,比如对于M-D-M,Terminator2的embeddings通过M-D-M的元路径即可求的另一个M(Termintor)的embeddings。

语义级别:对于Terminator的embeddings不再是根据一种meta-path进行获取,而是根据两种meta-path进行权重的分配相加得到。

节点级别:

举例子:

在这里插入图片描述

如上图所示,对于异构图,一种meta-path为蓝-黄-蓝,对于节点x1-xa-x2,所以x1与x2通过meta-path元路径,同理每一对节点,构成上图中的第二个图的连接方式。

在这里插入图片描述

对于节点x1,与节点x2、x3、x6相连,所以x2、x3、x6都是节点x1的邻居节点,也就是公式2。

对于公式三,分子将i和j节点拼接在一起以后乘以一个可学习的参数然后再通过激活函数,再通过exp。分母就是他的邻居节点的。

对后求的节点级别下的embeddings。

语义级别:

简单来说语义级别就是多种meta-path呗,只需要把每种meta-path下面的求出来进行加权就可以了。

在这里插入图片描述

如上图所示,通过节点级别的求解方法,求出来对于每一种metapath下面的embeddings,然后最后进行加权求和。

知道了上面的HAN的原理,下面讲解一下model代码。

在讲解原理的时候分为语义级别和节点级别,在代码的时候会分为给定已经处理好的邻接矩阵和直接输入异构图。

异构图直接输入(异构图模型。):

需要将meta-path转化为邻接矩阵即元组形式。

实现了Heterogeneous Graph Attention Network(HAN)模型,用于处理异构图数据。HAN是一种深度学习模型,用于在异构图中进行节点分类任务

import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GATConv

首先,导入了PyTorch库以及用于图神经网络的相关模块。

class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()
        # input:[Node, metapath, in_size]; output:[None, metapath, 1]; 所有节点在每个meta-path上的重要性值
        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

这里定义了一个名为SemanticAttention的PyTorch模型类,它用于计算每个节点在不同元路径(metapath)上的重要性。SemanticAttention类有以下成员:

  • __init__方法:初始化模型。它接受输入特征的维度in_size以及可选的隐藏层维度hidden_size。在初始化过程中,它创建了一个神经网络模块self.project,该模块包括两个线性层和一个Tanh激活函数,最后一个线性层没有偏差。
    def forward(self, z):
        w = self.project(z).mean(0)#每个节点在metapath维度的均值; mean(0): 每个meta-path上的均值(/|V|); (MetaPath, 1)
        beta = torch.softmax(w, dim=0)       # 归一化   # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape) #  拓展到N个节点上的metapath的值   (N, M, 1)
        return (beta * z).sum(1)#(beta*z)=>所有节点,在metapath上的attention值;(beta*z).sum(1)=>节点最终的值(N,D*K)
  • forward方法:用于计算每个节点在不同元路径上的重要性。首先,将输入特征z通过self.project模块传递,然后计算每个元路径上的重要性均值w。接着,使用softmax函数对这些均值进行归一化,以获得每个元路径上的注意力权重beta。最后,将注意力权重与输入特征相乘,并对所有元路径求和,得到最终的节点表示。

这个SemanticAttention模块的目的是计算每个节点在不同元路径上的权重,以便后续的元路径级别的注意力聚合。

接下来,定义了另一个模型类HANLayer

class HANLayer(nn.Module):
    def __init__(self, num_meta_paths, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()
        self.gat_layers = nn.ModuleList()
        for i in range(num_meta_paths):  # meta-path Layers; 两个meta-path的维度是一致的
            self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
                                           dropout, dropout, activation=F.elu))
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)  # 语义attention; out-size*layers
        self.num_meta_paths = num_meta_paths

HANLayer类代表了HAN模型中的一个层次。每个HANLayer层包括以下成员:

  • __init__方法:初始化层。它接受以下参数:
    • num_meta_paths:元路径的数量。
    • in_size:输入特征的维度。
    • out_size:输出特征的维度。
    • layer_num_heads:每个GAT层中的注意力头的数量。
    • dropout:用于正则化的dropout率。

在初始化过程中,它首先创建了多个GATConv层,每个GATConv层对应一个元路径,这些层将用于图注意力聚合。然后,创建了一个SemanticAttention模块,用于计算每个节点在不同元路径上的语义级别的注意力。

接下来,定义了整个HAN模型类HAN

class HAN(nn.Module):
    def __init__(self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(HANLayer(num_meta_paths, in_size, hidden_size, num_heads[0], dropout)) # meta-path数量 + semantic_attention
        for l in range(1, len(num_heads)): # 多层多头,目前是没有
            self.layers.append(HANLayer(num_meta_paths, hidden_size * num_heads[l-1],
                                        hidden_size, num_heads[l], dropout))
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)  # hidden*heads, classes; HAN->classes

HAN类是整个HAN模型的定义。它接受以下参数:

  • num_meta_paths:元路径的数量。
  • in_size:输入特征的维度。
  • hidden_size:隐藏层的维度。
  • out_size:输出特征的维度(通常是类别数量)。
  • num_heads:一个列表,指定每个HANLayer层中的注意力头数量。
  • dropout:用于正则化的dropout率。

在初始化过程中,它首先创建了多个HANLayer层,每个HANLayer层包括一个或多个GATConv层和一个SemanticAttention层。

输入处理好的异构图,即邻接矩阵(普通图模型。):

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GATConv

首先,导入了必要的库和模块。

class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

这里定义了一个名为SemanticAttention的PyTorch模型类,它用于计算每个节点在不同元路径上的语义级别的重要性。和第一个代码段的SemanticAttention类相似,这个类也包括以下成员:

  • __init__方法:初始化模型。它接受输入特征的维度in_size以及可选的隐藏层维度hidden_size。在初始化过程中,它创建了一个神经网络模块self.project,该模块包括两个线性层和一个Tanh激活函数,最后一个线性层没有偏差。
    def forward(self, z):
        w = self.project(z).mean(0)                    # (M, 1)
        beta = torch.softmax(w, dim=0)                 # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)

        return (beta * z).sum(1)                       # (N, D * K)
  • forward方法:用于计算每个节点在不同元路径上的语义级别的重要性。首先,将输入特征z通过self.project模块传递,然后计算每个元路径上的语义级别的均值权重w。接着,使用softmax函数对这些均值进行归一化,得到每个元路径上的注意力权重beta,将这些权重与输入特征相乘,并对所有元路径求和,得到最终的节点表示。

接下来,定义了另一个模型类HANLayer,它代表HAN模型中的一个层次。

class HANLayer(nn.Module):

    def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_paths)):
            self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
                                           dropout, dropout, activation=F.elu,
                                           allow_zero_in_degree=True))
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
        self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)  # 将meta-path转换成元组形式

        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        semantic_embeddings = []

        if self._cached_graph is None or self._cached_graph is not g:  # 第一次,建立一张metapath下的异构图
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path in self.meta_paths:
                self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph(
                        g, meta_path)  # 构建异构图的邻居;
        # self._cached_coalesced_graph 多个metapath下的异构图
        for i, meta_path in enumerate(self.meta_paths):
            new_g = self._cached_coalesced_graph[meta_path]  # meta-path下的节点邻居图
            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))   # 图attention
        semantic_embeddings = torch.stack(semantic_embeddings, dim=1)                  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)                            # (N, D * K)

HANLayer类包括以下主要部分:

  • __init__方法:初始化HAN层,它包括多个GATConv层以及一个语义注意力模块。每个GATConv层对应一个元路径,用于处理节点在该元路径上的信息。语义注意力模块用于计算节点在不同元路径上的语义级别的注意力。

  • forward方法:执行HAN层的前向传播。对于每个元路径,首先获取该元路径的邻居图,然后通过GATConv层计算节点的注意力表示。最后,通过语义注意力模块将不同元路径上的表示进行加权求和,得到最终的节点表示。

最后,定义了整个HAN模型类HAN

class HAN(nn.Module):
    def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout))
        for l in range(1, len(num_heads)):
            self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l-1],
                                        hidden_size, num_heads[l], dropout))
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

HAN类定义了整个HAN模型,包括多个HANLayer层以及最后的预测层。

  • __init__方法:初始化HAN模型,它包括多个HANLayer层,每个HANLayer层用于处理一个元路径的信息。最后,添加一个线性预测层,将最终的节点表示映

射到输出特征(通常是类别数量)。

  • forward方法:执行HAN模型的前向传播。它依次通过多个HANLayer层来计算最终的输出,每个HANLayer层都包括元路径信息的处理和注意力聚合。

训练代码train

训练代码就是常规的套路。

  1. 引入必要的库和模块:

    • 导入了PyTorch库和sklearn库,用于深度学习和评估模型性能。
    • 导入了自定义的load_dataEarlyStopping函数,以及其他必要的模块。
  2. score函数:

    • 这个函数用于计算模型的性能指标,包括准确率(accuracy)、微平均F1分数(micro_f1),和宏平均F1分数(macro_f1)。
    • 它接受模型的预测结果(logits)和真实标签(labels),然后计算这些性能指标。
    • 准确率表示正确分类的样本比例,微平均F1分数和宏平均F1分数是一种综合的评估指标,用于度量分类模型的性能。
  3. evaluate函数:

    • 这个函数用于评估模型在验证集上的性能。
    • 它接受模型(model)、图数据(g)、特征数据(features)、标签数据(labels)、掩码数据(mask),以及损失函数(loss_func)作为输入。
    • 在评估过程中,模型处于评估模式(model.eval()),不会更新梯度。
    • 通过模型预测验证集上的结果,并计算损失、准确率、微平均F1分数和宏平均F1分数。
    • 最后返回这些评估指标。
  4. main函数:

    • 这是主要的训练和评估逻辑所在的函数。
    • 首先,加载数据(包括图数据、特征数据、标签数据等)并将其移动到指定的计算设备(CPU或GPU)上。
    • 根据参数args中的'hetero'标志,选择不同的模型和数据处理方式。如果'hetero'为True,则使用异构图模型;否则,使用普通图模型。
    • 定义了模型的损失函数、优化器和早停(EarlyStopping)对象。
    • 开始训练循环,每个epoch进行一次训练和验证。在训练过程中,计算损失、准确率和F1分数等指标,并打印出来。如果验证集上的性能不再提升,会触发早停(early stopping)。
    • 最后,在测试集上评估模型的性能,并打印出测试集上的损失、准确率、微平均F1分数和宏平均F1分数。
  5. if __name__ == '__main__': 部分:

    • 这个部分用于设置命令行参数,并调用main函数来运行训练和评估过程。
    • 可以通过命令行传递参数来配置模型的训练和数据处理方式。

rlyStopping)对象。

  • 开始训练循环,每个epoch进行一次训练和验证。在训练过程中,计算损失、准确率和F1分数等指标,并打印出来。如果验证集上的性能不再提升,会触发早停(early stopping)。
  • 最后,在测试集上评估模型的性能,并打印出测试集上的损失、准确率、微平均F1分数和宏平均F1分数。
  1. if __name__ == '__main__': 部分:

    • 这个部分用于设置命令行参数,并调用main函数来运行训练和评估过程。
    • 可以通过命令行传递参数来配置模型的训练和数据处理方式。

总体来说,这段代码实现了一个用于异构图数据或普通图数据的节点分类任务的训练和评估流程。它加载数据、选择模型、进行训练和验证,最后在测试集上评估模型性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/967040.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[C++] STL_list常用接口的模拟实现

文章目录 1、list的介绍与使用1.1 list的介绍1.2 list的使用 2、list迭代器3、list的构造4、list常用接口的实现4.1 list capacity4.2 插入删除、交换、清理4.2.1 insert任意位置插入4.2.2 push_front头插4.2.3 push_back尾插4.2.4 erase任意位置删除4.2.5 pop_front头删4.2.6 …

Keil 编译 Debug

# 头文件无法导入进来 # 导入头文件,只有函数声明,但缺少函数实现 已经导入了air32f10x_gpio.h但是没有导入 .c,就导致 编译出错出现undefined symbol (某个函数),这时候按照下面的操作,导入外设模块就好。

PQUEUE - Printer Queue

题目描述 The only printer in the computer science students union is experiencing an extremely heavy workload. Sometimes there are a hundred jobs in the printer queue and you may have to wait for hours to get a single page of output. Because some jobs are …

pip切换源

pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

CLFS信息泄露漏洞CVE-2023-28266分析

引用 这篇文章的目的是介绍今年4月发布的CLFS信息泄露漏洞CVE-2023-28266分析. 文章目录 引用简介CVE-2023-28266漏洞分析CVE-2023-28266调试过程漏洞复现相关引用参与贡献 简介 文章结合了逆向代码和调试结果分析了CVE-2023-28266漏洞利用过程和漏洞成因. CVE-2023-28266漏洞…

两个线程同步执行:解决乱箭穿心(STL/Windows/Linux)

C自学精简教程 目录(必读) C并发编程入门 目录 多线程同步 线程之间同步是指线程等待其他线程执行完某个动作之后再执行(本文情况)。 线程同步还可以是像十字路口的红绿灯一样,只允许一个方向的车同行,其他方向的车等待。 本…

UART串口Shell软硬件模型分析总结

文章目录 层次一、最底层逻辑配置交互----如何从Uart硬件读写单个字节数据层次二、抽象串口软件模块交互----基于串口对接输入输出流 和 Printf适配层次三、类似Shell封装抽象交互----基于串口交互命令行界面(命令解析、补全、修改、记录)case1 依次输入…

自建音乐服务器Navidrome之一

这里写自定义目录标题 1.1 官方网站 2. Navidrome 简介2.1 简介2.2 特性 3. 准备工作4. 视频教程5. 界面演示5.1 初始化页5.2 专辑页 前言 之前给大家介绍过 Koel 音频流服务,就是为了解决大家的这个问题:下载下来的音乐,只能在本机欣赏&…

上海的正西边有哪些城市

背景 上海一路向西,来一趟拉萨之行,那么上海出现,所经过的那么多城市,哪些是在上海的正西边呢? 画一幅地图 基于这个背景需求,我们需要拿来一幅地图,一看便知。下面的python代码生成了一幅地…

通信原理板块——平稳随机过程

微信公众号上线,搜索公众号小灰灰的FPGA,关注可获取相关源码,定期更新有关FPGA的项目以及开源项目源码,包括但不限于各类检测芯片驱动、低速接口驱动、高速接口驱动、数据信号处理、图像处理以及AXI总线等 1、平稳随机过程的定义 (1)严平稳随…

UE4 显示遮挡物体

SceneDepth是你相机能够看见的物体的深度距离 CustomDepth是你相机包括看不见被遮挡的物体的深度距离 如果CustemDepth比SceneDepth的距离相等,那么就是没有被遮挡的物体,如果被遮挡那么就是CustemDepth比SceneDepth深度距离远,然后再做对应…

PYTHON知识点学习-循环语句

🚀write in front🚀 🔎大家好,我是Aileen★。希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 🆔本文由 Aileen_0v0★ 原创 CSDN首发🐒 如需转载还…

【机器学习】线性回归

Model Representation 1、问题描述2、表示说明3、数据绘图4、模型函数5、预测总结附录 1、问题描述 一套 1000 平方英尺 (sqft) 的房屋售价为300,000美元,一套 2000 平方英尺的房屋售价为500,000美元。这两点将构成我们的数据或训练集。面积单位为 1000 平方英尺&a…

Swift 如何从图片数据(Data)检测原图片类型?

功能需求 如果我们之前把图片对应的数据(Data)保持在内存或数据库中,那么怎么从 Data 对象检测出原来图片的类型呢? 如上图所示:我们将 11 张不同类型的图片转换为 Data 数据,然后从 Data 对象正确检测出了原图片类型。 目前,我们的代码可以检测出 jpeg(jpg), tiff,…

WebRTC 安全之一

WebRTC 的安全需要满足三个基本需求 Authentication 用户访问需要认证Authorization 用户访问需要授权Audit 用户的访问应该可被追踪和审查 其中前两项也可以归结为 CIA Confidentiality 机密性:信息需要保密, 访问权限也需要控制Integrity 完整性&#…

Spring Cloud集成Nacos配置中心/注册中心

Spring Cloud版本 2021.0.5 Spring Cloud Alibaba版本 2021.0.5.0 Spring Boot版本 2.7.10 pom文件 需要放在依赖管理的pom文件 <dependencyManagement><dependencies><!-- spring boot依赖 --><dependency><groupId>org.springframewor…

2023-9-3 试除法判定质数

题目链接&#xff1a;试除法判定质数 #include <iostream>using namespace std;bool is_prime(int n) {if(n < 2) return false;for(int i 2; i < n / i; i){if(n % i 0) return false;}return true; }int main() {int n;cin >> n;while(n--){int x;cin &g…

git大文件推送报错

报错信息 不多掰扯&#xff0c;直接上报错信息和截图 Delta compression using up to 8 threadsRPC failde; HTTP 413 curl 22 The requested URL returned error: 413 Request Entity Too Large从以上的报错信息不难看出推送仓库的时候&#xff0c;请求体过大&#xff0c;为…

C++ do...while 循环

不像 for 和 while 循环&#xff0c;它们是在循环头部测试循环条件。do…while 循环是在循环的尾部检查它的条件。 do…while 循环与 while 循环类似&#xff0c;但是 do…while 循环会确保至少执行一次循环。 语法 C 中 do…while 循环的语法&#xff1a; do {statement(s…

AD16 基础应用技巧(一些 “偏好“ 设置)

1. 修改铺铜后自动更新铺铜 AD16 铺铜 复制 自动变形 偏好设置 将【DXP】中的【参数选择】。 将【PCB Editor】中的【General】&#xff0c;然后勾选上【Repour Polygons After Modification】。 2. PCB直角走线处理与T型滴泪 一些没用的AD技巧——AD PCB直角走线处理与…