WGAN-GP 原理及实现
- 一、WGAN-GP 原理
- 1.1 WGAN-GP 核心原理
- 1.2 WGAN-GP 实现步骤
- 1.3 总结
- 二、WGAN-GP 实现
- 2.1 导包
- 2.2 数据加载和处理
- 2.3 构建生成器
- 2.4 构建判别器
- 2.5 训练和保存模型
- 2.6 图片转GIF
一、WGAN-GP 原理
Wasserstein GAN with Gradient Penalty (WGAN-GP) 是对原始 WGAN 的改进,通过梯度惩罚(Gradient Penalty)
替代权重裁剪(Weight Clipping),解决了 WGAN 训练不稳定、权重裁剪导致梯度消失或爆炸的问题。
1.1 WGAN-GP 核心原理
(1) Wasserstein 距离(Earth-Mover 距离)
- 原始 GAN 的 JS 散度在分布不重叠时梯度消失,而 WGAN 使用 Wasserstein 距离衡量生成分布
P
g
P_g
Pg 和真实分布
P
r
P_r
Pr 的距离:
W ( P r , P g ) = inf γ ∼ Π ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(P_r, P_g) = \inf_{\gamma \sim \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim \gamma} [\|x-y\|] W(Pr,Pg)=infγ∼Π(Pr,Pg)E(x,y)∼γ[∥x−y∥] - 通过 Kantorovich-Rubinstein 对偶形式,转化为:
W ( P r , P g ) = sup ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] W(P_r, P_g) = \sup_{\|D\|_L \leq 1} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] W(Pr,Pg)=sup∥D∥L≤1Ex∼Pr[D(x)]−Ez∼Pz[D(G(z))],其中 D D D 是 1-Lipschitz 函数(梯度范数不超过 1)
(2) 梯度惩罚(Gradient Penalty)
- 原始 WGAN 的问题:通过权重裁剪强制判别器(Critic)满足 Lipschitz 约束,但会导致梯度不稳定或容量下降
- WGAN-GP 的改进:直接对判别器的梯度施加惩罚项,强制其梯度范数接近 1:
λ
⋅
E
x
^
∼
P
x
^
\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}
λ⋅Ex^∼Px^
[
(
∥
∇
x
^
D
(
x
^
)
∥
2
−
1
)
2
]
\left [(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]
[(∥∇x^D(x^)∥2−1)2]
- x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点: x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1−ϵ)G(z), ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵ∼U[0,1]
- λ \lambda λ 是惩罚系数(通常设为 10)
1.2 WGAN-GP 实现步骤
(1) 判别器(Critic)的损失函数
判别器的目标是最大化 Wasserstein 距离,同时满足梯度约束:
L
D
=
E
x
∼
P
r
[
D
(
x
)
]
−
E
z
∼
P
z
[
D
(
G
(
z
)
)
]
⏟
Wasserstein 距离
+
λ
⋅
E
x
^
∼
P
x
^
[
(
∥
∇
x
^
D
(
x
^
)
∥
2
−
1
)
2
]
⏟
梯度惩罚
L_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]}_{\text{梯度惩罚}}
LD=Wasserstein 距离
Ex∼Pr[D(x)]−Ez∼Pz[D(G(z))]+梯度惩罚
λ⋅Ex^∼Px^[(∥∇x^D(x^)∥2−1)2]
(2) 生成器(Generator)的损失函数
生成器的目标是最小化 Wasserstein 距离:
L
G
=
−
E
z
∼
P
z
[
D
(
G
(
z
)
)
]
L_G = -\mathbb{E}_{z \sim P_z}[D(G(z))]
LG=−Ez∼Pz[D(G(z))]
(3) 训练流程
- 输入:真实数据 x x x,噪声 z ∼ N ( 0 , 1 ) z \sim \mathcal{N}(0,1) z∼N(0,1)
- 生成数据: G ( z ) G(z) G(z)
- 插值采样: x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1−ϵ)G(z), ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵ∼U[0,1]
- 计算梯度惩罚:
- 对插值样本 x ^ \hat{x} x^ 计算判别器输出 D ( x ^ ) D(\hat{x}) D(x^)
- 求梯度 ∇ x ^ D ( x ^ ) \nabla_{\hat{x}} D(\hat{x}) ∇x^D(x^) 并计算惩罚项
- 更新判别器:最小化 L D L_D LD
- 更新生成器:最小化 L G L_G LG(每 n critic n_{\text{critic}} ncritic 次判别器更新后更新 1 次生成器)
1.3 总结
WGAN-GP 通过梯度惩罚替代权重裁剪,显著提升了 WGAN 的训练稳定性,是生成对抗网络的重要改进之一。实际应用中需注意:
- 判别器架构设计
- 梯度惩罚的正确实现
- 学习率和训练次数的调优
二、WGAN-GP 实现
2.1 导包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchsummary import summary
# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 指定存放日志路径
writer=SummaryWriter(log_dir="./runs/wgan_gp")
os.makedirs("./img/wgan_gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录
2.2 数据加载和处理
# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,28,28)):
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到[-1,1]
])
# 下载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)
return train_loader, test_loader
2.3 构建生成器
class Generator(nn.Module):
"""生成器"""
def __init__(self, latent_dim=100,img_shape=(1,28,28)):
super(Generator,self).__init__()
# 网络块
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat))
layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh() # 输出归一化到[-1,1]
)
def forward(self,z): # 噪声z,2维[batch_size,latent_dim]
gen_img=self.model(z)
gen_img=gen_img.view(gen_img.shape[0],*img_shape)
return gen_img # 4维[batch_size,1,H,W]
2.4 构建判别器
class Discriminator(nn.Module):
"""判别器"""
def __init__(self,img_shape=(1,28,28)):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(256, 1)
)
def forward(self,img): # 输入图片,4维[batc_size,1,H,W]
img=img.view(img.shape[0], -1)
pred = self.model(img)
return pred # 2维[batch_size,1]
2.5 训练和保存模型
-
WGAN-GP 算法流程
-
定义梯度惩罚函数
def compute_gradient_penalty(critic, real, fake, device):
batch_size = real.shape[0]
epsilon = torch.rand(batch_size, 1, 1, 1).to(device) # 随机插值系数
interpolates = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
critic_interpolates = critic(interpolates)
# 计算梯度
gradients = torch.autograd.grad(
outputs=critic_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(critic_interpolates),
create_graph=True,
retain_graph=True,
)[0]
gradients = gradients.view(gradients.shape[0], -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
- 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本
# WGAN的特别设置
num_iter_critic = 5
lambda_gp = 10
# 设置图片形状1*28*28
img_shape = (1,28,28)
# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)
# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)
# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
# 开始训练
batches_done=0
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):
# 进入训练模式
G.train()
D.train()
loop = tqdm(train_loader, desc=f"第{epoch+1}轮")
for i, (real_imgs, _) in enumerate(loop):
real_imgs=real_imgs.to(device) # [B,C,H,W]
# -----------------
# 训练判别器
# -----------------
# 获取噪声样本[B,latent_dim)
z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device) #从正态分布中抽样
# Step-1 计算判断器损失=判断真实图片损失+判断生成图片损失+惩罚项
fake_imgs=G(z).detach()
gradient_penalty=compute_gradient_penalty(D, real_imgs, fake_imgs, device)
dis_loss=-torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))+lambda_gp*gradient_penalty
# Step-2 更新判别器参数
optimizer_D.zero_grad() # 梯度清零
dis_loss.backward() #反向传播,计算梯度
optimizer_D.step() #更新判别器
# -----------------
# 训练生成器
# -----------------
# 判别器每迭代 num_iter_critic 次,生成器迭代一次
if i % num_iter_critic ==0 :
gen_imgs=G(z).detach()
# 更新生成器参数
optimizer_G.zero_grad() #梯度清零
gen_loss=-torch.mean(D(gen_imgs))
gen_loss.backward() #反向传播,计算梯度
optimizer_G.step() #更新生成器
# 更新进度条
loop.set_postfix(
gen_loss=f"{gen_loss:.8f}",
dis_loss=f"{dis_loss:.8f}"
)
# 每 sample_interval 次迭代保存生成样本
if batches_done % sample_interval == 0:
save_image(gen_imgs.data[:25], f"./img/wgan_gp_mnist/{epoch}_{i}.png", nrow=5, normalize=True)
batches_done += 1
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))
#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/WGAN-GP_G.pth")
torch.save(D.state_dict(), "./model/WGAN-GP_D.pth")
2.6 图片转GIF
from PIL import Image
def create_gif(img_dir="./img/wgan_gp_mnist", output_file="./img/wgan_gp_mnist/wgan_gp_figure.gif", duration=100):
images = []
img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]
# 自定义排序:按 "x_y.png" 的 x 和 y 排序
img_paths_sorted = sorted(
img_paths,
key=lambda x: (
int(x.split('_')[0]), # 第一个数字(如 0_400.png 的 0)
int(x.split('_')[1].split('.')[0]) # 第二个数字(如 0_400.png 的 400)
)
)
for img_file in img_paths_sorted:
img = Image.open(os.path.join(img_dir, img_file))
images.append(img)
images[0].save(output_file, save_all=True, append_images=images[1:],
duration=duration, loop=0)
print(f"GIF已保存至 {output_file}")
create_gif()
