深度学习--对抗生成网络(GAN, Generative Adversarial Network)

news2024/12/23 14:04:28

对抗生成网络(GAN, Generative Adversarial Network)是一种深度学习模型,由Ian Goodfellow等人在2014年提出。GAN主要用于生成数据,通过两个神经网络相互对抗,来生成以假乱真的新数据。以下是对GAN的详细阐述,包括其概念、作用、核心要点、实现过程、代码实现和适用场景。

1. 概念

GAN由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。

  • 生成器负责生成伪造的样本数据,它的目标是生成足够真实的数据,使判别器难以区分。
  • 判别器负责区分数据是真实的(来自训练数据集)还是生成的(来自生成器)。

这两个网络通过博弈的方式相互对抗:

  • 生成器尝试欺骗判别器,生成与真实数据无差别的虚假数据;
  • 判别器试图提高辨别能力,正确区分真假数据。

最终的目标是使生成器生成的数据越来越接近于真实数据,直至判别器无法区分两者。

2. 作用

GAN的主要作用是生成新数据,常用于图像生成、数据增强、艺术创作等领域。它的优势在于无需明确的监督信号,仅通过数据分布的隐含特征进行学习和生成。

具体应用包括:

  • 图像生成:例如生成逼真的人脸、风景等图像。
  • 数据增强:扩充小样本数据集,改善模型训练效果。
  • 超分辨率重建:将低分辨率图像生成高分辨率图像。
  • 风格转换:将一种图像风格转换为另一种,例如将照片转化为绘画风格。
  • 生成虚拟数据:例如医学影像、合成声音、文本等。

3. 核心要点

GAN的核心在于生成器和判别器的相互博弈,这种机制使模型能够自我优化,但同时也存在一些关键挑战和要点:

  • 损失函数:GAN的损失函数是基于极小极大博弈的。生成器的目标是最大化判别器的损失,即让判别器判断出错;而判别器的目标是最小化这个损失,使其能够更好地区分真假数据。

    通常使用交叉熵损失(Binary Cross-Entropy)来优化生成器和判别器:

  • 模式崩溃:生成器有时会陷入生成某些特定模式的数据(称为模式崩溃),即生成器输出的多样性不足,难以生成多样的真实数据。为了解决这一问题,改进的GAN模型(如WGAN)引入了不同的损失函数和训练策略。

  • 平衡训练:生成器和判别器的训练需要保持平衡,过强的判别器会导致生成器无法学习,而过强的生成器又会让判别器失效。训练GAN时,需要小心调节它们的训练速率。

  • 网络架构:生成器和判别器的网络结构设计非常重要,通常使用深度卷积神经网络(DCNN)进行构建,尤其在图像生成任务中,DCGAN(Deep Convolutional GAN)表现优异。

4. 实现过程

GAN的实现过程包括以下几个步骤:

  1. 数据准备:选择训练数据集,例如图像或其他类型的数据集,通常需要大量真实样本。

  2. 生成噪声:生成器的输入是随机噪声,一般从高维的均匀分布或正态分布中采样。

  3. 构建生成器网络:生成器将噪声数据映射为真实数据的空间,通过深度神经网络进行逐层生成,最终输出一个逼真的样本。

  4. 构建判别器网络:判别器是一个二分类网络,输入为真实数据或生成器生成的数据,输出为其判断的概率值(0-1之间,表示真假)。

  5. 训练:采用交替训练方式,先固定生成器,训练判别器;再固定判别器,训练生成器。这个过程不断循环,生成器和判别器相互竞争,直至生成器的生成能力足以欺骗判别器。

  6. 模型评估:训练过程中,使用对抗损失或其他指标来评估生成器和判别器的效果。视觉上,生成的图像逐渐从粗糙变得逼真。

5.GAN的代码实现

下面是一个简单的GAN实现,用于生成与MNIST数据集类似的手写数字图像。

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist

# 设置随机种子,便于复现
np.random.seed(1000)
tf.random.set_seed(1000)

# 超参数设置
latent_dim = 100  # 生成器输入的噪声维度
batch_size = 128
epochs = 10000
save_interval = 1000

# 1. 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = (x_train - 127.5) / 127.5  # 将图像归一化到[-1, 1]
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 重塑为28x28x1的图像

# 2. 创建生成器模型
def build_generator():
    model = Sequential()
    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model

# 3. 创建判别器模型
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(Dense(1, activation='sigmoid'))  # 输出0或1,判断真伪
    return model

# 4. 编译生成器和判别器
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

# 5. 创建并编译GAN模型
discriminator.trainable = False  # 固定判别器,训练时只训练生成器
gan_input = tf.keras.Input(shape=(latent_dim,))
generated_image = generator(gan_input)
validity = discriminator(generated_image)

gan = tf.keras.Model(gan_input, validity)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

# 6. 训练GAN
def train(epochs, batch_size=128, save_interval=100):
    half_batch = int(batch_size / 2)

    for epoch in range(epochs):
        # 训练判别器
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        real_images = x_train[idx]

        noise = np.random.normal(0, 1, (half_batch, latent_dim))
        generated_images = generator.predict(noise)

        real_labels = np.ones((half_batch, 1))
        fake_labels = np.zeros((half_batch, 1))

        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        valid_labels = np.ones((batch_size, 1))

        g_loss = gan.train_on_batch(noise, valid_labels)

        # 每隔save_interval保存并展示一次结果
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")
            save_images(epoch)

# 7. 生成并保存图像
def save_images(epoch):
    noise = np.random.normal(0, 1, (25, latent_dim))
    gen_images = generator.predict(noise)
    gen_images = 0.5 * gen_images + 0.5  # 缩放回[0, 1]区间

    fig, axs = plt.subplots(5, 5)
    cnt = 0
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig(f"gan_images/mnist_{epoch}.png")
    plt.close()

# 开始训练
train(epochs=epochs, batch_size=batch_size, save_interval=save_interval)

6. 适用场景

GAN适用于许多生成任务,特别是那些需要从数据中提取复杂模式的任务:

  • 图像生成与修复:GAN可用于生成逼真的图像,修复图像中的缺失部分。
  • 数据增强:在数据稀缺的场景下,GAN可以生成类似于训练数据的样本,帮助改进模型的泛化能力。
  • 超分辨率图像重建:通过生成细节清晰的高分辨率图像,应用于图像处理、视频质量提升等场景。
  • 风格迁移:通过GAN实现不同风格的图像、视频转换,例如将照片转为艺术风格画。
  • 医学影像生成:GAN可以生成医学图像,例如CT扫描、MRI数据等,辅助疾病检测与诊断。
  • 文本到图像生成:通过输入文本描述,GAN可以生成与描述相匹配的图像,应用于自动图像生成等场景。

总结

对抗生成网络(GAN)是近年来在生成式模型领域的重要突破,通过生成器与判别器的对抗博弈,GAN能够生成高度逼真的数据。其应用范围广泛,涵盖了图像生成、数据增强、超分辨率重建、风格迁移等多个领域。然而,GAN的训练过程具有挑战性,特别是在平衡两者的对抗关系上仍然存在技术难题。随着技术的不断发展,GAN在生成数据、创造内容等方面的应用前景将更加广阔。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2113494.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【原创】java+swing+mysql简易员工管理系统设计与实现

个人主页:程序员杨工 个人简介:从事软件开发多年,前后端均有涉猎,具有丰富的开发经验 博客内容:全栈开发,分享Java、Python、Php、小程序、前后端、数据库经验和实战 文末有本人名片,希望和大家…

web登录校验

基础登录功能 LoginController PostMapping("/login")Result login(RequestBody Emp emp) {log.info("前端,发送了一个登录请求");Emp e empService.login(emp);return e!null?Result.success():Result.error("用户" "名或密…

isxdigit函数讲解 <ctype.h>头文件函数

目录 1.头文件 2.isxdigit函数使用 方源一把抓住VS2022&#xff0c;顷刻 炼化&#xff01; ​​​​​​​ 1.头文件 以上函数都需要包括头文件<ctype.h> &#xff0c;其中包括 isxdigit 函数 #include<ctype.h> 2.isxdigit函数使用 isxdigit 函数是判断字符…

Leetcode Hot 100刷题记录 -Day10(合并区间)

合并区间 问题描述&#xff1a; 以数组 intervals 表示若干个区间的集合&#xff0c;其中单个区间为 intervals[i] [starti,endi] 。请你合并所有重叠的区间&#xff0c;并返回 一个不重叠的区间数组&#xff0c;该数组需恰好覆盖输入中的所有区间 。 示例 1&#xff1a; 输入&…

vscode从本地安装插件

1. 打开VSCode。 2. 点击左侧菜单中的“扩展”&#xff08;或按CtrlShiftX&#xff09;。 3. 点击“更多操作”&#xff08;三个点&#xff09;> “从VSIX安装”。 4. 选择下载的.vsix文件。 5. 点击“安装”即可安装插件。

IstoreOS安装的1Panel无法安装应用

IstoreOS安装的1Panel无法安装应用&#xff0c;无法安装OpenResty&#xff0c;创建Docker提示文件不存在 这个路径&#xff1a; /root/Configs/1Panel/1panel/apps/openresty/openresty/www /root/Configs/1Panel/1panel/apps/openresty/openresty/1pwaf/data /root/Configs/…

请求响应-02.请求-postman工具

一.前后端分离开发 当前主流的开发模式是前后端分离开发&#xff0c;每开发一个功能&#xff0c;就需要对该功能接口进行测试&#xff0c;当前我们的测试方法是直接将url地址输入到浏览器中&#xff0c;查看web页面是否满足我们的要求。但是浏览器发起的请求全部都是GET请求&am…

【笔记】408刷题笔记

文章目录 三对角三叉树求最小带权路径UDP报文首部和TCP报文首部IP报文首部TCP报文首部UDP报文首部 刷新和再生的区别地址译码 为了区分队空队满&#xff0c;可以使用三种处理方式 1&#xff09;牺牲一个单元 队头指针在队尾指针的下一位置作为队满的标志 队满条件&#xff1a;(…

每日一题,力扣leetcode Hot100之238.除自身以外数组的乘积

乍一看这个题很简单&#xff0c;但是不能用除法&#xff0c;并且在O(N)时间复杂度完成或许有点难度。 考虑到不能用除法&#xff0c;如果我们要计算输出结果位置i的值&#xff0c;我们就要获取这个位置左边的乘积和右边的乘积&#xff0c;那么我新设立两个数组L和R。 对于L来…

Hive 本地启动时报错 Persistence Manager has been closed

Hive 本地启动时报错 Persistence Manager has been closed 2024-09-07 17:21:45 ERROR RetryingHMSHandler:215 - Retrying HMSHandler after 2000 ms (attempt 2 of 10) with error: javax.jdo.JDOFatalUserException: Persistence Manager has been closedat org.datanucle…

使用亚马逊Bedrock的Stable Diffusion XL模型实现文本到图像生成:探索AI的无限创意

引言 什么是Amazon Bedrock&#xff1f; Amazon Bedrock是亚马逊云服务&#xff08;AWS&#xff09;推出的一项旗舰服务&#xff0c;旨在推动生成式人工智能&#xff08;AI&#xff09;在各行业的广泛应用。它的核心功能是提供由顶尖AI公司&#xff08;如AI21 Labs、Anthropic…

基于 RocketMQ 的云原生 MQTT 消息引擎设计

作者&#xff1a;沁君 概述 随着智能家居、工业互联网和车联网的迅猛发展&#xff0c;面向 IoT&#xff08;物联网&#xff09;设备类的消息通讯需求正在经历前所未有的增长。在这样的背景下&#xff0c;高效和可靠的消息传输标准成为了枢纽。MQTT 协议作为新一代物联网场景中…

Windows 11安装nvm教程

1、nvm是什么 nvm 全名 node.js version management&#xff0c;是一个 nodejs 的版本管理工具。通过它可以安装和切换不同版本的 nodejs&#xff0c;主要解决 node 各种版本存在不兼容现象。   在工作中&#xff0c;我们可能同时在进行2个或者多个不同的项目开发&#xff0…

一、Maven工程的GAVP属性及项目结构说明

1、GAVP Maven 中的 GAVP 是指 GroupId、ArtifactId、Version、Packaging 等四个属性的缩写&#xff0c;其中前三个是必要的&#xff0c;而 Packaging 属性为可选项。这四个属性主要为每个项目在maven仓库总做一个标识&#xff0c;类似人的《姓-名》。有了具体标识&#xff0c…

高清4K短视频素材网站有哪些?推荐8个高清4K短视频素材网站

是不是还在为找不到合适的4K高清素材而苦恼&#xff1f;别急&#xff01;今天我为大家精心挑选了8个超级优秀的4K高清短视频素材网站&#xff0c;不仅能让你的视频质量爆表&#xff0c;还能大大提高账号的互动率和曝光度&#xff01;每一个推荐都是精心筛选过的&#xff0c;每一…

【leetcode详解】爬楼梯:DP入门典例(附DP通用思路 同类进阶练习)

实战总结&#xff1a; vector常用方法&#xff1a; 创建一个长为n的vector&#xff0c;并将所有元素初始化为某一定值x vector<int> vec(len, x) 代码执行过程中将所有元素更新为某一值x fill(vec.begin(), vec.end(), x) // 更多实战方法欢迎参考文章&#xff1a;…

SpringBoot教程(十五) | SpringBoot集成RabbitMq(消息丢失、消息重复、消息顺序、消息顺序)

SpringBoot教程&#xff08;十五&#xff09; | SpringBoot集成RabbitMq&#xff08;消息丢失、消息重复、消息顺序、消息顺序&#xff09; RabbitMQ常见问题解决方案问题一&#xff1a;消息丢失的解决方案&#xff08;1&#xff09;生成者丢失消息丢失的情景解决方案1&#xf…

TensorRT-LLM高级用法

--multi_block_mode decoding phase, 推理1个新token&#xff0c; 平时&#xff1a;按照batch样本&#xff0c;按照head&#xff0c;将计算平均分给所有SM&#xff1b; batch_size*num_heads和SM数目相比较小时&#xff1a;有些SM会空闲&#xff1b;加了--multi_block_mode&…

JavaScript 知识点(从基础到进阶)

&#x1f30f;个人博客主页&#xff1a;心.c ​ 前言&#xff1a;JavaScript已经学完了&#xff0c;和大家分享一下我的笔记&#xff0c;希望大家可以有所收获&#xff0c;花不多说&#xff0c;开干&#xff01;&#xff01;&#xff01; &#x1f525;&#x1f525;&#x1f5…

urllib与requests爬虫简介

urllib与requests爬虫简介 – 潘登同学的爬虫笔记 文章目录 urllib与requests爬虫简介 -- 潘登同学的爬虫笔记第一个爬虫程序 urllib的基本使用Request对象的使用urllib发送get请求实战-喜马拉雅网站 urllib发送post请求 动态页面获取数据请求 SSL证书验证伪装自己的爬虫-请求头…