tensorflow fashion_mnist数据集模型训练及预测

news2024/11/14 2:49:58

✨ 博客主页:小小马车夫的主页
✨ 所属专栏:Tensorflow

文章目录

  • 前言
  • 一、环境
  • 二、fashion_mnist数据集介绍
  • 三、fashion_mnist数据集下载和展示
  • 四、数据预处理
  • 五、构建模型和训练模型
  • 六、模型预测
  • 总结


前言

前面介绍mnist手写数字集训练,本文对fashion_mnist数据集训练和预测进行简要介绍。


一、环境

MacOS: 13.0
Python: 3.9.13
Tensorflow: 2.11.0

二、fashion_mnist数据集介绍

fashion_mnist数据集和mnist数据集类似,都是28x28的灰度图片,区分是fashion_mnist数据集是服装图片,具体分类如下图:

分类英文描述中文描述
0t-shirtT恤
1trouser牛仔裤
2pullover套衫
3dress裙子
4coat外套
5sandal凉鞋
6shirt衬衫
7sneaker运动鞋
8bag
9ankle boot短靴

三、fashion_mnist数据集下载和展示

运用tensorflow下载fashion_mnist数据集与mnist类似,代码如下:

import tensorflow as tf
from tensorflow import keras
import numpy as np

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape, train_labels.shape)
print(test_images.shape, test_labels.shape)

输出:

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

可以看到训练集是60000张28x28的灰度图,测试集是10000张28x28的灰度图。
一些样例展示如下:
fashion_mnist

四、数据预处理

数据预处理主要是对图片归一化处理,如下:

train_images=train_images / 255.
test_images = test_images / 255.

五、构建模型和训练模型

模型构建

model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))
model.summary()

模型结构如下:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 128)               100480    
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________

模型训练

class MyCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
    	#loss小于0.25就停止训练
        if logs.get('loss') < 0.25:
            self.model.stop_training = True
callbacks = MyCallback()
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['acc'])
h = model.fit(train_images, train_labels, batch_size=32, epochs=15, validation_data=(test_images_scaled, test_labels), callbacks=[callbacks])

查看结果

Epoch 1/15
1875/1875 [==============================] - 11s 5ms/step - loss: 0.5031 - acc: 0.8239 - val_loss: 0.4201 - val_acc: 0.8499
Epoch 2/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3774 - acc: 0.8648 - val_loss: 0.4333 - val_acc: 0.8482
Epoch 3/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3371 - acc: 0.8773 - val_loss: 0.3662 - val_acc: 0.8667
Epoch 4/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3145 - acc: 0.8845 - val_loss: 0.3697 - val_acc: 0.8667
Epoch 5/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2929 - acc: 0.8921 - val_loss: 0.3404 - val_acc: 0.8794
Epoch 6/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2805 - acc: 0.8958 - val_loss: 0.3453 - val_acc: 0.8793
Epoch 7/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2683 - acc: 0.9009 - val_loss: 0.3452 - val_acc: 0.8778
Epoch 8/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2566 - acc: 0.9032 - val_loss: 0.3370 - val_acc: 0.8820
Epoch 9/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2480 - acc: 0.9065 - val_loss: 0.3482 - val_acc: 0.8789

用图标显示损失曲线和准确率曲线

loss_list = h.history['loss']
acc_list = h.history['acc']
test_loss_list = h.history['val_loss']
test_acc_list = h.history['val_acc']

plt.rcParams['font.sans-serif'] = ['Songti SC']
plt.rcParams['axes.unicode_minus'] = False

plt.figure(figsize=(20, 10))

plt.subplot(221)
plt.ylabel('loss')
plt.plot(loss_list, color='blue', marker='.', label='train_loss')
plt.plot(test_loss_list, color='red', marker='.', label='val_loss')
plt.legend(loc='upper left')
plt.title('损失曲线', fontsize=16)

plt.subplot(222)
plt.ylabel('acc')
plt.plot(acc_list, color='blue', marker='.', label='train_acc')
plt.plot(test_acc_list, color='red', marker='.', label='val_acc')
plt.legend(loc='upper left')
plt.title('准确率曲线', fontsize=16)
plt.show()

输出:
在这里插入图片描述

六、模型预测

选一个图像进行预测:

image = tf.cast(test_images[1], tf.float32)
image = tf.reshape(image, [1, 28, 28])
np.argmax(model.predict(image))
print(test_labels[1])
plt.imshow(test_images[1])
plt.show()

输出:

1/1 [==============================] - 0s 408ms/step
2

predict

总结

本文主要介绍了tensorflow fashion_mnist的下载、训练、预测,模型用的全连接网络。

如果觉得有些帮助或觉得文章还不错,请关注一下博主,你的关注是我持续写作的动力。另外,如果有什么问题,可以在评论区留言,或者私信博主,博主看到后会第一时间进行回复。
【间歇性的努力和蒙混过日子,都是对之前努力的清零】
欢迎转载,转载请注明出处:https://blog.csdn.net/xxm524/article/details/128160073

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/59547.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自制肥鲨HDO2电源降压延长线,支持3S~6S动力电池

自制肥鲨HDO2电源降压延长线&#xff0c;支持3S~6S动力电池1. 问题源由2. 破题思路2.1 10元大钞搞定2.2 两个毛爷爷搞定3. 解决方案4. 最终延长线产出4.1 裸照4.2 成品5. 花絮1. 问题源由 源由&#xff1a; 电池盒电源线接触不良。 肥鲨眼镜的电源盒问题由来已久&#xff0c;…

SecureCRT隧道,跳板机+端口转发,内网穿透

背景 ServerA(Linux系统)&#xff1a; 内网&#xff1a;192.168.111.201 公网&#xff1a;10.121.8.88&#xff08;虚构的ip方便理解&#xff09; ServerB&#xff1a; 内网&#xff1a;192.168.111.202 本机&#xff1a; 安装有SecureCRT软件 注意上图中的箭头。箭头指向可…

Android动画——使用动画启动Activity

1、使用动画启动Activity概述 我们在Android开发应用时&#xff0c;会遇到一个页面跳转到另一个页面的情况&#xff0c;这时候我们如果使用动画过渡会使得页面更加的流畅。这是一个滑动式的进入和退出的动画可以看到Android的过渡动画可以在不同状态之间建立视觉联系。您可以为…

find 命令这 7 种高级用法

可以很肯定地说&#xff0c;find 命令是 Linux 后台开发人员必须熟知的操作之一&#xff0c;除非您使用的是 Windows Server。 对于技术面试&#xff0c;它也是一个热门话题。让我们看一道真题&#xff1a; 如果你的 Linux 服务器上有一个名为 logs 的目录&#xff0c;如何删…

MySQL性能调优——索引篇

MySQL为什么会选错索引 使用explain命令可以查看查询语句使用了具体使用了哪个索引&#xff0c;比如 explain select * from t where a between 10000 and 20000;查询结果如图所示。 选择索引是优化器的工作 优化器选择索引的目的是想找到一个最优的执行方案&#xff0c;并…

08_线程池

08_线程池前言Callable接口ThreadPoolExecutor**为什么用线程池****线程池的好处**架构说明创建线程池底层实现线程池的重要参数拒绝策略线程池底层工作原理问题二: 线程池使用过吗?谈谈在生产上如何设置的参数?线程池的拒绝策略你谈谈?工作中单一的/固定数的/可变数的三种创…

设计模式 之 行为型模式

设计模式 之 行为型模式 模式 & 描述包括行为型模式 这些设计模式特别关注对象之间的通信。责任链模式&#xff08;Chain of Responsibility Pattern&#xff09; 命令模式&#xff08;Command Pattern&#xff09;解释器模式&#xff08;Interpreter Pattern&#xff09;…

Web 性能指标

Web 性能指标 对于 Web 开发人员来说&#xff0c;如何衡量一个 Web 页面的性能一直是一个难题。 最初&#xff0c;我们使用 Time to First Byte、DomContentLoaded 和 Load 这些衡量文档加载进度的指标&#xff0c;但它们不能直接反应用户视觉体验。 为了能衡量用户视觉体验…

[附源码]计算机毕业设计springboot志愿者服务平台

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

[数据结构]八大排序算法总结

作者&#xff1a; 华丞臧专栏&#xff1a;【数据结构】 各位读者老爷如果觉得博主写的不错&#xff0c;请诸位多多支持(点赞收藏关注)。如果有错误的地方&#xff0c;欢迎在评论区指出。推荐一款刷题网站 &#x1f449; LeetCode刷题网站 目录 一、排序的概念及其运用 1.1排…

【目的:windows下VS2017/2022配置使用opengl - 初探-创建一个空窗口】

目的&#xff1a;windows下VS2017/2022配置使用opengl - 初探-创建一个空窗口 环境&#xff1a; 系统&#xff1a;Win10 环境&#xff1a;VS2017 64bit步骤&#xff1a; windows下visualstudio下使用opengl&#xff0c;搭建配置环境并测试窗口 1、opengl库&#xff0c;vs下自…

Crack:Open Inventor 10.12.1 Fixed Bugs List 10.12

10.12.0 - 10.12.1 Open Inventor 10.12.1 Core #OIV-4245 Shapes not rendered with MultipleInstancing  #OIV-4258 Transparency issue with SoPackedColor – CAS-41256-F0S4 OivSuite.Java #OIV-4273 Memory leak with RemoteViz Java and JVM VolumeViz #OI…

CMake中add_library的使用

CMake中的add_library命令用于使用指定的源文件向项目(project)中添加库&#xff0c;其格式如下&#xff1a; add_library(<name> [STATIC | SHARED | MODULE][EXCLUDE_FROM_ALL][<source>...]) # Normal Libraries add_library(<name> OBJECT [<source&…

【Java 快速复习】垃圾回收算法 垃圾回收器

快速理解 Java 垃圾回收算法 & 垃圾回收器 先说个关系概念&#xff0c;垃圾回收的算法是逻辑概念的定义&#xff0c;用于规范垃圾回收器实现方的一些行为&#xff0c;而垃圾回收器就是实现这些算法的工具&#xff0c;这些工具大概是一系列的 C 的类以及其实现的一些对应回…

Linux服务器上跑深度学习实验

原文地址&#xff1a;Linux上跑深度学习实验 目录远程连接环境搭建与服务器断开连接后代码停止之前一直使用Google Colab跑实验&#xff0c;因为实验的规模不大&#xff0c;配合Google Drive用起来就很舒服&#xff0c;但是最近要系统地进行实验&#xff0c;规模一下子上来了&a…

【Spring】一文带你搞懂Spring容器配置

前言 本文为大家介绍的是Spring容器配置相关知识&#xff0c;包含Bean和Configuration的使用&#xff0c;使用 AnnotationConfigApplicationContext实例化Spring容器&#xff0c;Bean注解的使用&#xff0c;Configuration的使用&#xff0c;Import 注解的使用&#xff0c;结合J…

C++中STL-set详解

目录 set/ multiset容器 1. set基本概念 2.set构造和赋值 3.set大小和交换 4.set插入和删除 5.set容器-查找和统计 6.set和multiset的区别 7.pair对组创建 8.set容器排序 9.set存放自定义数据类型 set/ multiset容器 1. set基本概念 简介: 所有元素都会在插入时自动…

使用Apisix打造家庭NAS网关,免公网IP访问

使用Apisix打造家庭NAS网关 本文使用apisix打造家庭NAS网关&#xff0c;并通过cloudflare进行穿透&#xff0c;可免公网IP访问。首先你的NAS支持Docker&#xff0c;没有NAS也没有关系&#xff0c;只要你的电脑支持Docker同样可以参照该教程。 1 依赖资源准备 准备域名: 免费…

HTML+CSS+JS做一个好看的个人网页—web网页设计作业

个人网页设计个人网页&#xff08;htmlcssjs&#xff09;——网页设计作业带背景音乐&#xff08;The way I still Love you&#xff09;、樱花飘落效果、粒子飘落效果页面美观&#xff0c;样式精美涉及&#xff08;htmlcssjs&#xff09;&#xff0c;下载后可以根据自己需求进…

8086,8088CPU管脚,奇偶地址体, 时钟信号发生器8284 ,ready信号,reset复位信号。规则字和非规则字

8086/8088均为40条引线&#xff0c;双列直插式封装&#xff0c;某些引线有多重功能&#xff0c;其功能转换有两种情况&#xff1a;一种是分时复用&#xff0c;一种是按组态定义。 用8088微处理器构成系统时&#xff0c;有两种不同的组态&#xff1a; 最小组态&#xff1a;808…