在 CelebA 数据集上训练的 PyTorch 中的基本变分自动编码器

news2024/11/25 20:22:35
摩西·西珀博士

一、说明

        我最近发现自己需要一种方法将图像编码到潜在嵌入中,调整嵌入,然后生成新图像。有一些强大的方法可以创建嵌入从嵌入生成。如果你想同时做到这两点,一种自然且相当简单的方法是使用变分自动编码器。

        这样的深度网络不仅可以进行编码和解码,而且相当简单,我可以在以后的研究中使用它,而不必过多担心编码解码阶段的各种隐藏复杂性。我也更喜欢对软件内部有尽可能多的控制。

        因此,考虑到所有这些规范,我从 GitHub 收集了一些零碎的东西,施展了一些我自己的魔法,最终得到了一个漂亮、简单的变分自动编码器。我将在下面描述主要部分,完整的包可在以下位置找到:

vae-torch-celeba,PyTorch 中 CelebA 数据集的变分自动编码器,下载vae-torch-celeba的源码_GitHub_帮酷

PyTorch 中用于 CelebA 数据集的变分自动编码器 - GitHub - moshesipper/vae-torch-celeba:变分...

它相当小并且完全独立——这就是我的意图!

二、自动编码器

        为了使本文简短易懂,我将避免提供变分自动编码器的冗长概述。此外,您还可以在 Medium 上找到有关基础知识的优秀文章。我只提供三张快速图片。

        这是基本自动编码器的样子:

来源: https: //commons.wikimedia.org/wiki/File :Autoencoder_schema.png

        简而言之,网络将输入数据压缩为潜在向量(也称为嵌入),然后将其解压缩回来。这两个阶段称为编码解码

        变分自动编码器(VAE)看起来非常相似,除了中间的嵌入部分。对于每个输入,VAE 的编码器输出潜在空间中预定义分布的参数,而不是潜在空间中的向量:

来源:https ://commons.wikimedia.org/wiki/File:Reparameterized_Variational_Autoencoder.png

        最后一张图片:如果我们处理的是图像输入,我们需要一个卷积VAE,如下所示:

来源:https://github.com/arthurmeyer/Saliency_Detection_Convolutional_Autoencoder

        注意#1:观察编码器部分如何在每一层中添加越来越多的滤波器,图像变得越来越小;解码器则相反。

        注意#2:注意符号。如果只有一个通道,则术语“过滤器”和“内核”基本相同。对于多个通道,每个过滤器都是一组内核。查看这篇很棒的 Medium 文章:“直观地理解深度学习的卷积”。

三、CelebA数据集

我将使用的数据集是 CelebA,其中包含 202,599 张名人面孔图像。

CelebA 数据集

CelebFaces Attributes Dataset (CelebA) 是一个大规模人脸属性数据集,包含超过 20 万张名人图像……

可以通过以下方式访问它torchvision:

from torchvision.datasets import CelebA

train_dataset = CelebA(path, split='train')
test_dataset = CelebA(path, split='valid') # or 'test'

四、VAE类

        我的 VAE 基于此PyTorch 示例和存储库的普通 VAE模型(将我使用的普通 VAE 替换为中的任何其他模型PyTorch-VAE应该不会太难)。PyTorch-VAE

        该文件vae.py包含VAE类以及图像大小的定义、两个潜在向量的维度(均值和方差)以及数据集的路径:

CELEB_PATH = './data/'
IMAGE_SIZE = 150
LATENT_DIM = 128
image_dim = 3 * IMAGE_SIZE * IMAGE_SIZE

在课堂上VAE,我使用了以下隐藏过滤器维度:

hidden_dims = [32, 64, 128, 256, 512]

编码器看起来像这样:

in_channels = 3
modules = []
for h_dim in hidden_dims:
    modules.append(
        nn.Sequential(
            nn.Conv2d(in_channels, out_channels=h_dim,
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(h_dim),
            nn.LeakyReLU())
    )
    in_channels = h_dim
self.encoder = nn.Sequential(*modules)

然后是潜在向量:

self.fc_mu = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)
self.fc_var = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)

        最后我们用解码器“倒退”:

hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
   modules.append(
      nn.Sequential(
         nn.ConvTranspose2d(hidden_dims[i],
                            hidden_dims[i + 1],
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            output_padding=1),
         nn.BatchNorm2d(hidden_dims[i + 1]),
         nn.LeakyReLU())
   )

self.decoder = nn.Sequential(*modules)

这就是它的要点——还有一些零碎的内容vae.py可以完成这VAE门课。

五、训练

        该文件trainvae.py包含训练我们刚刚编码的 VAE 的代码。老实说,没什么花哨的......有 3 个主要函数:(train随着训练的进行,它也输出损失值),test(它还构建一个重建图像的小样本)和loss_function。训练和测试相当普通,损失函数是标准 VAE,带有重建组件 (MSE) 和 KL 散度组件。

        epoch 上的主循环执行 4 个操作:1) train、2) test、3) 生成随机潜在向量并调用decode以输出相应的输出图像,以及 4) 将 epoch 的模型保存到文件中pth

        以下是示例运行的输出。通过 20 个训练周期,您最终会得到 20 个重建图像文件、20 个潜在采样文件和 20 个 python 模型文件:

        这里reconstruction_20.png,顶行显示 8 张原始图片,底行显示经过训练的 VAE 的相应重建。

        在 epoch 20 时从模型重建(输出)图像。

        这里的sample_20.png,显示了从随机潜在向量生成的 64 张图像:

        只是为了好玩,我添加了一小段代码 — genpics.py— 从数据集中挑选一个随机图像并生成 7 个重建。以下是一些示例(最左边的图像是原始图像):

        最后,我再次放置 GitHub 链接。享受!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1173587.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Markdown语法教程

Markdown:一种轻量级语言,有简洁的编写方式,能够提高大家的工作效率。 一、标题 1.1 标题 标题的编写格式以#号开始,分别表示h1 ~ h6,注意:# 后面有空格! # 一级标题 ## 二级标题 ### 三级标题…

K8s学习笔记——认识理解篇

1. K8s诞生背景 回顾应用的部署,经历了以下几个阶段: 传统部署:物理服务器上运行应用程序。虚拟机部署:物理服务器上安装虚拟机,在虚拟机上运行应用程序。容器部署:物理服务器上安装容器运行时&#xff0…

MINIO minio 安装 报错 问题

minio MINIO 安装 报错 问题 前言问题1问题产生原因分析解决方案 问题2原因分析解决方案 问题3问题产生原因分析解决方案 问题4问题产生 原因分析解决方案 问题5问题产生 原因分析解决方案 关键词: 1: WARNING: MINIO_ACCESS_KEY and MINIO_SECRET_KEY are deprecat…

全能数据分析软件 Tableau Desktop 2019 mac中文版功能亮点

Tableau Desktop 2019 mac是一款专业的全能数据分析工具,可以让用户将海量数据导入并记性汇总,并且支持多种数据类型,比如像是编程常用的键值对、哈希MAP、JSON类型数据等,因此用户可以将很多常用数据库文件直接导入Tableau Deskt…

2019 ICPC 银川题解(A)

赛时没发挥好6题金尾(rank38),剩下很多能写的题,其中四个dp,傻眼ing A Girls Band Party(背包) 有点迷惑的题,当时看只要 5 5 5 张牌一下子想到暴力枚举,结果发现是不…

【漏洞复现】Redis未授权

感谢互联网提供分享知识与智慧,在法治的社会里,请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、影响版本1.3、漏洞复现环境搭建写入计划任务获取反弹shell利用SSH写公钥,免密登陆利用Redis写WebShell 3、检测未授权访问漏洞检测工具参考 1.1…

《 Hello 算法 》 - 免费开源的数据结构与算法入门教程电子书,包含大量动画、图解,通俗易懂

这本学习算法的电子书应该是我看过这方面最好的书了,代码例子有多种编程语言,JavaScript 也支持。 《 Hello 算法 》,英文名称是 Hello algo,是一本关于编程中数据解构和算法入门的电子书,作者是毕业于上海交通大学的…

零代码复现-TCGA联合GEO免疫基因结合代谢基因生信套路(三)

前面的分析中,整理好的关键基因集表达谱矩阵,接下来就准备分子亚型的相关分析。 六、一致性聚类构建分子亚型 在6.TCGA和GEO差异基因获取和预后数据的整理\TCGA文件中获取文件 准备一个生存数据和表达谱矩阵,这里需要注意的是,…

基于Jenkins实现接口自动化持续集成,学完涨薪5k

一、JOB项目配置 1、添加描述 可选选项可填可不填 2、限制项目的运行节点 节点中要有运行环境所需的配置 节点配置教程:https://blog.csdn.net/YZL40514131/article/details/131504280 3、源码管理 需要将脚本推送到远程仓库中 4、构建触发器 可以选择定时构建…

【爬虫实战】用python爬取微博任意关键词搜索结果、exe文件

项目功能简介: 1.交互式配置; 2.两种任意关键词来源(直接输入、本地关键词文件); 3.自动翻页(无限爬取); 4.指定最大翻页页码; 5.数据保存到csv文件; 6.程序支持打包成exe文件…

Go类型嵌入介绍和使用类型嵌入模拟实现“继承”

Go类型嵌入介绍和使用类型嵌入模拟实现“继承” 文章目录 Go类型嵌入介绍和使用类型嵌入模拟实现“继承”一、独立的自定义类型二、继承三、类型嵌入3.1 什么是类型嵌入 四、接口类型的类型嵌入4.1 接口类型的类型嵌入介绍4.2 一个小案例 五、结构体类型的类型嵌入5.1 结构体类…

视频视觉效果制作After Effects 2023 MacOS中文

After Effects 2023是一款业界领先的动态图形和视觉特效软件。它提供了强大的工具集,帮助用户创建引人入胜的视觉效果、动态图形和电影级特效。新的版本带来了更快的渲染速度、增强的图像处理和优化的工作流程,使用户能够更高效地工作。无论您是在电影、…

Kotlin 进阶函数式编程技巧

Kotlin 进阶函数式编程技巧 Kotlin 简介 软件开发环境不断变化,要求开发人员不仅适应,更要进化。Kotlin 以其简洁的语法和强大的功能迅速成为许多人进化过程中的信赖伙伴。虽然 Kotlin 的初始吸引力可能是它的简洁语法和与 Java 的互操作性&#xff0c…

idea文件比对

idea文件比对 1.项目内的文件比对2.项目间的文件比对3. 剪切板对比4. 版本历史(不同分支和不同commit)对比 1.项目内的文件比对 在项目中选择好需要比对的文件(类),然后选择Compare Files Mac下的快捷键是Commandd, 这样的比对像是git冲突解决一样 …

Peter算法小课堂—单调子序列

最长上升子序列 dp解法: f[i]表示以i结尾的最长上升子序列的长度 按照倒数第二个选谁分类: 我们先扫描i号元素前的每个元素(正向),找出第一个比i号元素小的元素k号。①仍然选i号元素,f[i]。②选k号&…

selenium自动化测试入门 —— cookie 处理

driver.get_cookies() # 获得cookie 信息 driver.get_cookies(name) # 获得对应name的cookie信息 add_cookie(cookie_dict) # 向cookie 添加会话信息 delete_cookie(name) # 删除特定(部分)的cookie delete_all_cookies() # 删除所有cookie 示例: from sel…

2022年电工杯数学建模B题5G网络环境下应急物资配送问题求解全过程论文及程序

2022年电工杯数学建模 B题 5G网络环境下应急物资配送问题 原题再现: 一些重特大突发事件往往会造成道路阻断、损坏、封闭等意想不到的情况,对人们的日常生活会造成一定的影响。为了保证人们的正常生活,将应急物资及时准确地配送到位尤为重要…

MapReduce:大数据处理的范式

一、介绍 在当今的数字时代,生成和收集的数据量正以前所未有的速度增长。这种数据的爆炸式增长催生了大数据领域,传统的数据处理方法往往不足。MapReduce是一个编程模型和相关框架,已成为应对大数据处理挑战的强大解决方案。本文探讨了MapRed…

【深度学习】pytorch——Autograd

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~ 深度学习专栏链接: http://t.csdnimg.cn/dscW7 pytorch——Autograd Autograd简介requires_grad计算图没有梯度追踪的张量ensor.data 、tensor.detach()非叶子节点的梯度计算图特点总结 利用Autograd实…

全网最详细的【shell脚本的入门】

🏅我是默,一个在CSDN分享笔记的博主。📚📚 ​ 🌟在这里,我要推荐给大家我的专栏《Linux》。🎯🎯 🚀无论你是编程小白,还是有一定基础的程序员,这…