文章目录
- 前言
- 一、torch.squeeze()函数
- 二、torch.unsqueeze()函数
前言
这两个函数在pytorch框架下的深度学习经常用到,这次把它们记录一下。
一、torch.squeeze()函数
torch.squeeze()用来“挤”掉某一个维度为1的维度,或者所有维度为1的维度。(只挤掉维度为1的维度)
例子如下:
import torch
A=torch.rand(1,3,224,224)
B=torch.unsqueeze(A,dim=0)
print(B.shape)
结果:
一般来说,这个函数多用于最后网络输出图片的可视化。
如果对维度不为1的维度进行去除:
import torch
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=1)
print(B.shape)
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=2)
print(B.shape)
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=3)
print(B.shape)
不会发生变化
二、torch.unsqueeze()函数
torch.unsqueeze()函数用来插入新的维度扩充张量。例子如下:
在第0维度增加一个维度大小为1的维度(也就是在最前面加一个1)
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=0)
print(B.shape)
结果为:(这个一般用的最多,比如输入的VGG的照片是1,3,224,224.一般的三通道照片是3,224,224,这时就需要用unsqueeze函数)
在第1,2,3维度增加一个维度大小为1的维度,只需要把dim改改就行
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=1)
print(B.shape)
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=2)
print(B.shape)
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=3)
print(B.shape)