【从零开始学习深度学习】43. 算法优化之Adam算法【RMSProp算法与动量法的结合】介绍及其Pytorch实现

news2025/1/24 5:06:45

Adam算法是在RMSProp算法基础上对小批量随机梯度也做了指数加权移动平均 【可以看做是RMSProp算法与动量法的结合】。

目录

  • 1. Adam算法介绍
  • 2. 从零实现Adam算法
  • 3. Pytorch简洁实现Adam算法--optim.Adam
  • 总结

1. Adam算法介绍

Adam算法使用了动量变量 v t \boldsymbol{v}_t vt和RMSProp算法中小批量随机梯度按元素平方的指数加权移动平均变量 s t \boldsymbol{s}_t st,并在时间步0将它们中每个元素初始化为0。给定超参数 0 ≤ β 1 < 1 0 \leq \beta_1 < 1 0β1<1(算法作者建议设为0.9),时间步 t t t的动量变量 v t \boldsymbol{v}_t vt即小批量随机梯度 g t \boldsymbol{g}_t gt的指数加权移动平均:

v t ← β 1 v t − 1 + ( 1 − β 1 ) g t . \boldsymbol{v}_t \leftarrow \beta_1 \boldsymbol{v}_{t-1} + (1 - \beta_1) \boldsymbol{g}_t. vtβ1vt1+(1β1)gt.

和RMSProp算法中一样,给定超参数 0 ≤ β 2 < 1 0 \leq \beta_2 < 1 0β2<1(算法作者建议设为0.999),
将小批量随机梯度按元素平方后的项 g t ⊙ g t \boldsymbol{g}_t \odot \boldsymbol{g}_t gtgt做指数加权移动平均得到 s t \boldsymbol{s}_t st

s t ← β 2 s t − 1 + ( 1 − β 2 ) g t ⊙ g t . \boldsymbol{s}_t \leftarrow \beta_2 \boldsymbol{s}_{t-1} + (1 - \beta_2) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stβ2st1+(1β2)gtgt.

由于我们将 v 0 \boldsymbol{v}_0 v0 s 0 \boldsymbol{s}_0 s0中的元素都初始化为0,
在时间步 t t t我们得到 v t = ( 1 − β 1 ) ∑ i = 1 t β 1 t − i g i \boldsymbol{v}_t = (1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} \boldsymbol{g}_i vt=(1β1)i=1tβ1tigi。将过去各时间步小批量随机梯度的权值相加,得到 ( 1 − β 1 ) ∑ i = 1 t β 1 t − i = 1 − β 1 t (1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} = 1 - \beta_1^t (1β1)i=1tβ1ti=1β1t。需要注意的是,当 t t t较小时,过去各时间步小批量随机梯度权值之和会较小。例如,当 β 1 = 0.9 \beta_1 = 0.9 β1=0.9时, v 1 = 0.1 g 1 \boldsymbol{v}_1 = 0.1\boldsymbol{g}_1 v1=0.1g1。为了消除这样的影响,对于任意时间步 t t t,我们可以将 v t \boldsymbol{v}_t vt再除以 1 − β 1 t 1 - \beta_1^t 1β1t,从而使过去各时间步小批量随机梯度权值之和为1。这也叫作偏差修正。在Adam算法中,我们对变量 v t \boldsymbol{v}_t vt s t \boldsymbol{s}_t st均作偏差修正:

v ^ t ← v t 1 − β 1 t , \hat{\boldsymbol{v}}_t \leftarrow \frac{\boldsymbol{v}_t}{1 - \beta_1^t}, v^t1β1tvt,

s ^ t ← s t 1 − β 2 t . \hat{\boldsymbol{s}}_t \leftarrow \frac{\boldsymbol{s}_t}{1 - \beta_2^t}. s^t1β2tst.

接下来,Adam算法使用以上偏差修正后的变量 v ^ t \hat{\boldsymbol{v}}_t v^t s ^ t \hat{\boldsymbol{s}}_t s^t,将模型参数中每个元素的学习率通过按元素运算重新调整:

g t ′ ← η v ^ t s ^ t + ϵ , \boldsymbol{g}_t' \leftarrow \frac{\eta \hat{\boldsymbol{v}}_t}{\sqrt{\hat{\boldsymbol{s}}_t} + \epsilon}, gts^t +ϵηv^t,

其中 η \eta η是学习率, ϵ \epsilon ϵ是为了维持数值稳定性而添加的常数,如 1 0 − 8 10^{-8} 108。和AdaGrad算法、RMSProp算法以及AdaDelta算法一样,目标函数自变量中每个元素都分别拥有自己的学习率。最后,使用 g t ′ \boldsymbol{g}_t' gt迭代自变量:

x t ← x t − 1 − g t ′ . \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}_t'. xtxt1gt.

2. 从零实现Adam算法

我们按照Adam算法中的公式实现该算法。其中时间步 t t t通过hyperparams参数传入adam函数。

%matplotlib inline
import torch
import sys
import d2lzh_pytorch as d2l

features, labels = d2l.get_data_ch7()

def init_adam_states():
    v_w, v_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)
    s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)
    return ((v_w, s_w), (v_b, s_b))

def adam(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-6
    for p, (v, s) in zip(params, states):
        v[:] = beta1 * v + (1 - beta1) * p.grad.data
        s[:] = beta2 * s + (1 - beta2) * p.grad.data**2
        v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
        s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
        p.data -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)
    hyperparams['t'] += 1

使用学习率为0.01的Adam算法来训练模型。

d2l.train_ch7(adam, init_adam_states(), {'lr': 0.01, 't': 1}, features, labels)

输出:

loss: 0.245370, 0.065155 sec per epoch

在这里插入图片描述

3. Pytorch简洁实现Adam算法–optim.Adam

通过名称为“Adam”的优化器实例,我们便可使用PyTorch提供的Adam算法。

d2l.train_pytorch_ch7(torch.optim.Adam, {'lr': 0.01}, features, labels)

输出:

loss: 0.242066, 0.056867 sec per epoch

在这里插入图片描述

总结

  • Adam算法在RMSProp算法的基础上对小批量随机梯度也做了指数加权移动平均。
  • Adam算法使用了偏差修正。

如果文章内容对你有帮助,感谢点赞+关注!

欢迎关注下方GZH:阿旭算法与机器学习,共同学习交流~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150581.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LVGL官方UI设计软件——SquareLine Studio micropython 使用简单测评

经常去LVGL官网逛的人一定都知道这个软件&#xff0c;作为官方的亲儿子&#xff0c;使用体验如何呢&#xff0c;我简单体验了一周左右&#xff0c;简单做个测评&#xff0c;本测评仅代表我个人意见&#xff0c;并且仅限micropython的使用体验&#xff01; 首先是价格&#xff0…

TCP报文段(segment)首部格式

TCP传给IP的数据单元称作TCP报文段或简称为TCP段&#xff08;TCP segment&#xff09;。 IP传给链路层的数据单元称作IP数据报(IP datagram)。 通过以太网传输的比特流称作帧(Frame)。 逐层封装&#xff1a; 源端口号发送端端口号&#xff0c;字段长16位&#xff08;2字节&…

计算机网络第二章

物理层的基本概念物理层的作用&#xff1a;物理层解决如何在连接各种计算机的传输媒体上传输数据比特流&#xff0c;而不是指具体的传输媒体。物理层的主要任务&#xff1a;确定与传输媒体接口有关的一些特性 &#x1f51c;本质&#xff1a;定义一些固定标准物理层的四大特性&a…

Word怎么转PDF?8个Word转PDF工具分析

Word 到 PDF 转换工具是用于将 Microsoft Word&#xff08;DOC 或 DOCX&#xff09;文档转换为 PDF 格式的程序。根据操作模式&#xff0c;它可以是在线或离线软件。当然&#xff0c;考虑到市场上充斥着此类工具&#xff0c;获得最好的 DOCX 到 PDF 转换器可能会让人头疼。正是…

MySQL基础篇第12章(MySQL数据类型)

1. MySQL中的数据类型 常见的数据类型的属性&#xff1a; 2. 整数介绍 2.1 类型介绍 整数类型一共有 5 种&#xff0c;包括 TINYINT、SMALLINT、MEDIUMINT、INT&#xff08;INTEGER&#xff09;和 BIGINT。 它们的区别如下表所示&#xff1a; 2.2 可选属性 整数类型的可选…

javaweb-异步请求AjaxaxiosJSON

1&#xff0c;Ajax 1.1 概述 AJAX (Asynchronous JavaScript And XML)&#xff1a;异步的 JavaScript 和 XML。 我们先来说概念中的 JavaScript 和 XML&#xff0c;JavaScript 表明该技术和前端相关&#xff1b;XML 是指以此进行数据交换。而这两个我们之前都学习过。 ####…

JavaWeb基础——从入门到超神(笔记,持续更新)

day00综述 需要学习SpringBoot&#xff0c;但是JavaWeb是基础&#xff0c;来补一下 JavaWeb就是将数据库中的数据用好看的样式在网页上呈现出来 day01MySQL基础 接下来就是MySQL的安装什么的 mysqld --initialize-insecure mysqld -install net start mysql至此我的电脑上已…

【蓝桥杯-筑基篇】基础入门

&#x1f353;系列专栏:蓝桥杯 &#x1f349;个人主页:个人主页 目录 1.数位翻转 2.三个数求最大值的写法 3.两数交换的几种方法 4.身份证第18位合法性校验 5.黑洞数&#xff08;陷阱数&#xff09; 1.数位翻转 如: 整数 12345 返回结果为整数: 54321 当第一次看到这个题…

【零基础】学python数据结构与算法笔记7

文章目录前言41.查找排序部分习题42.查找排序习题143.查找排序习题244.查找排序习题345.查找排序习题4总结前言 学习python数据结构与算法&#xff0c;学习常用的算法&#xff0c; b站学习链接 41.查找排序部分习题 选题部分来自leetcode 42.查找排序习题1 242. 有效的…

蓝桥杯备赛Day6——链表

目录 数组的缺点 链表 单向链表 双向链表 Python链表的实现 手写链表 数组的缺点 1)需要占用连续的空间 若某个数组很大&#xff0c;可能没有这么大的连续空间给它用。 2〉不方便删除和插入 例如删除数组中间的一个数据&#xff0c;需要把后面所有的数据往前挪填补这个空…

CODESYS开发教程7-字符串及其基本操作

今天继续我们的小白教程&#xff0c;老鸟就不要在这浪费时间了&#x1f60a;。 前面一期我们介绍了CODESYS的关键字及变量。这一期主要介绍CODESYS的字符串类型&#xff0c;以及如何利用字符串操作函数来实现字符串的查找、插入、替换、连接、分割、删除等相关操作。注意本文介…

Realsense相机的RGB与depth图像的对齐

第三部分 将RGB图像和Depth图像对齐 文章目录第三部分 将RGB图像和Depth图像对齐前言一、创建对齐的cpp文件1.用vim创建C文件二、使用CMake构建C工程1.创建并编写CMakeList.txt文件2.编译CMakeLists.txt总结前言 将RGB图像和深度图像对齐有两种方式&#xff0c;一种是将深度图…

音视频开发-第一章-H264编解码

目录参考原文一、概述二、封装格式2.1、视频文件封装格式2.2、音视频编码方式2.2.1、视频编码方式2.2.2、音频编码方式三、H264相关概念3.1、H264基本单元3.2、帧类型3.3、GOP(画面组)3.4、IDR 帧四、H264压缩方式4.1、H264压缩方式4.2、H264压缩方式说明五、H264分层结构5.1、…

【websocket】前端websocket 实时通信

前端websocket 实时通信 文章目录前端websocket 实时通信什么是websocket为什么传统的http协议不能做到websocket实现的功能websocket前后端所用到的事件对比WebSocket.readyState代码什么是websocket websocket是HTML5开始提供的一种网络通信协议&#xff0c;它诞生的目的是在…

60条小妙招帮助你开车更省油

1、把备胎和千斤顶&#xff0c;工具&#xff0c;都放在家里&#xff0c;不跑长途不带这些&#xff0c;省油。2、说明书上说92号或以上标号&#xff0c;那么加95号油省油。如果是95或以上的标注&#xff0c;那就加98省油。3、驾驶中尽量减少急加速 急刹车&#xff0c;省油。4、驾…

Java——多态

好久不见啊&#xff0c;兄弟们&#xff01;&#xff01;这不将近期末考试了吗&#xff0c;阿涛平日里课听的不多&#xff0c;所以最近都在疯狂补课&#xff0c;祖宗之法也可变&#xff0c;阿涛的学校终于不是二十周校历了&#xff01;&#xff01;希望从今往后我们的生活都能够…

[oeasy]python0041_ 转义字符_转义序列_escape_序列_sequence

转义序列 回忆上次内容 上次回顾了5bit-Baudot博多码的来历从 莫尔斯码 到 博多码 原来 人 来 收发电报 现在 机器 来 收发电报 输入方式 从 电键改成 键盘 输出方式 从 纸带变成 打印纸张 后来 电传打字机ASR-33成为 初代 经典终端 除了 \n 和 \r 之外 还有什么 特殊字符 吗…

前端格式化工具使用(eslint、stylelint、prettier、lint-staged和husky搭配格式化代码)

目录 eslint 安装eslint .eslintrc.js env extends parser parserOptions rules globals plugins 屏蔽eslint检测具体规则 官方规则 stylelint 安装stylelint 创建stylelint配置文件 .eslintrc.js extends sass文件的格式检查 prettier 安装使用prettier …

Vue组件-插槽

一、插槽 1. 组件的三大核心&#xff1a;属性&#xff08;data、props&#xff09;、事件、插槽 2. 插槽&#xff08;slot&#xff09; 插槽&#xff08;slot&#xff09;将父组件的内容与子组件的模板相混合&#xff0c;从而弥补了视图的不足。 插槽的目的&#xff1a;使组件…

ConstraintLayout2

ConstraintLayout2ImageFilterView 属性 app:altSrc&#xff1a;altSrc提供的资源将会和src提供的资源通过crossfade属性形成交叉淡化效果。默认情况下,crossfade0&#xff0c;altSrc所引用的资源不可见,取值在0-1app:saturation&#xff1a;float型&#xff0c;默认1&#xf…