用VAE生成图像

news2025/1/9 16:41:54

用VAE生成图像

  • 自编码器AE,auto-encoder
  • VAE
    • 讲讲为什么是log_var
    • 为什么要用重参数化技巧
  • 用VAE生成图像

变分自编码器是自编码器的改进版本,自编码器AE是一种无监督学习,但它无法产生新的内容,变分自编码器对其潜在空间进行拓展,使其满足正态分布,情况就大不一样了。

自编码器AE,auto-encoder

自编码器是通过对输入X进行编码后得到一个低维的向量z,然后根据这个向量还原出输入X。通过对比X与 X ∼ \overset{\sim}{X} X的误差,再利用神经网络去训练使得误差逐渐减小,从而达到非监督学习的目的。
下图为AE的架构图:
在这里插入图片描述
自编码器不能随意产生合理的潜在变量,从而导致它无法产生新的内容。因为潜在变量Z都是编码器从原始图片中产生的。为了解决这一问题,研究人员对潜在空间Z(潜在变量对应的空间) 增加一些约束,使 Z 满足正态分布,由此就出现了VAE模型, VAE对编码器添加约束,就是强迫它产生服从单位正态分布的潜在变量。正是这种约束,把VAE和 AE 区分开来。

VAE

变分自编码器关键一点就是增加一个对潜在空间Z的正态分布约束,如何确定这个正态分布就成为主要目标,我们知道要确定正态分布,只要确定其两个参数: 均值 μ \mu μ和标准差 σ \sigma σ。那如何确定 μ , σ \mu, \sigma μ,σ呢?用一般的方法或估计比较麻烦效果也不好,研究人员发现**用神经网络去拟合,简单效果也不错。**下图为VAE的架构图:
在这里插入图片描述
上图中,模块①的功能把输入样本X通过编码器输出两个m维向量( μ , l o g _ v a r \mu, \mathrm{log\_var} μ,log_var), 这两个向量是潜在空间(假设满足正态分布)的两个参数(相当于均值和方差)。那么如何从这个潜在空间采样一个点 Z ?

这里假设潜在正态分布能生成输入图像,从标准正态分布 N(0, 1)中采样一个 ϵ \epsilon ϵ(模块②的功能), 然后使
Z = μ + e l o g _ v a r ∗ ϵ Z=\mu + e^{log\_var}*\epsilon Z=μ+elog_varϵ
这也是模块③的主要功能。
Z是从潜在空间抽取的一个向量,Z通过解码器生成一个样本 X ∼ \overset{\sim}{X} X, 这是模块④的功能。这里的 ϵ \epsilon ϵ 是随机采样的,这就可以保证潜在空间的连续性,良好的结构性。而这些特性使得潜在空间的每个方向都表示数据中有意义的变化方向。

以上的这些步骤构成整个网络的前向传播过程,那反向传播应如何进行?要确定反向传播就会设计损失函数,损失函数是衡量模型优劣的主要指标。这里我们需要从以下两个方面进行衡量。
1)生成的新图像与原图像的相似度;
2)隐含空间的分布与正态分布的相似度。

度量图像的相似度一般采用交叉熵(如nn.BCELoss) , 度量两个分布的相似度一般采用KL散度(Kullback-Leibler divergence)。这两个度量的和构成了整个模型的损失函数。

以下是损失函数的具体代码,VAE损失函数的推导过程可以参考原论文

# 定义重构损失函数及KL散度
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
kl_div = -0.5*torch.sum(1+log_var-mu.pow(2)-log_var.exp())
# 两者相加得总损失
loss = reconst_loss + kl_div

讲讲为什么是log_var

这里可以看成 log_var = log ⁡ σ \log \sigma logσ,所以 Z = μ + e l o g _ v a r ∗ ϵ Z=\mu + e^{log\_var}*\epsilon Z=μ+elog_varϵ也就是 Z = μ + σ ∗ ϵ Z=\mu + \sigma*\epsilon Z=μ+σϵ
其中 ϵ ∼ N ( 0 , 1 ) \epsilon\sim N(0,1) ϵN(0,1), 这里涉及到重参数化reparameterization。

为什么要用重参数化技巧

如果想从高斯分布 N ( μ , σ 2 ) N(\mu,\sigma^{2}) N(μ,σ2)中采样,可以先从标准分布 N ( 0 , 1 ) N(0,1) N(0,1)采样出 ϵ \epsilon ϵ , 再得到 Z = σ ∗ ϵ + μ Z = \sigma*\epsilon+\mu Z=σϵ+μ.
这样做的好处是

  1. 如果直接对 N ( μ , σ 2 ) N(\mu,\sigma^{2}) N(μ,σ2)进行采样得到Z,则Z无法对 μ , σ \mu,\sigma μ,σ进行求偏导

  2. 将随机性转移到了 ϵ \epsilon ϵ 这个常量上,而 σ \sigma σ μ \mu μ则当做仿射变换网络的一部分,这样得到的 Z = σ ∗ ϵ + μ Z = \sigma*\epsilon+\mu Z=σϵ+μ,则Z就可以对 μ , σ \mu,\sigma μ,σ进行求偏导来计算损失函数,进行求梯度,进行BP。

用VAE生成图像

下面将结合代码,用pytorch实现,为便于说明起见,数据集采用MNIST,整个网络架构如下图所示。
VAE网络架构图

# 1. 导入必要的包
import os 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
from torchvision import transforms 
from torchvision.utils import save_image

# 2. 定义一些超参数
image_size = 784 # 28*28 
h_dim = 400 
z_dim = 20 
num_epochs = 30 
batch_size = 128 
learning_rate = 0.001 

# 如果没有文件夹就创建一个文件夹
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
                                     
  1. 对数据集进行预处理,如转换为Tensor, 把数据集转换为循环,可批量加载的数据集
# 只下载训练数据集即可
# 下载MNIST训练集
dataset = torchvision.datasets.MNIST(root='./data', train=True, 
                                    transform=transforms.ToTensor(),
                                    download=True)

# 数据加载
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                    batch_size=batch_size,
                                    shuffle=True)
  1. 构建VAE模型,主要由Encoder和Decoder两部分组成

# 定义AVE模型
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encoder(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    # 用mu, log_var生成一个潜在空间点z, mu, log_var为两个统计参数,我们假设
    # 这个假设分布能生成图像
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std 

    def decoder(self, z):
        h = F.relu(self.fc4(z))
        return F.sigmoid(self.fc5(h))
    
    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decoder(z)
        return x_reconst, mu, log_var 
    
  1. 选择GPU和优化器
# 选择GPU和优化器
torch.cuda.set_device(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  1. 训练模型,同时保存原图像与随机生成的图像
# 训练模型,同时保存原图像与随机生成的图像
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # 获取样本,并前向传播
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # 计算重构损失和KL散度(KL散度用于衡量两种分布的相似程度)
        # KL散度的计算可以参考https://shenxiaohai.me/2018/10/20/pytorch-tutorial-advanced-02/
        reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # 反向传播和优化
        loss = reconst_loss + kl_div 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1)%100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div:{:.4f}'.
                  format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # 利用训练的模型进行测试
    with torch.no_grad():
        # 保存采样图像,即潜在向量z通过解码器生成的新图像
        # 随机生成的图像
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decoder(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # 保存重构图像,即原图像通过解码器生成的图像
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))


  1. 展示原图像及重构图像
#显示图片
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 
import numpy as np 

recons_path = './samples/reconst-30.png'
Image = mpimg.imread(recons_path)
plt.imshow(Image)
plt.axis('off')
plt.show()
# reconst
# 奇数列为原图像,欧数列为原图像重构的图像,可以看出重构效果还不错。

在这里插入图片描述
8. 由潜在空间通过解码器生成的新图像,这个图像效果也不错


# sampled
# 为由潜在空间通过解码器生成的新图像,这个图像效果也不错
genPath = './samples/sampled-30.png'
Image = mpimg.imread(genPath)
plt.imshow(Image)
plt.axis('off')
plt.show()

在这里插入图片描述
总结:这里构建网络主要用全连接层,有兴趣的读者,可以把卷积层,如果编码层使用卷积层(如nn.Conv2d), 则解码器就需要使用反卷积层(如nn.ConvTranspose2d)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/385341.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二、Redis安装配置(云服务器、vmware本地虚拟机)

一、自己购买服务器 自己购买阿里云、青牛云、腾讯云或华为云服务器, 自带CentoOS或者Ubuntu环境,直接开干 二、Vmware本地虚拟机安装 1、VMWare虚拟机的安装,不讲解,默认懂 2、如何查看自己的linux是32位还是64位 getconf L…

云HIS医院管理系统源码 云HIS系统源码 SaaS模式 springboot开发

▶ SaaS运维平台多医院入驻强大的电子病历模板 ,有源码,有演示! ▶ 云HIS系统技术框架: 总体框架: SaaS应用,全浏览器访问 前后端分离,多服务协同 服务可拆分,功能易扩展 ▶ 云HI…

初阶C语言——实用调试技巧【详解】

文章目录1. 什么是bug?2. 调试是什么?有多重要?2.1 调试是什么?2.2 调试的基本步骤2.3 Debug和Release的介绍3.学会使用快捷键4.调试的时候查看程序当前信息4.1 查看临时变量的值4.2 查看内存信息4.3 查看调用堆栈4.4 查看汇编信息…

混凝土搅拌站远程监控解决方案

一、项目背景 随着大规模的基础设施建设,对混凝土搅拌设备的需求量日益增加,对其技术指标的要求也日益提高,其技术性能将直接关系到工程的质量和使用寿命。而混凝土生产的质量是在生产过程中形成的,而非最终强度的检测。混凝土生…

10 面向接口编程(上):一切皆服务,服务基于协议

按照面向接口编程的理念,将每个模块看成是一个服务,服务的具体实现我们其实并不关心,我们关心的是服务提供的能力,即接口协议。那么框架主体真正要做的事情是什么呢?其实是:定义好每个模块服务的接口协议&a…

That引导的宾语从句

That引导的宾语从句指的是that为宾语从句的引导词。宾语从句:置于动词、介词等词性后面,在句子中起宾语作用的从句叫做宾语从句。宾语从句分为三类:动词的宾语从句,介词的宾语从句和形容词的宾语从句。 一、that引导的宾语从句(在…

《数据万象带你玩转视图场景》第一期:avif图片压缩详解

前言随着硬件的发展,不管是手机还是专业摄像设备拍出的图片随便可能就有几M,甚至几十M,并且现在我们处于随处可及的信息海洋里,海量的图片带来了存储问题、带宽问题、加载时延问题等等。对图片信息进行有效的压缩处理无疑会极大的…

ARM架构Ubuntu下使用Docker安装MySQL

大家好,我是中国码农摘星人。 欢迎分享/收藏/赞/在看! 由于ARM架构的限制,许多软件还没有做到完全适配,CentOS、MySQL等软件安装频繁出错。于是决定做一栏相关软件环境安装的文章。 基础信息 Apple M1 ProUbuntu 22.04 运行 使…

Python 如何安装 MySQLdb ?

人生苦短 我用python Python 标准数据库接口为 Python DB-API, Python DB-API为开发人员提供了数据库应用编程接口。 Python 数据库接口支持非常多的数据库, 你可以选择适合你项目的数据库: GadFlymSQLMySQLPostgreSQLMicrosoft SQL Serve…

来 CSDN 三年,我写了一本Python书

大家好,我是朱小五。转眼间已经来 CSDN 3年了,其中给大家一共分享了252篇Python文章。 但这三年,最大的收获还是写了一本Python书! 在这个自动化时代,我们有很多重复无聊的工作要做。想想这些你不再需要一次又一次地做…

站内信箱系统的设计与实现

技术:Java、JSP等摘要:在经济全球化和信息技术成为发展迅速的今时今日,人们通过电子邮件收发进行信息传递已经成为主流。随着互联网和网络办公的发展,电子邮件正在被广泛应用在人们的日常生活中。跟据调查研究统计,在全…

文件系统-

文件系统 是一个面向用户的可视化管理类型的操作系统 其实就是管理硬盘的基本单位扇区,然后将存储数据可视化管理给用户 文件系统包含两个部分 文件的集合和目录结构 对于用户和系统来说文件系统时不一样的 操作系统只解释可执行文件 文件内部结构 文件就是基本…

【JVM】详解Java内存区域和分配

这里写目录标题一、前言二、运行时数据分区2.1程序计数器(PC)2.2 Java虚拟机栈2.3 本地方法栈2.4 Java堆2.5 方法区2.5.1 运行时常量池2.6 直接内存三、HotSpot虚拟机对象探秘3.1 对象的创建3.2 对象的内存布局3.3 对象的访问定位一、前言 C/C需要自行回收和释放已经没用的对象…

2023年 Java 发展趋势

GitHub 语言统计表明,Java在编程语言中排名第二,而在2022年的TIOBE指数中,Java排在第四。 抛开排名,Java是自诞生以来企业使用率最高的编程语言,作为一种编程语言,它比许多竞争对手都有更多的优点&#xf…

【C/C++】逗号表达式、算术运算符优先级

一、逗号表达式 1、如下图中代码,为变量d赋值,d的值为逗号表达式中的哪一个呢? 运行结果:d的值为6 2、再举个例子 运行结果:d的结果还是6 3、再举个例子 运行结果 以上面三种不同的逗号表达式为例,…

IronOCR 2023.3.2 Crack by Xacker

适用于 .NET 2023.3.2 的 IronOCR 提高了从 PDF 阅读文本时的可靠性。2023 年 3 月 2 日 - 10:22 新版本特征 添加了与 Amazon AWS (Amazon Linux) 的兼容性。添加了对各种旧版 Linux 发行版的兼容性。提高了从 PDF 阅读文本时的可靠性。创建可搜索的 PDF 时提高了速度和保真度…

JVM调优面试题——基础知识

文章目录1、JDK,JRE以及JVM的关系2、编译器到底干了什么事?3、类加载机制是什么?3.1、装载(Load)3.2、链接(Link)3.3、初始化(Initialize)4、类加载器有哪些?5、什么是双亲委派机制?6、介绍一下JVM内存划分&#xff08…

[蓝桥杯] 数学与简单DP问题

文章目录 一、简单数学问题习题练习 1、1 买不到的数目 1、1、1 题目描述 1、1、2 题解关键思路与解答 1、2 饮料换购 1、2、1 题目描述 1、2、2 题解关键思路与解答 二、DP问题习题练习 2、1 背包问题 2、1、1 题目描述 2、1、2 题解关键思路与解答 2、2 摘花生 2、2、1 题目…

收个滴滴Offer:从小伙三面经历,看看需要学点啥?

说在前面 在尼恩的(50)读者社群中,经常有小伙伴,需要面试大厂。 后续结合一些大厂的面试真题,给大家梳理一下学习路径,看看大家需要学点啥? 这里也一并把题目以及参考答案,收入咱…

Spring 容器创建初始化,获取bean流程分析

Spring 容器创建初始化,获取bean流程分析 Spring 容器创建初始化 流程分析 1、首先读取bean.xml 文件 2、扫描指定的包 com.hspedu.spring.component 2.1、扫描包,得到bean的class对象,排除包下不是bean的 2.2、扫描将bean信息封装BeanDef…