TensorFlow进行MNIST数据集手写数字识别,保存模型并且进行外部手写图片测试

news2024/7/4 5:53:01

首先,你已经配置好Anaconda3的环境,下载了TensorFlow模块,并且会使用jupyter了,那么接下来就是MNIST实验步骤。

数据集官网下载:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burgesicon-default.png?t=N4P3http://yann.lecun.com/exdb/mnist/

 

将上面四个全部下载,都是数据集,其中前两个是训练集,后两个是测试集

当然上面的数据集使用TensorFlow1.x要方便一点,本文章是使用TensorFlow2.x版本,使用npz格式数据集 

数据集官网下载:https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

导入该导的包

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import datetime

接下来是读取数据集,有以下方法可以读取

第一种非常简单,直接使用tensorflow中的方法从官网下载数据集,这样就不用手动去官网下载数据集了,但是比较耗时间

mnist=tf.keras.datasets.mnist

#获取数据,训练集,测试集 60k训练,10K测试
#网络下载
(x_train,y_train),(x_test,y_test)=mnist.load_data()

第二种方法,使用下载好的本地npz文件,将下面的路径改成你自己的

#本地加载
dataset = np.load('/home/tensor/jupyter/MNIST_data2/mnist.npz')
# 获取训练集和测试集
x_train, y_train = dataset['x_train'], dataset['y_train']
x_test, y_test = dataset['x_test'], dataset['y_test']

下一步,原图片的大小为28x28,将图片转成32x32的大小效果更好

#首先是数据 INPUT 层,输入图像的尺寸统一归一化为32*32。
x_train= np.pad(x_train,((0,0),(2,2),(2,2)),'constant',constant_values=0) #28*28-》32*32
x_test= np.pad(x_test,((0,0),(2,2),(2,2)),'constant',constant_values=0) #28*28-》32*32
#print(x_train.shape,x_test.shape)

#数据集格式转换
# x_train=x_train.astype('float32')
# x_train=x_train.astype('float32')

#归一化,就是为了限定你的输入向量的最大值跟最小值不超过你的隐层跟输出层函数的限定范围。
x_train=x_train/255#归一化
x_test=x_test/255#归一化

x_train=x_train.reshape(x_train.shape[0],32,32,1)
x_test=x_test.reshape(x_test.shape[0],32,32,1)
print(x_train.shape,x_test.shape)

输出结果如下

创建神经网络模型,这里我简单构建一个卷积神经网络模型

def LeNetModel():
#模型实例化,根据LeNet 的结构
    model=tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=6,kernel_size=(5,5),padding='valid',activation=tf.nn.relu,input_shape=(32,32,1)),
        tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=(2,2),padding='same'),
        tf.keras.layers.Conv2D(filters=16,kernel_size=(5,5),padding='valid',activation=tf.nn.relu,input_shape=(32,32,1)),
        tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=(2,2),padding='same'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=120,activation=tf.nn.relu),
        tf.keras.layers.Dense(units=84,activation=tf.nn.relu),
        tf.keras.layers.Dense(units=10,activation=tf.nn.softmax),
        ])
    return model

加载模型,并且输出摘要

model = LeNetModel()
model.summary() #输出摘要

输出如下:

  

定义超参数

num_epochs=1#训练次数
batch_size=60#每个批次喂多少张图片
lr=0.001#学习率

开始训练模型

#优化器
adam_optimizer=tf.keras.optimizers.Adam(lr)
     
model.compile(
        optimizer=adam_optimizer,
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy']
    )

start_time=datetime.datetime.now() #开始训练时间
     
model.fit(x=x_train,y=y_train,batch_size=batch_size,epochs=num_epochs)
end_time=datetime.datetime.now() #训练结束时间
time_cost=end_time-start_time #训练总时间
print('time_cost: ',time_cost)
model.save('leNet_model.h5') #保存模型
print(model.evaluate(x_test,y_test))
print("Finished!")

输出结果如下:

接下来选择一张测试集图片或多张测试集图片进行测试

def pred_function(image,label):
    image_index=20
     
    # 预测
    pred=model.predict(image.reshape(1,32,32,1))
    print("label:",label,"predict result:",pred.argmax())
     
    # 显示
    plt.imshow(image.reshape(32,32))
    plt.savefig("predict_num.jpg")
    plt.show()
index = 100 #选择第100张图片进行测试
pred_function(x_test[index],y_test[index])

 结果如下:

在结果中,label表示正确值,predict result表示模型测试的结果值

完成以上步骤,那么下面就是使用已保存的模型来预测我们自己手写的数字,将10.jpg改成你的手写图片路径

#模型使用
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib as m
import numpy as np
import cv2
import os

mnist=tf.keras.datasets.mnist
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# np.set_printoptions(threshold=np.inf)
#加载模型
def digit_predict():
    model=tf.keras.models.load_model('leNet_model.h5')
     
    #图片预处理
    img=cv2.imread('10.jpg')
    print(img.shape)
    plt.imshow(img)
    plt.show()
     
    #灰度图
    img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
    print(img.shape)
    plt.imshow(img,cmap='Greys')
    plt.show()
     
    #取反
    img=cv2.bitwise_not(img)
    plt.imshow(img,cmap='Greys')
    plt.show()
     
    print('二值化前:',img.shape)
    print('二值化前:',img)
    #纯黑 纯白 二值化
    img[img<=100]=0
    img[img>=140]=255
    plt.imshow(img,cmap='Greys')
    plt.show()
    print('二值化后:',img.shape)
    print('二值化后:',img)
     
    #尺寸
    img=cv2.resize(img,(32,32))
    print('尺寸:',img.shape)
    print('尺寸',img)
     
    #归一化
    img=img/255
    print('归一化:',img.shape)
    print('归一化:',img)#0和1组成

    #预测
    pred=model.predict(img.reshape(1,32,32,1))
    print('prediction Number: ',pred.argmax())
     
    #打印图片信息
    plt.imshow(img)
    plt.show()
     
digit_predict()

 下面是测试我自己手写数字的结果:

我写的数字是8,预测结果也是8,效果还可以 ,那么本篇文章就到此结束啦,感谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/592265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

apple pencil的替代品买啥比较好?平价电容笔推荐

随着技术的发展&#xff0c;出现了许多种类的电容笔。一款好的电容笔&#xff0c;不但可以极大地提升我们的工作效率&#xff0c;也可以极大地改善我们的学习效果。平替电容笔无论是在技术方面&#xff0c;还是在产品质量方面&#xff0c;都有着非常广泛的应用前景。下面就是我…

Java领域的序列化与反序列化,Java的对象如何传输,常用序列化技术

文章目录 一、引出问题&#xff1a;Java原生的序列化1、基于Socket传输对象案例2、什么是序列化3、Java 原生序列化4、serialVersionUID 的作用5、transient 关键字绕开 transient 机制的办法writeObject 和 readObject 原理 6、Java 序列化的一些简单总结 二、分布式架构下常见…

【智能座舱系列| AR-HUD增强现实】—AR-HUD到底是“鸡肋”还是“真”香?

AR-HUD 概念 HUD,即抬头显示(Head Up Display),又叫平视显示系统。它的作用,就是把时速、导航等重要的行车信息,投影到驾驶员前面的挡风玻璃上,让驾驶员尽量做到不低头、不转头就能看到。 这种显示系统,原是军用战斗机上的显示系统,飞行员不必低头,就能在挡风玻璃上…

ChatGPT学习笔记;Meta发布Megabyte AI模型抗衡Transformer

AI知识 ChatGPT学习笔记 文章包括如下的内容&#xff1a; ChatGPT 介绍科普 背景知识ChatGPT 功能ChatGPT 原理 等等&#xff0c;文章的地址在这里。 AI新闻 &#x1f680; Meta发布Megabyte AI模型抗衡Transformer&#xff1a;解决后者已知问题、速度提升4成 摘要&…

笔试强训5

作者&#xff1a;爱塔居 专栏&#xff1a;笔试强训 作者简介&#xff1a;大三学生&#xff0c;希望和大家一起进步 目录 day6 day7 day6 1.关于抽象类与最终类&#xff0c;下列说法错误的是&#xff1f; A 抽象类能被继承&#xff0c;最终类只能被实例化。 B 抽象类和最终类…

NET HELPMSG 3534 报错

使用了带管理员权限的 PowerShell&#xff08;即在管理员权限下运行CMD&#xff09; 然后进行安装和服务启动操作 1、清空 MySQL 下的 data 文件夹&#xff1b; 2、确保系统环境变量中已经配置了 mysql 的 bin 目录到Path中&#xff1b; 3、执行以下命令&#xff1a; sc delet…

《Opencv3编程入门》学习笔记—第四章

《Opencv3编程入门》学习笔记 记录一下在学习《Opencv3编程入门》这本书时遇到的问题或重要的知识点。 第四章 OpenCV数据结构与基本绘图 四、基础图像容器Mat &#xff08;一&#xff09;数字图像存储概述 图像在数码设备中的表现形式为包含众多强度值的像素点矩阵。 &a…

JAVA键盘录入

文章目录 JAVA键盘录入1.导包2.创建对象3.接受数据接收 b o o l e a n \color{red}{boolean} boolean类型数据接收 b y t e \color{red}{byte} byte类型数据接收 s h o r t \color{red}{short} short类型数据接收 i n t \color{red}{int} int类型数据接收 l o n g \color{red}{…

库的制作与使用

什么是库 库是一种可执行的二进制文件&#xff0c;是编译好的代码。使用库可以提高开发效率。在 Linux 下有静态库和动 态库。因此编译出来的体积就比较大。 静态库在程序编译的时候会被链接到目标代码里面。所以程序在运行的时候不再需要静态库了。因此 编译出来的体积就比较大…

轻松学习白嫖GPT-4,已经标星38K,不再害怕高昂的AI模型费用!

文章目录 白嫖方式GPT-4当前可用站点 白嫖方式GPT-4 计算机专业学生xtekky在GitHub上发布了一个名为gpt4free的开源项目&#xff0c;该项目允许您免费使用GPT4和GPT3.5模型。这个项目目前已经获得了380000颗星。 开源地址&#xff1a;https://github.com/xtekky/gpt4free 简而…

vue ts写法

Vue.js 和 TypeScript 结合使用可以让你的项目更加健壮和易于维护。在 Vue 3 中&#xff0c;你可以使用 Vue.js 的 Composition API 和 TypeScript 一起使用。以下是一个简单的 Vue.js 和 TypeScript 结合使用的例子&#xff1a; 首先&#xff0c;确保你已经安装了 Vue.js 和 T…

如何从电机控制转换为运动控制

随着越来越多的技术广泛应用于工业自动化&#xff0c;我们已经进入了工业4.0时代。新技术不断涌现&#xff0c;赋能人工智能和机器学习、数据分析、工业网络、网络安全和功能安全。然而&#xff0c;大多数工业自动化作为其他所有技术的核心&#xff0c;仍然依靠机器人和运动控制…

【PWN · ret2text | PIE 】[NISACTF 2022]ezpie

简单的PIE绕过 目录 前言 一、题目重述 二、解题思路 1.现有信息 2.思考过程 3.exp 总结 前言 所接触的PIE保护的第一题&#xff0c;也非常简单。 一、题目重述 二、解题思路 1.现有信息 PIE保护——程序可能被加载到任意位置&#xff0c;所以位置是可变的。程序返回…

聚观早报 | 英伟达推「AI」超算;中国2030年前载人登月

今日要闻&#xff1a;英伟达推「AI」超算&#xff1b;中国2030年前载人登月&#xff1b;AI大热&#xff0c;游戏股全线大涨&#xff1b;ofo创始人二次创业项目陷入困境&#xff1b;微信视频号原创标记已对外显示 英伟达推「AI」超算 5 月 29 日&#xff0c;NVIDIA 宣布推出一款…

安捷伦E4440A 26.5G频谱分析仪Agilent e4440a 销售/回收

Agilent E4440A HP E4440A频谱分析仪&#xff0c;3 Hz - 26.5 GHz&#xff08;PSA 系列&#xff09; Agilent / Keysight PSA 系列 E4440A 高性能频谱分析仪提供强大的一键式测量、多功能功能集和前沿技术&#xff0c;可满足您的项目和需求。选项可供您选择&#xff08;详情请…

maven 项目中引入第三方jar,并且打包到项目的运行jar包中

背景说明 项目中遇到了人大金仓数据库的jar连接驱动&#xff0c;需要在maven中引入依赖信息 实践 方案1&#xff1a; 1.在官网下载jar包&#xff0c;https://www.kingbase.com.cn/zxwd/index.htm 下载地址。在项目文件中创建libs目录。 修改pom文件的配置信息 <depende…

如何在 Windows 中检查打开的TCP/IP端口

每当应用程序想要通过网络访问自己时,它都会声明一个TCP/IP端口,这意味着该端口不能被其他任何东西使用。那么,如何检查打开的端口以查看哪个应用程序已经在使用它呢? 检查打开的TCP/IP端口 查看端口使用和进程名称查看端口使用和进程标识符查看端口使用和进程名称 首先,你…

【完全揭秘】Traefik云原生网关——助力你的业务破万QPS

Traefik 是一款开源的反向代理和负载均衡软件&#xff0c;可以自动地为多个微服务实例进行负载均衡&#xff0c;并提供 HTTP/HTTPS/TCP/UDP 等协议支持。 Traefik 具有简单易用、自动发现服务、动态配置、可插拔的中间件等特点&#xff0c;被广泛应用于云原生和容器化场景中&am…

【随时更新】面试所需算法数据结构计算机知识点回顾

操作系统LRU算法 MySQL B树 哈夫曼编码和解码 C 哈夫曼编码 【介绍编码过程】 哈夫曼树编码及其图形化的实现 【使用可视化方式展现最终编码效果】 Python中使用哈夫曼算法实现文件的压缩与解压缩 【Python实现】 哈夫曼树 C语言实现 【图解如何生成】 编码过程 1. 使用二进…

基于SpringBoot+Vue的素材管理系统

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; 随着数字化时代的到来…