AIGC实战——CycleGAN详解与实现
- 0. 前言
- 1. CycleGAN 基本原理
- 2. CycleGAN 模型分析
- 3. 实现 CycleGAN
- 小结
- 系列链接
0. 前言
CycleGAN
是一种用于图像转换的生成对抗网络(Generative Adversarial Network, GAN),可以在不需要配对数据的情况下将一种风格的图像转换成另一种风格,而无需为每一对输入-输出图像配对训练数据。CycleGAN
的核心思想是利用两个生成器和两个判别器,它们共同学习两个域之间的映射关系。例如,将马的图像转换成斑马的图像,或者将苹果图像转换为橙子图像。在本节中,我们将学习 CycleGAN
的基本原理,并实现该模型用于将夏天的风景图像转换成冬天的风景图像,或反之将冬天的风景图像转换为夏天的风景图像。
1. CycleGAN 基本原理
CycleGAN
是一种无需配对的图像转换技术,它可以将一个图像域中的图像转换为另一个图像域中的图像,而不需要匹配这两个域中的图像。它使用两个生成器和两个判别器,其中一个生成器将一个域中的图像转换为另一个域中的图像,而第二个生成器将其转换回来。这个过程被称为循环一致性,转换过程是可逆的。
CycleGAN
可以用于执行从一个类别到另一个类别的图像转换,而无需提供相匹配的输入-输出图像对来训练模型,只需要在两个不同的文件夹中提供这两个类别的图像。在本节中,我们将学习如何训练 CycleGAN
将夏天的风景图像转换成冬天的风景图像,或反之将冬天的风景图像转换为夏天的风景图像,CycleGAN
中的 Cycle
是指将图像从一个类别转换到另一个类别,然后再转换回原始类别的过程。
为了实现图像转换,使用两个 GAN
,每个 GAN
的生成器执行从一个域到另一个域的图像转换。具体来说,假设输入是
X
X
X,那么第一个 GAN
的生成器执行映射
G
:
X
→
Y
G:X\rightarrow Y
G:X→Y,其输出为
Y
=
G
(
X
)
Y = G(X)
Y=G(X);第二个 GAN
的生成器执行逆映射
F
:
Y
→
X
F:Y\rightarrow X
F:Y→X,结果为
X
=
F
(
Y
)
X = F(Y)
X=F(Y)。每个判别器都训练用于区分真实图像和生成图像:
为了训练 CycleGAN
,除了传统的对抗损失外,还添加了循环一致性损失,用于确保给定图像
X
X
X 作为输入,那么经过两次转换
F
(
G
(
X
)
)
∼
X
F(G(X)) \sim X
F(G(X))∼X 后得到的图像与
X
X
X 相同,类似地,需要损失确保
G
(
F
(
Y
)
)
∼
Y
)
G(F(Y)) \sim Y)
G(F(Y))∼Y)。
总体而言,在 CycleGAN
中,需要使用三种不同的损失值:
- 鉴别器损失:用于区分真实图像和伪造图像
- 循环一致性损失:由于
CycleGAN
使用了两个生成器,因此需要确保转换是可逆的,循环一致性损失通过将转换过的图像再次传递到原始的生成器中,并将生成的图像与原始图像进行比较来实现 - 恒等损失 (
Identity loss
):确保生成器在不进行转换的情况下仍然能够生成与原始图像相似的图像,通过将原始图像传递到生成器中,并计算生成图像与原始图像之间的差异
2. CycleGAN 模型分析
CycleGAN
模型构建策略如下:
- 导入数据集并进行预处理
- 定义
UNet
架构用于构建生成器和判别器网络 - 定义两个生成器:
G_AB
:将类别A
图像转换为类别B
图像的生成器G_BA
:将类别B
图像转换为类别A
图像的生成器
- 定义恒等损失:
- 如果将一张橘子的图像输入到橙子生成器,理想情况下,如果生成器完全理解橙子的所有信息,它不应该改变图像,而应该“生成”完全相同的图像,据此,我们可以创建一个恒等变换
- 当类别
A
(real_A
) 的图像通过G_BA
并与real_A
进行比较时,恒等损失应该是最小的 - 当类别
B
(real_B
) 的图像通过G_AB
并与real_B
进行比较时,恒等损失应该是最小的
- 定义GAN损失:
real_A
和fake_A
的判别器和生成器损失(当real_B
图像通过G_BA
时得到fake_A
)real_B
和fake_B
的判别器和生成器损失(当real_A
图像通过G_AB
时得到fake_B
)
- 定义循环一致性损失:
- 一张苹果图像需要通过橙子生成网络进行转换,生成伪造的橘子图像,然后再通过苹果生成网络将伪造的橙子图像转换回苹果图像
fake_B
是real_A
通过G_AB
时的输出,当fake_B
通过G_BA
时应该重新生成real_A
fake_A
是real_B
通过G_BA
时的输出,当fake_A
通过G_AB
时应该重新生成real_B
- 优化三个损失函数的加权和
3. 实现 CycleGAN
在本节中,我们使用 TensorFlow
实现 CycleGAN
模型。
(1) 导入所需模块,并使用 tensorflow_datasets
加载数据集,并使用tensorflow_examples 库中预定义的 pix2pix 模型的生成器和鉴别器:
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
tensorflow_examples
包含一组适用于 CycleGAN
的数据集,如马-斑马、苹果-橘子等。在本节中,我们将使用 summer2winter_yosemite
数据集,包含了夏季图像和冬季图像,训练 CycleGAN
将输入的夏季图像转换为冬季图像,或反之将冬季图像转换为夏季图像。
(2) 加载数据,并获取训练和测试图像:
import os
import time
import matplotlib.pyplot as plt
import tensorflow as tf
from glob import glob
AUTOTUNE = tf.data.AUTOTUNE
train_summer = tf.data.Dataset.list_files('summer2winter_yosemite/trainA/*.jpg')
train_winter = tf.data.Dataset.list_files('summer2winter_yosemite/trainB/*.jpg')
test_summer = tf.data.Dataset.list_files('summer2winter_yosemite/testA/*.jpg')
test_winter = tf.data.Dataset.list_files('summer2winter_yosemite/testB/*.jpg')
(3) 设置超参数:
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
(4) 训练网络之前,对图像进行预处理,为了获得更好的性能,对训练图像添加随机抖动。执行归一化后,将图像调整为 286 x 286
,然后随机裁剪为 256 x 256
,最后应用随机抖动:
def random_crop(image):
cropped_image = tf.image.random_crop(
image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
image = tf.cast(image, tf.float32)
image = (image / 127.5) - 1
return image
def random_jitter(image):
# resizing to 286 x 286 x 3
image = tf.image.resize(image, [286, 286],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
# randomly cropping to 256 x 256 x 3
image = random_crop(image)
# random mirroring
image = tf.image.random_flip_left_right(image)
return image
def load(image):
image = tf.io.read_file(image)
image = tf.image.decode_jpeg(image)
input_image = tf.cast(image, tf.float32)
return input_image
(5) 数据增强(随机裁剪和抖动)仅对训练图像进行,因此需要分别定义训练数据和测试数据的图像预处理函数:
def preprocess_image_train(image):
image = load(image)
image = random_jitter(image)
image = normalize(image)
return image
def preprocess_image_test(image):
image = load(image)
image = normalize(image)
return image
(6) 将以上函数应用于图像时,会将其归一化到范围 [-1,1]
之间,并对训练图像进行数据增强。在训练和测试数据集上应用以上函数,并创建一个数据加载器,用于批量提供训练图像:
train_summer = train_summer.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
train_winter = train_winter.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
test_summer = test_summer.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
test_winter = test_winter.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
在以上代码中,参数 num_parallel_calls
用于指定需要利用系统中的 CPU
核心数量,可以将其值设置为系统中的 CPU
全部核心数。可以使用 AUTOTUNE = tf.data.AUTOTUNE
值,以便 TensorFlow
动态确定合适的 CPU
核心数量。
(7) 使用在 tensorflow_examples
模块中定义的 pix2pix
模型的生成器和鉴别器,定义两个生成器和两个鉴别器:
sample_summer = next(iter(train_summer))
sample_winter = next(iter(train_winter))
OUTPUT_CHANNELS = 3
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
(8) 查看示例图像,每张图像在绘制之前都会执行归一化处理:
to_winter = generator_g(sample_summer)
to_summer = generator_f(sample_winter)
plt.figure(figsize=(8, 8))
contrast = 8
imgs = [sample_summer, to_winter, sample_winter, to_summer]
title = ['Summer', 'To Winter', 'Winter', 'To Summer']
for i in range(len(imgs)):
plt.subplot(2, 2, i+1)
plt.title(title[i])
if i % 2 == 0:
plt.imshow(imgs[i][0] * 0.5 + 0.5)
else:
plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
(9) 定义损失函数和优化器,使用与 DCGAN
相同的生成器和鉴别器的损失函数:
LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
real_loss = loss_obj(tf.ones_like(real), real)
generated_loss = loss_obj(tf.zeros_like(generated), generated)
total_disc_loss = real_loss + generated_loss
return total_disc_loss * 0.5
def generator_loss(generated):
return loss_obj(tf.ones_like(generated), generated)
(10) 由于 CycleGAN
包含四个模型,两个生成器和两个鉴别器,因此需要定义四个优化器:
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
(11) 此外,在 CycleGAN
中,还需要定义两个额外的损失函数。首先是循环一致性损失,用于确保生成结果接近原始输入:
def calc_cycle_loss(real_image, cycled_image):
loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
return LAMBDA * loss1
另外还需要定义一个恒等损失,用于确保如果将图像 Y Y Y 输入生成器 G : X → Y G:X\rightarrow Y G:X→Y,它会输出类似于 Y Y Y 的图像。因此,如果给夏季图像生成器一个夏季的图像作为输入,它不应该对其进行过多修改:
def identity_loss(real_image, same_image):
loss = tf.reduce_mean(tf.abs(real_image - same_image))
return LAMBDA * 0.5 * loss
(12) 定义函数训练生成器和鉴别器。两个鉴别器和两个生成器将通过 tape
梯度进行训练。训练步骤可以分为 4
步:
- 从两个生成器中获取输出图像
- 计算损失
- 计算梯度
- 最后,应用梯度
@tf.function
def train_step(real_x, real_y):
# persistent is set to True because the tape is used more than
# once to calculate the gradients.
with tf.GradientTape(persistent=True) as tape:
# Generator G translates X -> Y
# Generator F translates Y -> X.
fake_y = generator_g(real_x, training=True)
cycled_x = generator_f(fake_y, training=True)
fake_x = generator_f(real_y, training=True)
cycled_y = generator_g(fake_x, training=True)
# same_x and same_y are used for identity loss.
same_x = generator_f(real_x, training=True)
same_y = generator_g(real_y, training=True)
disc_real_x = discriminator_x(real_x, training=True)
disc_real_y = discriminator_y(real_y, training=True)
disc_fake_x = discriminator_x(fake_x, training=True)
disc_fake_y = discriminator_y(fake_y, training=True)
# calculate the loss
gen_g_loss = generator_loss(disc_fake_y)
gen_f_loss = generator_loss(disc_fake_x)
total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
# Total generator loss = adversarial loss + cycle loss
total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)
disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
# Calculate the gradients for generator and discriminator
generator_g_gradients = tape.gradient(total_gen_g_loss,
generator_g.trainable_variables)
generator_f_gradients = tape.gradient(total_gen_f_loss,
generator_f.trainable_variables)
discriminator_x_gradients = tape.gradient(disc_x_loss,
discriminator_x.trainable_variables)
discriminator_y_gradients = tape.gradient(disc_y_loss,
discriminator_y.trainable_variables)
# Apply the gradients to the optimizer
generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
generator_g.trainable_variables))
generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
generator_f.trainable_variables))
discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
discriminator_x.trainable_variables))
discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
discriminator_y.trainable_variables))
(13) 训练网络 200
个epoch:
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
generator_f=generator_f,
discriminator_x=discriminator_x,
discriminator_y=discriminator_y,
generator_g_optimizer=generator_g_optimizer,
generator_f_optimizer=generator_f_optimizer,
discriminator_x_optimizer=discriminator_x_optimizer,
discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
定义检查点保存模型权重。由于训练一个优秀的 CycleGAN
可能需要大量时间,保存检查点能够用于确保模型从上次中断的地方继续学习,只需在下次开始时加载现有的检查点。
(14) 查看 CycleGAN
生成的图像。生成器 A
以夏季照片作为输入,将它们转换为冬季照片,而生成器 B
以冬季照片作为输入,将它们转换为夏季照片:
EPOCHS = 100
def generate_images(model, test_input):
prediction = model(test_input)
plt.figure(figsize=(12, 12))
display_list = [test_input[0], prediction[0]]
title = ['Input Image', 'Predicted Image']
for i in range(2):
plt.subplot(1, 2, i+1)
plt.title(title[i])
# getting the pixel values between [0, 1] to plot it.
plt.imshow(display_list[i] * 0.5 + 0.5)
plt.axis('off')
plt.savefig('image.png')
plt.show()
for epoch in range(EPOCHS):
start = time.time()
n = 0
for image_x, image_y in tf.data.Dataset.zip((train_summer, train_winter)):
train_step(image_x, image_y)
if n % 10 == 0:
print ('.', end='')
n += 1
# Using a consistent image (sample_horse) so that the progress of the model
# is clearly visible.
generate_images(generator_g, sample_summer)
if (epoch + 1) % 5 == 0:
ckpt_save_path = ckpt_manager.save()
print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
ckpt_save_path))
print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
time.time()-start))
to_winter = generator_g(sample_summer)
to_summer = generator_f(sample_winter)
plt.figure(figsize=(8, 8))
contrast = 8
imgs = [sample_summer, to_winter, sample_winter, to_summer]
title = ['Summer', 'To Winter', 'Winter', 'To Summer']
for i in range(len(imgs)):
plt.subplot(2, 2, i+1)
plt.title(title[i])
if i % 2 == 0:
plt.imshow(imgs[i][0] * 0.5 + 0.5)
else:
plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
for inp in test_summer.take(5):
generate_images(generator_g, inp)
可以尝试使用 TensorFlow CycleGAN
数据集中其他的数据集,如 apple2orange
数据集。
小结
CycleGAN
是一种用于无监督图像转换的深度学习模型,它通过两个生成器和两个判别器的组合来学习两个不同域之间的映射关系。生成器负责将一个域的图像转换成另一个域的图像,而判别器则用于区分生成的图像和真实的图像。CycleGAN
引入循环一致性损失,确保图像转换是可逆的,从而提高生成图像的质量。通过对抗训练和循环一致性损失,CycleGAN
可以实现在没有配对标签的情况下进行图像域转换。
系列链接
AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)
AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)
AIGC实战——自回归模型(Autoregressive Model)
AIGC实战——改进循环神经网络
AIGC实战——像素卷积神经网络(PixelCNN)
AIGC实战——归一化流模型(Normalizing Flow Model)
AIGC实战——能量模型(Energy-Based Model)
AIGC实战——扩散模型(Diffusion Model)
AIGC实战——GPT(Generative Pre-trained Transformer)
AIGC实战——Transformer模型
AIGC实战——ProGAN(Progressive Growing Generative Adversarial Network)
AIGC实战——StyleGAN(Style-Based Generative Adversarial Network)
AIGC实战——VQ-GAN(Vector Quantized Generative Adversarial Network)
AIGC实战——基于Transformer实现音乐生成
AIGC实战——MuseGAN详解与实现
AIGC实战——多模态模型DALL.E 2
AIGC实战——多模态模型Flamingo
AIGC实战——世界模型(World Model)
AIGC实战——生成式人工智能总结与展望