前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose
类和 to_multi
包装函数。不过 [1] 没考虑不同图用不同 augmentation 的情况,如:
- ColorJitter 只对 image 做,而不对 label 做;
- image 的 resize interpolation 可任选,但 label 只能用
nearest
。
本篇更新写法,支持各图同用、异用 augmentation。
Code
- 对比 [1],主要改变是改写
MultiCompose
类,并将to_multi
吸收入内。 MultiCompose
的用法还是和torchvision.transforms.Compose
几乎一致,不过支持异用 augmentation:只要为各图指定各自的 augmentation 类/函数即可。见下一节例程。
def to_multi():
"""不用单独的 to_multi 打包了,已并入 MultiCompose"""
raise NotImplementedError
class MultiCompose:
"""Extension of torchvision.transforms.Compose that accepts multiple inputs
and ensures the same random seed is applied on each of these inputs at each transforms.
This can be useful when simultaneously transforming images & segmentation masks.
"""
# numpy.random.seed range error:
# ValueError: Seed must be between 0 and 2**32 - 1
MIN_SEED = 0 # - 0x8000_0000_0000_0000
MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)
def __init__(self, transforms):
# self.transforms = [to_multi(t) for t in transforms]
no_op = lambda x: x # i.e. identity function
self.transforms = []
for t in transforms:
if isinstance(t, (tuple, list)):
# convert `None` to `no_op` for convenience
self.transforms.append([no_op if _t is None else _t for _t in t])
else:
self.transforms.append(t)
def __call__(self, *images):
for t in self.transforms:
if isinstance(t, (tuple, list)):
assert len(images) <= len(t) # allow redundant transform
else:
t = [t] * len(images)
_aug_images = []
_seed = random.randint(self.MIN_SEED, self.MAX_SEED)
for _im, _t in zip(images, t):
seed_everything(_seed)
_aug_images.append(_t(_im))
images = _aug_images
if len(images) == 1:
images = images[0]
return images
Usage & Test
例程沿用 [1],但改一下 augmentation:
train_trans = MultiCompose([
# image 用 bilinear,label 用 nearest
(ResizeZoomPad((224, 256), "bilinear"), ResizeZoomPad((224, 256), "nearest")), # 异用
transforms.RandomAffine(30, (0.1, 0.1)), # 同用,传一个就行
transforms.RandomHorizontalFlip(), # 同用
# ColorJitter 只对 image 做,label 不做(None)
[transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None], # 异用
])
- 效果:
References
- pytorch一致数据增强