相关阅读
Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482
在Pytorch中,flatten是Tensor的一个重要方法,同时它也是一个torch模块中的一个函数,它们的语法如下所示。
Tensor.flatten(start_dim=0, end_dim=-1) → Tensor
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
input (Tensor) – the input tensor
start_dim (int) – the first dim to flatten
end_dim (int) – the last dim to flatten
flatten函数(或方法)用于将一个张量以特定方法展平, 如果传递了一个参数,则会将从start_dim到end_dim之间的维度展开。默认情况下,flatten将从第0维展平至最后1维。
可以看几个例子以更好的理解:
import torch
# 创建一个张量
x = torch.rand(3, 3, 3)
# 使用flatten函数,展平x张量
y=x.flatten()
print(x)
tensor([[[0.2581, 0.8408, 0.0216],
[0.6353, 0.9141, 0.4098],
[0.6391, 0.9829, 0.3967]],
[[0.2167, 0.8983, 0.6492],
[0.1947, 0.4953, 0.3281],
[0.1740, 0.2092, 0.2048]],
[[0.3972, 0.6290, 0.3010],
[0.6107, 0.5429, 0.7515],
[0.7950, 0.0538, 0.8963]]])
print(y)
tensor([0.2581, 0.8408, 0.0216, 0.6353, 0.9141, 0.4098, 0.6391, 0.9829, 0.3967,
0.2167, 0.8983, 0.6492, 0.1947, 0.4953, 0.3281, 0.1740, 0.2092, 0.2048,
0.3972, 0.6290, 0.3010, 0.6107, 0.5429, 0.7515, 0.7950, 0.0538, 0.8963])
print(id(x),id(y))
1185516393792 1185516395312 # 说明两个张量对象不同
print(x.storage().data_ptr(), y.storage().data_ptr())
1185641974912 1185641974912 # 说明两个张量对象里面保存的数据存储是共享的
print(id(x[0,0,0]), id(y[0]))
1186163118464 1186163118464 # 进一步说明两个张量对象里面保存的数据存储是共享的