使用Pytorch从零开始构建WGAN

news2025/3/1 21:03:14

引言

在考虑生成对抗网络的文献时,Wasserstein GAN 因其与传统 GAN 相比的训练稳定性而成为关键概念之一。在本文中,我将介绍基于梯度惩罚的 WGAN 的概念。文章的结构安排如下:

  1. WGAN 背后的直觉;
  2. GAN 和 WGAN 的比较;
  3. 基于梯度惩罚的WGAN的数学背景;
  4. 使用 PyTorch 从头开始​​在;
  5. CelebA-Face 数据集上实现;
  6. WGAN 结果讨论。

WGAN 背后的直觉

GAN 最初由Ian J. Goodfellow 等人发明。在 GAN 中,有一个由生成器和判别器进行的双玩家最小最大游戏。早期 GAN 的主要问题是模式崩溃和梯度消失问题。为了克服这些问题,长期以来发明了许多技术。WGAN 是试图克服传统 GAN 的这些问题的方法之一。

GAN 与 WGAN

与传统的 GAN 相比,WGAN 有一些改进/变化。

  1. 评论家而非判别器;
  2. W-Loss 代替 BCE Loss;
  3. 使用梯度惩罚/权重剪裁进行权重正则化。

传统GAN的判别器被“Critic”取代。从实现的角度来看,这只不过是最后一层没有 Sigmoid 激活的判别器。

我们稍后将讨论 WGAN 损失函数和权重正则化。

数学背景

损失函数

这是基于梯度惩罚的 WGAN 的完整损失函数。

等式 1. 具有梯度惩罚的完整 WGAN 损失函数 — [3]
在这里插入图片描述
看起来很吓人吧?让我们分解一下这个方程。

第 1 部分:原始批评损失
在这里插入图片描述

该方程产生的值应由生成器正向最大化,同时由批评家负向最大化。请注意,这里的 x_CURL 是生成器 (G(z)) 生成的图像。

这里,D 在最后一层没有 Sigmoid 激活,因此 D(*) 可以是任何实数。这给出了地球移动器的真实分布和生成分布之间的距离的近似值 - [1]。我们在这里想做的是,

  1. 评论家的观点:通过最大化等式 2结果的负值/最小化正值,尽可能地将评论家对真实图像和生成图像的输出分布分开。这反映了评论家的目标,即为真实图像提供更高的分数,为更低的分数到生成的图像。
  2. 生成器的观点:尝试通过以相反的方向分离真实图像和生成图像的输出分布来抵消评论家的努力。这最终使式 2 的结果的正值最大化。这反映了生成器的目标是通过欺骗 Critic 来提高生成图像的 Critic 分数。
  • 在这里你可能已经注意到,Critic over Discriminator这个名字的出现是因为 Critic 不区分真假图像,只是给出一个无界的分数。

为了确保方程有效,我们需要确保 Critic 函数是 1-Lipschitz 连续的 — [1]。

1-Lipschitz连续性

函数 f(x) 是 1-L 连续的,梯度应始终小于或等于 1。

为了确保这种1-Lipschitz连续性,文献中主要提出了2种方法。

  1. Weight Clipping——这是 WGAN 论文 [2] 附带的初始方法;
  2. 梯度惩罚方法——这是在最初的论文之后作为改进提出的[3]。

在本文中,我们将重点关注基于梯度惩罚的 WGAN。

第二部分:梯度惩罚
在这里插入图片描述
这是 Gulrajani 等人提出的梯度惩罚。——[3]。这里我们通过减小 Critic 梯度的 L2 范数与 1 之间的平方距离来强制 Critic 的梯度为 1。注意,我们不能强制 Critic 的梯度为 0,因为这会导致梯度消失问题。

等等!x(^)是什么?

考虑到 1-Lipschitz 连续性的定义,所有 x 的梯度应≤1。但实际上,确保所有可能的图像都满足这种条件是很困难的。因此,我们使用 x(^) 表示使用真实图像和生成图像作为梯度惩罚的数据点的随机插值图像。这确保了 Critic 的梯度将通过查看训练期间遇到的一组公平的数据点/图像进行正则化。

Pytorch实现

在这里,我将介绍大家应该做的必要更改,以便将传统的 GAN 更改为 WGAN。

对于下面的实现,我将使用我在之前有关 DCGAN 的文章中详细解释的模型和训练原理。

数据集

Celeba-face 数据集用于训练。下载、预处理、制作数据加载器脚本如代码1所示。

import zipfile
import os
if not os.path.isfile('celeba.zip'):
  !mkdir data_faces && wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip 
  with zipfile.ZipFile("celeba.zip","r") as zip_ref:
    zip_ref.extractall("data_faces/")
    
from torch.utils.data import DataLoader

transform = transforms.Compose([
                               transforms.Resize((img_size,img_size)),
                               transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5, 0.5),(0.5, 0.5, 0.5))])

dataset = datasets.ImageFolder('data_faces', transform=transform)
data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)

生成器和评论家

Critic 与 Discriminator 相同,但不包含最后一层 Sigmoid 激活。

class Generator(nn.Module):
  def __init__(self,noise_channels,img_channels,hidden_G):
    super(Generator,self).__init__()
    self.G=nn.Sequential(
        conv_trans_block(noise_channels,hidden_G*16,kernal_size=4,stride=1,padding=0),
        conv_trans_block(hidden_G*16,hidden_G*8),
        conv_trans_block(hidden_G*8,hidden_G*4),
        conv_trans_block(hidden_G*4,hidden_G*2),
        nn.ConvTranspose2d(hidden_G*2,img_channels,kernel_size=4,stride=2,padding=1),
        nn.Tanh()
    )
  def forward(self,x):
    return self.G(x)

class Critic(nn.Module):
  def __init__(self,img_channels,hidden_D):
    super(Critic,self).__init__()
    self.D=nn.Sequential(
        conv_block(img_channels,hidden_G),
        conv_block(hidden_G,hidden_G*2),
        conv_block(hidden_G*2,hidden_G*4),
        conv_block(hidden_G*4,hidden_G*8),
        nn.Conv2d(hidden_G*8,1,kernel_size=4,stride=2,padding=0))
    
  def forward(self,x):
    return self.D(x)

Generator 和 Critic 的支持块如下面的代码 3 所示。

class conv_trans_block(nn.Module):
  def __init__(self,in_channels,out_channels,kernal_size=4,stride=2,padding=1):
    super(conv_trans_block,self).__init__()
    self.block=nn.Sequential(
        nn.ConvTranspose2d(in_channels,out_channels,kernal_size,stride,padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU())
  def forward(self,x):
    return self.block(x)

class conv_block(nn.Module):
  def __init__(self,in_channels,out_channels,kernal_size=4,stride=2,padding=1):
    super(conv_block,self).__init__()
    self.block=nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernal_size,stride,padding),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2))
  def forward(self,x):
    return self.block(x)

损失函数

与任何其他典型的损失函数不同,损失函数可能有点棘手,因为它包含梯度。在这里,我们将使用梯度惩罚来实现 W-loss,稍后可以将其插入 WGAN 模型中。

def get_gen_loss(crit_fake_pred):
  gen_loss= -torch.mean(crit_fake_pred)
  return gen_loss

def get_crit_loss(crit_fake_pred, crit_real_pred, gradient_penalty, c_lambda):
  crit_loss= torch.mean(crit_fake_pred)- torch.mean(crit_real_pred)+ c_lambda* gradient_penalty
  return crit_loss

让我们分解一下代码 4 中所示的损失函数。

  1. 生成器损失 - 生成器损失不受梯度惩罚的影响。因此,它必须仅最大化 D(x_CURL)/ D(G(z)) 项,这意味着最小化 -D(G(z))。这是在第 2 行中实现的。
  2. 批评者损失 - 批评者损失包含等式 1 中所示损失的 2 个部分。在第 6 行中,前两项给出等式 2 中解释的原始批评者损失,而最后一项给出等式 3 中解释的梯度惩罚。

梯度惩罚可以按照下面的代码 5 来实现 - [1]。

def get_gradient(crit, real_imgs, fake_imgs, epsilon):

  mixed_imgs= real_imgs* epsilon + fake_imgs*(1- epsilon)
  mixed_scores= crit(mixed_imgs)

  gradient= torch.autograd.grad(outputs= mixed_scores,
                                inputs= mixed_imgs,
                                grad_outputs= torch.ones_like(mixed_scores),
                                create_graph=True,
                                retain_graph=True)[0]
  return gradient

def gradient_penalty(gradient):
  gradient= gradient.view(len(gradient), -1)
  gradient_norm= gradient.norm(2, dim=1)
  penalty = torch.nn.MSELoss()(gradient_norm, torch.ones_like(gradient_norm))
  return penalty

在代码 5 中,get_gradient()函数返回从x_hat (混合图像)开始到Critic 输出 (mixed_scores)结束的所有网络梯度。这将在gradient_penalty()函数中使用,它返回Critic梯度的1和L2范数之间的均方距离。

减少 Critic 的损失最终会减少这种梯度惩罚。这确保了 Critic 函数保留了 1-Lipschitz 连续性。

训练

训练将与上一篇文章中的几乎相同。但这里的损失与传统的 GAN 损失不同。我已经使用WANDB记录我的结果。如果您有兴趣记录结果,WANDB 是一个非常好的工具。

C=Critic(img_channels,hidden_C).to(device)
G=Generator(noise_channels,img_channels,hidden_G).to(device)

#C=C.apply(init_weights)
#G=G.apply(init_weights)

wandb.watch(G, log='all', log_freq=10)
wandb.watch(C, log='all', log_freq=10)

opt_C=torch.optim.Adam(C.parameters(),lr=lr, betas=(0.5,0.999))
opt_G=torch.optim.Adam(G.parameters(),lr=lr, betas=(0.5,0.999))

gen_repeats=1
crit_repeats=3

noise_for_generate=torch.randn(batch_size,noise_channels,1,1).to(device)

losses_C=[]
losses_G=[]

for epoch in range(1,epochs+1):
  loss_C_epoch=[]
  loss_G_epoch=[]

  for idx,(x,_) in enumerate(data_loader):
    C.train()
    G.train()

    x=x.to(device)
    x_len=x.shape[0]

    ### Train C

    loss_C_iter=0
    for _ in range(crit_repeats):
      opt_C.zero_grad()
      z=torch.randn(x_len,noise_channels,1,1).to(device)

      real_imgs=x
      fake_imgs=G(z).detach()

      real_C_out=C(real_imgs)
      fake_C_out=C(fake_imgs)

      epsilon= torch.rand(len(x),1,1,1, device= device, requires_grad=True)
      gradient= get_gradient(C, real_imgs, fake_imgs.detach(), epsilon)
      gp= gradient_penalty(gradient)
      loss_C= get_crit_loss(fake_C_out, real_C_out, gp, c_lambda=10)

      loss_C.backward()
      opt_C.step()

      loss_C_iter+=loss_C.item()/crit_repeats

    ### Train G
    loss_G_iter=0
    for _ in range(gen_repeats):
      opt_G.zero_grad()
      z=torch.randn(x_len,noise_channels,1,1).to(device)
      fake_C_out = C(G(z))
      loss_G= get_gen_loss(fake_C_out)
      loss_G.backward()
      opt_G.step()

      loss_G_iter+=loss_G.item()/gen_repeats

结果

这是经过 10 个 epoch 训练后获得的结果。与传统 GAN 一样,生成的图像随着时间的推移变得更加真实。WANDB 项目的所有结果都可以在这里找到。
在这里插入图片描述

结论

生成对抗网络一直是深度学习社区的热门话题。由于 GAN 传统训练方法的缺点,WGAN 随着时间的推移变得越来越流行。这主要是因为它对模式崩溃具有鲁棒性并且不存在梯度消失问题。在本文中,我们实现了一个能够生成人脸的简单 WGAN 模型。

请随意查看 GitHub 代码。如有任何意见、建议和意见,我们将不胜感激。

Reference

[1] GAN specialization on coursera

[2] Arjovsky, Martin et al. “Wasserstein GAN”

[3] Gulrajani, Ishaan et al. “Improved Training of Wasserstein GANs”

[4] Goodfellow, Ian et al. “Generative Adversarial Networks”

[5] Vincent Herrmann, “Wasserstein GAN and the Kantorovich-Rubinstein Duality”

[6] Karras, Tero et al. “A Style-Based Generator Architecture for Generative Adversarial Networks”

本文译自Udith Haputhanthri的博文。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1240267.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

opencv-图像金字塔

图像金字塔是一种图像处理技术,它通过不断降低图像的分辨率,形成一系列图像。金字塔分为两种类型:高斯金字塔和拉普拉斯金字塔。 高斯金字塔(Gaussian Pyramid): 高斯金字塔是通过使用高斯滤波和降采样&a…

Selenium切换窗口、框架和弹出框window、ifame、alert

一、切换窗口 #获取打开的多个窗口句柄 windows driver.window_handles #切换到当前最新打开的窗口 driver.switch_to.window(windows[-1]) #最大化浏览器 driver.maximize_window() #刷新当前页面 driver.refresh() 二、切换框架frame 如存在以下网页&#xff1a; <htm…

基于SSM的济南旅游网站设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

Dubbo从入门到上天系列第十八篇:Dubbo引入注册中心简介以及DubboAdmin简要介绍,为后续详解Dubbo各种注册中心做铺垫!

一&#xff1a;Dubbo注册中心引言 1&#xff1a;什么是Dubbo的注册中心&#xff1f; Dubbo注册中心是Dubbo服务治理中极其重要的一个概念。它主要是用于对Rpc集群应用实例进行管理。 对于我们的Dubbo服务来讲&#xff0c;至少有两部分构成&#xff0c;一部分是Provider一部分是…

Postgresql源码(116)提升子查询案例分析

0 总结 对于SQL&#xff1a;select * from student, (select * from score where sno > 2) s where student.sno s.sno; pullup在pull_up_subqueries函数内递归完成&#xff0c;分几步&#xff1a; 将内层rte score追加到上层rtbable中&#xff1a;rte1是student、rte2带…

Jenkins 配置节点交换内存

查看交换内存 free -hswapon -s创建swap文件 dd if/dev/zero of/mnt/swap bs1M count1024启用交换文件 设置权限 chmod 600 /mnt/swap设置为交换空间 mkswap /mnt/swap启用交换 swapon /mnt/swap设置用户组 chown root:root /mnt/swap查看 swapon -s重启系统也能生效还需要修…

Redisson分布式锁源码解析

一、使用Redisson步骤 Redisson各个锁基本所用Redisson各个锁基本所用Redisson各个锁基本所用 二、源码解析 lock锁 1&#xff09; 基本思想&#xff1a; lock有两种方法 一种是空参 另一种是带参 * 空参方法&#xff1a;会默认调用看门狗的过期时间30*1000&…

Android进阶知识:ANR的定位与解决

1、前言 ANR对于Android开发者来说一定不会陌生&#xff0c;从刚开始学习Android时的一不注意就ANR&#xff0c;到后来知道主线程不能进行耗时操作注意到这点后&#xff0c;程序出现ANR的情况就大大减少了&#xff0c;甚至于消失了。那么真的是只要在主线程做耗时操作就会产生…

【尚硅谷】第06章:随堂复习与企业真题(面向对象-基础)

第06章&#xff1a;随堂复习与企业真题&#xff08;面向对象-基础&#xff09; 一、随堂复习 1. &#xff08;了解&#xff09;面向过程 vs 面向对象 不管是面向过程、面向对象&#xff0c;都是程序设计的思路。面向过程&#xff1a;以函数为基本单位&#xff0c;适合解决简单…

golang学习笔记——罗马数字转换器

文章目录 罗马数字转换器代码 参考LeetCode 罗马数字转整数代码 罗马数字转换器 编写一个程序来转换罗马数字&#xff08;例如将 MCLX 转换成 1,160&#xff09;。 使用映射加载要用于将字符串字符转换为数字的基本罗马数字。 例如&#xff0c;M 将是映射中的键&#xff0c;其值…

分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析-改进蜣螂算法优化最小二乘支持向量机的分类预测

分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析-改进蜣螂算法优化最小二乘支持向量机的分类预测 目录 分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析-改进蜣螂算法优化最小二乘支持向量机的分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.多特…

OpenAI再次与Altman谈判;ChatGPT Voice正式上线

11月22日&#xff0c;金融时报消息&#xff0c;OpenAI迫于超过700名员工联名信的压力&#xff0c;再次启动了与Sam Altman的谈判&#xff0c;希望他回归董事会。 在Sam确定加入微软后&#xff0c;OpenAI超700名员工签署了一封联名信&#xff0c;要求Sam和Greg Brockman&#x…

【SpringCloud】认识微服务、服务拆分以及远程调用

SpringCloud 1.认识微服务 1.1单体架构 单体架构&#xff1a;将业务的所有功能集中在一个项目中开发&#xff0c;打成一个包部署 单体架构的优缺点&#xff1a; 优点&#xff1a; 架构简单&#xff0c;部署成本低 缺点&#xff1a; 耦合度高&#xff08;维护困难&#x…

一键去水印免费网站快速无痕处理图片、视频水印

水印问题往往是一个大麻烦。即使我们只想将这些照片保留在我们的个人相册中以供怀旧&#xff0c;水印也可能像顽固的符号一样刺激我们的眼睛。为了解决这个问题&#xff0c;我们需要不断探索创新的解决方案&#xff0c;让我们深入研究一款强大的一键去水印免费网站“水印云”。…

IOS+Appium+Python自动化全实战教程

由于公司的产品坐落于不同的平台&#xff0c;如ios、mac、Android、windows、web。因此每次有新需求的时候&#xff0c;开发结束后&#xff0c;留给测试的时间也不多。此外&#xff0c;一些新的功能实现&#xff0c;偶尔会影响其他的模块功能正常的使用。 网上的ios自动化方面的…

npm install安装报错

npm WARN notsup Not compatible with your version of node/npm: v-click-outside-x3.7.1 npm ERR! Error while executing: npm ERR! /usr/bin/git ls-remote -h -t ssh://gitgithub.com/itargaryen/simple-hotkeys.git 解决办法1&#xff1a;&#xff08;没有解决我的问题…

likeshop单商户商城系统 任意文件上传漏洞复现

0x01 产品简介 likeshop单商户标准商城系统适用于B2C、单商户、自营商城场景。完美契合私域流量变现闭环交易使用。 系统拥有丰富的营销玩法&#xff0c;强大的分销能力&#xff0c;支持电子面单和小程序直播等功能。无论运营还是二开都是性价比极高的100%开源商城系统。 0x02…

VUE语法-$refs和ref属性的使用

1、$refs和ref属性的使用 1、$refs:一个包含 DOM 元素和组件实例的对象&#xff0c;通过模板引用注册。 2、ref实际上获取元素的DOM节点 3、如果需要在Vue中操作DOM我们可以通过ref和$refs这两个来实现 总结:$refs可以获取被ref属性修饰的元素的相关信息。 1.1、$refs和re…

杭电oj 2064 汉诺塔III C语言

#include <stdio.h>void main() {int n, i;long long sum[35] { 2,8,26 };for (i 3; i < 35; i)sum[i] 3 * sum[i - 1] 2;while (~scanf_s("%d", &n))printf("%lld\n", sum[n - 1]); }

孟德尔随机化 MR入门基础-简明教程-工具变量-暴露

孟德尔随机化&#xff08;MR&#xff09;入门介绍和分章分享&#xff08;暂时不解读&#xff09; 大家好&#xff0c;孟德尔随机化大火&#xff0c;但是什么是孟德尔随机化&#xff0c;具体怎么实操呢 这没有其他教程的繁冗&#xff0c;我这篇讲最基础的孟德尔随机化的核心步…