在torch中,常用view()函数来改变tensor的形状
查询官方文档:
torch.Tensor.view — PyTorch 2.2 documentationhttps://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view示例
1.创建一个4x4的二维数组进行测试
x = torch.randn(4, 4)
print(x)
print(x.size())
(1).将二维数组变换为一维数组
y = x.view(16)
print(y)
print(y.size())
(2).将二维数组变换为其他形式的二维数组
z = x.view(2, 8)
print(z)
print(z.size())
(3).可以将其中一个参数设置为-1,view()会根据已设置的维度自动推断出另外一个维度的大小
# the size -1 is inferred from other dimensions
yy = x.view(-1, 8)
print(yy)
print(yy.size())
zz = x.view(8, -1)
print(zz)
print(zz.size())
可以看到分别得到了2x8的yy和8x2的zz,符合实际的情况。
2.创建一个1x2x3x4的四维矩阵进行测试
x = torch.rand(1, 2, 3, 4)
print(x)
print(x.size())
(1).将四维数组变换为一维数组
y = x.view(-1)
print(y)
print(y.size())
(2).将四维数组变换为二维数组
z = x.view(2,-1)
print(z)
print(z.size())
(3).将四维数组变换为三维数组
a = x.view(2, -1, 4)
print(a)
print(a.size())
(4).将四维数组转换为其他形式的四维数组
b = x.view(1, 3, 2, 4)
print(b)
print(b.size())
值得注意的是view()函数并不改变tensor数据在内存中的层次
利用tranpose函数进行验证,transpose函数可以交换数据指定的维度:
c = x.transpose(1, 2)
print(c)
print(c.size())
transpose(1,2)将第二个维度和第三个维度互换(四维对应的索引是0,1,2,3)
利用equal()函数判断b和c是否相同:
print("b和c是否相等:")
print(torch.equal(b, c))
由如上结果可知,view()函数并不改变数据在内存中的层次。