生成对抗网络DCGAN学习

news2025/1/11 12:36:13

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型生成出眉毛 A 的特征,则有鼻子 B、C 和 D 相关的备选项;而如果生成出眉毛 E 的特征,则有鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

2.5 模式崩溃

训练时还会出现一种情况,即生成器始终卡在一个生成结果上,比如生成0-9数字,结果训练几轮后始终在生成数字3。
这种情况称为模式崩溃,一般增加训练样本数量并调节参数,没有比较好的办法。

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers

# 定义生成器模型
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# 定义判别器模型
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()

# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def generator_loss(fake_output):
    return loss_fn(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = loss_fn(tf.ones_like(real_output), real_output)
    fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

# 定义训练循环
@tf.function  #这个是tensorflow的装饰器,标记后可提升性能,不加此标记也可
def train_step(images):
    # 生成噪声向量
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 使用生成器生成假图片
        generated_images = generator(noise, training=True)

        # 使用判别器判断真假图片
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # 计算损失函数
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 计算梯度并更新生成器和判别器的参数
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def generate_and_save_images(model, epoch, test_input):

    predictions = model(test_input, training=False)
    print("predictions.shape:", predictions.shape)
    num_images = predictions.shape[0]
    rows = int(num_images ** 0.5) # 计算行数
    cols = num_images // rows # 计算列数
    
    fig = plt.figure(figsize=(8, 8))
    
    for i in range(num_images):
        plt.subplot(rows, cols, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    #plt.show()

# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5

# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50

# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
    for i,image_batch in enumerate(dataset):
        print("sub i",i)
        train_step(image_batch)

    print("------------------------------------------------------epoch:", epoch)

    # 每个 epoch 结束后生成并保存一组图像
    if (epoch + 1) % 5 == 0:
        seed = tf.random.normal([BATCH_SIZE, 100])
        generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/826649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在 Ubuntu 上部署 ONLYOFFICE 协作空间社区版?

ONLYOFFICE 协作空间是一个在线协作平台,帮助您更好地与客户、业务合作伙伴、承包商及第三方进行文档协作。今天我们来介绍一下,如何在 Ubuntu 上安装协作空间的自托管版。 ONLYOFFICE 协作空间主要功能 使用 ONLYOFFICE 协作空间,您可以&am…

Golang 函数参数的传递方式 值传递,引用传递

基本介绍 我们在讲解函数注意事项和使用细节时,已经讲过值类型和引用类型了,这里我们再系统总结一下,因为这是重难点,值类型参数默认就是值传递,而引用类型参数默认就是引用传递。 两种传递方式(函数默认都…

MySQL数据库备份

目录 一、概述 二、数据备份的重要性 三、造成数据丢失的原因 四、备份类型 五、常见的备份方法 六、备份 一、概述 数据库备份是指将数据库中的数据、表格、视图、存储过程、触发器等信息备份到另一个地方,以便在数据库丢失或损坏时进行恢复。数据库备份是数…

以CS32F031为例浅说国产32位MCU的内核处理器

芯片内核又称CPU内核,它是CPU中间的核心芯片,是CPU最重要的组成部分。由单晶硅制成,CPU所有的计算、接受/存储命令、处理数据都由核心执行。各种CPU核心都具有固定的逻辑结构,一级缓存、二级缓存、执行单元、指令级单元和总线接口…

【C++】string类

目录 🌞专栏导读 🌛为什么学习string类? ⭐C语言中的字符串 🌛标准库中的string类 ⭐基本使用string ⭐string类的常用接口 ⭐总结: 🌛范围for的使用 🌞专栏导读 🌟作者简介…

docker更换数据存储路径

1. 先停掉docker服务 sudo systemctl stop docker 可能会出现的问题: 这样会导致docker关闭失败,解决办法:systemctl stop docker.socket 确保docker关闭: 2.备份现在的 Docker 数据存储目录 /var/lib/docker(默认路径) mv /var/lib/docker /var/lib/…

高斯函数的傅里叶变换与离散化频谱分析

1 高斯函数的傅里叶变换 主要参考自 http://www.cse.yorku.ca/~kosta/CompVis_Notes/fourier_transform_Gaussian.pdf 对于中心化的高斯函数,即 g ( x ) 1 2 π σ e − x 2 2 σ 2 , (1.1) g\left( x \right) \frac{1}{{\sqrt {2\pi } \sigma }}{e^{ - \frac{{{x…

【有趣的】关于Map的一些小测试

Map在代码中用到得非常多,它是无序的、key-value结构的,其读取会非常快。 今天看了个小文章Map判空 、空字符串、空key值等各种判断方法,你都掌握了吗?便自己也玩一下。 一、判空 因为对象已经new出来了,所以map指向的…

【洁洁送书第三期】人性的光辉,python之光

这里写目录标题 python学习现状python之光亮点python学习配套视频python之光目录强力推荐 python学习现状 作为生产力工具,Python是当今极为流行的编程语言。Python编程逐渐成为一项通用能力,从小学生到各个行业的从业人员都在学Python。Python确实能够…

高忆管理:多重利好共振 外资加码布局A股

资本商场活泼信号正在继续开释,内外资决心取得有力提振。以北向资金为代表的外资近来表现活泼,近六个买卖日已连续净买入超500亿元。多家外资组织近期表态称,伴跟着方针力度加强,我国经济有望继续复苏,活泼看好我国权益…

优思学院|质量工程师应具备什么能力?

质量工程师是一个需要耐心、细心、坚持态度、沟通能力、协调能力的工作,更需要持续学习强化自身的专业知识。 质量工程师负责审核、客户投诉的调查、过程的改进以达到质量之提升,他們也必须要预警生产线风险、质量异常,并且协调不同的部門一…

【雕爷学编程】Arduino动手做(181)---Maixduino AI开发板2

37款传感器与执行器的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&am…

[语义分割] ASPP不同版本对比(DeepLab、DeepLab v1、DeepLab v2、DeepLab v3、DeepLab v3+、LR-ASPP)

1. 引言 1.1 本文目的 本文主要对前段时间学习的 ASPP 模块进行对比,涉及到的 ASPP 有: ASPP in DeepLab v2,简称 ASPP v2ASPP in DeepLab v3,简称 ASPP v3ASPP in DeepLab v3,简称 ASPP v3ASPP in MobileNet v3&am…

开发提测?

前言 开发提测是正式开始测试的重要关卡,提测质量的好坏会直接影响测试阶段的效率,进而影响项目进度。较好的提测质量,对提高测试效率和优化项目进度有着事半功倍的作用。如何更好的推进开发提高提测质量呢?下面小编结合自己项目…

攻防世界-web-lottery

题目描述:里面有个附件,是网站的源代码,还有一个链接,是线上的网站 主页告诉了我们规则: 1. 每个人的初始金额为20美元 2. 一支彩票2美元,挑选7个数字,根据匹配上的数字有不同的奖励 我们先体…

基于dynamorio自制反汇编小工具 instr_trace安装

目录 概述一、下载源码二、安装dynamorio1、安装依赖2、编译3、测试安装是否成功参考截图 三、安装instr_trace工具1、文件说明2、编译3、运行 四、生成的文件格式说明(1)mov指令(寄存器->寄存器)(2)mov…

用 Yara 对红队工具 “打标“

​前言: YARA 通常是帮助恶意软件研究人员识别和分类恶意软件样本的工具,它基于文本或二进制模式创建恶意样本的描述规则,每个规则由一组字符串和一个布尔表达式组成,这些表达式决定了它的逻辑。 但是这次我们尝试使用 YARA 作为一种扫描工…

ELK日志分析系统介绍及搭建(超详细)

目录 一、ELK日志分析系统简介 二、Elasticsearch介绍 2.1Elasticsearch概述 三、Logstash介绍 四、Kibana介绍 五、ELK工作原理 六、部署ELK日志分析系统 6.1ELK Elasticsearch 集群部署(在Node1、Node2节点上操作) 6.2部署 Elasticsearch 软件 …

Chrome 调试技巧

有时候,qa测试会忽然出问题,然后需要你刷新界面,按照他的操作再来一次。 下面介绍一个更好的办法。 可以让qa打开chrome里的这个选项 Pause on uncaught 是在遇到未try的错误时暂停 下面那个是在try的时候出错时暂停 chrome自动断点 是不是特…

latex subfloat出现双括号的问题

使用latex的subfloat插入子图,编译完以后出现双括号: 搞了很长时间没搞出来,在网上查阅资料得到,在加载包的部分把subfigure去掉(不知道为什么,我并没有使用subfigure包啊,这是在头部引用中引入…