深度学习--对抗生成网络(GAN)

news2024/9/20 18:53:59

对抗生成网络(Generative Adversarial Network, GAN)是一种深度学习模型,由伊恩·古德费洛(Ian Goodfellow)及其同事在2014年提出。GAN通过两个神经网络的对抗过程来生成数据,这两个网络分别是生成器(Generator)和判别器(Discriminator)。

一、GAN的基本概念与作用

  1. 生成器(Generator):生成器的任务是从随机噪声(通常是从正态分布或均匀分布中采样)中生成伪造数据,目的是让这些数据看起来尽可能像真实数据。

  2. 判别器(Discriminator):判别器的任务是区分生成器生成的伪造数据和真实数据。它通过对输入数据进行分类,输出一个概率值,表示该数据是“真实”还是“伪造”。

  3. 对抗过程:生成器和判别器在训练过程中处于一种博弈状态。生成器尝试生成能够欺骗判别器的数据,而判别器则试图尽可能准确地识别伪造数据和真实数据。这个过程通过交替优化生成器和判别器的损失函数来实现。

  4. 作用:GAN能够生成与训练数据分布相似的新数据,在图像生成、图像超分辨率、风格转换、文本生成等领域有广泛应用。

二、GAN的原理

GAN的训练过程可以看作是一个二人零和博弈:

  • 生成器的目标是最大化判别器分类错误的概率,即最大化判别器预测为真实数据的概率。
  • 判别器的目标是最大化区分真实数据和生成数据的能力,即最大化正确分类的概率。

GAN的优化目标是通过以下损失函数来实现的:

三、GAN的应用

  1. 图像生成:GAN可以生成高质量的图像,如人脸图像、艺术作品等。

  2. 图像修复:GAN可以用于填补图像中的缺失部分或修复损坏的图像。

  3. 图像超分辨率:GAN可以将低分辨率的图像转换为高分辨率的图像。

  4. 风格迁移:GAN可以用于将一种图像的风格迁移到另一种图像上,如将照片转换为油画风格。

  5. 数据增强:在数据集不足的情况下,GAN可以生成更多样的数据,以提高模型的泛化能力。

  6. 文本生成:GAN也被应用于生成与真实文本相似的自然语言文本。

四、GAN的简单代码实现

以下是一个简单的GAN实现示例,使用Python和TensorFlow/Keras来生成简单的手写数字图片。

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# 生成器模型
def build_generator():
    model = tf.keras.Sequential([
        layers.Dense(256, input_dim=100),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(28 * 28, activation='tanh'),
        layers.Reshape((28, 28))
    ])
    return model

# 判别器模型
def build_discriminator():
    model = tf.keras.Sequential([
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

# GAN模型
def build_gan(generator, discriminator):
    discriminator.compile(optimizer=tf.keras.optimizers.Adam(), loss='binary_crossentropy')
    discriminator.trainable = False
    gan_input = layers.Input(shape=(100,))
    generated_image = generator(gan_input)
    gan_output = discriminator(generated_image)
    gan = tf.keras.models.Model(gan_input, gan_output)
    gan.compile(optimizer=tf.keras.optimizers.Adam(), loss='binary_crossentropy')
    return gan

# 训练GAN
def train_gan(generator, discriminator, gan, epochs=10000, batch_size=128):
    (x_train, _), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train / 127.5 - 1.0  # Normalize to [-1, 1]
    
    for epoch in range(epochs):
        # 训练判别器
        noise = np.random.normal(0, 1, (batch_size, 100))
        generated_images = generator.predict(noise)
        real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]
        
        real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))
        
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, 100))
        fake_labels = np.ones((batch_size, 1))
        
        discriminator.trainable = False
        g_loss = gan.train_on_batch(noise, fake_labels)
        
        if epoch % 1000 == 0:
            print(f"Epoch {epoch}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")
            plot_generated_images(epoch, generator)

# 可视化生成结果
def plot_generated_images(epoch, generator, examples=10, dim=(1, 10), figsize=(10, 1)):
    noise = np.random.normal(0, 1, (examples, 100))
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5
    
    plt.figure(figsize=figsize)
    for i in range(examples):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f"gan_generated_image_epoch_{epoch}.png")
    plt.show()

# 构建和训练模型
generator = build_generator()
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)
train_gan(generator, discriminator, gan)

五、总结

GAN是一个强大的生成模型,通过生成器和判别器的对抗训练,能够生成与真实数据分布相似的伪造数据。它在图像生成、修复、风格迁移等领域都有广泛的应用。上面的代码示例展示了如何使用Keras实现一个简单的GAN,用于生成手写数字图片。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2072546.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Chapter 03 Vue指令(下)

欢迎大家订阅【Vue2Vue3】入门到实践 专栏,开启你的 Vue 学习之旅! 文章目录 前言一、v-on指令二、v-for指令三、v-bind指令 前言 在 Vue.js 中,指令是带有 v- 前缀的特殊属性,不同属性对应不同的功能。通过学习不同的指令&#…

临床医生与人工智能识别三级淋巴结成熟状态的研究对比|文献速递·24-08-24

小罗碎碎念 这期推文的主题是三级淋巴结,主要解决一个问题——临床上如何识别三级淋巴结&人工智能如何应用于三级淋巴结的识别。这两篇文献来源于临床和工科两位不同的老师,是在与他们交流的过程中推荐的,在这里向他们表示感谢&#xff…

在VSCode中使用REST Client插件调试HTTP接口

在 VSCode 中安装 REST Client 扩展程序。新建 test.http 文件。编写请求 请求编写格式可以查看 REST Client 扩展程序说明。点击“Send Request”发送请求 5. 等待请求完成查看响应 请求完成会自动打开响应结果。响应结果上面部分是响应头,下面部分是响应…

idea付费插件,SequenceDiagram比较好用

以下idea付费插件你们都用过哪些呢? SequenceDiagram插件是一种用于绘制时序图的工具。时序图是一种图形化的表示对象之间消息传递顺序的方法。 该插件可以在使用各种编程语言编写代码时,方便地绘制时序图,以帮助开发者更好地理解和描述系统…

【数据分享】全球含建筑高度的建筑物数据(shp格式\约15亿栋建筑物)

建筑数据是我们在各项研究中经常使用到的数据。之前我们能获取到的建筑数据大多没有建筑高度信息,而建筑高度是建筑数据最重要的属性。之前我们给大家分享了我国分城市的含建筑高度的建筑物数据(可查看之前的文章获悉详情),本次我…

ST-LINK常见错误总结

伴随着走进STM32 开发 ,烧录部分一直会出现 各种各样的问题 ,写一篇博文记录关于烧录部分的问题,此文会持续更新,可能之后又遇到其他新的问题,会回来再添加的。 目录 STLINK CONNECTION ERROR 问题的解决 固件丢失 …

buuctf [MRCTF2020]hello_world_go

前言 学习笔记 这题签到! 64IDA打开。 查找字符串发现什么都没有。。。 没事 搜索main()【不知道go语言有没有,先搜索再说】 随便点开一个。 有flag格式,提交看看呗。 成了,签到。 flag{hello_world_gogogo} 题外话,…

双系统报错verifiying shim SBAT data falled: Security Pollcy Violation,Ubuntu无法打开

问题 一觉醒来,打开电脑报错无法打开,详细报错如下: verifiying shim SBAT data falled: Security Pollcy Violation Something has gone serlously wrong: seni self-check falled: Security Policy vlolation 这是由于Windows系统自动更新…

x-cmd mod | x btop - 使用 btop 来查看进程的实时信息

目录 介绍使用语法子命令选项FLAGS 介绍 btop 是系统监控工具,能够实时监控 CPU、内存、磁盘、网络和进程使用情况。 使用语法 x btop [FLAGS]子命令 名称描述–cmd直接运行 btop 命令 选项 名称描述–preset,-p 从预设开始,整数范围为 0-9。–upda…

【深度学习】使用Conda虚拟环境安装多个版本的CUDA和CUDNN方便切换

conda虚拟环境安装CUDA和CUDNN 官网教程 https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#conda-installation 1. 背景 深度学习用显卡训练的时候,需要安装与显卡对应的cuda和cudnn。但不同的项目所支持的pytorch版本是不一样的&#x…

考研备考是选择电子学习工具无纸化学习?还是纸质版训练考感?

作为一名成功上岸的考研学子,回顾备考的艰辛历程,深感学习工具的选择至关重要。在当今数字化时代,我们面临着一个关键的抉择:是延续传统的纸质版资料学习,还是投身于电子学习工具的怀抱,开启无纸化学习之旅…

安卓飞机大战设计过程

用户界面 XML布局文件和Activity类 Android布局文件XML是在res/layout文件夹下的xml文件,里面可以放一些组件 启动Activity时, Android 框架会调用 Activity 中的 onCreate() 回调方法,从而加载应 用代码中的布局资源; Overri…

PDF编辑神器!免费版助你轻松搞定文档转换

随着数字化时代的来临,PDF文件因其稳定性和兼容性成为了我们在职场中常用的文档格式。而面对众多的PDF编辑器,免费版的工具选择显得尤为重要。今天分享五款我用过的免费版PDF编辑器的使用感受,帮助大家更好地了解并选择适合自己的办公工具。 …

Flink1.18 同步 MySQL 到 Doris

一、前言 使用Apache Flink实现数据同步的ETL(抽取、转换、加载)过程通常涉及从源系统(如数据库、消息队列或文件)中抽取数据,进行必要的转换,然后将数据加载到目标系统(如另一个数据库…

数据结构之排序(二)

目录 基本思想: 1.1冒泡排序 ​编辑1.1.1代码实现 1.3冒泡排序的特性总结: 2.1 快速排序 2.1.1基本思想 2.2.2代码实现 1. hoare版本 2.挖坑法 3.前后指针版本 2.2.3 快速排序的优化(三数取中) 实现步骤 3.1 快速排序非…

链表--随机链表复制

给你一个长度为 n 的链表,每个节点包含一个额外增加的随机指针 random ,该指针可以指向链表中的任何节点或空节点。 构造这个链表的 深拷贝。 深拷贝应该正好由 n 个 全新 节点组成,其中每个新节点的值都设为其对应的原节点的值。新节点的 n…

太阳方向角/高度角/赤纬角/太阳时角/真平太阳时差/理论计算方法(matlab)

1. 理论学习 方向角,高度角计算公式 如图,直观地描述了方位角(圆盘上红色夹角)与高度角(黄色线与圆盘的夹角) 赤纬角计算公式 地球赤道平面与太阳和地球中心的连线之间的夹角 如图所示,23度那个. 时角计算公式 太阳时角是指日面中心的时角…

博客园OpenApi管理平台

简介 博客园(Cnblogs)提供了OpenAPI服务,允许开发者通过API来获取博客园中的数据。使用这个API,可以实现从博客园抓取文章、评论等信息的功能,这对于想要集成博客园内容到自己网站或应用的开发者来说是非常有用的。 …

【hot100篇-python刷题记录】【二叉树的最大深度】

R6-二叉树篇 最简单的方法: 循环len(root)次,每次循环执行以下操作: 循环pow(2,i)次,每次都root.pop(0) 如果为空,立即退出,返回i1 class Solution:def maxDepth(self, root: Optional[TreeNode]) ->…

C语言基础(十七)

C语言中的结构体&#xff08;Struct&#xff09;是一种用户自定义的数据类型&#xff0c;允许将不同类型的数据项组合成一个单一的类型&#xff1a; 测试代码1&#xff1a; #include "date.h" #include <stdio.h> #include <string.h> // 定义结构…