目录
1. torch.cat(tensors, dim)
2. torch.stack(tensors, dim)
3. 两者不同
torch.cat() 和 torch.stack()常用来进行张量的拼接,在神经网络里经常用到。且前段时间有一个面试官也问到了这个知识点,虽然内容很小很细,但需要了解。
1. torch.cat(tensors, dim)
- tensors:待拼接的多个张量,可用list, tuple表示
- dim:待拼接的维度,默认是0
- 注意:tensors里不同张量对应的待拼接维度的size可以不一致,但是其他维度的size要保持一致。如代码中待拼接维度是0,x和y对应的维度0上的值不一样,但是其他维度上的值(维度1上的值)要保持一致,即都为4,否则会报错。
示例:新生成的tensor在dim=0这个维度进行了拼接,即 3 + 2 = 5,剩余维度保持不变
x = torch.rand(3, 4)
y = torch.rand(2, 4)
xy = torch.cat([x, y], dim=0)
print(xy.shape) # torch.Size([5, 4])
2. torch.stack(tensors, dim)
- tensors:待拼接的多个张量,可用list, tuple表示
- dim:待拼接的维度,默认是0
- 注意:tensors里所有张量的维度要保持一致,否则会报错
x = torch.rand(7, 4)
y = torch.rand(7, 4)
z = torch.rand(7, 4)
xy = torch.stack([x, y, z])
print(xy.shape) # torch.Size([3, 7, 4])
3. 两者不同
从上面的代码结果可看出两者区别:
- torch.cat会在dim的维度上进行合并,不会扩展出新的维度。
- torch.stack则会在dim的维度上拓展出一个新的维度,然后进行拼接,该维度的大小为tensors的个数