【深度学习】在 MNIST实现自动编码器实践教程

news2025/1/15 20:09:44

一、说明

        自动编码器是一种无监督学习的神经网络模型,主要用于降维或特征提取。常见的自动编码器包括基本的单层自动编码器、深度自动编码器、卷积自动编码器和变分自动编码器等。

        其中,基本的单层自动编码器由一个编码器和一个解码器组成,编码器将输入数据压缩成低维数据,解码器将低维数据还原成原始数据。深度自动编码器是在单层自动编码器的基础上增加了多个隐藏层,可以实现更复杂的特征提取。卷积自动编码器则是针对图像等数据特征提取的一种自动编码器,它使用卷积神经网络进行特征提取和重建。变分自动编码器则是一种生成式模型,可以用于生成新的数据样本。

        总的来说,不同类型的自动编码器适用于不同类型的数据和问题,选择合适的自动编码器可以提高模型的性能。

二、在Minist数据集实现自动编码器

2.1 概述

        本文中的代码用于在 MNIST 数据集上训练自动编码器。自动编码器是一种旨在重建其输入的神经网络。在此脚本中,自动编码器由两个较小的网络组成:编码器和解码器。编码器获取输入图像,将其压缩为 64 个特征,并将编码表示传递给解码器,然后解码器重建输入图像。自动编码器通过最小化重建图像和原始图像之间的均方误差来训练。该脚本首先加载 MNIST 数据集并规范化像素值。然后,它将图像重塑为一维表示,以便可以将其输入神经网络。之后,使用tensorflow.keras库中的输入层和密集层创建编码器和解码器模型。自动编码器模型是通过链接编码器和解码器模型创建的。然后使用亚当优化器和均方误差损失函数编译自动编码器。最后,自动编码器在归一化和重塑的MNIST图像上训练25个epoch。通过绘制训练集和测试集在 epoch 上的损失来监控训练进度。训练后,脚本绘制一些测试图像及其相应的重建。此外,还计算了原始图像和重建图像之间的均方误差和结构相似性指数(SSIM)。

        下图显示了模型的良好拟合,可以看到模型的良好拟合。

训练和测试数据的模型丢失

        该代码比较两个图像,一个来自测试集的原始图像和一个由自动编码器生成的预测图像。它使用该函数计算两个图像之间的均方误差 (MSE),并使用 scikit-image 库中的函数计算两个图像之间的结构相似性指数 (SSIM)。根据 mse 和 ssim 代码检索test_labels以打印测试图像的值。msessim

2.2 代码实现

import numpy as np
import tensorflow
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.layers import Layer 
from skimage import metrics
## import os can be skipped if there is nocompatibility issue 
## with the OpenMP library and TensorFlow 
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


# Load the MNIST dataset
(x_train, train_labels), (x_test, test_labels) = mnist.load_data()

# Normalize the data
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.


# Flatten the images
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

 
# Randomize both the training and test
permutation = np.random.permutation(len(x_train))
x_train, train_labels = x_train[permutation], train_labels[permutation]
permutation = np.random.permutation(len(x_test))
x_test, test_labels = x_test[permutation], test_labels[permutation]
# Create the encoder


list_xtest = [ [x_test[i], test_labels[i]] for i in test_labels] 
print(len(list_xtest)) 

encoder_input = Input(shape=(784,))
encoded = Dense(64, activation='relu')(encoder_input)
encoder = Model(encoder_input, encoded)

# Create the decoder
decoder_input = Input(shape=(64,))
decoded = Dense(784, activation='sigmoid')(decoder_input)
decoder = Model(decoder_input, decoded)

# Create the autoencoder
autoencoder = Model(encoder_input, decoder(encoder(encoder_input)))

lr_schedule = tensorflow.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate = 5e-01, decay_steps = 2500, decay_rate = 0.75,staircase=True) 
tensorflow.keras.optimizers.Adam(learning_rate = lr_schedule,beta_1=0.95,beta_2=0.99,epsilon=1e-01)
autoencoder.compile(optimizer='adam', loss='mean_squared_error')


# Train the autoencoder
history = autoencoder.fit(x_train, x_train,
                epochs=25,
                batch_size=512,
                shuffle=True,
                validation_data=(x_test, x_test))

# Plot the training history
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper right')
plt.show()

# Plot the test figures vs. predicted figures
decoded_imgs = autoencoder.predict(x_test)


def mse(imageA, imageB):
    err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
    err /= float(imageA.shape[0])
    return err

def ssim(imageA, imageB):
    return metrics.structural_similarity(imageA, imageB,channel_axis=None)

decomser = [] 
decossimr = [] 
n = 10
list_xtestn = [ [x_test[i], test_labels[i]] for i in range(10)] 
print([list_xtestn[i][1] for i in range(n)]) 
plt.figure(figsize=(20, 4))
for i in range(n):
    # Display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    
    if mse(list_xtestn[i][0],decoded_imgs[i]) <= 0.01: 
        msel = mse(list_xtestn[i][0],decoded_imgs[i])
        decomser.append(list_xtestn[i][1])  
    if ssim(list_xtestn[i][0],decoded_imgs[i]) > 0.85:
        ssiml = ssim(list_xtestn[i][0],decoded_imgs[i])
        decossimr.append(list_xtestn[i][1])   
    print("mse and ssim for image %s are %s and %s" %(i,msel,ssiml)) 
plt.show() 

print(decomser)
print(decossimr)

三、实验的部分结果示例 

        该模型可以预测手写数据,如下所示。

原始数据和预测数据

        此外,使用MSE和ssim方法将预测图像与测试图像进行比较,可以访问test_labels并打印预测数据。

预测和测试图像的 MSE 和 SSM 值,以及 SSE 和 SSIM 方法test_labels返回的数字列表

        此代码演示如何使用自动编码器通过图像比较教程来训练和建立手写识别网络。一开始,训练和测试图像是随机的,因此每次运行的图像集都不同。

        在另一篇文章中,我们将展示如何使用 Padé 近似值作为自动编码器 (link.medium.com/cqiP5bd9ixb) 的激活函数。

引用:

  1. 原始的MNIST数据集:LeCun,Y.,Cortes,C.和Burges,C.J.(2010)。MNIST手写数字数据库。AT&T 实验室 [在线]。可用: http://yann。莱昆。com/exdb/mnist/
  2. 自动编码器概念和应用:Hinton,G.E.和Salakhutdinov,R.R.(2006)。使用神经网络降低数据的维数。科学, 313(5786), 504–507.
  3. 使用自动编码器进行图像重建:Masci,J.,Meier,U.,Cireşan,D.和Schmidhuber,J.(2011年52月)。用于分层特征提取的堆叠卷积自动编码器。在人工神经网络国际会议(第 59-<> 页)中。施普林格,柏林,海德堡。
  4. The tensorflow.keras library: Chollet, F. (2018).使用 Python 进行深度学习。纽约州谢尔特岛:曼宁出版公司
  5. 均方误差损失函数和亚当优化器:Kingma,D.P.和Ba,J.(2014)。Adam:一种随机优化的方法。arXiv预印本arXiv:1412.6980。
  6. 结构相似性指数(SSIM):Wang,Z.,Bovik,A.C.,Sheikh,H.R.和Simoncelli,E.P.(2004)。图像质量评估:从错误可见性到结构相似性。IEEE图像处理事务,13(4),600-612。
  7. 弗朗西斯·贝尼斯坦特

    ·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/830112.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OLED透明屏安装指南:准备工作、步骤和注意事项

随着科技的不断发展&#xff0c;OLED透明屏作为一种新型的显示技术&#xff0c;逐渐得到了广泛的应用。 OLED透明屏具有高透明度、高亮度和广视角等优势&#xff0c;可以实现透明显示效果&#xff0c;为商业展示、户外广告等领域提供了更广阔的空间。 然而&#xff0c;正确的…

Qt实现可伸缩的侧边工具栏(鼠标悬浮控制伸缩栏)

Qt实现可伸缩的侧边工具栏 一直在网上找&#xff0c;发现大多的实现方案都是用一个按钮&#xff0c;按下控制侧边栏的伸缩&#xff0c;但是我想要实现鼠标悬浮在侧边栏的时候就伸出&#xff0c;移开就收缩的功能&#xff0c;也没找到好的参考&#xff0c;所以决定自己实现一个…

Apache Kafka Learning

一、Kafka Kafka是由Apache软件基金会开发的一个开源流处理平台&#xff0c;由Scala和Java编写。Kafka是一种高吞吐量的分布式发布订阅消息系统&#xff0c;它可以收集并处理用户在网站中的所有动作流数据以及物联网设备的采样信息。 Apache Kafka是Apache软件基金会的开源的流…

Quartz使用文档,使用Quartz实现动态任务,Spring集成Quartz,Quartz集群部署,Quartz源码分析

文章目录 一、Quartz 基本介绍二、Quartz Java 编程1、文档2、引入依赖3、入门案例4、默认配置文件 三、Quartz 重要组件1、Quartz架构体系2、JobDetail3、Trigger&#xff08;1&#xff09;代码实例&#xff08;2&#xff09;SimpleTrigger&#xff08;3&#xff09;CalendarI…

低代码开发工具到底是给“谁”用的?

不同的工具&#xff0c;受众也不一样。 你不要认为“低代码开发工具”只有一种&#xff0c;实际上它分 3 种。 第一种&#xff1a;企业级低代码开发平台 这种通常是给专业开发人员使用的&#xff0c;但也没有限制得很死&#xff0c;只要你懂编程逻辑&#xff0c;能写sql语句&…

[数据分析与可视化] Python绘制数据地图4-MovingPandas入门指北

MovingPandas是一个基于Python和GeoPandas的开源地理时空数据处理库&#xff0c;用于处理移动物体的轨迹数据。它提供了一组强大的工具&#xff0c;可以轻松地加载、分析和可视化移动物体的轨迹。通过使用MovingPandas&#xff0c;用户可以轻松地处理和分析移动对象数据&#x…

微信云开发-数据库操作

文章目录 前提初始化数据库插入数据查询数据获取一条数据获取多条数据查询指令 更新数据更新指令 删除数据总结 前提 首先有1个集合(名称:todos). 其中集合中的数据为: {// 计划描述"description": "learn mini-program cloud service",// 截止日期"…

阿里云OSS的开通+配置及其使用

云存储解决方案-阿里云OSS 文章目录 云存储解决方案-阿里云OSS1. 阿里云OSS简介2. OSS开通&#xff08;1&#xff09;打开https://www.aliyun.com/ &#xff0c;申请阿里云账号并完成实名认证。&#xff08;2&#xff09;充值 (可以不用做)&#xff08;3&#xff09;开通OSS&am…

小程序云开发快速入门(2/4)

前言 我们对《微信小程序云开发快速入门&#xff08;1/4&#xff09;》的知识进行回顾一下。在上章节我们知道了云开发的优势以及能力&#xff0c;并且我们还完成了码仔备忘录的本地版到网络版的改造&#xff0c;主要学习了云数据库同时还通过在小程序使用云API直接操作了云数…

选读SQL经典实例笔记16_逻辑否定

1. 示例数据 1.1. student insert into student values (1,AARON,20) insert into student values (2,CHUCK,21) insert into student values (3,DOUG,20) insert into student values (4,MAGGIE,19) insert into student values (5,STEVE,22) insert into student values (6…

Java内存溢出的排查工具和方法

JVM内存溢出事故回顾 JVM内存溢出的排查方法个工具介绍 事故回顾 • 9:58收到报警&#xff0c;资讯延时1小时。 • 10:10排查出接口全部超时&#xff0c;超时时间2s。 • 去运维那边执行jstat发现元空间沾满了&#xff0c;疯狂fgc。 • 执行jmap -dump 并下载。 • 使用MAT分…

VLAN原理+配置

目录 一&#xff0c; 以太网二层交换机 二&#xff0c;三层架构&#xff1a; 三&#xff0c;VLAN配置思路 1.创建vlan 2.接口划入vlan 3.trunk干道 4.vlan间路由器 5.DHCP池塘配置 四&#xff0c;华为VLAN部分的接口模式讲解&#xff1a; 五&#xff0c;华为VLAN部分的…

【雕爷学编程】MicroPython动手做(30)——物联网之Blynk 2

知识点&#xff1a;什么是掌控板&#xff1f; 掌控板是一块普及STEAM创客教育、人工智能教育、机器人编程教育的开源智能硬件。它集成ESP-32高性能双核芯片&#xff0c;支持WiFi和蓝牙双模通信&#xff0c;可作为物联网节点&#xff0c;实现物联网应用。同时掌控板上集成了OLED…

阿里云出品—高分计算机好书推荐榜

1、云原生架构白皮书 云原生是一种构建和运行应用程序的方法&#xff0c;它能实现构建应用简便快捷&#xff0c;部署应用轻松自如&#xff0c;越来越多公司和个人选择使用云原生技术。《云原生架构白皮书》作为业界首本全方位构建云原生架构规划与实践全景图的白皮书&#xff…

【牛客】统计字符

⭐️ 题目描述 &#x1f31f; OJ链接&#xff1a;HJ40 统计字符 ps&#xff1a; 判断字符可以直接使用头文件自带的函数。 函数作用iscntrl判断是否为控制字符isspace判断是否为空白字符&#xff08;空格、换页’\f’、换行’\n’、回车’\r’、制表符’\t&#xff09;isdigi…

「应用实时监控 ARMS 」斩获「根因分析技术」先进级认证

阿里云云原生可观测 ARMS 率先斩获「根因分析技术」先进级认证 7 月 25 日&#xff0c;由中国信通院发起的“2023 可信云-系统稳定性”首批评估结果在可信云大会现场公布&#xff0c;应用实时监控服务 ARMS 斩获《可观测性标准体系要求 - 根因分析技术分级能力要求》“先进级”…

Pytorch深度学习之余弦退火学习率设置

1. 什么是余弦退火学习率&#xff1f; 余弦退火学习速率调度是改进深度神经网络学习过程的常用方法。当深度神经网络在大型数据集上训练时&#xff0c;它尤其有用&#xff0c;因为在大型数据集中&#xff0c;学习过程可能会陷入局部极小值。在训练过程中&#xff0c;学习率以不…

OpenMMLab【超级视客营】——把类别信息加入可视化结果中(MMSegmentation的第二个PR)

文章目录 1. 任务说明1.0 新手指引1.1 任务目标1.2 提交格式 2. 实施2.1 可视化的形式2.2 拉分支和提交PR2.2.1 拉分支2.2.2 提交PR 2.3 MMSegmentation中关于可视化的内容2.3.1 文档说明2.3.2 相关PR&#xff08;确定要修改的文件&#xff09;2.3.3 提交时的代码测试 2.4 发现…

java实现5种不同的验证码图片,包括中文、算式等,并返回前端

导入以下依赖 <!--图片验证码--><dependency><groupId>com.github.whvcse</groupId><artifactId>easy-captcha</artifactId><version>1.6.2</version></dependency> 编写controller package com.anXin.user.controlle…

Tessy 4.3.18

Tessy 4.3.18 windows 2692407267qq.com&#xff0c;更多内容请见http://user.qzone.qq.com/2692407267/