在AI领域的模型训练中通常会遇到模型过拟合问题,通常采取的办法就是数据增强处理,例如在图像处理中,数据增强是指对原始图像进行旋转、缩放、剪切、翻转等操作,以扩大训练数据集的规模,提高模型泛化能力,降低过拟合风险。
笔者在这里以深度学习框架Pytorch中的数据增强工具(transforms模块)为例介绍数据增强处理。torchvision.transforms是PyTorch中用于图像处理和数据增强的模块。它提供了许多函数,可以在图像上应用各种转换,例如裁剪、旋转、翻转、缩放、归一化等操作,从而生成更多变化的图像数据。
transforms模块中的函数可以分为两类:一类是针对PIL图像对象的操作函数,例如:Resize、RandomCrop、RandomHorizontalFlip等;另一类是对Tensor对象的操作函数,例如:Normalize、ToTensor等。下面笔者将分别介绍这些函数的原理和代码示例。
针对PIL图像对象的操作函数
Resize
Resize函数可以将图像缩放到指定大小。常用的缩放方法等比例缩放和非等比例缩放两种。其中如果预先目标大小,Resize函数会按照目标大小等比例缩放图像,如果预先仅指定了图片宽度或高度,则会进行非等比例缩放。
from torchvision import transforms
from PIL import Image
# 等比例缩放
transform = transforms.Compose([transforms.Resize((224, 224))])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)
# # 非等比例缩放
transform = transforms.Compose([transforms.Resize((224, 300))])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)
-
等比例缩放
-
非等比例缩放
CenterCrop
CenterCrop函数可以从图像中心裁剪指定大小的区域。
from torchvision import transforms
from PIL import Image
transform = transforms.Compose([transforms.CenterCrop(224)])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)
RandomCrop
RandomCrop函数可以随机裁剪指定大小的区域。
from torchvision import transforms
from PIL import Image
transform = transforms.Compose([transforms.RandomCrop(224)])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)
RandomHorizontalFlip
RandomHorizontalFlip函数可以随机水平翻转图像。
from torchvision import transforms
from PIL import Image
transform = transforms.Compose([transforms.RandomHorizontalFlip()])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)
针对Tensor对象的操作函数
ToTensor
ToTensor函数可以将PIL图像对象转换为Tensor对象。
from torchvision import transforms
from PIL import Image
transform = transforms.Compose([transforms.ToTensor()])
img = Image.open('1.png')
img = transform(img)
print(img.size()) #torch.Size([3, 800, 1000])
Normalize
Normalize函数可以对Tensor对象进行归一化,以减少模型训练的时间,提高模型性能和稳定性。
from torchvision import transforms
from PIL import Image
import torch
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = Image.open('1.png')
img = transform(img)
print(torch.min(img), torch.max(img)) # tensor(-1.) tensor(1.)
上述代码将图像像素值归一化到[-1, 1]之间。
Compose函数
Compose函数则用于将多个transforms函数组合在一起,形成一个transforms的列表。在数据加载时,会按照列表中的顺序,依次对图像进行变换。
from torchvision import transforms
transform = transforms.Compose([
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = Image.open('1.png')
img = transform(img)
print(torch.min(img), torch.max(img))
上述代码中,我们通过Compose函数将CenterCrop、RandomHorizontalFlip、ToTensor和Normalize四个函数组合在一起,形成了一个transform对象。