AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)

news2025/1/6 17:45:09

AIGC实战——条件生成对抗网络

    • 0. 前言
    • 1. CGAN架构
    • 2. 模型训练
    • 3. CGAN 分析
    • 小结
    • 系列链接

0. 前言

我们已经学习了如何构建生成对抗网络 (Generative Adversarial Net, GAN) 以从给定的训练集中生成逼真图像。但是,我们无法控制想要生成的图像类型,例如控制模型生成男性或女性的面部图像;我们可以从潜空间中随机采样一个点,但是不能预知给定潜变量能够生成什么样的图像。在本节中,我们将构建一个能够控制输出的 GAN,即条件生成对抗网络 (Conditional Generative Adversarial Net, GAN)。该模型最早由 MirzaOsindero2014 年提出,是对 GAN 架构的简单改进。

1. CGAN架构

在节中,我们将使用面部数据集中的头发颜色属性来设置 CGAN 的条件。也就是说,我们将能够明确指定是否要生成带有金发的图像。头发颜色标签作为 CelebA 数据集的一部分已在数据集中提供,CGAN 的架构如下图所示。

CGAN 架构

标准 GANCGAN 之间的关键区别在于:在 CGAN 中,我们需要向生成器和判别器传递与标签相关的额外信息。在生成器中,标签信息转化为独热编码 (one-hot) 向量后附加在潜空间样本之后。在判别器中,通过重复独热编码向量填充得到与输入图像相同形状的通道,将标签信息添加为 RGB 图像的额外通道。
CGAN 之所以能够生成指定类型的图像,是因为其判别器可以获得关于图像内容的额外信息,因此生成器必须确保其输出与提供的标签一致,以继续欺骗判别器。如果生成器生成了与图像标签不一致的图像,即使图像非常逼真,判别器会将它们判定为伪造图像,因为图像和标签并不匹配。
在本节所构建的 CGAN 中,因为有两个类别(金发和非金发),独热编码标签的长度是 2。但是,我们也可以根据需要拥有使用多个标签。例如,在 Fashion-MNIST 数据集上训练 CGAN 时,为了输出 10 种不同类型的 Fashion-MNIST 图像,可以通过将长度为 10 的独热编码标签向量并入生成器的输入,并将 10 个额外的独热编码标签通道并入判别器的输入。
综上,我们需要对标准 GAN 架构所进行的修改是,将标签信息与生成器和判别器的现有输入连接起来:

# 图像通道和标签通道分别传递给判别器,并进行连接
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
label_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CLASSES))
x = layers.Concatenate(axis=-1)([critic_input, label_input])
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)

critic = models.Model([critic_input, label_input], critic_output)
print(critic.summary())
# 潜向量和标签类别分别传递给生成器,并在调整形状之前进行连接
generator_input = layers.Input(shape=(Z_DIM,))
label_input = layers.Input(shape=(CLASSES,))
x = layers.Concatenate(axis=-1)([generator_input, label_input])
x = layers.Reshape((1, 1, Z_DIM + CLASSES))(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    64, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
generator_output = layers.Conv2DTranspose(
    CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model([generator_input, label_input], generator_output)
print(generator.summary())

2. 模型训练

调整 CGANtrain_step 方法,以令生成器和判别器适应新的输入格式:

    def train_step(self, data):
        # 从数据集中提取图像和标签
        real_images, one_hot_labels = data
        # 将独热编码向量扩展为具有与输入图像相同空间尺寸 (64×64) 的独热编码图像
        image_one_hot_labels = one_hot_labels[:, None, None, :]
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=1)
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=2)

        batch_size = tf.shape(real_images)[0]

        for i in range(self.critic_steps):
            random_latent_vectors = tf.random.normal( shape=(batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                # 生成器接受包含两个输入的列表——随机潜向量和独热编码的标签向量
                fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)
                # 判别器接受包含两个输入的列表——真实/生成图像和独热编码的标签通道
                fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)
                real_predictions = self.critic([real_images, image_one_hot_labels], training=True)

                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
                c_gp = self.gradient_penalty(batch_size, real_images, fake_images, image_one_hot_labels)
                # 梯度惩罚函数还需要通过独热编码的标签通道传递(由于其流经判别器)
                c_loss = c_wass_loss + c_gp * self.gp_weight

            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            self.c_optimizer.apply_gradients(zip(c_gradient, self.critic.trainable_variables))

        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        with tf.GradientTape() as tape:
            # 生成器训练过程的修改与判别器训练步骤的修改相同
            fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)
            fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))

        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)
        return {m.name: m.result() for m in self.metrics}

3. CGAN 分析

我们可以通过将特定的独热编码标签传递到生成器的输入中来控制 CGAN 的输出。例如,要生成一张非金发的人脸图像,我们传入向量 [1, 0];要生成一张金发的人脸图像,我们传入向量 [0, 1]
CGAN 的输出如下图所示。可以看到,在保持随机潜向量不变的情况下,只改变条件标签向量,显然 CGAN 已经学会使用标签向量来控制图像的头发颜色属性,且图像的其余部分几乎没有改变。这证明了 GAN 能够以这种方式组织潜空间中的点,使得各个特征可以相互解耦。

生成结果

如果数据集中有标签可用,即使不一定需要将生成的输出与标签相关联,将它们作为 GAN 的输入通常也可以提高生成图像的质量,我们可以把标签看作是像素输入的信息扩展。

小结

在本节中,构建了一个条件生成对抗网络 (Conditional Generative Adversarial Net, CGAN),通过将标签作为输入传递给判别器和生成器,能够生成可控类别的图像,这是由于标签为网络提供了额外的信息,以便使生成的输出与给定的标签相关联。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1319858.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Java的音乐网站的设计与实现(带论文)

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…

如何实现公网访问本地内网搭建的WBO白板远程协作办公【内网穿透】

最近,我发现了一个超级强大的人工智能学习网站。它以通俗易懂的方式呈现复杂的概念,而且内容风趣幽默。我觉得它对大家可能会有所帮助,所以我在此分享。点击这里跳转到网站。 文章目录 前言1. 部署WBO白板2. 本地访问WBO白板3. Linux 安装cp…

数据结构与算法python版本之列表和字典复杂度

前面我们了解了大O表示法以及对不同算法的预估 接下来我们讨论python两种内置数据类型(列表和字典)上各种操作的大O数量级 列表数据类型 list类型各种操作的实现方法很多,如何选择具体哪种实现方法。总的方案就是,让最常用的操作…

微服务实战系列之ZooKeeper(实践篇)

前言 关于ZooKeeper,博主已完整的通过庖丁解牛式的“解法”,完成了概述。我想掌握了这些基础原理和概念后,工作的问题自然迎刃而解,甚至offer也可能手到擒来,真实一举两得,美极了。 为了更有直观的体验&a…

Spark基础入门

spark基础入门 环境搭建 localstandlonespark ha spark code spark corespark sqlspark streaming 环境搭建 准备工作 创建安装目录 mkdir /opt/soft cd /opt/soft下载scala wget https://downloads.lightbend.com/scala/2.13.12/scala-2.13.12.tgz -P /opt/soft解压scala…

基于 Flink 构建实时数据湖的实践

本文整理自火山引擎云原生计算研发工程师王正和闵中元在本次 CommunityOverCode Asia 2023 数据湖专场中的《基于 Flink 构建实时数据湖的实践》主题演讲。 实时数据湖是现代数据架构的核心组成部分,随着数据湖技术的发展,用户对其也有了更高的需求&…

Mysql高可用|索引|事务 | 调优

前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 文章目录 前言sql语句的执行顺序关键词连接名字解释sql语句 面试坑点存储引擎MYSQL存储引擎 SQL优化索引索引失效索引的数据结构面试坑点 锁事务四大特性事务的隔离级别MVCC 读写分离面试坑…

以低成本实现高转化:搭建年入百万的知识付费网站的技巧与方法

明理信息科技知识付费平台 一、引言 随着知识经济的崛起,越来越多的知识提供者希望搭建自己的知识付费平台。然而,对于新手来说,如何以低成本、高效率地实现这一目标,同时满足自身需求并提高客户转化率,是一大挑战。…

POST:http://XXX:XXXX/XXXX/XXXX(404 Not found)离谱

很离谱,同样的请求方式,不同的接口会有404的问题。看下边: 上边接口访问正常,下边接口出现404.且本地测试也可以,代码也推到公司git上了。真的很离谱。 我也不知道怎么回事,无语||||||| 哪位兄弟知道啊&a…

4.配置系统时钟思路及方法

前言: 比起之前用过的三星的猎户座4412芯片,STM32F4的系统时钟可以说是小巫见大巫,首先我们需要清晰时钟产生的原理:几乎大多数的芯片都是由晶振产生一个比较低频的频率,然后通过若干个PLL得到单片机能承受的频率&…

2023_Spark_实验二十八:Flume部署及配置

实验目的:熟悉掌握Flume部署及配置 实验方法:通过在集群中部署Flume,掌握Flume配置 实验步骤: 一、Flume简介 Flume是一种分布式的、可靠的和可用的服务,用于有效地收集、聚合和移动大量日志数据。它有一个简单灵活…

LibreNMS:从docker出发

引言 LibreNMS 是一个免费开源的网络监控和自动化工具,用于监视网络设备、服务器和应用程序的性能和状态。它提供了一个集中的管理平台,帮助管理员实时监控和管理整个网络基础设施。 以下是 LibreNMS 的一些主要特点和功能: 自动发现&#…

20、清华、杭州医学院等提出:DA-TransUNet,超越TranUNet,深度医学图像分割框架的[皇帝的新装]

前言: 本文由清华电子工程学院、杭州医学院、大阪大学免疫学前沿研究所、日本科学技术高等研究院信息科学学院、东京法政大学计算机与信息科学专业共同作者,于2023年11月14号发表于arXiv的《Electrical Engineering and Systems Science》期刊。 论文&…

【Python基础】生成器

文章目录 [toc]什么是生成器生成器示例生成器工作流程生成器表达式send()方法和close方法send()方法close()方法 什么是生成器 在Python中,使用生成器可以很方便地支持迭代器协议生成器通过生成器函数产生,通过def定义,但不是通过return返回…

酷雷曼再获“国家高新技术企业”认定

2023年12月8日,《对湖北省认定机构2023年认定报备的第五批高新技术企业拟进行备案的公示》正式发布,酷雷曼武汉同创蓝天科技有限公司成功获评“国家高新技术企业”认定。 屡获权威认定,见证硬核实力 被评定为高新技术企业是我国企业最高荣誉…

武汉小程序开发全攻略:从创意到上线,10个必备步骤详解

在当前数字化时代,小程序已经成为企业营销和服务的重要工具。特别是在武汉这样的创新型城市,小程序开发更是备受青睐。本文将为您详细解读武汉小程序开发的全攻略,从创意到上线的10个必备步骤。 步骤一:确定小程序类型和功能定位…

DSP捕获输入简单笔记

之前使用stm32的大概原理是: 输入引脚输入一个脉冲,捕获1开始极性捕获,捕获的是从启动捕获功能开始计数,捕获的是当前的计数值; 例如一个脉冲,捕获1捕获上升沿,捕获2捕获下降沿;而两…

mysql自动安装脚本(快速部署mysql)

mysql_install - 适用于生产环境单实例快速部署 MySQL8.0 自动安装脚本 mysql8_install.sh(执行前修改一下脚本里的配置参数,改成你自己的)(博客末尾) my_test.cnf(博客末尾)(这个…

Linux性能优化常做的一些事情

Linux性能优化是一个广泛的主题,涉及多个方面。以下是一些常见的Linux性能优化建议: 硬件和系统配置: 使用SSD替代HDD。确保系统有足够的RAM。使用多核CPU。配置合适的网络硬件和带宽。 磁盘I/O性能: 使用RAID来提高I/O性能。使用…

WordCloud—— 词云

【说明】文章内容来自《机器学习入门——基于sklearn》,用于学习记录。若有争议联系删除。 wordcloud 是python的第三方库,称为词云,也成文字云,可以根据文本中的词频以直观和艺术化的形式展示文本中词语的重要性。 依赖于pillow …