文章目录
- 前言小笔记
- 关键特性
- 技术栈
- 使用场景
- 贡献者:
- 完整代码
- 代码解析
- 1. 导入必要的库
- 2. 设备配置
- 3. 超参数设置
- 4. 创建样本目录
- 5. 图像处理
- 6. 加载MNIST数据集
- 7. 创建数据加载器
- 8. 定义判别器(Discriminator)D
- 9. 定义生成器(Generator)G
- 10. 设备设置
- 11. 损失函数和优化器
- 12. 辅助函数
- 13. 训练循环
- 14. 打印训练进度
- 15. 保存图像
- 16. 保存模型检查点
- 运行效果
前言小笔记
这份代码是利用深度学习技术,通过生成对抗网络(GAN)模型,实现了对手写数字图像的生成。MNIST数据集是一个广泛使用的数据库,包含了大量的手写数字灰度图像,是机器学习和计算机视觉领域的标准测试集。
关键特性
- 设备适配性:自动检测并使用可用的GPU资源,提高计算效率。
- 超参数配置:提供了灵活的超参数设置,包括潜在空间大小、隐藏层大小、图像尺寸等,以适应不同的训练需求。
- 图像预处理:实现了图像的归一化处理,将像素值标准化到[-1, 1]区间,以利于神经网络的训练。
- 数据加载:使用
DataLoader
高效地加载和批处理数据,同时支持数据打乱,提高模型泛化能力。 - 判别器与生成器网络:定义了两个神经网络模型,判别器用于区分真实图像与生成图像,生成器用于生成逼真的数字图像。
- 损失函数与优化器:选用了二元交叉熵损失函数和Adam优化器,确保了模型的有效训练。
- 训练循环:实现了完整的训练逻辑,包括判别器和生成器的交替训练,以及梯度的更新。
- 进度监控:在训练过程中提供了详细的进度输出,方便用户监控训练状态。
- 图像保存:训练过程中会生成并保存真实图像和假图像的样本,用于可视化训练效果。
- 模型保存:训练完成后,模型参数会被保存,方便后续的模型加载和使用。
技术栈
- PyTorch:主要的深度学习框架,用于构建和训练神经网络。
- torchvision:PyTorch的扩展包,提供图像处理和数据加载工具。
使用场景
本项目适用于深度学习研究、教育、数据科学竞赛等场景,特别是在需要生成图像数据或理解GAN工作原理的场合。
贡献者:
本代码来源于
https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/generative_adversarial_network/main.py#L41-L57适用于希望快速入门GAN或在MNIST数据集上实践GAN模型的用户。
完整代码
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
# Create a directory if not exists
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# Image processing
# transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
# std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], # 1 for greyscale channels
std=[0.5])])
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transform,
download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
# Discriminator
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1),
nn.Sigmoid())
# Generator
G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1)
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
for i, (images, _) in enumerate(data_loader):
images = images.reshape(batch_size, -1).to(device)
# Create the labels which are later used as input for the BCE loss
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ================================================================== #
# Train the discriminator #
# ================================================================== #
# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
# Second term of the loss is always zero since real_labels == 1
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# Compute BCELoss using fake images
# First term of the loss is always zero since fake_labels == 0
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# Backprop and optimize
d_loss = d_loss_real + d_loss_fake
reset_grad()
d_loss.backward()
d_optimizer.step()
# ================================================================== #
# Train the generator #
# ================================================================== #
# Compute loss with fake images
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
g_loss = criterion(outputs, real_labels)
# Backprop and optimize
reset_grad()
g_loss.backward()
g_optimizer.step()
if (i+1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# Save real images
if (epoch+1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
# Save sampled images
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# Save the model checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
代码解析
1. 导入必要的库
代码开始处导入了多个Python库,这些库提供了后续操作所需的功能。
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
2. 设备配置
设置设备为GPU(如果可用),否则使用CPU。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
3. 超参数设置
定义了网络训练过程中将使用的超参数,如潜在空间大小、隐藏层大小等。
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
4. 创建样本目录
如果不存在,创建一个目录来保存生成的样本图像。
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
5. 图像处理
定义图像预处理步骤,包括转换为张量和归一化。
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
6. 加载MNIST数据集
加载MNIST数据集,并应用上面定义的转换。
mnist = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transform,
download=True)
7. 创建数据加载器
创建一个DataLoader
对象,用于批量加载数据,并在训练时打乱数据顺序。
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
8. 定义判别器(Discriminator)D
使用nn.Sequential
定义了一个简单的神经网络结构作为判别器。
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2),
# ... 其他层 ...
nn.Sigmoid()
)
9. 定义生成器(Generator)G
同样使用nn.Sequential
定义生成器网络结构。
G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
# ... 其他层 ...
nn.Tanh()
)
10. 设备设置
将判别器和生成器移动到之前设置的设备上。
D = D.to(device)
G = G.to(device)
11. 损失函数和优化器
定义了二元交叉熵损失函数和两个Adam优化器,分别用于判别器和生成器。
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
12. 辅助函数
定义了两个辅助函数,denorm
用于将归一化的图像反归一化,reset_grad
用于清除梯度。
def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1)
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
13. 训练循环
实现了GAN的训练过程,包括训练判别器和生成器的逻辑。
for epoch in range(num_epochs):
# ... 训练逻辑 ...
14. 打印训练进度
在训练过程中,每隔一定步数打印当前的训练状态。
15. 保存图像
在训练过程中,保存真实图像和生成的假图像。
# Save real images
# Save sampled images
16. 保存模型检查点
训练结束后,保存生成器和判别器的模型参数。
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
这段代码实现了一个典型的GAN训练流程,包括数据预处理、模型定义、训练循环、图像保存和模型保存等步骤。
运行效果
源代码直接运行: