程序员学长 | 快速学习一个算法,GAN

news2024/9/17 3:40:27

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学习一个算法,GAN

GAN 如何工作?

GAN 由两个部分组成:生成器(Generator)和判别器(Discriminator)。这两个部分通过一种对抗的过程来相互改进和优化。

c03159492fff41b799068bb5cf88bd59.png

生成器(Generator)

生成器的任务是接收一个随机噪声向量,并将其转换为看起来尽可能真实的数据。它通常使用一个深度神经网络来实现,从随机噪声中生成类似于训练数据的样本。

生成器的目标是生成的样本能够骗过判别器,使判别器认为生成的数据是真实的。

判别器(Discriminator)

判别器是一个二分类模型,它的任务是区分真实数据和生成器生成的假数据。

判别器接收一个数据样本,并输出一个概率值,表示该样本是真实数据的概率。

判别器的目标是尽可能准确地将真实数据与生成数据区分开来。

对抗过程

GAN 的训练过程可以被看作是生成器和判别器之间的博弈。具体来说,生成器试图生成逼真的数据以欺骗判别器,而判别器则试图更好地识别生成的假数据

这个过程可以描述为一个最小最大化的优化问题。

4c8b68b8b9184d9092ffda2498ba9173.png

鉴别器 D 想要最大化目标函数,使得 D(x) 接近于 1,D(G(z)) 接近于 0。这意味着鉴别器应该将训练集中的所有图像识别为真实 (1),将所有生成的图像识别为假 (0)。

生成器 (G) 想要最小化目标函数,使得 D(G(z)) 为 1。这意味着生成器试图生成被鉴别器网络分类为 1 的图像。

训练步骤

  1. 初始化生成器和判别器的参数

  2. 判别器训练

    • 从真实数据集中采样一个 mini-batch 的真实数据。

    • 从生成器的噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新判别器的参数,使其能够更好地区分真实数据和生成数据。

      f6487b7759c24cf9aa3570065e6f957b.png

  3. 生成器训练

    861dc76c582746c281d4d47c3e1156c0.png
    • 从噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新生成器的参数,使其生成的数据更能够欺骗判别器。

  4. 重复上述过程,直到生成器生成的数据足够逼真,判别器无法准确区分真实数据和生成数据

GAN 在图像生成、图像修复、超分辨率、风格迁移、数据增强等领域有广泛应用。例如,通过 GAN,可以生成高分辨率的图像,将低分辨率图像转换为高分辨率图像,或者将某种风格的图像转换为另一种风格的图像。

案例分享

下面是使用 GAN 来生成图像的案例。这里我们以手写数字识别数据集为例进行说明。

1.读取数据集

from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
import matplotlib.pyplot as plt
import sys
import numpy as np

num_rows = 28
num_cols = 28
num_channels = 1
input_shape = (num_rows, num_cols, num_channels)
z_size = 100
batch_size = 32
(train_ims, _), (_, _) = mnist.load_data()
train_ims = train_ims / 127.5 - 1.
train_ims = np.expand_dims(train_ims, axis=3)

valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

2.定义生成器

生成器 (D) 在 GAN 中扮演着至关重要的角色,因为它负责生成能够欺骗鉴别器的真实图像。

它是 GAN 中图像形成的主要组件。

在本文中,我们为生成器使用了一种特定的架构,该架构包含一个完全连接 (FC) 层并采用 Leaky ReLU 激活。然而,值得注意的是,生成器的最后一层使用 TanH 激活而不是 LeakyReLU。

def build_generator():
    gen_model = Sequential()
    gen_model.add(Dense(256, input_dim=z_size))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(512))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(1024))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(np.prod(input_shape), activation='tanh'))
    gen_model.add(Reshape(input_shape))

    gen_noise = Input(shape=(z_size,))
    gen_img = gen_model(gen_noise)
    return Model(gen_noise, gen_img)

3.定义鉴别器

在生成对抗网络 (GAN) 中,鉴别器 (D) 通过评估真实性和可能性来执行区分真实图像和生成图像的关键任务。

此组件可以看作是一个二元分类问题。

为了解决此任务,我们可以采用一个简化的网络架构,该架构由全连接层 (FC)、Leaky ReLU 激活和 Dropout 层组成。值得一提的是,鉴别器的最后一层包括 FC 层,后跟 Sigmoid 激活。Sigmoid 激活函数产生所需的分类概率。

def build_discriminator():
    disc_model = Sequential()
    disc_model.add(Flatten(input_shape=input_shape))
    disc_model.add(Dense(512))
    disc_model.add(LeakyReLU(alpha=0.2))
    disc_model.add(Dense(256))
    disc_model.add(LeakyReLU(alpha=0.2))
    disc_model.add(Dense(1, activation='sigmoid'))

    disc_img = Input(shape=input_shape)
    validity = disc_model(disc_img)
    return Model(disc_img, validity)

4.计算损失函数

我们可以使用二元交叉熵损失来实现生成器和鉴别器。

# discriminator
disc= build_discriminator()
disc.compile(loss='binary_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

z = Input(shape=(z_size,))

# generator
img = generator(z)

disc.trainable = False

validity = disc(img)

# combined model
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='sgd')

5.优化损失

def intialize_model():
    disc= build_discriminator()
    disc.compile(loss='binary_crossentropy',
        optimizer='sgd',
        metrics=['accuracy'])

    generator = build_generator()

    z = Input(shape=(z_size,))
    img = generator(z)

    disc.trainable = False

    validity = disc(img)

    combined = Model(z, validity)
    combined.compile(loss='binary_crossentropy', optimizer='sgd')
    return disc, generator, combined

6. 模型训练

def train(epochs, batch_size=128, sample_interval=50):
    # load images
    (train_ims, _), (_, _) = mnist.load_data()
    # preprocess
    train_ims = train_ims / 127.5 - 1.
    train_ims = np.expand_dims(train_ims, axis=3)

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    # training loop
    for epoch in range(epochs):

        batch_index = np.random.randint(0, train_ims.shape[0], batch_size)
        imgs = train_ims[batch_index]
    # create noise
        noise = np.random.normal(0, 1, (batch_size, z_size))
    # predict using a Generator
        gen_imgs = gen.predict(noise)
    # calculate loss functions
        real_disc_loss = disc.train_on_batch(imgs, valid)
        fake_disc_loss = disc.train_on_batch(gen_imgs, fake)
        disc_loss_total = 0.5 * np.add(real_disc_loss, fake_disc_loss)

        noise = np.random.normal(0, 1, (batch_size, z_size))

        g_loss = full_model.train_on_batch(noise, valid)
   
    # save outputs every few epochs
        if epoch % sample_interval == 0:
            one_batch(epoch)

7.生成手写数字

使用 MNIST 数据集,我们可以创建一个实用函数,使生成器为一组图像生成预测。

此函数生成随机声音,将其提供给生成器,运行它以显示生成的图像并将其保存在特殊文件夹中。建议定期运行此实用函数,例如每 200 个周期运行一次,以监控网络进度。

def one_batch(epoch):
    r, c = 5, 5
    noise_model = np.random.normal(0, 1, (r * c, z_size))
    gen_images = gen.predict(noise_model)

    # Rescale images 0 - 1
    gen_images = gen_images*(0.5) + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_images[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%d.png" % epoch)
    plt.close()

在我们的实验中,我们使用 32 的批次大小对 GAN 进行了大约 10,000 个时期的训练。

为了跟踪训练的进度,我们每 200 个时期保存一次生成的图像,并将它们存储在名为 “images” 的指定文件夹中。

disc, gen, full_model = intialize_model()
train(epochs=10000, batch_size=32, sample_interval=200)

现在,我们来检查一下不同阶段的 GAN 模拟结果。

初始化、5000 个 epoch 以及 10000 个 epoch 的最终结果。

最初,我们以随机噪声作为生成器的输入。

2e954b9c4c274f3e98d6ff8c9ba2be87.png

经过 5000 个时期的训练后,我们可以观察到生成的图形开始类似于 MNIST 数据集。

1a0dc57d0de14beebed11deee9879f68.png

经过 10,000 个时期的训练后,我们获得以下输出。

7c9ae539548c463db3d7b6a41fe4db86.png

可以看到,这些生成的图像与手写数字数据已经非常相似了。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1919097.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C标准库读写文件

函数介绍 库变量 变量描述size_t无符号整数类型,是sizeof关键字的结果,表示对象大小FILE文件流类型,适合存储文件流信息的对象类型 库宏 宏描述NULL空指针常量EOF表示已经到达文件结束的负整数stderr、stdin、stdout指向FILE类型的指针&a…

【AIGC】二、mac本地采用GPU启动keras运算

mac本地采用GPU启动keras运算 一、问题背景二、技术背景三、实验验证本机配置安装PlaidML安装plaidml-keras配置默认显卡 运行采用 CPU运算的代码step1 先导入keras包,导入数据cifar10,这里可能涉及外网下载,有问题可以参考[keras使用基础问题…

starccm+软件许可优化解决方案

starccm软件介绍 Simcenter Star CCM专注于CFD的多物理场仿真,支持流体动力学模拟、电池模拟、协同仿真、设计探索、电机、电化学、引擎模拟、移动物体、流变学、固体力学等多个方面,无论是真实的多物理场仿真,捕捉产品的完整几何形状&#x…

LVS实验

LVS实验 nginx1 RS1 192.168.11.137 nginx2 RS2 192.168.11.138 test4 调度器 ens33 192.168.11.135 ens36 12.0.0.1 test2 客户端 12.0.0.10 一、test4 配置两张网卡地址信息 [roottest4 network-scripts]# cat ifcfg-ens33 TYPEEthernet BOOTPROTOstatic DEFROUTEyes DEVIC…

利用 Plotly.js 创建交互式条形图

本文由ScriptEcho平台提供技术支持 项目地址:传送门 利用 Plotly.js 创建交互式条形图 应用场景介绍 交互式条形图广泛应用于数据可视化和分析领域。它可以直观地展示不同类别或分组之间的数值差异,并允许用户通过交互操作探索数据。 代码基本功能介…

【经典面试题】环形链表

1.环形链表oj 2. oj解法 利用快慢指针: /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/typedef struct ListNode ListNode; bool hasCycle(struct ListNode *head) {ListNode* slow head, *fast…

51单片机(STC8051U34K64)_RA8889_SPI4参考代码(v1.3)

硬件:STC8051U34K64 RA8889开发板(硬件跳线变更为SPI-4模式,PS101,R143,R141短接,R142不接) STC8051U34K64是STC最新推出来的单片机,主要用于替换传统的8051单片机,与标…

大佬,简单解释下“嵌入式软件开发”和“嵌入式硬件开发”的区别

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!!首先,嵌入式硬…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第一篇 嵌入式Linux入门篇-第十九章 Linux 工具之make 工具和 makefile 文件

i.MX8MM处理器采用了先进的14LPCFinFET工艺,提供更快的速度和更高的电源效率;四核Cortex-A53,单核Cortex-M4,多达五个内核 ,主频高达1.8GHz,2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

【car】深入浅出学习机械燃油车知识、结构、原理、维修、保养、改装、编程

汽车的五大总成通常是指发动机、变速器、前后桥、车架和悬挂系统。 发动机:是汽车的动力来源,负责将燃料的化学能转化为机械能,驱动汽车行驶。常见的发动机类型有内燃机(如汽油发动机、柴油发动机)和电动机&#xff0…

hypermill软件许可优化解决方案

Hypermill软件介绍 hyperMILL的最大优势表现在五轴联动方面 五轴联动被广泛应于汽车、工具、模具、机械、航空航天等领域,比如航空叶轮、叶片、结构件的铣削。现在很多机床和控制器都可以适应五轴铣削要求,然而在软件方面多采取定位加工方式(…

案例|LabVIEW连接S7-1200PLC

附带: 写了好的参考文章: 通讯测试工具和博图仿真机的连接教程【内含图文完整过程软件使用】 解决博图V15 V16 V17 V18等高版本和低版本在同款PLC上不兼容的问题 目录 前言一、准备条件二、步骤1. HslCommunicationDemo问题1:连接失败?问题…

..质数..

先弄清楚我们在上小学时 学的概念。 1、什么是质因数? -质因数是指能够整除给定正整数的质数。每个正整数都可以被表示为几个质数的乘积,这些质数就是该数的质因数。质因数分解是将一个正整数分解成若干个质数相乘的过程。例如,数字 12…

[激光原理与应用-109]:南京科耐激光-激光焊接-焊中检测-智能制程监测系统IPM介绍 - 12 - 焊接工艺之影响焊接效果的因素

目录 一、影响激光焊接效果的因素 1.1、光束特征 1.2、焊接特征 1.3、保护气体 二、材料对焊接的影响 2.1 材料特征 2.2 不同材料对激光的吸收率 (一)、不同金属材料对不同激光的吸收率 1. 金属材料对激光的普遍反应 2. 不同波长激光的吸收率差…

ant design pro多页签功能

效果: 原理: 1、所有需要页签页面,都需要一个共同父组件 2、如何缓存,用的是ant的Tabs组件,在共同父组件中,实际是展示的Tabs组件 3、右键,用的是ant的Dropdown组件,当点击时&…

SpringBoot新手快速入门系列教程十:基于docker容器,部署一个简单的项目

前述: 本篇教程将略过很多docker下载环境配置的基础步骤,如果您对docker不太熟悉请参考我的上一个教程:SpringBoot新手快速入门系列教程九:基于docker容器,部署一个简单的项目 使用 Docker Compose 支持部署 Docker 项…

MySQL某个字段按指定值排序,其他值按创建时间排序

项目场景: MySQL某个字段按指定值排序,其他值按创建时间排序,我们需要用到FIELD() 函数,它是一种对查询结果排序的方法,可以根据指定的字段值顺序进行排序。 order by FIELD() 函数的语法如下: ORDER BY …

[GHCTF 2024 新生赛]ezzz_unserialize

源码&#xff1a; <?php /*** Author: hey* message: Patience is the key in life,I think youll be able to find vulnerabilities in code audits.* Have fun and Good luck!!!*/ error_reporting(0); class Sakura{public $apple;public $strawberry;public function …

LiteOS增加执行自定义源码

开发过程注意事项&#xff1a; 源码工程路径不能太长 源码工程路径不能有中文 一定要关闭360等杀毒软件&#xff0c;否则编译的打包阶段会出错 增加自定义源码的步骤: 1.创建源码目录 2. 创建源文件 新建myhello目录后&#xff0c;再此目录下再新建源文件myhello_demo.c 3. 编…