【深度学习中常见的优化器总结】SGD+Adagrad+RMSprop+Adam优化算法总结及代码实现

news2024/11/17 2:51:01

文章目录

  • 一、SGD,随机梯度下降
    • 1.1、算法详解
      • 1)MBSGD(Mini-batch Stochastic Gradient Descent)
      • 2)动量法:momentum
      • 3)NAG(Nesterov accelerated gradient)
      • 4)权重衰减项(weight_decay)
      • 5)总结
    • 1.2、Pytorch实现:torch.optim.SGD
    • 1.3、示例
  • 二、Adagrad:自适应梯度
    • 2.1、算法详解
    • 2.2、Pytorch的实现:torch.optim.Adagrad
  • 三、RMSprop
    • 3.1、算法详解
    • 3.2、Pytorch的实现:torch.optim.RMSprop
  • 四、Adam
    • 4.1、算法详解
    • 4.2、Pytorch的实现:torch.optim.Adam

  • 这个博客讲的非常清晰:https://blog.csdn.net/xian0710830114/article/details/126551268

一、SGD,随机梯度下降

1.1、算法详解

1)MBSGD(Mini-batch Stochastic Gradient Descent)

  • 随机梯度下降其实可以有三种实现方式,最为常用,而且在pytorch中实现的也是小批量随机梯度下降。
  • 有以下三种:

1)BGD(批量梯度下降法):每次迭代使用全部训练样本来计算梯度,并根据梯度的平均值来更新模型的参数。尽管 BGD 对参数更新的方向更稳定,但由于计算梯度需要考虑所有样本,因此在大规模数据集上会导致较高的计算开销。
2)SGD(随机梯度下降法):在每次迭代中,随机选择一个样本来计算梯度并更新模型的参数。与 BGD 不同,SGD 每次只使用一个样本,因此计算效率更高。然而,由于单个样本的梯度估计可能存在噪声,SGD 的参数更新方向更加不稳定,收敛速度也相对较慢。
3)MBSGD(小批量随机梯度下降法):MBGD 是 BGD 和 SGD 的折中方法。在每次迭代中,随机选择一个小批量的样本来计算梯度,并根据梯度的平均值来更新模型的参数。这样可以减少计算开销,并且相对于 SGD 而言,参数更新方向更加稳定。

  • 对于含有 n个训练样本的数据集,每次参数更新,选择一个大小为 m(m<n) 的mini-batch数据样本计算其梯度,其参数更新公式如下,其中 j 是一个batch的开始:
    在这里插入图片描述
  • 小批量随机梯度下降可以加速收敛,一定程度上有摆脱局部最优的能力(起码比SGD好),但是又可能会存在噪声。

2)动量法:momentum

  • 动量(Momentum)是一种优化梯度下降算法的技术,用于加速模型参数的更新,并帮助模型跳出局部最优解。
  • 它在训练过程中考虑了之前参数更新的方向和速度。通过将当前梯度与过去梯度加权平均,来获取即将更新的梯度。
  • 如图b,可以看出能够加速收敛
    在这里插入图片描述
  • 动量项通常设置为0.9或类似值。
  • 参数更新公式如下,其中ρ 是动量衰减率,m是速率(即一阶动量):
    在这里插入图片描述

3)NAG(Nesterov accelerated gradient)

  • 暂时略过,其实它也是加速收敛的方法

4)权重衰减项(weight_decay)

  • weight_decay通过对模型的权重进行惩罚来减小权重的大小,用于防止模型过拟合。(简单来说就是控制了模型复杂度,即强制的使权重不会特别大,因为进行了权重衰减,大权重衰减的就多)
  • 其实就相当于在梯度后面增加了一个wieght_decay × \times × θ t − 1 \theta_{t-1} θt1
    g t = g t + λ θ t − 1 g_t = g_t + \lambda\theta_{t-1} gt=gt+λθt1
  • 其实就是在梯度中,增加了权重衰减。weight_decay 用于控制模型权重衰减(weight decay)的程度。
  • 较小的 weight_decay 值会使权重衰减的影响较小,而较大的值会使权重衰减的影响更显著。
  • 这与岭回归类似,岭回归是在损失函数中增加了L2范数的约束,用于防止过拟合(尤其是当特征数大于样本数时,导致多重非线性)

5)总结

  • 优点:收敛速度变快,有一定摆脱局部最优的能力
  • 缺点:需要手动调参,例如学习率等

1.2、Pytorch实现:torch.optim.SGD

CLASS torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, maximize=False)
""
params(iterable)- 参数组,优化器要优化的那部分参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
momentum(float)- 动量,通常设置为 0.90.8
weight_decay(float)- 权重衰减系数,也就是 L2 正则项的系数
nesterov(bool)- bool 选项,是否使用 NAG(Nesterov accelerated gradient)
maximize(bool)- 最大化还是最小化损失函数,默认是最小化,即False
""

在这里插入图片描述

1.3、示例

SGD优化器计算过程(以线性回归为例)
建立模型为:y = w^Tx = w1x1+w2x2+w3x3
初始化:y=1*x1+1*x2+1*x3,三个参数w为[1, 1, 1]
损失函数:
l = (pred-gt)**2 = (w1x1+w2x2+w3x3) ** 2
求导(链式法则,先对pred求导,再对w求导):
l'(w1) = 2(pred-gt)*x1
l'(w2) = 2(pred-gt)*x2
l'(w3) = 2(pred-gt)*x3
 
输入数据:
x = tensor([ 1.0943,  1.3479, -1.6927])
预测结果:
p = 1*1.0943+1*1.3479+1*-1.6927=0.7495
 
1)当weight_decay = 0
输出梯度:grad: tensor([[ 2.8188,  3.4719, -4.3600]])
手动计算验证:
l'(w1) = 2*(0.7495- -0.5384)*1.0943=2.81869794
l'(w2) = 2*(0.7495- -0.5384)*1.3479=3.47192082
l'(w3) = 2*(0.7495- -0.5384)*-1.6927=-4.36005666

权重更新:lr = 0.01
w = tensor([[0.9718, 0.9653, 1.0436]], requires_grad=True)
w1 = 1-0.01*2.81869794=0.9718130206
w2 = 1-0.01*3.47192082=0.9652807918
w3 = 1-0.01*-4.36005666=1.0436005666
 
2)当weight_decay = 0.1,lr = 0.01
输出梯度:grad: tensor([[ 2.8188,  3.4719, -4.3600]])

l'(w1) = l`(w1) + 0.1*1=2.9188
w1:= 1-0.01*2.9188 = 0.9708

参考链接:https://blog.csdn.net/qq_39707285/article/details/124257377

二、Adagrad:自适应梯度

2.1、算法详解

  • Adagrad优化算法可以自适应调整不同参数的学习率大小,用于解决这样一个问题:常见特征(频繁特征)的参数更新较快,而不常见特征(稀疏特征)的更新较慢

  • Adagrad优化算法是引入了二阶动量,即 v t v_t vt,表示之前所有时间步长(iteration/epoch)的历史梯度的平方和。再将学习率变为 η v t + ε \frac{\eta }{\sqrt{v_t+\varepsilon } } vt+ε η,那么学习率就可以自适应更新:如果梯度大(更新较快),学习率就会降低;如果梯度小(更新较慢),学习率就会升高。
    在这里插入图片描述

  • 通过这种自适应调整学习率的方式,每个参数都分别拥有自己的学习率。使得对稀疏特征和频繁特征都能得到较好的更新效果。

  • 总结:

优点:Adagrad可以自适应调整学习率,使得对稀疏特征和频繁特征都能得到较好的更新效果。
缺点:仍需要手工设置一个全局学习率;在分母中累积平方梯度,因此在训练过程中累积和不断增长。这会导致学习率不断变小并最终变得无限小,使模型不能继续更新。

2.2、Pytorch的实现:torch.optim.Adagrad

CLASS torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
''

params (iterable) – 待优化参数的iterable或者是定义了参数组的dict
lr (float, 可选) – 学习率(默认: 1e-2)
lr_decay (float, 可选) – 学习率衰减(默认: 0)
weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0)
initial_accumulator_value - 累加器的起始值,必须为正。
''

在这里插入图片描述

三、RMSprop

3.1、算法详解

  • RMSprop是对 Adagrad 的一种改进,将AdaGrad的梯度平方和累加 改为 指数加权的移动平均,参数更新公式:
    在这里插入图片描述
  • RMSprop 通过对梯度平方进行移动平均来计算参数的自适应学习率。具体来说,它引入了一个衰减系数(decay rate,即 ρ \rho ρ,一般设为0.99),用于控制历史梯度平方的权重。
  • 可以使学习率的调整更加平稳

3.2、Pytorch的实现:torch.optim.RMSprop

CLASS torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False''

params (iterable) – 待优化参数的iterable或者是定义了参数组的dict
lr (float, 可选) – 学习率(默认:1e-2)
momentum (float, 可选) – 动量因子(默认:0)
alpha (float, 可选) – 平滑常数(默认:0.99)
eps (float, 可选) – 为了增加数值计算的稳定性而加到分母里的项(默认:1e-8)
weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0)
centered (bool, 可选) – 如果为True,计算中心化的RMSProp,并且用它的方差预测值对梯度进行归一化
''

在这里插入图片描述

四、Adam

4.1、算法详解

  • Adam算法结合了Momentum 和 RMSprop,并进行了偏差修正。
  • 也可以从数学理论上解释:Adam 利用梯度的一阶矩估计(momentum)结合过去梯度的更新方向以确定当前梯度的方向,以及二阶矩估计(梯度平方的移动平均)动态的调整学习率。

1)梯度一阶矩估计(通常称为动量):它表示先前梯度的指数加权移动平均,类似于动量优化算法中的动量项。它考虑了过去梯度的方向,并在更新时产生相关影响,有助于加速收敛
2)梯度二阶矩估计(称为自适应学习率):它表示先前梯度的平方的指数加权移动平均。它衡量了过去梯度大小的变化情况,用于自适应地调整学习率,使得在梯度变化较大时减小学习率,在梯度变化较小时增加学习率。

  • Adam的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳。
    在这里插入图片描述
    在这里插入图片描述
  • 总结:

1)自适应学习率:根据梯度的二阶矩估计自动调整学习率大小,在梯度变化较大时减小学习率,在梯度变化较小时增加学习率。这种自适应性使得Adam算法对于不同参数和数据集具有较好的适应性,可以更快地收敛到最优解。
2)动量:利用梯度的一阶矩估计(动量)来考虑过去梯度的方向信息,从而加速模型训练的收敛过程。动量的引入有助于跳出局部最优解。

4.2、Pytorch的实现:torch.optim.Adam

CLASS torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
''
params (iterable) – 待优化参数的iterable或者是定义了参数组的dict
lr (float, 可选) – 学习率(默认:1e-3)
betas (Tuple[float,float], 可选) – 用于计算梯度以及梯度平方的移动平均值的系数(默认:0.90.999)
eps (float, 可选) – 为了增加数值计算的稳定性而加到分母里的项(默认:1e-8)
weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0''

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/800464.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c++网络编程

网络编程模型 c/s 模型&#xff1a;客户端服务器模型b/s 模型&#xff1a;浏览器服务器模型1.tcp网络流程 服务器流程&#xff1a; 1.创建套接字2.完善服务器网络信息结构体3.绑定服务器网络信息结构体4.让服务器处于监听状态5.accept阻塞等待客户端连接信号6.收发数据7.关闭套…

C++那些事之template disambiguator

template disambiguator 1.背景 最近看到一段代码&#xff1a; auto chunk_left first_sort_key.template GetChunk<ArrayType>(left); 请问&#xff0c;这里的.template代表什么意义&#xff1f; 本节将从实际例子出发&#xff0c;探讨这个意义。 2.template disambigu…

mac不识别移动硬盘导致无法拷贝资源

背景 硬盘插入到Mac电脑上之后&#xff0c;mac不识别移动硬盘导致无法拷贝资源。 移动硬盘在Mac上无法被识别的原因可能有很多&#xff0c;多数情况下&#xff0c;是硬盘的格式与Mac电脑不兼容。 文件系统格式不兼容 macOS使用的文件系统是HFS或APFS&#xff0c;如果移动硬盘是…

【java】【面对对象高级4】内部类、枚举、泛型

目录 1、内部类 1.1 成员内部类【了解】 1.1.1 定义 1.1.2 扩展变量 1.2 静态内部类【了解】 1.2.1 定义 1.2.2 扩展变量 1.3 局部内部类【了解】 1.4 匿名内部类【重点】 1.4.1 定义 1.4.1.1 常规写法 1.4.1.2 匿名内部类改造 1.4.2 匿名内部类的常见使用场景 1.4.2…

超卖等高并发秒杀场景的问题及解决方案

超卖等高并发秒杀场景的问题及解决方案 1. 超卖问题&#xff08;多人秒杀&#xff09;1.1 原因1.2 解决方案1.3 总结 2. 锁失效问题&#xff08;单人重复抢&#xff09;2.1 原因2.2 解决方案 3. 事务边界问题&#xff08;单人重复抢&#xff09;3.1 原因3.2 解决方案3.3 总结 4…

【踩坑】三种方式解决 Homebrew failing to install - fatal: not in a git directory

问题描述 解决方法一 添加安全目录&#xff0c;没有测试。 git config --global --add safe.directory /opt/homebrew/Library/Taps/homebrew/homebrew- git config --global --add safe.directory /opt/homebrew/Library/Taps/homebrew/homebrew-cask 解决方法二 取消挂载这…

Redis 主从同步原理

一、什么是主从同步&#xff1f; 主从同步&#xff0c;就是将数据冗余备份&#xff0c;主库&#xff08;Master&#xff09;将自己库中的数据&#xff0c;同步给从库&#xff08;Slave&#xff09;。 从库可以一个&#xff0c;也可以多个&#xff0c;如图所示&#xff1a; 二…

Acwing.291 蒙德里安的梦想

题目 求把NM的棋盘分割成若干个12的的长方形&#xff0c;有多少种方案。 例如当N2&#xff0c;M4时&#xff0c;共有5种方案。当N2&#xff0c;M3时&#xff0c;共有3种方案。如下图所示: 输入格式 输入包含多组测试用例。 每组测试用例占一行&#xff0c;包含两个整数N和M…

STM32 CAN通讯实验程序

目录 STM32 CAN通讯实验 CAN硬件原理图 CAN外设原理图 TJA1050T硬件描述 实验线路图 回环实验 CAN头文件配置 CAN_GPIO_Config初始化 CAN初始化结构体 CAN筛选器结构体 接收中断优先级配置 接收中断函数 main文件 实验现象 补充 STM32 CAN通讯实验 CAN硬件原理图…

JavaScript的函数中this的指向

JavaScript的函数中this的指向 JavaScript 语言之所以有 this 的设计&#xff0c;跟内存里面的数据结构有关系。 以下例子来简单描述this在不同情况下所指向的对象。 var obj {aa: function(){console.log(this.num)},num: 5 };var aa obj.aa; var num 10;obj.aa(); // …

简要介绍 | 走向自然的身份认证:步态识别技术简介

注1&#xff1a;本文系“简要介绍”系列之一&#xff0c;仅从概念上对步态识别进行非常简要的介绍&#xff0c;不适合用于深入和详细的了解。 走向自然的身份认证&#xff1a;步态识别技术简介 Gait Recognition Based on Deep Learning: A Survey | ACM Computing Surveys 背景…

一文谈谈Git

"And if forever lasts till now Alright" 为什么要有git&#xff1f; 想象一下&#xff0c;现如今你的老师同时叫你和张三&#xff0c;各自写一份下半年的学习计划交给他。 可是你的老师是一个极其"较真"的人&#xff0c;发现你俩写的学习计划太"水&…

【弹力设计篇】聊聊异步通讯设计

为什么需要异步设计 刚开始参加工作&#xff0c;发现有一些API设计中回落数据之后&#xff0c;然后将数据写入到消息队列中&#xff0c;当时很是不理解为什么要这么做&#xff0c;直到后边系统学习消息队列之后才发现原来这其实就是异步处理&#xff0c;当流量很多的时候&…

一张表中几列字段以不同的条件规则去统计计数展示实现思路设计

今天在写一个业务的时候&#xff0c;遇到这样一个需求 一、需求描述 一张表中其中几列字段需要以不同的条件规则去统计计数&#xff0c;求实现方式 因为项目业务涉及隐私&#xff0c;我就想了一个类似的情景 二、情景描述 有一张月考成绩表&#xff0c;包含学生和他的各科…

区间预测 | MATLAB实现QRBiGRU双向门控循环单元分位数回归多输入单输出区间预测

区间预测 | MATLAB实现QRBiGRU双向门控循环单元分位数回归多输入单输出区间预测 目录 区间预测 | MATLAB实现QRBiGRU双向门控循环单元分位数回归多输入单输出区间预测效果一览基本介绍模型描述程序设计参考资料 效果一览 基本介绍 MATLAB实现QRBiGRU双向门控循环单元分位数回归…

EXCEL,如何比较2个表里的数据差异(使用数据透视表)

目录 1 问题: 需要比较如下2个表的内容差异 1.1 原始数据喝问题 1.2 提前总结 2 使用EXCEL公式方法 2.1 新增辅助列&#xff1a; 辅助index 2.2 具体公式 配合条件格式 使用 3 数据透视表方法 3.1 新增辅助列&#xff1a; 辅助index 3.2 需要先打开 数据透视表向导 …

基于CNN卷积神经网络的调制信号识别算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 1. 卷积神经网络&#xff08;CNN&#xff09; 2. 调制信号识别 3.实现过程 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2022A 3.部分核心程序 % 构建调制类型…

支付宝短视频平台创作分成激励项目

没想到支付宝也开通了中视频计划&#xff0c;这波羊毛算是蒿定了&#xff0c;最近啊&#xff0c;马爸爸火速上线了支付宝创作分成计划&#xff0c;明显就是抄的抖音中视频计划&#xff0c;目前还在内测阶段&#xff0c;补贴的力度非常大&#xff0c;错过的话就只能拍大腿了&…

Prometheus 的应用服务发现及黑河部署等

目录 promtool检查语法 部署Prometheus Server 检查语法是否规范 部署node-exporter 部署Consul 直接请求API进行服务注册 使用register命令注册服务&#xff08;建议使用&#xff09; 单个和多个注册&#xff0c;多个后面多加了s 在Prometheus上做consul的服务发现 部署…

windows安装linux

https://www.cnblogs.com/liuqingzheng/p/16271895.html 咱们安装linux系统是centos7 准备工作&#xff1a; 安装软件&#xff1a;vmware -------虚拟机 官网下载地址&#xff1a;下载 VMware Workstation Pro | CN 也可以从这里面下载 链接&#xff1a;https://pan.bai…