PyTorch 基础学习(2)- 张量 Tensors

news2024/9/22 7:31:33

PyTorch张量简介

张量是数学和计算机科学中的一个基本概念,用于表示多维数据,是AI世界中一切事物的表示和抽象。可以将张量视为一个扩展了标量、向量和矩阵的通用数据结构。以下是对张量的详细解释:

张量的定义

  1. 标量(0阶张量):标量是一个单一的数值,例如 3 或 -1.5。

  2. 向量(1阶张量):向量是一维数组,可以表示为一列或一行数值,例如 ([1, 2, 3])。

  3. 矩阵(2阶张量):矩阵是一个二维数组,包含行和列,例如:

    [
\begin{bmatrix}
1 & 2 & 3 \
4 & 5 & 6
\end{bmatrix}
]

  4. 高阶张量:高阶张量是更高维度的数组。例如,三阶张量可以被视为一个矩阵的集合,每个矩阵位于一个"层"上,如下所示:

    [
\begin{bmatrix}
\begin{bmatrix}
1 & 2 \
3 & 4
\end{bmatrix},
\begin{bmatrix}
5 & 6 \
7 & 8
\end{bmatrix}
\end{bmatrix}
]

张量的特点

  • 维度(阶):张量的维度指的是数据的方向数。例如,一个三阶张量可以用一个3D数组表示,其中每个维度称为一个轴。

  • 形状:张量的形状由沿每个轴的元素数量定义。例如,形状为(3, 4, 5)的张量意味着它有3个“层”,每个“层”包含4行5列的数据。

  • 数据类型:张量可以存储不同类型的数据,如整数、浮点数等。

张量在机器学习中的应用

在深度学习和机器学习中,张量是处理和操作数据的基础。以下是一些应用示例:

  1. 图像数据:图像通常表示为3阶张量,其中维度分别表示高度、宽度和颜色通道。

  2. 时间序列数据:时间序列数据可以表示为2阶张量,其中一个维度表示时间步,另一个维度表示特征。

  3. 神经网络权重:在神经网络中,权重和偏置通常存储为张量,以便进行高效的计算。

将图像转换成张量

from PIL import Image
from torchvision import transforms

# 第一步:加载图像
image_path = 'path_to_your_image.jpg'  # 替换为你的图像路径
image = Image.open(image_path)

# 第二步:定义转换
transform = transforms.Compose([
    transforms.ToTensor()  # 将图像转换为张量,并将像素值归一化到[0, 1]之间
])

# 第三步:应用转换
image_tensor = transform(image)

# 查看结果
print(f"张量形状: {image_tensor.shape}")  # 输出张量形状,通常为 (C, H, W)
print(f"数据类型: {image_tensor.dtype}")  # 输出张量的数据类型

输出:

张量形状: torch.Size([3, 1024, 1024])  //一幅三通道(RGB)的 1024x1024 像素图像。
数据类型: torch.float32 //张量的数据类型为32位浮点数

张量计算

PyTorch等深度学习框架利用张量作为基本数据结构,因为它们可以高效地在GPU上进行并行计算。张量支持多种数学运算,如加法、乘法、转置、索引和切片等,这使得它们非常适合处理复杂的数据计算任务。

总结来说,张量是一个灵活且高效的数据结构,广泛用于科学计算、图像处理、自然语言处理等领域。通过使用张量,我们可以在计算机中有效地表示和操作高维数据。

张量操作

  1. 张量检测

    • torch.is_tensor(obj):判断对象obj是否是一个PyTorch张量,返回TrueFalse
  2. 元素个数

    • torch.numel(input):返回张量input中的元素个数。
    import torch
    
    a = torch.randn(1, 2, 3, 4, 5)
    print(torch.numel(a))  # 输出: 120
    
  3. 设置默认张量类型

    • torch.set_default_tensor_type(t):设置默认的张量类型。
  4. 打印选项

    • torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None):设置张量打印的选项,如精度、行宽等。

张量创建

  1. 单位矩阵

    • torch.eye(n, m=None, out=None):返回一个2D单位矩阵。
    print(torch.eye(3))
    
  2. 从Numpy数组创建张量

    • torch.from_numpy(ndarray):将Numpy数组转换为PyTorch张量。
    import numpy as np
    
    a = np.array([1, 2, 3])
    t = torch.from_numpy(a)
    
  3. 线性间隔张量

    • torch.linspace(start, end, steps=100, out=None):创建一个1D张量,包含从startend的等间隔数值。
    print(torch.linspace(3, 10, steps=5))
    
  4. 对数间隔张量

    • torch.logspace(start, end, steps=100, out=None):在对数尺度上创建等间隔数值的1D张量。
    print(torch.logspace(start=0.1, end=1.0, steps=5))
    
  5. 全1张量

    • torch.ones(*sizes, out=None):创建一个形状为sizes的全1张量。
    print(torch.ones(2, 3))
    
  6. 随机张量

    • torch.rand(*sizes, out=None):创建一个形状为sizes的张量,包含均匀分布的随机数。
    • torch.randn(*sizes, out=None):创建一个形状为sizes的张量,包含标准正态分布的随机数。
  7. 排列张量

    • torch.randperm(n, out=None):生成从0到n-1的随机排列。
    print(torch.randperm(4))
    

索引与切片

  1. 张量连接

    • torch.cat(inputs, dimension=0):在给定维度上连接张量序列。
    x = torch.randn(2, 3)
    print(torch.cat((x, x, x), 0))
    
  2. 张量分块

    • torch.chunk(tensor, chunks, dim=0):将张量分成指定数量的块。
  3. 选择性索引

    • torch.index_select(input, dim, index, out=None):在指定维度上选择特定的索引项。
    x = torch.randn(3, 4)
    indices = torch.LongTensor([0, 2])
    print(torch.index_select(x, 0, indices))
    
  4. 非零元素索引

    • torch.nonzero(input, out=None):返回一个张量,包含非零元素的索引。
    print(torch.nonzero(torch.Tensor([1, 1, 0, 0, 1])))
    
  5. 挤压与扩展

    • torch.squeeze(input, dim=None, out=None):移除形状中的1。
    • torch.unsqueeze(input, dim, out=None):在指定位置插入维度1。
    x = torch.Tensor([1, 2, 3, 4])
    print(torch.unsqueeze(x, 0))
    

随机抽样

  1. 设定随机数种子

    • torch.manual_seed(seed):设定随机数生成的种子。
  2. 伯努利分布

    • torch.bernoulli(input, out=None):从伯努利分布中抽取二元随机数。
    a = torch.Tensor(3, 3).uniform_(0, 1)
    print(torch.bernoulli(a))
    
  3. 多项分布抽样

    • torch.multinomial(input, num_samples, replacement=False, out=None):从多项分布中抽取样本。
  4. 正态分布

    • torch.normal(means, std, out=None):从正态分布中抽取样本。
    print(torch.normal(means=torch.arange(1, 6)))
    

序列化

  1. 保存对象

    • torch.save(obj, f):保存对象到文件。
  2. 加载对象

    • torch.load(f, map_location=None):从文件加载对象。

并行化

  1. 获取和设置线程数
    • torch.get_num_threads():获取用于并行化CPU操作的线程数。
    • torch.set_num_threads(int):设置用于并行化CPU操作的线程数。

应用实例:线性回归模型

下面是一个简单的线性回归模型示例,展示了如何使用上述PyTorch操作实现基本的机器学习任务。

import torch
import torch.nn as nn
import torch.optim as optim

# 生成随机数据
x = torch.rand(100, 1) * 10  # 100个样本,每个样本有1个特征
y = 2 * x + 3 + torch.randn(100, 1)  # y = 2*x + 3 + 噪声

# 定义线性模型
model = nn.Linear(1, 1)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
for epoch in range(1000):
    # 前向传播
    outputs = model(x)
    loss = criterion(outputs, y)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/1000], Loss: {loss.item():.4f}')

# 打印模型参数
[w, b] = model.parameters()
print(f'Weight: {w.item()}, Bias: {b.item()}')

输出:

......
Epoch [800/1000], Loss: 1.1024
Epoch [900/1000], Loss: 1.1014
Epoch [1000/1000], Loss: 1.1011
Weight: 1.9926187992095947, Bias: 3.0762107372283936

总结

本教程介绍了PyTorch中的基本张量操作,包括创建、索引、随机采样、序列化和并行化等。通过实例演示了如何使用这些操作构建一个简单的线性回归模型。掌握这些基本操作能够帮助开发者更高效地进行深度学习模型的构建与优化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2043381.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Assembly(七)实验环境搭建

本篇文章将讲解在win11环境下的王爽老师的汇编语言的环境搭建 首先凑齐这些文件: 随后安装好Dosbox,去官网下载就好 打开箭头所指文件 找到文件最后部分 [autoexec] # Lines in this section will be run at startup. # You can put your MOUNT lines here. MOUNT C D:\Debug …

快速搭建Vue_cli以及ElementUI简单项目学生管理系统雏形

为了帮助大家快速搭建Vue_cli脚手架还有ElementUI的简单项目,今天我给大家提供方法. 因为这个搭建这个项目步骤繁多,容易忘记,所以给大家提供这个资料希望可以帮助到你们. 废话不多说开始搭建项目: 搭建Vue_cli项目 首先点开HBuilder左上角的文件点击新建,点击项目,选择vue项…

2024年人工智能固态硬盘采购容量预计超过45 EB

根据TrendForce发布的最新市场报告,人工智能(AI)服务器客户在过去两个季度显著增加了对企业级固态硬盘(SSD)的订单。为了满足AI应用中不断增长的SSD需求,上游供应商正在加速工艺升级,并计划在20…

智慧交通物联网应用,5G路由器赋能高速道路监控数据传输

高速道路为了保障交通的高速、安全运行,沿线部署了控制设施、监视设施、情报设施、传输设施、显示设施及控制中心等。在传统的高速管理中,这些设施的传输设施多采用光纤线缆进行数据传输,但高速道路覆盖范围广、距离远,布线与施工…

韩顺平 集合

集合 一、体系结构图二、Collection2.1 Collection 接口和常用方法2.2 集合遍历2.2.1 迭代器2.2.2 增强for循环 三、List接口及其常用方法3.1 三种遍历方式3.2 ArrayList3.3 LinkedList 四 MAP4.1 hashmap 一、体系结构图 集合主要是两组 单列和双列集合 Collection接口有两个重…

第十五章:高级调度

本章内容包括: 使用节点污点和pod容忍度组织pod调度到特定节点将节点亲和性规则作为节点选择器的一种替代使用节点亲和性进行多个pod的共同调度使用节点非亲和性来分离多个pod Kubernetes允许你去影响pod被调度到哪个节点。起初,只能通过在pod规范⾥指定…

Linux安装Nginx后,无法解析Windows主机Hosts文件

问题展示: 配置好Linux的Nginx配置后,Windows同样配置好host,而通过浏览器只能用IP地址成功访问,而域名则不行 解决方法: 点击Windows图标,搜索记事本,选择以管理员身份运行,编辑…

php-xlswriter实现数据导出excel单元格合并,内容从指定行开始写

最终效果图: 代码: public function export_data() {$list $this->get_list_organ();$content [];$content[] []; // 第2行不设置内容,设置为空foreach ($list as $key > $value) {$content[] [$value[organ_name], $value[clas…

防火墙技术与地址转换

文章目录 前言一、四种区域二、实验拓扑图基础配置防火墙配置测试结果 前言 防火墙是计算机网络中的一种安全设备或软件功能,旨在监控和控制进出网络的网络流量。其核心目的是保护内部网络免受外部攻击或不必要的访问。防火墙通过设定一系列安全规则,允…

【iOS】UITableViewCell的重用问题解决方法

我自己在实验中对cell的重用总结如下: 非自定义Cell和非自定义cell的复用情况一样: 第一次加载创建tableView的时候,是屏幕上最多也显示几行cell就先创建几个cell,此时复用池里什么都没有开始下滑tableView,刚开始滑…

可视化编程-七巧低代码入门02

1.1.什么是可视化编程 非可视化编程是一种直接在集成开发环境中(IDE)编写代码的编程方式,这种编程方式要求开发人员具备深入的编程知识,开发效率相对较低,代码维护难度较大,容易出现错误,也需要…

最新的APS高级计划排程系统推动的MRP供应链计划是什么?

在当下“内卷”的市场环境下,制造业的订单需求从过去大批量标准品生产已经演变成小批量、多订单的非标订单生产,这对制造业的供应链提出了更高的要求。为了应对市场实现产销平衡,中大型的企业都开始重视供应链的建设工作,以应对企…

数字签名和CA数字证书的核心原理和作用

B站讲解视频,讲述HTTPS CA认证的整个行程过程与原理 https://www.bilibili.com/video/BV1mj421d7VE

[Qt][Qt 文件]详细讲解

目录 1.输入输出设备类2.文件读写类3.文件和目录信息类 1.输入输出设备类 在Qt中,⽂件读写的类为QFile,其⽗类为QFileDevice QFileDevice提供了⽂件交互操作的底层功能QFileDevice的⽗类是QIODevice,其⽗类为QObject QIODevice是Qt中所有I/O…

【数学建模备赛】Ep05:斯皮尔曼spearman相关系数

文章目录 一、前言🚀🚀🚀二、斯皮尔曼spearman相关系数:☀️☀️☀️1. 回顾皮尔逊相关系数2. 斯皮尔曼spearman相关系数3. 斯皮尔曼相关系数公式4. 另外一种斯皮尔曼相关系数定义5. matlab的用法5. matlab的用法 三、对斯皮尔曼相…

立仪光谱共焦传感器行业应用|透明胶水高度测量

01|检测需求:透明胶水高度测量 02|检测方式 根据客户要求及观察我们使用立仪科技D40A26XL镜头搭配E系列控制器进行测量 03|光谱共焦测量结果 经过测量可以得出胶水的高度为1076.406μm 04|光谱共焦侧头 D40A26XL侧头…

uniapp接口请求this.$request

代码示例: createPhoto(url) {this.$request({url: /emp/gallery-photo/create,//后端接口method: post,//请求方法header: {//请求头tenant-id: 1,},data: {//请求参数galleryId: this.albumId,empUserId: this.empUserId,"url": url,}}).then((res) &…

JVM -垃圾回收器

本人在这篇文章中讲解了垃圾回收机制,这为前置知识 美团一面面经:Threadlocal(线程局部变量的原理)->内存泄漏问题->垃圾回收机制_threadlocal回收-CSDN博客 首先对前置知识漏洞做一个补充:ja…

时序电路实验-节拍脉冲发生器

二、实验目的 掌握节拍脉冲发生器的设计方法,理解节拍脉冲发生器的工作原理。 三、实验环境 PC计算机 四、实验内容 单步/连续节拍发生电路设计 增加两个2-1多路选择器,可将图3.3所示电路修改为图3.5所示电路。 图3.5单步/连续节拍脉冲发生器工作波…

如何进行长截图的两种方法

前言 本文主要讲2种截图方式,分别是谷歌和QQ。 谷歌分为Web端 和 移动端,选一种即可。 第一种:谷歌浏览器控制台自带的 1.先把控制台语言更改为中文,方便查看 ①.按F12,点击设置面板 ②.修改语言为中文并关闭 ③.点击…