彻底改变时尚:使用 GAN 实现 AI 的未来
一、介绍
想象一下,在这个世界里,时装设计师永远不会用完新想法,我们穿的每一件衣服都是一件艺术品。听起来很有趣,对吧?好吧,我们可以在通用对抗网络 (GAN) 的帮助下在现实中实现这一目标。GAN模糊了现实与想象之间的界限。它就像一个瓶子里的精灵,满足了我们所有的创造性愿望。我们甚至可以在GAN的帮助下在地球上创造一个太阳,这在现实生活中是不可能的。
早在 2010 年代,Lan Goodfellow 和他的同事就引入了这个框架。他们实际上旨在解决无监督学习的挑战,即模型从未标记的数据中学习并生成新样本。GAN 凭借其制作引人入胜和栩栩如生的内容的能力彻底改变了许多行业,而时尚行业正在引领这一潜力。现在我们将探索 GAN 的潜力并了解它们是如何神奇地工作的。
学习目标
- 关于生成对抗网络 (GAN) 和 GAN 的工作原理。
- GAN在ML和AI领域的作用
- 我们还将看到使用GAN的一些挑战及其未来潜力
- 了解 GAN 的力量和潜力
- 最后,在MNIST时尚数据集上实现GANs
本文是作为数据科学博客马拉松的一部分发表的。
二、内容提要--生成对抗网络(GAN)
生成对抗网络是一类机器学习模型,用于生成新的现实数据。它可以生成高度逼真的图像、视频等等。它只包含两个神经网络:生成器和鉴别器。
2.1 发电机
生成器是一种卷积神经网络,它生成鉴别器无法区分的数据样本。在这里,generator学习如何从噪声中创建数据。它总是试图愚弄鉴别器。
2.2 鉴别器
鉴别器是一种反卷积神经网络,它试图正确分类生成器生成的真实样本和虚假样本。鉴别器获取生成器生成的真实和虚假数据,并学习将其与真实数据区分开来。鉴别器将给出一个介于 0 和 1 之间的分数作为生成图像的输出。这里 0 表示图像是假的,1 表示图像是真实的。
2.3 对抗性训练
训练过程包括生成虚假数据,鉴别器试图正确识别它。它涉及两个阶段:生成器训练和鉴别器训练。它还涉及优化生成器和鉴别器。生成器的目标是生成无法与真实数据区分开来的数据,而鉴别器的目标是识别真实和虚假数据。如果两个网络都正常工作,那么我们可以说模型是优化的。它们都是使用反向传播进行训练的。因此,每当发生错误时,它都会传播回来,并且它们将更新其权重。
2.4 GAN的训练通常包括以下步骤:
- 定义问题陈述
- 选择架构
- 在真实数据上训练鉴别器
- 为生成器生成虚假输入
- 在虚假数据上训练鉴别器
- 带有鉴别器输出的列车生成器
- 迭代和优化
来源: IBM Developer
2.5 损失函数
GAN中使用的损失函数由两个组件组成,因为我们在其架构中有两个网络。在这种情况下,生成器的损失取决于它生成鉴别器无法区分的真实数据的能力。它总是试图将鉴别器的能力降到最低。另一方面,鉴别器的损失取决于它对真实和虚假样本的分类能力。它试图最大限度地减少错误分类。
在训练期间,生成器和鉴别器交替更新。在这里,两者都试图将损失降到最低。生成器试图通过为鉴别器生成更好的样本来减少其损失,而鉴别器则试图通过准确分类假样本和真实样本来减少其损失。此过程一直持续到GAN达到所需的收敛水平。
三、GAN在机器学习和人工智能中的作用
由于其生成新真实数据的能力,GAN在机器学习和人工智能领域变得越来越重要。这有多种应用,如视频生成、图像生成、文本到图像合成等。这些彻底改变了许多行业。让我们看看 GAN 在这个领域很重要的一些原因。
- 数据生成:我们知道,数据是构建模型最重要的东西。我们需要大量的数据集来训练和构建更好的模型。有时数据稀缺,或者数据昂贵。在这种情况下,GAN可用于使用现有数据生成更多新数据。
- 数据隐私:有时我们需要使用数据来训练模型,但这可能会影响个人的隐私。在这种情况下,我们可以使用 GAN 创建与原始数据类似的数据,并训练模型以保护个人隐私。
- 逼真的模拟:这些可以创建对真实世界情况的准确模拟,并可用于创建机器学习模型。例如,由于在现实世界中测试机器人可能存在风险或昂贵,因此我们可以利用它们来测试机器人。
- 对抗性攻击: GAN 可用于创建对抗性攻击,以测试机器学习模型的鲁棒性。它有助于识别漏洞,有助于开发更好的模型,并提高安全性。
- 创意应用: GAN 可用于为 AI 生成创意应用程序。它们可用于创建游戏、音乐、艺术品、电影、动画、照片等等。此外,它还可以产生原创作品,如故事、诗歌等。
来源:无聊熊猫
随着对GANs的研究仍在继续,我们可以期待这项技术在未来创造更多的奇迹。
四、挑战和局限性
尽管 GAN 已经显示出它们生成真实和多样化数据的能力,但它仍然存在一些需要考虑的挑战和局限性。让我们看看GAN的一些挑战和局限性。
- GAN非常依赖于训练数据。生成的数据基于用于训练的数据。这些将生成类似于训练数据的数据。如果它的多样性有限,那么 GAN 也会生成多样性和质量有限的数据。
- 训练 GAN 很困难,因为它们对网络架构和所用超参数的选择高度敏感。这些容易导致训练不稳定,因为生成器和鉴别器可能会陷入相互欺骗的循环中。这会导致收敛性差,从而产生质量差的样本。
- 如果生成器非常擅长区分真假样本,那么生成器将能够生成可以欺骗鉴别器进行区分的样本。这导致产生彼此高度相似的样本,并且它将能够生成涵盖数据集中所有可能性的样本。
- 训练 GAN 也很昂贵。训练 GAN 的计算成本可能很高,尤其是在处理大型数据集和复杂架构时。
- GANs最令人担忧的挑战之一是创建逼真的虚假数据对社会的影响。这可能会导致隐私问题、偏见或滥用。例如,这些可能会生成虚假图像或视频,从而导致错误信息和欺诈。
五、未来潜力
尽管它有一些挑战和局限性,但 GAN 有一个潜在的光明未来。包括医疗保健、金融和娱乐在内的众多行业预计将因 GAN 而经历一场革命。
- 其潜在发展之一将是生成医学。它可以为他们生成个性化的医学图像和治疗计划。在这些GAN的帮助下,即使是医生也可以通过开发更有效的治疗方法来更好地治疗患者。
- 它可以用来创建虚拟现实环境。这些非常逼真,有很多应用,比如娱乐。
- 使用 GAN,我们可以创建更逼真的模拟环境,用于测试自动驾驶汽车。这样我们就可以开发更安全、更有效的自动驾驶汽车。
- 这些不仅限于与图像相关的任务。它们也可用于自然语言处理(NLP)任务。其中包括文本生成、翻译等等。他们可以生成与上下文相关的文本,这是构建虚拟助手和聊天机器人的必要条件。
- 这对建筑师非常有帮助。它可以为建筑物或任何其他结构生成新的设计。这非常有助于建筑师和设计师创造更具创新性的设计。
- 它也可以用于科学研究,因为它可以生成可以模仿现实世界现象的数据。他们可以创建合成数据,用于科学研究中的测试和验证,帮助药物开发和分子设计,并模拟复杂的物理过程。
- GANs也可用于犯罪调查。例如,我们可以使用嫌疑人的身份创建嫌疑人的图像。这样可以更快、更成功地进行调查。
六、时尚MNIST数据集
它是机器学习中用于各种目的的流行数据集。它是原始 MNIST 数据集的替代品,该数据集包含从 0 到 9 的数字。在我们的时尚MNIST数据集中,我们有各种时尚单品的图像,而不是数字。该数据集包含70000张图像,其中60000张为训练图像,10000张为测试图像。它们中的每一个都是灰度的,像素为 28 x 28 像素。时尚 MNIST 数据集包含 10 类时尚商品。他们是:
- 体恤衫
- 连衣裙
- 外套
- 套衫
- 衬衫
- 裤子
- 袋
- 凉鞋
- 运动鞋
- 踝靴
资料来源:ResearchGate
最初,创建此数据集是为了开发用于分类的机器学习模型。该数据集甚至被用作评估许多机器学习算法的基准。该数据集易于访问,可以从各种来源下载,包括 Tensorflow 和 PyTorch 库。与原始数字MINIST数据集相比,它更具挑战性。模特必须能够区分可能具有相似形状或图案的各种时尚产品。这使得它适用于测试各种算法的鲁棒性。
七、GAN在时尚行业的应用
由于 GAN,时尚行业经历了巨大的转变,它使创造力和变革成为可能。我们设计、生产和体验时尚的方式已经被 GAN 彻底改变了。让我们看看通用对抗网络(GAN)在时尚行业中的一些实际应用。
- 时装设计与生成:GAN能够产生新的设计和新的时尚概念。这有助于设计师创造创新和有吸引力的风格。使用 GAN 可以探索各种组合、图案和颜色。例如,服装店 H&M 使用 GAN 为其产品开发新鲜服装。
- 虚拟试妆:虚拟试妆是一个虚拟试用室。在这种情况下,GAN可以生成更逼真的客户服装图像。因此,客户实际上可以知道他们穿着这些衣服的样子,而无需实际穿着它们。
来源: Augray
- 时尚预测:GAN 也用于预测。他们可以在未来产生时尚趋势。这有助于时尚品牌产生新的风格并紧跟潮流。
- 织物和纹理合成:GANs通过虚拟试验各种材料和图案,帮助设计师生成高分辨率的织物纹理,而无需实际试验它们。这有助于节省大量时间和资源,也有助于创新设计流程。
八、时尚 MNIST 数据集的实现
我们现在将使用生成对抗网络 (GAN) 使用 MNIST 时尚数据集生成时尚样本。首先导入所有必要的库。
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys
我们必须加载数据集。在这里,我们使用的是时尚MNIST数据集。这是 tensorflow 中的内置数据集。因此,我们可以使用 tensorflow keras 直接加载它。该数据集基本上用于分类任务。如前所述,它具有 28 x 28 像素的灰度图像。我们只需要一组训练数据。因此,我们将它分为训练数据集和测试数据集,并仅加载训练集。
然后,加载的数据在 -1 和 1 之间归一化。我们通常会进行归一化,以提高深度学习模型在训练过程中的稳定性和收敛性。这是大多数深度学习任务中的常见步骤。最后,我们将向数据数组添加一个额外的维度。因为我们需要匹配生成器的预期输入形状。生成器需要一个 4D 张量。它表示批处理大小、高度、宽度和通道数。
# Load fashion dataset
(X_train, _), (_, _) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)
设置发生器和鉴别器的尺寸。这里gen_input_dim是生成器输入的大小,在下一行中,定义生成器生成的图像的形状。这里是 28 x 28 和灰度,因为我们只提供一个通道。
gen_input_dim = 100
img_shape = (28, 28, 1)
8.1 定义生成器模型
现在我们将定义生成器模型。它只需要一个参数,那就是输入维度。它使用 keras 顺序 API 来构建模型。它有三个完全连接的层,具有 LeakyReLU 激活函数和批量归一化。在最后一层中,它使用tanh激活函数来生成最终的输出图像。最后,它返回一个 keras 模型对象,该对象将噪声向量作为输入,并给出生成的图像作为输出。
def build_generator(input_dim):
model = Sequential()
model.add(Dense(256, input_dim=input_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(img_shape), activation='tanh'))
model.add(Reshape(img_shape))
noise = Input(shape=(input_dim,))
img = model(noise)
return Model(noise, img)
8.2 定义判别器模型
下一步是构建鉴别器。它几乎与发生器模型相似,但在这里它只有两个完全连接的层,并且最后一层具有 sigmoid 激活函数。它通过将噪声向量作为输入返回模型对象作为输出,并输出图像为实数的概率。
def build_discriminator(img_shape):
model = Sequential()
model.add(Flatten(input_shape=img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
img = Input(shape=img_shape)
validity = model(img)
return Model(img, validity)
8.3 编译模型
现在我们必须编译它们。我们使用二进制交叉熵损失和 Adam 优化器来编译判别器和生成器。我们将学习率设置为 0.0002,将衰减率设置为 0.5。使用二元交叉熵损失函数构建和编译鉴别器模型,该函数常用于二元分类任务。还定义了准确性指标来评估鉴别器。
同样,构建了一个生成器模型,用于创建生成器的体系结构。在这里,我们不会像为鉴别器那样编译生成器。它将以对抗性的方式对歧视者进行训练。z 是表示发生器随机噪声的输入层。生成器将 z 作为输入并生成 img 作为输出。在组合模型的训练过程中,鉴别器的权重被冻结。生成器的输出将被馈送到鉴别器并生成有效性,从而测量生成图像的质量。然后,使用 z 作为输入,使用 validity 作为输出来创建组合模型。这用于训练生成器。
optimizer = Adam(0.0002, 0.5)
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
generator = build_generator(gen_input_dim)
z = Input(shape=(gen_input_dim,))
img = generator(z)
discriminator.trainable = False
validity = discriminator(img)
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
8.4 训练
是时候训练我们的 GAN 了。我们知道它运行周期的迭代次数。在每次迭代中,从训练集中获取一批随机图像,生成器通过传递噪声生成一批假图像。
鉴别器在真实图像和虚假图像上进行训练。并计算平均损失。对发生器进行噪声训练并计算损耗。在这里,我们将sample_interval定义为 1000。因此,每 1000 次迭代,就会打印损失。
# Train GAN
epochs = 5000
batch_size = 32
sample_interval = 1000
d_losses = []
g_losses = []
for epoch in range(epochs):
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[idx]
# Train discriminator
noise = np.random.normal(0, 1, (batch_size, gen_input_dim))
fake_images = generator.predict(noise)
d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
d_losses.append(d_loss[0])
# Train generator
noise = np.random.normal(0, 1, (batch_size, gen_input_dim))
g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))
g_losses.append(g_loss)
# Print progress
if epoch % sample_interval == 0:
print(f"Epoch {epoch}, Discriminator loss: {d_loss[0]}, Generator loss: {g_loss}")
8.5 生成示例图像
现在让我们看看一些生成的样本。在这里,我们绘制了一个包含这些样本的 5 行和 10 列的网格。这是使用 matplotlib 创建的。这些生成的样本类似于我们用于训练的数据集。我们可以通过训练更多纪元来生成更高质量的样本。
# Generate sample images
r, c = 5,10
noise = np.random.normal(0, 1, (r * c, gen_input_dim))
gen_imgs = generator.predict(noise)
# Rescale images 0 - 1
gen_imgs = 0.5 * gen_imgs + 0.5
# Plot images
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
plt.show()
九、结论
生成对抗网络 (GAN) 因其独特的架构、训练过程和生成数据的能力而成为许多应用程序最受欢迎的选择。与任何技术一样,GAN也有一些挑战和局限性。研究人员正在努力将它们最小化,并渴望更好的GAN。总的来说,我们已经学习并理解了GAN的力量和潜力及其工作。我们还构建了一个 GAN,以使用时尚 MNIST 数据集生成时尚样本。