目录
1. scatter() 定义和参数说明
2. 示例和详细解释
3. scatter() 常见用途
1. scatter() 定义和参数说明
scatter() 或 scatter_() 常用来返回根据index映射关系映射后的新的tensor。其中,scatter() 不会直接修改原来的 Tensor,而 scatter_() 直接在原tensor上修改。
官方文档:torch.Tensor.scatter_ — PyTorch 2.0 documentation
参数定义:
- dim:沿着哪个维度进行索引
- index:索引值
- src:数据源,可以是张量,也可以是标量
简言之 scatter() 是通过 src 来修改另一个张量,修改的元素值和位置由 dim 和 index 决定
2. 示例和详细解释
在官方文档中,给出了3维tensor的具体操作说明,看起来很蒙,没关系继续往下看
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
接下来的示例,我们以2维为例,那上面的公式简化为如下,
self[index[i][j]][j] = src[i][j] # if dim == 0
self[i][index[i][j]] = src[i][j] # if dim == 1
示例:将全零的张量,根据index和scr进行值的变化
src = torch.arange(1, 11).reshape((2, 5))
# src: tensor([[0.8351, 0.2974, 0.9028, 0.4250, 0.0370],
# [0.4564, 0.6832, 0.6854, 0.6056, 0.7118]])
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 4, 2, 3]])
dist = torch.zeros(2, 5, dtype=src.dtype).scatter(1, index, src)
# dist: tensor([[0.0370, 0.2974, 0.9028, 0.0000, 0.0000],
# [0.4564, 0.6832, 0.6056, 0.7118, 0.6854]])
将上述张量使用表格表示:
当 dim = 1时,dist[i] [index[i][j]] = src[i][j],所以具体的计算如下
- 当i=0, j=0时,dist[0][index[0][0]] = src[0][0], 即 dist[0][1] = 0.8351
- 当i=0, j=1时,dist[0][index[0][1]] = src[0][1], 即 dist[0][0] = 0.2974
- 当i=0, j=2时,dist[0][index[0][2]] = src[0][2], 即 dist[0][2] = 0.9028
- 当i=0, j=3时,dist[0][index[0][3]] = src[0][3], 即 dist[0][4] = 0.4250
- 当i=0, j=4时,dist[0][index[0][4]] = src[0][4], 即 dist[0][3] = 0.0370
- 当i=1, j=0时,dist[1][index[1][0]] = src[1][0], 即 dist[1][0] = 0.4564
- 当i=1, j=1时,dist[1][index[1][1]] = src[1][1], 即 dist[1][1] = 0.6832
- 当i=1, j=2时,dist[1][index[1][2]] = src[1][2], 即 dist[1][4] = 0.6854
- 当i=1, j=3时,dist[1][index[1][3]] = src[1][3], 即 dist[1][2] = 0.6056
- 当i=1, j=4时,dist[1][index[1][4]] = src[1][4], 即 dist[1][3] = 0.7118
当 dim = 0时,赋值方式不同,即dist[index[i][j]][j] = src[i][j],计算的过程与上述同理,不再赘述
当 src 为标量时,更为简单。因为全部都用该标量值进行赋值即可
3. scatter() 常见用途
scatter() 常用来对标签进行 one-hot 编码。
注意:label的类型是torch.LongTensor
class_num, batch_size = 10, 4
label = torch.tensor([[6], [0], [3], [2]], dtype=torch.long)
one_hot = torch.zeros(batch_size, class_num).scatter(1, label, 1)
print(one_hot)
# tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
这个程序比较好理解,dim =1, label = index, src = 1 ,所以
- dist[0][label[0][0]] = dist[0][6] = 1
- dist[1][label[1][0]] = dist[1][0] = 1
- dist[2][label[2][0]] = dist[2][3] = 1
- dist[3][label[3][0]] = dist[3][2] = 1