【从零开始学习深度学习】42. 算法优化之AdaDelta算法【基于AdaGrad算法的改进】介绍及其Pytorch实现

news2024/9/27 7:24:15

除了上一篇文章介绍的RMSProp算法以外,另一个常用优化算法AdaDelta算法也针对AdaGrad算法在迭代后期可能较难找到有用解的问题做了改进 。比较有意思的是,AdaDelta算法没有学习率这一超参数

目录

  • 1. AdaDelta算法介绍
  • 2. 从零实现AdaDelta算法
  • 3. Pytorch简洁实现AdaDelta算法---optim.Adadelta
  • 总结

1. AdaDelta算法介绍

AdaDelta算法也像RMSProp算法一样,使用了小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方的指数加权移动平均变量 s t \boldsymbol{s}_t st。在时间步0,它的所有元素被初始化为0。给定超参数 0 ≤ ρ < 1 0 \leq \rho < 1 0ρ<1(对应RMSProp算法中的 γ \gamma γ),在时间步 t > 0 t>0 t>0,同RMSProp算法一样计算

s t ← ρ s t − 1 + ( 1 − ρ ) g t ⊙ g t . \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stρst1+(1ρ)gtgt.

与RMSProp算法不同的是,AdaDelta算法还维护一个额外的状态变量 Δ x t \Delta\boldsymbol{x}_t Δxt,其元素同样在时间步0时被初始化为0。我们使用 Δ x t − 1 \Delta\boldsymbol{x}_{t-1} Δxt1来计算自变量的变化量:

g t ′ ← Δ x t − 1 + ϵ s t + ϵ ⊙ g t , \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, gtst+ϵΔxt1+ϵ gt,

其中 ϵ \epsilon ϵ是为了维持数值稳定性而添加的常数,如 1 0 − 5 10^{-5} 105。接着更新自变量:

x t ← x t − 1 − g t ′ . \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t. xtxt1gt.

最后,我们使用 Δ x t \Delta\boldsymbol{x}_t Δxt来记录自变量变化量 g t ′ \boldsymbol{g}'_t gt按元素平方的指数加权移动平均:

Δ x t ← ρ Δ x t − 1 + ( 1 − ρ ) g t ′ ⊙ g t ′ . \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t. ΔxtρΔxt1+(1ρ)gtgt.

可以看到,如不考虑 ϵ \epsilon ϵ的影响,AdaDelta算法跟RMSProp算法的不同之处在于使用 Δ x t − 1 \sqrt{\Delta\boldsymbol{x}_{t-1}} Δxt1 来替代学习率 η \eta η

2. 从零实现AdaDelta算法

AdaDelta算法需要对每个自变量维护两个状态变量,即 s t \boldsymbol{s}_t st Δ x t \Delta\boldsymbol{x}_t Δxt。我们按AdaDelta算法中的公式实现该算法。

%matplotlib inline
import torch
import sys 
import d2lzh_pytorch as d2l

features, labels = d2l.get_data_ch7()

def init_adadelta_states():
    s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)
    delta_w, delta_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)
    return ((s_w, delta_w), (s_b, delta_b))

def adadelta(params, states, hyperparams):
    rho, eps = hyperparams['rho'], 1e-5
    for p, (s, delta) in zip(params, states):
        s[:] = rho * s + (1 - rho) * (p.grad.data**2)
        g =  p.grad.data * torch.sqrt((delta + eps) / (s + eps))
        p.data -= g
        delta[:] = rho * delta + (1 - rho) * g * g

使用超参数 ρ = 0.9 \rho=0.9 ρ=0.9来训练模型。

d2l.train_ch7(adadelta, init_adadelta_states(), {'rho': 0.9}, features, labels)

输出:

loss: 0.243728, 0.062991 sec per epoch

在这里插入图片描述

3. Pytorch简洁实现AdaDelta算法—optim.Adadelta

通过名称为Adadelta的优化器方法,我们便可使用PyTorch提供的AdaDelta算法。它的超参数可以通过rho来指定。

d2l.train_pytorch_ch7(torch.optim.Adadelta, {'rho': 0.9}, features, labels)

输出:

loss: 0.242104, 0.047702 sec per epoch

在这里插入图片描述

总结

  • AdaDelta算法没有学习率超参数,它通过使用有关自变量更新量平方的指数加权移动平均的项来替代RMSProp算法中的学习率。

如果文章内容对你有帮助,感谢点赞+关注!

欢迎关注下方GZH:阿旭算法与机器学习,共同学习交流~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150945.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UDS诊断系列介绍04-10会话服务

本文框架1. 系列介绍10服务概述2. 10服务请求与应答2.1 10服务请求2.2 肯定应答2.3 否定应答1. 系列介绍 UDS&#xff08;Unified Diagnostic Services&#xff09;协议&#xff0c;即统一的诊断服务&#xff0c;是面向整车所有ECU的一种诊断通信方式&#xff0c;是基于ISO 14…

Linux学习笔记——集群化环境前置准备

5.7、集群化环境前置准备 5.7.1、介绍 在前面&#xff0c;我们所学习安装的软件&#xff0c;都是以单机模式运行的。 后续&#xff0c;我们将要学习大数据相关的软件部署&#xff0c;所以后续我们所安装的软件服务&#xff0c;大多数都是以集群化&#xff08;多台服务器共同…

使用OpenCV读取视频、图片并做简单处理

1.OpenCV的安装与卸载 在conda中安装opencv&#xff0c;打开Anaconda Prompt 使用国内镜像源安装opencv&#xff0c;命令如下&#xff1a; pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple 也可以安装opencv的另一个扩展包opencv-contrib-python&am…

Centos下使用yum安装Mysql8(Mysql5.7)以及常见的配置和使用

记录一下在centos7.x下面使用yum方式安装mysql8(Mysql5.7)关系型数据库安装之前一般需要先确定centos7.x服务器里是否已经安装&#xff0c;未安装或者刚初始化的centos7.x服务器最好安装&#xff0c;原来已经有的要升级的话一定要对系统原有mysql 或mariadb卸载干净&#xff0c…

系统测试的具体测试类型

系统测试&#xff1a;是为判断系统是否符合要求而对集成的软、硬件系统进行的测试活动、它是将已经集成好的软件系统&#xff0c;作为基于整个计算机系统的一个元素&#xff0c;与计算机硬件、外设、某些支持软件、人员、数据等其他系统元素结合在一起&#xff0c;在实际运行环…

Charles - 夜神模拟器证书安装App抓包

Charles - 夜神模拟器证书安装App抓包 文章目录Charles - 夜神模拟器证书安装App抓包前言一、软件安装1.Openssl安装1.1下载安装1.2配置环境变量1.3查看openssl版本&#xff0c;输入命令&#xff1a;openssl version2.夜神模拟器安装1.1 下载安装1.2工具准备&#xff0c;MT管理…

【Lilishop商城】No4-6.业务逻辑的代码开发,涉及到:接口入参、出参开发逻辑,及POJO的各种总结

仅涉及后端&#xff0c;全部目录看顶部专栏&#xff0c;代码、文档、接口路径在&#xff1a; 【Lilishop商城】记录一下B2B2C商城系统学习笔记~_清晨敲代码的博客-CSDN博客 全篇会结合业务介绍重点设计逻辑&#xff0c;其中重点包括接口类、业务类&#xff0c;具体的结合源代…

完整iOS APP发布App Store上架流程指南

本文章的目的在于教会你如何创建ios的打包证书和如何上架假如你没有任何的打包或上架经验&#xff0c;参考本文有很大的收益。通常创建ios证书和上架&#xff0c;是需要MAC电脑的&#xff0c;本文重点介绍如何在没有mac电脑的情况下&#xff0c;创建mac证书和上架。假如你还没有…

STM32CUBEIDE-简单案例生成

STM32CUBEIDE-简单案例生成 京东链接&#xff1a;https://i-item.jd.com/66584659856.html 生成工程 使用STM32CUBEMX生成例程&#xff0c;这里使用STM32F103C8T6系统板。 新建一个工程&#xff0c;这里有3种新建工程方式。 ● 基于MCU/MPU新建工程 ● 基于ST模块新建工程 ●…

PCB板缺陷检测机器视觉识别算法 yolo

PCB板缺陷检测机器视觉识别算法通过pythonyolo系列网络深度学习模型对PCB电路板外观实时监测&#xff0c;当模型算法监测到有缺陷的PCB板时立即抓拍存档。Python是一种由Guido van Rossum开发的通用编程语言&#xff0c;它很快就变得非常流行&#xff0c;主要是因为它的简单性和…

Vue2进阶笔记

Vue2进阶笔记一、基础知识1.1 computed计算属性1.2 watch监视属性1.3 动态绑定样式1.4 列表循环渲染 key的探讨1.5 列表过滤1.6 数据监视1.7 表单收集1.8 过滤器1.9 生命周期函数1.10 nextTick1.11 动画与过渡1.12 脚手架配置跨域代理二、组件化开发2.1 演替与定义2.2 使用与注…

多线程进阶(一)锁策略,CAS及Synchronized原理

目录 前言&#xff1a; 常见锁策略 CAS CAS应用场景 标准库中基于CAS实现的原子类介绍 代码实现 ABA问题 Synchronized原理 锁升级 锁消除 锁粗化 小结&#xff1a; 前言&#xff1a; 通过这篇文章可以更加深入理解锁内部的一些实现原理&#xff0c;以及怎样描述一…

Qt 使用 Matlab函数

背景&#xff1a;个人的Qt项目中&#xff0c;需要一个图片分割算法。该算法之前在Matlab上实现过&#xff0c;同时转成C版本有点麻烦&#xff0c;因此尝试通过Qt与Matlab编程相结合的方式&#xff0c;实现该功能。 注意&#xff1a;以下所有功能及配置过程&#xff0c;默认已经…

CSDN竞赛21期题解

总结 &#xff08;PS&#xff1a;这次竞赛的奖励对我诱惑力感觉没多大&#xff0c;因为高级背包不久前才收到一个&#xff0c;邹老师的两本签名书也早就拿到了&#xff0c;程序员杂志、帆布包也都有了&#xff0c;扑克牌都拿了几副了&#xff0c;所以还是换点其他的书比较好&a…

c语言tips-【c语言内存模型】

0.摘要 C语言是比较接近底层的语言&#xff0c;因此它的很多知识点是和操作系统挂钩的&#xff0c;例如它的内存模型&#xff0c;其实也是操作系统进程的内存模型&#xff0c;本文章就是解释进程&#xff0c;虚拟内存空间&#xff0c;内存模型的相关知识和它们之间的联系 1. 虚…

热交换器及一维平行流换热器分析(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f468;‍&#x1f4bb;4 Matlab代码 &#x1f4a5;1 概述 首先试图对热交换器的设置进行建模&#xff0c;并获得该过程的控制方程。使用相应的控制方程并设置边界条件并获得适当的边界值…

RHCE第五天之NFS服务器详解

文章目录一、NFS服务器简介二、NFS的使用三、客户端使用autofs自动挂载四、实验练习一、NFS服务器简介 NFS&#xff08;Network File System&#xff0c;网络文件系统&#xff09;&#xff1a; 是FreeBSD支持的文件系统中的一种&#xff0c;它允许网络中的计算机&#xff08;不…

Qt / Qml 视频硬解码(CUDA)中如何实现无上传硬渲染(一)

【写在前面】 很多时候&#xff0c;我们在对视频的解码和渲染的处理都要经过以下步骤&#xff1a; 软解码&#xff0c;视频帧位于内存。 软渲染&#xff0c;需要拷贝到图像然后渲染&#xff1b;硬渲染则需要上传纹理&#xff0c;然后渲染。硬解码&#xff0c;视频帧位于显存。…

OPengl学习(四)——顶点数组

文章目录1、 问题2、步骤2.1 激活数组2.2 指定数组的数据2.3 解引用和渲染3、例子1、 问题 1、在前面我们实现一个多彩三角形&#xff0c;调用三次glvertext&#xff08;&#xff09;函数&#xff0c;如果在多边形&#xff0c;如20条边的&#xff0c;那么就要使用22次函数&…

【大数据之路】数据管理篇 《三》存储和成本管理 【搬运小结】

文章目录【大数据之路】数据管理篇 《三》存储和成本管理1.1数据压缩1.2存储治理项优化1.3生命周期管理1.3.1 生命周期管理策略1.3.2 生命周期管理策略1.4数据成本计量【大数据之路】数据管理篇 《三》存储和成本管理 1.1数据压缩 在分布式文件系统中&#xff0c;为了提高数据…