python与深度学习(八):CNN和fashion_mnist二

news2024/11/17 8:44:16

目录

  • 1. 说明
  • 2. fashion_mnist的CNN模型测试
    • 2.1 导入相关库
    • 2.2 加载数据和模型
    • 2.3 设置保存图片的路径
    • 2.4 加载图片
    • 2.5 图片预处理
    • 2.6 对图片进行预测
    • 2.7 显示图片
  • 3. 完整代码和显示结果
  • 4. 多张图片进行测试的完整代码以及结果

1. 说明

本篇文章是对上篇文章训练的模型进行测试。首先是将训练好的模型进行重新加载,然后采用opencv对图片进行加载,最后将加载好的图片输送给模型并且显示结果。

2. fashion_mnist的CNN模型测试

2.1 导入相关库

在这里导入需要的第三方库如cv2,如果没有,则需要自行下载。

from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist

2.2 加载数据和模型

把fashion_mnist数据集进行加载,并且把训练好的模型也加载进来。

# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')

2.3 设置保存图片的路径

将数据集的某个数据以图片的形式进行保存,便于测试的可视化。
在这里设置图片存储的位置。

# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)

在书写完上述代码后,需要在代码的当前路径下新建一个imgs的文件夹用于存储图片,如下。
在这里插入图片描述

执行完上述代码后就会在imgs的文件中可以发现多了一张图片,如下(下面测试了很多次)。
在这里插入图片描述

2.4 加载图片

采用cv2对图片进行加载,下面最后一行代码取一个通道的原因是用opencv库也就是cv2读取图片的时候,图片是三通道的,而训练的模型是单通道的,因此取单通道。

# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]

2.5 图片预处理

对图片进行预处理,即进行归一化处理和改变形状处理,这是为了便于将图片输入给训练好的模型进行预测。

# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)

2.6 对图片进行预测

将图片输入给训练好我的模型并且进行预测。
预测的结果是10个概率值,所以需要进行处理, np.argmax()是得到概率值最大值的序号,也就是预测的数字。

# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])

2.7 显示图片

对预测的图片进行显示,把预测的数字显示在图片上。
下面5行代码分别是创建窗口,设定窗口大小,显示图片,停留图片,清除内存。

# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()

3. 完整代码和显示结果

以下是完整的代码和图片显示结果。

from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
1/1 [==============================] - 0s 168ms/step
test.png的预测概率: [[2.9672831e-04 7.3040414e-05 1.4721525e-04 9.9842703e-01 4.7597905e-06
  8.9959512e-06 1.0416918e-03 8.6147125e-09 4.2549357e-07 1.2974965e-07]]
test.png的预测概率: 0.99842703
test.png的所属类别: Dress

在这里插入图片描述

4. 多张图片进行测试的完整代码以及结果

为了测试更多的图片,引入循环进行多次测试,效果更好。

from tensorflow import keras
from keras.datasets import fashion_mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np

# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载mnist数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')

prepicture = int(input("input the number of test picture :"))
for i in range(prepicture):
    path1 = input("input the test picture path:")
    # 创建图片保存路径
    test_file_path = os.path.join(sys.path[0], 'imgs', path1)
    # 存储测试数据的任意一个
    num = int(input("input the test picture num:"))
    Image.fromarray(x_test[num]).save(test_file_path)
    # 加载本地test.png图像
    image = cv2.imread(test_file_path)
    # 复制图片
    test_img = image.copy()
    # 将图片大小转换成(28,28)
    test_img = cv2.resize(test_img, (28, 28))
    # 取单通道值
    test_img = test_img[:, :, 0]
    # 预处理: 归一化 + reshape
    new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)
    # 预测
    y_pre_pro = recons_model.predict(new_test_img, verbose=1)
    # 哪一类数字
    class_id = np.argmax(y_pre_pro, axis=1)[0]
    print('test.png的预测概率:', y_pre_pro)
    print('test.png的预测概率:', y_pre_pro[0, class_id])
    print('test.png的所属类别:', class_names[class_id])
    text = str(class_names[class_id])
    # # 显示
    cv2.namedWindow('img', 0)
    cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
    cv2.imshow('img', image)
    cv2.waitKey()
    cv2.destroyAllWindows()

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
input the number of test picture :2
input the test picture path:101.jpg
input the test picture num:1
1/1 [==============================] - 0s 145ms/step
test.png的预测概率: [[5.1000708e-05 2.9449904e-13 9.9993873e-01 5.5402721e-11 4.8696438e-06
  1.2649738e-12 5.3379590e-06 6.5959898e-17 7.1223938e-10 4.0113624e-12]]
test.png的预测概率: 0.9999387
test.png的所属类别: Pullover

在这里插入图片描述

input the test picture path:102.jpg
input the test picture num:2
1/1 [==============================] - 0s 21ms/step
test.png的预测概率: [[3.01315001e-10 1.00000000e+00 1.03142118e-14 8.63922683e-11
  4.10812981e-11 6.07313693e-22 2.31636132e-09 5.08595438e-25
  1.02018335e-13 8.82350167e-28]]
test.png的预测概率: 1.0
test.png的所属类别: Trouser

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/799032.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于新浪微博海量用户行为数据、博文数据数据分析:包括综合指数、移动指数、PC指数三个指数

基于新浪微博海量用户行为数据、博文数据数据分析:包括综合指数、移动指数、PC指数三个指数 项目介绍 微指数是基于海量用户行为数据、博文数据,采用科学计算方法统计得出的反映不同事件领域发展状况的指数产品。微指数对于收录的关键词,在指…

javaweb如何格式化日期时间

解决方式: 方式一:在属性上加入注解,对日期进行格式化 方式二:在 WebMvcConfiguration 中扩展Spring MVC的消息转换器,统一对日期类型进行格式化处理 package com.sky.json;import com.fasterxml.jackson.databind.Des…

[OnWork.Tools]系列 01-简介

说明 OnWork.Tools 是基于 Net6 的桌面程序。支持Windows7SP1及以上系统,主要是日常办公或者是开发工作过程中常用的工具集合。界面使用WPF Mvvm模式开发,目的是将开源项目中,好用的项目集成到一起,方便大家使用和学习。 功能 …

牛客周赛 Round 4

A 游游的字符串构造 思路分析 构造字符串&#xff0c;注意k>1 时间复杂度 O(n) AC代码 #include<bits/stdc.h> using namespace std; int main() {int n,k;cin>>n>>k;if(n>3*k){for(int i0;i<k;i)cout<<"you";for(int i0;i<…

工作纪实34-emoji表情包存储异常,修改db的字段类型

线上问题&#xff0c;发现emojo表情写入数据库出现异常 修改mysql字段的字符集 ALTER TABLE customer_cycle_info MODIFY COLUMN customer_sales_remark varchar(500) CHARACTER SET

已实现商业化却仍陷亏损泥潭,瑕瑜错陈的觅瑞集团求上市

撰稿|行星 来源|贝多财经 7月25日&#xff0c;Mirxes Holding Company Limited-B&#xff08;以下简称“觅瑞集团”&#xff09;向港交所递交上市申请材料&#xff0c;计划在港交所主板上市&#xff0c;中金公司和建银国际为其联席保荐人。 据招股书介绍&#xff0c;成立于2…

0基础系列C++教程 从0开始 第二课

0基础系列C教程 从0开始 第二课来了&#xff01; 复习第一课内容 1 怎么输出数字“1919810”&#xff1f; 答案&#xff08;关键语句&#xff09;: cout<<"1919810"; 2 怎么输出字符串“Hello World”&#xff1f; 答案&#xff08;关键语句&#xff09;&a…

MySQL案例——多表查询以及嵌套查询

系列文章目录 MySQL笔记——表的修改查询相关的命令操作 MySQL笔记——MySQL数据库介绍以及在Linux里面安装MySQL数据库&#xff0c;对MySQL数据库的简单操作&#xff0c;MySQL的外接应用程序使用说明 文章目录 系列文章目录 前言 一 创建数据库 1.1 创建一个部门表 1.…

7.24 C++

封装Vector #include <iostream> #include <cstring> using namespace std;template <typename T> class Vector { private:T *start; //指向容器的起始地址T *last; //指向容器存放数据的下一个位置T *end; //指向容器的最后的下一个位置 public:Vec…

MyBatis-入门-快速入门程序

本次使用MyBatis框架是基于SpringBoot框架进行的&#xff0c;在IDEA中创建一个SpringBBot工程&#xff0c;根据自己的需求选择对应的依赖即可 快速入门 需求&#xff1a;使用MyBatis查询所有用户数据步骤&#xff1a; 准备工作&#xff08;创建Spring Boot工程、数据库user表…

【MATLAB】ILOSpsi制导率的代码解析

ILOSpsi制导率的代码解析 这里记录一下关于fossen的MMS工具箱中&#xff0c;关于ILOSpsi制导率的代码解析内容&#xff0c;结合fossen的marine carft hydrodynamics and motion control这本书来参考看 文章目录 ILOSpsi制导率的代码解析前言一、代码全文二、内容解析1.persist…

静态代理、jdk、cglib动态代理 搞不清? 看这个文章就懂了

一、代理模式 代理模式是一种比较好的理解的设计模式。简单来说就是 &#xff1a; 我们使用代理对象来增强目标对象(target obiect)&#xff0c;这样就可以在不修改原目标对象的前提下&#xff0c;提供额外的功能操作&#xff0c;扩展目标对象的功能。将核心业务代码和非核心…

linux下i2c调试神器i2c-tools安装及使用

i2c-tools介绍 在嵌入式linux开发中&#xff0c;有时候需要确认i2c硬件是否正常连接&#xff0c;设备是否正常工作&#xff0c;设备的地址是多少等等&#xff0c;这里我们就需要使用一个用于测试I2C总线的工具——i2c-tools。 i2c-tools是一个专门调试i2c的开源工具&#xff…

Linux查看内存的几种方法

PS的拼接方法 ps aux|head -1;ps aux|grep -v PID|sort -rn -k 4|head 进程的 status 比如说你要查看的进程pid是33123 cat /proc/33123/status VmRSS: 表示占用的物理内存 top PID&#xff1a;进程的ID USER&#xff1a;进程所有者 PR&#xff1a;进程的优先级别&#x…

503、415、403、404报错

1、503 报错 Service Unvailable 解决&#xff1a;如果你用的是网关gateway&#xff0c;观察下你的项目是否导入了nacos依赖。 2、415 报错 Unsupported Media Type post传对象的时候&#xff0c;这样就会报错 解决&#xff1a;axios传对象的时候&#xff0c;用这种就可以了。…

阿里内部都在疯传!企业级 Spring Boot 项目开发实战教程,先肝为敬

前言 本书结合大量的实际开发经验&#xff0c;由浅入深地讲解 Spring Boot 的技术原理和企业级应用开发涉及的的技术及其完整流程。无论是对 Java 企业级开发人员&#xff0c;还是 对其他相关技术爱好者&#xff0c;本书都极具参考价值。 本书特点 理论知识结合实践代码&#…

专项练习-04编程语言-03JAVA-05

1. 设有下面两个类的定义&#xff1a; class Person {} class Student extends Person { public int id; //学号 public int score; //总分 public String name; // 姓名 public int getScore(){return score;} } 类Person和类Student的关系是&#xff08;&#x…

vue2中开发时通过template中的div等标签自动输出对应的less形式带层级的class,只显示带class的

1.写完静态不是要写less吗&#xff0c;自动生成一下实现 this.getLevelClass(domId); domId是自定义的class名称&#xff0c;跟根据自己的需要设置 //vue2中开发时通过template中的div等标签自动输出对应的less形式带层级的class,只显示带class的getLevelClass(name) {let dom…

Python基础语法第七章之文件

目录 一、文件 1.1文件是什么 1.2文件路径 1.3文件操作 1.3.1 打开文件 1.3.2关闭文件 1.3.3写文件 1.3.4读文件 二、使用上下文管理器 2.1上下文管理器 一、文件 1.1文件是什么 变量是把数据保存到内存中. 如果程序重启/主机重启, 内存中的数据就会丢失. 要想能让…

Excel 两列数据中相同的数据进行同行显示

一、要求 假设您有两个列&#xff0c;分别是A列和B列&#xff0c;需要在C列中找出A列对应的B列的值。 二、方案 方法1&#xff1a;寻常思路 凸显重复项对A列单独进行筛选–按颜色进行排序&#xff0c;然后升序对B列重复上述操作即可 方法2&#xff1a;两个公式 VLOOKUP 纵向查找…