paper: https://arxiv.org/pdf/1703.10593.pdf
github: https://github.com/aitorzip/PyTorch-CycleGAN
1 cycleGAN 小结
- 网络:
生成器2个:G_A,G_B
判别器两个: D_A,D_B - 损失函数8个
6个生成器损失函数
2个判别器损失函数
1.1 数据
- fake_B
原始A,经过生成器G_A,生成fake_B
A->G_A = fake_B - rec_A
原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
A->G_A ->G_B = rec_A - fake_A
B->G_B = fake_A
原始B,经过生成器G_B,生成fake_A - rec_B
B->G_B ->G_A = rec_B
原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
1.2 损失函数
6个生成器损失,2个判别器损失
1)6个生成器损失:
- 生成器一致损失 2个
数据B经过生成器G_A,后生成的B,与原始B距离最小。A同理
① B-> G_A->B’ : 使得B 与B’距离最小
② A-> F_B->A’ : 使得A 与A’距离最小 - 生成器损失 2个
生成器生成的数据,让判别器都判别为真
③ MSELoss(D_A(fake_B), True)
④ MSELoss(D_B(fake_A), True) - 循环一致损失 2个
⑤ 原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
A->G_A ->G_B = rec_A
⑥原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
B->G_B ->G_A = rec_B
2) 2个判别器损失
- 判别器损失 2个
使真实图片为判别为真,假图片判别为假
① D_A
pred_real = D_A(real); pred_fake= D_A(fake)
MSELoss(pred_real, True)+MSELoss(pred_fake, False)
②D_B
pred_real = D_B(real); pred_fake= D_B(fake)
MSELoss(pred_real, True)+MSELoss(pred_fake, False)
2 模型架构
- 两个生成网络:
G: X——> Y ,输入X生成Y
F: Y——> X :输入Y生成X - 两个判别网络:
D_A: 用于区分真实A和 F(B)生成的假A.
D_B:用于区分真实B和 G(A)生成的假B.
3 损失函数
3.1 Adversarial Loss
对抗损失:
- 对于生成器 G: X——> Y
生成器G_X: 最小化以下目标函数
对于判别器D_Y:最大化以下目标函数
- 对于生成器 F: Y——> X,损失函数同上
生成器F_Y,使判别器D_X判断为真
对于判别器D_X:是真实X判断为真,F_Y生成的X,判断为假。
L G A N ( F , D X , Y , X ) L_{GAN}(F,D_X,Y,X) LGAN(F,DX,Y,X)
3.2 Cycle Consistency Loss
循环一致损失,即 X 经过生成器G_x后 得到Y,Y再过F_Y生成X,使得前后生成的X距离最小。
1) 前向一致损失
即从x 经过网络后还原为x的过程
X
−
>
G
(
x
)
−
>
F
(
G
(
x
)
)
=
X
X -> G(x) -> F(G(x)) =X
X−>G(x)−>F(G(x))=X
2)反向一致损失
即y从经过网络后还原为y的过程
Y
−
>
F
(
y
)
−
>
G
(
F
(
y
)
)
=
Y
Y -> F(y) -> G(F(y)) =Y
Y−>F(y)−>G(F(y))=Y
3.3 Full Objective
4 代码实现
4.1网络结构
- 1 生成器A :
netG_A:可以选用resnet,或者unet网络
输入数据A,生成数据B
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
- 2 生成器B:
netG_B: 与netG_A网络一样
输入数据B,生成数据A
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
- 3 判别器A
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
- 4 判别器B
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
5 损失
5.1 前向传播数据
- self.fake_B
原始A,经过生成器G_A,生成fake_B
A->G_A = fake_B - self.rec_A
原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
A->G_A ->G_B = rec_A - self.fake_A
B->G_B = fake_A
原始B,经过生成器G_B,生成fake_A - self.rec_B
B->G_B ->G_A = rec_B
原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
5.2 生成器一致损失
数据B经过生成器G_A,后生成的B,与原始B距离最小。
B-> G_A->B’ : 使得B 与B’距离最小
A-> F_B->A’ : 使得A 与A’距离最小
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A= self.L1Loss(self.idt_A, self.real_B)
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.L1Loss(self.idt_B, self.real_A)
5.2 生成器损失
生成器生成的数据,让判别器都判别为真
(备注:判别器输出不是一个值,而是一个矩阵,需要使判别器输出矩阵每一个值都接近1)
# GAN loss D_A(G_A(A))
self.loss_G_A = self.MSELoss(self.netD_A(self.fake_B), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterioMSELossGAN(self.netD_B(self.fake_A), True)
5.3 循环一致损失
使得重构的A与原始A距离最近,使用L1Loss
- self.rec_A
原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
A->G_A ->G_B = rec_A - self.rec_B
B->G_B ->G_A = rec_B
原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.L1Loss(self.rec_A, self.real_A)
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.L1Loss(self.rec_B, self.real_B)
5.4 生成器总loss
上面6个生成器损失求和即为总的生成损失函数
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
5.5 判别器损失
判别器:使真实图片为判别为真,假图片判别为假
pred_real = netD(real)
loss_D_real = self.MSELoss(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.MSELoss(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5