进阶必看,3种灵活操作PyTorch张量的高级方法

news2025/1/12 10:56:21

大家好,在PyTorch中进行高级张量操作时,开发者经常面临这样的问题,如何根据一个索引张量从另一个张量中选取元素。

例如有一个包含数千个特征的大规模数据集,需要根据特定的索引模式快速提取信息。本文将介绍三种索引选择方法来解决这类问题。

torch.index_select

torch.index_select函数通过在指定的维度上进行元素选择,同时在其他维度上保持元素不变。也就是说,在目标维度上根据索引张量来挑选元素,而其他维度的元素则原封不动。为了更直观地理解这一概念,来看一个2D张量的示例,这里将沿着维度1进行元素的选择:

num_picks = 2

values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)

由此得到的张量形状为[len_dim_0, num_picks]:对于维度0上的每个元素,都从维度1中选取了相同的元素。将其形象化:

现在迈入三维张量的世界,这样更贴近机器学习与数据科学的实际需求。

设想一个三维张量,其维度为[batch_size, num_elements, num_features]:num_elements表示每个批次中的项目数,每个项目具有num_features个特征。这种张量结构所有元素都是以批量方式处理的。

import torch

batch_size = 16
num_elements = 64
num_features = 1024
num_picks = 2

values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(num_picks,))
# [batch_size, num_picks, num_features]
picked = torch.index_select(values, 1, indices)

若更倾向于通过代码来理解index_select的功能,以下是使用简单的for循环来模拟该功能实现的示例:

picked_manual = torch.zeros_like(picked)
for i in range(batch_size):
    for j in range(num_picks):
        for k in range(num_features):
            picked_manual[i, j, k] = values[i, indices[j], k]

assert torch.all(torch.eq(picked, picked_manual))

torch.gather

torch.gather函数在功能上与torch.index_select相似,但提供了更为灵活的元素选择方式。

torch.gather中,选择的元素不仅取决于索引张量,还受到其他维度的影响。以机器学习项目为例,可以针对每个批次和每个特征,根据条件从元素维度中选取不同的元素,实现这一点是通过使用另一个张量来指定索引。

在实际应用中,这种用法非常普遍,比如在决策树中根据特定条件选择节点。

每个节点由一组特征定义,可以创建一个索引矩阵,将选定的元素放置在批次维度上,并在特征维度上复制这些值。这样,对于每个批次索引,都可以基于特定条件选择不同的元素,尽管在我们的示例中,这些条件仅与批次索引相关,但也可以根据特征索引来确定。

为了更清楚地理解这一点,再次从二维(2D)示例开始,逐步展示如何使用torch.gather来实现这种灵活的索引选择。

num_picks = 2

values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(len_dim_0, num_picks))
# [len_dim_0, num_picks]
picked = torch.gather(values, 1, indices)

直观来看,torch.gather的元素选择呈现出与torch.index_select不同的模式。不同于后者沿直线进行选择,torch.gather根据维度0上的每个索引,在维度1中挑选出不同的元素:

接下来进入三维世界,并展示如何用Python代码来实现类似的选择机制:

import torch

batch_size = 16
num_elements = 64
num_features = 1024
num_picks = 5
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(batch_size, num_picks, num_features))
picked = torch.gather(values, 1, indices)

picked_manual = torch.zeros_like(picked)
for i in range(batch_size):
    for j in range(num_picks):
        for k in range(num_features):
            picked_manual[i, j, k] = values[i, indices[i, j, k], k]

assert torch.all(torch.eq(picked, picked_manual))

torch.take

在三个函数中,torch.take的工作原理最为简单明了。它首先将输入张量视为一维数组,然后根据指定的索引从中选取元素。

例如,对于一个4行5列的张量,如果使用torch.take并选取索引6和19,实际上获取的是这个张量在一维化之后位于第6个位置和第19个位置的元素,分别对应于原始二维结构中的第2行第2列和最后一行最后一列的元素。

2D示例:

num_picks = 2

values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_0 * len_dim_1, size=(num_picks,))
# [num_picks]
picked = torch.take(values, indices)

现在得到了两个元素:

接下来探讨三维张量的索引选择及其实现。索引张量不受固定形状的限制,可以是任意形状。根据这个索引张量进行的元素选择,其结果也将遵循这种形状,确保输出与索引张量的维度结构一致。

import torch

batch_size = 16
num_elements = 64
num_features = 1024
num_picks = (2, 5, 3)

values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, batch_size * num_elements * num_features, size=num_picks)
# [2, 5, 3]
picked = torch.take(values, indices)

picked_manual = torch.zeros(num_picks)
for i in range(num_picks[0]):
    for j in range(num_picks[1]):
        for k in range(num_picks[2]):
            picked_manual[i, j, k] = values.flatten()[indices[i, j, k]]

assert torch.all(torch.eq(picked, picked_manual))

本文介绍了Pytorch中的三种常见选择方法:torch.index_selecttorch.gathertorch.take。可以使用这些方法,根据不同的条件从张量中选取或索引特定的元素。

对于每种方法,都先通过简单的二维(2D)示例引入,并直观地展示了选择结果。接着,进入更为复杂且实际的三维(3D)应用场景,演示了如何在形状为[batch_size, num_elements, num_features]的张量中进行元素选择——这种情况在机器学习项目中十分常见。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1839486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java基础概念-数据类型-笔记-标识符-键盘录入

数据类型 分为两种:基本数据类型,引用数据类型 基本数据类型: 注意如果定义long类型变量,需要加L做后缀 long n9999999999L float f10.1F FL大小写都可以 练习 实例: 输出个人信息: public class text…

Danikor智能拧紧轴控制器过压维修知识

【丹尼克尔拧紧轴控制器故障代码维修】 【丹尼克尔Danikor控制器维修具体细节】 丹尼克尔拧紧轴控制器作为一种高精度的电动拧紧工具,广泛应用于各种工业生产线。然而,在使用过程中,由于各种原因,可能会出现Danikor扭矩扳手控制…

Graph RAG 的力量:智能搜索的未来

随着世界越来越依赖数据,对准确、高效的搜索技术的需求从未如此高涨。传统搜索引擎虽然功能强大,但往往难以满足用户复杂而细微的需求,尤其是在处理长尾查询或专业领域时。Graph RAG(检索增强生成)正是在这种情况下应运…

MBR60200PT-ASEMI逆变箱专用MBR60200PT

编辑:ll MBR60200PT-ASEMI逆变箱专用MBR60200PT 型号:MBR60200PT 品牌:ASEMI 封装:TO-247 最大平均正向电流(IF):60A 最大循环峰值反向电压(VRRM):200V…

Vue - 第3天

文章目录 一、Vue生命周期二、Vue生命周期钩子三、工程化开发和脚手架1. 开发Vue的两种方式2. 脚手架Vue CLI基本介绍:好处:使用步骤: 四、项目目录介绍和运行流程1. 项目目录介绍2. 运行流程 五、组件化开发六、根组件 App.vue1. 根组件介绍…

汉化版PSAI全面测评,探索国产AI绘画软件的创新力量

引言 随着AI技术的飞速发展,图像处理和绘画领域迎来了新的变革。作为一名AIGC测评博主,今天我们测评的是一款国产AI绘画软件——StartAI,一句话总结:它不仅在技术上毫不逊色于国际大牌,更在用户体验和本地化服务上做到…

GLib库对核心应用的支持

代码&#xff1a; /** main.c** Created on: 2024-6-19* Author: root*/#include <glib.h> // 包含GLib函数库 static GMutex *mutex NULL; static gboolean t1_end FALSE; // 用于结束线程1的标志 static gboolean t2_end FALSE; // 用于结束线程…

Anti-human IL-10 mAb (12G8), biotin:Mabtech热销品

Anti-human IL-10 mAb (12G8), biotin该单克隆抗体能够在ELISpot、FluoroSpot和ELISA等免疫分析方法中特异性检测人白介素10&#xff08;IL-10&#xff09;。可以将该单克隆抗体12G8作为检测抗体与单克隆抗体9D7&#xff08;ca#3430-3&#xff09;作为捕获抗体配对用于ELISpot、…

js语法---理解反射Reflect对象和代理Proxy对象

Reflect 基本要点 反射&#xff1a;reflect是一个内置的全局对象&#xff0c;它的作用就是提供了一些对象实例的拦截方法&#xff0c;它的用法和Math对象相似&#xff0c;都只有静态方法和属性&#xff0c;同时reflect也没有构造器&#xff0c;无法通过new运算符构建实例对象&…

vue自建h5应用,接入企业微信JDK(WECOM-JSSDK),实现跳转添加好友功能

一、项目场景&#xff1a; 1、使用vue开发了一套h5页面的项目 2、这个h5链接是在企业微信里某个地方打开的 3、打开页面的时候有一个好友列表&#xff0c;点击好友列表某一条复制手机号跳转到企业微信添加好友页面 二、实现的效果图 博客只允许上传gif图&#xff0c;所以我只…

SQL注入-下篇

HTTP注入 一、Referer注入 概述 当你访问一个网站的时候&#xff0c;你的浏览器需要告诉服务器你是从哪个地方访问服务器的。如直接在浏览器器的URL栏输入网址访问网站是没有referer的&#xff0c;需要在一个打开的网站中&#xff0c;点击链接跳转到另一个页面。 Less-19 判…

预算有限?如何挑选经济适用的ERP系统?

中小企业在运营过程中&#xff0c;经常面临着一个共同的挑战——如何在有限的预算内挑选到一款既符合业务需求又经济适用的ERP系统。然而&#xff0c;市场上ERP系统种类繁多&#xff0c;价格差异大&#xff0c;功能复杂&#xff0c;使得许多企业在选择时感到迷茫和困惑。 如果…

【BEV】BEVFormer总结

本文分享BEV感知方案中&#xff0c;具有代表性的方法&#xff1a;BEVFormer。 它基于Deformable Attention&#xff0c;实现了一种融合多视角相机空间特征和时序特征的端到端框架&#xff0c;适用于多种自动驾驶感知任务。 主要由3个关键模块组成&#xff1a; BEV Queries Q&am…

带你了解甘肃独特的调料-苦豆粉

在众多的调味料中&#xff0c;苦豆粉是一种相对小众但却极具特色的存在。今天&#xff0c;就让我们一起深入探究甘肃特产苦豆粉的奇妙世界。苦豆粉&#xff0c;又被称为胡巴豆、大巴豆等&#xff0c;它主要源自于一种豆科植物。很多人初次接触苦豆粉时&#xff0c;可能会被它独…

13.2 Go 接口的动态性

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

RK3588 Android12音频驱动分析全网最全

最近没有搞音频相关的了&#xff0c;在搞BMS, 把之前的经验总结一下。 一、先看一下Android 12音频总架构 从这张图可以看到音频数据流一共经过了3个用户空间层的进程&#xff0c;然后才流到kernel驱动层。Android版本越高&#xff0c;通用性越高&#xff0c;耦合性越低&#…

【Portswigger 学院】CORS

教程和靶场来源于 Burpsuite 的官网 Portswigger&#xff1a;Cross-origin resource sharing (CORS) - PortSwigger 跨域资源共享&#xff08;Cross-origin resource sharing&#xff0c;CORS&#xff09;是一种浏览器机制&#xff0c;允许浏览器访问不同源的资源。同源策略的作…

【Python】已解决Python错误:ImportError: cannot import name get_column_letter的报错解决办法

【Python】已解决Python错误&#xff1a;ImportError: cannot import name get_column_letter的报错解决办法 &#x1f60e; 作者介绍&#xff1a;我是程序员洲洲&#xff0c;一个热爱写作的非著名程序员。CSDN全栈优质领域创作者、华为云博客社区云享专家、阿里云博客社区专家…