1 介绍
torch.nn.Flatten(start_dim=1, end_dim=-1)
将一个连续的维度范围扁平化为一个张量
start_dim (int) | 要开始扁平化的第一个维度(默认值 = 1) |
end_dim (int) | 要结束扁平化的最后一个维度(默认值 = -1) |
2 举例
input = torch.randn(32, 1, 5, 5)
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])