【youcans的深度学习 07】PyTorch入门教程:张量的基本操作 2

news2025/1/9 16:21:06

欢迎关注『youcans的深度学习』系列,持续更新中…
【youcans的深度学习 01】安装环境之 miniconda
【youcans的深度学习 02】PyTorch CPU版本安装与环境配置
【youcans的深度学习 03】PyTorch CPU版本安装与环境配置
【youcans的深度学习 04】PyTorch入门教程:基础知识
【youcans的深度学习 05】PyTorch入门教程:快速入门
【youcans的深度学习 06】PyTorch入门教程:张量的基本操作 1
【youcans的深度学习 07】PyTorch入门教程:张量的基本操作 2

【youcans的深度学习 07】PyTorch入门教程:张量的基本操作 2

    • 5. 改变Tensor的形状
      • 5.1 张量的阶、轴与形状
      • 5.2 张量的形状处理
        • 5.2.1 张量的维度交换(transpose)
        • 5.2.2 张量形状的重构(reshape)
        • 5.2.3 张量形状的重构(view)
        • 5.2.4 维度删除(squeeze)和插入(unsqueeze)
        • 5.2.5 重新排列张量的维度(permute)
    • 6. 张量的索引、切片和连接
      • 6.1 张量的索引与切片
        • 一维张量的索引
        • 二维张量的索引
        • 三维张量的索引
        • 索引函数 index_select
      • 6.2 张量的拆分
        • 拆分函数 split
        • 拆分函数 chunk
      • 6.3 张量的连接
        • 张量连接函数 cat
        • 张量连接函数 stack
    • 7. PyTorch张量与 Numpy/列表/元组的转换
      • 7.1 Numpy/列表/元组转换为张量
      • 7.2 张量转换为Numpy/列表/元组
      • 7.3 提取张量的数值
    • 8. 序列化张量


5. 改变Tensor的形状

本节深入讨论张量的三个基本属性:阶(Rank)、轴(Axis)和形状(Shape),在此基础上讨论如何改变张量的形状。

5.1 张量的阶、轴与形状

  • 张量的阶(Rank)是指张量中的维数。

张量的阶告诉我们访问(引用)张量数据结构中的特定数据元素需要多少个索引。

A tensor’s rank tells us how many indexes are needed to refer to a specific element within the tensor.

例如,一个二阶张量,对应于一个二维数组、二维矩阵。

在这里插入图片描述

  • 张量的轴(axis)是张量的某个特定的维度。

An axis of a tensor is a specific dimension of a tensor.

张量中的元素被认为是存在并且在某个轴上运动,并受到该轴的长度的限制。换句话说,每个轴的长度,表示沿着这个轴有多少个索引。

对于 n维张量,沿着第一个轴 Axis0 的每个元素是一个 n-1 维数组,沿着第二个轴 Axis1 的每个元素是一个 n-2 维数组,…,沿着最后一个轴 的元素是数字。

二维张量的轴

例如,一个二阶张量有两个维度,这个张量有两个轴。

t = torch.tensor([[1,2,3,4],
                  [5,6,7,8], 
                  [9,10,11,12]])
print("t.ndim:", t.ndim)  # 张量的维数,t.ndim: 2
print("t.shape:", t.shape)  # 张量的形状,t.shape: torch.Size([3, 4])

二维张量 t 的第一个轴 Axis0 的长度为 3,因此沿着第一个轴可以索引 3个位置:

t[0], t[1], t[2]

二维张量 t 的沿着第一个轴 Axis0 的每个元素是一个一维张量(数组):

Axis0: 
t[0]: tensor([1, 2, 3, 4])
t[1]: tensor([5, 6, 7, 8])
t[2]: tensor([ 9, 10, 11, 12])

t 的第二个轴 Axis1 的长度为4,因此沿着第二个轴可以索引 4个位置:

t[0][0], t[1][0], t[2][0]
t[0][1], t[1][1], t[2][1] 
t[0][2], t[1][2], t[2][2] 
t[0][3], t[1][3], t[2][3] 

二维张量 t 的沿着第二个轴 Axis1 的每个元素是一个 0维张量(数字):

Axis1: 
t[0][0]: tensor(1)
t[0][3]: tensor(4)
t[2][3]: tensor(12)

三维张量的轴

三维张量的第一个轴 Axis0 表示的是“深度”或者“高度”(z方向),第二个轴 Axis1 表示的是“长度”(x方向),第三个轴 Axis2 表示的是“宽度”(y方向)。

因此,三维张量的结构是多个二维张量从上到下叠放而成

B = torch.tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]],  # B[0]
                  [[4, 3, 2, 1],
                   [8, 7, 6, 5],
                   [12, 11, 10, 9]]])  # B[1]

三维张量 B 的第一个轴 Axis0 的长度为 2,因此沿着第一个轴可以索引 2个位置:

B[0], B[1]

三维张量 B 的沿着第一个轴 Axis0 的每个元素是一个二维张量:

Axis0: 
B[0]:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
B[1]:
tensor([[ 4,  3,  2,  1],
        [ 8,  7,  6,  5],
        [12, 11, 10,  9]])

类似地,四维张量的结构是多个三维张量从上到下叠放而成,N维张量的结构是多个(N-1)维张量从上到下叠放而成。所谓从上到下叠放是相对的,取决于第一个轴 Axis0 的观察角度。但是, PyTorch 张量在输出时是按照从上到小的竖直方向排列,所以选择从上到下 的方向作为第一个轴 Axis0 的观察角度比较直观。

  • 张量的形状(shape)描述了每个轴的长度(元素数),也可以说张量的形状由每个轴的长度决定。

The shape of a tensor gives us the length of each axis of the tensor.

张量的形状不仅给出了张量的每个轴的长度,而且可以获得有关轴、秩以及索引的相关信息。在 PyTorch 中,张量的大小和形状是一样的。使用 shapesize() 可以查看张量的形状。

注意张量形状的维度是按第一轴 Axis0 向第(N-1)轴 Axis(N-1) 排列的,dim=0 表示第一轴,dim=-1 表示最后一轴。因此,任何一个二维以上的 PyTorch 张量,dim=-1 都表示 y 轴(最后一个维度),dim=-2 都表示 x 轴(倒数第二个维度)。例如,torch.Size([2, 3, 4] 表示张量的第一轴长度为 2、第二轴长度为 3、第三轴长度为 4。

print("B.ndim:", B.ndim)  # 张量的维数,B.ndim: 3
print("B.shape:", B.shape)  # 张量的形状,Bshape: torch.Size([2, 3, 4])

例程:

# (16) 张量的阶(Rank)、轴(Axis)和形状(Shape)
# 二维张量
t = torch.tensor([[1,2,3,4],
                  [5,6,7,8],
                  [9,10,11,12]])
print("t.ndim:", t.ndim)  # 张量的维数,t.ndim: 2
print("t.shape:", t.shape)  # 张量的形状,t.shape: torch.Size([3, 4])
print("Axis0:", t[0], t[1], t[2])
print("Axis1:", t[0][0], t[0][3], t[2][3])
# 三维张量
B = torch.tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]],  # B[0]
                  [[4, 3, 2, 1],
                   [8, 7, 6, 5],
                   [12, 11, 10, 9]]])  # B[1]
print("B.ndim:", B.ndim)  # 张量的维数,B.ndim: 3
print("B.shape:", B.shape)  # 张量的形状,Bshape: torch.Size([2, 3, 4])
print("Axis0:")
print(B[0])
print(B[1])

5.2 张量的形状处理

PyTorch 中的张量形状处理方法,是指改变张量的维度而不改变元素的数值。常用方法有:torch.reshape,torch.view,torch.repeat,torch.expand,torch.permute,torch.transpose。

方法说明
torch.reshape(input, shape)改变 input 的维度为 shape (如果适合),也可以使用 torch.Tensor.reshape().
torch.Tensor.view(shape)张量的维度变为 shape,返回的张量与原始张量共享数据。
torch.stack(tensors, dim=0)tensors 沿新维度 (dim)连接一系列序列,所有序列的tensors 大小必须相同。
torch.squeeze(input)移除input 维度为1 的行或者列。
torch.unsqueeze(input, dim)添加 input 维度为1 的行或者列。
torch.permute(input, dims)input 的维度根据 dims 重新排列,若赋予新的张量与原始张量共享数据

5.2.1 张量的维度交换(transpose)

transpose 方法用于对张量进行维度交换。

对于多维张量,使用函数 torch.transpose 交换指定的维度。

torch.transpose(input, dim0, dim1) → Tensor

其中,input 表示张量,dim0、dim1 表示交换的张量维度(第几轴)。

# (17) 张量的转置 transpose (变换给定张量的维度)
X3d = torch.tensor([[[1, 2, 3, 4],  # 三维张量
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]],  # B[0]
                    [[4, 3, 2, 1],
                     [8, 7, 6, 5],
                     [12, 11, 10, 9]]])  # B[1]
Xt1 = torch.transpose(X3d, 0, 1)  # 交换 Axis0 与 Axis1
Xt2 = torch.transpose(X3d, 0, 2)  # 交换 Axis0 与 Axis2
Xt3 = torch.transpose(X3d, 1, 2)  # 交换 Axis1 与 Axis2
print("X3d.shape:", X3d.shape)  # shape of X3d: torch.Size([2, 3, 4])
print("Xt1.shape:", Xt1.shape)  # shape of Xt1: torch.Size([3, 2, 4])
print("Xt2.shape:", Xt2.shape)  # shape of Xt2: torch.Size([4, 3, 2])
print("Xt3.shape:", Xt3.shape)  # shape of Xt3: torch.Size([2, 4, 3])

对于二维张量,只有两个维度 dim0 和 dim1,既可以使用函数 torch.transpose 交换 dim0 与 dim1,也可以使用 troch.t(input)input.T 实现矩阵转置。

X2d = torch.tensor([[1, 2, 3, 4],  # 三维张量
                    [5, 6, 7, 8],
                    [9, 10, 11, 12]])
print("X2d.shape:", X2d.shape)  # 张量的形状,X2d.shape: torch.Size([3, 4])
Xt = torch.t(X2d)  # 等价于:Xt = X2d.T
print("Xt.ndim:", Xt.ndim)  # 张量的维数,Xt.ndim: 2
print("Xt.shape:", Xt.shape)  # 张量的形状,Xt.shape: torch.Size([4, 3])

5.2.2 张量形状的重构(reshape)

函数 torch.reshape 返回将输入张量转变为指定的形状大小,元素总数不变。可以理解为将张量压扁为一维后再按指定形状重构。

torch.reshape(input, shape) → Tensor

其中,input 表示张量,shape 表示重构张量的形状。如果 shape 中的一个参数指定为 -1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。

返回的张量将尽可能是输入的视图,否则才会返回输入的拷贝。注意:这里很容易由于混淆“视图”与“拷贝”而发生错误,推荐编程时先强制使用“复制”,调试通过后再考虑优化内存。

连续输入和步幅兼容的输入可以在不复制的情况下重新整形,但不应依赖于复制与查看行为。

# (18) 张量的形状重构
X3d = torch.tensor([[[1, 2, 3, 4],  # 三维张量
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]],  # B[0]
                    [[4, 3, 2, 1],
                     [8, 7, 6, 5],
                     [12, 11, 10, 9]]])  # B[1]

X1 = torch.reshape(X3d, (3, 4, 2))
print(X1)
print("X1.shape:", X1.shape)  # shape of X1: torch.Size([3, 4, 2])

X2 = torch.reshape(X3d, (3, 8))
print(X2)
print("X2.shape:", X2.shape)  # shape of X2: ([3, 8])

X3 = torch.reshape(X3d, (-1,))
print(X3)
print("X3.shape:", X3.shape)  # shape of X3: torch.Size([24])

5.2.3 张量形状的重构(view)

函数 torch.Tensor.view 返回一个新张量,将输入张量转变为指定的形状大小,元素总数不变。可以理解为将张量压扁为一维后再按指定形状重构,相当于 reshape 操作。

torch.Tensor.view(*shape) → Tensor

其中,Tensor 表示张量,shape 表示重构张量的形状。

返回的张量与输入张量必须具有相同数量的元素,但可能具有不同的形状。返回张量的形状,必须与输入张量的形状(size)和步长(stride)兼容。如果 shape 中的一个参数指定为 -1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。

注意,返回的张量与输入张量共享内存,因此更改其中一个张量的内容也会改变另一个张量的内容。如果需要不共享内存,返回一个真正新的副本,可以先使用 clone() 创造一个副本然后再使用 view() 。使用 clone() 还有一个好处是会被记录在计算图中,即梯度回传到副本时也会传到源 Tensor。
s t r i d e [ i ] = s t r i d e [ i + 1 ] × s i z e [ i + 1 ] stride[i]=stride[i+1]×size[i+1] stride[i]=stride[i+1]×size[i+1]


5.2.4 维度删除(squeeze)和插入(unsqueeze)

函数 torch.squeeze() 用于删除长度为 1 的维度,函数 torch.unsqueeze() 用于在指定位置插入大小为 1 的维度。

torch.squeeze(input, dim=None, *, out=None) → Tensor

torch.unsqueeze(input, dim) → Tensor

其中,tensor 表示输入的张量,dim 表示指定的位置。

函数 torch.squeeze() 默认将 input 中所有长度为 1 的维度删掉;也可以通过 dim 指定位置,删掉指定位置的维数为 1的维度。例如,input 的形状为 (A, 1, B),squeeze(input) 返回结果的形状是 (A, B),squeeze(input, 1) 返回结果的形状也是 (A, B),而 squeeze(input, 0) 返回结果的形状仍是 (A, 1, B),

注意:

(1)返回的张量与输入张量共享内存,因此更改其中一个张量的内容也会改变另一个张量的内容。

(2)对于批量张量,如果批量轴(batch axis)的长度为 1(batch=1),维度删除操作也会删除批量维度(batch dimension),可能会导致错误。

# (19) 张量的维度删除和维度插入
x = torch.zeros(2, 1, 3, 4, 1)
y1 = torch.squeeze(x)
y2 = torch.squeeze(x, 0)
y3 = torch.squeeze(x, 1)
print(x.shape)  # torch.Size([2, 1, 3, 4, 1])
print(y1.shape)  # torch.Size([2, 3, 4])
print(y2.shape)  # torch.Size([2, 1, 3, 4, 1])
print(y3.shape)  # torch.Size([2, 3, 4, 1])

z1 = torch.unsqueeze(y1, 0)
z2 = torch.unsqueeze(y1, 1)
print(z1.shape)  # torch.Size([1, 2, 3, 4])
print(z2.shape)  # torch.Size([2, 1, 3, 4])

5.2.5 重新排列张量的维度(permute)

函数 torch.permute() 用于重新排列张量的维度。

torch.permute(input, dims) → Tensor

其中,tensor 表示输入的张量,dims 表示所需的维度及长度。

dims 是元组,表示张量的维数,元组元素依次是第0维长度、第1维长度、等等…。

# (20) 重新排列张量的维度
x = torch.randint(0, 10, size=(2, 3, 5))  # 均匀分布随机整数,[0, 10) 区间
print(x)
print(x.shape)  # torch.Size([2, 3, 5])
y = torch.permute(x, (2, 0, 1))
print(y)
print(y.shape)  # torch.Size([5, 2, 3])

6. 张量的索引、切片和连接

6.1 张量的索引与切片

索引是指访问张量中的唯一元素。张量是有序序列,可以根据每个元素在系统内的顺序位置,访问特定的元素,这就是索引 。

张量的索引方式与 Python 中的列表索引方式类似,可以使用整数、切片、布尔值和其他张量来进行索引。例如,对于一个二维张量,可以使用两个整数来访问其中的元素,如 tensor[0][1]

切片是针对某个维度方向下标访问张量, 通过切片可以一次访问多个顺序的元素。每个维度方向上都可以进行各自独立的切片访问,最终得到分布在不同维度方向上的多个张量元素。例如,对于一个二维张量,可以使用切片来访问张量中的元素,如 tensor[0:2, 0:2]

切片的一般形式是 [start: end: step],通过三个参数和冒号来定义切片的方式。其中:

  • start 是闭合区间即包含 start 索引的元素,end 是开区间即不包含 end 索引的元素。
  • start 如果缺省则默认为 0,end 如果缺省则默认为 len(tensor)。
  • 在 PyTorch 中 step 必须是正整数,不允许负数步长。

一维张量的索引

一维张量索引是从左到右,从 0 开始的。

# (21) 张量的索引与切片
# 一维张量的索引
t1 = torch.arange(10)*2
print(t1)  # tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])

# 访问索引位置为 0 的元素
print(t1[1])  # tensor(2),零维张量
print(t1[1].item())  # 2,取出零维张量的值

# 一维张量的切片
print(t1[0:2])  # tensor([0, 2])
print(t1[:2])  # tensor([0, 2])
print(t1[::2])  # tensor([ 0,  4,  8, 12, 16])

特别注意:

  1. 张量元素的索引是零维张量,而不是一个数值!要得到张量元素索引的数值,需要使用 item()方法。
  2. 张量索引的结果与原始的张量共享内存,修改原始张量或张量索引中一个的值,则另一个也会被修改。
  3. 注意 t1[:2] 是对第 0 列到 第 1 列进行切片,而 t1[::2] 是对所有列以 2 为步长进行切片。

二维张量的索引

二维张量索引的逻辑与一维张量索引相同。形状为 (nx, ny) 的二维张量,可以理解为由 nx 个一维张量构成,每个一维张量由 ny 个元素构成 。

对于一个二维张量,可以使用两个整数来访问其中的元素,如 tensor[0][1] 表示索引第 0 行、第 1 列的元素。

# 二维张量的索引
t2 = torch.arange(16).reshape(4, 4)
print(t2[0, 1])  # tensor(1)
print(t2[0][1])  # tensor(1)
print(t2[0, 1].item())  # 1

# 二维张量的切片
print(t2[::2, ::2])  # tensor([[0, 2], [8, 10]])
print(t2[::2][::2])  # tensor([[0, 1, 2, 3]])
print(t2[::2])  # tensor([[0, 1, 2, 3], [8, 9, 10, 11]])
print((t2[::2])[::2])  # tensor([[ 0, 1, 2, 3]])

特别注意:

  1. 二维张量元素的索引 t2[0][1],也可以写成 t2[0,1],结果相同,都是零维张量。

  2. 但是,t2[::2][::2]t2[::2, ::2] 切片的结果不同。二维切片 t2[::2, ::2] 使用逗号隔开时,可以理解为对二维张量的全局索引,取隔行隔列的元素。二维切片 t2[::2][::2]在两个中括号中时,可以理解为对二维张量的两次隔行切片 (t2[::2])[::2],先取隔行切片构成一个新的张量 (t2[::2]),又对新的张量 (t2[::2]) 再进行隔行索引,因此切片的结果是张量 t2 的第 0 行。


三维张量的索引

形状为 (nx, ny, nz) 三维张量,可以理解为由 nx 个二维张量构成,每个二维张量由 ny 个一维张量构成,每个一维张量由 nz 个元素构成 。

# 三维张量的索引和切片
t3 = torch.arange(27).reshape(3, 3, 3)
print(t3[0, 1, 2])  # tensor(5)
print(t3[0, 1, 2].item())  # 5
print(t3[0, ::2, ::2])  # tensor([[0, 2], [6, 8]])
print(t3[0, :, ::2])  # tensor([[0, 2], [3, 5], [6, 8]])

索引函数 index_select

PyTorch 还提供了函数 torch.index_select 来对张量进行索引。

torch.index_select(input, dim, index, *, out=None) → Tensor

参数说明:

input,张量,要索引的张量
dim,整数,要索引的维度
index,整形张量,包括索引序号的一维张量

函数 index_select() 表示在张量的哪个维度进行索引,索引的位值是多少,返回依 index 索引数据拼接的张量。

# 张量的索引函数
x = torch.arange(12).reshape(3, 4)
indices = torch.tensor([0, 2])
t1 = torch.index_select(x, 0, indices)
t2 = torch.index_select(x, 1, indices)
print(x)  # tensor([[0, 1, 2, 3], [4, 5, 6, 7], [ 8, 9, 10, 11]])
print(t1)  # tensor([[0, 1, 2, 3], [ 8, 9, 10, 11]])
print(t2)  # tensor([[0, 2], [4, 6], [8, 10]])

6.2 张量的拆分

拆分函数 split

PyTorch 提供了函数 torch.split 将张量拆分为块,每个块都是输入张量的一个视图。

torch.split(tensor, split_size_or_sections, dim=0) → List[Tensor]

参数说明:

  • tensor,张量,要拆分的张量
  • split_size_or_sections,整数类型,表示单个块的大小;或整型列表,表示每个块大小的列表。
  • dim,整数,拆分张量的维度
  • List[Tensor],返回值,张量的列表

函数说明:

  1. split_size_or_sections 是整数类型时,则将张量按整数拆分为大小相等的块,不能整除时最后一块将更小。
  2. split_size_or_sections 是整型列表时,则按列表长度将张量拆分为若干块,每块大小是对应的列表元素值。
# (22) 张量的拆分
x = torch.arange(10).reshape(5, 2)
s1 = torch.split(x, 2)
s2 = torch.split(x, [2, 3])
print(x)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7],
#         [8, 9]])
print(s1)
# (tensor([[0, 1], [2, 3]]),
#  tensor([[4, 5], [6, 7]]),
#  tensor([[8, 9]]))
print(s2)
# (tensor([[0, 1], [2, 3]]),
#  tensor([[4, 5], [6, 7], [8, 9]]))

拆分函数 chunk

PyTorch 提供了函数 torch.chunk 将张量拆分为指定数量的块,每个块都是输入张量的一个视图。

torch.chunk(input, chunks, dim=0) → List[Tensor]

参数说明:

  • input,张量,要拆分的张量
  • chunk,整数,表示返回的块数。
  • dim,整数,拆分张量的维度
  • List[Tensor],返回值,张量的列表

函数说明:

  1. 如果沿着维度 dim 的张量大小可以被块数 chunk 整除,那么所有块的大小相同。
  2. 如果沿着维度 dim 的张量大小不能被块数 chunk 整除,那么最后一块将更小,其它所有块都的大小相同。
  3. 如果无法按以上规则拆分,则函数返回的块数可能少于指定的块数 chunk。
# 张量拆分函数 chunk
t1 = torch.arange(10)
print(t1.chunk(3))
# (tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9]))
print(t1.chunk(4))
# (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
print(t1.chunk(5))
# (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))

6.3 张量的连接

函数 torch.cat 和函数 torch.stack 都可以实现张量的连接,区别在于是否产生新的维度。函数 torch.cat 不产生新的维度,相当于张量堆叠;而函数 torch.stack 可以在新的维度上进行拼接,产生高维张量。

张量连接函数 cat

函数 torch.cat 沿着指定维度 dim 对输入的张量序列进行连接。

torch.cat(tensors, dim=0, *, out=None) → Tensor

参数说明:

  • tensors,张量序列(列表或元组),要连接的张量
  • dim,整数,张量连接的维度,可选项,默认为 0

函数说明:

  • 张量序列中所有非空的张量,进行连接的维度 dim 的大小可以不同,其它维度的形状必须相同。
  • 函数 torch.cat 是函数 torch.chunk 的反向操作。
# (23) 张量的连接函数
# 张量连接函数 cat
t1 = torch.arange(6).reshape(2, 3)
t2 = torch.arange(10, 13).reshape(1, 3)

c1 = torch.cat((t1, t1), dim=0)
print(c1)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [0, 1, 2],
#         [3, 4, 5]])
c2 = torch.cat((t1, t1, t1), dim=1)
print(c2)
# tensor([[0, 1, 2, 0, 1, 2],
#         [3, 4, 5, 3, 4, 5]])
c3 = torch.cat((t1, t2), dim=0)
print(c3)
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [10, 11, 12]])
# c4 = torch.cat((t1, t2), dim=1)  # 报错, dim=1 维度不一致

张量连接函数 stack

函数 torch.stack 沿着一个新维度对输入的张量序列进行连接。

torch.stack(tensors, dim=0, *, out=None) → Tensor

参数说明:

  • tensors,张量序列(列表或元组),要连接的张量
  • dim,整数,插入的维度,0~len(out) 之间的整数

函数说明:

  • 张量序列中所有张量的形状必须相同。
  • 将张量序列在增加的新维度 dim 进行堆叠,例如把多个二维张量拼接为一个三维张量,多个三维张量拼接为一个四维张量。
# 张量连接函数 stack
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = torch.zeros_like(x1)
x3 = torch.ones_like(x1)
print(x1.shape)  # torch.Size([2, 3])

xStack1 = torch.stack((x1, x2), dim=0)
print(xStack1.shape)  # torch.Size([2, 2, 3])
print(xStack1)

xStack2 = torch.stack((x1, x2), dim=1)
print(xStack2.shape)  # torch.Size([2, 2, 3])
print(xStack2)

print(torch.stack((x1, x2, x3), dim=2).shape)  # torch.Size([2, 3, 3])
# print(torch.stack((x1, x2), dim=3).shape)  # IndexError: Dimension out of range

结果如下:

x1.shape: torch.Size([2, 3])
xStack1.shape: torch.Size([2, 2, 3])
tensor([[[1, 2, 3],
         [4, 5, 6]],
        [[0, 0, 0],
         [0, 0, 0]]])
xStack2.shape: torch.Size([2, 2, 3])
tensor([[[1, 2, 3],
         [0, 0, 0]],
        [[4, 5, 6],
         [0, 0, 0]]])
xStack3.shape: torch.Size([2, 3, 3])

7. PyTorch张量与 Numpy/列表/元组的转换

7.1 Numpy/列表/元组转换为张量

可以将 Numpy 数组、列表 list 或元组 tuple 转换为张量。

# (24) 将 Numpy 数组、列表 list 或元组 tuple转换为张量
# Numpy 数组转换为张量
xnp = np.array([[1, 2, 3], [4, 5, 6]])
x1 = torch.tensor(xnp)
print(x1)
x2 = torch.from_numpy(xnp)
print(x2)
# list 列表转换为张量
xlist = [[1, 2, 3], [4, 5, 6]]
x2 = torch.tensor(xlist)
print(x2)
# tuple 元组转换为张量
xtuple = ((1, 2, 3), (4, 5, 6))
x3 = torch.tensor(xtuple)
print(x3)

输出为:

tensor([[1, 2, 3],
[4, 5, 6]], dtype=torch.int32)

tensor([[1, 2, 3],
[4, 5, 6]], dtype=torch.int32)

tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 2, 3],
[4, 5, 6]])

torch.tensor(xnp) 与 torch.from_numpy(xnp) 都能将 Numpy 数组转换为张量,但 torch.from_numpy() 可以自动识别并保留 Numpy 数组的数据类型。


7.2 张量转换为Numpy/列表/元组

函数 torch.numpy() 用于将张量转换为 Numpy 数组,函数 torch.tolist() 用于将张量转换为 list 列表。

# (25) 张量转换为Numpy/列表/元组
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 张量转换为 Numpy 数组
xnp = X.numpy()
print(type(xnp))  # <class 'numpy.ndarray'>
print(xnp)  # [[1 2 3], [4 5 6]]
# 张量转换为 list 列表
xlist = X.tolist()
print(type(xlist))  # <class 'list'>
print(xlist)  # [[1, 2, 3], [4, 5, 6]]
# 张量转换为 list 列表
Xlist = list(X)
print(type(Xlist))  # <class 'list'>
print(Xlist)  # [tensor([1, 2, 3]), tensor([4, 5, 6])]

注意:由于 Numpy 数组不支持 GPU 运算,要把 GPU 上创建的张量转换为 Numpy 数组,需要先通过 to('cpu')cpu() 函数把 GPU 上的张量转移到 CPU 上,然后再分离数据结构。


7.3 提取张量的数值

函数 torch.item() 用于将 0 维张量转换为数值,此方法只能用于只有一个元素的 0 维张量,通常用于输出最终的计算结果。

# 0 维张量转换为 数值
tout = torch.tensor(6)  # 仅用于只有一个元素的张量
print("tout =", tout.item())  # tout = 6

8. 序列化张量

序列化张量是指保存某个时序的张量到文件中,在后面可以装载到程序中。

PyTorch 在内部使用 pickle 来序列化张量对象,并为存储添加专用的序列化代码。

通过下述方法可以将张量 points 保存到 ourpoints.t 文件中:

# (26) 序列化张量
# Method 1
torch.save(points, './ourpoints.t')
# Method 2
with open('./ourpoints.t', 'wb') as f:
    torch.save(points, f)

从 ourpoints.t 中读取:

# Method 1
points = torch.load('./ourpoints.t')
# Method 2
with open('./ourpoints.t', 'rb') as f:
    points = torch.load(f)

上述保存张量的方法很简单,但是对于 .t 的文件,只能通过 PyTorch 打开。

我们也可以使用 h5py 来序列化到 HDF5,但需要安装 h5py。

conda install h5py

使用方法如下。

import h5py
f = h5py.File('ourpoints.hdf5', 'w')
dset = f.create_dataset('coords', data=points.numpy())
f.close()

HDF5 的优势是可以在磁盘上索引数据集,并且只访问我们需要的元素。

例如只想要 points 中最后两个点的坐标,例程如下:

f = h5py.File('ourpoints.hdf5', 'r')
dset = f['coords']
last_points = dset[-2:]
last_points
'''
array([[5., 3.],
       [2., 1.]], dtype=float32)
'''
last_points = torch.from_numpy(dset[-2:])
f.close()
last_points
'''
tensor([[5., 3.],
        [2., 1.]])
'''

【本节完】


版权声明:
欢迎关注『youcans的深度学习』系列,转发请注明原文链接:
【youcans的深度学习 06】PyTorch入门教程:张量的基本操作 2 (https://youcans.blog.csdn.net/article/details/130564877)
Copyright 2023 youcans, XUPT
Crated:2023-05-08

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/515259.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面向对象程序设计概述

&#x1f9d1;‍&#x1f4bb;CSDN主页&#xff1a;夏志121的主页 &#x1f4cb;专栏地址&#xff1a;Java核心技术专栏 目录 一、类 二、对象 三、识别类 四、类之间的关系 面向对象程序设计&#xff08;Object-Oriented Programming,OOP)是当今的主流程序设计范型&#x…

线段树详解

目录 线段树的概念 线段树的实现 线段树的存储 需要4n大小的数组 线段树的区间是确定的 线段树的难点在于lazy操作 代码样例 线段树的概念 线段树&#xff08;Segment Tree&#xff09;是一种平衡二叉树&#xff0c;用于解决区间查询问题。它将一个区间划分成若干个子区…

Android 车载值不值得入手学?

前言 随着智能车的不断普及和智能化程度的提高&#xff0c;车载系统也在逐步升级和演进&#xff0c;越来越多的汽车厂商开始推出采用Android系统的车载设备&#xff0c;这为Android车载开发提供了广泛的市场需求。 其次&#xff0c;随着人工智能技术的发展和应用&#xff0c;…

Linux : 安装源码包

安装源码包之前我们要准备好yum环境&#xff0c;或者使用默认上网下载的yum仓库或者查看&#xff1a;Linux&#xff1a;rpm查询安装 && yum安装_鲍海超的博客-CSDN博客 准备离线yum仓库 &#xff0c;默认的需要在有网环境下才能去网上下载 其次就是安装 gcc make 准…

UDP协议 sendto 和 recvfrom 浅析与示例

UDP&#xff08;user datagram protocol&#xff09;用户数据报协议&#xff0c;属于传输层。 UDP是面向非连接的协议&#xff0c;它不与对方建立连接&#xff0c;而是直接把数据报发给对方。UDP无需建立类如三次握手的连接&#xff0c;使得通信效率很高。因此UDP适用于一次传…

Kyligence Zen 一站式指标平台体验——“绝对实力”的指标分析和管理工具——入门体验评测

&#x1f996;欢迎观阅本本篇文章&#xff0c;我是Sam9029 文章目录 前言Kyligence Zen 是什么Kyligence Zen 能做什么Kyligence Zen 优势在何处 正文注册账号平台功能模块介绍指标图表新建指标指标模板 目标仪表盘数据设置 实际业务体验---使用官网数据范例使用流程归因分析指…

MySQL --- 多表设计

关于单表的操作(包括单表的设计、单表的增删改查操作)我们就已经学习完了。接下来我们就要来学习多表的操作&#xff0c;首先来学习多表的设计。 项目开发中&#xff0c;在进行数据库表结构设计时&#xff0c;会根据业务需求及业务模块之间的关系&#xff0c;分析并设计表结构…

ChatGPT-4怎么对接-ChatGPT-4强化升级了哪些功能

ChatGPT-4怎么使用 使用ChatGPT-4&#xff0c;需要通过OpenAI的API接口来对接ChatGPT-4。OpenAI是一个人工智能公司&#xff0c;为开发者提供多个API接口&#xff0c;包括自然语言处理&#xff0c;图像处理等。ChatGPT-4是OpenAI开发的最新版本的聊天式对话模型&#xff0c;可…

React antd Form item「受控组件与非受控组件」子组件 defaultValue 不生效等问题总结

一、为什么 Form.Item 下的子组件 defaultValue 不生效&#xff1f; 当你为 Form.Item 设置 name 属性后&#xff0c;子组件会转为受控模式。因而 defaultValue 不会生效。你需要在 Form 上通过 initialValues 设置默认值。name 字段名&#xff0c;支持数组 类型&#xff1a;N…

2.存储器层次系统

存储器 随机访问存储器 RAM&#xff08;随机存储器&#xff09; SRAM 双稳态触发器&#xff0c;有电就保持不变&#xff0c;干扰消除后时会恢复到稳定值&#xff0c;晶体管多因此密集度低 DRAM 每个位存储为对一个电容的充电&#xff0c;对干扰敏感&#xff0c;漏电所以需要刷…

静态数码管

静态数码管 1、简介工作方式数码管静态显示原理 2、硬件设计3、软件设计4、 1、简介 一般共阳极数码管更为常用 好处&#xff1a;将驱动数码管的工作交到公共端&#xff08;一般接驱动电源&#xff09;&#xff0c;加大驱动电源的功率自然要比加大IC芯片I/O口的驱动电流简单许…

【python 生成器】零基础也能轻松掌握的学习路线与参考资料

一、学习路线 了解生成器的概念和作用 首先&#xff0c;需要明确生成器的概念和作用&#xff0c;生成器是一种特殊的迭代器&#xff0c;它可以在循环中逐个地产生值&#xff0c;而不是一次性将所有的值产生出来。它的作用是使程序更加高效&#xff0c;达到节省内存等的效果。…

Linux 入门

文章目录 一、概述二、安装CentOS下载地址VMware下载地址 三、linux文件与目录结构Linux系统中一切皆文件Linux目录结构 四、VI/VIM 编辑器vi/vim是什么一般模式常用语法键盘图编辑模式指令模式 五、网络配置六、远程登陆七、系统管理Linux 中的进程和服务service 服务管理chkc…

几种常见的电源防反接电路

电源防反接&#xff0c;也即是防止电源的正负极搞反而导致电路损坏&#xff0c;例如你采用的是标准的DC口&#xff0c;那么没什么必要加入此种电路。而如果采用的是非常规的&#xff0c;如自定义的接插件等&#xff0c;那么就很有必要了。 举个例子&#xff1a;小编以前就采用…

企业在线制作帮助中心,选择:语雀、石墨、Baklib哪个好?

在当今互联网时代&#xff0c;越来越多的企业开始将帮助中心建设在线化。在线帮助中心的好处不仅可以提高用户的使用体验&#xff0c;也可以提高企业的工作效率。然而&#xff0c;选择一个合适的在线制作帮助中心工具却并不是一件容易的事情。在众多的在线制作帮助中心工具中&a…

Python3 入门教程||Python3 SMTP发送邮件||Python3 多线程

Python3 SMTP发送邮件 在Python3 中应用的SMTP&#xff08;Simple Mail Transfer Protocol&#xff09;即简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则&#xff0c;由它来控制信件的中转方式。 python的 smtplib 提供了一种很方便的途径发送电子邮件。它对…

[cryptoverse CTF 2023] crypto部分

没打,完事作作题. Warmup 1 Decode the following ciphertext: GmvfHt8Kvq16282R6ej3o4A9Pp6MsN. Remember: CyberChef is your friend. Another great cipher decoding tool is Ciphey. 热身一下就凉,问了别人,用ciphey说是能自动解,但是安装报错 rot13base58 这个没有自动的…

JavaCollection集合:概述、体系特点、常用API、遍历方式

一、集合概述 集合和数组都是容器 数组 特点&#xff1a;数组定义完成并启动后&#xff0c;类型确定、长度固定。 劣势&#xff1a;在进行增删数据操作的时候&#xff0c;数组是不太合适的&#xff0c;增删数据都需要放弃原有数组或者移位。 使用场景&#xff1a;当业务数…

JMeter 常用的几种断言方法,你会了吗?

JMeter是一款常用的负载测试工具&#xff0c;通过模拟多线程并发请求来测试系统的负载能力和性能。在进行性能测试时&#xff0c;断言&#xff08;Assertion&#xff09;是非常重要的一部分&#xff0c;可以帮助我们验证测试结果的正确性。下面介绍JMeter常用的几种断言方法。 …

MySQL 运算符解析

1.算术运算符 算术运算符主要用于数学运算&#xff0c;其可以连接运算符前后的两个数值或表达式&#xff0c;对数值或表达式进行加 &#xff08;&#xff09;、减&#xff08;-&#xff09;、乘&#xff08;*&#xff09;、除&#xff08;/&#xff09;和取模&#xff08;%&…