神经网络 06(优化方法)

news2025/1/9 2:07:57

一、优化方法

网络搭建好,损失函数设计好之后, 根据损失函数更新参数(权重,偏移)。参数更新过程就是一个神经网络优化过程。

 

二、梯度下降方法

梯度下降法简单来说就是一种寻找使损失函数最小化的方法。从数学上的角度来看,梯度的方向是函数增长速度最快的方向、那么梯度的反方向就是函数减少最快的方向,所以有:

w是权重,E是损失函数

其中,n是学习率,如果学习率太小,那么每次训练之后得到的效果都太小,增大训练的时间成本。如果,学习率太大,那就有可能直接跳过最优解,进入无限的训练中。解决的方法就是,学习率也需要随着训练的进行而变化

 

实际中使用较多的是小批量的梯度下降算法,在tf.keras中通过以下方法实现

tf.keras.optimizers.SGD(
    learning_rate=0.01, momentum=0.0, nesterov=False, name='SGD', **kwargs
)
# 导入相应的工具包
import tensorflow as tf
# 实例化优化方法:SGD 
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
# 定义要调整的参数
var = tf.Variable(1.0)
# 定义损失函数:无参但有返回值
loss = lambda: (var ** 2)/2.0  
# 计算梯度,并对参数进行更新,步长为 `- learning_rate * grad`
opt.minimize(loss, [var]).numpy()
# 展示参数更新结果
var.numpy()

在进行模型训练时,有三个基础的概念:

实际上,梯度下降的几种方式的根本区别就在于 Batch Size不同,,如下表所示:

 

假设数据集有 50000 个训练样本,现在选择 Batch Size = 256 对模型进行训练。

每个 Epoch 要训练的图片数量:50000
训练集具有的 Batch 个数:50000/256+1=196
每个 Epoch 具有的 Iteration 个数:196
10个 Epoch 具有的 Iteration 个数:1960
 

三、反向传播算法(BP算法)

利用反向传播算法对神经网络进行训练。该方法与梯度下降算法相结合,对网络中所有权重计算损失函数的梯度,并利用梯度值来更新权值以最小化损失函数。 

3.1 前向传播和反向传播

前向传播是指数据输入的神经网络中,逐层向前传输,一直运算到输出层为止。

在网络的训练过程中经过前向传播后得到的最终结果跟样本的真实值总是存在误差,误差使用损失函数衡量。想要减小这个误差,就用损失函数Error,从后向前,依次求出各个参数的偏导,这就是反向传播。

3.2 链式法则

反向传播算法是利用链式法则进行梯度求解及权重更新的。对于复杂的复合函数,我们将其拆分为一系列的加减乘除或指数,对数,三角函数等初等函数,通过链式法则完成复合函数的求导。为简单起见,这里以一个神经网络中常见的复合函数的例子来说明 这个过程.令复合函数 (x; w,b)为:

其中x是输入数据,w是权重,b是偏置。我们可以将该复合函数分解为:

 

3.3 反向传播算法

反向传播算法利用链式法则对神经网络中的各个节点的权重进行更新

假设当前前向传播的过程如下图所以:

计算损失函数,进行反向传播:

计算梯度值

3.4 梯度下降优化方法

梯度下降算法在进行网络训练时,会遇到鞍点,局部极小值这些问题,那我们怎么改进SGD呢?在这里我们介绍几个比较常用的

3.4.1 动量算法

动量算法主要针对鞍点问题,介绍动量算法之前,首先看下指数加权平均数的计算方法。

指数加权平均数

其中Yt为t时刻时的真实值,St为t加权平均后的值,β为权重值。红线即是指数加权平均后的结果。
上图中β设为0.9,那么指数加权平均的计算结果为:

动量梯度下降算法

动量梯度下降 (Gradient Descent with Momentum)计算梯度的指数加权平均数,并利用该值来更新参数值。动量梯度下降法的整个过程为,其中β通常设置为0.9:

 

与原始的梯度下降算法相比,它的下降趋势更平滑

3.4.2 AdaGrad

AdaGrad算法会使用一个小批量随机梯度g_t按元素平方的累加变量st。在首次送代时,AdaGrad将s0中每个元素初始化为0。在t次迭代,首先将小批量随机梯度gt按元素平方后累加到变量st:

其中O是按元素相乘。接着,我们将目标函数自变量中每个元素的学习率通过按元素运算重新调整下:

其中α是学习率,e是为了维持数值稳定性而添加的常数,如10^(-6)。这里开方、除法和乘法的运算都是按元素运算的。这些按元素运算使得目标函数自变量中每个元素都分别拥有自己的学习率。

3.4.3  RMSprop

AdaGrad算法在迭代后期由于学习率过小,能较难找到最优解。为了解决这一问题,RMSProp算法对AdaGrad算法做了一点小小的修改。


不同于AdaGrad算法里状态变量st是截至时间步t所有小批量随机梯度gt按元素平方和,RMSProp(RootMean Square Prop) 算法将这些梯度按元素平方做指数加权移动平均

其中e是一样为了维持数值稳定一个常数。最终自变量每个元素的学习率在迭代过程中就不再一直降低。RMSProp 有助于减少抵达最小值路径上的摆动,并允许使用一个更大的学习率α,从而加快算法学习速度。

3.4.4  Adam

Adam 优化算法 (Adaptive Moment Estimation,自适应矩估计)将 Momentum 和RMSProp算法结合在一起。Adam算法在RMSProp算法基础上对小批量随机梯度也做了指数加权移动平均
假设用每一个mini-batch 计算 dw、db,第t次迭代时:

建议的参数设置的值
。学习率a:需要尝试一系列的值,来寻找比较合适的
。β1:常用的缺省值为 0.9
。β2:建议为0.999
e:默认值1e-8

四、学习率退火

 在训练神经网络时,一般情况下学习率都会随着训练而变化,这主要是由于,在神经网络训练的后期,如果学习率过高,会造成loss的振荡,但是如果学习率减小的过快,又会造成收敛变慢的情况。

4.1 分段常数衰减

分段常数衰减是在事先定义好的训练次数区间上,设置不同的学习率常数。刚开始学习率大一些,之后越来越小,区间的设置需要根据样本量调整,一般样本量越大区间间隔应该越小。

4.2 指数衰减

4.3 1/t衰减

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1005522.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

仿东郊到家app系统及功能介绍

类似东郊到家app系统开发,预约sap东郊到家软件定制开发,东郊到家小程序APP开发,东郊到家模式系统定制开发 一、东郊到家软件介绍 1、东郊到家app是一家以推拿为主项,个人定制型的o2o平台,东郊到家app平台提供、正规、安…

计算即时订单比例-首单使用开窗函数row_number()

1 需求 即时订单和计划订单 订单配送中,如果期望配送日期和下单日期相同,称为即时订单,如果期望配送日期和下单日期不同,称为计划订单。 请从配送信息表(delivery_info)中求出每个用户的首单(用…

langchain主要模块(一):模型输入输出

langchain2之模型输入输出 langchain1.概念2.主要模块模型输入/输出 (Model I/O)数据连接 (Data connection)链式组装 (Chains)代理 (Agents)内存 (Memory)回调 (Callbacks) 3.模型输入/输出 (Model I/O)提示提示模板示例选择器 模型LLMsChatModels 输出解释器 langchain 1.概…

计算机竞赛 多目标跟踪算法 实时检测 - opencv 深度学习 机器视觉

文章目录 0 前言2 先上成果3 多目标跟踪的两种方法3.1 方法13.2 方法2 4 Tracking By Detecting的跟踪过程4.1 存在的问题4.2 基于轨迹预测的跟踪方式 5 训练代码6 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 深度学习多目标跟踪 …

opencv(python)视频按帧切片/cv2.VideoCapture()用法

一、介绍 cv2.VideoCapture是OpenCV中一个用于捕捉视频的类。它可以访问计算机的摄像头,或从视频文件中读取图像。通过cv2.VideoCapture,用户可以轻松地捕捉、保存、编辑和传输视频流数据。 使用cv2.VideoCapture可以实现以下功能: 1. 打开…

基于微信小程序的自习室系统设计与实现,可作为毕业设计

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 文章目录 1 简介2 技术栈3 需求分析3.1用户需求分析3.1.1 学生用户3.1.3 管理员用户 4 数据库设计4.4.1 E…

linux————ansible

一、认识自动化运维 自动化运维: 将日常IT运维中大量的重复性工作,小到简单的日常检查、配置变更和软件安装,大到整个变更流程的组织调度,由过去的手工执行转为自动化操作,从而减少乃至消除运维中的延迟,实现“零延时”…

Could not find artifact com.mysql:mysql-connector-j:pom:unknown

在 <dependency><groupId>com.mysql</groupId><artifactId>mysql-connector-j</artifactId><scope>runtime</scope> </dependency> 添加版本号 这里用的是8.0.33版本&#xff0c;输入5.0的版本依然会报错 我自身用的是5.0…

做期权卖方有资金限制吗?

做期权卖方一般是有经济实力的自然人或机构做的&#xff0c;而且必须开立保证金账户&#xff0c;万一买方要行权就会有较高的风险&#xff0c;当然&#xff0c;做期权卖方在交易方面对经验和行情的预判是有一定要求的&#xff0c;下文介绍做期权卖方有资金限制吗&#xff1f; 一…

TCP服务器使用多路复用

启用复用的作用&#xff1f; 解决linux系统中的io阻塞问题&#xff0c;让多个阻塞io接口可以一起执行。无需开启线程&#xff0c;节省系统资源。 linux系统中的阻塞io有哪些&#xff1f; scanf、read管道、eadTcp套接字、acppet接收连接请求 有以下两种方式实现多路复用&am…

广州xx策划公司MongoDB恢复-2023.09.09

2023.09.08用户的MongoDB数据库被勒索病毒攻击&#xff0c;数据全部被清空。 提示&#xff1a; mongoDB的默认端口为27017&#xff0c;黑客通常通过全网段扫描27017是否开放判断是否是MongoDB服务器。一旦发现27017开放&#xff0c;黑客就会用空密码、弱密码尝试连接数据库。黑…

总结987

考研倒计时102天 时间记录&#xff1a; 6:20起床 7:00~7:40早读&#xff0c;13年tex2 7:50~8:20实验室 8:30~8:34列日计划 8:40~11:18进步本回顾&#xff0c;记录 11&#xff1a;20~12:20计算机网络网课 2:10~3:05计网20道选择题 3:07~4:42政治1000题25道选择题纠错 …

idea纯java工程使用gradle指定生成jar的Main-Class,idea生成jar

build.gradle核心代码如下&#xff1a; jar {manifest {attributes "Main-Class": "com.example.sample.Application"}from {configurations.compile.collect { it.isDirectory() ? it : zipTree(it) }} } 完整代码如下: group com.example.sample ver…

蚂蚁金融大模型

9月8日&#xff0c;蚂蚁集团在上海外滩大会发布的蚂蚁金融大模型基于蚂蚁基础大模型&#xff0c;针对金融产业深度定制。蚂蚁基础大模型平台具备万卡异构集群&#xff0c;其中千卡规模训练 MFU 可达到40%&#xff0c;集群有效训练时长占比90&#xff05;以上&#xff0c; RLHF …

前端面试题JS篇(6)

ES6 Module 和 CommonJS 模块的区别&#xff1a; CommonJS 是对模块的浅拷⻉&#xff0c;ES6 Module 是对模块的引⽤&#xff0c;即 ES6 Module 只存只读&#xff0c;不能改变其值&#xff0c;也就是指针指向不能变&#xff0c;类似 const&#xff1b; import 的接⼝是 read-o…

计算机网络第四节 数据链路层

一&#xff0c;引入数据链路层的目的 1.目的意义 数据链路层是体系结构中的第二层&#xff1b; 从发送端来讲&#xff0c;物理层可以将数据链路层交付下来的数据&#xff0c;装换成光&#xff0c;电信号发送到传输介质上了 从接收端来讲&#xff0c;物理层能将传输介质的光&…

ARM接口编程—RTC(exynos 4412平台)

RTC简介 RTC(Real Time Clock)即实时时钟&#xff0c;它是一个可以为系统提供精确的时间基准的元器件&#xff0c;RTC一般采用精度较高的晶振作为时钟源&#xff0c;有些RTC为了在主电源掉电时还可以工作&#xff0c;需要外加电池供电。 RTC内部原理 RTC寄存器 RTC控制寄存器 …

Leetcode 504.七进制数

给定一个整数 num&#xff0c;将其转化为 7 进制&#xff0c;并以字符串形式输出。 示例 1: 输入: num 100 输出: "202"示例 2: 输入: num -7 输出: "-10" 我的答案&#xff1a; 一、信息 1.目的实现十进制向其他进制的转换。 2.原理&#xff1a;公…

openGauss学习笔记-67 openGauss 数据库管理-创建和管理普通表-创建表

文章目录 openGauss学习笔记-67 openGauss 数据库管理-创建和管理普通表-创建表67.1 背景信息67.2 创建表 openGauss学习笔记-67 openGauss 数据库管理-创建和管理普通表-创建表 67.1 背景信息 表是建立在数据库中的&#xff0c;在不同的数据库中可以存放相同的表。甚至可以通…

Redis模块二:缓存分类 + Redis模块三:常见缓存(应用)

缓存大致可以分为两大类&#xff1a;1&#xff09;本地缓存 2&#xff09;分布式缓存 目录 本地缓存 分布式缓存 常见缓存的使用 本地缓存&#xff1a;Spring Cache 分布式缓存&#xff1a;Redis 本地缓存 本地缓存也叫单机缓存&#xff0c;也就是说可以应⽤在单机环…