AE——重构数字(Pytorch+mnist)

news2024/11/25 0:21:18

1、简介

  • AE(自编码器)由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。
  • AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽可能接近。
  • AE通常用于数据压缩、去噪、特征提取等任务。
  • 本文利用AE,输入数字图像。训练后,输入测试数字图像,重构生成新的数字图像。
    • 【注】本文案例需要输入才能生成输出,目标是重构,而不是生成。
  • 可以看出,重构图片和原始图片差别不大。 
  • 【注】输出的10张数字图像是输入的测试图像的第一批次。

2、代码

  • import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    
    
    # 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
    class AE(nn.Module):
        def __init__(self, **kwargs):
            super().__init__()
            self.encoder_hidden_layer = nn.Linear(
                in_features=kwargs["input_shape"], out_features=128
            )  # 编码器隐藏层
            self.encoder_output_layer = nn.Linear(
                in_features=128, out_features=128
            )  # 编码器输出层
            self.decoder_hidden_layer = nn.Linear(
                in_features=128, out_features=128
            )  # 解码器隐藏层
            self.decoder_output_layer = nn.Linear(
                in_features=128, out_features=kwargs["input_shape"]
            )  # 解码器输出层
    
        # 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成
        def forward(self, features):
            activation = self.encoder_hidden_layer(features)
            activation = torch.relu(activation)  # ReLU 激活函数,得到编码器的激活值
            code = self.encoder_output_layer(activation)
            code = torch.sigmoid(code)  # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内
            activation = self.decoder_hidden_layer(code)
            activation = torch.relu(activation)
            activation = self.decoder_output_layer(activation)
            reconstructed = torch.sigmoid(activation)
            return reconstructed
    
    
    if __name__ == '__main__':
        # 设置批大小、学习周期和学习率
        batch_size = 512
        epochs = 30
        learning_rate = 1e-3
    
        # 载入 MNIST 数据集中的图片进行训练
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量
    
        train_dataset = torchvision.datasets.MNIST(
            root="~/torch_datasets", train=True, transform=transform, download=True
        )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True
    
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据
    
        # 在使用定义的 AE 类之前,有以下事情要做:
        # 配置要在哪个设备上运行
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
        # 建立 AE 模型并载入到 CPU 设备
        model = AE(input_shape=784).to(device)
    
        # Adam 优化器,学习率 10e-3
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
        # 使用均方误差(MSE)损失函数
        criterion = nn.MSELoss()
    
        # 在GPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数
        # 训练:
        for epoch in range(epochs):
            loss = 0
            for batch_features, _ in train_loader:
                # 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备
                batch_features = batch_features.view(-1, 784).to(device)
    
                # 梯度设置为 0,因为 torch 会累加梯度
                optimizer.zero_grad()
    
                # 计算重构
                outputs = model(batch_features)
    
                # 计算训练重建损失
                train_loss = criterion(outputs, batch_features)
    
                # 计算累积梯度
                train_loss.backward()
    
                # 根据当前梯度更新参数
                optimizer.step()
    
                # 将小批量训练损失加到周期损失中
                loss += train_loss.item()
    
            # 计算每个周期的训练损失
            loss = loss / len(train_loader)
    
            # 显示每个周期的训练损失
            print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))
    
        # 用训练过的自编码器提取一些测试用例来重构
        test_dataset = torchvision.datasets.MNIST(
            root="~/torch_datasets", train=False, transform=transform, download=True
        )  # 加载 MNIST 测试数据集
    
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=10, shuffle=False
        )  # 创建一个测试数据加载器
    
        test_examples = None
    
        # 通过循环遍历测试数据加载器,获取一个批次的图像数据
        with torch.no_grad():  # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算
            for batch_features in test_loader:  # 历测试数据加载器中的每个批次的图像数据
                batch_features = batch_features[0]  # 获取当前批次的图像数据
                test_examples = batch_features.view(-1, 784).to(
                    device)  # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上
                reconstruction = model(test_examples)  # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像
                break
    
        # 试着用训练过的自编码器重建一些测试图像
        with torch.no_grad():
            number = 10  # 设置要显示的图像数量
            plt.figure(figsize=(20, 4))  # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)
            for index in range(number):  # 遍历要显示的图像数量
                # 显示原始图
                ax = plt.subplot(2, number, index + 1)
                plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
                plt.gray()
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
    
                # 显示重构图
                ax = plt.subplot(2, number, index + 1 + number)
                plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
                plt.gray()
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            plt.savefig('reconstruction_results.png')  # 保存图像
            plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是nginx正向代理和反向代理?

什么是代理? 代理(Proxy), 简单理解就是自己做不了的事情或实现不了的功能,委托别人去做。 什么是正向代理? 在nginx中,正向代理指委托者是客户端,即被代理的对象是客户端 在这幅图中,由于左边内网中…

如何解决kafka rebalance导致的暂时性不能消费数据问题

文章目录 背景思考答案排它故障转移共享 背景 之前在review同组其它业务的时候,发现竟然把kafka去掉了,问了下原因,有一个单独的服务,我们可以把它称为agent,就是这个服务是动态扩缩容的,会采集一些指标&a…

k8s的pod访问service的方式

背景 在k8s中容器访问某个service服务时有两种方式,一种是把每个要访问的service的ip注入到客户端pod的环境变量中,另一种是客户端pod先通过DNS服务器查找对应service的ip地址,然后在通过这个service ip地址访问对应的service服务 pod客户端…

HarmonyOS 应用开发之FA模型访问Stage模型DataShareExtensionAbility

概述 无论FA模型还是Stage模型,数据读写功能都包含客户端和服务端两部分。 FA模型中,客户端是由DataAbilityHelper提供对外接口,服务端是由DataAbility提供数据库的读写服务。 Stage模型中,客户端是由DataShareHelper提供对外接…

腾讯云2核2G服务器优惠价格,61元一年

腾讯云2核2G服务器多少钱一年?轻量服务器61元一年,CVM 2核2G S5服务器313.2元15个月,轻量2核2G3M带宽、40系统盘,云服务器CVM S5实例是2核2G、50G系统盘。腾讯云2核2G服务器优惠活动 txybk.com/go/txy 链接打开如下图:…

java数组与集合框架(三)--Map,Hashtable,HashMap,LinkedHashMap,TreeMap

Map集合: Map接口: 基于 键(key)/值(value)映射 Map接口概述 Map与Collection并列存在。用于保存具有映射关系的数据:key-value Map 中的key 和value 都可以是任何引用类型的数据Map 中的key 用Set来存放&#xff0…

X进制减法(蓝桥杯)

文章目录 X进制减法题目描述解题思路贪心算法模拟减法(大数相减) X进制减法 题目描述 进制规定了数字在数位上逢几进一。 X 进制是一种很神奇的进制,因为其每一数位的进制并不固定!例如说某种 X 进制数,最低数位为二…

创建Qt Quick Projects

在创建Qt Quick项目之前,我们简单说一下Qml和Qt Quick的关系:它们的关系类似于C和STL标准库的关系,Qml类比C语言,提供了基本语言特性和类型;而Qt Quick则类比STL标准库,Qt Quick在QML的基础上加入了一系列界…

Https【Linux网络编程】

目录 一、为什么需要https 二、常见加密方法 1、对称加密 2、非对称加密 3、数据指纹 三、选择什么加密方案? 方案一:对称加密() 方案二:双方使用非对称加密(效率低) 方案三&#xff1a…

深度学习十大算法之Diffusion扩散模型

1. 引言 扩散模型在近年来成为了热门话题,其火速蹿红主要归功于在图像生成领域的突破应用。尤其是一些从文本到图像的生成技术,它们成功地运用了扩散模型来创建令人惊叹的逼真图像。如果你听说过某个应用能够迅速且高质量地生成图像,那么很可…

【SpringBoot整合系列】SpirngBoot整合EasyExcel

目录 背景需求发展 EasyExcel官网介绍优势常用注解 SpringBoot整合EaxyExcel1.引入依赖2.实体类定义实体类代码示例注解解释 3.自定义转换器转换器代码示例涉及的枚举类型 4.Excel工具类5.简单导出接口SQL 6.简单导入接口SQL 7.复杂的导出(合并行、合并列&#xff0…

docker 共享网络的方式实现容器互联

docker 共享网络的方式实现容器互联 本文以nacos连接mysql为例 前提已经在mysql容器中初始化好nacos数据库,库名nacos 创建一个共享网络 docker network create --driver bridge \ --subnt 192.168.0.0/24 \ --gateway 192.168.0.1 mynet此处可以不指定网络模式、…

Python下载bing每日壁纸并实现win11 壁纸自动切换

前言: 爬虫哪家强,当然是python 我是属于啥语言都用,都懂点,不精通,实际工作中能能够顶上就可以。去年写的抓取bing每日的壁纸,保存到本地,并上传到阿里云oss,如果只是本地壁纸切换,存下来就行,一直想做个壁纸站点&…

Java代码基础算法练习-自定义函数之字符串连接-2024.03.30

任务描述: 写一函数,将两个字符串连接起来,然后在主函数中调用该函数实现字符串连接操作。 任务要求: 代码示例: package M0317_0331;import java.util.Scanner;public class m240330 {public static void main(Stri…

【Java】MyBatis快速入门及详解

文章目录 1. MyBatis概述2. MyBatis快速入门2.1 创建项目2.2 添加依赖2.3 数据准备2.4 编写代码2.4.1 编写核心配置文件2.4.2 编写SQL映射文件2.4.3 编写Java代码 3. Mapper代理开发4. MyBatis核心配置文件5. 案例练习5.1 数据准备5.2 查询数据5.2.1 查询所有数据5.2.2 查询单条…

全国青少年软件编程(Python)等级考试三级考试真题2023年12月——持续更新.....

青少年软件编程(Python)等级考试试卷(三级) 分数:100 题数:38 一、单选题(共25题,共50分) 1.一个非零的二进制正整数,在其末尾添加两个“0”,则该新数将是原数的&#xf…

Redis从入门到精通(二)Redis的数据类型和常见命令介绍

文章目录 前言第2章 Redis数据类型和常见命令2.1 key结构2.2 Redis通用命令2.3 String类型及其常用命令2.4 Hash类型及其常用命令2.5 List类型2.5 Set类型2.6 SortedSet类型2.7 小结 前言 在上一节【Redis从入门到精通(一)Redis安装与启动、Redis客户端的使用】中,…

【智能算法】猎人猎物算法(HPO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2022年,Naruei等人受到自然界动物猎食过程启发,提出了猎人猎物算法(Hunter-Prey Optimization, HPO)。 2.算法原理 2.1算法思想 HPO模拟…

国产AI大模型推荐(一)

文心一言 主要功能: 各种类型的问答、各种文本创作、推理与数学计算、写代码、聊天交流、图片生成等。 链接:文心一言 讯飞星火 特点: 内容生成能力:我可以进行多风格多任务长文本生成,例如邮件、文案、公文、作文、对…

剑指Offer题目笔记23(归并排序)

面试题77: 问题: ​ 输入一个链表的头节点,将该链表排序。 解决方案: ​ 使用归并排序。将链表分为两个子链表,在对两个子链表排序后再将它们合并为一个排序的链表。 源代码: /*** Definition for sin…