Optimizer神经网络中各种优化器介绍

news2025/1/16 19:08:53

1. SGD

1.1 batch-GD

每次更新使用全部的样本,注意会对所有的样本取均值,这样每次更新的速度慢。计算量大。

1.2 SGD

每次随机取一个样本。这样更新速度更快。SGD算法在于每次只去拟合一个训练样本,这使得在梯度下降过程中不需去用所有训练样本来更新Theta。BGD每次迭代都会朝着最优解逼近,而SGD由于噪音比BGD多,多以SGD并不是每次迭代都朝着最优解逼近,但大体方向是朝着最优解,SGD大约要遍历1-10次数据次来获取最优解。

但是 SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。

1.3. MBGD(Mini-batch Gradient Descent)

MBGD有时候甚至比SGD更高效。MBGD不像BGD每次用m(所有训练样本数)个examples去训练,也不像SGD每次用一个example。MBGD使用中间值b个examples
经典的b取值大约在2-100。例如 b=10,m=1000。

2. Momentum

在这里插入图片描述

SGD存在的一个主要问题是:在沟壑处无法正常收敛的问题。如果初始化不好不幸陷入沟壑区,则会出现下面左图的震荡问题:即在一个方向上梯度很大,且正负交替出现。而momentum会加上前面的一次迭代更新时的梯度。让与上一次同方向的值更大,反方向的更小,如下面右图所示。momentum公式为:
v t = γ v t − 1 + η Δ θ J ( θ ) θ = θ − v t \begin{align} v_t&=\gamma v_{t-1}+\eta\Delta_\theta J(\theta)\\ \theta &= \theta-v_t \end{align} vtθ=γvt1+ηΔθJ(θ)=θvt

  • 下降初期时,使用上一次参数更新,下降方向一致,乘上较大的$mu$能够进行很好的加速。
  • 下降中后期时,在局部最小值来回震荡的时候, g r a d i e n t → 0 gradient\to 0 gradient0 μ \mu μ使得更新幅度增大,跳出陷阱。
  • 在梯度改变方向的时候, μ \mu μ能够减少更新 总而言之,momentum项能够在相关方向加速SGD,抑制振荡,从而加快收敛。
    正确的方向上让他更快,错误的方向上让他更慢。如果上次的momentum(v)与这次的负梯度方向是相同的,那这次下降的幅度就会加大,从而加速收敛。
    momentum的更新方式为:
    在这里插入图片描述

momentum设置太小动量效果不明显,设置太大容器使得本来收敛很好的地方震动太大,特别是训练的后期,一般取0.9。

3. NAG(Nesterov accelerated gradient)

动量法每下降一步都是由前面下降方向的一个累积和当前点的梯度方向组合而成。于是一位大神(Nesterov)就开始思考,既然每一步都要将两个梯度方向(历史梯度、当前梯度)做一个合并再下降,那为什么不先按照历史梯度往前走那么一小步,按照前面一小步位置的“超前梯度”来做梯度合并呢?如此一来,小球就可以先不管三七二十一先往前走一步,在靠前一点的位置看到梯度,然后按照那个位置再来修正这一步的梯度方向。如此一来,有了超前的眼光,小球就会更加”聪明“, 这种方法被命名为Nesterov accelerated gradient 简称 NAG。

NAG的更新方式为:
在这里插入图片描述

与momentum不同的是,NAG是先往前走一步,谈谈路,用超前的梯度来进行修正。
更新公式为:
在这里插入图片描述

实现证明,比momentum更快。

4. AdaGrad

SGD+Momentum的问题是:

  • 设置初始的学习率比较难
  • 所有的参数都使用相同的学习率
    Adam采用累加前面梯度的平方和的方式。能够对每个参数自适应不同的学习速率。因此对于稀疏特征,学习率会大一点。对非稀疏特征,学习率会小一点。因此次方法适合处理稀疏特征。公式为:
    θ t + 1 , i = θ t , i − η G t , i + ϵ g t , i \theta_{t+1, i}=\theta_{t, i}-\frac {\eta}{\sqrt{G_{t,i}+\epsilon}}g_{t,i} θt+1,i=θt,iGt,i+ϵ ηgt,i

其中 同样是当前的梯度,连加和开根号都是元素级别的运算。 是初始学习率,由于之后会自动调整学习率,所以初始值就不像之前的算法那样重要了。而是一个比较小的数,用来保证分母非0。

其含义是,对于每个参数,随着其更新的总距离增多,其学习速率也随之变慢。

g t g_t gt从1到t进行一个递推形成一个约束项, ϵ \epsilon ϵ保证分母非0。
G t , i = ∑ r = 1 t ( g r , i 2 ) G_{t, i}=\sum_{r=1}^t(g_{r,i}^2) Gt,i=r=1t(gr,i2)
为前面的参数的梯度平方和。特点为:

  • 前期梯度较小的时候,叠加的梯度平方和也比较小,能够加快梯度。
  • 后期梯度叠加项比较大uo,梯度也会变小,能够以小步幅更新。
  • 对于不同的变量可以用不同的学习率。
  • 适合处理稀疏的数据。

缺点:

  • 依赖一个全局学习率
  • 中后期,梯度的平方和累加会越来越大,会使得 g r a d i e n t → 0 gradient\to 0 gradient0,使得后期训练很慢,甚至接近0。

5. AdaDelta

Adadelta是对于Adagrad的扩展。最初方案依然是对学习率进行自适应约束,但是进行了计算上的简化。 Adagrad会累加之前所有的梯度平方,而Adadelta只累加固定大小的项(Adagrad需要存储),并且也不直接存储这些项,仅仅是近似计算对应的平均值。即:
E [ g 2 ] t = γ E [ g 2 ] t − 1 + ( 1 − γ ) g t 2 Δ θ t = − η E [ g 2 ] t + ϵ g t \begin{align} E[g^2]_t&=\gamma E[g^2]_{t-1}+(1-\gamma)g_t^2\\ \Delta\theta_t&=-\frac{\eta}{\sqrt{E[g^2]_t+\epsilon}}g_t \end{align} E[g2]tΔθt=γE[g2]t1+(1γ)gt2=E[g2]t+ϵ ηgt

因为AdaDelta需要计算 R [ g t − w : t ] R[g_t-w:t] R[gtw:t],需要存储前面 w w w个状态,比较麻烦。因此AdaDelta采用了类似momtemum的平均话方法,如果 γ = 0.5 \gamma=0.5 γ=0.5,则相当于前面的均方根RMS。其中Inception-V3的初始化建议为1。

此处AdaDelta还是依赖于全局学习率,因此作者做了一定的处理来近似:
经过二阶海森矩阵近似之后,得到 Δ x ∼ x \Delta x\sim x Δxx
Δ x t = − ∑ r = 1 t − 1 Δ x r 2 E [ g 2 ] t + ϵ \Delta_{x_t}=-\frac{\sqrt{\sum_{r=1}^{t-1}\Delta x_r^2}}{\sqrt{E[g^2]_t+\epsilon}} Δxt=E[g2]t+ϵ r=1t1Δxr2
这样的话,AdaDelta已经不依赖于全局学习率了。

  • 训练初中期,加速效果不错,很快
  • 训练后期,反复在局部最小值附近抖动

6. RMSProp

RMSProp是AdaDelta的一种扩展。当 γ = 0.5 \gamma=0.5 γ=0.5的时候就变成了RMSProp。但是RMSProp仍然依赖于全局学习率。效果介于AdaGrad和AdaDelta之间。

7. Adam

Adam(Adaptive Moment Estimation)本质上是带有动量项的RMSprop,它利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。Adam的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳。公式如下:
g t = Δ θ J ( θ t − 1 ) m t = β 1 m t − 1 + ( 1 − β 1 ) g t v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 m ^ t = m t 1 − β 1 t v ^ t = v t 1 − β 2 t θ = θ − α m ^ t v ^ t + ϵ \begin{align} g_t&=\Delta_\theta J(\theta_{t-1})\\ m_t&=\beta_1m_{t-1}+(1-\beta_1)g_t\\ v_t&=\beta_2 v_{t-1}+(1-\beta_2)g_t^2\\ \hat m_t&=\frac{m_t}{1-\beta_1^t}\\ \hat v_t&=\frac{v_t}{1-\beta_2^t}\\ \theta&=\theta-\alpha \frac{\hat m_t}{\sqrt{\hat v^t}+\epsilon} \end{align} gtmtvtm^tv^tθ=ΔθJ(θt1)=β1mt1+(1β1)gt=β2vt1+(1β2)gt2=1β1tmt=1β2tvt=θαv^t +ϵm^t
然后对 m t m_t mt n t n_t nt进行无偏估计。因为 m 0 m_0 m0 n 0 n_0 n0初始化都是0,我们希望能够快点从0中调出来。因为如果 β \beta β比较大的话,原来的 m t m_t mt可能会调不出来。因此进行无偏估计后能够放大。 β 1 \beta_1 β1 β 2 \beta_2 β2两个超参数一般设置为0.9和0.999。:
m ^ t = m t 1 − β 1 t \hat m_t=\frac{m_t}{1-\beta_1^t} m^t=1β1tmt
接下来更新参数,初始的学习率 α \alpha α(默认0.001)乘以梯度均值与梯度方差的平方根之比。由表达式可以看出,对更新的步长计算,能够从梯度均值及梯度平方两个角度进行自适应地调节,而不是直接由当前梯度决定。

直接对梯度的矩进行估计对内存没有额外的要求,而且可以根据梯度进行动态调整。而且后面的一项比值可以对学习率形成一个动态约束,因为它是有范围的。

目前来讲,效果最好的是Adam。但是经典的论文搞上去的方式都是先用Adam,然后再用SGD+momentum死磕上去。

Adam看作是Momentum+RMSProp的结合体。

形成一个动态约束,因为它是有范围的。

目前来讲,效果最好的是Adam。但是经典的论文搞上去的方式都是先用Adam,然后再用SGD+momentum死磕上去。

Adam看作是Momentum+RMSProp的结合体。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557792.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot 整合Redis第1篇

SpringBoot是一个开发框架,Redis是一个高性能的键值存储数据库, 常用于缓存、会话管理、消息队列等应用场景。 定义 Redis是什么? 它是一个存储层级, 在实际项目中,位于关系数据库之上, 类似Android分为5…

(C语言)fread与fwrite详解

1. fwrite函数详解 头文件&#xff1a;stdio.h 函数有4个参数&#xff0c;只适用于文件输出流 作用&#xff1b;将从ptr中拿count个大小为size字节的数据以二进制的方式写到文件流中。返回写入成功的数目。 演示 #include <stdio.h> int main() {FILE* pf fopen(&qu…

相册清理大师-手机重复照片整理、垃圾清理软件

相册清理大师是一款超级简单实用的照片视频整理工具。通过便捷的操作手势&#xff0c;帮助你极速整理相册中的照片和视频、释放手机存储空间。 【功能简介】 向上滑动&#xff1a;删除不要的照片 向左滑动&#xff1a;切换下一张照片 向右滑动&#xff1a;返回上一张照片 整理分…

【shell】select in实现终端交互场景

文章目录 序言1. select in语句及其语法2. select in和case语句相结合3. 执行界面示例 序言 shell脚本实现简单的终端交互功能&#xff0c;根据用户不同输入执行不同功能脚本 1. select in语句及其语法 select in是shell独有的一种循环&#xff0c;非常适合终端交互场景 该语…

链表的极致——带头双向循环链表

​ 文章目录 双向带头循环链表简介&#xff1a;双向&#xff1a;带头&#xff1a;特点&#xff1a;链表带头节点的好处&#xff1a; 循环&#xff1a;特点&#xff1a;循环的好处&#xff1a; 双向带头循环链表的接口函数实现准备工作&#xff1a; 初始化链表&#xff08;头结…

C++:数据类型—布尔(12)

布尔类型代表就是真和假&#xff08;bool&#xff09; 真就是1&#xff08;true&#xff09; 假就是0&#xff08;false&#xff09; 也可以任务非0即为真 bool 直占用1个字节大小 语法&#xff1a;bool 变量名 (true | false&#xff09; 提示&#xff1a;bool在后期判断也是…

深度学习pytorch——经典卷积网络之ResNet(持续更新)

错误率前五的神经网络&#xff08;图-1&#xff09;&#xff1a; 图-1 可以很直观的看到&#xff0c;随着层数的增加Error也在逐渐降低&#xff0c;因此深度是非常重要的&#xff0c;但是学习更好的网络模型和堆叠层数一样简单吗&#xff1f;通过实现表明&#xff08;图-2&…

《自动机理论、语言和计算导论》阅读笔记:p49-p67

《自动机理论、语言和计算导论》学习第4天&#xff0c;p49-p67总结&#xff0c;总计19页。 一、技术总结 1.Deterministic Finite Automata(DFA) vs Nondeterministic Finite Automata(NFA) (1)DFA定义 (2)NFA定义 A “nonedeterministic” finite automata has the power t…

python之绘制曲线

以同一种型号的钻头&#xff0c;钻21种类型的板材&#xff0c;每种板材使用3根钻头&#xff0c;分别在钻第一个孔、2001孔、4001孔和6001孔前测量钻头外径&#xff0c;收集数据。 1、测试方法 采用激光钻径分选机测量微钻钻径以评估微钻外径磨损&#xff0c;测量从钻尖起始间…

C语言-文件

目录 1.什么是文件&#xff1f;1.1 程序文件1.2 数据文件 2.二进制文件和文本文件&#xff1f;3.文件的打开和关闭4.文件的顺序读写5.文件的随机读写5.1 fseek5.2 ftell5.3 rewind 6.文件读取结束的判定7.文件缓冲区 1.什么是文件&#xff1f; 磁盘上的文件就是文件 一般包含两…

使用pytorch构建带梯度惩罚的Wasserstein GAN(WGAN-GP)网络模型

本文为此系列的第三篇WGAN-GP&#xff0c;上一篇为DCGAN。文中仍然不会过多详细的讲解之前写过的&#xff0c;只会写WGAN-GP相对于之前版本的改进点&#xff0c;若有不懂的可以重点看第一篇比较详细。 原理 具有梯度惩罚的 Wasserstein GAN (WGAN-GP)可以解决 GAN 的一些稳定性…

caffe源码编译安装

一、前置准备 (1)vs2015 目前不要想着2019这些工具了,成功率太低了,就老老实实用vs2015吧 解决“VS2015安装包丢失或损坏“问题_vs2015跳过包会影响使用吗-CSDN博客 注意在安装vs2015过程中老是出现这个问题,其实就是缺少两个证书,安装完后就可以正常安装vs2015了,注意…

大数据面试专题 -- kafka

1、什么是消息队列&#xff1f; 是一个用于存放数据的组件&#xff0c;用于系统之间或者是模块之间的消息传递。 2、消息队列的应用场景&#xff1f; 主要是用于模块之间的解耦合、异步处理、日志处理、流量削峰 3、什么是kafka&#xff1f; kafka是一种基于订阅发布模式的…

AE——重构数字(Pytorch+mnist)

1、简介 AE&#xff08;自编码器&#xff09;由编码器和解码器组成&#xff0c;编码器将输入数据映射到潜在空间&#xff0c;解码器将潜在表示映射回原始输入空间。AE的训练目标通常是最小化重构误差&#xff0c;即尽可能地重构输入数据&#xff0c;使得解码器输出与原始输入尽…

什么是nginx正向代理和反向代理?

什么是代理&#xff1f; 代理(Proxy), 简单理解就是自己做不了的事情或实现不了的功能&#xff0c;委托别人去做。 什么是正向代理&#xff1f; 在nginx中&#xff0c;正向代理指委托者是客户端&#xff0c;即被代理的对象是客户端 在这幅图中&#xff0c;由于左边内网中…

如何解决kafka rebalance导致的暂时性不能消费数据问题

文章目录 背景思考答案排它故障转移共享 背景 之前在review同组其它业务的时候&#xff0c;发现竟然把kafka去掉了&#xff0c;问了下原因&#xff0c;有一个单独的服务&#xff0c;我们可以把它称为agent&#xff0c;就是这个服务是动态扩缩容的&#xff0c;会采集一些指标&a…

k8s的pod访问service的方式

背景 在k8s中容器访问某个service服务时有两种方式&#xff0c;一种是把每个要访问的service的ip注入到客户端pod的环境变量中&#xff0c;另一种是客户端pod先通过DNS服务器查找对应service的ip地址&#xff0c;然后在通过这个service ip地址访问对应的service服务 pod客户端…

HarmonyOS 应用开发之FA模型访问Stage模型DataShareExtensionAbility

概述 无论FA模型还是Stage模型&#xff0c;数据读写功能都包含客户端和服务端两部分。 FA模型中&#xff0c;客户端是由DataAbilityHelper提供对外接口&#xff0c;服务端是由DataAbility提供数据库的读写服务。 Stage模型中&#xff0c;客户端是由DataShareHelper提供对外接…

腾讯云2核2G服务器优惠价格,61元一年

腾讯云2核2G服务器多少钱一年&#xff1f;轻量服务器61元一年&#xff0c;CVM 2核2G S5服务器313.2元15个月&#xff0c;轻量2核2G3M带宽、40系统盘&#xff0c;云服务器CVM S5实例是2核2G、50G系统盘。腾讯云2核2G服务器优惠活动 txybk.com/go/txy 链接打开如下图&#xff1a;…

java数组与集合框架(三)--Map,Hashtable,HashMap,LinkedHashMap,TreeMap

Map集合&#xff1a; Map接口: 基于 键&#xff08;key&#xff09;/值&#xff08;value&#xff09;映射 Map接口概述 Map与Collection并列存在。用于保存具有映射关系的数据:key-value Map 中的key 和value 都可以是任何引用类型的数据Map 中的key 用Set来存放&#xff0…