深度学习(33)——CycleGAN(2)
完整项目在在这里:欢迎造访
文章目录
- 深度学习(33)——CycleGAN(2)
- 1. Generator
- 2. Discriminator
- 3. fake pool
- 4. loss定义
- 5. 模型参数量
- 6. debug 记录
数据格式:
1. Generator
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
特征提取部分使用backbone是resnet(可选择的,可以换其他模型做backbone)
- 上采样一共9个ResNet Block
- 下采样部分
2. Discriminator
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
在上一节说过,discriminator就是一个辨别真假的二分类模型,输入还是一张三通道的图像,最终判断这张图片是真是假。
3. fake pool
用于保存生成的fake image
self.fake_A_pool = ImagePool(opt.pool_size)
4. loss定义
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.MSE
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
- GANLoss 根据gan_mode定义,此处为MSELoss
5. 模型参数量
6. debug 记录
-
set_input(input): 得到real_A,real_B
-
optimize_parameters(): 计算loss做反向传播
-
forward(): 生成fake_A,fake_B,rec_A,rec_B
- generatorA先根据real_A生成fake_B
- generatorB使用fake_B生成rec_A
- generatorB根据real_B生成fake_A
- generatorA使用fake_A生成rec_B
-
backward_G():反向传播
- 计算identity_loss:generatorA是输入real_A得到fake_B的,那现在输入real_B是不是也可以生成和real_B差不多,将这个生成的命名为idt_A,idt_A和real_B之间会存在identity_loss,同理idt_B和real_A之间也存在identity_loss
- 计算generator_loss:generatorA生成的feak_B的loss,我们是希望feak_B是骗过discriminatorA的,所以希望discriminatorA认为是真的A,所以这里将fake_B与True做MSEloss,同理希望discriminatorB认为fake_A是真的B
- 计算cycle_loss:real_A经过generatorA生成fake_B,fake_B经过经过generatorB返回生成rec_A,计算这样循环生成的A和真实A之间的loss,B也同理。
- 最终的generator_loss是上面三者的和,因为有AB之分,所以一共有6项
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-
backward_D_A(): 计算discriminator_A的loss
-
backward_D_B(): 计算discriminator_B的loss
-
注
- 当optimizer generator的时候discriminator设置为无梯度,不反向传播。
self.set_requires_grad([self.netD_A, self.netD_B], False)
就酱,欢迎提问讨论,886~