图像生成--对抗生成模型

news2024/11/23 18:44:38

生成模型概述

对抗生成模型

机器学习中的两大主要问题:

  1. 判别
  2. 生成

判别模型的典型代表即为图像分类任务,即给定一个数据,判定他是哪一类。

判别模型学习到的是一个概率(贝叶斯过程)

而生成模型的区别在于,给定一个数据,将其生成为预期数据。

 

在数学上,生成模型与判别模型的区别在于:

给定观测值x:

  • 判别模型旨在判别得到y的概率

  • 生成模型旨在根据指定的y得到x的概率

生成模型的应用

超分辨率

图像生成(风格迁移)

生成模型原理简要说明

在GoodFellow的论文中,以最大似然估计进行举例。

首先需要说明的问题是:

生成模型的本质,在于从训练数据中学习到数据的分布

学习到了分布之后,给定一个随机的噪声

过程可以简单地理解为:这个噪声中,符合指定分布的内容得到加强,不符合指定分布的内容会被削弱

当在迭代过程中,数据逐渐贴合预期的输入,从而看上去更逼真。

方法分类

基于最大似然估计的数据生成,是生成模型的理论基础。

按照不同的形式和流派,大致可以分成下面的类别。

此处不对研究脉络的具体细节进行探究,只是对原理进行比喻式介绍。

  1. Explicit density: 显性密度。也就是说,我们在这类方法中,需要给出分布模型的具体形式(密度函数),通过各种迭代运算,来得到模型的真实参数。

  2. Implicit density:隐性密度。在这类方法中,不指定数据分布密度函数,而是通过数据分布所满足的条件,用拟合能力比较强的模型来寻找合适的模型和分布参数。

GAN则属于隐式密度方法,不需要指定模型的具体分布密度函数,来得到较好的分布拟合。

拓展:生成模型可以视为一种损失函数

该部分内容会在后续进行进一步展开,此处只做简单介绍。

首先,我们通常会采用显式的函数作为损失函数。

这种方式带来很多便利,但并不一定精确(对特定任务来说)。

我们用对抗生成式的模型对网络进行约束,从而能够不使用显性的函数来约束模型。

对于用于约束的网络,我们将一些必须要满足的条件作为约束目标,从而令约束模型进一步地摸索出更好的约束边界。

生成对抗模型GAN

Generative Adversarial Network,GAN是一种深度学习模型,属于一种无监督学习的方法。

其目的在于,从数据中学习分布,来得到足以以假乱成真的数据。

为了达到这个目的,通常包含两个基本模型:生成器和判别器。(generative model, G)和(discriminative model, D)

判别模型学习“分界面(分解曲线)”

在训练过程中,利用合理的结构和设定,令二者满足纳什均衡,来得到最优解。

GAN原理

GAN的过程,离不开两个关键内容:生成与对抗。

Goodfellow的例子如下:

一个城市中,有一群小偷(生成器)和一群警察(判别器)。

小偷的目的在于,想方设法地欺骗警察;

而警察的目的在于,想方设法地不受欺骗。

这样一来,小偷在不断的欺骗和被识破的过程中不断精进技能,从而掌握了更加不易被识破的欺骗技能;

警察则在被欺骗的过程中,不断提高辨识功能,从而对欺骗的细节做出判断,更加接近本质。

生成器 生成器采用随机输入,尝试输出样本数据。根据输入的样本随机产生一个数据,将其送入鉴别器

鉴别器 鉴别器的任务在于,接受两个输入,分别是生成器的输入和真实数据,判别器的目的在于判断生成器的输入是不是真的。

数学表达

上述过程中,希望判别器能够最大程度地判别出真实数据为真,生成数据为假

而生成器则是能够最大程度地令判别器产生误判

训练过程

两阶段训练:

  1. 固定生成器参数,训练判别器

  2. 固定判别器,训练生成器

GAN模型的训练过程是一个非常复杂的训练过程,早期的GAN训练也非常麻烦。

训练难度之所以大,一个重要的原因在于,难以掌控生成器和判别器的能力。

理解:

如果小偷很厉害,则警察无法从中提升判别能力;

如果警察很厉害,小偷则会被一网打尽,无法提升其“造假能力”

理论上,如果判别器过于强大,生成器则会由于步长太大无法找到全局最优解。

一个简单的例子在于,人类现代科技无法从外星人科技中吸收影响,从而无法引发科技进步。

因此,通常是训练多轮生成器,再训练少轮判别器

通俗来说,GAN训练的过程应当是一个循序渐进,相辅相成的过程。如果一开始,通过载入与训练模型令判别器具有很高的能力,往往会令GAN难以有效收敛。

代码实践

参考

In [1]:

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm

In [2]:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(.5, .5)
])
train_data = torchvision.datasets.MNIST('data', 
                                        train=True,
                                        transform=transform,
                                        download=True)

dataloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/9912422 [00:00<?, ?it/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/28881 [00:00<?, ?it/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/1648877 [00:00<?, ?it/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
  0%|          | 0/4542 [00:00<?, ?it/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

In [3]:

# generator
class Gen(nn.Module):
    def __init__(self):
        super(Gen, self).__init__()
        self.gen = nn.Sequential(nn.Linear(100, 256), 
                                 nn.ReLU(),
                                 nn.Linear(256, 512), 
                                 nn.ReLU(), 
                                 nn.Linear(512, 28*28), 
                                 nn.Tanh())
    
    def forward(self, x):
        img = self.gen(x)
        img = img.view(-1, 28, 28)
        return img

In [4]:

# discriminator
class Dis(nn.Module):
    def __init__(self):
        super(Dis, self).__init__()
        self.dis = nn.Sequential(nn.Linear(28*28, 512), 
                                 nn.LeakyReLU(),
                                 nn.Linear(512, 256), 
                                 nn.LeakyReLU(), 
                                 nn.Linear(256, 1),
                                 nn.Sigmoid())
    
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.dis(x)
        return x

In [5]:

gen = Gen().to('cpu')
dis = Dis().to('cpu')

d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)

loss_func = torch.nn.BCELoss()

In [6]:

# train
loss_d = []
loss_g = []

for epoch in range(50):
    d_epoch_loss = 0
    g_epoch_loss = 0
    batch_count = len(dataloader)
    
    for i, (img, _) in enumerate(tqdm(dataloader)):
        img = img.to('cpu')
        size = img.size(0)
        random_noise = torch.randn(size, 100, device='cpu')
        
        d_opt.zero_grad()
        real_output = dis(img)
        d_real_loss = loss_func(real_output, 
                                torch.ones_like(real_output))
        d_real_loss.backward()
        
        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_func(fake_output, 
                                torch.zeros_like(fake_output))
        d_fake_loss.backward()
        
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()
        
        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_func(fake_output, 
                           torch.ones_like(fake_output))
        
        g_loss.backward()
        g_optim.step()
    
        torch.save(gen.state_dict(), str(epoch).zfill(2) + ".pth")
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.29it/s]
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.55it/s]
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.48it/s]
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.15it/s]
 57%|███████████████████████▍                 | 537/938 [00:12<00:09, 41.52it/s]
---------------------------------------------------------------------------

In [7]:

# show result
def result_show(weight, test_input):
    gen = Gen().to('cpu')
    gen.load_state_dict(torch.load(weight))
    gen.eval()
    plot_img(gen, test_input)

In [8]:

# plot image
import matplotlib.pyplot as plt

def plot_img(model, _input):
    prediction = model(_input).detach().cpu().numpy()
    print(prediction.shape)
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i]+1)/2)
        plt.axis('off')
    plt.show()

In [9]:

random_noise = torch.randn(size, 100)
import numpy as np

result_show('./00.pth', random_noise)
(64, 28, 28)

GAN模型进阶

GAN模型的本质

学习训练数据的分布,符合训练数据分布的数据,具有较好的可视化效果;

在分布之外的数据,可视化效果较差。

那么GAN模型的根本问题是:

找一个生成模型G,该模型定义了概率分布

给定一个分布z,找到一个G,可以使分布比较相似。

 具体上,从符合z分布中采样多个点,得到了多个x。

进而,从创造一个D,用于引导采样。

需要说明的是,D的loss值与生成数据和真实数据的内容息息相关。

如果说损失越大,则越说明生成的数据和真实数据越接近。

一个直观的例子

李宏毅推荐的例子

GAN的本质:散度

散度定义(divergence):p(x)和q(x)到底有多不一样

性质1: 散度取值在0-1之间,越接近于0,分布越相似。否则分布区别越大。

那么GAN的本质,在于如何度量散度,即如何设定一个合适的函数f,来得到一个良好的分布拟合。

KL散度:描述数据分布之间的相似性

卡方散度:判断两个样本是否符合相同的分布

关于散度和GAN的关系

散度用于评价分布的相似程度。

常用的KL散度,公式为

但KL散度存在不对称性,在basic gan里,用的是JS散度

使用JS散度存在一个比较大的问题,即如果分布相差较远,则会等于一个恒定的值。不利于模型收敛。

因此,可以灵活地调整散度,来适应不同类型的数据。

如何把散度作为优化目标?

散度可以衡量两个分布,那么如何将散度作为他的优化函数呢?

凸共轭

 

红线部分即共轭函数的曲线,可以看出他也是凸函数。

如何求解一个函数的凸函数?

采用极值求导的方式求解。

例如f(x)=xlog⁡x

一般形式的GAN

回到GAN中,有

那么我们的目的就在于:

直观上的感受:

另一种思路 WGAN

有颜色的色块表示把第i行的分布,修改到第j行。(推土机)

运送路径越多,运送的货物越多,则做的功越大。

那么首先定义运送的功

进而,只需要找到运送功最小的那个方案就可以了

注意,这里需要定义D的函数需要满足1-Lipschitz,即

其中,k=1

这样的作用在于,令y的增长不超过x。也就是限制模型不要更新的太快。

否则,如果取消限制,那么就会令D直接爆炸。

求解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/650736.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Elacticsearch】 原理/数据结构/面试经典问题整理

对Elacticsearch 原理/数据结构/面试经典问题整理的文章&#xff1b; 映射 | Elasticsearch: 权威指南 | Elastic Elacticsearch介绍 Elasticsearch,这里简称ES。ES是一个开源的高可用高扩展的分布式全文搜索与分析引擎&#xff0c;可以提供PB级近实时的数据存储和检索能力&am…

SpringBoot自定义starter入门

一、创建一个普通的Maven项目 pom.xml文件修改 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation&quo…

你知道如何将音频转文字吗?

我跟你们说&#xff0c;我认识一名盲人音乐人&#xff0c;他很热爱音乐创作&#xff0c;但是因为听力的限制&#xff0c;无法像其他人那样从录音中获取音乐素材。然而&#xff0c;有一天他听说了一个神奇的功能——音频转文字&#xff0c;这个功能可以将音频文件转换成文字文本…

从技术谈到管理,把系统优化的技术用到企业管理

很多技术人员在职业上对自己要求高&#xff0c;工作勤奋&#xff0c;承担越来越大的责任&#xff0c;最终得到信任&#xff0c;被提拔到管理岗位。但是往往缺乏专业的管理知识&#xff0c;在工作中不能从整体范围优化工作流程&#xff0c;仍然是“个人贡献者”的工作方式&#…

低功耗晶振电路设计

晶振电路设计 晶振中负性阻抗的原理 晶振的回路主要由两部分组成&#xff0c; 一部分是激活分支&#xff0c; 用于提供能量给晶振启动直至达到稳定的相位&#xff0c;另一部分是被动分支&#xff0c; 主要由电阻&#xff0c; 两个外部负载电容以及所有的寄生电容&#xff0c;…

手写RPC总结篇

协议制定&#xff1a;client到server做交互的通信协议&#xff0c;比如request response 网络端点peer 难点1 : Jetty嵌入 ◆jetty Server ◆ServletContextHandler ◆ServletHolder jetty server 起到网络监听的作用ServletContextHandler注册到jetty server中ServletHolde…

测试开发之Python自动化 Pytest 之 fixture

Pytest 之 fixture unittest 和 nose 都支持 fixture 的,但是 fixture 在 pytest 里使用更灵活。也算是 pytest 的一个闪光点吧可以理解为一个跟 setup 和 teardown 这种前后置类似的东西。但是比它们要强大、灵活很多 fixtur 当做参数传入 # -*- coding: utf-8 -*-import p…

图像处理实战02-yolov5目标检测

yolov5 YOLOv5 是一种目标检测算法&#xff0c;它是 YOLO (You Only Look Once) 系列算法的最新版本。YOLOv5 采用了一种新的架构&#xff0c;它包括一个基于 CSPNet (Cross Stage Partial Network) 的主干网络以及一系列改进的技巧&#xff0c;如多尺度训练、数据增强、网络混…

互联网行业-镭速文件传输系统方案

互联网行业是一个快速变化和高度竞争的行业&#xff0c;这一行业需要传输大量的数据、代码和文件。在互联网企业的生产和运营过程中&#xff0c;需要传输各种敏感和大型的文件&#xff0c;例如业务报告、数据分析、软件代码等。这些文件需要在不同的部门、不同的地点之间高效地…

微服务springcloud 07 hystrix + turbine 集群聚合监控

01.hystrix dashboard 一次只能监控一个服务实例&#xff0c;使用 turbine 可以汇集监控信息&#xff0c;将聚合后的信息提供给 hystrix dashboard 来集中展示和监控 02.新建 sp10-turbine 项目 03.pom.xml <?xml version"1.0" encoding"UTF-8"?&…

C语言---malloc(0)会产生什么结果,真的是空指针吗?

前言 &#xff08;1&#xff09;几天前在一个交流群中看到有人说&#xff0c;面试问malloc(0)会怎么样是真的恶心。 &#xff08;2&#xff09;这个突然激起了我的好奇心。居然还可以malloc(0)&#xff1f;&#xff01; &#xff08;3&#xff09;经过测试最后&#xff0c;发现…

基于Java学生课外知识学习网站设计实现(源码+lw+部署文档+讲解等)

博主介绍&#xff1a; ✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战 ✌ &#x1f345; 文末获取源码联系 &#x1f345; &#x1f447;&#x1f3fb; 精…

PulsarMQ系列入门篇

文章目录 介绍&#xff1a;部署安装讲解:安装单机版本测试&#xff08;Linux下&#xff09;&#xff1a; 介绍&#xff1a; PulsarMQ 现托管于apache Apache 软件基金会顶级项目&#xff0c;2016年由雅虎公司开源的分布式多租户消息中间件 &#xff0c;是下一代云原生分布式消息…

PaddleOCR #hello paddle: 从普通程序走向机器学习程序 - 初识机器学习

这篇示例向你介绍普通程序跟机器学习程序的区别&#xff0c;并带着你用百度飞桨框架&#xff0c;实现第一个机器学习程序&#xff0c;并初步认识机器学习。 作为一名开发者&#xff0c;你最熟悉的开始学习一门编程语言&#xff0c;或者一个深度学习框架的方式&#xff0c;可能是…

万字长文:大模型训练避坑指南

自 2022 年 11 月底 ChatGPT 发布以来&#xff0c;大模型的热度持续发酵&#xff0c;相信高屋建瓴的讨论大家已经看了很多了。今天我们选择从实用角度&#xff0c;分别就算力、算法、工程、数据和团队等方向讨论了训练一个千亿参数量级的大语言模型和 ChatGPT 需要些什么&#…

4.17 TCP三次握手 4.18滑动窗口 4.19TCP四次挥手

4.17 TCP三次握手 TCP 是一种面向连接的单播协议&#xff0c;在发送数据前&#xff0c;通信双方必须在彼此间建立一条连接。所谓的“连接”&#xff0c;其实是客户端和服务器的内存里保存的一份关于对方的信息&#xff0c;如 IP 地址、端口号等。 TCP 可以看成是一种字节流&a…

i.MX 91x推出,飞凌嵌入式携手NXP打造更强大、更经济、更安全的解决方案

NXP在COMPUTEX 2023上发布了i.MX 91应用处理器系列&#xff0c;作为i.MX 9系列的入门级产品&#xff0c;i.MX 91x简化了高性价比边缘设备的开发过程&#xff0c;助力构建需要安全性、高性能表现以及Linux支持的可扩展、高可靠性的平台&#xff0c;可满足下一代基于Linux的物联网…

【数据库三】MySQL索引

MySQL索引、事务与存储引擎 1.MySQL索引1.1 索引的概念1.2 索引的作用​1.3 索引的副作用​1.4 创建索引的原则依据​ 2.索引的分类和创建2.1 普通索引2.2 唯一索引2.3 主键索引2.4 组合索引2.5 全文索引 3. 查看索引4.删除索引5. 知识点总结 1.MySQL索引 1.1 索引的概念 索引…

基于html+css的图展示127

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

Unity光照贴图的切换,实现黑夜和白天效果

有这么一个需求&#xff0c;不能使用实时光来进行动态控制光照开关&#xff0c;但是又要实现白天和黑夜的效果&#xff0c;我的场景中有大概十几个点光源和平行光 实现步骤&#xff1a; 一、模型原模原样复制到另一个场景中&#xff08;因为贴图只能存在于当前的场景文件夹&am…