揭示 Wasserstein 生成对抗网络的潜力:生成建模的新范式

news2024/11/16 19:40:36

导 读

Wasserstein 生成对抗网络 (WGAN) 作为一项关键创新而出现,解决了经常困扰传统生成对抗网络 (GAN) 的稳定性和收敛性的基本挑战。

由 Arjovsky 等人于2017 年提出,WGAN 通过利用 Wasserstein 距离彻底改变了生成模型的训练,提供了一个强大的框架,可以提高生成样本的质量和多样性。

本文深入探讨了 WGAN 的概念基础、优势和实际意义,说明了它们在更广泛的生成建模背景下的重要性。

有需要的朋友关注公众号【小Z的科研日常】,获取更多内容

01、WGAN的概念框架

WGAN 与其前辈的区别在于用 Wasserstein 距离代替 Jensen-Shannon 散度作为其损失函数。

瓦瑟斯坦距离,直观地理解为推土机距离,量化了将一种概率分布转换为另一种概率分布所需的最小成本。

该指标赋予 WGAN 在训练过程中更平滑、更可靠的梯度信号,即使在真实数据分布和生成数据分布不重叠的情况下,也有助于生成更高质量的样本。

与传统 GAN 的一个重要区别是取代了判别器。与将输入分类为真实或虚假的判别器不同,WGAN 框架中的批评者评估真实样本和生成样本的分布之间的 Wasserstein 距离。

这种从分类到估计的转变标志着生成模型处理学习过程的方式发生了根本性变化,从而实现了更细致、更有效的训练动态。

02、相比于传统GAN的优势与挑战

WGAN 提供了几个引人注目的优势,可以解决传统 GAN 框架的局限性。

首先,它们表现出改进的训练稳定性,降低了对超参数设置和架构选择的敏感性。这种稳定性源于 Wasserstein 距离的特性,即使真实分布和生成分布之间没有重叠,它也能提供有用的梯度信息——这是一个可能阻碍传统 GAN 训练的常见问题。

此外,WGAN 还缓解了模式崩溃问题,即生成器学习产生有限范围的输出,从而无法捕获真实数据分布的多样性的现象。Wasserstein 距离的连续且更有意义的损失景观鼓励生成器探索更广泛的输出,从而增强生成样本的多样性。

WGAN 中损失度量的可解释性也代表了重大进步。与传统 GAN(判别器的准确性不一定与生成样本的质量相关)不同,WGAN 中的批评者损失提供了更直接的收敛性衡量标准,为训练过程和生成数据的质量提供了有价值的见解。

尽管有其优点,WGAN 也带来了新的挑战,主要与计算效率有关。WGAN 的最初实现需要权重裁剪来强制执行 Lipschitz 约束,这对于 Wasserstein 距离的理论属性至关重要。

然而,权重裁剪可能会导致优化困难和容量利用率不足。为了解决这个问题,引入带有梯度惩罚的 WGAN (WGAN-GP) 提出了一种替代方法来强制实施 Lipschitz 约束,而无需进行权重裁剪,从而提高训练稳定性和模型性能。

03、代码

为 Wasserstein 生成对抗网络 (WGAN) 创建完整的代码示例涉及几个步骤,包括定义生成器和批评者的模型架构、准备合成数据集、训练模型以及通过指标和图评估性能。

此示例将说明使用 TensorFlow 和 Keras 的基本实现,并使用简单的合成数据集以便于理解。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

def build_critic():
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ])
    return model

def build_generator(latent_dim):
    model = keras.Sequential([
        keras.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 128),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding='same', activation='sigmoid'),
    ])
    return model

class WGAN(keras.Model):
    def __init__(self, critic, generator, latent_dim):
        super(WGAN, self).__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.g_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.critic_loss_tracker = keras.metrics.Mean(name="critic_loss")
        self.generator_loss_tracker = keras.metrics.Mean(name="generator_loss")

    @property
    def metrics(self):
        return [self.critic_loss_tracker, self.generator_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

       # 在潜在空间中随机取样
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # 将它们解码为假图像
        generated_images = self.generator(random_latent_vectors)

        # 将它们与真实图像相结合
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # 组合标签,辨别真假图像
        labels = tf.concat(
            [tf.ones((batch_size, 1)), -tf.ones((batch_size, 1))], axis=0
        )
        # 在标签中添加随机噪音--重要技巧!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # 训练批评家
        with tf.GradientTape() as tape:
            predictions = self.critic(combined_images)
            d_loss = self.d_loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.critic.trainable_variables)
        self.d_optimizer.apply_gradients(
            zip(grads, self.critic.trainable_variables)
        )

       # 在潜在空间中随机取样
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # 组装 "所有真实图像 "的标签
        misleading_labels = -tf.ones((batch_size, 1))

        # 训练生成器(通过评论家模型)
        with tf.GradientTape() as tape:
            predictions = self.critic(self.generator(random_latent_vectors))
            g_loss = self.g_loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_variables)
        )

       # 更新指标
        self.critic_loss_tracker.update_state(d_loss)
        self.generator_loss_tracker.update_state(g_loss)

        return {
            "critic_loss": self.critic_loss_tracker.result(),
            "generator_loss": self.generator_loss_tracker.result(),
        }

latent_dim = 128

# 准备数据集
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# 实例化批评者和生成器模型
critic = build_critic()
generator = build_generator(latent_dim)

# 实例化 WGAN 模型
wgan = WGAN(critic=critic, generator=generator, latent_dim=latent_dim)

# 编译 WGAN 模型
wgan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    d_loss_fn=keras.losses.MeanSquaredError(),
    g_loss_fn=keras.losses.MeanSquaredError(),
)

wgan.fit(x_train, batch_size=32, epochs=100)

def generate_and_save_images(model, epoch, test_input):
    predictions = model.generator(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

# 生成潜在点
random_latent_vectors = tf.random.normal(shape=(16, latent_dim))
generate_and_save_images(wgan, 0, random_latent_vectors)
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5405 - generator_loss: 2.4530
Epoch 99/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5408 - generator_loss: 2.4463
Epoch 100/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5384 - generator_loss: 2.4411

此代码提供了使用简单数据集通过 TensorFlow 和 Keras 实现 WGAN 的基础框架。对于实际应用程序,您可能需要调整数据集、架构和训练参数以满足您的特定需求。

04、结论

Wasserstein 生成对抗网络代表了生成建模领域的重大飞跃。通过将 Wasserstein 距离集成到 GAN 框架中,WGAN 为训练生成模型提供了更稳定、可靠和可解释的方法。

尽管存在与计算需求和 Lipschitz 约束的执行相关的挑战,但 WGAN 及其后续迭代(如 WGAN-GP)所带来的进步继续影响着生成模型的发展。

随着该领域研究的进展,WGAN 有望进一步释放生成模型在从图像合成到自然语言生成等众多应用中的潜力,预示着人工智能驱动的创造力和创新的新时代。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1482883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端监控与埋点

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

SINAMICS V90 PN 指导手册 第6章 BOP面板 LED灯、基本操作、辅助功能

概述 使用BOP可进行以下操作: 独立调试诊断参数查看参数设置SD卡驱动重启 SINAMICS V90 PN 基本操作面板 LED灯 共有两个LED状态指示灯,(RDY和COM)可用来显示驱动状态,两个LED灯都为三色(绿色/红色/黄色) LED灯状态 状态指示灯的颜色、状…

什么是VR虚拟现实|虚拟科技博物馆|VR设备购买

虚拟现实(Virtual Reality,简称VR)是一种通过计算机技术模拟出的一种全新的人机交互方式。它可以通过专门的设备(如头戴式显示器)将用户带入一个计算机生成的虚拟环境之中,使用户能够与这个虚拟环境进行交互…

1小时网络安全事件报告要求,持安零信任如何帮助用户应急响应?

12月8日,国家网信办起草发布了《网络安全事件报告管理办法(征求意见稿)》(以下简称“办法”)。拟规定运营者在发生网络安全事件时应当及时启动应急预案进行处置。 1小时报告 按照《网络安全事件分级指南》&#xff0c…

centos7单节点部署ceph(mon/mgr/osd/mgr/rgw)

使用ceph建议采用多节点多磁盘方式部署,本文章仅作为单节点部署参考,请勿用于生产环境 使用ceph建议采用多节点多磁盘方式部署,本文章仅作为单节点部署参考,请勿用于生产环境 使用ceph建议采用多节点多磁盘方式部署,…

2024年2月文章一览

2024年2月编程人总共更新了5篇文章: 1.2024年1月文章一览 2.Programming Abstractions in C阅读笔记:p283-p292 3.Programming Abstractions in C阅读笔记:p293-p302 4.Programming Abstractions in C阅读笔记:p303-p305 5.P…

如何在jupyter notebook 中下载第三方库

在anconda 中找到: Anaconda Prompt 进入页面后的样式: 在黑色框中输入: 下载第三方库的命令 第三方库: 三种输入方式 标准保证正确 pip instsall 包名 -i 镜像源地址 pip install pip 是 Python 包管理工具,…

2024年腾讯云优惠券领取页面_代金券使用方法_新老用户均可

腾讯云代金券领取渠道有哪些?腾讯云官网可以领取、官方媒体账号可以领取代金券、完成任务可以领取代金券,大家也可以在腾讯云百科蹲守代金券,因为腾讯云代金券领取渠道比较分散,腾讯云百科txybk.com专注汇总优惠代金券领取页面&am…

【IC前端虚拟项目】inst_buffer子模块DS与RTL编码

【IC前端虚拟项目】数据搬运指令处理模块前端实现虚拟项目说明-CSDN博客 需要说明一下的是,在我所提供的文档体系里,并没有模块的DS文档哈,因为实际项目里我也不怎么写DS毕竟不是每个公司都和HISI一样对文档要求这么严格的。不过作为一个培训的虚拟项目,还是建议在时间充裕…

PaddleOCR的部署教程(实操环境安装、数据集制作、实际应用案例)

文章目录 前言 PaddleOCR简介 一、PaddleOCR环境搭建 因为我之前安装过cuda和cudnn,查看cuda的版本根据你版本安装合适的paddlepaddle版本(之前没有安装过cuda的可以看我这篇文章Ubuntu20.04配置深度学习环境yolov5最简流程) 1.创建一个…

ESP32 partitions分区表的配置

由于在使用ESP32会遇到编译出来的bin文件大于分区表的时候,因此需要我们修改分区表或者使用自定义分区表的方式来解决。(项目是使用VScode来搭建和调试的,VScode YYDS) 具体分区标的含义这里就不讲了,网上有很多文档介…

【MySQL】:高效利用MySQL函数实用指南

🎥 屿小夏 : 个人主页 🔥个人专栏 : MySQL从入门到进阶 🌄 莫道桑榆晚,为霞尚满天! 文章目录 📑前言一. MySQL函数概论二. 字符串函数三. 数值函数四. 日期函数五. 流程函数&#x1…

C. Bitwise Operation Wizard

解题思路 可以相同先通过,找出最大值再通过每个数与取或,记录得值最大时的每个数其中的最小值与的最大 import java.io.*; import java.math.BigInteger; import java.util.Arrays; import java.util.BitSet; import java.util.HashMap; import java.ut…

StarRocks实战——携程酒店实时数仓

目录 一、实时数仓 二、实时数仓架构介绍 2.1 Lambda架构 2.2 Kappa架构 三、携程酒店实时数仓架构 3.1 架构选型 3.2 实时计算引擎选型 3.3 OLAP选型 四、携程酒店实时订单 4.1 数据源 4.2 ETL数据处理 4.3 应用效果 4.4 总结 原文大佬的这篇实时数仓建设案例有借…

代码随想录-回溯算法

组合 //未剪枝 class Solution {List<List<Integer>> ans new ArrayList<>();Deque<Integer> path new LinkedList<>();public List<List<Integer>> combine(int n, int k) {backtracking(n, k, 1);return ans;}public void back…

js中浏览器渲染原理

JavaScript&#xff08;JS&#xff09;是一种广泛使用的编程语言&#xff0c;特别是在Web开发中。在浏览器中&#xff0c;JS被用于实现动态网页效果、交互性和用户体验的提升。然而&#xff0c;要理解JS在浏览器中的工作原理&#xff0c;我们首先需要了解浏览器的渲染过程。 浏…

C++之函数,指针

函数 1&#xff0c;函数概述 作用&#xff1a;将一段经常使用的代码封装起来&#xff0c;减少重复代码 一个较大的程序&#xff0c;一般分为若干份程序块&#xff0c;每个模块实现特定的功能 2&#xff0c;函数的定义 函数的定义一般有五个步骤&#xff1a; 1&#xff0c…

【前端素材】推荐优质数医院办公后台管理系统网页Stisla平台模板(附源码)

一、需求分析 在线后台管理系统是指供管理员或运营人员使用的Web应用程序&#xff0c;用于管理和监控网站、应用程序或系统的运行和数据。它通常包括一系列工具和功能&#xff0c;用于管理用户、内容、权限、数据等。下面是关于在线后台管理系统的详细分析&#xff1a; 1、功…

​用细节去解释,如何打造一款行政旗舰车型

高山行政加长版应该是这个级别里最大的几款 MPV 之一了&#xff0c;对于一款较大的车型&#xff0c;其最重要的是解决行驶的便利性。 这次我们就试试魏牌高山行政加长版&#xff0c;从产品本身出发看几个纬度的细节&#xff1a; 行政该如何定义加长后产品的功能变化加长之后到…

ssm172旅行社管理系统的设计与实现

** &#x1f345;点赞收藏关注 → 私信领取本源代码、数据库&#x1f345; 本人在Java毕业设计领域有多年的经验&#xff0c;陆续会更新更多优质的Java实战项目希望你能有所收获&#xff0c;少走一些弯路。&#x1f345;关注我不迷路&#x1f345;** 一 、设计说明 1.1 研究…