文章目录
- torch.max()
- dim
- keepdim
- dim=0
- dim=1
- out:返回命名元组 (values, indices)
torch.max()
torch.max(input)
→ Tensor:返回 input 张量中所有元素的最大值。
注意输入的必须是张量形式,输出的也为张量形式
当输入为tuple类型时,会报错,需要将输入改为tensor类型,输出也为tensor类型
torch.max():官方文档
torch.max(input,dim,keepdim=False,*,out=None)
主要参数:
- input(Tensor)-输入张量。
- dim(int)-要减少的维度。
- keepdim(bool)-输出张量是否保留了 dim 。默认值: False 。
关键字参数:- out(tuple,optional)-两个输出张量的结果元组(max,max_indices)
dim
对于二维数组来说,dim=0为行,dim=1为列
在torch.max()中代表要减少的维度(dimension)
import torch
a = torch.tensor([1, 2, 3, 4])
max = torch.max(a, dim=0)
print(max)
对于以上程序,由于只存在行,所以torch.max(a, dim=0)
只能减少的维度为行向量,即dim=0
如果 max = max = torch.max(a, dim=1)
,则会报错:维度错误
注:如果在减少的行中存在多个最大值,则返回第一个最大值的索引。
import torch
a = torch.tensor([4, 2, 3, 4])
max = torch.max(a, dim=0)
print(max)
keepdim
输出张量是否保留了 dim,即设置是否保留torch.max(input, dim=0, keepdim=True)
中需要消去的dim。
如果 keepdim 是 True ,则输出张量的大小与 input 相同,除了在维度 dim 中它们的大小为1。
dim=0
二维数组中dim=0代表行,torch.max(a, dim=0)代表消去行,求每列的最大值,keepdim=True则代表保留行
import torch
a = torch.tensor([[1, 2, 3, 4],
[4, 1, 2, 3],
[6, 2, 3, 4],
[3, 4, 5, 9]])
# dim = 0
max1_1 = torch.max(a, dim=0, keepdim=False)
max1_2 = torch.max(a, dim=0, keepdim=True)
print(max1_1)
print(max1_2)
dim=0,消去的维数为行,即求每列的最大值
keepdim=False,vlaues=tensor([6, 4, 5, 9])有一个中括号
keepdim=True,vlaues=tensor([[6, 4, 5, 9]])有两个中括号
indices代表最大值所处的位置(第一列第三个:2,第一列第四个:3,第三列第四个:3,第四列第四个:3)
dim=1
二维数组中dim=1代表列,torch.max(a, dim=0)代表消去列,求每行的最大值,keepdim=True则代表保留列
import torch
a = torch.tensor([[1, 2, 3, 4],
[4, 1, 2, 3],
[6, 2, 3, 4],
[3, 4, 5, 9]])
# dim = 1
max2_1 = torch.max(a, dim=1, keepdim=False)
max2_2 = torch.max(a, dim=1, keepdim=True)
print(max2_1)
print(max2_2)
dim=1,消去的维数为列,即求每行的最大值
keepdim=False,vlaues=tensor([4, 4, 6, 9])有一个中括号
keepdim=True,vlaues=tensor([[4], [4], [6], [9]])有两个中括号
indices代表最大值所处的位置(第一行第四个:3,第二行第一个:0,第三行第一个:0,第四行第四个:3)
out:返回命名元组 (values, indices)
values
是给定维度 dim 中 input 张量的每行的最大值。
indices
是找到的每个最大值(argmax)的索引位置。