torch.clip函数
torch.clip(input, min=None, max=None, out=None)
input:输入张量,即要进行裁剪的张量。
min(可选):裁剪的下限。如果未指定,则不进行下限裁剪。
max(可选):裁剪的上限。如果未指定,则不进行上限裁剪。
out(可选):输出张量,如果提供,结果将存储在这个张量中。
使用
# 语义分割
# 假如有15个类别
result = seg_softmax(x).argmax(dim=1)
# 现在只需要前3个类别的结果
top_3_class_result = torch.clip(result, min=0, max=3)