torch.argmax()
返回输入中所有元素的最大值的索引,与torch.max()
中返回(values, indices)
中的indices
类似,它也常被用于深度学习中的分类问题。
在下面程序中,使用torch.argmax()
import torch
a = torch.tensor([[1, 2, 3, 4],
[4, 1, 2, 3],
[6, 2, 3, 4],
[3, 4, 5, 9]])
print(torch.argmax(a))
最大值在tensor(15)
的位置
接下来引入dim属性,dim=0
代表消去维数dim=0
(行),即求每列最大值的索引。
print(torch.argmax(a, dim=0))
而dim=1
代表消去维数dim=1
(列),即求每行最大值的索引。
print(torch.argmax(a, dim=1))
再接下来引入keepdim
属性,默认为False
。
它表示是否保留要消去的维数,用上面的程序来示范keepdim=True
的情况,它保留了要消去的列。
print(torch.argmax(a, dim=1, keepdim=True))
在深度学习中,我们常用argmax来预测分类的标签,例如:
import torch
outputs = torch.tensor([[0.1, 0.2],
[0.3, 0.4]])
preds = outputs.argmax(1)
targets = torch.tensor([0, 1])
print((preds == targets).sum().item())
- 假设上面outputs是深度学习模型预测的概率值分布
- argmax(1)代表预测最大概率所在的标签
- 通过预测标签与真实标签相比,如果相等代表预测正确,否则相反,用来表示模型预测的正确率从而评估模型
下面是某深度学习模型在刚开始训练时所预测的标签与真实标签的差异,随着训练的进行,准确率也会不断上升。
官方文档:
torch.argmax(input,dim,keepdim=False)
主要参数:
- input(Tensor)-输入张量。
- dim(int)-要减少的维度。如果为 None ,则返回展平输入的argmax。
- keepdim(bool)-输出张量是否保留了 dim 。如果 dim=None ,则忽略。