【深度学习】超详细的 PyTorch 学习笔记(上)

news2024/10/5 18:33:04

文章目录

  • 一、PyTorch环境检查
  • 二、查看张量类型
  • 三、查看张量尺寸和所占内存大小
  • 四、创建张量
    • 4.1 创建值全为1的张量
    • 4.2 创建值全为0的张量
    • 4.3 创建值全为指定值的张量
    • 4.4 通过 list 创建张量
    • 4.5 通过 ndarray 创建张量
    • 4.6 创建指定范围和间距的有序张量
    • 4.7 创建单位矩阵(对角线为1)
  • 五、生成随机张量
    • 5.1 按均匀分布生成
    • 5.2 按标准正态分布生成
    • 5.3 生成指定区间的整型随机张量
    • 5.4 获取随机序列
  • 六、张量的索引与切片
    • 6.1 索引
    • 6.2 切片
      • 6.2.1 获取张量的前/后N个元素
      • 6.2.2 根据指定步长获取张量的前/后N个元素
      • 6.2.3 根据特殊索引获取张量值
      • 6.2.4 根据 mask 选取张量值
      • 6.2.5 根据展平的索引获取张量值
  • 七、张量的维度变换
    • 7.1 view 和 reshape 尺寸变换
    • 7.2 unsqueeze 升维
    • 7.3 squeeze 降维
    • 7.4 expand 扩展
    • 7.5 repeat 复制
    • 7.6 .t() 转置
    • 7.7 transpose 维度变换
    • 7.8 permute 维度变换
  • 八、张量的拼接和拆分
    • 8.1 cat
    • 8.2 stack
    • 8.3 split
    • 8.4 chunk
  • 九、基本运算
    • 9.1 广播机制
    • 9.2 matmul 矩阵/张量乘法
    • 9.3 pow 次方运算
    • 9.4 sqrt 平方根运算
    • 9.5 exp 指数幂运算
    • 9.6 log 对数运算(相当于 ln)
    • 9.7 取整
    • 9.8 clamp 控制张量的取值范围
  • 十、统计属性
    • 10.1 norm 求范数
    • 10.2 mean、median、sum、min、max、prod、argmax、argmin
    • 10.3 topk 获取最大的k个值
    • 10.4 kthvalue 获取第k大的值
    • 10.5 比较运算函数
  • 十一、高级操作
    • 11.1 where
    • 11.2 gather


一、PyTorch环境检查

import torch
# 输出PyTorch版本
print(torch.__version__)
# 检查PyTorch是否支持GPU加速
print("cuda:", torch.cuda.is_available())

输出:

1.8.0+cu101
cuda: True

二、查看张量类型

import torch

a = torch.randn(2, 3)
b = torch.randint(0, 1, (2, 3))
print(a.type())
print(b.type())
print(type(a))
print(type(b))
print(isinstance(a, torch.FloatTensor))
print(isinstance(b, torch.FloatTensor))

输出:

torch.FloatTensor
torch.LongTensor
<class 'torch.Tensor'>
<class 'torch.Tensor'>
True
False

三、查看张量尺寸和所占内存大小

import torch

a = torch.randn(2, 3)
print(a.size(), type(a.size()))
print(a.shape, type(a.shape))
print("维度数:", a.dim())
print("所占内存大小:", a.numel())

输出:

torch.Size([2, 3]) <class 'torch.Size'>
torch.Size([2, 3]) <class 'torch.Size'>
维度数: 2
所占内存大小: 6

四、创建张量

4.1 创建值全为1的张量

import torch

a = torch.ones(2, 3)
print(a)

输出:

tensor([[1., 1., 1.],
        [1., 1., 1.]])

4.2 创建值全为0的张量

import torch

a = torch.zeros(2, 3)
print(a)

输出:

tensor([[0., 0., 0.],
        [0., 0., 0.]])

4.3 创建值全为指定值的张量

import torch

a = torch.full([2, 3], 6.6)
print(a)
print(a.shape)

a = torch.full([], 6.6)
print(a)
print(a.shape)

输出:

tensor([[6.6000, 6.6000, 6.6000],
        [6.6000, 6.6000, 6.6000]])
torch.Size([2, 3])
tensor(6.6000)
torch.Size([])

4.4 通过 list 创建张量

import torch

print(torch.LongTensor([[1, 2], [3, 4]]))
print(torch.Tensor([[1, 2], [3, 4]]))
print(torch.FloatTensor([[1, 2], [3, 4]]))

输出:

tensor([[1, 2],
        [3, 4]])
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])

4.5 通过 ndarray 创建张量

import torch
import numpy as np

a = np.array([2, 3.3])
print(type(a))

print(torch.from_numpy(a))

输出:

<class 'numpy.ndarray'>
tensor([2.0000, 3.3000], dtype=torch.float64)

4.6 创建指定范围和间距的有序张量

import torch

print("torch.arange(0,10):", torch.arange(0, 10))
print("torch.arange(0,10,2):", torch.arange(0, 10, 2))
print("torch.linspace(0,10,steps = 4):", torch.linspace(0, 10, steps=4))
print("torch.linspace(0,10,steps = 10):", torch.linspace(0, 10, steps=10))
print("torch.linspace(0,10,steps = 11):", torch.linspace(0, 10, steps=11))
print("torch.logspace(0,-1,steps = 10):", torch.logspace(0, -1, steps=10))
print("torch.logspace(0,1,steps = 10):", torch.logspace(0, 1, steps=10))

输出:

torch.arange(0,10): tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
torch.arange(0,10,2): tensor([0, 2, 4, 6, 8])
torch.linspace(0,10,steps = 4): tensor([ 0.0000,  3.3333,  6.6667, 10.0000])
torch.linspace(0,10,steps = 10): tensor([ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,  7.7778,
         8.8889, 10.0000])
torch.linspace(0,10,steps = 11): tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
torch.logspace(0,-1,steps = 10): tensor([1.0000, 0.7743, 0.5995, 0.4642, 0.3594, 0.2783, 0.2154, 0.1668, 0.1292,
        0.1000])
torch.logspace(0,1,steps = 10): tensor([ 1.0000,  1.2915,  1.6681,  2.1544,  2.7826,  3.5938,  4.6416,  5.9948,
         7.7426, 10.0000])

4.7 创建单位矩阵(对角线为1)

import torch

# n * n
print(torch.eye(3))
print(torch.eye(4, 4))

# 非 n * n
print(torch.eye(2, 3))

输出:

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
tensor([[1., 0., 0.],
        [0., 1., 0.]])

五、生成随机张量

5.1 按均匀分布生成

均匀分布:0-1之间

import torch

# 生成shape为(2,3,2)的Tensor
random_tensor = torch.rand(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

输出:

tensor([[[0.3321, 0.7077],
         [0.8372, 0.2545],
         [0.5849, 0.0312]],

        [[0.6792, 0.8339],
         [0.9689, 0.5579],
         [0.2843, 0.6578]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])

5.2 按标准正态分布生成

标准正态分布:均值为0,方差为1

import torch

# 生成shape为(2,3,2)的Tensor
random_tensor = torch.randn(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

输出:

tensor([[[ 1.6622,  1.4002],
         [ 1.5145, -0.0427],
         [ 0.4082, -0.3527]],

        [[ 1.2381, -0.2409],
         [-1.0770, -1.1289],
         [-1.5798,  0.3093]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])

5.3 生成指定区间的整型随机张量

import torch

# 生成shape为(2,3,2)的Tensor
# 整数范围[1,4)
random_tensor = torch.randint(1, 4, (2, 3, 2))
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

输出:

tensor([[[3, 2],
         [2, 2],
         [3, 3]],

        [[2, 1],
         [2, 1],
         [1, 1]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])

5.4 获取随机序列

import torch

# torch中没有random.shuffle
# y = torch.randperm(n) y是把0到n-1这些数随机打乱得到的一个数字序列
# randperm(n, out=None, dtype=torch.int64)-> LongTensor
idx = torch.randperm(3)
a = torch.Tensor(4, 2)
print(a)
print(idx, idx.type())
print(a[idx])

输出:

tensor([[0.0000e+00, 5.5491e-43],
        [1.8754e+28, 8.0439e+20],
        [4.2767e-05, 1.0413e-11],
        [4.2002e-08, 6.5558e-10]])
tensor([0, 1, 2]) torch.LongTensor
tensor([[0.0000e+00, 5.5491e-43],
        [1.8754e+28, 8.0439e+20],
        [4.2767e-05, 1.0413e-11]])

六、张量的索引与切片

6.1 索引

import torch

a = torch.rand(4, 3, 28, 28)
print("a[0].shape:", a[0].shape)
print("a[0,0].shape:", a[0, 0].shape)
print("a[0,0,2,4]:", a[0, 0, 2, 4])

输出:

a[0].shape: torch.Size([3, 28, 28])
a[0,0].shape: torch.Size([28, 28])
a[0,0,2,4]: tensor(0.2935)

6.2 切片

6.2.1 获取张量的前/后N个元素

import torch

a = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)

print("a[:2].shape:", a[:2].shape)
print("a[:2,:1,:,:].shape:", a[:2, :1, :, :].shape)
print("a[:2,1:,:,:].shape:", a[:2, 1:, :, :].shape)
print("a[:2,-1:,:,:].shape:", a[:2, -1:, :, :].shape)

输出:

a.shape: torch.Size([4, 3, 28, 28])
a[:2].shape: torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape: torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape: torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape: torch.Size([2, 1, 28, 28])

6.2.2 根据指定步长获取张量的前/后N个元素

import torch

a = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)

print("a[:,:,0:28:2,0:28:2].shape:", a[:, :, 0:28:2, 0:28:2].shape)
print("a[:,:,::2,::2].shape:", a[:, :, ::2, ::2].shape)

输出:

a.shape: torch.Size([4, 3, 28, 28])
a[:,:,0:28:2,0:28:2].shape: torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape: torch.Size([4, 3, 14, 14])

6.2.3 根据特殊索引获取张量值

import torch

a = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)

print("a.index_select(0,torch.tensor([0,2])).shape:", a.index_select(0, torch.tensor([0, 2])).shape)
print("a.index_select(1,torch.tensor([1,2])).shape:", a.index_select(1, torch.tensor([1, 2])).shape)
print("a.index_select(2,torch.arange(28)).shape:", a.index_select(2, torch.arange(28)).shape)
print("a.index_select(2,torch.arange(8)).shape:", a.index_select(2, torch.arange(8)).shape)

print("a[...].shape:", a[...].shape)
print("a[0,...].shape:", a[0, ...].shape)
print("a[:,1,...].shape:", a[:, 1, ...].shape)
print("a[...,:2].shape:", a[..., :2].shape)

输出:

a.shape: torch.Size([4, 3, 28, 28])
a.index_select(0,torch.tensor([0,2])).shape: torch.Size([2, 3, 28, 28])
a.index_select(1,torch.tensor([1,2])).shape: torch.Size([4, 2, 28, 28])
a.index_select(2,torch.arange(28)).shape: torch.Size([4, 3, 28, 28])
a.index_select(2,torch.arange(8)).shape: torch.Size([4, 3, 8, 28])
a[...].shape: torch.Size([4, 3, 28, 28])
a[0,...].shape: torch.Size([3, 28, 28])
a[:,1,...].shape: torch.Size([4, 28, 28])
a[...,:2].shape: torch.Size([4, 3, 28, 2])

6.2.4 根据 mask 选取张量值

import torch

a = torch.rand(3,4)
print(a)

mask = a.ge(0.5)
print(mask)

b = torch.masked_select(a, mask)
print(b)
print(b.shape)

输出:

tensor([[0.6119, 0.3231, 0.8763, 0.6680],
        [0.5421, 0.0359, 0.2040, 0.2894],
        [0.5961, 0.7953, 0.2759, 0.7808]])
tensor([[ True, False,  True,  True],
        [ True, False, False, False],
        [ True,  True, False,  True]])
tensor([0.6119, 0.8763, 0.6680, 0.5421, 0.5961, 0.7953, 0.7808])
torch.Size([7])

6.2.5 根据展平的索引获取张量值

import torch

a = torch.Tensor([[4, 3, 5], [6, 7, 8]])
print(a)
print(torch.take(a, torch.tensor([0, 2, -1])))

输出:

tensor([[4., 3., 5.],
        [6., 7., 8.]])
tensor([4., 5., 8.])

七、张量的维度变换

7.1 view 和 reshape 尺寸变换

view 和 reshape 的用法一致

import torch

a = torch.rand(4, 1, 28, 28)
print(a.shape)

print(a.view(4, 28 * 28).shape)
print(a.view(4 * 28, 28).shape)
print(a.view(4, 28, 28, 1).shape)

输出:

torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([112, 28])
torch.Size([4, 28, 28, 1])

7.2 unsqueeze 升维

import torch

a = torch.rand(4, 1, 28, 28)
print("a.shape:", a.shape)

print("a.unsqueeze(0).shape:", a.unsqueeze(0).shape)
print("a.unsqueeze(-1).shape:", a.unsqueeze(-1).shape)

b = torch.rand(32)
print("b.shape:", b.shape)
print("b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape:", b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)

输出:

a.shape: torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape: torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape: torch.Size([4, 1, 28, 28, 1])
b.shape: torch.Size([32])
b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape: torch.Size([1, 32, 1, 1])

7.3 squeeze 降维

import torch

b = torch.rand(4, 1, 28, 28)
print("b.shape:", b.shape)

print("b.squeeze().shape:", b.squeeze().shape)
print("b.squeeze(0).shape:", b.squeeze(0).shape)
print("b.squeeze(-1).shape:", b.squeeze(-1).shape)

输出:

b.shape: torch.Size([4, 1, 28, 28])
b.squeeze().shape: torch.Size([4, 28, 28])
b.squeeze(0).shape: torch.Size([4, 1, 28, 28])
b.squeeze(-1).shape: torch.Size([4, 1, 28, 28])

7.4 expand 扩展

import torch

b = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)

print("b.expand(4,32,14,14).shape:", b.expand(4, 32, 14, 14).shape)
print("b.expand(-1,32,-1,-1).shape:", b.expand(-1, 32, -1, -1).shape)
print("b.expand(-1,32,-1,4).shape:", b.expand(-1, 32, -1, 4).shape)

输出:

b.shape: torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape: torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape: torch.Size([1, 32, 1, 1])
b.expand(-1,32,-1,4).shape: torch.Size([1, 32, 1, 4])

7.5 repeat 复制

import torch

b = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)

print("b.repeat(4,32,1,1).shape:", b.repeat(4, 32, 1, 1).shape)
print("b.repeat(4,1,1,1).shape:", b.repeat(4, 1, 1, 1).shape)
print("b.repeat(4,1,32,32).shape:", b.repeat(4, 1, 32, 32).shape)

输出:

b.shape: torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape: torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape: torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32).shape: torch.Size([4, 32, 32, 32])

7.6 .t() 转置

import torch

b = torch.rand(3, 4)
print(b)
print(b.t())

输出:

tensor([[0.3598, 0.3820, 0.9488, 0.2987],
        [0.7339, 0.2339, 0.5251, 0.2017],
        [0.8442, 0.6528, 0.2914, 0.5034]])
tensor([[0.3598, 0.7339, 0.8442],
        [0.3820, 0.2339, 0.6528],
        [0.9488, 0.5251, 0.2914],
        [0.2987, 0.2017, 0.5034]])

7.7 transpose 维度变换

import torch

a = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.transpose(1, 3).shape)

输出:

torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])

7.8 permute 维度变换

import torch

a = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.permute(0,2,3,1).shape)

输出:

torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])

八、张量的拼接和拆分

8.1 cat

import torch

a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.cat([a, b], dim=0)
print(c.shape) # torch.Size([9, 32, 8])

8.2 stack

import torch

a1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
c = torch.stack([a1,a2],dim = 2)
print(c.shape) # torch.Size([4, 3, 2, 16, 32])

8.3 split

import torch

a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 32, 8])

aa, bb = c.split([1, 1], dim=0)
print(aa.shape, bb.shape)  # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])

aa, bb = c.split([20, 12], dim=1)
print(aa.shape, bb.shape)  # torch.Size([2, 20, 8]) torch.Size([2, 12, 8])

8.4 chunk

import torch

a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 32, 8])

aa, bb = c.chunk(2, dim=0)
print(aa.shape, bb.shape)  # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])

aa, bb = c.chunk(2, dim=1)
print(aa.shape, bb.shape)  # torch.Size([2, 16, 8]) torch.Size([2, 16, 8])

aa, bb, cc, dd = c.chunk(4, dim=1)
print(aa.shape, bb.shape, cc.shape,
      dd.shape)  #torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8])

九、基本运算

9.1 广播机制

import torch

a = torch.rand(2,2)
print(a)
b = torch.rand(2)
print(b)
print(a+b)

输出:

tensor([[0.4668, 0.6053],
        [0.5321, 0.8734]])
tensor([0.7595, 0.6517])
tensor([[1.2263, 1.2570],
        [1.2916, 1.5251]])

9.2 matmul 矩阵/张量乘法

import torch

a = torch.ones(2, 2) * 3
b = torch.ones(2, 2)
print(a)
print(b)
print(torch.matmul(a, b))

输出:

tensor([[3., 3.],
        [3., 3.]])
tensor([[1., 1.],
        [1., 1.]])
tensor([[6., 6.],
        [6., 6.]])

9.3 pow 次方运算

import torch

a = torch.ones(2, 2) * 3
print(a)
print(torch.pow(a, 3))

输出:

tensor([[3., 3.],
        [3., 3.]])
tensor([[27., 27.],
        [27., 27.]])

9.4 sqrt 平方根运算

import torch

a = torch.ones(2, 2) * 9
print(a)
print(torch.pow(a, 0.5))
print(torch.sqrt(a))

输出:

tensor([[9., 9.],
        [9., 9.]])
tensor([[3., 3.],
        [3., 3.]])
tensor([[3., 3.],
        [3., 3.]])

9.5 exp 指数幂运算

import torch

a = torch.ones(2, 2)
print(a)
print(torch.exp(a))

输出:

tensor([[1., 1.],
        [1., 1.]])
tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])

9.6 log 对数运算(相当于 ln)

import torch

a = torch.ones(2, 2) * 3
print(a)
print(torch.log(a))

输出:

tensor([[3., 3.],
        [3., 3.]])
tensor([[1.0986, 1.0986],
        [1.0986, 1.0986]])

9.7 取整

  • floor():向下取整
  • ceil():向上取整
  • round():四舍五入
  • trunc():截取整数部分
  • frac():截取小数部分
import torch

a = torch.tensor(3.14)
print(a)  # tensor(3.1400)
print(torch.floor(a))  #tensor(3.)
print(torch.ceil(a))  #tensor(4.)
print(torch.round(a))  #tensor(3.)
print(torch.trunc(a))  #tensor(3.)
print(torch.frac(a))  #tensor(0.1400)

9.8 clamp 控制张量的取值范围

import torch

a = torch.rand(2, 3) * 15

print(a)
# 将大于8的值设置为8;小于4的值设置为4
print(torch.clamp(a, 4, 8))

输出:

tensor([[ 8.8872,  5.6534, 14.3027],
        [ 0.8305, 12.6266, 13.9683]])
tensor([[8.0000, 5.6534, 8.0000],
        [4.0000, 8.0000, 8.0000]])

十、统计属性

10.1 norm 求范数

import torch

a = torch.ones(2, 3)
b = torch.norm(a)  # 默认求2范数
c = torch.norm(a, p=1)  # 指定求1范数
print(a)
print(b)
print(c)

输出:

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor(2.4495)
tensor(6.)

10.2 mean、median、sum、min、max、prod、argmax、argmin

  • prod():返回张量里所有元素的乘积
  • armax():返回张量中最大元素的展平索引
  • argmin():返回张量中最小元素的展平索引
import torch

a = torch.arange(8).view(2, 4).float()
print(a)
'''
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]])
'''
print(a.mean())  #tensor(3.5000)
print(a.median())  #tensor(3.)
print(a.sum())  #tensor(28.)
print(a.min())  #tensor(0.)
print(a.max())  #tensor(7.)
print(a.prod())  #tensor(0.)
print(a.argmax())  #tensor(7)
print(a.argmin())  #tensor(0)
import torch

a = torch.rand(2,4)
print(a)

print(a.max(dim = 1))
print(a.max(dim = 1,keepdim = True))

输出:

tensor([[0.7239, 0.9412, 0.7602, 0.2131],
        [0.6277, 0.1033, 0.8300, 0.9909]])
torch.return_types.max(
values=tensor([0.9412, 0.9909]),
indices=tensor([1, 3]))
torch.return_types.max(
values=tensor([[0.9412],
        [0.9909]]),
indices=tensor([[1],
        [3]]))

10.3 topk 获取最大的k个值

import torch

a = torch.rand(2,4)
print(a)
print(a.topk(2,dim=1))
'''
tensor([[0.3247, 0.9220, 0.4314, 0.8123],
        [0.7133, 0.2471, 0.0281, 0.3595]])
torch.return_types.topk(
values=tensor([[0.9220, 0.8123],
        [0.7133, 0.3595]]),
indices=tensor([[1, 3],
        [0, 3]]))
'''

10.4 kthvalue 获取第k大的值

import torch

a = torch.rand(2, 4)
print(a)
print(a.kthvalue(3,dim=1))
'''
tensor([[0.0980, 0.0479, 0.9298, 0.5638],
        [0.9095, 0.9071, 0.4913, 0.6144]])
torch.return_types.kthvalue(
values=tensor([0.5638, 0.9071]),
indices=tensor([3, 1]))
'''

10.5 比较运算函数

import torch

a = torch.rand(2, 3)
print(a)
'''
tensor([[0.1196, 0.5068, 0.9272],
        [0.6395, 0.2433, 0.9702]])
'''
# a >= 0.5
print(a.ge(0.5))
'''
tensor([[False,  True,  True],
        [ True, False,  True]])
'''
# a > 0.5
print(a.gt(0.5))
'''
tensor([[False,  True,  True],
        [ True, False,  True]])
'''
# a <= 0.5
print(a.le(0.5))
'''
tensor([[ True, False, False],
        [False,  True, False]])
'''
# a < 0.5
print(a.lt(0.5))
'''
tensor([[ True, False, False],
        [False,  True, False]])
'''
# a = 0.5
print(a.eq(0.5))
'''
tensor([[False, False, False],
        [False, False, False]])
'''

十一、高级操作

11.1 where

import torch

cond = torch.rand(2, 2)
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
print(cond)
'''
tensor([[0.3622, 0.9658],
        [0.1774, 0.6670]])
'''
print(a)
'''
tensor([[0., 0.],
        [0., 0.]])
'''
print(b)
'''
tensor([[1., 1.],
        [1., 1.]])
'''
# 满足条件cond.ge(0.5)的按照a的对应元素赋值,否则按照b的对应元素赋值
print(torch.where(cond.ge(0.5), a, b))
'''
tensor([[1., 0.],
        [1., 0.]])
'''

11.2 gather

帮助我们从批量tensor中取出指定乱序索引下的数据

import torch

a = torch.arange(3, 12).view(3, 3)
print(a)
'''
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
'''

index = torch.tensor([[2, 1, 0]])
print(a.gather(1, index)) # tensor([[5, 4, 3]])

index = torch.tensor([[2, 1, 0]]).t()
print(a.gather(1, index))
'''
tensor([[5],
        [7],
        [9]])
'''

index = torch.tensor([[0, 2],
                      [1, 2]])
print(a.gather(1, index))
'''
tensor([[3, 5],
        [7, 8]])
'''

参考链接:图解PyTorch中的torch.gather函数

在强化学习DQN中的使用 gather() 函数

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/78357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【力扣算法简单五十题】23.环形链表

给你一个链表的头节点 head &#xff0c;判断链表中是否有环。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&#xff0c;评测系统内部使用整数 pos 来表示链表尾连接到链表中的位置&#xff08;索…

基于多种优化算法及神经网络的光伏系统控制(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️❤️&#x1f4a5;&#x1f4a5;&#x1f4a5; &#x1f389;作者研究&#xff1a;&#x1f3c5;&#x1f3c5;&#x1f3c5;本科计算机专业&#xff0c;研究生电气学硕…

NNDL 实验八 网络优化与正则化(3)不同优化算法比较

文章目录7.3 不同优化算法的比较分析7.3.1 优化算法的实验设定7.3.1.1 2D可视化实验7.3.1.2 简单拟合实验7.3.1.3 与Torch API对比&#xff0c;验证正确性7.3.2 学习率调整7.3.2.1 AdaGrad算法7.3.2.2 RMSprop算法7.3.3 梯度估计修正7.3.3.1 动量法7.3.3.2 Adam算法7.3.4 不同优…

【并发】深度解析CAS原理与底层源码

【并发】深度解析CAS原理与底层源码 什么是 CAS&#xff1f; CAS全称是&#xff08;Compare And Swap&#xff0c;比较并交换&#xff09;&#xff0c;通常指的是这样一种原子操作&#xff08;针对一个变量&#xff0c;首先比较它的内存值与某个期望值是否相同&#xff0c;如…

不就是Redis吗?竟让我一个月拿了8个offer,其中两家都是一线大厂

在高并发的场景Redis是必须的&#xff0c;而 Redis非关系型内存存储不可谓不彪悍。 支持异步持久化达到容灾&#xff1a;速度快、并发高。官方号称支持并发11万读操作&#xff0c;并发8万写操作。惊了吗&#xff1f; 支持数据结构丰富&#xff1a;string&#xff08;字符串&a…

盘点5种最频繁使用的检测异常值的方法(附Python代码)

本文介绍了数据科学家必备的五种检测异常值的方法。 无论是通过识别错误还是主动预防&#xff0c;检测异常值对任何业务都是重要的。本文将讨论五种检测异常值的方法。 文章目录什么是异常值&#xff1f;为什么我们要关注异常值&#xff1f;技术提升方法1——标准差方法2——箱…

【OpenEnergyMonitor】开源的能源监控系统--项目介绍

OpenEnergyMonitor1. 系统框架2.项目组成2.1 emonPi模块:2.1.1 emonpi的安装&#xff1a;2.1.2 emonTx & emonBase 安装2.1.3 emonTx Wifi 安装&#xff1a;2.1.4 添加额外的 emonTx 节点&#xff1a;2.1.5 添加额外的emonTx-节点监控三项电源2.1.6 添加 emonTH 温度节点2.…

【Vue核心】8.计算属性

1. 定义: 要用的属性不存在,要通过已有属性计算得来。 2. 原理 底层借助了objcet.defineproperty方法提供的getter fllsetter. 3. get两数什么时候执行? (1),初次读取时会执行一次。 (2),当依赖的数据发生改变时会被再次调用。 4. 优势 与methods实现相比,内部有缓存机…

进厂手册:Git 学习笔记(详解命令)

文章目录git 对象通过git对象进行文件的保存git对象的缺点树对象构建树对象提交对象高层命令工作区的文件状态git reset hard 咋用以及用错了怎么恢复git checkout vs git resetGit存储后悔药工作区暂存区版本库reset三部曲checkout深入理解tag远程上的相关操作ssh登入一些个人…

[附源码]计算机毕业设计家庭整理服务管理系统Springboot程序

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

LeetCode刷题复盘笔记—一文搞懂动态规划之213. 打家劫舍 II问题(动态规划系列第十八篇)

今日主要总结一下动态规划完全背包的一道题目&#xff0c;213. 打家劫舍 II 题目&#xff1a;213. 打家劫舍 II Leetcode题目地址 题目描述&#xff1a; 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋&#xff0c;每间房内都藏有一定的现金。这个地方所有的房屋都 围成一…

快速排序详解

快速排序&#xff0c;简称快排。其实看快速排序的名字就知道它肯定是一个很牛的排序&#xff0c;C语言中的qsort和C中的sort底层都是快排。 快速排序由于排序效率在同为O(N*logN)的几种排序方法中效率较高&#xff0c;因此经常被采用&#xff0c;再加上快速排序思想----分治法…

Opencv 基本操作五 各种连通域处理方法

在深度学习中&#xff0c;尤其是语义分割模型部署的结果后处理中&#xff0c;离不开各类形态学处理方法&#xff0c;其中以连通域处理为主&#xff1b;同时在一些传统的图像处理算法中&#xff0c;也需要一些形态学、连通域处理方法。为此&#xff0c;整理了一些常用的连通域处…

leetcode每日一题寒假版:1691. 堆叠长方体的最大高度 (hard)( 换了皮的最长递增子序列)

2022-12-10 1691. 堆叠长方体的最大高度 (hard) &#x1f6a9; 学如逆水行舟&#xff0c;不进则退。 —— 《增广贤文》 题目描述&#xff1a; 给你 n 个长方体 cuboids &#xff0c;其中第 i 个长方体的长宽高表示为 cuboids[i] [width(i), length(i), height(i)]&#xf…

Docker补充知识点--自定义网络实现直连容器

前面介绍docker镜像的秘密这篇知识点的时候&#xff0c;https://blog.csdn.net/dudadudadd/article/details/128200522&#xff0c;提到了docker容器也有属于自己的IP的概念&#xff0c;默认的Docker容器是采用的是bridge网络模式。并且提到了一嘴自定义网卡配置&#xff0c;本…

java基于Springboot的健身房课程预约平台-计算机毕业设计

项目介绍 开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven 本健身网站系统是针…

Unity纹理优化:缩小包体

Android打包apk大小约&#xff1a;475M 查看打包日志&#xff1a;Console→Open Editor Log; 或者依赖第三方插件&#xff1a;build reports tool&#xff08;在unity store里可以下载&#xff09;&#xff1b; 定位问题 经过排查后&#xff0c;发现项目中纹理占比很高&#…

分布式能源的不确定性——风速测试(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️❤️&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清…

(6)Pytorch数据处理

Pytorch 数据处理 要点总结 1、功能 Dataset&#xff1a;准备数据集&#xff0c;一般会针对自己的数据集格式重写Dataset&#xff0c;定义数据输入输出格式 Dataloader&#xff1a;用于加载数据&#xff0c;通常不用改这部分内容 2、看代码时请关注 Dataloader中collate_fn 传入…

【云原生】K8s Ingress rewrite与TCP四层转发讲解与实战操作

文章目录一、背景二、K8s Ingress安装三、K8s Ingress rewrite 讲解与使用1&#xff09;配置说明2&#xff09;示例演示1、部署应用2、配置ingress rewrite转发&#xff08;http&#xff09;3、配置ingress rewrite转发&#xff08;https&#xff09;【1】创建证书&#xff08;…