形象化理解pytorch中的tensor.scatter操作

news2024/11/16 3:17:32

定义

        scatter_(dim, index, src, *, reduce=None) -> Tensor

pytorch官网说这个函数的作用是从src中把index指定的位置把数据写入到self里面,然后给了一个公式:           

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0

self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1

self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

这个公式我也是一脸懵,但是我们可以把他降维到二维表格上,即:

            self[ index[i][j] ][j] = src[i][j]  # if dim == 0

把src从 i 行 移动到了 index[i][j] 行

            self[i][ index[i][j] ] = src[i][j]  # if dim == 1

把src 从 j 列移动到了 index[i][j] 列

对此,个人认为比较直观的理解:
        dim=0,就是把本行这个data放到本列的哪行(上下移动)
        dim=1,就是把本列这个data放到本行的哪列(左右移动)

所以,index数组其实是一个位置变化的映射表

例子1

给定src是一个顺序数组,我们可以更清楚看到这一变化过程。

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

当我们指定 dim=0,就是把每一行data放到上下移动位置,比如我们给一个例子

>>> index = torch.tensor([ [0, 0, 0], [1, 1, 1], [2, 2, 0] ]) 
>>> src.scatter(dim = 0, index=index, src = src)
tensor([[1, 2, 9],
        [4, 5, 6],
        [7, 8, 9]])

可以看到,scatter之后只有 src[0][2] 发生了变化,为什么呢?

 前面提到了index数组其实是一个位置变化的映射表,  dim=0 时候是把src从 i 行 移动到了 index[i][j] 行(上下移动), 这里的index表 0行所有的元素都移动到了0行对应位置, 1行所有的的元素都移动到了1行对应位置, 只有2行最后一个元素移动到了0行,造成的结果就是src只有最后一个元素移动到了0行的对应位置(从src[2][2]移动到了src[0][2])

 例子2

下面我们再试试dim = 1 时候 把src 从 j 列移动到了 index[i][j] 列

给定src是一个顺序数组,我们可以更清楚看到这一变化过程。

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

index给定如下

>>> index = torch.tensor([ [0, 1, 2], [0, 1, 2], [0, 1, 0] ])
>>> index
tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 0]])

>>> src.scatter(dim = 1, index=index, src = src) 
tensor([[1, 2, 3],
        [4, 5, 6],
        [9, 8, 9]])

可以看到,这里src也只有一个位置发生了变化,为什么呢?

 前面提到了index数组其实是一个位置变化的映射表,  dim=1 时候是把src从 i 列 移动到了 index[i][j] 列 (左右移动), 这里的index表 0行 012 列对应的元素都移动到了0行 012列 对应位置(相当于没动), 1行 012 列对应的元素都移动到了1行 012列 对应位置(相当于没动), 只有2行最后一个元素移动到了0列,造成的结果就是src只有最后一个元素移动到了2行0列的位置(从src[2][2]移动到了src[2][0] )

意义

那么这种映射这么复杂,它的意义在哪里呢? 

答:一般scatter用于生成onehot向量

这里还是举个例子

我们还是拿之前的src数组

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

我们要如何理解它呢?我们可以认为它是三只股票在昨天、今天、明天的股票价格,昨天三只股票的价格分别为1,4,7,今天三只股票的价格分别为2,5,8, 明天三只股票的价格分别为3,6,9。

现在我们要训练一个预测后天股票价格的神经网络,我们给模型的输入应该是昨天三只股票的价格、今天三只股票的价格、明天三只股票的价格,即1,4,7,2,5,8,3,6,9。同时,我们要把每个数字转化为一个onehot的向量,这样的结果是我们期望的。

所以,我们要做的事情是把src转换为一个 3*3 的矩阵,矩阵中每个元素是一个能表示0-9的10维one-hot向量。

拿一段常用的onehot生成代码说事。


def one_hot(x, n_class, dtype=torch.float32):
    # X shape: (batch, 1), output shape: (batch, n_class)
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res


# X shape: batch_size, prices_list
def to_onehot(X, n_class):
    # 返回结果 shape: prices_list, batch_size, onehot_size 三维
    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

先不谈代码含义,输出结果如下 

>>> to_onehot(src, 10) 
[tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]),
 tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]), 
 tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])]

这个结果基本符合了我们的期望,那么这个是如何做到的呢? 


# X shape: batch_size, prices_list
def to_onehot(X, n_class):
    # 返回结果 shape: prices_list, batch_size, onehot_size 三维
    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

首先,src按照昨天,今天,明天的维度,被切分为了三个列向量 [1,4,7]、[2,5,8]、 [3,6,9] 。这三个列向量对应了我们的输出,one_hot给定一个列向量,可以转换为一个one-hot列向量组。

def one_hot(x, n_class, dtype=torch.float32):
    # X shape: (batch, 1), output shape: (batch, n_class)
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res

为了简单,我们举一个例子


>>> one_hot(torch.tensor([1,2,3]), 4) 
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

>>> torch.tensor([1,2,3]).view(-1,1)  
tensor([[1],
        [2],
        [3]])

 可以看到,res是一个全0矩阵,scatter操作在dim=1时,是一个左右移动的位置映射表,这里的res是一个 3 * 4 的矩阵,src是一个数字,可以认为是跟res同样大小的全1矩阵,但是index是一个 3*1 的矩阵,也就是这个位置映射表可以认为是一个3行1列的映射表,即 全1矩阵的0 行 0 列映射到res的 0 行 1列,全1矩阵的1行0列映射到res的1行2列,全1矩阵的2行0列映射到res的2行3列,其他保持不变(其他都是0),dim=1这种操作就是制造了one-hot向量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2120305.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ROS CDK魔法书:建立你的游戏王国(TypeScript篇)

引言 在虚拟游戏的世界里,数字化的乐趣如同流动的音符,谱写着无数玩家的共同回忆。而在这片充满创意与冒险的乐园中,您的使命就是将独特的游戏体验与丰富的技术知识相结合,打造出令人难以忘怀的作品。当面对如何实现这一宏伟蓝图…

MOS管G极串联电阻的作用是什么

MOS管栅极串联电阻是如何抑制谐振? 为什么会震荡? 首先了解一下LC串联谐振电路,虽然,LC串联在电路中运用的并不多,但是在无意中总会形成串联谐振,从而产生很多各种各样的现象。如果不了解其本质,会让我们很难理解。比如,使用同样的LC电路滤波,用到两个电路上,有的电…

CCF推荐B类会议和期刊总结:(计算机体系结构/并行与分布计算/存储系统领域)

目录 前言 B类会议 1. SoCC 2. SPAA 3. PODC 4. FPGA 5. CGO 6. DATE 7. HOT CHIPS 8. CLUSTER 9. ICCD 10. ICCAD 11. ICDCS 12. CODESISSS 13. HiPEAC 14. SIGMETRICS 15. PACT 16. ICPP 17. ICS 18. VEE 19. IPDPS 20. Performance 21. HPDC 22. ITC…

在亚马逊云科技上利用Graviton4代芯片构建高性能Java应用(下篇)

简介 在AI迅猛发展的时代,芯片算力对于模型性能起到了至关重要的作用。一款能够同时兼具高性能和低成本的芯片,能够帮助开发者快速构建性能稳定的生成式AI应用,同时降低开发成本。今天小李哥将介绍亚马逊推出的4代高性能计算处理器Gravition…

计算机毕业设计Python电影评论情感分析 电影可视化 豆瓣电影爬虫 电影推荐系统 电影数据分析 电影大数据 大数据毕业设计 机器学习 深度学习 知识图谱

相关技术介绍 豆瓣电影数据采集与可视化分析系统是用当前应用很广泛的Python语言和Flask框架,并结合CSS与HTML搭建Web网页,使用MySQL数据库对数据进行存储,依次来开发实现系统的功能。本系统运行需要的软件有Pycharm、普通浏览器、Navicat f…

AD原理图update为pcb

首先,要在自己的项目下面创建好原理图和PCB,记得保存!!! 点击设计>update 更新成功!

数据结构-图-存储-邻接矩阵-邻接表

数据结构-图-存储 邻接矩阵 存储如下图1,图2 图1 对应邻接矩阵 图2 #include<bits/stdc.h> #define MAXN 1005 using namespace std; int n; int v[MAXN][MAXN]; int main(){cin>>n;for(int i1;i<n;i){for(int j1;j<n;j){cin>>v[i][j];}}for(int…

RM比赛常见的电机(直流无刷电机)

声明&#xff1a;个人笔记&#xff0c;仅供参考 一、M2006电机 M2006 P36 电机采用三相永磁直流无刷结构&#xff0c;具有输出转速高、体积小、功率密度高等特点。该电机内置位置传感器&#xff0c;可提供精确的位置反馈&#xff0c;以 FOC 矢量控制方式使电机产生连续的扭矩。…

Cannot Locate Document 原理图导入pcb出现报错

将原理图update到pcb时报错Cannot Locate Document&#xff1a; 记得保存pcb到你的项目就可以了

爬虫之淘宝接口获取||Python返回淘宝商品详情数据SKU接口

在学习爬虫的过程中&#xff0c;大多数的人都是些豆瓣&#xff0c;招聘网站什么的。这里给出一些工作上能够用得到的内容&#xff0c; 仅供大家参考。 本次需要看的是淘宝的接口&#xff0c; 这个接口与微博寻找接口的方式大致相同。请看详细的寻找方法。首先我们先在百度页面点…

python实现c4d的tp粒子在多个物体上发射思维粒子

基本状态思维粒子只能传入一个物体&#xff0c;在一个物体身上发射粒子。 场景如下&#xff0c;右边的multiEmitter的python标签里的python脚本执行后会在其下面生成数个从pt物体的拷贝&#xff0c;同时拷贝其上的XPresso标签及标签里的内容 下面是pt物体的XPresso标签标签的内…

[ RK3566-Android11 ] 关于 RK628F 驱动移植以及调试说明

问题描述 我这个项目的SDK比较老&#xff0c;移植RK628F最新驱动的调试过程&#xff0c;踩了很多坑&#xff0c;希望大家别踩坑。 解决方案&#xff1a; 首先在FTP上下载最新的RK628的驱动 rk628-for-all-v27-240730 版本。 下载完后 不要直接替换&#xff0c;不要直接替换&a…

Vue获取后端重定向拼接的参数

前言 比如我们要重定向这样一个连接&#xff1a; http://192.168.2.189:8081?nameadmin springboot重定向&#xff1a; Vue获取&#xff1a; getParam(param) {var reg new RegExp("(^|&)" param "([^&]*)(&|$)");var r location.searc…

计算机的错误计算(八十八)

摘要 探讨双曲反正切函数 atanh(x)的计算精度问题。 IEEE 754-2019 中含有 atanh(x)函数。其定义为 例1. 计算 atanh(0.9999999999997) . 不妨用 LibreOffice中的电子表格计算&#xff0c;则有&#xff1a; 若在线运行JavaScript代码&#xff1a; let result Math.atanh(0.…

单电源转正负双电源电路

单电源转正负双电源电路&#xff1a; 1.通过两个DCDC芯片进行降压&#xff1a; 不同负载下电源纹波不同&#xff0c;所以看电源纹波首先先说明负载是什么&#xff1a; 采用TPS5430将单电源转换成双电源的方式供电&#xff1a; 2.通过电荷泵的方式转换电压 成本可以压低&#…

民生水暖工程背后的科技力量引领工程智能化转型

物联网技术的广泛应用&#xff0c;使得物理设备能够实时传输运行状态数据至云端&#xff0c;实现了设备的全面感知与互联互通。每一台机器、每一个传感器都成为数据的源泉&#xff0c;为远程监控提供了坚实的基础。而大数据分析技术的应用&#xff0c;则让这些海量数据得以被高…

R语言统计分析——用回归做ANOVA

参考资料&#xff1a;R语言实战【第2版】 ANOVA&#xff08;方差分析&#xff09;和回归都是广义线性模型的特例&#xff0c;方差分析也都可以使用lm()函数来分析。 # 加载multcomp包 library(multcomp) # 查看cholesterol数据集的处理水平 levels(cholesterol$trt) # 用aov()…

久久派搭建风电系统网站(基于mariadb数据库)

久久派搭建风电系统网站 1、安装mariadb2、设置root账号密码3、设置MariaDB开机自启4、允许远程登录5、还原数据库6、扩容swap7、拷贝数据8、运行系统方法1&#xff1a;通过sh脚本运行方法2&#xff1a;直接运行jar包 文中所需网盘资料及讲解视频在文章末尾哦1。 本文中参考资料…

万能无线航模模拟器加密狗说明书

快速开始 Step1 插入加密狗到你的电脑&#xff0c;手机或MAC的USB口。 Step2 使用加密狗上的按钮&#xff0c;选择一个合适的协议。具体看第一节。 Step3 和遥控器对码&#xff0c;成功后指示灯常亮。具体看FAQ第二节。 Step4 在你的电脑&#xff0c;手机或MAC 安装对…

GD32F103单片机-GPIO

GD32F103单片机-GPIO 一、GPIO介绍二、GD32F103库函数介绍三、GPIO输入输出3.1 GPIO输出-LED闪烁3.2 GPIO输入-独立按键 STM32GPIO部分见STM32F1单片机-GPIO 一、GPIO介绍 GD32的GPIO同STM32一样&#xff0c;GPIO可以配置成8种输入输出模式&#xff0c;由软件配置成推挽输出、…