变分自编码器(VAE)PyTorch Lightning 实现

news2024/11/27 19:49:09

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • VAE 简介
      • 基本原理
      • 应用与优点
      • 缺点与挑战
    • 使用 VAE 生成 MNIST 手写数字
      • 忽略警告
      • 导入必要的库
      • 设置随机种子
      • cuDNN 设置
      • 超参数设置
      • 数据加载
      • 定义 VAE 模型
      • 定义损失函数
      • 定义 Lightning 模型
      • 训练模型
      • 绘制训练过程
      • 随机生成新样本
      • 根据潜变量插值生成新样本


VAE 简介

变分自编码器(Variational Autoencoder,VAE)是一种深度学习中的生成模型,它结合了自编码器(Autoencoder, AE)和概率建模的思想,在无监督学习环境中表现出了强大的能力。VAE 在 2013 年由 Diederik P. Kingma 和 Max Welling 首次提出,并迅速成为生成模型领域的重要组成部分。

基本原理

自编码器(AE)基础:
自编码器是一种神经网络结构,通常由两部分组成:编码器(Encoder)和解码器(Decoder)。原始数据通过编码器映射到一个低维的潜在空间(或称为隐空间),这个低维向量被称为潜变量(latent variable)。然后,潜变量再通过解码器重构回原始数据的近似版本。在训练过程中,自编码器的目标是使得输入数据经过编码-解码过程后能够尽可能地恢复原貌,从而学习到数据的有效表示。

VAE的引入与扩展:
VAE 将自编码器的概念推广到了概率框架下。在 VAE 中,潜变量不再是确定性的,而是被赋予了概率分布。具体来说,对于给定的输入数据,编码器不直接输出一个点估计值,而是输出潜变量的均值和方差(假设潜变量服从高斯分布)。这样,每个输入数据可以被视为是从某个潜在的概率分布中采样得到的。

变分推断(Variational Inference):
训练 VA E时,由于真实的后验概率分布难以直接计算,因此采用变分推断来近似后验分布。编码器实际上输出的是一个参数化的概率分布 q ( z ∣ x ) q(z|x) q(zx),即给定输入 x x x 时潜变量 z z z 的概率分布。然后通过最小化 KL 散度(Kullback-Leibler divergence)来优化这个近似分布,使其尽可能接近真实的后验分布 p ( z ∣ x ) p(z|x) p(zx)

目标函数 - Evidence Lower Bound (ELBO):
VAE 的目标函数是证据下界(ELBO),它是原始数据 log-likelihood 的下界。优化该目标函数既鼓励编码器找到数据的高效潜在表示,又促使解码器基于这些表示重建出类似原始数据的新样本。

数学表达上,ELBO 通常分解为两个部分:

  1. 重构损失(Reconstruction Loss):衡量从潜变量重构出来的数据与原始数据之间的差异。
  2. KL散度损失(KL Divergence Loss):衡量编码器产生的潜变量分布与预设的标准正态分布(或其他先验分布)之间的距离。

应用与优点

  • VAE 可以用于生成新数据,例如图像、文本、音频等。
  • 由于其对潜变量进行概率建模,所以它可以提供连续的数据生成,并且能够探索数据的不同模式。
  • 在处理连续和离散数据时具有一定的灵活性。
  • 可以用于特征学习,提取数据的有效低维表示。

缺点与挑战

  • 训练 VAE 可能需要大量的计算资源和时间。
  • 生成的样本有时可能不够清晰或细节模糊,尤其是在复杂数据集上。
  • 对于某些复杂的分布形式,VAE 可能无法完美捕获所有细节。

使用 VAE 生成 MNIST 手写数字

下面我们将使用 PyTorch Lightning 来实现一个简单的 VAE 模型,并使用 MNIST 数据集来进行训练和生成。

在线 Notebook:https://www.kaggle.com/code/marquis03/vae-mnist

忽略警告

import warnings
warnings.filterwarnings("ignore")

导入必要的库

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

设置随机种子

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

cuDNN 设置

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

超参数设置

batch_size = 64

epochs = 10
KLD_weight = 1
lr = 0.001

input_dim = 784  # 28 * 28
h_dim = 256  # 隐藏层维度  
z_dim = 2  # 潜变量维度

数据加载

train_dataset = datasets.MNIST(root="data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

定义 VAE 模型

class VAE(nn.Module):
    def __init__(self, input_dim=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()

        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder
        self.fc1 = nn.Linear(input_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # mu
        self.fc22 = nn.Linear(h_dim, z_dim)  # log_var

        # Decoder
        self.fc3 = nn.Linear(z_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        mean = self.fc21(h)
        log_var = self.fc22(h)
        return mean, log_var

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = torch.relu(self.fc3(z))
        out = torch.sigmoid(self.fc4(h))
        return out

    def forward(self, x):
        mean, log_var = self.encode(x)
        z = self.reparameterize(mean, log_var)
        reconstructed_x = self.decode(z)
        return reconstructed_x, mean, log_var

vae = VAE(input_dim, h_dim, z_dim)
x = torch.randn((10, input_dim))
reconstructed_x, mean, log_var = vae(x)
print(reconstructed_x.shape, mean.shape, log_var.shape)
# torch.Size([10, 784]) torch.Size([10, 2]) torch.Size([10, 2])

定义损失函数

def loss_function(x_hat, x, mu, log_var, KLD_weight=1):
    BCE_loss = F.binary_cross_entropy(x_hat, x, reduction="sum") # 重构损失
    KLD_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL 散度损失
    loss = BCE_loss + KLD_loss * KLD_weight
    return loss, BCE_loss, KLD_loss

定义 Lightning 模型

class LitModel(pl.LightningModule):
    def __init__(self, input_dim=784, h_dim=400, z_dim=20):
        super().__init__()
        self.model = VAE(input_dim, h_dim, z_dim)

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(), lr=lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        reconstructed_x, mean, log_var = self(x)
        loss, BCE_loss, KLD_loss = loss_function(reconstructed_x, x, mean, log_var, KLD_weight=KLD_weight)
        self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(
            {
                "BCE_loss": BCE_loss,
                "KLD_loss": KLD_loss,
            },
            on_step=False,
            on_epoch=True,
            logger=True,
        )
        return loss
    
    def decode(self, z):
        out = self.model.decode(z)
        return out

训练模型

model = LitModel(input_dim, h_dim, z_dim)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="loss", min_delta=0.00, patience=5, verbose=False, mode="min")
trainer = pl.Trainer(
    max_epochs=epochs,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[early_stop_callback],
)
trainer.fit(model, train_loader)

绘制训练过程

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"

plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="loss", data=metrics, label="Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="BCE_loss", data=metrics, label="BCE Loss", linewidth=2, marker="^", markersize=12)
sns.lineplot(x=x_name, y="KLD_loss", data=metrics, label="KLD Loss", linewidth=2, marker="s", markersize=10)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()

训练过程

随机生成新样本

row, col = 4, 18
z = torch.randn(row * col, z_dim)
random_res = model.model.decode(z).view(-1, 1, 28, 28).detach().numpy()

plt.figure(figsize=(col, row))
for i in range(row * col):
    plt.subplot(row, col, i + 1)
    plt.imshow(random_res[i].squeeze(), cmap="gray")
    plt.xticks([])
    plt.yticks([])
    plt.axis("off")
plt.show()

随机生成新样本

根据潜变量插值生成新样本

from scipy.stats import norm

n = 15
digit_size = 28

grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        t = [xi, yi]
        z_sampled = torch.FloatTensor(t)
        with torch.no_grad():
            decode = model.decode(z_sampled)
            digit = decode.view((digit_size, digit_size))
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()

根据潜变量插值生成新样本

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1450755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Tcl 过程

一个Tcl过程就是Tcl脚本定义的一个命令。可以使用proc命令定义新的过程。Tcl还提供了处理变量作用域的特殊命令,这些命令允许使用引用而非值传递参数,并能把新的Tcl控制结构实现为过程。 一、proc与return 过程由proc命令创建, 其中参数{a b} 中的大括…

【Win10 触摸板】在插入鼠标时禁用触摸板,并在没有鼠标时自动启用触摸板。取消勾选连接鼠标时让触摸板保持打开状态,但拔掉鼠标后触摸板依旧不能使用

出现这种问题我的第一反应就是触摸板坏了,但是无意间我换了一个账户发现触摸板可以用,因此推断触摸板没有坏,是之前的账户问题,跟系统也没有关系,不需要重装系统。 解决办法:与鼠标虚拟设备有关 然后又从知…

【解决(几乎)任何机器学习问题】:超参数优化篇(超详细)

这篇文章相当长,您可以添加至收藏夹,以便在后续有空时候悠闲地阅读。 有了优秀的模型,就有了优化超参数以获得最佳得分模型的难题。那么,什么是超参数优化呢?假设您的机器学习项⽬有⼀个简单的流程。有⼀个数据集&…

Vulnhub靶机:DC5

一、介绍 运行环境:Virtualbox 攻击机:kali(10.0.2.15) 靶机:DC5(10.0.2.58) 目标:获取靶机root权限和flag 靶机下载地址:https://download.vulnhub.com/dc/DC-5.zi…

Rust 数据结构与算法:4栈:用栈实现进制转换

2、进展转换 将十进制数转换为二进制表示形式的最简单方法是“除二法”&#xff0c;可用栈来跟踪二进制结果。 除二法 下面实现一个将十进制数转换为二进制或十六进制的算法&#xff0c;代码如下&#xff1a; #[derive(Debug)] struct Stack<T> {size: usize, // 栈大…

SG7050EEN晶体振荡器SPXO规格书

频率范围:25mhz ~ 500mhz供电电压:2.5 V型。/ 3.3 V类型功能:输出使能(OE)外形尺寸:7.0 5.0 1.5 mm输出:LV-PECL低相位抖动:50fs型。(f0 156.25 MHz)工作温度:-40℃~ 105℃ 规范

二叉树相关OJ题

创作不易&#xff0c;感谢三连&#xff01;&#xff01; 一、选择题 1、某二叉树共有 399 个结点&#xff0c;其中有 199 个度为 2 的结点&#xff0c;则该二叉树中的叶子结点数为&#xff08; &#xff09; A.不存在这样的二叉树 B.200 C.198 D.199解析&#xff1a;选B&…

致创新者:聚焦目标,而非问题

传统的企业创新管理方式常常导致组织内部策略不协调、流程低效、创新失败率高等问题。而创新运营作为企业管理创新的新模式&#xff0c;通过整合文化、实践、人员和工具&#xff0c;提高组织创新能力。已经采用创新运营的公司报告了一系列积极的结果&#xff0c;如市场推出速度…

最长连续手牌 - 华为OD统一考试

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 200分 题解&#xff1a; Java / Python / C 题目描述 有这么一款单人卡牌游戏&#xff0c;牌面由颜色和数字组成&#xff0c;颜色为红、黄、蓝、绿中的一种&#xff0c;数字为 0−9 中的一个。游戏开始时玩家从手牌中…

力扣hot3--并查集+哈希

第一想法是排个序然后遍历一遍&#xff0c;but时间复杂度就超啦 并查集居然与哈希结合了&#xff08;&#xff09; 已经好久没用过并查集了&#xff0c;&#xff0c;&#xff0c;我们用哈希表f_node中来记录原结点的父节点&#xff0c;其中key是原结点&#xff0c;value是父节点…

OpenHarmony—UIAbility组件生命周期

概述 当用户打开、切换和返回到对应应用时&#xff0c;应用中的UIAbility实例会在其生命周期的不同状态之间转换。UIAbility类提供了一系列回调&#xff0c;通过这些回调可以知道当前UIAbility实例的某个状态发生改变&#xff0c;会经过UIAbility实例的创建和销毁&#xff0c;…

代码随想录算法训练营第二十七天|贪心算法理论基础,455.分发饼干,376. 摆动序列,53. 最大子序和

系列文章目录 代码随想录算法训练营第一天|数组理论基础&#xff0c;704. 二分查找&#xff0c;27. 移除元素 代码随想录算法训练营第二天|977.有序数组的平方 &#xff0c;209.长度最小的子数组 &#xff0c;59.螺旋矩阵II 代码随想录算法训练营第三天|链表理论基础&#xff…

抖音私信自动回复工具使用教程!

该工具基于网页版抖音&#xff0c;可以用于抖音个人号等任何权限的账号&#xff01; 获取软件 联系我的v 信&#xff1a;llike620 基本使用 了解GPT的&#xff0c;可以配置FastGPT这种训练知识库的AI进行回复 不了解的&#xff0c;可以配置关键词回复 点击抖音私信按钮&a…

Linux设置jar包开机自启动

步骤 1、新建jar包自启文件 sudo vi /etc/init.d/jarSysInit.sh 按i键进入编辑模式输入以下内容&#xff1a; export JAVA_HOME/home/jdk/jdk-11.0.22 export CLASSPATH.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar:$JAVA_HOME/jre/lib/rt.jar export PATH$PATH:$JAVA_…

H5大气的互联网建站服务公司静态HTML网站模板源码

H5大气的互联网建站服务公司静态HTML网站模板源码 源码介绍&#xff1a;一款大气的互联网建站服务公司/工作室静态HTML网站模板&#xff0c;带有多个单页&#xff0c;可自行二开作为工作室或公司官网。 下载地址&#xff1a; https://www.changyouzuhao.cn/13456.html

人力资源智能化管理项目(day08:云存储)

学习源码可以看我的个人前端学习笔记 (github.com):qdxzw/humanResourceIntelligentManagementProject 存储桶列表 &#xff1a;登录 - 腾讯云 API密钥管理&#xff1a;登录 - 腾讯云 上传图片-创建腾讯云存储桶 存储桶名称&#xff1a;intelligentmanagement-1306913843 地…

【Qt】环境安装与初识

目录 一、Qt背景介绍 二、搭建Qt开发环境 三、新建工程 四、Qt中的命名规范 五、Qt Creator中的快捷键 六、QWidget基础项目文件详解 6.1 .pro文件解析 6.2 widget.h文件解析 6.3 widget.cpp文件解析 6.4 widget.ui文件解析 6.5 main.cpp文件解析 七、对象树 八、…

Rust中不可变变量与const有何区别?

Rust作者认为变量默认应该是immutable&#xff0c;即声明后不能被改变的变量。这一点是让跨语言学习者觉得很别扭&#xff0c;不过这一点小的改变带来了诸多好处&#xff0c;本节我们来学习Rust的变量。 什么是变量&#xff1f; 如果你初次学习编程语言&#xff0c;变量会是一…

AlmaLinux更换鼠标样式为Windows样式

文章目录 前言先看看条件与依赖第一步&#xff1a;测试最终效果第二步&#xff1a;使用CursorXP修改鼠标样式CurosrXP安装CursorXP使用 第三步&#xff1a;Linux端环境搭建与命令执行UbuntuFedora其他系统均失败 第四步&#xff1a;应用主题 前言 只不过是突发奇想&#xff0c…

相机图像质量研究(17)常见问题总结:CMOS期间对成像的影响--靶面尺寸

系列文章目录 相机图像质量研究(1)Camera成像流程介绍 相机图像质量研究(2)ISP专用平台调优介绍 相机图像质量研究(3)图像质量测试介绍 相机图像质量研究(4)常见问题总结&#xff1a;光学结构对成像的影响--焦距 相机图像质量研究(5)常见问题总结&#xff1a;光学结构对成…