变分自动编码器 (VAE)02/2 PyTorch 教程

news2025/1/17 21:47:47

一、说明

        在自动编码器中,来自输入数据的信息被映射到固定的潜在表示中。当我们旨在训练模型以生成确定性预测时,这特别有用。相比之下,变分自动编码器(VAE)将输入数据转换为变分表示向量(顾名思义),其中该向量的元素表示有关输入数据分布的不同属性。VAE的这种概率特性使其成为一个生成模型。VAE中的潜在表示由最能定义输入数据的概率分布(μ,σ)组成。

二、前提知识

        要了解有关VAE直觉的更多信息,我建议您阅读了解变分自动编码器(VAE)和什么是变分自动编码器?

三、VAE的PyTorch实现

        在本文中,我们只关注 PyTorch 中的简单 VAE,并在 MNIST 数据集上训练后可视化其潜在表示。让我们从导入一些库开始:

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid

并下载 MNIST 数据集并制作数据加载器:

# create a transofrm to apply to each datapoint
transform = transforms.Compose([transforms.ToTensor()])

# download the MNIST datasets
path = '~/datasets'
train_dataset = MNIST(path, transform=transform, download=True)
test_dataset  = MNIST(path, transform=transform, download=True)

# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

让我们可视化一些示例训练数据:

# get 25 sample training images for visualization
dataiter = iter(train_loader)
image = dataiter.next()

num_samples = 25
sample_images = [image[0][i,0] for i in range(num_samples)] 

fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)

for ax, im in zip(grid, sample_images):
    ax.imshow(im, cmap='gray')
    ax.axis('off')

plt.show()
25 张样本训练图像

        现在,我们创建一个简单的VAE,它具有完全连接的编码器和解码器。输入维度为 784,这是 MNIST 图像 (28×28) 的扁平维度。在编码器中,均值 (μ) 和方差 (σ²) 向量是我们的变分表示向量 (size=200)。请注意,我们将潜在方差与 epsilon (ε) 参数相乘,以便在解码之前重新参数化。这使我们能够执行反向传播并解决节点随机性。在此处阅读有关重新参数化的更多信息。

        此外,我们的最终编码器维度为 2,即μ和σ向量。这些连续向量定义了我们的潜在空间分布,使我们能够在VAE中对图像进行采样。

class VAE(nn.Module):

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=200, device=device):
        super(VAE, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(0.2)
            )
        
        # latent mean and variance 
        self.mean_layer = nn.Linear(latent_dim, 2)
        self.logvar_layer = nn.Linear(latent_dim, 2)
        
        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
            )
     
    def encode(self, x):
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(device)      
        z = mean + var*epsilon
        return z

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterization(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, log_var
        
    def forward(self, x):
        mean, log_var = self.encode(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) 
        x_hat = self.decode(z)  
        return x_hat, mean, log_var

        现在,我们可以定义我们的模型和优化器:

model = VAE().to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

        VAE中的损失函数由再现损失和库尔巴克-莱布勒(KL)散度组成。KL 散度是用于衡量两个概率分布之间距离的指标。KL 散度是生成建模中的一个重要概念,但在本教程中,我们不会更详细地介绍。:了解KL背离的直观指南。

def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD

        最后,我们可以训练我们的模型:

def train(model, optimizer, epochs, device):
    model.train()
    for epoch in range(epochs):
        overall_loss = 0
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.view(batch_size, x_dim).to(device)

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)
            
            overall_loss += loss.item()
            
            loss.backward()
            optimizer.step()

        print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss/(batch_idx*batch_size))
    return overall_loss

train(model, optimizer, epochs=50, device=device)

        我们现在知道,从潜在空间生成图像所需要的只是两个浮点值(平均值和方差)。让我们从潜在空间生成一些图像:

def generate_digit(mean, var):
    z_sample = torch.tensor([[mean, var]], dtype=torch.float).to(device)
    x_decoded = model.decode(z_sample)
    digit = x_decoded.detach().cpu().reshape(28, 28) # reshape vector to 2d array
    plt.imshow(digit, cmap='gray')
    plt.axis('off')
    plt.show()

generate_digit(0.0, 1.0), generate_digit(1.0, 0.0)

                有趣!更令人印象深刻的潜在空间视图:

def plot_latent_space(model, scale=1.0, n=25, digit_size=28, figsize=15):
    # display a n*n 2D manifold of digits
    figure = np.zeros((digit_size * n, digit_size * n))

    # construct a grid 
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to(device)
            x_decoded = model.decode(z_sample)
            digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)
            figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size,] = digit

    plt.figure(figsize=(figsize, figsize))
    plt.title('VAE Latent Space Visualization')
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("mean, z [0]")
    plt.ylabel("var, z [1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(model)

        潜在空间可视化,范围:[-1.0、1.0]

        这就是介于 -1.0 和 1.0 之间的均值和方差值的潜在空间的外观。如果我们将此比例更改为 -5.0 和 5.0 会发生什么?

潜在空间可视化,范围:[-5.0、5.0]

        再次,有趣!现在,我们可以看到大多数数字表示形式所在的均值和方差值的范围。现在,我们知道如何从头开始构建一个简单的VAE,对图像进行采样并可视化潜在空间。但VAE并不止于此,还有更先进的技术使表征学习更加迷人。我将在以后的文章中探讨它们。在Github,LinkedIn和 Google Scholar上找到我。

四、在Google Colab中尝试代码:

import torch
import numpy as np
import torch.nn as nn
from torch.optim import Adam
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid

# create a transofrm to apply to each datapoint
transform = transforms.Compose([transforms.ToTensor()])

# download the MNIST datasets
path = '~/datasets'
train_dataset = MNIST(path, transform=transform, download=True)
test_dataset  = MNIST(path, transform=transform, download=True)

# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# get 25 sample training images for visualization
dataiter = iter(train_loader)
image = dataiter.next()

num_samples = 25
sample_images = [image[0][i,0] for i in range(num_samples)] 

fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)

for ax, im in zip(grid, sample_images):
    ax.imshow(im, cmap='gray')
    ax.axis('off')

plt.show()

class Encoder(nn.Module):
    
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=256):
        super(Encoder, self).__init__()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.var = nn.Linear (hidden_dim, latent_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)
        self.training = True
        
    def forward(self, x):
        x = self.LeakyReLU(self.linear1(x))
        x = self.LeakyReLU(self.linear2(x))

        mean = self.mean(x)
        log_var = self.var(x)                     
        return mean, log_var

class Decoder(nn.Module):
    
    def __init__(self, output_dim=784, hidden_dim=512, latent_dim=256):
        super(Decoder, self).__init__()

        self.linear2 = nn.Linear(latent_dim, hidden_dim)
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        x = self.LeakyReLU(self.linear2(x))
        x = self.LeakyReLU(self.linear1(x))
        
        x_hat = torch.sigmoid(self.output(x))
        return x_hat
class VAE(nn.Module):

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=200, device=device):
        super(VAE, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(0.2)
            )
        
        # latent mean and variance 
        self.mean_layer = nn.Linear(latent_dim, 2)
        self.logvar_layer = nn.Linear(latent_dim, 2)
        
        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
            )
     
    def encode(self, x):
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(device)      
        z = mean + var*epsilon
        return z

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterization(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, log_var
        
    def forward(self, x):
        mean, log_var = self.encode(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) 
        x_hat = self.decode(z)  
        return x_hat, mean, log_var
model = VAE().to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD

def train(model, optimizer, epochs, device, x_dim=784):
    model.train()
    for epoch in range(epochs):
        overall_loss = 0
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.view(batch_size, x_dim).to(device)

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)
            
            overall_loss += loss.item()
            
            loss.backward()
            optimizer.step()

        print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss/(batch_idx*batch_size))
    return overall_loss
train(model, optimizer, epochs=50, device=device)
Epoch 1 	Average Loss:  180.83752029750104
	Epoch 2 	Average Loss:  163.2696611703099
	Epoch 3 	Average Loss:  158.45957443721306
	Epoch 4 	Average Loss:  155.22525566699707
	Epoch 5 	Average Loss:  153.2932642294449
	Epoch 6 	Average Loss:  151.85221148202734
	Epoch 7 	Average Loss:  150.7171567847454
	Epoch 8 	Average Loss:  149.69312638577316
	Epoch 9 	Average Loss:  148.78454667284015
	Epoch 10 	Average Loss:  148.15143693264815
def generate_digit(mean, var):
    z_sample = torch.tensor([[mean, var]], dtype=torch.float).to(device)
    x_decoded = model.decode(z_sample)
    digit = x_decoded.detach().cpu().reshape(28, 28) # reshape vector to 2d array
    plt.title(f'[{mean},{var}]')
    plt.imshow(digit, cmap='gray')
    plt.axis('off')
    plt.show()

#img1: mean0, var1 / img2: mean1, var0
generate_digit(0.0, 1.0), generate_digit(1.0, 0.0)

def plot_latent_space(model, scale=5.0, n=25, digit_size=28, figsize=15):
    # display a n*n 2D manifold of digits
    figure = np.zeros((digit_size * n, digit_size * n))

    # construct a grid 
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to(device)
            x_decoded = model.decode(z_sample)
            digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)
            figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size,] = digit

    plt.figure(figsize=(figsize, figsize))
    plt.title('VAE Latent Space Visualization')
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("mean, z [0]")
    plt.ylabel("var, z [1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()
     
plot_latent_space(model, scale=1.0)
plot_latent_space(model, scale=5.0)

礼萨·卡兰塔尔

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1081106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

气象台卫星监测vr交互教学增强学生的学习兴趣和动力

对地观测是以地球为研究对象,依托卫星、飞船等光电仪器,进行各种探测活动,其核心是遥感技术,因此为了让遥感专业学员能提前熟悉对地观测规则、流程、方法及注意事项,借助VR虚拟现实制作的三维仿真场景,能让…

全新彩虹商城时光模板知识付费系统源码+内有5000多商品+易支付源码

源码简介: 全新彩虹商城时光模板知识付费系统源码,这是最新的彩虹知识付费商城系统,具备众多强大且实用的功能。首先,它支持二级分类和多级分销,使得商品分类更为清晰,销售网络更具扩展性。 其次&#xf…

机器人轨迹规划算法的研究现状

近年来,随着机器人技术的迅速发展,机器人在工业、医疗、军事等领域的应用越来越广泛。机器人轨迹规划是机器人控制的重要环节之一,它决定了机器人在执行任务时的运动轨迹,直接影响机器人的精度、速度和稳定性。因此,机…

【PCIE720】基于PCIe总线架构的高性能计算(HPC)硬件加速卡

PCIE720是一款基于PCI Express总线架构的高性能计算(HPC)硬件加速卡,板卡采用Xilinx的高性能28nm 7系列FPGA作为运算节点,在资源、接口以及时钟的优化,为高性能计算提供卓越的硬件加速性能。板卡一共具有5个FPGA处理节…

代码混淆界面介绍

代码混淆界面介绍 代码混淆功能包括oc,swift,类和函数设置区域。其他flutter,混合开发的最终都会转未oc活着swift的的二进制,所以没有其他语言的设置。 代码混淆功能分顶部的显示控制区域:显示方式,风险等…

python 深度学习 解决遇到的报错问题6

目录 一、解决报错HTTPSConnectionPool(hosthuggingface.co, port443): Max retries exceeded with url: /bert-base-uncased/resolve/main/vocab.txt (Caused by ConnectTimeoutError(, Connection to huggingface.co 如何从huggingface官网下载模型 二、nx.draw if cf._ax…

jupyter 切换虚拟环境

当前只有两个环kernel 我已经创建了很多虚拟环境,如何在notebook中使用这些虚拟环境呢?请看下面 比如说我要添加nlp 这个虚拟环境到notebook中 1. 切换到nlp环境 2. 安装如下模块 pip install ipykernel 3. 执行如下命令 python -m ipykernel install …

VS2019如何显示和去除控制台页面

这是控制台页面: 方法: 选中目标项目,右键--->属性--->配置属性--->链接器--->系统--->子系统--->(窗口/控制台)

地级市HVV | 未授权访问合集

在网站前后端分离盛行下,将大部分权限控制交给前端,导致js中隐藏未授权或者可绕过的前端鉴权。前后端分离的好处是提高开发效率,同时防止黑客更直接的对服务器造成危害,但权限控制的工作量全部交给前端会导致大量页面未授权或者后…

面试经典 150 题 1 —(双指针)— 125. 验证回文串

125. 验证回文串 方法一 class Solution { public:bool isPalindrome(string s) {string newStr "";for(int fast 0; fast < s.size(); fast){if(isalnum(s[fast]))){newStr tolower(s[fast]);}}string tmp newStr;reverse(tmp.begin(), tmp.end());if(strcm…

【计算机网络】TCP协议与UDP协议详解

文章目录 一、传输层 1、1 再次理解传输层 1、2 再次理解端口号 1、2、1 端口号范围划分 1、2、2 认识知名端口号 1、3 网络常用指令netstat 与 pidof 二、UDP协议 2、1 UDP协议的报文 2、2 UDP的特点 2、3 UDP的缓冲区 三、TCP协议 3、1 TCP协议的报文 3、2 确认应答 3、3 按…

计算机毕业设计选什么题目好?springboot 个人健康信息管理系统

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

基于SpringBoot的抗疫物资管理系统

目录 前言 一、技术栈 二、系统功能介绍 用户管理 公告信息管理 轮播图管理 物质分类管理 物质信息管理 物质入库管理 物质出库管理 个人信息 前台首页功能实现 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 随着现在网络的快速发展&#xff0c;网…

简单制作RT-Thread Studio的CH32V303的BSP支持包

简单制作RT-Thread Studio的CH32V303的BSP支持包 开原仓库链接在此&#xff1a;RTT_Studio_BSP_CH32V303 参考 CH32V307V-R1&#xff08;V1.0.8&#xff09;的 BSP&#xff0c;更新了外设驱动库之类的。 可以在 RT-Thread SDK 管理器中导入离线资源包&#xff0c;可以新建 RT…

CSS学习基础知识

CSS学习笔记 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-width,…

Linux Centos7 下使用yum安装的nginx平滑升级

1. 查看当前nginx版本 1nginx -v2. 查看centos版本 1cat /etc/redhat-release3. 创建一个新的文件nginx.repo&#xff0c;其中第三行的7是因为我的centos版本是7点多的&#xff0c;你看自己是多少就改多少 1vim /etc/yum.repos.d/nginx.repo23[nginx]4namenginx repo 5baseu…

基于 ACK Fluid 的混合云优化数据访问(三):加速第三方存储的读访问,降本增效并行

作者&#xff1a;车漾 前文回顾&#xff1a; 本系列将介绍如何基于 ACK Fluid 支持和优化混合云的数据访问场景&#xff0c;相关文章请参考&#xff1a; 基于 ACK Fluid 的混合云优化数据访问&#xff08;一&#xff09;&#xff1a;场景与架构 基于 ACK Fluid 的混合云优化…

react antd table表格点击一行选中数据的方法

一、前言 antd的table&#xff0c;默认是点击左边的单选/复选按钮&#xff0c;才能选中一行数据&#xff1b; 现在想实现点击右边的部分&#xff0c;也可以触发操作选中这行数据。 可以使用onRow实现&#xff0c;样例如下。 二、代码 1.表格样式部分 //表格table样式部分{…

如何在手机上设置节日提醒和倒计时天数?

在平淡的生活和工作中&#xff0c;时不时有各种各样节日的点缀&#xff0c;为我们的日常增添了一些仪式感&#xff0c;例如春节、元宵节、情人节、端午节、七夕节等。此外还有一些特殊的日子也值得纪念&#xff0c;例如恋爱纪念日、结婚纪念日、亲朋好友生日等。面对这些节日&a…

provide,inject

通过provide和inject函数可以简便的实现跨级组件通讯 父组件提供 const count ref(0); provide(count, count);const updateCount (num) > {count.value num; }; provide(updateCount, updateCount);子组件或者孙子组件接受 const count inject(count); const upda…