条件生成对抗网络
1.生成对抗网络
生成对网络由两个“对抗性”模型组成:一个生成模型 G,用于捕获数据分布,另一个判别模型 D,用于估计样本来自训练数据而不是 G 的概率。G 和 D 都可以是非线性映射函数。
为了学习数据 x 上的生成器分布 Pg,生成器构建从先验噪声分布 pz(z) 到数据空间的映射函数 G(z; θg)。判别器 D(x; θd) 输出一个标量,表示 x 来自训练数据而不是 pg 的概率。
G 和 D 都是同时训练的:我们调整 G 的参数以最小化 log(1 − D(G(z)) 并调整 D 的参数以最小化 logD(X),就好像它们遵循两人的最小-最大一样价值函数 V (G, D) 的博弈:
G(Generator) -> 生成模块
D (Discriminator) -> 鉴别模块(输出就结果可以是二进制也可以是一维的置信度)
2.条件生成对抗网络
如果生成器和判别器都以一些额外的信息 y 为条件,则生成对抗网络可以扩展到条件模型。y 可以是任何类型的辅助信息,例如类标签或来自其他模态的数据。我们可以通过将 y 作为额外的输入层输入到判别器和生成器中来执行调节。
在生成器中,先验输入噪声 pz(z) 和 y 被组合在联合隐藏表示中,并且对抗性训练框架允许在如何组成该隐藏表示方面具有相当大的灵活性。
在判别器中,x 和 y 作为输入呈现给判别函数(在本例中再次由 MLP 体现)。两人迷你最大游戏的目标函数为:
3.判别器损失函数
判别器(Discriminator)
判别器的目标是区分生成器生成的假数据和真实数据。它接受来自生成器的输出或真实数据集的样本作为输入,并输出一个概率值,表示输入样本是真实数据的概率。
生成器(Generator)
生成器(Generator)的损失函数是它在对抗过程中试图最小化的目标。生成器的目标是产生尽可能接近真实数据分布的假数据,以便判别器(Discriminator)难以区分真假数据。
训练过程
- 初始化:生成器和判别器的参数随机初始化。
- 对抗训练:
生成器生成假数据。
判别器尝试区分真假数据。
判别器的损失函数是它对真实数据和生成数据的预测误差的总和。
生成器的损失函数是它欺骗判别器的成功率,即判别器错误地将生成数据识别为真实数据的概率。
- 参数更新:
判别器根据损失函数更新参数,以更好地区分真假数据。
生成器根据损失函数更新参数,以生成更逼真的数据,以欺骗判别器。
代码实现
#以fashionMNist
# 损失函数
def d_loss_fn(r_logit, f_logit):
r_loss = torch.nn.functional.binary_cross_entropy_with_logits(r_logit, torch.ones_like(r_logit))
f_loss = torch.nn.functional.binary_cross_entropy_with_logits(f_logit, torch.zeros_like(f_logit))
return r_loss, f_loss
def g_loss_fn(f_logit):
f_loss = torch.nn.functional.binary_cross_entropy_with_logits(f_logit, torch.ones_like(f_logit))
return f_loss
# 生成模型
class GeneratorCGAN(nn.Module):
def __init__(self, z_dim, c_dim, dim=128):
super(GeneratorCGAN, self).__init__()
def dconv_bn_relu(in_dim, out_dim, kernel_size=4, stride=2, padding=1, output_padding=0):
return nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, output_padding),
nn.BatchNorm2d(out_dim),
nn.ReLU()
)
self.ls = nn.Sequential(
dconv_bn_relu(z_dim + c_dim, dim * 4, 4, 1, 0, 0), # (N, dim * 4, 4, 4)
dconv_bn_relu(dim * 4, dim * 2), # (N, dim * 2, 8, 8)
dconv_bn_relu(dim * 2, dim), # (N, dim, 16, 16)
nn.ConvTranspose2d(dim, 3, 4, 2, padding=1), nn.Tanh() # (N, 3, 32, 32)
)
def forward(self, z, c):
# z: (N, z_dim), c: (N, c_dim) ->[64, 110]
x = torch.cat([z, c], 1)
# [64, 110] -> [64, 3, 32, 32]
x = self.ls(x.view(x.size(0), x.size(1), 1, 1))
# print(x.shape)
# 输出生成的图像结果
return x
class DiscriminatorCGAN(nn.Module):
def __init__(self, x_dim, c_dim, dim=96, norm='none', weight_norm='spectral_norm'):
super(DiscriminatorCGAN, self).__init__()
norm_fn = _get_norm_fn_2d(norm)
weight_norm_fn = _get_weight_norm_fn(weight_norm)
def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
return nn.Sequential(
weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
norm_fn(out_dim),
nn.LeakyReLU(0.2)
)
self.ls = nn.Sequential( # (N, x_dim+c_dim, 32, 32)
conv_norm_lrelu(x_dim + c_dim, dim),
conv_norm_lrelu(dim, dim),
conv_norm_lrelu(dim, dim, stride=2), # (N, dim , 16, 16)
conv_norm_lrelu(dim, dim * 2),
conv_norm_lrelu(dim * 2, dim * 2),
conv_norm_lrelu(dim * 2, dim * 2, stride=2), # (N, dim*2, 8, 8)
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0), # (N, dim*2, 6, 6)
nn.AvgPool2d(kernel_size=6), # (N, dim*2, 1, 1)
torchlib.Reshape(-1, dim * 2), # (N, dim*2)
weight_norm_fn(nn.Linear(dim * 2, 1)) # (N, 1)
)
def forward(self, x, c):
# x: (N, x_dim, 32, 32), c: (N, c_dim)
# [64, 10] -> [64, 10, 32, 32]
c = c.view(c.size(0), c.size(1), 1, 1) * torch.ones([c.size(0), c.size(1), x.size(2), x.size(3)], dtype=c.dtype, device=c.device)
# 常规损失函数 [64, 10, 32, 32] ->[64, 1]
logit = self.ls(torch.cat([x, c], 1))
# 输出置信度
return logit
# model:鉴别器输入维度3:三通道图像,输出维度10:对应类别
D = DiscriminatorCGAN(x_dim=3, c_dim=c_dim)
# 生成器模型:编码维度,输出维度10:对应类别
G = GeneratorCGAN(z_dim=z_dim, c_dim=c_dim)
训练架构
# 训练鉴别器模型输入与输出
# 图像
x = x.to(device)
# 对应类别
c_dense = c_dense.to(device)
# 随机图像
z = torch.randn(batch_size, z_dim).to(device)
# 条件标签
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
# 随机数与条件输入生成器生成伪图像
x_f = G(z, c).detach()
# 原始图像与条件输入鉴别器计算标签图像分数
x_gan_logit = D(x, c) # [batchsize,1]
# 输入伪图像与条件计算伪图像分数
x_f_gan_logit = D(x_f, c) # [batchsize,1]
_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
# 训练生成器模型输入与输出
z = torch.randn(batch_size, z_dim).to(device)
# 生成器中计算损失函数
x_f = G(z, c)
x_f_gan_logit = D(x_f, c)
g_gan_loss = g_loss_fn(x_f_gan_logit)