CVAE——生成0-9数字图像(Pytorch+mnist)

news2024/12/24 21:16:56

1、简介

  • CVAE(Conditional Variational Autoencoder,条件变分自编码器)是一种变分自编码器(VAE)的变体,用于生成有条件的数据。在传统的变分自编码器中,生成的数据是完全由潜在变量决定的,而CVAE允许在生成过程中引入外部条件信息。
  • 具体来说,CVAE在生成数据时,除了使用随机采样的潜在变量外,还会接收一个额外的条件信息。这个条件信息可以是类别标签、属性信息、或者其他形式的上下文信息,取决于具体的任务。通过将条件信息作为输入提供给编码器和解码器,CVAE能够生成与条件信息相关的数据。
  • CVAE在许多任务中都很有用,例如图像生成中的类别条件生成、属性编辑、生成特定风格的图像等。通过引入条件信息,CVAE使得生成的数据更具有控制性和可解释性。
  • 本文利用CVAE,输入数字图像和对应的标签。训练后,生成0-9数字图像。
    • (epoch=10)
    • (epoch=20)
    • (epoch=30)

2、代码

  • import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torch.nn.functional as F
    from torchvision.utils import save_image
    
    
    # 变分自编码器
    class CVAE(nn.Module):
        def __init__(self):
            super(CVAE, self).__init__()
            self.labels = 10  # 标签数量
    
            # 编码器层
            self.fc1 = nn.Linear(input_size + self.labels, 512)  # 编码器输入层
            self.fc2 = nn.Linear(512, latent_size)
            self.fc3 = nn.Linear(512, latent_size)
    
            # 解码器层
            self.fc4 = nn.Linear(latent_size + self.labels, 512)  # 解码器输入层
            self.fc5 = nn.Linear(512, input_size)  # 解码器输出层
    
        # 编码器部分
        def encode(self, x):
            x = F.relu(self.fc1(x))  # 编码器的隐藏表示
            mu = self.fc2(x)  # 潜在空间均值
            log_var = self.fc3(x)  # 潜在空间对数方差
            return mu, log_var
    
        # 重参数化技巧
        def reparameterize(self, mu, log_var):  # 从编码器输出的均值和对数方差中采样得到潜在变量z
            std = torch.exp(0.5 * log_var)  # 计算标准差
            eps = torch.randn_like(std)  # 从标准正态分布中采样得到随机噪声
            return mu + eps * std  # 根据重参数化公式计算潜在变量z
    
        # 解码器部分
        def decode(self, z):
            z = F.relu(self.fc4(z))  # 将潜在变量 z 解码为重构图像
            return torch.sigmoid(self.fc5(z))  # 将隐藏表示映射回输入图像大小,并应用 sigmoid 激活函数,以产生重构图像
    
        # 前向传播
        def forward(self, x, y):  # 输入图像 x,标签 y 通过编码器和解码器,得到重构图像和潜在变量的均值和对数方差
            x = torch.cat([x, y], dim=1)
            mu, log_var = self.encode(x)
            z = self.reparameterize(mu, log_var)
            z = torch.cat([z, y], dim=1)
            return self.decode(z), mu, log_var
    
    
    # 使用重构损失和 KL 散度作为损失函数
    def loss_function(recon_x, x, mu, log_var):  # 参数:重构的图像、原始图像、潜在变量的均值、潜在变量的对数方差
        MSE = F.mse_loss(recon_x, x.view(-1, input_size), reduction='sum')  # 计算重构图像 recon_x 和原始图像 x 之间的均方误差
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())  # 计算潜在变量的KL散度
        return MSE + KLD  # 返回二进制交叉熵损失和 KLD 损失的总和作为最终的损失值
    
    
    def sample_images(epoch):
        with torch.no_grad():  # 上下文管理器,确保在该上下文中不会进行梯度计算。因为在这里只是生成样本而不需要梯度
            number = 10
            # 生成标签
            sample_labels = torch.arange(10).long().to(device)  # 0-9的标签
            sample_labels_onehot = F.one_hot(sample_labels, num_classes=10).float()
            # 生成随机噪声
            sample = torch.randn(number, latent_size).to(device)  # 生成一个形状为 (64, latent_size) 的张量,其中包含从标准正态分布中采样的随机数
            sample = torch.cat([sample, sample_labels_onehot], dim=1)  # 连接图片和标签
    
            sample = model.decode(sample).cpu()  # 将随机样本输入到解码器中,解码器将其映射为图像
            save_image(sample.view(number, 1, 28, 28), f'sample{epoch}.png', nrow=int(number / 2))  # 将生成的图像保存为文件
    
    
    if __name__ == '__main__':
        batch_size = 512  # 批次大小
        epochs = 30  # 学习周期
        sample_interval = 10  # 保存结果的周期
        learning_rate = 0.001  # 学习率
        input_size = 784  # 输入大小
        latent_size = 64  # 潜在变量大小
    
        # 载入 MNIST 数据集中的图片进行训练
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量
    
        train_dataset = torchvision.datasets.MNIST(
            root="~/torch_datasets", train=True, transform=transform, download=True
        )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True
    
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据
    
        # 在使用定义的 AE 类之前,有以下事情要做:
        # 配置要在哪个设备上运行
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
        # 建立 CVAE 模型并载入到 CPU 设备
        model = CVAE().to(device)
    
        # Adam 优化器,学习率
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
        # 训练
        for epoch in range(epochs):
            train_loss = 0
            for batch_idx, (data, labels) in enumerate(train_loader):
                data = data.to(device)  # 将输入数据移动到设备(GPU 或 CPU)上
                data = data.view(-1, input_size)  # 重塑维度
    
                labels = F.one_hot(labels, num_classes=10).float().to(device)  # 转换为独热编码
                # print(labels[1])
    
                optimizer.zero_grad()  # 进行反向传播之前,需要将优化器中的梯度清零,以避免梯度的累积
    
                # 重构图像 recon_batch、潜在变量的均值 mu 和对数方差 log_var
                recon_batch, mu, log_var = model(data, labels)
    
                loss = loss_function(recon_batch, data, mu, log_var)  # 计算损失
                loss.backward()  # 计算损失相对于模型参数的梯度
                train_loss += loss.item()
    
                optimizer.step()  # 更新模型参数
    
            train_loss = train_loss / len(train_loader)  # # 计算每个周期的训练损失
            print('Epoch [{}/{}], Loss: {:.3f}'.format(epoch + 1, epochs, train_loss))
    
            # 每10次保存图像
            if (epoch + 1) % sample_interval == 0:
                sample_images(epoch + 1)
    
            # 每训练10次保存模型
            if (epoch + 1) % sample_interval == 0:
                torch.save(model.state_dict(), f'vae{epoch + 1}.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1559735.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

9.图像中值腐蚀膨胀滤波的实现

1 简介 在第七章介绍了基于三种卷积前的图像填充方式,并生成了3X3的图像卷积模板,第八章运用这种卷积模板进行了均值滤波的FPGA实现与MATLAB实现,验证了卷积模板生成的正确性和均值滤波算法的MATLAB算法实现。   由于均值滤波、中值滤波、腐…

【QT+QGIS跨平台编译】054:【exiv2lib+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

点击查看专栏目录 文章目录 一、exiv2lib介绍二、文件下载三、文件分析四、pro文件五、编译实践一、exiv2lib介绍 exiv2lib 是一个用于处理图像元数据的开源 C++ 库。它可用于读取、编辑和写入图像文件中的 Exif 元数据(Exchangeable Image File Format,可交换图像文件格式)…

怎么打包出release.aar包

第一种 选择build variant 更改成release 第二钟 在gradle中选择相应任务来编译 选择assemble release如果没有这个选项,可能是你没有开启那个Task 收集的选项

机器学习——降维算法-奇异值分解(SVD)

机器学习——降维算法-奇异值分解(SVD) 在机器学习中,降维是一种常见的数据预处理技术,用于减少数据集中特征的数量,同时保留数据集的主要信息。奇异值分解(Singular Value Decomposition,简称…

为 Linux 中的 Docker 配置阿里云和网易云国内镜像加速下载中心

由于默认情况下,Docker 的镜像下载中心默认为国外的镜像中心,使用该镜像中心拉去镜像会十分缓慢,所以我们需要配置国内的 Docker 镜像下载中心,加速 Docker 镜像的拉取。Docker 的国内镜像下载中心常用的有:阿里云、网…

微信小程序(黑马优购:购物车页面)

1.渲染商品页面 <template><view><!-- 商品列表的标题区域 --><view class"cart-title"><!-- 左侧的图标 --><uni-icons type"shop" size"18"></uni-icons><!-- 右侧的文本 --><text class…

力扣 1143. 最长公共子序列

题目来源&#xff1a;https://leetcode.cn/problems/longest-common-subsequence/description/ C题解&#xff08;思路来源代码随想录&#xff09;&#xff1a;动态规划。 1. 确定dp数组&#xff08;dp table&#xff09;以及下标的含义 dp[i][j]&#xff1a;长度为[0, i - 1]…

Python之Opencv进阶教程(1):图片模糊

1、Opencv提供了多种模糊图片的方法 加载原始未经模糊处理的图片 import cv2 as cvimg cv.imread(../Resources/Photos/girl.jpg) cv.imshow(girl, img)1.1 平均值 关键代码 # Averaging 平均值 average cv.blur(img, (3, 3)) cv.imshow(Average Blur, average)实现效果 1.2…

备战蓝桥杯---贪心刷题1

话不多说&#xff0c;直接看题&#xff1a; 本质是一个数学题&#xff1a; 我们令xi<0表示反方向传递&#xff0c;易得我们就是求每一个xi的绝对值之和min,我们令平均值为a爸。 易得约束条件&#xff1a; x1-x2a1-a,x2-x3a2-a..... 解得x1x1-0,x2x1-((n-1)*a-a2-...an)。…

通过搜索引擎让大模型获取实时数据-实现类似 perplexity 的效果

文章目录 一、前言二、初衷三、实现方式四、总结 一、前言 汇报一下这周末的工作&#xff0c;主要是开发了一门课程&#xff1a;通过搜索引擎让大模型获取实时数据&#xff0c;第一次开发一门课程&#xff0c;难免会有很多不熟悉和做的不好的地方。 已经训练好的大模型有气数…

今天起,Windows可以一键召唤GPT-4了

ChatGPT狂飙160天&#xff0c;世界已经不是之前的样子。 新建了人工智能中文站https://ai.weoknow.com 每天给大家更新可用的国内可用chatGPT资源 发布在https://it.weoknow.com 更多资源欢迎关注 微软 AI 大计的最后一块拼图完成了&#xff1f; 把 Copilot 按钮放在 Window…

【Linux】权限的基本概念

在本篇博客中&#xff0c;作者将会讲解在linux系统中&#xff0c;权限的基本概念。 一.什么是权限 通俗的讲&#xff0c;权限是用来约束人的。比如说&#xff1a;你买了某软件的vip会员&#xff0c;那么你就可以执行相对操作&#xff0c;如果你没买&#xff0c;则就会有权限约束…

Linux的中间件

我们先补充点关于awk的内容 awk的用法其实很广。 $0 表示整条记录 变量&#xff1a; NF 一行中有多少个字段&#xff08;表示字段数&#xff09; NR &#xff1a; 代表当前记录的序号&#xff0c;从1开始计数。每读取一条记录&#xff0c;NR的值就会自动增加1。&#xff08;…

基于ssm旅游资源网站(java项目+文档+源码)

风定落花生&#xff0c;歌声逐流水&#xff0c;大家好我是风歌&#xff0c;混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的旅游资源网站。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 旅游资源网站的主要使用者分为管理…

稀碎从零算法笔记Day35-LeetCode:字典序的第K小数字

要考虑完结《稀碎从零》系列了哈哈哈 这道题和【LC.42 接雨水】&#xff0c;我愿称之为【笔试界的颜良&文丑】 题型&#xff1a;字典树、前缀获取、数组、树的先序遍历 链接&#xff1a;440. 字典序的第K小数字 - 力扣&#xff08;LeetCode&#xff09; 来源&#xff1…

Pytorch 下载失败原因

错误信息&#xff1a; ERROR: Could not find a version that satisfies the requirement torch (from versions: none) ERROR: No matching distribution found for torch 解决方案&#xff1a; 在官网看到&#xff0c;它需要python3.8-3.11的环境。过高和过低的版本都不…

番外篇 | 手把手教你如何用YOLOv8实现行人/车辆等过线统计

前言:Hello大家好,我是小哥谈。目标检测行人/车辆等过线统计是一种常见的视频分析任务,用于统计行人/车辆等在指定区域内过线的次数。这个任务通常需要使用目标检测算法来识别行人/车辆等,并使用计数器算法来统计过线的次数。🌈 目录 🚀1.本文介绍 🚀2.实现

LeetCode刷题【链表,图论,回溯】

目录 链表138. 随机链表的复制148. 排序链表146. LRU 缓存 图论200. 岛屿数量994. 腐烂的橘子207. 课程表 回溯 链表 138. 随机链表的复制 给你一个长度为 n 的链表&#xff0c;每个节点包含一个额外增加的随机指针 random &#xff0c;该指针可以指向链表中的任何节点或空节…

2024年泰迪杯数据挖掘B题详细思路代码文章教程

目前b题已全部更新包含详细的代码模型和文章&#xff0c;本文也给出了结果展示和使用模型说明。 同时文章最下方包含详细的视频教学获取方式&#xff0c;手把手保姆级&#xff0c;模型高精度&#xff0c;结果有保障&#xff01; 分析&#xff1a; 本题待解决问题 目标&#…

K8S之Secret的介绍和使用

Secret Secret的介绍Secret的使用通过环境变量引入Secret通过volume挂载Secret Secret的介绍 Secret是一种保护敏感数据的资源对象。例如&#xff1a;密码、token、秘钥等&#xff0c;而不需要把这些敏感数据暴露到镜像或者Pod Spec中。Secret可以以Volume或者环境变量的方式使…