GAT学习

news2024/11/24 20:39:54

文章目录

  • GAT
  • 注意力机制的定义
  • 图注意力层
  • 多头注意力机制
  • GATConv层中forward函数步骤解析:
    • 1. 计算wh。wh:带权特征向量
    • 2. 计算注意力分数e
    • 3. 激活注意力分数e
    • 4. 由边的索引获取邻接矩阵
    • 5. 获得注意力分数矩阵。 attention[i][j]表示i j之间的注意力分数
      • torch.where详解:
    • 6. 归一化注意力分数
    • 7. 加权融合特征向量
    • 8.添加偏置
  • 完整代码
  • 后记

GAT

由于
信息处理能力的局限,人类会选择性地关注完整信息中的某一部分,同时忽略其他信息。这种机制大大提高了人类对信息的处理效率。
注意力机制的核心在于对给定信息进行权重分配,权重高的信息意味着需要系统进行重点加工。
图注意力网络(Graph Attention Networks):自动学习图中节点对节点之间的影响度

注意力机制的定义

在这里插入图片描述
上式中:
Source是需要系统处理的信息源
Query代表某种条件或者先验信息
Attention Value是给定Query信息的条件下,通过注意力机制从Source中提取得到的信息。
similarity(Query,Keyi)表示Query向量和Key向量的相关度,最直接的方法是可以取两向量的内积<Query,Keyi>。内积越大,相似度越高

图注意力层

在这里插入图片描述
上图中,hi:hi∈Rd(l)任意节点vi在第l层所对应的特征向量。
经过一个以注意力机制为核心的聚合操作之后,输出的是每个节点的新的特征向量hi’, hi’∈Rd(l+1)。我们将这个聚合操作称为图注意力层。

假设中心节点为vi, 我们设邻居节点Vj到vi的权重系数eij 为:
在这里插入图片描述
W∈Rd(l+1)xd(l) 是该层节点特征变换的权重系数。
α(·) 是计算两个节点相关度的函数。原则上可以计算图中任意一个节点到节点vi的权重系数,为简化计算将其限制在一节邻居内(在GAT中,将自己也视为自己的邻居)。这里的α可以用向量的内积,只要保证最后输出一个实数就可以。
这里采用如下方程:
在这里插入图片描述
α是一个权重参数,α∈R2d(l+1).这个R表示是实数,2d表示长度,l+1是层数。
W∈Rd(l+1)xd(l) 是该层节点特征变换的权重系数。
hi hj表示节点的特征向量。
在这里插入图片描述αij表示i-j之间的attention系数。 表示i-j之间的关联程度,重要性之类的。GAT使用自注意力机制来计算节点的邻居节点对节点 i 的贡献,并以加权的方式将邻居节点的特征融合到节点 i 的特征中。
h   i   ~ \widetilde{h~i~} h i  表示i节点的特征。
W是一个系数
[Whi||Whj] 表示将两个特征拼接在一起
a ⃗ \vec{a} a T 表示一个可学习的系数。
h ⃗ \vec{h} h j表示j节点(为 i 的邻居)的特征
h ⃗ \vec{h} h i表示节点i的特征
h ⃗ \vec{h} h i’ 表示i节点聚合了所有邻居之后的特征。
eij: 邻居节点vj到vi的权重系数
whi 是节点i的特征表示hi经过权重矩阵weight_w的线性变换后得到的结果, 可以理解为“节点i的权重特征”或“节点i的特征映射”
在这里插入图片描述

多头注意力机制

在这里插入图片描述
h ⃗ \vec{h} h i’ 表示i节点聚合了所有邻居之后的特征。
第二行的表示选取了多个参数,(αij、W)得到节点的多个特征向量。||表示将这些特征向量拼接到一起。
第三行是将多个特征向量求和取平均。


GATConv层中forward函数步骤解析:

1. 计算wh。wh:带权特征向量

这里的wh是所有节点的带权特征向量,whi和whj都包含在其中。
x是所有节点的初始特征向量,与weight_w这样一个权重相乘后得到带权特征向量。

wh = torch.mm(x, self.weight_w)     # 公式中的[Whi||whj], 包含所有结点的特征表示,每一行对应一个节点的特征 wh:[2708,16], x:[2708,1433], weight_w:[1433, 16]

2. 计算注意力分数e

e是一个考虑了所有点,但是没有考虑邻居关系的注意力分数矩阵。eij表示邻居节点vj到vi的权重系数,也叫注意力分数。就是vj对于vi来说的的注意力系数是多少。这里考虑了任意两个节点的注意力系数,但是GAT中只需要考虑一阶邻居的注意力系数(自己也算自己的邻居)

e = torch.mm(wh, self.weight_a[: self.out_channels]) + torch.matmul(wh, self.weight_a[self.out_channels:]).T # 公式中的eij, 表示注意力分数

3. 激活注意力分数e

e = self.leakyrelu(e)

4. 由边的索引获取邻接矩阵

if self.adj == None:
    self.adj = to_dense_adj(edge_index).squeeze()   # 将稀疏邻接矩阵转换为密集邻接矩阵

     # 添加自环,考虑自身加权
    if self.add_self_loops:
         self.adj += torch.eye(x.shape[0]).to(device)

5. 获得注意力分数矩阵。 attention[i][j]表示i j之间的注意力分数

这里的注意力分数矩阵attention是从注意力分数e演变过来的。前面说e考虑了任意两点之间的权重系数,但是我们只要一阶邻居的,所以这里是做了这么个操作。

attention = torch.where(self.adj > 0, e, -1e9 * torch.ones_like(e))

torch.where详解:

torch.where(condition, a, b)
如果condition满足,返回a,如果不满足,返回b

6. 归一化注意力分数

因为要保证所有邻居的权重系数和为1,所以要进行归一化。

attention = F.softmax(attention, dim=1)  # attention:[2708, 2708]

7. 加权融合特征向量

前面的一系列操作就是为了得到注意力系数矩阵attention,然后要将原来的特征项向量hi通过注意力系数进行加权:

output = torch.mm(attention, wh)        # output: [2707,2708]*[2708,16]=[2708,16]

8.添加偏置

if self.bias != None:
    return output + self.bias.squeeze().unsqueeze(0) # self.bias是[16, 1],要变成[16]或者[1, 16]才能自动broadcast相加。可以不用unsqueeze()
else:
    return output

完整代码

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import coo_matrix
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_dense_adj

# 1.加载Cora数据集
dataset = Planetoid(root='./data/Cora', name='Cora')    # 从PYG中加载数据集,保存到本地根目录的/data/Cora下


# 2.定义GATConv层
class GATConv(nn.Module):
    def __init__(self, in_channels, out_channels, heads=1, add_self_loops=True, bias=True):   # GATConv1: in_channels:1433, out_channel:16
        super(GATConv, self).__init__()     # 子类的初始化,但是在调用子类的初始化时会调用父类的初始化,所以相当于调用nn.Moudle的初始化
        self.in_channels = in_channels  # 输入图节点的特征数
        self.out_channels = out_channels  # 输出图节点的特征数
        self.adj = None
        self.add_self_loops = add_self_loops

        # 定义参数 θ
        self.weight_w = nn.Parameter(torch.FloatTensor(in_channels, out_channels))  #公式中的W  [1433, 16] nn.Parameter()将张量封装为可训练参数。
        self.weight_a = nn.Parameter(torch.FloatTensor(out_channels * 2, 1))        #公式中的a^T  weight_a:[32,1] 由于要和[Whi||whj]拼接在一起,所以size要*2
        # weight_a: 将节点的特征映射成注意力分数

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_channels, 1))
        else:
            self.register_parameter('bias', None)                                   # 注册上一个参数

        self.leakyrelu = nn.LeakyReLU()
        self.init_parameters()

    # 初始化可学习参数
    def init_parameters(self):
        nn.init.xavier_uniform_(self.weight_w)          # 使用xavier初始化方式初始化参数
        nn.init.xavier_uniform_(self.weight_a)

        if self.bias != None:
            nn.init.zeros_(self.bias)

    def forward(self, x, edge_index):
        # 1.计算wh,进行节点空间映射 wh:带权特征向量
        wh = torch.mm(x, self.weight_w)     # 公式中的[Whi||whj], 包含所有结点的特征表示,每一行对应一个节点的特征 wh:[2708,16], x:[2708,1433], weight_w:[1433, 16]

        # 2.计算注意力分数e    e:[2708, 2708],用到了广播机制 由[2708, 1] + [1, 2708]搞起来的.
        # 第一项得到一个点对其他点的注意力分数,第二项一转置得到所有点对其他点的注意力分数,然后通过广播机制相加,得到所有点对所有点的注意力分数。
        # 但是这里只是初始化的,并未考虑节点的邻居关系。
        e = torch.mm(wh, self.weight_a[: self.out_channels]) + torch.matmul(wh, self.weight_a[self.out_channels:]).T # 公式中的eij, 表示注意力分数

        # 3.激活
        e = self.leakyrelu(e)

        # 4.由边的索引获取邻接矩阵
        if self.adj == None:
            self.adj = to_dense_adj(edge_index).squeeze()   # 将稀疏邻接矩阵转换为密集邻接矩阵

            # 添加自环,考虑自身加权
            if self.add_self_loops:
                self.adj += torch.eye(x.shape[0]).to(device)

        # 5.获得注意力分数矩阵。 attention[i][j]表示i j之间的注意力分数
        attention = torch.where(self.adj > 0, e, -1e9 * torch.ones_like(e))

        # 6.归一化注意力分数
        attention = F.softmax(attention, dim=1)  # attention:[2708, 2708]

        # 7.加权融合特征向量
        output = torch.mm(attention, wh)        # output: [2707,2708]*[2708,16]=[2708,16]

        # 8.添加偏置
        if self.bias != None:
            return output + self.bias.squeeze().unsqueeze(0) # self.bias是[16, 1],要变成[16]或者[1, 16]才能自动broadcast相加
        else:
            return output


# 3.定义GAT网络
class GAT(nn.Module):
    def __init__(self, num_node_features, num_classes): # num_node_features:1433  num_classes:7
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels=num_node_features,
                             out_channels=32,
                             heads=2)   #heads表示多头
        self.conv2 = GATConv(in_channels=32,
                             out_channels=16,
                             heads=2)
        self.conv3 = GATConv(in_channels=16,
                             out_channels=num_classes,
                             heads=1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)       # 将节点特征x和边的索引edge_index作为输入通道和输出通道
        x = F.relu(x)
        x = F.dropout(x, training=self.training)    # training用于区分是否是训练模式
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        return F.log_softmax(x, dim=1)      # 计算节点的类别概率分布


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 设备
epochs = 200  # 学习轮数 训练轮数
lr = 0.0003  # 学习率
num_node_features = dataset.num_node_features  # 每个节点的特征数
num_classes = dataset.num_classes  # 每个节点的类别数
data = dataset[0].to(device)  # Cora的一张图

# 4.定义模型
model = GAT(num_node_features, num_classes).to(device)      # 将模型放到指定设备上运算
optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # 优化器
loss_function = nn.NLLLoss()  # 损失函数

# 训练模式
model.train()

for epoch in range(epochs):

    pred = model(data)

    loss = loss_function(pred[data.train_mask], data.y[data.train_mask])  # 损失

    correct_count_train = torch.eq(pred[data.train_mask].argmax(axis=1), data.y[data.train_mask]).sum().item()  # epoch正确分类数目
    # correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()  # epoch正确分类数目
    acc_train = correct_count_train / data.train_mask.sum().item()  # epoch训练精度
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print("【EPOCH: 】%s" % str(epoch + 1))
        print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))

print('【Finished Training!】')

# 模型验证
model.eval()
pred = model(data)

# 训练集(使用了掩码)
# 再在测试集上看看效果
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()

# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()

print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test  Accuracy: {:.4f}'.format(acc_test), 'Test  Loss: {:.4f}'.format(loss_test))

后记

今天,花了一天的时间学这个。对我来说,我觉得进步很大,终于不是一头雾水了,终于拨开云雾见青天了。
生活,重要是过的开心,最好的方法就是享受当下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1048741.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

苹果发布iOS 17.1首个beta版本,新增了这几个新功能!

苹果今日向iPhone/iPad用户推送了iOS/iPadOS17.1开发者预览版Beta版更新&#xff0c;iOS/iPadOS17.1Beta内部版本号为21B5045h。 iOS/iPadOS17.1Beta更新内容如下&#xff1a; 一&#xff1a;Apple Music“已喜爱”分类 用户可以在 iOS 17.1 的 Apple Music 中收藏歌曲、专辑…

优化邮箱体验!推荐替代方案:提升企业效率的选择

近年来&#xff0c;随着互联网技术的快速发展&#xff0c;电子邮件成为了企业沟通和协作的重要工具。而作为国内知名的企业邮箱服务提供商&#xff0c;网易企业邮箱凭借其稳定性、安全性和易用性&#xff0c;受到了广大企业的青睐。然而&#xff0c;随着市场竞争的加剧&#xf…

26532-2011 地理标志产品 慈溪杨梅

声明 本文是学习GB-T 26532-2011 地理标志产品 慈溪杨梅. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了慈溪杨梅的术语和定义、地理标志产品保护范围、要求、试验方法、检验规则及标志、 标签、包装、运输和贮存。 本标准适用…

基于java的银行风险预警系统的研究与实现

文章目录 导文摘要:前言:绪论:相关技术与方法介绍:系统分析:系统设计:系统实现:系统测试:总结与展望:导文 基于java的银行风险预警系统的研究与实现 本文基于Java开发了一款银行风险预警系统,旨在帮助银行有效管理风险并提前预警潜在风险。下面将对文中的各个部分进行…

微信小程序 课程签到系统

目录 前端页面展示主页面我的课程个人中心评论功能签到功能课程绑定超级管理员页面 前端文件结构文件结构app.json前端架构和开发工具前端项目地址 后端后端架构后端项目地址 注意事项 前端页面展示 主页面 登录页面&#xff1a; 账号是&#xff1a;用户名或者手机号 密码是&a…

【Elasticsearch】聚合查询(四)

Elasticsearch&#xff08;简称为ES&#xff09;是一个基于Lucene的开源搜索和分析引擎&#xff0c;提供了丰富的聚合查询功能。聚合查询指的是在搜索结果上执行分组、汇总和统计等操作&#xff0c;以便从大量数据中提取有用的信息和洞察。 这篇文章主要介绍检索相关的操作&…

freertos的任务调度器的启动函数分析(根据源码使用)

volatile uint8_t * const pucFirstUserPriorityRegister ( uint8_t * ) ( portNVIC_IP_REGISTERS_OFFSET_16 portFIRST_USER_INTERRUPT_NUMBER ); 通过宏pucFirstUserPriorityRegister0xE000E400&#xff08;根据宏名字&#xff0c;这是NVIC寄存器地址&#xff09; 查手册…

26523-2022 精制硫酸钴 随笔练习

声明 本文是学习GB-T 26523-2022 精制硫酸钴. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件规定了精制硫酸钴的要求、试验方法、检验规则、标志、标签、包装、运输和贮存。 本文件适用于精制硫酸钴。 注&#xff1a;该产品主要用于…

【面试题】有了Docker为啥还需要k8s?

个人主页&#xff1a;金鳞踏雨 个人简介&#xff1a;大家好&#xff0c;我是金鳞&#xff0c;一个初出茅庐的Java小白 目前状况&#xff1a;22届普通本科毕业生&#xff0c;几经波折了&#xff0c;现在任职于一家国内大型知名日化公司&#xff0c;从事Java开发工作 我的博客&am…

【python】基础语法

文章目录 元组列表字典集合推导式函数错误和异常处理文件和操作系统 元组 元组是一个固定长度&#xff0c;不可改变的Python序列对象。创建元组的最简单方式&#xff0c;是用逗号分隔一列值。 创建 2. 元组不可修改的解释 对于元组对象不可变的说明&#xff0c;通俗一点就是…

自学成为一名黑客

前言&#xff1a;想自学网络安全&#xff08;黑客技术&#xff09;首先你得了解什么是网络安全&#xff01;什么是黑客 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安…

【操作系统】调度算法的评价指标和三种调度算法

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaEE 操作系统 Redis 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 操作系统 一、调度算法的评价指标1.1 CPU利…

使用cpolar端口映射的方法轻松实现在Linux环境下SVN服务器的搭建与公网访问

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

深入剖析ThreadLocal使用场景、实现原理、设计思想

前言 ThreadLocal可以用来存储线程的本地数据&#xff0c;做到线程数据的隔离 ThreadLocal的使用不当可能会导致内存泄漏&#xff0c;排查内存泄漏的问题&#xff0c;不仅需要熟悉JVM、利用好各种分析工具还耗费人工 如果能明白其原理并正确使用&#xff0c;就不会导致各种意…

【Verilog 教程】6.2Verilog任务

关键词&#xff1a;任务 任务与函数的区别 和函数一样&#xff0c;任务&#xff08;task&#xff09;可以用来描述共同的代码段&#xff0c;并在模块内任意位置被调用&#xff0c;让代码更加的直观易读。函数一般用于组合逻辑的各种转换和计算&#xff0c;而任务更像一个过程&a…

win10搭建Selenium环境+java+IDEA(2)

接着上一个搭建环境开始叙述&#xff1a;win10系统x64安装java环境以及搭建自动化测试环境_荟K的博客-CSDN博客 上一步结尾的浏览器驱动&#xff0c;本人后面改到了谷歌浏览器.exe文件夹下&#xff1a; 这里需要注意&#xff0c;这个新路径要加载到系统环境变量中。 上一步下…

2023-9-28 JZ26 树的子结构

题目链接&#xff1a;树的子结构 import java.util.*; /** public class TreeNode {int val 0;TreeNode left null;TreeNode right null;public TreeNode(int val) {this.val val;}} */ public class Solution {public boolean HasSubtree(TreeNode root1,TreeNode root2) …

吉利微型纯电,5 万元的快乐

熊猫骑士作为一款主打下层市场的迷你车型&#xff0c;吉利熊猫骑士剑指宝骏悦也&#xff0c;五菱宏光 MINI 等热门选手。 9 月 15 日&#xff0c;吉利熊猫骑士正式上市&#xff0c;售价为 5.39 万&#xff0c;限时优享价 4 .99 万元。价格和配置上对这个级别定位的战略车型有一…

力扣刷题-哈希表-判断两个字符串_其他中元素是否一致

242 有效的字母异位词 给定两个字符串 s 和 t &#xff0c;编写一个函数来判断 t 是否是 s 的字母异位词。 示例 1: 输入: s “anagram”, t “nagaram” 输出: true 示例 2: 输入: s “rat”, t “car” 输出: false 说明: 你可以假设字符串只包含小写字母。 解释&#x…

云安全之身份认证与授权机制介绍

认证与授权技术概述 认证&#xff0c;用于证实某事是否真实或有效的过程。认证一般由标识(ldentification)和鉴别(Authentication)两部分组成。 认证技术分类 身份认证&#xff1a;口令认证、生物特征识别 报文认证&#xff1a;报文源的认证、报文宿的认证、报文内容的认证…