StarGANv2: Diverse Image Synthesis for Multiple Domains论文解读及实现

news2024/10/6 18:11:39

StarGAN v2: Diverse Image Synthesis for Multiple Domainsp
github:https://github.com/clovaai/stargan-v2

0 小结

0.1 模型 4个

模型四个:

  • Generator: G网络
    输入图片x,和风格编码s(可以是F网络或者E网络生成的s),生成x_fake
  • Mapping network: F网络
    输入随机隐变量z和领域风格y (onehot类型数据) ,生成风格编码s
  • Style encoder:E网络
    输入图片x和领域风格y,生成风格编码s
  • Discriminator:D网络
    输入图片x和领域风格y,判断图片X是否是真实的y领域

0.2 损失函数 7项(生成器损失4项,判别器损失3项)

  • 生成器总损失(4项)
    生成器图片判断为真+真实图片和生成图片的风格编码(E网络)s距离最小+不同隐变量z 生成的风格编码s距离大+重构前后的x距离小
  • 判别器总损失(3项)
    真实图片判断为真+生成器图片判别为假+ 正则化损失

1 模型架构

模型主要架构由四部分组成
①Generator、②Mapping network、③Style encoder、④Discriminator

  • Generator:G网络
    生成模型G将输入图片x转换成 输出图片G(x,s),反映了一个领域的独有风格编码s。s是有Maping network F或者风格编码E生成。s被设计表征为领域y的风格。

  • Mapping network: F网络
    给定一个隐变量z和一个领域y,mapping network F生成一个风格 s=F(z),F由一层MLP和多个输出分支组成,分支代表了领域的所有风格。通过随机采样不同的隐变量z,F能高效的学习各领域的风格表征。

  • Style encoder:E网络
    给定图片X和相应的领域y,encoder E挖掘风格编码 s=E(x),E和上面的F类似。使用不同的参考图片,E可以产生不同风格的编码s。

  • Discriminator:D网络
    判别器D是多任务判别器,由多个输出分支组成,每个分支Dy学习一个二分类,判断图片X是否是真实的y领域,或者由G生成的假图 G(x,s).

在这里插入图片描述

2 训练目标

2.1Adversarial objective.

对抗损失:
训练期间,随机采样隐变量z和领域y,通过F函数,生成风格编码s ,
风格编码: s ˉ = F y ˉ ( z ) 风格编码: \bar s=F_{\bar y}(z) 风格编码:sˉ=Fyˉ(z)
生成网络G,将图片X和上面的风格编码S作为输入,生成图片:
生成图片: G ( x , s ˉ ) 生成图片: G(x,\bar s) 生成图片:G(x,sˉ)
对抗损失函数为:
L a d v = E x , y [ l o g D y ( x ) ] + E x , y ˉ , z [ l o g ( 1 − D y ˉ ( G ( x , s ˉ ) ) ) ] ( 1 ) L_{adv}=E_{x,y}[logD_y(x)]+E_{x,\bar y,z}[log(1-D_{\bar y}(G(x,\bar s)))] \qquad (1) Ladv=Ex,y[logDy(x)]+Ex,yˉ,z[log(1Dyˉ(G(x,sˉ)))](1)

D_y:是y领域的判别器
F: 是提供y领域的风格编码s
G:输入图片和风格编码s,生成新图片

2.2 Style reconstruction

风格重构损失
使得前后的风格距离最小
L s t y = E x , y ˉ , z [ ∣ ∣ s ˉ − E y ˉ ( G ( x , s ˉ ) ) ∣ ∣ 1 ] ( 2 ) L_{sty}=E_{x,\bar y ,z}[||\bar s-E_{\bar y}(G(x,\bar s))||_1] \qquad (2) Lsty=Ex,yˉ,z[∣∣sˉEyˉ(G(x,sˉ))1](2)
E网络用来生成风格,上面有提到
(前面的E是求均值,后面的 E y ˉ E_{\bar y} Eyˉ是网络)

2.3 Style diversification

为了使生成器G产生更多风格图片,使得不同风格图片的距离尽可能大

L d s = E x , y ˉ , z 1 , z 2 [ ∣ ∣ G ( x , s ˉ 1 ) − G ( x , s ˉ 2 ) ∣ ∣ 1 ] ( 3 ) L_{ds}=E_{x,\bar y,z_1,z_2}[||G(x,\bar s_1)-G(x,\bar s_2)||_1]\qquad (3) Lds=Ex,yˉ,z1,z2[∣∣G(x,sˉ1)G(x,sˉ2)1](3)

s ˉ 1 和 s ˉ 2 \bar s_1和\bar s_2 sˉ1sˉ2是F在隐变量 z 1 和 z 2 条件下生成的 s ˉ i = F y ˉ ( z i ) f o r i ∈ 1 , 2 z_1和z_2条件下生成的 \bar s_i =F_{\bar y}(z_i) \quad for \quad i \in {1,2} z1z2条件下生成的sˉi=Fyˉ(zi)fori1,2

2.4 cycle consistency loss

循环一致损失
使得经过变换后的X与之前的X距离最小
L c y c = E x , y , y ˉ , z [ ∣ ∣ x − G ( G ( x , s ˉ ) , s ^ ) ∣ ∣ 1 ] ( 4 ) L_{cyc}=E_{x,y,\bar y,z}[||x-G(G(x,\bar s),\hat s)||_1] \qquad (4) Lcyc=Ex,y,yˉ,z[∣∣xG(G(x,sˉ),s^)1](4)

s ^ = E y ( x ) \hat s=E_y(x) s^=Ey(x)是E网络估计的风格code,y是原始的X的领域,使生成器G学会去保留原始的X的特征

2.5 full objective

将上面的损失函数求和,其中DS是最大化距离(所有用减号),其他是最小化
在这里插入图片描述

3 代码实现-网络

3.1 生成器网络:G网络

有多个卷积层组成
输入:图片 和风格编码s,
输入图片x shape [batch_size,3,512,512] ;s shape [batch_size,64]
输出:x_fake shape: [4,3,512,512]

        self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1) #(in_channels,out_channels,kernel_size,stride,padding)
        self.encode = nn.ModuleList()
        self.decode = nn.ModuleList()
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_in, 3, 1, 1, 0))

图片和风格编码会作如下融合,风格编码输出,gamma 和beta,然后和norm后的图片X进行如下计算 (1 + gamma) * self.norm(x) + beta

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta

  1. 损失函数
    生成器生成的图片,判别器判别为假
 x_fake = nets.generator(x_real, s_trg, masks=masks)
out = nets.discriminator(x_fake, y_trg)
loss_fake = adv_loss(out, 0)

3.2 MappingNetwork:F网络

由多层线性层组成
模型输入是隐变量z(随机生成)shape: [batch_size,dim],和风格领域y shape:[batch_size]
经过共享层后,对每个风格领域分别输出风格编码s, 输出 shpae [batch_size,64]

def forward(self, z, y):
    h = self.shared(z)
    out = []
    for layer in self.unshared:
        out += [layer(h)]
    out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
    idx = torch.LongTensor(range(y.size(0))).to(y.device)
    s = out[idx, y]  # (batch, style_dim)
    return s

3.3 StyleEncoder :E网络

输入x,y:
图片x shape [batch_size,3,256,256],领域风格 y shape [batch_size]
输出:
风格编码s shape [batch_size]

    def forward(self, x, y):
        h = self.shared(x)
        h = h.view(h.size(0), -1)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        s = out[idx, y]  # (batch, style_dim)
        return s

3.4 Discriminator:D网络

1) 模型输入输出
输入图片X,及对应的风格y
x shape [4,3,256,256] :[batch_size,channel_num,WH]
y为风格编码[1,1,1,0 ]:batch_size长度大小
输出shap[batch_size]:batch_size长度大小

        out = self.main(x)  # x shape  [4,3,256,256]
        out = out.view(out.size(0), -1)  # (batch, num_domains) # out shape ; [4,2]
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        out = out[idx, y]  # (batch)
        return out  # out shape [4,]

4 代码实现-损失函数

4.1 判别器损失函数

使所有真实的图片风格预测为1,
使所有G网络生成的图片预测为0

  1. 真实图片判断为真
out = nets.discriminator(x_real, y_org)
loss_real = adv_loss(out, 1)

二分类交叉熵损失

def adv_loss(logits, target):
assert target in [1, 0]
targets = torch.full_like(logits, fill_value=target)
loss = F.binary_cross_entropy_with_logits(logits, targets)
return loss

  1. 生成器生成的图片,判别器判别为假
 x_fake = nets.generator(x_real, s_trg, masks=masks)
out = nets.discriminator(x_fake, y_trg)
loss_fake = adv_loss(out, 0)
  1. 正则化损失
 loss_reg = r1_reg(out, x_real)
 def r1_reg(d_out, x_in):
    # zero-centered gradient penalty for real images
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
    return reg
  1. 判别器总损失
    真实图片判断为真+生成器图片判别为假+ 正则化损失
loss = loss_real + loss_fake + args.lambda_reg * loss_reg

4.2 生成器损失函数

  1. 使判别器将生成的图片全预测为真
    x_fake = nets.generator(x_real, s_trg, masks=masks)
    out = nets.discriminator(x_fake, y_trg)
    loss_adv = adv_loss(out, 1)
  1. 使真实图片和生成图片的E网络风格编码s,距离最小
    # style reconstruction loss
     s_trg = nets.style_encoder(x_ref, y_trg)
    s_pred = nets.style_encoder(x_fake, y_trg)
    loss_sty = torch.mean(torch.abs(s_pred - s_trg))

3)多样性损失函数
使得不同隐变量z 生成的风格编码s,最终生成的x_fake 距离尽可能大

    x_fake = nets.generator(x_real, s_trg, masks=masks)
    x_fake2 = nets.generator(x_real, s_trg2, masks=masks)
    x_fake2 = x_fake2.detach()
    loss_ds = torch.mean(torch.abs(x_fake - x_fake2))
  1. 循环一致损失
    使重构后的x_rec和原x_real距离小
    s_org = nets.style_encoder(x_real, y_org)
    x_rec = nets.generator(x_fake, s_org, masks=masks)
    loss_cyc = torch.mean(torch.abs(x_rec - x_real))
  1. 生成器总损失
    生成器图片判断为真+真实图片和生成图片的风格编码s距离最小+不同隐变量z 生成的风格编码s距离大+重构前后的x距离小
    loss = loss_adv + args.lambda_sty * loss_sty \
        - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/870725.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode 778. Swim in Rising Water【最小瓶颈路;二分+BFS或DFS;计数排序+并查集;最小生成树】2096

本文属于「征服LeetCode」系列文章之一,这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁,本系列将至少持续到刷完所有无锁题之日为止;由于LeetCode还在不断地创建新题,本系列的终止日期可能是永远。在这一系列刷题文章…

城市最短路

题目描述 下图表示的是从城市A到城市H的交通图。从图中可以看出,从城市A到城市H要经过若干个城市。现要找出一条经过城市最少的一条路线。 输入输出格式 输入格式: 无 输出格式: 倒序输出经过城市最少的一条路线 输入输出样例 输入样例…

【LeetCode75】第二十七题(933)最近的请求次数

目录 题目: 示例: 分析: 代码运行结果: 题目: 示例: 分析: 首先这是LeetCode75里第一道设计类的题目,这种类型的题目会比较新颖,就是按照题目要求来设计一个类。然后…

GIL 锁或将在 CPython 中成为可选项

哈喽大家好,我是咸鱼 几天前有媒体报道称,经过多次辩论,Python 指导委员会打算批准通过 PEP 703 提案,让 GIL(全局解释器)锁在 CPython 中成为一个可选项 PEP 703 提案主要目标是使 GIL 变成可选项&#…

二叉树的讲解

💓博主个人主页:不是笨小孩👀 ⏩专栏分类:数据结构与算法👀 刷题专栏👀 C语言👀 🚚代码仓库:笨小孩的代码库👀 ⏩社区:不是笨小孩👀 🌹欢迎大家三连关注&…

定义行业新标准?谷歌:折叠屏手机可承受20万次折叠

根据Patreon账户上的消息,Android专家Mishaal Rahman透露,谷歌计划推出新的硬件质量标准,以满足可折叠手机市场的需求。Android原始设备制造商(OEM)将需要完成谷歌提供的问卷调查,并提交样品设备进行严格审…

读书笔记 |【项目思维与管理】➾ 顺势而动

读书笔记 |【项目思维与管理】➾ 顺势而动 一、企业步入“终结者时代”二、过去成功的经验也许是最可怕的三、做好非重复性的事四、适应客户是出发点五、向知识型企业转变六、速度是决胜条件 💖The Begin💖点点关注,收藏不迷路💖 …

【C++ 学习 ⑬】- 详解 list 容器

目录 一、list 容器的基本介绍 二、list 容器的成员函数 2.1 - 迭代器 2.2 - 修改操作 三、list 的模拟实现 3.1 - list.h 3.2 - 详解 list 容器的迭代器 3.2 - test.cpp 一、list 容器的基本介绍 list 容器以类模板 list<T>&#xff08;T 为存储元素的类型&…

ruoyi-vue-v3.8.6-搭建

一、准备工作 环境&#xff1a; win10、MySQL8、JDKjdk1.8.0_311 redis6.2.6 IDEA 2022.3.3 maven3.9 Node v18.14.2 npm 9.5.0 版本&#xff1a; 若依框架官方文档&#xff1a;http://doc.ruoyi.vip/ 官网导航&#xff1a;http://120.79.202.7/ 若依项目地址&#xff…

将数组(矩阵)旋转根据指定的旋转角度scipy库的rotate方法

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 将数组(矩阵)旋转 根据指定的旋转角度 scipy库的rotate方法 关于下列代码说法正确的是&#xff1f; import numpy as np from scipy.ndimage import rotate a np.array([[1,2,3,4], …

腾讯云服务器竞价实例是什么?适用于什么行业?有啥优惠?

腾讯云服务器CVM计费模式分为包年包月、按量计费和竞价实例&#xff0c;什么是竞价实例&#xff1f;竞价实例和按量付费相类似&#xff0c;优势是价格更划算&#xff0c;缺点是云服务器实例有被自动释放风险&#xff0c;腾讯云服务器网来详细说下什么是竞价实例&#xff1f;以及…

八、解析应用程序——分析应用程序(1)

文章目录 一、确定用户输入入口点1.1 URL文件路径1.2 请求参数1.3 HTTP消息头1.4 带外通道 二、确定服务端技术2.1 提取版本信息2.2 HTTP指纹识别2.3 文件拓展名2.4 目录名称2.5 会话令牌2.6 第三方代码组件 小结 枚举尽可能多的应用程序内容只是解析过程的一个方面。分析应用程…

Leetcode35 搜索插入位置

给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;并返回其索引。如果目标值不存在于数组中&#xff0c;返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 解析&#xff1a;以示例2来举例&#xff0c;left 0,right 3,mid 1…

【CSS】CSS 布局——弹性盒子

Flexbox 是一种强大的布局系统&#xff0c;旨在更轻松地使用 CSS 创建复杂的布局。 它特别适用于构建响应式设计和在容器内分配空间&#xff0c;即使项目的大小是未知的或动态的。Flexbox 通常用于将元素排列成一行或一列&#xff0c;并提供一组属性来控制 flex 容器内的项目行…

SpringMVC的异常处理机制

1、简介 系统中异常包括两类&#xff1a;预期异常和运行时异常RuntimeException&#xff0c;前者通过捕获异常从而获取异常信息&#xff0c;后 者主要通过规范代码开发、测试等手段减少运行时异常的发生。 系统的Dao、Service、Controller出现都通过throws Exception向上抛出…

BFS 五香豆腐

题目描述 经过谢老师n次的教导&#xff0c;dfc终于觉悟了——过于腐败是不对的。但是dfc自身却无法改变自己&#xff0c;于是他找到了你&#xff0c;请求你的帮助。 dfc的内心可以看成是5*5个分区组成&#xff0c;每个分区都可以决定的的去向&#xff0c;0表示继续爱好腐败&…

【图像分类】CNN + Transformer 结合系列.4

介绍两篇利用Transformer做图像分类的论文&#xff1a;CoAtNet&#xff08;NeurIPS2021&#xff09;&#xff0c;ConvMixer&#xff08;ICLR2022&#xff09;。CoAtNet结合CNN和Transformer的优点进行改进&#xff0c;ConvMixer则patch的角度来说明划分patch有助于分类。 CoAtN…

多目标优化算法之樽海鞘算法(MSSA)

樽海鞘算法的主要灵感是樽海鞘在海洋中航行和觅食时的群聚行为。相关文献表示&#xff0c;多目标优化之樽海鞘算法的结果表明&#xff0c;该算法可以逼近帕雷托最优解&#xff0c;收敛性和覆盖率高。 通过给SSA算法配备一个食物来源库来解决第一个问题。该存储库维护了到目前为…

el-select 动态添加多个下拉框

实现的效果如下: 主要的代码如下: 这是formdata 的结构 主要的逻辑 在这个 methods

ubuntu supervisor 部署 python 项目

ubuntu supervisor 查看系统是否可用 cuda 初环境与设备安装 supervisor 环境创建 Supervisor 配置文件启动 Supervisor 服务管理项目 本篇文章将介绍 ubuntu supervisor 部署 python 项目 Supervisor 是一个用于管理和监控进程的系统工具。它的主要功能是确保系统中的进程持续…