直观理解torch.gather函数
1. gather的作用
因为深度学习里面,像分类或者分割,有时候去进行loss计算或准确度计算的时候,需要挑选某个维度特定的值,所以有了这么个函数。注意不要高估这个函数的能力,这个函数只是在指定维度上挑出某个指定的值。
2. gather的解释
官方解释: 链接: https://pytorch.org/docs/stable/generated/torch.gather.html
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
第一种理解方式(白话版):是从input矩阵中,获取一个与index矩阵大小相同的output矩阵。规则是,在获取output值时,按照设定的dim和index对应位置的值,从input矩阵中取值出来。
第二种理解方式(程序员版):把index矩阵的元素对应的下标取出来,用这个取出来的元素的值替换下标中指定的维度(dim)的值。
第三种方式(画图):
橙色矩阵是input矩阵,绿色矩阵是index矩阵,也是output矩阵,因为output矩阵和index矩阵是一样大的。红色矩阵是index矩阵和output矩阵中对应的第一个元素,红色这个元素的值是在三个维度中,指定一个维度,然后在这个维度上去取,取index矩阵这个位置对应的那个值为下标的input矩阵中的值。