拼接: | 拆分: |
---|---|
Cat、Stack | Split、Chunk |
1、cat(concat)
统计班级学生成绩:
[class1-4, students, scores]
[class5-9, students, scores]
将这九名学生的成绩进行合并
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
torch.cat([a, b], dim=0).shape
# dim=0,在第一个维度上进行合并
a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(5, 3, 32, 32)
torch.cat([a1, a2],dim=0).shape
a3 = torch.rand(4, 1, 32, 32)
torch.cat([a1, a3],dim=1).shape
2、stack
与concat相比,stack不同的是会创造一个新的维度
a1 = torch.rand(4, 3, 16, 32)
a2 = torch.rand(4, 3, 16, 32)
torch.cat([a1,a2],dim=2).shape
torch.stack([a1,a2],dim=2).shape
a = torch.rand(32, 8)
b = torch.rand(32, 8)
torch.stack([a,b],dim=0).shape
3、split
可以根据长度和数量进行拆分
a=torch.rand(32,8)
b=torch.rand(32,8)
c=torch.stack([a,b],dim=0)
①根据长度拆分
aa, bb = c.split(1, dim=0)
②根据数量拆分
c=torch.stack([a,b,b],dim=0)
aa, bb = c.split([2,1],dim=0)
4、chunk
根据数量拆分
c=torch.stack([a,b],dim=0)
aa, bb = c.chunk(2,dim=0)