CGAN笔记总结第二弹~

news2024/11/23 19:55:05

CGAN原理与源码分析

  • 一、复习GAN
    • 1.1损失函数
    • 1.2判别器源码
    • 1.3 生成器源码
  • 二、什么是CGAN?
    • 2.1 CGAN原理图
    • 2.2条件GAN的损失函数
    • 2.3 生成器源码
    • 2.4 判别器源码
    • 2.5 训练过程
      • 1)这里的训练顺序
      • 2)为什么先训练判别器后训练生成器呢?
    • 2.6 训练过程运行结果
    • 2.7测试结果
      • 1)测试代码

一、复习GAN

生成式对抗网络(Generative Adversarial Networks)是让两个神经网络进行博弈进行学习。基础结构包含生成器和判别器。生成器的目标是生成与真实图片相似的图片,以假乱真,尽可能地让判别器判断生成的图片是真实的。判别器的目标是能够区分真实图片和生成图片。生成器和判别器通过巧妙地设计损失函数,而结合在一起,在相互对抗中不断调整各自的参数,使得判别器难以判断生成器生成的图片是否真实,从而达到欺骗人眼的效果。
在    插入图片描述

1.1损失函数

在这里插入图片描述

在这里插入图片描述

1.2判别器源码

class Discriminator(nn.Module):
	def __init__(self):
		super().__init__()
		self.model = nn.Sequential(
						nn.Linear(784,1024),
						nn.LeakyReLU(0.2),
						nn.Dropout(0.3),
						nn.Linear(1024,512),
						nn.LeakyReLU(0.2),
						nn.Dropout(0.3),
						nn.Linear(512,256),
						nn.LeakyReLU(0.2),
						nn.Dropout(0.3),
						nn.Linear(256,1),
						nn.Sigmoid()
 		)
 	def forward(self, x):
 		return self.model(x)

在这里插入图片描述

1.3 生成器源码

class Generator(nn.Module):
	def __init__(self):
		super().__init__()
		self.model = nn.Sequential(
						nn.Linear(100,256),
						nn.LeakyReLU(0.2),
						
						nn.Linear(256,512),
						nn.LeakyReLU(0.2),
						
						nn.Linear(512,1024),
						nn.LeakyReLU(0.2),
					
						nn.Linear(1024,784),
						nn.Tahn()
 		)
 	def forward(self, x):
 		return self.model(x)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、什么是CGAN?

CGAN,全称Conditional Generative Aderversarial Networks.与GAN相比,条件GAN加入了额外信息c,从而能够生成指定的手写数字。

2.1 CGAN原理图

在这里插入图片描述

2.2条件GAN的损失函数

在这里插入图片描述
nn.BCELoss()是一个PyTorch中的损失函数,它被用于二分类问题。BCE代表二元交叉熵(Binary Cross Entropy)
这里用到的是二元交叉熵损失函数
D(x)代表的是判别器判别图片是真的概率;

2.3 生成器源码

class Generator(nn.Module):
    def __init__(self, num_channel=1, nz=100, nc=10, ngf=64):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 110 x 1 x 1
            nn.ConvTranspose2d(nz + nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 特征维度 (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 特征维度 (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 特征维度 (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 特征维度 (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, num_channel, 4, 2, 1, bias=False),
            nn.Tanh()
            # 特征维度. (num_channel) x 64 x 64
        )
        self.apply(weights_init)

    def forward(self, input_z, onehot_label):
        input_ = torch.cat((input_z, onehot_label), dim=1)
        n, c = input_.size()
        input_ = input_.view(n, c, 1, 1)
        return self.main(input_)

在生成器,
随机向量z是100维的,
额外信息c是10维的,(因为手写数字包含0-9,一共10类)
在这里,采用直接拼接的方式,最终形成了110维的输入

2.4 判别器源码

class Discriminator(nn.Module):
    def __init__(self, num_channel=1, nc=10, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 (num_c3
            # channel+nc) x 64 x 64  1*64*64的图像和10维的类别   10维类别先转换成10*64*64    然后合并就是11*64*64
            # 输入通道  输出通道   卷积核的大小   步长  填充
            #原始输入张量:b 11 64  64
            nn.Conv2d(num_channel + nc, ndf, 4, 2, 1, bias=False),   #b  64  32  32
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),  #b   64*2   16  16
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),    #b   64*4   8    8
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),    #b   64*8    4    4
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),        #b   1    1    1      其实就是一个数值,区间在正无穷到负无穷之间
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def forward(self, images, onehot_label):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        h, w = images.shape[2:]
        n, nc = onehot_label.shape[:2]
        label = onehot_label.view(n, nc, 1, 1) * torch.ones([n, nc, h, w]).to(device)
        input_ = torch.cat([images, label], 1)
        return self.main(input_)

在判别器中,输入的数据有
图片x,(可能是来自真实数据集的样本,也可能是来自生成器生成的虚假样本) 维度是1 * H * W
额外信息c,维度是10维,变换到10 * 1 * 1,将后两维进行复制 变换为10 * H * W的张量;
最终拼接在一起,构成11 * H * W的输入。

2.5 训练过程


MODEL_G_PATH = "./"
LOG_G_PATH = "Log_G.txt"
LOG_D_PATH = "Log_D.txt"
IMAGE_SIZE = 64
BATCH_SIZE = 128
WORKER = 1
LR = 0.0002
NZ = 100
NUM_CLASS = 10
EPOCH = 50

data_loader = loadMNIST(img_size=IMAGE_SIZE, batch_size=BATCH_SIZE)  #原始图片宽高是28*28的,给改变成64*64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(0.5, 0.999))

g_writer = LossWriter(save_path=LOG_G_PATH)
d_writer = LossWriter(save_path=LOG_D_PATH)

fix_noise = torch.randn(BATCH_SIZE, NZ, device=device)
fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
fix_input_c = onehot(fix_input_c, NUM_CLASS)

img_list = []
G_losses = []
D_losses = []
iters = 0

print("开始训练>>>")
for epoch in range(EPOCH):

    print("正在保存网络并评估...")
    save_network(MODEL_G_PATH, netG, epoch)
    with torch.no_grad():
        fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()
        images = recover_image(fake_imgs)
        full_image = np.full((5 * 64, 5 * 64, 3), 0, dtype="uint8")
        for i in range(25):
            row = i // 5
            col = i % 5
            full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]
            # !!!!!!!!!!!!!!
            #每一轮次结束后,这里只展示了一批图片的前25张。
        plt.imshow(full_image)
        #plt.show()
        plt.imsave("{}.png".format(epoch), full_image)

    for data in data_loader:
       
        netD.zero_grad()
        real_imgs, input_c = data   #这里的input_c其实就是数据集每一批中的每个图片对应的标签
        input_c = input_c.to(device)
        input_c = onehot(input_c, NUM_CLASS).to(device)

        # 1.1 来自数据集的样本
        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        #上面的torch.full是生成一维的 b_size这么多的,填充值为1.的张量
        # real_label = 1.
        # fake_label = 0.

        # 使用判别器对真实数据集样本做判断
        #!!!!!!!!!!!!!
        #output应该是判别器判别一批真图片真实的概率
        output = netD(real_imgs, input_c).view(-1)   
        errD_real = criterion(output, label)
        #!!!!!!
        #errD_real是判别器识别真图片的误差,为了训练判别器判别真图片为真
        errD_real.backward()
        D_x = output.mean().item()   
        #!!!!!!!
        #D_x就是判别器判别一批真图片为真的概率的平均值

        
        # 1.2 生成随机向量   这一步想要训练判别器是否能够识别出是虚假图片
        noise = torch.randn(b_size, NZ, device=device)
        # 生成随机标签
        input_c = (torch.rand(b_size, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
        input_c = onehot(input_c, NUM_CLASS)

        # 来自生成器生成的样本
        fake = netG(noise, input_c)
        label.fill_(fake_label)

        # real_label = 1.
        # fake_label = 0.
        # 使用判别器对生成器生成样本做判断
        #!!!!!!!!!!!
        #output应该是判别器判别一批假图片真实的概率
        output = netD(fake.detach(), input_c).view(-1)  
        errD_fake = criterion(output, label)
        # 对判别器进行梯度回传
        errD_fake.backward()
        #!!!!!!
        #errD_fake是判别器识别假图片的误差,为了训练判别器判别假图片为假
        D_G_z1 = output.mean().item()
        #!!!!!!!!!!!!
        #D_G_z1就是判别器判别一批假图片为真的概率的平均值
        errD = errD_real + errD_fake
        #!!!!!!
        #errD是判别器识别真实图片和假图片的误差和
        # 更新判别器
        optimizerD.step()


        
       
        netG.zero_grad()
        # 对于生成器训练,令生成器生成的样本为真,
        label.fill_(real_label)

        # real_label = 1.
        # fake_label = 0.
        #!!!!!!!!!!!
        #output应该是判别器判别一批假图片真实的概率
        output = netD(fake, input_c).view(-1)
        # 对生成器计算损失
        errG = criterion(output, label)
         #!!!!!!
        #errG是判别器识别假图片的误差,但是是为了训练生成器生成假图片,以假乱真
        # 因为这里判别器的角度label真实应该是0,但是站在生成器的角度,label真实应该是1,即生成器希望生成的虚假图片让判别器识别的时候,会误以为1才比较好,即误以为是真实的图片
        # 所以生成器交叉熵也是越小越好
        # 对生成器进行梯度回传
        errG.backward()
        D_G_z2 = output.mean().item()
        #!!!!!!!!!!!!
        #D_G_z2就是判别器判别一批假图片为真的概率的平均值
        # 更新生成器
        optimizerG.step()

        # 输出损失状态
        if iters % 5 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, EPOCH, iters % len(data_loader), len(data_loader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            d_writer.add(loss=errD.item(), i=iters)
            g_writer.add(loss=errG.item(), i=iters)

        # 保存损失记录
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        iters += 1

1)这里的训练顺序

这里训练的顺序是
先拿真实图片训练判别器,
再拿假图片训练判别器,
最后,拿假图片让判别器判断,来训练生成器。

2)为什么先训练判别器后训练生成器呢?

试想,假如先训练生成器,但是刚开始判别器还没有判别能力,所以达不到训练生成器,帮助生成器能越来越生成逼真的假图片。
所以,需要先训练判别器,让判别器先具有初步的判别能力,才能训练生成器,帮助生成器能够生成逼真的假图片。

2.6 训练过程运行结果

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
#errD是判别器识别真实图片和假图片的误差和,是为了训练判别器能够判别真假图片
#errG是判别器识别假图片的误差,但是是为了训练生成器生成假图片,以假乱真
#D_x就是判别器判别一批真图片为真的概率的平均值,训练判别器识别真图片
#D_G_z1就是判别器判别一批假图片为真的概率的平均值,训练判别器识别假图片
#D_G_z2就是判别器判别一批假图片为真的概率的平均值,训练生成器生成逼真的假图片

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

2.7测试结果

在这里插入图片描述

1)测试代码


NZ = 100
NUM_CLASS = 10
BATCH_SIZE = 10
DEVICE = "cpu"

netG = Generator()
netG = restore_network("./", "49", netG)
fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
fix_input_c = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
device = "cuda" if torch.cuda.is_available() else "cpu"
fix_input_c = onehot(fix_input_c, NUM_CLASS)
fix_input_c = fix_input_c.to(device)
fix_noise = fix_noise.to(device)
netG = netG.to(device)
#fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()



#fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
full_image = np.full((10 * 64, 10 * 64, 3), 0, dtype="uint8")
for num in range(10):
    input_c = torch.tensor(np.ones(10, dtype="int64") * num)
    input_c = onehot(input_c, NUM_CLASS)
    fix_noise = fix_noise.to(device)
    input_c = input_c.to(device)
    fake_imgs = netG(fix_noise, input_c).detach().cpu()
    images = recover_image(fake_imgs)
    for i in range(10):
        row = num
        col = i % 10
        full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]

plt.imshow(full_image)
plt.show()
plt.imsave("hah.png", full_image)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1304946.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Dijkstra求最短路 II(堆优化Dijkstra算法)

给定一个 n 个点 m 条边的有向图,图中可能存在重边和自环,所有边权均为非负值。 请你求出 11 号点到 n 号点的最短距离,如果无法从 11 号点走到 n 号点,则输出 −1−1。 输入格式 第一行包含整数 n 和 m。 接下来 m 行每行包含…

Vue指令之v-else与v-else-if

在上一篇博客中介绍了v-if,而在各式各样的程序语句中 if 和 else 通常是伴生的,在Vue中也不例外,Vue同样提供了v-else和v-else-if指令,其功能就是补充v-if的逻辑判断。 例如,当我们要根据一个分数输出对应的等级&…

路由器的转换原理--ENSP实验

目录 一、路由器的工作原理 二、路由表的形成 1、直连路由 2、非直连路由 2.1静态路由 2.2动态路由 三、静态路由和默认路由 1、静态路由 1.1静态路由的缺点 1.2路由的配置--结合ensp实验 2、默认路由--特殊的静态路由 2.1概念 2.2格式 2.3默认路由的配置--ens…

3GPP标准查看、下载和几个UE相关系列标准

由于一直做终端侧协议。最近以UE为核心重新下载了一系列文档。 总结并举例一下分类标准。 如何查看3GPP标准列表 实际上在3GPP网站如下链接:Specifications by Series,每个系列以及分类都说的很清楚。 几个系列分类举例 和终端协议层工作比较关系密切…

【Java】构建表达式二叉树和表达式二叉树求值

问题背景 1. 实现一个简单的计算器。通过键盘输入一个包含圆括号、加减乘除等符号组成的算术表达式字符串,输出该算术表达式的值。要求: (1)系统至少能实现加、减、乘、除等运算; (2)利用二叉…

LeetCode(55)环形链表【链表】【简单】

目录 1.题目2.答案3.提交结果截图 链接: 环形链表 1.题目 给你一个链表的头节点 head ,判断链表中是否有环。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评…

啊?150水冷踏板卷到7千多,巧格的钱购买150了?

力帆的车一般我是不太想写的,但是顶不住它这个价格,实在是....,标准版售价干到了7980元,和巧格一个价了,比福喜还便宜点,属实是离离原上谱,不过这个车不太影响的了豪爵大哥的UHR,两台…

Cypress安装与使用教程(2)—— 软测大玩家

😏作者简介:博主是一位测试管理者,同时也是一名对外企业兼职讲师。 📡主页地址:【Austin_zhai】 🙆目的与景愿:旨在于能帮助更多的测试行业人员提升软硬技能,分享行业相关最新信息。…

pip list 报错 ImportError: cannot import name ‘main‘ from ‘pip._int

文章目录 报错信息问题原因解决方案 关注公众号:『AI学习星球』 算法学习、4对1辅导、论文辅导或核心期刊可以通过公众号或CSDN滴滴我 报错信息 最近在操作服务器的时候,发现pip list这个命令不好使了,报错如下 外链图片转存失败,源站可能…

[Linux] yum安装分布式LNMP架构

1. 在一台主机安装nginx(192.168.136.120) 1.1 搭建nginx相关的yum源 cd /yum.repos.d mkdir bak mv *.repo bak vim /etc/yum.repos.d/nginx.repo [nginx-stable] namenginx stable repo baseurlhttp://nginx.org/packages/centos/7/$basearch/ gpgche…

Self-Distillation from the Last Mini-Batch for Consistency Regularization中文版

Self-Distillation from the Last Mini-Batch for Consistency Regularization 从上一个小批量自发蒸馏,实现一致性正则化 摘要 知识蒸馏(Knowledge distillation,KD)展示了强大的潜力,作为一种强有力的正则化策略&a…

CETN01 - How to Use Cloud Classroom

文章目录 I. Introduction to Cloud ClassroomII. How to Use Cloud Classroom1. Publish Resources2. Conduct Activities3. Class Teaching Reports4. View Experience Values5. Performance in Cloud Classroom I. 云课堂介绍II. 如何使用云课堂1. 发布资源2. 进行活动3. 班…

C++STL之List的实现

首先我们要实现List的STL,我们首先要学会双向带头链表的数据结构。那么第一步肯定是要构建我们的节点的数据结构。 首先要有数据域,前后指针域即可。 再通过模板类进行模板化。 然后再写List的构造函数,这个地方用T&,通过引用就可以减少一次形参拷…

Android 蓝牙BluetoothAdapter 相关(一)

Android 蓝牙相关 本文主要讲述android 蓝牙的简单使用. 1: 是否支持蓝牙 /*** 是否支持蓝牙** return*/ private boolean isSupportBluetooth() {BluetoothAdapter bluetoothAdapter BluetoothAdapter.getDefaultAdapter();return bluetoothAdapter ! null; }2: 开启蓝牙 …

强大的音频编辑器 Metadatics直装 for mac

Metadatics是一款Mac上的音频元数据编辑器,功能强大且高级。它支持批量编辑最常见的音频文件类型,包括MP3、M4A、AIFF、WAV、FLAC、APE、OGG、WMA等。它可以从在线资源中查找元数据,根据元数据重命名文件,或使用众多内置函数之一来…

Mysql、Oracle安全项检查表及操作脚本

软件开发全资料获取:点我获取 Mysql检查表 Oracle检查表

【Canvas】记录一次从0到1绘制风场空间分布图的过程

前言 📫 大家好,我是南木元元,热衷分享有趣实用的文章,希望大家多多支持,一起进步! 🍅 个人主页:南木元元 目录 背景 前置知识 风场数据 绘制风场 准备工作 生成二维网格 获取…

ppt转换成pdf文件

最近用到了,记一下; ppt转pdf分为两种情况: 小于2007版本的 .ppt格式(2003) 与大于2007版本的 .pptx格式(2007) .ppt格式为 二进制文件 .pptx格式为xml格式,在java中有不同的jar包需要使用 引入…

MacOS 12 开放指定端口 指定ip访问

MacOS 12 开放指定端口 指定ip访问 在 macOS 上开放一个端口,并指定只能特定的 IP 访问,你可以使用 macOS 内置的 pfctl(Packet Filter)工具来实现。 以下是一些基本的步骤: 1、 编辑 pf 配置文件: 打开 /…

Dockerfile创建镜像--LNMP+wordpress

实验准备: nginx:172.111.0.10 docker-nginx mysql:172.111.0.20 docker-mysql php:172.111.0.30 docker-php 自定义网段:172.111.0.0/16mkdir nginx mysql php mv nginx-1.22.0.tar.gz wordpress-6.4.2-zh_CN.ta…