pytorch-AutoEncoders实战之VAE

news2024/9/20 13:28:56

目录

  • 1. VAE回顾
  • 2. KL的计算公式
  • 3. 构建网络
  • 4. 模型训练

1. VAE回顾

VAE = Variational Auto Encoder,变分自编码器。是一种常见的生成模型,属于无监督学习的范畴。它能够学习一个函数/模型,使得输出数据的分布尽可能的逼近原始数据分布,其基本思路是:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个VAE模型.

下图中的公式,前半部分计算的是重建误差,可以理解为MSE或者是Cross Entropy,而后半部分KL是散度的公式,主要是计算q分布与p分布的相似度。
那么公式的目标就是重建误差越小越好,q和p的分布越接近越好。
在这里插入图片描述

2. KL的计算公式

reparametrize trick
在这里插入图片描述

在这里插入图片描述
按照上图推导公式实现即可。

3. 构建网络

根据公式可以知道,前半部分计算的是重建误差,后半部分是KL,再根据reparametrize trick分别计算z和epison~N(0, 1)
先将encode的[b,20],切分为两个[b,10]分别作为μ和σ,通过μ和σ计算z值,代码如下:

 mu, sigma = h_.chunk(2, dim=1)
 # reparametrize trick, epison~N(0, 1)
 z = mu + sigma * torch.randn_like(sigma)

计算KL,根据2中的推导公式写代码即可,代码中batchsz2828意思是计算像素级的kld,1e-8是防止log函数变量为0时,趋于无穷大,这里起到限幅的作用

kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

完整代码:

import  torch
from    torch import nn

class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

4. 模型训练

与上一篇的AutoEncoders步骤相近,这里不再详述

import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets

from    ae import AE
from    vae import VAE

import  visdom

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = iter(mnist_train).next()
    print('x:', x.shape)

    device = torch.device('cuda')
    # model = AE().to(device)
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2149026.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CCRC-CDO首席数据官:未成年人首次上网年龄持续降低

近日,中国社会科学院新闻与传播研究所联合社会科学文献出版社发布了《青少年蓝皮书:中国未成年人互联网运用报告(2024)》,该报告对中国未成年人的互联网使用情况进行了全面的研究和专项汇报。 调查数据透露,未成年人接触网络的年…

光耦选型 | 充电领域使用光耦型号推荐——晶体管光耦KL3H7

在充电领域,光耦作为一种常见的光电耦合器件,通常用于电气隔离、信号传输、电池保护和充电控制等方面。 电源气隔离:光耦可用于实现电源气隔离,将输入和输出电路进行隔离,提高系统的安全性和稳定性。 信号传输&#…

0基础也可以转行做产品经理吗?

转行成为产品经理,即使没有相关工作经验或技术背景,仍然是一个可行的目标。产品经理的职责多样,但成功的产品经理通常需要具备一系列的技能和素养,包括项目管理、市场分析、用户体验设计等。在没有相关经验的情况下,通…

动手学深度学习PyTorch 第 1 章 引言

在线电子书 深度学习介绍 安装 使用conda环境 conda create -n d2l-zh python3.8 pip安装需要的包 pip install jupyter d2l torch torchvision下载代码并执行 wget https://zh-v2.d2l.ai/d2l-zh.zip unzip d2l-zh.zip jupyter notebookpip install rise如果不想使用jupyt…

建筑企业有闲置资质怎么办?

在建筑行业中,企业可能会因为业务调整、市场变化或战略转型而拥有一些不再使用的资质。这些闲置的资质如果得不到合理处理,不仅会造成资源浪费,还可能影响企业的合规性。因此,建筑企业在面对闲置资质时,需要采取合适的…

《微信小程序实战(3) · 推广海报制作》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

C++ 9.19

练习&#xff1a;要求在堆区申请5个double类型的空间&#xff0c;用于存储5名学生的成绩。请自行封装函数完成 1> 空间的申请 2> 学生成绩的录入 3> 学生成绩的输出 4> 学生成绩进行降序排序 5> 释放申请的空间 主程序中用于测试上述函数 #include<ios…

选址模型 | 基于混沌模拟退火粒子群优化算法的电动汽车充电站选址与定容(Matlab)

目录 效果一览基本介绍程序设计参考资料 效果一览 基本介绍 基于混沌模拟退火粒子群优化算法的电动汽车充电站选址与定容&#xff08;Matlab&#xff09; 问题建模&#xff1a;首先&#xff0c;需要将电动汽车充电站选址与定容问题进行数学建模&#xff0c;确定目标函数和约束…

双指针 -- 移动零、复写零、快乐数

目录 移动零 题解&#xff1a; 复写零 题解&#xff1a; 快乐数 题解&#xff1a; 盛最多水的容器 移动零 283. 移动零 - 力扣&#xff08;LeetCode&#xff09;https://leetcode.cn/problems/move-zeroes/description/ 题解&#xff1a; 题目要求我们把数组中的 0 放…

大数据新视界 --大数据大厂之DevOps与大数据:加速数据驱动的业务发展

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

ReentrantLock实现原理

ReentrantLock是基于AQS实现的可重入锁&#xff0c;比synchronized更灵活&#xff0c;可以设置超时时间&#xff0c;可以中断&#xff0c;支持公平和非公平锁两种方式。公平锁获取锁的方式&#xff1a; 主要就是这三步 第一步&#xff1a;tryAcquire 先尝试获得锁 先获取state&…

Flux【真人模型】:高p高糊反向真实质感!网图风格的Lora模型,超逼真的AI美女大模型!

大家好&#xff0c;我是画画的小强 今天和大家分享一款基于Flux训练的网图风格的lora模型&#xff1a;墨幽-F.1-Lora-网图&#xff0c;该Lora模型由墨幽团队出品&#xff0c;旨在生成高p高糊的反向真实质感图片&#xff0c;而非真实摄影图片。不过&#xff0c;在自己出图过程中…

Linux操作系统 进程(3)

接上文 Linux进程优先级之后&#xff0c;我们了解到僵尸进程与孤儿进程的形成原因&#xff0c;既然是因为父进程没有接收子进程的退出状态导致的&#xff0c;那么我们该如何去获取子进程的退出状态呢&#xff1f;那本篇文章将围绕这个问题来解释进程。 环境 &#xff1a; vsco…

更高效的搜索工具,国内免费好用的AI智能搜索引擎工具

搜索引擎是我们获取信息的重要渠道&#xff0c;然而由于搜索引擎搜索结果存在较多的广告以及一些无关内容&#xff0c;这使我们的搜索效率变得更低效。小编就和大家分享几款国内免费好用的AI智能搜索工具&#xff0c;提高搜索效率。 1.开搜AI搜索 开搜AI搜索是一款基于深度学…

【学术会议:中国杭州,机器学习和计算机应用面临的新的挑战问题和研究方向】第五届机器学习与计算机应用国际学术会议(ICMLCA 2024)

您的学术研究值得被更多人看到&#xff01; 在这里&#xff0c;我为您提供精准的会议推荐&#xff0c;包括水利土木工程、计算机科学、地球科学、机械自动化、材料与制造技术、经管金融、人文社科等主流学科相关领域的国际会议。快速的稿件录用和高效的检索服务将确保您的研究…

30个小米集团芯片工程师岗位面试真题

在竞争激烈的半导体行业&#xff0c;小米集团作为全球知名的科技公司&#xff0c;对于芯片工程师的选拔标准自然也是极为严格。本篇分享一份《30个小米集团芯片工程师岗位面试真题》&#xff0c;通过对这30道真题的深入分析&#xff0c;我们可以一窥小米对于芯片设计人才的期待…

缓存数据和数据库数据一致性问题

根据以上的流程没有问题&#xff0c;但是当数据变更的时候&#xff0c;如何把缓存变到最新&#xff0c;使我们下面要讨论的问题 1. 更新数据库再更新缓存 场景&#xff1a;数据库更新成功&#xff0c;但缓存更新失败。 问题&#xff1a; 当缓存失效或过期时&#xff0c;读取…

Web后端服务平台解析漏洞与修复、文件包含漏洞详解

免责申明 本文仅是用于学习检测自己搭建的Web后端服务平台解析漏洞、文件包含漏洞的相关原理,请勿用在非法途径上,若将其用于非法目的,所造成的一切后果由您自行承担,产生的一切风险和后果与笔者无关;本文开始前请认真详细学习《‌中华人民共和国网络安全法》‌及其所在国…

阿贝云评测:免费虚拟主机和免费云服务器体验分享

最近我有幸体验了阿贝云提供的免费虚拟主机和免费云服务器&#xff0c;在这里分享一下我的使用体验。首先我想说的是&#xff0c;阿贝云的服务真的很不错。他们提供的免费虚拟主机性能稳定&#xff0c;速度快&#xff0c;对于刚开始建站的小伙伴来说是一个很好的选择。免费云服…

技术美术百人计划 | 《5.1.1 PBR-基于物理的材质》笔记

1. PBR定义-基于物理的材质 PBR&#xff0c;或者用更通俗一些的称呼是指基于物理的渲染(Physically Based Rendering)&#xff0c;它指的是一些在不同程度上都基于与现实世界的物理原理更相符的基本理论所构成的渲染技术的集合。 正因为基于物理的渲染目的便是为了使用一种更…