目录
1).max(1)的使用:
2).max(0)的使用:
1).max(1)的使用:
假设有一个形状为 ( m , n ) 的 Tensor x ,其中m表示行数,n表示列数。
x.max(1) ,相当于x.max(dim=1) 。作用:沿着第 1 维(即列数)的方向,返回 x 的每行最大值和对应的列索引。
import torch
# 构造一个 3x3 的 Tensor x
x = torch.tensor([[0.2, 0.9, 0.1], [0.8, 0.4, 0.3], [0.2, 0.7, 0.6]])
###### .max(1)的用法
values1, indices1 = x.max(1)
print(values1) # 每一行的最大值
print(indices1) # 每一行最大值所对应的列索引
####### 以下是输出结果:
# tensor([0.9000, 0.8000, 0.7000])
# tensor([1, 0, 1])
2).max(0)的使用:
假设有一个形状为 ( m , n ) 的 Tensor x ,其中m表示行数,n表示列数。
x.max(0) ,相当于x.max(dim=0) 。作用:沿着第 0 维(即行数)的方向,返回 x 的每列最大值和对应的行索引。
import torch
# 构造一个 3x3 的Tensor x
x = torch.tensor([[0.2, 0.9, 0.1], [0.8, 0.4, 0.3], [0.2, 0.7, 0.6]])
###### .max(0)的用法
values2, indices2 = x.max(0)
print(values2) # 每一列的最大值
print(indices2) # 每一列最大值所对应的行索引
####### 以下是输出结果:
# tensor([0.8000, 0.9000, 0.6000])
# tensor([1, 0, 2])