2. PyTorch——Tensor和Numpy

news2025/2/26 5:44:22

2.1Tensor和Numpy

Tensor和Numpy数组之间具有很高的相似性,彼此之间的互操作也非常简单高效。需要注意的是,Numpy和Tensor共享内存。由于Numpy历史悠久,支持丰富的操作,所以当遇到Tensor不支持的操作时,可先转成Numpy数组,处理后再转回tensor,其转换开销很小。

import numpy as np
a = np.ones([2, 3],dtype=np.float32)
a
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
b = t.from_numpy(a)
b
tensor([[1., 1., 1.],
        [1., 1., 1.]])
b = t.Tensor(a) # 也可以直接将numpy对象传入Tensor
b
tensor([[1., 1., 1.],
        [1., 1., 1.]])
a[0, 1]=100
b
tensor([[  1., 100.,   1.],
        [  1.,   1.,   1.]])
c = b.numpy() # a, b, c三个对象共享内存
c
array([[  1., 100.,   1.],
       [  1.,   1.,   1.]], dtype=float32)

注意: 当numpy的数据类型和Tensor的类型不一样的时候,数据会被复制,不会共享内存。

a = np.ones([2, 3])
# 注意和上面的a的区别(dtype不是float32)
a.dtype
dtype('float64')
b = t.Tensor(a) # 此处进行拷贝,不共享内存
b.dtype
torch.float32
c = t.from_numpy(a) # 注意c的类型(DoubleTensor)
c
tensor([[1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)
a[0, 1] = 100
b # b与a不共享内存,所以即使a改变了,b也不变
tensor([[1., 1., 1.],
        [1., 1., 1.]])
c # c与a共享内存
tensor([[  1., 100.,   1.],
        [  1.,   1.,   1.]], dtype=torch.float64)

注意: 不论输入的类型是什么,t.tensor都会进行数据拷贝,不会共享内存

tensor = t.tensor(a) 
tensor[0,0]=0
a
array([[  1., 100.,   1.],
       [  1.,   1.,   1.]])

广播法则(broadcast)是科学运算中经常使用的一个技巧,它在快速执行向量化的同时不会占用额外的内存/显存。
Numpy的广播法则定义如下:

  • 让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分通过在前面加1补齐
  • 两个数组要么在某一个维度的长度一致,要么其中一个为1,否则不能计算
  • 当输入数组的某个维度的长度为1时,计算时沿此维度复制扩充成一样的形状

PyTorch当前已经支持了自动广播法则,但是笔者还是建议读者通过以下两个函数的组合手动实现广播法则,这样更直观,更不易出错:

  • unsqueeze或者view,或者tensor[None],:为数据某一维的形状补1,实现法则1
  • expand或者expand_as,重复数组,实现法则3;该操作不会复制数组,所以不会占用额外的空间。

注意,repeat实现与expand相类似的功能,但是repeat会把相同数据复制多份,因此会占用额外的空间。

a = t.ones(3, 2)
b = t.zeros(2, 3,1)
# 自动广播法则
# 第一步:a是2维,b是3维,所以先在较小的a前面补1 ,
#               即:a.unsqueeze(0),a的形状变成(1,3,2),b的形状是(2,3,1),
# 第二步:   a和b在第一维和第三维形状不一样,其中一个为1 ,
#               可以利用广播法则扩展,两个形状都变成了(2,3,2)
a+b
tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]]])
# 手动广播法则
# 或者 a.view(1,3,2).expand(2,3,2)+b.expand(2,3,2)
a[None].expand(2, 3, 2) + b.expand(2,3,2)
tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]]])
# expand不会占用额外空间,只会在需要的时候才扩充,可极大节省内存
e = a.unsqueeze(0).expand(10000000000000, 3,2)

2.2 内部结构

tensor的数据结构如图1-1所示。tensor分为头信息区(Tensor)和存储区(Storage),信息区主要保存着tensor的形状(size)、步长(stride)、数据类型(type)等信息,而真正的数据则保存成连续数组。由于数据动辄成千上万,因此信息区元素占用内存较少,主要内存占用则取决于tensor中元素的数目,也即存储区的大小。

一般来说一个tensor有着与之相对应的storage, storage是在data之上封装的接口,便于使用,而不同tensor的头信息一般不同,但却可能使用相同的数据。下面看两个例子。

在这里插入图片描述

a = t.arange(0, 6)
a.storage()
 0
 1
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
b = a.view(2, 3)
b.storage()
 0
 1
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
# 一个对象的id值可以看作它在内存中的地址
# storage的内存地址一样,即是同一个storage
id(b.storage()) == id(a.storage())
True
# a改变,b也随之改变,因为他们共享storage
a[1] = 100
b
tensor([[  0, 100,   2],
        [  3,   4,   5]])
c = a[2:] 
c.storage()
 0
 100
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
c.data_ptr(), a.data_ptr() # data_ptr返回tensor首元素的内存地址
# 可以看出相差8,这是因为2*4=8--相差两个元素,每个元素占4个字节(float)
(4478820552016, 4478820552000)
c[0] = -100 # c[0]的内存地址对应a[2]的内存地址
a
tensor([   0,  100, -100,    3,    4,    5])
d = t.LongTensor(c.storage())
d[0] = 6666
b
tensor([[6666,  100, -100],
        [   3,    4,    5]])
# 下面4个tensor共享storage
id(a.storage()) == id(b.storage()) == id(c.storage()) == id(d.storage())
True
a.storage_offset(), c.storage_offset(), d.storage_offset()
(0, 2, 0)
e = b[::2, ::2] # 隔2行/列取一个元素
id(e.storage()) == id(a.storage())
True
b.stride(), e.stride()
((3, 1), (6, 2))
e.is_contiguous()
False

可见绝大多数操作并不修改tensor的数据,而只是修改了tensor的头信息。这种做法更节省内存,同时提升了处理速度。在使用中需要注意。
此外有些操作会导致tensor不连续,这时需调用tensor.contiguous方法将它们变成连续的数据,该方法会使数据复制一份,不再与原来的数据共享storage。
另外读者可以思考一下,之前说过的高级索引一般不共享stroage,而普通索引共享storage,这是为什么?(提示:普通索引可以通过只修改tensor的offset,stride和size,而不修改storage来实现)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1293291.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java设计模式学习之【装饰器模式】

文章目录 引言装饰器模式简介定义与用途实现方式 使用场景优势与劣势装饰器模式在Spring中的应用画图示例代码地址 引言 在日常生活中,我们常常对基本事物添加额外的装饰以增强其功能或美观。例如,给手机加一个保护壳来提升其防护能力,或者在…

阿里云磁盘在线扩容

我们从阿里云的控制面板中给硬盘扩容后结果发现我们的磁盘空间并没有改变 注意:本次操作是针对CentOS 7的 #使用df -h并没有发现我们的磁盘空间增加 #使用fdisk -l发现确实还有部分空间 运行df -h命令查看云盘分区大小。 以下示例返回分区&#xf…

Java的三种代理模式实现

代理模式的定义: Provide a surrogate or placeholder for another object to control access to it.(为其他对象提供一种代理以控制对这个对象的访问。) 简单说,就是设置一个中间代理来控制访问原目标对象,达到增强原…

复杂gRPC之go调用go

1. 复杂的gRPC调用 我们使用了一个较为复杂的proto文件,这个文件的功能主要是用来定位的,详细内容可以看代码中的注解 syntax "proto3"; //指定生成的所属的package,方便调用 option go_package "./"; package route…

渗透测试——七、网站漏洞——命令注入和跨站请求伪造(CSRF)

渗透测试 一、命令注入二、跨站请求伪造(CSRF)三、命令注入页面之注人测试四、CSRF页面之请求伪造测试 一、命令注入 命令注入(命令执行) 漏洞是指在网页代码中有时需要调用一些执行系统命令的函数例如 system()、exec()、shell_exec()、eval()、passthru(),代码未…

B站缓存视频M4S合并MP4(js + ffmpeg )

文章目录 B站缓存视频转MP4(js ffmpeg )1、说明 2、ffmpeg2.1 下载地址2.2 配置环境变量2.3 测试2.4 转换MP4命令 3、处理程序 B站缓存视频转MP4(js ffmpeg ) 注意:这样的方式只用于个人之间不同设备的离线观看。请…

链式二叉树的创建及遍历(数据结构实训)

题目: 链式二叉树的创建及遍历 描述: 树的遍历有先序遍历、中序遍历和后序遍历。先序遍历的操作定义是先访问根结点,然后访问左子树,最后访问右子树。中序遍历的操作定义是先访问左子树,然后访问根,最后访问…

阿里云实时数据仓库HologresFlink

1. 实时数仓Hologres特点 专注实时场景:数据实时写入、实时更新,写入即可见,与Flink原生集成,支持高吞吐、低延时、有模型的实时数仓开发,满足业务洞察实时性需求。亚秒级交互式分析:支持海量数据亚秒级交…

单片机学习13——串口通信

单片机的通信功能: 实现单片机和单片机的信息交换,实现单片机和计算机的信息交换。 计算机通信是指计算机与外部设备或计算机与计算机之间的信息交换。 通信有并行通信和串行通信两种方式。 在多微机系统以及现在测控系统中信息的交换多采用串行通信方…

游戏盾的防御原理以及为什么程序类型更适合接入游戏盾。

游戏盾是一种专门用于游戏服务器的安全防护服务,旨在抵御各种网络攻击。它的原理主要包括以下几个方面: 流量清洗和过滤:游戏盾会对进入游戏服务器的流量进行实时监测、分析和过滤。它通过识别恶意流量和攻击流量,过滤掉其中的攻击…

SVM原理理解

目录 概念推导: 共识:距离两个点集距离最大的分类直线的泛化能力更好,更能适应复杂数据。 怎么能让margin最大? 最大化margin即: 拉格朗日乘子法: 为什么公式中出现求和符号? SVM模型: 小结&#…

css弹窗动画效果,示例弹窗从底部弹出

从底部弹出来,有过渡动画效果 用max-height可以自适应内容的高度,当内容会超过最大高度时可以在弹窗里加个scroll-view 弹窗不能用v-if来隐藏,不然transition没效果,transition只能对已有dom元素起效果,所以用透明和v…

55.MQ高级特性

目录 一、RabbitMQ部署指南。 1)单机部署。 1.1.下载镜像 1.2.安装MQ 2)安装DelayExchange插件。 2.1.下载插件 2.2.上传插件 2.3.安装插件 2.4.使用插件。 3)集群部署。 3.1.集群分类 3.2.获取cookie 3.3.准备集群配置 3.4.启…

【开源】基于JAVA语言的农家乐订餐系统

项目编号: S 043 ,文末获取源码。 \color{red}{项目编号:S043,文末获取源码。} 项目编号:S043,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户2.2 管理员 三、系统展示四、核…

Day15——File类与IO流

1.java.io.File类的使用 1.1 File类的理解 File 类及本章下的各种流,都定义在 java.io 包下。一个 File 对象代表硬盘或网络中可能存在的一个文件或者文件目录(俗称文件夹),与平台无关。(体会万事万物皆对象&#xf…

Qt 输入一组数,排序后用柱状图显示

Qt柱状图&#xff0c;需要使用到QChart模块&#xff0c;因此需要在安装Qt时勾选上QChart模块。然后在工程.pro文件中加上 QT charts 参考代码&#xff1a; //MainWindow.h #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include <QPushButton…

Global IIIumination(GI)全局光照原理(一)3D空间全局光照

文章目录 一、Global IIIumination&#xff08;GI&#xff09;全局光照基本概念二、主流的全局光照方法&#xff1a;三、Reflective shadow maps&#xff08;RSM&#xff09;反射阴影贴图 全局光照四、Light Propagation Volumes (LPV)光线传播体积 全局光照1.第一步&#xff0…

Apache Flink(四):Flink 其他实时计算框架对比

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 根据前文描述我们知道Flink主要处…

Qt::UniqueConnection和lambda一块用无效

如果槽函数是lambda。 那么用了Qt::UniqueConnection也会出现槽函数被多次调用的问题。 原因&#xff1a; 参考官方文档&#xff1a; QObject Class | Qt Core 5.15.16https://doc.qt.io/qt-5/qobject.html#connect

UE Websocket笔记

参考链接 [UE4 C入门到进阶]12.Websocket网络通信 - 哔哩哔哩 包含怎么用Nodejs 写测试服务器 UE4_使用WebSocket和Json&#xff08;上&#xff09; - 知乎 包含Python写测试服务器 UE4_使用WebSocket和Json&#xff08;下&#xff09; - 知乎 示例代码 xxx.Build.cs"W…