Pytorch索引、切片、连接

news2024/9/24 17:12:32

文章目录

    • 1.torch.cat()
    • 2.torch.column_stack()
    • 3.torch.gather()
    • 4.torch.hstack()
    • 5.torch.vstack()
    • 6.torch.index_select()
    • 7.torch.masked_select()
    • 8.torch.reshape
    • 9.torch.stack()
    • 10.torch.where()
    • 11.torch.tile()
    • 12.torch.take()



1.torch.cat()

  torch.cat() 是 PyTorch 库中的一个函数,用于沿指定维度连接张量。它接受一系列张量作为输入,并沿指定的维度进行连接。

torch.cat(tensors, dim=0, out=None)
"""
tensors:要连接的张量序列(例如,列表、元组)。
dim(可选):要沿其进行连接的维度。它指定了轴或维度编号。默认情况下,它设置为0,表示沿第一个维度进行连接。
out(可选):存储结果的输出张量。如果指定了 out,结果将存储在此张量中。如果未提供 out,则会创建一个新的张量来存储结果。
"""
import torch

# 创建两个张量
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

# 沿着维度0连接两个张量
result = torch.cat((tensor1, tensor2), dim=0)

print(result)

2.torch.column_stack()

 torch.column_stack() 是 PyTorch 中的一个函数,用于按列堆叠张量来创建一个新的张量。它将输入张量沿着列的方向进行堆叠,并返回一个新的张量。

torch.column_stack(tensors)
"""
tensors:要堆叠的张量序列。它可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torch

tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

result = torch.column_stack((tensor1, tensor2))

print(result)

3.torch.gather()

torch.gather() 是 PyTorch 中的一个函数,用于根据给定的索引从输入张量中收集元素。它允许你按照指定的索引从输入张量中选择元素,并将它们组合成一个新的张量。

torch.gather(input, dim, index, out=None, sparse_grad=False)
"""
input:输入张量,从中收集元素。
dim:指定索引的维度。
index:包含要收集元素的索引的张量。
out(可选):输出张量,用于存储结果。
sparse_grad(可选):指定是否启用稀疏梯度。默认为 False
"""

在这里插入图片描述

import torch

# 输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 索引张量
index = torch.tensor([[0, 0], [1, 0]])

# 根据索引从输入张量中收集元素
result = torch.gather(input, 1, index)

print(result)
import torch

# 输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 索引张量
index = torch.tensor([[0, 0], [1, 0]])

# 根据索引从输入张量中收集元素
result = torch.gather(input, 0, index)

print(result)

4.torch.hstack()

  torch.hstack() 是 PyTorch 中的一个函数,用于沿着水平方向(列维度)堆叠张量来创建一个新的张量。它将输入张量沿着水平方向进行堆叠,并返回一个新的张量。

torch.hstack(tensors) -> Tensor
"""
tensors:要堆叠的张量序列。可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

result = torch.hstack((tensor1, tensor2))

print(result)
# tensor([[1, 2, 5, 6],
#        [3, 4, 7, 8]])

5.torch.vstack()

torch.vstack()是PyTorch中用于沿垂直方向(行维度)堆叠张量的函数。它将输入张量沿垂直方向进行堆叠,并返回一个新的张量。

torch.vstack(tensors) -> Tensor
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

result = torch.vstack((tensor1, tensor2))

print(result)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

6.torch.index_select()

torch.index_select() 是 PyTorch 中的一个函数,用于按索引从输入张量中选择元素并返回一个新的张量。

torch.index_select(input, dim, index, out=None) -> Tensor
"""
input:输入张量,从中选择元素。
dim:指定索引的维度。即要在 input 张量的哪个维度上进行索引。
index:指定要选择的索引的张量。它的形状可以与 input 张量的形状不同,但必须满足广播规则。
out(可选):输出张量,用于存储结果。如果提供了 out,则结果将存储在此张量中。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 索引张量
index = torch.tensor([0, 2])

# 根据索引从输入张量中选择元素
result = torch.index_select(input, 0, index)

print(result)
tensor([[1, 2, 3],
        [7, 8, 9]])

7.torch.masked_select()

torch.masked_select() 是 PyTorch 中的一个函数,用于根据给定的掩码从输入张量中选择元素并返回一个新的张量。

torch.masked_select(input, mask, out=None) -> Tensor
"""
input:输入张量,从中选择元素。
mask:掩码张量,用于指定要选择的元素。mask 张量的形状必须与 input 张量的形状相同,或者满足广播规则。
out(可选):输出张量,用于存储结果。如果提供了 out,则结果将存储在此张量中。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 掩码张量
mask = torch.tensor([[True, False, True], [False, True, False], [True, False, True]])

# 根据掩码从输入张量中选择元素
result = torch.masked_select(input, mask)

print(result)
tensor([1, 3, 5, 7, 9])

8.torch.reshape

torch.reshape() 是 PyTorch 中的一个函数,用于改变张量的形状而不改变元素的数量。它返回一个具有新形状的新张量,其中的元素与原始张量相同。

torch.reshape(input, shape) -> Tensor
"""
input:输入张量,要改变形状的张量。
shape:指定的新形状。可以是一个整数元组或传递一个张量,其中包含新的形状。
torch.reshape() 函数将输入张量重新排列为指定的新形状。新的形状应该满足以下条件:

1. 新形状的元素数量与原始张量的元素数量相同。
2. 新形状中各维度的乘积与原始张量的元素数量相同。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 改变形状为 (3, 2)
result1 = torch.reshape(input, (3, 2))

# 改变形状为 (1, 6)
result2 = torch.reshape(input, (1, 6))

# 改变形状为 (6,)
result3 = torch.reshape(input, (6,))

print(result1)
print(result2)
print(result3)

9.torch.stack()

torch.stack() 是 PyTorch 中的一个函数,用于沿着新的维度对给定的张量序列进行堆叠操作。

torch.stack(tensors, dim=0, *, out=None) -> Tensor
"""
tensors:张量的序列,要进行堆叠操作的张量。
dim(可选):指定新的维度的位置。默认值为 0。
out(可选):输出张量。如果提供了输出张量,则将结果存储在该张量中。
"""
import torch

# 张量序列
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
tensor3 = torch.tensor([7, 8, 9])

# 在维度 0 上进行堆叠操作
result = torch.stack([tensor1, tensor2, tensor3], dim=0)

print(result)
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

10.torch.where()

torch.where() 是 PyTorch 中的一个函数,用于根据给定的条件从两个张量中选择元素。

torch.where(condition, x, y) -> Tensor
"""
condition:条件张量,一个布尔张量,用于指定元素选择的条件。
x:张量,与 condition 形状相同的张量,当对应位置的 condition 元素为 True 时,选择 x 中的对应元素。
y:张量,与 condition 形状相同的张量,当对应位置的 condition 元素为 False 时,选择 y 中的对应元素。
"""
import torch

# 条件张量
condition = torch.tensor([[True, False], [False, True]])

# 选择的张量 x
x = torch.tensor([[1, 2], [3, 4]])

# 选择的张量 y
y = torch.tensor([[5, 6], [7, 8]])

# 根据条件选择元素
result = torch.where(condition, x, y)

print(result)
#tensor([[1, 6],
#       [7, 4]])
import torch

# 输入张量
input = torch.tensor([1.5, 0.8, -1.2, 2.7, -3.5])

# 阈值
threshold = 0

# 根据阈值选择元素
result = torch.where(input > threshold, torch.tensor(1), torch.tensor(0))

print(result)#tensor([1, 1, 0, 1, 0])

11.torch.tile()

torch.tile() 是 PyTorch 中的一个函数,用于在指定维度上重复张量的元素。

torch.tile(input, reps) -> Tensor
"""
input:输入张量,要重复的张量。
reps:重复的次数,可以是一个整数或一个元组。
"""
import torch

# 输入张量
input = torch.tensor([1, 2, 3])

# 在维度 0 上重复 2 次
result = torch.tile(input, 2)

print(result)#tensor([1, 2, 3, 1, 2, 3])
import torch

# 输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 在维度 0 和维度 1 上重复
result = torch.tile(input, (2, 3))

print(result)
tensor([[1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4],
        [1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4]])

12.torch.take()

torch.take() 是 PyTorch 中的一个函数,用于在给定索引处提取张量的元素。

torch.take(input, indices) -> Tensor
"""
input:输入张量,要从中提取元素的张量。
indices:索引张量,包含要提取的元素的索引。它可以是一个一维整数张量或一个具有相同形状的张量。
"""
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 索引张量
indices = torch.tensor([1, 4, 7])

# 提取元素
result = torch.take(input, indices)

print(result)# tensor([2, 5, 8])
import torch

# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 索引张量
indices = torch.tensor([[0, 2], [1, 2]])

# 提取部分元素
result = torch.take(input, indices)

print(result)
tensor([[1, 3],
        [2, 3]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1706265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

联想凌拓 NetApp AFF C250 全闪存存储助力丰田合成打造数据新“引擎”

联想凌拓 NetApp AFF C250全闪存存储助力丰田合成打造数据新“引擎” 丰田合成(张家港)科技有限公司(以下简称“丰田合成”)于2003年12月成立,坐落在中国江苏省张家港市保税区中华路113号,是日本丰田合成株…

亚马逊自养号与机刷有何区别?

在亚马逊这一全球电商巨头中,买家评价的重要性如同指南针般引领着消费者的购买决策。在购买前,消费者们往往会驻足查看产品的评论,仔细比较不同产品的买家口碑,以确保自己的选择既明智又满意。因此,测评成为了各大电商…

空间转录组数据的意义

10X空间转录组Visium学习笔记(三)跑通Visium全流程记录 | 码农家园 (codenong.com) 这两个的区别是:一个是像素的位置信息,一个是阵列的位置信息

百度百科个人词条怎么这么难通过?

百度百科作为国内最具影响力的知识平台,个人词条的通过率却让很多人感到困惑。为什么我的个人词条总是难以通过?伯乐网络传媒给大家揭秘百度百科个人词条审核的难点,并提供相应的对策。 一、百度百科词条难以通过的原因分析 1. 内容不符合审…

小动物单通道麻醉机、多通道麻醉机

ZL-04A-5多通道小动物麻醉机采用英国进口的挥发罐体,国内组装而成,产品输出气体稳定。多通道小动物麻醉机无需氧气瓶,自带空气输出机,小动物麻醉机对氧气浓度有要求可以选配氧气输出机。 详情介绍: 产品特点&#xf…

巅峰对决:OpenAI与Google如何用大模型开创未来

2024年,人工智能领域正引领着一场波澜壮阔的全球技术革命。 5月14日,OpenAI揭开了其新一代多模态人工智能大模型GPT4系列的神秘面纱,其中GPT-4o不仅拥有流畅迷人的嗓音,还展现出幽默、机智和深刻的洞察力……紧接着,在…

【MySQL数据库】存储过程实战——图书借阅系统

图书借阅归还 借阅不用count判断,归还不用具体字段值判断 每次借阅或者归还只能操作1本 数据准备 -- 创建数据库 create database db_test3 CHARACTER SET utf8 COLLATE utf8_general_ci; -- 使用数据库 use db_test3; -- 创建图书信息表: create tabl…

C++容器之双端队列(std::deque)

目录 1 概述2 使用实例3 接口使用3.1 construct3.2 assigns3.3 iterators3.4 capacity3.5 rezize3.6 shrink_to_fit3.7 access3.8 assign3.9 push_back3.10 push_front3.11 pop_back3.12 pop_front3.13 insert3.14 erase3.15 swap3.16 clear3.17 emplace3.18 emplace_front3.19…

TCS工作原理

1、TCS的基本原理 TCS 的原理建立在驱动轮最优滑转率基础之上。理论研究证明,轮胎与路面之间的纵向附着特性决定汽车的加速和制动能力,轮胎滑动率与路面附着之间存在一定的关系,驱动轮的滑动率 λ \lambda λ 可以表示如下: λ…

SurfaceFinger layer创建过程

SurfaceFinger layer创建过程 引言 本篇博客重点分析app创建Surface时候,SurfaceFlinger是如何构建对应的Layer的主要工作有那些! 这里参考的Android源码是Android 13 aosp! app端创建Surface 其核心流程可以分为如下接部分: app使用w,h,fo…

使用nvm管理node多版本(安装、卸载nvm,配置环境变量,更换npm淘宝镜像)淘宝的镜像域名更换

最近 使用nvm 管理 node 的时候发现nvm install node版本号 总是失败。 nvm install 20.12.2Error retrieving "http://npm.taobao.org/mirrors/node/latest/SHASUMS256.txt": HTTP Status 404查看原因,因为淘宝的镜像域名更换,由于 npm.taob…

搜维尔科技:【系统集成案例】三面CAVE系统案例

用户名称:成都东软学院 主要产品:工业激光投影机、光学跟踪系统、主动立体眼镜、主动式立体眼镜发生器 在4米x9米的空间内,通过三通道立体成像,对立体模型进行数字化验证,辅助unity课程设计。 立体投影大屏方案采用的…

行波进位加法器和超前进位加法器比较

文章目录 1.行波进位加法器2.超前进位加法器 1.行波进位加法器 行波进位加法器就是将全加器串联起来,将低位的进位输出作为高位的进位输入。 由全加器公式可知: S A ⊕ B ⊕ C i n C o u t A B B C i n A C i n SA\oplus B\oplus C_{in}\\ C_{out}…

ssm教职工防疫打卡小程序-计算机毕业设计源码73760

摘 要 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,教职工防疫打卡小程序被用户普遍使用,为方便用…

QGIS:根据已知色带数组生成sld文件

1、将已知的色带数组转换为Excel文件(Test.xlsx) 格式如下: 2、Python生成QGIS colormap文件(.txt) import numpy as np import os import sys import pandas as pd#QGSI输入文件 #读取C中导出的Rainbow数组 fp_rain…

电商API接口||电商数据连接器:一键连接,效率加倍!

电商数据API接口: 一键连接,效率加倍! 打造智能数据生态,让数据流动更加高效 在数字化时代,数据已成为企业发展的核心驱动力。电商API数据采集接口,助力电商企业实现数据的高效管理和应用。 电商数据API…

【乐吾乐3D可视化组态编辑器】模型类型与属性

编辑器地址:3D可视化组态 - 乐吾乐Le5le 本章主要为您介绍模型的属性功能。 一个模型至少会包含一个节点(Node),从节点类型上可以分为转换节点(TransformNode)、网格(Mesh)、实例网…

【OpenGL手册-13】光源和颜色模型

文章目录 一、说明二、颜色综述三、灯光场景四、光源位置 一、说明 光源和颜色模型也是OpenGL的重要模型之一,我们将光源也看成是一个物体,这个物体特点是,不仅可以自己移动位置,而且要和其它物体颜色进行反射运算,从而…

应急响应-网页篡改-技术操作只指南

初步判断 网页篡改事件区别于其他安全事件地明显特点是:打开网页后会看到明显异常。 业务系统某部分出现异常字词 网页被篡改后,在业务系统某部分网页可能出现异常字词,例如,出现赌博、色情、某些违法APP推广内容等。2019年4月…

【一小时学会Charles抓包详细教程】初识Charles (1)

🚀 个人主页 极客小俊 ✍🏻 作者简介:程序猿、设计师、技术分享 🐋 希望大家多多支持, 我们一起学习和进步! 🏅 欢迎评论 ❤️点赞💬评论 📂收藏 📂加关注 Charles介绍 …