【从零开始学习深度学习】41. 算法优化之RMSProp算法【基于AdaGrad算法的改进】介绍及其Pytorch实现

news2024/11/16 10:41:26

上一篇文章AdaGrad算法中提到,因为调整学习率时分母上的变量 s t \boldsymbol{s}_t st一直在累加按元素平方的小批量随机梯度,所以目标函数自变量每个元素的学习率在迭代过程中一直在降低(或不变)。因此,当学习率在迭代早期降得较快且当前解依然不佳时,AdaGrad算法在迭代后期由于学习率过小,可能较难找到一个有用的解。为了解决这一问题,RMSProp算法对AdaGrad算法做了一点修改。

目录

  • 1. RMSProp算法介绍
  • 2. 从零实现RMSProp算法
  • 3. Pytorch简洁实现RMSProp算法--optim.RMSprop
  • 总结

1. RMSProp算法介绍

不同于AdaGrad算法里状态变量 s t \boldsymbol{s}_t st是截至时间步 t t t所有小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方和,RMSProp算法将这些梯度按元素平方做指数加权移动平均[在之前动量法里介绍过指数加权移动平均]。具体来说,给定超参数 0 ≤ γ < 1 0 \leq \gamma < 1 0γ<1,RMSProp算法在时间步 t > 0 t>0 t>0计算

s t ← γ s t − 1 + ( 1 − γ ) g t ⊙ g t . \boldsymbol{s}_t \leftarrow \gamma \boldsymbol{s}_{t-1} + (1 - \gamma) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stγst1+(1γ)gtgt.

和AdaGrad算法一样,RMSProp算法将目标函数自变量中每个元素的学习率通过按元素运算重新调整,然后更新自变量

x t ← x t − 1 − η s t + ϵ ⊙ g t , \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \frac{\eta}{\sqrt{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, xtxt1st+ϵ ηgt,

其中 η \eta η是学习率, ϵ \epsilon ϵ是为了维持数值稳定性而添加的常数,如 1 0 − 6 10^{-6} 106。因为RMSProp算法的状态变量 s t \boldsymbol{s}_t st是对平方项 g t ⊙ g t \boldsymbol{g}_t \odot \boldsymbol{g}_t gtgt的指数加权移动平均,所以可以看作是最近 1 / ( 1 − γ ) 1/(1-\gamma) 1/(1γ)个时间步的小批量随机梯度平方项的加权平均。如此一来,自变量每个元素的学习率在迭代过程中就不再一直降低(或不变)。

让我们先观察RMSProp算法对目标函数 f ( x ) = 0.1 x 1 2 + 2 x 2 2 f(\boldsymbol{x})=0.1x_1^2+2x_2^2 f(x)=0.1x12+2x22中自变量的迭代轨迹。依然使用的学习率为0.4的AdaGrad算法,自变量在迭代后期的移动幅度较小。但在同样的学习率下,RMSProp算法可以更快逼近最优解。

%matplotlib inline
import math
import torch
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

def rmsprop_2d(x1, x2, s1, s2):
    g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6
    s1 = gamma * s1 + (1 - gamma) * g1 ** 2
    s2 = gamma * s2 + (1 - gamma) * g2 ** 2
    x1 -= eta / math.sqrt(s1 + eps) * g1
    x2 -= eta / math.sqrt(s2 + eps) * g2
    return x1, x2, s1, s2

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

eta, gamma = 0.4, 0.9
d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))

输出:

epoch 20, x1 -0.010599, x2 0.000000

在这里插入图片描述

2. 从零实现RMSProp算法

接下来按照RMSProp算法中的公式实现该算法。

features, labels = d2l.get_data_ch7()

def init_rmsprop_states():
    s_w = torch.zeros((features.shape[1], 1), dtype=torch.float32)
    s_b = torch.zeros(1, dtype=torch.float32)
    return (s_w, s_b)

def rmsprop(params, states, hyperparams):
    gamma, eps = hyperparams['gamma'], 1e-6
    for p, s in zip(params, states):
        s.data = gamma * s.data + (1 - gamma) * (p.grad.data)**2
        p.data -= hyperparams['lr'] * p.grad.data / torch.sqrt(s + eps)

我们将初始学习率设为0.01,并将超参数 γ \gamma γ设为0.9。此时,变量 s t \boldsymbol{s}_t st可看作是最近 1 / ( 1 − 0.9 ) = 10 1/(1-0.9) = 10 1/(10.9)=10个时间步的平方项 g t ⊙ g t \boldsymbol{g}_t \odot \boldsymbol{g}_t gtgt的加权平均。

d2l.train_ch7(rmsprop, init_rmsprop_states(), {'lr': 0.01, 'gamma': 0.9},
              features, labels)

输出:

loss: 0.243452, 0.049984 sec per epoch

在这里插入图片描述

3. Pytorch简洁实现RMSProp算法–optim.RMSprop

通过名称为RMSprop的优化器方法,我们便可使用PyTorch提供的RMSProp算法来训练模型。注意,超参数 γ \gamma γ通过alpha指定。

d2l.train_pytorch_ch7(torch.optim.RMSprop, {'lr': 0.01, 'alpha': 0.9},
                    features, labels)

输出:

loss: 0.243676, 0.043637 sec per epoch

在这里插入图片描述

总结

  • RMSProp算法和AdaGrad算法的不同在于,RMSProp算法使用了小批量随机梯度按元素平方的指数加权移动平均来调整学习率。

如果文章内容对你有帮助,感谢点赞+关注!

欢迎关注下方GZH:阿旭算法与机器学习,共同学习交流~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150229.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode 45. 跳跃游戏 II

45. 跳跃游戏 II - 力扣&#xff08;LeetCode&#xff09; 解法1&#xff1a;&#xff08;动态规划 贪心&#xff09; 果然代码越短&#xff0c;思路越难。这题用的是动态规划贪心的思想。首先分析题意我们可以知道&#xff0c;从索引0这个点开始&#xff0c;我们走一步可以…

redis命令第二弹

1、redis命令-hash类型练习2、redis命令-list类型练习3、redis命令-set类型练习

YOLOV5环境搭建以及训练COCO128数据集

前言记录了自己训练coco128的全过程手把手教你YOLOV5环境搭建以及训练COCO128数据集。相关配置文件在百度网盘中。如果懒得话可以直接全部用我的数据一、准备工作1.1创建环境打开anaconda power shell&#xff08;最好以管理员身份运行&#xff0c;免得到后面相关文件权限进不去…

sentinel-介绍(一)

Sentinel Website&#xff08;Sentinel 官网网站&#xff09; Sentinel: 分布式系统的流量防卫兵 Sentinel 是什么&#xff1f; 随着微服务的流行&#xff0c;服务和服务之间的稳定性变得越来越重要。Sentinel 以流量为切入点&#xff0c;从流量控制、流量路由、熔断降级、系…

ansible配置yum源仓库

1.挂载本地光盘到/mnt 2.配置yum源仓库文件通过多种方式实现 仓库1 &#xff1a; Name: RH294_Base Description&#xff1a; RH294 base software Base urt: file:///mnt/BaseOS 不需要验证钦件包 GPG 签名 启用此软件仓库 仓库 2: Name: RH294_S…

LeetCode刷题模版:41 - 50

目录 简介41. 缺失的第一个正数42. 接雨水43. 字符串相乘44. 通配符匹配45. 跳跃游戏 II46. 全排列47. 全排列 II48. 旋转图像49. 字母异位词分组50. Pow(x, n)结语简介 Hello! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~ ଘ(੭ˊᵕˋ)੭ 昵称:海轰 标…

axios系列之取消请求

文章の目录写在最后使用 cancel token 取消请求 Axios 的 cancel token API 基于cancelable promises proposal&#xff0c;它还处于第一阶段。 可以使用 CancelToken.source 工厂方法创建 cancel token&#xff0c;像这样&#xff1a; const CancelToken axios.CancelToken;…

Revit二次开发小技巧(十七)实时监控模型线的生成

前言&#xff1a;项目中需要一个需求&#xff0c;用户想调用出Revit中自带的绘制模型线方法&#xff0c;然后再绘制结束时&#xff0c;可以拿到绘制的模型线&#xff0c;然后实现后面的算法。这里记录一种方法&#xff0c;通过DocumentChange事件修改Tag的PropertyChanged事件来…

【Python】pandas获取全省人口数据并作可视化分析

前言 今天我们看看自己所在的省份的人口人数&#xff0c;使用pandas并作可视化分析。 环境使用 python 3.9pycharm 模块使用 pandasPandas 是基于NumPy的一种工具&#xff0c;该工具是为解决数据分析任务而创建的。Pandas 纳入了大量库和一些标准的数据模型&#xff0c;提供…

java和vue募捐网水滴筹项目捐款爱心系统筹款系统

简介 募捐网&#xff0c;注册用户实名认证通过后可以发布募捐&#xff0c;管理员审核募捐通过后&#xff0c;前台用户可以看到该募捐信息&#xff0c;进行募捐或者举报&#xff08;管理审核举报成功后&#xff0c;会拉黑该募捐发起人&#xff09;&#xff0c;前台展示公告、爱…

83. 删除排序链表中的重复元素(链表)

文章目录题目描述方法一 暴力法方法二 递归法参考文献题目描述 给定一个已排序的链表的头 head &#xff0c; 删除所有重复的元素&#xff0c;使每个元素只出现一次 。返回 已排序的链表 。 示例 1&#xff1a; 输入&#xff1a;head [1,1,2] 输出&#xff1a;[1,2] 示例 2…

酷开系统——家庭场景下的智能营销系统!

随着人们生活方式的改变&#xff0c;以往传统的营销资源和渠道正在慢慢陷入一个“无用”的尴尬境地&#xff0c;而作为家庭娱乐中心的智能大屏&#xff0c;近两年所表现出来的数据和效果却逐渐备受企业和品牌方关注&#xff0c;有数据显示&#xff0c;智能大屏的家庭覆盖规模正…

bug 站在一个测试的角度看bug

如何描述一个bug ?如何定义bug的级别 ?bug的生命周期 ?如何开始第一次测试 ?测试的执行和bug管理 ?产生争执怎么办 ?如何描述一个bug?作为一名测试人员&#xff0c;提bug是最基础的工作&#xff0c;那我们如何才能把bug提的清晰易懂呢?发现问题的版本 开发人员获取对应…

k8s之基于kubeadm搭建k8s集群

写在前面 你可能知道搭建k8s集群的kind&#xff0c;minikube工具&#xff0c;但是他们都太简单了&#xff0c;不能满足生产级的要求&#xff0c;想要真正的部署生产级别的k8s集群&#xff0c;我们还需要另外一个集群管理工具kubeadm ,本文就一起看下如何使用该工具来搭建k8s集…

STM32MP157驱动开发——Linux下的单总线驱动

STM32MP157驱动开发——Linux下的单总线驱动0.前言一、DS18B20 及工作时序简介1.DS18B20 简介2.DS18B20 时序简介4.DS18B20温度读取流程二、DHT11 及工作时序简介1.DHT11 简介2.DHT11 工作时序简介三、驱动开发1.DS18B20驱动1&#xff09;修改设备树2&#xff09;驱动编写2.测试…

IB生物:干细胞与生命的各种功能

国际学校生物老师解读IB生物&#xff0c;感兴趣的同学记得收藏哦~IB生物分为SL(standard level)和HL(higher level)SL有6个topic∶细胞生物&#xff0c;分子生物&#xff0c;遗传学&#xff0c;生态学&#xff0c;物种进化以及多样性和人体生理。HL除了上述6个topic外还要加上∶…

Python入门基础实例讲解——两个数字比大小,并输出最大值

嗨害大家好鸭&#xff01; 我是小熊猫~ 今天也是给大家带来干货的一天~ pycharm永久激活码可以从这里找到我&#xff1a; 输出&#xff1a;print&#xff08;&#xff09; print() 方法用于打印输出&#xff0c;最常见的一个函数。 比较运算符 >&#xff1a; 大于&#…

关于校园网的各种连接问题

校园网网络使用异常&#xff0c;掉线、卡顿以及无法连接网络&#xff0c;经网络上收据的信息&#xff0c;大致分为五类&#xff1a;1.能获取到校园网地址如10.*.*.*&#xff0c;但无法跳出认证界面。2.物理链路故障&#xff1b;3.IP配置故障&#xff1b;4. 网络正常&#xff0c…

SpringCloud高级应用-1(SpringCloud技术栈概览)

1、SpringCloud技术栈 开发分布式系统可能具有挑战性&#xff0c;复杂性已从应用程序层转移到网络层&#xff0c;并要求服务之间进行更多的交互。将代码设为“cloud-native”就需要解决12-factor&#xff0c;例如外部配置&#xff0c;服务无状态&#xff0c;日志记录以及连接…

【矩阵论】8. 常用矩阵总结——单阵,正规阵,幂0阵,幂等阵,循环阵

矩阵论 1. 准备知识——复数域上矩阵,Hermite变换) 1.准备知识——复数域上的内积域正交阵 1.准备知识——Hermite阵&#xff0c;二次型&#xff0c;矩阵合同&#xff0c;正定阵&#xff0c;幂0阵&#xff0c;幂等阵&#xff0c;矩阵的秩 2. 矩阵分解——SVD准备知识——奇异值…