【Python机器学习】神经网络中常用激活函数、损失函数、优化方法(图文解释 附源码)

news2025/3/6 11:22:50

下面以经典的分类任务:MNIST手写数字识别,采用全连接层神经网络

MNIST数据集是一个手写体的数字图片集,它包含有训练集和测试集,由250个人手写的数字构成。训练集包含60000个样本,测试集包含10000个样本。每个样本包括一张图片和一个标签。每张图片由28×28个像素点构成,每个像素点用1个灰度值表示。标签是与图片对应的0到9的数字。

随着训练损失值逐渐降低 精确度上升 

 

 部分代码如下

import numpy as np
import tensorflow.keras as ka
import datetime
 
np.random.seed(0)
 
(X_train, y_train), (X_val, y_val) = ka.datasets.mnist.load_data("D:\datasets\MNIST_Data\mnist.npz") # 加载数据集,并分成训练集和验证集
 
num_pixels = X_train.shape[1] * X_train.shape[2] # 每幅图片的像素数为784

# 将二维的数组拉成一维的向量
X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
X_val = X_val.reshape(X_val.shape[0], num_pixels).astype('float32')

# 归一化
X_train = X_train / 255
X_val = X_val / 255
 
y_train = ka.utils.to_categorical(y_train) # 转化为独热编码
y_val = ka.utils.to_categorical(y_val)
num_classes = y_val.shape[1] # 10

# 多层全连接神经网络模型
model = ka.Sequential([
    ka.layers.Dense(num_pixels, input_shape=(num_pixels,), kernel_initializer='normal', activation='sigmoid'),
    ka.layers.Dense(784, kernel_initializer='normal', activation='sigmoid'),    
    ka.layers.Dense(num_classes, kernel_initializer='normal', activation='sigmoid')
])
model.summary()

#model.compile(loss='mse', optimizer='sgd', metrics=['accuracy'])
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

startdate = datetime.datetime.now() # 获取当前时间
model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=20, batch_size=200, verbose=2)
enddate = datetime.datetime.now()

print("训练用时:" + str(enddate - startdate))

下面以上述代码来讨论神经网络中常用的激活函数、损失函数和优化方法

激活函数

常用的激活函数还是ReLU函数 Softplus函数 tanh函数和Softmax函数等等

ReLU函数的定义为:

Softplus函数的定义为:

 

tanh函数的图像类似于Sigmoid函数,作用也类似于Sigmoid函数。它的定义为:

 

假设有一组实数y_1,y_2,…,y_K(可看作多分类的结果),Softmax函数将它们转化为一组对应的概率值: 

 

假如有一组数1、2、5、3,容易计算出它们的Softmax函数值分别约为0.01、0.04、0.83、0.11,将它们的原数值和Softmax函数值、max函数值等比例画出

修改示例代码,使模型分别采用不同激活函数组合进行比较,其他参数不变,仍为MSE损失函数、SGD优化方法,并训练20轮,运行结果:

 

采用什么样的激活函数,要根据理论研究、工程经验和试验综合分析。过拟合示例中,如果采用softplus激活函数,训练轮数仍为5000,网络结构仍然是四层(1,5,5,1)结构,分别对样本特征进行归一化处理和不归一化处理时拟合多项式的结果 

损失函数

MSE损失函数时基于欧式距离的损失函数,还有KL散度损失函数,交叉熵损失函数等等

交叉熵可以用来衡量两个分布之间的差距

信息熵的定义:H(X)=−∑_i=1^n▒p_ilogp_i。用p_i表示第i个输出的标签值,即真实值,用q_i表示第i个输出值,即预测值。将p_i与q_i之间的对数差在p_i上的期望值称为相对熵:

计算a和d两项输出的相对熵:

 

将相对熵的定义式进一步展开:

 

将相对熵的定义式进一步展开:前一项保持不变,因此一般用后一项作为两个分布之间差异的度量,称为交叉熵:H(p,q)=−∑_i=1^n▒p_ilogq_i

如果只有正负两个分类,记标签为正类的概率为y,记预测为正类的概率为p,那么上式为:

 

交叉熵损失函数在梯度下降法中可以避免MSE学习速率降低的问题,得到了广泛的应用。 

 优化算法

下面讨论常用于多层神经网络中的优化算法,它们都是梯度下降法的主要改进方法,主要从增加动量和调整优化步长两方面着手

### MindSpore框架下
class mindspore.nn.SGD(params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, loss_scale=1.0)
 
### TensorFlow框架下
tf.keras.optimizers.SGD(
    learning_rate=0.01, momentum=0.0, nesterov=False, name='SGD', **kwargs
)

为了克服固定步长的弊端,MindSpore深度学习框架和TensorFlow2深度学习框架都提供了动态调整步长的方法。

learning_rate超参数即为梯度下降法中的步长,也称为学习率,它们的默认初始值都是固定的0.1,可以设置成动态的步长。

MindSpore提供了函数和类两种预定义的动态调整步长方法,两种方法的具体功能相近,它们分别按余弦函数、指数函数、与时间成反比、多项式函数等方式衰减步长。

learning_rate = 0.1
decay_rate = 0.9
total_step = 6
step_per_epoch = 2
decay_epoch = 1
output = exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
print(output)
>>> [0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002]

设当前为第i步,其步长的计算方法为:

其中,current_epocℎ=floor(i/step_per_epocℎ),floor为向下取整运算。

在TensorFlow2框架中也提供了类似的动态调整步长方法,它们都在tensorflow.keras.optimezers.schedules模块内。

这些动态调整步长的方法,实际上并没有结合优化的具体进展情况来设定步长,仍然可以看成是一组预先设定的步长,只不过它们的大小按一定方式逐步衰减了。

结合优化具体进展的自适应步长调整方法Adagrad(Adaptive Gradient)算法记录下所有历史梯度的平方和,并用它的平方根来除步长,这样就使得当前的实际步长越来越小。

MindSpore中实现该算法的类为:mindspore.nn.Adagrad。TensorFlow2中实现该算法的类是:tf.keras.optimizers.Adagrad。

在经典力学中,动量(Momentum)表示为物体的质量和速度的乘积,体现为物体运动的惯性。在梯度下降法中,如果使梯度下降的过程具有一定的“动量”,保持原方向运动的一定的 “惯性”,则有可能在下降的过程中“冲过”小的“洼地”,避免陷入极小值点。

 

在SGD算法中,通过配置momentum 参数,就可以使梯度下降法利用这种“惯性”。momentum 参数设置的是“惯性”的大小

加入动量的梯度下降的迭代关系式还有一种改进方法,称为NAG(Nesterov accelerated gradient)。该方法中,计算梯度的点发生了变化,它可以理解为先按“惯性”前进一小步,再计算梯度。这种方法在每一步都往前多走了一小步,有时可以加快收敛速度。设置SGD的nesterov为True,即可使用该算法。

结合动量和步长进行优化的算法有RMSProp(Root Mean Square Prop)和Adam(Adaptive moment estimation)算法等。

RMSProp(Root Mean Square Prop)算法通过对Adagrad算法逐步增加控制历史信息与当前梯度的比例系数、增加动量因子和中心化操作形成了三个版本。在MindSpore中,实现该算法的类是:mindspore.nn.RMSProp,在TensorFlow2中实现该算法的是:tensorflow.keras.optimizers.RMSprop。

 Adam(Adaptive moment estimation)算法是一种结合了AdaGrad算法和RMSProp算法优点的算法。Adam算法综合效果较好,应用广泛。

神经网络三隐层分别采用relu、relu和softmax激活函数组合,采用交叉熵损失函数,训练20轮,采用不同的优化方法:

局部收敛与梯度消散 

BP神经网络不一定收敛,也就是说,网络的训练不一定成功。误差的平方是非凸函数,BP神经网络是否收敛或者能否收敛到全局最优,与初始值有关。

在校对误差反向传播的过程中,如果偏导数较小,在多次连乘之后,校对误差会趋近于0,导致梯度也趋近于0,前面层的参数无法得到有效更新,称之为梯度消散。相反,如果偏导数较大,则会在反向传播的过程中呈指数级增长,导致溢出,无法计算,网络不稳定,称之为梯度爆炸。

常用的解决方法包括尽量使用合适的激活函数(如Relu函数,它在正数部分导数为1);预训练;合适的网络模型(有些网络模型具有防消散和爆炸能力);梯度截断等等

创作不易 觉得有帮助请点赞关注收藏~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/104037.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构基础--散列表

一、散列简介 散列表,又叫哈希表(Hash Table),是能够通过给定的关键字的值直接访问到具体对应的值的一个数据结构。也就是说,把关键字映射到一个表中的位置来直接访问记录,以加快访问速度。 通常&#xff0…

Android设计模式详解之建造者模式

前言 Builder模式是一种创建型设计模式。 定义:将一个复杂对象的创建与它的表示分离,使得同样的构造过程可以创建不同的表示。 使用场景: 相同的方法,不同的执行顺序,产生不同的事件结果时;多个部件或零…

java学习day63(乐友商城)商品新增后台、商品编辑后台、搭建前台系统页面

1.商品新增 当我们点击新增商品按钮: 就会出现一个弹窗: 里面把商品的数据分为了4部分来填写: 基本信息:主要是一些简单的文本数据,包含了SPU和SpuDetail的部分数据,如 商品分类:是SPU中的cid1&…

大型项目都会使用到的Makefile

一、vi编辑器之神 1.vi编辑器的三种模式: 插入模式:可以编辑文档 编辑模式:可以敲一些命令,执行例如复制n行 剪切n行 ,粘贴等功能 命令模式:(最后一行模式) 在此模式下可以保存文件,退出vi…

第03讲:Redis的持久化方案

前言 redis是一个内存数据库,当redis服务器重启,获取电脑重启,数据会丢失,我们可以将redis内存中的数据持久化保存到硬盘的文件中。 redis提供两种持久化方式: RDB:快照,通过从服务器保存和持久化AOF&…

Codeforces Round #839 (Div. 3) A~G all answer

Dashboard - Codeforces Round #839 (Div. 3) - Codeforces 最近状态奇差无比,还有点生病,低烧反复横跳,应该没阳?(虽然家人都阳了,就剩我一个了wuwuwu~(A B C就不作解释了&#xff…

【小5聊】Python3 使用selenium模块实现简单爬虫系列一

第一次听说Python还是在工作的时候,还是一位女生在用,当时她说可以用来处理excel文档,特别是一些统计分析。第二次让我真正进入python世界,还是在一次C站举办的大赛上。聊聊你是因为什么机缘巧合进入到python圈的呢?不…

金盾杯2022-AGCTFS战队 wp

文章目录Web图书馆EzPHPeZphp2SQLSkip有来无回反败为胜Crypto小菜一碟RRSSAAsimpleRrandMISC盗梦空间qianda0_Sudoku数据泄露01-账号泄露追踪数据泄露02-泄露的密码数据泄露03-泄露的密钥ReverseTeaPwnLoginWtfWeb 图书馆 根据提示找到 干货|最全的Tomcat漏洞复现…

Qt5 网页标题、关键词提取工具Findyou

Qt5 网页标题、关键词提取工具Findyou 一、程序运行 运行界面 辅助功能,可用于将扫描器的扫描结果转换为url 二、所涉及的重要知识点 1、Qt爬取https的网页 来自宇龍_ https://blog.csdn.net/qq_45809384/article/details/122049295?spm1001.2014.3001.5506 打…

Foxmail客户端添加163账号和邮件备份163邮箱

文章目录一、Foxmail添加163账号1. 点击图标2. 账号管理3. 新建4. 手动设置5. 填写信息6. 创建二、邮件转移备份2.1. 邮件折叠2.2. 选择目标邮箱2.3. 同步服务端Foxmail客户端添加163账号的具体步骤如下:一、Foxmail添加163账号 1. 点击图标 首先打开Foxmail客户端…

51寻找数组中出现次数超一半的数

51寻找数组中出现次数超一半的数 一看题目就想用hash表,但是要求空间复杂度为1,说明不可以用哈希表去存。一直在原地数组上思考,类似桶排序,可是这取决于数值的大小,最后还是看了题解,学到了。 思想是&…

外汇天眼:一笔赚了12600美元 你羡慕吗?

在外汇投资中,黑平台一直是外汇投资圈的一枚毒瘤,不能顺利出金也是外汇投资面临的最大风险之一。 对于外汇投资者而言,外汇交易平台的选择至关重要。 选择好的外汇交易平台,最重要的是:选择安全可靠的平台&#xff0…

Blackmagic黑魔法摄像机braw视频帧损坏文件修复方法

Blackmagic是全球知名的影视级产品供应商,其高清摄像机是国内外各种剧组的最爱。Blackmagic的新产品目前使用braw格式,其编码采用自定义的raw编码,视频的效果和阿莱不相上下。之前我们已经多次介绍过这种braw文件的修复,近期我们处…

grpc的使用

GRPC学习 本文包括grpc的入门使用和四种实现方式 文章目录一、GRPC 安装和hello world1、什么是GRPC2、安装grpc和代码3、服务端3.1、取出 server3.2、挂载方法3.3、注册服务3.4、创建监听4、客户端二、protobuf语法三、GRPC server 的使用1、普通服务2、流式传入(客…

通达信破解接口怎么委托下单?

通达信破解接口主要是利用数学公式建立模型,通过大量数据判断未来价格走势,通过程序选股。虽然选股也比较广泛,但也能覆盖A股市场的四千多只股票,能排除强行涨跌等人为因素,执行的纪律性强。所以对于通达信破解接口对股…

【笔记】git 修改之前的提交记录信息(git commit -m ‘...‘)

文章目录一、修改最后一条提交记录信息二、修改前面某条或某几条提交记录信息一、修改最后一条提交记录信息 git commit --amend进入vi编辑器后: 按i下方出现’- - 插入 - -‘的提示时,便可编辑提交记录信息按ESC,输入:wq保存退出&#xff0…

ICG衍生物ICG-Sulfo-OSu的产品描述及保存建议

中文名称 ICG-Sulfo-OSu 英文名字 ICG-Sulfo-OSu 凯新生物描述: (ICG)是一种用于医学诊断的菁染料它用于测定心输出量、肝功能和肝血流,以及眼科血管造影它的峰值光谱吸收接近800 nm这些红外频率穿透视网膜层,使ICG血管造影能够比…

【STA】(2)概念

目录 1.CMOS逻辑设计 1.1 基本MOS结构 1.2 COMS逻辑门 1.3 标准单元 2.CMOS单元建模 3.电平翻转波形 4.传播延迟 5.波形的转换率 6.信号之间的偏移 7. 时序弧和单调性 8.最小和最大时序路径 9.时钟域 10.工作条件 1.CMOS逻辑设计 1.1 基本MOS结构 MOS(Metal Oxide…

2022年Python笔试选择题及答案(秋招)

2022年Python笔试选择题及答案(秋招) 🏠个人主页:编程ID 🧑个人简介:大家好,我是编程ID,一个想要与大家共同进步的程序员儿 🧑如果各位哥哥姐姐在准备面试,找…

【Redis-11】Redis事务实现原理

Redis通过MULTI、EXEC、WATCH等命令来实现事务的功能,事务提供了一种将多个命令请求打包,然后一次性,顺序性的执行多个命令的机制。在事务执行期间,服务器不会中断事务去执行其他客户端的命令,他会讲事务中所有命令执行…