Tensorflow2.0笔记 - AutoEncoder做FashionMnist数据集训练

news2024/10/6 12:34:54

        本笔记记录自编码器做FashionMnist数据集训练,关于autoencoder的原理,请自行百度。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input,losses
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.models import Model

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#加载fashion mnist数据集
(x_train, _), (x_test, _) = datasets.fashion_mnist.load_data()
#图片像素数据范围限值到[0,1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)

h_dim = 64 
class Autoencoder(Model):
  def __init__(self, h_dim):
    super(Autoencoder, self).__init__()
    self.h_dim = h_dim   
    #encoder层,[b, 28, 28] => [b, 784] => [b, h_dim]
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(h_dim, activation='relu'),
    ])
    #decoder层,[b, h_dim] => [b,784] => [b, 28, 28]
    self.decoder = tf.keras.Sequential([
      layers.Dense(784, activation='sigmoid'),
      #恢复成28x28的图片
      layers.Reshape((28, 28))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = Autoencoder(h_dim)

model.compile(optimizer='adam', loss=losses.MeanSquaredError())
model.fit(x_train, x_train,
                epochs=10,
                shuffle=True,
                validation_data=(x_test, x_test))


encoded_imgs = model.encoder(x_test).numpy()
decoded_imgs = model.decoder(encoded_imgs).numpy()
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  #绘制原始图像
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test[i])
  plt.title("original")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  #绘制重建的图像
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs[i])
  plt.title("reconstructed")
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1695495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

17.7K星开源产品分析平台:Posthog

Posthog:开源洞察,产品优化的得力助手 - 精选真开源,释放新价值。 概览 PostHog是一个全面开源的平台,旨在帮助团队构建更好的产品。它提供了从产品分析到会话回放、功能标志和A/B测试等一系列工具,支持自托管&#x…

Keras深度学习框架第二十讲:使用KerasCV中的Stable Diffusion进行高性能图像生成

1、绪论 1.1 概念 为便于后文讨论,首先进行相关概念的陈述。 Stable Diffusion:Stable Diffusion 是一个在图像生成领域广泛使用的技术,尤其是用于文本到图像的转换。它基于扩散模型(Diffusion Models),这…

Springboot017学生读书笔记共享

springboot005学生心理咨询评估系统 亲测完美运行带论文:获取源码,私信评论或者v:niliuapp 运行视频 Springboot017学生读书笔记共享 包含的文件列表(含论文) 数据库脚本:db.sql其他文件:ppt.ppt论文&am…

Java+Swing+Mysql实现飞机订票系统

一、系统介绍 1.开发环境 操作系统:Win10 开发工具 :Eclipse2021 JDK版本:jdk1.8 数据库:Mysql8.0 2.技术选型 JavaSwingMysql 3.功能模块 4.数据库设计 1.用户表(users) 字段名称 类型 记录内容…

【Linux】Linux的安装

文章目录 一、Linux环境的安装虚拟机 镜像文件云服务器(可能需要花钱) 未完待续 一、Linux环境的安装 我们往后的学习用的Linux版本为——CentOs 7 ,使用 Ubuntu 也可以 。这里提供几个安装方法: 电脑安装双系统(不…

OpenWrt 安装Quagga 支持ospf Bgp等动态路由协议 软路由实测 系列四

1 Quagga 是一个路由软件套件, 提供 OSPFv2,OSPFv3,RIP v1 和 v2,RIPng 和 BGP-4 的实现. 2 web 登录安装 #或者ssh登录安装 opkg install quagga quagga-zebra quagga-bgpd quagga-watchquagga quagga-vtysh # reboot 3 ssh 登录 #重启服务 /etc/init.d/quagga restart #…

Linux下Vision Mamba环境配置+多CUDA版本切换

上篇文章大致讲了下Vision Mamba的相关知识,网上关于Vision Mamba的配置博客太多,笔者主要用来整合下。 笔者在Win10和Linux下分别尝试配置相关环境。 Win10下配置 失败 \textcolor{red}{失败} 失败,最后出现的问题如下: https://…

ps进程查看命令详解

1、PS 命令是什么 查看它的man手册可以看到,ps命令能够给出当前系统中进程的快照。它能捕获系统在某一事件的进程状态。如果你想不断更新查看的这个状态,可以使用top命令。 2、ps命令支持三种使用的语法格式 UNIX 风格,选项可以组合在一起…

flutter 实现旋转星球

先看效果 planet_widget.dart import dart:math; import package:flutter/material.dart; import package:vector_math/vector_math_64.dart show Vector3; import package:flutter/gestures.dart; import package:flutter/physics.dart;class PlanetWidget extends StatefulW…

Android14 - 绘制系统 - 概览

从Android 12开始,Android的绘制系统有结构性变化, 在绘制的生产消费者模式中,新增BLASTBufferQueue,客户端进程自行进行queue的生产和消费,随后通过Transation提交到SurfaceFlinger,如此可以使得各进程将缓…

Altium Designer 中键拖动,滚轮缩放,并修改缩放速度

我的版本是AD19,其他版本应该都一样。 滚轮缩放 首先,要用滚轮缩放,先要调整一下AD 设置,打开Preferences,在Mouse Wheel Configuration 里,把Zoom Main Window 后面Ctrl 上的对勾取消掉,再把…

翻译《The Old New Thing》- What‘s the deal with the EM_SETHILITE message?

Whats the deal with the EM_SETHILITE message? - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20071025-00/?p24693 Raymond Chen 2007年10月25日 简要 文章讨论了EM_SETHILITE和EM_GETHILITE消息在文档中显示为“未实现”的原因。这些…

C语言 | Leetcode C语言题解之第112题路径总和

题目: 题解: bool hasPathSum(struct TreeNode *root, int sum) {if (root NULL) {return false;}if (root->left NULL && root->right NULL) {return sum root->val;}return hasPathSum(root->left, sum - root->val) ||ha…

与用户沟通获取需求的方法

1 访谈 访谈是最早开始使用的获取用户需求的技术,也是迄今为止仍然广泛使用的需求分析技术。 访谈有两种基本形式,分别是正式的和非正式的访谈。正式访谈时,系统分析员将提出一些事先准备好的具体问题,例如&#xff0…

C++_string简单源码剖析:模拟实现string

文章目录 🚀1.构造与析构函数🚀2.迭代器🚀3.获取🚀 4.内存修改🚀5. 插入🚀6. 删除🚀7. 查找🚀8. 交换swap🚀9. 截取substr🚀10. 比较符号重载🚀11…

pod install 报错 ‘SDK does not contain ‘libarclite‘ at the path...‘

报错内容: SDK does not contain ‘libarclite’ at the path ‘/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/arc/libarclite_iphoneos.a’; 这是报错已经很明确告诉我们,Xcode默认的工具链中缺少一个工具…

Linux Tcpdump抓包入门

Linux Tcpdump抓包入门 一、Tcpdump简介 tcpdump 是一个在Linux系统上用于网络分析和抓包的强大工具。它能够捕获网络数据包并提供详细的分析信息,有助于网络管理员和开发人员诊断网络问题和监控网络流量。 安装部署 # 在Debian/Ubuntu上安装 sudo apt-get install…

爬虫100个Python例子优化

今天看到一个Python 100例的在线资源,感觉每个都需要去点,太费时间了,于是,使用Python将数据爬取下来,方便查看。实际效果如下: 。。。。。。 用了13分钟,当然,这是优化后的效果,如果没有优化,需要的时间更长。 爬取url如下: https://www.runoob.com/python/pytho…

Pytorch深度学习实践笔记1

🎬个人简介:一个全栈工程师的升级之路! 📋个人专栏:pytorch深度学习 🎀CSDN主页 发狂的小花 🌄人生秘诀:学习的本质就是极致重复! 《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibi…

【数据结构与算法 刷题系列】移除链表元素

💓 博客主页:倔强的石头的CSDN主页 📝Gitee主页:倔强的石头的gitee主页 ⏩ 文章专栏:数据结构与算法刷题系列(C语言) 期待您的关注 目录 一、问题描述 二、解题思路 三、源代码实现 一、问题…