【模型学习之路】TopK池化,全局池化

news2024/11/28 14:56:21

来学学图卷积中的池化操作

目录

DataBatch

Dense Batching

Dynamic Batching

DataBatch

存取操作

TopKPooling

GAP/GMP

一个例子

后话


DataBatch

当进行图级别的任务时,首先的任务是把多个图合成一个batch。

在Transformer中,一个句子的维度是【单词数,词向量长度】。在一个batch内,batch_size个长度相同的句子(长度短了就做padding)的维度是【句子数,单词数,词向量长度】。

这里,在图任务中得到batch有两种策略。

Dense Batching

一个batch有batch_size个图,第i个图的x的特征维度为m_{i}f,那么先:

m = max(m_{1}, m_{2}, ..., m_{batchsize})

把所有的图做padding,然后合到一起,那么最后数据的维度就是【batch_size, m, f】。

这种方式通常用于需要固定大小输入的场景,例如某些图神经网络的实现或者特定的并行计算框架。

Dynamic Batching

这是PyG默认的批处理方式,它不要求所有图具有相同数量的节点。在这种模式下,每个图的节点特征被拼接在一起,形成一个大的特征矩阵【M,f】,其中:

M = \sum_{i=1}^{batchsize}m_{i}

同时,会有一个batch向量,它是一个长度为M的一维Tensor,记录每个节点属于哪个图。

DataBatch

前面提到过,Data对象是PyG数据的基本单元。我们先生成一个一个Data对象的list:

import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

data_lst = [Data(x=torch.randint(0, 2, (5, 3)), 
                 edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
                 y=torch.randint(0, 1, (5,)))
                for _ in range(1000)]

重写Dataset,然后将list[Data]转化为Dataset:

class MyDataset(Dataset):
    def __init__(self, data_lst):
        super(MyDataset, self).__init__()
        self.data_lst = data_lst
    
    def __len__(self):
        return len(self.data_lst)
    
    def __getitem__(self, idx):
        return self.data_lst[idx]

dataset = MyDataset(data_lst)
dataset

# output
MyDataset(1000)

进一步做成Dataloader:

dataloader = DataLoader(dataset, batch_size=32, follow_batch=['x'], shuffle=True)
first_batch = list(dataloader)[0]
first_batch

# output
DataBatch(x=[160, 3], x_batch=[160], x_ptr=[33], edge_index=[2, 128], y=[160], batch=[160], ptr=[33])

x,y,edge_index都是由多个图拼接而成。x_batch就是用来记录每个节点属于哪个图。ptr用于记录每个图的位置信息(不用过多关注),大小正好是batch_size + 1,记录每个图的终点和起点。

不指定follow_batch=['x'],就没有了ptr,模型就会认为这是一个由很多图拼起来的一个大图,而不是视为很多图。这里不必深究,指定一下follow_batch就好了。

存取操作

可以继承重写PyG中一些与数据相关的类,做到存取的效果,不过有些难度可以看看这个:【图神经网络工具】PyTorch Geometric Tutorial 之Data Handling - 知乎

也可以看看这个的15~19集:5-数据集创建函数介绍_哔哩哔哩_bilibili

我们实现一个简单的存取方法:

from torch_geometric.data import Batch
batch = Batch.from_data_list(data_lst)
batch

# output
DataBatch(x=[5000, 3], edge_index=[2, 4000], y=[5000], batch=[5000], ptr=[1001])

可以看到,和我们Dataloader取出来的东西一样,都是DataBatch对象。然后我们把它存起来:

torch.save(batch, 'batch.pt')

loaded_batch = torch.load('batch.pt', map_location='cpu', weights_only=False)
data_lst = loaded_batch.to_data_list()

TopKPooling

先端上官方文档:

torch_geometric.nn.pool.TopKPooling — pytorch_geometric documentation

再端上一张网上随便一找就能看到的图:

p是要学习的参数。y的维度是(M, 1),计算出每一个点的“重要性”。除以二范数是为了标准化。

然后选取M个点中k个最重要的

根据这个topk,在X以及A中挑出对应的k个,得到,相应的邻接矩阵也只保留剩下的边之间的关系。

最后,由于y’本身记录了“重要性”的信息,那就把重要性加权到X中:

  

仅发表一下个人意见,出于归一化的想法,感觉用softmax挺合适:

 

好,搞定。

一个小问题,在做这个pool操作时,会不会导致某一个图的所有节点全部消失?

并不会,因为TopK是独立地在每个图中做topk操作。

GAP/GMP

global_mean_pool(GAP)和global_max_pool(GMP)是两种常用的全局池化(global pooling)操作,它们用于将整个图的信息聚合为一个固定大小的向量。

全局平均池化(GAP)操作将图中所有节点的特征向量求平均。简单说来就是,每一个图表示为自己所有节点求平均得到的向量。

全局最大池化(GMP)操作将图中所有节点的特征向量进行逐元素的最大值操作。简单来说就是,对于每一个图,拿出自己所有的节点,拿到每个特征的最大值,组成一个向量。

So,在维度上,都会有这样的特征:【M, f】-> 【batch_size,f】

这俩是两种常用的全局池化操作,它们用于将图中所有节点的特征聚合为一个全局特征向量。这两种操作通常在图神经网络的最后阶段使用,以便将图级别的表示用于图分类或其他下游任务。

一个例子

用PyG写个一个神经网络模型。

import torch
import torch.nn as nn
from torch_geometric.nn import TopKPooling, SAGEConv
from torch_geometric.nn import global_mean_pool as gap
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        torch.manual_seed(114514)
        
        self.conv1 = SAGEConv(128, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        
        self.embed = nn.Embedding(100, 128)
        
        self.lin = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1), 
        )
        
        self.bn = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(64)
        
    def forward(self, data):
        """
        x: [M, 1]
        edge_index: [2, e]
        batch: [M]
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = x.squeeze(1)  # [M, 1] -> [M]  # 这里是大坑!在github评论区逛了一圈,还好一个老外和我一样的错误
        x = self.embed(x)  # [M] -> [M, 128]  
        
        x = self.conv1(x, edge_index)  # [M, 128]
        x = F.relu(x)
        x, edge_index, _, batch, *_ = self.pool1(x, edge_index, None, batch)  # [0.8*M, 128]
        
        x1 = gap(x, batch)  # [batch, 128]

        x = self.conv2(x, edge_index)  # [0.8*M, 128]
        x = F.relu(x)
        x, edge_index, _, batch, *_ = self.pool2(x, edge_index, None, batch)  # [0.8*0.8*M, 128]
        
        x2 = gap(x, batch)  # [batch, 128]
        
        x = self.conv3(x, edge_index)  # [0.8*0.8*M, 128]
        x = F.relu(x)
        x, edge_index, _, batch, *_ = self.pool3(x, edge_index, None, batch)  # [0.8*0.8*0.8*M, 128]
        
        x3 = gap(x, batch)  # [batch, 128]
        
        out = x1 + x2 + x3  # [batch, 128]
        out = self.lin(out)  # [batch, 1]
        out = out.squeeze(1)  # [batch]
        out = F.sigmoid(out)
        return out
        
        

这个网络架构的设计意图是利用图卷积层提取局部图结构特征,通过池化层进行降采样以捕捉更全局的信息,然后通过全连接层和激活函数进行特征融合和分类。这种架构在图分类、节点分类等任务中很常见。

后话

代码中的SAGEConv是什么?它是众多卷积方式的一种。

PyG文档上有大量卷积层、池化层的类。确实,路漫漫其修远兮!

这个文章上有很多的卷积层和池化层的讲解,看看能不能在未来的时间里都弄懂它们的原理:转载 | 一文遍览GNN卷积与池化的代表模型 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

<项目代码>YOLOv8 停车场空位识别<目标检测>

YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一个回归问题,能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法(如Faster R-CNN),YOLOv8具有更高的…

如何在Python中进行数学建模?

数学建模是数据科学中使用的强大工具,通过数学方程和算法来表示真实世界的系统和现象。Python拥有丰富的库生态系统,为开发和实现数学模型提供了一个很好的平台。本文将指导您完成Python中的数学建模过程,重点关注数据科学中的应用。 数学建…

ThingsBoard规则链节点:GCP Pub/Sub 节点详解

目录 引言 1. GCP Pub/Sub 节点简介 2. 节点配置 2.1 基本配置示例 3. 使用场景 3.1 数据传输 3.2 数据分析 3.3 事件通知 3.4 任务调度 4. 实际项目中的应用 4.1 项目背景 4.2 项目需求 4.3 实现步骤 5. 总结 引言 ThingsBoard 是一个开源的物联网平台&#xff…

【工具变量】城市供应链创新试点数据(2007-2023年)

一、测算方式:参考C刊《经济管理》沈坤荣和乔刚老师(2024)的做法,使用“供应链创新与应用试点”的政策虚拟变量(TreatPost)表征。若样本城市为试点城市,则赋值为 1,否则为 0&#xf…

小程序租赁系统开发的优势与应用解析

内容概要 随着科技的迅猛发展,小程序租赁系统应运而生,成为许多企业优化业务的重要工具。首先,它提升了用户体验。想象一下,用户只需轻轻一点,就能够浏览和租赁心仪的商品,这种便捷的过程使繁琐的操作大大…

Spring MVC练习(前后端分离开发实例)

White graces:个人主页 🙉专栏推荐:Java入门知识🙉 🐹今日诗词:二十五弦弹夜月,不胜清怨却飞来🐹 ⛳️点赞 ☀️收藏⭐️关注💬卑微小博主🙏 ⛳️点赞 ☀️收藏⭐️关注&#x1f4…

使用IDEA构建springboot项目+整合Mybatis

目录 目录 1.Springboot简介 2.SpringBoot的工作流程 3.SpringBoot框架的搭建和配置 4.用Springboot实现一个基本的select操作 5.SpringBoot项目部署非常简单,springBoot内嵌了 Tomcat、Jetty、Undertow 三种容器,其默认嵌入的容器是 Tomcat,…

不玩PS抠图了,改玩Python抠图

网上找了两个苏轼的印章图片: 把这两个印章抠出来的话,对于不少PS高手来说是相当容易,但是要去掉其中的水印,可能要用仿制图章慢慢描绘,图章的边缘也要慢慢勾画或者用通道抠图之类来处理,而且印章的红色也不…

ElasticSearch的下载和基本使用(通过apifox)

1.概述 一个开源的高扩展的分布式全文检索引擎,近乎实时的存储,检索数据 2.安装路径 Elasticsearch 7.8.0 | Elastic 安装后启动elasticsearch-7.8.0\bin里的elasticsearch.bat文件, 启动后就可以访问本地的es库http://localhost:9200/ …

26届JAVA 学习日记——Day16

2024.11.27 周三 尽量在抽出时间做项目,持续学习优化简历,等到基础的八股都熟悉、leetcode热题100刷完、苍穹外卖项目AI项目彻底完成投简历,目标是找到日常实习,然后边做边准备暑期实习。 八股 WebSocket WebSocket是什么&…

Javaweb 前端 HTML css 案例 总结

顶部导航栏 弹性布局 搜索表单区域 表单标签 表单标签,表单项 复选,一次选多个 隐藏域,看不到,但会传参数 text输入框 radio单选 男女,是 前端页面上显示的值 搜索表单区域 button 按钮 表格数据展示区域 fo…

每日一练:【动态规划算法】斐波那契数列模型之使用最小花费爬楼梯(easy)

1. 题目链接:746. 使用最小花费爬楼梯 2. 题目描述 根据一般的思维,我们会认为本题中数组的最后一个位置是楼顶,但是根据第一个例子,如果最后一个位置是楼顶,花费最少应该为10,但是结果是15,因…

HCIP——堆叠技术实验配置

目录 一、堆叠的理论知识 二、堆叠技术实验配置 三、总结 一、堆叠的理论知识 1.1堆叠概述: 是指将两台交换机通过堆叠线缆连接在一起,从逻辑上变成一台交换设备,作为一个整体参与数据的转发。 1.2堆叠的基本概念 堆叠系统中所有的单台…

微软正在测试 Windows 11 对第三方密钥的支持

微软目前正在测试 WebAuthn API 更新,该更新增加了对使用第三方密钥提供商进行 Windows 11 无密码身份验证的支持。 密钥使用生物特征认证,例如指纹和面部识别,提供比传统密码更安全、更方便的替代方案,从而显著降低数据泄露风险…

ubuntu 安装proxychains

在Ubuntu上安装Proxychains,你可以按照以下步骤操作: 1、更新列表 sudo apt-update 2、安装Proxychains sudo apt-get install proxychains 3、安装完成后,你可以通过编辑/etc/proxychains.conf文件来配置代理规则 以下是一个简单的配置示例&…

数组学习后记——递归

数组这块学得有点乱,条理性欠佳。这次正好总结一下。上周的课堂内容没有更新, 因为小白自己也还没来得及吸收呢qwq。也解释一下为什么文中有这么多例题。因为我呢喜欢就着题去分析和学习,直接灌输知识不太能理解,有例子就能及时检验和应用了的。 先看看B3817 基础的双数组…

螺旋矩阵(java)

题目描述 给你一个 m 行 n 列的矩阵 matrix &#xff0c;请按照 顺时针螺旋顺序 &#xff0c;返回矩阵中的所有元素。 代码思路&#xff1a; class Solution {public List<Integer> spiralOrder(int[][] matrix) {List<Integer> list new ArrayList<>(); …

【C#设计模式(16)——解释器模式(Interpreter Pattern)】

前言 解释器模式是用来解释和执行特定的语法或表达式。它将一种表达式的规则和语义进行抽象和封装&#xff0c;然后通过解释器来解析和执行这些规则&#xff0c;将其转化为可执行的操作。 代码 //抽象表达式public interface Expression{int Interpret(Context context); //解释…

OpenHarmony属性信息怎么修改?触觉智能RK3566鸿蒙开发板来演示

本文介绍在开源鸿蒙OpenHarmony系统下&#xff0c;修改产品属性信息的方法&#xff0c;触觉智能Purple Pi OH鸿蒙开发板演示&#xff0c;搭载了瑞芯微RK3566四核处理器&#xff0c;Laval鸿蒙社区推荐开发板&#xff0c;已适配全新OpenHarmony5.0 Release系统&#xff0c;感兴趣…

Python学习35天

# 定义父类 class Computer: CPUNone MemoryNone diskNone def __init__(self,CPU,Memory,disk): self.disk disk self.Memory Memory self.CPU CPU def get_details(self): return f"CPU:{self.CPU}\tdisk:{self.disk}\t…