LeNet对MNIST 数据集中的图像进行分类--keras实现

news2024/12/28 21:51:28

我们将训练一个卷积神经网络来对 MNIST 数据库中的图像进行分类,可以与前面所提到的CNN实现对比CNN对 MNIST 数据库中的图像进行分类-CSDN博客

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图像
from keras.datasets import mnist

# 使用 Keras 导入预洗牌 MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

将前六个训练图像可视化 

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np

# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):
    ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])
    ax.imshow(X_train[i], cmap='gray')
    ax.set_title(str(y_train[i]))

 查看图像的更多细节

def visualize_input(img, ax):
    ax.imshow(img, cmap='gray')
    width, height = img.shape
    thresh = img.max()/2.5
    for x in range(width):
        for y in range(height):
            ax.annotate(str(round(img[x][y],2)), xy=(y,x),
                        horizontalalignment='center',
                        verticalalignment='center',
                        color='white' if img[x][y]<thresh else 'black')

fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# normalize the data to accelerate learning
mean = np.mean(X_train)
std = np.std(X_train)
X_train = (X_train-mean)/(std+1e-7)
X_test = (X_test-mean)/(std+1e-7)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import np_utils

num_classes = 10 
# print first ten (integer-valued) training labels
print('Integer-valued labels:')
print(y_train[:10])

# one-hot encode the labels
# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)

# print first ten (one-hot) training labels
print('One-hot labels:')
print(y_train[:10])

重塑数据以适应我们的 CNN(和 input_shape)

# input image dimensions 28x28 pixel images. 
img_rows, img_cols = 28, 28

X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

print('image input shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

论文地址:lecun-01a.pdf

要在 Keras 中实现 LeNet-5,请阅读原始论文并从第 6、7 和 8 页中提取架构信息。以下是构建 LeNet-5 网络的主要启示:

  • 每个卷积层的滤波器数量:从图中(以及论文中的定义)可以看出,每个卷积层的深度(滤波器数量)如下:C1 = 6、C3 = 16、C5 = 120 层。
  • 每个 CONV 层的内核大小:根据论文,内核大小 = 5 x 5
  • 每个卷积层之后都会添加一个子采样层(POOL)。每个单元的感受野是一个 2 x 2 的区域(即 pool_size = 2)。请注意,LeNet-5 创建者使用的是平均池化,它计算的是输入的平均值,而不是我们在早期项目中使用的最大池化层,后者传递的是输入的最大值。如果您有兴趣了解两者的区别,可以同时尝试。在本实验中,我们将采用论文架构。
  • 激活函数:LeNet-5 的创建者为隐藏层使用了 tanh 激活函数,因为对称函数被认为比 sigmoid 函数收敛更快。一般来说,我们强烈建议您为网络中的每个卷积层添加 ReLU 激活函数。

需要记住的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除了网络中的最后一层,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最后一层应该是具有软最大激活函数的密集(FC)层。最终层的节点数应等于数据集中的类别总数。
from keras.models import Sequential
from keras.layers import Conv2D, AveragePooling2D, Flatten, Dense
#Instantiate an empty model
model = Sequential()

# C1 Convolutional Layer
model.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding='same'))

# S2 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))

# C3 Convolutional Layer
model.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))

# S4 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))

# C5 Fully Connected Convolutional Layer
model.add(Conv2D(120, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))

#Flatten the CNN output so that we can connect it with fully connected layers
model.add(Flatten())

# FC6 Fully Connected Layer
model.add(Dense(84, activation='tanh'))

# Output Layer with softmax activation
model.add(Dense(10, activation='softmax'))

# print the model summary
model.summary()

编译模型

我们将使用亚当优化器

# the loss function is categorical cross entropy since we have multiple classes (10) 


# compile the model by defining the loss function, optimizer, and performance metric
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

训练模型

LeCun 和他的团队采用了计划衰减学习法,学习率的值按照以下时间表递减:前两个历元为 0.0005,接下来的三个历元为 0.0002,接下来的四个历元为 0.00005,之后为 0.00001。在论文中,作者对其网络进行了 20 个历元的训练。

from keras.callbacks import ModelCheckpoint, LearningRateScheduler

# set the learning rate schedule as created in the original paper
def lr_schedule(epoch):
    if epoch <= 2:     
        lr = 5e-4
    elif epoch > 2 and epoch <= 5:
        lr = 2e-4
    elif epoch > 5 and epoch <= 9:
        lr = 5e-5
    else: 
        lr = 1e-5
    return lr

lr_scheduler = LearningRateScheduler(lr_schedule)

# set the checkpointer
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, 
                               save_best_only=True)

# train the model
hist = model.fit(X_train, y_train, batch_size=32, epochs=20,
          validation_data=(X_test, y_test), callbacks=[checkpointer, lr_scheduler], 
          verbose=2, shuffle=True)

在验证集上加载分类准确率最高的模型

# load the weights that yielded the best validation accuracy
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率

# evaluate test accuracy
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]

# print test accuracy
print('Test accuracy: %.4f%%' % accuracy)

评估模型

import matplotlib.pyplot as plt

f, ax = plt.subplots()
ax.plot([None] + hist.history['accuracy'], 'o-')
ax.plot([None] + hist.history['val_accuracy'], 'x-')
# 绘制图例并自动使用最佳位置: loc = 0。
ax.legend(['Train acc', 'Validation acc'], loc = 0)
ax.set_title('Training/Validation acc per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('acc')
plt.show()

import matplotlib.pyplot as plt

f, ax = plt.subplots()
ax.plot([None] + hist.history['loss'], 'o-')
ax.plot([None] + hist.history['val_loss'], 'x-')

# Plot legend and use the best location automatically: loc = 0.
ax.legend(['Train loss', "Val loss"], loc = 0)
ax.set_title('Training/Validation Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1276941.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

规则引擎专题---2、开源规则引擎对比

开源规则引擎 开源的规则引擎整体分为下面几类&#xff1a; 通过界面配置的成熟规则引擎&#xff0c;这种规则引擎相对来说就比较重&#xff0c;但功能全&#xff0c;比较出名的有:drools, urule。 基于jvm脚本语言&#xff0c;互联网公司会觉得drools太重了&#xff0c;然后…

用100ask 6ull配合 飞凌 elf1的教程进行学习的记录

启动方式 百问网 elf1: 固件 emmc-otg 串口 网络 改eth0, 网线接在右边的网口eth2上

spring boot mybatis TypeHandler 看源码如何初始化及调用

目录 概述使用TypeHandler使用方式在 select | update | insert 中加入 配置文件中指定 源码分析配置文件指定Mapper 执行query如何转换 结束 概述 阅读此文 可以达到 spring boot mybatis TypeHandler 源码如何初始化及如何调用的。 spring boot 版本为 2.7.17&#xff0c;my…

触控板绘画工具Inklet mac功能介绍

Inklet mac是一款触控板绘画工具&#xff0c;把你的触控板变成画画的板子&#xff0c;意思是&#xff0c;你点在触控板的哪里&#xff0c;鼠标就会出现载相应的地方。例如&#xff0c;但你把手指移动到触控盘左下角&#xff0c;那么鼠标也会出现在左下角&#xff0c;对于用户而…

【已解决】Cannot find project Scala library 2.11.8 for module XXX

问题描述 在 flink 示例程序调试过程中&#xff0c;reload project 报错 Cannot find project Scala library 2.11.8 for module HbasePrint 报错如下图所示&#xff1a; 问题解决 经过搜索&#xff0c;初步判定是 pom 文件中 Scala 版本配置和项目中实际使用的版本不一致导…

11.29 知识回顾(视图层、模板层)

一、视图层 1.1 响应对象 响应---》本质都是 HttpResponse -HttpResponse---》字符串 -render----》放个模板---》模板渲染是在后端完成 -js代码是在客户端浏览器里执行的 -模板语法是在后端执行的 -redirect----》重定向 -字符串参数不是…

RabbitMq整合Springboot超全实战案例+图文演示+源码自取

目录 介绍 简单整合 简单模式 定义 代码示例 work模式 定义 代码示例 pubsub模式 定义 代码示例 routing模式 定义 代码示例 top模式 定义 代码 下单付款加积分示例 介绍 代码 可靠性投递示例 介绍 代码 交换机投递确认回调 队列投递确认回调 ​延迟消…

前缀和 LeetCode1094 拼车

1094. 拼车 车上最初有 capacity 个空座位。车 只能 向一个方向行驶&#xff08;也就是说&#xff0c;不允许掉头或改变方向&#xff09; 给定整数 capacity 和一个数组 trips , trip[i] [numPassengersi, fromi, toi] 表示第 i 次旅行有 numPassengersi 乘客&#xff0c;接…

抖音怎么一次性隐藏全部视频

很多朋友不知道抖音怎么一次性隐藏全部视频&#xff0c;其实只需要在设置菜单中将账号设置为【私密账号】即可&#xff0c;在抖音中依次点击【设置】-【我】-【隐私设置】-【私密账号】&#xff0c;在弹出的窗口中将账号设为私密即可。也可以依次打开抖音作品&#xff0c;点击底…

golang Pool实战与底层实现

使用的go版本为 go1.21.2 首先我们写一个简单的Pool的使用代码 package mainimport "sync"var bytePool sync.Pool{New: func() interface{} {b : make([]byte, 1024)return &b}, }func main() {for j : 0; j < 10; j {obj : bytePool.Get().(*[]byte) // …

解决element ui tree组件不产生横向滚动条

结果是这样的 需要在tree的外层&#xff0c;包一个父组件 <div class"tree"><el-tree :data"treeData" show-checkbox default-expand-all></el-tree></div> 在css里面这样写,样式穿透按自己使用的css编译器以及框架要求就好 &l…

SQL Server 2016(创建数据库)

1、实验环境。 某公司有一台已经安装了SQL Server 2016的服务器&#xff0c;现在需要新建数据库。 2、需求描述。 创建一个名为"db_class"的数据库&#xff0c;数据文件和日志文件初始大小设置为10MB&#xff0c;启用自动增长&#xff0c;数据库文件存放路径为C:\db…

文献速递:人工智能在健康和医学中

人工智能在健康和医学中 01 文献速递介绍 这篇文章详细探讨了人工智能&#xff08;AI&#xff09;在医学领域的最新进展、挑战和未来发展的机遇。 1.医学AI算法的最新进展&#xff1a; **AI在医疗实践中的应用&#xff1a;**虽然AI系统在多项回顾性医学研究中表现出色&…

docker 搭建开发环境,解决deepin依赖问题

本机环境&#xff1a; deepin v23b2 删除docker旧包 sudo apt-get remove docker docker-engine docker.io containerd runc注意卸载docker旧包的时候Images, containers, volumes, 和networks 都保存在 /var/lib/docker 卸载的时候不会自动删除这块数据&#xff0c;如果你先…

Beautiful Soup4爬虫速成

做毕业论文需要收集数据集&#xff0c;我的数据集就是文本的格式&#xff0c;而且是静态页面的形式&#xff0c;所以只是一个简单的入门。动态页面的爬虫提取这些比较进阶的内容&#xff0c;我暂时没有这样的需求&#xff0c;所以有这类问题的朋友们请移步。 如果只是简单的静态…

目标检测——Faster R-CNN算法解读

论文&#xff1a;Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks 作者&#xff1a;Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun 链接&#xff1a;https://arxiv.org/abs/1506.01497 代码&#xff1a;https://github.com/rbgirsh…

vue使用elementui的el-menu的折叠菜单collapse

由于我的是在el-menu所在组件外面的兄弟组件设置是否折叠的控制&#xff0c;我用事件总线bus进行是否折叠传递 参数说明类型可选值默认值collapse是否水平折叠收起菜单&#xff08;仅在 mode 为 vertical 时可用&#xff09;boolean—falsebackground-color菜单的背景色&#…

深入理解Servlet(上)

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 为什么要了解Servlet …

用JavaScript的管道方法简化代码复杂性

用JavaScript的管道方法简化代码复杂性 在现代 web 开发中&#xff0c;维护干净有效的代码是必不可少的。随着项目的增加&#xff0c;我们功能的复杂性也在增加。然而&#xff0c;javaScript为我们提供了一个强大的工具&#xff0c;可以将这些复杂的函数分解为更小的、可管理的…

什么是Anaconda

Anaconda的安装也很方便。打开这个网站Anaconda下载&#xff0c;然后安装即可。 Anaconda可以帮助我们解决团队之间合作的包依赖管理问题。在没有使用Anaconda之前&#xff0c;如果你的Python程序想让你的同事运行&#xff0c;那么你的同事可能会遇到很多包依赖问题&#xff0…