【深度学习笔记】7_6 RMSProp算法

news2024/9/23 5:20:29

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

7.6 RMSProp算法

我们在7.5节(AdaGrad算法)中提到,因为调整学习率时分母上的变量 s t \boldsymbol{s}_t st一直在累加按元素平方的小批量随机梯度,所以目标函数自变量每个元素的学习率在迭代过程中一直在降低(或不变)。因此,当学习率在迭代早期降得较快且当前解依然不佳时,AdaGrad算法在迭代后期由于学习率过小,可能较难找到一个有用的解。为了解决这一问题,RMSProp算法对AdaGrad算法做了一点小小的修改。该算法源自Coursera上的一门课程,即“机器学习的神经网络” [1]。

7.6.1 算法

我们在7.4节(动量法)里介绍过指数加权移动平均。不同于AdaGrad算法里状态变量 s t \boldsymbol{s}_t st是截至时间步 t t t所有小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方和,RMSProp算法将这些梯度按元素平方做指数加权移动平均。具体来说,给定超参数 0 ≤ γ < 1 0 \leq \gamma < 1 0γ<1,RMSProp算法在时间步 t > 0 t>0 t>0计算

s t ← γ s t − 1 + ( 1 − γ ) g t ⊙ g t . \boldsymbol{s}_t \leftarrow \gamma \boldsymbol{s}_{t-1} + (1 - \gamma) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stγst1+(1γ)gtgt.

和AdaGrad算法一样,RMSProp算法将目标函数自变量中每个元素的学习率通过按元素运算重新调整,然后更新自变量

x t ← x t − 1 − η s t + ϵ ⊙ g t , \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \frac{\eta}{\sqrt{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, xtxt1st+ϵ ηgt,

其中 η \eta η是学习率, ϵ \epsilon ϵ是为了维持数值稳定性而添加的常数,如 1 0 − 6 10^{-6} 106。因为RMSProp算法的状态变量 s t \boldsymbol{s}_t st是对平方项 g t ⊙ g t \boldsymbol{g}_t \odot \boldsymbol{g}_t gtgt的指数加权移动平均,所以可以看作是最近 1 / ( 1 − γ ) 1/(1-\gamma) 1/(1γ)个时间步的小批量随机梯度平方项的加权平均。如此一来,自变量每个元素的学习率在迭代过程中就不再一直降低(或不变)。

照例,让我们先观察RMSProp算法对目标函数 f ( x ) = 0.1 x 1 2 + 2 x 2 2 f(\boldsymbol{x})=0.1x_1^2+2x_2^2 f(x)=0.1x12+2x22中自变量的迭代轨迹。回忆在7.5节(AdaGrad算法)使用的学习率为0.4的AdaGrad算法,自变量在迭代后期的移动幅度较小。但在同样的学习率下,RMSProp算法可以更快逼近最优解。

%matplotlib inline
import math
import torch
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

def rmsprop_2d(x1, x2, s1, s2):
    g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6
    s1 = gamma * s1 + (1 - gamma) * g1 ** 2
    s2 = gamma * s2 + (1 - gamma) * g2 ** 2
    x1 -= eta / math.sqrt(s1 + eps) * g1
    x2 -= eta / math.sqrt(s2 + eps) * g2
    return x1, x2, s1, s2

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

eta, gamma = 0.4, 0.9
d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))

输出:

epoch 20, x1 -0.010599, x2 0.000000

在这里插入图片描述

7.6.2 从零开始实现

接下来按照RMSProp算法中的公式实现该算法。

features, labels = d2l.get_data_ch7()

def init_rmsprop_states():
    s_w = torch.zeros((features.shape[1], 1), dtype=torch.float32)
    s_b = torch.zeros(1, dtype=torch.float32)
    return (s_w, s_b)

def rmsprop(params, states, hyperparams):
    gamma, eps = hyperparams['gamma'], 1e-6
    for p, s in zip(params, states):
        s.data = gamma * s.data + (1 - gamma) * (p.grad.data)**2
        p.data -= hyperparams['lr'] * p.grad.data / torch.sqrt(s + eps)

我们将初始学习率设为0.01,并将超参数 γ \gamma γ设为0.9。此时,变量 s t \boldsymbol{s}_t st可看作是最近 1 / ( 1 − 0.9 ) = 10 1/(1-0.9) = 10 1/(10.9)=10个时间步的平方项 g t ⊙ g t \boldsymbol{g}_t \odot \boldsymbol{g}_t gtgt的加权平均。

d2l.train_ch7(rmsprop, init_rmsprop_states(), {'lr': 0.01, 'gamma': 0.9},
              features, labels)

输出:

loss: 0.243452, 0.049984 sec per epoch

在这里插入图片描述

7.6.3 简洁实现

通过名称为RMSprop的优化器方法,我们便可使用PyTorch提供的RMSProp算法来训练模型。注意,超参数 γ \gamma γ通过alpha指定。

d2l.train_pytorch_ch7(torch.optim.RMSprop, {'lr': 0.01, 'alpha': 0.9},
                    features, labels)

输出:

loss: 0.243676, 0.043637 sec per epoch

在这里插入图片描述

小结

  • RMSProp算法和AdaGrad算法的不同在于,RMSProp算法使用了小批量随机梯度按元素平方的指数加权移动平均来调整学习率。

参考文献

[1] Tieleman, T., & Hinton, G. (2012). Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural networks for machine learning, 4(2), 26-31.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1510750.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

得帆助力大族激光主数据平台建设,用数据为企业生产力赋能

本期客户 大族激光科技产业集团股份有限公司&#xff08;以下简称“大族激光”&#xff09;是一家从事工业激光加工设备与自动化等配套设备及其关键器件的研发、生产、销售&#xff0c;激光、机器人及自动化技术在智能制造领域的系统解决方案的优质提供商&#xff0c;是国内激光…

RPC通信原理

RPC通信原理 RPC的概念 如果现在我有一个电商项目&#xff0c;用户要查询订单&#xff0c;自然而然是通过Service接口来调用订单的实现类。 我们把用户模块和订单模块都放在一起&#xff0c;打包成一个war包&#xff0c;然后再tomcat上运行&#xff0c;tomcat占有一个进程&am…

智能革新:思通数科开源AI平台在保险合同管理中的应用与优化

思通数科开源的多模态AI能力引擎平台是一个强大的工具&#xff0c;它结合了自然语言处理&#xff08;NLP&#xff09;、图像识别和语音识别技术&#xff0c;为企业提供自动化处理和分析文本、音视频和图像数据的能力。这个平台的开源性质意味着它可以被广泛地应用于各种业务场景…

JSP中间件漏洞

jsp的注入最难挖 另外3个好挖 struts2 url有action 就代表是struts2 用漏洞利用工具 下面之这两个一般都可以用工具扫一下 、 有些网站看起来没有 action实际上我们提交了 我们的账号和密码 之后就有了 工具包是下面 这些 github上面也有 用法就是如下图 把url放进去就…

3d视觉笔记 | 神经辐射场NeRF(Neural Radiance Fields)

NeRF概念 NeRF&#xff08;Neural Radiance Fields&#xff0c;神经辐射场&#xff09;是一种用于3D场景重建和图像渲染的深度学习方法。它由Ben Mildenhall等人在2020年的论文《NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis》中首次提出。NeRF通过…

8-100V转5V 2A 12V 2A 降压芯片 外置MOS 恒压输出

SC9103 一款宽电压范围降压型 DC-DC 电源管理芯片&#xff0c;内部集成使能开关控制、基准电源、误差放大器、 过热保护、限流保护、短路保护等功能&#xff0c;非常适合宽电压输入降压使用。 SC9103 零功耗使能控制&#xff0c;可以大大节省外围器件&#xff0c;更加适合电池场…

20240309web前端_第一周作业_完成用户注册界面

作业一&#xff1a;完成用户注册界面 成果展示&#xff1a; 完整代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-…

c++初阶------类和对象(下)

作者前言 &#x1f382; ✨✨✨✨✨✨&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f382; ​&#x1f382; 作者介绍&#xff1a; &#x1f382;&#x1f382; &#x1f382; &#x1f389;&#x1f389;&#x1f389…

避抗指南:如何寻找OLED透明屏供应商

寻找OLED透明屏供应商&#xff0c;你可以按照以下步骤进行&#xff1a; 明确需求&#xff1a;首先&#xff0c;你需要明确自己的需求&#xff0c;包括所需OLED透明屏的尺寸、分辨率、亮度、色彩饱和度等具体参数&#xff0c;以及预算和采购量。这有助于你更精准地找到符合需求的…

Django入门 整体流程跑通

Django学习笔记 一、Django整体流程跑通 1.1安装 pip install django //安装 import django //在python环境中导入django django.get_version() //获取版本号&#xff0c;如果能获取到&#xff0c;说明安装成功Django目录结构 Python310-Scripts\django-admi…

滑动窗口的概念,糊涂窗口综合征,nagle算法

目录 1.流量控制 2.滑动窗口 3.思考问题 1.流量控制 一般来说,我们总是希望数据传输得更快一些,但如果发送方把数据发送得过快,接收方就可能来不及接收,这就会造成数据的丢失.所谓流量控制(flow control)就是发送方的发送速率不要太快,要让接收方来得及接收. 2.滑动窗口 T…

【VS Code插件开发】自定义指令实现 git 命令 (九)

&#x1f431; 个人主页&#xff1a;不叫猫先生&#xff0c;公众号&#xff1a;前端舵手 &#x1f64b;‍♂️ 作者简介&#xff1a;前端领域优质作者、阿里云专家博主&#xff0c;共同学习共同进步&#xff0c;一起加油呀&#xff01; ✨优质专栏&#xff1a;VS Code插件开发极…

WebServer -- 架构图 面试题(上)

目录 &#x1f382;前言 &#x1f33c;流程图 && 架构图 1&#xff09;什么是 WebServer 2&#xff09;服务器基本框架 3&#xff09;Reactor && Proactor 模式 4&#xff09;同步 I/O 模拟Proactor模式&#xff08;Linux&#xff09; 5&#xff09;主从…

小白刷题CTF show web方向

web01 右键查看源代码&#xff0c;再使用在线解密&#xff0c;就可以得出答案了 web02 sql注入 admin or 11 或者 1 or 11可以登录查询几个字段&#xff1a;1 or 11 order by 3 # 使用此语句&#xff0c;判断列数。 order by 3不会出错&#xff0c;但是order by 4就没有显示…

上传文件携带参数总是deubg不进去

const { data } await createVerificationMaterialApi({file: info.file,name: file,filename: info.file.name,data: { ids },})//这样传参数&#xff0c;网络里看发的请求会是如下样子&#xff0c;这样debug不到代码里正确方法 const { data } await createVerificationMat…

这下爽了,全是特殊版实用软件,功能强大还免费

闲话少说&#xff0c;直接上狠货。 1、我的ABC软件工具箱 简洁而不失强大&#xff0c;我的ABC软件工具箱是您批量处理办公任务的得力小助手。完全免费&#xff0c;界面清新无广告&#xff0c;让您轻松开启高效办公之旅。 面对日常办公中繁多的文件处理需求&#xff0c;如内容…

1.Datax数据同步之Windows下,mysql数据同步至另一个mysql数据库

目录 前言步骤操作大纲步骤明细其他问题 前言 Datax是什么&#xff1f; DataX 是阿里巴巴集团内被广泛使用的离线数据同步工具/平台&#xff0c;实现包括 MySQL、SQL Server、Oracle、PostgreSQL、HDFS、Hive、HBase、OTS、ODPS 等各种异构数据源之间高效的数据同步功能。准备…

IO学习--02

标准IO由ANSI C库说明&#xff0c;在很多系统都实现了标准IO库。标准IO库处理很多细节&#xff0c;如缓冲的分配、优化长度执行IO等&#xff0c;使得用户不需要考虑选择合适的长度。标准IO是在系统调用函数构建的&#xff0c;便于用户使用。 标准IO的所有操作都是围绕流&#x…

c语言经典测试题12

1.题1 float f[10]; // 假设这里有对f进行初始化的代码 for(int i 0; i < 10;) { if(f[i] 0) break; } 上述代码有那些缺陷&#xff08;&#xff09; A: for(int i 0; i < 10;)这一行写错了 B: f是float型数据直接做相等判断有风险 C: f[i]应该是f[i] D: 没有缺…

LLM PreTraining from scratch -- 大模型从头开始预训练指北

最近做了一些大模型训练相关的训练相关的技术储备,在内部平台上完成了多机多卡的llm 预训练的尝试,具体的过程大致如下: 数据准备: 大语言模型的训练依赖于与之匹配的语料数据,在开源社区有一群人在自发的整理高质量的语料数据,可以通过 以下的一些链接获取 liwu/MNBVC…