深度学习常用优化器总结,具详细(SGD,Momentum,AdaGrad,Rmsprop,Adam,Adamw)

news2024/11/24 14:00:19

学习需要,总结一些常用优化器。

目录

  • 前言
  • SGD:随机梯度下降
  • BGD:批量梯度下降
  • MBGD:小批量梯度下降
  • Momentum
  • AdaGrad
  • RMSprop
  • Adam: Adaptive Moment Estimation
  • AdamW
  • 参考文章

前言

优化器的本质是使用不同的策略进行参数更新。常用的方法就是梯度下降,那梯度下降是指在给定待优化的模型参数 θ ∈ R d \theta \in R^d θRd,和目标函数 J ( θ ) J(\theta) J(θ),算法通过沿梯度 ∇ J ( θ ) \nabla J(\theta) J(θ)的反方向更新权重 θ \theta θ,来最小化目标函数。
学习率 μ \mu μ决定了每一时刻的更新步长。对于每一个时刻 t ,我们可以用下述公式描述梯度下降的流程:
θ t + 1 = θ t − μ ∇ J ( θ ) \theta_{t+1} = \theta_{t} - \mu \nabla J(\theta) θt+1=θtμJ(θ)
梯度下降法目前主要分为三种方法,区别在于每次参数更新时计算的样本数据量不同:批量梯度下降法(BGD, Batch Gradient Descent),随机梯度下降法(SGD, Stochastic Gradient Descent)及小批量梯度下降法(Mini-batch Gradient Descent)。

SGD:随机梯度下降

随机梯度下降是指在一个批次的训练样本中,我随机挑选一个样本计算其关于目标函数的梯度,然后用此梯度进行梯度下降。
设选择的样本为 ( x i , y i ) (x^i,y^i) (xi,yi),首先计算其梯度 ∇ J ( θ , x i , y i ) \nabla J(\theta,x^i,y^i) J(θ,xi,yi),然后进行权值更新:
θ t + 1 = θ t − μ ∇ J ( θ , x i , y i ) \theta_{t+1} = \theta_{t} - \mu \nabla J(\theta,x^i,y^i) θt+1=θtμJ(θ,xi,yi)
SGD的优点是实现简单、效率高,缺点是收敛速度慢、容易陷入局部最小值;迭代次数多

BGD:批量梯度下降

与SGD对应的,BGD是对整个批次的训练样本都进行梯度计算。
设批样本为 { ( x 1 , y 1 ) , . . . , ( x n , y n ) } \{(x^1,y^1),..., (x^n,y^n)\} {(x1,y1),...,(xn,yn)},首先计算所有的样本梯度的平均值 1 n ∑ i = 1 n ∇ J ( θ , x i , y i ) \frac{1}{n} \sum _{i=1} ^{n} \nabla J(\theta,x^i,y^i) n1i=1nJ(θ,xi,yi),然后进行梯度更新:
θ t + 1 = θ t − μ 1 n ∑ i = 1 n ∇ J ( θ , x i , y i ) \theta_{t+1} = \theta_{t} - \mu \frac{1}{n} \sum _{i=1} ^{n} \nabla J(\theta,x^i,y^i) θt+1=θtμn1i=1nJ(θ,xi,yi)
BGD得到的是一个全局最优解,但是每迭代一步,都要用到训练集的所有数据,如果样本数巨大大,那上述公式迭代起来则非常耗时,模型训练速度很慢;迭代次数少

MBGD:小批量梯度下降

是BGD和SGD的折中,从训练样本中选取一小批样本进行梯度计算,然后更新梯度:
θ t + 1 = θ t − μ 1 b ∑ i = 1 b ∇ J ( θ , x i , y i ) \theta_{t+1} = \theta_{t} - \mu \frac{1}{b} \sum _{i=1} ^{b} \nabla J(\theta,x^i,y^i) θt+1=θtμb1i=1bJ(θ,xi,yi)

Momentum

指数加权移动平均是一种常用的序列数据处理方式,用于描述数值的变化趋势,本质上是一种近似求平均的方法。计算公式如下:
v t = β v t − 1 + ( 1 − β ) θ t ​​ v_t=βv _{t−1}+(1−β)θ_t​​ vt=βvt1+(1β)θt​​
v t v_t vt 表示第t个数的估计值, β \beta β为一个可调参数,能表示 v t − 1 v_{t-1} vt1 的权重, θ t \theta_t θt 表示第t个数的实际值

Momentum就是在普通的梯度下降法中引入指数加权移动平均,即定义一个动量,它是梯度的指数加权移动平均值,然后使用该值代替原来的梯度方向来更新。定义的动量为:
v t = β v t − 1 + ( 1 − β ) ∇ θ J ( θ t ) v_t=βv _{t−1}+(1−β)\nabla_{ \theta} J(\theta_t) vt=βvt1+(1β)θJ(θt)
因此梯度下降表达式为:
θ t + 1 = θ t − η v t \theta_{t+1}=\theta_{t} - η v_t θt+1=θtηvt
普通的随机梯度下降法中,由于无法计算损失函数的确切导数,嘈杂的数据会使下降过程并不朝着最佳方向前进,使用加权平均能对嘈杂数据进行一定的屏蔽,使前进方向更接近实际梯度。此外,随机梯度下降法在局部极小值极有可能被困住,但Momentum由于下降方向由最近的一些数共同决定,能在一定程度反应总体的最佳下降方向,所以被困在局部最优解的可能会减小。

AdaGrad

Adagrad是对学习率进行了一个约束,对于经常更新的参数,由于已经积累了大量关于它的知识,不希望被单个样本影响太大,所以希望学习速率慢一些;对于偶尔更新的参数,由于了解的信息太少,希望能从每个偶然出现的样本身上多学一些,即需要学习率大一些。
该方法开始使用二阶动量,才意味着“自适应学习率”优化算法时代的到来。二阶动量是用来度量历史更新频率的,即迄今为止所有梯度值的平方和。二阶动量越大,学习率就越小,这一方法在稀疏数据场景下表现非常好。
v t = ∑ i = 1 n g t 2 v_{t} = \sum _{i=1} ^{n} g^2_t vt=i=1ngt2
θ t + 1 = θ t − η v t + ϵ \theta_{t+1}=\theta_{t} - \frac{η}{\sqrt{v_t+ \epsilon}} θt+1=θtvt+ϵ η
缺点:
仍需要手工设置一个全局学习率 , 如果 设置过大的话,会使regularizer过于敏感,对梯度的调节太大
中后期,分母上梯度累加的平方和会越来越大,使得参数更新量趋近于0,使得训练提前结束,无法学习

RMSprop

RMSProp算法修改了AdaGrad的梯度平方和累加为指数加权的移动平均,还将学习速率除以平方梯度的指数衰减平均值,使得其在非凸设定下效果更好。设定参数:全局初始率η默认设为0.001,decay rate β \beta β,默认设置为0.9,一个极小的常量 ,通常为10e-6。E是取期望的意思。
E [ g 2 ] t = β E [ g 2 ] t + ( 1 − β 1 ) g t 2 E[g^2]_t = \beta E[g^2]_t+(1-\beta _{1})g^2_{t} E[g2]t=βE[g2]t+(1β1)gt2
θ t + 1 = θ t − η E [ g 2 ] t + ϵ g t \theta_{t+1}=\theta_{t} - \frac{η}{\sqrt{E[g^2]_t}+ \epsilon}g_t θt+1=θtE[g2]t +ϵηgt

Adam: Adaptive Moment Estimation

对AdaGrad的优化,一种通过计算模型参数的梯度以及梯度平方的加权平均值(一阶动量和二阶动量),来调整模型的参数。
g t = ∇ θ J ( θ t ) g_t = \nabla_{ \theta} J(\theta_t) gt=θJ(θt)
m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta _{ 1}m_{t-1} + (1-\beta _{1})g_{t} mt=β1mt1+(1β1)gt
v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta _{2}v_{t-1} + (1-\beta _{2})g^2_{t} vt=β2vt1+(1β2)gt2
m t ^ = m t 1 − β 1 t \hat{m_t}=\frac{m_t}{1-\beta^t_{1}} mt^=1β1tmt
v t ^ = v t 1 − β 2 t \hat{v_t}=\frac{v_t}{1-\beta^t_{2}} vt^=1β2tvt
θ t + 1 = θ t − η v t ^ + ϵ m t ^ \theta_{t+1}=\theta_{t} - \frac{η}{\sqrt{\hat{v_t}}+ \epsilon}\hat{m_t} θt+1=θtvt^ +ϵηmt^
其中,各个变量含义如下:
g t g_t gt:模型参数在第t次迭代时的梯度,
m t 和 v t m_t和v_t mtvt:模型参数在第t次迭代时的一阶动量和二阶动量,
β 1 和 β 2 \beta _{1}和\beta _{2} β1β2:超参数(默认是0.9和0.999),
β 1 t 和 β 2 t \beta _{1}^{t}和\beta _{2}^{t} β1tβ2t β 1 \beta _{1} β1 β 2 \beta _{2} β2的t次方。
m t ^ \hat{m_t} mt^ v t ^ \hat{v_t} vt^ t是梯度的偏差纠正后的移动平均值
Adam优化器的主要优点是它能够自适应地调整每个参数的学习率,从而提高模型的收敛速度和泛化能力。

AdamW

Adam 虽然收敛速度快,但没能解决参数过拟合的问题。学术界讨论了诸多方案,其中包括在损失函数中引入参数的 L2 正则项。这样的方法在其他的优化器中或许有效,但会因为 Adam 中自适应学习率的存在而对使用 Adam 优化器的模型失效(因为正则项同时存在于adam的分子和分母,参考adam的公式,这样正则就抵消了)。AdamW就是在Adam+L2正则化的基础上进行改进的算法。
以往的L2正则是直接加在损失函数上,比如加入正则,损失函数变化如下:
L l 2 ( θ ) = L ( θ ) + 1 2 λ ∣ ∣ θ ∣ ∣ 2 L_{l_2}(\theta)=L(\theta) + \frac{1}{2}λ||\theta||^2 Ll2(θ)=L(θ)+21λ∣∣θ2
图片中红色是上述的Adam+L2 regularization的方式,绿色就是adamw即Adam + weight decay的方式。
在这里插入图片描述
为什么这么做?bert给出的解释是

参考文章

[1] 梯度下降优化算法Momentum
[2]多种梯度下降优化算法总结分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/785572.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从不同的使用场景认识STag26

当你买下STag26时, 你买到的是什么? 如果你是商超生鲜区的经理, 你买到的是在促销旺季时的高效与安心。 你不用再担心价格没有及时更新, 导致水果蔬菜的滞销。 毕竟,STag26能够一键改价,实时更新&#x…

Linux の shell 流程控制

条件控制 # if then 如果else 没有语句 可以省略 if condition then#语句 fi# if then 。。。 else 。。。 fi if condition then#语句 else#语句 fi# if condition then#语句 elif condition2 then#语句 else#语句 fiif [ $a -gt $b ] thenecho "a > b&quo…

自学网络安全(黑客)一定要注意什么

自学网络安全(黑客)时,你需要注意以下几点: 合法性:确保你的学习和实践活动是合法的。未经授权的入侵、攻击或侵犯他人隐私的行为是违法的,并可能导致严重的法律后果。 遵守道德准则:确保你的学…

【漏洞通知】Apache Shiro又爆认证绕过漏洞CVE-2023-34478,漏洞等级:高危

2023年7月24日,Apache Shiro发布更新版本,修复了一个身份验证绕过漏洞,漏洞编号:CVE-2023-34478,漏洞危害等级:高危。 Apache Shiro版本1.12.0之前和2.0.0-alpha-3 之前容易受到路径遍历攻击,当…

Vue组件自定义事件

v-on:xxx"" &#xff1a;绑定 this.$emit(xxx) : 触发 this.$off() : 解绑 App.vue <template><div class"app"><h1>{{msg}}</h1><!--通过父组件给子组件传递函数类型的props实现&#xff1a;子给父传递参数--><…

观察者模式、中介者模式和发布订阅模式

观察者模式 定义 观察者模式定义了对象间的一种一对多的依赖关系&#xff0c;当一个对象的状态发生改变时&#xff0c;所有依赖于它的对象都将得到通知&#xff0c;并自动更新 观察者模式属于行为型模式&#xff0c;行为型模式关注的是对象之间的通讯&#xff0c;观察者模式…

【洛谷算法题】B2025-输出字符菱形【入门1顺序结构】

&#x1f468;‍&#x1f4bb;博客主页&#xff1a;花无缺 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 花无缺 原创 收录于专栏 【洛谷算法题】 文章目录 【洛谷算法题】B2025-输出字符菱形【入门1顺序结构】&#x1f30f;题目描述&#x1f30f;输入格…

1 第一个vue程序

复习次数 &#xff1a;✔ 1.1 vue优势 1.2 vue环境 直接在idea的插件搜vue.js&#xff0c;然后下载。 接着创建一个空项目&#xff0c;并添加模块。然后&#xff0c;创建一个html文件。 1.3 vue例子 完整的html代码如下&#xff1a; <!DOCTYPE html> <html lang&qu…

Linux Day01

目录 一、Linux终端介绍 二、Linux目录介绍 1.目录结构 2.常见目录说明 3.绝对路径与相对路径 4.家目录 一、Linux终端介绍 二、Linux目录介绍 Linux目录&#xff1a;是从根目录"/"开始的 是一棵倒着的树 1.目录结构 2.常见目录说明 目前记住 bin 存放常用命…

网络安全领域关键信息泄露事件引发关注

近日&#xff0c;一家知名网络安全公司发布了一份报告揭露了一起重大信息泄露事件。据称&#xff0c;该事件涉及大量敏感用户数据的泄露引发了全球网络安全领域的广泛关注。 根据报道&#xff0c;该事件发生在全球范围内涉及多个国家和组织。专家指出&#xff0c;此次泄露事件…

ESP-C2模组实现透传示例说明

WIFI-TTL透传模块说明 V 1.0 2022-11-24 1 简介 WiFi-TTL透传模块基于我司DT-ESPC2-12模块研发&#xff0c;引出串口TTL、EN、STATE 等引脚。产品内置我司最新版本的串口透传固件可完成设备TTL 端口到WiFi/云的数据实时透传。本模块可直接取代原有的有线串口&#xff0c;实现…

java注解和自定义注解

目录 一、注解的概念 二、注解的类型 2.1、内置注解 2.2、元注解 2.2.1、各个元注解的作用 2.3、自定义注解 2.4、自定义注解实现及测试 一、注解的概念 1、注解的作用 ①&#xff1a;注解一般用于对程序的说明&#xff0c;就像注释一样&#xff0c;但是区别是注释是给…

出现了HTTPSConnectionPool(host=‘huggingface.co‘, port=443)错误的解决方法

在下载huggingface 模型的时候&#xff0c;经常会出现这个错误&#xff0c;HTTPSConnectionPool(host‘huggingface.co’, port443)&#xff0c;即使你已经有了正确的上网姿势。 如在下载Tokenizer的时候&#xff0c;就会出现&#xff1a; tokenizer AutoTokenizer.from_pre…

权智A133P 安卓10移植SPI转串WK2124驱动

硬件连接示意图 主控CPU通过SPI总线与WK2XXX芯相连接。WK2XXX控制4个UART的数据收发。 其中重要的参数有CS片选线和IRQ中断引脚。 LInux串口驱动框架 当WK2XXX驱动在内核注册成功后&#xff0c;会在/dev目录下面生成ttysWK0,ttysWK1,ttysWK2,ttysWK3节点。上层通过open,read,w…

pytest 第三方插件

目录 前言&#xff1a; 顺序执行&#xff1a;pytest-ordering 失败重试&#xff1a;pytest-rerunfailures 并行执行&#xff1a;pytest-xdist 前言&#xff1a; pytest 是一个广泛使用的 Python 测试框架。它具有强大的测试运行器、测试驱动开发和测试结果可视化等功能。除…

《面试1v1》如何能从Kafka得到准确的信息

&#x1f345; 作者简介&#xff1a;王哥&#xff0c;CSDN2022博客总榜Top100&#x1f3c6;、博客专家&#x1f4aa; &#x1f345; 技术交流&#xff1a;定期更新Java硬核干货&#xff0c;不定期送书活动 &#x1f345; 王哥多年工作总结&#xff1a;Java学习路线总结&#xf…

对高校数字化转型的思考

数字新技术与国民经济各产业的融合深化&#xff0c;使行业产业数字化、网络化、全球化、知识化、智能化趋势愈发显著&#xff0c;深刻改变着人的职业生涯、现代社会对人才的需求和新型就业形式&#xff0c;引发教育资源、形态和范式的深刻变革。数字化转型对于提高学校管理效率…

Redis简介、常用命令

目录 一、​​关系数据库​​与非关系型数据库概述 1.1 关系型数据库 1.2 非关系型数据库 二、关系数据库与非关系型数据库区别 2.1 数据存储方式不同 2.2 扩展方式不同 2.3 对事务性的支持不同 三、非关系型数据库产生背景 四、Redis简介 4.1 Redis的单线程模式 4.…

Linux系列---【Ubuntu 20.04安装KVM】

Ubuntu 20.04安装KVM 一、安装kvm 1.安装kvm sudo apt install qemu-kvm libvirt-daemon-system libvirt-clients bridge-utils 2. 将当前用户添加至libvirt 、 kvm组 sudo adduser $USER libvirt sudo adduser $USER kvm 3.验证安装 virsh list --all 4.启动libvert sudo syst…

Jmeter 压测实战:Jmeter 二次开发之自定义函数

目录 1 前言 2 开发准备 3 自定义函数核心实现 3.1 新建项目 3.2 继承实现 AbstractFunction 类 3.3 最终项目结构 4 Jmeter 加载扩展包 4.1 maven 构建配置 4.2 项目打包 4.3 Jmeter 加载扩展包 5 自定义函数调用调试 5.1 打开 Jmeter 函数助手&#xff0c;选择自…