DCGAN--Keras实现

news2025/1/15 6:31:15

文章目录

  • 一、Keras与tf.keras?
  • 二、keras中Model的使用
  • 三、使用Keras来实现DCGan
    • 1、导入必要的包
    • 2.指定模型输入维度:图像尺寸和噪声向量 的长度
    • 3、构建生成器
    • 4、构造鉴别器
    • 5、构建并编译DCGan
    • 6、对模型进行训练
    • 7、显示生成图像
    • 8、运行模型
  • 总结


一、Keras与tf.keras?

这个博客简单明了
前两篇文章可以看看,会通透很多。
Keras是一个高级API,不过如果我们想要自定义损失函数或者其他,只能使用Tensorflow来定义
tf,keras包含了所有Keras API,所以,我们最好还是使用tf.keras

二、keras中Model的使用

见此篇博客

三、使用Keras来实现DCGan

在这里使用的是Minst 手写数据集

我是在colab实现的,具体的版本如下:

1、导入必要的包

代码如下(示例):

%matplotlib inline
 
import matplotlib.pyplot as plt
import numpy as np
 
from keras.datasets import mnist
from keras.layers import Activation, BatchNormalization, Dense, Dropout, Flatten, Reshape
from keras.layers import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential
from keras.optimizers import Adam

2.指定模型输入维度:图像尺寸和噪声向量 的长度

代码如下(示例):

img_rows = 28
img_cols = 28
channels = 1
 
# 输入图像的维度
img_shape = (img_rows, img_cols, channels)
 
# 噪声向量Z的长度
z_dim = 100

3、构建生成器

解释:生成器其实是一个卷积的逆过程,卷积网络一般被用于图像分类中,在卷积的过程中,我们不断的减少宽高提升通道数,但是在生成器中,我们输入的是一个随机向量,我们需要让它最终输出一个图像,这里输出的是28**28*1的灰度图。
在这里插入图片描述
综合所有步骤如下。
(1)取一个随机噪声向量 ,通过全连接层将其重塑为7×7×256
张量。
(2)使用转置卷积,将7×7×256张量转换为14×14×128张量。
(3)应用批归一化和LeakyReLU激活函数。
(4)使用转置卷积,将14×14×128张量转换为14×14×64张
量。注意:宽度和高度尺寸保持不变。可以通过将
Conv2DTranspose中的stride参数设置为1来实现。
(5)应用批归一化和LeakyReLU激活函数。
(6)使用转置卷积,将14×14×64张量转换为输出图像大小
28×28×1。
(7)应用tanh激活函数。

def build_generator(z_dim):
   #序列化model
    model = Sequential()
 
    # 通过全连接层将输入重新调整大小7*7*256的张量
    model.add(Dense(256 * 7 * 7, input_dim=z_dim))
    model.add(Reshape((7, 7, 256)))
 
    # 通过转置卷积层将7*7*256的张量转换为14*14*128的张量
    model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))
 
    # 批归一化
    model.add(BatchNormalization())
 
    # Leaky ReLU 激活函数
    model.add(LeakyReLU(alpha=0.01))
 
    # 通过转置卷积层将14*14*128的张量转换为14*14*64的张量
    model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'))
 
    # 批归一化
    model.add(BatchNormalization())
 
    # Leaky ReLU 激活函数
    model.add(LeakyReLU(alpha=0.01))
 
    # 通过转置卷积层将14*14*64的张量转换为28*28*1的张量
    model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'))
 
    # 带有tanh激活函数的的输出层
    model.add(Activation('tanh'))
 
    return model

4、构造鉴别器

鉴别器与普通的卷积神经网络用来进行图像分类没什么区别,输入图像,输出一个值用来判断图像真伪。在这里插入图片描述

综合所有步骤如下。
(1)使用卷积层将28×28×1的输入图像转换为14×14×32的张
量。
(2)应用LeakyReLU激活函数。
(3)使用卷积层将14×14×32的张量转换为7×7×64的张量。
(4)应用批归一化和LeakyReLU激活函数。
(5)使用卷积层将7×7×64的张量转换为3×3×128的张量。
(6)应用批归一化和LeakyReLU激活函数。
(7)将3×3×128张量展成大小为3×3×128=1152的向量。
(8)使用全连接层,输入sigmoid激活函数计算输入图像是否真
实的概率。

def build_discriminator(img_shape):
 
    model = Sequential()
 
    # 通过卷积层将大小为28*28*1的张量转变为14*14*32的张量
    model.add(Conv2D(32, kernel_size=3,strides=2, input_shape=img_shape,padding='same'))
 
    # Leaky ReLU 激活函数
    model.add(LeakyReLU(alpha=0.01))
 
    # 通过卷积层将大小为14*14*32的张量转变为7*7*64的张量
    model.add(
        Conv2D(64,
               kernel_size=3,
               strides=2,
               input_shape=img_shape,
               padding='same'))
 
    # 批归一化
    model.add(BatchNormalization())
 
    # Leaky ReLU 激活函数
    model.add(LeakyReLU(alpha=0.01))
 
    # 通过卷积层将7*7*64的张量转变为3*3*128的张量
    model.add(
        Conv2D(128,
               kernel_size=3,
               strides=2,
               input_shape=img_shape,
               padding='same'))
 
    # 批归一化
    model.add(BatchNormalization())
 
    # Leaky ReLU 激活函数    
    model.add(LeakyReLU(alpha=0.01))
 
    # 带有sigmoid激活函数的输出层
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
 
    return model

5、构建并编译DCGan

def build_gan(generator, discriminator):
    model = Sequential()
 
    # 将生成器和鉴定器结合到一起
    model.add(generator)
    model.add(discriminator)
 
    return model
# 构建并编译鉴定器(使用了二元交叉熵作为损失函数,Adam的优化算法)
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])
 
# 构建生成器
generator = build_generator(z_dim)
 
# 在生成器训练的时候,将鉴定器的参数固定
discriminator.trainable = False
 
#构建并编译固定的鉴定器的GAN模型,并训练生成器
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())
# 在生成器训练的时候,将鉴定器的参数固定
discriminator.trainable = False
 
#构建并编译固定的鉴定器的GAN模型,并训练生成器
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

6、对模型进行训练

losses = []
accuracies = []
iteration_checkpoints = []
 
 
def train(iterations, batch_size, sample_interval):
 
    # 导入mnist数据集
    (X_train, _), (_, _) = mnist.load_data()
 
    # 灰度像素值从[0,255]缩放到[-1, 1]
    X_train = X_train / 127.5 - 1.0
    X_train = np.expand_dims(X_train, axis=3)
 
    # 真实图像的标签都为1
    real = np.ones((batch_size, 1))
 
    # 假图像的标签都为0
    fake = np.zeros((batch_size, 1))
 
    for iteration in range(iterations):
 
        # -------------------------
        #  训练鉴定器
        # -------------------------
 
        # 抽取真实图像的一个批次
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
 
        # 生成一批次的假图像
        z = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(z)
 
        # 训练鉴定器
        d_loss_real = discriminator.train_on_batch(imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)
 
        # ---------------------
        # 训练生成器
        # ---------------------
 
        # 生成一批次的假照片
        z = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(z)
 
        # 训练生成器
        g_loss = gan.train_on_batch(z, real)
 
        if (iteration + 1) % sample_interval == 0:
 
            # 保存损失和准确率以便训练后绘图
            losses.append((d_loss, g_loss))
            accuracies.append(100.0 * accuracy)
            iteration_checkpoints.append(iteration + 1)
 
            # 输出训练过程
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (iteration + 1, d_loss, 100.0 * accuracy, g_loss))
 
            # 输出生成图像的采样
            sample_images(generator)

7、显示生成图像

为完整起见,清单4.7包含了sample_images()函数,它在指定的
训练迭代中输出一个4×4的图像网格。

def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
 
    # 样本的随机噪声(4*4张的合成图)
    z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
 
    # 从随机噪声中生成图像
    gen_imgs = generator.predict(z)
 
    # 将图像像素值重新缩放为[0,1]内
    gen_imgs = 0.5 * gen_imgs + 0.5
 
    # 建立图像网格
    fig, axs = plt.subplots(image_grid_rows,
                            image_grid_columns,
                            figsize=(4, 4),
                            sharey=True,
                            sharex=True)
 
    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            # 输出一个图像网格
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1

8、运行模型

# 设置超参数
iterations = 10000
batch_size = 128
sample_interval = 1000
 
# 训练模型直到指定的迭代次数
train(iterations, batch_size, sample_interval)

总结

遇到的问题:
1、keras的更新
2、TensorFlow中padding卷积的两种方式“SAME”和“VALID”与我之前学习过的深度学习不同。tensorflow见此
3、numpy中expand_dims()函数详解
4、np.random.randint()的用法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/524695.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣sql中等篇练习(二十)

力扣sql中等篇练习(二十) 1 寻找面试候选人 1.1 题目内容 1.1.1 基本题目信息1 1.1.2 基本题目信息2 1.1.3 示例输入输出 a 示例输入 b 示例输出 1.2 示例sql语句 # 分为以下两者情况,分别考虑,然后union进行处理(有可能同时满足,需要去进行去重) # ①该用户在 三场及更多…

软件测试八股文,软件测试常见面试合集【附答案】

PS:加上参考答案有几十万字,答案就没有全部放上来了,高清打印版本超过400多页,评论区留言直接获取 1、你的测试职业发展是什么? 2、你认为测试人员需要具备哪些素质 3、你为什么能够做测试这一行 4、测试的目的是什么? 5、测…

一图看懂 attrs 模块:一个在类定义时可替换 `__init__`, `__eq__`, `__repr__`等方法的样板,资料整理+笔记(大全)

本文由 大侠(AhcaoZhu)原创,转载请声明。 链接: https://blog.csdn.net/Ahcao2008 一图看懂 attrs 模块:一个在类定义时可替换 __init__, __eq__, __repr__等方法的样板,资料整理笔记(大全) 🧊摘要&#x1…

吴恩达|chatgpt 提示词工程师学习笔记。

目录 一、提示指南 写提示词的2大原则: 模型的限制 二、迭代 三、总结 四、推断 五、转换 六、扩展 七、对话机器人 吴恩达和openai团队共同开发了一款免费的课程,课程是教大家如何更有效地使用prompt来调用chatgpt,整个课程时长1个…

ctfshow周末大挑战2023/5/12

本周周末大挑战用到的函数讲解 parse_url() 作用:解析URL,返回其组成部分 语法: parse_url ( string $url [, int $component -1 ] ) 参数: url:要解析的 URL。无效字符将使用 _ 来替换。 component: …

Sentinel———隔离和降级

FeignClient整合Sentinel SpringCloud中,微服务调用都是通过Feign来实现的,因此做客户端保护必须整合Feign和Sentinel。 第一步 修改OrderService的application.yml文件,开启Feign的Sentinel功能(消费者服务) feig…

算法基础第二章

算法基础第二章 第二章:数据结构1、链表1.1、单链表(写邻接表:存储图和树)1.2、双链表(优化某些问题) 2、栈与队列2.1、栈2.1.1、数组模拟栈2.1.2、单调栈 2.2、队列2.2.1、数组模拟队列2.2.2、滑动窗口(单调队列的使用…

操作系统实验二 进程(线程)同步

前言 实验二相比实验一难度有所提升,首先得先掌握好相应的理论知识(读者-写者问题和消费者-生产者问题),才能在实验中得心应手。任务二的代码编写可以借鉴源码,所以我们要先读懂源码。 1.实验目的 掌握Linux环境下&a…

linux系统状态检测命令

1、ifconfig命令 用于获取网卡配置于状态状态的等信息: ens33:网卡名称 inet:ip地址 ether:网卡物理地址(mac地址) RX、TX:接收数据包与发送数据包的个数及累计流量 我们也可以直接通过网卡名称查对应信息: 2、查看系统版本的…

设计模式 - 工厂 Factory Method Pattern

文章参考来源 一、概念 创建简单的对象直接 new 一个就完事,但对于创建时需要各种配置的复杂对象例如手机,没有工厂的情况下,用户需要自己处理屏幕、摄像头、处理器等配置,这样用户和手机就耦合在一起了。 可以使代码结构清晰&a…

【人工智能】— 贝叶斯网络

【人工智能】— 贝叶斯网络 频率学派 vs. 贝叶斯学派贝叶斯学派Probability(概率):独立性/条件独立性:Probability Theory(概率论):Graphical models (概率图模型)什么是图模型(Grap…

【每日一题/哈希表运用题】1054. 距离相等的条形码

⭐️前面的话⭐️ 本篇文章介绍【距离相等的条形码】题解,题目标签【哈希表】, 【贪心】,【优先级队列】,展示语言c/java。 📒博客主页:未见花闻的博客主页 🎉欢迎关注🔎点赞&#…

【计算机网络复习】第四章 网络层 2

源主机网络层的主要工作 路由器网络层的主要工作 目的主机网络层的主要工作 网络层提供的服务 o 屏蔽底层网络的差异,向传输层提供一致的服务 虚电路网络 o 虚电路网络提供面向连接的服务 n 借鉴了电路交换的优点 n 发送数据之前,源主机和目的主机…

MTK耳机识别

MTK耳机检测分为Eint only和EintAccdet 其中主流的是Eint Accdet(multi-key)。 图为MTK 耳机相关电路图的主要部分。 其中,左右声道的33pF主要滤除TDD干扰。串的10R100nf下地电容为低通滤波器。磁珠主要影响的是Fm以及音频THD性能。 Eint:检测耳机是否…

网络基础知识(3)——初识TCP/IP

首先给大家说明的是,TCP/IP 协议它其实是一个协议族,包含了众多的协议,譬如应用层协议 HTTP、 FTP、MQTT…以及传输层协议 TCP、UDP 等这些都属于 TCP/IP 协议。 所以,我们一般说 TCP/IP 协议,它不是指某一个具体的网络…

Casdoor 开始

Casdoor 是一个基于 OAuth 2.0 / OIDC 的中心化的单点登录(SSO)身份验证平台,简单来说,就是 Casdoor 可以帮你解决用户管理的难题,你无需开发用户登录、注册等与用户鉴权相关的一系列功能,只需几个步骤进行…

C++多线程中共享变量同步问题

目录 1、互斥量 (1)std::mutex (2)std::recursive_mutex (3)std::timed_mutex 2、锁管理器 (1)std::lock_guardlk (2)std::unique_locklk &#xff0…

掌控MySQL并发:深度解析锁机制与并发控制

前一篇MySQL读取的记录和我想象的不一致——事物隔离级别和MVCC 讲了事务在并发执行时可能引发的一致性问题的各种现象。一般分为下面3种情况: 读 - 读情况:并发事务相继读取相同的记录。读取操作本身不会对记录有任何影响,不会引起什么问题&…

【C++】C++中的多态

目录 一.多态的概念二.多态的定义及实现2.1虚函数2.2虚函数的重写虚函数重写的两个例外 2.3多态的构成条件2.4C11 override 和final2.5重载、重写、隐藏的对比 三.抽象类3.1概念3.2接口继承和实现继承 四.多态的原理4.1虚函数表4.2多态的原理(1)代码分析(2)清理解决方案 4.3动态…

MySQL高阶语句与连接

目录 高级查询selectDISTINCTWHEREAND ORINBETWEEN通配符与likeORDER BY数学函数聚合函数字符串函数mysql进阶查询GROUP BYHAVING别名子查询EXISTS连接查询inner join(内连接)left join(左连接)right join(右连接)自我连接 高级查询 实验准备: 第一张表&#xff1a…