GAN与DCGAN

news2024/9/26 12:12:38

GAN:生成对抗网络,首先是一个生成模型,区别与之前的辨别模型,对抗体现在生成器与辨别器之间的对抗。
生成器输入的是噪音,通过多层的MLP可以产生图片,将产生的图片和真实图片输入到辨别器,辨别器进行分辨生成的图片是否是真实的图片,如果是输出1,不是输出0。
GAN主要的优化公式:
在这里插入图片描述
1:固定G,训练D,真实数据x希望被D分为1,生成数据z希望被D分为0。根据log函数性质,如果x被错分为0的话,那么logD(x)就会变为负无穷小。如果生成数据G(z)被错分为1的话,那个log(1-1)也是负无穷小。所以要最大化D。
2:固定D,训练G,第一项没有G,跳过,第二项生成数据G(z)目的是,最理想状态是骗过分辨器D,所以他希望D(G(z))为1,则log0为负无穷小,所以要最大化G。
在这里插入图片描述
看一下算法步骤,即如何进行梯度传播。我们可以看到分别对生成器和辨别器进行更新,更新辨别器时同时将真实图片和生成图片输入到辨别器,更新生成器时,将生成的图片输入进去。
1:也就是生成器输入的是噪音,经过多层感知机即linear层后,输入和真实图片大小一样的图片。
2:将真实图片和生成图片共同输入到辨别器进行损失计算。
3:生成器和辨别器各自更新,互不影响。


如何将GAN和CNN结合起来。DCGAN应运而生。
DCGAN将GAN的生成器和辨别器替换为CNN。模型生成器结构:噪声首先经过一个线性层,然后在review为图片,再经过转置卷积进行上采样。
在这里插入图片描述
辨别器结构:因为输出的是一个概率,所以最后大小为(Batchsize,1)。生成器产生的图片输入到辨别器然后经过步长为2的卷积进行下采样,不使用池化是因为卷积可以学习如何进行下采样。最后review为2维,经过一个linear层后紧接一个sigmoid获得最终的概率。
在这里插入图片描述
且包括一些细节部分:
在这里插入图片描述
代码:参考添加链接描述

# -*- coding: utf-8 -*-
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

# 加载数据
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=0.5, std=0.5)])

train_ds = torchvision.datasets.MNIST('/home/Projects/ZQB/a/dataset',
                                      train=True,
                                      transform=transform,
                                      download=False)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)


# 定义生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(100, 256 * 7 * 7)
        self.bn1 = nn.BatchNorm1d(256 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          stride=1,
                                          padding=1
                                          )  # 得到128*7*7的图像
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1  # 64*14*14
                                          )
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 1,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1  # 1*28*28
                                          )

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.bn1(x)
        x = x.view(-1, 256, 7, 7)
        x = F.relu(self.deconv1(x))
        x = self.bn2(x)
        x = F.relu(self.deconv2(x))
        x = self.bn3(x)
        x = torch.tanh(self.deconv3(x))
        return x


# 定义判别器
# input:1,28,28
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2)  # 第一层不适用bn  64,13,13
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)  # 128,6,6
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128 * 6 * 6, 1)  # 输出一个概率值

    def forward(self, x):
        x = F.dropout2d(F.leaky_relu(self.conv1(x)))
        x = F.dropout2d(F.leaky_relu(self.conv2(x)))  # (batch, 128,6,6)
        x = self.bn(x)
        x = x.view(-1, 128 * 6 * 6)  # (batch, 128,6,6)--->  (batch, 128*6*6)
        x = torch.sigmoid(self.fc(x))
        return x


# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)

# 损失计算函数
loss_function = torch.nn.BCELoss()

# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)

test_input = torch.randn(16, 100, device=device)

# 开始训练
D_loss = []
G_loss = []
# 训练循环
for epoch in range(30):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)
    # 对全部的数据集做一次迭代
    for i, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.shape[0]  # 返回img的第一维的大小
        random_noise = torch.randn(size, 100, device=device)

        d_optim.zero_grad()  # 将上述步骤的梯度归零
        real_output = dis(img)  # 对判别器输入真实的图片,real_output是对真实图片的预测结果
        d_real_loss = loss_function(real_output,torch.ones_like(real_output, device=device))
        d_real_loss.backward()  # 求解梯度

        # 得到判别器在生成图像上的损失
        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_function(fake_output,torch.zeros_like(fake_output, device=device))
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        d_optim.step()  # 优化

        # 得到生成器的损失
        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_function(fake_output,torch.ones_like(fake_output, device=device))
        g_loss.backward()
        g_optim.step()
        torchvision.utils.save_image(gen_img,fp='/home/Projects/ZQB/a/DCGAN/DCGAN/result/result'+f"image_{epoch}.png")
    print('Epoch:', epoch)

1:最主要看一下三个损失计算:
判别器两个:真实输出,希望判别器判为1,用torch.ones_like,生成的输出即fake输出,希望判别器输出为0,用torch.zeros_like

        d_real_loss = loss_function(real_output,torch.ones_like(real_output, device=device))
        d_fake_loss = loss_function(fake_output,torch.zeros_like(fake_output, device=device))

生成器一个:我们希望生成器输出的图片骗过判别器即希望判别器输出为1,torch.ones_like

        g_loss = loss_function(fake_output,torch.ones_like(fake_output, device=device))

2:生成器和判别器采用各自的优化器和各自的反向传播。
3:训练30代后将生成的结果用grid格式保存下来看一下::
epoch1:
在这里插入图片描述
epoch15:
在这里插入图片描述
epoch30:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/495091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI绘画5大免费工具

AI现在最火爆的两个方向一个是以ChatGPT为主导的文本生成工具,还有一个就是以Midjourne为主导的文本生成图片工具。 Midjourne 现在基本是都是需要收费的,但确实Midjourne的效果是顶尖的,如果我们只是想试一下 文本生图的过程,这里…

【ArcGIS Pro二次开发】(26):数据筛选器

在使用【OpenItemDialog】打开数据时,其中一个重要的属性【Filter】,可用于筛选要打开的数据。示例代码如下: // 打开文件对话框OpenItemDialog dlg new OpenItemDialog(){Title "选择要打开的文本文件",Filter ItemFilters.Dat…

如何用ChatGPT写专业方向的科普内容?

该场景对应的关键词库(13个): 目标用户、科普内容、生活问题、医疗类型、科普文章、病情症状、通俗性、专业名词、背景资质、权威领域、执业范围、证言人、内容形式。 提问模板(3个): 第一步,…

打包工具--pyinstaller

下载库 pip install pyinstaller 打包命令 Pyinstaller -D setup.py 打包exePyinstaller -F -w run.py 不带控制台的打包Pyinstaller -F -i xx.ico setup.py 打包指定exe图标打包 ❝ -D:打包为一个文件夹,其中exe文件在文件夹内部,这样子单个…

更换外线和智能电表后家里用电频繁跳闸的检修

老家的电路老是跳闸。今天检修了老家的线路,故障就是更换了外线路后,家里烧水或者用电磁炉就频繁跳闸。其实也说不清楚,因为最近又改了智能表嘛。 到电表处观察,是插卡智能表,电表进线有个空开C63A。电表出来有个空开C…

万字长文 - Nature 综述系列 - 给生物学家的机器学习指南 4 (生物应用的挑战)...

万字长文 - Nature 综述系列 - 给生物学家的机器学习指南 1 万字长文 - Nature 综述系列 - 给生物学家的机器学习指南 2 (传统机器学习方法如何选择) 万字长文 - Nature 综述系列 - 给生物学家的机器学习指南 3 (人工神经网络) 生…

C++实践模拟(stack,queue priority_queue,仿函数)

stack和queue的实现,不同于vector和list那般复杂,如果你经历过vector和list的洗礼,那么当你看到stack和queue的大致实现时,你可能会惊叹,怎么能这么简洁。其原因有很多方面的,比如stack和queue不需要实现迭…

第11届蓝桥杯国赛真题剖析-2020年10月31日Scratch编程初中级组

[导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,后续会不定期解读蓝桥杯真题,这是Scratch蓝桥杯真题解析第129讲。 第11届蓝桥杯Scratch国赛真题,这是2020年10月31日举办的全国总决赛,由于疫情影响&am…

【图像】图像格式(3) : BMP

1. 背景 BMP可以说是图像中最简单的格式了,没有图像压缩算法,基本可以看做图像的RGB裸数据加了一些基本的metadata构成。 这也导致了bmp的文件一般都是非常的大,除了windows原生的支持之外(从1990年的windows3.0开始)…

破事精英2◎爬向未来

胡强的2033未免有些过去可怕,海星果然又是反派。 只剩“脑子”的胡强 400百斤只剩“嘴”的庞小白 将自己身体分成一个个“方块”的苏克杰 苍蝇满天飞“衣服堆”的金若愚 “脑子”送到月球打两份工的沙乐乐 有机器人或者分身帮我们干活赚钱,我们去吃喝玩…

FM33A048B 红外调制

TZBRG寄存器保存一个 11 位的分频系数 X ,其值为 0~2047 之间的任一整数。 6 路 UART 共用一个红外调制频率发生器。 红外调制频率计算公式: FIR FAPBCLK/ (TZBRGTZBRG 1) 红外调制的方式为:发送数据0 时调制红外频率,发送数据 1…

JavaScript实现输入两个数比较两个数的大小,输出个人信息的两个程序代码

以下为实现输入两个数比较两个数的大小,输出个人信息的两个程序代码和运行截图 目录 前言 一、实现输入两个数比较两个数的大小 1.1 运行流程及思想 1.2 代码段 1.3 JavaScript语句代码 1.4 运行截图 二、输出个人信息 2.1 运行流程及思想 2.2 代码段 2.3…

Java每日一练(20230506) 全排列II、岛屿数量、有效数独

目录 1. 全排列 II 🌟🌟 2. 岛屿数量 🌟🌟 3. 有效的数独 🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 Java每日一练 专栏 …

atbf中imu数据读取逻辑分析仪抓取

一、说明 使用逻辑分析仪抓区imu的spi和中断io的信号,从而侧面描述atbf在imu上的数据读取方式; 二、硬件说明 1、硬件材料 1、mcu at32F437开发板 2、imu icm42688p 3、逻辑分析仪 梦源逻辑分析仪 4、调试器 jlink 2、原理图 3、实物图 4、固…

【git】git lfs

目录 原理 使用方法 报错记录 certificate signed by unknown authority 原理 项目中的大文件会很占空间。 git lfs(large file storage)将大文件替换为小指针, 当真正需要到这些大文件的时候, 才会从本地或者远端的lfs缓存中下载这些大文件. git lfs拥有本地lfs缓存和远端…

ubuntu系统版本查询命令方法

目录 一、使用命令:cat /proc/version 查看 二、 使用命令:uname -a 查看 三、 使用命令:lsb_release -a 查看 四、使用命令:hostnamectl 查看 五、使用命令:cat /etc/issue 查看 一、使用命令:cat /…

LeetCode:21. 合并两个有序链表

21. 合并两个有序链表 1)题目2)思路3)代码4)结果 1)题目 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2…

百度地图API介绍

4. 百度地图api 介绍 1. api开发文档 1.2 区别 JavaScript API v3.0 JavaScript API v3.0 链接 ,百度地图JavaScript API是一套由JavaScript语言编写的应用程序接口,可帮助您在网站中构建功能丰富、交互性强的地图应用,支持PC端和移动端基于浏览器的地图应用开发,且支持HT…

2023.03 青少年机器人技术等级考试理论综合试卷(三级)

2023 年 3 月青少年机器人技术等级考试理论综合试卷(三级) 一、单选题(共 20 题,共 80 分) 1. Arduino UNO/Nano 主控板,电位器连接到 A0 引脚,下图程序运行时,变量 potVal 值的范围是?&#xf…

【原创】DELL R750xs 无盘ESXi7安装

一、环境 一台磁盘阵列 多台DELL R750xs 充当esxi主机。 当前端口组 当前虚拟交换机 当前物理网卡 当前VMKernel网卡 当前ISCSI配置 二、问题 虚拟化环境重启时,ESXi主机比磁盘阵列先启动,启动后发现磁盘阵列处于脱机状态。 三、目标 让磁盘阵列启动…