生成对抗网络(DCGAN)手写数字生成

news2025/4/8 19:09:56

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
  • 二、什么是生成对抗网络
    • 1. 简单介绍
    • 2. 应用领域
  • 三、创建模型
      • 1. 生成器
      • 2. 判别器
  • 四、定义损失函数和优化器
      • 1. 判别器损失
      • 2. 生成器损失
  • 五、定义训练循环
  • 六、训练模型
  • 七、创建 GIF

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(Inception-ResNet-v2)交通标志识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
    
# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras  import layers
from IPython           import display
import matplotlib.pyplot as plt
import numpy             as np
import glob,imageio,os,PIL,time
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')

# 将图片标准化到 [-1, 1] 区间内
train_images = train_images / 127.5 - 1  
BUFFER_SIZE = 60000
BATCH_SIZE  = 256

# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。
  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、创建模型

1. 生成器

生成器使用 tf.keras.layers.Conv2DTranspose (上采样)层来从种子(随机噪声)中产生图片。以一个使用该种子作为输入的 Dense 层开始,然后多次上采样直到达到所期望的 28x28x1 的图片尺寸。注意除了输出层使用 tanh 之外,其他每层均使用 tf.keras.layers.LeakyReLU 作为激活函数。

def make_generator_model():
    model = tf.keras.Sequential([
        layers.Dense(7*7*256, use_bias=False, input_shape=(100,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        
        layers.Reshape((7, 7, 256)),
        
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])

    return model

generator = make_generator_model()
generator.summary()

2. 判别器

判别器是一个基于 CNN 的图片分类器。

def make_discriminator_model():
    model = tf.keras.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        
        layers.Flatten(),
        layers.Dense(1)
    ])

    return model

discriminator = make_discriminator_model()
discriminator.summary()

四、定义损失函数和优化器

为两个模型定义损失函数和优化器。

# 该方法返回计算交叉熵损失的辅助函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

1. 判别器损失

该方法量化判断真伪图片的能力。它将判别器对真实图片的预测值与值全为 1 的数组进行对比,将判别器对伪造(生成的)图片的预测值与值全为 0 的数组进行对比。

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

2. 生成器损失

生成器损失量化其欺骗判别器的能力。直观来讲,如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或 1)。这里我们将把判别器在生成图片上的判断结果与一个值全为 1 的数组进行对比。

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

由于我们需要分别训练两个网络,判别器和生成器的优化器是不同的。

generator_optimizer     = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

五、定义训练循环

EPOCHS = 60
noise_dim = 100
num_examples_to_generate = 16

# 我们将重复使用该种子(在 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

训练循环在生成器接收到一个随机种子作为输入时开始。该种子用于生产一张图片。判别器随后被用于区分真实图片(选自训练集)和伪造图片(由生成器生成)。针对这里的每一个模型都计算损失函数,并且计算梯度用于更新生成器与判别器。

# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):
    # 生成噪音
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        
        # 计算loss
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    
    #计算梯度
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    #更新模型
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # 实时更新生成的图片
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    # 最后一个 epoch 结束后生成图片
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)
def generate_and_save_images(model, epoch, test_input):
    # 注意 training` 设定为 False
    # 因此,所有层都在推理模式下运行(batchnorm)。
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4,4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('./images/19/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

六、训练模型

调用上面定义的 train() 方法来同时训练生成器和判别器。在训练之初,生成的图片看起来像是随机噪声。随着训练过程的进行,生成的数字将越来越真实。在大概 50 个 epoch 之后,这些图片看起来像是 MNIST 数字。

%%time:将会给出cell的代码运行一次所花费的时间。

%%time
train(train_dataset, EPOCHS)

在这里插入图片描述

七、创建 GIF

import imageio,pathlib

def compose_gif():
    # 图片地址
    data_dir = "./images/19"
    data_dir = pathlib.Path(data_dir)
    paths    = list(data_dir.glob('*'))
    
    gif_images = []
    for path in paths:
        gif_images.append(imageio.imread(path))
    imageio.mimsave("./pic_gif/MINST_DCGAN_19.gif",gif_images,fps=8)
    
compose_gif()
print("GIF动图生成完成!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1274748.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Amazon CodeWhisperer 使用体验

文章作者:STRIVE Amazon CodeWhisperer 是最新的代码生成工具,支持多种编程语言,如 java,js,Python 等,能减少开发人员手敲代码时间,提升工作效率。PS:本人是一名 CodeWhisperer 业余爱好者 亚马逊云科技开发者社区为开…

Spring Cloud 配置 Nacos

一,下载Nacos 下载地址:https://github.com/alibaba/nacos/releases 二,启动Nacos 安装Nacos的bin目录下, 执行:startup.cmd -m standalone 然后打开上图红框的地址 三,配置服务 1 配置Nacos 创建命名…

【C++】异常抛出变量的生命周期

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab,机器人运动控制、多机器人协作,智能优化算法,滤波估计、多传感器信息融合,机器学习,人工智能等相关领域的知识和技术。搜…

[Android] c++ 通过 JNI 调用 JAVA函数

如何使用: Calling Java from C with JNI - CodeProject c里的 JNI 类型 和 JAVA 类型的映射关系: JNI Types and Data Structures Primitive Types and Native Equivalents Java TypeNative TypeDescriptionbooleanjbooleanunsigned 8 bitsbytejbyt…

高级java工程师手把手教你解决内存不足引起JVM奔溃真实生产事故案例实战

高级java工程师手把手教你解决内存不足引起JVM奔溃案例实战 一、真实事故描述: 生产环境的Java程序进程,直接宕掉,进程都没有了,JVM奔溃了。生产事故,生产直接停止了,甲方爸爸客户着急了,公司…

使用yolov7进行多图像视频识别

1.yolov7你可以让你简单的部署,比起前几代来说特别简单 #下面是我转换老友记的测试视频,可以看到几乎可以准确预测 2.步骤 1.在github官网下载代码 https://github.com/WongKinYiu/yolov7 2.点击下载权重文件放到项目中 3.安装依赖,我的python版本是3.6的 pip install -r requ…

SQL中left join、right join、inner join等的区别

一张图可以简洁明了的理解出left join、right join、join、inner join的区别: 1、left join 就是“左连接”,表1左连接表2,以左为主,表示以表1为主,关联上表2的数据,查出来的结果显示左边的所有数据&#…

如何从初级进阶中级测试工程师?测试人该具备哪些素养?

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 1、如何成为一枚中…

JAVA全栈开发 day14_集合(Collection\List接口、数据结构、泛型)

一、数组 数组是一个容器,可以存入相同类型的多个数据元素。 数组局限性: ​ 长度固定:(添加–扩容, 删除-缩容) ​ 类型是一致的 对象数组 : int[] arr new int[5]; … Student[] arr …

分享88个清新唯美PPT,总有一款适合您

分享88个清新唯美PPT,总有一款适合您 88个清新唯美PPT下载链接:https://pan.baidu.com/s/1XUUjxjmWFw2fJKENjk6_Yg?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整…

【亚马逊云科技】re:Invent 2023 | Amazon Q王炸产品震撼来袭

re:Invent 2023前沿资讯快速入口➡️:2023亚马逊云科技reinvent大会,与开发者一起构建未来! 文章目录 一、2023 亚马逊云科技 re:Invent 精彩内容速递🎨二、Amazon Q 震撼来袭2.1 什么是Amazon Q?2.2 Amazon Q功能介绍…

OpenHarmony 关闭息屏方式总结

前言 OpenHarmony源码版本:4.0release 开发板:DAYU / rk3568 一、通过修改系统源码实现不息屏 修改目录:base/powermgr/power_manager/services/native/profile/power_mode_config.xml 通过文件中的提示可以知道DisplayOffTime表示息屏的…

wordpress安装之Linux ftp传输

工欲善其事,必先利其器。 最近准备在自己的服务器上搭建一个个人技术分享的平台。 因为我发现现在网络上的工具呀,还有一些问题的解答总是模棱两可,所以我打算自己做一个。 首先呢,我们需要有一个linxu的系统当服务器,然后呢&a…

d3dcompiler_47.dll缺失怎么修复?一招搞定电脑弹窗问题

在计算机使用过程中,我们常常会遇到一些错误提示,其中之一就是“d3dcompiler_47.dll缺失”。这个错误通常出现在游戏或应用程序运行时,它会导致程序无法正常启动或运行。为了解决这个问题,我们需要采取一些措施来修复缺失的文件。…

带米勒钳位的隔离驱动SiLM5350系列 工作原理、特性参数、封装形式

带米勒钳位的隔离驱动SiLM5350系列 单通道 30V,10A 带米勒钳位的隔离驱动 具有驱动电流更大、传输延时更低、抗干扰能力更强、封装体积更小等优势, 为提高电源转换效率、安全性和可靠性提供理想之选。 描述: SiLM5350系列是单通道隔离驱动器&#xff0…

2023年中国数据要素市场研究报告

第一章 概况 1.1 定义 中国数据要素交易市场是一个多层次、多维度的复杂体系,涵盖了不同的交易方式、市场类型和行业应用。数据要素作为一种新兴的生产要素,涉及社会经营活动中所有可以电子化记录、为使用者或所有者带来经济效益的数据资源。 在狭义上…

图片点击放大

在列表中添加插槽 <template slot-scope"scope">&#xff0c;获取当前点击的数据 在图片中添加点击事件的方法&#xff0c;用来弹出窗口 <vxe-columnfield"icon"title"等级图标"><template slot-scope"scope"><…

基于若依的ruoyi-nbcio流程管理系统仿钉钉流程初步完成转bpmn设计(还有bug,以后再修改)

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码&#xff1a; https://gitee.com/nbacheng/ruoyi-nbcio 演示地址&#xff1a;RuoYi-Nbcio后台管理系统 今天初步完成仿钉钉流程转bpmn设计的工作&#xff0c;当然还有不少bug&#xff0c;以后有需要或者网友也帮…

【android开发-01】android中toast的用法介绍

1&#xff0c;android中toast的作用 在Android开发中&#xff0c;Toast是一种用于向用户显示简短消息的轻量级对话框。它通常用于向用户提供一些即时的反馈信息&#xff0c;例如操作结果、提示或警告。 Toast的主要作用如下&#xff1a; 提供反馈&#xff1a;Toast可以在用户…

索尼PMW580视频帧EC碎片重组开启方法

索尼PMW580视频帧EC碎片重组开启方法 索尼PMW-580摄像机生成的MXF文件存在严重的碎片化&#xff0c;目前CHS零壹视频恢复程序MXF版、专业版、高级版已经支持重组结构体正常的碎片&#xff0c;同时也支持对于结构体破坏或者覆盖后仅存在音视频帧EC数据的重组&#xff0c;需要注…