机器学习实战 | MNIST手写数字分类项目(深度学习初级)

news2025/1/17 6:07:18

目录

  • 简介
  • 技术流程
    • 1. 载入依赖包和数据集
    • 2. 数据预处理
    • 3. 创建卷积神经网络模型
    • 4. 训练神经网络
    • 5. 评价网络
  • 完整程序
    • train.py 程序
    • gui.py程序

简介

准备写个系列博客介绍机器学习实战中的部分公开项目。首先从初级项目开始。


本文为初级项目第二篇:利用MNIST数据集训练手写数字分类。
项目原网址为:Deep Learning Project – Handwritten Digit Recognition using Python。

第一篇为:机器学习实战 | emojify 使用Python创建自己的表情符号(深度学习初级)

技术流程

项目构想:
MNIST数字分类项目,使机器能够识别手写数字。该Python项目对于计算机视觉可能非常有用。在这里,我们将使用MNIST数据集使用卷积神经网络训练模型

经过训练后,在GUI页面(gui.py程序)显示效果如下:左边是手写数字,通过鼠标手写键入;右边点击recognise会提示训练结果以及识别置信度。
在这里插入图片描述

1. 载入依赖包和数据集

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape, y_train.shape)

除了常规包外,同样需要提前配置KerasTensorFlow,安装命令为:

pip install keras==2.10.0
pip install TensorFlow==2.10.0

这里需要注意MNIST手写数据集导入方法,直接从Keras中加载:keras.datasets.mnist

通过mnist.load获取训练数据和测试数据,训练数据集维度为: 60000 × 28 × 28 60000 \times 28\times28 60000×28×28,测试数据集维度为: 10000 × 28 × 28 10000 \times 28\times28 10000×28×28.

2. 数据预处理

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)

# convert class vectors to binary class matrices
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
  • x_train.reshape:将图像数据转换为神经网络输入,图像大小 60000 × 28 × 28 60000 \times 28\times28 60000×28×28,输出大小为 60000 × 28 × 28 × 1 60000 \times 28\times28\times1 60000×28×28×1
  • keras.utils.to_categorical:将阿拉伯数字的0-9共10个数字(类别)转换为one-shot特征,用二进制表示分类类别,比如数字0用0000表示,数字1用0001表示,数字2用0010表示。
  • x_train /= 255:将图像数据归一化,首先将数据类型转换为float32,接着将数据归一化到0~1范围内。

3. 创建卷积神经网络模型

batch_size = 128
num_classes = 10
epochs = 50

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),activation='relu',input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adadelta(),metrics=['accuracy'])

该项目设计了卷积神经网络(CNN)模型,包括两层卷积层、池化层、全连接层等。

函数分析:

  • Sequential:序贯模型,与函数式模型对立。from keras.models import Sequential, 序贯模型通过一层层神经网络连接构建深度神经网络。
  • add(): 叠加网络层,参数可为conv2D卷积神经网络层,MaxPooling2D二维最大池化层,Dropout随机失活层(防止过拟合),Dense密集层(全连接FC层,在Keras层中FC层被写作Dense层),下面会详细介绍这几个层的含义和参数设置。
  • compile(): 编译神经网络结构,参数包括:loss,字符串结构,指定损失函数(包括MSE等);optimizer,表示优化方式(优化器),用于控制梯度裁剪;metrics,列表,用来衡量模型指标,表示评价指标。

网络结构介绍:

  • conv2D: 卷积神经网络层,参数包括:
  1. filters: 层深度(纵向),一般来说前期数据减少,后期数量逐渐增加,建议选择 2 N 2^N 2N作为深度,比如说:[32,64,128] => [256,512,1024];
  2. kernel_size: 决定了2D卷积窗口的宽度和高度,一般设置为 ( 1 × 1 ) (1\times1) (1×1) ( 3 × 3 ) (3\times3) (3×3) ( 5 × 5 ) (5\times5) (5×5) ( 7 × 7 ) (7\times7) (7×7).
  3. activation:激活函数,可选择为:sigmoid,tanh,relu等
  • MaxPooling2D: 池化层,本质上是采样,对输入的数据进行压缩,一般用在卷积层后,加快神经网络的训练速度。没有需要学习的参数,数据降维,用来防止过拟合现象。
  • Dropout:防过拟合层,在训练时,忽略一定数量的特征检测器,用来增加稀疏性,用伯努利分布(0-1分布)B(1,p)来随机忽略特征数量,输入参数为p的大小
  • Flatten:将多维输入数据一维化,用在卷积层到全连接层的过渡,减少参数的使用量,避免过拟合现象,无参。
  • Dense:全连接层,将特征非线性变化映射到输出空间上。

4. 训练神经网络

hist = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(x_test, y_test))
print("The model has successfully trained")

model.save('mnist.h5')
print("Saving the model as mnist.h5")

  • model.fit:在搭建完成后,将数据送入模型进行训练。参数包括:
  1. x:训练数据输入;
  2. y:训练数据输出;
  3. batch_size: batch样本数量,即训练一次网络所用的样本数;
  4. epochs:迭代次数,即全部样本数据将被“轮”多少次,轮完训练停止;
  5. verbose:可选训练过程中信息是否输出参数,0表示不输出信息,1表示显示进度条(一般默认为1),2表示每个epoch输出一行记录;
  6. valdation_data:验证数据集。
  • model.save:保存训练模型权重

训练成功后,会在源目录下保存mnist.h5文件,即为权重文件。

5. 评价网络

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
  • model.evaluate:评价网络,返回值是一个浮点数,表示损失值和评估指标值,输入参数为测试数据,verbose表示测试过程中信息是否输出参数,同样verbose=0表示不输出测试信息。

完整程序

train.py : 完整训练代码。

gui.py: GUI窗口,输出可互动的界面。

train.py 程序

"""
Handwrittern digit recognition
"""

"""
1. Import the libraries and load the dataset
"""
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape, y_train.shape)

"""
2. Preprocess the data
"""
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)

# convert class vectors to binary class matrices
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

"""
3. Create the model
"""
batch_size = 128
num_classes = 10
epochs = 50

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),activation='relu',input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adadelta(),metrics=['accuracy'])

"""
4. Train the model
"""
hist = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(x_test, y_test))
print("The model has successfully trained")

model.save('mnist.h5')
print("Saving the model as mnist.h5")

"""
5. Evaluate the model
"""
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

训练结果会保存在源目录下,生成文件名为:mnist.h5

gui.py程序

from keras.models import load_model
from tkinter import *
import tkinter as tk
import win32gui
from PIL import ImageGrab, Image
import numpy as np

model = load_model('mnist.h5')

def predict_digit(img):
    #resize image to 28x28 pixels
    img = img.resize((28, 28))
    #convert rgb to grayscale
    img = img.convert('L')
    img = np.array(img)
    #reshaping to support our model input and normalizing
    img = img.reshape(1, 28, 28, 1)
    img = img/255.0
    #predicting the class
    res = model.predict([img])[0]
    return np.argmax(res), max(res)

class App(tk.Tk):
    def __init__(self):
        tk.Tk.__init__(self)

        self.x = self.y = 0

        # Creating elements
        self.canvas = tk.Canvas(self, width=300, height=300, bg = "white", cursor="cross")
        self.label = tk.Label(self, text="Thinking..", font=("Helvetica", 48))
        self.classify_btn = tk.Button(self, text = "Recognise", command =self.classify_handwriting)
        self.button_clear = tk.Button(self, text = "Clear", command = self.clear_all)

        # Grid structure
        self.canvas.grid(row=0, column=0, pady=2, sticky=W, )
        self.label.grid(row=0, column=1,pady=2, padx=2)
        self.classify_btn.grid(row=1, column=1, pady=2, padx=2)
        self.button_clear.grid(row=1, column=0, pady=2)

        #self.canvas.bind("<Motion>", self.start_pos)
        self.canvas.bind("<B1-Motion>", self.draw_lines)

    def clear_all(self):
        self.canvas.delete("all")

    def classify_handwriting(self):
        HWND = self.canvas.winfo_id() # get the handle of the canvas
        rect = win32gui.GetWindowRect(HWND) # get the coordinate of the canvas
        im = ImageGrab.grab(rect)

        digit, acc = predict_digit(im)
        self.label.configure(text= str(digit)+', '+ str(int(acc*100))+'%')

    def draw_lines(self, event):
        self.x = event.x
        self.y = event.y
        r=8
        self.canvas.create_oval(self.x-r, self.y-r, self.x + r, self.y + r, fill='black')

app = App()
mainloop()

gui.py程序中用了tkinter包来呈现GUI页面,具体语句这里就不再分析解释了,需要学习的话可以参考以下链接:Python GUI编程(Tkinter)

gui.py运行后,输出页面为:

在这里插入图片描述
通过键盘在左侧手写字符,点击recognise输出识别结果。
在这里插入图片描述


如有问题,欢迎指出和讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/738964.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

111、基于51单片机的电磁感应无线充电系统 手机无线充电器设计(程序+原理图+Proteus仿真+程序流程图+论文参考资料等)

方案选择 单片机的选择 方案一&#xff1a;AT89C52是美国ATMEL公司生产的低电压&#xff0c;高性能CMOS型8位单片机&#xff0c;器件采用ATMEL公司的高密度、非易失性存储技术生产&#xff0c;兼容标准MCS-51指令系统&#xff0c;片内置通用8位中央处理器(CPU)和Flash存储单元…

在SpringBoot中对微服务项目的简单使用

准备数据库的数据 create database leq_sc; CREATE TABLE if not exists products(id INT PRIMARY KEY AUTO_INCREMENT,name VARCHAR(50), #商品名称 price DOUBLE,flag VARCHAR(2), #上架状态 goods_desc VARCHAR(100), #商品描述images VARCHAR(400), #商品图?goods_stock I…

[工业互联-21]:常见EtherCAT主站方案:Kithara实时套件

第1章 Kithara实时套件概述 1.1 概述 Kithara Software是一家德国的软件公司&#xff0c;专注于实时技术和嵌入式解决方案。 他们为Windows操作系统提供了Kithara RealTime Suite&#xff0c;这是一套实时扩展模块&#xff0c;使Windows能够实现硬实时任务和控制。 Kithara…

菜比:你还不会接口测试?

很多人会谈论接口测试。到底什么是接口测试&#xff1f;如何进行接口测试&#xff1f;这篇文章会帮到你。 一、前端和后端 在谈论接口测试之前&#xff0c;让我们先明确前端和后端这两个概念。 前端是我们在网页或移动应用程序中看到的页面&#xff0c;它由 HTML 和 CSS 编写…

QT - 20230710

练习&#xff1a;实现一个简易闹钟 widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QDateTime> #include <QDebug> #include <QTextToSpeech>namespace Ui { class Widget; }class Widget : public QWidget {Q_OBJECTpubl…

Kafka入门, Kafka-Kraft 模式 部署(二十六)

Kafka-Kraft 模式 左图为kafka现有架构&#xff0c;元数据在zookeeper中&#xff0c;运行时动态选举controller,由controller进行kafka集群管理&#xff0c;右图为kraft模式架构&#xff08;实验性&#xff09;&#xff0c;不再依赖zookeeper集群&#xff0c;而是用三台control…

MyBatis将查询的两个字段分别作为Map的key和value

问题背景 首先查出 危险源id 和 危险源报警的个数 alarm 遍历危险源&#xff0c;将报警数填充进去 所以&#xff0c;我需要根据id得到alarm 最方便的就是Map 经过sql查询 -- 危险源下的对象的报警个数select id, ifnull(alarm_count,0) alarm from spang_monitor_danger_…

LongLLaMA:LLaMA的升级版,处理超长上下文的利器!

原文来源&#xff1a;芝士AI吃鱼 有效整合新知识&#xff1a;大模型面临的挑战 大家使用过大型模型产品的时候可能会遇到一个共同的问题&#xff1a;在进行多轮对话时&#xff0c;模型可能会忘记之前的对话内容&#xff0c;导致回答不连贯。这实际上是由于大型模型在处理大量新…

ARM day10 (IIC协议接收温湿传感器数据)

iic.h #ifndef __IIC_H__ #define __IIC_H__ #include "stm32mp1xx_gpio.h" #include "stm32mp1xx_rcc.h" /* 通过程序模拟实现I2C总线的时序和协议* GPIOF ---> AHB4* I2C1_SCL ---> PF14* I2C1_SDA ---> PF15** */#define SET_SDA_OUT do{…

c++实现贝塞尔曲线,生成缓动和回弹动画

贝塞尔曲线于1962年由法国工程师皮埃尔贝塞尔(Pierre Bzier)所广泛发表,他运用贝塞尔曲线来为汽车的主体进行设计。 一般参数公式 贝兹曲线可如下推断。给定点P0、P1、…、Pn,其贝兹曲线即: 几何学的方向上理解贝塞尔曲线: 一阶贝塞尔曲线 二阶贝塞尔曲线 三阶贝塞尔曲…

记录使用注入的方式为Unity编辑器实现扩展能力

使用场景 当前项目编辑器中不方便存放或者提交扩展代码相同的扩展功能需要在多个项目(编辑器)中使用项目开发中&#xff0c;偶尔临时需要使用一个功能&#xff0c;想随时使用随时卸载 设计思路 使用进程注入&#xff0c;将一个c/c dll注入到当前运行的unity编辑器中使用c/c …

分布式搜索 (二)

一、DSL 查询文档 1. DSL Query 的分类 Elasticsearch 提供了基于 JSON 的 DSL (Domain Specific Language) 来定义查询 常见的查询类型包括&#xff1a; ① 查询所有&#xff1a;查询出所有数据&#xff0c;一般测 试用 例如&#xff1a;match_all ② 全文检索 (full text) …

C++数据结构笔记(8)循环链表实现

1.循环链表与单链表的区别在于尾部结点存在指向头结点的指针 2.无论尾部结点指向第一个结点&#xff08;头结点&#xff09;还是第二个结点&#xff08;第一个有效结点&#xff09;&#xff0c;都可以被称为循环链表 3.判断循环结束的两种方式&#xff1a;遍历次数等于size;或…

《深度探索c++对象模型》笔记

非原创&#xff0c;在学习 1 关于对象&#xff08;Object Lessons&#xff09; 这里最开始从C语言的结构体引出C中的”抽象数据类型&#xff08;ADT&#xff09;“。 而加上封装之后&#xff0c;布局成本没有增加&#xff0c;三个data member直接内含在每一个class object之中…

深入选择屏幕

2.3.4.4 屏幕输入报表筛选条件等 &--------------------------------------------------------------------- *& selection-screen /option/parameter:屏幕输入报表赛选条件 *& TABLES . *selection-screen begin of block test select-options: selection-screen…

PHY芯片快速深度理解

摘要&#xff1a; 什么是phy 为什么要熟悉RJ45网口 网络七层协议 两个模块进行通信 什么是MDIO协议 MDIO的作用 MDIO没那么重要 MDIO读写时序 为什么说读取的phy最多32个 什么是phy 物理层芯片称为PHY、数据链路层芯片称为MAC。 可以看到PHY的数据是RJ45网络接口&am…

linux常见指令下

接下来我们就聊聊linux的后面十条指令。 一:echo 作用是往显示器输出内容&#xff0c;和printf类型&#xff0c;但是该指令最核心的是与之相关的一些概念 概念1.输出重定向&#xff1a; echo不仅可以向显示打印内容&#xff0c;还可以向文件输出内容&#xff0c;本应该输出到…

在服务器上启动springboot项目

环境搭建&#xff1a;要在服务器上运行SpringBoot Web项目&#xff0c;需要先在服务器上安装JDK&#xff08;CentOS系统安装JDK参考&#xff1a;http://t.csdn.cn/0zYml&#xff09; 第一步&#xff1a;创建项目 创建一个简单的springboot项目&#xff0c;并通过测试&#xf…

Java Web Servlet (2)23.7.8

1.7 urlPattern配置 Servlet类编写好后&#xff0c;要想被访问到&#xff0c;就需要配置其访问路径&#xff08;urlPattern&#xff09; 一个Servlet,可以配置多个urlPattern package com.itheima.web;import javax.servlet.ServletRequest; import javax.servlet.ServletRes…

嵌入式基础知识-流水线

提到流水线&#xff0c;最先想到的可能是流水线车间中的产品制造过程。 工业上的流水线&#xff0c;又称装配线&#xff0c;指每一个生产单位只专注处理某一个片段的工作&#xff0c;以提高工作效率及产量。 在计算机领域中&#xff0c;也有流水线的概念&#xff0c;其核心原理…