Pytorch提供了许多工具来简化和希望数据加载,使代码更具可读性。这里将专门讲述transforms数据预处理方法,即数据增强。
数据增强又称为数据增广、数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。
# 在进行下面代码学习前需要安装torchvision==0.8.2
!pip install torchvision==0.8.2 --user
from PIL import Image
from torchvision import transforms as T
import torch as t
to_tensor = T.ToTensor()
to_pil = T.ToPILImage()
cat = Image.open('./cat.jpeg')
transforms——Crop
# torchvision.transforms.CenterCrop
transforms = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor()]) # Resize:缩放
cat_t = transforms(cat) # 传入transforms中的数据是PIL数据,lena_t为tensor
cat_t.shape # 3*224*224 ; 当T.CenterCrop()的参数大于T.Resize()的参数时,周围用0填充
to_pil(cat_t)
# torchvision.transforms.RandomCrop
transforms = T.Compose([T.Resize(224),T.RandomCrop(224, padding=(16, 64)),T.ToTensor()]) # Resize:缩放
cat_t = transforms(cat) # 传入transforms中的数据是PIL数据,lena_t为tensor
cat_t.shape # 3*224*224 ; 当T.CenterCrop()的参数大于T.Resize()的参数时,周围用0填充
to_pil(cat_t)
transforms——Flip
# torchvision.transforms.RandomHorizontalFlip
transforms = T.Compose([T.Resize(224),T.RandomHorizontalFlip(p=0.5),T.ToTensor()]) # Resize:缩放
cat_t = transforms(cat) # 传入transforms中的数据是PIL数据,lena_t为tensor
cat_t.shape # 3*224*224 ; 当T.CenterCrop()的参数大于T.Resize()的参数时,周围用0填充
to_pil(cat_t)
# torchvision.transforms.RandomRotation
transforms = T.Compose([T.Resize(224),T.RandomRotation(30, center=(0, 0), expand=True),T.ToTensor()]) # Resize:缩放
cat_t = transforms(cat) # 传入transforms中的数据是PIL数据,lena_t为tensor
cat_t.shape # 3*224*224 ; 当T.CenterCrop()的参数大于T.Resize()的参数时,周围用0填充
to_pil(cat_t)
图像变换
transforms的操作
自定义transforms
自定义transforms要素:
- 仅接收一个参数,返回一个参数
- 注意上下游的输出与输入
class Compose(object):
def __call__(self, img):
for t in transforms:
img = t(img)
return img
通过类实现多参数传入:
class YourTransforms(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点,白点称为盐噪声,黑色为椒噪声。
信噪比(Signal-Noise Rate,SNR)是衡量噪声的比例,图像中为图像像素的占比。
class AddPepperNoise(object):
def __init__(self, snr, p):
self.snr = snr
self.p = p
def __call__(self, img):
# 添加椒盐噪声具体实现过程
img = None
return img