pytorch一致数据增强

news2024/11/26 18:40:18

分割任务对 image 做(某些)transform 时,要对 label(segmentation mask)也做对应的 transform,如 Resize、RandomRotation 等。如果对 image、label 分别用 transform 处理一遍,则涉及随机操作的可能不一致,如 RandomRotation 将 image 转了 a 度、却将 label 转了 b 度。

MONAI 有个 ArrayDataset 实现了这功能,思路是每次 transform 前都重置一次 random seed 先。对 monai 订制 transform 的方法不熟,torchvision.transforms 的订制接口比较简单,考虑基于 pytorch 实现。要改两个东西:

  • 扩展 torchvison.transforms.Compose,使之支持多个输入(image、label);
  • 一个 wrapper,扩展 transform,使之支持多输入。

思路也是重置 random seed,参考 [1-4]。

Code

  • to_multi:将处理单幅图的 transform 扩展成可处理多幅;
  • MultiCompose:扩展 torchvision.transforms.Compose,可输入多幅图。内部调用 to_multi 扩展传入的 transforms。
import random, os
import numpy as np
import torch

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def to_multi(trfm):
    """wrap a transform to extend to multiple input with synchronised random seed
    Input:
        trfm: transformation function/object (custom or from torchvision.transforms)
    Output:
        _multi_transform: function
    """
    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    min_seed = 0 # - 0x8000_0000_0000_0000
    max_seed = min(2**32 - 1, 0xffff_ffff_ffff_ffff)
    def _multi_transform(*images):
        """images: [C, H, W]"""
        if len(images) == 1:
            return trfm(images[0])
        _seed = random.randint(min_seed, max_seed)
        res = []
        for img in images:
            seed_everything(_seed)
            res.append(trfm(img))
        return tuple(res)

    return _multi_transform


class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple input.
    Usage is the same as torchvision.transforms.Compose. This class will wrap input
    transforms with `to_multi` to support simultaneous multiple transformation.
    This can be useful when simultaneously transforming images & segmentation masks.
    """
    def __init__(self, transforms):
        """transforms should be wrapped by `to_multi`"""
        self.transforms = [to_multi(t) for t in transforms]

    def __call__(self, *images):
        for t in self.transforms:
            images = t(*images)
        return images

test

测试一致性,用到预处理过的 verse’19 数据集、一些工具函数、一个订制 transform:

  • verse’19 数据集及预处理见 iTomxy/data/verse;
  • digit_sort_key:数据文件排序用;
  • get_palettecolor_segblend_seg:可视化用;
  • MyDataset:看其中 __getitem__ 的 transform 用法,即同时传入 image 和 label;
  • ResizeZoomPad:一个订制的 transform;
import os, os.path as osp, random
from glob import glob
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F


def digit_sort_key(s, num_pattern=re.compile('([0-9]+)')):
    """natural sort,数据排序用"""
    return [int(text) for text in num_pattern.split(s) if text.isdigit()]


def get_palette(n_classes, pil_format=True):
    """创建调色盘,可视化用"""
    n = n_classes
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3

    if pil_format:
        return palette

    res = []
    for i in range(0, len(palette), 3):
        res.append(tuple(palette[i: i+3]))
    return res


def color_seg(label, n_classes=0):
    """segmentation mask 上色,可视化用"""
    if n_classes < 1:
        n_classes = math.ceil(np.max(label)) + 1
    label_rgb = Image.fromarray(label.astype(np.int32)).convert("L")
    label_rgb.putpalette(get_palette(n_classes))
    return label_rgb.convert("RGB")


def blend_seg(image, label, n_classes=0, alpha=0.7, rescale=False, transparent_bg=True, save_file=""):
    """融合 image 和其 segmentation mask,可视化用"""
    if rescale:
        denom = image.max() - image.min()
        if 0 != denom:
            image = (image - image.min()) / denom * 255
        image = np.clip(image, 0, 255).astype(np.uint8)
    img_pil = Image.fromarray(image).convert("RGB")
    lab_pil = color_seg(label, n_classes)
    blended_image = Image.blend(img_pil, lab_pil, alpha)
    if transparent_bg:
        blended_image = Image.fromarray(np.where(
            (0 == label)[:, :, np.newaxis],
            np.asarray(img_pil),
            np.asarray(blended_image)
        ))
    if save_file:
        blended_image.save(save_file)
    return blended_image


class MyDataset(torch.utils.data.Dataset):
    """订制 dataset,看 __getitem__ 处 transform 的调法"""
    def __init__(self, image_list, label_list, transform=None):
        assert len(image_list) == len(label_list)
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, index):
        img = np.load(self.image_list[index]) # [h, w]
        lab = np.load(self.label_list[index])
        img = torch.from_numpy(img).unsqueeze(0).float() # -> [c=1, h, w]
        lab = torch.from_numpy(lab).unsqueeze(0).int()
        if self.transform is not None:
            img, lab = self.transform(img, lab) # 同时传入 image、label
        return img, lab


class ResizeZoomPad:
    """订制 resize"""
    def __init__(self, size, interpolation="bilinear"):
        if isinstance(size, int):
            assert size > 0
            self.size = [size, size]
        elif isinstance(size, (tuple, list)):
            assert len(size) == 2 and size[0] > 0 and size[1] > 0
            self.size = size

        if isinstance(interpolation, str):
            assert interpolation.lower() in {"nearest", "bilinear", "bicubic", "box", "hamming", "lanczos"}
            interpolation = {
                "nearest": F.InterpolationMode.NEAREST,
                "bilinear": F.InterpolationMode.BILINEAR,
                "bicubic": F.InterpolationMode.BICUBIC,
                "box": F.InterpolationMode.BOX,
                "hamming": F.InterpolationMode.HAMMING,
                "lanczos": F.InterpolationMode.LANCZOS
            }[interpolation.lower()]
        self.interpolation = interpolation

    def __call__(self, image):
        """image: [C, H, W]"""
        scale_h, scale_w = float(self.size[0]) / image.size(1), float(self.size[1]) / image.size(2)
        scale = min(scale_h, scale_w)
        tmp_size = [ # clipping to ensure size
            min(int(image.size(1) * scale), self.size[0]),
            min(int(image.size(2) * scale), self.size[1])
        ]
        image = F.resize(image, tmp_size, self.interpolation)
        assert image.size(1) <= self.size[0] and image.size(2) <= self.size[1]
        pad_h, pad_w = self.size[0] - image.size(1), self.size[1] - image.size(2)
        if pad_h > 0 or pad_w > 0:
            pad_left, pad_right = pad_w // 2, (pad_w + 1) // 2
            pad_top, pad_bottom = pad_h // 2, (pad_h + 1) // 2
            image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom))
        return image


# 读数据文件
data_path = os.path.expanduser("~/data/verse/processed-verse19-npy-horizontal")
train_images, train_labels, val_images, val_labels = [], [], [], []
for d in os.listdir(osp.join(data_path, "training")):
    if d.endswith("_ct"):
        img_p = osp.join(data_path, "training", d)
        lab_p = osp.join(data_path, "training", d[:-3]+"_seg-vert_msk")
        assert osp.isdir(lab_p)
        train_labels.extend(glob(os.path.join(lab_p, "*.npy")))
        train_images.extend(glob(os.path.join(img_p, "*.npy")))
for d in os.listdir(osp.join(data_path, "validation")):
    if d.endswith("_ct"):
        img_p = osp.join(data_path, "validation", d)
        lab_p = osp.join(data_path, "validation", d[:-3]+"_seg-vert_msk")
        assert osp.isdir(lab_p)
        val_labels.extend(glob(os.path.join(lab_p, "*.npy")))
        val_images.extend(glob(os.path.join(img_p, "*.npy")))

# 数据文件名排序
train_images = sorted(train_images, key=lambda f: digit_sort_key(os.path.basename(f)))
train_labels = sorted(train_labels, key=lambda f: digit_sort_key(os.path.basename(f)))
val_images = sorted(val_images, key=lambda f: digit_sort_key(os.path.basename(f)))
val_labels = sorted(val_labels, key=lambda f: digit_sort_key(os.path.basename(f)))

# transform
# 用 MultiCompose,其内部调用 to_multi 将 transforms wrap 成支持多输入的
train_trans = MultiCompose([
    ResizeZoomPad((224, 256)),
    transforms.RandomRotation(30),
])

# 测试:读数据,可试化 image 和 label
check_ds = MyDataset(train_images, train_labels, train_trans)
check_loader = torch.utils.data.DataLoader(check_ds, batch_size=10, shuffle=True)
for images, labels in check_loader:
    print(images.size(), labels.size())
    for i in range(images.size(0)):
        # print(i, end='\r')
        img = images[i][0].numpy()
        lab = labels[i][0].numpy()
        print(np.unique(lab))
        seg_img = blend_seg(img, lab)
        img = (255 * (img - img.min()) / (img.max() - img.min())).astype(np.uint8)
        img = np.asarray(Image.fromarray(img).convert("RGB"))
        lab = np.asarray(color_seg(lab))
        comb = np.concatenate([img, lab, seg_img], axis=1)
        Image.fromarray(comb).save(f"test-dataset-{i}.png")
    break

效果:
test-dataset-7.png
可见,image 和 label 转了同一个随机角度。

Limits

有些 augmentations 是只对 image 做而不对 label 做的,如 ColorJitter,这里没有考虑怎么处理。

References

  1. How to Set Random Seeds in PyTorch and Tensorflow
  2. ihoromi4/seed_everything.py
  3. Reproducibility
  4. What is the max seed you can set up?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1300975.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于深度学习的yolov7植物病虫害识别及防治系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介简介YOLOv7 系统特性工作流程 二、功能三、系统四. 总结 一项目简介 # YOLOv7植物病虫害识别及防治系统介绍 简介 该系统基于深度学习技术&#xff0c;采…

【KCC@南京】KCC南京“数字经济-开源行”活动回顾录

11月26日&#xff0c;由KCC南京、中科南京软件研究所、傲空间、PowerData联合主办的 KCC南京“数字经济-开源行” 的活动已圆满结束。此次活动&#xff0c;3 场主题研讨&#xff0c;11 场分享&#xff0c;现场参会人数 60&#xff0c;线上直播观看 3000&#xff0c;各地小伙伴从…

抓取真实浏览器设备指纹fingerprint写入cookie方案

一个关于抓取真实浏览器设备指纹写入cookie方案&#xff0c;用户访问页面获取到用户设备生成指纹id&#xff0c;通过js把指纹存入cookie&#xff0c;然后用php进行获取cookie存的指纹值到后台。 用途&#xff1a;追踪用户设备&#xff0c;防恶意注册&#xff0c;防恶意采集 浏…

1827_ChibiOS中OSLIB的邮箱机制

全部学习汇总&#xff1a; GreyZhang/g_ChibiOS: I found a new RTOS called ChibiOS and it seems interesting! (github.com) 1. 邮箱其实是一个环形队列&#xff1b; 2. 使用场景上&#xff0c;邮箱主要是用来实现异步单向的一些消息或者数据处理的。在处理机制上&#xff…

C语言 预处理 + 条件编译宏 + 井号运算符

预处理阶段任务 预处理指令 条件编译宏 条件编译宏的作用在于根据编译时的条件进行代码的选择性编译&#xff0c;从而实现不同环境、不同配置或不同功能的编译版本。 这可以用于实现调试模式和发布模式的切换&#xff0c;平台适配&#xff0c;以及选择性地编译不同的功能模块等…

【Spring 基础】00 入门指南

【Spring 基础】00 入门指南 文章目录 【Spring 基础】00 入门指南1.简介2.概念1&#xff09;控制反转&#xff08;IoC&#xff09;2&#xff09;依赖注入&#xff08;DI&#xff09; 3.核心模块1&#xff09;Spring Core2&#xff09;Spring AOP3&#xff09;Spring MVC4&…

组件之间传值

目录 1&#xff1a;组件中的关系 2&#xff1a;父向子传值 3&#xff1a;子组件向父组件共享数据 4&#xff1a;兄弟组件数据共享 1&#xff1a;组件中的关系 在项目中使用到的组件关系最常用两种是&#xff0c;父子关系&#xff0c;兄弟关系 例如A组件使用B组件或者C组件…

大师学SwiftUI第18章Part2 - 存储图片和自定义相机

存储图片 在前面的示例中&#xff0c;我们在屏幕上展示了图片&#xff0c;但也可以将其存储到文件或数据库中。另外有时使用相机将照片存储到设备的相册薄里会很有用&#xff0c;这样可供其它应用访问。UIKit框架提供了如下两个保存图片和视频的函数。 UIImageWriteToSavedPh…

CCF刷题记录 -- 202305-2:矩阵运算 --python解法

2023.12.7 主要算法 矩阵置换矩阵相乘 满分注意点 运算顺序&#xff0c;利用了矩阵运算法则中的&#xff08;A*B&#xff09;*c A*(B*C) # 矩阵置换 def zhihuan(a):b[]for i in range(d):c []for j in range(n):c.append(a[j][i])b.append(c)return b# 矩阵相乘 def ju_zh…

C# WPF上位机开发(通讯协议的编写)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 作为上位机&#xff0c;它很重要的一个部分就是需要和外面的设备进行数据沟通的。很多时候&#xff0c;也就是在这个沟通的过程当中&#xff0c;上…

SuperMap iObject.NET三维场景拖拽框选实现详解及完整源代码(一)——环境准备及项目配置

作者&#xff1a;超图研究院技术支持中心-于丁1 SuperMap iObject.NET三维场景拖拽框选实现详解及完整源代码&#xff08;一&#xff09;——环境准备及项目配置   三维场景框选是一种在三维空间中进行选择和操作的功能&#xff0c;它可以让使用者通过鼠标拖动来创建一个矩形…

stu06-VSCode里的常用快捷键

Alt Z&#xff1a;文字自动换行。当一行的文字太长时&#xff0c;可以使用。或者查看→自动换行Alt Shift ↓ &#xff1a;快速复制当前行到下一行Alt Shift ↑ &#xff1a;快速复制当前行到上一行Alt B&#xff1a;在默认浏览器中打开当前.html文件Ctrl Enter&#xf…

前端 Web Workers 简介

简介 以前我们总说&#xff0c;JS 是单线程没有多线程&#xff0c;当 JS 在页面中运行长耗时同步任务的时候就会导致页面假死影响用户体验&#xff0c;从而需要设置把任务放在任务队列中&#xff1b;执行任务队列中的任务也并非多线程进行的&#xff0c;然而现在 HTML5 提供了…

如何一个例子玩明白GIT

一个例子玩明白GIT GIT的介绍和教程五花八门&#xff0c;但实际需要用的就是建仓、推送、拉取等操作&#xff0c;这儿咱可以通过一个例子熟悉这些操作&#xff0c;一次性搞定GIT的使用方法学习。下面这个例子的内容是内容是建立初始版本库&#xff0c;然后将数据复制到 "远…

05-详解Nacos配置管理中心,配置拉取的方式,热更新,配置共享(优先级)的步骤

Nacos配置管理 新建配置文件 当微服务部署的实例越来越多时,如果需要修改微服务的配置就需要逐个修改配置文件并且还要重启关联的微服务十分繁琐还易出错 项目中的配置文件分为每个项目特有的配置,项目所公用的配置 每个项目特有的配置: 有些项目中需要但有些项目中又不需要…

初学者如何入门 Generative AI 之 Stable Diffusion 与 CLIP :看两篇综述,玩几个应用感受一下先!超多高清大图,沉浸式体验

文章大纲 4种 图片生成 的算法扩散模型的起源Stable DiffusionCLIP参考文献与学习路径A synthography of an astronaut riding a horse created in NightCafe Studio with Stable Diffusion XL (SDXL). Prompt is a photograph of an astronaut riding a horse with weight of …

lenovo联想拯救者Legion R7000P 2020H(82GR)笔记本原厂Windows10系统包

拯救者笔记本电脑原装出厂WIN10系统ISO镜像 链接&#xff1a;https://pan.baidu.com/s/1iPNXELRipKaAIR-yaq5HNg?pwdm27n 提取码&#xff1a;m27n 自带有所有驱动、出厂主题壁纸、系统属性专属LOGO标志、Office办公软件、联想电脑管家等预装程序 所需要工具&#xff1a;1…

【网络奇缘系列】计算机网络|数据通信方式|数据传输方式

&#x1f308;个人主页: Aileen_0v0&#x1f525;系列专栏: 一见倾心,再见倾城 --- 计算机网络~&#x1f4ab;个人格言:"没有罗马,那就自己创造罗马~" 这篇文章是关于计算机网络中数据通信的基础知识点&#xff0c; 从模型&#xff0c;术语再到数据通信方式&#…

Jmeter 请求签名api接口-BeanShell

Jmeter 请求签名api接口-BeanShell 项目签名说明编译扩展jar包jmeter 使用 BeanShell 调用jar包中的签名方法 项目签名说明 有签名算法的api接口本地不好测试&#xff0c;使用BeanShell 扩展jar 包对参数进行签名&#xff0c;接口签名算法使用 sha512Hex 算法。签名的说明如下…

java实现网络聊天

网络聊天实现步骤&#xff08;从功能谈论方法&#xff09;&#xff1a; 客户端&#xff1a; 1.登录面板&#xff1a;注册提醒用户注册格式&#xff0c;登录账号密码不为空&#xff0c;点击登录的时候需要连接服务器端&#xff0c;启动聊天面板。&#xff08;监听用户点击登录…