python与深度学习(四):ANN和fashion_mnist二

news2025/1/16 7:54:06

目录

  • 1. 说明
  • 2. fashion_mnist的ANN模型测试
    • 2.1 导入相关库
    • 2.2 加载数据和模型
    • 2.3 设置保存图片的路径
    • 2.4 加载图片
    • 2.5 图片预处理
    • 2.6 对图片进行预测
    • 2.7 显示图片
  • 3. 完整代码和显示结果
  • 4. 多张图片进行测试的完整代码以及结果

1. 说明

本篇文章是对上篇文章训练的模型进行测试。首先是将训练好的模型进行重新加载,然后采用opencv对图片进行加载,最后将加载好的图片输送给模型并且显示结果。

2. fashion_mnist的ANN模型测试

2.1 导入相关库

在这里导入需要的第三方库如cv2,如果没有,则需要自行下载。

from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist

2.2 加载数据和模型

把fashion_mnist数据集进行加载,并且把训练好的模型也加载进来。

# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载ann_mnist.h5文件,重新生成模型对象
recons_model = keras.models.load_model('ann_fashion.h5')

2.3 设置保存图片的路径

将数据集的某个数据以图片的形式进行保存,便于测试的可视化。
在这里设置图片存储的位置。

# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)

在书写完上述代码后,需要在代码的当前路径下新建一个imgs的文件夹用于存储图片,如下。
在这里插入图片描述
执行完上述代码后就会在imgs的文件中可以发现多了一张图片,如下(下面测试了很多次)。
在这里插入图片描述

2.4 加载图片

采用cv2对图片进行加载,下面最后一行代码取一个通道的原因是用opencv库也就是cv2读取图片的时候,图片是三通道的,而训练的模型是单通道的,因此取单通道。

# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]

2.5 图片预处理

对图片进行预处理,即进行归一化处理和改变形状处理,这是为了便于将图片输入给训练好的模型进行预测。

# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 784)

2.6 对图片进行预测

将图片输入给训练好我的模型并且进行预测。
预测的结果是10个概率值,所以需要进行处理, np.argmax()是得到概率值最大值的序号,也就是预测的数字。

# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
y_pre_pro = tf.nn.softmax(y_pre_pro)  # 因为模型搭建的时候,输出层没有激活函数softmax,因此这里需要使用softmax
# 哪一类数字
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])

2.7 显示图片

对预测的图片进行显示,把预测的数字显示在图片上。
下面5行代码分别是创建窗口,设定窗口大小,显示图片,停留图片,清除内存。

# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()

3. 完整代码和显示结果

以下是完整的代码和图片显示结果。

from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载ann_mnist.h5文件,重新生成模型对象
recons_model = keras.models.load_model('ann_fashion.h5')
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 784)
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
y_pre_pro = tf.nn.softmax(y_pre_pro)  # 因为模型搭建的时候,输出层没有激活函数softmax,因此这里需要使用softmax
# 哪一类数字
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
1/1 [==============================] - 0s 224ms/step
test.png的预测概率: tf.Tensor(
[[1.3604405e-02 4.0369225e-04 4.1837210e-04 9.8322290e-01 1.0399223e-04
  1.5519042e-07 2.2383246e-03 8.2791501e-10 8.1014296e-06 1.0210023e-08]], shape=(1, 10), dtype=float32)
test.png的预测概率: tf.Tensor(0.9832229, shape=(), dtype=float32)
test.png的所属类别: Dress

在这里插入图片描述

4. 多张图片进行测试的完整代码以及结果

为了测试更多的图片,引入循环进行多次测试,效果更好。

from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion_mnist数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载ann_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('ann_fashion.h5')

prepicture = int(input("input the number of test picture :"))
for i in range(prepicture):
    path1 = input("input the test picture path:")
    # 创建图片保存路径
    test_file_path = os.path.join(sys.path[0], 'imgs', path1)
    # 存储测试数据的任意一个
    num = int(input("input the test picture num:"))
    Image.fromarray(x_test[num]).save(test_file_path)
    # 加载本地test.png图像
    image = cv2.imread(test_file_path)
    # 复制图片
    test_img = image.copy()
    # 将图片大小转换成(28,28)
    test_img = cv2.resize(test_img, (28, 28))
    # 取单通道值
    test_img = test_img[:, :, 0]
    # 预处理: 归一化 + reshape
    new_test_img = (test_img/255.0).reshape(1, 784)
    # 预测
    y_pre_pro = recons_model.predict(new_test_img, verbose=1)
    y_pre_pro = tf.nn.softmax(y_pre_pro)  # 因为模型搭建的时候,输出层没有激活函数softmax,因此这里需要使用softmax
    # 哪一类数字
    class_id = np.argmax(y_pre_pro, axis=1)[0]
    print('test.png的预测概率:', y_pre_pro)
    print('test.png的预测概率:', y_pre_pro[0, class_id])
    print('test.png的所属类别:', class_names[class_id])
    text = str(class_names[class_id])
    # # 显示
    cv2.namedWindow('img', 0)
    cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
    cv2.imshow('img', image)
    cv2.waitKey()
    cv2.destroyAllWindows()

下面的test picture num指的是数据集中该数据的序号(0-59999),并不是值实际的数字。

```python
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
input the number of test picture :2
input the test picture path:71.jpg
input the test picture num:2
1/1 [==============================] - 0s 176ms/step
test.png的预测概率: tf.Tensor(
[[5.1171044e-20 1.0000000e+00 1.0369056e-20 3.2896547e-11 1.0296870e-15
  3.9021204e-35 1.4200611e-16 0.0000000e+00 3.1272054e-24 3.7608996e-36]], shape=(1, 10), dtype=float32)
test.png的预测概率: tf.Tensor(1.0, shape=(), dtype=float32)
test.png的所属类别: Trouser

在这里插入图片描述

input the test picture path:72.jpg
input the test picture num:3
1/1 [==============================] - 0s 32ms/step
test.png的预测概率: tf.Tensor(
[[2.8700330e-18 1.0000000e+00 5.2245058e-19 7.1288919e-10 3.2442375e-14
  2.9884308e-32 3.5976920e-15 2.7016272e-38 3.2129908e-22 3.0112130e-33]], shape=(1, 10), dtype=float32)
test.png的预测概率: tf.Tensor(1.0, shape=(), dtype=float32)
test.png的所属类别: Trouser

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/794169.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Guitar Pro有没有免费的?

Guitar Pro是一款功能强大的吉他谱编辑软件,它为音乐创作者和乐手提供了极大的便利。但是,许多人对于Guitar Pro是否有免费版本,问题存在疑虑。下面我们就来详细介绍Guitar Pro有没有免费的内容吧! 一、Guitar Pro有没有免费的 …

音视频——视频流H264编码格式

1 H264介绍 我们了解了什么是宏快,宏快作为压缩视频的最小的一部分,需要被组织,然后在网络之间做相互传输。 H264更深层次 —》宏块 太浅了 ​ 如果单纯的用宏快来发送数据是杂乱无章的,就好像在没有集装箱 出现之前,…

MultipartFile类型接收上传文件报出的UncheckedIOException以及删除tomcat临时文件失败源码探索

1、描述异常背景: 因为需要分析数据,待处理excel文件的数据行数太大,手动太累,花半小时写了一个定制的数据入库工具,改成了通用的,整个项目中的万级别数据都在工具上分析,写SQL进行分析&#x…

Rust vs Go:常用语法对比(十)

题图来自 Rust vs. Golang: Which One is Better?[1] 182. Quine program Output the source of the program. 输出程序的源代码 package mainimport "fmt"func main() { fmt.Printf("%s%c%s%c\n", s, 0x60, s, 0x60)}var s package mainimport "fm…

【C语言】嵌入式C语言项目管理利器:深入理解Makefile的应用与实践

目录 一、makedile的概述 1、案例引入 2、makefile 3、Makefile优点 二、makefile的语法规则 1、语法规则 2、简单实战 三、makefile的变量 1、自定义变量 2、系统环境变量 3、预定义变量 4、高级makefile 一、makedile的概述 1、案例引入 gcc a.c b.c c.c ‐o …

Error: Please select Android SDK解决方案(仅供参考)

一、问题描述 今天开始正式接触项目的工作内容,然后从组里的代码仓库里git clone了一份Android Studio项目下来。下好了以后我使用Android Studio打开这个项目,但在尝试编译运行的时候遇到了很多错误。例如,开发环境界面上面用于编译的小锤子…

【MySQL】基本查询(插入查询结果、聚合函数、分组查询)

目录 一、插入查询结果二、聚合函数三、分组查询(group by & having)四、SQL查询的执行顺序五、OJ练习 一、插入查询结果 语法: INSERT INTO table_name [(column [, column ...])] SELECT ...案例:删除表中重复数据 --创建…

ARM汇编中类似c语言中宏定义的使用

—# 一、是什么? .equ xxx,xxx 类似c语言中#define xxx xxxx ## 1.操作例子 代码如下(示例): .equ bss_start,0x2000 .equ bss_end,0x20000x100 .global _start _start:ldr r0,bss_startldr r1,bss_end

用ChatGPT的AI配音插件,1秒获取想要的语音样本

Hi! 大家好,我是专注于AI项目实战的赤辰,今天继续跟大家介绍另外一款GPT4.0插件Speechki。 一、什么是Speechki? Speechki是一个专门的文本转语音(TTS)工具,它的主要功能是将文本转换为高质量的音频文件。…

【C++初阶】 priority_queue(优先级队列)

⭐博客主页:️CS semi主页 ⭐欢迎关注:点赞收藏留言 ⭐系列专栏:C初阶 ⭐代码仓库:C初阶 家人们更新不易,你们的点赞和关注对我而言十分重要,友友们麻烦多多点赞+关注,你们的支持是我…

wordpress我的个人网站搭建

WordPress介绍 WordPress是一个功能强大且易于使用的网站管理平台。它是基于PHP和MySQL构建的,可以在各种不同的主机上运行。 wordpress对服务器的要求 需求最低版本要求PHP7.4 或更高版本MySQL5.6 或更高版本Web服务器任意(如:Apache、Ng…

【Nodejs】登录鉴权-Cookie

1.什么是认证(Authentication) 通俗地讲就是验证当前用户的身份,证明“你是你自己”(比如:你每天上下班打卡,都需要通过指纹打卡,当你的指纹和系统里录入的指纹相匹配时,就打卡成功&…

read_table()函数--Pandas

1. 函数功能 将通用定界符文件读取到DataFrame中 2. 函数语法 pandas.read_table(filepath_or_buffer, *, sep_NoDefault.no_default, delimiterNone, headerinfer, names_NoDefault.no_default, index_colNone, usecolsNone, dtypeNone, engineNone, convertersNone, true…

java并发编程 13:JUC之Semaphore、CountdownLatch、 CyclicBarrier

目录 Semaphore使用常见应用原理源码流程 CountdownLatch使用原理 CyclicBarrier使用 Semaphore 使用 Semaphore是一种计数信号量,它用于控制对共享资源的访问。它维护了一个许可计数器,表示可用的许可数量。线程在访问共享资源前必须先获得许可&#…

小学期笔记——天天酷跑2

程序的錯誤一共是3個維度 1:編譯錯誤 (1)点击后边的红色点点 (2)鼠标移动红线上 (3)复制错误提示内容 (4)粘贴到CHATGPT (5) 复制爆红正行(全文&#xff0…

JavaScript function默认参数赋值前后顺序差异

1、(num1,num2num1) 当传值仅传一个参数时,先给到第一个参数即num1,num1再赋值给num2, function sum(num1, num2 num1) {console.log(num1 num2) } sum(10)//20 sum(10,3)//13 2、(t2t1,t1) 当传值仅有一个参数时,先给到第一个…

Spring Boot OAuth2 认证服务器搭建及授权码认证演示

本篇使用JDK版本是1.8,需要搭建一个OAuth 2.0的认证服务器,用于实现各个系统的单点登录。 框架构思 这里选择Spring Boot+Spring Security + Spring Authorization Server 实现,具体的版本选择如下: Spirng Boot 2.7.14 , Spring Boot 目前的最新版本是 3.1.2,在官方的介…

华为eNSP:isis配置跨区域路由

一、拓扑图 二、路由器的配置 1、配置接口IP AR1: <Huawei>system-view [Huawei]int g0/0/0 [Huawei-GigabitEthernet0/0/0]ip add 1.1.1.1 24 [Huawei-GigabitEthernet0/0/0]q AR2: [Huawei]int g0/0/0 [Huawei-GigabitEthernet0/0/0]ip add 1.1.1.2 24 [Huawe…

【弹力设计篇】聊聊限流设计

为什么需要限流 对于一个后端系统来说&#xff0c;其处理能力是有限的&#xff0c;会受到业务流程&#xff0c;系统配置等的限制&#xff0c;QPS和TPS有一个上限值&#xff0c;比如一个订单系统1分钟可以处理100个请求。当1分钟超过100个请求的时候&#xff0c;我们为了保证系…

13.2.2 【Linux】使用者功能

不论是 useradd/usermod/userdel &#xff0c;那都是系统管理员所能够使用的指令&#xff0c; 如果我是一般身份使用者&#xff0c;那么我是否除了密码之外&#xff0c;就无法更改其他的数据呢&#xff1f; 当然不是啦&#xff01;这里我们介绍几个一般身份使用者常用的帐号数据…