2、StarGAN V2

news2024/11/15 4:17:17

2、StarGAN V2

StarGAN 论文链接:StarGAN

StarGAN V2 论文链接:StarGAN V2

在介绍StarGAN V2之前,我们先对StarGAN有一定的了解,StarGAN V2只是在StarGAN的基础上做出了改进,基本的架构是没有变的,只是将风格编码做成了向量的形式,使得风格编码也是可以学习的。

StarGAN
StarGAN的出发点

StarGAN(Star Generative Adversarial Network)是一种生成对抗网络(GAN)的变体,用于图像到图像的多域转换任务。StarGAN 的核心特点是,它可以在单一模型中实现多域图像转换,而不需要为每个领域的转换训练不同的模型。其实就是来解决在CycleGAN中转化一种风格就需要训练一个模型的问题,设计一种编码来实现一个生成器和一个判别器能够生成多种风格,解决了CycleGAN的弊端。

在这里插入图片描述

StarGAN架构图
  1. StarGAN为了解决CycleGAN每一个风格需要训练一个模型,并且需要多个生成器和判别器的问题,StarGAN采用了风格编码来实现只需要一个生成器和一个判别器,但是总体思想仍然采用CycleGAN的思想来设计损失函数。

    • 生成器

      • 在StarGAN中,生成器的输入不仅是图像,还包含目标域的域标签(即风格编码)。生成器会根据该标签生成属于目标域的图像。
      • 生成器同时使用了循环一致性损失(cycle-consistency loss),这是借鉴了CycleGAN的思想。通过将生成的图像转换回原始域,以确保生成图像保留了输入图像的关键信息。
      • 目标是通过风格编码使得生成器能够将一张图片从一个域(如人脸图片)转换为多个目标域(如不同表情、发型或年龄),并在多个域之间进行切换。

      判别器

      • StarGAN的判别器不仅需要判断图像的真假(真实图像 vs. 生成图像),还需要判别该图像属于哪个域(风格编码)。
      • 判别器会输出多个域的分类信息,并在真假分类的同时,判断生成的图像是否符合指定的域标签。

      损失函数

      • 对抗性损失:用于保证生成器生成的图像能够欺骗判别器。
      • 域分类损失:用于确保生成的图像与目标域标签匹配。
      • 循环一致性损失:用于确保生成图像能够还原回原始域,以保持输入的主要特征。

在这里插入图片描述

StarGAN V2
StarGAN V2出发点

StarGAN V2的出发点来自于StarGAN中使用的编码是一些固定的01编码,是不可学习,而StarGAN V2则在风格编码做出来改进,将风格编码初始化成向量,同时也可以通过原始输入图像来生成风格编码,而生成风格编码的网络是可学习的,使的风格更加的差异化,并且生成的图像风格更加准确。模型设计主要流程上并没有做出改动,主要在于损失函数的改动。理解损失函数也是掌握对抗生成网络的关键。

模型架构图

1. 生成器(Generator)

StarGAN V2 的生成器与传统的 GAN 不同,它融合了风格编码和图像转换的思想。生成器的主要目标是将输入图像转换为不同风格的图像。

生成器的核心组成部分:

  • 输入:生成器的输入不仅包括要转换的图像,还包括目标风格编码(可以是从风格编码器得到的风格向量,或者是随机采样的向量)。
  • 风格编码器:StarGAN V2 引入了一个风格编码器,它可以从目标图像中提取出风格信息,将其表示为风格向量。这样,生成器可以利用不同的风格编码生成对应风格的图像。
  • 结构设计:生成器采用了基于卷积的网络架构,但通过风格向量来调控生成过程中的特征图。这使得生成器可以生成具有不同风格特征的图像。
  • 多样性建模:生成器能够通过不同的风格编码生成多个同一源图像的多样化风格变化。这依赖于生成器对风格编码的处理,使得输出图像既能够保持输入的语义信息,又能够呈现目标风格。

2. 判别器(Discriminator)

StarGAN V2 的判别器不仅要判断图像的真假,还要判断生成图像是否符合目标风格。它负责区分生成器生成的图像和真实图像,并检测生成图像的风格是否与目标域匹配。

判别器的核心组成部分:

  • 输入:判别器接收图像输入,同时附带目标风格标签。它的任务是判断输入的图像是否来自真实的目标域,并判断生成器生成的图像是否匹配目标风格。
  • 多域分类:判别器输出的是多分类结果。除了判断图像是真实还是生成的,它还需要对图像的风格域进行分类,确保生成的图像符合目标风格。
  • PatchGAN 设计:判别器通常采用 PatchGAN(局部感知)的设计,它对图像的每个局部区域进行真假和风格分类。这种设计有助于判别器更好地捕捉图像的局部特征,尤其是风格特征,从而在视觉上确保生成的图像看起来自然

损失函数的改进

  • 对抗损失依然是生成对抗网络的核心,用于确保生成图像能欺骗判别器。

  • 风格一致性损失:StarGAN V2通过风格一致性损失来确保生成的图像能保持输入图像的关键信息,并且使风格变化是自然且符合目标域的。

  • 循环一致性损失:与StarGAN类似,StarGAN V2依然采用了循环一致性损失来保证生成图像在转换回原始域时能保持输入图像的主要特征。

  • 多样性损失: StarGAN V2还通过引入多样性损失,确保生成的图像在同一目标域内保持足够的多样性,而不仅仅是简单的风格映射。通过学习不同的风格编码,生成器可以在同一个目标域中生成多个不同风格的图像。

在这里插入图片描述

生成器损失

包含下面四种:对抗损失风格一致性损失多样性损失循环一致性损失

对抗损失

生成对抗网络的核心,用于确保生成图像能欺骗判别器。

公式:在这里插入图片描述

风格一致性损失

风格一致性损失,就是保证模型生成的图片的风格和需要生成的风格越接近越好。首先使用x和风格s生成一张图片,然后再用Style encoder进行编码,获得生成后图片的风格编码,计算它和需要生成风格编码之间的差距作为风格一致性损失。

公式:在这里插入图片描述

多样性损失

多样性损失是确保生成的图像在同一目标域内保持足够的多样性,而不仅仅是简单的风格映射。通过学习不同的风格编码,生成器可以在同一个目标域中生成多个不同风格的图像。简单来说,就是两者的标签一样的,同时采用同样的Mapping network进行编码,但是要使编码出来风格编码差异性越大越好,这样采用生成多种不同风格的图像,学习的是Mapping network。

公式:在这里插入图片描述

循环一致性损失

循环一致性损失和CycleGAN的思想是一样的,要求我们生成出来的图片必须经过还原后还是能够与原来的图像越接近越好。从公式中可以看出,先对x和某中风格编码s生成图像,在使用x经过style encoder生成s1,然后将s1和生成的图像输入生成器,得到图片与原来的图片做比较,这样就得到原始图像和还原后图像之间的差异作为循环一致性损失。

公式:在这里插入图片描述

最终Loss值公式:

Ladv 是对抗损失,Lds前面的负号,说明他们之间的差异越大越好。

在这里插入图片描述

生气器损失计算源码
def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
    # 确保 z_trgs 和 x_refs 其中一个不为空
    assert (z_trgs is None) != (x_refs is None)
    
    # 当 z_trgs 不为空时,解包 z_trg 和 z_trg2
    if z_trgs is not None:
        z_trg, z_trg2 = z_trgs
    # 当 x_refs 不为空时,解包 x_ref 和 x_ref2
    if x_refs is not None:
        x_ref, x_ref2 = x_refs

    # 对抗损失(adversarial loss)
    if z_trgs is not None:
        s_trg = nets.mapping_network(z_trg, y_trg)  # 通过映射网络生成目标风格编码
    else:
        s_trg = nets.style_encoder(x_ref, y_trg)  # 通过风格编码器生成目标风格编码

    x_fake = nets.generator(x_real, s_trg, masks=masks)  # 使用生成器生成假图像
    out = nets.discriminator(x_fake, y_trg)  # 判别器判断生成的假图像
    loss_adv = adv_loss(out, 1)  # 对抗损失,目标是真

    # 风格重构损失(style reconstruction loss)
    s_pred = nets.style_encoder(x_fake, y_trg)  # 从生成的假图像中提取风格编码
    loss_sty = torch.mean(torch.abs(s_pred - s_trg))  # 风格重构损失,比较生成和目标风格编码的差异

    # 多样性敏感损失(diversity sensitive loss)
    if z_trgs is not None:
        s_trg2 = nets.mapping_network(z_trg2, y_trg)  # 生成第二个风格编码
    else:
        s_trg2 = nets.style_encoder(x_ref2, y_trg)  # 从参考图像中提取第二个风格编码
    x_fake2 = nets.generator(x_real, s_trg2, masks=masks)  # 生成第二个假图像
    x_fake2 = x_fake2.detach()  # 停止梯度计算
    loss_ds = torch.mean(torch.abs(x_fake - x_fake2))  # 计算两个假图像之间的差异,鼓励多样性

    # 循环一致性损失(cycle-consistency loss)
    masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None  # 使用 FAN 模型获取热图(如果 w_hpf > 0)
    s_org = nets.style_encoder(x_real, y_org)  # 提取输入图像的原始风格编码
    x_rec = nets.generator(x_fake, s_org, masks=masks)  # 将假图像转换回原始域
    loss_cyc = torch.mean(torch.abs(x_rec - x_real))  # 循环一致性损失,确保恢复的图像与原图像相似

    # 总损失,由对抗损失、风格重构损失、多样性损失和循环一致性损失组成
    loss = loss_adv + args.lambda_sty * loss_sty \
        - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc

    # 返回总损失以及每部分的损失值
    return loss, Munch(adv=loss_adv.item(),
                       sty=loss_sty.item(),
                       ds=loss_ds.item(),
                       cyc=loss_cyc.item())

判别器损失

它对真实图像和生成的假图像分别进行判别,并计算对应的对抗损失。对真实图像,函数计算其对抗损失(希望判别器将其判别为真)和 R1 正则化损失,以提高训练稳定性。对生成的假图像,生成器根据目标域的风格编码生成假图像,判别器再判断该假图像并计算对抗损失(希望判别器将其判别为假)。最后,将真实损失、假图像损失和正则化损失加和,作为判别器的总损失。

判别器损失计算源码
def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
    # 确保 z_trg 和 x_ref 中只有一个不为空
    assert (z_trg is None) != (x_ref is None)
    
    # 对真实图像进行操作
    x_real.requires_grad_()  # 允许对 x_real 进行梯度计算
    out = nets.discriminator(x_real, y_org)  # 使用判别器判断真实图像
    loss_real = adv_loss(out, 1)  # 真实图像的对抗损失,目标是 1
    loss_reg = r1_reg(out, x_real)  # R1 正则化损失,用于提高训练稳定性

    # 对生成的假图像进行操作
    with torch.no_grad():  # 假图像的生成不需要计算梯度
        if z_trg is not None:
            s_trg = nets.mapping_network(z_trg, y_trg)  # 通过映射网络生成目标风格编码
        else:  # x_ref 不为空时,通过风格编码器生成风格编码
            s_trg = nets.style_encoder(x_ref, y_trg)

        x_fake = nets.generator(x_real, s_trg, masks=masks)  # 生成假图像
    out = nets.discriminator(x_fake, y_trg)  # 判别器判断生成的假图像
    loss_fake = adv_loss(out, 0)  # 假图像的对抗损失,目标是 0

    # 总损失,由真实损失、假图像损失和正则化损失组成
    loss = loss_real + loss_fake + args.lambda_reg * loss_reg

    # 返回总损失以及每部分的损失值
    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2155398.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

探索AI编程新境界:aider库揭秘

文章目录 **探索AI编程新境界:aider库揭秘**背景:为何选择aider?简介:aider是什么?安装指南:如何安装aider?功能演示:aider的简单用法实战应用:aider在不同场景下的使用常…

【RabbitMQ】应用问题

RabbitMQ 应用问题 1. 幂等性保障1.1 幂等性介绍1.2 解决⽅案全局唯⼀ID业务逻辑判断 2. 顺序性保障2.1 顺序性保障介绍2.2 顺序性保障⽅案 3. 消息积压问题3.1 原因分析3.2 解决⽅案 1. 幂等性保障 1.1 幂等性介绍 幂等性是数学和计算机科学中某些运算的性质, 它们可以被多次…

华为静态路由(route-static)

静态路由的组成 在华为路由器中,使用ip route-static命令配置静态路由。 一条静态路由主要包含以下要素: 目的地址:数据包要到达的目标IP地址 子网掩码:用于指定目的地址的网络部分和主机部分 下一跳地址(可选&#…

【yolo破损纸板-包装盒-快递袋缺陷检测】

yolo破损纸板-包装盒-快递袋缺陷检测 破损纸质包装盒检测方盒型快递包裹检测 破损纸质包装盒检测 数据集合模型 可视化 方盒型快递包裹检测 数据集和模型 train: ../train/images val: ../valid/images test: ../test/images nc: 1 names: - box_packet可视化

理解JVM中的死锁:原因及解决方案

死锁是并发应用程序中的常见问题。在此类应用程序中,我们使用锁定机制来确保线程安全。此外,我们使用线程池和信号量来管理资源消耗。然而,在某些情况下,这些技术可能会导致死锁。 在本文中,我们将探讨死锁、死锁出现…

蓝桥杯模块一:LED指示灯的基本控制

模块训练一:LED指示灯的基本控制 模块1到模块13都是通过I\O模式进行设计 一、电路图 二、电路分析 1.74HC573锁存器介绍 OE端接地,上电即工作,控制LE端,当LE端接高电平时,锁存器开始工作,接通D和Q 2.电路工作原理分析…

C语言 | Leetcode C语言题解之第415题字符串相加

题目: 题解: char* addStrings(char* num1, char* num2) {int i strlen(num1) - 1, j strlen(num2) - 1, add 0;char* ans (char*)malloc(sizeof(char) * (fmax(i, j) 3));int len 0;while (i > 0 || j > 0 || add ! 0) {int x i > 0 ?…

SpringCloud入门(五)Nacos注册中心(上)

国内公司一般都推崇阿里巴巴的技术,比如注册中心,SpringCloudAlibaba也推出了一个名为Nacos的注册中心。Dynami Naming and Configuration Service。是阿里巴巴2018年7月开源的项目。 Nacos是阿里巴巴的产品,现在是SpringCloud中的一个组件。…

nuget包管理

1、下载 下载nuget 下载nuget.exe,配置系统环境变量,打开电脑属性一高级系统设置一环境变量一系统变量,选择Path,添加nuget.exe目录 2、常用命令 nuget install System.Data.SQLITE -SolutionDirectory D:\NugetPackages\ -Packa…

基于BiGRU+Attention实现风力涡轮机发电量多变量时序预测(PyTorch版)

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对…

37. Vector3与模型位置、缩放属性

本文章给通过组对象Group (opens new window)给大家讲解一下threejs层级模型或树结构的概念。 Group层级模型(树结构)案例 下面代码创建了两个网格模型mesh1、mesh2,通过THREE.Group类创建一个组对象group,然后通过add方法把网格模型mesh1、mesh2作为设置为组对象g…

【Godot4.3】GraphEdit全解析(1) - 基础介绍

概述 最早系统性的讲述Godot的GraphEdit和GraphNode的教程应该是Hi小胡的了,也有小伙伴已经设计出一些插件或小应用用于辅助自己的项目。或者更直观的你可以去看看B站的Godot的Visual Shader教程。 我是学了好几次,学完就忘了用,本篇是基于…

Java只有国人在搞了?

从Java诞生到现在,在全球一直属于最大的开发平台,拥有着世界上最多的开发者和最活跃的社区。你说Java只有国人在搞就有点过分了,Java中常用的主流框架全是外国人写的,虽说阿里也为Java做了很多贡献,但你还真没有资格说…

代码随想录Day 52|题目:101.孤岛的面积、102.沉没孤岛、103.水流问题、104.建造最大岛屿

提示:DDU,供自己复习使用。欢迎大家前来讨论~ 文章目录 图论part03题目一:101.孤岛的总面积解题思路DFS**BFS** 题目二:102. 沉没孤岛解题思路 题目三:103. 水流问题解题思路优化 题目四:104.建造最大岛屿…

Windows11+Microsoft MPI v10.1.3 安装配置记录

WindowsMicrosoft MPI v10.1.3 安装配置记录 MS-MPI 安装VS中进行配置属性管理器-添加新项目属性表VC目录-包含目录链接器-常规-附加库目录链接器-输入-附加依赖项 测试 某个项目需要MPI支持,在此记录MS MPI的安装配置过程。 MS-MPI 安装 在微软官网下载 两个都下…

去中心化的力量:探索Web3的分布式网络

Web3作为一种新兴的网络架构,代表了对互联网发展的一种探索。与传统的中心化互联网模式相比,Web3致力于通过去中心化的方式构建更加开放和透明的数字世界。本文将探讨Web3的核心理念、技术实现及其潜在应用。 一、去中心化的核心理念 Web3的去中心化理…

深度学习02-pytorch-06-张量的形状操作

在 PyTorch 中,张量的形状操作是非常重要的,可以让你灵活地调整和处理张量的维度和数据结构。以下是一些常用的张量形状函数及其用法,带有详细解释和举例说明: 1. reshape() 功能: 改变张量的形状,但不改变数据的顺序…

Stable Diffusion 使用详解(12)--- 设计师风格变换

目录 背景 seg模型(语义分割) 描述 原理 实战-装修风格变换 现代风格 欧式风格转换 提示词及相关参数设置 模型选择 seg cn 加持 效果 还能做点啥 问题 解决方法 出图效果 二次优化调整 二次出图效果 地中海风格转换 参数修改 效果 …

软硬件项目运维方案(Doc原件完整版套用)

1 系统的服务内容 1.1 服务目标 1.2 信息资产统计服务 1.3 网络、安全系统运维服务 1.4 主机、存储系统运维服务 1.5 数据库系统运维服务 1.6 中间件运维服务 2 运维服务流程 3 服务管理制度规范 3.1 服务时间 3.2 行为规范 3.3 现场服务支持规范 3.4 问题记录规范…

C++容器list底层迭代器的实现逻辑~list相关函数模拟实现

目录 1.两个基本的结构体搭建 2.实现push_back函数 3.关于list现状的分析(对于我们如何实现这个迭代器很重要) 3.1和string,vector的比较 3.2对于list的分析 3.3总结 4.迭代器类的封装 5.list容器里面其他函数的实现 6.个人总结 7.代码附录 1.两…