【深度学习】pytorch——Tensor(张量)详解

news2024/10/6 8:26:24

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~

pytorch——Tensor

  • 简介
  • 创建Tensor
    • torch.Tensor( )和torch.tensor( )的区别
      • torch.Tensor( )
      • torch.tensor( )
    • tensor可以是一个数(标量)、一维数组(向量)、二维数组(矩阵)和更高维的数组(高阶数据)。
      • 标量(scalar )
      • 向量(vector)
      • 矩阵(matrix)
  • 常用Tensor操作
    • 调整tensor的形状
      • tensor.view
      • tensor.squeeze与tensor.unsqueeze
        • tensor.squeeze(dim)
        • tensor.unsqueeze(dim)
      • None 可以为张量添加一个新的轴(维度)
    • 索引操作
      • 切片索引
      • gather( )
  • 高级索引
  • Tensor数据类型
  • Tensor逐元素
  • Tensor归并操作
  • Tensor比较操作
  • Tensor线性代数
  • Tensor和Numpy
  • Tensor的数据结构

简介

Tensor,又名张量。它可以是一个数(标量)、一维数组(向量)、二维数组(矩阵)和更高维的数组(高阶数据)。Tensor和Numpy的ndarrays类似,但PyTorch的tensor支持GPU加速。
在这里插入图片描述

官方文档
https://pytorch.org/docs/stable/tensors.html

import torch  as t
t.__version__		# '2.1.0+cpu'

创建Tensor

创建方法示例输出
通过给定数据创建张量torch.Tensor([1, 2, 3])tensor([1., 2., 3.])
通过指定tensor的形状torch.Tensor(2, 3)tensor([[1.1395e+23, 1.6844e-42, 0.0000e+00],[0.0000e+00, 0.0000e+00, 0.0000e+00]])
使用torch.arange()创建连续的张量torch.arange(0, 10, 2)tensor([0, 2, 4, 6, 8])
使用torch.zeros()创建全零张量torch.zeros((3, 4))tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
使用torch.ones()创建全一张量torch.ones((2, 2))tensor([[1., 1.], [1., 1.]])
使用torch.randn()创建随机张量torch.randn((3, 3))tensor([[ 1.0553, -0.4815, 0.6344], [-0.7507, 1.3891, 1.0460], [-0.5625, 1.9531, -0.5468]])
使用torch.rand()创建在0到1之间均匀分布的随机张量torch.rand((3, 3))tensor([[1, 6, 5], [2, 0, 4], [8, 5, 7]])
使用torch.randint()创建在给定范围内的整数随机张量torch.randint(low=0, high=10, size=(3, 3))tensor([[0, 8, 9], [1, 8, 7], [4, 4, 4]])
使用torch.eye()创建单位矩阵torch.eye(5)tensor([[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]])
从Python列表或Numpy数组创建张量torch.tensor([1, 2, 3])torch.tensor(np.array([1, 2, 3]))tensor([1, 2, 3])或tensor([1, 2, 3], dtype=torch.int32)
将整个张量填充为常数值torch.full((3, 3), 3.14)tensor([[3.1400, 3.1400, 3.1400], [3.1400, 3.1400, 3.1400], [3.1400, 3.1400, 3.1400]])
创建指定大小的空张量torch.empty((3, 3))tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
创建长度为5的随机排列张量torch.randperm(5)tensor([1, 2, 0, 3, 4])

torch.Tensor( )和torch.tensor( )的区别

torch.Tensor( )

torch.Tensor([1, 2, 3]) 的创建方式会根据输入的数据类型来确定张量的数据类型。
例如,如果输入的是整数列表,那么创建的张量将使用默认的数据类型 torch.float32。这意味着即使输入的数据是整数,张的数据类型也会被转换为浮点数类型。

a = t.Tensor([1, 2, 3])
a
# tensor([1., 2., 3.])

torch.Tensor(1,2) 通过指定tensor的形状创建张量

a= t.Tensor(1,2) # 注意和t.tensor([1, 2])的区别
a.shape
# torch.Size([1, 2])

torch.tensor( )

torch.tensor([1, 2, 3]) 的创建方式会根据输入的数据类型灵活地选择张量的数据类型。它可以接受各种数据类型的输入,包括整数、浮点数、布尔值等,并根据输入的数据类型自动确定创建张量使用的数据类型。

a = t.tensor([1, 2, 3])
a
# tensor([1, 2, 3])

tensor可以是一个数(标量)、一维数组(向量)、二维数组(矩阵)和更高维的数组(高阶数据)。

标量(scalar )

scalar = t.tensor(3.14) 
print('scalar: %s, shape of sclar: %s' %(scalar, scalar.shape))

输出为:

scalar: tensor(3.1400), shape of sclar: torch.Size([])

向量(vector)

vector = t.tensor([1, 2, 3])
print('vector: %s, shape of vector: %s' %(vector, vector.shape))

输出为:

vector: tensor([1, 2, 3]), shape of vector: torch.Size([3])

矩阵(matrix)

matrix = t.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
matrix,matrix.shape

输出为:

(tensor([[0.1000, 1.2000],
         [2.2000, 3.1000],
         [4.9000, 5.2000]]), torch.Size([3, 2]))

常用Tensor操作

方法描述
tensor.view(*args)改变张量形状
tensor.reshape(*args)改变张量形状
tensor.size()返回张量形状
tensor.dim()返回张量维度
tensor.unsqueeze(dim)在指定维度上添加一个新的维度
tensor.squeeze(dim)压缩指定维度的大小为1的维度
tensor.transpose(dim0, dim1)交换两个维度
tensor.permute(*dims)重新排列张量的维度
tensor.flatten()展平所有维度
tensor.mean(dim)沿指定维度计算张量的平均值
tensor.sum(dim)沿指定维度计算张量的和
tensor.max(dim)沿指定维度返回张量的最大值
tensor.min(dim)沿指定维度返回张量的最小值
tensor.argmax(dim)沿指定维度返回张量最大元素的索引
tensor.argmin(dim)沿指定维度返回张量最小元素的索引
tensor.add(value)将标量加到张量中的每个元素
tensor.add(tensor)将另一个张量加到该张量
tensor.sub(value)将标量从张量中的每个元素减去
tensor.sub(tensor)从该张量中减去另一个张量
tensor.mul(value)将张量中的每个元素乘以标量
tensor.mul(tensor)将该张量与另一个张量相乘
tensor.div(value)将张量中的每个元素除以标量
tensor.div(tensor)将该张量除以另一个张量

调整tensor的形状

tensor.view

通过tensor.view方法可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,返回的新tensor与原tensor共享内存,也即更改其中的一个,另外一个也会跟着改变。

a = t.arange(0, 6)
a.view(2, 3)

输出结果为:

tensor([[0, 1, 2],
        [3, 4, 5]])
  • 案例1
b = a.view(-1, 2) # 当某一维为-1的时候,会自动计算它的大小
b.shape		# torch.Size([3, 2])
  • 案例2
b = a.view(-1, 3) # 当某一维为-1的时候,会自动计算它的大小
b.shape		# torch.Size([2,3])

tensor.squeeze与tensor.unsqueeze

tensor.squeeze(dim)

tensor.squeeze(dim) 方法用于压缩张量中指定维度大小为1的维度,即将大小为1的维度去除。如果未指定 dim 参数,则会去除所有大小为1的维度。

# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.arange(12).reshape(1, 3, 1, 4)
print(x.shape)  # 输出: torch.Size([1, 3, 1, 4])

# 使用 squeeze 去除大小为1的维度
y = x.squeeze()
print(y.shape)  # 输出: torch.Size([3, 4])

# 指定 dim 参数去除指定维度大小为1的维度
z = x.squeeze(0)
print(z.shape)  # 输出: torch.Size([3, 1, 4])
tensor.unsqueeze(dim)

tensor.unsqueeze(dim) 方法用于在指定维度上添加一个新的维度,新的维度大小为1。

# 创建一个形状为 (3, 4) 的张量
x = t.randn(3, 4)
print(x.shape)  # 输出: torch.Size([3, 4])

# 使用 unsqueeze 在维度0上添加新维度
y = x.unsqueeze(0)
print(y.shape)  # 输出: torch.Size([1, 3, 4])

# 使用 unsqueeze 在维度2上添加新维度
z = x.unsqueeze(2)
print(z.shape)  # 输出: torch.Size([3, 4, 1])

None 可以为张量添加一个新的轴(维度)

在 PyTorch 中,使用 None 可以为张量添加一个新的轴(维度)。这个新的轴可以在任何位置添加,从而改变张量的形状。以下是一个示例:

# 创建一个形状为 (3, 4) 的二维张量
a = t.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 使用 None 在第一维度上新增一个轴
b = a[None, :, :]

print(b.shape)  # 输出: torch.Size([1, 3, 4])

在上面的例子中,使用 None 将张量 a 在第一维度上扩展,结果得到了一个形状为 [1, 3, 4] 的三维张量 b。通过为 a 添加新的轴,我们可以改变张量的维度和形状,从而为其提供更多的灵活性。

索引操作

索引出来的结果与原tensor共享内存,也即修改一个,另一个会跟着修改。

切片索引

# 创建一个形状为 (3, 4) 的二维张量
a = t.tensor([[1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12]])

# 使用切片操作访问其中的元素
b = a[:, 1:3]
print(b)
# tensor([[ 2,  3],
#         [ 6,  7],
#         [10, 11]])


# 可以使用 step 参数控制步长
c = a[::2, ::2]
print(c)
# tensor([[ 1,  3],
#         [ 9, 11]])


# 可以使用负数索引从后往前访问元素
d = a[:, -2:]
print(d)
# tensor([[ 3,  4],
#         [ 7,  8],
#         [11, 12]])

gather( )

gather() 是 PyTorch 中的一个张量索引函数,可以用于按照给定的索引从输入张量中检索数据。它的语法如下:

torch.gather(input, dim, index, out=None, sparse_grad=False) -> Tensor

其中,参数含义如下:

  • input:输入张量,形状为 (N*,*C) 或 (N,C,d1,d2,…,dk)。
  • dim:要检索的维度。
  • index:用于检索的索引张量,形状为 (M,) 或(M,d1,d2,…,dk)。
  • out:输出张量,形状与 index 相同。
  • sparse_grad:是否在反向传播时启用稀疏梯度计算。

gather() 函数主要用于按照给定的索引从输入张量中检索数据。具体来说,对于二维输入张量 input 和一维索引张量 indexgather() 函数会返回一个一维张量,其中每个元素是 input 中相应行和 index 中相应列的交点处的数值。对于更高维度的输入,索引张量 index 可以选择任何维度的元素。

示例1

# 创建一个形状为 (3, 4) 的二维张量
input = t.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 创建一个索引张量,用于按列检索元素
index = t.tensor([[0, 2, 3],
                  [1, 3, 2]])

# 使用 gather 函数按列检索元素,返回一个二维张量
output = t.gather(input, dim=1, index=index)

print(output)
# 输出:
# tensor([[ 1,  3,  4],
#         [ 6,  8,  7]])

在上面的示例中:

  • 创建了一个形状为 (3, 4) 的二维输入张量 input
  • 创建了一个形状为 (2, 3) 的索引张量 index,用于检索元素。
  • 使用 gather() 函数按列检索元素,并将结果存储到输出张量 output 中。

示例2

# 创建一个形状为 (2, 3, 4) 的三维张量
input = t.tensor([[[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]],
                      
                      [[13, 14, 15, 16],
                       [17, 18, 19, 20],
                       [21, 22, 23, 24]]])

# 创建一个形状为 (2, 3) 的索引张量
index = t.tensor([[0, 2, 1],
                      [2, 1, 0]])

# 添加一个维度到索引张量
index = index.unsqueeze(2)

# 使用 gather 函数按第二个维度检索元素
output_dim_1 = t.gather(input, dim=1, index=index)

# 使用 gather 函数按第三个维度检索元素
output_dim_2 = t.gather(input, dim=2, index=index)

print(output_dim_1)
print(output_dim_2)
'''
输出:
tensor([[[ 1],
         [ 9],
         [ 5]],

        [[21],
         [17],
         [13]]])
tensor([[[ 1],
         [ 7],
         [10]],

        [[15],
         [18],
         [21]]])
'''

高级索引

高级索引可以看成是普通索引操作的扩展,但是高级索引操作的结果一般不和原始的Tensor共享内存。

x = t.arange(0,27).view(3,3,3)
print(x)

a = x[[1, 2], [1, 2], [2, 0]] # x[1,1,2]和x[2,2,0]
print(a)

b = x[[2, 1, 0], [0], [1]] # x[2,0,1],x[1,0,1],x[0,0,1]
print(b)

c = x[[0, 2], ...] # x[0] 和 x[2]
print(c)

输出结果为:

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]]) 

tensor([14, 24]) 

tensor([19, 10,  1]) 

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]])

Tensor数据类型

以下是常见的 Tensor 数据类型及其相应的字符串表示:

数据类型字符串表示
32 位浮点数‘torch.float32’ 或 ‘torch.float’
64 位浮点数‘torch.float64’ 或 ‘torch.double’
16 位浮点数(半精度)‘torch.float16’ 或 ‘torch.half’
8 位整数(无符号)‘torch.uint8’
8 位整数(有符号)‘torch.int8’
16 位整数‘torch.int16’ 或 ‘torch.short’
32 位整数‘torch.int32’ 或 ‘torch.int’
64 位整数‘torch.int64’ 或 ‘torch.long’
布尔型‘torch.bool’

使用 PyTorch 中的 dtype 属性可以获取 Tensor 的数据类型。例如:

x = t.randn(3, 4) # 创建一个随机的 FloatTensor

print(x.dtype) # 输出 torch.float32

Tensor逐元素

以下是 PyTorch 支持的逐元素操作及其相应的函数名:

操作函数名
加法torch.add()torch.add_()
减法torch.sub()torch.sub_()
乘法torch.mul()torch.mul_()
除法torch.div()torch.div_()
幂运算torch.pow()torch.pow_()
取整torch.floor()torch.floor_()
取整(向上)torch.ceil()torch.ceil_()
取整(四舍五入)torch.round()torch.round_()
指数函数torch.exp()torch.exp_()
对数函数torch.log()torch.log_()
平方根函数torch.sqrt()torch.sqrt_()
绝对值torch.abs()torch.abs_()
正弦函数torch.sin()torch.sin_()
余弦函数torch.cos()torch.cos_()
正切函数torch.tan()torch.tan_()
反正弦函数torch.asin()torch.asin_()
反余弦函数torch.acos()torch.acos_()
反正切函数torch.atan()torch.atan_()

下面是三个逐元素操作的示例:

  1. 加法操作:
x = t.tensor([1, 2, 3])
y = t.tensor([4, 5, 6])

result = t.add(x, y)

print(result)  # 输出 tensor([5, 7, 9])
  1. 平方根函数操作:
x = t.tensor([4.0, 9.0, 16.0])

result = t.sqrt(x)

print(result)  # 输出 tensor([2., 3., 4.])
  1. 绝对值操作:
x = t.tensor([-1, -2, 3, -4])

result = t.abs(x)

print(result)  # 输出 tensor([1, 2, 3, 4])

Tensor归并操作

以下是 PyTorch 支持的归并操作及其相应的函数名:

操作函数名
求和torch.sum()
平均值torch.mean()
方差torch.var()
标准差torch.std()
最小值torch.min()
最大值torch.max()
中位数torch.median()
排序torch.sort()

下面是三个归并操作的示例:

  1. 求和操作:
x = t.tensor([[1, 2], [3, 4]])

result = t.sum(x)

print(result)  # 输出 tensor(10)
  1. 平均值操作:
x = t.tensor([[1, 2], [3, 4]], dtype=t.float)

result = t.mean(x)

print(result)  # 输出 tensor(2.5000)
  1. 最小值操作:
x = t.tensor([[1, 2], [3, 4]])

result = t.min(x)

print(result)  # 输出 tensor(1)

Tensor比较操作

以下是 PyTorch 支持的比较、排序和取最大/最小值的操作及其相应的函数名:

操作函数名功能
大于/小于/大于等于/小于等于/等于/不等于torch.gt()/torch.lt()/torch.ge()/ torch.le()/torch.eq()/torch.ne()对两个张量进行比较,返回一个布尔型张量。
最大的k个数torch.topk()返回输入张量中最大的 k 个元素及其对应的索引。
排序torch.sort()对输入张量进行排序。
比较两个 tensor 最大/最小值torch.max()/torch.min()比较两个张量之间的最大值或最小值,返回一个张量。

下面是三个操作的示例:

  1. topk 操作:
x = t.tensor([1, 3, 2, 4, 5])
result = t.topk(x, k=3)

print(result)  
'''输出:
torch.return_types.topk(
values=tensor([5, 4, 3]),
indices=tensor([4, 3, 1]))
'''

上述代码中,使用 topk() 函数获取张量 x 中的前三个最大值及其索引。

  1. sort 操作:
x = t.tensor([1, 3, 2, 4, 5])
result = t.sort(x)

print(result)  
'''输出:
torch.return_types.sort(
values=tensor([1, 2, 3, 4, 5]),
indices=tensor([0, 2, 1, 3, 4]))
'''

上述代码中,使用 sort() 函数对张量 x 进行排序,并返回排好序的张量及其索引。

  1. max 操作:
x = t.tensor([1, 3, 2, 4, 5])
y = t.tensor([2, 3, 1, 5, 4])

result = t.max(x, y)

print(result)  # 输出 tensor([2, 3, 2, 5, 5])

上述代码中,使用 max() 函数比较张量 xy 中的最大值,并返回一个新的张量。

Tensor线性代数

函数名功能
torch.trace()计算矩阵的迹
torch.diag()提取矩阵的对角线元素
torch.triu()提取矩阵的上三角部分,可指定偏移量
torch.tril()提取矩阵的下三角部分,可指定偏移量
torch.mm()计算两个2维张量的矩阵乘法
torch.bmm()计算两个3维张量的批量矩阵乘法
torch.addmm()将两个矩阵相乘并加上一个矩阵
torch.addbmm()批量矩阵相乘并加上一个矩阵
torch.addmv()矩阵和向量相乘并加上一个向量
torch.addr()计算两个向量的外积
torch.badbmm()批量进行矩阵乘法操作的累积和
torch.t()转置张量或矩阵
torch.dot()计算两个1维张量的点积
torch.cross()计算两个3维张量的叉积
torch.inverse()计算方阵的逆矩阵
torch.svd()计算矩阵的奇异值分解

以下是几个线性代数操作的示例:

  1. torch.trace() 计算矩阵的迹:
x = t.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

result = t.trace(x)

print(result)  # 输出 tensor(15)

上述代码中,我们使用 trace() 函数计算矩阵 x 的迹。

  1. torch.diag() 提取矩阵的对角线元素:
x = t.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

result = t.diag(x)

print(result)  # 输出 tensor([1, 5, 9])

上述代码中,我们使用 diag() 函数提取矩阵 x 的对角线元素。

  1. torch.mm() 计算两个2维张量的矩阵乘法:
x = t.tensor([[1, 2], [3, 4]])
y = t.tensor([[5, 6], [7, 8]])

result = t.mm(x, y)

print(result)  # 输出 tensor([[19, 22], [43, 50]])

上述代码中,我们使用 mm() 函数计算张量 xy 的矩阵乘法。

Tensor和Numpy

Tensor和Numpy数组之间具有很高的相似性,彼此之间的互操作也非常简单高效。需要注意的是,Numpy和Tensor共享内存。

当遇到Tensor不支持的操作时,可先转成Numpy数组,处理后再转回tensor,其转换开销很小。

Numpy和Tensor共享内存

import numpy as np
a = np.ones([2, 3],dtype=np.float32)
print("\na:\n",a)

# Tensor——>Numpy
b = t.from_numpy(a)
print("\nb:\n",b)

a[0, 1]=100
print("\n改变后的b:\n",b)

# Numpy——>Tensor
c = b.numpy() # a, b, c三个对象共享内存
print("\nc:\n",c)

输出结果为:

a:
 [[1. 1. 1.]
 [1. 1. 1.]]

b:
 tensor([[1., 1., 1.],
        [1., 1., 1.]])

改变后的b:
 tensor([[  1., 100.,   1.],
        [  1.,   1.,   1.]])

c:
 [[  1. 100.   1.]
 [  1.   1.   1.]]

注意:numpy的数据类型和Tensor的类型不一样的时候,数据会被复制,不会共享内存

import numpy as np
a = np.ones([2, 3])
print("\na:\n",a)

# Tensor——>Numpy
b = t.Tensor(a)
print("\nb:\n",b)

# Tensor——>Numpy
c = t.from_numpy(a)
print("\nc:\n",c)

a[0, 1]=100
print("\n改变后的b:\n",b,"\n\n改变后的c:\n",c)

输出结果为:

a:
 [[1. 1. 1.]
 [1. 1. 1.]]

b:
 tensor([[1., 1., 1.],
        [1., 1., 1.]])

c:
 tensor([[1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)

改变后的b:
 tensor([[1., 1., 1.],
        [1., 1., 1.]]) 

改变后的c:
 tensor([[  1., 100.,   1.],
        [  1.,   1.,   1.]], dtype=torch.float64)

Tensor的数据结构

tensor分为头信息区(Tensor)和存储区(Storage),信息区主要保存着tensor的形状(size)、步长(stride)、数据类型(type)等信息,而真正的数据则保存成连续数组。由于数据动辄成千上万,因此信息区元素占用内存较少,主要内存占用则取决于tensor中元素的数目,也即存储区的大小。
在这里插入图片描述

一个tensor有着与之相对应的storage, storage是在data之上封装的接口,不同tensor的头信息一般不同,但却可能使用相同的数据。

a = t.arange(0, 6)
print(a.storage())

b = a.view(2, 3)
print(b.storage())

# 一个对象的id值可以看作它在内存中的地址
# storage的内存地址一样,即是同一个storage
id(b.storage()) == id(a.storage())

输出结果为:

0
 1
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]

 0
 1
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]

True

绝大多数操作并不修改tensor的数据,而只是修改了tensor的头信息。这种做法更节省内存,同时提升了处理速度。在使用中需要注意。 此外有些操作会导致tensor不连续,这时需调用tensor.contiguous方法将它们变成连续的数据,该方法会使数据复制一份,不再与原来的数据共享storage。

思考:高级索引一般不共享stroage,而普通索引共享storage,为什么?

在 PyTorch 中,高级索引(advanced indexing)和普通索引(basic indexing)的行为是不同的,这导致了对存储(storage)共享的处理方式也不同。

普通索引是指使用整数、切片或布尔掩码进行索引,例如 tensor[0]tensor[1:3]tensor[mask]。在这种情况下,返回的索引结果与原来的 Tensor 共享相同的存储空间。这意味着对返回的索引结果进行修改会影响到原来的 Tensor,因为它们实际上指向相同的内存位置。

示例代码:

import torch

x = torch.tensor([1, 2, 3, 4, 5])
y = x[1:3]

y[0] = 10

print(x)  # 输出 tensor([ 1, 10,  3,  4,  5])

在上述代码中,对索引结果 y 进行修改后,原始 Tensor x 也被修改了,这是因为它们共享了相同的存储空间。

而对于高级索引,情况不同。高级索引是指使用整数数组或布尔数组进行索引,例如 tensor[[0, 2]]tensor[mask]。在这种情况下,返回的索引结果与原来的 Tensor 不再共享相同的存储空间。返回的索引结果将会创建一个新的 Tensor,其存储空间是独立的。

示例代码:

import torch

x = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 2])
y = x[indices]

y[0] = 10

print(x)  # 输出 tensor([1, 2, 3, 4, 5])

在上述代码中,对索引结果 y 进行修改后,原始 Tensor x 并没有被修改,因为它们不再共享相同的存储空间。

这种差异是由于普通索引和高级索引的底层机制不同所导致的。普通索引可以通过在存储中使用偏移量和步长来定位对应的元素,因此共享存储;而高级索引需要创建一个新的 Tensor 来存储索引结果,因此不共享存储。

了解这种差异很重要,因为它会影响到对原始 Tensor 和索引结果进行操作时是否会相互影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1164774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EasyExcel动态复杂表头导出方法

目录 需求分析解决方案数据问题数据导入 需求分析 公司数据比较特殊有一部分数据需要动态修改导致信息导入时表头是不确定的,但其中又有一部分表头是固定的,如下图所示,如果表头全部是固定的话可以通过EasyExcel实体类的注解很轻松的解决&am…

Linux CentOS7.9安装OpenJDK17

Linux CentOS7.9安装OpenJDK17 一、OpenJDK下载 清华大学开源软件镜像站 国内的站点,下载速度贼快 二、上传解压 文件上传到服务器后,解压命令: tar -zxvf jdk-xxxx-linux-x64.tar.gz三、配置环境 export JAVA_HOME/home/local/java/j…

ffmpeg mp3截取命令,视频与mp3合成带音频视频命令

从00:00:03.500开始截取往后长度到结尾的mp3音频(这个更有用,测试好用) ffmpeg -i d:/c.mp3 -ss 00:00:03.500 d:/output.mp3 将两个音频合并成一个音频(测试好用) ffmpeg -i "concat:d:/c.mp3|d:/output.mp3&…

Linux C语言进阶-D10指针数组

指针变量构成的数组 理解下面printf中的a和p的表示,其中p[0]、p[1]、p[2]表示存储的a,a1和a2这几个地址,而再加个*,相当于对地址解引用,从而得到数组中的值。 如下图,要想得到a[0][1]的值可以直接打印a[0][1]&#xf…

使用pytorch处理自己的数据集

目录 1 返回本地文件中的数据集 2 根据当前已有的数据集创建每一个样本数据对应的标签 3 tensorboard的使用 4 transforms处理数据 tranfroms.Totensor的使用 transforms.Normalize的使用 transforms.Resize的使用 transforms.Compose使用 5 dataset_transforms使用 1 返回本地…

YOLOv5:修改backbone为SPPCSPC

YOLOv5:修改backbone为SPPCSPC 前言前提条件相关介绍SPPCSPCYOLOv5修改backbone为SPPCSPC修改common.py修改yolo.py修改yolov5.yaml配置 参考 前言 记录在YOLOv5修改backbone操作,方便自己查阅。由于本人水平有限,难免出现错漏,敬…

一文读懂最小相位滤波器和线性相位滤波器

一文读懂最小相位滤波器和线性相位滤波器 1. 举例说明2. 最小相位定义2.1 最小相位多项式2.2 最大相位滤波器2.3 最小相位意味最快的衰减2.4 最小相位/全通分解 3. 建立最小相位系统 前一篇博客 《一文读懂滤波器的线性相位,全通滤波器,群延迟》 详细解…

curl(二)HTTP协议和头

一 HTTP协议相关 ① 强制发出请求的http1.0 7.29 版本默认是http1.1 ② 查看当前curl版本是否支持http2 方式2: curl --version看Features 补充: 7.33.0 版本才引入 http2,才能使用curl发出http2.0版本的请求 ③ 强制发送http3 说明: 了解即可 二…

DevChat:超越编码的未来 - 优势、安装、使用、以及未来前景

目录 前言1 DevChat的优势1.1 精确的上下文控制1.2 灵活的提示管理1.3 上下文构建1.4 提前准备好的提示模板1.5 高级命令控制 2 安装DevChat插件3 使用DevChat插件3.1 代码生成3.2 文档撰写3.3 解释代码3.4 解决问题3.5 版本控制 4 DevChat的未来前景结语 前言 在软件开发领域…

Docker 多阶段构建的原理及构建过程展示

Docker多阶段构建是一个优秀的技术,可以显著减少 Docker 镜像的大小,从而加快镜像的构建速度,并减少镜像的传输时间和存储空间。本文将详细介绍 Docker 多阶段构建的原理、用途以及示例。 Docker 多阶段构建的原理 在传统的 Docker 镜像构建…

阿里面试:让代码不腐烂,DDD是怎么做的?

说在前面 在40岁老架构师 尼恩的读者交流群(50)中,最近有小伙伴拿到了一线互联网企业如阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试资格,遇到很多很重要的面试题: 谈谈你的高并发落地经验? 谈谈你对DDD的理解&#xf…

CIM与MES

CIM系统,全称计算机集成制造系统(Computer-Integrated Manufacturing),是一种集成了计算机技术、网络通讯技术和软件系统的制造自动化框架。CIM的主要目标是整合制造过程中的所有活动,包括生产管理、设备管理和品质管理…

物流小程序制作教程:从零到有,详细解析

随着互联网的快速发展,物流行业也逐渐实现了数字化转型。为了满足消费者对更加便捷、高效的服务需求,许多物流企业选择制作自己的小程序。本文将通过乔拓云网后台,带你轻松搭建物流小程序,主要分为以下几个部分: 一、进…

设置echarts折线图虚线

itemStyle:{normal: {lineStyle: { type: solid}}}itemStyle:{normal: {lineStyle: { type: dashed}}}放到每个红框里面

梯度消失和梯度爆炸的原因

梯度消失和梯度爆炸 梯度爆炸和梯度消失本质上是因为梯度反向传播中的连乘效应。 梯度下降算法 举一个简单的例子,函数表达式为loss 2w^2 4w,如下图 ​​​​​​​ ​​​​​​​ 为了求得w的最优值,使得loss最小,从上图很容易看出来当w -1时,loss最小…

ER图设计神器,帮你省时省力,高效完成工作!

ER图(Entity-Relationship Diagram)工具用于设计数据库模型,通常用于表示数据实体、关系和属性之间的关系。以下是10个好用的ER图工具。 一、Lucidchart Lucidchart 是一款基于云的协作式图表设计工具,它允许用户创建、编辑和共享…

SAP发票及复制控制

一、概述 众所周知,SAP为业财一体化的ERP管理系统,因此财务发票必不可少,很多外贸企业还会用到形式发票。 发票相关的配置主要包含:发票类型、复制控制、发票定价以及自动过账的配置。 二、系统配置 1. 发票类型 1.1 概述 发…

Latex安装使用教程

在论文投稿时有些期刊要求使用Latex格式,比如博主现在就遇到了这个问题,木有办法,老老实实的学呗。大家可以去官网下载,但官网的界面设计属实有些一言难尽,因此我们可以使用国内的镜像。 LaTeX 基于 TeX,主…

2023年双十一第2波红包活动淘宝天猫京东双11红包领取优惠券跨店满多少减多少规则?

本文为大家提供众多福利: 2023年淘宝/天猫双十一红包第2波活动时间与领取入口,最高23888超级红包 2023年京东双十一红包第2波活动时间与领取入口,最高11111京享红包 草柴APP领淘宝/天猫、京东大额内部隐藏优惠券,拿购物返利 美…

Content-Type 值有哪些?

1、application/x-www-form-urlencoded 最常见 POST 提交数据的方式。 浏览器的原生 form 表单&#xff0c;如果不设置 enctype 属性&#xff0c;那么最终就会以 application/x-www-form-urlencoded 方式提交数据。 <form action"http://www.haha/ads/sds?name小草莓…