导 读
Wasserstein 生成对抗网络 (WGAN) 作为一项关键创新而出现,解决了经常困扰传统生成对抗网络 (GAN) 的稳定性和收敛性的基本挑战。
由 Arjovsky 等人于2017 年提出,WGAN 通过利用 Wasserstein 距离彻底改变了生成模型的训练,提供了一个强大的框架,可以提高生成样本的质量和多样性。
本文深入探讨了 WGAN 的概念基础、优势和实际意义,说明了它们在更广泛的生成建模背景下的重要性。
有需要的朋友关注公众号【小Z的科研日常】,获取更多内容。
01、WGAN的概念框架
WGAN 与其前辈的区别在于用 Wasserstein 距离代替 Jensen-Shannon 散度作为其损失函数。
瓦瑟斯坦距离,直观地理解为推土机距离,量化了将一种概率分布转换为另一种概率分布所需的最小成本。
该指标赋予 WGAN 在训练过程中更平滑、更可靠的梯度信号,即使在真实数据分布和生成数据分布不重叠的情况下,也有助于生成更高质量的样本。
与传统 GAN 的一个重要区别是取代了判别器。与将输入分类为真实或虚假的判别器不同,WGAN 框架中的批评者评估真实样本和生成样本的分布之间的 Wasserstein 距离。
这种从分类到估计的转变标志着生成模型处理学习过程的方式发生了根本性变化,从而实现了更细致、更有效的训练动态。
02、相比于传统GAN的优势与挑战
WGAN 提供了几个引人注目的优势,可以解决传统 GAN 框架的局限性。
首先,它们表现出改进的训练稳定性,降低了对超参数设置和架构选择的敏感性。这种稳定性源于 Wasserstein 距离的特性,即使真实分布和生成分布之间没有重叠,它也能提供有用的梯度信息——这是一个可能阻碍传统 GAN 训练的常见问题。
此外,WGAN 还缓解了模式崩溃问题,即生成器学习产生有限范围的输出,从而无法捕获真实数据分布的多样性的现象。Wasserstein 距离的连续且更有意义的损失景观鼓励生成器探索更广泛的输出,从而增强生成样本的多样性。
WGAN 中损失度量的可解释性也代表了重大进步。与传统 GAN(判别器的准确性不一定与生成样本的质量相关)不同,WGAN 中的批评者损失提供了更直接的收敛性衡量标准,为训练过程和生成数据的质量提供了有价值的见解。
尽管有其优点,WGAN 也带来了新的挑战,主要与计算效率有关。WGAN 的最初实现需要权重裁剪来强制执行 Lipschitz 约束,这对于 Wasserstein 距离的理论属性至关重要。
然而,权重裁剪可能会导致优化困难和容量利用率不足。为了解决这个问题,引入带有梯度惩罚的 WGAN (WGAN-GP) 提出了一种替代方法来强制实施 Lipschitz 约束,而无需进行权重裁剪,从而提高训练稳定性和模型性能。
03、代码
为 Wasserstein 生成对抗网络 (WGAN) 创建完整的代码示例涉及几个步骤,包括定义生成器和批评者的模型架构、准备合成数据集、训练模型以及通过指标和图评估性能。
此示例将说明使用 TensorFlow 和 Keras 的基本实现,并使用简单的合成数据集以便于理解。
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
def build_critic():
model = keras.Sequential([
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
layers.LeakyReLU(alpha=0.2),
layers.GlobalMaxPooling2D(),
layers.Dense(1),
])
return model
def build_generator(latent_dim):
model = keras.Sequential([
keras.Input(shape=(latent_dim,)),
layers.Dense(7 * 7 * 128),
layers.Reshape((7, 7, 128)),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
layers.LeakyReLU(alpha=0.2),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(1, (7, 7), padding='same', activation='sigmoid'),
])
return model
class WGAN(keras.Model):
def __init__(self, critic, generator, latent_dim):
super(WGAN, self).__init__()
self.critic = critic
self.generator = generator
self.latent_dim = latent_dim
self.d_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
self.g_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
self.critic_loss_tracker = keras.metrics.Mean(name="critic_loss")
self.generator_loss_tracker = keras.metrics.Mean(name="generator_loss")
@property
def metrics(self):
return [self.critic_loss_tracker, self.generator_loss_tracker]
def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
super(WGAN, self).compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.d_loss_fn = d_loss_fn
self.g_loss_fn = g_loss_fn
def train_step(self, real_images):
if isinstance(real_images, tuple):
real_images = real_images[0]
# 在潜在空间中随机取样
batch_size = tf.shape(real_images)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# 将它们解码为假图像
generated_images = self.generator(random_latent_vectors)
# 将它们与真实图像相结合
combined_images = tf.concat([generated_images, real_images], axis=0)
# 组合标签,辨别真假图像
labels = tf.concat(
[tf.ones((batch_size, 1)), -tf.ones((batch_size, 1))], axis=0
)
# 在标签中添加随机噪音--重要技巧!
labels += 0.05 * tf.random.uniform(tf.shape(labels))
# 训练批评家
with tf.GradientTape() as tape:
predictions = self.critic(combined_images)
d_loss = self.d_loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.critic.trainable_variables)
self.d_optimizer.apply_gradients(
zip(grads, self.critic.trainable_variables)
)
# 在潜在空间中随机取样
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# 组装 "所有真实图像 "的标签
misleading_labels = -tf.ones((batch_size, 1))
# 训练生成器(通过评论家模型)
with tf.GradientTape() as tape:
predictions = self.critic(self.generator(random_latent_vectors))
g_loss = self.g_loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_variables)
self.g_optimizer.apply_gradients(
zip(grads, self.generator.trainable_variables)
)
# 更新指标
self.critic_loss_tracker.update_state(d_loss)
self.generator_loss_tracker.update_state(g_loss)
return {
"critic_loss": self.critic_loss_tracker.result(),
"generator_loss": self.generator_loss_tracker.result(),
}
latent_dim = 128
# 准备数据集
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_train = np.expand_dims(x_train, axis=-1)
# 实例化批评者和生成器模型
critic = build_critic()
generator = build_generator(latent_dim)
# 实例化 WGAN 模型
wgan = WGAN(critic=critic, generator=generator, latent_dim=latent_dim)
# 编译 WGAN 模型
wgan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
d_loss_fn=keras.losses.MeanSquaredError(),
g_loss_fn=keras.losses.MeanSquaredError(),
)
wgan.fit(x_train, batch_size=32, epochs=100)
def generate_and_save_images(model, epoch, test_input):
predictions = model.generator(test_input, training=False)
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i + 1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
# 生成潜在点
random_latent_vectors = tf.random.normal(shape=(16, latent_dim))
generate_and_save_images(wgan, 0, random_latent_vectors)
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5405 - generator_loss: 2.4530
Epoch 99/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5408 - generator_loss: 2.4463
Epoch 100/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5384 - generator_loss: 2.4411
此代码提供了使用简单数据集通过 TensorFlow 和 Keras 实现 WGAN 的基础框架。对于实际应用程序,您可能需要调整数据集、架构和训练参数以满足您的特定需求。
04、结论
Wasserstein 生成对抗网络代表了生成建模领域的重大飞跃。通过将 Wasserstein 距离集成到 GAN 框架中,WGAN 为训练生成模型提供了更稳定、可靠和可解释的方法。
尽管存在与计算需求和 Lipschitz 约束的执行相关的挑战,但 WGAN 及其后续迭代(如 WGAN-GP)所带来的进步继续影响着生成模型的发展。
随着该领域研究的进展,WGAN 有望进一步释放生成模型在从图像合成到自然语言生成等众多应用中的潜力,预示着人工智能驱动的创造力和创新的新时代。