【从零开始学习深度学习】40. 算法优化之AdaGrad算法介绍及其Pytorch实现

news2024/11/16 1:49:40

之前介绍的梯度下降算法中,目标函数自变量的每一个元素在相同时间步都使用同一个学习率来自我迭代。例如,假设目标函数为 f f f,自变量为一个二维向量 [ x 1 , x 2 ] ⊤ [x_1, x_2]^\top [x1,x2],该向量中每一个元素在迭代时都使用相同的学习率。例如,在学习率为 η \eta η的梯度下降中,元素 x 1 x_1 x1 x 2 x_2 x2都使用相同的学习率 η \eta η来自我迭代:
x 1 ← x 1 − η ∂ f ∂ x 1 , x 2 ← x 2 − η ∂ f ∂ x 2 . x_1 \leftarrow x_1 - \eta \frac{\partial{f}}{\partial{x_1}}, \quad x_2 \leftarrow x_2 - \eta \frac{\partial{f}}{\partial{x_2}}. x1x1ηx1f,x2x2ηx2f.
通过上一篇文章动量法中我们知道,当 x 1 x_1 x1 x 2 x_2 x2的梯度值有较大差别时,需要选择足够小的学习率使得自变量在梯度值较大的维度上不发散。但这样会导致自变量在梯度值较小的维度上迭代过慢。动量法依赖指数加权移动平均使得自变量的更新方向更加一致,从而降低发散的可能。

本文我们介绍AdaGrad算法,它可以根据自变量在每个维度的梯度值的大小来调整各个维度上的学习率,从而避免统一的学习率难以适应所有维度的问题

目录

  • 1. AdaGrad算法介绍
    • 1.1 AdaGrad算法特点
  • 2. 从零实现AdaGrad算法
  • 3. Pytorch简洁实现AdaGrad算法--使用optim.Adagrad
  • 总结

1. AdaGrad算法介绍

AdaGrad算法会使用一个小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方的累加变量 s t \boldsymbol{s}_t st。在时间步0,AdaGrad将 s 0 \boldsymbol{s}_0 s0中每个元素初始化为0。在时间步 t t t,首先将小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方后累加到变量 s t \boldsymbol{s}_t st

s t ← s t − 1 + g t ⊙ g t , \boldsymbol{s}_t \leftarrow \boldsymbol{s}_{t-1} + \boldsymbol{g}_t \odot \boldsymbol{g}_t, stst1+gtgt,

其中 ⊙ \odot 是按元素相乘。接着,我们将目标函数自变量中每个元素的学习率通过按元素运算重新调整一下:

x t ← x t − 1 − η s t + ϵ ⊙ g t , \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \frac{\eta}{\sqrt{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, xtxt1st+ϵ ηgt,

其中 η \eta η是学习率, ϵ \epsilon ϵ是为了维持数值稳定性而添加的常数,如 1 0 − 6 10^{-6} 106。这里开方、除法和乘法的运算都是按元素运算的。这些按元素运算使得目标函数自变量中每个元素都分别拥有自己的学习率。

1.1 AdaGrad算法特点

需要强调的是,小批量随机梯度按元素平方的累加变量 s t \boldsymbol{s}_t st出现在学习率的分母项中。因此,如果目标函数有关自变量中某个元素的偏导数一直都较大,那么该元素的学习率将下降较快;反之,如果目标函数有关自变量中某个元素的偏导数一直都较小,那么该元素的学习率将下降较慢。然而,由于 s t \boldsymbol{s}_t st一直在累加按元素平方的梯度,自变量中每个元素的学习率在迭代过程中一直在降低(或不变)。所以,当学习率在迭代早期降得较快且当前解依然不佳时,AdaGrad算法在迭代后期由于学习率过小,可能较难找到一个有用的解

下面我们仍然以目标函数 f ( x ) = 0.1 x 1 2 + 2 x 2 2 f(\boldsymbol{x})=0.1x_1^2+2x_2^2 f(x)=0.1x12+2x22为例观察AdaGrad算法对自变量的迭代轨迹。我们实现AdaGrad算法并使用和之前动量法相同的学习率0.4。可以看到,自变量的迭代轨迹较平滑。但由于 s t \boldsymbol{s}_t st的累加效果使学习率不断衰减,自变量在迭代后期的移动幅度较小。

%matplotlib inline
import math
import torch
import sys
import d2lzh_pytorch as d2l

def adagrad_2d(x1, x2, s1, s2):
    g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6  # 前两项为自变量梯度
    s1 += g1 ** 2
    s2 += g2 ** 2
    x1 -= eta / math.sqrt(s1 + eps) * g1
    x2 -= eta / math.sqrt(s2 + eps) * g2
    return x1, x2, s1, s2

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

eta = 0.4
d2l.show_trace_2d(f_2d, d2l.train_2d(adagrad_2d))

输出:

epoch 20, x1 -2.382563, x2 -0.158591

在这里插入图片描述

下面将学习率增大到2。可以看到自变量更为迅速地逼近了最优解。

eta = 2
d2l.show_trace_2d(f_2d, d2l.train_2d(adagrad_2d))

输出:

epoch 20, x1 -0.002295, x2 -0.000000

在这里插入图片描述

2. 从零实现AdaGrad算法

同动量法一样,AdaGrad算法需要对每个自变量维护同它一样形状的状态变量。我们根据AdaGrad算法中的公式实现该算法。

features, labels = d2l.get_data_ch7()

def init_adagrad_states():
    s_w = torch.zeros((features.shape[1], 1), dtype=torch.float32)
    s_b = torch.zeros(1, dtype=torch.float32)
    return (s_w, s_b)

def adagrad(params, states, hyperparams):
    eps = 1e-6
    for p, s in zip(params, states):
        s.data += (p.grad.data**2)
        p.data -= hyperparams['lr'] * p.grad.data / torch.sqrt(s + eps)

与之前小批量随机梯度下降相比,这里使用更大的学习率来训练模型。

d2l.train_ch7(adagrad, init_adagrad_states(), {'lr': 0.1}, features, labels)

输出:

loss: 0.243675, 0.049749 sec per epoch

在这里插入图片描述

3. Pytorch简洁实现AdaGrad算法–使用optim.Adagrad

通过名称为Adagrad的优化器方法,我们便可使用PyTorch提供的AdaGrad算法来训练模型。

d2l.train_pytorch_ch7(torch.optim.Adagrad, {'lr': 0.1}, features, labels)

输出:

loss: 0.243147, 0.040675 sec per epoch

在这里插入图片描述

总结

  • AdaGrad算法在迭代过程中不断调整学习率,并让目标函数自变量中每个元素都分别拥有自己的学习率。
  • 使用AdaGrad算法时,自变量中每个元素的学习率在迭代过程中一直在降低(或不变)。

如果文章内容对你有帮助,感谢点赞+关注!

欢迎关注下方GZH:阿旭算法与机器学习,共同学习交流~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

线程安全详解

线程安全 线程安全是多线程编程时的计算机程序代码中的一个概念。在拥有共享数据的多条线程并行执行的程序中,线程安全的代码会通过同步机制保证各个线程都可以正常且正确的执行,不会出现数据污染等意外情况。 多个线程访问同一个对象时,如…

【docker11】docker安装常用软件

docker安装常用软件 1.安装软件说明 框架图 总体步骤: 搜索镜像拉去镜像查看镜像启动镜像 - 服务端口映射停止容器移除容器 1.安装tomcat docker hub上查找tomcat镜像 从docker hub上拉取tomcat镜像到本地 docker images查看是否有拉取到的tomcat 使用to…

Spring Security 导致 Spring Boot 跨域失效问题

1、CORS 是什么首先我们要明确,CORS 是什么,以及规范是如何要求的。这里只是梳理一下流程。CORS 全称是 Cross-Origin Resource Sharing,直译过来就是跨域资源共享。要理解这个概念就需要知道域、资源和同源策略这三个概念。域,指…

sahrding-jdbc的雪花算法取模为0或1的问题

工作时无意间发现sahrding-jdbc使用雪花算法生成的id 在某一业务分库分表 永远在那两个库表里面,排查后这里做下分享 环境、配置、问题介绍 16库16表使用的是org.apache.shardingsphere.core.strategy.keygen下面generateKey生成id分库表算法是对16取模生成数据永远在0库0表 0…

如何有效的防护暴力破解和撞库攻击

在网络威胁领域,暴力破解攻击仍然是网络犯罪分子非常喜爱且有利可图的攻击方法。,黑客通过收集互联网已泄露的用户和密码信息,生成对应的字典表,由于许多用户重复使用相同的用户名和密码,攻击者可以使用撞库攻击获得对…

前缀和算法

目录1.概述2.代码实现2.1.一维前缀和2.2.二维前缀和3.应用本文参考: LABULADONG 的算法网站 1.概述 前缀和算法分为一维和二维,一维前缀和可以快速求序列中某一段的和,而二维前缀和可以快速求一个矩阵中某个子矩阵的和。 2.代码实现 2.1.一…

约拍小程序开发,优化约拍产业路径

随着网红、直播经济的发展,年轻群体对于拍摄服务的需求正在急速增长,与此同时该群体用户又极具个性,很多传统影楼拍摄价格高、拍摄风格固定化、等待时间久,个人摄影师宣传推广难度又大,导致拍摄需求被抑制,…

_Linux多线程-基础篇

文章目录1. 什么是线程Linux中的线程叫做轻量级进程(LWP)2. 线程的优点3. 线程的缺点4. 线程异常5. 线程用途6. Linux进程VS线程单进程7. 总结线程在进程内部执行是OS调度的基本单位。不同视角看待进程轻量级进程1. 什么是线程 在一个程序里的一个执行路…

【Cfeng Work】 Open API的intro和 梳理

OpenAPI 开放平台 内容管理Open API intro腾讯云OpenAPIOpenAPI请求API密钥管理云API签名过程拼接规范请求串 CanonicalRequest拼接待签名字符串计算签名拼接到Authorization中云API签名失败规范Token 和 长期密钥公共参数OpenApi设计AppId、AppSecretsign签名timestamp 时间戳…

制造服务行业需要项目管理软件吗?

伴随着时代的快速发展,电子制造服务行业也正在飞速发展,在日新月异得时代步伐下,科学得管理方法和现代化得管理工具相结合,保证企业得进步与发展。项目管理工具通过构造高效统一的项目管理平台,优化企业管理中存在的问…

win10 conda安装labme安装和使用

1、安装conda 在WIN10中配置conda 1.1miniconda下载 https://docs.conda.io/en/latest/miniconda.html 可以任意选择对应的版本,安装时选择配置路径,以免后期重复配置环境。 1.2.2 在环境变量中添加路径 Win R,打开运行,输入…

在ubuntu系统上用pyinstaller加密打包yolov5项目代码的详细步骤

目录0. 背景1. 创建虚拟环境2. pyinstaller打包2.1. 生成并修改spec文件2.2. 重新生成二进制文件3. 测试4. 加密打包4.1. 创建入口函数main.py4.2. 修改detect.py4.3. 加密生成main.spec文件4.4. 修改main.spec文件4.5. 生成二进制文件4.6. 测试0. 背景 最近需要在ubuntu 18.0…

Payso×OceanBase:云上拓新,开启云数据库的智能托管

日前,聚合支付厂商 Payso( 独角鲨北京科技有限公司)的平台、交易系统引入 OceanBase 原生分布式数据库,实现显著降本增效—— 硬件成本降低 20~30%,查询效率提升 80%,执行效率提升了 30%&#x…

学习IB生物,我们需要知道什么知识点?

学习IB课程的很多同学应该都听说过一个说法:IB生物算是理科中的文科,没有公式推导,只有大量需要记忆的内容,不需要用学习理科的思维去学习,其实这种观点是有误区的。实际上,学习生物这门课将会面对的是一个…

01背包问题详解

目录 1.1二维dp数组 1.2一维dp数组改进 1.3相关例题 1.3.1分割等和子集 1.3.2一和零 1.1二维dp数组 概述:背包的最大重量是固定的,物品的数量,重量也是固定的,并且物品只能放一次,可以选择放或者不放&#xff0c…

Redis 核心原理串讲(中),架构演进之高可用

文章目录Redis 核心原理总览(全局篇)前言一、持久化1、RDB2、AOF3、AOF 重写4、混合持久化5、对比二、副本1、同步模式2、部分重同步三、哨兵1、核心能力2、节点通信总结Redis 核心原理总览(全局篇) 正文开始之前,我们…

【FPGA】Verilog:基本实验步骤演示 | 功能电路创建 | 添加仿真激励 | 观察记录仿真波形

前言: 本章内容主要是演示Vivado下利用Verilog语言进行电路设计、仿真、综合和下载的完整过程、Verilog语言基本运用,电路设计和Test Bench程序的编写、以及实验开发板的使用,通过观察和数据记录理解仿真和FGPA实现的差异。 目录 Ⅰ. 基础知…

考研政治 马原 易混淆知识点

马哲 1. 哲学基本问题 从何者为第一性,分为唯物主义和唯心主义 从是否具有同一性,分为可知论(有同一性)和不可知论(无同一性) 辩证法:联系,发展的观点看世界,认为发展的…

python对接API二次开发高级实战案例解析:百度地图Web服务API封装函数(行政区划区域检索、地理编码、国内天气查询、IP定位、坐标转换)

文章目录前言一、IP定位1.请求URL2.获取IP定位封装函数3.输出结果二、国内天气查询1.请求url2.天气查询封装函数3.输出结果三、行政区划区域检索1.请求url2.区域检索封装函数3.输出结果四、地理编码1.请求url2.地理编码封装函数3.输出结果五、坐标转换1.请求url2.坐标转换封装函…

一文细说Linux虚拟文件系统原理

在 Unix 的世界里,有句很经典的话:一切对象皆是文件。这句话的意思是说,可以将 Unix 操作系统中所有的对象都当成文件,然后使用操作文件的接口来操作它们。Linux 作为一个类 Unix 操作系统,也努力实现这个目标。 虚拟…