VAE-pytorch代码

news2024/11/16 3:42:30

 

 

 

import os
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
 
from torchvision import transforms, datasets
from torchvision.utils import save_image
 
from tqdm import tqdm
 
 
class VAE(nn.Module):  # 定义VAE模型
    def __init__(self, img_size, latent_dim):  # 初始化方法
        super(VAE, self).__init__()  # 继承初始化方法
        self.in_channel, self.img_h, self.img_w = img_size  # 由输入图片形状得到图片通道数C、图片高度H、图片宽度W
        self.h = self.img_h // 32  # 经过5次卷积后,最终特征层高度变为原图片高度的1/32
        self.w = self.img_w // 32  # 经过5次卷积后,最终特征层宽度变为原图片高度的1/32
        hw = self.h * self.w  # 最终特征层的尺寸hxw
        self.latent_dim = latent_dim  # 采样变量Z的长度
        self.hidden_dims = [32, 64, 128, 256, 512]  # 特征层通道数列表
        # 开始构建编码器Encoder
        layers = []  # 用于存放模型结构
        for hidden_dim in self.hidden_dims:  # 循环特征层通道数列表
            layers += [nn.Conv2d(self.in_channel, hidden_dim, 3, 2, 1),  # 添加conv
                       nn.BatchNorm2d(hidden_dim),  # 添加bn
                       nn.LeakyReLU()]  # 添加leakyrelu
            self.in_channel = hidden_dim  # 将下次循环的输入通道数设为本次循环的输出通道数
 
        self.encoder = nn.Sequential(*layers)  # 解码器Encoder模型结构
 
        self.fc_mu = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linaer,将特征向量转化为分布均值mu
        self.fc_var = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linear,将特征向量转化为分布方差的对数log(var)
        # 开始构建解码器Decoder
        layers = []  # 用于存放模型结构
        self.decoder_input = nn.Linear(self.latent_dim, self.hidden_dims[-1] * hw)  # linaer,将采样变量Z转化为特征向量
        self.hidden_dims.reverse()  # 倒序特征层通道数列表
        for i in range(len(self.hidden_dims) - 1):  # 循环特征层通道数列表
            layers += [nn.ConvTranspose2d(self.hidden_dims[i], self.hidden_dims[i + 1], 3, 2, 1, 1),  # 添加transconv
                       nn.BatchNorm2d(self.hidden_dims[i + 1]),  # 添加bn
                       nn.LeakyReLU()]  # 添加leakyrelu
        layers += [nn.ConvTranspose2d(self.hidden_dims[-1], self.hidden_dims[-1], 3, 2, 1, 1),  # 添加transconv
                   nn.BatchNorm2d(self.hidden_dims[-1]),  # 添加bn
                   nn.LeakyReLU(),  # 添加leakyrelu
                   nn.Conv2d(self.hidden_dims[-1], img_size[0], 3, 1, 1),  # 添加conv
                   nn.Tanh()]  # 添加tanh
        self.decoder = nn.Sequential(*layers)  # 编码器Decoder模型结构
 
    def encode(self, x):  # 定义编码过程
        result = self.encoder(x)  # Encoder结构,(n,1,32,32)-->(n,512,1,1)
        result = torch.flatten(result, 1)  # 将特征层转化为特征向量,(n,512,1,1)-->(n,512)
        mu = self.fc_mu(result)  # 计算分布均值mu,(n,512)-->(n,128)
        log_var = self.fc_var(result)  # 计算分布方差的对数log(var),(n,512)-->(n,128)
        return [mu, log_var]  # 返回分布的均值和方差对数
 
    def decode(self, z):  # 定义解码过程
        y = self.decoder_input(z).view(-1, self.hidden_dims[0], self.h,
                                       self.w)  # 将采样变量Z转化为特征向量,再转化为特征层,(n,128)-->(n,512)-->(n,512,1,1)
        y = self.decoder(y)  # decoder结构,(n,512,1,1)-->(n,1,32,32)
        return y  # 返回生成样本Y
 
    def reparameterize(self, mu, log_var):  # 重参数技巧
        std = torch.exp(0.5 * log_var)  # 分布标准差std
        eps = torch.randn_like(std)  # 从标准正态分布中采样,(n,128)
        return mu + eps * std  # 返回对应正态分布中的采样值
 
    def forward(self, x):  # 前传函数
        mu, log_var = self.encode(x)  # 经过编码过程,得到分布的均值mu和方差对数log_var
        z = self.reparameterize(mu, log_var)  # 经过重参数技巧,得到分布采样变量Z
        y = self.decode(z)  # 经过解码过程,得到生成样本Y
        return [y, x, mu, log_var]  # 返回生成样本Y,输入样本X,分布均值mu,分布方差对数log_var
 
    def sample(self, n, cuda):  # 定义生成过程
        z = torch.randn(n, self.latent_dim)  # 从标准正态分布中采样得到n个采样变量Z,长度为latent_dim
        if cuda:  # 如果使用cuda
            z = z.cuda()  # 将采样变量Z加载到GPU
        images = self.decode(z)  # 经过解码过程,得到生成样本Y
        return images  # 返回生成样本Y
 
 
def loss_fn(y, x, mu, log_var):  # 定义损失函数
    recons_loss = F.mse_loss(y, x)  # 重建损失,MSE
    kld_loss = torch.mean(0.5 * torch.sum(mu ** 2 + torch.exp(log_var) - log_var - 1, 1), 0)  # 分布损失,正态分布与标准正态分布的KL散度
    return recons_loss + w * kld_loss  # 最终损失由两部分组成,其中分布损失需要乘上一个系数w
 
 
if __name__ == "__main__":
    total_epochs = 100  # epochs
    batch_size = 64  # batch size
    lr = 5e-4  # lr
    w = 0.00025  # kld_loss的系数w
    num_workers = 8  # 数据加载线程数
    image_size = 32  # 图片尺寸
    image_channel = 1  # 图片通道
    latent_dim = 128  # 采样变量Z长度
    sample_images_dir = "sample_images"  # 生成样本示例存放路径
    train_dataset_dir = "../dataset/mnist"  # 训练样本存放路径
 
    os.makedirs(sample_images_dir, exist_ok=True)  # 创建生成样本示例存放路径
    os.makedirs(train_dataset_dir, exist_ok=True)  # 创建训练样本存放路径
    cuda = True if torch.cuda.is_available() else False  # 如果cuda可用,则使用cuda
    img_size = (image_channel, image_size, image_size)  # 输入样本形状(1,32,32)
 
    vae = VAE(img_size, latent_dim)  # 实例化VAE模型,传入输入样本形状与采样变量长度
    if cuda:  # 如果使用cuda
        vae = vae.cuda()  # 将模型加载到GPU
    # dataset and dataloader
    transform = transforms.Compose(  # 图片预处理方法
        [transforms.Resize(image_size),  # 图片resize,(28x28)-->(32,32)
         transforms.ToTensor(),  # 转化为tensor
         transforms.Normalize([0.5], [0.5])]  # 标准化
    )
    dataloader = DataLoader(  # 定义dataloader
        dataset=datasets.MNIST(root=train_dataset_dir,  # 使用mnist数据集,选择数据路径
                               train=True,  # 使用训练集
                               transform=transform,  # 图片预处理
                               download=True),  # 自动下载
        batch_size=batch_size,  # batch size
        num_workers=num_workers,  # 数据加载线程数
        shuffle=True  # 打乱数据
    )
    # optimizer
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr)  # 使用Adam优化器
    # train loop
    for epoch in range(total_epochs):  # 循环epoch
        total_loss = 0  # 记录总损失
        pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{total_epochs}", postfix=dict,
                    miniters=0.3)  # 设置当前epoch显示进度
        for i, (img, _) in enumerate(dataloader):  # 循环iter
            if cuda:  # 如果使用cuda
                img = img.cuda()  # 将训练数据加载到GPU
            vae.train()  # 模型开始训练
            optimizer.zero_grad()  # 模型清零梯度
            y, x, mu, log_var = vae(img)  # 输入训练样本X,得到生成样本Y,输入样本X,分布均值mu,分布方差对数log_var
            loss = loss_fn(y, x, mu, log_var)  # 计算loss
            loss.backward()  # 反向传播,计算当前梯度
            optimizer.step()  # 根据梯度,更新网络参数
            total_loss += loss.item()  # 累计loss
            pbar.set_postfix(**{"Loss": loss.item()})  # 显示当前iter的loss
            pbar.update(1)  # 步进长度
        pbar.close()  # 关闭当前epoch显示进度
        print("total_loss:%.4f" %
              (total_loss / len(dataloader)))  # 显示当前epoch训练完成后,模型的总损失
        vae.eval()  # 模型开始验证
        sample_images = vae.sample(25, cuda)  # 获得25个生成样本
        save_image(sample_images.data, "%s/ep%d.png" % (sample_images_dir, (epoch + 1)), nrow=5,
                   normalize=True)  # 保存生成样本示例(5x5)

其中计算KLloss的代码的解释如下:

代码的目标是计算变分自编码器(VAE)中近似后验分布q(z∣x) 和标准正态分布 p(z) 之间的KL散度。KL散度公式的具体计算步骤如下:

1. mu ** 2

计算均值的平方项: μ2 这个项是为了衡量均值偏离零的程度。

2. torch.exp(log_var)

对数方差取指数,以获得实际的方差: exp⁡(log⁡(σ2))=σ2 这个项衡量方差的大小。

3. - log_var

减去对数方差: −log⁡(σ2) 这个项衡量分布的扩展程度。

4. - 1

减去 1,是KL散度公式中的常数项,用于归一化。

将这些项加在一起:

μ2+exp⁡(log⁡(σ2))−log⁡(σ2)−1

5. torch.sum(..., 1)

对所有维度求和,计算单个样本的KL散度: ∑(μ2+σ2−log⁡(σ2)−1) 这一步是将每个样本的所有维度的KL散度加起来。

6. 0.5 * ...

乘以 0.5,因KL散度公式中有系数 0.5: 0.5×∑(μ2+σ2−log⁡(σ2)−1)

7. torch.mean(..., 0)

对所有样本取平均,得到最终的KL散度损失: mean(0.5×∑(μ2+σ2−log⁡(σ2)−1))

整个公式的作用是计算出近似后验分布 q(z∣x) 和标准正态分布 p(z) 之间的KL散度,该散度表示了两个分布之间的差异。这种损失通常用于变分自编码器(VAE)训练中,确保生成的潜在变量分布接近标准正态分布。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1868558.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AI大模型】Transformers大模型库(十四):Datasets Viewer

目录 一、引言 二、Datasets Viewer数据查看器 2.1 概述 2.2 示例 三、总结 一、引言 这里的Transformers指的是huggingface开发的大模型库,为huggingface上数以万计的预训练大模型提供预测、训练等服务。 🤗 Transformers 提供了数以千计的预训练…

python-立方和不等式

[题目描述] 试求满足下述立方和不等式的 m 的整数解。 1^32^3...m^3≤n。本题算法如下: 对指定的 n,设置求和循环,从 i1 开始,i 递增1取值,把 i^3 (或 i∗i∗i)累加到 s,直至 s>n,脱离循环作…

VMware Workstation环境下,邮件(E-Mail)服务的安装配置,并用Windows7来验证测试

需求说明: 某企业信息中心计划使用IP地址17216.11.0用于虚拟网络测试,注册域名为xyz.net.cn.并将172.16.11.2作为主域名的服务器(DNS服务器)的IP地址,将172.16.11.3分配给虚拟网络测试的DHCP服务器,将172.16.11.4分配给虚拟网络测试的web服务器,将172.16.11.5分配给FTP服务器…

深度學習筆記14-CIFAR10彩色圖片識別(Pytorch)

🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客🍖 原作者:K同学啊 | 接輔導、項目定制 一、我的環境 電腦系統:Windows 10 顯卡:NVIDIA GeForce GTX 1060 6GB 語言環境:Python 3.7.0 開發…

ubuntu 18.04 server源码编译安装freeswitch 1.10.7支持音视频通话、收发短信——筑梦之路

软件版本说明 ubuntu版本18.04:https://releases.ubuntu.com/18.04.6/ubuntu-18.04.6-live-server-amd64.iso freeswitch 版本1.10.7:https://files.freeswitch.org/freeswitch-releases/freeswitch-1.10.7.-release.tar.gz spandsp包:https:…

家人们谁懂啊?手机信息删除找不回,原来3个技巧就能恢复

你是否也曾经遇到过这样的情况:手机里的重要信息不小心删除了,翻遍所有地方都找不到,心情烦躁到了极点。其实,信息删除后找回它们并不像你想象的那么复杂。不要担心,因为今天我将与你分享3个技巧,帮助你轻松…

1 哈希应用

O(1) 的哈希 Python中的哈希表主要通过内置的字典(dict)类型实现。对于字典的操作,包括插入(insert)、删除(delete)和查找(lookup)的时间复杂度,在理想情况下…

上位机图像处理和嵌入式模块部署(mcu之iap升级)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 mcu种类很多,如果是开发的时候需要对固件升级,整体还是比较容易的。不管是dap,还是st-link v2、j-link&#xf…

idea中使用springboot进行开发时遇到的工程结构问题汇总

idea中的工程结构和eclipse中不同,但是配置的内容都是一样的。 IDEA中也就是这个页面,快捷键ctrlaltshifts 如果在eclipse中,经常会遇到jre和jdk不正确的情况,但IDEA中这个问题很少,但是IDEA中会经常由于未正常配置根…

ESXI存储设备已经分区,无法创建数据存储。

问题:ESXi 存储设备已经分区完成,并且有 VMFS 文件系统,无法创建数据存储,选项是灰色。 解决办法:通过命令行工具 在现有的 VMFS 分区上创建数据存储。 在现有的 VMFS 分区上创建数据存储 1.ESXI开SSH2.windows自带CMD登入ESXI。&…

使用宝塔安装ModstartCMS (非一键安装)

操作系统 Linux Windows 推荐 Linux 操作系统,性能比较好 软件环境 稳定版 PHP 5.6 PHP 7.0 MySQL >5.0 PHP Extension:Fileinfo Apache/Nginx Laravel 9.0 版本 PHP 8.1 MySQL >5.0 PHP Extension:Fileinfo Apache/Ngin…

Windows 注册表是什么?如何备份注册表?

Windows注册表(Windows Registry)是微软Windows操作系统中的一个重要组件,用于存储系统和应用程序的配置信息和选项。下面就给大家详细讲解一下什么是注册表。 注册表的概念 Windows 注册表是一个集中管理的数据库,存储了系统、…

安霸CVFlow推理开发笔记

一、安霸环境搭建: 1.远程172.20.62.13 2. 打开Virtualbox,所在目录:E:\Program Files\Oracle\VirtualBox 3. 配置好ubuntu18.04环境,Ubuntu密码:amba 4. 安装toolchain,解压Ambarella_Toolchain_CNNGe…

成都工业学院2022级数据库原理及应用专周课程学生选课系统(基础篇)

运行环境 操作系统:Windows 11 家庭版 运行软件:Navicat Premium 16 项目内容 需求分析 学生:选课、退课、查看课程信息、查看选课情况等操作 教师:查看选课名单等操作 管理员:课程管理等操作 实体关系模式图 关…

js计算某个时间距离当前时间多少天,少于7天红色展示

效果图 后端返回数据格式 info:{vip_validity:"2027-09-07" }<div>到期时间&#xff1a;{{ info.vip_validity }}, 剩余<span :class"countdownDays(info.vip_validity) < 7 ? surplus : ">{{ !!info.vip_validity ? countdownDays(inf…

MySQL进阶-索引-使用规则-索引失效情况一(索引列运算,字符串不加引号,头部模糊匹配)

文章目录 1、索引列运算1.1、查询表tb_user1.2、查看tb_user的索引1.3、查询 phone177999900151.4、执行计划 phone177999900151.5、查询 substring(phone,10,2) 151.6、执行计划 substring(phone,10,2) 15 2、字符串不加引号2.1、查询 phone177999900152.2、执行计划 phone177…

强化学习详解:理论基础与核心算法解析

本文详细介绍了强化学习的基础知识和基本算法&#xff0c;包括动态规划、蒙特卡洛方法和时序差分学习&#xff0c;解析了其核心概念、算法步骤及实现细节。 关注TechLead&#xff0c;复旦AI博士&#xff0c;分享AI领域全维度知识与研究。拥有10年AI领域研究经验、复旦机器人智能…

【漏洞复现】用友GRP-U8——SQL注入

声明&#xff1a;本文档或演示材料仅供教育和教学目的使用&#xff0c;任何个人或组织使用本文档中的信息进行非法活动&#xff0c;均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 用友GRP-U8是一款企业管理软件&#xff0c;其系统dialog_moreUs…

客户案例|某 SaaS 企业租户敏感数据保护实践

近年来&#xff0c;随着云计算技术的快速发展&#xff0c;软件即服务&#xff08;SaaS&#xff09;在各行业的应用逐渐增多&#xff0c;SaaS 应用给企业数字化发展带来了便捷性、成本效益与可访问性&#xff0c;同时也带来了一系列数据安全风险。作为 SaaS 产品运营服务商&…

活动|华院计算受邀参加2024全球人工智能技术大会(GAITC),探讨法律大模型如何赋能社会治理

6月22至23日&#xff0c;备受瞩目的2024全球人工智能技术大会&#xff08;GAITC&#xff09;在杭州市余杭区未来科技城隆重举行。本届大会以“交叉、融合、相生、共赢”为主题&#xff0c;集“会、展、赛”为一体&#xff0c;聚“产、学、研”于一堂。值得一提的是&#xff0c;…