官方文档在这里,说的比较清楚,但是举的例子不是很直观。我们再详细解释一下:
torch.roll(input, shifts, dims=None) → Tensor
- input:输入的tensor
- shifts:滚动的方向和长度,若为正,则向索引大的方向滚动;若为负,则向索引减小的方向滚动。可是一个整数也可以是一个元组
- dims:tensor滚动的维度;要和shifts设置的数量对齐。
这里要特别指出,如果移动的位置已经超出本身维度的大小,就补到反方向去。即不会丢弃数值,也不会凭空补齐数值。会循环滚动。
# 二维tensor举例
x = torch.tensor([[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]])
# 第0维向索引大的方向滚动1个位置,即整体向下移动1个像素,第1行的元素由第4行元素补齐
y = torch.roll(x, shifts=(1), dims=(0))
tensor([[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0]])
# 第0维向索引小的方向滚动1个位置,即整体向上移动1个像素,第4行的元素由第1行元素补齐
x = torch.roll(x, shifts=(-1), dims=(0))
tensor([[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]])
# 第1维向索引大的方向滚动2个位置,即整体向右移动2个像素,第1、2列的元素由第3、4列元素补齐
x = torch.roll(x, shifts=(0, 2), dims=(0, 1))
tensor([[0, 0, 0, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 0, 0, 0]])
# 第0、1维向索引小的方向滚动2像素,即整体向下、右移动2个像素
x = torch.roll(x, shifts=(2, 2), dims=(0, 1))
tensor([[1, 0, 0, 1],
[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 0, 0, 1]])
torch.roll()的使用比较简单,在实际中的应用也比较多,比如在swin-transformer中,利用torch.roll()进行MSA的计算,具体原理就不讲解了: