深度学习训练营之DCGAN网络学习

news2024/9/30 7:30:11

深度学习训练营之DCGAN网络学习

  • 原文链接
  • 环境介绍
  • DCGAN简单介绍
    • 生成器(Generator)
    • 判别器(Discriminator)
    • 对抗训练
  • 前置工作
    • 导入第三方库
    • 导入数据
    • 数据查看
  • 定义模型
    • 初始化权重
    • 定义生成器generator
    • 定义判别器
  • 模型训练
    • 定义参数
    • 模型训练
  • 结果可视化

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第G2周:深度学习训练营之DCGAN网络学习
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.11.4
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

DCGAN简单介绍

DCGAN(Deep Convolutional Generative Adversarial Network)是一种基于生成对抗网络(GAN)的深度学习模型,用于生成逼真的图像。它通过将生成器和判别器两个网络相互对抗地训练,以实现生成高质量图像的目标。

DCGAN 的核心思想是使用卷积神经网络(CNN)作为生成器和判别器的网络结构。下面是 DCGAN 的一般工作原理:

生成器(Generator)

生成器接受一个随机噪声向量作为输入,并使用反卷积层(或称为转置卷积层)将其逐渐放大和转换为图像。
通过层层上采样处理和卷积操作,生成器逐渐学习到将低分辨率噪声向量转化为高分辨率逼真图像的映射。
生成器的目标是尽可能接近真实图像的分布,从而生成看起来真实的图像。

判别器(Discriminator)

判别器是一个二分类的CNN网络,用于区分真实图像和生成器生成的假图像。
判别器接受输入图像并输出一个概率,表示输入图像是真实图像的概率。
判别器通过对真实图像分配较高的概率值,并对生成器生成的假图像分配较低的概率值,来辨别真实和假的图像。

对抗训练

DCGAN 的核心是通过对抗训练生成器和判别器来提升它们的性能(属于是无监督的学习)。
在训练过程中,生成器试图生成逼真的图像以欺骗判别器,而判别器则努力区分真实和生成的图像。
这里就可以理解为生成器通过尽可能地生成逼近于真实图片的图像来尝试骗过判别器,而判别器就是通过尽可能地将假图片和真图片进行区分,当两种之间发生冲突的时候,就会进行进一步的优化,直到达到平衡,在后续的代码当中我们也可以看到生成器和判别器之间的网络价格正好是相反的
生成器和判别器相互对抗地进行训练,通过最小化生成器生成图像被判别为假的概率(对抗损失)和最大化真实图像被判别为真的概率(真实损失)来优化网络。
通过反复训练生成器和判别器,并使它们相互对抗地提升,最终可以得到一个生成器能够生成高质量逼真图像的模型
在这里插入图片描述

前置工作

导入第三方库

import torch,random,os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML 
manualSeed=999#随机数种子
print("Random Seed:",manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True)

999

导入数据

导入数据并设置超参数

dataroot="./DCGAN/"
# 数据集和上一周的一样,所以就放在一起了
batch_size=128
image_size=64
nz=100 #z潜在的向量大小(生成器generator的尺寸)
ngf=64 #生成器中的特征图大小
ndf=64
num_epochs=50
lr=0.00002
beta1=0.5
print(dataroot)

数据查看

进行数据的导入,

  • ImageFolder类来创建数据集对象,

  • Transforms.Compose组合成一系列的图像变换操作来对图像进行预处理

  • DataLoder类来创建一个数据加载器的对象

  • Matplotlib库来绘制这些图像

dataset=dset.ImageFolder(root=dataroot,
                         transform=transforms.Compose([
                            transforms.Resize(image_size),
                            transforms.CenterCrop(image_size),
                            transforms.ToTensor(),#转换成张量
                            transforms.Normalize((0.5,0.5,0.5),
                                                 (0.5,0.5,0.5)),
                         ]))
dataloader=torch.utils.data.DataLoader(dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=5#使用多个线程加载数据的工作进程数
                                      )
device=torch.device=("cuda:0"if (torch.cuda.is_available())else "cpu")
print("使用的设备为 " +device)

real_batch=next(iter(dataloader)) 
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],
                                         padding=2,
                                         normalize=True).cpu(),(1,2,0)))

在这里插入图片描述

定义模型

初始化权重

def weights_init(m):
    #获取当前层类名
    classname=m.__class__.__name__
    #包含conv,表示当前层是卷积层
    if classname.find('Conv')!=-1:
        #j均值设为0.0,标准差为0.02
        nn.init.normal_(m.weight.data,0.0,0.02)#直接在张量上进行参数初始化
    elif classname.find('BatchNorm')!=-1:
        nn.init.normal_(m.weight.data,1.0,0.02)
        nn.init.constant_(m.bias.data,0)

定义生成器generator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        ## 模型中间块儿
        self.main=nn.Sequential(
            nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf,3,4,2,1,bias=False),
            nn.Tanh()#Tanh激活函数
        )
       

    def forward(self, input):                           
        return self.main(input)                      
#创建生成器 
netG=Generator().to(device)
netG.apply(weights_init)
print(netG)

大家可以注意一下这个网络的架构,会和后面的判别器是相反的
在这里插入图片描述

定义判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main=nn.Sequential(
            nn.Conv2d(3,ndf,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
             nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*8,1,4,1,0,bias=False),
            nn.Sigmoid()#Sigmoid激活函数
        )

    def forward(self, input):
        return self.main(input)   
#创建判别器
netD=Discriminator().to(device)
netD.apply(weights_init)#weights_init初始化所有权重
print(netD)

在这里插入图片描述

模型训练

定义参数

criterion=nn.BCELoss()
fixed_noise=torch.randn(64,nz,1,1,device=device)
real_label=1.#1表示真实
fake_label=0.#0表示虚假生成

#设置优化器
optimizerD=optim.Adam(netD.parameters(),lr=lr,betas=(beta1,0.999))
optimizerG=optim.Adam(netG.parameters(),lr=lr,betas=(beta1,0.999))

模型训练

img_list=[]#用存储生成的图像列表
G_losses=[]
D_losses=[]
iters=0#迭代次数
print("开始训练Starting Training Loop..")
for epoch in range(num_epochs):
    #dataloader中的每个batch
    for i,data in enumerate(dataloader,0):
        ####
        #最大化log(D(x))+log(1-D(G(z)))
        ####
        netD.zero_grad()#清除判别器网络的梯度 
        real_cpu=data[0].to(device)
        b_size=real_cpu.size(0)
        label=torch.full((b_size,),real_label,dtype=torch.float,device=device) 
        #输入判别器进行前向传播
        output=netD(real_cpu).view(-1)
        errD_real=criterion(output,label)
        errD_real.backward()
        D_x=output.mean().item()#计算批判别器对真实图像样本的输出平均值
        
        
        '''使用生成图像样本进行训练'''
        noise=torch.randn(b_size,nz,1,1,device=device)
        fake=netG(noise)
        label.fill_(fake_label)
        output=netD(fake.detach()).view(-1)
        errD_fake=criterion(output,label)
        errD_fake.backward()
        D_G_z1=output.mean().item()
        errD=errD_fake+errD_real
        optimizerD.step()
        
        '''更新生成网络'''
        netG.zero_grad()
        label.fill_(real_label)
        output=netD(fake).view(-1)
        errG=criterion(output,label)
        errG.backward()
        D_G_z2=output.mean().item()
        optimizerG.step()
        
        if i % 400 == 0:
                print('[%d/%d][%d/%d]\tLoss_D:%.4f\tLoss_G:%.4f\tD(x):%.4f\tD(G(z)):%.4f / %.4f'
                 % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        #保存损失值
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        #保存固定噪声上的输出来检查生成器的性能 
        if(iters%500==0)or((epoch==num_epochs-1)and(i==len(dataloader)-1)):
            with torch.no_grad():
                fake=netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake,padding=2,normalize=True))
        
        iters+=1
开始训练Starting Training Loop..
[0/50][0/36]	Loss_D:1.3728	Loss_G:1.0315	D(x):0.6877	D(G(z)):0.5443 / 0.4221
[1/50][0/36]	Loss_D:0.3502	Loss_G:2.3366	D(x):0.9120	D(G(z)):0.1921 / 0.1283
[2/50][0/36]	Loss_D:0.1925	Loss_G:3.2138	D(x):0.9384	D(G(z)):0.0957 / 0.0582
[3/50][0/36]	Loss_D:0.1281	Loss_G:3.6822	D(x):0.9570	D(G(z)):0.0674 / 0.0370
[4/50][0/36]	Loss_D:0.1669	Loss_G:4.0574	D(x):0.9308	D(G(z)):0.0563 / 0.0262
[5/50][0/36]	Loss_D:0.1337	Loss_G:4.2146	D(x):0.9428	D(G(z)):0.0551 / 0.0209
[6/50][0/36]	Loss_D:0.0729	Loss_G:4.5967	D(x):0.9696	D(G(z)):0.0344 / 0.0138
[7/50][0/36]	Loss_D:0.0770	Loss_G:4.6592	D(x):0.9747	D(G(z)):0.0344 / 0.0133
[8/50][0/36]	Loss_D:0.0932	Loss_G:4.8994	D(x):0.9742	D(G(z)):0.0303 / 0.0105
[9/50][0/36]	Loss_D:0.0790	Loss_G:5.0675	D(x):0.9819	D(G(z)):0.0269 / 0.0083
[10/50][0/36]	Loss_D:0.0496	Loss_G:5.0618	D(x):0.9807	D(G(z)):0.0278 / 0.0085
[11/50][0/36]	Loss_D:0.0452	Loss_G:5.2256	D(x):0.9800	D(G(z)):0.0221 / 0.0069
[12/50][0/36]	Loss_D:0.0332	Loss_G:5.4038	D(x):0.9833	D(G(z)):0.0148 / 0.0058
[13/50][0/36]	Loss_D:0.0370	Loss_G:5.2032	D(x):0.9815	D(G(z)):0.0171 / 0.0064
[14/50][0/36]	Loss_D:0.0326	Loss_G:5.5015	D(x):0.9838	D(G(z)):0.0149 / 0.0053
[15/50][0/36]	Loss_D:0.0368	Loss_G:5.4651	D(x):0.9872	D(G(z)):0.0162 / 0.0055
[16/50][0/36]	Loss_D:0.0349	Loss_G:5.6891	D(x):0.9849	D(G(z)):0.0186 / 0.0047
[17/50][0/36]	Loss_D:0.0214	Loss_G:5.5402	D(x):0.9925	D(G(z)):0.0133 / 0.0048
[18/50][0/36]	Loss_D:0.0216	Loss_G:5.6668	D(x):0.9912	D(G(z)):0.0123 / 0.0041
[19/50][0/36]	Loss_D:0.0219	Loss_G:5.6475	D(x):0.9919	D(G(z)):0.0132 / 0.0046
[20/50][0/36]	Loss_D:0.0165	Loss_G:5.7313	D(x):0.9956	D(G(z)):0.0118 / 0.0040
[21/50][0/36]	Loss_D:0.0203	Loss_G:5.7859	D(x):0.9939	D(G(z)):0.0138 / 0.0040
[22/50][0/36]	Loss_D:0.0266	Loss_G:5.7094	D(x):0.9850	D(G(z)):0.0104 / 0.0040
[23/50][0/36]	Loss_D:0.0207	Loss_G:5.7429	D(x):0.9899	D(G(z)):0.0101 / 0.0038
...
[46/50][0/36]	Loss_D:0.0100	Loss_G:6.6160	D(x):0.9945	D(G(z)):0.0044 / 0.0024
[47/50][0/36]	Loss_D:0.0114	Loss_G:7.1434	D(x):0.9927	D(G(z)):0.0025 / 0.0017
[48/50][0/36]	Loss_D:0.0039	Loss_G:7.2856	D(x):0.9980	D(G(z)):0.0019 / 0.0012
[49/50][0/36]	Loss_D:0.0198	Loss_G:6.2926	D(x):0.9882	D(G(z)):0.0048 / 0.0029

结果可视化

plt.figure(figsize=(10, 5))
plt.title('Generator and Discriminator Loss During Training')
plt.plot(G_losses, label='G')
plt.plot(D_losses, label='D')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()

阿哲,训练效果好差,不知道是不是硬件的问题
请添加图片描述

fig = plt.figure(figsize=(8, 8))
plt.axis('off')

ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/804013.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

7.28

1.思维导图 2.qt的sever #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QTcpServer> //服务器类 #include<QTcpSocket> //客户端类 #include<QMessageBox> //对话框类 #include<QList> …

计组 [指令系统] 预习题目

PPT第5章 第2部分预习题目 预习内容及相关问题 什么是R,I,J型指令&#xff0c;它们的特点&#xff1f; &#xff08;二&#xff09;R型指令的数据通路 &#xff08;指令功能与其对应的逻辑结构&#xff09; 功能&#xff1a;R[rd] ← R[rs] op R[rt]&#xff0c;如&#xff1a…

React的UmiJS搭建的项目集成海康威视h5player播放插件H5视频播放器开发包 V2.1.2

最近前端的一个项目&#xff0c;大屏需要摄像头播放&#xff0c;摄像头厂家是海康威视的&#xff0c;网上找了一圈都没有React集成的&#xff0c;特别是没有使用UmiJS搭脚手架搭建的&#xff0c;所以记录一下。 海康威视的开放平台的API地址&#xff0c;相关插件和文档都可以下…

行列转换.

表abc&#xff1a; &#xff08;建表语句在文章末尾&#xff09; 想要得到&#xff1a; 方法一 with a as(select 年,产 from abc where 季1), b as(select 年,产 from abc where 季2), c as(select 年,产 from abc where 季3), d as(select 年,产 from abc where 季4) selec…

图像识别概述

图像识别的过程 图像识别技术的过程分以下几步&#xff1a; 1. 信息的获取&#xff1a; 是指通过传感器&#xff0c;将光或声音等信息转化为电信息。也就是获取研究对象的基本信息并通过某种方法将其转变为机器能够认识的信息。 2. 预处理&#xff1a; 主要是指图像处理中的…

行业追踪,2023-07-28

自动复盘 2023-07-28 凡所有相&#xff0c;皆是虚妄。若见诸相非相&#xff0c;即见如来。 k 线图是最好的老师&#xff0c;每天持续发布板块的rps排名&#xff0c;追踪板块&#xff0c;板块来开仓&#xff0c;板块去清仓&#xff0c;丢弃自以为是的想法&#xff0c;板块去留让…

Android中绘制的两个天气相关的View

文章目录 一、前言二、降雨的代码三、风向代码 一、前言 开发天气相关软件时候&#xff0c;做了两个自定义View&#xff0c;这里进行记录&#xff0c;由于涉及类较多&#xff0c;这里仅包含核心代码&#xff0c;需要调整后才可以运行&#xff0c;自定义View范围仅包含网格相关…

机器学习伦理:探讨隐私保护、公平性和透明度

文章目录 &#x1f340;引言&#x1f340;隐私保护&#x1f340;公平性&#x1f340;透明度&#x1f340;结论 随着机器学习技术的不断发展和应用&#xff0c;我们必须面对伦理问题&#xff0c;以确保这些智能系统的发展和使用是符合道德和法律规范的。本文将就机器学习伦理的关…

Revit二次开发 插件加密、打包、发布、授权全套教程

目录 代码加密及授权 添加授权工具引用 添加授权验证代码段 使用VMProtect进行代码保护 代码加密标记 代码加密 发布产品 软件打包 软件发布 相关文件的获取地址 本教程基于mxbim.com所提供的服务。 Revit二次开发 插件加密、打包、发布、授权全套教程 本网站(www.…

实锤研究,ChatGPT能力掉线!

早在一个多月前&#xff0c;ChatGPT性能下降的传闻便开始在网上流行&#xff0c;不少订阅了Plus版的用户纷纷表示&#xff0c;感觉ChatGPT在经历了几轮更新后开始降智&#xff0c;甚至有时反应速度也会出现问题。而如今&#xff0c;这一传闻终于得到了证实。 就在本周&#xf…

如何学好Java并调整学习过程中的心态:学习之路的秘诀

文章目录 第一步&#xff1a;建立坚实的基础实例分析&#xff1a;选择合适的学习路径 第二步&#xff1a;选择合适的学习资源实例分析&#xff1a;参与编程社区 第三步&#xff1a;动手实践实例分析&#xff1a;开发个人项目 调整学习过程中的心态1. 不怕失败2. 持续学习3. 寻求…

ORA-38760: This database instance failed to turn on flashback database

早晨接一个任务&#xff0c;使用rman备份在虚拟化单机上恢复实例&#xff0c;恢复参数文件、控制文件和数据文件都正常&#xff0c;recover归档时报错如下&#xff1a; Starting recover at 2023-07-28 10:25:01 using channel ORA_DISK_1 starting media recovery media reco…

实时云渲染技术:VR虚拟现实应用的关键节点

近年来&#xff0c;虚拟现实&#xff08;Virtual Reality, VR&#xff09;技术在市场上的应用越来越广泛&#xff0c;虚拟现实已成为一个热门的科技话题。相关数据显示&#xff0c;2019年至2021年&#xff0c;我国虚拟现实市场规模不断扩大&#xff0c;从2019年的282.8亿元增长…

攻防世界-Reverse-simple-unpack

题目描述&#xff1a;菜鸡拿到了一个被加壳的二进制文件 1. 思路分析 提示很清楚了&#xff0c;加壳的二进制文件&#xff0c;正好对这一块知识点是残缺的&#xff0c;先了解下加壳到底是什么 通过这段描述&#xff0c;其实加壳的目的是使得逆向起来更难了&#xff0c;因此这里…

基于SSM实现个人随笔分享平台:创作心灵,分享自我

项目简介 本文将对项目的功能及部分细节的实现进行介绍。个人随笔分享平台基于 SpringBoot SpringMVC MyBatis 实现。实现了用户的注册与登录、随笔主页、文章查询、个人随笔展示、个人随笔查询、写随笔、草稿箱、随笔修改、随笔删除、访问量及阅读量统计等功能。该项目登录模…

十六章:可靠性确实重要:一种端到端的弱监督语义分割方法

0.摘要 弱监督语义分割是一项具有挑战性的任务&#xff0c;因为它只利用图像级别的信息作为训练的监督&#xff0c;但在测试时需要产生像素级别的预测。为了应对这样一个具有挑战性的任务&#xff0c;最近最先进的方法提出了采用两步解决方案&#xff0c;即&#xff1a;1&#…

自动上传git

自动上传git 执行脚本 保存为.bat文件 echo off title bat 交互执行git命令 D: cd D:/git/test git add . git commit -m %date:~0,4%年%date:~5,2%月%date:~8,2%日 git push教程如下 1、搜索任务计划程序&#xff08;最好管理员身份运行&#xff0c;普通用户可能无权限&am…

下载JMeter的历史版本——个人推荐5.2.1版本

官网地址&#xff1a;https://archive.apache.org/dist/jmeter/binaries/

【Git|项目管理】Git的常用命令以及使用场景

文章目录 1.前言2.工作区,暂存区,版本库简介3.Git的常用命令4.版本回退5.撤销修改6.删除文件7.总结 1.前言 在学习Git命令之前,需要先了解工作区,暂存区和版本库这三个概念 2.工作区,暂存区,版本库简介 在使用Git进行版本控制时&#xff0c;有三个重要的概念&#xff1a;工作…

机器学习——异常检测

异常点检测(Outlier detection)&#xff0c;⼜称为离群点检测&#xff0c;是找出与预期对象的⾏为差异较⼤的对象的⼀个检测过程。这些被检测出的对象被称为异常点或者离群点。异常点&#xff08;outlier&#xff09;是⼀个数据对象&#xff0c;它明显不同于其他的数据对象。异…