生成对抗网络DCGAN实践笔记

news2024/9/24 11:25:55

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

在训练中,判别器返回一个数值作为判断结果。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型训练出眉毛 A 的特征,则会生成与鼻子 B、C 和 D 相关的备选项;而如果训练出眉毛 E 的特征,则会生成与鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers

# 定义生成器模型
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# 定义判别器模型
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()

# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def generator_loss(fake_output):
    return loss_fn(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = loss_fn(tf.ones_like(real_output), real_output)
    fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

# 定义训练循环
@tf.function
def train_step(images):
    # 生成噪声向量
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 使用生成器生成假图片
        generated_images = generator(noise, training=True)

        # 使用判别器判断真假图片
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # 计算损失函数
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 计算梯度并更新生成器和判别器的参数
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def generate_and_save_images(model, epoch, test_input):

    predictions = model(test_input, training=False)
    print("predictions.shape:", predictions.shape)
    num_images = predictions.shape[0]
    rows = int(num_images ** 0.5) # 计算行数
    cols = num_images // rows # 计算列数
    
    fig = plt.figure(figsize=(8, 8))
    
    for i in range(num_images):
        plt.subplot(rows, cols, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    #plt.show()

# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5

# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50

# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
    for i,image_batch in enumerate(dataset):
        print("sub i",i)
        train_step(image_batch)

    print("------------------------------------------------------epoch:", epoch)

    # 每个 epoch 结束后生成并保存一组图像
    if (epoch + 1) % 5 == 0:
        seed = tf.random.normal([BATCH_SIZE, 100])
        generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/810624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为OD机试真题2022Q4 A + 2023 B卷(Java)

大家好,我是哪吒。 五月份之前,如果你参加华为OD机试,收到的应该是2022Q4或2023Q1,这两个都是A卷题。 5月10日之后,很多小伙伴收到的是B卷,那么恭喜你看到本文了,抓紧刷题吧。B卷新题库正在更…

微服务 - Consul集群化 · 服务注册 · 健康检测 · 服务发现 · 负载均衡

一、Consul 概括 Consul 是由N多个节点(台机/虚机/容器)组成,每个节点中都有 Agent 运行着,各节点间用RPC通信,所有节点内相同的 Datacenter 名称为一个数据中心,节点又分三种角色 Client/Server/Leader: Agent&…

Python算法笔记(3)-树、二叉树、二叉堆、二叉搜索树

树和二叉树 什么是树 树是一种非线性的数据结构,由n个节点构成的有限集合,节点数0的树叫空树,在任意一棵树中,有且仅有一个特点的称为根节点,当N>1时,其余节点可分m为互不相交的有限集。 例如如下&…

子序列,回文串相关题目

class Solution { public:int dp[2510];int lengthOfLIS(vector<int>& nums) {//dp[i]表示以nums[i]为结尾的最长子序列的长度int nnums.size();for(int i0;i<n;i){dp[i]1;}for(int i1;i<n;i){for(int j0;j<i;j){if(nums[i]>nums[j]){dp[i]max(dp[i],dp[…

因子分解机介绍和PyTorch代码实现

因子分解机&#xff08;Factorization Machines&#xff0c;简称FM&#xff09;是一种用于解决推荐系统、回归和分类等机器学习任务的模型。它由Steffen Rendle于2010年提出&#xff0c;是一种基于线性模型的扩展方法&#xff0c;能够有效地处理高维稀疏数据&#xff0c;并且在…

用Blender做一个足球烯C60

文章目录 作图思路先做一个足球球棍模型平滑 Blender初学者入门&#xff1a;做一个魔方 作图思路 C 60 C_{60} C60​是由60个碳原子构成&#xff0c;形似足球&#xff0c;又名足球烯。而足球的顶点&#xff0c;可以通过正二十面体削去顶点得到&#xff0c;原理可参照这篇&…

基于数据驱动的多尺度表示的信号去噪统计方法研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

React组件进阶之children属性,props校验与默认值以及静态属性static

React组件进阶之children属性,props校验与默认值以及静态属性static 一、children属性二、props校验2.1 props说明2.2 prop-types的安装2.3 props校验规则2.4 props默认值 三、静态属性static 一、children属性 children 属性&#xff1a;表示该组件的子节点&#xff0c;只要组…

网站创建004:跟用户交互的标签

input 系列&#xff1a; <body><input type"text" /> <!--文本输入框--><input type"password" /> <!--密码输入框--><input type"checkbox" /> <!--复选框--><input type"checkbox"…

【MySQL】使用C语言连接

​&#x1f320; 作者&#xff1a;阿亮joy. &#x1f386;专栏&#xff1a;《零基础入门MySQL》 &#x1f387; 座右铭&#xff1a;每个优秀的人都有一段沉默的时光&#xff0c;那段时光是付出了很多努力却得不到结果的日子&#xff0c;我们把它叫做扎根 目录 &#x1f449;my…

用CSS和HTML写一个水果库存静态页面

HTML代码&#xff1a; <!DOCTYPE html> <html> <head><link rel"stylesheet" type"text/css" href"styles.css"> </head> <body><header><h1>水果库存</h1></header><table>…

函数指针及其使用

类比 数组的地址 函数的地址 数组指针 函数的指针 函数指针的运用 有趣的代码1

从0到1构建基于自身业务的前端工具库

前言 在实际项目开发中无论 M 端、PC 端&#xff0c;或多或少都有一个 utils 文件目录去管理项目中用到的一些常用的工具方法&#xff0c;比如&#xff1a;时间处理、价格处理、解析url参数、加载脚本等&#xff0c;其中很多是重复、基础、或基于某种业务场景的工具&#xff0…

链表(一) 单链表操作详解

文章目录 一、什么是链表二、链表的分类1、单向或者双向2、带头或不带头3、循环或不循环 三、无头单向不循环链表的实现SList.hSList.c动态申请一个节点单链表打印单链表尾插单链表头插单链表的尾删单链表头删单链表查找在pos位置前插入单链表在pos位置之后插入x删除pos位置单链…

自动驾驶下半场的“入场券”

交流群 | 进“传感器群/滑板底盘群/汽车基础软件群/域控制器群”请扫描文末二维码&#xff0c;添加九章小助手&#xff0c;务必备注交流群名称 真实姓名 公司 职位&#xff08;不备注无法通过好友验证&#xff09; 作者 | 张萌宇 自动驾驶战争的上半场拼的是硬件和算法&…

DTC介绍

DTC 一般由3个字节组成&#xff1a; 字节1&#xff1a;High Byte bit 7-6: 对应DTC属于哪一个系统&#xff0c;P: 00动力系统、C: 01底盘、B: 10车身和U: 11通信系统bit 5-4: 用来区分DTC是标准组织所定义还是制造商自定义 00: ISO/SAE01: 制造商10: ISO/SAE11: ISO/SAE bit 3…

【Rust教程 | 基础系列2 | Cargo工具】Cargo介绍及使用

文章目录 前言一&#xff0c;Cargo介绍1&#xff0c;Cargo安装2&#xff0c;创建Rust项目2&#xff0c;编译项目&#xff1a;3&#xff0c;运行项目&#xff1a;4&#xff0c;测试项目&#xff1a;5&#xff0c;更新项目的依赖&#xff1a;6&#xff0c;生成项目的文档&#xf…

python皮卡丘字符打印代码,用python皮卡丘的代码

大家好&#xff0c;本文将围绕python皮卡丘字符打印代码展开说明&#xff0c;python皮卡丘编程代码教程是一个很多人都想弄明白的事情&#xff0c;想搞清楚python皮卡丘编程代码需要先了解以下几个事情。 1、我用python画皮卡丘&#xff0c;没有错误出现&#xff0c;我也打开才…

内网横向移动—NTLM-Relay重放Responder中继攻击LdapEws

内网横向移动—NTLM-Relay重放&Responder中继攻击&Ldap&Ews 1. 前置了解1.1. MSF与CS切换权限1.1.1. CS会话中切换权限1.1.1.1. 查看进程1.1.1.2. 权限权限 1.1.2. MSF会话中切换权限 2. NTLM中继攻击—Relay重放—SMB上线2.1. 案例测试2.1.1. 同账户密码测试2.1.2…

如何使用CRM系统进行客户关系维护管理?

企业要想持续的发展&#xff0c;就必须管理和维护与客户的关系。但如今客户需求更加复杂和多样化&#xff0c;维护客户关系的难度越来越大。许多企业使用CRM系统来帮助自己管理客户关系。通过本文&#xff0c;让您客户关系维护管理全知道。 1、客户画像 CRM系统可以帮助企业建…