葡萄书--图注意力网络

news2024/11/18 7:35:41

同质图中的注意力网络

注意力网络

GAT 核心就是对于每个顶点都计算其与邻居节点的注意力系数,通过注意力系数来聚合节点的特征,而非认为每个邻居节点的贡献程度都是相同的

单头注意力机制

为了计算注意力权重 eij​,我们首先将即节点特征 hi​ 和 hj​ 通过 W 映射到一个高维空间,然后将这两个向量拼接在一起,之后通过一个全连接层 afc​ 将拼接后的向量映射到一个标量,再应用 LeakyReLU 激活函数得到注意力系数 eij

,其中afc​ 的参数是被共享的
假设目标节点 i 有 k 个邻居节点 j,因此我们需要用 softmax 对得到的所有注意力系数进行归一化,就可以用其对邻居节点特征进行线性组合,得到最终的目标节点特征

多头注意力机制

使用 K 个 Wk 将原特征进行映射到 K 个不同的高维空间,然后重复上述计算单头注意力的步骤

最后concat或者加权组合起来

注意力网络代码

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        '''
        g : dgl的grapg实例
        in_dim : 节点embedding维度
        out_dim : attention编码维度
        '''
        self.g = g #dgl的一个graph实例
        self.fc = nn.Linear(in_dim, out_dim, bias=False) # 对节点进行通用的映射的fc
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False) # 计算edge attention的fc
        self.reset_parameters() # 参数初始化
        
    def edge_attention(self, edges):
                '''
                edges 的src表示源节点,dst表示目标节点
                '''
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1) # eq.1 里面的拼接操作
        a = self.attn_fc(z2) # eq.1 里面对e_{ij}的计算
        return {'e' : F.leaky_relu(a)} # 这里的return实际上等价于 edges.data['e'] =  F.leaky_relu(a),这样后面才能将这个 e 传递给 源节点 
      
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
    
    def message_func(self, edges):
        return {'z' : edges.src['z'], 'e' : edges.data['e']} #将前面 edge_attention 算出来的 e 以及 edges的源节点的 node embedding 都传给 nodes.mailbox 
    
    def reduce_func(self, nodes):
          # 通过 message_func 之后即可得到 源节点的node embedding 与 edges 的 e
        alpha = F.softmax(nodes.mailbox['e'], dim=1) # softmax归一化,得到 a
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1) # 根据a进行加权求和
        return {'h' : h}
    
    def forward(self, h):
        z = self.fc(h) # eq. 1 这里的h就是输入的node embedding
        self.g.ndata['z'] = z 
        self.g.apply_edges(self.edge_attention) # eq. 2
        self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4
        return self.g.ndata.pop('h') # 返回经过attention的 node embedding

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        '''
        g : dgl的grapg实例
        in_dim : 节点embedding维度
        out_dim : attention编码维度
        num_heads : 头的个数
        merge : 最后一层为'mean',其他层为'cat'
        '''
        self.heads = nn.ModuleList()
        for i in range(num_heads):
              # 这里简单粗暴,直接声明 num_heads个GATLayer达到 Multi-Head的效果
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
        
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        # 根据merge的类别来处理 Multi-Head的逻辑
        if self.merge == 'cat':
            return torch.cat(head_outs, dim=1)
        else:
            return torch.mean(torch.stack(head_outs))
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        '''
        g : dgl的grapg实例
        in_dim : 节点embedding维度
        hidden_dim : 隐层的维度
        out_dim : attention编码维度
        num_heads : 头的个数
        '''
        # 这里简简单单的写了一个两层的 MultiHeadGATLayer 
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
        
    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

结果可通过 t-SNE 降维来进行可视化

异质图中的注意力网络

异质图注意力网络

元路径及依据元路径定义的邻居

在异质图中会有非常复杂的节点之间的联系,但是这种联系并不全是有效的,所以通过定义元路径来定义一些有意义的连接方式

节点 i 在通过元路径生成的图中的邻居就是依据元路径定义的邻居

如图元路径可定义为MAM和MDM,便得到了依据元路径定义的邻居d

异质图注意力网络整体架构

  • (a)节点级注意力的目的是学习节点与基于元路径的邻居节点之间的重要性。
  • (b)语义级注意力的目的是学习不同元路径的重要性。
  • (c)通过接 MLP 输出节点 i 的预测 y~​i​。

节点级注意力

对每一条元路径形成的邻居节点子图进行一次同质图卷积网络 GAT 的注意力计算

语义级注意力

首先,初始化一个语义级别的可学习向量 q,然后用 q 和每个元路径的节点级特征的高维映射 WZiΦp​​ 求内积,并且平均了所有的节点 i 后,得到元路径 p 的语义级注意力系数

 

最后对进行softmax归一化后的注意力系数进行加权求和

异质图注意力网络代码

class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

    def forward(self, z):
        w = self.project(z).mean(0)                    # (M, 1)
        beta = torch.softmax(w, dim=0)                 # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)

        return (beta * z).sum(1)                       # (N, D * K)

class HANLayer(nn.Module):
    """
    HAN layer.

    Arguments
    ---------
    num_meta_paths : number of homogeneous graphs generated from the metapaths.
    in_size : input feature dimension
    out_size : output feature dimension
    layer_num_heads : number of attention heads
    dropout : Dropout probability

    Inputs
    ------
    g : list[DGLGraph]
        List of graphs
    h : tensor
        Input features

    Outputs
    -------
    tensor
        The output feature
    """
    def __init__(self, num_meta_paths, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(num_meta_paths):
            self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
                                           dropout, dropout, activation=F.elu))
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
        self.num_meta_paths = num_meta_paths

    def forward(self, gs, h):
        semantic_embeddings = []

        for i, g in enumerate(gs):
            semantic_embeddings.append(self.gat_layers[i](g, h).flatten(1))
        semantic_embeddings = torch.stack(semantic_embeddings, dim=1)                  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)                            # (N, D * K)
      
class HAN(nn.Module):
    def __init__(self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(HANLayer(num_meta_paths, in_size, hidden_size, num_heads[0], dropout))
        for l in range(1, len(num_heads)):
            self.layers.append(HANLayer(num_meta_paths, hidden_size * num_heads[l-1],
                                        hidden_size, num_heads[l], dropout))
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h, is_training=True):
        for gnn in self.layers:
            h = gnn(g, h)
            
        if is_training:
            return self.predict(h)
        else:
            return h

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1610580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

移动Web学习09-响应式布局bootstrap案例开发

3、综合案例-AlloyTeam移动全端 准备工作 HTML 结构 <title>腾讯全端</title> <link rel"shortcut icon" href"favicon.ico" type"image/x-icon"> <!-- 层叠性&#xff1a;咱们的css 要 层叠 框架的 --> <link rel&…

JAVA 项目<果园之窗>_2

上节主要是理论流程&#xff0c;这次直接用实际例子过一遍整个流程 目标是向数据库添加一个员工 上述是前端页面&#xff0c;点击保存 浏览器向我后端发送http请求 后端这一部分专门接收employee请求 在这里对http post请求进行转换成JAVA数据&#xff0c;并处理数据&#xff…

安装zlmediakit和wvp-pro

通过docker安装zlmediakit&#xff0c;并单独启动wvp-pro.jar - zlmediakit安装 zlmediakit安装比较依赖环境和系统配置&#xff0c;所以这里直接使用docker的方式来安装。 docke pull拉取镜像 docker pull zlmediakit/zlmediakit:master使用下边命令先运行起来 sudo docke…

Linux 下安装MySQL 5.7与 8.0详情

官网资源下载 1、检查 /tmp 临时目录权限&#xff0c;将 /tmp 目录及所有子目录和文件的权限设置为所有用户可读、写和执行 [rootlocalhost ~]# chmod -R 777 /tmp2、安装前检查相关依赖 2.1、查找是否安装了 libaio 包。libaio 是 Linux 下异步 I/O 库的一个实现&#xff0…

华为ensp中rip和ospf路由重分发 原理及配置命令

作者主页&#xff1a;点击&#xff01; ENSP专栏&#xff1a;点击&#xff01; 创作时间&#xff1a;2024年4月20日20点21分 路由重分发&#xff08;Route Redistribution&#xff09;是指路由器将从一种路由协议学习到的路由信息&#xff0c;通过另一种路由协议通告出去的功…

汇舟问卷:国外问卷调查做题完整步骤细节展示

大家好&#xff0c;我是汇舟问卷。想要做国外问卷调查&#xff0c;搭建国外的环境是必不可少的&#xff0c;今天给大家讲解一下搭建国外环境的完整步骤&#xff0c;只要按照步骤一步步操作就一定能把环境搭建出来​。 一、安装辅助软件 需要先下载对应的软件进行安装&#xf…

项目七:学会使用python爬虫解析库(小白大成级)

前期我们学会了怎么使用python爬虫请求库和解析库的简单应用和了解&#xff0c;同时能够对爬虫有一个较为清晰的体系&#xff0c;毕竟简单的爬虫基本上都是请求数据——解析数据——存储数据的大概流程。 那么回忆一下&#xff0c;请求库我们学的是requests模块&#xff0c;解…

【数据结构练习题】二叉树(1)——1.相同的树2.另一颗树的子树3.翻转二叉树4.平衡二叉树5.对称二叉树

♥♥♥♥♥个人主页♥♥♥♥♥ ♥♥♥♥♥数据结构练习题总结专栏♥♥♥♥♥ ♥♥♥♥♥上一章&#xff1a;队——1.用队实现栈2.用栈实现队♥♥♥♥♥ 文章目录 1.相同的树1.1题目描述1.2 思路分析1.3绘图分析1.4代码实现2.另一颗树的子树2.1问题描述2.2思路分析2.3绘图分析2.…

视觉SLAM学习打卡【11】-尾述

到目前为止&#xff0c;视觉SLAM14讲已经到了终章&#xff0c;历时一个半月&#xff0c;时间有限&#xff0c;有些地方挖掘的不够深入&#xff0c;只能在后续的学习中更进一步。接下来&#xff0c;会着手ORB-SLAM2的开源框架&#xff0c;同步学习C。 视觉SLAM学习打卡【11】-尾…

【linux】多路径|Multipath I/O 技术

目录 简略 详细 什么是多路径? Multipath安装与使用 安装 使用 Linux下multipath软件介绍 附录 配置文件说明 其他解 简略 略 详细 什么是多路径? 普通的电脑主机都是一个硬盘挂接到一个总线上&#xff0c;这里是一对一的关系。 而到了分布式环境&#xff0c;主机和存储网络连…

Linux使用Libevent库实现一个网页服务器---C语言程序

Web服务器 这一个库的实现 其他的知识都是这一个专栏里面的文章 实际使用 编译的时候需要有一个libevent库 gcc httpserv.c -o httpserv -levent实际使用的时候需要指定端口以及共享的目录 ./httpserv 80 .这一个函数会吧这一个文件夹下面的所有文件共享出去 实际的效果, 这…

Matlab之空间坐标系绘制平面图形

在空间直角坐标系中&#xff0c;绘制指定平面方程的图形 版本说明&#xff1a; 20240413_V1.01&#xff1a;更正代码错误&#xff0c;并修改输入参数类型&#xff08;测试用例得修改&#xff09; 20240413_V1.00&#xff1a;初始版本 一、平面方程 基本形式为&#xff1a;A…

[阅读笔记29][AgentStudio]A Toolkit for Building General Virtual Agents

这篇论文是24年3月提交的&#xff0c;提出了一个用于agent开发的全流程工具包。 作者提到目前agent开发主要有两个阻碍&#xff0c;一个是缺乏软件基础&#xff0c;另一个是缺乏在真实世界场景中进行评估。针对这两个阻碍&#xff0c;作者涉及了一个开发工具包&#xff0c;包括…

中颖51芯片学习8. ADC模数转换

中颖51芯片学习8. ADC模数转换 一、ADC工作原理简介1. 概念2. ADC实现方式3. 基准电压 二、中颖芯片ADC功能介绍1. 中颖芯片ADC特性2. ADC触发源&#xff08;1&#xff09;**软件触发**&#xff08;2&#xff09;**TIMER4定时器触发**&#xff08;3&#xff09;**外部中断2触发…

2024第八届图像、信号处理和通信国际会议 (ICISPC 2024)即将召开!

2024第八届图像、信号处理和通信国际会议 &#xff08;ICISPC 2024&#xff09;将于2024年7月19-21日在日本福冈举行。启迪思维&#xff0c;引领未来&#xff0c;ICISPC 2024的召开&#xff0c;旨在全球专家学者共襄盛举&#xff0c;聚焦图像信号&#xff0c;在图像中寻找美&am…

通用大模型研究重点之五:llama family

LLAMA Family decoder-only类型 LLaMA&#xff08;Large Language Model AI&#xff09;在4月18日公布旗下最大模型LLAMA3&#xff0c;参数高达4000亿。目前meta已经开源了80亿和700亿版本模型&#xff0c;主要升级是多模态、长文本方面工作。 模型特点&#xff1a;采用标准的…

企业监管工具:为何如此重要?

随着通信技术的发展&#xff0c;员工使用微信等即时通讯工具来进行工作沟通已经成为了常态。为了帮助企业有效地监管员工的工作微信使用情况&#xff0c;微信管理系统应运而生。 下面就一起来看看&#xff0c;它都有哪些功能吧&#xff01; 1、历史消息&#xff1a;洞察员工聊…

VMware设置Centos7静态ip

1、获取网段&#xff0c;子网掩码和网关 到此获取到的信息&#xff1a; 网段&#xff1a;192.168.204.128 ~ 192.168.204.254 子网掩码&#xff1a;255.255.255.0 网关IP&#xff1a;192.168.204.2 2、修改Centos系统的网络配置 使用命令vim /etc/sysconfig/network-scripts/…

一键搞定线性回归亚组森林图!快速生成顶级SCI论文的高清图!

现在亚组分析好像越来越流行&#xff0c;无论是观察性研究还是RCT研究&#xff0c;亚组分析一般配备森林图。 其实亚组分析的原理十分简单&#xff1a;它一般属于文章的附加内容&#xff0c;文章主体通过对全人群进行分析后&#xff0c;希望在亚组人群中进一步探索暴露与结局的…

DS:顺序表的实现

感谢各位友友的支持&#xff01;目前我的博客进行到了DS阶段&#xff0c;在此阶段首先会介绍一些数据结构相关的知识&#xff0c;然后再进行顺序表的学习。学习数据结构是为后面的通讯录项目打基础。 在学习数据结构之前&#xff0c;需要友友们掌握一些储备知识——结构体、指…