MMdetection框架实现数据增强的N种方法
- 1 为什么要进行数据增强
- 2 数据增强的常见误区
- 3 常见的六种数据增强方式
- 3.1 随机翻转(RandomFlip)
- 3.2 随机裁剪(RandomCrop)
- 3.3 随机比例裁剪并缩放(RandomResizedCrop)
- 3.4 色彩抖动(ColorJitter)
- 3.5 随机灰度化(RandomGrayscale)
- 3.6 随机光照变换(Lighting)
1 为什么要进行数据增强
众所周知,即使是目前最先进的神经网络模型,其本质上也是在利用一系列线性和非线性的函数去拟合目标输出。既然是拟合,当然越多的样本就能获得越准确的结果,这也是为什么现在训练神经网络所使用的数据规模越来越大。
然而,在实际使用中,我们往往可能只有几千甚至几百份数据。面对神经网络数以 M 计的参数,很容易陷入过拟合的陷阱。因为神经网络的收敛需要一个较长的训练过程,而这个过程中网络遇到的反反复复都是训练集的那几张图片,硬背都背下来了,自然很难学到什么能够泛化的特征。一个自然的想法是,能不能用一张图片去生成一系列图片,从而成百上千倍地扩充我们的数据集?而这,也正是数据增强的目的之一。
神经网络的“取巧”——神经网络是没有常识的,因此它永远只会用最“方便”的方式区分两个类别。
假设要训练一个区分苹果和橘子的神经网络,但手上的数据只有红苹果和青橘子,那无论拍摄多少张照片,神经网络也只会简单地认为红色的就是苹果,青色的就是橘子。这在实际使用中经常出现,拍摄的灯光、拍摄的角度等等,任何一个不起眼的区分点,都会被神经网络当做分类的依据。
接下来将列举当前分类方向研究中最常用的一系列数据增强手段和效果。
2 数据增强的常见误区
一些人会认为,既然有这么多数据增强的方法,那么我一口气全堆到一起,是不是就能获得最好的增强效果?答案是否定的,数据增强的目标并不是无脑地堆数据,而是尽可能地去覆盖原始数据无法覆盖不到,但现实生活中会出现的情况。
举例说明:现在要训练一个用以区分道路上汽车种类的神经网络,那么图片的垂直翻转很大程度上就不是一个好的数据增强方法,毕竟现实中不太可能遇到汽车四轮朝上的情况。
3 常见的六种数据增强方式
3.1 随机翻转(RandomFlip)
随机翻转是一个非常常用的数据增强方法,包括水平和垂直翻转。其中,水平翻转是最常用的,但根据实际目标的不同,垂直翻转也可以使用。
在 MMClassificiation 中,大部分数据增强方法都可以通过修改 config 中的 pipeline 配置来实现。这里我们提供了一份 python 代码,用来展示如上图所示的数据增强效果:
import mmcvfrom mmcls.datasets
import PIPELINES
# 数据增强配置,利用 Registry 机制创建数据增强对象
aug_cfg = dict(
type='RandomFlip',
flip_prob=0.5, # 按 50% 的概率随机翻转图像
direction='horizontal', # 翻转方向为水平翻转
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
# 为了便于信息在预处理函数之间传递,数据增强项的输入和输出都是字典
img_info = {'img': img}
img_aug = aug(img_info)['img']
mmcv.imshow(img_aug)
3.2 随机裁剪(RandomCrop)
在图片的随机位置,按照指定的大小进行裁剪。这种数据增强的方式能够在保留图像比例的基础上,移动图片上各区域在图片上的位置。
在 MMClassification 中,可使用以下配置:
# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
type='RandomCrop',
size=(384, 384), # 裁剪大小
padding = None, # 边缘填充宽度(None 为不填充)
pad_if_needed=True, # 如果图片过小,是否自动填充边缘
pad_val=(128, 128, 128), # 边缘填充像素
padding_mode='constant', # 边缘填充模式
)
3.3 随机比例裁剪并缩放(RandomResizedCrop)
这一方法目前几乎是 ImageNet 等通用图像数据集在进行分类网络训练时的标准增强手段。相较于 RandomCrop 死板地裁剪下固定尺寸的图片,RandomResizedCrop 会在一定的范围内,在随机位置按照随机比例裁剪图像,之后再缩放至统一的大小。
因此,图像会在比例上存在一定程度的失真。 但这对分类来说不一定是件坏事,毕竟你并不会把一个稍扁一点的猫认成狗,而网络也能够通过这种增强学到更加接近本质的特征。另外,因为是按比例的裁剪,这种增强手段也就对不同分辨率的图片输入更加友好。
在 MMClassification 中,可使用以下配置:
# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
type='RandomResizedCrop',
size=(384, 384), # 目标大小
scale=(0.08, 1.0), # 裁剪图片面积占比限制(不得小于原始面积的 8%)
ratio=(3. / 4., 4. / 3.), # 裁剪图片长宽比例限制,防止过度失真
max_attempts=10, # 当长宽比和面积限制无法同时满足时,最大重试次数
interpolation='bilinear', # 图像缩放算法
backend='cv2', # 缩放后端,有时 'cv2'(OpenCV) 和 'pillow' 有微小差别
)
3.4 色彩抖动(ColorJitter)
对图像的色彩进行数据增强的方法,其中最常用的莫过于 ColorJitter,这种方法会在一定范围内,对图像的亮度(Brightness)、对比度(Contrast)、饱和度(Saturation)和色相(Hue)进行随机变换,从而模拟真实拍摄中不同灯光环境等条件的变化。
在 MMClassification 中,可使用以下配置:
# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
type='ColorJitter',
brightness=0.5, # 亮度变化范围(0.5 ~ 1.5)
contrast=0.5, # 对比度变化范围(0.5 ~ 1.5)
saturation=0.5, # 饱和度变化范围(0.5 ~ 1.5)
# 色相变换应用较少,目前 MMClassification 暂不支持 Hue 的增强
)
3.5 随机灰度化(RandomGrayscale)
按照一定概率,将图片转变为灰度图。这种增强方法消除了颜色的影响,在特定场景有所应用。
在 MMClassification 中,可使用以下配置:
# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
type='RandomGrayscale',
gray_prob=0.5, # 按 50% 的概率随机灰度化图像
)
3.6 随机光照变换(Lighting)
论文提出的一种针对图片光照的数据增强方法,其来源:https://dl.acm.org/doi/pdf/10.1145/3065386
在这种方法中,首先在训练数据集中对所有图像的像素进行 PCA(主成分分析),从而获得 RGB 空间中的特征值和特征向量。那么这个特征向量代表了什么呢?此处,论文作者认为它代表了光照强度对图片像素的影响。毕竟虽然图像内容各种各样,但不管哪张图片的哪个位置,都不可避免地受到光照条件的影响。
既然特征向量代表了光照强度的影响,那么只要沿着特征向量的方向对图片的像素值做一些随机的加减,就能模拟不同光照的图像。
在 MMClassification 中,可使用以下配置。需要注意的是其中特征值和特征向量的设置,如果你的任务是在通用场景下的分类,那么可以直接沿用 ImageNet 的值;而如果你的任务是在特殊光照环境下的,那么则需要采集不同光照强度下的图像,在自己的数据集上进行 PCA 来替代这里的设置。
import mmcvfrom mmcls.datasets
import PIPELINES
aug_cfg = dict(
type='Lighting',
eigval=[55.4625, 4.7940, 1.1475], # 在 ImageNet 训练集 PCA 获得的特征值
eigvec=[[-0.5675, 0.7192, 0.4009], # 在 ImageNet 训练集 PCA 获得的特征向量
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203]],
alphastd=2.0, # 随机变换幅度,为了展示效果,这里设置较大,通常设置为 0.1
to_rgb=True, # 是否将图像转换为 RGB,mmcv 读取图像为 BGR 格式,为了与特征向量对应,此处转为 RGB
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
img_info = {'img': img}
img_aug = aug(img_info)['img']
# Lighting 变换得到的图像为 float32 类型,且超出 0~255 范围,为了可视化,此处进行限制
img_aug[img_aug < 0] = 0
img_aug[img_aug > 255] = 255
img_aug = img_aug.astype('uint8')[:, :, ::-1] # 转回 BGR 格式
mmcv.imshow(img_aug)
以上介绍的数据增强方法只是常用方法的一部分,更多的数据增强方法,如多种方法的随机组合(AutoAugment、RandAugment)、多张图片的混合增强(MixUp、CutMix)等。