VAE——生成数字(Pytorch+mnist)

news2025/4/8 9:15:44

1、简介

  • VAE(变分自编码器)同样由编码器和解码器组成,但与AE不同的是,VAE通过引入隐变量并利用概率分布来学习潜在表示。
  • VAE的编码器学习将输入数据映射到潜在空间的概率分布的参数,而不是直接映射到确定性的潜在表示。
  • VAE的解码器则通过从编码器学得的概率分布中采样,从而生成样本。
  • VAE的训练目标既包括最小化重构误差,也包括最大化编码器输出的潜在空间与单位高斯分布之间的KL散度,以促使学得的潜在表示更接近于标准正态分布。
  • VAE可以生成更连续、更具表现力的样本,并且具有更强的概率建模能力。
  • 本文利用VAE,输入数字图像。训练后,生成新的数字图像。
    • (100个epochs的结果)
  • 【注】本文案例输出的是随机的64个数字。

2、代码

  • import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torch.nn.functional as F
    from torchvision.utils import save_image
    
    
    # 变分自编码器
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
    
            # 编码器层
            self.fc1 = nn.Linear(input_size, 512)  # 编码器输入层
            self.fc2 = nn.Linear(512, latent_size)
            self.fc3 = nn.Linear(512, latent_size)
    
            # 解码器层
            self.fc4 = nn.Linear(latent_size, 512)  # 解码器输入层
            self.fc5 = nn.Linear(512, input_size)  # 解码器输出层
    
        # 编码器部分
        def encode(self, x):
            x = F.relu(self.fc1(x))  # 编码器的隐藏表示
            mu = self.fc2(x)  # 潜在空间均值
            log_var = self.fc3(x)  # 潜在空间对数方差
            return mu, log_var
    
        # 重参数化技巧
        def reparameterize(self, mu, log_var):  # 从编码器输出的均值和对数方差中采样得到潜在变量z
            std = torch.exp(0.5 * log_var)  # 计算标准差
            eps = torch.randn_like(std)  # 从标准正态分布中采样得到随机噪声
            return mu + eps * std  # 根据重参数化公式计算潜在变量z
    
        # 解码器部分
        def decode(self, z):
            z = F.relu(self.fc4(z))  # 将潜在变量 z 解码为重构图像
            return torch.sigmoid(self.fc5(z))  # 将隐藏表示映射回输入图像大小,并应用 sigmoid 激活函数,以产生重构图像
    
        # 前向传播
        def forward(self, x):  # 输入图像 x 通过编码器和解码器,得到重构图像和潜在变量的均值和对数方差
            mu, log_var = self.encode(x.view(-1, input_size))
            z = self.reparameterize(mu, log_var)
            return self.decode(z), mu, log_var
    
    
    # 使用重构损失和 KL 散度作为损失函数
    def loss_function(recon_x, x, mu, log_var):  # 参数:重构的图像、原始图像、潜在变量的均值、潜在变量的对数方差
        MSE = F.mse_loss(recon_x, x.view(-1, input_size), reduction='sum')  # 计算重构图像 recon_x 和原始图像 x 之间的均方误差
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())  # 计算潜在变量的KL散度
        return MSE + KLD  # 返回二进制交叉熵损失和 KLD 损失的总和作为最终的损失值
    
    
    def sample_images(epoch):
        with torch.no_grad():  # 上下文管理器,确保在该上下文中不会进行梯度计算。因为在这里只是生成样本而不需要梯度
            number = 64
            sample = torch.randn(number, latent_size).to(device)  # 生成一个形状为 (64, latent_size) 的张量,其中包含从标准正态分布中采样的随机数
            sample = model.decode(sample).cpu()  # 将随机样本输入到解码器中,解码器将其映射为图像
            save_image(sample.view(number, 1, 28, 28), f'sample{epoch}.png')  # 将生成的图像保存为文件
    
    
    if __name__ == '__main__':
        batch_size = 512  # 批次大小
        epochs = 100  # 学习周期
        sample_interval = 10  # 保存结果的周期
        learning_rate = 0.001  # 学习率
        input_size = 784  # 输入大小
        latent_size = 64  # 噪声大小
    
        # 载入 MNIST 数据集中的图片进行训练
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量
    
        train_dataset = torchvision.datasets.MNIST(
            root="~/torch_datasets", train=True, transform=transform, download=True
        )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True
    
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据
    
        # 在使用定义的 AE 类之前,有以下事情要做:
        # 配置要在哪个设备上运行
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
        # 建立 VAE 模型并载入到 CPU 设备
        model = VAE().to(device)
    
        # Adam 优化器,学习率
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
        # 训练
        for epoch in range(epochs):
            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = data.to(device)  # 将输入数据移动到设备(GPU 或 CPU)上
    
                optimizer.zero_grad()  # 进行反向传播之前,需要将优化器中的梯度清零,以避免梯度的累积
    
                # 重构图像 recon_batch、潜在变量的均值 mu 和对数方差 log_var
                recon_batch, mu, log_var = model(data)
    
                loss = loss_function(recon_batch, data, mu, log_var)  # 计算损失
                loss.backward()  # 计算损失相对于模型参数的梯度
                train_loss += loss.item()
    
                optimizer.step()  # 更新模型参数
    
            train_loss = train_loss / len(train_loader)  # # 计算每个周期的训练损失
            print('Epoch [{}/{}], Loss: {:.3f}'.format(epoch + 1, epochs, train_loss))
    
            # 每10次保存图像
            if (epoch + 1) % sample_interval == 0:
                sample_images(epoch + 1)
    
            # 每训练10次保存模型
            if (epoch + 1) % sample_interval == 0:
                torch.save(model.state_dict(), f'vae{epoch + 1}.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557668.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Astro 宣布:将超过 500 多个测试从 Mocha 迁移到了 Node.js

近期,Astro 在其官方博客中宣布,虽然我们对 Mocha 感到满意,但也在寻求让我们的 CI 作业更快的方法。最终将超过 500 多个测试从 Mocha 迁移到了 Node.js。 先了解下 Astro 是什么?Astro 是适合构建像博客、营销网站、电子商务网站…

2_1.Linux中的网络配置

#1.什么是IP ADDRESS# internet protocol ADDRESS ##网络进程地址 ipv4 internet protocol version 4 ip是由32个01组成 11111110.11111110.11111110.11111110 254.254.254.254 #2.子网掩码# 用来划分网络区域 子网掩码非0的位对应的ip上的数字表示这个ip的网络位 子网掩码0位…

TC16-161T+ 音频 信号变压器 RF Transformers 600kHz-160MHz 射频集成电路 Mini-Circuits

Mini-Circuits是一家全球领先的射频、微波和毫米波元器件及子系统制造商。TC16-161T是Mini-Circuits出产的一款射频IC(射频集成电路),具有平衡-不平衡转换器功用。制造商: Mini-Circuits 产品品种: 音频变压器/信号变压器 RoHS…

JWFD流程图转换为矩阵数据库的过程说明

在最开始设计流程图的时候,请务必先把开始节点和结束节点画到流程图上面,就是设计器面板的最开始两个按钮,先画开始点和结束点,再画中间的流程,然后保存,这样提交到矩阵数据库就不会出任何问题,…

实习管理系统的设计与实现|Springboot+ Mysql+Java+ B/S结构(可运行源码+数据库+设计文档)

本项目包含可运行源码数据库LW,文末可获取本项目的所有资料。 推荐阅读100套最新项目持续更新中..... 2024年计算机毕业论文(设计)学生选题参考合集推荐收藏(包含Springboot、jsp、ssmvue等技术项目合集) 1. 前台功能…

zabbix主动发现,注册及分布式监控

主动发现 结果 主动注册 结果 分布式监控 服务机:132 代理机:133 客户端:135 代理机 数据库赋权: 代理机配置 网页上配置代理 客户端配置 网页上配置主机 重启代理机服务 网页效果

开源知识库平台Raneto--使用Docker部署Raneto

文章目录 一、Raneto介绍1.1 Raneto简介1.2 知识库介绍 二、阿里云环境2.1 环境规划2.2 部署介绍 三、本地环境检查3.1 检查Docker服务状态3.2 检查Docker版本3.3 检查docker compose 版本 四、下载Raneto镜像五、部署Raneto知识库平台5.1 创建挂载目录5.2 编辑config.js文件5.…

书生·浦语训练营二期第一次笔记

文章目录 书生浦语大模型全链路开源体系视频笔记Intern2模型体系 训练数据集书生浦语全链条开源开放体系开放高质量语料数据预训练微调中立全面性能榜单大模型评测全栈工具链部署 书生浦语大模型全链路开源体系-Bilibili视频InternLM2技术报告(中文)Inte…

C#基础知识总结

C语言、C和C#的区别 ✔ 面向对象编程(OOP): C 是一种过程化的编程语言,它不直接支持面向对象编程。然而,C 是一种支持 OOP 的 C 的超集,它引入了类、对象、继承、多态等概念。C# 是完全面向对象的&#xff…

酷柚易讯无人空间小程序注册后需开通的部分接口权限

注意:无人共享小程序注册认证与备案后,需要开通以下接口系统才能正常使用! 登录小程序后,找到开发管理->接口设置(申请对应的接口权限)

设计模式10--适配器模式

定义 案例一 案例二 优缺点

用html实现一个手风琴相册设计

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>手风琴相册设计</title><link rel"stylesheet" href"./style.css"> </head> <body> <h1>Accordio…

S7-1500PLC与ABB机器人RobotStudio调试演示

(1)建立空工作站 (2)选择机器人、导入吸盘、托盘、传送带 (3) 将导入的吸盘变为工具 (4)创建机器人系统 布局如下 (5)创建物体 (6)设置物体本地原点 (7)创建传送带Smart组件 (8)创建吸盘Smart组件 将吸盘的传感器拖到吸盘上 (9)示教目标点 (10)同步示教点 (11)添加信号 创建…

汽车租赁(源码+文档)

汽车租赁&#xff08;小程序、ios、安卓都可部署&#xff09; 文件包含内容程序简要说明含有功能项目截图客户端登录界面首页订单个人信息我的界面新手指引注册界面车型选择支付界面修改信息 管理端用户管理订单管理分类管理 文件包含内容 1、搭建视频 2、流程图 3、开题报告 …

Uibot6.0 (RPA财务机器人师资培训第4天 )批量开票机器人案例实战

类似于小北之前发布的一篇博客&#xff08;不能说很像&#xff0c;只能说是一模一样&#xff09; Uibot (RPA设计软件&#xff09;财务会计Web应用自动化(批量开票机器人&#xff09;-CSDN博客https://blog.csdn.net/Zhiyilang/article/details/136782171?spm1001.2014.3001.…

搜索与图论——染色法判定二分图

一个图是二分图当且仅当这个图中不含奇数环 由于图中没有奇数环&#xff0c;所以染色过程中一定没有矛盾 所以一个二分图一定可以成功被二染色&#xff0c;反之在二染色的过程中出现矛盾的图中一定有奇数环&#xff0c;也就一定不是二分图 #include<iostream> #includ…

Unity LineRenderer的基本了解

在Unity中&#xff0c;LineRenderer组件用于在场景中绘制简单的线条。它通常用于绘制轨迹、路径、激光等效果。 下面来了解下它的基本信息。 1、创建 法1&#xff1a;通过代码创建 using UnityEngine;public class CreateLineRenderer : MonoBehaviour {void Start(){// 创…

C# wpf 嵌入winform控件

WPF Hwnd窗口互操作系列 第一章 嵌入Hwnd窗口 第二章 嵌入WinForm控件&#xff08;本章&#xff09; 第三章 嵌入WPF控件 第四章 底部嵌入HwndHost 文章目录 WPF Hwnd窗口互操作系列前言一、导入WinForm1、.Net Framwork&#xff08;1&#xff09;、右键添加引用&#xff08;2…

Linux:对TCP阻塞控制/面向字节流/异常的理解

文章目录 阻塞控制面向字节流TCP链接异常 本篇总结TCP的最后一点小知识 阻塞控制 首先对于阻塞控制的概念是&#xff0c;它是和网络环境息息相关的 如果在发送数据的时候出现问题&#xff0c;不仅仅是由于对方链接出错&#xff0c;其实还和当前的网络情况有关&#xff0c;假…

金三银四面试题(八):JVM常见面试题(2)

今天我们继续探讨常见的JVM面试题。这些问题不比之前的问题庞大&#xff0c;多用于面试中​JVM部分的热身运动&#xff0c;开胃菜&#xff0c;但是大家已经要认真准备。 JRE、JDK、JVM 及JIT 之间有什么不同&#xff1f; JRE 代表Java 运行时&#xff08;Java run-time&#…