【PyTorch][chapter 18][李宏毅深度学习]【无监督学习][ VAE]

news2024/11/16 3:16:33

前言:

          VAE——Variational Auto-Encoder,变分自编码器,是由 Kingma 等人于 2014 年提出的基于变分贝叶斯(Variational Bayes,VB)推断的生成式网络结构。与传统的自编码器通过数值的方式描述潜在空间不同,它以概率的方式描述对潜在空间的观察,在数据生成方面表现出了巨大的应用价值。VAE一经提出就迅速获得了深度生成模型领域广泛的关注,并和生成对抗网络(Generative Adversarial Networks,GAN)被视为无监督式学习领域最具研究价值的方法之一,在深度生成模型领域得到越来越多的应用。

           Durk Kingma 目前也是 OpenAI 的研究科学家

   VAE 是我深度学习过程中偏难的一部分,涉及到的理论基础:

          极大似然估计, KL 散度 ,Bayes定理,蒙特卡洛重采样思想,VI变分思想,ELBO


目录:

  1.    AE 编码器缺陷
  2.    VAE 编码器 跟AE 编码器差异
  3.    VAE 编码器
  4.     VAE 思想
  5.      Python 代码例子

一 AE 编码器缺陷

   1.1 AE 简介

   输入一张图片 x

   编码器Encoder:

                 z=f(x)  通过神经网络得到低维度的特征空间Z

   解码器Decoder:

                 \hat{x}=g(z)  通过特征空间 重构输入的图像

   损失函数:

               J=mse(x,\hat{x})

   1.2 特征空间z

           单独使用解码器Decoder

           特征空间z 维度为10,固定其它维度参数. 取其中两维参数,产生不同的

值(如下图星座图),然后通过Decoder 生成不同的图片.就会发现该维度

跟图像的某些特征有关联.

1.3 通过特征空间z重构缺陷:泛化能力差

     

       如上图:

            假设通过AE 模型训练动物的图像,特征空间Z为一维度。

      两种狗分别对应特征向量z_1,z_3, 我们取一个特征向量z_2,期望通过

     解码器输出介于两种狗中间的一个样子的一种狗。

          实际输出: ,随机输出一些乱七八糟的图像。

     原因:

          因为训练的时候,模型对训练的图像和特征空间Z的映射是离散的,对特征空间z

中没有训练过的空间没有约束,所以通过解码器输出的图像也是随机的.


二  VAE 编码器 跟AE 编码器差异

        2.1  AE 编码器特征空间

      假设特征空间Z 为一维,

  通过编码器生成的特征空间为一维空间的一个离散点c,然后通过解码器重构输入x

2.2 VAE 编码器

      通过编码器产生一个均值为u,方差为\sigma的高斯分布,然后在该分布上采样得到

特征空间的一个点c, 通过解码器重构输入. 现在特征空间Z是一个高斯分布,

泛化能力更强


三 VAE 编码器

3.1 模型简介

 输入 :x

 经过编码器 生成一个服从高斯分布的特征空间 z \sim N(u,\sigma^2) ,

通过重参数采样技巧 采样出特征点 C=\begin{bmatrix} c_1,c_2,c_3 \end{bmatrix}

 把特征点 输入解码器,重构出输入x

3.2 标准差\sigma(黄色模块)设计原理

           方差 \sigma^2   标准差 \sigma 

           因为标准差是非负的,但是经过编码器输出的可能是负的值,所以

    认为其输出值为 a=log (\sigma) ,再经过 exp 操作,得到一个非负的标准差

        \sigma=e^{a}=\sigma

      很多博主用的\sigma^2,我理解是错误的,为什么直接用 标准差

       参考3.3  苏剑林的 重参数采样 原理画出来的。

3.3 为什么要重参数采样 reparameterization trick

        我们要从p(Z|X)中采样一个Z出来,尽管我们知道了p(Z|X)是正态分布,但是均值方差都是靠模型算出来的,我们要靠这个过程反过来优化均值方差的模型。
但是“采样”这个操作是不可导的,而采样的结果是可导的

p(Z|X) 的概率可以写成如下形式

   说明

    服从 N(0,1)的标准正态分布

   从N(u,\sigma^2)中采样一个Z,相当于从N(0,I)标准正态分布中采样一个e,然后让

    Z=u+e*\sigma

   我们将从采样N(u,\sigma^2)变成了从N(0,I)中采样,然后通过参数变换得服从N(u,\sigma^2)分布。这样一来,“采样”这个操作就不用参与梯度下降了,改为采样的结果参与,使得整个模型可训练了。其中 u,\sigma是求导参数,e 为已知道参数

3.4 损失函数

         J=J_1+J_2

         该模型有两个约束条件

         1   一个输入图像x和重构的图像\hat{x},mse 误差最小

                    J_1= ||x-\hat{x}||_2

         2   特征空间Z 要服从高斯分布(使用KL 散度)

                     J_2=KL(N(u,\sigma^2)||N(0,1))

                  该值越小越好

     KL 散度简化

3.5 伪代码


四  VAE 思想

        4.1 高斯混合模型

             我们重构出m张图片 X=\begin{Bmatrix} x_1 &x_2 & ... & x_m \end{Bmatrix}

              P(X)=\prod_i^{m} P(x_i),P(X) 很复杂无法求解.

            常用的思路是通过引入隐藏变量(latent variable) Z。

           寻找 Z空间到 X空间的映射,这样我们通过在Z空间采样映射到 X  空间就可以生成新的图片。

          P(X)=\int _z P(x|z)P(z)dz   

          我们使用多个高斯分布的P(z) 去拟合P(X)的分布,这里面P(z)为已知道

            在强化学习里面,蒙特卡罗重采样也是用了该方案.

例:

如上图 P(X=红色)=2/5  ,P(X=绿色)=3/5 

 我们可以通过高斯混合模型原理的方法求解

P(X=红色)=P(X=红色|Z=正方形)*P(Z=正方形)+ P(X=红色|Z=圆形)*P(Z=圆形)

                    

 P(X=绿色)也是一样

   4.2 极大似然估计

      目标:极大似然函数

            L= logP(x) 

      已知:

            编码器的概率分布\int_z q(z|x)dz=1

       则:

          L=L*\int_z q(z|x)dz(相当于乘以1)

              =\int_z q(z|x) log P(x)dz (因为P(x)跟z 无关,可以直接拿到积分里面)

             =\int_z q(z|x)log \frac{P(z,x)}{p(z|x)}

           贝叶斯定理:

         P(z,x)=p(x)p(z|x)

          =\int q(z|x)log \frac{p(z,x)}{p(z|x)}\frac{q(z|x)}{q(z|x)}

         =\int_z q(z|x)log \frac{q(z|x)}{p(z|x)}+\int_z q(z|x)log \frac{p(z,x)}{q(z|x)}

         =KL(q(z|x)||q(z|x))+\int_z q(z|x)log \frac{p(z,x)}{q(z|x)}

   1:  VAE叫做“变分自编码器”,它跟变分法有什么联系

固定概率分布p(x)(或q(x)的情况下,对于任意的概率分布q(x)(或p(x))),都有KL(p(x)||q(x))≥0,而且只有当p(x)=q(x)时才等于零。

因为KL(p(x)∥∥q(x))实际上是一个泛函,要对泛函求极值就要用到变分法

  \geq L_b=\int_z q(z|x)log\frac{p(z,x)}{q(z|x)}

ELBO:全称为 Evidence Lower Bound,即证据下界。

上面KL(q(z|x)||q(z|x)) 我们取了下界0

        =\int_z q(z|x)log \frac{p(z)p(x|z)}{q(z|x)}

  贝叶斯定理

   p(z,x)=p(x|z)p(z)

   注意: 这里面P(Z)在4.1 高斯混合模型 是已知道的概率分布,符合高斯分布

 =\int_z q(z|x)log p(z|x)+\int_z q(z|x)log \frac{p(z)}{p(z|x)}

=-KL(q(z|x)||p(z))+H(q(z|x)||p(x||z))

我们目标值是求L 的最大值

第一项:

因为KL 散度的非负性

-KL(q(z|x)||p(z))极大值点为 p(z)=q(z|x),因为p(z)是符合高斯分布的

所以通过编码器生成的q(z|x)也要跟它概率一致,符合高斯分布。

第二项:

H(q(z|x)||p(x||z)) 

 这部分代表重构误差,我们用mse(x,\hat{x}) 来训练该部分的误差


五 Python 代码

# -*- coding: utf-8 -*-
"""
Created on Mon Feb 26 15:47:20 2024

@author: chengxf2
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms   # transforms用于数据预处理


# 定义变分自编码器(VAE)模型
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=latent_dim*2),  # 输出均值和方差
            nn.ReLU()
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(in_features =latent_dim , out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=784),
            nn.Sigmoid()
        )
        
    def reparameterize(self, mu, logvar):
        
        std = torch.exp(logvar/2.0)  # 计算标准差,Encoder 出来的可能有负的值,标准差为非负值,所以要乘以exp
        eps = torch.randn_like(std)  # 从标准正态分布中采样噪声
        z = mu + eps * std  # 重参数化技巧
        return z
    
    def forward(self, x):
        # 编码[batch, latent_dim*2]
        encoded = self.encoder(x)
        #[ z = mu|logvar]
        mu, logvar = torch.chunk(encoded, 2, dim=1)  # 将输出分割为均值和方差
      
        
        z = self.reparameterize(mu, logvar)  # 重参数化
        
        # 解码
        decoded = self.decoder(z)
        return decoded, mu, logvar

# 定义训练函数
def train_vae(model, train_loader, num_epochs, learning_rate):
    criterion = nn.BCELoss()  # 二元交叉熵损失函数
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam优化器
    
    model.train()  # 设置模型为训练模式
    
    for epoch in range(num_epochs):
        total_loss = 0.0
        
        for data in train_loader:
            images, _ = data
            images = images.view(images.size(0), -1)  # 展平输入图像
            
            optimizer.zero_grad()
            
            # 前向传播
            outputs, mu, logvar = model(images)
            
            # 计算重构损失和KL散度
            reconstruction_loss = criterion(outputs, images)
            kl_divergence = 0.5 * torch.sum( -logvar +mu.pow(2) +logvar.exp()-1)
            
            # 计算总损失
            loss = reconstruction_loss + kl_divergence
            
            # 反向传播和优化
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        # 输出当前训练轮次的损失
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, total_loss / len(train_loader)))
    
    print('Training finished.')

# 示例用法
if __name__ == '__main__':
    # 设置超参数
  
    latent_dim = 32  # 潜在空间维度
    num_epochs = 1  # 训练轮次
    learning_rate = 1e-4  # 学习率
    
    # 加载MNIST数据集
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
    
    # 创建VAE模型
    model = VAE(latent_dim)
    
    # 训练VAE模型
    train_vae(model, train_loader, num_epochs, learning_rate)


VAE到底在做什么?VAE原理讲解系列#1_哔哩哔哩_bilibili

VAE里面的概率知识。VAE原理讲解系列#2_哔哩哔哩_bilibili

vae损失函数怎么理解? - 知乎

如何搭建VQ-VAE模型(Pytorch代码)_哔哩哔哩_bilibili

变分自编码器(一):原来是这么一回事 - 科 学空间|Scientific Spaces

16: Unsupervised Learning - Auto-encoder_哔哩哔哩_bilibili

【生成模型VAE】十分钟带你了解变分自编码器及搭建VQ-VAE模型(Pytorch代码)!简单易懂!—GAN/机器学习/监督学习_哔哩哔哩_bilibili

[diffusion] 生成模型基础 VAE 原理及实现_哔哩哔哩_bilibili

[论文简析]VAE: Auto-encoding Variational Bayes[1312.6114]_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1479908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用JavaScript动态提取视频中的文字

现阶段整个社会短视频,中视频为王,文字传播虽然被弱化,但在业务中还是有一定的传播价值,今天就来讲一讲如何使用js动态提取视频中的字幕。 先来看看效果: 屏幕录制2024-02-29 15.40.18 一,tesseract.js介…

springcloud alibaba组件简介

一、Nacos 服务注册中心/统一配置中心 1、介绍 Nacos是一个配置中心,也是一个服务注册与发现中心。 1.1、配置中心的好处: (1)配置数据脱敏 (2)防止出错,方便管理 (3&#xff…

Web漏扫工具OWASP ZAP安装与使用(非常详细)从零基础入门到精通,看完这一篇就够了。

本文仅用于安全学习使用!切勿非法用途。 一、OWASP ZAP简介 开放式Web应用程序安全项目(OWASP,Open Web Application Security Project)是一个组织,它提供有关计算机和互联网应用程序的公正、实际、有成本效益的信息。…

javascript作用域编译浅析

作用域思维导图 1:编译原理 分词/词法分析 如果词法单元生成器在判断a是一个独立的词法单元还是其他词法单元的一部分时,调用的是有状态的解析规则,那么这个过程就被称为词法分析。 解析/语法分析 由词法单元流转换成一个由元素逐级嵌套所组…

java: 错误: 不支持发行版本 5

目录 一、问题描述 二、解决办法 方法一:修改idea设置中的jdk版本 方法二:配置pom.xml文件 方法三:配置maven的xml文件(推荐) 三、结果 一、问题描述 问题描述:今天创建了一个maven项目,…

第六课:NIO简介

一、传统BIO的缺点 BIO属于同步阻塞行IO,在服务器的实现模型为,每一个连接都要对应一个线程。当客户端有连接请求的时候,服务器端需要启动一个新的线程与之对应处理,这个模型有很多缺陷。当客户端不做出进一步IO请求的时候,服务器…

Gitlab: 私有化部署

目录 1. 说明 2. 资源要求 3. 安装 4. 配置实践 4.1 服务器 4.2 人员与项目 4.2 部署准备 4.2.1 访问变量及用户账号设置 4.2.2 Runner设置 4.2.3 要点 5. 应用项目 CI/CD 6. 参考 1. 说明 gitlab是一个强大且免费的代码管理/部署工具,能统一集成代码仓…

springboot233大学生就业需求分析系统

大学生就业需求分析系统设计与实现 摘 要 信息数据从传统到当代,是一直在变革当中,突如其来的互联网让传统的信息管理看到了革命性的曙光,因为传统信息管理从时效性,还是安全性,还是可操作性等各个方面来讲&#xff…

SpringBoot接收参数的几种形式

SpringBoot接收参数的几种形式 在SpringBoot中获取参数基本方式有5种,需要都掌握. 这里需要记住一个技术术语或概念 API接口: 你写好的那个URL地址,就被称为API接口 1. 接收常规参数 给/param/demo1这个URL接口发送id, name两个参数 以上是以GET请求类型进行发送,实际发送…

一封来自 DatenLord 关于GSoC 2024的挑战书

Google Summer of Code 是一项全球性的在线计划,致力于将新的contributor引入开源软件开发领域。GSoC 参与者在导师的指导下,与开源组织合作开展为期 12 周以上的编程项目。今年,达坦科技入选作为开源社区组织,携CNCF Sandbox项目…

深入探讨Java中的OutputStreamWriter类

咦咦咦,各位小可爱,我是你们的好伙伴——bug菌,今天又来给大家普及Java SE相关知识点了,别躲起来啊,听我讲干货还不快点赞,赞多了我就有动力讲得更嗨啦!所以呀,养成先点赞后阅读的好…

动态规划(算法竞赛、蓝桥杯)--分组背包DP

1、B站视频链接&#xff1a;E16 背包DP 分组背包_哔哩哔哩_bilibili #include <bits/stdc.h> using namespace std; const int N110; int v[N][N],w[N][N],s[N]; // v[i,j]:第i组第j个物品的体积 s[i]:第i组物品的个数 int f[N][N]; // f[i,j]:前i组物品&#xff0c;能放…

Power Apps 学习笔记 -- Plugin

文章目录 1. Plugin 简介2. Plugin 配置2.1 步骤Step核心分析 3. Plugin 代码 1. Plugin 简介 Plugin基础教程 : Plugin基础教程 插件Plugin: 1. 插件Plugin通常用于默认数据处理操作区间&#xff0c;增加数据默认行为的方法。(无重用性)2. Plugin 配置 .NET环境&#xff1a;.…

图像分割 - 轮廓拟合(最小外接矩形和圆形)

1、前言 拟合:用一条光滑的曲线将平面上的点连接起来 轮廓拟合:将凹凸不平的轮廓用平整的几何图形体现出来 本章将介绍如何用最小外接矩形或者最小外接圆形将下面的图像轮廓拟合 几何图形的轮廓绘制,参考前面的文章:图像分割 - 查找图像的轮廓(cv2.findContours函数) 2、…

rk3568-一种基于wifi的网络环境搭建方案

前言&#xff1a; PC--Ubuntu--开发板 三者之间的网络互相ping通很重要&#xff0c;尤其是ubuntu和开发板互ping成功最关键&#xff0c;关系到nfs&#xff0c;tftp等常用的开发手段。现在大多数开发板都带有wifi芯片&#xff0c;现在提供一种方案可以三个设备无线地搭建网络环境…

这4款一键生成的AI写作软件值得一试

自今年初以来&#xff0c;各类AI工具如潮水般涌现&#xff0c;包括AI写作、AI绘画、AI音频处理和AI抠图等等。这些工具层出不穷&#xff0c;为我们的工作和生活带来了极大的便利。学会充分利用这些AI工具可以显著提升我们的生产效率。 软件一&#xff1a;爱制作AI 推荐指数&am…

在实训云平台上配置云主机

文章目录 零、学习目标一、实训云升级二、实训云登录&#xff08;一&#xff09;登录实训云&#xff08;二&#xff09;切换界面语言&#xff08;三&#xff09;规划云主机实例 三、创建网络三、创建路由器2024-2-29更新到此四、添加接口五、创建端口六、添加安全组规则七、创建…

一文详解CRM系统是什么?让你轻松了解CRM的全貌!

互联网上关于CRM管理系统的介绍文章各式各样&#xff0c;但是很多都是为了做品牌推广&#xff0c;并不能真正帮助读者理解CRM这一系统。这篇文章有别于您读到的其他文章&#xff0c;将从CRM系统的概念理解、常见分类、基础功能、应用阶段、发展趋势、系统定价和选型技巧这七个方…

Tomcat 下部署若依单体应用可观测最佳实践

实现目标 采集指标信息采集链路信息采集日志信息采集 RUM 信息会话重放 即用户访问前端的一系列过程的会话录制信息&#xff0c;包括点击某个按钮、操作界面、停留时间等&#xff0c;有助于客户真是意图、操作复现 版本信息 Tomcat (9.0.81)Springboot(2.6.2)JDK (>8)DDT…

mount命令最新详细教程

背景 需要在设备上面&#xff0c;自动化运行u盘里面的脚本&#xff0c;并且进入一个产测模式。因此实际使用了这个mount命令&#xff0c;所以&#xff0c;写了这么一篇供大家参考。 一. 定义 mount命令在Linux和类Unix系统中用于挂载文件系统&#xff0c;即将存储设备…