CNN神经网络——手写体识别

news2024/12/23 7:37:23

目录

Load The Datesets

Defining,Training,Measuring CNN Algorithm

Datasets

GRAET HONOR TO SHARE MY KNOWLEDGE WITH YOU

This paper is going to show how to use keras to relize a CNN model for digits classfication

Load The Datesets

The datasets files are shown in the follwing chart:

we see, there are all .gz file, which is a common zipped file format. And differ from some datasets needed to be splitted as training sest and testting set, this datasets shown above is already be splitted as traning and testing set, where file t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz is the features and labels of testing sets respectively, while the other two is features and labels of traning sets. When upzipping those files,  one could find there is a binary file:

As we see, this file is not CSV , XLSX, DATA or other image file like PNG, JPG or JPEG. So How to read it to Python is another problem we have to deal with. Since is not our focus, there I just show the code:

def data_generate():
    import gzip
    datadir = r'..\数据集\MNIST_data'   #datadir为解压路径
    sources = ['t10k-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz',
               'train-images-idx3-ubyte.gz','train-labels-idx1-ubyte.gz']
    
    """函数说明:def extract_tar函数用于解压某个tar.gz的压缩文件。"""
    def extract_tar(datafile,extractdir):   #定义一个解压缩函数
        file = datafile.replace(".gz","")
        g_file = gzip.GzipFile(datafile)
        #读取解压后的文件,并写入去掉后缀名的同名文件(即得到解压后的文件)
        open(file, "wb+").write(g_file.read())    #将文件加压缩到压缩文件所在的文件夹中
        g_file.close()
        print("%s 解压完成."%datafile)
        return file   #返回解压缩后文件的名称字符串
    data_file = []    
    for source in sources:    #通过遍历解压文件夹datasets_q8中所有文件
        datafile = r'%s\%s' %(datadir,source)   #指定待解压缩文件
        file = extract_tar(datafile,datadir)    #将文件压缩到路径datadir
        data_file.append(file)
    
    
    """很多类型的文件,其起始的几个字节的内容是固定的(或是有意填充,
    或是本就如此)。根据这几个字节的内容就可以确定文件类型,
    因此这几个字节的内容被称为魔数 (magic number)。"""
    import struct
    import numpy as np
    def decode_idx1_ubyte(idx1_ubyte_file):
        """
        解析idx1文件的通用函数
        :param idx1_ubyte_file: idx1文件路径
        :return: 数据集
        """
        # 读取二进制数据
        bin_data = open(idx1_ubyte_file, 'rb').read()
    
        # 解析文件头信息,依次为魔数和标签数
        offset = 0
        fmt_header = '>ii'
        magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
        print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))
    
        # 解析数据集
        offset += struct.calcsize(fmt_header)
        fmt_image = '>B'
        labels = np.empty(num_images)
        for i in range(num_images):
            if (i + 1) % 10000 == 0:
                print ('已解析 %d' % (i + 1) + '张')
            labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
            offset += struct.calcsize(fmt_image)
        return labels
    y_train = decode_idx1_ubyte(data_file[3])
    y_test = decode_idx1_ubyte(data_file[1])
    y_train = y_train.astype(np.int)
    y_test = y_test.astype(np.int)
    def decode_idx3_ubyte(idx3_ubyte_file):
        """
        解析idx3文件的通用函数
        :param idx3_ubyte_file: idx3文件路径
        :return: 数据集
        """
        # 读取二进制数据
        bin_data = open(idx3_ubyte_file, 'rb').read()
    
        # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
        offset = 0
        fmt_header = '>iiii' #因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。
        magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
        print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))
    
        # 解析数据集
        image_size = num_rows * num_cols
        offset += struct.calcsize(fmt_header)  #获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
        print(offset)
        fmt_image = '>' + str(image_size) + 'B'  #图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
        print(fmt_image,offset,struct.calcsize(fmt_image))
        images = np.empty((num_images, num_rows, num_cols))
        #plt.figure()
        for i in range(num_images):
            if (i + 1) % 10000 == 0:
                print('已解析 %d' % (i + 1) + '张')
                print(offset)
            images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))
            offset += struct.calcsize(fmt_image)
        return images
    X_train = decode_idx3_ubyte(data_file[2])
    X_test = decode_idx3_ubyte(data_file[0])
    
    """画图代码:读者可以运行该代码直观地查看数据集"""
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(5,5, figsize=(8, 8),subplot_kw={'xticks':[], 'yticks':[]},
                            gridspec_kw=dict(hspace=0.1, wspace=0.1))
    for i, ax in enumerate(axes.flat):
        ax.imshow(X_train[i], cmap='gray', interpolation='nearest')
        ax.text(0.07, 0.07, str(y_train[i]),transform=ax.transAxes, color='white')
    return X_train,y_train,X_test,y_test

The above code define a function that generates the features of Trainingset denoted as X_train, labels of Trainingset denoted as y_train, and X_test, y_test of Testingset. Meanwhile, the above code plotting some of the datasets, shown below:

Defining,Training,Measuring CNN Algorithm

This section we will construct, traning a CNN model ,and measuring it in the testing sets using accuracy score

We run the function to gain dataset:

X_train,y_train,X_test,y_test = data_generate()

Then we import some useful package:

from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Conv2D,AveragePooling2D,Flatten

Then we define the CNN model :

model = Sequential()   #创建神经网络类
model.add(Conv2D(16,kernel_size=(3,3),
                 input_shape=(width,height,1),activation='relu'))    #创建一个卷积层,核为3X3
model.add(Dropout(0.2))    #Dropout正则化
model.add(Conv2D(32,kernel_size=(3,3),activation='relu'))
model.add(Dropout(0.2))
model.add(AveragePooling2D(pool_size=(2,2)))    #添加池化层,并根据平均值将矩阵缩放成2X2
model.add(Dropout(0.2))
model.add(Conv2D(64,kernel_size=(3,3),activation='relu'))
model.add(AveragePooling2D(pool_size=(2,2)))    #添加池化层,缩放成2X2
model.add(Flatten())    #添加Flatten层
model.add(Dense(units=512,activation='relu'))    #BP神经网络的隐藏层,512个节点
model.add(Dense(units=10,activation='sigmoid'))    #输出层

For the principle of Convolution layer or Pooling layer, please refer to 
https://blog.csdn.net/weixin_42141390/article/details/105004900

Then we select Adam as the algorithm to finding the parameters of the CNN model with cross-entropy as cost function,then we set the max iteration is 100, the code is shown as below:

model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])    #使用Adam算法训练模型
history = model.fit(X_train,y_train,epochs=100,batch_size=256,validation_data=(X_test,y_test))    #定义随机搜索算法的mini-batch=256

Running the code, the output windows will show the following information:

Epoch 1/100

60000/60000 [==============================] - 219s 4ms/step - loss: 0.0869 - acc: 0.9712 - val_loss: 0.0199 - val_acc: 0.9934

.......... 

Epoch 99/100

60000/60000 [==============================] - 186s 3ms/step - loss: 2.6969e-05 - acc: 1.0000 - val_loss: 0.0109 - val_acc: 0.9986

Epoch 100/100

60000/60000 [==============================] - 186s 3ms/step - loss: 2.6969e-05 - acc: 1.0000 - val_loss: 0.0110 - val_acc: 0.9986

Then we plot the training curve by varible "history",using the following code:

import matplotlib.pyplot as plt
font1 = {'family' : 'Times New Roman',
'weight' : 'normal',
'size'   : 20,
}
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.plot(history.history['acc'],linewidth=3,label='Train')
plt.plot(history.history['val_acc'],linewidth=3,linestyle='dashed',label='Test')
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('精确度',fontsize=20)
plt.legend(prop=font1)

Datasets

to gain the datasets, you can refer to :https://github.com/1259975740/Machine_Learning/tree/master/chapter12/%E6%95%B0%E6%8D%AE%E9%9B%86/MNIST_data

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/404265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

便携式井用自动采样器主要有哪些功能特点?

如图此款井用采样器整机小巧,非常适合狭小领域使用,携带方便,采样精准,可以延伸放到井下进行工作。尤其适合:窨井、下水道、沟渠 等现场条件恶劣的工作场合。可帮助采样人员采取到具有代表性的水样从而进行检测 参数特…

笔记本固态盘数据丢失怎么办?笔记本固态盘怎么恢复数据

如果笔记本固态盘数据丢失怎么办?笔记本固态盘怎么恢复数据?下面将为大家详细地介绍一下笔记本固态硬盘数据恢复的三种实用方法,希望对大家有所帮助。一、简单恢复方法笔记本固态硬盘数据删除以后,较为简单直接的恢复方法就是从回…

《C++代码分析》第三回:类成员函数覆盖父类函数的调用(分析this指针的变化)

一、前言 在C的学习中我们知道&#xff0c;子类是可以覆盖父类的方法&#xff0c;这里我们探讨一下子类方法调用父类方法时this指针时如何变化的。 二、示例代码 #include "windows.h" #include "windef.h" #include <iostream> #include <tch…

自学大数据第六天~HDFS命令

HDFS常用命令 查看hadoop版本 version hadoop version注意,没有 ‘-’ [hadoopmaster ~]$ hadoop version Hadoop 3.3.4 Source code repository https://github.com/apache/hadoop.git -r a585a73c3e02ac62350c136643a5e7f6095a3dbb Compiled by stevel on 2022-07-29T12:3…

蓝桥杯嵌入式(G4系列):定时器输出可调PWM

前言&#xff1a; 蓝桥杯定时器输出PWM的考点在历届真题中的出现次数较多&#xff0c;而且之前关于STM32的学习&#xff0c;我对于使用STM32Cubemx配置PWM的方式确实不是很熟悉&#xff0c;这里简单记录一下自己的学习过程。 STM32Cubemx配置部分&#xff1a; 这里我们是改编真…

yocto编译烧录和脚本解析

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、初始化构建目录二、imx-setup-release.sh脚本解析三、编译单独编译内核四、烧录总结前言 本篇文章主要讲解如何在下载好源码之后进行编译和yocto的脚本解析…

剑指 Offer II 027. 回文链表

题目链接 剑指 Offer II 027. 回文链表 easy 题目描述 给定一个链表的 头节点 head&#xff0c;请判断其是否为回文链表。 如果一个链表是回文&#xff0c;那么链表节点序列从前往后看和从后往前看是相同的。 示例 1&#xff1a; 输入: head [1,2,3,3,2,1] 输出: true 示例…

多目标遗传算法NSGA-II原理详解及算法实现

在接触学习多目标优化的问题上&#xff0c;经常会被提及到多目标遗传算法NSGA-II&#xff0c;网上也看到了很多人对该算法的总结&#xff0c;但真正讲解明白的以及配套用算法实现的文章很少&#xff0c;这里也对该算法进行一次详解与总结。会有侧重点的阐述&#xff0c;不会针对…

理解并解决【跨域】问题--通过代理或【CROS】

文章目录一.理解跨域问题引起跨域问题的原因浏览器的同源策略二.跨域问题的解决办法解决方法1-------代理&#xff08;前端配置&#xff09;解决方法2-----开启跨域资源共享CORS&#xff08;后端&#xff09;知识小贴士一.理解跨域问题 主要出现在前后端分离项目 引起跨域问题…

磨金石教育摄影技能干货分享|春之旅拍

春天来一次短暂的旅行&#xff0c;你会选择哪里呢&#xff1f;春天的照片又该如何拍呢&#xff1f;看看下面的照片&#xff0c;或许能给你答案。照片的构图很巧妙&#xff0c;画面被分成两部分&#xff0c;一半湖泊&#xff0c;一半绿色树林。分开这些的是一条斜向的公路&#…

合并两个链表(自定义位置合并与有序合并)LeetCode--OJ题详解

图片: csdn 自定义位置合并 问题&#xff1a; 给两个链表 list1 和 list2 &#xff0c;它们包含的元素分别为 n 个和 m 个。 请你将 list1 中 下标从 a 到 b 的全部节点都删除&#xff0c;并将list2 接在被删除节点 的位置。 比如&#xff1a; 输入&#xff1a;list1 [1…

【STL】list剖析及模拟实现

✍作者&#xff1a;阿润菜菜 &#x1f4d6;专栏&#xff1a;C 初识list 1. list基本概况 list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。list的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立…

前端前沿web 3d可视化技术 ThreeJS学习全记录

前端前沿web 3d可视化技术 随着浏览器性能和网络带宽的提升 使得3D技术不再是桌面的专利 打破传统平面展示模式 前端方向主要流向的3D图形库包括Three.js和WebGL WebGL灵活高性能&#xff0c;但代码量大&#xff0c;难度大&#xff0c;需要掌握很多底层知识和数学知识 Threej…

卷积神经网络模型之——LeNet

目录LeNet模型参数介绍该网络特点关于C3与S2之间的连接关于最后的输出层子采样参考LeNet LeNet是一个用来识别手写数字的最经典的卷积神经网络&#xff0c;是Yann LeCun在1998年设计并提出的。Lenet的网络结构规模较小&#xff0c;但包含了卷积层、池化层、全连接层&#xff0…

Mr. Cappuccino的第49杯咖啡——冒泡APP(升级版)之基于Docker部署Gitlab

冒泡APP&#xff08;升级版&#xff09;之基于Docker部署Gitlab基于Docker安装Gitlab登录Gitlab创建Git项目上传代码使用Git命令切换Git地址使用IDE更换Git地址基于Docker安装Gitlab 查看beginor/gitlab-ce镜像版本 下载指定版本的镜像 docker pull beginor/gitlab-ce:11.3.0…

c# 源生成器

本文概述了 .NET Compiler Platform&#xff08;“Roslyn”&#xff09;SDK 附带的源生成器。 通过源生成器&#xff0c;C# 开发人员可以在编译用户代码时检查用户代码。 生成器可以动态创建新的 C# 源文件&#xff0c;这些文件将添加到用户的编译中。 这样&#xff0c;代码可以…

线程(一)

线程 1. 线程 定义&#xff1a;线程是进程的组成部分&#xff0c;不同的线程执行不同的任务&#xff0c;不同的功能模块&#xff0c;同时线程使用的资源师由进程管理&#xff0c;主要分配CPU和内存。 ​ 在进程中&#xff0c;线程执行的方式是抢占式执行操作&#xff0c;需要考…

33--Vue-前端开发-使用Vue脚手架快速搭建项目

一、vue脚手架搭建项目 node的安装: 官方下载,一路下一步 node命令类似于python npm命令类似于pip 使用npm安装第三方模块,速度慢一些,需换成淘宝镜像 以后用cmpm代替npm的使用 npm install -g cnpm --registry=https://registry.npm.taobao.org安装脚手架: cnpm inst…

汉诺塔--课后程序(Python程序开发案例教程-黑马程序员编著-第6章-课后作业)

实例3&#xff1a;汉诺塔 汉诺塔是一个可以使用递归解决的经典问题&#xff0c;它源于印度一个古老传说&#xff1a;大梵天创造世界的时候做了三根金刚石柱子&#xff0c;其中一根柱子从下往上按照从大到小的顺序摞着64片黄金圆盘&#xff0c;大梵天命令婆罗门把圆盘从下面开始…

C++回顾(二十)—— vector容器 和 deque容器

20.1 vector容器 20.1.1 vector容器简介 vector是将元素置于一个动态数组中加以管理的容器。vector可以随机存取元素&#xff08;支持索引值直接存取&#xff0c; 用[]操作符或at()方法&#xff09;。vector尾部添加或移除元素非常快速。但是在中部或头部插入元素或移除元素比…