keras深度学习框架通过简单神经网络实现手写数字识别

news2025/1/13 21:15:48

   背景

     keras深度学习框架,并不是一个独立的深度学习框架,它后台依赖tensorflow或者theano。大部分开发者应该使用的是tensorflow。keras可以很方便的像搭积木一样根据模型搭出我们需要的神经网络,然后进行编译,训练,测试,预测。

    今天介绍的手写数字识别实验,主要是熟悉keras搭建神经网络的流程,以及大体的思路。现如今,手写数字识别实验的代码各种各样,对于初学者而言,我们需要的是类似helloworld那样简单的示例。通过示例,我们可以了解神经网络的搭建过程。

    这里使用的手写数字识别,通过搭建网络,构建模型,最后保存模型,然后我们加载模型,通过真实的图片来预测,也检验一下神经网络的能力。

     这里手写数字识别数据来源于官方自带mnist数据集,这个数据集包含60000个训练集和10000个测试集。每个数据是由28 * 28 = 784个矩阵元素组成。所以我们自己用来测试的图片最后应该也要按照这个28*28的尺寸来制作,并且最后进行预测predict的时候,也要像训练集或者测试集一样,把图片转为一个784元素的数组。

    准备代码

import keras
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation
from tensorflow.keras import datasets, utils
import matplotlib.pyplot as plt


(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape((-1, 28*28))
x_train = x_train.astype('float32')/255
x_test = x_test.reshape((-1, 28*28))
x_test = x_test.astype('float32')/255

y_train = utils.to_categorical(y_train, num_classes=10)
y_test = utils.to_categorical(y_test, num_classes=10)

print('x_train.shape', x_train.shape)
print('x_test.shape', x_test.shape)
print('y_train.shape', y_train.shape)
print('y_test.shape', y_test.shape)
"""
layer = [Dense(32, input_shape=(784,)),
         Activation('relu'),
         Dense(10),
         Activation('softmax')]

model = Sequential(layer)
"""
model = Sequential()
# model.add(Dense(units=784, activation="relu", input_dim=784))
model.add(Dense(512, activation="relu", input_shape=(28*28, )))
model.add(Dense(10, activation="softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()

history = model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test))

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label="Training accuracy")
plt.plot(epochs, val_acc, 'b', label="Validation accuracy")
plt.title('Training and Validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
model.save("mnist.h5")
prediction = model.predict(x_test[:1], batch_size=32)
print(x_test[:1])
print(y_test[:1])
print(prediction)
print(np.argmax(prediction, axis=1))

    这个代码在引入了相关库之后,进行的第一件事就是数据处理:

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape((-1, 28*28))
x_train = x_train.astype('float32')/255
x_test = x_test.reshape((-1, 28*28))
x_test = x_test.astype('float32')/255
y_train = utils.to_categorical(y_train, num_classes=10)
y_test = utils.to_categorical(y_test, num_classes=10)

print('x_train.shape', x_train.shape)
print('x_test.shape', x_test.shape)
print('y_train.shape', y_train.shape)
print('y_test.shape', y_test.shape)

     我们的数据集x_train,x_test就是我们的图片数据,这个数据是784个元素组成的数组,我们先进行转矩阵,然后对像素点取模,得到0-1之间的值。我们代码最后打印了x_test[:1],可以看看它的样子:

    这里我们还使用了utils.to_categorical(y_test,num_classes=10) 对我们的目标进行了one-hot转码。通过这个图我们也看到了,数字 7 转了one-hot编码之后,变为了[0,0,0,0,0,0,0,1,0,0]。

    这个代码构建了一个简单的神经网络,也就两层,

     第一层输入层 Dense(512,activation="relu",input_shape=(28*28, ))  #512个节点,relu激活函数,输入形状或者维度 28*28=784。代码中也给出了另一种通过input_dim来指定维度的方法,意思是一样的,但是那种写法model.add(Dense(units=784, activation="relu", input_dim=784))指定的网络节点units=784。这个数字可以随便定义。手写数字识别里面,设置512,784都可以。

    第二层输出层 Dense(10, activation="softmax") #这里指定对应十个分类,也就是数字0,1,2,3,4,5,6,7,8,9的个数。手写数字识别是一个多分类问题。

     没有隐藏层,也没有其他的Dropout。就是简单神经网络。

     另外,代码中还给出了一种构建神经网络的办法:

layer = [Dense(32, input_shape=(784,)),
         Activation('relu'),
         Dense(10),
         Activation('softmax')]

model = Sequential(layer)

    意思是一样的,只不过,这里units=32,也就是输入层由32个神经网络节点组成。 

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()

    这是编译神经网络和打印神经网络概要。

    编译神经网络传入loss="categorical_cressentropy" 表示损失函数求的是交叉熵。optimizer="adam",表示优化器是adam,表示自适应算法,另外,也有可能会看到sgd,随机梯度下降算法,或者rmsprop也是一种自适应算法。metrics=["accuracy"]统计指标,这里指定成功率。 

   通过model.summary()我们可以看到神经网络节点信息: 

   

history = model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test))

    这里是把训练和测试神经网络放在一起了,我们传入的validation_data指定了测试数据集。如果不指定validation_data,那么后面,我们通过model.evaluate(x_test,y_test) 也可以得到loss,acc等数据。

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label="Training accuracy")
plt.plot(epochs, val_acc, 'b', label="Validation accuracy")
plt.title('Training and Validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

    我们通过matplot来展示acc,val_acc等信息,结果如下图所示:

       我们还通过model.save("mnist.h5")保存模型,后面我们会加载这个模型来进行预测。

prediction = model.predict(x_test[:1], batch_size=32)
print(x_test[:1])
print(y_test[:1])
print(prediction)
print(np.argmax(prediction, axis=1))

    我们简单通过测试集的第一个数字7来进行了一个验证,这个验证,主要是要知道我们将来传入图片需要什么类型的数据,以及得到预测结果之后,怎么取值。这里prediction是一个按照概率来进行组装的数组,哪个概率大,最终的结果就是谁。我们通过np.argmax(prediction, axis=1)指定获取一个数组中按行(axis=1)来统计最大的那个数。

***************************************************************

    预测

     很多代码示例里面,基本上到了model.evaluate()对算法进行评估之后,就没有了,对于刚入门的人来说,神经网络创建了,测试了,好不好用也不知道。因为这个训练集和测试机都是官网给出的例子,对于程序员来说,通过实践来验证一个猜测,那才是最重要的,至于这是什么不重要。

     上面的代码最后,我们通过测试集x_test[:1]也就是第一个测试数字简单做了一个预测,大概知道了要预测,需要的数据是一个[28*28=784]的数组。而我们准备的测试图片应该也要和官方给出的测试数据对应上,也即是前面提到的图片是28*28像素的数字图片,如下所示:

    同样的给出代码:

import keras
import numpy as np
import cv2
from keras.models import load_model

model = load_model("mnist.h5")


def predict(img_path):
    img = cv2.imread(img_path, 0)
    img = img.reshape(28, 28).astype("float32") / 255  # 0 1
    img = img.reshape(-1, 784)  # 28 * 28 -> 784
    label = model.predict(img)
    label = np.argmax(label, axis=1)
    print('{} -> {}'.format(img_path, label[0]))


if __name__ == '__main__':
    for _ in range(10):
        predict("number_images/b_{}.png".format(_))

    这些图片我们放在number_images目录下,命名规则是b_0.png,b_1.png这样子。

    最后,我们加载模型,并通过opencv库加载图片,并转换图片矩阵为784个元素的数组。然后交给模型预测,预测结果是一个概率数组,取概率最大的那个数组元素。

     预测结果如下:

    结果很感人,并没有达到很高的概率,准确率60%,而且这个概率对于手写图片识别来说,还有点偏高,因为实际上很多数字图片识别错误。 

    这篇文章,主要就keras构建简单神经网络,并进行训练,测试,最后还通过我们自己手写的数字图片来进行预测验证,也过了一把深度学习的瘾。

    本文keras和tensorflow版本是2.8.0,可能有几个api与其他地方有区别,比如datasets,这里使用的是tensorflow.keras.datasets。另外在计算成功率acc的时候,使用的是history['accuracy'],有的地方可能直接是history['acc'],应该是版本的问题,根据自己的版本找到合适的方法就行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/935709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4.22 TCP 四次挥手,可以变成三次吗?

目录 为什么 TCP 挥手需要四次呢? 粗暴关闭 vs 优雅关闭 close函数 shotdown函数 什么情况会出现三次挥手? 什么是 TCP 延迟确认机制? TCP 序列号和确认号是如何变化的? 在一些情况下, TCP 四次挥手是可以变成 T…

如何识别PCI/PCIE设备需要申请多大的地址空间?

1、PCI/PCIE设备的配置空间 (1)PCI/PCIE设备需要的资源都在配置空间里进行指定,其中需要的地址空间资源在配置空间的基地址寄存器里指定; (2)参考博客:《PCI设备和PCI桥的配置空间(header_type0、header_type1)和配置命令(type0、type1)详解》…

交叉编译 libzdb

参考博客:移植libzdb3.2.2到arm_configure: error: no available database found or s_酣楼驻海的博客-CSDN博客 编译时间 2023-08-23 libzdb 下载: 源码访问如下: https://bitbucket.org/tildeslash/libzdb/src/master/ git 下载链接 …

JavaScript设计模式(一)——构造器模式、原型模式、类模式

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

Lottery抽奖项目学习第二章第一节:环境、配置、规范

Lottery抽奖项目学习第二章第一节:环境、配置、规范 环境、配置、规范 下面以DDD架构和设计模式落地实战的方式,进行讲解和实现分布式抽奖系统的代码开发,那么这里会涉及到很多DDD的设计思路和设计模式应用,以及互联网大厂开发中…

【GoLang】go入门:go语言执行过程分析 常见数据类型(基本数据类型)

1、go语言执行过程分析 【1】执行流程分析 通过 go build 进行编译 运行上一步生成的可执行文件 通过 go run 命令直接运行 【2】上述两种执行流程的区别 在编译时,编译器会将程序运行时依赖的库文件包含在可执行文件中,所以可执行文件会变大很多通过g…

c++中的基本类型

专栏简介:为什么我要重新介绍c的相关知识,在此之前,我对于c的了解也仅仅是在表面。而在后来与c慢慢的接触中,c编程语言越来越让我觉得深奥,所以还是想要重新开创一个专栏来介绍c。对于c的介绍,本专栏会先介…

学会shell 基本语法,玩转linux

01、获取当前时间,年月日时分秒 now$(date %Y%m%d%H%M%S) echo "$now" 输出为:20181202222727 02、date 在脚本中的几种用法 date %Y 以 4 位数字格式打印年份 date %y 以 2 位数字格式打印年份 date %m 月份 date %d 日期 date %H 小时 d…

IO模型和NGINX安装升级

IO模型和NGINX安装升级 IO模型 IO概念 I/O在计算机中指Input/Output, IOPS (Input/Output Per Second)即每秒的输入输出量(或读写次数),是衡量磁盘性能的主要指标之一。 Linux的IO类型 磁盘I/O 磁盘I/O是进程向内核发起系统调用,请求磁…

抖店出单后怎么操作?谈厂家话术与发货注意事项,抖店最新教程

我是王路飞。 当你的抖店出单后,你是怎么操作的? 还是像之前那样去拼多多代拍发货?这样做的商家,不知道你的店铺被封了几个了? 记住,现在抖店出单后,一定不要再去多多拍单发货了! 关于抖店…

0基础入门C++之类和对象下篇

目录 1.再谈构造函数1.1构造函数赋值1.2初始化列表1.3explicit关键字 2.static成员2.1概念2.1静态成员变量2.2静态成员函数2.3特性 3.匿名对象4.友元函数4.1友元函数4.2友元类 5.内部类6.再次理解类和对象 1.再谈构造函数 首先我们先来回忆一下构造函数: 构造函数是…

【springboot】springboot定时任务:

文章目录 一、文档:二、案例: 一、文档: 【cron表达式在线生成器】https://cron.qqe2.com/ 二、案例: EnableScheduling //开启任务调度package com.sky.task;import com.sky.entity.Orders; import com.sky.mapper.OrderMapper; …

CAN总线学习——物理层、数据链路层、CANopen协议

1、CAN总线介绍 1.1、CAN总线描述 (1)CAN总线支持多节点通信,但是节点不分区主从,也就是不存在一个节点来负责维护总线的通信;这点可以和I2C总线对对比,I2C是一主多从模式; (2)是差分、异步、串行总线,采用…

分布式—雪花算法生成ID

一、简介 1、雪花算法的组成: 由64个Bit(比特)位组成的long类型的数字 0 | 0000000000 0000000000 0000000000 000000000 | 00000 | 00000 | 000000000000 1个bit:符号位,始终为0。 41个bit:时间戳,精确到毫秒级别&a…

Python slice(切片)

在Python中,切片(slice)是对序列型对象(如list, string, tuple)的一种高级索引方法。普通索引只取出序列中一个下标对应的元素,而切片取出序列中一个范围对应的元素,这里的范围不是狭义上的连续片段。 切片的基本语法为: object…

ChromeOS 的 Linux 操作系统和 Chrome 浏览器分离

导读科技媒体 Ars Technica 报道称,谷歌正在将 ChromeOS 的浏览器从操作系统中分离出来 —— 让它变得更像 Linux。虽然目前还没有任何官方消息,但这项变化可能会在本月的版本更新中推出。 据介绍,谷歌将该项目命名为 "Lacros"——…

防溺水预警识别系统算法

防溺水预警识别系统旨在通过opencvpython网络模型深度学习算法,防溺水预警识别系统算法实时监测河道环境,对学生等违规下水游泳等危险行为进行预警和提醒。Python是一种由Guido van Rossum开发的通用编程语言,它很快就变得非常流行&#xff0…

strcat函数

目录 函数介绍: 函数声明: 具体使用: 注意事项: 字符串⾃⼰给⾃⼰追加,如何? 模拟实现strcat函数: 函数介绍: 被称为字符串的追加/链接函数,它的功能就是在一个字符…

gcc/linux下的c++异常实现

概述 本文不一定具有很好的说教性,仅作为自我学习的笔记。不妨可参阅国外大神博文C exceptions under the hood链接中包含了大量的例子。 偶有在对ELF做分析的时候看到如下图一些注释,部分关键字看不懂,比如什么FDE, unwind , __gxx_perso…

【技巧分享】如何获取子窗体选择了多少记录数?一招搞定!

Hi,大家好久不见。 我这个更新速度是不是太慢了呀,因为,最近又又又在忙,请大家谅解啦。 现在更新文章、视频都要花好久去考虑,好不容易有个灵感了,一搜索,结果发现之前都已经分享过了(委屈脸&…