AIGC实战——WGAN(Wasserstein GAN)

news2025/1/26 15:49:27

AIGC实战——WGAN

    • 0. 前言
    • 1. WGAN-GP
      • 1.1 Wasserstein 损失
      • 1.2 Lipschitz 约束
      • 1.3 强制 Lipschitz 约束
      • 1.4 梯度惩罚损失
      • 1.5 训练 WGAN-GP
    • 2. GAN 与 WGAN-GP 的关键区别
    • 3. WGAN-GP 模型分析
    • 小结
    • 系列链接

0. 前言

原始的生成对抗网络 (Generative Adversarial Network, GAN) 在训练过程中面临着模式坍塌和梯度消失等问题,为了解决这些问题,研究人员提出了大量的关键技术以提高GAN模型的整体稳定性,并降低了上述问题出现的可能性。例如 WGAN (Wasserstein GAN) 和 WGAN-GP (Wasserstein GAN-Gradient Penalty) 等,通过对原始生成对抗网络 (Generative Adversarial Network, GAN) 框架进行了细微调整,就能够训练复杂GAN。在本节中,我们将学习 WGANWGAN-GP,两者都对原始 GAN 框架进行了细微调整,以改善图像生成过程的稳定性和质量。

1. WGAN-GP

WGAN (Wasserstein GAN) 是提高 GAN 训练稳定性方面的一次巨大进步,在经过一些简单改动后 GAN 就能够实现以下两个特点:

  • 与生成器的收敛度和生成样本质量相关的损失度量
  • 优化过程的稳定性得到提高

具体来说,WGAN 针对判别器和生成器提出了一种新的损失函数 (Wasserstein Loss),用这种损失函数代替二元交叉熵就可以让 GAN 的收敛更加稳定。
在本节中,我们将构建一个 WGAN-GP (Wasserstein GAN-Gradient Penalty),利用 CelebA 数据集训练模型以生成人脸图像。

1.1 Wasserstein 损失

首先我们来回顾一下二元交叉嫡, 在训练 DCGAN 判别器和生成器时采用了这种损失函数:
− 1 n ∑ i = 1 n ( y i l o g ( p i ) + ( 1 − y i ) l o g ( 1 − p i ) ) -\frac 1 n \sum_{i=1}^n(y_ilog(p_i)+(1-y_i)log(1-p_i)) n1i=1n(yilog(pi)+(1yi)log(1pi))
为了训练 GAN 的判别器 D,我们根据以下两者计算损失:真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,以及生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi))与标签 y i = 0 y_i=0 yi=0 之间的误差。因此,对于 GAN 的判别器来说,损失函数最小化的过程可以表示为:
min ⁡ D − ( E x ∼ p X [ log ⁡ D ( x ) ] + E z ∼ p Z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ) \mathop {\min} \limits_{D}-(\mathbb E_{x\sim p_X}[\log D(x)]+\mathbb E_{z\sim p_Z}[\log (1-D(G(z)))]) Dmin(ExpX[logD(x)]+EzpZ[log(1D(G(z)))])
为了训练 GAN 的生成器 G,我们根据生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 的误差计算损失。因此,对于 GAN 的生成器来说,将损失函数最小化的过程可以表示为:
min ⁡ G − ( E z ∼ p Z [ log ⁡ D ( G ( z ) ) ] ) \mathop {\min}\limits_{G}-(\mathbb E_{z\sim p_Z}[\log D(G(z))]) Gmin(EzpZ[logD(G(z))])
接下来,我们比较上述损失函数与 Wasserstein 损失函数。
Wasserstein 损失 (Wasserstein Loss) 是用于 Wasserstein GAN (WGAN) 的一种损失函数。与传统的二元交叉熵损失函数不同,Wasserstein 损失引入了标签 1-1,将判别器的输出从概率值转变为分数 (score),因此,WGAN 的判别器通常也被称为评论家 (critic),并要求判别器是 1-Lipschitz 连续函数。
具体来说,Wasserstein 损失使用标签 y i = 1 y_i=1 yi=1 y i = − 1 y_i=-1 yi=1 代替 y i = 1 y_i=1 yi=1 y i = 0 y_i=0 yi=0,同时还需要移除判别器最后一层的 Sigmoid激活函数,如此一来预测结果 p i p_i pi 就不一定在 [ 0 , 1 ] [0,1] [0,1] 范围内了,它可以是 [ − ∞ , ∞ ] [-∞,∞] [,] 范围内的任何值。Wasserstein 损失的定义如下:
− 1 n ∑ i = 1 n ( y i p i ) -\frac 1 n∑_{i=1}^n(y_ip_i) n1i=1n(yipi)
在训练 WGAN 的判别器 D 时,我们将计算以下损失:判别器对真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = − 1 y_i=-1 yi=1 之间的误差。因此,对于 WGAN 判别器,最小化损失函数的过程可以表示为:
min ⁡ D − ( E x ∼ p X [ D ( x ) ] − E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ D - (\mathbb E_{x\sim p_X}[D(x)] - \mathbb E_{z\sim p_Z}[D(G(z))]) Dmin(ExpX[D(x)]EzpZ[D(G(z))])
换句话说,WGAN 判别器试图最大化其对真实图像的预测和生成图像的预测之间的差异,且真实图像的得分更高。
而对于 WGAN 生成器 G 的训练,我们根据判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 计算损失。因此,对于 WGAN 生成器,最小化损失函数可以表示为:
min ⁡ G − ( E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ G - (\mathbb E_{z\sim p_Z}[D(G(z))]) Gmin(EzpZ[D(G(z))])
换句话说,WGAN 生成器试图生成被判别器以极高分数判定为真实图像的图像(即,令判别器认为它们是真实的)。

1.2 Lipschitz 约束

由于我们允许判别器输出 [ − ∞ , ∞ ] [-∞,∞] [,] 范围内的任意值,而不是按照 Sigmoid 函数那样将输出限制在 [ 0 , 1 ] [0,1] [0,1] 范围内,因此 Wasserstein 损失可能会非常大。因此,为了使 Wasserstein 损失函数正常工作,需要对判别器进行额外约束,即 1-Lipschitz 连续性约束。判别器是一个将图像转换为预测的函数 D,如果对于任意两个输人图像 x 1 x_1 x1 x 2 x_2 x2,判别器函数 D 满足以下不等式,则该函数为 1-Lipschitz 连续:
∣ D ( x 1 ) − D ( x 2 ) ∣ ∣ x 1 − x 2 ∣ ≤ 1 \frac {|D(x_1) - D(x_2)|}{|x_1 - x_2|} ≤ 1 x1x2D(x1)D(x2)1
其中, ∣ x 1 − x 2 ∣ |x_1 - x_2| x1x2 表示两个图像的平均像素之差的绝对值, ∣ D ( x 1 ) − D ( x 2 ) ∣ |D(x_1) - D(x_2)| D(x1)D(x2) 表示判别器预测之间的绝对值。这意味着判别器的预测变化速率在任何情况下都是有界的(即梯度的绝对值不能大于 1)。可以在下图中的 Lipschitz 连续的一维函数中看到,无论将圆锥放在任何位置,曲线都不会进入圆锥内部。换句话说,曲线上任何一点的上升或下降速度都是有限的。

Lipschitz 连续

1.3 强制 Lipschitz 约束

在原始的 WGAN 论文中,作者通过在每个训练结束后将判别器的权重裁剪到一个较小范围内 [ − 0.01 , 0.01 ] [-0.01, 0.01] [0.01,0.01] 来强制执行 Lipschitz 约束。
由于我们裁剪了判别器的权重,判别器的学习能力大大降低,因此,事实上,权重裁剪并不是一种理想的强制 Lipschitz 约束的方式。一个强大的判别器对于 WGAN 的成功至关重要,因为如果没有准确的梯度,生成器无法学习如何调整其权重以产生更好的样本。
因此,研究人员提出了许多其他方法来强制执行 Lipschitz 约束,并提高 WGAN 学习复杂特征的能力。其中一种方法是带有梯度惩罚 (Gradient Penalty) 的 Wasserstein GAN
通过在判别器的损失函数中包含一个梯度惩罚项来直接强制执行 Lipschitz 约束,如果梯度范数偏离 1 时,该项会惩罚模型,从而使训练过程更加稳定。
接下来,将这个额外的梯度惩罚项加入到判别器损失函数中。

1.4 梯度惩罚损失

下图展示了 WGAN-GP 判别器的训练过程,与原始判别器的训练过程进行比较,我们可以看到关键的改进是将梯度惩罚损失作为整体损失函数的一部分,并与来自真实图像和生成图像的 Wasserstein 损失一起使用。

WGAN-GP

梯度惩罚损失衡量了预测关于输入图像的梯度范数与 1 之间的平方差。模型倾向于找到能够使梯度惩罚项最小化的权重,从而鼓励模型符合 Lipschitz 约束。
在训练过程中,每一处的计算梯度是非常困难的,因此WGAN-GP 只在少数几个点处评估梯度。为了确保平衡的,我们使用一组插值图像,在真实图像与伪造图像之间的随机位置逐像素进行插值 (Interpolation) 以生成一些图像。

插值图像

使用 Keras 计算梯度惩罚项:

    def gradient_penalty(self, batch_size, real_images, fake_images):
        # 批数据中的每个图像都会得到一个 0~1 之间的随机数字,存储到向量 alpha 中
        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
        # 计算一组插值图像
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 使用判别器对每个插值图像进行评分
            pred = self.critic(interpolated, training=True)
        # 计算插值图像 (y_pred) 的预测对于输入 interpolated_samples) 的梯度
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 计算这个向量的 L2 范数(即欧几里得长度)
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        # 函数返回 L2 范数与 1 之差的平方的均值
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

1.5 训练 WGAN-GP

使用 Wasserstein 损失函数的一个优点是,不再需要担心平衡判别器和生成器的训练。事实上,在使用 Wasserstein 损失时,必须在更新生成器之前将判别器训练到收敛,以确保生成器更新的梯度准确无误。这与标准 GAN 相反,标准 GAN 中重要的是不要让判别器变得过强。
因此,使用 Wasserstein GAN,我们可以简单地在生成器更新之间多次训练判别器,以确保它接近收敛。通常每次生成器更新一次,判别器更新三到五次。
了解了 WGAN-GP 的两个关键概念 (Wasserstein 损失和梯度惩罚项)后,使用 Keras 实现 WGAN-GP

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]
        # 对判别器进行三次更新
        for i in range(self.critic_steps):
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )

            with tf.GradientTape() as tape:
                fake_images = self.generator(
                    random_latent_vectors, training=True
                )
                fake_predictions = self.critic(fake_images, training=True)
                real_predictions = self.critic(real_images, training=True)
                # 计算判别器的 Wasserstein 损失
                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
                # 计算梯度惩罚项
                c_gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # 判别器损失函数是 Wasserstein 损失和梯度惩罚的加权和
                c_loss = c_wass_loss + c_gp * self.gp_weight
            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            # 更新判别器的权重
            self.c_optimizer.apply_gradients(
                zip(c_gradient, self.critic.trainable_variables)
            )
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            fake_predictions = self.critic(fake_images, training=True)
            # 计算生成器的 Wasserstein 损失
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # 更新生成器的权重
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)
        return {m.name: m.result() for m in self.metrics}

在训练 WGAN-GP 之前,需要注意的最后一点是判别器不应该使用批量归一化。这是因为批归一化会在同一批图像之间创建相关性,从而使梯度惩罚损失的效果降低。实验证明,即使在判别器中没有批归一化, WGAN-GP 仍然可以输出出色的结果。

2. GAN 与 WGAN-GP 的关键区别

总而言之,标准 GANWGAN-GP 之间存在以下:

  • WGAN-GP 使用 Wasserstein 损失
  • WGAN-GP 使用 1 表示真实图像标签,使用 -1 表示伪造图像的标签
  • 判别器的最后一层没有使用 sigmoid 激活
  • 在判别器的损失函数中包含梯度惩罚项
  • 每训练一次生成器更新权重,需要多次训练判别器
  • 判别器中没有批归一化层

3. WGAN-GP 模型分析

训练 25epoch 后,WGAN-GP 模型的生成器能够生成合理图像:

面部生成结果

该模型已经学习到了面部的重要高级特征,且没有出现模式坍塌的迹象。
如果我们将 WGAN-GP 的输出与变分自编码器 (Variational Autoencoder, VAE) 的输出进行比较,可以看到 WGAN-GP 生成的图像通常更清晰。总的来说,VAE 倾向于产生颜色边界模糊的图像,而 GAN 产生的图像更加清晰合理。GAN 通常比 VAE 更难训练,需要更长的时间才能获得满意的数据质量。

小结

在本节中,我们学习了如何使用 Wasserstein 损失函数以解决经典 GAN 训练过程中的模式坍塌和梯度消失等问题,使得 GAN 的训练更加可预测和可靠。WGAN-GP 通过在损失函数中添加一个令梯度范数指向 1 的项,为训练过程施加 1-Lipschitz 约束。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1301427.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【js】数字字符串的比较

今天排查一个日历组件的bug,month打印出来是9,month1打印出来为12,比较month和month1大小进入if或者else,奇怪的是每次都是进入的month>month1语句里面 打印typeOf(a)和typeOf(b&#xff09…

深入学习Redis:从入门到实战

Redis快速入门 1.初识Redis1.1.认识NoSQL1.1.1.结构化与非结构化1.1.2.关联和非关联1.1.3.查询方式1.1.4.事务1.1.5.总结 1.2.认识Redis1.3.安装Redis1.3.1.依赖库1.3.2.上传安装包并解压1.3.3.启动1.3.4.默认启动1.3.5.指定配置启动1.3.6.开机自启 1.4.Redis桌面客户端1.4.1.R…

MySQL - InnoDB 和 MyISAM 的索引实现的区别

InnoDB 和 MyISAM 底层都是 B 树的实现,但是二者却完全不同 。 主键索引文件存储不同 MyISAM 引擎的索引文件和数据文件是分离的,而 InnoDB 引擎的索引文件和数据文件是不分离的。 MyISAM 引擎的叶子节点存储的是数据文件的地址,而 InnoDB 的…

VR全景开启智能化酒店获客新模式,打造高人气入住

一般而言,消费者如果想要了解酒店信息,渠道主要是通过第三方平台商家发布的图片,有多次入住经验的可能还会看下以往用户的评价,以及借助酒店的平面宣传效果图看下空间布局等,但是这种用户评价的主观色彩,很…

【数据结构与算法】实现红黑树

文章目录 一、红黑树的五条规则二、红黑树的三种变换2.1.变色2.2.左旋转2.3.右旋转 三、红黑树的插入操作3.1.情况13.2.情况23.3.情况33.4.情况43.5.情况53.6.案例插入10插入9插入8插入7插入6插入5插入4插入3插入2插入1 一、红黑树的五条规则 红黑树除了符合二叉搜索树的基本规…

使用WebyogSQLyog使用数据库

数据库 实现数据持久化到本地: 使用完整的管理系统统一管理, 数据库(DateBase): 为了方便数据存储和管理(增删改查),将数据按照特定的规则存储起来 安装WebyogSQLyog -- 创建数…

探索低代码的潜力、挑战与未来展望

低代码开发作为一种新兴的开发方式,正在逐渐改变着传统的编程模式,低代码使得开发者无需编写大量的代码即可快速构建各种应用程序。然而,低代码也引发了一系列争议,有人称赞其为提升效率的利器,也有人担忧其可能带来的…

【Spring教程22】Spring框架实战:Spring事务角色与 Spring事务属性、事务传播行为代码示例详解

目录 1.Spring事务角色1.1 未开启Spring事务之前:1.2 开启Spring的事务管理后2 Spring事务属性2.1 事务配置2.2 转账业务追加日志案例2.2.1 需求分析2.2.2 环境准备 2.3 事务传播行为2.3.1.修改logService改变事务的传播行为2.3.2 事务传播行为的可选值 欢迎大家回到《 Java教…

【EXCEL】规划求解

题目: s1:设置EXCEL加载项(第一次使用):开发工具–>EXCEL加载项–>勾选“规划求解加载项”–>确定 s2:填入公式(等号左边) s3:数据–>规划求解 s4:得出结果 总结:这玩意…

0基础学java-day15(泛型)

一、泛型 1 泛型的理解和好处 1.1 看一个需求 【不小心加入其它类型,会导致出现类型转换异常】 package com.hspedu.generic;import java.util.ArrayList;/*** author 林然* version 1.0*/ public class Generic01 {SuppressWarnings("all")public st…

NSSCTF web刷题记录7

文章目录 [SDCTF 2022]CURL Up and Read [SDCTF 2022]CURL Up and Read 考点:SSRF 打开题目发现是curl命令,提示填入url 尝试http://www.baidu.com,成功跳转 将url的字符串拿去解码,得到json格式数据 读取下环境变量&#xff0c…

【算法集训】基础数据结构:三、链表

链表就是将所有数据都用一个链子串起来,其中链表也有多种形式,包含单向链表、双向链表等; 现在毕竟还是基础阶段,就先学习单链表吧; 链表用头结点head表示一整个链表,每个链表的节点包含当前节点的值val和下…

社交媒体图像识别与情感分析

社交媒体图像识别与情感分析是当前人工智能领域的一个研究热点。通过对社交媒体上大量的图像和文本数据进行深度学习和情感分析,可以提取出图像中的情感信息,从而为社交媒体用户提供更加个性化和精准的内容推荐和服务。 在社交媒体图像识别方面&#xff…

LabVIEW与Tektronix示波器实现电源测试自动化

LabVIEW与Tektronix示波器实现电源测试自动化 在现代电子测试与测量领域,自动化测试系统的构建是提高效率和精确度的关键。本案例介绍了如何利用LabVIEW软件结合Tektronix MDO MSO DPO2000/3000/4000系列示波器,开发一个自动化测试项目。该项目旨在自动…

C#结合JavaScript实现多文件上传

目录 需求 引入 关键代码 操作界面 ​JavaScript包程序 服务端 ashx 程序 服务端上传后处理程序 小结 需求 在许多应用场景里,多文件上传是一项比较实用的功能。实际应用中,多文件上传可以考虑如下需求: 1、对上传文件的类型、大小…

《微信小程序开发从入门到实战》学习四十五

4.4 云函数 云函数是开发者提前定义好的、保存在云端并且将在云端运行的JS函数。 开发者先定义好云函数,再使用微信开发工具将云函数上传到云空间,在云开发控制台中可看到已经上传的云函数。 云函数运行在云端Node.js环境中。 小程序端通过wx.cloud.…

使用阿里巴巴同步工具DataX实现Mysql与ElasticSearch数据同步

一、Linux环境要求 二、准备工作 2.1 Linux安装jdk 2.2 linux安装python 2.3 下载DataX: 三、DataX压缩包导入,解压缩 四、编写同步Job 五、执行Job 六、定时更新 6.1 创建定时任务 6.2 提交定时任务 6.3 查看定时任务 七、增量更新思路 一、Linux环境要…

内外网文件传输中的4大风险,你都知道吗?

一般来说,企业实施内外网隔离的原因主要就是两个:外因和内因。外因就是因为政策法规要求,这个主要是面向一些特定行业的,比如党政机关、金融、医疗、能源等行业,受这方面监管和要求的会比较多。内因就是为了自身的数据…

C++面试宝典第4题:合并链表

题目 有一个链表,其节点声明如下: struct TNode {int nData;struct TNode *pNext;TNode(int x) : nData(x), pNext(NULL) {} }; 现给定两个按升序排列的单链表pA和pB,请编写一个函数,实现这两个单链表的合并。合并后,…

架构设计系列之基础:初探软件架构设计

11 月开始突发奇想,想把自己在公司内部做的技术培训、平时的技术总结等等的内容分享出来,于是就开通了一个 Wechat 订阅号(灸哥漫谈),开始同步发送内容。 今天(12 月 10 日)也同步在 CSDN 上开通…