PyTorch中Tensor的存储结构

news2024/10/26 7:27:59

PyTorch中Tensor的存储结构

Tensor数据的类型

Tensor 中数据主要有下面两种类型:

  • meta data:元数据,也就是描述数据特征的数据,例如 shape、dtype、device、stride等等
  • raw data:数据本身,我们可以通过 tensor.data_ptr() 获取到数据存储的内存位置

参考下面案例

def tensor_struct():
    #  meta_data / raw_data
    nd_array = np.array([[1, 2, 3], [4, 5, 6]])
    # tensor = torch.tensor(nd_array) # deep copy
    tensor = torch.from_numpy(nd_array)

    # raw data
    print(f"pytorch data: \n{tensor}")
    # print("pytorch raw data: \n", tensor.storage())
    print(f"numpy raw data_ptr: {nd_array.ctypes.data}")
    print(f"pytroch raw data_ptr: {tensor.data_ptr()}")  # raw_data

    print(f"numpy data id: {id(nd_array)}", )
    print(f"pytorch data id: {id(tensor)}")

    tensor2 = tensor.reshape(1, 6)
    # 观察可以看到 tensor 及 tensor2 的 id 是不同的, 但是 data_ptr 却相同
    # tensor2 的 row_data 没有变化, meta_data 发生了变化 -> tensor2 是 tensor 的一个 view
    print(f"tensor id: {id(tensor)}")
    print(f"tensor2 id: ", id(tensor2))
    print(f"tensor pointer addr: {tensor.data_ptr()}")
    print(f"tensor2 pointer addr: {tensor2.data_ptr()}")

视图

首先了解一下 Pytorch 中下面的两个概念:

  • stride() :获取张量(Tensor)的步幅信息。步幅(Stride)描述了张量在内存中相邻元素之间的距离(以元素个数为单位),对于多维张量而言,它是一个表示各维度间跳跃关系的元组
  • data_ptr():获取张量(Tensor)底层数据在内存中的起始地址。这个地址是一个整数值,通常表示为一个C语言指针类型(在Python环境中表现为Python整数)

参考下面案例

# 理解 tensor 的步长
def stride_demo():
    tensor = torch.randn(2, 3, 5)
    # stride 就是 tensor 中某一个维度上, 相邻元素之间的步长(以元素个数为单位)
    # 对于 shape 为 2,3,5 的 tensor
    # 在第0维上, 两个元素之间的步长为 3*5 = 15
    # 在第1维上, 两个元素之间的步长为 5*1 = 5
    # 在第2维上, 由于是最后一个维度了, 两个相邻元素间步长就是1了
    tensor_stride = tensor.stride()
    print(f"tensor_stride: {tensor_stride}")
    print(f"tensor.stride(0): {tensor.stride(0)}")
    print(f"tensor.stride(1): {tensor.stride(1)}")
    print(f"tensor.stride(2): {tensor.stride(2)}")

实际上PyTorch获取指定索引位置的数据时,本质上是通过data_ptr()的位置获取多维数组的起始点,然后依据 stride() 计算指定维度走一步需要移动的位置,最终计算出当前索引的数据。

对于一个 shape 为 [2, 3, 5]的 tensor,那么它的 stride 应当为:

  • 第0维:stride[0] 应当为后面两维的乘积,也就是 5*3 = 15
  • 第1维:stride[1] 应当为后面一维的维度,也就是 5
  • 第1维:stride[2] 上面每一个数值都是连续的,也就是1

因此,stride也就是 [15, 5, 1]
img

连续型与破坏连续性

Tensor中的连续性

如果 Tensor 的 stride 满足前面的定义,那么在读取数据时可以认为是连续的,在做类似矩阵乘法时读取数据的效率就会比较高。

但是有一些操作是会破坏这种连续性的

参考下面案例

def contiguous_demo():
    data0 = torch.randint(0, 10, (2, 5))
    data1 = data0.transpose(1, 0)
    data2 = data0.reshape(5, 2)
    print(f"data0: {data0}")
    # data1 和 data2 的 shape 相同, 但是对应位置上的值是不同的
    # data0: [ [3, 5, 5, 9, 2], [8, 7, 4, 9, 7] ]
    # data1: [ [3, 8], [5, 7], [5, 4], [9, 9], [2, 7] ]
    # data2: [ [3, 5], [5, 9], [2, 8], [7, 4], [9, 7] ]
    print(f"data1: {data1}")
    print(f"data2: {data2}")

    # data0、data1、data2 中 的data_ptr() 都是是相同的,说明 row_data 是没有变化的
    # transpose 以及 reshape 操作虽然数据不同,但转换以后 raw_data 是没有变化的
    print(f"data0 data_ptr: {data0.data_ptr()}")
    print(f"data1 data_ptr: {data1.data_ptr()}")
    print(f"data2 data_ptr: {data2.data_ptr()}")

    # transpose 以及 reshape 的区别在于两个操作以后 tensor 的 stride 发生了变化
    # 根据之前的例子对于一个 (5, 2) 的 tensor, stride 取值应当是 (2, 1)
    # 可以看到, reshape 以后是满足这个性质的
    # ------------------------ transpose 导致的不连续现象 -------------------------
    # tensor 在 transpose 操作之后, 读取数据的方式发生了改变, 不能像之前一样 "挨个" 读取数据
    # 从而发生了数据 "不连续" 的现象 !!!
    # 也就是说 transpose 操作本质上仍然是获取的是一个 view,但是会导致数据的不连续
    # ------------------------ transpose 导致的不连续现象 -------------------------
    print(f"data0 stride: {data0.stride()}")  # (5, 1)
    print(f"data1 stride: {data1.stride()}")  # (1, 5)
    print(f"data2 stride: {data2.stride()}")  # (2, 1)

    print(f"data0 is_contiguous: {data0.is_contiguous()}")  # True
    print(f"data1 is_contiguous: {data1.is_contiguous()}")  # False
    print(f"data2 is_contiguous: {data2.is_contiguous()}")  # True

可以看到 transpose 操作会与原始的 tensor 共享同一份 raw_data,但是会使得原来读取最后一个维度数据时发生不连续的现象,因此使得数据变得 “不连续” 了。

常见的破坏连续性的算子

主要有 transpose、permute、T 等等

参考下面案例

def discontinuous_operator():
    data0 = torch.randint(0, 10, (2, 3, 4))
    # transpose 指定交换 第0轴 和 第1轴
    data1 = data0.transpose(0, 1)
    # permute 指的是: 原来第0轴 -> 第2轴, 原来第1轴 -> 第0轴, 原来第2轴 -> 第1轴
    data2 = data0.permute(2, 0, 1)
    data3 = data0.T

    print(f"data0.shape: {data0.shape}")  # [2, 3, 4]
    print(f"data1.shape: {data1.shape}")  # [3, 2, 4]
    print(f"data2.shape: {data2.shape}")  # [4, 2, 3]
    print(f"data3.shape: {data3.shape}")  # [4, 3, 2]

    print(f"data0 stride: {data0.stride()}")  # (12, 4, 1)
    print(f"data1 stride: {data1.stride()}")  # (4, 12, 1)
    print(f"data2 stride: {data2.stride()}")  # (1, 12, 4)
    print(f"data3 stride: {data3.stride()}")  # (1, 4, 12)
contiguous() 方法

既然有些算子会破坏Tensor的连续性,那么有没有什么方法可以避免呢?
我们可以使用 Tensor 中提供的 contiguous()方法使得 Tensor 变为连续的,本质上也就是新开辟了一个数据存储空间,然后把原来的数据挪到新空间下。

参考下面案例

def contiguous_method():
    data0 = torch.randint(0, 10, (2, 5))
    # 这时候 data1 只是 data0 的一个 view
    data1 = data0.transpose(0, 1)
    # 此时创建了一个新的数据空间, data1 已经不是 data0 的一个 view了, 两者的 raw_data 已经不同了
    data1 = data1.contiguous()

    print(f"data1 shape: {data1.shape}")
    print(f"data1 stride: {data1.stride()}")

    # 可以看到此时 data0 与 data1 的 data_ptr 已经不同了
    print(f"data0 data_ptr: {data0.data_ptr()}")
    print(f"data1 data_ptr: {data1.data_ptr()}")

我们可以看到,对于一个不连续的 Tensor 调用 contiguous()方法后,Tensor重新变为连续的了,但是 raw_data 也发生了改变。

reshape vs view

在大部分情况下,reshape 和 view 的作用都是相同的,但是在处理不连续的 Tensor 时,两个算子处理上有所差异:

  • view:直接报错 _view size is not compatible with input tensor's size and stride_
  • reshape:会新开辟一个空间存储,将原有数据copy到新的存储空间当中。

参考下面案例

def view_discontinuous():
    data0 = torch.randint(0, 10, (2, 5))
    data1 = data0.transpose(0, 1)
    # 直接报错: view size is not compatible with input tensor's size and stride
    data2 = data1.view(2, 5)
    print(f"data2: {data2}")


def reshape_discontinuous():
    data0 = torch.randint(0, 10, (2, 5))
    data1 = data0.transpose(0, 1)
    # 此时程序可以跑通
    data2 = data1.reshape(2, 5)

    print(f"data0: {data0}")
    print(f"data1: {data1}")
    print(f"data2: {data2}")

    # 可以看到 data0 和 data1 共享一份 raw_data, 但是 data2 的 raw_data 发生了改变
    # 也就是说: reshape 一个不连续的 tensor, 会新创建一个空间, 将原来的数据 copy 到新的空间
    print(f"data0 data_ptr: {data0.data_ptr()}")
    print(f"data1 data_ptr: {data1.data_ptr()}")
    print(f"data2 data_ptr: {data2.data_ptr()}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2210517.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【漏洞复现】SpringBlade menu/list SQL注入漏洞

》》》产品描述《《《 致远互联智能协同是一个信息窗口与工作界面,进行所有信息的分类组合和聚合推送呈现。通过面向角色化、业务化、多终端的多维信息空间设计,为不同组织提供协同门户,打破组织内信息壁垒,构建统一协同沟通的平台。 》》》漏洞描述《《《 致远互联 FE协作办公…

【尚硅谷】FreeRTOS学笔记(更新中更新时间2024.10.12)

在网上看到的一段很形象的描述,放在这里给大家娱乐一下。 裸机开发:n个人拉屎,先进去一个拉完,下一个再来。看门狗:如果有人拉完屎还占着,茅坑刷视频,把他拖出去中断系统:n个人拉屎&…

Qt基础对话框QDialog

模态显示对话框 调用exec方法可以使得对话框模态显示,但是一个阻塞函数 [virtual slot] int QDialog::exec() 对话框的三个槽函数 accept [virtual slot] void QDialog::accept(); reject [virtual slot] void QDialog::reject() done [virtual slot] void QDia…

搭建mongodb单机部署-认证使用

搭建mongodb单机部署-认证使用 实现思路 先将配置文件配置好,使用不用认证的启动命令启动docker,然后创建账号并制定角色。在使用开启认证的命令重新启动容器就好。 这里我并没有说先停止容器,删掉容器重新创建容器。是因为我的启动命令中…

libaom 源码分析系列:noise_model.c 文件

libaom libaom 是 AOMedia(开放媒体联盟)开发的一个开源视频编解码器库,它是 AV1 视频压缩格式的参考实现,并被广泛用于多种生产系统中。libaom 支持多种功能,包括可扩展视频编码(SVC)、实时通信(RTC)优化等,并定期进行更新以提高压缩效率和编码速度 。 libaom 的一些…

豆包MarsCode体验有京东卡和现金

https://www.marscode.cn/events/s/iBpts1oT/ 先登录注册 然后到VSCODE里,在最左侧导航栏处看到EXTEBSIONS点一下(快捷键CtrlShiftX),然后搜索MarsCode,并安装插件。 安装后登录体验一次问答即可。然后回到活动页即…

电瓶车的无钥匙启动功能为用户带来了极大的便利

电瓶车智能钥匙一键启动系统是一种依赖智能钥匙和一键启动按钮的启动方式。 智能钥匙和一键启动系统的结合使用提高了车辆的安全性,防止了未经授权的启动。 携带智能钥匙进入车辆,按下一键启动按钮,车辆通过感应智能钥匙存在而启动。 一键…

数据结构-C语言顺序栈功能实现

栈 栈&#xff1a;类似于一个容器&#xff0c;如我们生活中的箱子&#xff0c;我们向箱子里放东西&#xff0c;那么最先放的东西是最后才能拿出来的 代码实现 #include <stdio.h> #include <stdlib.h>#define MAX_SIZE 100typedef struct {int* base; // 栈底指针…

【亲测可行】ubuntu根目录空间不够,将其它盘挂载到/opt

文章目录 &#x1f315;缘起&#x1f315;从其它盘压缩出一个未分配的空间&#x1f319;从windows系统中压缩出个未分配的空间&#x1f319;从linux系统中压缩出个未分配的空间 &#x1f315;右键点击未分配的盘新建分区&#x1f315;查看分区&#x1f315;先将新分区挂载到/mn…

VMware vCenter Server 6.7U3v 发布下载 - ESXi 集中管理软件

VMware vCenter Server 6.7U3v 发布下载 - ESXi 集中管理软件 集中式控制 vSphere 环境 请访问原文链接&#xff1a;https://sysin.org/blog/vmware-vcenter-6-7/ 查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org VMware vCenter Server 是…

爬虫post收尾以及cookie加代理

爬虫post收尾以及cookie加代理 目录 1.post请求收尾 2.cookie加代理 post收尾 post请求传参有两种格式&#xff0c;载荷中有请求载荷和表单参数&#xff0c;我们需要做不同的处理。 1.表单数据&#xff1a;data字典传参 content-type: application/x-www-form-urlencoded; …

【STM32单片机_(HAL库)】6-2【串口通信UART、USART】串口通信框图

USB转TTL 串口通信协议 USART框图 UART&#xff1a;通用异步收发器&#xff1b;USART&#xff1a;通用同步异步收发器 STM32F103C8T6支持三个串口通信

yolo参数调节

1-weight 不同版本的神经网络 可以在这下载复制 2 source图片路径或者文件夹路径 3 img size 尺寸&#xff08;尽量与神经网络模型匹配&#xff09; 4 4 -conf-thres 简单理解就是模型识别成功概率超过这一标准才会显示 5 iou多区域重合 &#xff08;重合比例&#xff09;…

HTML入门教程一口气讲完!(下)\^o^/

HTML 表单 HTML 表单和输入 HTML 表单用于收集不同类型的用户输入。 在线实例 创建文本字段 (Text field) 本例演示如何在 HTML 页面创建文本域。用户可以在文本域中写入文本。 创建密码字段 本例演示如何创建 HTML 的密码域。 &#xff08;在本页底端可以找到更多实例。&a…

MySQL基础教程(二):检索数据和排序检索数据

本篇文章主要介绍通过 MySQL 中的 SELECT, DISTINCT, ORDER BY, LIMIT语句完成最基本的数据检索和对检索到的数据进行排序。最基本的数据检索是指我们通过 SELECT 语句查询表中的某些列或者行。对检索到的数据进行排序是指对数据以某种规则显示&#xff0c;例如按照某个字段升序…

QD1-P20 CSS 简单了解

本节学习&#xff1a;简单了解CSS&#xff0c;什么是什么CSS&#xff0c;如何在HTML中使用CSS&#xff1f; ‍ 本节视频 www.bilibili.com/video/BV1n64y1U7oj?p20 CSS是什么&#xff1f; CSS&#xff08;层叠样式表&#xff0c;Cascading Style Sheets&#xff09;是一种样…

【Java面试——基础知识——Day2】

1.面向对象基础 1.1 面向对象和面向过程的区别 面向过程编程&#xff08;POP&#xff09;&#xff1a;面向过程把解决问题的过程拆成一个个方法&#xff0c;通过一个个方法的执行解决问题。面向对象编程&#xff08;OOP&#xff09;&#xff1a;面向对象会先抽象出对象&#…

Jetbrains Fleet1.41 发布:新特性杀疯了

决定我们自身的不是过去的经历 而是我们自己赋予经历的意义 因为过去的经历 是否影响他 如何影响他 完全由他自己决定 有时候 克服恐惧最好的办法 就是把恐惧说出来 前几日 jetbrains fleet1.41 正式发布了,这次的发布可谓是真的诚意满满,包含了多个开发者非常喜欢的小…

Bootstrap 4 多媒体对象

Bootstrap 4 多媒体对象 引言 Bootstrap 4 是目前最受欢迎的前端框架之一,它提供了一套丰富的工具和组件,帮助开发者快速构建响应式和移动设备优先的网页。在本文中,我们将重点探讨 Bootstrap 4 中的多媒体对象(Media Object)组件,这是一种用于构建复杂和灵活布局的强大…

Java:数据结构-LinkedList和链表(2)

一 LinkedList LinkedList的方法的实现 1.头插法 public class MyLinkedList implements IList{static class ListNode{public int val;public ListNode next;public ListNode prev;public ListNode(int val){this.valval;}}public ListNode head;public ListNode last;Overr…