生成对抗网络DCGAN学习实践

news2025/1/18 9:59:59

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型生成出眉毛 A 的特征,则有鼻子 B、C 和 D 相关的备选项;而如果生成出眉毛 E 的特征,则有鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

2.5 模式崩溃

训练时还会出现一种情况,即生成器始终卡在一个生成结果上,比如生成0-9数字,结果训练几轮后始终在生成数字3。
这种情况称为模式崩溃,一般增加训练样本数量并调节参数,没有比较好的办法。

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers

# 定义生成器模型
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# 定义判别器模型
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()

# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def generator_loss(fake_output):
    return loss_fn(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = loss_fn(tf.ones_like(real_output), real_output)
    fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

# 定义训练循环
@tf.function
def train_step(images):
    # 生成噪声向量
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 使用生成器生成假图片
        generated_images = generator(noise, training=True)

        # 使用判别器判断真假图片
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # 计算损失函数
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 计算梯度并更新生成器和判别器的参数
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def generate_and_save_images(model, epoch, test_input):

    predictions = model(test_input, training=False)
    print("predictions.shape:", predictions.shape)
    num_images = predictions.shape[0]
    rows = int(num_images ** 0.5) # 计算行数
    cols = num_images // rows # 计算列数
    
    fig = plt.figure(figsize=(8, 8))
    
    for i in range(num_images):
        plt.subplot(rows, cols, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    #plt.show()

# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5

# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50

# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
    for i,image_batch in enumerate(dataset):
        print("sub i",i)
        train_step(image_batch)

    print("------------------------------------------------------epoch:", epoch)

    # 每个 epoch 结束后生成并保存一组图像
    if (epoch + 1) % 5 == 0:
        seed = tf.random.normal([BATCH_SIZE, 100])
        generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/814021.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

809协议服务端程序解码程序

809协议服务端程序解码程序 目录概述需求: 设计思路实现思路分析1.服务端2.code: 拓展实现性能参数测试:1.功能测试 参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip…

【LeetCode】探索杨辉三角模型

一、题目描述 力扣原题 首先我们要来了解一下题目本身在说些什么,通过下方的动图我们可以更加清楚地看到杨辉三角是怎样一步步生成的。给到的示例中我们通过输入杨辉三角的行数,然后通过计算得到这个杨辉三角的每一行是什么具体的数值 二、模型选择 首先…

大数据技术之ClickHouse---入门篇---介绍

星光下的赶路人star的个人主页 一棵树长到它想长到的高度之后,它才知道怎样的空气适合它 文章目录 1、Clickhouse入门1.1 什么是Clickhouse1.1.1 Clickhouse的特点1.1.1.1 列示储存1.1.1.2 DBMS的功能1.1.1.3 多样化引擎1.1.1.4 高吞吐写入能力1.1.1.5 数据分区与线…

JAVA SE -- 第十三天

(全部来自“韩顺平教育”) 集合 一、集合框架体系 集合主要是两组(单列集合、双列集合) Collection接口有两个重要的子接口List 、Set,它们的实现子类都是单列集合 Map接口的实现子类是双列集合,存放的…

进阶C语言——再识结构体

1 结构的基础知识 结构是一些值的集合,这些值称为成员变量。结构的每个成员可以是不同类型的变量。 2 结构的声明 struct tag {member-list; }variable-list;如果下面来描述一个学生的话,我们会想到学生的姓名,成绩,性别等&#…

【序列化工具JdkSerialize和Protostuff】

序列化工具对比 JdkSerialize:java内置的序列化能将实现了Serilazable接口的对象进行序列化和反序列化, ObjectOutputStream的writeObject()方法可序列化对象生成字节数组 Protostuff:google开源的protostuff采用更为紧凑的二进制数组&#…

1-linux下mysql8.0.33安装

在互联网企业的日常工作/运维中,我们会经常用到mysql数据库,而linux下mysql的安装方式有三种: 1.mysql rpm安装 2.mysql二进制安装 3.mysql源码安装 今天就为大家讲讲linux下mysql8.0.33版本rpm方式的安装。 1.前提 1.1.系统版本 Cent…

目标识别数据集互相转换——xml、txt、json数据格式互转

VOC数据格式与YOLO数据格式互转 1.VOC数据格式 VOC(Visual Object Classes)是一个常用的计算机视觉数据集,它主要用于对象检测、分类和分割任务。VOC的标注格式,也被许多其他的数据集采用,因此理解这个数据格式是很重…

Pytest+Allure+Excel接口自动化测试框架实战

1. Allure 简介 简介 Allure 框架是一个灵活的、轻量级的、支持多语言的测试报告工具,它不仅以 Web 的方式展示了简介的测试结果,而且允许参与开发过程的每个人可以从日常执行的测试中,最大限度地提取有用信息。 Allure 是由 Java 语言开发…

C#文件操作从入门到精通(1)——INI文件操作

点击这里:微软官方文档查看writePrivateProfileString函数定义 常见错误: 1、中文路径写入失败,为啥? 2、文件不是全路径,只有文件名也会写入失败: 3、GetLastError怎么使用? GetLastError错误代码含义: (0)-操作成功完成。 (1)-功能错误。 (2)- 系统找不到指定的文件…

62 # 借用 promise 写成类的方法

新建 62 文件夹&#xff0c;里面添加三个文件 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><tit…

在周末,找回属于自己的时间~

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;全栈领域新星创作者✌&#xff0c;2023年6月csdn上海赛道top4。 &#x1f466;&#x1f3fb;个人主页 &#xff1a; 点击这里 &#x1f4bb;推荐专栏1&#xff1a;PHP面试题专区&#xff08;2023&#xff09; PHP入门基…

【深度学习】Inst-Inpaint: Instructing to Remove Objects with Diffusion Models,指令式图像修复

论文&#xff1a;https://arxiv.org/abs/2304.03246 code:http://instinpaint.abyildirim.com/ 文章目录 AbstractIntroductionRelated WorkDataset GenerationMethodPS Abstract 图像修复任务是指从图像中擦除不需要的像素&#xff0c;并以语义一致且逼真的方式填充它们。传统…

C++基础知识 (引用)

⭐️ 往期相关文章 ✨ 链接&#xff1a;C基础知识&#xff08;命名空间、输入输出、函数的缺省参数、函数重载&#xff09; ⭐️ 引用 引用从语法的层面上讲&#xff1a;引用不是定义一个新的变量&#xff0c;而是给已存在的变量取了一个别名&#xff0c;编译器不会为引用变…

黑马大数据学习笔记3-MapReduce配置和YARN部署以及基本命令

目录 部署说明MapReduce配置文件YARN配置文件分发配置文件集群启动命令开始启动YARN集群 查看YARN的WEB UI页面保存快照YARN集群的启停命令一键启动脚本单进程启停 提交MapReduce任务到YARN执行提交wordcount示例程序查看运行日志提交求圆周率示例程序 p41~43 https://www.bili…

无线温湿度信息中继器模块的组成和工作状态及编程与组网建议

在无线温湿度信息收集系统中&#xff0c;信息中继器模块是连接终端信息点与互联网的重要节点。本文将详细介绍该模块的组成和工作状态&#xff0c;并给出编程和组网的建议。 一、组成 该无线温湿度信息中继器模块由以下几个核心组成部分构成&#xff1a; STM32F103ZET6主控芯片…

17- C++ const和异常-5 (C++)

第六章 C对C的拓展2 6.1 const详解 6.1.1 const 修饰普通变量 被修饰的对象是只读的 const int a; //a的值是只读的 int const a; const int * p; 该语句表示指向整形常量的 指针&#xff0c;它指向的值不能修改。 int const * p; 该语句与b的含义相同&#xff0c;表…

adobe ps beta的使用方法

1、人物换发型。 1&#xff09;套索套选出来相关的头发。 2&#xff09;点击生成&#xff0c;输入“red hair” 按“生成”键。 2、人物换眼睛。 1&#xff09;套索套选出来相关的眼睛区域&#xff0c;大一点范围。 2&#xff09;点击生成&#xff0c;输入“blue eyes"…

【Golang 接口自动化03】 解析接口返回XML

目录 解析接口返回数据 定义结构体 解析函数&#xff1a; 测试 优化 资料获取方法 上一篇我们学习了怎么发送各种数据类型的http请求&#xff0c;这一篇我们来介绍怎么来解析接口返回的XML的数据。 解析接口返回数据 定义结构体 假设我们现在有一个接口返回的数据resp如…

分布式软件架构——内容分发网络

内容分发网络&#xff08;CDN&#xff0c;Content Distribution Network或Content Delivery Network&#xff09; 其基本思路是尽可能避开互联网上有可能影响数据传输速度和稳定性的瓶颈和环节&#xff0c;使内容传输得更快、更稳定。通过在网络各处放置节点服务器所构成的在现…