torch中张量与数据类型的介绍

news2024/11/27 6:16:27

PyTorch张量的定义介绍

PyTorch最基本的操作对象是张量,它表示一个多维数组,类似NumPy的数组,但是前者可以在GPU上加速计算

初始化张量
t=torch.tensor([1,2])  # 创建一个张量
print(t)       
t.dtype       #打印t的数据类型为torch.int64

如果直接从Python数据创建张量,无须指定类型,PyTorch会自动推荐其类型,可通过张量的dtype属性查看其数据的类型,与numpy中的语法很相似

#创建float类型
t=torch.FloatTensor([1,2])
print(t)
print(t.dtype)

# 创建int类型的数据
t=torch.LongTensor([1,2])
print(t)
print(t.dtype)

也可以使用torch.from_numpy()方法从NumPy数组ndarray创建张量。

np_array=np.array([[1,2],[3,4]])
t_np=torch.from_numpy(np_array.reshape(1,4)).type(torch.FloatTensor)
print(t_np)
print(t_np.dtype)

首先我们能够看得到np_array依然是numpy的对象数据,依旧可以使用reshape等相关的方法
我们可以使用torch.from_numpy()方法在numpy数组上创建张量

t=torch.tensor([1,2],dtype=torch.int64)
print(t)
print(t.dtype)

t=torch.tensor([1,2],dtype=torch.float32)
print(t)
print(t.dtype)

这里用的最多的两种类型为int64,float32,这两种类型也尝尝被表示为torch.long和torch.float,其实这两种了类型对应了两种创建张量的方法,分别为
torch.FloatTensor()和torch.LongTensor()

#三种方法创建int64类型的数据
#1.
t1=torch.LongTensor([1,2])
print(t1)
print(t1.dtype)

#2.
t2=torch.tensor([1,2],dtype=torch.long)
print(t2)
print(t2.dtype)

#3.
t3=torch.tensor([1,2],dtype=torch.int64)
print(t3)
print(t3.dtype)

#三种方法创建float32类型的数据
#1.
t1=torch.FloatTensor([1,2])
print(t1)
print(t1.dtype)

#2.
t2=torch.tensor([1,2],dtype=torch.float)
print(t2)
print(t2.dtype)

#3.
t3=torch.tensor([1,2],dtype=torch.float32)
print(t3)
print(t3.dtype)

那他们确实是相等的吗,我们用代码来实现一下

print(torch.float==torch.float32)#判断是否相等,结果放回为true
print(torch.long==torch.int64)

PyTorch中张量的类型可以使用type()方法进行转换

t=torch.tensor([1,2],dtype=torch.float)
print(t.dtype)
#使用type()方法进行数据类型转化
t=t.type(torch.int64)
print(t.dtype)

torch框架中提供了两个快捷的实力转换方法

t=torch.tensor([1,2],dtype=torch.float)
print(t.dtype)
t=t.long()#使用long()将float32数据转换类int64的数据类型
print(t.dtype)
t=t.float()
print(t.dtype)

创建随机值张量

t=torch.rand(2,3)#创建一个2列3行的0-1均匀分布的随机数
print(t)

t=torch.randn(2,3)#创建一个2列3行的标准正态分布随机数
print(t)

t=torch.zeros(3,3)#创建一个3行3列的全是0的张量
print(t)

t=torch.ones(3,2)#创建一个3行2列的全是1的张量
print(t)

t=torch.ones(3,3,3)#创建三维的数组,就只需要传入3个数据进入就行
print(t)

x=torch.zeros_like(t)  #类似的方法还有torch.ones_like()
print(x)

x=torch.rand_like(t)
print(x)

张量的属性

t=torch.ones(2,3,dtype=torch.float32)
print(t.shape)
print(t.size())
print(t.size(0))
print(t.dtype)
print(t.device)

将张量移动到显存

张量可以在cpu上运行,也可以在GPU上运行,在GPU上的运算速度通常高于CPU,默认是在CPU上创建张量,如下可以使用tensor.to()方法将张量移动到GPU上

# 如果GPU可用,将张量移动到显存
if torch.cuda.is_available():
    t=t.to('cuda')
    print('true')
print(t.device)

#一般使用如下代码获取当前可用设备
device="cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
t=t.to(device)
print(t.device)

张量的运算

t1=torch.randn(2,3)
t2=torch.ones(2,3)

print(t1)
print(t2)
print(t1+3)   # t1中每一个元素都加3
print(t1+t2)  #t1与t2中每一个相同位置的元素相加 #不能将不同维数的张量相加,张量当中也有一个广播机制存在

张量加法运算
t3=torch.ones(1,3)
print(t3)
print(t1+t3)

t4=torch.randn(2,3)
t5=torch.ones(2,3)
print(t4)
print(t5)
t4.add_(t5)
print(t4)

如果一个运算方法后面加上下划线,代表就地改变原值,即上方中的t1.add_(t2)会直接将运算结果保存为t1,这样做可以节省内存,但是缺点就是会直接改变t1原值,在使用此方法的时候一定要谨慎使用。

矩阵乘法
print(t5.T)
print(t4.matmul(t5.T)) # matmul中没有下划线的方法
print(t4 @ (t5.T)) #都是对t1和t2的转置进行矩阵乘法

tensor.item
#输出一个python浮点数
print(t3)
print(t3.item())
#首先将张量中的所有的元素求和,得到只有一个元素的张量,然后使用tensor.item()方法将其转换为标量
#这种转换在我们希望打印模型正确率和损失值的时候很常见

与numpy数据类型的转换

a=np.random.randn(2,3)  #创建一个形状为(2,3)的ndarray对象
print("ndarray为:\n",a)                
t=torch.from_numpy(a)   #使用torch.from_numpy创建一个ndarray创建一个张量
print("tensor为:\n",t)       
print("numpy所对应的ndarray为:\n",t.numpy())         #使用torch.numpy()方法获得张量对应的ndarray

张量的变形

tensor.size()和tensor.shape属性可以返回张量的形状,方需要改变张量的形状时,可以通过tensor.view()方法,这个方法相当于NumPy中的reshape方法,用于改变张量的形状,但是在转换的过程中,一定要确保元素数量一致。

t=torch.randn(4,6,dtype=torch.float32)
print(t)
print(t.shape)

t1=t.view(3,8)
print(t1)
print(t1.shape)

t1=t.view(-1,1)
print(t1)

如果想展成横着的一维,就默认让view(1,-1)中设置1

t1=t.view(1,-1)
print(t1)

# 也可以使用view增加维度,当然元素个数是不变的
t1=t.view(1,4,6)
print(t1)
print(t1.shape)

对于维度长度为1的张量,可以使用torch.squeeze()方法去掉长度为1的维度,相应的也有一个增加长度为1维度的方法,即torch.unsqueeze()

print(t1.shape)
print(t1)
t2=torch.squeeze(t1)
print(t2.shape)
print(t2)
t3=torch.unsqueeze(t2,0)
print(t3.shape)
print(t3)

张量的自动微分

在PyTorch中,张量有一个requires_grad属性可以在创建张量时指定此属性为True,如果requires_grad属性被设置为True,PyTorch将开始跟踪对此张量的所有计算,完成计算,可以对计算记过调用backward()方法,PyTorch将自动计算所有梯度。该张量的梯度将累加到张量的grad属性中,张量的grad_fn属性则指向运算生成此张量的方法。
张量的requires_grad属性用来明确是否跟踪张量的梯度,grad属性表示计算得到的梯度,grad_fn属性表示运算得到生成此张量的方法。

t=torch.ones(2,2,requires_grad=True)  # 这里将requires_grad设为True
print(t)
print(t.requires_grad)  # 输出是否跟踪计算张量梯度,输出True
print(t.grad)  #输出tensor.grad输出张量的梯度,输出为None,表示目前t没有梯度

接下来进行张量运算,得到y

y=t+5
print(y)
print(y.grad_fn)
print(y.grad)

# 进行其他运算
z=y*2
out=z.mean()
print(out)

上面的代码中,首先创建了张量,并指定requires_grad属性为True,目前其grad和grad_fn属性均为空。然后经过加法、乘法和取均运算,我们得到了out这个最终结果,注意,现在out只有单个元素,它是一个标量。下面在out上执行自动微分运算,并且输出t的梯度(d(out)/d(x))

out.backward()
print(t.grad)

print(t.requires_grad)
print((t+2).requires_grad)

# 将代码块包装在with torch.no_grad():上下文中
with torch.no_grad():
    print((t+2).requires_grad) # PyTorch没有继续跟踪此张量的运算

也可以使用tensor.detach()方法获得具有相同内容但是不需要跟踪运算的新张量,可以认为是获得张量的值

print(out.requires_grad)
s=out.detach()  # 获得out的值,也可以使用out.data()方法
print(s.requires_grad)

可以使用requires.grad_()方法就地改变张量的这个属性,当我们希望模型的参数不在随着训练变化时,我们可以使用此方法

print(t.requires_grad)
t.requires_grad_(False)
print(t.requires_grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1318734.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue 指定class区域增加水印显示(人员姓名+时间)

效果 代码,存放位置 /utils/waterMark.js //waterMark.js文件let waterMark {}let setWaterMark (str,str1) > {let id 1.23452384164.123412416;if (document.getElementById(id) ! null) {//ui-table是table上的一个样式,一般水印显示在这个tab…

visual stdio code运行js没有输出

visual code运行js没有输出 先Debug file 然后右键直接run code就会输出了 插件的安装 visual stdio code插件安装 c qt wordle游戏实现

RK3568平台(网络篇)添加网络交换芯片RTL8306M

一.硬件原理图 分析: 该交换芯片支持I2C、SPI、mdio通信,但是看ast1520的uboot代码采用的是mdio去通信phy芯片的,所以暂时也先采用mdio的方式,需要配置相应的引脚才可以配置成mdio通信模式,具体的配置硬件工程师解决。…

代码随想录算法训练营Day4 | 24.两两交换链表中的节点、19.删除链表的倒数第 N 个节点、面试题. 链表相交、142.环形链表II

LeetCode 24 两两交换链表中的节点 本题要注意的条件: 遍历终止条件改变引用指向的时候,需要保存一些节点记录 为了更好的操作链表,我定义了一个虚拟的头节点 dummyHead 指向链表。如下图所示 既然要交换链表中的节点,那么肯定…

Ribbon使用

Ribbon :处理客户端负载均衡和容错的解决方案 配置Ribbon的负载均衡 Rule接口: 定义客户端负载均衡的规则 RandomRule :随机选择RoundRobinRuleZoneAvoidanceRule 配置ribbon的负载均衡策略 在配置文件中配置 user-center:ribbon:NFLoadBalancerRul…

网络安全项目实战(六)--报文检测

11. NTP应用协议报文解析 目标 了解NTP协议了解NTP包基本捕获方式了解NTP协议探测(解析)方法(简单方法) 11.1. 使用ntpdate同步网络时间 安装 $ sudo apt-get install ntpdate对时服务 查看时间 $ date #date可以查看当前系…

Jupyter Notebook: 交互式数据科学和编程工具

Jupyter Notebook: 功能强大的交互式编程和数据科学工具 简介 Jupyter Notebook是一个开源的Web应用程序,广泛用于数据分析、科学计算、可视化以及机器学习等领域。它允许创建和共享包含实时代码、方程式、可视化和解释性文本的文档。总而言之,我认为它…

PLC、RS485、变频器通讯接线图详解

plc与变频器两者是一种包含与被包含的关系,PLC与变频器都可以完成一些特定的指令,用来控制电机马达,PLC是一种程序输入执行硬件,变频器则是其中之一。 但是PLC的涵盖范围又比变频器大,还可以用来控制更多的东西&#x…

excel可视化看板【动态关联公司、部门、人员、及时间】

昨天网友花钱定制了一个可视化报表,花了一整天时间,做了这份酷炫的可视化报表,右边按钮控件可以动态关联可视化图表 做这种这重要是数据的统计,只要能统计到,剩下的只是如何展示,慢慢的调整,美…

【复杂网络分析与可视化】——通过CSV文件导入Gephi进行社交网络可视化

目录 一、Gephi介绍 二、导入CSV文件构建网络 三、图片输出 一、Gephi介绍 Gephi具有强大的网络分析功能,可以进行各种网络度量,如度中心性、接近中心性、介数中心性等。它还支持社区检测算法,可以帮助用户发现网络中的群组和社区结构。此…

计算机网络 网络层上 | IP数据报,IP地址,ICMP,ARP等

文章目录 1 网络层的两个层面2 网络协议IP2.1 虚拟互联网络2.2 IP地址2.2.1 固定分类编址方式2.2.2 无分类编制CIDR2.2.3 MAC地址和IP地址区别 2.3 地址解析协议ARP2.3.1 解析过程 2.4 IP数据报格式 3 IP层转发分组流程4 国际控制报文协议ICMP4.1 ICMP格式结构4.2 分类4.2.1 差…

快速排序(一)

目录 快速排序(hoare版本) 初级实现 问题改进 中级实现 时空复杂度 高级实现 三数取中 快速排序(hoare版本) 历史背景:快速排序是Hoare于1962年提出的一种基于二叉树思想的交换排序方法 基本思想&#xff1a…

强制性产品认证车辆一致性证书二维码解析

目录 说明 界面 下载 强制性产品认证车辆一致性证书二维码解析 说明 二维码扫描出的信息为: qW0qS6aFjU50pMOqis0WupBnM21DnMxy0dGFN/2Mc9gENXhKh0qEBxFgfXSLoR qW0qS6aFjU50pMOqis0WupBnM21DnMxy0dGFN/2Mc9gENXhKh0qEBxFgfXSLoR 解析后的信息为&#xff1a…

20来岁,大专毕业,学软件测试可行吗?

转行软件测试找不到工作! 转行软件测试找不到工作! 转行软件测试找不到工作! 重要的事情说三遍!千万别听培训班咨询老师给你画饼 ;我就是某某软件测试培训班出来的,大专,其他专业毕业&#x…

初级数据结构(六)——堆

文中代码源文件已上传&#xff1a;数据结构源码 <-上一篇 初级数据结构&#xff08;五&#xff09;——树和二叉树的概念 | NULL 下一篇-> 1、堆的特性 1.1、定义 堆结构属于完全二叉树的范畴&#xff0c;除了满足完全二叉树的限制之外&#xff0c;还满…

小程序商城活动页面怎么生成二维码

背景 小程序商城某些页面需要做成活动推广页&#xff0c;或需要某一个页面做成二维码进行推广。比如某些非公开的商品做成一个活动&#xff0c;发送指定部分用户&#xff0c;这个活动页面可以做成二维码。 前提 小程序已经上线 步骤 登录微信小程序官网&#xff0c;选择工具…

数据库交付运维高级工程师-腾讯云TDSQL

数据库交付运维高级工程师-腾讯云TDSQL上机指导&#xff0c;付费指导&#xff0c;暂定99

Element 介绍

Element 介绍 Vue 快速入门 Vue 常见组件 表格 分页组件 其他自己去看吧 链接: 其他组件

【idea】解决sprintboot项目创建遇到的问题

目录 一、报错Plugin ‘org.springframework.boot:spring-boot-maven-plugin:‘ not found 二、报错java: 错误: 无效的源发行版&#xff1a;17 三、java: 无法访问org.springframework.web.bind.annotation.CrossOrigin 四、整合mybatis的时候&#xff0c;报java.lang.Ill…

文件函数的简单介绍

1. 向文件中写入一个字符 fputc int_Ch指的是输入文件中的字符 &#xff08;int&#xff09;的原因是以ascll码值的型式输入 #include <stdio.h> #include <errno.h> #include <string.h> int main() { FILE* pf fopen("test.txt","…