多层感知机(MLP)算法原理和代码实现

news2024/9/20 9:00:50

文章目录

  • 多层感知机入门
  • 算法优化原理
  • sklearn代码实现
  • 核心优缺点分析

多层感知机入门

神经网络在最近几年,是个很火的名词了。常听到的卷积神经网络(CNN)或者循环神经网络(RNN),都可以看做是神经网络在特定场景下的具体应用方式。

本文我们尝试从神经网络的基础:多层感知机(Multilayer perceptron, MLP)入手,以此了解其解决预测问题的基本算法原理。

要入门MLP,个人认为比较简单的方式是将其理解为一种广义的线性模型。

看上图,对于线性模型(左上图)
y = w [ 1 ] ∗ x [ 1 ] + w [ 2 ] ∗ x [ 2 ] + w [ 3 ] ∗ x [ 3 ] + b y=w[1]*x[1]+w[2]*x[2]+w[3]*x[3]+b y=w[1]x[1]+w[2]x[2]+w[3]x[3]+b
对于MLP(右上图), x x x y y y之间增加了一组黄色的 h h h模块,叫隐藏层。

y y y h h h x x x的关系为
h [ 1 ] = f ( w [ 1 , 1 ] x [ 1 ] + w [ 2 , 1 ] x [ 2 ] + w [ 3 , 1 ] x [ 3 ] + b [ 0 ] ) h[1]=f(w[1,1]x[1]+w[2,1]x[2]+w[3,1]x[3]+b[0]) h[1]=f(w[1,1]x[1]+w[2,1]x[2]+w[3,1]x[3]+b[0])
h [ 2 ] = f ( w [ 1 , 2 ] x [ 1 ] + w [ 2 , 2 ] x [ 2 ] + w [ 3 , 2 ] x [ 3 ] + b [ 1 ] ) h[2]=f(w[1,2]x[1]+w[2,2]x[2]+w[3,2]x[3]+b[1]) h[2]=f(w[1,2]x[1]+w[2,2]x[2]+w[3,2]x[3]+b[1])
y = v [ 1 ] h [ 1 ] + v [ 2 ] h [ 2 ] + s y=v[1]h[1]+v[2]h[2]+s y=v[1]h[1]+v[2]h[2]+s
其中, f ( ⋅ ) f(·) f()是某个非线性函数,叫激活函数。

从数学上来看,如果 f ( ⋅ ) f(·) f()是线性的,那么增加了隐藏层的MLP本质上还是一个线性模型。为了让MLP的学习能力更强大, f ( ⋅ ) f(·) f()需要使用非线性函数,目前应用比较多的有校正非线性(relu)、正切双曲线(tanh)或sigmoid函数。

三个函数的表达式分别为
relu : y = m a x ( 0 , x ) \text{relu}:y=max(0, x) relu:y=max(0,x)
tanh : y = e x − e − x e x + e − x \text{tanh}:y=\frac{e^x-e^{-x}}{e^x+e^{-x}} tanh:y=ex+exexex
sigmoid : y = 1 1 + e − x \text{sigmoid}:y=\frac{1}{1+e^{-x}} sigmoid:y=1+ex1
以下是它们各自对应的图形

此处最直接的一个疑问是,为什么有了这些激活函数后,学习能力就能变强大了?

我们可以这样理解:任意非线性函数都可以用多个分段线性函数来表示,即分段线性函数可以拟合任意非线性函数。而relu显然可以拟合分段线性函数,所以理论上便可以拟合任意非线性函数;从图上看,tanh和sigmoid可以认为是软化后的relu,所以对拟合效果影响不大,而且这两个函数计算梯度会更方便。

MLP的设计其实还有一种理解,就是看成是基于生物神经元的抽象模型,这也是神经网络模型的最初由来。神经元的结构如下,MLP和它的形状确实很像了。

算法优化原理

在上一节中,隐藏层只有1层,该层中也只有2个隐藏单元,变量的总个数为11个。但实际上,隐藏层的数量和每个隐藏层中的隐藏单元数量都可以为更多个,此时变量的个数就更多了。那么,该如何找到这些变量的最优值呢?

回到MLP设计的初衷,我们是要让MLP针对样本的预测误差最小。针对任意一个样本,误差可以定义为
L o s s = ( y − y ^ ) 2 Loss=(y-\hat y)^2 Loss=(yy^)2
还是以之前的实例为例,上式变为
L o s s = [ y − ( v 1 f ( w 11 x 1 + w 21 x 2 + w 31 x 3 + b 0 ) + v 2 f ( w 12 x 1 + w 22 x 2 + w 32 x 3 + b 1 ) + s ) ] 2 Loss=[y-(v_1f(w_{11}x_1+w_{21}x_2+w_{31}x_3+b_0)+v_2f(w_{12}x_1+w_{22}x_2+w_{32}x_3+b_1)+s)]^2 Loss=[y(v1f(w11x1+w21x2+w31x3+b0)+v2f(w12x1+w22x2+w32x3+b1)+s)]2
如果训练集中共有 N N N个样本,那么需要将每个样本的误差加和,并使其最小化
m i n ∑ i = 1 N L o s s i min \quad \sum_{i=1}^NLoss_i mini=1NLossi
上式中,待优化变量为 v , w , b \pmb v,\pmb w,\pmb b v,w,b s s s,而且为无约束问题,因此可以通过梯度类算法求解。

经典的梯度类算法,可以参考梯度类算法原理:最速下降法、牛顿法和拟牛顿法。不过由于MLP中的优化变量数量通常非常大,优化过程较为耗时,所以像牛顿法这类需要计算海森矩阵的算法自然是不合适的;即使是拟牛顿法这种已经使用其他方案替代海森矩阵的算法也逐渐被大家舍弃;最终,在提升训练效率的驱使下,人们普遍选择了最速下降法这类只需要一阶梯度的算法作为基本算法。

既然需要使用一阶梯度的信息,那么该如何计算梯度呢?在计算之前,我们要先理清楚的是,这里的梯度,指的是损失函数Loss针对优化变量为 v , w , b \pmb v,\pmb w,\pmb b v,w,b s s s的梯度。

我们以Loss针对 w w w的梯度为例,说明一下求解过程。还是使用之前的MLP实例,同时把表达式写成矩阵形式
h = f ( w T x + b ) ⇒ ∂ h ∂ w T = f ′ ⋅ x h=f(w^Tx+b) \Rightarrow \frac{\partial h}{\partial w^T}=f'·x h=f(wTx+b)wTh=fx
y ^ = v T h + s ⇒ ∂ y ^ ∂ h = v T \hat y = v^Th+s \Rightarrow \frac{\partial \hat y}{\partial h}=v^T y^=vTh+shy^=vT
L o s s = ( y − y ^ ) 2 ⇒ ∂ L o s s ∂ y ^ = y − y ^ Loss=(y-\hat y)^2 \Rightarrow \frac{\partial Loss}{\partial \hat y}=y-\hat y Loss=(yy^)2y^Loss=yy^
根据求导的链式法则
∂ L o s s ∂ w T = ∂ L o s s ∂ y ^ ⋅ ∂ y ^ ∂ h ⋅ ∂ h ∂ w T = ( y − y ^ ) ⋅ v T ⋅ f ′ ⋅ x \frac{\partial Loss}{\partial w^T}=\frac{\partial Loss}{\partial \hat y}·\frac{\partial \hat y}{\partial h}·\frac{\partial h}{\partial w^T}=(y-\hat y)·v^T·f'·x wTLoss=y^Losshy^wTh=(yy^)vTfx
这就是我们常说的:误差反向传播,此处可以理解为梯度计算的过程是从后向前逐步推导过来的。

以上实例给的是,针对一个样本的梯度求解过程。为了得到准确的梯度值,就需要针对每个样本都走一遍如上的过程。如果我们的训练样本数量比较大,就会导致梯度计算的过程非常慢。

如果换个思路来处理,我们每次只计算其中一个样本的梯度值,然后将该值作为最终的梯度值,那么结果就是梯度的计算过程快但值不那么准确。

所以当下流行的折中方案是,每次挑选一部分样本(batch)出来,计算出梯度值作为整体的梯度值。通过batch数量的调整,便可以兼顾计算的效率和值的准确性。

有了梯度后,我们回顾一下梯度类算法的基本迭代公式:
θ k + 1 = θ k − η k ⋅ d k \theta_{k+1}=\theta_k-\eta_k·d_k θk+1=θkηkdk
其中, θ \theta θ表示优化变量, η k \eta_k ηk d k d_k dk分别为第 k k k次的迭代步长(MLP中更多称之为“学习率”)和迭代方向。

在最速下降法中
d k = ∂ L o s s ∂ θ = ∇ ( θ ) d_k=\frac{\partial Loss}{\partial \theta}=\nabla(\theta) dk=θLoss=(θ)

d k d_k dk的计算过程刚刚已经描述了。如果算法逻辑到此为止,我们一般会称之为随机梯度下降法(Stochastic gradient descent,SGD)。

但是,人们一般很难满足于这种比较基础的算法,所以在此之上,相继出现了很多改进算法。从迭代公式可以看出,改进点主要包含3类:第一类是针对迭代方向,如Momentum;第二类是针对学习率,如Adagrad;第三类是同时改进迭代方向和学习率,如Adam。具体改进的策略可以参考学术文献:An overview of gradient descent optimization algorithms。

sklearn代码实现

sklearn中已经集成了MLP的工具包,下例中使用分类功能:MLPClassifier。该函数参数的含义可以参考MLPClassifier参数详解。数据集我们使用two_moons,该数据集的详细介绍可以参考之前的文章:决策树入门、sklearn实现、原理解读和算法分析。

import mglearn.plots
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

from data.two_moons import two_moons
import matplotlib.pyplot as plt


def train_two_moons():
    # 导入两个月亮数据集
    features, labels = two_moons()
    
    # 将数据集分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(features, labels, stratify=labels, random_state=42)
    # 设置 MLP 分类器属性:激活函数为tanh,2个隐藏层,每层均包含10个隐藏单元,最大迭代次数为1000
    mlp = MLPClassifier(solver='lbfgs', activation='tanh', random_state=0, hidden_layer_sizes=[10, 10], max_iter=1000)

    # 使用训练集对 MLP 分类器进行训练
    mlp.fit(X_train, y_train)
    # 打印模型在训练集上的准确率
    print('Accuracy on train set: {:.2f}'.format(mlp.score(X_train, y_train)))
    # 打印模型在测试集上的准确率
    print('Accuracy on test set: {:.2f}'.format(mlp.score(X_test, y_test)))
    
    # 绘制 2D 分类边界
    mglearn.plots.plot_2d_separator(mlp, X_train, fill=True, alpha=.3)
    # 绘制训练集的散点图
    mglearn.discrete_scatter(X_train[:, 0], X_train[:, 1], y_train)
    # 设置 x 轴标签
    plt.xlabel("Feature 0")
    # 设置 y 轴标签
    plt.ylabel("Feature 1")
    

if __name__ == '__main__':

    train_two_moons()

运行以上代码后,首先可以得到MLP的性能指标:

Accuracy on train set: 1.00
Accuracy on test set: 0.84

然后还可以得到该模型的决策边界。显然,该决策边界是非线性的,但是看起来还算平滑。

核心优缺点分析

MLP的主要优点是,可以通过调整隐藏层的层数和每个隐藏层内的隐藏单元数,构建出复杂的模型,从而学习到数据中包含的各类信息。如果给予足够的计算时间和数据,并且仔细调节参数,MLP通常可以打败其他机器学习算法(无论是分类任务还是回归任务)。

其主要缺点有两个:第一个是复杂模型往往需要花费较长的时间去训练模型中的参数,第二节中SGD->Momentum、Adagrad->Adam的算法进化,可以理解为人们尝试通过算法改进的方式,降低模型训练的总时间;第二个是可解释差,即我们无法解释清楚模型给出某个预测结果的原因,所以这个模型很多时候也被称之为黑箱模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/736783.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

干货 | 智能网联汽车大数据基础平台构建研究

以下内容整理自大数据能力提升项目必修课《大数据系统基础》同学们的期末答辩汇报。 各位老师大家上午好,我们组的题目是智能网联汽车大数据基础平台的构建。我们的指导企业是西部智联。我们的汇报将从这五个方面进行展开,第一个方面是项目背景与需求分析…

uni-app基础知识

发展 DCloud于2012年开始研发小程序技术,优化webview的功能和性能,并加入W3C和HTML5中国产业联盟。 2015年,DCloud正式商用了自己的小程序,产品名为“流应用”,它不是B/S模式的轻应用,而是能接近原生功能…

Debezium系列之:基于debezium将mysql数据库数据更改流式传输到 Elasticsearch和PostgreSQL数据库

Debezium系列之:基于debezium将mysql数据库数据更改流式传输到 Elasticsearch和PostgreSQL数据库 一、背景二、技术路线三、配置四、从mysql同步数据到Elasticsearch和PostgreSQL数据库五、总结 一、背景 基于 Debezium 的端到端数据流用例,将数据流式传…

I/O 多路复用小结

Socket 模型 Socket 编程是一种使用 Socket 模型进行网络通信的编程技术。它是一种基于网络套接字的编程模型,用于实现不同计算机之间的数据传输。 事实上,在进行网络通信前,通信双方都要创建一个 Socket,双方的数据读写都要依赖于…

【Python】执行SQL报错

可以再数据库查询界面执行的SQL,一直报错 unsupported format character Y (0x59) at index 61 SQL如下: datapd.read_sql_query(sql"""selectdate_format(create_time,%Y-%m) as mon,count(distinct order_id) as ord_cntfrom prod.o…

HTTP与HTTPS

HTTP与HTTPS介绍 超文本传输协议HTTP协议被用于在Web浏览器和网站服务器之间传递信息,HTTP协议以明文方式发送内容,不提供任何方式的数据加密,如果攻击者截取了Web浏览器和网站服务器之间的传输报文,就可以直接读懂其中的信息&…

qt源码--事件系统

qt的事件传播主要依赖于QCoreApplication、QAbstractEventDispatcher(会根据不同的平台生成各自的处理对象)、QEvent(各种事件类型)等。 首先看下QCoreApplication的实现: 2、了解QCoreApplication的构造函数 其构造函…

在最新ICP备案域名的基础上,结合其他网络营销手段,打造强大的品牌推广效果

API接口是一种软件系统之间进行交互的方式,通过API接口,可以在不同的系统之间传递数据、命令等信息。在网络营销中,API接口可以帮助我们更加高效地进行品牌推广。本文将以在最新ICP备案域名的基础上,结合其他网络营销手段&#xf…

JVM回收算法(标记-清除算法, 复制算法, 标记-整理算法)

1.标记-清除算法 最基础的算法,分为两个阶段,“标记”和“清除” 原理: - 标记阶段:collector从mutator根对象开始进行遍历,对从mutator根对象可以访问到的对象都打上一个标识,一般是在对象的header中&am…

vue-router 4.0 动态路由会跳转到 404 页面的问题

引子 开发过前端单页面应用的小伙伴们,应该对前端路由都不陌生吧。 无论是用 vue 或者 react,都有官方提供的 router 方案。 但是有些场景下,处于安全性和友好性考虑,我们需要用到动态路由。 如果你不知道什么叫动态路由&…

翻遍整个牛客网,整理出了全网最全的Java面试八股文大合集,整整4000多页

大家从 Boss 直聘上或者其他招聘网站上都可以看到 Java 岗位众多,Java 岗位的招聘薪酬天差地别,人才要求也是五花八门。而很多 Java 工程师求职过程中,也是冷暖自知。很多时候技术有,但是面试的时候就是过不了! 为了帮…

4.7 x64dbg 应用层的钩子扫描

所谓的应用层钩子(Application-level hooks)是一种编程技术,它允许应用程序通过在特定事件发生时执行特定代码来自定义或扩展其行为。这些事件可以是用户交互,系统事件,或者其他应用程序内部的事件。应用层钩子是在应用…

【Zabbix 监控 Windows 系统,Java应用,SNMP】

目录 一、Zabbix 监控 Windows 系统1、下载 Windows 客户端 Zabbix agent 22、安装客户端,配置3、在服务端 Web 页面添加主机,关联模板 二、Zabbix 监控 java 应用1、客户端开启 java jmxremote 远程监控功能1、配置 java jmxremote 远程监控功能2、启动…

【ARM Coresight 系列文章 3.1 - ARM Coresight DP 对 AP 的访问 2】

文章目录 图 1-1 如上图1-1 所示,DAP上可以集成多个MEM-AP,上图是集成了3个MEM-AP,它们可能是AXI-AP, AHB-AP, APB-AP。 那么AP的类型是如何区分的呢? 不同的组件会使用不同MEM-AP接口,如Cortex-A/Coretex-R 系列的c…

ai绘画工具有哪些?这几款好用的ai绘画工具免费分享给你

上周,我去了一家现代艺术画廊,墙上挂满了令人叹为观止的绘画作品。我被其中一幅细腻而充满情感的油画所深深吸引,想要了解背后的创作过程。这时,一位热情的艺术导师走到我身边。她微笑着说:“这幅作品实际上是由ai绘画…

解决“_mkdir无法识别空格目录“问题

在C编程里,有时候需要创建一个文件夹,通常使用库函数_mkdir(const char* dirname)来新建一个文件夹,该库函数每次只能创建一个文件夹,不能级联创建。若要级联创建文件,则请用递归方式或者for循环方式调用_mkdir()。 #…

7月6日华为云盘古气象大模型登上《Nature》杂志:相比传统数值预报快10000倍

7月6日,国际顶级学术期刊《自然》(Nature)杂志正刊发表了华为云盘古大模型研发团队的最新研究成果——《三维神经网络用于精准中期全球天气预报》(《Accurate medium-range global weather forecasting with 3D neural networks》…

CrossKD 原理与代码解析

paper:CrossKD: Cross-Head Knowledge Distillation for Dense Object Detection official implementation: https://github.com/jbwang1997/CrossKD 前言 蒸馏可以分为预测蒸馏prediction mimicking和特征蒸馏feature imitation两种,201…

【LeetCode】HOT 100(26)

题单介绍: 精选 100 道力扣(LeetCode)上最热门的题目,适合初识算法与数据结构的新手和想要在短时间内高效提升的人,熟练掌握这 100 道题,你就已经具备了在代码世界通行的基本能力。 目录 题单介绍&#…

什么是 AOP?对于 Spring IoC 和 AOP 的理解?Spring AOP 和 AspectJ AOP 有什么区别?

什么是 AOP? AOP(Aspect-Oriented Programming),即 面向切面编程, 它与OOP( ObjectOriented Programming, 面向对象编程) 相辅相成,提供了与OOP 不同的抽象软件结构的视角 在 OOP 中, 我们以类(class)作为我们的基本单元 而 A…