深入解析与实现:变分自编码器(VAE)完整代码详解

news2024/10/5 14:35:41

VAE理论上一篇已经详细讲完了,虽然VAE已经是过去的东西了,但是它对后面强大的生成模型是很有指导意义的。接下来,我们简单实现一下其代码吧。

1 VAE在minist数据集上的实现

完整的代码如下,没有什么特别好讲的。

import cv2
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

""" 就用线性层构造最简单的vae吧"""


class VAE(nn.Module):
    def __init__(self, image_size=28*28, hidden1=400, hidden2=100, latent_dims=40):
        super().__init__()

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(image_size, hidden1),
            nn.ReLU(),
            nn.Linear(hidden1, hidden2),
            nn.ReLU(),
        )
        self.mu = nn.Sequential(
            nn.Linear(hidden2, latent_dims),
        )

        self.logvar = nn.Sequential(
            nn.Linear(hidden2, latent_dims),
        )   # 由于方差是非负的,因此预测方差对数

        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dims, hidden2),
            nn.ReLU(),
            nn.Linear(hidden2, hidden1),
            nn.ReLU(),
            nn.Linear(hidden1, image_size),
            nn.Tanh()
        )

    # 重参数,为了可以反向传播
    def reparametrization(self, mu, logvar):
        # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
        std = 0.5 * torch.exp(logvar)
        # N(mu, std^2) = N(0, 1) * std + mu
        z = torch.randn(std.size(), device=mu.device) * std + mu
        return z

    def forward(self, x):
        en = self.encoder(x)
        mu = self.mu(en)
        logvar = self.logvar(en)
        z = self.reparametrization(mu, logvar)

        return self.decoder(z), mu, logvar


def loss_function(fake_imgs, real_imgs, mu, logvar):

    kl = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)
    reconstruction = ((real_imgs - fake_imgs)**2).sum()

    return kl, reconstruction


def train(num_epoch):

    write_fake = SummaryWriter(f'logs/fake')

    device = torch.device("cuda:0")

    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)

    vae = VAE().to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)

    vae.train()
    step = 0
    for epoch in range(num_epoch):
        for batch_indx, (inputs, _) in enumerate(trainloader):

            inputs = inputs.to(device)

            real_imgs = torch.flatten(inputs, start_dim=1)

            fake_imgs, mu, logvar = vae(real_imgs)

            loss_kl, loss_re = loss_function(fake_imgs, real_imgs, mu, logvar)

            loss_all = loss_kl + loss_re

            optimizer.zero_grad()
            loss_all.backward()
            optimizer.step()

            print(f"epoch:{epoch}, loss kl:{loss_kl.item()}, loss re:{loss_re.item()}, loss all:{loss_all.item()}")
            if batch_indx == 0:
                with torch.no_grad():
                    x = torch.randn((32, 40)).to(device)
                    fake = vae.decoder(x).reshape(-1, 1, 28, 28)
                    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)

                    write_fake.add_image(
                        "Mnist Fake Image", img_grid_fake, global_step=step
                    )
                    step += 1


if __name__ == "__main__":
    summary(VAE(), input_size=(1, 784))
    train(1000)

模型结构打印如下:
VAE [1, 784] –
├─Sequential: 1-1 [1, 100] –
│ └─Linear: 2-1 [1, 400] 314,000
│ └─ReLU: 2-2 [1, 400] –
│ └─Linear: 2-3 [1, 100] 40,100
│ └─ReLU: 2-4 [1, 100] –
├─Sequential: 1-2 [1, 40] –
│ └─Linear: 2-5 [1, 40] 4,040
├─Sequential: 1-3 [1, 40] –
│ └─Linear: 2-6 [1, 40] 4,040
├─Sequential: 1-4 [1, 784] –
│ └─Linear: 2-7 [1, 100] 4,100
│ └─ReLU: 2-8 [1, 100] –
│ └─Linear: 2-9 [1, 400] 40,400
│ └─ReLU: 2-10 [1, 400] –
│ └─Linear: 2-11 [1, 784] 314,384
│ └─Tanh: 2-12 [1, 784] –

训练结果,从结果上来看,是不如GAN的,主要原因在于其在KL散度和重建损失之间很难做到平衡,所以很难训练得好,当然原因是多方面的。
在这里插入图片描述

2 VAE的缺陷

变分自编码器(VAE, Variational Autoencoder)作为一种强大的深度学习模型,在生成建模领域有着广泛的应用,但它也存在一些缺陷,主要包括:

  • 生成样本质量:与生成对抗网络(GANs)相比,VAE生成的样本可能显得较为模糊或缺乏清晰度。尽管VAE能够生成连续且有结构的潜在空间,其生成的样本在某些情况下可能不够真实或细节不够丰富。

  • 潜在空间的连续性问题:虽然VAE设计用于学习连续的潜在空间,以允许插值和生成流畅的变化序列,但在实践中,这种连续性可能不如理论中那样完美。潜在空间中可能会出现空洞或不连贯区域,影响样本生成的质量和连续性变换的效果。

  • KL散度的平衡问题:VAE通过在其损失函数中加入KL散度项来约束潜在变量的分布,以确保它接近先验分布(通常是标准正态分布)。然而,KL散度的权重难以选择,如果设置不当,可能导致模型过分关注重构损失而忽视了潜在空间的平滑性和多样性,或者相反。

  • 训练难度与稳定性:VAE的训练过程比一些其他模型更为复杂,涉及到优化 Evidence Lower Bound (ELBO),这可能导致训练过程较为不稳定,需要更多的计算资源和更长的训练时间。特别是优化过程中对似然的近似以及对数似然的下界处理增加了训练的复杂度。

  • 表达能力与模型容量:由于VAE的编码器和解码器结构相对简单(通常为全连接层或简单的卷积层),在处理高度复杂的高维数据时,其表达能力可能受限,影响生成样本的质量和多样性。

这些缺陷提示研究者和实践者在使用VAE时需要仔细调整模型架构、损失函数的平衡以及训练策略,以最大化其生成能力和实用性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1700775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机毕业设计】安卓054基于Android校园助手

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

渗透测试的测试流程与注意事项

软件测试流程 渗透测试是一种重要的软件测试技术,通过对系统进行模拟攻击和漏洞评估,帮助组织发现和修复潜在的安全风险,提高系统的安全性和稳定性。在进行渗透测试时,需要注意合法授权、技术能力、安全意识和报告质量等方面的问…

基于springboot实现华府便利店信息管理系统项目【项目源码+论文说明】

基于springboot实现华府便利店信息管理系统演示 摘要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本华府便利店信息管理系统就是在这样的大环境下诞生&#xff…

VMware安装保姆教程、Docker安装/依赖安装缓慢等问题

常见问题前置: 1、docker依赖安装缓慢,没有走设置的资源库:解决安装docker-ce过慢 Operation too slow. Less than 1000 bytes/sec transferred the last 30 seconds‘) 在添加阿里云镜像后安装依旧慢: yum-config-manager --add-repo http://mirrors.aliyun.com/docker…

移动云:连接未来的智慧之旅

随着数字化转型的加速,云服务在各行各业中的应用越来越广泛。移动云不仅提供了灵活的计算和存储资源,还通过创新的技术手段,为企业和开发者解决了许多实际问题。在这个变革的大背景下,移动云服务作为中国移动倾力打造的云业务品牌…

计算机网络6——应用层5 DHCP/SNMP

文章目录 一、动态主机配置协议 DHCP二、简单网络管理协议 SNMP1、网络管理的基本概念2、管理信息结构SMI1)被管对象的命名2)被管对象的数据类型3)编码方法 3、管理信息库 MIB4、SNMP的协议数据单元和报文 一、动态主机配置协议 DHCP 为了把…

产品经理-交互说明撰写(八)

1. 交互说明 交互说明可以看做是交互设计师或者产品经理输出的最核心的”产品“交互说明面向的”用户“是下游的同事 ⇒ UI设计师、开发工程师、测试工程师 2. 基本交互形式 2.1 页面交互 2.2 元素控件交互 3. 交互说明主要包括以下3个维度 3.1 页面流程(页面之…

unity制作app(10)--统一字体

1.载入字体,微软雅黑,需要3分钟左右 加载进来3个 2.font文件夹下创建一个txt,内部的内容如下: 啊阿埃挨哎唉哀皑癌蔼矮艾碍爱隘鞍氨安俺按暗岸胺案肮昂盎凹敖熬翱袄傲奥懊澳芭捌扒叭吧笆八疤巴拔跋靶把耙坝霸罢爸白柏…

linux系统更改SSH端口号配置

1.编写sshd.config cd /etc/ssh sudo cp sshd_config sshd_config.bak vim sshd_config2.重启服务 systemctl restart sshd 结束!!

第22讲:RBD块存储COW克隆解除父子镜像的依赖关系

RBD块存储COW克隆解除父子镜像的依赖关系 1.COW镜像克隆存在的依赖关系 在前面使用copy-on-write机制基于快照做出来的链接克隆,与快照依赖性很强,如果快照损坏或者丢失,那么克隆的镜像将无法使用,使用这个镜像创建的虚拟机也会…

【详细介绍WebKit的结构】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

记住这三个神仙代码,时刻为你的电脑保驾护航

在这个数字化飞速发展的时代,我们的电脑不仅存储着重要的个人信息,还承载着繁重的工作任务。如何确保电脑的安全与稳定运行,成为了一个至关重要的问题。今天小编给大家分享这三个神仙代码,记好了这三个代码,时刻为你的…

如何修改 Miyoo Mini + 中的键位 Onion OS

如何修改 Miyoo Mini 中的键位 Onion OS MiyooMini 键位跟 XBox 键位不同 MiyooMini 买来之后就发现键位跟 XBox 手柄的键位不同。 玩 FC 游戏的时候也非常别扭,跟我以前玩 FC 游戏时的键位非常不同,正好AB XY 调换过来了。 实物如下: …

阿里云产品DTU评测报告(二)

阿里云产品DTU评测报告(二) 问题回顾问题处理继续执行 问题回顾 基于上一次DTU评测,在评测过程中遇到了windows系统情况下执行amp命令失败的情况,失败情况如图 导致后续命令无法执行,一时之间不知如何处理&#xff0…

人力资源管理信息化系统如何支持企业开展管理诊断?

华恒智信人力资源顾问有限公司致力于帮助企业开展人力资源管理方面的各项提升改进工作,在长期的咨询工作中,最常听到企业提到的问题莫过于管理诊断方面的问题,事实上,很多企业在日常工作中,都意识到企业内部存在管理方…

AC/DC电源模块:提供高质量的电力转换解决方案

BOSHIDA AC/DC电源模块:提供高质量的电力转换解决方案 AC/DC电源模块是一种电力转换器件,可以将交流电转换为直流电。它通常用于各种电子设备和系统中,提供高质量的电力转换解决方案。 AC/DC电源模块具有许多优点。首先,它能够提…

5.4 Go 匿名函数与闭包

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

2024/5/27 英语每日一段

Rejecting diet culture became something of a feminist cause. “A growing number of women are joining in an anti-diet movement,” The New York Times reported in 1992. “They are forming support groups and ceasing to diet with a resolve similar to that of se…

一篇文章带你快速搞定Kafka术语no.2

在Kafka的世界中有很多概念和术语是需要你提前理解并熟练掌握的,这对于后面你深入学习Kafka各种功能和特性将大有裨益。下面我来盘点一下Kafka的各种术语。 在专栏的第一期我说过Kafka属于分布式的消息引擎系统,它的主要功能是提供一套完备的消息发布与…

Spring Boot 系统学习第三天:Spring依赖注入原理分析

1.概述 Spring中关于依赖注入的代码实现非常丰富,涉及大量类和组件之间的协作与交互。从原理上讲,任何一个框架都存在一条核心执行流程,只要抓住这条主流程,就能把握框架的整体代码结构,Spring也不例外。无论采用何种依…