【深度学习笔记】优化算法——Adam算法

news2024/10/6 18:31:28

Adam算法

🏷sec_adam

本章我们已经学习了许多有效优化的技术。
在本节讨论之前,我们先详细回顾一下这些技术:

  • 在 :numref:sec_sgd中,我们学习了:随机梯度下降在解决优化问题时比梯度下降更有效。
  • 在 :numref:sec_minibatch_sgd中,我们学习了:在一个小批量中使用更大的观测值集,可以通过向量化提供额外效率。这是高效的多机、多GPU和整体并行处理的关键。
  • 在 :numref:sec_momentum中我们添加了一种机制,用于汇总过去梯度的历史以加速收敛。
  • 在 :numref:sec_adagrad中,我们通过对每个坐标缩放来实现高效计算的预处理器。
  • 在 :numref:sec_rmsprop中,我们通过学习率的调整来分离每个坐标的缩放。

Adam算法 :cite:Kingma.Ba.2014将所有这些技术汇总到一个高效的学习算法中。
不出预料,作为深度学习中使用的更强大和有效的优化算法之一,它非常受欢迎。
但是它并非没有问题,尤其是 :cite:Reddi.Kale.Kumar.2019表明,有时Adam算法可能由于方差控制不良而发散。
在完善工作中, :cite:Zaheer.Reddi.Sachan.ea.2018给Adam算法提供了一个称为Yogi的热补丁来解决这些问题。
下面我们了解一下Adam算法。

算法

Adam算法的关键组成部分之一是:它使用指数加权移动平均值来估算梯度的动量和二次矩,即它使用状态变量

v t ← β 1 v t − 1 + ( 1 − β 1 ) g t , s t ← β 2 s t − 1 + ( 1 − β 2 ) g t 2 . \begin{aligned} \mathbf{v}_t & \leftarrow \beta_1 \mathbf{v}_{t-1} + (1 - \beta_1) \mathbf{g}_t, \\ \mathbf{s}_t & \leftarrow \beta_2 \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2. \end{aligned} vtstβ1vt1+(1β1)gt,β2st1+(1β2)gt2.

这里 β 1 \beta_1 β1 β 2 \beta_2 β2是非负加权参数。
常将它们设置为 β 1 = 0.9 \beta_1 = 0.9 β1=0.9 β 2 = 0.999 \beta_2 = 0.999 β2=0.999
也就是说,方差估计的移动远远慢于动量估计的移动。
注意,如果我们初始化 v 0 = s 0 = 0 \mathbf{v}_0 = \mathbf{s}_0 = 0 v0=s0=0,就会获得一个相当大的初始偏差。
我们可以通过使用 ∑ i = 0 t β i = 1 − β t 1 − β \sum_{i=0}^t \beta^i = \frac{1 - \beta^t}{1 - \beta} i=0tβi=1β1βt来解决这个问题。
相应地,标准化状态变量由下式获得

v ^ t = v t 1 − β 1 t  and  s ^ t = s t 1 − β 2 t . \hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_1^t} \text{ and } \hat{\mathbf{s}}_t = \frac{\mathbf{s}_t}{1 - \beta_2^t}. v^t=1β1tvt and s^t=1β2tst.

有了正确的估计,我们现在可以写出更新方程。
首先,我们以非常类似于RMSProp算法的方式重新缩放梯度以获得

g t ′ = η v ^ t s ^ t + ϵ . \mathbf{g}_t' = \frac{\eta \hat{\mathbf{v}}_t}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon}. gt=s^t +ϵηv^t.

与RMSProp不同,我们的更新使用动量 v ^ t \hat{\mathbf{v}}_t v^t而不是梯度本身。
此外,由于使用 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon} s^t +ϵ1而不是 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t + \epsilon}} s^t+ϵ 1进行缩放,两者会略有差异。
前者在实践中效果略好一些,因此与RMSProp算法有所区分。
通常,我们选择 ϵ = 1 0 − 6 \epsilon = 10^{-6} ϵ=106,这是为了在数值稳定性和逼真度之间取得良好的平衡。

最后,我们简单更新:

x t ← x t − 1 − g t ′ . \mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \mathbf{g}_t'. xtxt1gt.

回顾Adam算法,它的设计灵感很清楚:
首先,动量和规模在状态变量中清晰可见,
它们相当独特的定义使我们移除偏项(这可以通过稍微不同的初始化和更新条件来修正)。
其次,RMSProp算法中两项的组合都非常简单。
最后,明确的学习率 η \eta η使我们能够控制步长来解决收敛问题。

实现

从头开始实现Adam算法并不难。
为方便起见,我们将时间步 t t t存储在hyperparams字典中。
除此之外,一切都很简单。

%matplotlib inline
import torch
from d2l import torch as d2l


def init_adam_states(feature_dim):
    v_w, v_b = torch.zeros((feature_dim, 1)), torch.zeros(1)
    s_w, s_b = torch.zeros((feature_dim, 1)), torch.zeros(1)
    return ((v_w, s_w), (v_b, s_b))

def adam(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-6
    for p, (v, s) in zip(params, states):
        with torch.no_grad():
            v[:] = beta1 * v + (1 - beta1) * p.grad
            s[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)
            v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
            s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
            p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)
                                                       + eps)
        p.grad.data.zero_()
    hyperparams['t'] += 1

现在,我们用以上Adam算法来训练模型,这里我们使用 η = 0.01 \eta = 0.01 η=0.01的学习率。

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adam, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.244, 0.015 sec/epoch

在这里插入图片描述

此外,我们可以用深度学习框架自带算法应用Adam算法,这里我们只需要传递配置参数。

trainer = torch.optim.Adam
d2l.train_concise_ch11(trainer, {'lr': 0.01}, data_iter)
loss: 0.254, 0.015 sec/epoch

在这里插入图片描述

Yogi

Adam算法也存在一些问题:
即使在凸环境下,当 s t \mathbf{s}_t st的二次矩估计值爆炸时,它可能无法收敛。
:cite:Zaheer.Reddi.Sachan.ea.2018 s t \mathbf{s}_t st提出了的改进更新和参数初始化。
论文中建议我们重写Adam算法更新如下:

s t ← s t − 1 + ( 1 − β 2 ) ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \left(\mathbf{g}_t^2 - \mathbf{s}_{t-1}\right). stst1+(1β2)(gt2st1).

每当 g t 2 \mathbf{g}_t^2 gt2具有值很大的变量或更新很稀疏时, s t \mathbf{s}_t st可能会太快地“忘记”过去的值。
一个有效的解决方法是将 g t 2 − s t − 1 \mathbf{g}_t^2 - \mathbf{s}_{t-1} gt2st1替换为 g t 2 ⊙ s g n ( g t 2 − s t − 1 ) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}) gt2sgn(gt2st1)
这就是Yogi更新,现在更新的规模不再取决于偏差的量。

s t ← s t − 1 + ( 1 − β 2 ) g t 2 ⊙ s g n ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}). stst1+(1β2)gt2sgn(gt2st1).

论文中,作者还进一步建议用更大的初始批量来初始化动量,而不仅仅是初始的逐点估计。

def yogi(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-3
    for p, (v, s) in zip(params, states):
        with torch.no_grad():
            v[:] = beta1 * v + (1 - beta1) * p.grad
            s[:] = s + (1 - beta2) * torch.sign(
                torch.square(p.grad) - s) * torch.square(p.grad)
            v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
            s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
            p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)
                                                       + eps)
        p.grad.data.zero_()
    hyperparams['t'] += 1

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(yogi, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.245, 0.015 sec/epoch

在这里插入图片描述

小结

  • Adam算法将许多优化算法的功能结合到了相当强大的更新规则中。
  • Adam算法在RMSProp算法基础上创建的,还在小批量的随机梯度上使用EWMA。
  • 在估计动量和二次矩时,Adam算法使用偏差校正来调整缓慢的启动速度。
  • 对于具有显著差异的梯度,我们可能会遇到收敛性问题。我们可以通过使用更大的小批量或者切换到改进的估计值 s t \mathbf{s}_t st来修正它们。Yogi提供了这样的替代方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1507763.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode: 151. 反转字符串中的单词 + 双指针】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

【格与代数系统】格与代数系统汇总

【格与代数系统】格与哈斯图 目录 关系 偏序关系 偏序集 可比性 全序集 最值与上下界 上下确界 格 代数系统 性质 格与代数系统的关系 分配格 有界格 有补格 布尔代数 例1 例2 对偶格 软代数 完备格 稠密性 优软代数 小结 关系 X,Y是两个非空集合, 记若…

C语言编译成库文件的要求

keil编译成库文件 在Keil中,将C语言源文件编译成库文件通常需要进行以下步骤: 创建一个新的Keil项目,并将所需的C语言源文件添加到该项目中。 在项目设置中配置编译选项,确保生成的目标文件符合库文件的标准格式。 编译项目&…

ULBF810-ASEMI新能源整流桥ULBF810

编辑:ll ULBF810-ASEMI新能源整流桥ULBF810 型号:ULBF810 品牌:ASEMI 封装:ULBF-4 最大重复峰值反向电压:1000V 最大正向平均整流电流(Vdss):8A 功率(Pd):中小功率 芯片个数&#xff1a…

无人机手持地面站软件功能详解,无人机手持地面站软件开发人员组成及成本分析

无人机手持地面站软件是专为无人机操控和任务管理设计的移动应用,它通常集成在智能手机、平板电脑或其他便携式设备上,使得用户可以在远离无人机的地方对飞行器进行实时监控与远程控制。 主要功能详解: 1. 飞行控制与姿态显示: …

Android 音频系统

导入 早期Linux版本采用的是OSS框架,它也是Unix及类Unix系统中广泛使用的一种音频体系。 ALSA是Linux社区为了取代OSS而提出的一种框架,是一个源代码完全开放的系统(遵循GNU GPL和GNU LGPL)。ALSA在Kernel 2.5版本中被正式引入后,OSS就逐步…

力扣每日一题 猜数字游戏 阅读理解

Problem: 299. 猜数字游戏 思路 &#x1f468;‍&#x1f3eb; 灵神 复杂度 Code class Solution {public String getHint(String secret, String guess) {int a 0;int[] cntS new int[10];int[] cntG new int[10];for(int i 0; i < secret.length(); i){if(secre…

数据库(SQL sever)

本博客将主要讲述数据库&#xff08;SQL sever&#xff09; 1.数据库解决的数据问题&#xff1a; Data redundancy and inconsistency(数据冗余和不一致) Difficulty in accessing data Data isolation (数据孤立) Integrity problems (完整性问题) Atomicity of updates…

组态软件基础知识

一、组态软件基础知识 1、概述 &#xff08;1&#xff09;、组态软件概念与产生背景 “组态”的概念是伴随着集散型控制系统&#xff08;Distributed Control System简称DCS&#xff09;的出现才开始被广大的生产过程自动化技术人员所熟知的。在工业控制技术的不断发展和应用…

基于FPGA的HyperRam接口设计与实现

一 HyperRAM 针对一些低功耗、低带宽应用&#xff08;物联网、消费产品、汽车和工业应用等&#xff09;&#xff0c;涉及到外部存储&#xff0c;HyperRAM提供了更简洁的内存解决方案。 HyperRAM具有以下特性&#xff1a; 1、超低功耗&#xff1a;200MHz工作频率下读写不到50mW…

基于SpringBoot的农产品特色供销系统(蔬菜商城)

基于SpringBoot的农产品特色供销系统&#xff08;蔬菜商城&#xff09; 系统介绍 该系统使用Java、MySQL、Redis、Spring Boot和HTML等技术作为系统的技术支撑&#xff0c;实现了以下功能模块&#xff1a; &#xff08;1&#xff09;后台管理模块&#xff0c;包括权限、日志、…

学习Java的第七天

目录 一、什么是数组 二、作用 三、如何使用数组 1、声明数组变量 2、创建数组 示例&#xff1a; 3、数组的使用 示例&#xff1a; 4、数组的遍历 for循环示例&#xff08;不知道for循环的可以查看我之前发的文章&#xff09; for-each循环&#xff08;也就是增强for…

CAN一致性测试:物理层测试之终端电阻测试

从本周开始结合工作实践&#xff0c;给大家总结CAN一致性相关的测试 包括&#xff1a;物理层、数据链路层、应用层三大块知识点 CAN一致性测试:物理层测试之终端电阻测试 试验目的&#xff1a; 测试控制器的 CANH 对地、CANL 对地、CANH 对 CANL 的内阻是否符合 ISO11898-2的…

2024年k8s最新版本使用教程

2024年k8s最新版本使用教程 3. YAML语言入门3.1 基本语法规则3.2 支持的数据结构3.3 其他语法 4 资源管理4.1 k8s资源查询4.2 资源操作命令4.3 资源操作方式4.3.1 命令行方式4.3.2 YAML文件方式 5 Namespace5.1 查看命名空间5.2 创建命名空间5.3 删除命名空间5.4 命名空间资源限…

GitOps实践之Argo CD (2)

argocd 【-1】argocd可以解决什么问题? helm 部署是手动的?依赖流水线。而有时候仅仅更新一个小东西,流水线跑好久,CD真的不应该和CI耦合。不同环境的helm配置不同,手动修改问题多,可以用git管理起来,例如分不同环境用目录区分。argocd创建应用可以不通环境部署到不同集…

基于Java+springboot+VUE+redis实现的前后端分类版网上商城项目

基于Java springbootVUEredis实现的前后端分类版网上商城项目 博主介绍&#xff1a;多年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言…

备忘: 踩坑linux环境部署轻量化的Langchain-Chatchat集成通义千问

看了许多材料,为了利用大模型构建以对话方式驱动的本地应用程序需要使用LangChain-chatchat,其基本介绍参考Langchain-Chatchat项目 通过查询资料&#xff0c;查到win上安装流程使用免费的通义千问 api 最轻量化部署Langchain-Chatchat&#xff0c;原以为在信创Linux系统上非常…

MM01/MM02/MM03物料主数据增强

1.屏幕增强 -在主表中附加结构(判断数据的主表,如MARA,MARC) 增强字段数据元素勾选更改文档以后,会记录字段变更历史 -SPRO-->物流-常规-->物料主数据-->配置物料主记录-->创建定制子屏幕的程序 会生成对应的函数组--里面会包含两个屏幕(0001,0002) 这里的0001屏…

2024年华为HCIA-DATACOM新增题库(H12-811)

801、[单选题]178/832、在系统视图下键入什么命令可以切换到用户视图? A quit B souter C system-view D user-view 试题答案&#xff1a;A 试题解析&#xff1a;在系统视图下键入quit命令退出到用户视图。因此答案选A。 802、[单选题]“网络管理员在三层交换机上创建了V…

【联邦学习综述:概念、技术】

出自——联邦学习综述&#xff1a;概念、技术、应用与挑战。梁天恺 1*&#xff0c;曾 碧 2&#xff0c;陈 光 1 从两个方面保护隐私数据 硬件层面 可 信 执 行 环 境 &#xff08;Trusted Execution Environment&#xff0c;TEE&#xff09;边 缘 计 算&#xff08;Edge Com…