图解PyTorch中的torch.gather函数和 scatter 函数

news2024/12/25 23:58:00

前言

torch.gather在目前基于 transformer or query based 的目标检测中,在最后获取目标结果时,经常用到。

这里记录下用法,防止之后又忘了。

介绍

torch.gather

在这里插入图片描述
官方文档对torch.gather()的定义非常简洁

定义:从原tensor中获取指定dim和指定index的数据
看到这个核心定义,我们很容易想到gather()的基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下

用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的

示例

我们找个3x3的二维矩阵做个实验

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

2.1 输入行向量index,并替换行索引(dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

输出结果

tensor([[9, 7, 5]])

过程如图所示
在这里插入图片描述

2.2 输入行向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5, 4, 3]])

过程如图所示
在这里插入图片描述

2.3 输入列向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5],
        [7],
        [9]])

过程如图所示
在这里插入图片描述

scatter

基本是 gather 的反过程,是将数据添加进去,
doc:https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

example:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

具体过程见 gather 的就好~一摸一样,一个获取,一个填入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1562019.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

苹果App上架指南

苹果上架要求是苹果公司对于提交应用程序到苹果商店上架的要求和规定。这些要求主要是为了保证用户体验、应用程序的质量和安全性。以下是苹果上架要求的详细介绍:1. 应用程序的内容和功能必须符合苹果公司的规 苹果上架要求是苹果公司对于提交应用程序到苹果商店上…

iPhone设备中调试应用程序崩溃日志的高效方法探究

​ 目录 如何在iPhone设备中查看崩溃日志 摘要 引言 导致iPhone设备崩溃的主要原因是什么? 使用克魔助手查看iPhone设备中的崩溃日志 奔溃日志分析 总结 摘要 本文介绍了如何在iPhone设备中查看崩溃日志,以便调查崩溃的原因。我们将展示三种不同的…

基于YOLOv8的绝缘子检测系统

💡💡💡本文摘要:基于YOLOv8的绝缘子小目标检测,阐述了整个数据制作和训练可视化过程 1.YOLOv8介绍 Ultralytics YOLOv8是Ultralytics公司开发的YOLO目标检测和图像分割模型的最新版本。YOLOv8是一种尖端的、最先进的&a…

iPhone设备中定位应用程序崩溃问题的日志分析技巧

​ 目录 如何在iPhone设备中查看崩溃日志 摘要 引言 导致iPhone设备崩溃的主要原因是什么? 使用克魔助手查看iPhone设备中的崩溃日志 奔溃日志分析 总结 摘要 本文介绍了如何在iPhone设备中查看崩溃日志,以便调查崩溃的原因。我们将展示三种不同的…

iPhone设备中通过开发者选项查看应用程序崩溃日志的实用技术

​ 目录 如何在iPhone设备中查看崩溃日志 摘要 引言 导致iPhone设备崩溃的主要原因是什么? 使用克魔助手查看iPhone设备中的崩溃日志 奔溃日志分析 总结 摘要 本文介绍了如何在iPhone设备中查看崩溃日志,以便调查崩溃的原因。我们将展示三种不同的…

[Rust开发]用可视化案例讲Rust编程6.动态分发与最终封装

全系列合集 [Rust开发]用可视化案例讲Rust编程1.用Rust画个百度地图 [Rust开发]用可视化案例讲Rust编程2. 编码的核心组成:函数 [Rust开发]用可视化案例讲Rust编程3.函数分解与参数传递 [Rust开发]用可视化案例讲Rust编程4.用泛型和特性实现自适配shapefile的读取 […

目标检测——服饰属性标签识别数据集

一、重要性及意义 首先,随着电商、时尚推荐等业务的发展,服饰属性标签识别已经成为一项关键的计算机视觉任务。这些标签,如颜色、款式、材质等,对于实现图像搜索、时尚推荐等业务需求至关重要。服饰属性标签识别数据集为此类任务…

用ChatGPT出题,完全做不完

最近小朋友正在学习加减法,正好利用ChatGPT来生成加减法练习题,小朋友表示够了,够了,完全做不完。本文将给大家介绍如何利用ChatGPT来生成练习题。 尚未获得ChatGPT的用户,请移步:五分钟开通GPT4.0。 角色…

谷粒商城实战(007 压力测试)

Java项目《谷粒商城》架构师级Java项目实战,对标阿里P6-P7,全网最强 总时长 104:45:00 共408P 此文章包含第141p-第p150的内容 简介 安装jmeter 安装jmeter 使用中文 这样写就是200个线程循环100次 一共是2万个请求 介绍线程组 添加请求 可以是htt…

设计模式之享元模式详解(下)

4)完整解决方案-不带外部状态 1.结构图 IgoChessman充当抽象享元类,BlackIgoChessman和WhiteIgoChessman充当具体享元类,IgoChessmanFactory充当享元工厂类。 2.代码案例 抽象享元类 //围棋棋子类:抽象享元类 abstract class …

ES-7.12-官方文档阅读-ILM-Automate rollover

教程:使用ILM自动化滚动创建index 当你持续将带有时间戳的文档index到Elasticsearch当中时,通常会使用数据流(data streams)以便可以定义滚到到新索引。这是你能够实施一个hot-warm-cold架构来满足你的性能要强,控制随…

【单片机 5.3开关检测】

文章目录 前言一、5.3开关检测1.1没按键按下的1.2有按键按下的 二、改进1.改进 三、独立键盘3.1为什么要取反3.2 实用的按键 总结 前言 提示:这里可以添加本文要记录的大概内容: 课程需要: 提示:以下是本篇文章正文内容&#xf…

C++项目——集群聊天服务器项目(十一)服务器异常退出与添加好友业务

本节来实现C集群聊天服务器项目中的服务器异常退出与添加好友业务,一起来试试吧 一、服务器异常退出 在Linux环境下,我们在服务器端使用CTRLC结束程序执行,即使用CTRLC让服务器异常退出,这样的后果是本应登录服务器的用户在数据库…

2023年第十四届蓝桥杯 - 省赛 - Python研究生组 - A.工作时长

题目 数据文件:https://labfile.oss.aliyuncs.com/courses/21074/records.txt Idea 直接通过 datetime 模块加载时间字符串进行格式化,然后对时间列表进行排序,最后两两计算时间差。 Code Python from datetime import datetimetime_lis…

数论与线性代数——整除分块【数论分块】的【运用】【思考】【讲解】【证明(作者自己证的QWQ)】

文章目录 整除分块的思考与运用整除分块的时间复杂度证明 & 分块数量整除分块的公式 & 公式证明公式证明 代码code↓ 整除分块的思考与运用 整除分块是为了解决一个整数求和问题 题目的问题为: ∑ i 1 n ⌊ n i ⌋ \sum_{i1}^{n} \left \lfloor \frac{n}{…

ESP32 引脚分配

请注意,以下引脚分配参考适用于流行的 30 引脚ESP32 devkit v1开发板。 仅输入引脚 GPIO34~39是GPIs–仅输入的管脚。这些引脚没有内部上拉或下拉电阻。它们不能用作输出,因此只能将这些管脚用作输入:GPIO 34、GPIO 35、GPIO 36、GPIO 39 S…

【opencv】教程代码 —features2D(7)根据单应性矩阵估计相机坐标系下的物体位姿...

pose_from_homography.cpp从图像中找到棋盘角点并进行姿态估计 从图像中找到棋盘角点并显示 计算角点在世界坐标系中的位置 读取相机内参和畸变系数并校正图像中的角点 计算从3D点到2D点的单应性矩阵 通过奇异值分解(SVD)优化对旋转矩阵的估计 基于单应矩阵分解及其优化结果&am…

VSCode - 离线安装扩展python插件教程

1,下载插件 (1)首先使用浏览器打开 VSCode 插件市场link (2)进入插件主页,点击右侧的 Download Extension 链接,将离线安装包下载下来(文件后缀为 .vsix) 2,…

Django之REST framework环境搭建

一、环境搭建 Django REST framework是基于Django实现的一个RESTful风格API框架,能够帮助我们快速开发RESTful风格的API 官网:Home - Django REST framework 中文文档:主页 - Django REST framework中文站点 1.1、安装 Python3.8+ pip install django==4.1.1 pip inst…

RISC-V/ARM mcu OpenOCD 调试架构解析

Risc-v/ARM mcu OpenOCD 调试架构解析 最近有使用到risc-v的单片机,所以了解了下risc-v单片机的编译与调试环境的搭建,面试时问到risc-v的调试可参看以下内容。 risc-v根据官方的推荐,调试器服务是选择OpenOCD,DopenOCD(开放片上…