深度学习实战(11):使用多层感知器分类器对手写数字进行分类

news2024/11/23 12:17:25

使用多层感知器分类器对手写数字进行分类

图 1:多层感知器网络

1.简介

1.1 什么是多层感知器(MLP)?

MLP 是一种监督机器学习 (ML) 算法,属于前馈人工神经网络 [1] 类。该算法本质上是在数据上进行训练以学习函数。给定一组特征和一个目标变量(例如标签),它会学习一个用于分类或回归的非线性函数。在本文中,我们将只关注分类案例。

1.2 MLP和逻辑回归有什么相似之处吗?

有!逻辑回归只有两层,即输入和输出,但是,在 MLP 模型的情况下,唯一的区别是我们可以有额外的中间非线性层。这些被称为隐藏层。除了输入节点(属于输入层的节点)之外,每个节点都是一个使用非线性激活函数的神经元[1]。由于这种非线性性质,MLP 可以学习复杂的非线性函数,从而区分不可线性分离的数据!请参见下面的图 2,了解具有一个隐藏层的 MLP 分类器的可视化表示。

1.3 MLP 是如何训练的?

MLP 使用反向传播进行训练。

1.4 MLP的主要优缺点.

优点:

  • 可以学习非线性函数,从而分离不可线性分离的数据 。
    缺点:
  • 隐藏层的损失函数导致非凸优化问题,因此存在局部最小值。
  • 不同的权重初始化可能会导致不同的输出/权重/结果。
  • MLP 有一些超参数,例如隐藏神经元的数量,需要调整的层数(时间和功耗)。
  • MLP 可能对特征缩放敏感 。
    图 2:具有一个隐藏层和标量输出的 MLP

2.使用scikit-learn的Python动手实例

2.1 数据集

对于这个实践示例,我们将使用 MNIST 数据集。 MNIST 数据库是一个著名的手写数字数据库,用于训练多个 ML 模型 。有 10 个不同数字的手写图像,因此类别数为 10 (参见图 3)。

注意:由于我们处理图像,因此这些由二维数组表示,并且数据的初始维度是每个图像的 28 by 28 ( 28x28 pixels )。然后二维图像被展平,因此在最后由矢量表示。每个 2D 图像都被转换为维度为 [1, 28x28] = [1, 784] 的 1D 向量。最后,我们的数据集有 784 个特征/变量/列。

图 3:数据集中的一些样本

2.2 数据导入与准备

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.neural_network import MLPClassifier
# Load data
X, y = fetch_openml("mnist_784", version=1, return_X_y=True)
# Normalize intensity of images to make it in the range [0,1] since 255 is the max (white).
X = X / 255.0

请记住,每个 2D 图像现在都转换为维度为 [1, 28x28] = [1, 784] 的 1D 矢量。我们现在来验证一下。

print(X.shape)

这将返回: (70000, 784) 。我们有 70k 个扁平图像(样本),每个图像包含 784 个像素(28*28=784)(变量/特征)。
因此,输入层权重矩阵的形状为

784 x #neurons_in_1st_hidden_layer.

输出层权重矩阵的形状为

#neurons_in_3rd_hidden_layer x #number_of_classes

2.3 模型训练

现在让我们构建模型、训练它并执行分类。我们将分别使用 3 个隐藏层和 50,20 and 10 个神经元。此外,我们将设置最大迭代次数 100 ,并将学习率设置为 0.1 。这些是我在简介中提到的超参数。我们不会在这里微调它们。

# Split the data into train/test sets
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
classifier = MLPClassifier(
    hidden_layer_sizes=(50,20,10),
    max_iter=100,
    alpha=1e-4,
    solver="sgd",
    verbose=10,
    random_state=1,
    learning_rate_init=0.1,
)
# fit the model on the training data
classifier.fit(X_train, y_train)

2.4 模型评估

现在,让我们评估模型。我们将估计训练和测试数据和标签的平均准确度。

print("Training set score: %f" % classifier.score(X_train, y_train))
print("Test set score: %f" % classifier.score(X_test, y_test))

训练集分数:

0.998633

测试集分数:

0.970300

2.5 成本函数演变的可视化

训练期间损失减少的速度有多快?让我们制作一个漂亮的图表看一看!

fig, axes = plt.subplots(1, 1)
axes.plot(classifier.loss_curve_, 'o-')
axes.set_xlabel("number of iteration")
axes.set_ylabel("loss")
plt.show()

图 4:训练迭代中损失的演变
在这里,我们看到损失在训练期间下降得非常快,并且在 40th 迭代后饱和(请记住,我们将最大 100 次迭代定义为超参数)。

2.6 可视化学习到的权重

这里我们首先需要了解权重(每一层的学习模型参数)是如何存储的。

根据文档,属性 classifier.coefs_ 是形状为 (n_layers-1, ) 的权重数组的列表,其中索引 i 处的权重矩阵表示层 i 和层 i+1 之间的权重。在这个例子中,我们定义了 3 个隐藏层,我们还有输入层和输出层。因此,我们希望层间权重有 4 个权重数组(图 5 中的 in-L1, L1-L2, L2-L3 和 L2-out )。

类似地, classifier.intercepts_ 是偏置向量列表,其中索引 i 处的向量表示添加到层 i+1 的偏置值。

在这里插入图片描述

让我们验证一下:

len(classifier.intercepts_) == len(classifier.coefs_) == 4

正确返回 True 。

输入层权重矩阵的形状为

784 x #neurons_in_1st_hidden_layer.

输出层权重矩阵的形状为

#neurons_in_3rd_hidden_layer x #number_of_classes.

2.7 可视化输入层的学习权重

target_layer = 0 #0 is input, 1 is 1st hidden etc
fig, axes = plt.subplots(1, 1, figsize=(15,6))
axes.imshow(np.transpose(classifier.coefs_[target_layer]), cmap=plt.get_cmap("gray"), aspect="auto")
axes.set_xlabel(f"number of neurons in {target_layer}")
axes.set_ylabel("neurons in output layer")
plt.show()

图 6:输入层和第一个隐藏层之间的神经元学习权重的可视化
将它们重新整形并绘制为 2D 图像。

# choose layer to plot
target_layer = 0 #0 is input, 1 is 1st hidden etc
fig, axes = plt.subplots(4, 4)
vmin, vmax = classifier.coefs_[0].min(), classifier.coefs_[target_layer].max()
for coef, ax in zip(classifier.coefs_[0].T, axes.ravel()):
    ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=0.5 * vmin, vmax=0.5 * vmax)
    ax.set_xticks(())
    ax.set_yticks(())
plt.show()

3.总结

MLP 分类器是一种非常强大的神经网络模型,可以学习复杂数据的非线性函数。该方法使用前向传播来构建权重,然后计算损失。接下来,反向传播用于更新权重,从而减少损失。这是以迭代方式完成的,迭代次数是一个输入超参数,正如我在简介中所解释的那样。其他重要的超参数是每个隐藏层中的神经元数量和隐藏层总数。这些都需要微调。
更多Ai资讯:公主号AiCharm
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/337288.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2021年新公开工业控制系统严重漏洞汇总

声明 本文是学习ITOT一体化工业信息安全态势报告(2019). 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 工业互联网安全威胁 2021年新公开工业控制系统严重漏洞 缓冲区溢出漏洞 缓冲区溢出(buffer overflow&…

什么是自动化测试?自动化测试现状怎么样?

什么是自动化测试:其实自动化测试,就是让我们写一段程序去测试另一段程序是否正常的过程,自动化测试可以更加省力的替代一部分的手动操作。 现在自动化测试的现状,也是所有学习者关心的,但现在国内公司主要是以功能测…

算法设计 - 前缀和 差分数列

一维数组前缀和的概念 前缀和的概念很简单,我们用一维数组的前缀和来举例,有如下一维数组: arr [1, 2, 3, 4, 5] 该数组的前缀和数组如下 preSum [1, 3, 6, 10, 15] 关系如下: preSum[0] arr[0]preSum[1] arr[0] arr[1]pre…

Codeforces Round #851 (Div. 2) A — C

Codeforces Round #851 A. One and Two 题目描述 给定一个序列a中的每个元素都是1或2。找出整数k是否存在,以满足以题目所给条件。 题目分析 1对乘积没有贡献,只需要注意2的个数即可,偶数个2即可满足条件,记录第cnt/2个2的位…

集成nanocaptcha库生成登录验证码

背景 需要实现一个验证码登录的功能需求。这个需求挺简单的,主要实现验证码图片生成给前端,然后,在登录接口比对验证码即可。刚拿到这个需求,好久没有搞过登录这一块了,所以,查了一下相关验证码的知识。下…

cv2--特征点特征提取(Sift,Orb,Surf)

cv2–特征点特征提取(Sift,Orb,Surf) 文章目录cv2--特征点特征提取(Sift,Orb,Surf)1. 关键点和关键点描述子2. Sift2.1 检测的步骤2.2 同时计算关键点kp和描述子des3. Surf4. Orb5. …

61 UseSerialGc的新生代回收调试

前言 呵呵 很久之前看到这样的两篇文章 [讨论] HotSpot VM Serial GC的一个问题 新生代回收调试的一些心得 在第一篇帖子中 R大 详细的讲述了 cheney 算法, 以及自己编写的 cheney 算法, 以及 DefNewGeneration 的具体的一些细节, 以及 和现有的例子的对比 另外还有一些…

leaflet 加载topojson数据,显示图形(代码示例047)

第047个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+leaflet中加载topojson文件,将图形显示在地图上。TopoJSON文件格式是geoJSON的一种扩展,它可以对地理空间拓扑进行编码。TopoJSON文件包含数据属性和地理空间的属性。 直接复制下面的 vue+openlayers源代码,操作…

C语言入门(什么是C语言,C语言的编程机制以及一些基础计算机概念)

目录 一.什么是C语言 1.面向对象: 2.面向过程: 二.C语言特点 三.C语言开发时间 四.环境搭建 1.代码编辑器 2.C编译器 五.C语言标准 六.计算机补充知识 1.计算机构成 2.CPU工作 3.编译器 七.编写程序步骤 八.源文件,目标文件&a…

SpringBoot 三大开发工具,你都用过么?

本文已经收录到Github仓库,该仓库包含计算机基础、Java基础、多线程、JVM、数据库、Redis、Spring、Mybatis、SpringMVC、SpringBoot、分布式、微服务、设计模式、架构、校招社招分享等核心知识点,欢迎star~ Github地址:https://github.com/…

Sentinel服务熔断功能(sentinel整合ribbon+openFeign+fallback)

目录 1、Sentinel服务熔断功能 一、Ribbon系列 (一)提供者9003/9004 (二)消费者84 二、OpenFeign系列 三、熔断框架比较 2、规则持久化 1、Sentinel服务熔断功能 一、Ribbon系列 (一)提供者9003/9004 …

DAMA数据管理知识体系指南之元数据管理

第11章 元数据管理 11.1简介 按照通常的说法,元数据的定义是“关于数据的数据”,但是其确切含义是什么?元数据与数据的关系就像数据与自然界的关系。数据反映了真实世界的交易、事件、对象和关系,而元数据则反映了数据的交易、事…

技术分享|终端安全防护|ChatGPT会创造出超级恶意软件吗?

ChatGPT是一个强大的人工智能聊天机器人,它使用大量的数据收集和自然语言处理与用户“交谈”,感觉像是和正常的人类对话。它的易用性和相对较高的准确性让用户可以利用它做任何事情,从解决复杂的数学问题,到写论文,创建…

【Linux】操作系统与进程的概念

目录 冯诺依曼体系 注意 为什么CPU不直接访问输入或输出设备? 跨主机间数据的传递 操作系统 管理 进程 描述进程 进程的查看和终止 bash 通过系统调用创建子进程 fork的辨析 冯诺依曼体系 🥖冯诺依曼结构也称普林斯顿结构,是一种将…

(超详细)Navicat的安装和激活,亲测有效

步骤一:准备安装包 下载Navicat,我用的v15最好一致(私信可以发你安装包和注册码)步骤二:关闭杀毒软件,然后需要断掉网络(一定断网) 步骤三:一路next安装,安装…

nodejs如何实现Digest摘要认证?

文章目录1.前言2. 原理3. 过程4. node实现摘要认证5. 前端如何Digest摘要登录认证(下面是海康的设备代码)1.前言 根据项目需求,海康设备ISAPI协议需要摘要认证,那么什么是摘要认证?估计不少搞到几年的前端连摘要认证都…

每日一个解决问题:事务无法回滚是什么原因?

今天在码代码时发现事务不回滚了,学过MySQL 事务小伙伴们都懂,通过 begin 开启事务,通过 commit 提交事务或者通过 rollback 回滚事务。 正常来说,当我们开启一个事务之后,需要 commit 或者 rollback 来结束一个事务的…

下面这段Python代码执行后的输出结果是?

点击上方“Python爬虫与数据挖掘”,进行关注回复“书籍”即可获赠Python从入门到进阶共10本电子书今日鸡汤几行归塞尽,念尔独何之。大家好,我是皮皮。一、前言前几天在Python青铜交流群【桐霄L】问了一个Python基础的问题,这里拿出…

TortoiseGit 使用教程

一、下载工具 这里给大家准备了所有安装包自取 链接:https://pan.baidu.com/s/1xrxxgyNXNQEGD_RjwKnPMg 提取码:qwer 也可在官网自行下载最新版 1.下载git,直接去官网下载Git - Downloads,根据自己的系统合理下载&#xff0c…

openGauss客户端安装

目录1. 准备两台Linux系统2. 安装openGauss客户端3. 设置客户端主机环境变量4. 修改服务端配置文件5. 测试客户端远程连接客户端环境:openEuler release 22.03 (LTS-SP1) 服务端环境: openEuler release 20.03 (LTS-SP3) openEuler系统官网下载 1. 准备两台Linux系…