人工智能基础部分20-生成对抗网络(GAN)的实现应用

news2024/12/24 2:45:43

大家好,我是微学AI,今天给大家介绍一下人工智能基础部分20-生成对抗网络(GAN)的实现应用。生成对抗网络是一种由深度学习模型构成的神经网络系统,由一个生成器和一个判别器相互博弈来提升模型的能力。本文将从以下几个方面进行阐述:生成对抗网络的概念、GAN的原理、GAN的实验设计。

一、前言

随着近年来人工智能发展的不断加速,尤其是深度学习的出现,使得计算机视觉领域取得了许多重要突破。生成对抗网络(Generative Adversarial Networks, GAN)是其中一种具有广泛应用前景的技术。GAN是一种生成式模型,它的主要原理是通过博弈论的方式,将生成模型与判别模型进行对抗训练,从而实现生成图像、音频等数据的任务。本文将对GAN 的工作原理进行详细解释,并通过一个图像生成示例项目,展示如何使用 PyTorch 框架实现 GAN,并给出实验结果与完整代码。

二、生成对抗网络(GAN)原理

GAN的核心思想是让两个网络(生成器和判别器)进行博弈,最终迭代得到一个高质量的生成器。生成器的任务是生成与真实数据分布相近的伪数据,而判别器的任务则是判断输入数据是来源于真实数据还是伪数据。通过优化生成器与判别器的博弈过程,使得生成器逐渐改进,能够生成越来越接近真实数据的伪数据。

2.1 生成器

生成器的主要作用是以随机噪声为输入,输出生成的伪数据。随机噪声是一个高斯分布的向量,我们可以通过一个深度神经网络模型(如卷积神经网络、前馈神经网络等)将这个高斯分布的向量映射成我们想要输出的伪数据。

2.2 判别器

判别器是一个二分类神经网络模型,输入可能来自生成器也可能来自真实数据。其任务是对输入数据进行分类,输出一个概率值以判断输入数据是来自真实数据集还是生成器生成的伪数据。

2.3 博弈过程

生成器与判别器博弈的过程即是各自的训练过程。生成器训练的目标是使得判别器对其生成的数据预测为真实数据的概率最大;判别器训练的目标是使得自身对真实数据与生成的数据的分类准确率最高。通过反复迭代这个过程,最终生成器能够生成越来越接近真实数据的伪数据。

2.4 数学原理

生成对抗网络(Generative Adversarial Networks,简称 GAN)是一种基于博弈论的生成模型,其数学原理可以用以下公式表示:

假设p_{data}(x)表示真实数据的分布,p_z(z) 表示生成器输入随机噪声z 的分布,G(z;\theta_g)表示生成器的输出,其中 \theta_g是生成器的参数,D(x;\theta_d) 表示判别器的输出,其中\theta_d是判别器的参数。

GAN 的目标是最小化以下损失函数:

\min_G\max_D V(D,G) = \mathbb{E}{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

其中 \mathbb{E} 表示期望值,\log表示自然对数。

这个损失函数的含义是:最小化生成器生成的数据与真实数据之间的差距,同时最大化判别器对生成器生成的数据和真实数据的区分度。具体来说,第一项\mathbb{E}{x \sim p{data}(x)}[\log D(x)]表示真实数据被判别为真实数据的概率,第二项 \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] 表示生成器生成的虚构数据被判别为虚构数据的概率。

在训练过程中,GAN 会交替训练生成器和判别器,通过最小化损失函数 V(D,G)来优化模型参数。具体来说,对于每个训练迭代,我们首先固定生成器的参数,通过最大化损失函数V(D,G) 来优化判别器的参数。然后,我们固定判别器的参数,通过最小化损失函数V(D,G) 来优化生成器的参数。这个过程会一直迭代下去,直到达到预定的迭代次数或者损失函数收敛。

三、实验设计

本文使用 tensorflow  框架实现 GAN,并在图像生成任务上进行训练。实验workflow 分为以下五个步骤:数据准备\构建生成器与判别器\设置损失函数与优化器、训练过程,让我们先从数据准备开始。

四、代码实现

下面我们将使用MNIST(手写数字化)这一经典的数据集来展示GANs的实际应用效果。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 设置随机种子以获得可重现的结果
np.random.seed(42)
tf.random.set_seed(42)

# 加载MNIST数据集
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()

# 将数据规范化到[-1, 1]范围内
x_train = x_train.astype(np.float32) / 127.5 - 1

# 将数据集重塑为(-1, 28, 28, 1)
x_train = np.expand_dims(x_train, axis=-1)


# 创建生成器模型
def create_generator():
    generator = keras.Sequential()
    generator.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
    generator.add(layers.BatchNormalization())
    generator.add(layers.LeakyReLU(alpha=0.2))

    generator.add(layers.Reshape((7, 7, 256)))

    generator.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias = False))
    generator.add(layers.BatchNormalization())
    generator.add(layers.LeakyReLU(alpha=0.2))

    generator.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias = False))
    generator.add(layers.BatchNormalization())
    generator.add(layers.LeakyReLU(alpha=0.2))

    generator.add(
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias = False, activation ='tanh'))
    return generator


generator = create_generator()


# 创建鉴别器模型
def create_discriminator():
    discriminator = keras.Sequential()
    discriminator.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape = (28, 28, 1)))
    discriminator.add(layers.LeakyReLU(alpha=0.2))
    discriminator.add(layers.BatchNormalization())

    discriminator.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    discriminator.add(layers.LeakyReLU(alpha=0.2))
    discriminator.add(layers.BatchNormalization())

    discriminator.add(layers.Flatten())
    discriminator.add(layers.Dropout(0.2))
    discriminator.add(layers.Dense(1, activation='sigmoid'))
    return discriminator


discriminator = create_discriminator()

# 编译鉴别器
discriminator_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy', metrics = ['accuracy'])

# 创建和编译整体GAN结构
discriminator.trainable = False
gan_input = keras.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = keras.Model(gan_input, gan_output)

gan_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

# 模型训练函数
def train_gan(epochs=100, batch_size=128):
    num_examples = x_train.shape[0]
    num_batches = num_examples // batch_size
    for epoch in range(epochs):
        for batch_idx in range(num_batches):
            noise = np.random.normal(size=(batch_size, 100))
            generated_images = generator.predict(noise)

            real_images = x_train[(batch_idx * batch_size):((batch_idx + 1) * batch_size)]
            all_images = np.concatenate([generated_images, real_images])

            labels = np.zeros(2 * batch_size)
            labels[batch_size:] = 1

            # 在噪声上加一点随机数,提高生成器的鲁棒性
            labels += 0.05 * np.random.rand(2 * batch_size)

            discriminator_loss = discriminator.train_on_batch(all_images, labels)

            noise = np.random.randn(batch_size, 100)
            misleading_targets = np.ones(batch_size)

            generator_loss = gan.train_on_batch(noise, misleading_targets)

            if (batch_idx + 1) % 50 == 0:
                print(
                    f"Epoch:{epoch + 1}/{epochs} Batch:{batch_idx + 1}/{num_batches} Discriminator Loss: {discriminator_loss[0]} Generator Loss:{generator_loss}")


train_gan()

以上实现了生成对抗网络是训练过程,实际中我们可以替换数据训练自己的数据模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/579756.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

flutter_学习记录_03_通过事件打开侧边栏

实现类似这样的侧边栏的效果&#xff1a; 可以用Drawer来实现。 1. 在Scaffold组件下设置endDrawer属性 代码如下&#xff1a; import package:flutter/material.dart;class ProductListPage extends StatefulWidget {ProductListPage( {super.key}) ;overrideState<Pro…

首发Yolov8优化:Adam该换了!斯坦福最新Sophia优化器,比Adam快2倍 | 2023.5月斯坦福最新成果

1.Sophia优化器介绍 斯坦福2023.5月发表的最新研究成果,他们提出了「一种叫Sophia的优化器,相比Adam,它在LLM上能够快2倍,可以大幅降低训练成本」。 论文:https://arxiv.org/pdf/2305.14342.pdf 本文介绍了一种新的模型预训练优化器:Sophia(Second-order Clippe…

低资源方面级情感分析研究综述

文章目录 前言1. 引言2. 问题定义、数据集和评价指标2.1 问题定义2.2 任务定义2.3 常用数据集 3. 方面级情感分析的方法3.1 **方面词抽取**3.1.1 基于无监督学习的方法3.1.1.1 基于规则的方面词抽取3.1.1.2 基于统计的方面词抽取 3.1.2 基于有监督浅层模型的方法3.1.3 基于有监…

【C++初阶】类和对象(下)之友元 + 内部类 + 匿名对象

&#x1f466;个人主页&#xff1a;Weraphael ✍&#x1f3fb;作者简介&#xff1a;目前学习C和算法 ✈️专栏&#xff1a;C航路 &#x1f40b; 希望大家多多支持&#xff0c;咱一起进步&#xff01;&#x1f601; 如果文章对你有帮助的话 欢迎 评论&#x1f4ac; 点赞&#x1…

一台服务器通过nginx安装多个web应用

1.首先安装nginx网站服务器 yum install nginx 2.nginx 的主配置文件&#xff1a;/etc/nginx/nginx.conf (一台服务器有两个域名部署) 我们在/etc/nginx/nginx.d/下创建一个conf文件&#xff0c;这个文件会被嵌套到主配置文件当中 server { listen 80; …

《数据库应用系统实践》------ 个人作品管理系统

系列文章 《数据库应用系统实践》------ 个人作品管理系统 文章目录 系列文章一、需求分析1、系统背景2、 系统功能结构&#xff08;需包含功能结构框图和模块说明&#xff09;3&#xff0e;系统功能简介 二、概念模型设计1&#xff0e;基本要素&#xff08;符号介绍说明&…

Netty客户端与服务器端闲暇检测与心跳检测(三)

网络应用程序中普遍存在一个问题&#xff1a;连接假死&#xff0c;连接假死现象是:在某一端(服务器端|客户端)看来,底层的TCP连接已经断开,但是应用程序没有捕获到,因此会认为这个连接还存在。从TCP层面来说,只有收到四次握手数据包,或者一个RST数据包,才表示连接状态已断开; 连…

Spring练习二ssm框架整合应用

导入教程的项目&#xff0c;通过查看源码对aop面向切面编程进行理解分析 aop面向编程就像是我们给程序某些位置丢下锚点&#xff08;切入点&#xff09;以及当走到锚点时需要调用的方法&#xff08;切面&#xff09;。在程序运行的过程中&#xff0c; 一旦到达锚点&#xff0c;…

f-stack的源码编译安装

DPDK虽然能提供高性能的报文转发&#xff08;安装使用方法见DPDK的源码编译安装&#xff09;&#xff0c;但是它并没有提供对应的IP/TCP协议栈&#xff0c;所以在网络产品的某些功能场景下&#xff08;特别是涉及到需要使用TCP协议栈的情况&#xff09;&#xff0c;比如BGP邻居…

Ansible原理简介与安装篇

工作原理 1、在Ansible管理体系中&#xff0c;存在“管理节点”和“被管理节点” 2、被管理节点通常被称为”资产“ 3、在管理节点上&#xff0c;Ansible将AdHoc或PlayBook转换为python脚本。并通过SSH将这些python脚本传递到被管理服务器上。在被管理服务器上依次执行&#xf…

遥感云大数据在灾害、水体与湿地领域及GPT模型应用

近年来遥感技术得到了突飞猛进的发展&#xff0c;航天、航空、临近空间等多遥感平台不断增加&#xff0c;数据的空间、时间、光谱分辨率不断提高&#xff0c;数据量猛增&#xff0c;遥感数据已经越来越具有大数据特征。遥感大数据的出现为相关研究提供了前所未有的机遇&#xf…

基础篇010.2 STM32驱动RC522 RFID模块之二:STM32硬件SPI驱动RC522

目录 基础篇010.1 STM32驱动RC522 RFID模块之一&#xff1a;基础知识 1. 实验硬件及原理图 1.1 RFID硬件 1.2 硬件原理图 2. 单片机与RFID硬件模块分析 3. 利用STM32CubeMX创建MDK工程 3.1 STM32CubeMX工程创建 3.2 配置调试方式 3.3 配置时钟电路 3.4 配置时钟 3.5 配…

【C++】Map、Set 模拟实现

文章目录 &#x1f4d5; 概念&#x1f4d5; 实现框架Find()★ 迭代器 ★反向迭代器map 的 operator[ ] &#x1f4d5; 源代码rb_tree.hset.hmap.h &#x1f4d5; 概念 map、set 是 C 中的关联式容器&#xff0c;由于 map 和set所开放的各种操作接口&#xff0c;RB-tree 也都提…

2023.05.28 学习周报

文章目录 摘要文献阅读1.题目2.现有方法存在的局限性3.SR-GNN模型4.模型的组成部分4.1 构图4.2 item向量表示4.3 session向量表示4.4 预测模块 5.实验与分析5.1 数据集5.2 比较方法5.3 评估指标5.4 实验结果 6.结论 有限元法1.一个例子2.进一步 深度学习1.张量场2.对流-扩散方程…

Linux(基础IO详解)

在基础IO这篇博客中&#xff0c;我们将了解到文件系统的构成&#xff0c;以及缓冲区究竟是个什么东东&#xff0c;我们都知道缓冲区&#xff0c;有时也谈论缓冲区&#xff0c;但不一定真的去深入了解过缓冲区。为什么内存和磁盘交互速度如此之慢&#xff1f;为什么都说Linux中一…

Dom解析与Sax解析的区别

1.Dom解析&#xff1a; Dom解析的时候&#xff0c;首先要把整个文件读取完毕&#xff0c;装载到内存中。然后进行解析&#xff0c;在解析的过程中&#xff0c;你可以直接获取某个节点&#xff0c;进行操作&#xff0c;也可以获取根节点然后进行遍历操作&#xff0c;得到所有的…

一台服务器通过apache安装多个web应用

当我们只有一台linux服务器资源但有创建多个网站的需求时&#xff0c;我们可以通过安装一个网站服务器Apache进行搭建&#xff0c;此次服务器使用Centos 7 下面分别介绍一个域名多个端口和多个域名用Apache来搭建多个网站的操作过程。 一、使用apache 服务器 &#xff08;一…

HCIA-MSTP替代技术之链路捆绑(LACP模式)

目录 手工链路聚合的不足&#xff1a; LACP链路聚合的原理 LACP模式&#xff1a; LACPDU&#xff1a; 1&#xff0c;设备优先级&#xff1a; 设备优先级的比较是&#xff1a;先比较优先级大小&#xff0c;0到32768&#xff0c;越小优先级越高&#xff0c;如果优先级相同&a…

华为FinalMLP

FinalMLP:An Enhanced Two-Stream MLP model for CTR Prediction 摘要 Two-Stream model&#xff1a;因为一个普通的MLP网络不足以学到丰富的特征交叉信息&#xff0c;因此大家提出了很多实用MLP和其他专用网络结合来学习。 MLP是隐式地学习特征交叉&#xff0c;当前很多工作…

分布式网络通信框架(二)——RPC通信原理和技术选型

项目实现功能 技术选型 黄色部分&#xff1a;设计rpc方法参数的打包和解析&#xff0c;也就是数据的序列化和反序列化&#xff0c;用protobuf做RPC方法调用的序列化和反序列化。 使用protobuf的好处: protobuf是二进制存储&#xff0c;xml和json是文本存储&#xff1b; pro…