2、StarGAN V2
StarGAN 论文链接:StarGAN
StarGAN V2 论文链接:StarGAN V2
在介绍StarGAN V2之前,我们先对StarGAN有一定的了解,StarGAN V2只是在StarGAN的基础上做出了改进,基本的架构是没有变的,只是将风格编码做成了向量的形式,使得风格编码也是可以学习的。
StarGAN
StarGAN的出发点
StarGAN(Star Generative Adversarial Network)是一种生成对抗网络(GAN)的变体,用于图像到图像的多域转换任务。StarGAN 的核心特点是,它可以在单一模型中实现多域图像转换,而不需要为每个领域的转换训练不同的模型。其实就是来解决在CycleGAN中转化一种风格就需要训练一个模型的问题,设计一种编码来实现一个生成器和一个判别器能够生成多种风格,解决了CycleGAN的弊端。
StarGAN架构图
StarGAN为了解决CycleGAN每一个风格需要训练一个模型,并且需要多个生成器和判别器的问题,StarGAN采用了风格编码来实现只需要一个生成器和一个判别器,但是总体思想仍然采用CycleGAN的思想来设计损失函数。
生成器:
- 在StarGAN中,生成器的输入不仅是图像,还包含目标域的域标签(即风格编码)。生成器会根据该标签生成属于目标域的图像。
- 生成器同时使用了循环一致性损失(cycle-consistency loss),这是借鉴了CycleGAN的思想。通过将生成的图像转换回原始域,以确保生成图像保留了输入图像的关键信息。
- 目标是通过风格编码使得生成器能够将一张图片从一个域(如人脸图片)转换为多个目标域(如不同表情、发型或年龄),并在多个域之间进行切换。
判别器:
- StarGAN的判别器不仅需要判断图像的真假(真实图像 vs. 生成图像),还需要判别该图像属于哪个域(风格编码)。
- 判别器会输出多个域的分类信息,并在真假分类的同时,判断生成的图像是否符合指定的域标签。
损失函数:
- 对抗性损失:用于保证生成器生成的图像能够欺骗判别器。
- 域分类损失:用于确保生成的图像与目标域标签匹配。
- 循环一致性损失:用于确保生成图像能够还原回原始域,以保持输入的主要特征。
StarGAN V2
StarGAN V2出发点
StarGAN V2的出发点来自于StarGAN中使用的编码是一些固定的01编码,是不可学习,而StarGAN V2则在风格编码做出来改进,将风格编码初始化成向量,同时也可以通过原始输入图像来生成风格编码,而生成风格编码的网络是可学习的,使的风格更加的差异化,并且生成的图像风格更加准确。模型设计主要流程上并没有做出改动,主要在于损失函数的改动。理解损失函数也是掌握对抗生成网络的关键。
模型架构图
1. 生成器(Generator)
StarGAN V2 的生成器与传统的 GAN 不同,它融合了风格编码和图像转换的思想。生成器的主要目标是将输入图像转换为不同风格的图像。
生成器的核心组成部分:
- 输入:生成器的输入不仅包括要转换的图像,还包括目标风格编码(可以是从风格编码器得到的风格向量,或者是随机采样的向量)。
- 风格编码器:StarGAN V2 引入了一个风格编码器,它可以从目标图像中提取出风格信息,将其表示为风格向量。这样,生成器可以利用不同的风格编码生成对应风格的图像。
- 结构设计:生成器采用了基于卷积的网络架构,但通过风格向量来调控生成过程中的特征图。这使得生成器可以生成具有不同风格特征的图像。
- 多样性建模:生成器能够通过不同的风格编码生成多个同一源图像的多样化风格变化。这依赖于生成器对风格编码的处理,使得输出图像既能够保持输入的语义信息,又能够呈现目标风格。
2. 判别器(Discriminator)
StarGAN V2 的判别器不仅要判断图像的真假,还要判断生成图像是否符合目标风格。它负责区分生成器生成的图像和真实图像,并检测生成图像的风格是否与目标域匹配。
判别器的核心组成部分:
- 输入:判别器接收图像输入,同时附带目标风格标签。它的任务是判断输入的图像是否来自真实的目标域,并判断生成器生成的图像是否匹配目标风格。
- 多域分类:判别器输出的是多分类结果。除了判断图像是真实还是生成的,它还需要对图像的风格域进行分类,确保生成的图像符合目标风格。
- PatchGAN 设计:判别器通常采用 PatchGAN(局部感知)的设计,它对图像的每个局部区域进行真假和风格分类。这种设计有助于判别器更好地捕捉图像的局部特征,尤其是风格特征,从而在视觉上确保生成的图像看起来自然
损失函数的改进:
对抗损失依然是生成对抗网络的核心,用于确保生成图像能欺骗判别器。
风格一致性损失:StarGAN V2通过风格一致性损失来确保生成的图像能保持输入图像的关键信息,并且使风格变化是自然且符合目标域的。
循环一致性损失:与StarGAN类似,StarGAN V2依然采用了循环一致性损失来保证生成图像在转换回原始域时能保持输入图像的主要特征。
多样性损失: StarGAN V2还通过引入多样性损失,确保生成的图像在同一目标域内保持足够的多样性,而不仅仅是简单的风格映射。通过学习不同的风格编码,生成器可以在同一个目标域中生成多个不同风格的图像。
生成器损失
包含下面四种:对抗损失、风格一致性损失、多样性损失、循环一致性损失
对抗损失
生成对抗网络的核心,用于确保生成图像能欺骗判别器。
公式:
风格一致性损失
风格一致性损失,就是保证模型生成的图片的风格和需要生成的风格越接近越好。首先使用x和风格s生成一张图片,然后再用Style encoder进行编码,获得生成后图片的风格编码,计算它和需要生成风格编码之间的差距作为风格一致性损失。
公式:
多样性损失
多样性损失是确保生成的图像在同一目标域内保持足够的多样性,而不仅仅是简单的风格映射。通过学习不同的风格编码,生成器可以在同一个目标域中生成多个不同风格的图像。简单来说,就是两者的标签一样的,同时采用同样的Mapping network进行编码,但是要使编码出来风格编码差异性越大越好,这样采用生成多种不同风格的图像,学习的是Mapping network。
公式:
循环一致性损失
循环一致性损失和CycleGAN的思想是一样的,要求我们生成出来的图片必须经过还原后还是能够与原来的图像越接近越好。从公式中可以看出,先对x和某中风格编码s生成图像,在使用x经过style encoder生成s1,然后将s1和生成的图像输入生成器,得到图片与原来的图片做比较,这样就得到原始图像和还原后图像之间的差异作为循环一致性损失。
公式:
最终Loss值公式:
Ladv 是对抗损失,Lds前面的负号,说明他们之间的差异越大越好。
生气器损失计算源码
def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
# 确保 z_trgs 和 x_refs 其中一个不为空
assert (z_trgs is None) != (x_refs is None)
# 当 z_trgs 不为空时,解包 z_trg 和 z_trg2
if z_trgs is not None:
z_trg, z_trg2 = z_trgs
# 当 x_refs 不为空时,解包 x_ref 和 x_ref2
if x_refs is not None:
x_ref, x_ref2 = x_refs
# 对抗损失(adversarial loss)
if z_trgs is not None:
s_trg = nets.mapping_network(z_trg, y_trg) # 通过映射网络生成目标风格编码
else:
s_trg = nets.style_encoder(x_ref, y_trg) # 通过风格编码器生成目标风格编码
x_fake = nets.generator(x_real, s_trg, masks=masks) # 使用生成器生成假图像
out = nets.discriminator(x_fake, y_trg) # 判别器判断生成的假图像
loss_adv = adv_loss(out, 1) # 对抗损失,目标是真
# 风格重构损失(style reconstruction loss)
s_pred = nets.style_encoder(x_fake, y_trg) # 从生成的假图像中提取风格编码
loss_sty = torch.mean(torch.abs(s_pred - s_trg)) # 风格重构损失,比较生成和目标风格编码的差异
# 多样性敏感损失(diversity sensitive loss)
if z_trgs is not None:
s_trg2 = nets.mapping_network(z_trg2, y_trg) # 生成第二个风格编码
else:
s_trg2 = nets.style_encoder(x_ref2, y_trg) # 从参考图像中提取第二个风格编码
x_fake2 = nets.generator(x_real, s_trg2, masks=masks) # 生成第二个假图像
x_fake2 = x_fake2.detach() # 停止梯度计算
loss_ds = torch.mean(torch.abs(x_fake - x_fake2)) # 计算两个假图像之间的差异,鼓励多样性
# 循环一致性损失(cycle-consistency loss)
masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None # 使用 FAN 模型获取热图(如果 w_hpf > 0)
s_org = nets.style_encoder(x_real, y_org) # 提取输入图像的原始风格编码
x_rec = nets.generator(x_fake, s_org, masks=masks) # 将假图像转换回原始域
loss_cyc = torch.mean(torch.abs(x_rec - x_real)) # 循环一致性损失,确保恢复的图像与原图像相似
# 总损失,由对抗损失、风格重构损失、多样性损失和循环一致性损失组成
loss = loss_adv + args.lambda_sty * loss_sty \
- args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc
# 返回总损失以及每部分的损失值
return loss, Munch(adv=loss_adv.item(),
sty=loss_sty.item(),
ds=loss_ds.item(),
cyc=loss_cyc.item())
判别器损失
它对真实图像和生成的假图像分别进行判别,并计算对应的对抗损失。对真实图像,函数计算其对抗损失(希望判别器将其判别为真)和 R1 正则化损失,以提高训练稳定性。对生成的假图像,生成器根据目标域的风格编码生成假图像,判别器再判断该假图像并计算对抗损失(希望判别器将其判别为假)。最后,将真实损失、假图像损失和正则化损失加和,作为判别器的总损失。
判别器损失计算源码
def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
# 确保 z_trg 和 x_ref 中只有一个不为空
assert (z_trg is None) != (x_ref is None)
# 对真实图像进行操作
x_real.requires_grad_() # 允许对 x_real 进行梯度计算
out = nets.discriminator(x_real, y_org) # 使用判别器判断真实图像
loss_real = adv_loss(out, 1) # 真实图像的对抗损失,目标是 1
loss_reg = r1_reg(out, x_real) # R1 正则化损失,用于提高训练稳定性
# 对生成的假图像进行操作
with torch.no_grad(): # 假图像的生成不需要计算梯度
if z_trg is not None:
s_trg = nets.mapping_network(z_trg, y_trg) # 通过映射网络生成目标风格编码
else: # x_ref 不为空时,通过风格编码器生成风格编码
s_trg = nets.style_encoder(x_ref, y_trg)
x_fake = nets.generator(x_real, s_trg, masks=masks) # 生成假图像
out = nets.discriminator(x_fake, y_trg) # 判别器判断生成的假图像
loss_fake = adv_loss(out, 0) # 假图像的对抗损失,目标是 0
# 总损失,由真实损失、假图像损失和正则化损失组成
loss = loss_real + loss_fake + args.lambda_reg * loss_reg
# 返回总损失以及每部分的损失值
return loss, Munch(real=loss_real.item(),
fake=loss_fake.item(),
reg=loss_reg.item())