DCGAN-MNIST——使用TensorFlow 2 / Keras实现深度卷积DCGAN来生成时尚MNIST的灰度图像

news2025/4/16 13:48:21

DCGAN-MNIST——使用TensorFlow 2 / Keras实现深度卷积DCGAN来生成时尚MNIST的灰度图像

    • 1. 效果图
    • 2. 原理
      • 2.1 结构指南
      • 2.2 模型结构及训练过程

这篇博客将介绍如何使用TensorFlow 2 / Keras中实现深度卷积GAN(DCGAN)来生成类似时尚MNIST的灰度图像。将介绍DCGAN架构指南,如何训练稳定的DCGAN。在TensorFlow 2/Keras中使用灰度时尚MNIST图像完成DCGAN代码实现。使用了Keras Model子类化来定制train_step,然后调用Keras Model.fit()进行训练。

下一篇博客将实现用时尚彩色图像训练的DCGAN来展示GAN训练的挑战。

  • DCGAN 架构指南
  • 定制 train_step() 与Keras model.fit()
  • TensorFlow 2 / Keras实现DCGAN

每个GAN至少有一个发生器和一个鉴别器。当生成器和鉴别器相互竞争时,生成器在从鉴别器获得反馈时,能够更好地生成接近训练数据分布的图像。

1. 效果图

生成器结构:
在这里插入图片描述

鉴别器结构:
在这里插入图片描述

训练1 VS 25 VS 50效果图如下:
在这里插入图片描述

(28, 28, 1)


2023-05-25 20:44:36.432534: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/50
1/1 [] - 3s 3s/step - d_loss: 0.7167 - g_loss: 1.2281
Epoch 2/50
1/1 [
] - 0s 160ms/step - d_loss: 0.7766 - g_loss: 0.9896
Epoch 3/50
1/1 [] - 0s 154ms/step - d_loss: 0.8732 - g_loss: 0.6658
Epoch 4/50
1/1 [
] - 0s 166ms/step - d_loss: 0.6789 - g_loss: 0.7311
Epoch 5/50
1/1 [] - 0s 154ms/step - d_loss: 0.4631 - g_loss: 1.0747
Epoch 6/50
1/1 [
] - 0s 165ms/step - d_loss: 0.4101 - g_loss: 1.0436
Epoch 7/50
1/1 [] - 0s 223ms/step - d_loss: 0.3703 - g_loss: 1.1298
Epoch 8/50
1/1 [
] - 0s 197ms/step - d_loss: 0.5815 - g_loss: 1.1503
Epoch 9/50
1/1 [] - 0s 166ms/step - d_loss: 0.4747 - g_loss: 1.7595
Epoch 10/50
1/1 [
] - 0s 156ms/step - d_loss: 0.3227 - g_loss: 2.4748
Epoch 11/50
1/1 [] - 0s 173ms/step - d_loss: 0.2040 - g_loss: 3.1504
Epoch 12/50
1/1 [
] - 0s 166ms/step - d_loss: 0.2089 - g_loss: 2.6114
Epoch 13/50
1/1 [] - 0s 155ms/step - d_loss: 0.3123 - g_loss: 1.8193
Epoch 14/50
1/1 [
] - 0s 150ms/step - d_loss: 0.2877 - g_loss: 2.4994
Epoch 15/50
1/1 [] - 0s 150ms/step - d_loss: 0.1433 - g_loss: 2.4561
Epoch 16/50
1/1 [
] - 0s 149ms/step - d_loss: 0.1219 - g_loss: 1.9404
Epoch 17/50
1/1 [] - 0s 156ms/step - d_loss: 0.1367 - g_loss: 1.8200
Epoch 18/50
1/1 [
] - 0s 152ms/step - d_loss: 0.1019 - g_loss: 2.0167
Epoch 19/50
1/1 [] - 0s 188ms/step - d_loss: 0.0696 - g_loss: 1.5635
Epoch 20/50
1/1 [
] - 0s 173ms/step - d_loss: 0.0684 - g_loss: 1.4290
Epoch 21/50
1/1 [] - 0s 196ms/step - d_loss: 0.0742 - g_loss: 1.5191
Epoch 22/50
1/1 [
] - 0s 192ms/step - d_loss: 0.0975 - g_loss: 1.5210
Epoch 23/50
1/1 [] - 0s 190ms/step - d_loss: 0.0831 - g_loss: 0.6667
Epoch 24/50
1/1 [
] - 0s 173ms/step - d_loss: 0.1027 - g_loss: 0.6803
Epoch 25/50
1/1 [] - 0s 173ms/step - d_loss: 0.0767 - g_loss: 1.1220
Epoch 26/50
1/1 [
] - 0s 181ms/step - d_loss: 0.0440 - g_loss: 1.4847
Epoch 27/50
1/1 [] - 0s 174ms/step - d_loss: 0.0304 - g_loss: 1.4330
Epoch 28/50
1/1 [
] - 0s 182ms/step - d_loss: 0.0268 - g_loss: 1.2883
Epoch 29/50
1/1 [] - 0s 188ms/step - d_loss: 0.0293 - g_loss: 1.3937
Epoch 30/50
1/1 [
] - 0s 173ms/step - d_loss: 0.0136 - g_loss: 1.0047
Epoch 31/50
1/1 [] - 0s 209ms/step - d_loss: 0.0154 - g_loss: 0.8617
Epoch 32/50
1/1 [
] - 0s 199ms/step - d_loss: 0.0114 - g_loss: 0.5661
Epoch 33/50
1/1 [] - 0s 219ms/step - d_loss: 0.0093 - g_loss: 0.6212
Epoch 34/50
1/1 [
] - 0s 193ms/step - d_loss: 0.0084 - g_loss: 0.5213
Epoch 35/50
1/1 [] - 0s 210ms/step - d_loss: 0.0073 - g_loss: 0.4086
Epoch 36/50
1/1 [
] - 0s 195ms/step - d_loss: 0.0059 - g_loss: 0.3696
Epoch 37/50
1/1 [] - 0s 193ms/step - d_loss: 0.0088 - g_loss: 0.3803
Epoch 38/50
1/1 [
] - 0s 177ms/step - d_loss: 0.0084 - g_loss: 0.2576
Epoch 39/50
1/1 [] - 0s 185ms/step - d_loss: 0.0072 - g_loss: 0.3387
Epoch 40/50
1/1 [
] - 0s 182ms/step - d_loss: 0.0056 - g_loss: 0.3223
Epoch 41/50
1/1 [] - 0s 228ms/step - d_loss: 0.0046 - g_loss: 0.2862
Epoch 42/50
1/1 [
] - 0s 226ms/step - d_loss: 0.0059 - g_loss: 0.2288
Epoch 43/50
1/1 [] - 0s 197ms/step - d_loss: 0.0049 - g_loss: 0.2531
Epoch 44/50
1/1 [
] - 0s 200ms/step - d_loss: 0.0056 - g_loss: 0.1869
Epoch 45/50
1/1 [] - 0s 193ms/step - d_loss: 0.0038 - g_loss: 0.2534
Epoch 46/50
1/1 [
] - 0s 192ms/step - d_loss: 0.0050 - g_loss: 0.1715
Epoch 47/50
1/1 [] - 0s 198ms/step - d_loss: 0.0044 - g_loss: 0.1654
Epoch 48/50
1/1 [
] - 0s 181ms/step - d_loss: 0.0056 - g_loss: 0.1122
Epoch 49/50
1/1 [] - 0s 211ms/step - d_loss: 0.0035 - g_loss: 0.1579
Epoch 50/50
1/1 [
] - 0s 188ms/step - d_loss: 0.0043 - g_loss: 0.1457

2. 原理

2.1 结构指南

DCGAN论文介绍了一种GAN架构,其中鉴别器和生成器(discriminator and generator)由卷积神经网络(CNNs)定义。它提供了几个体系结构指南来提高训练稳定性:为了简洁起见,将生成器称为G,鉴别器称为D。
在这里插入图片描述

  1. GD都替换卷积为条纹卷积和分数阶跨步卷积
  • 条纹卷积(Strided convolutions):步长为2的卷积层,用于D中的下采样。
  • 分数阶跨步卷积(Fractional-strided convolutions):Conv2Transpose层的跨步为2,用于G中的上采样。
  1. GD都使用归一化
  • 批量规一化
    本文建议在G和D中使用批量归一化(batchnorm)来帮助稳定GAN训练。Batchnorm将输入层标准化为具有零均值和单位方差。它通常添加在隐藏层之后和激活层之前。随着我们在GAN系列中的进展,您将学习到更好的GAN规范化技术。
  1. 移除深度架构中的全量连接隐藏层
  2. 除使用Tanh的输出层,都生成使用ReLU激活器
  • 激活器
    DCGAN生成器和鉴别器中有四种常用的激活函数:sigmoid、tanh、ReLU和leakyReLU。
  • sigmoid:将数字压缩为0(假)和1(真)。由于DCGAN鉴别器进行二元分类,在D的最后一层使用sigmoid。
  • tanh(Hyperbolic Tangent 双曲正切):也是s形的,类似于s形;事实上,它是一个缩放的s形,但以0为中心,并将输入值压缩为[-1,1]。根据论文的建议在G的最后一层使用tanh。这就是为什么需要将训练图像预处理到[-1,1]的范围内。
  • ReLU(Rectified Linear Activation 整流线性激活):当输入值为负值时,返回0;否则,它将返回输入值。建议对G中的所有层进行ReLU激活,除了使用tanh的输出层。
  • LeakyReLU:与ReLU类似,只是当输入值为负值时,它使用常数alpha来给它一个非常小的斜率。正如论文所建议的那样,将斜率(alpha)设置为0.2。在D中对除最后一层之外的所有层使用LeakyReLU激活。

2.2 模型结构及训练过程

同时训练两个网络:一个生成器和一个鉴别器。为了创建DCGAN模型,首先需要使用Keras Sequential API定义生成器和鉴别器的模型体系结构。然后使用Keras模型子类化来创建DCGAN。

  1. 数据
    第一步是为训练做好数据准备。将使用时尚MNIST数据来训练DCGAN。

  2. 数据加载
    Fashion MNIST数据集具有训练/测试分割。使用训练数据或加载两个训练/测试数据集用于训练目的。对于具有Fashion MNIST的DCGAN,仅使用训练数据集进行训练就足够了
    使用train_images.shape查看Fashion MNIST训练数据形状,并注意到(60000,28,28)的形状,这意味着有60000个28x28大小的训练灰度图像。

  3. 可视化
    将训练数据可视化,以了解图像的外观。看看Fashion MNIST灰度28x28x1图片是什么样子

  4. 数据预处理
    加载的数据是(60000,28,28)的形状是灰度级的。因此需要将通道的第4个维度添加为1,并根据TensorFlow中训练的需要将数据类型(从NumPy数组)转换为float32。
    将输入图像归一化到[-1,1]的范围,因为生成器的最终层激活使用了前面提到的tanh。
    防止电脑内存占到100% 死机,只选择100张照片作为训练数据集

  5. 生成器模型
    生成器的工作是生成看似合理的图像。它的目的是试图欺骗鉴别器,使其认为生成的图像是真实的。
    生成器将随机噪声作为输入,并输出与训练图像相似的图像。由于我们在这里生成的是28x28灰度图像,因此模型架构需要确保得到的形状使得生成器输出应该是28x28x1
    使用Reshape层将1D随机噪声(潜在矢量)转换为3D
    在Fashion MNIST的情况下,用Keras Conv2DTranspose层(论文中提到的分数阶跨步卷积)上采样几次,达到输出图像大小,即28x28x1形状的灰度图像。
    有几层构成了G的构建块:
    密集(完全连接)层:仅用于重塑和平坦噪声矢量
    Conv2DTranspose:上采样
    BatchNormalization:稳定训练;在conv层之后和激活功能之前。
    除了使用tanh的输出之外,所有层都使用G中的ReLU激活。

  6. 鉴别器模型
    鉴别器是一个简单的二元分类器,可以告诉图像是真还是假。它的目的是试图对图像进行正确的分类。鉴别器和常规分类器之间有一些区别:
    使用LeakyReLU作为DCGAN论文中的激活函数。
    鉴别器有两组输入图像:标记为1的训练数据集或真实图像,以及标记为0的生成器创建的伪图像。
    注意:鉴别器网络通常比生成器更小或更简单,因为鉴别器的工作比生成器容易得多。如果鉴别器太强,那么发生器就不会有很好的改善。
    创建方法以构建鉴别器,输入为真实的图像和生成器生成的图像,及图像的宽/高/深,LeakyReLU的值
    Fashion MNIST的图像大小为28x28x1,这些图像作为argos传递到宽度、高度和深度的函数中。alpha表示LeakyReLU用于定义泄漏的斜率。

  7. 损失函数:修改后的极大极小损失
    在创建DCGAN模型之前,先讨论一下损失函数。计算损失是DCGAN(或任何GAN)训练的核心。对于DCGAN将实现修改的极大极小损失,它使用二进制交叉熵(BCE)损失函数。随着在GAN系列中的进展将了解不同GAN变体中的其他损失函数。
    需要计算两个损失:一个用于鉴别器损失,另一个用于生成器损失。
    鉴别器损失: 由于有两组图像被输入鉴别器(真实图像和伪图像),将计算每组图像的损失,并将它们组合作为鉴别器损失。
    生成器损失: 对于生成器损失可以训练G来最大化log D(G(z)),而不是训练G来最小化log(1−D(G))。这就是修正后的极小极大损失。

  8. DCGAN模型:覆盖train_step
    已经定义了生成器和鉴别器架构,并了解了损失函数是如何工作的。准备好将D和G放在一起,通过子类化keras.model并重写train_step()来训练鉴别器和生成器,从而创建DCGAN模型。
    以下是关于如何编写低级别代码以自定义model.fit()的文档。这种方法的优点是仍然可以使用GradientTape进行自定义训练循环,同时仍然可以受益于fit()的方便功能(例如,回调和内置分发支持等)。
    因此对keras.Model进行子类化,以创建DCGAN类–类DCGAN(keras.MModel)
    用真实图像(标记为1)和伪图像(标记为0)来训练鉴别器
    在真实图像上计算鉴别器损失
    在伪图像上计算鉴别器损失
    总的鉴别器损失
    计算鉴别器梯度gradients
    更新鉴别器权重
    不更新鉴别器权重的情况下训练生成器
    计算生成器梯度
    更新生成器权重

  9. 训练期间的监控和可视化:覆盖Keras callback()来监控鉴别器/生成器损失
    例如,对于图像分类,损失可以帮助了解模型的性能。对于GAN,D损失和G损失表明每个模型是如何单独执行的,可能是也可能不是GAN模型总体执行情况的准确衡量标准。我
    们将在“GAN培训挑战”的下一篇文章中对此进行进一步讨论
    对于GAN评估,对训练过程中生成的图像进行视觉检查是很重要的,未来将学习其他评估方法。
    训练50个循环,GPU下每个循环只耗时25s
    训练过程中,可以可视的检查图像以确定生成器的图像质量
    分别查看 第1次,25次,50次训练后生成的Fashion-MNIST图像,可以看到生成器变的越来越好。
    使用generator.summary()查看定义的生成器模型架构,以确保每一层都是想要的形状

3. 源码

# 同时训练两个网络:一个生成器和一个鉴别器。为了创建DCGAN模型,首先需要使用Keras Sequential API定义生成器和鉴别器的模型体系结构。然后使用Keras模型子类化来创建DCGAN。
#
# 启用Colab GPU,要在Colab中启用GPU运行时,请转到编辑→ 笔记本设置或运行时→ 更改运行时类型,然后从硬件加速器下拉菜单中选择“GPU”。
# 导入包,使用TensorFlow 2/Keras编写代码,并使用matplotlib进行可视化。
# USAGE
# python dcgan_minist.py

import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# 数据
# 第一步是为训练做好数据准备。将使用时尚MNIST数据来训练DCGAN。

# 数据加载
# Fashion MNIST数据集具有训练/测试分割。使用训练数据或加载两个训练/测试数据集用于训练目的。对于具有Fashion MNIST的DCGAN,仅使用训练数据集进行训练就足够了
(train_images, train_labels), (_, _) = tf.keras.datasets.fashion_mnist.load_data()
# 使用train_images.shape查看Fashion MNIST训练数据形状,并注意到(60000,28,28)的形状,这意味着有60000个28x28大小的训练灰度图像。

# 可视化
# 将训练数据可视化,以了解图像的外观。看看Fashion MNIST灰度28x28x1图片是什么样子
plt.figure()
plt.imshow(train_images[0], cmap='gray')
plt.show()

# 数据预处理
# 加载的数据是(60000,28,28)的形状是灰度级的。因此需要将通道的第4个维度添加为1,并根据TensorFlow中训练的需要将数据类型(从NumPy数组)转换为float32。
# 将输入图像归一化到[-1,1]的范围,因为生成器的最终层激活使用了前面提到的tanh。
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5
print(train_images[0].shape)
# train_images = tf.convert_to_tensor(train_images)
# 防止电脑内存占到100% 死机,只选择100张照片作为训练数据集
train_images = tf.convert_to_tensor(train_images[:20])

# 随机噪声的潜在维数
LATENT_DIM = 100

# 生成器模型
# 生成器的工作是生成看似合理的图像。它的目的是试图欺骗鉴别器,使其认为生成的图像是真实的。
# 生成器将随机噪声作为输入,并输出与训练图像相似的图像。由于我们在这里生成的是28x28灰度图像,因此模型架构需要确保得到的形状使得生成器输出应该是28x28x1
# 使用Reshape层将1D随机噪声(潜在矢量)转换为3D
# 在Fashion MNIST的情况下,用Keras Conv2DTranspose层(论文中提到的分数阶跨步卷积)上采样几次,达到输出图像大小,即28x28x1形状的灰度图像。
# 有几层构成了G的构建块:
# 密集(完全连接)层:仅用于重塑和平坦噪声矢量
# Conv2DTranspose:上采样
# BatchNormalization:稳定训练;在conv层之后和激活功能之前。
# 除了使用tanh的输出之外,所有层都使用G中的ReLU激活。
def build_generator():
    # Con2DTranspose层的权重初始化
    WEIGHT_INIT = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    # 图像的颜色通道, 1 for gray scale and 3 for color images
    CHANNELS = 1

    # 使用Keras Sequential API创建模型
    model = Sequential(name='generator')

    # 定义一个密集层 为重塑为3D做准备,并确保在模型架构的第一层中定义输入形状。添加BatchNormalization和ReLU层
    model.add(layers.Dense(7 * 7 * 256, input_dim=LATENT_DIM))
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())

    # reshape 1D 为 3D
    model.add(layers.Reshape((7, 7, 256)))

    # 2次使用2步阶的Conv2DTranspose 以获取7x7 to 14x14 to 28x28 在每个Conv2DTranspose层后提娜佳ReLU激活层
    # upsample to 14x14: apply a transposed CONV => BN => RELU
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding="same", kernel_initializer=WEIGHT_INIT))
    model.add(layers.BatchNormalization())
    model.add((layers.ReLU()))

    # upsample to 28x28: apply a transposed CONV => BN => RELU
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding="same", kernel_initializer=WEIGHT_INIT))
    model.add(layers.BatchNormalization())
    model.add((layers.ReLU()))

    # 最后使用tanh为激活函数的Conv2D层
    # 注意:CHANNELS之前被定义为1,这将生成28x28x1的图像,与原始的灰度训练图像相匹配。
    model.add(layers.Conv2D(CHANNELS, (5, 5), padding="same", activation="tanh"))

    return model


# 鉴别器模型
# 鉴别器是一个简单的二元分类器,可以告诉图像是真还是假。它的目的是试图对图像进行正确的分类。鉴别器和常规分类器之间有一些区别:
# 使用LeakyReLU作为DCGAN论文中的激活函数。
# 鉴别器有两组输入图像:标记为1的训练数据集或真实图像,以及标记为0的生成器创建的伪图像。
# 注意:鉴别器网络通常比生成器更小或更简单,因为鉴别器的工作比生成器容易得多。如果鉴别器太强,那么发生器就不会有很好的改善。
# 创建方法以构建鉴别器,输入为真实的图像和生成器生成的图像,及图像的宽/高/深,LeakyReLU的值
# Fashion MNIST的图像大小为28x28x1,这些图像作为argos传递到宽度、高度和深度的函数中。alpha表示LeakyReLU用于定义泄漏的斜率。
def build_discriminator(width, height, depth, alpha=0.2):
    # 使用Keras Sequential API创建模型
    model = Sequential(name='discriminator')
    input_shape = (height, width, depth)

    # We use Conv2D, BatchNormalization, and LeakyReLU twice to downsample.
    # 使用Conv2D, BatchNormalization, and LeakyReLU 2次以进行下采样
    # first set of CONV => BN => leaky ReLU layers
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding="same", input_shape=input_shape))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=alpha))

    # second set of CONV => BN => leacy ReLU layers
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same"))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=alpha))

    # Flatten and apply dropout 展平以及应用dropout
    model.add(layers.Flatten())
    model.add(layers.Dropout(0.3))

    # 最后一层使用 sigmoid激活函数一输出二进制分类(binary classification)的结果
    model.add(layers.Dense(1, activation="sigmoid"))

    return model


# 损失函数:修改后的极大极小损失
# 在创建DCGAN模型之前,先讨论一下损失函数。计算损失是DCGAN(或任何GAN)训练的核心。对于DCGAN将实现修改的极大极小损失,它使用二进制交叉熵(BCE)损失函数。随着在GAN系列中的进展将了解不同GAN变体中的其他损失函数。
# 需要计算两个损失:一个用于鉴别器损失,另一个用于生成器损失。
# 鉴别器损失
# 由于有两组图像被输入鉴别器(真实图像和伪图像),将计算每组图像的损失,并将它们组合作为鉴别器损失。
# total_D_loss = loss_from_real_images + loss_from_fake_images
# 生成器损失
# 对于生成器损失可以训练G来最大化log D(G(z)),而不是训练G来最小化log(1−D(G))。这就是修正后的极小极大损失。
class DCGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(DCGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(DCGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

    # DCGAN模型:覆盖train_step
    # 已经定义了生成器和鉴别器架构,并了解了损失函数是如何工作的。准备好将D和G放在一起,通过子类化keras.model并重写train_step()来训练鉴别器和生成器,从而创建DCGAN模型。
    # 以下是关于如何编写低级别代码以自定义model.fit()的文档。这种方法的优点是仍然可以使用GradientTape进行自定义训练循环,同时仍然可以受益于fit()的方便功能(例如,回调和内置分发支持等)。
    # 因此对keras.Model进行子类化,以创建DCGAN类–类DCGAN(keras.MModel)
    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))

        # 用真实图像(标记为1)和伪图像(标记为0)来训练鉴别器
        with tf.GradientTape() as tape:
            # 在真实图像上计算鉴别器损失
            pred_real = self.discriminator(real_images, training=True)
            d_loss_real = self.loss_fn(tf.ones((batch_size, 1)), pred_real)

            # 在伪图像上计算鉴别器损失
            fake_images = self.generator(noise)
            pred_fake = self.discriminator(fake_images, training=True)
            d_loss_fake = self.loss_fn(tf.zeros((batch_size, 1)), pred_fake)

            # 总的鉴别器损失
            d_loss = (d_loss_real + d_loss_fake) / 2

        # 计算鉴别器梯度gradients
        grads = tape.gradient(d_loss, self.discriminator.trainable_variables)
        # 更新鉴别器权重
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables))

        # We train the generator while not updating the weights of the discriminator.
        # 不更新鉴别器权重的情况下训练生成器
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            fake_images = self.generator(noise, training=True)
            pred_fake = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, pred_fake)
        # 计算生成器梯度
        grads = tape.gradient(g_loss, self.generator.trainable_variables)
        # 更新生成器权重
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }


# 训练期间的监控和可视化:覆盖Keras callback()来监控鉴别器/生成器损失
# 例如,对于图像分类,损失可以帮助了解模型的性能。对于GAN,D损失和G损失表明每个模型是如何单独执行的,可能是也可能不是GAN模型总体执行情况的准确衡量标准。我
# 们将在“GAN培训挑战”的下一篇文章中对此进行进一步讨论
# 对于GAN评估,对训练过程中生成的图像进行视觉检查是很重要的,未来将学习其他评估方法。
# 训练50个循环,GPU下每个循环只耗时25s
# 训练过程中,可以可视的检查图像以确定生成器的图像质量
# 分别查看 第1次,25次,50次训练后生成的Fashion-MNIST图像,可以看到生成器变的越来越好。
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=3, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images *= 255
        generated_images.numpy()
        for i in range(self.num_img):
            img = keras.preprocessing.image.array_to_img(generated_images[i])
            img.save("images/generated_img_%03d_%d.png" % (epoch, i))


generator = build_generator()
# 使用generator.summary()查看定义的生成器模型架构,以确保每一层都是想要的形状
print(generator.summary())

discriminator = build_discriminator(width=28, height=28, depth=1, alpha=0.2)
print(discriminator.summary())

# 编译和训练模型
dcgan = DCGAN(discriminator=discriminator, generator=generator, latent_dim=LATENT_DIM)

LR = 0.0002  # learning rate

# 如DCGAN论文所建议的,使用Adam优化器,生成器和鉴别器的学习率均为0.0002。对D和G都使用二进制交叉熵损失函数。
dcgan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=LR, beta_1=0.5),
    g_optimizer=keras.optimizers.Adam(learning_rate=LR, beta_1=0.5),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# 简单的调用model.fit 训练DCGAN模型
NUM_EPOCHS = 50  # number of epochs
dcgan.fit(train_images, epochs=NUM_EPOCHS,
          callbacks=[GANMonitor(num_img=16, latent_dim=LATENT_DIM)])

参考

  • https://pyimagesearch.com/2021/11/11/get-started-dcgan-for-fashion-mnist/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/580466.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

魔法反射--java反射进阶(实战篇)

👳我亲爱的各位大佬们好😘😘😘 ♨️本篇文章记录的为 魔法反射–java反射进阶(实战篇) 相关内容,适合在学Java的小白,帮助新手快速上手,也适合复习中,面试中的大佬🙉🙉🙉…

openpose原理及安装教程(姿态识别)

OpenPose是一个基于深度学习的人体姿态估计框架,可以实时地估计人体的关键点,包括身体和手部姿势。它是由卡内基梅隆大学的研究团队开发的,已经成为了人体姿态估计领域的一个重要项目。 OpenPose的原理是基于卷积神经网络(CNN),通过对图像进行深度学习处理,可以检测出…

如何在华为OD机试中获得满分?Java实现【寻找峰值】一文详解!

✅创作者:陈书予 🎉个人主页:陈书予的个人主页 🍁陈书予的个人社区,欢迎你的加入: 陈书予的社区 🌟专栏地址: Java华为OD机试真题(2022&2023) 文章目录 1. 题目描述2. 输入描述3. 输出描述…

Aerial Vision-and-Dialog Navigation阅读报告

Aerial Vision-and-Dialog Navigation 本次报告,包含以下部分:1摘要,2数据集/模拟器,3AVDN任务,4模型,5实验结果。重点介绍第2/3部分相关主页:Aerial Vision-and-Dialog Navigation (google.com…

【章节2】husky + 自动检测是否有未解决的冲突 + 预检查debugger + 自动检查是否符合commit规范

在章节1中我们学习到了commit的规范、husky的安装和使用、lint-staged怎么安装以及怎么用来格式化代码。那么这篇文章我们来看看commit预处理中我们还能做哪些处理呢? 自然,我们还是要用到husky这个东西的,大致过程其实和章节1异曲同工&#…

不要再来问我小学、初中毕业想出去学习编程找到工作的问题了,你要做就去做,结果自己扛着就行了!

🚀 个人主页 极客小俊 ✍🏻 作者简介:web开发者、设计师、技术分享博主 🐋 希望大家多多支持一下, 我们一起进步!😄 🏅 如果文章对你有帮助的话,欢迎评论 💬点赞&#x1…

探索Java面向对象编程的奇妙世界(五)

⭐ Object 类⭐ toString 方法⭐ 和 equals 方法⭐ super 关键字⭐ 继承树追溯⭐ 封装(encapsulation) ⭐ Object 类 Object 类基本特性 🐟 Object 类是所有类的父类,所有的 Java 对象都拥有 Object 类的属性和方法。 🐟 如果在类的声明中未…

docker-compose方式安装运行Jenkins

docker-compose方式安装运行Jenkins 服务器系统:centos 7.6 以docker-compose 编排容器方式安装,当然需提前安装docker-compose环境(见百度->docker-compose环境安装) docker-compose.yml version: 3.1 services:jenkins:i…

WF攻击(网站指纹攻击)

网站指纹(WF)攻击是被动的本地攻击者通过比较用户发送和接收的数据包序列与先前记录的数据集来确定加密互联网流量的目的地。可以通过网络流量中的模式来识别Tor用户访问过的页面。因此,WF攻击是Tor等隐私增强技术特别关注的题。 攻击过程 该…

分布式网络通信框架(九)——RpcChannel调用过程

介绍 客户端使用RpcChannel对象来构造UserServiceRpc_Stub对象&#xff0c;并利用该对象中RpcChannel::CallMethod来进行rpc调用请求,RpcChannel完成的工作是如下rpc调用流程图的红圈部分&#xff1a; 客户端使用mprpc框架的业务代码 // calluserservice.cc #include <ios…

【算法题解】31. 翻转二叉树的递归解法

这是一道 简单 题 https://leetcode.cn/problems/invert-binary-tree/ 题目 给你一棵二叉树的根节点 r o o t root root &#xff0c;翻转这棵二叉树&#xff0c;并返回其根节点。 示例 1&#xff1a; 输入&#xff1a;root [4,2,7,1,3,6,9] 输出&#xff1a;[4,7,2,9,6…

Vivado综合属性系列之十二 BLACK_BOX

目录 一、前言 二、BLACK_BOX ​2.1 属性说明 ​2.2 工程代码 ​2.3 结果 一、前言 ​在调试中&#xff0c;有时不需要知道一个模块或实例的具体实现&#xff0c;或者需要使其对外属于不可见&#xff0c;只知道它的输入输出&#xff0c;即像一个黑盒&#xff0c;此时可以对模…

Linux内核源码分析 1:Linux内核体系架构和学习路线

好久没有动笔写文章了&#xff0c;这段时间经历了蛮多事情的。这段时间自己写了一两个基于不同指令集的Linux内核&#xff0c;x86和RISC-V。期间也去做了一些嵌入式相关的工作&#xff0c;研究了一下ARM指令集架构。 虽然今年九月份我就要申请了&#xff0c;具体申请AI方向还是…

【使用ChatGPT制作视频】

内容目录 一、利用ChatGPT生成视频文案1. 打开ChatGPT&#xff1a;2. 输入需求&#xff1a;3. 复制&#xff1a; 二、制作生成思维导图1. 打开视频制作网站&#xff1a;2. 网页版下侧 - 一键成片 -粘贴Markdown内容&#xff0c;就会自动生成视频&#xff0c;这里放了其中一段&a…

【刷题之路Ⅱ】百度面试题——迷宫问题

【刷题之路Ⅱ】百度面试题——迷宫问题 一、题目描述二、解题1、方法1——暴力递归1.1、思路分析1.2、先将栈实现一下1.3、代码实现 一、题目描述 原题连接&#xff1a; 迷宫问题 题目描述&#xff1a; 定义一个二维数组 N*M &#xff0c;如 5 5 数组下所示&#xff1a; int …

自学网络安全(黑客),一般人我劝你还是算了吧

一、自学网络安全学习的误区和陷阱 1.不要试图先成为一名程序员&#xff08;以编程为基础的学习&#xff09;再开始学习 我在之前的回答中&#xff0c;我都一再强调不要以编程为基础再开始学习网络安全&#xff0c;一般来说&#xff0c;学习编程不但学习周期长&#xff0c;而…

Fiddler抓包工具之fiddler设置抓HTTPS的请求证书安装

设置抓HTTPS的请求包 基础配置&#xff1a; 路径&#xff1a;启动Fiddler 》Tools》Options》HTTPS 注意&#xff1a;Option更改完配置需重启Fiddler才能生效 选中"Decrpt HTTPS traffic", Fiddler就可以截获HTTPS请求&#xff0c;如果是第一次会弹出证书安装提…

车载软件架构 —— 功能安全与基础软件

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 在最艰难的时候&#xff0c;自己就别去幻想太远的将来&#xff0c;只要鼓励自己过好今天就行了&#xff01; 这世…

node.js 学习 -- koa

一、搭建项目 1. 安装 Koa 框架 yarn add koa2. 引入 const Koa require("koa"); const app new Koa();3. 配置中间件 // ctx 所有http的上下文 // 配置中间件 app.use((ctx, next) > {ctx.body "hello api"; });4. 监听端口 app.listen(3000, …

TPO69 01|Why Snakes Have Forked Tongues|阅读真题精读|10:40-11:40+15:30-16:57

Why Snakes Have Forked Tongues 5/10 目录 Why Snakes Have Forked Tongues P1 P1生词 P1段落大意 无题目 P2 P2生词 P2段落大意 P2题目 【1】词汇题 secreteproduce ✅ 【2】事实信息题|考频高|难度高|定位错误​ P34​ P34生词 P34段落大意 P34题目 【3】词汇题 simultaneo…