图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】

news2024/10/2 12:28:44

文章目录

    • GAT
    • 代码实现【PyGAT】
      • GraphAttentionLayer【一个图注意力层实现】
      • 用上面实现的单层网络测试
      • 加入Multi-head机制的GAT
      • 对数据集Cora的处理
      • csr_matrix()处理稀疏矩阵
      • encode_onehot()对label编号
      • build graph
      • 邻接矩阵构造
    • GAT的推广

GAT

题:Graph Attention Networks
摘要:
提出了图形注意网络(GAT) ,这是一种基于图结构数据的新型神经网络结构,利用掩蔽的自我注意层来解决基于图卷积或其近似的先前方法的缺点。通过叠加层,节点能够参与其邻域的特征,我们能够(隐式地)为邻域中的不同节点指定不同的权重,而不需要任何代价高昂的矩阵操作(如反演) ,或者依赖于预先知道图的结构。通过这种方法,我们同时解决了基于谱的图形神经网络的几个关键问题,并使我们的模型容易地适用于归纳和转导问题。我们的 GAT 模型已经实现或匹配了四个已建立的转导和归纳图基准的最新结果: Cora,Citeseer 和 Pubmed 引用网络数据集,以及protein-protein interaction dataset(其中测试图在训练期间保持不可见)。

Paper with code 网址,可找到对应论文和github源码,原论文使用TensorFlow实现,本篇主要对Pytorch版本的 PyGAT附详细注释帮助理解和测试。

GitHUb: keras版本实现
Pytorch版本实现 PyGAT

在这里插入图片描述
在这里插入图片描述

截图及下文代码注释参考自视频:GAT详解及代码实现

视频中的eij的实现与源码不同,视频中是先拼接两个W,再与a乘
源码在_prepare_attentional_mechanism_input()函数中先分别与a乘,再拼接

代码实现【PyGAT】

在PyGAT :

  • layers.py中定义Simple GAT layer实现(GraphAttentionLayer)和Sparse version GAT layer实现(SpGraphAttentionLayer)。
  • models.py 实现两个版本加入Multi-head机制
  • trains.py 使用model定义的GAT构建模型进行训练,使用cora数据集

GraphAttentionLayer【一个图注意力层实现】

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features#结点向量的特征维度
        self.out_features = out_features#经过GAT之后的特征维度
        self.alpha = alpha#dropout参数
        self.concat = concat#LeakyReLU参数

        # 定义可训练参数,即论文中的W和a
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)# xavier初始化
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)# xavier初始化

        # 定义leakyReLU激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        '''
        adj图邻接矩阵,维度[N,N]非零即一
        h.shape: (N, in_features), self.W.shape:(in_features,out_features)
        Wh.shape: (N, out_features)
        '''
        Wh = torch.mm(h, self.W) # 对应eij的计算公式
        e = self._prepare_attentional_mechanism_input(Wh)#对应LeakyReLU(eij)计算公式

        zero_vec = -9e15*torch.ones_like(e)#将没有链接的边设置为负无穷
        attention = torch.where(adj > 0, e, zero_vec)#[N,N]
        # 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留
        # 否则需要mask设置为非常小的值,因为softmax的时候这个最小值会不考虑
        attention = F.softmax(attention, dim=1)# softmax形状保持不变[N,N],得到归一化的注意力全忠!
        attention = F.dropout(attention, self.dropout, training=self.training)# dropout,防止过拟合
        h_prime = torch.matmul(attention, Wh)#[N,N].[N,out_features]=>[N,out_features]

        # 得到由周围节点通过注意力权重进行更新后的表示
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        # Wh.shape (N, out_feature)
        # self.a.shape (2 * out_feature, 1)
        # Wh1&2.shape (N, 1)
        # e.shape (N, N)
        # 先分别与a相乘再进行拼接
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
        # broadcast add
        e = Wh1 + Wh2.T
        return self.leakyrelu(e)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

用上面实现的单层网络测试

x = torch.randn(6,10)
adj=torch.tensor([[0,1,1,0,0,0],
                  [1,0,1,0,0,0],
                  [1,1,0,1,0,0],
                  [0,0,1,0,1,1],
                  [0,0,0,1,0,0,],
                  [0,0,0,1,1,0]])
my_gat = GraphAttentionLayer(10,5,0.2,0.2)
print(my_gat(x,adj))
输出:
tensor([[-0.2965,  2.8110, -0.6680, -0.9643, -0.9882],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.4981, -0.7515,  1.1159,  0.3546,  1.3592],
        [ 0.4679,  1.7208,  0.3084, -0.5331, -0.1291],
        [-0.4375, -0.8778,  1.1767, -0.5869,  1.5154],
        [-0.2164, -0.5897,  0.4988, -0.3125,  0.6423]], grad_fn=<EluBackward>)

加入Multi-head机制的GAT

用不同head捕捉不同特征,使模型有更好的拟合能力。

class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        # 加入Multi-head机制
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)

对数据集Cora的处理

在这里插入图片描述
数据集中两个文件,cites:比如上图11行:编号25和编号1114331的文章
content文件:如下图,每篇文章的id、features及类别
在这里插入图片描述

csr_matrix()处理稀疏矩阵

utils.py中对数据进行的处理

#数据是稀疏的,csr_matrix操作从行开始将1的位置取出来,对数据进行压缩

features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
labels = encode_onehot(idx_features_labels[:, -1])

在这里插入图片描述

encode_onehot()对label编号

有7个类别,通过classes_dict是7*7的对角阵把每个类别映射成不同向量,对所有label进行编号,再将编号转换为one_hot向量
在这里插入图片描述

def encode_onehot(labels):
    # The classes must be sorted before encoding to enable static class encoding.
    # In other words, make sure the first class always maps to index 0.
    classes = sorted(list(set(labels)))
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)
    return labels_onehot

build graph

见注释:

# build graph
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)#获取所有文章id
    idx_map = {j: i for i, j in enumerate(idx)}#按文章数目,对id重新映射
    # 读取数据集中文章和文章直接的引用关系
    edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), dtype=np.int32)
    # 根据idx_map,将文章引用关系也重新映射
    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), dtype=np.int32).reshape(edges_unordered.shape)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape=(labels.shape[0], labels.shape[0]), dtype=np.float32)

    # build symmetric adjacency matrix 生成邻接矩阵
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

    features = normalize_features(features)
    adj = normalize_adj(adj + sp.eye(adj.shape[0]))

邻接矩阵构造

csr_matrix()只记录了(0,1)1,忽略了(1,0)1。所以需要coo_matrix()操作!才能还原出无向图的邻接矩阵!
在这里插入图片描述

本文的一些代码注释及截图还可见视频

一个拓展:

GAT的推广

GAT的推广
GAT仅仅是应用在了单层图结构网络上,我们是否可以将它推广到多层网络结构呢?

这里我们假设一个有N层网络的结构,每层网络都定义了相同的节点,但是节点之间的关系有所差异。举一个简单的例子,假设有一个用户关系型网络,每层网络的节点都被定义成了网络中的用户,网络的第一层视图的关系可以定义为,两个用户之间是否具有好友关系;网络的第二层视图可以定义为,你评论过我的动态;网络的第三层视图可以定义为你转发过我的动态;第四层关系可以定义为,你at过我等等。

通过这样的定义我们就完成了一个多层网络的构建,他们共享相同的节点,但又分别具有不同的邻边,如果我们分别处理每一层视图视图,然后将他们得出的节点表示单纯相加的话,就可能会失去不同视图之间的协作关系,降低分类(预测)的精度。

基于以上观点,我们提出了一种新的方法:首先在每一层单视图中应用GAT进行学习,并计算出每层视图的节点表示。之后再不同视图之间引入attention机制来让网络自行学习不同视图的权重。之后根据学习的权重,将各个视图加权相加得到全局节点表示并进行后续的诸如节点表示,链接预测等任务。

同时,因为不同视图共享同样的节点,即使每一层视图都表示了不同的节点关系,最终得到的每一层的节点嵌入表示应具有一定的相关性。基于以上理论,我们在每层GAT的网络参数间引入正则化项来约束参数,使其向互相相近的方向学习。大致的网络流程图如下:

这部分来源于 链接:https://www.jianshu.com/p/d5d366ba1a57 来源:简书

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/380335.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Netty之ChannelFuture详解

目录 目标 Netty版本 Netty官方API 客户端如何与服务器建立连接&连接成功后的操作方式 实现 如何处理客户端与服务器连接关闭后的操作 正确关闭连接的方式 方法一 方法二 目标 了解Netty如何处理客户端与服务器之间的连接与关闭问题。 Netty版本 <dependency&…

Kafka系列之:Kafka生产者和消费者

Kafka系列之:Kafka生产者和消费者 一、Kafka生产者发送流程二、提高生产者吞吐量三、Kafka消费方式四、Kafka消费者总体工作流程五、按照时间消费Kafka Topic一、Kafka生产者发送流程 batch.size:只有数据积累到batch.size之后,sender才会发送数据,默认16K。linger.ms:如果…

预热:Eyeshot 2023 Beta 正式版不远 Eyeshot 2023 Fem

预热&#xff1a;Eyeshot 2023 Beta 离正式版不远 Eyeshot 2023 Fem 破解版 devDept Software 自豪地宣布推出新的Eyeshot 2023 Beta版本。 现在已经完成了几次迁移&#xff0c;我们有了一个最终的工作区架构&#xff0c;它不再需要设计/设计用户界面分离的对象。正如我们在迁移…

SMPL可视化大杀器,你并不需要下载SMPL就能可视化你的3D Pose

SMPL 是一种3D人体建模方法&#xff0c;现在几乎所有的元宇宙人体建模都是基于此类方法&#xff0c;包括但不限于元宇宙&#xff0c;自动驾驶等领域。它能估计出比较准确的人体3D姿态&#xff0c;得益于海量数据训练的人体3D先验。不仅仅是人体&#xff0c;包括手部&#xff0c…

【Windows应急响应】HW蓝队必备——开机启动项、临时文件、进程排查、计划任务排查、注册表排查、恶意进程查杀、隐藏账户、webshell查杀等

Windows应急响应应急响应的重要性开机启动项temp文件分析浏览器信息分析文件时间属性分析最近打开文件分析进程分析计划任务隐藏账户的发现添加与删除恶意进程发现及关闭补丁信息webshell查杀应急响应的重要性 近年来信息安全事件频发&#xff0c;信息安全的技能、人才需求大增…

linux + jenkins + svn + maven + node 搭建及部署springboot多模块前后端服务

linux搭建jenkins 基础准备 linux配置jdk、maven&#xff0c;配置系统配置文件 vi /etc/profile配置jdk、maven export JAVA_HOME/usr/java/jdk1.8.0_261-amd64 export CLASSPATH.:$JAVA_HOME/jre/lib/rt.jar:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jarexport MAVEN_H…

【深入浅出 Yarn 架构与实现】4-6 RM 行为探究 - 申请与分配 Container

本小节介绍应用程序的 ApplicationMaster 在 NodeManager 成功启动并向 ResourceManager 注册后&#xff0c;向 ResourceManager 请求资源&#xff08;Container&#xff09;到获取到资源的整个过程&#xff0c;以及 ResourceManager 内部涉及的主要工作流程。 一、整体流程 …

吴恩达机器学习笔记——线性回归

1.模型描述有训练集数据房子面积和卖出的价钱&#xff0c;我们用这组数据来模拟特定面积的房子能够卖出的价钱。这是一个很明显的监督学习&#xff08;supervised learning&#xff09;的例子&#xff0c;因为我们的训练集里包含了正确的结果&#xff08;即房子的卖价&#xff…

非递归迭代实现二叉树前序,中序,后序遍历

文章目录1. 前序遍历2. 中序遍历3. 后序遍历1. 前序遍历 题目链接 解题思路&#xff1a; 非递归遍历一棵树有两点&#xff1a; 1.左路结点 2.左路结点的右子树 什么意思呢&#xff1f; 我们知道前序遍历是按照根&#xff0c;左子树&#xff0c;右子树来的。所以它是先根&…

js中的原型链

js中原型和原型链&#x1f61a; 1、为什么需要原型链&#xff1f;&#x1f923;&#x1f61a; 凡事都是有一定的需求和原因发展起来的&#xff0c;在ECMA中为什么要提出原型链这个概念呢&#xff1f; 我们知道&#xff0c;创建对象有两种方式。一种是通过字面量来创建&#…

科研 | 论文写作 | 最常用的LaTeX语法

最常用的LaTeX语法1. 行内公式2. 行间公式3. 下标4. 上标5. 公式编号6. 数学公式7. 根号和分式8. 上下标记9. 向量10. 积分、极限、求和、乘积11. 三圆点12. 重音符号13. 矩阵14. 小写希腊字母和大写希腊字母15. 公式组合16. 拆分单个公式1. 行内公式 格式&#xff1a;将公式编…

流计算框架storm概览

Attention: supervison 和 nimbus的状态都实时保存在zookeeper集群中和本地. Enchance, this means you can kill -9 Nimbus or the Supervisors and theyll start back up as nothing happened. Topologies 1. storm jar all-my-code.jar org.apache.storm.MyTopology a…

父类子类静态代码块、构造代码块、构造方法执行顺序

github:https://github.com/nocoders/java-everything.git 名词解释 静态代码块&#xff1a;java中使用static关键字修饰的代码块&#xff0c;每个代码块只会执行一次&#xff0c;JVM加载类时会执行静态代码块中的代码&#xff0c;静态代码块先于主方法执行。构造代码块&#…

[Java面经] 三年工作经验, 极兔一二面

极兔一二面面经: 1. mysql的acid怎么实现的 这一点先回答ACID分别是A(原子性),C(一致性),I(隔离性),D(持久性), 其中持久性是数据库落磁盘的操作,无需额外实现. 隔离性是通过事务的隔离级别来实现, MySQL默认的隔离级别是RR(可重复读), 虽然上面还有一层Serializable(串行化…

如何在canvas中模拟css的背景图片样式

笔者开源了一个Web思维导图mind-map&#xff0c;最近在优化背景图片效果的时候遇到了一个问题&#xff0c;页面上展示时背景图片是通过css使用background-image渲染的&#xff0c;而导出的时候实际上是绘制到canvas上导出的&#xff0c;那么就会有个问题&#xff0c;css的背景图…

【日常总结】docker容器相互调用,占用服务器带宽解决方案

目录 一、场景&#xff1a; 1. 环境 2. 项目背景&#xff1a; 3. 全球时区解决方案 4. 方案二步骤 二、问题 三、产生原因 四、解决方案 五、解决步骤 六、整改效果 一、场景&#xff1a; docker容器相互调用&#xff0c;占用慢服务器带宽&#xff0c;导致netty连接的…

go 切片(slice)原理及用法注意事项

切片(slice)定义 go语言中的slice是一种数据结构,其定义为一个结构体,如下所示; type SliceHeader struct {Data uintptr // 指向底层数组的指针Len int // 切片的长度Cap int // 切片的容量 }切片与数组 切片的底层数据存储结构是 数组切片较为灵活,能动态扩容,而数组是定…

vue2使用v-viewer实现图片预览ImagePreview

追溯&#xff1a; View UI Plus 是 View Design 设计体系中基于 Vue.js 3 的一套 UI 组件库&#xff0c;里面有个组件ImagePreview可以实现“图片预览”。 使用ImagePreview组件&#xff0c;报错&#xff1a; [Vue warn]: Unknown custom element: <ImagePreview> - d…

odoo15 标题栏自定义

odoo15 标题栏自定义 如何显示为自定义呢 效果如下: 代码分析: export class WebClient extends Component {setup() {this.menuService = useService("menu");this.actionService = useService("action");this.title = useService("title&…

在Docker 上完成对Springboot+Mysql+Redis的前后端分离项目的部署(全流程,全截图)

本文章全部阅读大约2小时&#xff0c;包含一个完整的springboot vue mysqlredis前后端分离项目的部署在docker上的全流程&#xff0c;比较复杂&#xff0c;请做好心理准备&#xff0c;遇到问题可留言或则私信 目录 1 安装Docker&#xff0c;以及简单使用参照 2 Docker部署m…