人脸图像生成(DCGAN)

news2024/11/26 10:43:36

一、理论基础

1.什么是深度卷积对抗网络(Deep Convolutional Generative Adversarial Network,)
深度卷积对抗网络(Deep Convolutional Generative Adversarial Network,DCGAN)是一种生成对抗网络(GAN)的变体,它结合了深度卷积神经网络(CNN)的特性和生成对抗网络的架构。

生成对抗网络是由生成器(Generator)和判别器(Discriminator)组成的模型。生成器尝试生成与真实数据相似的样本,而判别器则尝试区分生成的样本和真实样本。两者通过博弈的方式不断优化,使生成器生成更逼真的样本。

DCGAN引入了卷积神经网络的结构,以处理图像生成任务。它的主要特点包括:

卷积层替代全连接层: DCGAN中的生成器和判别器都使用卷积层,这有助于模型学习图像中的空间层次特征,从而更好地捕捉图像的结构信息。

批归一化(Batch Normalization): 在生成器和判别器中广泛使用批归一化,有助于加速训练过程,同时提高模型的稳定性和生成效果。

去除全连接层: DCGAN中移除了全连接层,这有助于减少模型参数数量,降低过拟合的风险。

使用Leaky ReLU激活函数: 生成器和判别器中使用Leaky ReLU激活函数,以避免梯度消失的问题,同时引入一定的负斜率,促使模型更容易学习。

DCGAN的目标是通过训练生成器生成逼真的图像,同时训练判别器以有效地区分真实和生成的图像。这种架构的成功应用包括图像生成、图像编辑、图像超分辨率等领域。

2.DCGAN原理
DCGAN(Deep Convolutional Generative Adversarial Network)由GAN进行改进得到,它由两个子网络组成:生成器和判别器。

生成器网络接受一个随机噪声向量作为输入,并尝试生成看起来像真实数据的输出。具体来说,生成器网络通常由多个卷积层和反卷积层组成,这些层将随机噪声转换为具有现实特征的图像。

判别器网络则接受输入并尝试将其分类为“真实”或“生成”。判别器网络通常由多个卷积层组成,这些层将输入转换为具有现实特征的表示形式,并输出一个二进制数字,表示输入是否是真实数据。

在训练过程中,生成器和判别器交替进行预测和生成,以逐渐提高生成器输出的质量。生成器试图生成看起来像真实数据的输出,而判别器则试图将其与真实数据区分开来。通过不断地调整生成器和判别器,最终生成器可以生成非常逼真的数据。

3.DCGAN与GAN相同点与不同点
GAN(Generative Adversarial Network)和DCGAN(Deep Convolutional Generative Adversarial Network)都是生成对抗网络,它们的基本原理是相同的,即通过两个相互对抗的网络(生成器和判别器)来进行无监督的学习,以生成高质量的新数据。但是,DCGAN相对于GAN有一些改进和不同点,主要表现在以下方面:

相同点:
基本原理相同:DCGAN和GAN的基本原理都是生成对抗网络(Generative Adversarial Networks),其中两个子网络(生成器和判别器)在对抗中优化彼此。
判别器结构相同:在DCGAN和GAN中,判别器结构都是相同的,都是使用卷积神经网络(CNN)进行特征提取和分类。
不同点:
网络结构不同:DCGAN将卷积运算的思想引入到生成式模型当中,生成器和判别器模型都使用了卷积层,而GAN使用了全连接层。
训练方法不同:DCGAN使用正交标注(Orthogonal Annotation)来生成伪标签(Pseudo Labels),以用于半监督学习。而GAN通常使用真实标签(True Labels)进行监督学习。
输入输出不同:DCGAN输入的数据通常是3D体积,而GAN输入的数据通常是2D图像。此外,DCGAN输出的数据也是3D体积,而GAN输出的数据是2D图像。
判别器模型不同:在DCGAN中,判别器模型使用卷积步长取代了空间池化,以更好地提取图像特征。而在GAN中,判别器模型通常使用空间池化。
生成器模型不同:在DCGAN中,生成器模型中使用反卷积操作扩大数据维度,以更好地处理高分辨率的图像。而在GAN中,生成器模型通常使用上采样操作。
模型优化不同:在DCGAN中,整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。此外,DCGAN还使用了一些其他的技术,如Batch Normalization、Leaky ReLU激活函数和Tanh输出层,以帮助控制输出范围和提高生成质量。而在GAN中,这些技术并不常见。
4.训练原理
DCGAN的训练原理是基于半监督学习,利用已有的少量标注数据和大量无标注数据来生成新的数据。以下是DCGAN训练原理的主要步骤:

正交标注:首先,对一个标记的3D体积进行两个正交切片的标注,即正交标注。这样可以减少标注的负担,并且利用了不同方向的切片提供的互补信息。
注册模块:通过注册模块,将正交标注传播到整个体积,生成了伪标签。注册模块利用了正交标注的信息,将其传播到整个体积,从而生成了伪标签,用于半监督学习的训练过程中。
生成器训练:使用已经生成的伪标签和未标注的数据来训练生成器模型。生成器模型的目标是最小化判别器模型的输出,即让判别器无法区分生成器和真实数据之间的差异。
判别器训练:使用已经生成的伪标签和真实数据来训练判别器模型。判别器模型的目标是最小化生成器模型的输出,即让生成器无法生成看起来像真实数据的输出。
迭代训练:不断地迭代训练生成器和判别器模型,直到生成器能够生成看起来像真实数据的输出,并且判别器能够准确地区分生成数据和真实数据。
DCGAN的训练原理是基于半监督学习,通过生成伪标签和使用注册模块来减少标注负担,从而利用未标注数据来生成新的数据。在整个训练过程中,生成器和判别器相互对抗,以提高生成器的性能和生成数据的质量。

二、前期准备 

import os
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets as dset
import torchvision.utils as vutils
from torchvision.utils import save_image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
 
 
manualSeed = 999  # 随机种子
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results
 
# 超参数配置
dataroot   = "D:/GAN-Data"  # 数据路径
batch_size = 128                   # 训练过程中的批次大小
n_epochs   = 5                     # 训练的总轮数
img_size   = 64                    # 图像的尺寸(宽度和高度)
nz         = 100                   # z潜在向量的大小(生成器输入的尺寸)
ngf        = 64                    # 生成器中的特征图大小
ndf        = 64                    # 判别器中的特征图大小
beta1      = 0.5                   # Adam优化器的Beta1超参数
beta2      = 0.2                   # Adam优化器的Beta1超参数
lr         = 0.0002                # 学习率
 
# 创建数据集
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),        # 调整图像大小
                               transforms.CenterCrop(img_size),    # 中心裁剪图像
                               transforms.ToTensor(),                # 将图像转换为张量
                               transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量
                                                    (0.5, 0.5, 0.5)),]))
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,  # 批量大小
                                         shuffle=True)           # 是否打乱数据集
# 选择要在哪个设备上运行代码
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print("使用的设备是:",device)
# 绘制一些训练图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],
           padding=2,
           normalize=True).cpu(),(1,2,0)))
 
 

三、定义模型 

criterion = nn.BCELoss()

fixed_noise = torch.randn(64, nz, 1, 1,device=device)
real_label = 1.
fake_label = 0.

# 为生成器(G)和判别器( D)设置Adam优化器
optimizerD = optim.Adam(netD.parameters(),lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(),lr=lr, betas=(beta1, 0.999))
img_list = []  # 用于存储生成的图像列表
G_losses = []  # 用于存储生成器的损失列表
D_losses = []  # 用于存储判别器的损失列表
iters = 0  # 迭代次数
print("Starting Training Loop...")
for epoch in range(num_epochs ):
    for i, data in enumerate(dataloader, 0):
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()
        
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach( )).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()
        
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        
        if i % 400 == 0:
            print('[%d/%d][%d/%d]\tLoss_D:%.4f\tLoss_G:%.4f\tD(x):%.4f\tD(G(z)):%.4f / %.4f'
                 % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        if(iters % 500 == 0) or ((epoch == num_epochs - 1) and(i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
plt.figure(figsize=(10, 5))
plt.title('Generator and Discriminator Loss During Training')
plt.plot(G_losses, label='G')
plt.plot(D_losses, label='D')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1659973.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

跨域问题(服务器和浏览器之间)待补充

一、为什么产生: 同源策略(域名,协议,端口),安全问题 二、怎么解决: 1、cros:修改响应头 2、jp:采用js标签 3、代理(创建服务器,定义规则,服…

十二届蓝桥杯Python组1月中/高级试题 第五题

** 十二届蓝桥杯Python组1月中/高级试题 第五题 ** 第五题(难度系数 5,35 个计分点) 提示信息: 平均数:是指在一组数据中所有数据之和再除以这组数据的个数。 如:“1,2,3&#xf…

安防监控/视频汇聚系统EasyCVR+AI智能分析助力解决校园霸凌事件

一、方案背景 校园霸凌这一校园中不应存在的现象,却屡见不鲜,它像一把锋利的刀,深深地刺入那些无辜的心灵,让受害者承受着无尽的痛苦。随着科技的进步与发展,我们应该追求有效、进步的手段来阻止校园霸凌事件的发生&a…

达坦科技@了你,并邀请你参加2024开源之夏!

开源之夏(英文简称“OSPP”)是中科院软件所“开源软件供应链点亮计划”指导下的系列暑期活动。达坦科技自开源之夏创办首期起每年参与,积极鼓励在校学生积极参与开源软件的开发维护,培养和发掘更多优秀的开发者。今年,…

超越机械抓手:看多指机器人如何灵活运用触觉?

论文标题: Learning Visuotactile Skills with Two Multifingered Hands 论文作者: Toru Lin, Yu Zhang, Qiyang Li, Haozhi Qi, Brent Yi, Sergey Levine, and Jitendra Malik 1. 机器人新挑战:多指手指操作 在自动化和智能化日益普及的…

【Vulhub靶场】Nginx 中间件漏洞复现

【Vulhub靶场】Nginx 中间件漏洞复现 一、Nginx 文件名逻辑漏洞(CVE-2013-4547)1. 影响版本2. 漏洞原理3. 漏洞复现 二、Nginx越界读取缓存漏洞(CVE-2017-7529)1. 漏洞详情2. 影响版本3. 漏洞复现 三、Nginx 配置错误导致漏洞&…

预告 | 飞凌嵌入式邀您共聚2024上海充换电展

第三届上海国际充电桩及换电站展览会(CPSE),即将于5月22日~24日在上海汽车会展中心举行。届时,飞凌嵌入式将带来多款嵌入式核心板、开发板、充电桩TCU以及储能EMS网关产品,与来自全国的客户朋友及行业伙伴一同交流分享…

基于R语言绘图 | 转录代谢趋势图绘制教程

原文链接:基于R语言绘图 | 转录代谢趋势图绘制教程 本期教程 小杜的生信笔记,自2021年11月开始做的知识分享,主要内容是R语言绘图教程、转录组上游分析、转录组下游分析等内容。凡事在社群同学,可免费获得自2021年11月份至今全部…

【ArcGIS Pro微课1000例】0058:玩转NetCDF多维数据集

一、NetCDF介绍 NetCDF(network Common Data Form)网络通用数据格式是由美国大学大气研究协会(University Corporation for Atmospheric Research,UCAR)的Unidata项目科学家针对科学数据的特点开发的,是一种面向数组型并适于网络共享的数据的描述和编码标准。NetCDF广泛应…

羊大师:当代年轻人如何应对压力

羊大师:当代年轻人如何应对压力 当代年轻人面临各种压力,包括工作、学习、人际关系、经济等方面的压力。以下是一些建议,帮助年轻人应对这些压力: 认识并接受压力: 首先要认识到压力是生活中不可避免的一部分。 尝试…

WPF之DataGird应用

1,DataGrid相关属性 GridLinesVisibility:DataGrid网格线是否显示或者显示的方式。HorizontalGridLinesBrush:水平网格线画刷。VerticalGridLinesBrush:垂直网格线画刷。HorizontalScrollBarVisibility:水平滚动条可见…

卷积通用模型的剪枝、蒸馏---蒸馏篇--RKD关系蒸馏(以deeplabv3+为例)

本文使用RKD实现对deeplabv3+模型的蒸馏;与上一篇KD蒸馏的方法有所不同,RKD是对展平层的特征做蒸馏,蒸馏的loss分为二阶的距离损失Distance-wise Loss和三阶的角度损失Angle-wise Loss。 一、RKD简介 RKD算法的核心是以教师模型的多个输出为结构单元,取代传统蒸馏学习中以教…

【经验总结】 常用的模型优化器

优化器是一种用于优化模型权重和偏差的算法,它根据训练数据更新模型参数,以模型的预测结果更加准确。 1. 常见的优化器 SGD(Stochastic Gradient Descent):SGD是一种基本的优化算法,它在每次迭代中随机选择…

借势吃货节趣味小游戏的效果是什么

吃货节对食品、餐饮等行业厂家/商家来说非常利好,借势节日气氛能更快达成预期营销效果,除了传统方式外,线上趣味互动游戏营销也是重要形式。 搜索【雨科】平台拥有多款吃货节趣味抽奖h5小游戏形式,不同玩法和内容承载、渠道传播用…

简单的Python HTML 输出

1、问题背景 一名初学者在尝试将 Python 脚本输出到网页上时遇到了一些问题。他当前使用 Python 和 HTML 进行开发,并且遇到了以下问题: 担心自己的代码过于复杂,尤其是 WebOutput() 函数。希望通过 JavaScript 使用 HTML 模板文件更新数据。…

48. UE5 RPG 实现攻击伤害数字显示

在前面的文章中,我们实现了对敌人的攻击的受击效果,并且能够降低目标的血量,实现死亡效果。相对于正常的游戏,我们还需要实现技能或者攻击对敌人造成的伤害数值,并直观的显示出来。 所以,接下来&#xff0c…

【JAVA】JAVA的垃圾回收机制详解

对于Java的垃圾回收机制,它是Java虚拟机(JVM)提供的一种自动内存管理机制,主要负责回收不再使用的对象以释放内存空间。垃圾回收机制主要包括以下几个方面的内容: 垃圾对象的识别:Java虚拟机通过一些算法&…

MySQL索引优化(超详细)篇章2--索引调优

目录 1.索引失效状况2.性能分析3.表的索引信息--调整索引顺序4.删除冗余索引5.最佳左前缀法则5.1下面是一个实际的例子来说明这个概念: 6.数据长度和索引长度占用空间比较 1.索引失效状况 MySQL索引失效通常指的是查询语句无法有效地利用索引,而导致全表…

matlab打开文件对话框

在使用matlab GUI制作时,为了便于用户交互使用,经常设置文件打开对话框,让用户根据实际需要选择打开的文件。下面以打开一张图片为例,matlab代码如下: [temp_filepath,temp_filename]uigetfile(*.jpg,请选择要打开的图…

设计模式(2)创造型设计模式

创建型模式 创建型模式1.工厂模式1.1 抽象工厂模式(Abstract factory)1.2 工厂方法模式(Factory Method)1.3 简单工厂模式(Simple Factory) 2. 建造者模式(Builder)3. 原型模式&…