生成对抗:DCGAN

news2025/1/11 14:09:05

DCGAN简介

  Generative Adversarial Networks(GANs),GANs有两个模型组成,一个是生成器,用于训练生成假的数据,另一个是判别器,用于预测生成器的输出结果。其中生成器提供训练数据给判别器,提高判别器的准确率。判别器提供生成样本的预测结果,给生成器提供优化的方向。其实在1990年前后,对抗的思想就已经应用于无监督人工神经网络领域,通过最小化另一个程序最大化的目标函数来求解问题。

  生成器的输入通常是一些随机向量,然后去生成接近真实的训练数据。为了生成逼真的数据,生成器的输出会作为判别器的输入,生成器会收到来自判别器的反馈(判别器的识别结果)。生成器可以通过判别器的反馈,知道自己生成结果跟真实数据的差距,从而不断的提高自己,同时判别器也从生成器源源不断获得训练数据,不断提高自己的鉴别能力。

生成器的目标函数

m i n i m i z e E x ∼ P G ( x ∗ ) [ l o g ( 1 − D ( x ) ) ] ,      D ( x ) → 1 minimize E_{x \sim P_G(x^*)} [ log(1 - D(x))] , ~~~~D(x) \to 1 minimizeExPG(x)[log(1D(x))],    D(x)1

  • 假设真实样本 标签为 1, 生成样本标签为 0;
  • z ∼ P z z \sim P_z zPz : 随机噪声
  • P G ( G ( z ) ) = P G ( x ∗ ) P_G(G(z)) = P_G(x^*) PG(G(z))=PG(x): 生成数据的分布
  • D ( x ) D(x) D(x) : 判别器输出结果

判别器的目标函数

  • 准确分辨出真实样本 : x ∼ P r ( x ) x \sim P_r(x) xPr(x)

m a x i m i z e E x ∼ P r ( x ) [ l o g ( D ( x ) ] ,      D ( x ) → 1 maximize E_{x \sim P_r(x)} [log(D(x)], ~~~~ D(x) \to 1 maximizeExPr(x)[log(D(x)],    D(x)1

  • 准确分辨出生成样本 :

m a x i m i z e E x ∼ P G ( x ∗ ) [ l o g ( 1 − D ( x ) ) ] ,      D ( x ) → 0 maximize E_{x \sim P_G(x^*)} [log(1 - D(x))] , ~~~~D(x) \to 0 maximizeExPG(x)[log(1D(x))],    D(x)0

综合目标函数

G m i n D m a x L ( D , G ) = E x ∼ P r ( x ) [ l o g ( D ( x ) ] + E x ∼ P G ( x ∗ ) [ l o g ( 1 − D ( x ) ) ] G_{min} D_{max} L(D, G) = E_{x \sim P_r(x)} [log(D(x)] + E_{x \sim P_G(x^*)} [log(1 - D(x))] GminDmaxL(D,G)=ExPr(x)[log(D(x)]+ExPG(x)[log(1D(x))]

DCGAN实现

import os
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.layers import Conv2D,LeakyReLU,Input,BatchNormalization,Flatten,Conv2DTranspose,Activation,Dense,Reshape,Dropout
from tqdm import tqdm

生成模型

def create_generator(alpha=0.2):
    inputs = Input(shape=(128,))
    x = Dense(units=28 * 28 * 128, use_bias=False)(inputs)
    x = LeakyReLU(alpha=alpha)(x)
    x = BatchNormalization()(x)
    x = Reshape((28, 28, 128))(x)
    x = Conv2DTranspose(filters=128, strides=(1, 1),kernel_size=(5, 5), padding='same',use_bias=False)(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = BatchNormalization()(x)
    x = Conv2DTranspose(filters=128, strides=(2, 2), kernel_size=(5, 5), padding='same', use_bias=False)(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = BatchNormalization()(x)
    x = Conv2DTranspose(filters=128, strides=(2, 2), kernel_size=(5, 5), padding='same', use_bias=False)(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = BatchNormalization()(x)
    x = Conv2D(filters=3, kernel_size=(3, 3),strides=(1, 1), padding='same')(x)
    outputs = Activation('tanh')(x)
    return Model(inputs, outputs)
generator = create_generator()
generator_opt = Adam(learning_rate=1e-4)

判别模型

def create_discriminator(alpha=0.2, dropout=0.2):
    inputs = Input(shape=(112, 112, 3))
    x = Conv2D(filters=128, kernel_size=(5, 5), strides=(2, 2), padding='same')(inputs)
    x = LeakyReLU(alpha=alpha)(x)
    x = Dropout(rate=dropout)(x)
    x = Conv2D(filters=64, kernel_size=(5, 5),strides=(2, 2), padding='same')(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = Dropout(rate=dropout)(x)
    x = Conv2D(filters=64, kernel_size=(5, 5),strides=(2, 2), padding='same')(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = Dropout(rate=dropout)(x)
    x = Conv2D(filters=64, kernel_size=(5, 5),strides=(2, 2), padding='same')(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = Dropout(rate=dropout)(x)
    x = Flatten()(x)
    outputs = Dense(units=1)(x)
    return Model(inputs, outputs)
discriminator = create_discriminator()
discriminator_opt = Adam(learning_rate=1e-4)

损失函数

loss = BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, fake):
    real_loss = loss(tf.ones_like(real), real)
    fake_loss = loss(tf.zeros_like(fake), fake)
    return real_loss + fake_loss
def generator_loss(fake):
    return loss(tf.ones_like(fake), fake)

Training Step

单个训练步骤:

  • 生成随机噪声向量
  • 根据随机向量生成图像
  • 判断真实图像和生成图像的真伪
  • 计算生成损失和判别损失
  • 计算生成模型的损失函数对于模型参数的梯度
  • 更新生成模型的参数
  • 计算判别模型的损失函数对于模型参数的梯度
  • 更新判别模型的参数
@tf.function
def train_step(images, batch_size, noise_dim = 100):
    noise = tf.random.normal((batch_size, noise_dim))
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        gen_images = generator(noise, training=True)
        real_pred = discriminator(images, training=True)
        fake_pred = discriminator(gen_images, training=True)
        gen_loss = generator_loss(fake_pred)
        dis_loss = discriminator_loss(real_pred, fake_pred)
    gen_gradient = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_opt_args = zip(gen_gradient, generator.trainable_variables)
    generator_opt.apply_gradients(gen_opt_args)
    dis_gradient = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    dis_opt_args = zip(dis_gradient, discriminator.trainable_variables)
    discriminator_opt.apply_gradients(dis_opt_args)
    return gen_loss, dis_loss

训练数据

来自 32 国家的 211 种不同的硬币数据集。

import pathlib
import numpy as np
from glob import glob
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
file_patten = str(pathlib.Path.home()/'DeepVision/SeData/coins/data/*/*/*.jpg')

DataSetPaths = np.array([*glob(file_patten)])
def process_image(image):
    image = (image - 127.5) / 127.5
    return image
def resize_image(original_image, size=(112, 112)):
    new_size = tuple(size)
    resized = original_image.resize(new_size)
    resized = np.array(resized)
    resized = resized.astype(np.uint8)
    return resized
train_data = []
for image_path in tqdm(DataSetPaths, ncols=60):
    image = Image.open(image_path)
    image = resize_image(image)
    if image.shape[-1] != 3:
        continue
    train_data.append(process_image(image))
train_data = np.array(train_data)
100%|██████████████████| 8101/8101 [00:30<00:00, 264.58it/s]

在这里插入图片描述
在这里插入图片描述

训练模型

def generate_batch_image(train_data, batch_size):
    indices = np.random.choice(range(0, len(train_data)), batch_size, replace=False)
    batch_images = train_data[indices]
    return batch_images
BS = 128
Gloss = []
Dloss = []
for epoch in tqdm(range(200),ncols=60):
    loss1 = []
    loss2 = []
    for step in range(150):
        batch_images = generate_batch_image(train_data, batch_size=BS)
        gloss, dloss = train_step(batch_images, noise_dim=128, batch_size=BS)
        loss1.append(gloss)
        loss2.append(dloss)
    #print("Gen_loss : %.3f, Dis_loss : %.3f"%(np.mean(loss1), np.mean(loss2)))
    Gloss.append(np.mean(loss1))
    Dloss.append(np.mean(loss2))
100%|███████████████████| 200/200 [1:15:27<00:00, 22.64s/it]

测试生成图像

def generate_and_save_images(model=generator, epoc=0, test_input=None):
    predictions = model(test_input, training=False)
    plt.figure(figsize=(12, 12))
    for i in range(predictions.shape[0]):
        plt.subplot(6, 6, i + 1)
        image = predictions[i, :, :, :] *127.5 + 127.5
        image = tf.cast(image, tf.uint8)
        plt.imshow(image, cmap='Greys_r')
        plt.axis('off')
    #plt.savefig(f'{epoch}.png')
    plt.show()
test_seed = tf.random.normal((36, 128)) 
generate_and_save_images(test_input=test_seed)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/118277.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java系列】小小练习——带你回顾Java基本运算符

返回主篇章         &#x1f447; 【Java】才疏学浅小石Java问道之路 Java小练习1. 练习一1.1 题目1.2 题解(附解析)2. 练习二2.1 题目2.2 题解(附解析)3. 练习三3.1 题目3.2 题解(附解析)小结1. 练习一 1.1 题目 一个三位数&#xff0c;将其拆分为个位、十位、百位后…

mac安装cocoapods完整步骤

一、概念理解 首先不要急着搜索终端命令&#xff0c;你需要明白安装 cocoapods 都需要什么环境&#xff0c;这对于安装途中如果遇到问题该如何解决很重要&#xff0c;很重要&#xff0c;很重要&#xff01; 1、安装pods需要依赖 ruby 环境&#xff0c;而安装 ruby 你需要借助工…

[网鼎杯 2020 白虎组]PicDown(任意文件读取)

打开界面发现有一个get传参然后&#xff0c;尝试任意文件读取漏洞&#xff0c;/etc/passwd看一下,提示下载了一个jpg图片然后 打不开只能用 010查看一下信息 看来是猜对了&#xff0c;然后 如果日记没删掉可以查看历史记录 .bash_history呃呃呃差不到&#xff0c;那就看一下现…

Python 现代控制理论 —— 梯度下降法实现的线性回归系统

线性回归是有监督学习中的经典问题&#xff0c;其核心在于找到样本的多个特征与标签值之间的线性关系。样本集中的第j个样本可被表示为&#xff1a; 特征向量&#xff1a;标签值&#xff1a; 而线性回归系统给出权重向量&#xff1a; 使得该样本的预测值为&#xff1a; 当所有…

Python采集某网站m3u8内容,美女我来了~

前言 嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 又到了学Python时刻~ 环境使用: Python 3.8 Pycharm 模块使用: import requests >>> pip install requests import re 正则表达式 解析数据 import json 安装python第三方模块: win R 输入 cmd 点击确…

不止一面的百变 ACE

这个时代&#xff0c;可谓是云原生的黄金时代。 站在这个云原生的风口&#xff0c;年轻一代的开发者如何看待自己所处的环境&#xff1f;他们眼中的云原生未来是什么样&#xff1f; 今天我们就将走近一位年轻的“云原生原住民”&#xff0c;听听他作为开发者的成长经历。 War…

【python3】9.python高阶内容(上)_基础

9.python高阶内容&#xff08;上&#xff09;_基础 2022.12.27 python高阶内容&#xff08;上&#xff09;_基础9.1 字符串的高阶玩法 9.1.1 %百分号模式 %d:整数%i:整数%s:字符%f:小数 【方式1】&#xff1a;前面用格式占位&#xff0c;后面用具体的内容 name "莫烦…

Android设计模式详解之访问者模式

前言 访问者模式是一种将数据操作与数据结构分离的设计模式&#xff1b; 定义&#xff1a;封装一些作用于某种数据结构中的各元素的操作&#xff0c;它可以在不改变这个数据结构的前提下定义作用于这些元素的新的操作&#xff1b; 使用场景&#xff1a; 对象结构比较稳定&a…

大厂与小厂招人的区别,看完多少有点不敢相信

前两天在头条发了一条招人的感慨&#xff0c;关于大厂招人和小公司招人的区别。 大厂&#xff1a;有影响力&#xff0c;有钱&#xff0c;能够吸引了大量的应聘者。因此&#xff0c;也就有了筛选的资格&#xff0c;比如必须985名校毕业&#xff0c;必须35岁以下&#xff0c;不能…

基于DoIP使用CANoe对ECU进行诊断测试

伴随以太网引入到车载网络中,本文分享通过常用工具CANoe怎么样对ECU进行通信以及测试。 相比在车载CAN总线,以太网又有什么与众不同之处? 1、硬件接口卡(收发器) 以往车载CAN网络较常使用的是VN 16XX 系列,在连接ECU进行通信时,除了配置波特率也要进行通道分配: 而…

7个学习UI、UX设计一定要经历的步骤

我们不是一些有才华的设计师。我们天生就有艺术天赋。后天我们学会了设计技巧。设计的根本目的是解决问题。设计是不断发现和解决问题。 有许多设计领域&#xff1a;UI、UX.产品设计师.平面设计师.交互设计师.信息架构师等&#xff0c;所以要找出你最感兴趣的设计专业。 现在让…

美颜sdk动态贴纸技术、代码分析

目前&#xff0c;美颜sdk动态贴纸已经成了各大直播平台主播的必备“直播伴侣”&#xff0c;在其他的视频拍摄场景动态贴纸的热度同样很高&#xff0c;本篇文章小编将为大家深度盘点一下美颜sdk动态贴纸的技术实现以及代码。 一、多终端适配 对于如今的直播平台终端来说&#x…

CAPL学习之路-测试功能集函数(测试结构化)

用户可以使用如下函数在测试报告中对每一条测试用例设置结构化的输出内容 TestCaseDescription 添加测试用例的描述文本 此函数用于测试用例中,描述文本会添加在固定区域(测试用例title的下方)。多次调用该函数,描述文本会合并显示在固定区域。如果想让描述文本换行,可以…

爆火的Web3.0背后,百度营销如何抓住流量密码?

出品| 大力财经 文 | 魏力 AI、元宇宙、Web3.0、AIGC等新技术、新概念的加持&#xff0c;给传统的流量营销平台带来了前所未有的挑战。尤其是短视频时代的崛起&#xff0c;用户的使用习惯开始改变&#xff0c;完全改变了流量的逻辑和习惯。 从搜索引擎业务起家的百度&#x…

DoIP---车载以太网诊断方面边缘节点的路由策略分析

假期后开工第一天&#xff0c;规划好自己一天需要做的事情&#xff0c;按部就班完成每日任务&#xff0c;做好每日总结。 自己一天一个脚印&#xff0c;这不是鸡血&#xff0c;这是工作态度&#xff01;&#xff01;&#xff01; 惯例分享一段喜欢的文字&#xff1a; 每个人…

目标检测之FCOS算法分析

网络结构 (图片来自原论文&#xff1a;FCOS: Fully Convolutional One-Stage Object Detection) 在ResNet50 Backbone中&#xff0c;C3,C4,C5C3,C4,C5C3,C4,C5是卷积特征图&#xff1b; 在FPN结构中&#xff0c;P3,P4,P5,P6,P7P3,P4,P5,P6,P7P3,P4,P5,P6,P7是最后用于预测的特…

2023跨境出海指南:泰国网红营销白皮书

作为东南亚第二大经济体&#xff0c;泰国一直是旅游和企业出海的热门之地。随着电商经济和互联网的发展&#xff0c;加上疫情的催化&#xff0c;泰国的社交媒体行业也得到了飞速发展&#xff0c;已经成为了主流营销方式之一。本文Nox聚星就从网红营销的角度&#xff0c;和大家探…

代码随想录-46-226.翻转二叉树

目录前言题目1.使用队列思路&#xff08;定义变量&#xff09;2. 本题思路分析&#xff1a;3. 算法实现4. pop函数的算法复杂度5. 算法坑点前言 在本科毕设结束后&#xff0c;我开始刷卡哥的“代码随想录”&#xff0c;每天一节。自己的总结笔记均会放在“算法刷题-代码随想录…

浅谈一下个人基于IRIS后端业务开发框架的理解

文章目录浅谈一下个人基于IRIS后端业务开发框架的理解现状方案具体实现BaseBizDataFilterSqlImp、RefApiUtil总结浅谈一下个人基于IRIS后端业务开发框架的理解现状由于国内使用基于M语言IRIS平台几乎都在医疗行业。医疗系统又非常的庞大和复杂。前期由于快速占领市场&#xff0…

珠城科技在创业板上市:IPO首日跌破发行价,市值相对蒸发约7亿元

12月26日&#xff0c;浙江珠城科技股份有限公司&#xff08;下称“珠城科技”&#xff0c;SZ:301280&#xff09;在深圳证券交易所创业板上市。本次上市&#xff0c;珠城科技的发行价格为67.40元/股&#xff0c;发行数量为1628.34万股&#xff0c;募资总额约为10.98亿元&#x…