GAN的原理分析与实例

news2024/9/23 15:24:24

为了便于理解,可以先玩一玩这个网站:GAN Lab: Play with Generative Adversarial Networks in Your Browser!

GAN的本质:枯叶蝶和鸟。生成器的目标:让枯叶蝶进化,变得像枯叶,不被鸟准确识别。判别器的目标:准确判别是枯叶还是鸟

伪代码: 

案例:

原始数据:

案例结果: 

 案例完整代码:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False

# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223


noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generator


class NetGenerator(nn.Module):
    def __init__(self):
        super(NetGenerator,self).__init__()
        self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行
            nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),
            #转置卷积层:输入特征映射的尺寸会放大,通道数可能会减小,普通卷积层:输入特征映射的尺寸会缩小,但通道数可能会增加
            nn.BatchNorm2d(n_generator_feature * 8),
            nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4
            nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 4),
            nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8
            nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 2),
            nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16
            nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature),
            nn.ReLU(True),      # (n_generator_feature) × 32 × 32
            nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()       # 3 * 96 * 96
        )

    def forward(self, input):
        return self.main(input)


class NetDiscriminator(nn.Module):
    def __init__(self):
        super(NetDiscriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32
            nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16
            nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8
            nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4
            nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()        # 输出一个概率
        )

    def forward(self, input):
        return self.main(input).view(-1)


def train():
    for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96
        real_image = Variable(image)
        #real_image = real_image.cuda()

        if (i + 1) % d_every == 0:  #d_every = 1,每一个batch训练一次discriminator
            optimizer_d.zero_grad()
            output = Discriminator(real_image)      # 尽可能把真图片判为True
            error_d_real = criterion(output, true_labels)
            error_d_real.backward()

            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises).detach()       # 根据噪声生成假图
            fake_output = Discriminator(fake_img)       # 尽可能把假图片判为False
            error_d_fake = criterion(fake_output, fake_labels)
            error_d_fake.backward()
            optimizer_d.step()

        if (i + 1) % g_every == 0:
            optimizer_g.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises)        # 这里没有detach
            fake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为True
            error_g = criterion(fake_output, true_labels)
            error_g.backward()
            optimizer_g.step()


def show(num):
    fix_fake_imags = Generator(fix_noises)
    fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5

    # x = torch.rand(64, 3, 96, 96)
    fig = plt.figure(1)

    i = 1
    for image in fix_fake_imags:
        ax = fig.add_subplot(8, 8, eval('%d' % i)) #将Figure划分为8行8列的子图网格,并将当前的子图添加到第i个位置。
        # plt.xticks([]), plt.yticks([])  # 去除坐标轴
        plt.axis('off')
        plt.imshow(image.permute(1, 2, 0)) #permute()函数可以对维度进行重排,Matplotlib期望的图像格式是(H, W, C),即高度、宽度、通道
        i += 1
    plt.subplots_adjust(left=None,  # the left side of the subplots of the figure
                        right=None,  # the right side of the subplots of the figure
                        bottom=None,  # the bottom of the subplots of the figure
                        top=None,  # the top of the subplots of the figure
                        wspace=0.05,  # the amount of width reserved for blank space between subplots
                        hspace=0.05)  # the amount of height reserved for white space between subplots)
    plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)
    plt.savefig("images/%dcgan.png" % num)


if __name__ == '__main__':
    transform = tv.transforms.Compose([
        tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecated
        tv.transforms.CenterCrop(96),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数
    ])

    dataset = tv.datasets.ImageFolder(dir, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'

    print('数据加载完毕!')
    Generator = NetGenerator()
    Discriminator = NetDiscriminator()

    optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()

    true_labels = Variable(torch.ones(batch_size))     # batch_size
    fake_labels = Variable(torch.zeros(batch_size))
    fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))
    noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布

    # if torch.cuda.is_available() == True:
    #     print('Cuda is available!')
    #     Generator.cuda()
    #     Discriminator.cuda()
    #     criterion.cuda()
    #     true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
    #     fix_noises, noises = fix_noises.cuda(), noises.cuda()


    #plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]
    plot_epoch = [1,5,10,50,100,200,500,800,1000,1200,1500]

    for i in range(1500):        # 最大迭代次数
        train()
        print('迭代次数:{}'.format(i))
        if i in plot_epoch:
            show(i)


http://t.csdnimg.cn/FTSriicon-default.png?t=N7T8http://t.csdnimg.cn/FTSri

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1311265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot之自定义starter

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于Spring Boot的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一. starter是什么 二.为什么要使…

什么是 API 代理?

API 代理就像您的计算机和互联网上的特殊服务之间的有用中间人。这有点像将翻译、保安和信使合而为一。 什么是 API 代理&#xff1f; API 代理就像您和在线服务之间的有用中间人。当您的计算机需要从特殊在线服务&#xff08;API&#xff09;获取某些内容时&#xff0c;API 代…

计算机网络:物理层(奈氏准则和香农定理,含例题)

带你速通计算机网络期末 文章目录 一、码元和带宽 1、什么是码元 2、数字通信系统数据传输速率的两种表示方法 2.1、码元传输速率 2.2、信息传输速率 3、例题 3.1、例题1 3.2、例题2 4、带宽 二、奈氏准则&#xff08;奈奎斯特定理&#xff09; 1、奈氏准则简介 2、…

leetcode做题笔记2132. 用邮票贴满网格图

给你一个 m x n 的二进制矩阵 grid &#xff0c;每个格子要么为 0 &#xff08;空&#xff09;要么为 1 &#xff08;被占据&#xff09;。 给你邮票的尺寸为 stampHeight x stampWidth 。我们想将邮票贴进二进制矩阵中&#xff0c;且满足以下 限制 和 要求 &#xff1a; 覆盖…

VUE中如果让全局组件在某一页面不显示

目录 前言 方法一 1.在全局组件中添加一个变量用于控制显示与隐藏。 2.在全局组件的模板中使用 v-if 条件来决定是否显示该组件 3.在不需要显示全局组件的页面中&#xff0c;修改 showGlobalComponent 变量的值为 false&#xff0c;以隐藏全局组件。 4.在需要隐藏全局组…

dockerfile---创建镜像

dockerfile创建镜像&#xff1a;创建自定义镜像。 包扩配置文件的创建&#xff0c;挂载点&#xff0c;对外暴露的端口。设置环境变量。 docker镜像的方式: 1、基于官方源进行创建 根据官方提供的镜像源&#xff0c;创建镜像&#xff0c;然后拉起容器。是一个白板&#xff0c…

国产猫粮推荐排行榜有哪些牌子?国产主食冻干猫粮品牌十大排行

近年来&#xff0c;冻干猫粮作为热门的高品质猫粮&#xff0c;受到了许多追求纯天然、健康食品的铲屎官的关注。萌新铲屎官就很疑惑了冻干猫粮可以代替猫粮作为主食吗&#xff1f;冻干猫粮真就那么好吗&#xff1f; 作为一个猫咖店长&#xff0c;这几年我至少给猫挑选了20几款…

智能故障诊断期刊推荐【中文期刊】

控制与决策 http://kzyjc.alljournals.cn/kzyjc/home 兵工学报 http://www.co-journal.com/CN/1000-1093/home.shtml 计算机集成制造系统 http://jsjjc.soripan.net/ 机械工程学报 http://www.cjmenet.com.cn/CN/0577-6686/home.shtml 太阳能学报 https://www.tynxb.org.c…

Windows ❀ 关闭Google的自动更新功能

文章目录 1. 故障问题2. 解决方法 1. 故障问题 如何关闭掉Google的自动更新功能&#xff1f; 2. 解决方法 修改更新域名本地hosts为环回地址即可。 # 禁止google自动更新 127.0.0.1 update.googleapis.com备注&#xff1a; mac路径&#xff1a;/etc/hostswindows路径&…

SpringBoot之数组,集合,日期参数的详细解析

1.4 数组集合参数 数组集合参数的使用场景&#xff1a;在HTML的表单中&#xff0c;有一个表单项是支持多选的(复选框)&#xff0c;可以提交选择的多个值。 多个值是怎么提交的呢&#xff1f;其实多个值也是一个一个的提交。 后端程序接收上述多个值的方式有两种&#xff1a; 数…

EM的理论基础

1 EM定义​ 电迁移(Electro-Migration)是指在外加电场下,电子和金属原子之间的动量转移导致材料的运动。这种动量传递导致金属原子(比如Cu原子)从其原始位置移位,如图7-1。这种效应随着导线中电流密度的增加而增加,并且在更高的温度下,动量传递变得更加严重。因此,在先…

2023全球开发者生态调研:84%的开发者表示他们在工作中正积极使用生成式AI工具

今年JetBrains首次在一年一度的开发者生态调研中&#xff0c;增加了人工智能方向的问题。在全球26348名开发者参与的调研中&#xff0c;总体对人工智能的发展持乐观态度。特别是生成式AI在软件开发和编程环节中的应用&#xff0c;84%的开发者表示他们在工作中正在积极使用生成式…

【STM32CubeMX】F103 BxCAN

F103&BxCAN bxCAN总体描述 有一个增强的过滤机制来处理各种类型的报文此外&#xff0c;应用层任务需要更多CPU时间&#xff0c;因此报文接收所需的实时响应程度需要减轻。 接收FIFO的方案允许&#xff0c;CPU花很长时间处理应用层任务而不会丢失报文。 构筑在底层CAN驱动程…

MySQL增量备份与恢复

实验环境 某学校近期在进行期中考试&#xff0c;要求数据库管理员负责一班&#xff0c;二班学生的考试成绩录入&#xff0c;为保证数据的可靠性&#xff0c;数据库管理员在录入学生成绩后均要做数据库备份&#xff0c;并且为了测试备份数据是否可 用&#xff0c;模拟数据丢失故…

柯桥日常英语口语,外贸英语商务英语|英文打电话的常用语

日常生活中&#xff0c;我们常常需要打电话交流。在打电话时说话清楚&#xff0c;使用适当的礼节是很重要的。 如果你太正式&#xff0c;人们在和你说话时&#xff0c;可能会很难感到舒适。如果你太随便&#xff0c;他们可能会认为你很粗鲁&#xff01; 所以&#xff0c;说话的…

Jmeter,提取响应体中的数据:正则表达式、Json提取器

一、正则表达式 1、线程组--创建线程组&#xff1b; 2、线程组--添加--取样器--HTTP请求&#xff1b; 3、Http请求--添加--后置处理器--正则表达式提取器&#xff1b; 4、线程组--添加--监听器--查看结果树&#xff1b; 5、线程组--添加--取样器--调试取样器。 响应体数据…

Disruptor详解,Java高性能内存队列最优解

文章目录 一、Disruptor介绍1、为什么要有Disruptor2、Disruptor介绍3、Disruptor的高性能设计4、RingBuffer数据结构5、等待策略6、Disruptor在日志框架中的应用7、术语 二、Disruptor实战1、引入依赖2、Disruptor构造器3、入门实例&#xff08;1&#xff09;Hello World&…

MATLAB 绘制伯德图之将幅频特性和相频特性分开绘制方法

幅频和相频特性分别在两个图窗&#xff0c;不在一起方便保存&#xff0c;无需再裁剪 clear; close all; k 1; numH 1; denH [1,k]; sysH tf(numH,denH); w logspace(-2,2);[mag, phase] bode(sysH,w);% 幅频特性 loglog(w,squeeze(mag));grid on; % 相频特性 semilogx(…

在springboot中引入参数校验

一、概要 一般我们判断前端传过来的参数&#xff0c;需要对某些值进行判断&#xff0c;是否满足条件。 而springboot相关的参数校验注解&#xff0c;可以解决我们这个问题。 二、快速开始 首先&#xff0c;我用的springboot版本是 3.1.5 引入参数校验相关依赖 <!--1…

数据之美:零售业的变革之道

数据可视化能够为零售业带来令人瞩目的变化。随着零售业务的发展&#xff0c;数据可视化成为了洞察市场、优化运营并提升客户体验的强大工具。下面我就以可视化从业者的视角出发&#xff0c;简单分析一下数据可视化为零售业可能带来的改变。 数据可视化让零售商深入了解消费者行…