【大模型】通俗解读变分自编码器VAE

news2025/1/11 11:52:48

目录

写在前面

一、VAE结构

二、损失函数

三、代码实现

1.训练代码

2.推理生成图片

3.插值编辑图片

四、总结


写在前面

        论文地址:https://arxiv.org/abs/1312.6114

        大模型已经有了突破性的进展,图文的生成质量都越来越高,可控性也越来越强。很多阅读大模型源码的小伙伴会发现,大部分大模型,尤其是CV模型都会用到一个子模型:变分自编码器(VAE),这篇文章就以图像生成为例介绍一下VAE,并且解释它问什么天生适用于图像生成。配合代码尽量做到通俗易懂。

        变分自编码器(VAE)是一种生成模型,旨在通过学习数据的潜在表示(Latent)来生成新样本。VAE 的训练目标是最大化变分下界,这意味着在学习潜在空间时,保持生成样本与真实数据的相似性,并尽量让潜在变量的分布接近标准正态分布。这样一来,模型就能有效地生成多样化的新图像。

        上面那段话似乎不容易理解,我用白话解释一遍。VAE 的最大作用是尽量简单的生成“能看的”图片。现在达到的效果是输入一段标准高斯分布的Latent,就能生成自然连贯的图像。而且生成的图像有如下三个特点:        

1.这个图像是全新的(也许跟某些训练数据相似);

2.通过编辑Latent可以一定程度上控制生成图像中的内容;

3.Latent空间中的结构化使得生成的图像自然且连贯,也就是说输入虽然是随机的,但输出是“能看的”,不是无意义的图像。

一、VAE结构

        VAE由如下三块组成:

        1.编码器(Encoder):输入数据通过编码器转换为潜在空间的分布。编码器通常由几层神经网络组成,输出潜在变量的均值和方差(其实是对数方差)。

        2.重参数化层(Reparameterize):从编码器输出的均值和方差中进行重参数化采样,生成潜在变量。这一过程使得模型能够在训练时进行反向传播。

        3.解码器(Decoder):解码器接收潜在变量并将其转换回原始数据的分布。解码器同样由神经网络组成,目的是重构输入数据。

        可以看到和AE相比,VAE的结构差别主要集中在编码器和潜在空间的处理。编码器有两个输出均值和方差(其实是对数方差);中间的重参数化层根据均值和方差重采样得到Latent,我们一般管他叫做z。

        下面我们使用MNIST数据集模拟一个VAE的结构,编码器和解码器使用最简单的全连接,Hidden维度400,Latent维度20,batch_size=128。

        可以看到,编码器的输出是两个128x20的特征图,用于重参数化;重参数化的输出是128x20,也就是每一个点都根据对应的均值和方差采样得来。

二、损失函数

        (VAE)的损失函数主要由两部分组成:

        1.重构损失(Reconstruction Loss):衡量模型生成的样本与原始输入之间的差异,通常使用均方误差(MSE)或二元交叉熵(Binary Cross-Entropy)作为度量。这部分确保生成的样本尽量忠实于输入数据。

        2.KL散度(Kullback-Leibler Divergence):衡量编码器输出的潜在分布与先验分布(通常是标准正态分布)之间的差异。目标是使得 q(z|x)逼近标准正态分布N(0,1),使得采样变得更加合理。

        重构损失没什么可说的,下面给出KL散度的公式:

D_{KL}(q(z|x)||p(z))=-0.5\cdot (1+log(\sigma ^2)-\mu ^2-\sigma ^2)

        KL散度代码实现:在代码实现的时候编码器的输出其实是均值mu和对数方差log_var,这一点在上图也能看出来:

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        其中log_var 对应对数方差log(\sigma^2),使用对数方差的形式可以保证数值稳定性、避免负值以及计算便利性,这种做法在许多深度学习模型中都得到了广泛应用,尤其是在处理概率分布时。;mu 是均值\mu\sigma^2=exp(log(\sigma^2)),在代码中就是log_var.exp()。

        KL散度会在下一篇文章详细介绍,这里到此为止。

三、代码实现

1.训练代码

        下面是训练的全部代码,很简单,没什么可说的,重点是重参数化层和损失函数中的KL散度。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # 输出均值和对数方差
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出为 [0, 1]
        )

    def encode(self, x):
        """
        编码器
        :param x:
        :return:
        """
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        return mu, log_var

    @staticmethod
    def reparameterize(mu, log_var):
        """
        重参数化
        :param mu:
        :param log_var:
        :return:
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """
        解码器
        :param z:
        :return:
        """
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


def loss_function(recon_x, x, mu, log_var):
    """
    重构损失和 KL 散度
    :param recon_x:
    :param x:
    :param mu:
    :param log_var:
    :return:
    """
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD


def train(model, train_loader, optimizer, epoch):
    """
    训练模型
    :param model:
    :param train_loader:
    :param optimizer:
    :param epoch:
    :return:
    """
    model.train()
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(data.size(0), -1)  # 展平输入
        optimizer.zero_grad()
        recon_batch, mu, log_var = model(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]: Loss: {loss.item()}')


# 超参数
input_dim = 28 * 28  # MNIST
hidden_dim = 400
latent_dim = 20
batch_size = 128
learning_rate = 1e-3
num_epochs = 200

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # 展平
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 初始化模型和优化器
model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(1, num_epochs + 1):
    train(model, train_loader, optimizer, epoch)
    if epoch % 20 == 0:
        # 保存模型
        torch.save(model.state_dict(), 'model_data/vae_mnist_{}.pth'.format(epoch))

2.推理生成图片

        下面是推理代码,理论上一个训练好的解码器,只需要标准高斯分布的随机噪声作为输入即可。我们来试一下,只使用解码器,输入是标准高斯分布的采样数据,输出是数字图片。

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出为 [0, 1]
        )

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        pass


def ran_demo():
    with torch.no_grad():
        z = torch.randn(64, latent_dim).to(device)  # 随机采样
        sample = model.decode(z).cpu()

    # 绘制生成的样本
    fig, axes = plt.subplots(8, 8, figsize=(8, 8))
    for i in range(64):
        axes[i // 8, i % 8].imshow(sample[i].view(28, 28), cmap='gray')
        axes[i // 8, i % 8].axis('off')
    plt.show()



if __name__ == '__main__':
    # 超参数
    input_dim = 28 * 28  # MNIST
    hidden_dim = 400
    latent_dim = 20
    # hidden_dim = 1024
    # latent_dim = 128
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 500
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 初始化模型和优化器
    model = VAE()

    # 加载模型并生成图像
    model.load_state_dict(torch.load('model_data/vae_mnist_1000.pth', map_location=torch.device('cpu')))
    # model.load_state_dict(torch.load('model_data/vae_mnist_200.pth', map_location=torch.device('cpu')))
    model.eval()

    # 随机输入
    ran_demo()

         输出结果如下:大部分是能看出来的数字的。毕竟只是一个简单的demo,就不要在意细节了。(#^.^#)

3.插值编辑图片

        下面玩一个有意思的,既然不同的Latent分布控制着不同的图像特征,那么我们试试把一个数字的Latent通过插值慢慢混入另一个数字的Latent,看看会发生什么。我们在数字6的Latent中慢慢混入7.

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # 输出均值和对数方差
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出为 [0, 1]
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        print(eps)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


def interpolate_demo(from_num, to_num):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))  # 展平
    ])
    dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    def interpolate(z1, z2, num_steps=10):
        return [(1 - alpha) * z1 + alpha * z2 for alpha in np.linspace(0, 1, num_steps)]

    # 找到数字“1”和“7”的潜在向量
    def get_latent_vector(digit):
        model.eval()
        with torch.no_grad():
            for data, labels in data_loader:
                if labels[0] == digit:
                    data = data.to(device)
                    mu, log_var = model.encode(data.view(-1, input_dim))
                    return mu.mean(0).cpu().numpy()  # 返回均值作为潜在向量
    # 获取两个数字的向量
    latent_1 = get_latent_vector(from_num)
    latent_7 = get_latent_vector(to_num)
    # 计算插值向量
    interpolated_latents = interpolate(latent_1, latent_7)
    # 使用解码器生成图像
    with torch.no_grad():
        generated_images = [model.decode(torch.tensor(latent).float().to(device)).view(28, 28).cpu().numpy() for latent
                            in interpolated_latents]

    # 可视化生成的图像
    fig, axs = plt.subplots(1, len(generated_images), figsize=(15, 3))
    for i, img in enumerate(generated_images):
        axs[i].imshow(img, cmap='gray')
        axs[i].axis('off')
    plt.show()


if __name__ == '__main__':
    # 超参数
    input_dim = 28 * 28  # MNIST
    hidden_dim = 400
    latent_dim = 20
    # hidden_dim = 1024
    # latent_dim = 128
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 500
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 初始化模型和优化器
    model = VAE()

    # 加载模型并生成图像
    model.load_state_dict(torch.load('model_data/vae_mnist_200.pth', map_location=torch.device('cpu')))
    model.eval()

    # 插值demo
    interpolate_demo(6, 7)

         可以看到数字6慢慢变成了数字7,中间的几张图既有6的特征又有7的特征。通过控制Latent确实可以控制输出图像的特征。那是不是也可以把一个人的脸慢慢变成另一个人的脸呢,我感觉可以试试。

四、总结

        1.与AE模型相比,VAE主要有两处修改:

        (1)编码器输出均值和方差(对数方差),经过重参数化层重采样后得到Latent,再进行解码;

        (2)损失函数加入了KL散度,衡量编码器输出的Latent分布与先验分布(通常是标准正态分布)之间的差异,同时起到正则化的目的,使码器输出的Latent分布尽量符合标准高斯分布。

        2.为什么VAE适合用在生成任务?

        (1)容易生成的“能看的”图像:解码器只需接受标准高斯分布的采样数据就能生成自然连贯的图像,这意味着我们不再为生成的图像过于抽象而烦恼;

        (2)生成图像的属性可以编辑:图像的各种属性特征都蕴含在Latent里,只要找到方法对齐并组合这些特征,我们就能控制输出图像的内容,比如:长着牛头的企鹅。这就是为什么当今很多生成模型吧VAE作为一个模块来使用,同时还需要配合其它模型来完成特定的生成任务,这点今天不做过多讨论。

        总之VAE极大推动了生成任务,是很有研究价值的,小伙伴们快玩起来吧。

        VAE就介绍到这,关注不迷路(*^__^*) 

  关注订阅号了解更多精品文章

交流探讨请加微信

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2177465.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

cesium渲染的3Dtiles的模型下载!glb模型文件下载!

前端开发测试或者学习cesium的时候最难最麻烦就是找到一个合适的模型,现在就直接给各位放几个可以满足我们测试使用的模型文件。 模型文件下载—香港3DTiles模型文件 某盘 通过百度网盘分享的文件:hk-效果图.png,hk.zip等2个文件 链接&…

react中的ref三种形式

1&#xff0c;字符串形式 <!-- 创建盒子 --><div id"test"></div> <script type"text/babel">class Demo extends React.Component{render(){return(<div><input type"text" refinput1 /><button onCl…

奔驰EQS450suv升级增强AR抬头显示HUD案例分享

以下是奔驰 EQS450 SUV 升级增强版 AR 抬头显示的一般改装案例步骤及相关信息&#xff1a; 配件&#xff1a;通常包括显示屏、仪表模块、饰板等。 安装步骤&#xff1a; 1. 拆下中控的仪表。 2. 在仪表上预留位置切割出合适的孔位&#xff0c;用于安装显示器。 3. 将显示器…

宝塔部署若依前端出现502解决方法

一、前言 ‌若依系统是一个基于Java语言的开源项目&#xff0c;旨在帮助开发者减少开发时间&#xff0c;特别适用于需要快速开发出一套具有用户管理、菜单管理、权限管理、定时任务、日志管理等功能的简单系统。‌ 系统分为前后端分离、分布式等架构 部署教程如下&#xff1a…

单体到微服务架构服务演化过程

架构服务化 聊聊从单体到微服务架构服务演化过程 单体分层架构 在 Web 应用程序发展的早期&#xff0c;大部分工程是将所有的服务端功能模块打包到单个巨石型&#xff08;Monolith&#xff09;应用中&#xff0c;譬如很多企业的 Java 应用程序打包为 war 包&#xff0c;最终会形…

软文代发高效率推广方式解析-华媒舍

在这个时代&#xff0c;软文代发成为了一种非常实用的推广方法。如何有效地开展软文代发营销推广&#xff0c;并不是每个人都知道的。下面我们就以高效软文代发推广方式大曝光为主线&#xff0c;为书友详细介绍科谱有关的内容。 一、什么叫软文代发 软文代发是指由企业或个人必…

引入 LangChain4j 来简化 LLM 与 Java 应用程序的集成

作者&#xff1a;来自 Elastic David Pilato LangChain4j 框架于 2023 年创建&#xff0c;其目标如下&#xff1a; LangChain4j 的目标是简化将 LLM 集成到 Java 应用程序的过程。 LangChain4j 提供了一种标准方法&#xff1a; 根据给定内容&#xff08;例如文本&#xff09;创…

VSCode编程配置再次总结

VScode 中C++编程再次总结 0.简介 1.配置总结 1.1 launch jsion文件 launch.json文件主要用于运行和调试的配置,具有程序启动调试功能。launch.json文件会启用tasks.json的任务,并能实现调试功能。 左侧任务栏的第四个选项运行和调试,点击创建launch.json {"conf…

AI变现N种方式,新手小白必看!【保姆级教程】

风口&#xff01;风口&#xff01;风口&#xff01; 终于不用再抱怨 “我们这代人啊&#xff0c;什么也没赶上” 因为我们现在正处于风口之上&#xff01; 在当今数字化的时代 AI 绘画正以惊人的速度崛起 并向各行各业渗透 既然阻止不了时代的变化 那就让它为我们所用 …

打造高业绩朋友圈:策略与实践

在数字化时代&#xff0c;朋友圈不仅是个人生活的展示窗口&#xff0c;更是商业变现的有力平台。许多人通过精心经营朋友圈&#xff0c;实现了财富的增长&#xff0c;甚至达到了年入百万的惊人业绩。朋友圈已成为普通人实现逆袭的重要战场。 要打造一个业绩过万的朋友圈&#…

微积分入门(真的很入门)

前置知识 前置知识&#xff1a;极限 我们要求 lim ⁡ x → 1 x 2 − 1 x − 1 \lim\limits_{x \to 1}\dfrac{x^2-1}{x-1} x→1lim​x−1x2−1​。 右边我们都知道是什么意思&#xff0c;那左边是什么呢&#xff1f; 意思就是&#xff0c;当 x x x 无限接近 1 1 1 时&…

Java IO 和 NIO

在 Java 编程中&#xff0c;输入输出&#xff08;IO&#xff09;是不可或缺的部分&#xff0c;随着技术的发展&#xff0c;Java 的 IO 系统也经历了显著的变化。本文将深入探讨 Java IO 和 NIO 的历史、优缺点以及适用场景。 1. Java IO 的历史 Java IO 包&#xff08;java.i…

JVM和GC监控技术

一、监控技术简介 JVM是什么&#xff1f;项目里面有JVM吗&#xff1f;JVM跟Tomcat有什么关系&#xff1f;为什么需要去分析JVM&#xff1f; 1. JVM(全称&#xff1a;Java Virtual Machine)&#xff0c;Java虚拟机 是Java程序运行的环境&#xff0c;它是一个虚构的计算机&…

盛世欢歌,共庆华诞!祝大家国庆节快乐!

举国同庆 盛世中华 盛世欢歌&#xff0c;共庆华诞&#xff01;在这美好的时光里&#xff0c;让我们一起欢庆国庆&#xff0c;感受祖国的强大和美好。数图祝大家国庆快乐&#xff01; 国庆来临之际&#xff0c;根据国家有关规定&#xff0c;现将2024年国庆放假安排通知如下&…

JVM(HotSpot):虚拟机栈(JVM Stacks)与本地方法栈(Native Method Stacks)

文章目录 一、内存结构图二、数据结构-栈三、JVM栈四、本地方法栈五、问题辨析1、垃圾回收是否涉及栈内存&#xff1f;2、栈内存越大越好吗&#xff1f;3、方法内的局部变量是否线程安全&#xff1f;4、栈内存溢出问题 一、内存结构图 二、数据结构-栈 数据结构中&#xff0c;…

windows 系统服务在注册表中的位置

计算机\HKEY_LOCAL_MACHINE\SYSTEM\ControlSet001\Services 此注册表项下是系统服务安装信息 利用此注册表项可以获取服务详细信息

新版Android Studio Koala 导入github第三方依赖 maven仓库的处理方法 (java版)

以下是依赖的处理 这是由于Android Studio 构建项目模式发生改变了。 旧版项目构造 创建新的项目采用build.gradle.kts配置。 先看旧版同样的配置是什么样的。 再来查看新版带.kts后缀文件官方自带的库是怎么配置&#xff0c;模拟配置就OK。 先看libs文件这个库的写法。 …

隐藏SpringBoot自动生成的文件

第一种方法——删除 第二种方法——Settings——Editor——fail types

题库系统平台开发功能解析

题库系统开发功能介绍可以从多个方面进行阐述&#xff0c;以下是一些核心功能及其详细解释 1. 题库管理系统 题目录入与编辑&#xff1a;提供灵活的题目录入方式&#xff0c;支持手动输入、批量导入&#xff08;如从Excel、Word等文件中导入&#xff09;以及从其他题库中复制试…

HuggingChat macOS版正式发布!文章内附体验地址!我国打造糖尿病专用AI模型|AI日报

文章推荐 全新豆包AI视频模型发布&#xff01;实测下的可灵与豆包&#xff01;原来它们的差距不止一点点... 今日热点 我国团队打造糖尿病专用AI模型 上海交通大学清源研究院MIFA实验室携手复旦大学附属中山医院内分泌科&#xff0c;组建专家团队&#xff0c;联手开发一款名…