【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客
《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)
如今AI艺术创作能力越来越强大,Google发布的ImageGen项目基于文本提示作画的结果和真实艺术家的成品难辨真假。本项目将使用PyTorch实现生成式对抗网络生成式对抗网络来完成AI生成动漫人物图像。
本项目中使用的数据集是一个由63 632个高质量动画人脸组成的数据集,从www.getchu.com中抓取,然后使用https://github.com/nagadomi/lbpcascade_animeface中的动画人脸检测算法进行裁剪。图像大小从90×90到120×120不等。该数据集包含高质量的动漫角色图像,具有干净的背景和丰富的颜色。数据集下载链接:https://github.com/bchao1/Anime-Face-Dataset。
我们知道在生成式对抗网络中有两个模型——生成模型(Generative Model,G)和判别模型(Discriminative Model,D)。G就是一个生成图片的网络,它接收一个随机的噪声z,然后通过这个噪声生成图片,生成的数据记作G(z)。D是一个判别网络,判别一幅图片是不是“真实的”(是不是捏造的)。它的输入参数是x,x代表一幅图片,输出D(x)代表x为真实图片的概率,如果为1,就代表是真实的图片,而输出为0,就代表不可能是真实的图片。
- 定义生成器Generator:生成器的输入为100维的高斯噪声,生成器会利用这个噪声生成指定大小的图片,关于最初的噪声,可以看成10011的特征图,然后利用转置卷积来进行尺寸还原操作,标准的卷积操作是不断缩小尺寸,转置卷积就可以理解为它的逆操作,这样就可以不断放大图像。
- 定义判别器Discriminator:判别器就是一个典型的二分类网络,首先它的输入是我们输入的图片,我们会利用一系列卷积操作来形成一维特征图进行分类操作,这里可以发现判别器的网络和生成器的相关操作是可逆的,唯独不一样的是激活函数。
模型训练的步骤如下:
步骤1:首先固定生成器,训练判别器,提高真实样本被判别为真的概率,同时降低生成器生成的假图像被判别为真的概率,目标是判别器能准确进行分类。
步骤2:固定判别器,训练生成器,生成器生成图像,尽可能提高该图像被判别器判别为真的概率,目标是生成器的结果能够骗过判别器。
步骤3:重复,循环交替训练,最终生成器生成的样本足够逼真,使得鉴别器只有大约50%的判断正确率(相当于乱猜)。
完整代码如下:
#####################GANDEMO.py####################
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
import torchvision
from torchvision import transforms, datasets
from tqdm import tqdm
class Config(object):
data_path = './gandata/data/'
image_size = 96
batch_size = 32
epochs = 200
lr1 = 2e-3
lr2 = 2e-4
beta1 = 0.5
gpu = False
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
nz = 100
ngf = 64
ndf = 64
save_path = './gandata/images'
generator_path = './gandata/generator.pkl' #模型保存路径
discriminator_path = './gandata/discriminator.pkl' #模型保存路径
gen_img = './gandata/result.png'
gen_num = 64
gen_search_num = 5000
gen_mean = 0
gen_std = 1
config = Config()
# 1.数据转换
data_transform = transforms.Compose([
transforms.Resize(config.image_size),
transforms.CenterCrop(config.image_size),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# 2.形成训练集
train_dataset = datasets.ImageFolder(root=os.path.join(config.data_path),
transform=data_transform)
# 3.形成迭代器
train_loader = torch.utils.data.DataLoader(train_dataset,
config.batch_size,
True,
drop_last=True)
print('using {} images for training.'.format(len(train_dataset)))
class Generator(nn.Module):
def __init__(self, config):
super().__init__()
ngf = config.ngf
self.model = nn.Sequential(
nn.ConvTranspose2d(config.nz, ngf * 8, 4, 1, 0),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
nn.ConvTranspose2d(ngf, 3, 5, 3, 1),
nn.Tanh()
)
def forward(self, x):
output = self.model(x)
return output
class Discriminator(nn.Module):
def __init__(self, config):
super().__init__()
ndf = config.ndf
self.model = nn.Sequential(
nn.Conv2d(3, ndf, 5, 3, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf * 8, 1, 4, 1, 0)
)
def forward(self, x):
output = self.model(x)
return output.view(-1)
generator = Generator(config)
discriminator = Discriminator(config)
optimizer_generator = torch.optim.Adam(generator.parameters(),
config.lr1,
betas=(config.beta1, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
config.lr2,
betas=(config.beta1, 0.999))
true_labels = torch.ones(config.batch_size)
fake_labels = torch.zeros(config.batch_size)
fix_noises = torch.randn(config.batch_size, config.nz, 1, 1)
noises = torch.randn(config.batch_size, config.nz, 1, 1)
for epoch in range(config.epochs):
for ii, (img, _) in tqdm(enumerate(train_loader)):
real_img = img.to(config.device)
if ii % 2 == 0:
optimizer_discriminator.zero_grad()
r_preds = discriminator(real_img)
noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
fake_img = generator(noises).detach()
f_preds = discriminator(fake_img)
r_f_diff = (r_preds - f_preds.mean()).clamp(max=1)
f_r_diff = (f_preds - r_preds.mean()).clamp(min=-1)
loss_d_real = (1 - r_f_diff).mean()
loss_d_fake = (1 + f_r_diff).mean()
loss_d = loss_d_real + loss_d_fake
loss_d.backward()
optimizer_discriminator.step()
else:
optimizer_generator.zero_grad()
noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
fake_img = generator(noises)
f_preds = discriminator(fake_img)
r_preds = discriminator(real_img)
r_f_diff = r_preds - torch.mean(f_preds)
f_r_diff = f_preds - torch.mean(r_preds)
loss_g = torch.mean(F.relu(1 + r_f_diff)) \
+ torch.mean(F.relu(1 - f_r_diff))
loss_g.backward()
optimizer_generator.step()
if epoch == config.epochs - 1:
# 保存模型
torch.save(discriminator.state_dict(), config.discriminator_path)
torch.save(generator.state_dict(), config.generator_path)
print('Finished Training')
generator = Generator(config)
discriminator = Discriminator(config)
noises = torch.randn(config.gen_search_num,
config.nz, 1, 1).normal_(config.gen_mean,
config.gen_std)
noises = noises.to(config.device)
generator.load_state_dict(torch.load(config.generator_path,
map_location='cpu'))
discriminator.load_state_dict(torch.load(config.discriminator_path,
map_location='cpu'))
generator.to(config.device)
discriminator.to(config.device)
fake_img = generator(noises)
scores = discriminator(fake_img).detach()
indexs = scores.topk(config.gen_num)[1]
result = []
for ii in indexs:
result.append(fake_img.data[ii])
torchvision.utils.save_image(torch.stack(result), config.gen_img,
normalize=True, value_range=(-1, 1))
代码运行结果如下:
using 900 images for training.
28it [00:20, 1.40it/s]
28it [00:20, 1.33it/s]
28it [00:21, 1.29it/s]
…
28it [00:26, 1.06it/s]
Finished Training
效果图如图13-9所示,由于只训练了100个Epoch,因此图像生成的纹理还不算太清楚,大家计算资源允许的话,可以多训练一些Epoch来生成更多的图像细节。
图13-9