Pytorch常用的函数(九)torch.gather()用法

news2025/1/9 16:50:59

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor(
       [[[0, 0, 0, 0],
         [2, 2, 2, 2]],
         
        [[0, 0, 0, 0],
         [2, 2, 2, 2]]]
)	

# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],
         
        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],
         
        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape

# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3

# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11

# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor(
       [[[0, 0, 0, 0],
         [2, 2, 2, 2]],
         
        [[0, 0, 0, 0],
         [2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

  • 论文链接:https://arxiv.org/pdf/2111.06377
  • 官方源码:https://github.com/facebookresearch/mae

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.py

import torch

def random_masking(x, mask_ratio=0.75):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数
    # 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】
    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample【核心代码】
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    # 利用torch.gather()从源tensor中获取25%的unmasked tokens
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore

if __name__ == '__main__':
    x = torch.arange(64).reshape(1, 16, 4)
    random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15],
         [16, 17, 18, 19], # 4
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31],
         [32, 33, 34, 35],
         [36, 37, 38, 39],
         [40, 41, 42, 43], # 10
         [44, 45, 46, 47],
         [48, 49, 50, 51], # 12
         [52, 53, 54, 55], # 13
         [56, 57, 58, 59],
         [60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],
         [12, 12, 12, 12],
         [ 4,  4,  4,  4],
         [13, 13, 13, 13]]])


# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],
         [48, 49, 50, 51],
         [16, 17, 18, 19],
         [52, 53, 54, 55]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1654973.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux中gitlab-runner部署使用备忘

环境: 操作系统::CentOS8 gitlab版本:13.11.4 查看gitlab-runner版本 可以从https://packages.gitlab.com/app/runner/gitlab-runner/search找到与安装的gitlab版本相近的gitlab-runner版本以及安装命令等信息,我找到与13.11.4相…

【离散数学】集合上二元关系性质判定的实现(c语言实现)

实验要求 关系矩阵的初始化和打印 我们将关系矩阵存入一个二维数组中,因为集合元素个数不会超过5个所以就用一个5行5列二维数组来表示。 在我们得到了集合元素个数之后我们就可以对数组进行0,1随机赋值 //初始关系矩阵 void init_matrix(int array[][5], int n) {…

【数据结构|C语言版】栈和队列

前言1. 栈1.1 栈的概念和性质1.2 顺序栈1.3 链栈1.4 共享栈 2. 队列2.1 队列的概念和性质2.2 循环队列2.3 链式队列2.4 双端队列 3. 例题精选3.1 有效的括号3.2 用队列实现栈2.4 用栈实现队列3.4 设计循环队列3.5 参考答案 结语 #include<GUIQU.h> int main { 上期回顾: …

centos8.5 安装 redis 7.2.4 详细步骤

1 下载Index of /releases/ (redis.io) 通过xftp等方式上传到服务器&#xff0c;安装依赖包 yum install gcc gcc-c make tcl -y [rootlocalhost software]# ll total 3308 -rw-r--r--. 1 root root 3386861 May 3 21:56 redis-7.2.4.tar.gz [rootlocalhost software]# ll…

SlowFast报错:ValueError: too many values to unpack (expected 4)

SlowFast报错&#xff1a;ValueError: too many values to unpack (expected 4) 报错细节 File "/home/user/yuanjinmin/SlowFast/tools/visualization.py", line 81, in run_visualizationfor inputs, labels, _, meta in tqdm.tqdm(vis_loader): ValueError: too …

软件测试之测试用例详细解读

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 一、通用测试用例八要素   1、用例编号&#xff1b;    2、测试项目&#xff1b;   3、测…

CSS---复合选择器和元素显示模式(三)

一、CSS的复合选择器 1.1 什么是复合选择器 在CSS中&#xff0c;可以根据选择器的类型把选择器分为基础选择器和复合选择器&#xff0c;复合选择器是建立在基础选择器之上&#xff0c;对基本选择器进行组合形成的。 复合选择器是由两个或多个基础选择器连写组成&#xff0c;它…

算法学习Day2——单调栈习题

第一题&#xff0c;合并球 题解&#xff1a;一开始写了一次暴力双循环&#xff0c;直接O(n^2)严重超时&#xff0c;后面于是又想到了O(n)时间复杂度的链表&#xff0c;但是还是卡在 最后一个数据会TLE&#xff0c;我也是高兴的拍起来安塞腰鼓和华氏护肤水&#xff0c;后面学长给…

信息系统项目管理师0094:项目管理过程组(6项目管理概论—6.4价值驱动的项目管理知识体系—6.4.3项目管理过程组)

点击查看专栏目录 文章目录 6.4.3项目管理过程组1.适应型项目中的过程组2.适应型项目中过程组之间的关系6.4.3项目管理过程组 项目管理过程组是为了达成项目的特定目标,对项目管理过程进行的逻辑上的分组。项目管理过程组不同于项目阶段:①项目管理过程组是为了管理项目,针对…

typescript 模块化

模块的概念&#xff1a; 把一些公共的功能单独抽离成一个文件作为一个模块。 模块里面的变量、函数、类等默认是私有的&#xff0c;如果我们要在外部访问模块里面的数据&#xff08;变量、函数、类&#xff09;&#xff0c;需要通过export暴露模块里面的数据&#xff08;&#…

C++面向对象程序设计 - 多态性

多态性是面向对象程序设计的一个重要特征&#xff0c;多态的意思是一个事物的多种形态。在C程序设计中&#xff0c;多态性是指具有不同功能的函数可以用同一个函数名&#xff0c;这样就可以用一个函数名调用不同内容的函数。面向对象方法中的多态性&#xff0c;比如说向不同的对…

Colibri for Mac v2.2.0 原生无损音频播放器 激活版

Colibri支持所有流行的无损和有损音频格式的完美清晰的比特完美播放&#xff0c;仅使用微小的计算能力&#xff0c;并提供干净和直观的用户体验。 Colibri在播放音乐时使用极少的计算能力。该应用程序使用最先进的Swift 3编程语言构建&#xff0c;BASS音频引擎作为机器代码捆绑…

计算机毕业设计 | springboot+vue凌云在线阅读平台 线上读书系统(附源码)

1&#xff0c;绪论 随着社会和网络技术的发展&#xff0c;网络小说成为人们茶钱饭后的休闲方式&#xff0c;但是现在很多网络小说的网站都是收费的&#xff0c;高额的收费制度是很多人接受不了的&#xff0c;另外就是很多小说网站都会有大量的弹窗和广告&#xff0c;这极大的影…

【文献解析】3D高斯抛雪球是个什么玩意

论文地址&#xff1a;https://arxiv.org/abs/2308.04079 项目&#xff1a;3D Gaussian Splatting for Real-Time Radiance Field Rendering 代码&#xff1a;git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive 一、文章概述 1.1问题导向 辐射…

如何通过 4 种方式在 Mac 上恢复未保存的 Excel 文件

您曾在 MacBook 上花费数小时处理 Excel 工作簿&#xff0c;但现在它消失了。或者&#xff0c;当您退出 Excel 文件时&#xff0c;您无意中选择了“不保存”。这是不是说你所有的努力都白费了&#xff1f;本文系统地介绍了如何在 Mac 上恢复丢失的 Excel 文件。通过我们的 4 种…

effective python学习笔记_函数

函数返回值尽量不要超过三个 局限性&#xff1a;当返回参数过多时&#xff0c;有时会搞混哪个是哪个&#xff0c;可能返回的两个值反了 解决方法&#xff1a;如果参数过多&#xff0c;可以组装*变量返回&#xff0c;或者自定义轻量类型或namedtuple返回 有意外情况时尽量抛异…

Linux学习笔记2---Makefile

单个文件编译用gcc编译确实是挺方便的&#xff0c;但是多个文件需要编译一个个的编译就属实是麻烦了&#xff0c;而针对多文件编译也有快捷的办法&#xff0c;即Makefile脚本。要运行Makefile需要先安装make程序。 apt install make 1.什么是Makefile 一个工程中的源文件不计…

如何提高日语听力?日语学习日语培训柯桥小语种学校

每次一说起练日语听力&#xff0c;总离不开一个词&#xff0c;那就是“磨耳朵”。 可是&#xff0c;“磨耳朵”真的有用吗&#xff1f; 在讨论这个问题之前&#xff0c;我们需要先知道&#xff1a;什么是“磨耳朵”&#xff1f; 所谓的“磨耳朵”&#xff0c;其实就是让我们的耳…

【Ubuntu20.04安装java-8-openjdk】

1 下载 官网下载链接&#xff1a; https://www.oracle.com/java/technologies/downloads/#java8 下载 最后一行 jdk-8u411-linux-x64.tar.gz&#xff0c;并解压&#xff1a; tar -zxvf jdk-8u411-linux-x64.tar.gz2 环境配置 1、打开~/.bashrc文件 sudo gedit ~/.bashrc2、…

2024蓝桥杯CTF writeUP--packet

根据流量分析&#xff0c;我们可以知道129是攻击机&#xff0c;128被留了php后门&#xff0c;129通过get请求来获得数据 129请求ls Respons在这 里面有flag文件 这里请求打开flag文件&#xff0c;并以base64编码流传输回来 获得flag的base64的数据 然后解码 到手