【学习笔记】【Pytorch】三、常用的Transforms
- 学习地址
- 主要内容
- 一、Transforms模块介绍
- 二、transforms.ToTensor类的使用
- 1.使用说明
- 2.代码实现
- 三、transforms.Normalize类的使用
- 1.使用说明
- 2.代码实现
- 四、transforms.Resize类的使用
- 1.使用说明
- 2.代码实现
- 五、transforms.Compose类的使用
- 1.使用说明
- 2.代码实现
- 六、transforms.RandomCrop类的使用
- 1.使用说明
- 2.代码实现
- 总结
- 参考
学习地址
PyTorch深度学习快速入门教程【小土堆】.
主要内容
一、Transforms模块介绍
介绍:PyTorch图像处理与数据增强方法。
二、transforms.ToTensor类的使用
作用:将图片转化成Tensor数据类型。
三、transforms.Normalize类的使用
作用:逐 channel 逐像素地对图像进行标准化。
四、transforms.Resize类的使用
作用:图片尺寸缩放。
五、transforms.Compose类的使用
作用:对多个图像变换API进行打包,顺序使用多个API。
六、transforms.RandomCrop类的使用
作用:在随机位置裁剪给定图像。
一、Transforms模块介绍
from torchvision import transforms
介绍:计算机视觉任务中,对图像的变换(Image Transform)往往是必不可少的操作,例如在迁移学习中,需要对图像尺寸进行变换以使用预训练网络的输入层,又如对数据进行增强以丰富训练数据。pytorch中torchvision.transforms提供的丰富多样的图像变换API
transforms文件夹结构:
二、transforms.ToTensor类的使用
作用:将图片转化成Tensor数据类型。
1.使用说明
【__call__重载】ToTensor_object(img)
- 作用:将图片转化成Tensor向量(Tensor类型数据),变为矩阵进行计算,为了之后的卷积做准备;
- img:PIL Image、numpy.ndarray两种格式。
- 返回值:Tensor格式图片。
- 例子:
tensor_trans = transforms.ToTensor() # 创建一个ToTensor实例
tensor_img = tensor_trans(img) # 返回 tensor 数据类型
2.代码实现
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
# tensor数据类型
# 通过 transforms.ToTensor 去解决两个问题
# 1、transforms该如何使用(python)
# 2、为什么我们需要Tensor数据类型
img_path = "data/train/ants_image/0013035.jpg"
img = Image.open(img_path) # PIL类型的图片数据
writer = SummaryWriter("logs")
# 1、transforms该如何使用(python)
tensor_trans = transforms.ToTensor() # 创建一个ToTensor实例
tensor_img = tensor_trans(img) # 返回 tensor 数据类型
# ToTensor类使用 __call__ 魔法方法重载了 () 运算符,
# 使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用。
# 使用“对象名()”形式可以自动调用 __call__ 方法
# 2、为什么我们需要Tensor数据类型:add_image()可以传入tensor格式的图片等
writer.add_image("Tensor_img", tensor_img) # 输入为tensor格式的图片
writer.close()
TensorBoard输出:
Tensor数据类型(print(tensor_img)):
tensor([[[0.3137, 0.3137, 0.3137, ..., 0.3176, 0.3098, 0.2980],
[0.3176, 0.3176, 0.3176, ..., 0.3176, 0.3098, 0.2980],
[0.3216, 0.3216, 0.3216, ..., 0.3137, 0.3098, 0.3020],
...,
[0.3412, 0.3412, 0.3373, ..., 0.1725, 0.3725, 0.3529],
[0.3412, 0.3412, 0.3373, ..., 0.3294, 0.3529, 0.3294],
[0.3412, 0.3412, 0.3373, ..., 0.3098, 0.3059, 0.3294]],
[[0.5922, 0.5922, 0.5922, ..., 0.5961, 0.5882, 0.5765],
[0.5961, 0.5961, 0.5961, ..., 0.5961, 0.5882, 0.5765],
[0.6000, 0.6000, 0.6000, ..., 0.5922, 0.5882, 0.5804],
...,
[0.6275, 0.6275, 0.6235, ..., 0.3608, 0.6196, 0.6157],
[0.6275, 0.6275, 0.6235, ..., 0.5765, 0.6275, 0.5961],
[0.6275, 0.6275, 0.6235, ..., 0.6275, 0.6235, 0.6314]],
[[0.9137, 0.9137, 0.9137, ..., 0.9176, 0.9098, 0.8980],
[0.9176, 0.9176, 0.9176, ..., 0.9176, 0.9098, 0.8980],
[0.9216, 0.9216, 0.9216, ..., 0.9137, 0.9098, 0.9020],
...,
[0.9294, 0.9294, 0.9255, ..., 0.5529, 0.9216, 0.8941],
[0.9294, 0.9294, 0.9255, ..., 0.8863, 1.0000, 0.9137],
[0.9294, 0.9294, 0.9255, ..., 0.9490, 0.9804, 0.9137]]])
三、transforms.Normalize类的使用
作用:逐 channel 逐像素地对图像进行标准化。
1.使用说明
【实例化】transforms.Normalize(mean, std, inplace=False)
- 作用:创建一个RGB图像每个通道对应的均值mean和标准差std的实例。
- mean:列表,3个通道的均值。
- std:列表,3个通道的标准差。
- 例子:
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
【_call_】Normalize_object(img)
- 作用:逐 channel 逐像素地对图像进行标准化。
- img:仅支持Tensor格式的图片。
- 返回值:标准化后的Tensor格式图片
- 计算公式:
output[channel] = (input[channel] - mean[channel]) / std[channel] - 例子:
img_norm = trans_norm(tensor_img) # 仅支持Tensor格式的图片
2.代码实现
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img_path = "data/train/ants_image/0013035.jpg"
img = Image.open(img_path) # PIL类型的图片数据
writer = SummaryWriter("logs")
# 1.ToTensor
tensor_trans = transforms.ToTensor() # 创建一个ToTensor实例
tensor_img = tensor_trans(img) # 返回 tensor 数据类型
writer.add_image("Tensor_img", tensor_img) # 输入为tensor格式的图片
# 2.Normalize:逐 channel 逐像素地对图像进行标准化(验证一个像素数据的标准化计算)
print(tensor_img[0][0][0])
# RGB图像每个通道对应一个均值和标准差
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(tensor_img) # 仅支持Tensor格式的图片
# 0.3137 * 2 - 1 = -0.3726
print(img_norm[0][0][0])
writer.add_image("Tensor_img_norm", img_norm, 1) # 输入为tensor格式的图片
trans_norm = transforms.Normalize([6, 3, 2], [9, 3, 5])
img_norm = trans_norm(tensor_img) # 仅支持Tensor格式的图片
writer.add_image("Tensor_img_norm", img_norm, 2) # 输入为tensor格式的图片
writer.close()
**控制台输出**:
控制台输出:
(768, 512)
torch.Size([3, 512, 512])
TensorBoard输出:
四、transforms.Resize类的使用
作用:图片尺寸缩放。一般输入深度网络的特征图长宽是相等的,就不能采取等比例缩放的方式了,最好需要同时指定长宽。
1.使用说明
【实例化】transforms.Resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None)
- 作用:创建一个缩放图片尺寸的实例。
- size:列表,(h, w),输出就是为该尺寸;int整数,等比例缩放,将图片短边缩放至size,长宽比保持不变,i.e,如果高度>宽度,则图像将被重新缩放为(size*高度/宽度,size)。
- 例子:
trans_resize = transforms.Resize((512, 512))
trans_resize = transforms.Resize(512)
【_call_】Resize_object(img)
- 作用:图片尺寸缩放。
- img:支持PIL、Tensor格式的图片。
- 返回值:缩放后的原格式图片
- 例子:
img_resize = trans_resize(img) # 支持 PTL、tensor 格式的图片
2.代码实现
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img_path = "data/train/ants_image/0013035.jpg"
img = Image.open(img_path) # PIL类型的图片数据
writer = SummaryWriter("logs")
# ToTensor
tensor_trans = transforms.ToTensor() # 创建一个ToTensor实例
tensor_img = tensor_trans(img) # 返回 tensor 数据类型
writer.add_image("Tensor_img", tensor_img) # 输入为tensor格式的图片
# Resize:变换图片尺寸
print(img.size)
trans_resize = transforms.Resize((512, 512))
# img PIL -> Resize -> img_resize PIL
img_resize = trans_resize(img) # 支持 PTL、tensor 格式的图片
# img_resize PIL -> ToTensor -> img_resize tensor
img_resize = tensor_trans(img_resize)
writer.add_image("Resize", img_resize, 0)
print(img_resize.shape)
writer.close()
控制台输出:
(768, 512)
torch.Size([3, 512, 512])
TensorBoard输出:
五、transforms.Compose类的使用
作用:我们做图像变换时,一般都不会单独使用一个图像变换API,而是顺序使用多个API。对于多个API,transforms模块中提供Compose类,对多个API进行打包。
1.使用说明
【实例化】transforms.Compose(transforms)
- 作用:创建一个由多个图像变换API组成的实例。
- transforms:transforms列表,由多个图像变换API组成的列表。
- 例子:
tensor_trans = transforms.ToTensor() # 创建一个ToTensor实例
trans_resize = transforms.Resize((512, 512)) # 创建一个Resize实例
trans_compose = transforms.Compose([tensor_trans, trans_resize]) # 创建一个Compose实例
【_call_】Compose_object(img)
- 作用:顺序使用多个图像变换API。
- img:支持PIL、Tensor格式的图片。
- 返回值:经过多个操作后的图片结果。
- 例子:
img_compose = trans_compose(img) # Tensor + Resize 的操作结合
2.代码实现
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img_path = "data/train/ants_image/0013035.jpg"
img = Image.open(img_path) # PIL类型的图片数据
print(img.size)
writer = SummaryWriter("logs")
# Compose:顺序使用多个图像变换API,例如先将PIL图片转换为tensor图片,最后变换图片尺寸
tensor_trans = transforms.ToTensor() # 创建一个ToTensor实例
trans_resize = transforms.Resize((512, 512)) # 创建一个Resize实例
trans_compose = transforms.Compose([tensor_trans, trans_resize]) # 创建一个Compose实例
img_compose = trans_compose(img) # Tensor + Resize 的操作结合
writer.add_image("Compose", img_compose, 0)
print(img_compose.shape)
writer.close()
控制台输出:
(768, 512)
torch.Size([3, 512, 512])
TensorBoard输出:
六、transforms.RandomCrop类的使用
作用:在随机位置裁剪给定图像。
1.使用说明
【实例化】transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode=“constant”)
- 作用:创建一个在随机位置裁剪给定图像的实例。
- size:列表, (h, w);int整形,方形裁剪(size,size)。
- 例子:
# 创建一个RandomCrop实例,指定值不要超过图片大小,方形裁剪
tans_random = transforms.RandomCrop(200)
tans_random = transforms.RandomCrop((200, 300)) # 创建一个RandomCrop实例,指定值不要超过图片大小
【torch.nn.Module父类的__call__】RandomCrop_object(img)
- 作用:调用随机位置裁剪给定图像API。
- img:支持PIL、Tensor格式的图片。
- 返回值:经过随机位置裁剪的原格式图片。
- 例子:
img_crop = tans_random(img)
2.代码实现
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img_path = "data/train/ants_image/0013035.jpg"
img = Image.open(img_path) # PIL类型的图片数据
writer = SummaryWriter("logs")
# RandomCorp:在随机位置裁剪给定图像。
trans_tensor = transforms.ToTensor() # 创建一个ToTensor实例
tans_random = transforms.RandomCrop(200) # 创建一个RandomCrop实例,指定值不要超过图片大小
trans_compose = transforms.Compose([tans_random, trans_tensor]) # 创建一个Compose实例
for i in range(10):
img_crop = trans_compose(img) # 输入:PIL图片。RandomCrop + Tensor 的操作结合
writer.add_image("RandomCrop", img_crop, i)
writer.close()
TensorBoard输出:
总结
- 关注输入和输出类型多看官方文档
- 关注方法需要什么参数
- 不知道返回值的时候
print()
print(type())
Debug
参考
1.[PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制
2.transforms模块—PyTorch图像处理与数据增强方法