一、基本概念
生成对抗网络(Generative Adversarial Network,GAN)是一种由两个神经网络共同组成深度学习模型:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗的方式进行训练,生成器尝试伪造逼真的样本数据,而判别器则负责判断输入的数据是真实数据还是生成器伪造出来的数据。理想情况下,判别器对真实样本和生成样本的判断概率都是1/2,意味着判别器已经无法判断生成器生成的数据真假。
二、模型原理
GAN的模型原理并不复杂。首先,GAN由以下两个子模型组成:
- 生成器(Generator):从随机噪声中生成数据,目标是欺骗判别器,使其认为生成的数据是真实的。
- 判别器(Discriminator):判断输入数据是来自真实数据分布还是生成器,目标是正确区分真实数据和生成数据。
然后,GAN的损失函数是训练的核心,我们需要构建一个合适的损失函数用于衡量生成器和判别器的表现:
- 生成器损失(G_loss):通常表示为最大化判别器对其生成样本的错误分类概率,也就是判别器判定所有生成数据均为真。
- 判别器损失(D_loss):由两部分组成,一部分是真实样本的损失(标签为1),另一部分是生成样本的损失(标签为0)。
最后,我们通过算法设计来交替训练生成器和判别器,例如生成器每训练5个Epoch,我们就训练一次判别器:
- 训练判别器:提高其区分真实样本和生成样本的能力。
- 训练生成器:提高其生成真实样本的能力,目标是最大化判别器将其生成样本识别为真实样本的概率。
三、python实现
1、导库
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
2、数据处理
这里我们的目标是训练一个生成对抗网络来生成iris数据,使用sklearn的iris数据集训练。这意味着,我们输入给生成器的信息中需要包含类别信息,这样生成器才能生成对应类别的数据样本。当然,这一步不是必要的,在类别不敏感的任务中,只需要生成符合要求的数据即可。
# 加载Iris数据集
iris = load_iris()
data = iris.data
labels = iris.target
# 标准化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)
# One-hot编码标签
encoder = OneHotEncoder(sparse=False)
# torch.Size([100, 3])
labels = encoder.fit_transform(labels.reshape(-1, 1))
# 转换为PyTorch张量
data = torch.FloatTensor(data)
labels = torch.FloatTensor(labels)
# 创建数据加载器
batch_size = 32
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
3、构建生成器
这里,我们构建一个全连接神经网络。生成器的输入包括随机初始化的x,以及x对应的期望类别,期望类别是可以真实标签,表示生成对应类别下的数据样本。
# 生成器网络
class Generator(nn.Module):
def __init__(self, input_dim, label_dim, output_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim + label_dim, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, output_dim),
)
def forward(self, x, labels):
x = torch.cat([x, labels], 1)
return self.model(x)
4、构建判别器
这里,我们的判别器实际上是一个二分类模型。判别器的输入维度跟生成器一直,都需要考虑类别信息。
# 判别器网络
class Discriminator(nn.Module):
def __init__(self, input_dim, label_dim):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim + label_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x, labels):
x = torch.cat([x, labels], 1)
return self.model(x)
5、超参数设置
值得注意的是,我们分别为生成器和判别器构造一个优化器,从而便于分开训练两个子模型。
# 设置超参数
latent_dim = 100
data_dim = data.shape[1]
label_dim = labels.shape[1]
lr = 0.0002
num_epochs = 200
# 初始化生成器和判别器
generator = Generator(latent_dim, label_dim, data_dim)
discriminator = Discriminator(data_dim, label_dim)
# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
# 损失函数
criterion = nn.BCELoss()
6、模型训练
这里,我们选择了分开训练生成器和判别器,在一个epoch中,先训练3次生成器,再训练一次判别器。这样的目的是增加生成器的学习时间,从而使得生成的样本更为真实。
# 训练GAN
for epoch in range(num_epochs):
for i, (real_data, real_labels) in enumerate(dataloader):
batch_size = real_data.size(0)
# 当前仅训练生成器
generator.train()
discriminator.eval()
# 迭代训练生成器,这里是每个epoch训练3次
for _ in range(3):
z = torch.randn(batch_size, latent_dim)
# 直接使用真实标签即可,这里的标签代表的是样本类别,目的是让模型学习到类别差异
# 生成器生成的是各对应类别的数据
fake_data = generator(z, real_labels)
# 使用判别器对生成的假数据进行分类
outputs = discriminator(fake_data, real_labels)
# 基于判别器的结果计算生成器的损失,目标是让判别器认为生成的数据是真实的(标签为1)
# 如果这里使用的是torch.zeros则生成器的结果将会非常差,几乎无法生成真实数据
# 这是由于我们的目标是让outputs逼近全1向量,也就是让判别器认为所有生成的数据都是真实的,这样才能让生成样本越来越真实
g_loss = criterion(outputs, torch.ones(batch_size, 1))
# 反向传播生成器的梯度
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
# 当前仅训练判别器
generator.eval()
discriminator.train()
# 训练判别器,真实样本标签为1,生成样本标签为0
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)
# 真实数据损失
outputs = discriminator(real_data, real_labels)
d_loss_real = criterion(outputs, real_targets)
real_score = outputs
# 生成假数据,计算损失
z = torch.randn(batch_size, latent_dim)
fake_data = generator(z, real_labels)
outputs = discriminator(fake_data.detach(), real_labels)
# 这里的目标与上面生成器部分相反,我们是要让outputs逼近全0向量,也就是全部预测出假数据
# 所以fake_targets是一个全0向量
d_loss_fake = criterion(outputs, fake_targets)
fake_score = outputs
# 总的判别器损失
d_loss = d_loss_real + d_loss_fake
# 反向传播判别器的梯度
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()
if epoch%10==0:
# 打印损失
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')
7、生成新数据
最后,我们使用训练好的GAN中的生成器来生成一批新数据。可以看到,效果不错。
# 生成新数据
num_samples = 100
z = torch.randn(num_samples, latent_dim)
labels = np.array([0, 1, 2] * (num_samples // 3) + [0] * (num_samples % 3))
labels = encoder.transform(labels.reshape(-1, 1))
labels = torch.FloatTensor(labels)
generated_data = generator(z, labels).detach().numpy()
# 降维
pca = PCA(n_components=2)
data_2d = pca.fit_transform(data)
generated_data_2d = pca.transform(generated_data)
# 可视化生成的数据
plt.figure(figsize=(10, 5))
for i in range(3):
real_class_data = data_2d[iris.target == i]
generated_class_data = generated_data_2d[np.argmax(labels.numpy(), axis=1) == i]
plt.scatter(real_class_data[:, 0], real_class_data[:, 1], label=f'Real Class {i}')
plt.scatter(generated_class_data[:, 0], generated_class_data[:, 1], label=f'Generated Class {i}')
plt.legend()
plt.show()
四、完整代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
# 加载Iris数据集
iris = load_iris()
data = iris.data
labels = iris.target
# 标准化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)
# One-hot编码标签
encoder = OneHotEncoder(sparse=False)
labels = encoder.fit_transform(labels.reshape(-1, 1))
# 转换为PyTorch张量
data = torch.FloatTensor(data)
labels = torch.FloatTensor(labels)
# 创建数据加载器
batch_size = 32
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 生成器网络
class Generator(nn.Module):
def __init__(self, input_dim, label_dim, output_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim + label_dim, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, output_dim),
)
def forward(self, x, labels):
x = torch.cat([x, labels], 1)
return self.model(x)
# 判别器网络
class Discriminator(nn.Module):
def __init__(self, input_dim, label_dim):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim + label_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x, labels):
x = torch.cat([x, labels], 1)
return self.model(x)
# 设置超参数
latent_dim = 100
data_dim = data.shape[1]
label_dim = labels.shape[1]
lr = 0.0002
num_epochs = 200
# 初始化生成器和判别器
generator = Generator(latent_dim, label_dim, data_dim)
discriminator = Discriminator(data_dim, label_dim)
# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
# 损失函数
criterion = nn.BCELoss()
# 训练GAN
for epoch in range(num_epochs):
for i, (real_data, real_labels) in enumerate(dataloader):
batch_size = real_data.size(0)
generator.train()
discriminator.eval()
# 迭代训练生成器,这里是每个epoch训练3次
for _ in range(3):
z = torch.randn(batch_size, latent_dim)
# 直接使用真实标签即可,这里的标签代表的是样本类别,目的是让模型学习到类别差异
# 生成器生成的是各对应类别的数据
fake_data = generator(z, real_labels)
# 使用判别器对生成的假数据进行分类
outputs = discriminator(fake_data, real_labels)
# 基于判别器的结果计算生成器的损失,目标是让判别器认为生成的数据是真实的(标签为1)
# 如果这里使用的是torch.zeros则生成器的结果将会非常差,几乎无法生成真实数据
# 这是由于我们的目标是让outputs逼近全1向量,也就是让判别器认为所有生成的数据都是真实的,这样才能让生成样本越来越真实
g_loss = criterion(outputs, torch.ones(batch_size, 1))
# 反向传播生成器的梯度
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
generator.eval()
discriminator.train()
# 训练判别器,真实样本标签为1,生成样本标签为0
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)
# 真实数据损失
outputs = discriminator(real_data, real_labels)
d_loss_real = criterion(outputs, real_targets)
real_score = outputs
# 生成假数据,计算损失
z = torch.randn(batch_size, latent_dim)
fake_data = generator(z, real_labels)
outputs = discriminator(fake_data.detach(), real_labels)
# 这里的目标与上面生成器部分相反,我们是要让outputs逼近全0向量,也就是全部预测出假数据
# 所以fake_targets是一个全0向量
d_loss_fake = criterion(outputs, fake_targets)
fake_score = outputs
# 总的判别器损失
d_loss = d_loss_real + d_loss_fake
# 反向传播判别器的梯度
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()
if epoch%10==0:
# 打印损失
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')
# 生成新数据
num_samples = 100
z = torch.randn(num_samples, latent_dim)
labels = np.array([0, 1, 2] * (num_samples // 3) + [0] * (num_samples % 3))
labels = encoder.transform(labels.reshape(-1, 1))
labels = torch.FloatTensor(labels)
generated_data = generator(z, labels).detach().numpy()
# 降维
pca = PCA(n_components=2)
data_2d = pca.fit_transform(data)
generated_data_2d = pca.transform(generated_data)
# 可视化生成的数据
plt.figure(figsize=(10, 5))
for i in range(3):
real_class_data = data_2d[iris.target == i]
generated_class_data = generated_data_2d[np.argmax(labels.numpy(), axis=1) == i]
plt.scatter(real_class_data[:, 0], real_class_data[:, 1], label=f'Real Class {i}')
plt.scatter(generated_class_data[:, 0], generated_class_data[:, 1], label=f'Generated Class {i}')
plt.legend()
plt.show()
五、总结
生成对抗网络是一个很经典的深度学习模型,它在诸多领域中发挥着重要作用。除了超参数调整之外,训练GAN的另一个关键步骤是构造一个合适的训练策略。例如,可以同时训练生成器和判别器,也可以交替训练二者,或者先训练生成器再训练判别器等等。但是,这两个网络是相互博弈的,由于生成器参数是随机初始化的,一开始生成的数据质量往往较差。我们的策略一般是先让生成器变强(通过构造更复杂的网络结构或者更多的训练次数),让生成的数据质量先提升。这样随着训练的迭代,生成的样本越来越逼真,判别器也不得不为了最小化D_loss而提升自身的能力。