前言
在进行语义分割时，数据集往往不够用，因此常常需要进行数据增广。
比较常用的数据增广方法包括：旋转、上下翻转、左右翻转、裁剪、调整对比度、调整饱和度、调整亮度、中心裁剪等。其中几何变换（旋转、翻转、裁剪）需要同时作用于图像和对应的标签（mask），而颜色类变换（对比度、亮度、饱和度）只作用于图像。
代码实现
import random
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as tf
class Augmentation:
    """Paired image/mask augmentations for semantic segmentation.

    Every method takes an ``image`` and its ``mask`` and returns the
    transformed ``(image, mask)`` pair. Geometric transforms (rotate, flip,
    crops) are applied identically to both so they stay pixel-aligned;
    photometric transforms (contrast, brightness, saturation) change only the
    image and return the mask untouched.

    NOTE(review): size handling assumes PIL images (``image.size`` is read as
    (width, height)); tensor inputs are not exercised by this file.
    """

    def __init__(self):
        pass

    def rotate(self, image, mask, angle=None):
        """Rotate image and mask by the same angle.

        Args:
            image: input image.
            mask: label mask aligned with ``image``.
            angle: rotation in degrees; when None, a random angle in
                [-180, 180] is drawn.

        Returns:
            The rotated ``(image, mask)`` pair.
        """
        if angle is None:
            angle = transforms.RandomRotation.get_params([-180, 180])
        image = tf.rotate(image, angle)
        # Nearest-neighbour keeps the mask's label values exact.
        mask = tf.rotate(mask, angle, interpolation=transforms.InterpolationMode.NEAREST)
        return image, mask

    def flip(self, image, mask):
        """Randomly flip the pair horizontally and/or vertically.

        Each flip is decided independently (``random.random() > 0.5``) and
        is always applied to image and mask together.
        """
        if random.random() > 0.5:
            image = tf.hflip(image)
            mask = tf.hflip(mask)
        if random.random() > 0.5:
            image = tf.vflip(image)
            mask = tf.vflip(mask)
        return image, mask

    def randomResizeCrop(self, image, mask, scale=(0.3, 1.0), ratio=(1, 1)):
        """Crop a random region and resize it back to the input size.

        Args:
            image: input image.
            mask: label mask aligned with ``image``.
            scale: range of the crop area as a fraction of the input area.
            ratio: range of the crop aspect ratio (width / height). With the
                default ``(1, 1)`` the crop is square, so for a non-square
                input it is stretched when resized back to the original size;
                pass a ratio near width/height to avoid that.

        Returns:
            The cropped and resized ``(image, mask)`` pair, both with the
            same height and width as the input.
        """
        h_image, w_image = np.array(image).shape[:2]
        i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=scale, ratio=ratio)
        # Resize to the full (height, width). A bare int only fixes the
        # shorter edge of the crop, which made non-square inputs come back
        # square instead of at their original size.
        out_size = [h_image, w_image]
        image = tf.resized_crop(image, i, j, h, w, out_size)
        # Nearest-neighbour for the mask: the default bilinear interpolation
        # would blend neighbouring labels into new, invalid values.
        mask = tf.resized_crop(
            mask, i, j, h, w, out_size,
            interpolation=transforms.InterpolationMode.NEAREST,
        )
        return image, mask

    def adjustContrast(self, image, mask):
        """Scale the image contrast by a random factor in [0.5, 1.5].

        The mask is returned unchanged.
        """
        factor = random.uniform(0.5, 1.5)
        image = tf.adjust_contrast(image, factor)
        return image, mask

    def adjustBrightness(self, image, mask):
        """Scale the image brightness by a random factor in [0.5, 1.5].

        The mask is returned unchanged.
        """
        factor = random.uniform(0.5, 1.5)
        image = tf.adjust_brightness(image, factor)
        return image, mask

    def centerCrop(self, image, mask, size=None):
        """Center-crop image and mask to ``size``; the result is not resized.

        Args:
            image: input image (PIL).
            mask: label mask aligned with ``image``.
            size: (height, width) of the crop, or an int for a square crop.
                When None, the full input size is used, so the crop keeps
                the whole image.

        Returns:
            The center-cropped ``(image, mask)`` pair.
        """
        if size is None:
            # PIL reports (width, height) but center_crop expects
            # (height, width); without the swap a non-square image is
            # cropped/padded to its transposed shape.
            width, height = image.size
            size = (height, width)
        image = tf.center_crop(image, size)
        mask = tf.center_crop(mask, size)
        return image, mask

    def adjustSaturation(self, image, mask):
        """Scale the image saturation by a random factor in [0.5, 1.5].

        The mask is returned unchanged.
        """
        factor = random.uniform(0.5, 1.5)
        image = tf.adjust_saturation(image, factor)
        return image, mask
def augmentationData(image_path, mask_path, option=(1, 2, 3, 4, 5, 6, 7), save_dir=None):
    """Augment every image/mask pair in two folders and save the results.

    Args:
        image_path: directory containing the input images.
        mask_path: directory containing the masks. Files are paired with the
            images by sorted filename order, so both folders should hold the
            same number of files whose names sort in the same order.
        option: augmentation codes to apply; each code produces one output
            pair per input pair. 1 rotate, 2 flip, 3 random crop resized back
            to the original size, 4 contrast, 5 center crop (no resize),
            6 brightness, 7 saturation. Unknown codes are skipped.
        save_dir: output root (required). Results are written to
            ``save_dir/img`` and ``save_dir/mask``, which are created if
            missing.

    Output files are named ``{n}_{type}.jpg`` (images) and
    ``{n}_{type}_mask.png`` (masks). ``n`` starts at the number of input
    pairs + 1 and is incremented once per (pair, option) combination.
    """
    aug_image_savedDir = os.path.join(save_dir, 'img')
    aug_mask_savedDir = os.path.join(save_dir, 'mask')
    if not os.path.exists(aug_image_savedDir):
        os.makedirs(aug_image_savedDir)
        print('create aug image dir.....')
    if not os.path.exists(aug_mask_savedDir):
        os.makedirs(aug_mask_savedDir)
        print('create aug mask dir.....')

    aug = Augmentation()
    # Option code -> (augmentation method, tag used in output filenames).
    operations = {
        1: (aug.rotate, 'rotate'),
        2: (aug.flip, 'flip'),
        3: (aug.randomResizeCrop, 'ResizeCrop'),
        4: (aug.adjustContrast, 'Contrast'),
        5: (aug.centerCrop, 'centerCrop'),
        6: (aug.adjustBrightness, 'Brightness'),
        7: (aug.adjustSaturation, 'Saturation'),
    }

    # os.listdir returns entries in arbitrary order; sort both listings so the
    # i-th image is paired with the i-th mask rather than an unrelated one.
    images = [os.path.join(image_path, f) for f in sorted(os.listdir(image_path))]
    masks = [os.path.join(mask_path, f) for f in sorted(os.listdir(mask_path))]
    if len(images) != len(masks):
        print(f'warning: {len(images)} images but {len(masks)} masks; '
              f'only the first {min(len(images), len(masks))} pairs are used')
    datas = list(zip(images, masks))

    num = len(datas)
    for image_file, mask_file in datas:
        # convert() loads the pixels into a new image, so the source files
        # can be closed immediately instead of leaking handles.
        with Image.open(image_file) as src:
            image = src.convert("RGB")
        with Image.open(mask_file) as src:
            mask = src.convert("RGB")
        for opt in option:
            num += 1
            entry = operations.get(opt)
            if entry is None:
                continue
            augment, aug_type = entry
            aug_image, aug_mask = augment(image, mask)
            # The results are already PIL images, so they are saved directly.
            aug_image.save(os.path.join(aug_image_savedDir, f'{num}_{aug_type}.jpg'))
            # Masks are stored losslessly: JPEG compression would introduce
            # spurious pixel values along label boundaries.
            aug_mask.save(os.path.join(aug_mask_savedDir, f'{num}_{aug_type}_mask.png'))
if __name__ == "__main__":
    # Example invocation; only runs when executed as a script, so importing
    # this module does not start writing files. Adjust the paths as needed.
    augmentationData(r'D:\wheat\project\tips\jpg', r'D:\wheat\project\tips\PNG',
                     save_dir=r'D:\wheat\project\tips\finished')