学习目标:
在学习PyG时,遇到了 scatter 这个函数,经过学习加上自身的理解,记录如下以备复习
学习内容:
- src:表示输入的tensor,接下来被处理;
- index:表示tensor对应的索引;
- dim:该值取0或者1(-1),默认是1;当
dim=0
时,表示从行
进行分割成元素;当dim=1
时,表示从列
进行分割成元素。 - reduce:表示对应的操作
具体操作如下:
例子1
from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)
1.首先是dim=0
表示对输入的tensor进行行
分割:[1,2,3],[4,5,6],[7,8,9]。
2.索引index=[0,0,1]表示处理的顺序:第一行元素和第二行元素进行处理,再是第三行的元素进行进行。对第一行元素[1,2,3]和第二行元素[4,5,6]进行reduce='mean'
得到[2.5,3.5,4.5],对第三行元素[7,8,9]进行reduce='mean'得到[7,8,9]
.
例子2
from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=1, reduce='mean')
print(out)
1.首先dim=1
表示对输入的tensor进行列向
分割元素[1,4,7]、[2,5,8]和[3,6,9]。
2.索引index=[0,0,1]
表示将[1,4,7]和[2,5,8]首先进行reduce='mean'
操作得到[1.5,4.5,7.5];[3,6,9]进行reduce=mean
操作后仍为[3,6,9],接着将其进行列向
拼接。
例子3–维度问题
from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([1, 1, 0,2], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)
1.dim=0
表示从行向
进行分割
[1,2,3]
[4,5,6]
[7,8,9]
[10,11,12]
2.索引index=[1,1,0,2]
,从索引可以看出顺序为[7,8,9]——[1,2,3]和[4,5,6]——[10,11,12],分别进行reduce='mean'
操作得到[7,8,9]——[2.5,3.5,4.5]——[10,11,12]三个tensor,然后进行行向
拼接。