[深度学习] 生成对抗网络GAN

news2024/11/17 17:48:51

生成对抗网络(Generative Adversarial Networks,GANs)是一种由 Ian Goodfellow 等人在2014年提出的深度学习模型Generative Adversarial Networks。GANs的基本思想是通过两个神经网络(生成器和判别器)的对抗过程,生成与真实数据分布相似的新数据。以下是对GANs的详细介绍。

1. 基本结构

GANs由两个主要组件构成:

1.1 生成器(Generator)

生成器的任务是从一个随机噪声(通常是高斯噪声)中生成逼真的数据样本。生成器是一个神经网络,它接受一个随机向量作为输入,并输出一个与真实数据分布相似的样本。目标是欺骗判别器,使其认为生成的数据是真实的。

1.2 判别器(Discriminator)

判别器是另一个神经网络,用于区分真实数据和生成器生成的假数据。它的输入是一个数据样本,输出是一个标量值,表示输入数据是真实的概率。判别器的目标是最大化区分真实数据和生成数据的准确性。

2. 工作原理

GANs通过生成器和判别器的对抗训练来实现目标。训练过程中,两者的目标是相反的:

  • 生成器试图最大限度地欺骗判别器,使其无法区分生成数据和真实数据。
  • 判别器则尽量提高区分真实数据和生成数据的能力。

这种对抗训练可以形式化为一个极小极大(minimax)问题:

在这里插入图片描述
其中,D(x) 是判别器对真实数据 x 的输出, G(z) 是生成器对随机噪声 z 的输出。

3. 训练过程

  1. 初始化:随机初始化生成器和判别器的参数。
  2. 训练判别器:在固定生成器的情况下,用一批真实数据和生成数据训练判别器,更新判别器的参数。
  3. 训练生成器:在固定判别器的情况下,用生成器生成的假数据训练生成器,更新生成器的参数。
  4. 循环:重复上述步骤,直到生成器生成的数据足够逼真。

下面是一个使用TensorFlow和Keras实现简单生成对抗网络(GAN)的示例代码。该代码使用MNIST手写数字数据集来训练生成器生成类似手写数字的图像。

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

# 加载MNIST数据集
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # 将图像标准化到 [-1, 1] 区间
BUFFER_SIZE = 60000
BATCH_SIZE = 256

# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# 生成器模型
def make_generator_model():
    model = models.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # 注意:None 是批量大小

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# 判别器模型
def make_discriminator_model():
    model = models.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# 定义损失函数和优化器
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

generator = make_generator_model()
discriminator = make_discriminator_model()

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# 训练步骤
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# 训练过程
def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)

        # 每个epoch结束时生成并保存一些图像
        generate_and_save_images(generator, epoch + 1, seed)

    # 最后一个epoch生成图像
    generate_and_save_images(generator, epochs, seed)

# 生成并保存图像的辅助函数
def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

# 训练模型
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

seed = tf.random.normal([num_examples_to_generate, noise_dim])

train(train_dataset, EPOCHS)
说明
  1. 数据预处理: 加载MNIST数据集并进行归一化处理。
  2. 模型构建:
    • 生成器: 从随机噪声生成图像。
    • 判别器: 区分真实图像和生成图像。
  3. 损失函数: 使用二元交叉熵损失函数。
  4. 优化器: 使用Adam优化器。
  5. 训练过程: 每个训练步骤包含生成图像、判别图像,并计算和应用梯度更新。
  6. 结果可视化: 在训练过程中生成并保存图像。

运行上述代码后,会看到生成的手写数字图像,这些图像与真实的MNIST数据集中的手写数字非常相似。

4. 常见问题与改进

生成对抗网络(GANs)虽然在许多领域取得了显著成果,但也存在一些显著的缺点和挑战。这些缺点主要源于其对抗训练的复杂性和网络设计的固有特性。

1. 训练不稳定

原因: GANs的训练是一个极小极大(minimax)优化问题,生成器和判别器的目标是相反的,这导致了训练过程的不稳定性。

  • 对抗失衡: 如果判别器太强,它会很容易区分生成的数据和真实数据,使生成器无法学习。如果生成器太强,它会欺骗判别器,使其失去鉴别能力。
  • 梯度消失或爆炸: 在训练过程中,梯度可能会消失或爆炸,尤其是在深层网络中,这使得训练过程更加困难。
2. 模式崩溃(Mode Collapse)

原因: 生成器可能会找到一些特定模式的生成样本,这些样本能很好地欺骗判别器,但缺乏多样性。

  • 单一输出: 生成器可能会倾向于生成几种特定的输出模式,而忽略其他可能的输出,导致生成样本缺乏多样性。
  • 判别器反馈不足: 如果判别器未能有效反馈生成样本的多样性问题,生成器将持续生成相同或相似的样本。

这些缺点和挑战使得GANs在实际应用中需要进行细致的调整和改进,但也促使研究者不断提出新的方法和变体,以改善GANs的性能和稳定性:

1 条件GAN(Conditional GANs,cGANs)

通过在生成器和判别器中加入条件变量,使得生成的样本能够受控于特定的条件。例如,可以生成特定类别的图像。

2 Wasserstein GANs(WGANs)

通过使用Wasserstein距离替代原始GAN中的Jensen-Shannon散度,解决了训练不稳定的问题,使得训练过程更加平滑。

3 深度卷积GAN(Deep Convolutional GANs,DCGANs)

利用卷积神经网络(CNN)来构建生成器和判别器,提高了生成图像的质量。

5. 应用场景

GANs在许多领域都有广泛的应用,包括但不限于:

  • 图像生成:生成高质量的图像,例如人脸、风景等。
  • 图像修复:修复受损图像或填补缺失部分。
  • 图像超分辨率:提高低分辨率图像的分辨率。
  • 数据增强:在数据不足的情况下生成更多训练样本。
  • 视频生成:生成逼真的视频序列。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1864407.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nodejs使用mqtt库连接阿里云服务器

建项目 命令行输入: npm init 输入项目名,自动化生成项目列表。 6.3 编写代码 新建mqtt_demo_aliyun.js,代码如下: // mqtt_demo_aliyun.jsconst mqtt require("mqtt"); const connectUrl "ws://post-cn-nw**…

展厅设计中需要人性化的地方

1、预留参观空间 展厅空间的布局设计必须尽可能的宽敞,以避免参观人数较多时可能会发生的拥堵,重点展品需要预留较大的展示空间或四面通畅的中心位置,更方便观众从不同角度与方位参观。因为是展厅,不仅代表着企业形象,…

安科瑞光伏并网电表ADL400N-CT双向计量防逆流自带互感器电表-安科瑞 蒋静

1 概述 ADL 系列导轨式多功能电能表,是主要针对于光伏并网系统、微逆系统、储能系统、交流耦合系统等新能源发电系统而设计的一款智能仪表,产品具有精度高、体积小、响应速度快、安装方便等特点。具有对电力参数进行采样计量和监测,逆变器或…

flask与vue实现通过websocket通信

在一些情况下,我们需要实现前后端之间的时刻监听,本文是一篇工具文档,用于解决前后端之间使用websocket交互。 一. Flask的相关配置 1. 下载相关依赖库 如果还没有配置flask的话,需要先安装flask,同时为解决跨域问题&#xff0…

Topaz Gigapixel AI图片无损放大软件下载安装,Topaz Gigapixel AI 高精度的图片无损放大

Topaz Gigapixel AI无疑是一款革命性的图片无损放大软件,它在图像处理领域开创了一种全新的可能性。 Topaz Gigapixel AI的核心功能在于能够将图片进行高精度的无损放大。虽然经过软件处理的图片严格意义上并不能算是完全无损,但相较于传统方法&#xf…

AI实战案例!如何运用SD完成运营设计海报?玩转Stable Diffusion必知的3大绝技

大家好我是安琪! Satble Diffusion 给视觉设计带来了前所未有的可能性和机会,它为设计师提供了更多选择和工具的同时,也改变了设计师的角色和设计流程。然而,设计师与人工智能软件的协作和创新能力仍然是不可或缺的。接下来我将从…

【语言模型】探索AI模型、AI大模型、大模型、大语言模型与大数据模型的关系与协同

一、引言 随着人工智能(AI)技术的飞速发展,各种AI模型如雨后春笋般涌现,其中AI模型、AI大模型、大模型、大语言模型以及大数据模型等概念在学术界和工业界引起了广泛关注。这些模型不仅各自具有独特的特点和应用场景,…

告别臭脚尴尬!安全鞋除臭秘籍大公开

你是否有过这样的烦恼,忙碌一天回到家,脱鞋的瞬间,那令人窒息的气味让人瞬间清醒?别担心,今天百华小编就与大家一起探讨下安全鞋除臭的秘籍,让你从此告别臭脚尴尬! 首先,我们要了解…

PHP 面向对象编程(OOP)入门指南

面向对象编程(Object-Oriented Programming,简称OOP)是一种编程范式,通过使用对象来设计和组织代码。PHP作为一种广泛使用的服务器端脚本语言,支持面向对象编程。本文将介绍PHP面向对象编程的基本概念和用法&#xff0…

SpringCloud Alibaba Seata2.0分布式事务AT模式实践总结

这里我们划分订单、库存与支付三个module来实践Seata的分布式事务。 依赖版本(jdk17)&#xff1a; <spring.boot.version>3.1.7</spring.boot.version> <spring.cloud.version>2022.0.4</spring.cloud.version> <spring.cloud.alibaba.version>…

美多多商城定义用户模型类遇见的问题

from django.db import models from django.contrib.auth.models import AbstractUser # Create your models here. class User(AbstractUser):mobile models.CharField(max_length11, uniqueTrue,verbose_name手机号)class Meta:db_tabletb_users #自定义表名verbose_name用户…

【动态内存】详解

Hi~&#xff01;这里是奋斗的小羊&#xff0c;很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~~ &#x1f4a5;&#x1f4a5;个人主页&#xff1a;奋斗的小羊 &#x1f4a5;&#x1f4a5;所属专栏&#xff1a;C语言 &#x1f680;本系列文章为个人学习…

深入浅出 langchain 1. Prompt 与 Model

示例 从代码入手来看原理 from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI prompt ChatPromptTemplate.from_template("tell me a short joke about…

B端页面:日志管理页面,简洁实用的设计法门

B端日志管理是指在企业级后台系统中对系统操作日志进行记录、查看和管理的功能。 它的作用主要有以下几点&#xff1a; 1. 安全审计&#xff1a;通过记录用户的操作日志&#xff0c;可以对系统的安全性进行审计和监控&#xff0c;及时发现异常操作和安全漏洞。 2. 故障排查&a…

Program LLMs,不只是Prompt LLMs

前言 随着大模型的使用和应用越来越频繁&#xff0c;也越来越广泛&#xff0c;大家有没有陷入到无限制的研究、调优自己的prompt。 随之&#xff0c;市面上也出现了提示词工程师&#xff0c;更有专门的提示工程一说。 现在网上搜一搜&#xff0c;有各种各样的写提示词的技巧…

Python多线程技巧心得详解

概要 多线程是一种能够并发执行代码的方法,可以提高程序的执行效率和响应速度。本文将详细介绍 Python 中多线程的概念、使用场景、基本用法以及实际应用,可以更好地掌握多线程编程。 什么是多线程? 多线程是一种在单个进程内并发执行多个线程的技术。每个线程共享相同的内…

电脑CPU速度很快,为什么3dMax还会出现卡顿的情况?

我们在使用3dmax时会经常遇到电脑变得很缓慢甚至卡顿的情况&#xff08;多发生于新手群体&#xff09;&#xff0c;即使我们的电脑CPU已经足够快&#xff0c; 也会出现滞后或性能延迟。包括但不限于 Intel i9 和 AMD“Ryzen Threadrippers”。 例如单击用户界面的任何区域或移…

红酒舞动,运动风采,品味力与美

当夜幕降临&#xff0c;城市的灯火渐次亮起&#xff0c;忙碌了一天的人们开始寻找那份属于自己的宁静与愉悦。在这个时刻&#xff0c;红酒与运动&#xff0c;这两个看似截然不同的元素&#xff0c;却能以它们不同的魅力&#xff0c;为我们带来一场视觉与感官的盛宴。 红酒&…

如何轻松获取 GitLab 指定分支特定路径下的文件夹内容

第一步&#xff1a; 获取 accessToken 及你的 项目 id &#xff1a; 获取 accessToken ,点击用户头像进入setting 按图示操作&#xff0c;第 3 步 填写你发起请求的域名。 获取项目 id , 简单粗暴方案 进入 你项目仓库页面后 直接 源码搜索 project_id&#xff0c; value 就…

ApolloClient GraphQL 与 ReactNative

要在 React Native 应用程序中设置使用 GraphQL 的简单示例&#xff0c;您需要遵循以下步骤&#xff1a; 设置一个 React Native 项目。安装 GraphQL 必要的依赖项。创建一个基本的 GraphQL 服务器&#xff08;或使用公共 GraphQL 端点&#xff09;。从 React Native 应用中的…