Pytorch Advanced(二) Variational Auto-Encoder

news2025/1/11 9:50:39

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。

自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的深层次特征再还原成为原来的数据特征。那么如何保证从压缩到解码两部分,原数据和解码数据保持一致呢?这就是要训练的过程。

如何理解降维?如果压缩的过程是卷积,维度可以根据核的个数变化,特征维度因此而改变。


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

dataset = torchvision.datasets.MNIST(root='../../data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, 
                                          shuffle=True)

模型搭建:这里搭建的是一个变分自编码,Variational Autoencoder

那么变分自编码是为了解决什么问题呢? ——- 其主要思想还是希望学习隐层变量,并将其用来表示原始数据,但是它加另一个条件, 即隐层变量能学习原始数据的分布, 并反过来生产一些和原始数据相似的数据(这有啥用?—-可用于图片修复,让图片按训练集的数据分布变化)。

变分自编码 (Variational Autoencoder) 为了让隐层抓住输入数据特性, 而不是简单的输出数据=输入数据,他在隐层中加入随机噪声(单位高斯噪声)(这个过程也叫reparametrize),以确保隐层能较好抽象输入数据特点。

代码中怎么做的呢?

1、编码过程中我们保存了第二层线性层的输出。其中第二层包含有fc2与fc3两部分,他们是并联的。

2、给隐藏层加入随机噪声,作为解码的输入

class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return F.sigmoid(self.fc5(h))
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

训练:由于训练中加入了噪声,所以损失值的结构也因此改变。一部分来源于解码内容核原内容的相似度,另一部分是kl_div,具体是什么意义需查看论文。

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and kl divergence
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))
    
    with torch.no_grad():
        # Save the sampled images
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

模型训练完成了之后该如何使用这个模型呢?

model.decode()是一个解码的过程,我们给他一个随机的中间特征z就可以输出一个数字图片了。

z = torch.randn(1,z_dim).to(device)
out = model.decode(z)
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

有了随机的一张图片之后,我们把他完整的放入模型中,生成了和输入相似的一张图片,也没看出来是修复了图像......

out,_,_ = model(out) 
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1003894.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python-爬虫-urllib

网络爬虫(Web Crawler),又叫网络蜘蛛、网络机器人,是一种自动化数据采集程序 数据采集 → 数据处理 → 数据存储 常见的工作流程如下: 1.定义采集的目标(网站、APP、公众号、小程序)&#xff…

RP9学习-2

1.基本元素2 1.1树 可以收起 添加子菜单 选中树的节点即可添加 移动层级 编辑树属性 选中某行文字,点击Edit Tree Properties 可以把箭头变成加减,另外也可以导入自己的图标 注意要使用自己的图标,需要勾选Show Icon 也可以给某个节点单…

开放式耳机也会有巅峰音质体验-南卡NANK OE PRO

前言 这两年,开放式耳机市场发展迅猛,新品层出不穷,各大耳机厂商也都相继推出了自家的产品。而在众多的厂家中,作为国内开发式耳机的TOP1,南卡通过多年来在业内领域的经验和专业的技术能力,为广大音乐爱好…

fastadmin在前端调用 /api/common/upload 返回未上传文件或超出服务器上传限制

第一步:在api目录直接调用 域名/api/common/upload 上传图片的时候要在Common.php文件里面把验证登录的 protected $noNeedLogin [init]; 方法注释掉。 // protected $noNeedLogin [init];protected $noNeedLogin *;protected $noNeedRight *; 第二步&#…

计算机竞赛 大数据分析:基于时间序列的股票预测于分析

1 简介 Hi,大家好,这里是丹成学长,今天向大家介绍一个大数据项目 大数据分析:基于时间序列的股票预测于分析 2 时间序列的由来 提到时间序列分析技术,就不得不说到其中的AR/MA/ARMA/ARIMA分析模型。这四种分析方法…

dp(1) - 数字三角形模型

898.数字三角形 题目链接 : 活动 - AcWing 题目 : 给定一个如下图所示的数字三角形,从顶部出发,在每一结点可以选择移动至其左下方的结点或移动至其右下方的结点,一直走到底层,要求找出一条路径,使路径上的数字的和…

利用LinuxPTP进行时间同步(软/硬件时间戳) - 研一

转自:https://blog.csdn.net/BUPTOctopus/article/details/86246335 官方文档:https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/6/html/deployment_guide/s1-using_ptp 查看网卡是否支持软硬件时间戳: sudo ethtoo…

为什么做期权卖方是比较高胜率的?

期权三大因素:行情方向、时间价值流失、波动率。波动率下降、时间价值流失,震荡行情,这几项对期权卖方交易有利,一般做期权卖方胜率基本可以达70%左右,下文揭秘为什么做期权卖方是比较高胜率的? 一、期权卖方交易如何…

MQTT网关对接水务二次供水管理平台案例

一、客户介绍 随着城市发展和人口增长,对水务行业的监测和管理要求也越来越高。然而,传统的水务行业监测方式存在很多不足,如数据传输的缓慢和不可靠,数据安全风险大等,为了更有效地监测和管理这些信息,供…

电子科大软件系统架构设计——系统规划

文章目录 系统规划定义意义目标任务路径规划规划步骤规划方法业务系统规划法业务流程重组法价值链分析法战略目标集转移法关键成功因素法 项目计划定义要素工作分解活动排序工期预算三点估计法德尔菲法 成本估算与计算进度安排甘特图法PERT图方法 可行性分析技术可行性分析进度…

固定资产电脑怎么编号管理

科技的发展已经深入到了我们的生活中的每一个角落,尤其是在办公室环境中,电脑已经成为了必不可少的工具。然而,随着电脑数量的增加和管理复杂性的提升,如何有效地管理和追踪这些固定资产变得越来越重要。本文将探讨一种创新的方式…

Scrum敏捷开发如何实施

​在当今高度变化的时代,软件开发的环境和要求也在不断变化。传统的开发方法往往难以适应这种快速变化,因此,一种新的软件开发方法——敏捷开发逐渐得到了广泛的关注和应用。 敏捷开发的实施可以按照以下步骤进行: 1、明确产品愿…

WebDAV之π-Disk派盘 + 飞傲音乐

飞傲音乐是一款专为手机解码耳放设计的本地播放器,旨在提供更符合发烧友使用习惯的音乐播放体验。它具备以下功能和特性: 1. DSD源码输出:支持DSD音频格式的输出,即使是普通手机也能够进行DSD硬解码播放。 2. Hi-Res高清音乐格式源码输出:支持高清音乐格式,可以播放高达…

led护眼灯真的能护眼吗?Led护眼灯的好处

随着人们对家庭环境艺术的重视,台灯因其摆设在桌案台几上的特殊地位,也要进求特有的装饰效果。家居用台灯开始逐新分流为工艺台灯和书写台灯两类。前者追求外观效果,将发展思路放在材质的创新、造型的求异上,以配合风格多样的家居…

字符串类型

目录 一、字符与字符串 二、字符串对象与自变量 三、正则表达式 1.普通字符 2.特殊字符 3.非打印字符 4.限定符 5.定位符 四、正则表达式的处理 1.Pattern.compile(String regex) 2.Matcher.matches() 3.Matcher.find() 4.Matcher.replaceAll(String replacement)…

CRM客户管理系统是什么?

CRM的含义我们都知道,是客户关系管理的缩写,更多地用来代表CRM系统。所以CRM管理又可以理解为通过CRM系统进行管理。那么下面我们就来详细说说,什么是crm管理? CRM管理功能主要包括: 营销管理: CRM系统可…

【规范】Apifox就应该这么玩

前言 🍊缘由 好的工具就要配好的玩法 起因是最近在回顾项目时,看到了年事已高并且长时间不用的Postman,发现之前自己整理的接口文档十分混乱且没有规律。遂打开现在使用的Apifox,将本狗目前项目中使用Apifox的整理规范和使用方…

【数据结构前置知识】初识集合框架和时间,空间复杂度

文章目录 1. 什么是集合框架2. 集合框架的重要性 3. 背后所涉及的数据结构以及算法3.1 什么是数据结构3.2 容器背后对应的数据结构3.3 相关java知识3.4 什么是算法 4.时间复杂度1. 如何衡量一个算法的好坏2. 算法效率3. 时间复杂度3.1 时间复杂度的概念3.2 大O的渐进表示法3.3 …

【ESP32】以蓝牙网关为例,记录队列的使用

📋 个人简介 💖 作者简介:大家好,我是喜欢记录零碎知识点的菜鸟打工人。😎📝 个人主页:欢迎访问我的 Ethernet_Comm 博客主页🔥🎉 支持我:点赞👍收…

geek完全卸载sqlserver2012

前言 有时候sqlserver2012 出现问题,需要卸载安装 会出现卸载不干净的问题 需要用到geek去卸载 卸载 双击exe打开软件 输入sql查询相关的软件 依次一个一个的去删除