深度学习训练营之DCGAN网络学习
- 原文链接
- 环境介绍
- DCGAN简单介绍
- 生成器(Generator)
- 判别器(Discriminator)
- 对抗训练
- 前置工作
- 导入第三方库
- 导入数据
- 数据查看
- 定义模型
- 初始化权重
- 定义生成器generator
- 定义判别器
- 模型训练
- 定义参数
- 模型训练
- 结果可视化
原文链接
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍦 参考文章:365天深度学习训练营-第G2周:深度学习训练营之DCGAN网络学习
- 🍖 原作者:K同学啊|接辅导、项目定制
环境介绍
- 语言环境:Python3.11.4
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2
DCGAN简单介绍
DCGAN(Deep Convolutional Generative Adversarial Network)是一种基于生成对抗网络(GAN)的深度学习模型,用于生成逼真的图像。它通过将生成器和判别器两个网络相互对抗地训练,以实现生成高质量图像的目标。
DCGAN 的核心思想是使用卷积神经网络(CNN)作为生成器和判别器的网络结构。下面是 DCGAN 的一般工作原理:
生成器(Generator)
生成器接受一个随机噪声向量作为输入,并使用反卷积层(或称为转置卷积层)将其逐渐放大和转换为图像。
通过层层上采样处理和卷积操作,生成器逐渐学习到将低分辨率噪声向量转化为高分辨率逼真图像的映射。
生成器的目标是尽可能接近真实图像的分布,从而生成看起来真实的图像。
判别器(Discriminator)
判别器是一个二分类的CNN网络,用于区分真实图像和生成器生成的假图像。
判别器接受输入图像并输出一个概率,表示输入图像是真实图像的概率。
判别器通过对真实图像分配较高的概率值,并对生成器生成的假图像分配较低的概率值,来辨别真实和假的图像。
对抗训练
DCGAN 的核心是通过对抗训练生成器和判别器来提升它们的性能(属于是无监督的学习)。
在训练过程中,生成器试图生成逼真的图像以欺骗判别器,而判别器则努力区分真实和生成的图像。
这里就可以理解为生成器通过尽可能地生成逼近于真实图片的图像来尝试骗过判别器,而判别器就是通过尽可能地将假图片和真图片进行区分,当两种之间发生冲突的时候,就会进行进一步的优化,直到达到平衡,在后续的代码当中我们也可以看到生成器和判别器之间的网络价格正好是相反的
生成器和判别器相互对抗地进行训练,通过最小化生成器生成图像被判别为假的概率(对抗损失)和最大化真实图像被判别为真的概率(真实损失)来优化网络。
通过反复训练生成器和判别器,并使它们相互对抗地提升,最终可以得到一个生成器能够生成高质量逼真图像的模型。
前置工作
导入第三方库
import torch,random,os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
manualSeed=999#随机数种子
print("Random Seed:",manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True)
999
导入数据
导入数据并设置超参数
dataroot="./DCGAN/"
# 数据集和上一周的一样,所以就放在一起了
batch_size=128
image_size=64
nz=100 #z潜在的向量大小(生成器generator的尺寸)
ngf=64 #生成器中的特征图大小
ndf=64
num_epochs=50
lr=0.00002
beta1=0.5
print(dataroot)
数据查看
进行数据的导入,
-
用
ImageFolder
类来创建数据集对象, -
Transforms.Compose
组合成一系列的图像变换操作来对图像进行预处理 -
DataLoder
类来创建一个数据加载器的对象 -
Matplotlib
库来绘制这些图像
dataset=dset.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),#转换成张量
transforms.Normalize((0.5,0.5,0.5),
(0.5,0.5,0.5)),
]))
dataloader=torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
num_workers=5#使用多个线程加载数据的工作进程数
)
device=torch.device=("cuda:0"if (torch.cuda.is_available())else "cpu")
print("使用的设备为 " +device)
real_batch=next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],
padding=2,
normalize=True).cpu(),(1,2,0)))
定义模型
初始化权重
def weights_init(m):
#获取当前层类名
classname=m.__class__.__name__
#包含conv,表示当前层是卷积层
if classname.find('Conv')!=-1:
#j均值设为0.0,标准差为0.02
nn.init.normal_(m.weight.data,0.0,0.02)#直接在张量上进行参数初始化
elif classname.find('BatchNorm')!=-1:
nn.init.normal_(m.weight.data,1.0,0.02)
nn.init.constant_(m.bias.data,0)
定义生成器generator
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
## 模型中间块儿
self.main=nn.Sequential(
nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),
nn.BatchNorm2d(ngf*8),
nn.ReLU(True),
nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),
nn.BatchNorm2d(ngf*4),
nn.ReLU(True),
nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
nn.BatchNorm2d(ngf*2),
nn.ReLU(True),
nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
nn.ConvTranspose2d(ngf,3,4,2,1,bias=False),
nn.Tanh()#Tanh激活函数
)
def forward(self, input):
return self.main(input)
#创建生成器
netG=Generator().to(device)
netG.apply(weights_init)
print(netG)
大家可以注意一下这个网络的架构,会和后面的判别器是相反的
定义判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.main=nn.Sequential(
nn.Conv2d(3,ndf,4,2,1,bias=False),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
nn.BatchNorm2d(ndf*2),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
nn.BatchNorm2d(ndf*4),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
nn.BatchNorm2d(ndf*8),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf*8,1,4,1,0,bias=False),
nn.Sigmoid()#Sigmoid激活函数
)
def forward(self, input):
return self.main(input)
#创建判别器
netD=Discriminator().to(device)
netD.apply(weights_init)#weights_init初始化所有权重
print(netD)
模型训练
定义参数
criterion=nn.BCELoss()
fixed_noise=torch.randn(64,nz,1,1,device=device)
real_label=1.#1表示真实
fake_label=0.#0表示虚假生成
#设置优化器
optimizerD=optim.Adam(netD.parameters(),lr=lr,betas=(beta1,0.999))
optimizerG=optim.Adam(netG.parameters(),lr=lr,betas=(beta1,0.999))
模型训练
img_list=[]#用存储生成的图像列表
G_losses=[]
D_losses=[]
iters=0#迭代次数
print("开始训练Starting Training Loop..")
for epoch in range(num_epochs):
#dataloader中的每个batch
for i,data in enumerate(dataloader,0):
####
#最大化log(D(x))+log(1-D(G(z)))
####
netD.zero_grad()#清除判别器网络的梯度
real_cpu=data[0].to(device)
b_size=real_cpu.size(0)
label=torch.full((b_size,),real_label,dtype=torch.float,device=device)
#输入判别器进行前向传播
output=netD(real_cpu).view(-1)
errD_real=criterion(output,label)
errD_real.backward()
D_x=output.mean().item()#计算批判别器对真实图像样本的输出平均值
'''使用生成图像样本进行训练'''
noise=torch.randn(b_size,nz,1,1,device=device)
fake=netG(noise)
label.fill_(fake_label)
output=netD(fake.detach()).view(-1)
errD_fake=criterion(output,label)
errD_fake.backward()
D_G_z1=output.mean().item()
errD=errD_fake+errD_real
optimizerD.step()
'''更新生成网络'''
netG.zero_grad()
label.fill_(real_label)
output=netD(fake).view(-1)
errG=criterion(output,label)
errG.backward()
D_G_z2=output.mean().item()
optimizerG.step()
if i % 400 == 0:
print('[%d/%d][%d/%d]\tLoss_D:%.4f\tLoss_G:%.4f\tD(x):%.4f\tD(G(z)):%.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
#保存损失值
G_losses.append(errG.item())
D_losses.append(errD.item())
#保存固定噪声上的输出来检查生成器的性能
if(iters%500==0)or((epoch==num_epochs-1)and(i==len(dataloader)-1)):
with torch.no_grad():
fake=netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake,padding=2,normalize=True))
iters+=1
开始训练Starting Training Loop..
[0/50][0/36] Loss_D:1.3728 Loss_G:1.0315 D(x):0.6877 D(G(z)):0.5443 / 0.4221
[1/50][0/36] Loss_D:0.3502 Loss_G:2.3366 D(x):0.9120 D(G(z)):0.1921 / 0.1283
[2/50][0/36] Loss_D:0.1925 Loss_G:3.2138 D(x):0.9384 D(G(z)):0.0957 / 0.0582
[3/50][0/36] Loss_D:0.1281 Loss_G:3.6822 D(x):0.9570 D(G(z)):0.0674 / 0.0370
[4/50][0/36] Loss_D:0.1669 Loss_G:4.0574 D(x):0.9308 D(G(z)):0.0563 / 0.0262
[5/50][0/36] Loss_D:0.1337 Loss_G:4.2146 D(x):0.9428 D(G(z)):0.0551 / 0.0209
[6/50][0/36] Loss_D:0.0729 Loss_G:4.5967 D(x):0.9696 D(G(z)):0.0344 / 0.0138
[7/50][0/36] Loss_D:0.0770 Loss_G:4.6592 D(x):0.9747 D(G(z)):0.0344 / 0.0133
[8/50][0/36] Loss_D:0.0932 Loss_G:4.8994 D(x):0.9742 D(G(z)):0.0303 / 0.0105
[9/50][0/36] Loss_D:0.0790 Loss_G:5.0675 D(x):0.9819 D(G(z)):0.0269 / 0.0083
[10/50][0/36] Loss_D:0.0496 Loss_G:5.0618 D(x):0.9807 D(G(z)):0.0278 / 0.0085
[11/50][0/36] Loss_D:0.0452 Loss_G:5.2256 D(x):0.9800 D(G(z)):0.0221 / 0.0069
[12/50][0/36] Loss_D:0.0332 Loss_G:5.4038 D(x):0.9833 D(G(z)):0.0148 / 0.0058
[13/50][0/36] Loss_D:0.0370 Loss_G:5.2032 D(x):0.9815 D(G(z)):0.0171 / 0.0064
[14/50][0/36] Loss_D:0.0326 Loss_G:5.5015 D(x):0.9838 D(G(z)):0.0149 / 0.0053
[15/50][0/36] Loss_D:0.0368 Loss_G:5.4651 D(x):0.9872 D(G(z)):0.0162 / 0.0055
[16/50][0/36] Loss_D:0.0349 Loss_G:5.6891 D(x):0.9849 D(G(z)):0.0186 / 0.0047
[17/50][0/36] Loss_D:0.0214 Loss_G:5.5402 D(x):0.9925 D(G(z)):0.0133 / 0.0048
[18/50][0/36] Loss_D:0.0216 Loss_G:5.6668 D(x):0.9912 D(G(z)):0.0123 / 0.0041
[19/50][0/36] Loss_D:0.0219 Loss_G:5.6475 D(x):0.9919 D(G(z)):0.0132 / 0.0046
[20/50][0/36] Loss_D:0.0165 Loss_G:5.7313 D(x):0.9956 D(G(z)):0.0118 / 0.0040
[21/50][0/36] Loss_D:0.0203 Loss_G:5.7859 D(x):0.9939 D(G(z)):0.0138 / 0.0040
[22/50][0/36] Loss_D:0.0266 Loss_G:5.7094 D(x):0.9850 D(G(z)):0.0104 / 0.0040
[23/50][0/36] Loss_D:0.0207 Loss_G:5.7429 D(x):0.9899 D(G(z)):0.0101 / 0.0038
...
[46/50][0/36] Loss_D:0.0100 Loss_G:6.6160 D(x):0.9945 D(G(z)):0.0044 / 0.0024
[47/50][0/36] Loss_D:0.0114 Loss_G:7.1434 D(x):0.9927 D(G(z)):0.0025 / 0.0017
[48/50][0/36] Loss_D:0.0039 Loss_G:7.2856 D(x):0.9980 D(G(z)):0.0019 / 0.0012
[49/50][0/36] Loss_D:0.0198 Loss_G:6.2926 D(x):0.9882 D(G(z)):0.0048 / 0.0029
结果可视化
plt.figure(figsize=(10, 5))
plt.title('Generator and Discriminator Loss During Training')
plt.plot(G_losses, label='G')
plt.plot(D_losses, label='D')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()
阿哲,训练效果好差,不知道是不是硬件的问题
fig = plt.figure(figsize=(8, 8))
plt.axis('off')
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())