使用pytorch构建带梯度惩罚的Wasserstein GAN(WGAN-GP)网络模型

news2025/1/16 18:55:01

本文为此系列的第三篇WGAN-GP,上一篇为DCGAN。文中仍然不会过多详细的讲解之前写过的,只会写WGAN-GP相对于之前版本的改进点,若有不懂的可以重点看第一篇比较详细。

原理

具有梯度惩罚的 Wasserstein GAN (WGAN-GP)可以解决 GAN 的一些稳定性问题。 具体来说,使用W-loss 作为损失函数替代传统的 BCE 等 loss,并使用梯度惩罚来防止 mode collapse。

  • WGAN-GP 使用了 Wasserstein distance(也成为Earth Mover’s distance, EMD)作为训练 GAN 模型的目标函数,Wasserstein distance is a function of amount and distance,体现的是生成的数据的分布移动到真实数据的分布之间所需的距离与量。
    在这里插入图片描述
    随着判别器训练的越来越好,使用 BCE loss 的话会让鉴别器给出接近于 0 或者接近于 1 的极端值,如下为 sigmoid 曲线,极端值的梯度无限接近于 0,这样判别器就没有太多有用的信息反馈给生成器让它学习,导致梯度消失或 model collapse。使用距离的方式可以有效解决,分布距离再远都不再限制。
    在这里插入图片描述
    在这里插入图片描述
  • BCE loss 本质是一个 minimax game, d 即 discriminator 希望尽可能的 minimize,g 即 generator 希望尽可能的 maximize(意味着造出来的假东西对于鉴别器来说看起来很真实),可以进行如下的简化:
    在这里插入图片描述
    基于 Wasserstein distance 的 W-loss 的的式子与其简化版进行对比:
    在这里插入图片描述
    在 Wasserstein GAN 中不再是 discriminator 了,因为输出不再是 0-1 之间来进行分类,既然不分类了就不是 discriminator 了,而是 critic,所以这里使用 c 代表 critic。critic 希望其尽可能的 maximize,因为希望让 real 和 feak 的距离尽可能的大,起到划清界限的目的;generator 希望其尽可能的minimize,减小两者之间的距离,达到以假乱真的目的。
  • mode collapse 即模式崩溃,当生成器学会从单个类生成特征来欺骗鉴别器时,就会发生 mode collapse(陷入一种模式出不来),跟 cnn 的局部最优是一个概念。这会导致输出出现重复,缺乏多样性和细节。

但在使用 W-loss 训练 GAN 时需要对 critic 有一定的条件 —— critic 需要 1-L(1-Lipschitz)连续:
∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ k ∣ x 1 − x 2   ∣ |f(x_1)-f(x_2)|\le k|x_1-x_2\ | f(x1)f(x2)kx1x2 
这里的 k = 1,也就是 critic 的 nn 函数曲线的梯度(斜率)始终在 -1 到 1 之间,即梯度的 L2 范数不超过1:

在这里插入图片描述
如图:
在这里插入图片描述
曲线的每个点的斜率都是在绿色区域内,很显然这个曲线并不符合。像如下这个曲线就是符合的:
在这里插入图片描述
达到 1-L 连续有两种方法:weigh clipping、gradient penalty。

  • weigh clipping 将权重裁剪到固定范围内,从而限制 critic 的学习能力。但是这样有缺点,可能让所有参数走极端,要么取最大值要么取最小值,critic 会非常倾向于学习一个简单的映射函数。
  • gradient penalty 则是添加一个正则项在 loss function 中,相比 weigh clipping 更加柔和对critic参数的限制更加灵活,通常不会导致梯度消失或梯度爆炸问题。
    在这里插入图片描述
    这里的 λ \lambda λ 为超参值,reg 等于 critic 神经网络梯度范数 -1 的平方,即:
    在这里插入图片描述
    当 critic 神经网络梯度范数 >1 时正则化项发挥作用。平方的作用是为了让其偏离越大,惩罚越大。
    这里的 x ^ \hat{x} x^ 为真实数据与生成数据之间随机取样得到的中间数据,随机值 ϵ \epsilon ϵ 作为权重值,假设 ϵ \epsilon ϵ 为0.3,那么 1- ϵ \epsilon ϵ 为0.7。
    在这里插入图片描述

代码

model.py

from torch import nn

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

class Critic(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=64):
        super(Critic, self).__init__()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        crit_pred = self.crit(image)
        return crit_pred.view(len(crit_pred), -1)

train.py

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

def get_gradient(crit, real, fake, epsilon):
    # Mix the images together
    mixed_images = real * epsilon + fake * (1 - epsilon)

    # Calculate the critic's scores on the mixed images
    mixed_scores = crit(mixed_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        # These other parameters have to do with the pytorch autograd engine works
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

def gradient_penalty(gradient):
    # Flatten the gradients so that each row captures one image
    gradient = gradient.view(len(gradient), -1)

    # Calculate the magnitude of every row
    gradient_norm = gradient.norm(2, dim=1)

    # Penalize the mean squared distance of the gradient norms from 1
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

def get_gen_loss(crit_fake_pred):
    gen_loss = -1. * torch.mean(crit_fake_pred)
    return gen_loss

def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp
    return crit_loss

cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            crit_fake_pred = crit(fake.detach())
            crit_real_pred = crit(real)

            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake.detach(), epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Update gradients
            crit_loss.backward(retain_graph=True)
            # Update optimizer
            crit_opt.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)

        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        gen_opt.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1

在这里插入图片描述

代码讲解

网络模型与上一篇的DCGAN没有变动。
在这里插入图片描述
这个模块进行梯度计算,即上文原理中正则项公式里面的梯度L2范数里的梯度。首先计算真实数据与生成数据之间随机取样的混合数据,然后输入 critic,最后计算出其梯度。
在这里插入图片描述
梯度惩罚模块,即上文原理中的整个正则项公式,梯度范数 -1 的平方。
在这里插入图片描述
critic 的 loss function 公式如下,generator 因为和真实数据无关,且与正则项也无关,所以只有中间一项。
在这里插入图片描述————————————————————————————————————————————

总之,WGAN-GP 不一定要提高 GAN 的整体性能,但会很好的提高稳定性并避免模式崩溃。

下一篇条件生成GAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

caffe源码编译安装

一、前置准备 (1)vs2015 目前不要想着2019这些工具了,成功率太低了,就老老实实用vs2015吧 解决“VS2015安装包丢失或损坏“问题_vs2015跳过包会影响使用吗-CSDN博客 注意在安装vs2015过程中老是出现这个问题,其实就是缺少两个证书,安装完后就可以正常安装vs2015了,注意…

大数据面试专题 -- kafka

1、什么是消息队列? 是一个用于存放数据的组件,用于系统之间或者是模块之间的消息传递。 2、消息队列的应用场景? 主要是用于模块之间的解耦合、异步处理、日志处理、流量削峰 3、什么是kafka? kafka是一种基于订阅发布模式的…

AE——重构数字(Pytorch+mnist)

1、简介 AE(自编码器)由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽…

什么是nginx正向代理和反向代理?

什么是代理? 代理(Proxy), 简单理解就是自己做不了的事情或实现不了的功能,委托别人去做。 什么是正向代理? 在nginx中,正向代理指委托者是客户端,即被代理的对象是客户端 在这幅图中,由于左边内网中…

如何解决kafka rebalance导致的暂时性不能消费数据问题

文章目录 背景思考答案排它故障转移共享 背景 之前在review同组其它业务的时候,发现竟然把kafka去掉了,问了下原因,有一个单独的服务,我们可以把它称为agent,就是这个服务是动态扩缩容的,会采集一些指标&a…

k8s的pod访问service的方式

背景 在k8s中容器访问某个service服务时有两种方式,一种是把每个要访问的service的ip注入到客户端pod的环境变量中,另一种是客户端pod先通过DNS服务器查找对应service的ip地址,然后在通过这个service ip地址访问对应的service服务 pod客户端…

HarmonyOS 应用开发之FA模型访问Stage模型DataShareExtensionAbility

概述 无论FA模型还是Stage模型,数据读写功能都包含客户端和服务端两部分。 FA模型中,客户端是由DataAbilityHelper提供对外接口,服务端是由DataAbility提供数据库的读写服务。 Stage模型中,客户端是由DataShareHelper提供对外接…

腾讯云2核2G服务器优惠价格,61元一年

腾讯云2核2G服务器多少钱一年?轻量服务器61元一年,CVM 2核2G S5服务器313.2元15个月,轻量2核2G3M带宽、40系统盘,云服务器CVM S5实例是2核2G、50G系统盘。腾讯云2核2G服务器优惠活动 txybk.com/go/txy 链接打开如下图:…

java数组与集合框架(三)--Map,Hashtable,HashMap,LinkedHashMap,TreeMap

Map集合: Map接口: 基于 键(key)/值(value)映射 Map接口概述 Map与Collection并列存在。用于保存具有映射关系的数据:key-value Map 中的key 和value 都可以是任何引用类型的数据Map 中的key 用Set来存放&#xff0…

X进制减法(蓝桥杯)

文章目录 X进制减法题目描述解题思路贪心算法模拟减法(大数相减) X进制减法 题目描述 进制规定了数字在数位上逢几进一。 X 进制是一种很神奇的进制,因为其每一数位的进制并不固定!例如说某种 X 进制数,最低数位为二…

创建Qt Quick Projects

在创建Qt Quick项目之前,我们简单说一下Qml和Qt Quick的关系:它们的关系类似于C和STL标准库的关系,Qml类比C语言,提供了基本语言特性和类型;而Qt Quick则类比STL标准库,Qt Quick在QML的基础上加入了一系列界…

Https【Linux网络编程】

目录 一、为什么需要https 二、常见加密方法 1、对称加密 2、非对称加密 3、数据指纹 三、选择什么加密方案? 方案一:对称加密() 方案二:双方使用非对称加密(效率低) 方案三&#xff1a…

深度学习十大算法之Diffusion扩散模型

1. 引言 扩散模型在近年来成为了热门话题,其火速蹿红主要归功于在图像生成领域的突破应用。尤其是一些从文本到图像的生成技术,它们成功地运用了扩散模型来创建令人惊叹的逼真图像。如果你听说过某个应用能够迅速且高质量地生成图像,那么很可…

【SpringBoot整合系列】SpirngBoot整合EasyExcel

目录 背景需求发展 EasyExcel官网介绍优势常用注解 SpringBoot整合EaxyExcel1.引入依赖2.实体类定义实体类代码示例注解解释 3.自定义转换器转换器代码示例涉及的枚举类型 4.Excel工具类5.简单导出接口SQL 6.简单导入接口SQL 7.复杂的导出(合并行、合并列&#xff0…

docker 共享网络的方式实现容器互联

docker 共享网络的方式实现容器互联 本文以nacos连接mysql为例 前提已经在mysql容器中初始化好nacos数据库,库名nacos 创建一个共享网络 docker network create --driver bridge \ --subnt 192.168.0.0/24 \ --gateway 192.168.0.1 mynet此处可以不指定网络模式、…

Python下载bing每日壁纸并实现win11 壁纸自动切换

前言: 爬虫哪家强,当然是python 我是属于啥语言都用,都懂点,不精通,实际工作中能能够顶上就可以。去年写的抓取bing每日的壁纸,保存到本地,并上传到阿里云oss,如果只是本地壁纸切换,存下来就行,一直想做个壁纸站点&…

Java代码基础算法练习-自定义函数之字符串连接-2024.03.30

任务描述: 写一函数,将两个字符串连接起来,然后在主函数中调用该函数实现字符串连接操作。 任务要求: 代码示例: package M0317_0331;import java.util.Scanner;public class m240330 {public static void main(Stri…

【Java】MyBatis快速入门及详解

文章目录 1. MyBatis概述2. MyBatis快速入门2.1 创建项目2.2 添加依赖2.3 数据准备2.4 编写代码2.4.1 编写核心配置文件2.4.2 编写SQL映射文件2.4.3 编写Java代码 3. Mapper代理开发4. MyBatis核心配置文件5. 案例练习5.1 数据准备5.2 查询数据5.2.1 查询所有数据5.2.2 查询单条…

全国青少年软件编程(Python)等级考试三级考试真题2023年12月——持续更新.....

青少年软件编程(Python)等级考试试卷(三级) 分数:100 题数:38 一、单选题(共25题,共50分) 1.一个非零的二进制正整数,在其末尾添加两个“0”,则该新数将是原数的&#xf…

Redis从入门到精通(二)Redis的数据类型和常见命令介绍

文章目录 前言第2章 Redis数据类型和常见命令2.1 key结构2.2 Redis通用命令2.3 String类型及其常用命令2.4 Hash类型及其常用命令2.5 List类型2.5 Set类型2.6 SortedSet类型2.7 小结 前言 在上一节【Redis从入门到精通(一)Redis安装与启动、Redis客户端的使用】中,…