CGAN原理讲解与源码

news2024/11/26 11:43:04

1.CGAN原理

生成器,输入的是c和z,z是随机噪声,c是条件,对应MNIST数据集,要求规定生成数字是几。
输出是生成的虚假图片。
在这里插入图片描述

判别器的输入是
1.生成器输出的虚假图片x;
2.对应图片的标签c

在这里插入图片描述
来自真实数据集,且标签是对的,就是1
如果是生成器生成的虚假照片就直接是1,都不需要看是否与标签对应

上面第二张图的意思就是,当图片是来自真实数据集,再来看是否与标签对应

2.CGAN损失函数

在这里插入图片描述
上面这个值,生成器越小越好,即判别器认为真实图片是真实图片的概率越低越好,认为虚假图片是真实图片的概率越高越好
判别器越大越好,即判别器认为真实图片是真实图片的概率越大越好,认为虚假图片是真实图片的概率越小越好

criterion(output, label)

在判别器中,
1)output是预测来自真实数据集的图片和标签是否是真实且符合标签的概率,label是1
2)output是预测虚假图片是否是虚假图片的概率,label是0
在生成器中,
output是判别器预测虚假图片是否是真实图片的概率,label是1
以上三种,都是交叉熵越小越好

3.生成器和判别器的源码

class Generator(nn.Module):
    def __init__(self, num_channel=1, nz=100, nc=10, ngf=64):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 110 x 1 x 1
            nn.ConvTranspose2d(nz + nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 特征维度 (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 特征维度 (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 特征维度 (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 特征维度 (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, num_channel, 4, 2, 1, bias=False),
            nn.Tanh()
            # 特征维度. (num_channel) x 64 x 64
        )
        self.apply(weights_init)

    def forward(self, input_z, onehot_label):
        input_ = torch.cat((input_z, onehot_label), dim=1)
        n, c = input_.size()
        input_ = input_.view(n, c, 1, 1)
        return self.main(input_)


class Discriminator(nn.Module):
    def __init__(self, num_channel=1, nc=10, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 (num_c3
            # channel+nc) x 64 x 64  1*64*64的图像和10维的类别   10维类别先转换成10*64*64    然后合并就是11*64*64
            # 输入通道  输出通道   卷积核的大小   步长  填充
            #原始输入张量:b 11 64  64
            nn.Conv2d(num_channel + nc, ndf, 4, 2, 1, bias=False),   #b  64  32  32
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),  #b   64*2   16  16
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),    #b   64*4   8    8
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),    #b   64*8    4    4
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),        #b   1    1    1      其实就是一个数值,区间在正无穷到负无穷之间
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def forward(self, images, onehot_label):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        h, w = images.shape[2:]
        n, nc = onehot_label.shape[:2]
        label = onehot_label.view(n, nc, 1, 1) * torch.ones([n, nc, h, w]).to(device)
        input_ = torch.cat([images, label], 1)
        return self.main(input_)

4.训练过程



MODEL_G_PATH = "./"
LOG_G_PATH = "Log_G.txt"
LOG_D_PATH = "Log_D.txt"
IMAGE_SIZE = 64
BATCH_SIZE = 128
WORKER = 1
LR = 0.0002
NZ = 100
NUM_CLASS = 10
EPOCH = 10

data_loader = loadMNIST(img_size=IMAGE_SIZE, batch_size=BATCH_SIZE)  #原始图片宽高是28*28的,给改变成64*64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(0.5, 0.999))

g_writer = LossWriter(save_path=LOG_G_PATH)
d_writer = LossWriter(save_path=LOG_D_PATH)

fix_noise = torch.randn(BATCH_SIZE, NZ, device=device)
fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
fix_input_c = onehot(fix_input_c, NUM_CLASS)

img_list = []
G_losses = []
D_losses = []
iters = 0

print("开始训练>>>")
for epoch in range(EPOCH):

    print("正在保存网络并评估...")
    save_network(MODEL_G_PATH, netG, epoch)
    with torch.no_grad():
        fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()
        images = recover_image(fake_imgs)
        full_image = np.full((5 * 64, 5 * 64, 3), 0, dtype="uint8")
        for i in range(25):
            row = i // 5
            col = i % 5
            full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]
        plt.imshow(full_image)
        #plt.show()
        plt.imsave("{}.png".format(epoch), full_image)

    for data in data_loader:
        #################################################
        
        #判别器交叉熵越小越好
        # 1. 更新判别器D: 最大化 log(D(x)) + log(1 - D(G(z)))
        # 等同于最小化 - log(D(x)) - log(1 - D(G(z)))
        #################################################
        netD.zero_grad()
        real_imgs, input_c = data   #这里的input_c其实就是数据集每一批中的每个图片对应的标签
        input_c = input_c.to(device)
        input_c = onehot(input_c, NUM_CLASS).to(device)

        # 1.1 来自数据集的样本
        #这里这一步就是想训练判别器,能够识别出是否真实图片,以及图片与对应的标签是否对应
        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        #上面的torch.full是生成一维的 b_size这么多的,填充值为1.的张量
        # real_label = 1.
        # fake_label = 0.

        # 使用鉴别器对数据集样本做判断
        output = netD(real_imgs, input_c).view(-1)   #view() 方法被用来将模型输出的张量进行扁平化操作,即将张量中的所有元素都展开成一个一维向量
        # 计算交叉熵损失 -log(D(x))
        errD_real = criterion(output, label)
        # 对判别器进行梯度回传
        errD_real.backward()
        D_x = output.mean().item()    #对同一批预测结果的交叉熵取平均值


        #
        # 1.2 生成随机向量   这一步想要训练判别器是否能够识别出是虚假图片
        noise = torch.randn(b_size, NZ, device=device)
        # 生成随机标签
        input_c = (torch.rand(b_size, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
        input_c = onehot(input_c, NUM_CLASS)

        #fix_noise = torch.randn(BATCH_SIZE, NZ, device=device)
        #fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
        #fix_input_c = onehot(fix_input_c, NUM_CLASS)


        # 来自生成器生成的样本
        fake = netG(noise, input_c)
        label.fill_(fake_label)

        # real_label = 1.
        # fake_label = 0.
        # 使用鉴别器对生成器生成样本做判断
        output = netD(fake.detach(), input_c).view(-1)   #view() 方法被用来将模型输出的张量进行扁平化操作,即将张量中的所有元素都展开成一个一维向量
        # 计算交叉熵损失 -log(1 - D(G(z)))
        errD_fake = criterion(output, label)
        # 对判别器进行梯度回传
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        # 对判别器计算总梯度,-log(D(x))-log(1 - D(G(z)))
        errD = errD_real + errD_fake
        # 更新判别器
        optimizerD.step()


        
        #################################################
        # 2. 更新生成器G: 最小化 log(D(x)) + log(1 - D(G(z))),
        # 等同于最小化log(1 - D(G(z))),即最小化-log(D(G(z)))
        # 也就等同于最小化-(log(D(G(z)))*1+log(1-D(G(z)))*0)
        # 令生成器样本标签值为1,上式就满足了交叉熵的定义
        #################################################
        netG.zero_grad()
        # 对于生成器训练,令生成器生成的样本为真,
        label.fill_(real_label)

        # real_label = 1.
        # fake_label = 0.
        output = netD(fake, input_c).view(-1)
        # 对生成器计算损失
        errG = criterion(output, label)

        # 因为这里判别器的角度label真实应该是0,但是站在生成器的角度,label真实应该是1,即生成器希望生成的虚假图片让判别器识别的时候,会误以为1才比较好,即误以为是真实的图片
        # 所以生成器交叉熵也是越小越好
        # 对生成器进行梯度回传
        errG.backward()
        D_G_z2 = output.mean().item()
        # 更新生成器
        optimizerG.step()

        # 输出损失状态
        if iters % 5 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, EPOCH, iters % len(data_loader), len(data_loader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            d_writer.add(loss=errD.item(), i=iters)
            g_writer.add(loss=errG.item(), i=iters)

        # 保存损失记录
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        iters += 1

5.关于交叉熵

熵代表确定性,熵越小越好,说明确定性越好
在这里,因为参照的是真实标签,它的熵是0
而交叉熵-熵=相对熵
故相对熵在预测情况相对真实情况的时候,相对熵=交叉熵,相对熵越小,说明预测情况越接近真实情况;
同理,交叉熵越小,说明预测情况越接近真实情况。

在二分类0,1任务中,经过卷积、正则化、激活函数ReLU等操作之后,假如生成了一个(B,1,1,1)的张量,每个值在(无穷小,无穷大)之间,经过sigmoid函数,会变成一个(B,1,1,1)的张量,数值h在(0,1)之间,如果这个h>0.5说明模型预测的是1,如果h<0.5说明模型预测的是0,但是这是模型预测的标签值y*,而还有个真实标签值y。假如现在h=0.6,那么说明模型预测的标签y*是1,真实标签却是0,

交叉熵= -y(lgh) -(1-y)(lg(1-h))
即当y=1时,交叉熵是-lgh 这个情况下,h越大越好
当y=0时,交叉熵是-(lg(1-h)) 这个情况下,h越小越好

6.训练过程运行结果

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

7.测试结果

在这里插入图片描述

测试代码


NZ = 100
NUM_CLASS = 10
BATCH_SIZE = 10
DEVICE = "cpu"

# fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(DEVICE)

netG = Generator()
netG = restore_network("./", "49", netG)
fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
fix_input_c = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
device = "cuda" if torch.cuda.is_available() else "cpu"
fix_input_c = onehot(fix_input_c, NUM_CLASS)
fix_input_c = fix_input_c.to(device)
fix_noise = fix_noise.to(device)
netG = netG.to(device)
#fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()

# images = recover_image(fake_imgs)
# full_image = np.full((1 * 64, 10 * 64, 3), 0, dtype="uint8")
# for i in range(10):
#     row = i // 10
#     col = i % 10
#     full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]

#fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
full_image = np.full((10 * 64, 10 * 64, 3), 0, dtype="uint8")
for num in range(10):
    input_c = torch.tensor(np.ones(10, dtype="int64") * num)
    input_c = onehot(input_c, NUM_CLASS)
    fix_noise = fix_noise.to(device)
    input_c = input_c.to(device)
    fake_imgs = netG(fix_noise, input_c).detach().cpu()
    images = recover_image(fake_imgs)
    for i in range(10):
        row = num
        col = i % 10
        full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]

plt.imshow(full_image)
plt.show()
plt.imsave("hah.png", full_image)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1260011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大模型下交互式数据挖掘的探索与发现

在这个数据驱动的时代&#xff0c;数据挖掘已成为解锁信息宝库的关键。过去&#xff0c;我们依赖传统的拖拉拽方式来建模&#xff0c;这种方式在早期的数据探索中起到了作用&#xff0c;但随着数据量的激增和需求的多样化&#xff0c;它的局限性逐渐显露。 >>>> 首…

selenium已知一个元素定位同级别的另一个元素

1.需求与实际情况 看下图来举例 &#xff08;1&#xff09;需求 想点击test22&#xff08;即序号-第9行&#xff09;这一行中右边的“复制”这一按钮 &#xff08;2&#xff09;实际情况 只能通过id或者class定位到文件名这一列的元素&#xff0c;而操作这一列的元素是不…

知识变现的未来:解析知识付费系统的核心

随着数字时代的发展&#xff0c;知识付费系统作为一种新兴的学习和知识分享模式&#xff0c;正逐渐引领着知识变现的未来。本文将深入解析知识付费系统的核心技术&#xff0c;揭示其在知识经济时代的重要性和潜力。 1. 知识付费系统的基本架构 知识付费系统的核心在于其灵活…

为什么要用 Redis 而不用 map/guava 做缓存? Redis为什么这么快 Redis有哪些数据类型 Redis的应用场景

文章目录 为什么要用 Redis 而不用 map/guava 做缓存?Redis为什么这么快Redis有哪些数据类型Redis的应用场景总结一计数器缓存会话缓存全页缓存&#xff08;FPC&#xff09;查找表消息队列(发布/订阅功能)分布式锁实现 总结二 简单的聊聊Redis常见的一些疑问点&#xff1a;具体…

预算削减与经济动荡:2024 年明智且经济地创新

如何在经济衰退周期中保持创新&#xff1f;这篇创新研究提供了实用建议。在经济下行压力下领导者往往会试图降低成本和维持生存。然而&#xff0c;这种二元对立的压力往往会导致领导者做出不够理想的决策&#xff0c;更多地关注生存而不是未来投资。本文提供了一系列实用的建议…

蓝桥杯-平方和(599)

【题目】平方和 【通过测试】代码 import java.util.Scanner; import java.util.ArrayList; import java.util.List; // 1:无需package // 2: 类名必须Main, 不可修改public class Main {public static void main(String[] args) {Scanner scan = new Scanner(System.in);//在…

MATLAB中FFT频谱分析使用详解

文章目录 语法说明语法一&#xff1a;Y fft(X)fft(X)返回X长度的傅里叶变换 语法二&#xff1a;Y fft(X,N)如果 X的长度小于 N&#xff0c;则为 X补上尾零以达到长度 N(FFT插值)双边谱转换为单边谱 如果 X 的长度大于 N&#xff0c;则对 X 进行截断以达到长度 N。 语法三&…

根据密码构成规则检测密码字符串

从键盘输入密码字符串&#xff0c;程序根据给定密码构成规则检测并给出对应提示。 (笔记模板由python脚本于2023年11月27日 19:27:47创建&#xff0c;本篇笔记适合熟悉Python字符串str对象的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xff1a;https://www.python.…

ViLT 论文精读【论文精读】

ViLT 论文精读【论文精读】_哔哩哔哩_bilibili 目录 ViLT 论文精读【论文精读】_哔哩哔哩_bilibili 1 地位 2 ViLT做了什么能让它成为这种里程碑式的工作&#xff1f; 3 ViLT到底把模型简化到了什么程度&#xff1f;到底能加速到什么程度&#xff1f; 2.1 过去的方法是怎…

C++之算术生成算法

C之算术生成算法 accumulate #include<iostream> using namespace std; #include<vector> #include<numeric>void test() {vector<int> v;for (int i 0; i < 10; i){v.push_back(i);}int total accumulate(v.begin(), v.end(),0);cout << t…

论文公式工具

论文公式工具 https://www.latexlive.com/home## 论文图片识别公式网页工具&#xff0c;免费的方便但是有限制次数&#xff0c;一天只能用三次公式图片识别。 先注册登录 我们到论文中截取一张图片 在识别得到的一串码中&#xff0c;删掉前面没用的 输出为这个格式&#x…

从零构建属于自己的GPT系列1:预处理模块(逐行代码解读)、文本tokenizer化

1 训练数据 在本任务的训练数据中&#xff0c;我选择了金庸的15本小说&#xff0c;全部都是txt文件 数据打开后的样子 数据预处理需要做的事情就是使用huggingface的transformers包的tokenizer模块&#xff0c;将文本转化为token 最后生成的文件就是train_novel.pkl文件&a…

【MATLAB】LMD分解+FFT+HHT组合算法

有意向获取代码&#xff0c;请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 LMDFFTHHT组合算法是一种基于局部均值分解&#xff08;LMD&#xff09;、快速傅里叶变换&#xff08;FFT&#xff09;和希尔伯特-黄变换&#xff08;HHT&#xff09;的组合算法。 LMD是…

什么是数据增强,为什么会让模型更健壮?

在做一些图像分类训练任务时&#xff0c;我们经常会遇到一个很尴尬的情况&#xff0c;那就是&#xff1a; 明明训练数据集中有很多可爱猫咪的照片&#xff0c;但是当我们给训练好的模型输入一张戴着头盔的猫咪进行测试时&#xff0c;模型就不认识了&#xff0c;或者说识别精度…

87基于matlab的双卡尔曼滤波算法

基于matlab的双卡尔曼滤波算法。第一步使用了卡尔曼滤波算法&#xff0c;用电池电压来修正SOC&#xff0c;然后将修正后的SOC作为第二个卡尔曼滤波算法的输入&#xff0c;对安时积分法得到的SOC进行修正&#xff0c;最终得到双卡尔曼滤波算法SOC估计值。结合EKF算法和安时积分法…

企业联系方式真的那么难获取吗?

企业联系方式的重要性&#xff0c;相信每一个销售人员都是知道的。对于销售来说&#xff0c;获取准确、全面的企业联系方式&#xff0c;无疑是开发客户的基础与保障&#xff0c;任凭能力再高&#xff0c;说服能力多强&#xff0c;没有与客户接触的机会&#xff0c;这些都是无稽…

CAN总线星型连接器及特点

CAN总线星型连接特点 CAN总线是一种广泛应用于汽车、工业自动化、家庭等领域的现场总线技术。它具有高速度、高可靠性、灵活性等特点&#xff0c;被广泛应用于汽车电子、工业自动化、家庭自动化等领域。在CAN总线的实际应用中&#xff0c;其连接方式可以是星型或菊花型。本文将…

Pycharm在debug问题解决方案

Pycharm在debug问题解决方案 前言一、Frames are not available二、查看变量时一直显示collecting data并显示不了任何内容 前言 Pycharm在debug时总是出现一些恼人的问题&#xff0c;以下是博主在训练中遇到的问题及在网上找到的可用解决方案&#xff1a; 一、Frames are not…

自己动手写编译器:golex 和 flex 比较研究 2

上一节我们运行了 gcc 使用的词法解析器&#xff0c;使用它从.l 文件中生成对应的词法解析程序。同时我们用相同的词法规则对 golex 进行测试&#xff0c;发现 golex 同样能实现相同功能&#xff0c;当然这个过程我们也发现了 golex 代码中的不少 bug&#xff0c;本节我们继续对…

基于单片机病房呼叫程序和仿真

如果学弟学妹们在毕设方面有任何问题&#xff0c;随时可以私信我咨询哦&#xff0c;有问必答&#xff01;学长专注于单片机相关的知识&#xff0c;可以解决单片机设计、嵌入式系统、编程和硬件等方面的难题。 愿毕业生有力&#xff0c;陪迷茫着前行&#xff01; 一、系统方案 1…