使用合成数据训练语义分割模型

news2024/11/27 2:41:33

计算机视觉应用现在在各个科技领域无处不在。 这些模型高效且有效,研究人员每年都会尝试新想法。 新的前沿是试图消除深度学习的最大负担:需要大量的标记数据。 正如本文所述,此问题的解决方案是使用合成数据。

从这些研究中获益最多的计算机视觉领域当然是语义分割领域,即预测图像每个像素的标签的任务,以便从图像中检索感兴趣的对象。 正如人们所预料的那样,手动标记训练集是一个昂贵、耗时且容易出错的过程,因此有多种利用合成数据的新方法。

在本文中,我们将看到其中一种方法,它利用生成对抗网络来解决使用合成数据的域适应问题。另一种常用的合成数据生成方法是利用逼真渲染的游戏引擎,例如基于UE5开发的UnrealSynth合成数据生成器:
在这里插入图片描述

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器

1、合成数据生成

为了生成语义分割任务的数据,最常见的解决方案是使用与渲染引擎关联的模拟器。 通过这种方式,可以随意生成图像,改变闪电条件、物体的数量和姿势以及它们之间的交互,并始终关联像素完美的语义标签。 例如,一个非常流行的数据集,几乎所有研究都用作基准,是 GTAV [1],其中使用的模拟引擎是同名视频游戏。 该数据集包含从汽车驾驶员的角度拍摄的图像,非常适合自动驾驶等应用。 另一个著名的数据集是 SINTHIA [2],它也包含城市环境的图像。

在这里插入图片描述

图 1.1 — 来自 GTAV 数据集 [2] 的带有标签的图像示例

2、领域适应的生成方法

直接使用合成数据训练模型是不够的,神经网络可能会学习模拟环境中存在的一些不真实的模式,无法很好地概括现实世界的数据。 这称为域适应问题(Domain Adaption Problem)。

为了克服这个问题,模型必须在训练过程中学习重新调整源域 S(合成域)和目标域 T(真实域)之间特征分布的最佳方法。 这可以通过对抗性训练、知识蒸馏和自我监督学习等多种方法来实现。

特别是,对抗性训练的特点是采用生成方法,将源域数据转换为更类似于目标域的分布。 它可以表述如下:

给定源域数据集 Dₛ= {(xᵢˢ, yᵢˢ), i=1…nₛ} 和目标域数据集 Dₜ = {xᵢᵗ, i=1…nₜ},其中 xᵢˢ 和 xᵢᵗ 是输入样本, yᵢˢ 是对应的样本 xᵢˢ 的标签,目标是学习一个映射函数 𝓍ᵢˢ = G(xᵢˢ),称为生成器,它将源域特征映射到目标域特征,以便在转换后的源域图像上训练的深度学习模型可以表现良好 在目标域上。 它是通过判别器来完成的,判别器是一种神经网络,它接收真实图像和变换后的合成图像的输入,并尝试预测输入是否来自真实分布。

网络在对抗性环境中进行训练,只有当鉴别器失败时,生成器才会获胜。 当变换后的图像与真实图像非常相似以至于鉴别器无法区分它们时,该过程会收敛,从而使预测不比随机猜测更好(准确度为 50%)。

3、几何引导的输入输出自适应

各种算法都利用生成方法。 其中之一被称为 GIO-Ada [3],代表几何引导输入输出适应。 该算法相对于简单方法引入了 2 项改进。

它使用可以从模拟引擎轻松检索的另一条信息:深度图。 直觉是,对象的几何信息更好地编码在其深度信息中,而不是其像素的语义标签中。 因此,模型被训练来估计输入图像的深度图,并且这个额外的信息仅在训练期间用作辅助损失。
它在输出级别使用第二个对抗阶段,第二个鉴别器对任务网络的输出(语义标签图和几何深度图)进行操作,经过训练以预测预测的输出来自真实的还是合成的 图像。

在这里插入图片描述

图 1.2 — GIO-Ada 架构概述。 源数据的流向以橙线显示,目标数据的流向以黑线显示

完整的架构由 4 个神经网络组成:生成器(用于转换合成图像)、任务网络(预测真实图像和转换图像的标签和深度图)以及 2 个判别器。 所有网络都经过端到端训练,并采用遵循对抗训练规则的通用优化步骤。

4、Pytorch Lightening实现

为了轻松实现和训练这种复杂的算法,pytorch_lightning 是一个可以提供帮助的库。 这是 pytorch 的包装器,有助于避免重新实现一些与 torch 配合使用所需的样板代码,例如实现训练循环、处理超参数和权重的记录和保存、管理 GPU(或多个 GPU)并执行优化器步骤。 在我们的例子中,最后一个功能不是必需的,因为对抗训练的特殊性恰恰在于生成器和判别器之间优化步骤的交替,并且需要定制。

让我们首先导入库并定义一个实用函数,该函数将用于为鉴别器创建标签。

import itertools
from typing import Iterator

import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics.classification.jaccard import MulticlassJaccardIndex


def _labels(inputs: torch.Tensor, fill_value: int) -> torch.Tensor:
    return torch.full((inputs.size(0), 1), fill_value).to(inputs)

神经网络被实现为torch模块。 给定 B = 批量大小、C = 图像通道、K = 类数、W、H= 图像的宽度和高度:

  • 任务网络必须处理形状为 B × C × W × H 的批量图像,并返回形状为 B × K × W × H 的标签预测和形状为 B × 1× W × H 的深度预测。一种可能的架构选择是 使用 DeepLabV3+ [4] 作为任务网络,具有两个不同的头,一个用于类别预测,一个用于深度预测。
  • 图像变换网络必须输入所有合成数据,即形状为 B × C × W × H 的图像、形状为 B × K × W × H 的标签和形状为 B × 1× W × H 的深度图,连接起来 它们,并在输出中生成形状为 B × C × W × H 的变换图像。
  • 鉴别器必须采用形状 B × (C 或 C + K + 1) × W × H 的输入,并产生形状 B × 1 的输出,表示样本为真实样本的概率。
class TaskNetwork(nn.Module):
    def __init__(
        self,
        input_channels: int,
        num_classes: int,
        pretrained_backbone: bool = False,
    ) -> None:
        ...

    def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        ...


class ImageTransformNetwork(nn.Module):
    def __init__(
        self,
        input_channels: int,
        output_channels: int,
    ) -> None:
        ...

    def forward(
        self,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> torch.Tensor:
        ...


class Discriminator(nn.Module):
    def __init__(self, input_channels: int) -> None:
        ...

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        ...

其余代码将在 LightningModule 内实现。 在这里,我们在 __init__方法中传递所有超参数,在该方法中我们实例化了 4 个神经网络,以及损失和指标。 卷积层的权重从正态分布初始化,任务网络的权重除外,其中权重可以预先训练,例如使用 ImageNet 数据集。

class GIOAda(pl.LightningModule):
    REAL_LABEL = 1
    FAKE_LABEL = 0

    def __init__(
        self,
        num_classes: int,
        pretrained_backbone: bool,
        init_lr: float,
        betas: tuple[float, float],
        num_epochs: int,
        num_steps_per_epoch: int,
        lam_input: float,
        lam_output: float,
        lam_depth: float,
    ) -> None:
        super().__init__()

        self.save_hyperparameters() # saved in the dictionary self.hparams
        # disabling automatic optimization, as it willl be made manually
        self.automatic_optimization = False

        self.task_network = TaskNetwork(
            input_channels=3,  # RGB Channels
            num_classes=num_classes,  # Classes
            pretrained_backbone=pretrained_backbone,
        )
        self.fake_transformation = ImageTransformNetwork(
            input_channels=num_classes + 4,  # RGB Channels + Classes + Depth
            output_channels=3,  # RGB Channels
        )
        self.input_discriminator = Discriminator(
            input_channels=3,  # RGB Channels
        )
        self.output_discriminator = Discriminator(
            input_channels=num_classes + 1,  # Classes + Depth
        )

        self.depths_loss = nn.L1Loss()
        self.labels_loss = nn.CrossEntropyLoss()
        self.discriminator_loss = nn.BCELoss()

        self.miou_index = MulticlassJaccardIndex(num_classes)

        self.weight_init(pretrained_backbone=pretrained_backbone)

    def weight_init(self, pretrained_backbone: bool = False):
        for name, module in self.named_modules():
            if "task" in name and pretrained_backbone:
                continue
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                module.weight.data.normal_(0, 0.001)
                if module.bias is not None:
                    module.bias.data.zero_()

然后我们定义优化器和学习率调度器。 我们需要一个优化器来处理“生成器”的权重,即生成器和任务网络,以及另一个优化器来处理鉴别器的权重。 作为学习率调度程序,我们将使用 OneCycle 策略,该策略在训练的第一部分通过提高学习率和降低动量来“预热”网络,从而允许早期探索权重空间并找到更好的起点 观点。 然后,在最后部分,通过余弦退火策略降低学习率。

    def configure_optimizers(
        self,
    ) -> tuple[
        list[torch.optim.Adam], list[torch.optim.lr_scheduler.OneCycleLR]
    ]:
        params_g = itertools.chain(
            self.fake_transformation.parameters(),
            self.task_network.parameters(),
        )
        params_d = itertools.chain(
            self.input_discriminator.parameters(),
            self.output_discriminator.parameters(),
        )
        optimizer_g, lr_sched_g = self._optimizer_lr_scheduler(params_g)
        optimizer_d, lr_sched_d = self._optimizer_lr_scheduler(params_d)
        return [optimizer_g, optimizer_d], [lr_sched_g, lr_sched_d]

    def _optimizer_lr_scheduler(
        self,
        parameters: Iterator[torch.nn.Parameter],
    ) -> tuple[torch.optim.Adam, torch.optim.lr_scheduler.OneCycleLR]:
        optimizer = torch.optim.Adam(
            parameters,
            lr=self.hparams["init_lr"],
            betas=self.hparams["betas"],
        )
        lr_sched = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams["init_lr"],
            epochs=self.hparams["num_epochs"],
            steps_per_epoch=self.hparams["num_steps_per_epoch"],
            base_momentum=self.hparams["betas"][0],
        )
        return optimizer, lr_sched

训练步骤接收输入:

  • 从真实数据集中采样的一批真实图像
  • 一批合成图像,以及从合成数据集中采样的相应标签和深度图。

然后它执行 2 个操作:

  • 鉴别器的优化步骤,需要所有输入
  • 生成器的优化步骤,仅需要合成数据

步骤的顺序对于确保模型的收敛至关重要。 由于生成器更容易崩溃,我们应该让判别器“引导”训练路径。 这样,在生成器步骤中,鉴别器的工作会更好一点,为生成器留下“更好”的梯度。

    def training_step(self, batch: tuple[torch.Tensor, ...]) -> None:
        optimizer_g, optimizer_d = self.optimizers()  
        real_images, fake_images, labels, depths = batch

        # Update D network: minimize log(D(x)) + log(1 - D(G(z)))
        self.toggle_optimizer(optimizer_d)
        optimizer_d.zero_grad()
        self._discriminator_step(real_images, fake_images, labels, depths)
        optimizer_d.step()
        self.untoggle_optimizer(optimizer_d)

        # Update G network: maximize log(D(G(z))) and minimize task loss
        self.toggle_optimizer(optimizer_g)
        optimizer_g.zero_grad()
        self._generator_step(fake_images, labels, depths)
        optimizer_g.step()
        self.untoggle_optimizer(optimizer_g)

鉴别器步骤只是最小化鉴别器输出的二元交叉熵损失。 首先在真实批次上完成此操作,其中预期标签全部为 1,然后在合成批次上完成,其中预期标签全部为零。

    def _discriminator_step(
        self,
        real_images: torch.Tensor,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> None:
        disc_lab = _labels(real_images, self.REAL_LABEL)
        disc_input = self.input_discriminator(real_images)
        disc_output = self.output_discriminator(
            torch.concat(self.task_network(real_images), dim=1)
        )
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        self.manual_backward(loss_input + loss_output)

        transformed = self.fake_transformation(fake_images, labels, depths)
        disc_lab = _labels(transformed, self.FAKE_LABEL)
        disc_input = self.input_discriminator(transformed)
        disc_output = self.output_discriminator(
            torch.concat(self.task_network(transformed), dim=1)
        )
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        self.manual_backward(loss_input + loss_output)
        
        # Log losses and metrics
        # ...

相反,生成器步骤最小化标签的交叉熵损失和深度估计的 L1Loss,并且还最大化鉴别器的二元交叉熵损失。 这是通过使用与之前相反的标签计算损失来完成的,因此所有标签都用于合成输入。 没有必要计算实际输入的损失,因为生成器的权重对此输出没有影响。

    def _generator_step(
        self,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> None:
        # Set disc_lab = REAL in order to maximize the loss for the 
        # discriminator when inputs are all fakes
        disc_lab = _labels(fake_images, self.REAL_LABEL)

        # Forward pass on all the networks to collect gradients for G
        transformed = self.fake_transformation(fake_images, labels, depths)
        fake_mask, fake_depth = self.task_network(transformed)
        disc_input = self.input_discriminator(transformed)
        disc_output = self.output_discriminator(
            torch.concat((fake_mask, fake_depth), dim=1)
        )

        # Calculate losses
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        loss_depths = (
            self.depths_loss(fake_depth, depths) * self.hparams["lam_depth"]
        )
        loss_labels = self.labels_loss(fake_mask, labels)

        # Calculate Gradients
        self.manual_backward(
            loss_input + loss_output + loss_depths + loss_labels
        )

        # Log losses and metrics
        # ...

5、结束语

事实证明,这里解释的方法在各种数据集上都非常有效。 在下图中,我们可以看到,利用 sintetic 数据训练的模型优于仅在小型 KITTI 数据集上训练的模型。 从大量合成数据中获取的知识使模型能够从真实图像中提取更细粒度的细节。

在这里插入图片描述

图 1.3 — KITTI 数据集上的语义分割定性结果。 从左到右:左:输入图像,中:非自适应结果,右:GIO-Ada 方法的结果。

该算法也有一些缺点。 首先,对抗性训练可能非常不稳定,这可以从之前看到的不寻常的训练步骤中猜测出来。 因此,详尽的超参数搜索对于获得良好结果至关重要。 另一个主要问题是训练生成网络是一项内存非常密集的工作,尤其是对于高分辨率图像。

最新的研究集中在其他方法(例如自学习)上,利用变压器层中注意力机制的强泛化特性以及特定领域的数据增强技术。

尽管如此,生成方法(例如本文中讨论的生成方法)由于易于适应新领域以及生成学习研究的不断发展,继续在该领域占据一席之地。


原文链接:用合成数据进行语义分割 — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1173089.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

公众号留言功能在哪?教你开通

为什么公众号没有留言功能?从2018年2月开始,新注册的微信公众号取消了留言功能,原因是为了规避一些营销号通过虚假留言骗取读者信任。不过大部分公众号运营者对TX此举感到失望,一方面大片的留言就像店前排队的顾客,能体…

Camtasia2024破解版电脑屏幕录制剪辑软件

屏幕录制剪辑 TechSmith Camtasia for Mac v2021是 TechSmith 公司所开发出一款专业屏幕录像和编辑, Camtasia Studio2024版是由TechSmith公司官方进行汉化推出的最新版本,除2023版以下版本均没有官方汉化。 同时TechSmith公司打击第三方贩卖Camtasia Studio汉化的…

不同VLAN间的通信原理

不同VLAN间的通信原理 VLANaccess口trunk口 不同VLAN间通信原理 首先我们来看看什么是VLAN VLAN VLAN(Virtual Local Area Network)虚拟局域网,是将一个物理的局域网在逻辑上划分成多个广播域的技术。VLAN技术部署在数据链路层。 VLAN能够隔…

Redis高级数据类型-HyperLogLogBitmap以及使用两种数据类型完成网站数据统计

网站数据统计 定义相关的Redis Key /*** 单日UV*/public static String getUVKey(String date) {return PREFIX_UVSPLITdate;}/*** 记录区间UV* param startData 开始日期* param endDate 结束日期* return*/public static String getUVkey(String startData,String endDate){r…

05 行列式

行列式 面积变化行列式空间定向改变三维空间行列式的计算 这是关于3Blue1Brown "线性代数的本质"的学习笔记。 面积变化 线性变换会使得基向量 i ⃗ \vec{i} i 和 j ⃗ \vec{j} j ​围城的区域面积被缩放。 图1 线性变换可能会使得基向量 i ⃗ \vec{i} i 和 j ⃗ …

webJS基础-----制作一个时间倒计时

1,可以使用以下两个方式制作 方式1:setTimeout ()定时器是在指定的时间后执行某些代码,代码执行一次就会自动停止; 方式2:setInterval ()定时器是按照指定的周期来重复执行某些代码,该定时器不会自动停止…

从 LLM 大模型到 AI Agent 技术演进

▼最近直播超级多,预约保你有收获 近期直播:《基于 LLM 大模型的微调构建AI Agents 案例实践》 —1— LLM 大模型有哪些局限性? 给定一些字或者词(称为 token),预测下一个字或者词模型,就是语言…

基于Magma构建灵活、低成本无线接入网

传统蜂窝网络一般基于特定接入技术并针对大规模公共网络设计,无法灵活适配小规模网络以及异构无线技术。本文介绍了Magma在构建低成本异构无线接入方面的探索。原文: Building Flexible, Low-Cost Wireless Access Networks With Magma 摘要 当今仍然有数十亿人受限…

想学计算机编程从什么学起?零基础如何自学计算机编程?中文编程开发语言工具箱之渐变标签组构件

想学计算机编程从什么学起?零基础如何自学计算机编程? 给大家分享一款中文编程工具,零基础轻松学编程,不需英语基础,编程工具可下载。 这款工具不但可以连接部分硬件,而且可以开发大型的软件,…

跨平台联调代码:Windows下VS2022远程连接Linux-protobuf为例

文章目录 Linux上头文件的位置Linux上共享库的位置Linux上配置好环境变量Windows上VS上的设置添加包含目录与库目录设置链接参数-库依赖项设置编译参数更新远程标头管理器代码的书写输出 Linux上头文件的位置 Linux上我的protobuf头文件的位置为: /usr/local/prot…

2023年中国制糖行业研究报告

第一章 行业概况 1.1 定义 制糖行业是指以甘蔗、甜菜等为主要原料,通过一系列的工艺流程,生产糖以及相关副产品的产业。它是食品工业的重要组成部分,为人们日常生活中的甜蜜体验提供了必不可少的物质基础。 主要原料: 制糖行业…

深入剖析:正则表达式的奥秘

简介 正则表达式(Regular Expressions)是一种强大的文本处理工具,一种用于匹配文本模式的字符串。它由特定的字符和操作符组成,用于定义一个搜索模式。这些搜索模式可以用于文本搜索、替换、验证和提取数据等多种用途。 以下是一…

在搜索引擎中屏蔽csdn

csdn是一个很好的技术博客,里面信息很丰富,我也喜欢在csdn上做技术笔记。 但是CSDN体量太大,文章质量良莠不齐。当在搜索引擎搜索技术问题时,搜索结果中CSDN的内容占比太多,导致难以从其他优秀的博客平台中获取信息。因…

Mac安装VMware

去官网下载一下VMware Download VMware Fusion | VMware | SG 下载完成之后,打开直接闪退,参考这篇文章解决 解决macOS13安装Fusion13闪退的问题-CSDN博客 然后即可成功顺行

linux入门到地狱

linux—001入门 IT圈必备(前端工作者用的比较少) 老旧电脑跑linux不容易卡 我代码没保存windows闪退,僵停(vs2019卡掉线),重启更新,占用cpu内存服务报错pip各种bug 出来生态环境友好其他的全是bug(bug时间成本超过了windows快捷友好生态) 那就说明wind…

行业安卓主板-基于RK3568/3288/3588的AI视觉秤/云相框/点餐机/明厨亮灶行业解决方案(一)

AI视觉秤 单屏Al秤集成独立NPU,可达0.8Tops算力,令AI运算效率大幅提升,以实现生鲜商品快速准确识别,快速称重打印标签,降低生鲜门店运营成本,缓解高峰期称重排队拥堵的现象,提高称重效率&#…

1、Sentinel基本应用限流规则(1)

Sentinel基本应用&限流规则 1.1 概述与作用 随着微服务的流行,服务和服务之间的稳定性变得越来越重要。缓存、降级和限流是保护微服务系统运行稳定性的三大利器。 缓存:提升系统访问速度和增大系统能处理的容量 降级:当服务出问题或者影…

如何写复盘报告

复盘报告在it公司中是为了在出现事情后,我们更好的回顾事情的前因后果,定位问题,指定解决措施,并且宣导,让这类事情减少发生的概率。那复盘报告一般怎样写合适呢?下来我们就看看, 一、一般会先…

Elasticsearch:RAG vs Fine-tunning (大语言模型微调)

如果你对 RAG 还不是很熟悉的话,请阅读之前的文章 “Elasticsearch:什么是检索增强生成 - RAG?”。你可以阅读文章 “Elasticsearch:在你的数据上训练大型语言模型 (LLM)” 来了解更多关于如何训练你的模型。在今天的文章中&#…

Linux项目自动化构建工具-make/Makefile使用

make/Makefile使用介绍 make是一个命令makefile是一个在当前目录下存在的一个具有特定格式的文本文件 ​ 下面我们设计一个场景&#xff0c;实现make命令对我们code.c文件进行编译和删除。 1 #include<stdio.h> 2 3 int main() 4 { 5 printf("hello,world!…