MLP实现fashion_mnist数据集分类(1)-模型构建、训练、保存与加载(tensorflow)

news2025/1/11 21:53:51

1、查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、fashion_mnist数据集下载与展示

(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
print(train_image.shape)
print(train_label.shape)
print(test_image.shape)
print(test_label.shape)

在这里插入图片描述

import matplotlib.pyplot as plt
# plt.imshow(train_image[0])  # 此处为啥是彩色的?

def plot_images_lables(images,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
plot_images_lables(train_image,train_label,0,5)
# plot_images_lables(test_image,test_label,0,5)

在这里插入图片描述

3、数据预处理

X_train,X_test = tf.cast(train_image/255.0,tf.float32),tf.cast(test_image/255.0,tf.float32) # 归一化
y_train,y_test = train_label,test_label # 此处对y没有做onehot处理,需要使用稀疏交叉损失函数

4、模型构建

from keras import Sequential
from keras.layers import Flatten,Dense,Dropout
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

5、模型配置

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])

6、模型训练

H = model.fit(x=X_train,
              y=y_train,
              validation_split=0.2,
              # validation_data=(X_test,y_test),
              epochs=10,
              batch_size=128,
              verbose=1)

在这里插入图片描述

plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()

在这里插入图片描述

plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()

在这里插入图片描述

7、模型评估

model.evaluate(X_test,y_test)

在这里插入图片描述

8、模型预测

import numpy as np
import matplotlib.pyplot as plt

def pred_plot_images_lables(images,labels,start_idx,num=5):
    # 预测
    res = model.predict(images[start_idx:start_idx+num])
    res = np.argmax(res,axis=1)

    # 画图
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
pred_plot_images_lables(X_test,y_test,0,5)

在这里插入图片描述

9、模型保存与加载

import numpy as np

tf.keras.models.save_model(model,"model.keras")
loaded_model = tf.keras.models.load_model("model.keras")
# assert np.allclose(model.predict(X_test[:5]), loaded_model.predict(X_test[:5]))
print(np.argmax(model.predict(X_test[:5]),axis=1))
print(np.argmax(loaded_model.predict(X_test[:5]),axis=1))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1645362.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jetson orin 刷机

因为现在的系统各种库已经都乱了,也怪自己太心急了,把cmake给删了,导致很多编译库都出现了问题。记住这个教训! 找到合适的教程 首先是PC系统,看来好几个教程都说用ubuntu,也有的说Windows也可以&#xf…

SGX Memory Organization

文章目录 前言一、Processor Reserved Memory (PRM)二、Enclave Page Cache (EPC)三、Enclave Page Cache Map (EPCM)参考资料 前言 本节内容主要介绍了SGX Memory Organization,来自参考资料里的综述文章,可供初学者了解SGX内存组织对应的知识。 一、…

【C++练级之路】【Lv.20】位图和布隆过滤器(揭开大数据背后的神秘面纱)

快乐的流畅:个人主页 个人专栏:《算法神殿》《数据结构世界》《进击的C》 远方有一堆篝火,在为久候之人燃烧! 文章目录 引言一、位图1.1 位图的概念1.2 位图的优势1.3 位图的模拟实现1.3.1 成员变量与默认成员函数1.3.2 test1.3.3…

LoRa无线通讯入门

本文图片来自于深入浅出讲解LoRa通信技术,LoRa技术介绍,LoRa开发与应用,物联网学习必备知识点!_哔哩哔哩_bilibili LoRa无线通讯 LoRa(Long Range)是一种低功耗广域网(LPWAN)技术&a…

5.06号模拟前端面试8问

5.06号模拟前端面试8问 1.promise如何实现then处理 在JavaScript中,Promise 是一个代表异步操作最终完成或失败的对象。它有三种状态:pending(等待),fulfilled(完成),rejected&…

WhisperCLI-本地部署语音识别系统;Mis开源LLM推理平台;Dokploy-开源版Vercel;Mem-大规模知识图谱

1. Whisper-cli:可本地部署的开源语音识别系统 近日,Ruff的开发团队发布了一款名为Whisper cpp cli的全新语音识别系统,该系统已在GitHub Repo上开源。这是一款完全自主研发的语音转文字系统,基于Whisper技术构建。Ruff团队一直以…

【Linux-点灯烧录-SD卡/USB烧写】

目录 1. 烧写方式2. 烧写之代码编译2.1 led.s->led.o2.2 led.o->led.elf2.3 led.elf->led.bin2.4 反汇编:led.elf->led.dis 3. 烧写之烧录到SD卡上:3.1 开启烧录软件权限:3.2 确定SD卡的格式:FAT323.3 烧录到SD卡上3.…

安卓跑马灯效果

跑马灯效果 当一行文本的内容太多,导致无法全部显示,也不想分行展示时,只能让文字从左向右滚动显示,类 似于跑马灯。电视在播报突发新闻时经常在屏幕下方轮播消息文字,比如“ 快讯:我国选手 *** 在刚刚结束…

(014) java.math.BigInteger cannot be cast to java.lang.Long

文章目录 问题原因 问题 mysql 和 Java 在进行数据类型的映射时,报错: 原因 部分 jdk8 和高版本的 jdk 对 mysql 的 BigInteger 类型转换为 Java的 Long 类型认为是错误的类型转换。 1.解决方法一:更换兼容的 jdk8版本。 2.解决方法二&am…

C++:特殊类的设计 | 单例模式

目录 1、特殊类的设计 2、设计一个类,不能被拷贝 3、设计一个类,只能在堆上创建对象 4、设计一个类,只能在栈上创建对象 5、设计一个类,不能被继承 6、单例模式 1、饿汉模式 2、懒汉模式 1、特殊类的设计 在实际应用场景中…

集合定义和使用方法

一.集合的长度 集合的长度,可以添加和删除,长度也会跟着去发生改变,数组一旦创建完成他的长度就不会发生改变。 二.集合的定义方式 ArrayList<String> list new ArrayList(); 三.集合能存储的数据类型 集合能够存储引用数据类型,存储基本数据类型需要使用包装类: 四…

vs配置cplex12.10

1.创建c空项目 2.修改运行环境 为release以及x64 3.创建cpp文件 4.鼠标右键点击项目中的属性 5.点击c/c&#xff0c;点击第一项常规&#xff0c;配置附加库目录 5.添加文件索引&#xff0c;主要用于把路径导进来 6.这一步要添加的目录与你安装的cplex的目录有关系 F:\program…

vue管理系统导航中添加新的iconfont的图标

1.在官网上将需要的图标&#xff0c;加入项目中&#xff0c;下载 2.下载的压缩包中&#xff0c;可以选择这两个&#xff0c;复制到项目目录中 3.如果和之前的iconfont有重复&#xff0c;那么就重新命名 4.将这里的.ttf文件&#xff0c;也重命名为自己的 5.在main文件中导入 6.在…

九泰智库 | 医械周刊- Vol.24

⚖️ 法规动态 国家药监局&#xff1a;2款创新器械获批上市 4月28日国家药品监督管理局公告&#xff0c;批准心擎医疗&#xff08;苏州&#xff09;股份有限公司“体外心室辅助设备”和“体外心室辅助泵头及管路”创新产品注册申请。 体外心室辅助设备由磁悬浮马达、控制主机…

Python语言在地球科学中地理、气象、气候变化、水文、生态、传感器等数据可视化到常见数据分析方法的使用

Python是功能强大、免费、开源&#xff0c;实现面向对象的编程语言&#xff0c;Python能够运行在Linux、Windows、Macintosh、AIX操作系统上及不同平台&#xff08;x86和arm&#xff09;&#xff0c;Python简洁的语法和对动态输入的支持&#xff0c;再加上解释性语言的本质&…

U盘提示“被写保护”无法操作处理怎么办?

今天在使用U盘复制拷贝文件时&#xff0c;U盘出现“U盘被写保护”提示&#xff0c;导致U盘明明有空闲内存却无法复制的情况。这种情况很常见&#xff0c;很多人在插入U盘到电脑后&#xff0c;会出现"U盘被写保护"的提示&#xff0c;导致无法进行删除、保存、复制等操…

力扣每日一题110:平衡二叉树

题目 简单 给定一个二叉树&#xff0c;判断它是否是 平衡二叉树 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;true示例 2&#xff1a; 输入&#xff1a;root [1,2,2,3,3,null,null,4,4] 输出&#xff1a;false示例 3&#xff1a; …

ComfyUI中的图像镜像反转(3种方式)

用下面的节点就可以让图片左右镜像反转&#xff0c;如下 如果想要上下翻转呢&#xff1f;用下面的节点&#xff0c;如下 这个节点不仅可以上下翻转&#xff0c;还可以左右翻转&#xff0c;把方向设置为水平就行&#xff0c;即设置为level&#xff0c;如下 或者用下面这个节点也…

动态规划——斐波那契数列模型:91.解码方法

文章目录 题目描述算法原理1.状态表示2.状态转移方程3.初始化⽅法⼀&#xff08;直接初始化&#xff09;⽅法⼆&#xff08;添加辅助位置初始化&#xff09; 4.填表顺序5.返回值 代码实现C优化Java优化 题目描述 题目链接&#xff1a;91.解码方法 算法原理 类似于斐波那契…

Elasticsearch初步认识

Elasticsearch初步认识 ES概述基本概念正向索引和倒排索引IK分词器ik_smart最少切分ik_max_word为最细粒度划分 ES索引库基本操作对索引库操作对文档操作 ES概述 Elasticsearch&#xff0c;简称为 ES&#xff0c;是一款非常强大的开源的高扩展的分布式全文检索引擎&#xff0c…