手撕深度学习中的优化器

news2025/1/17 22:55:59

深度学习中的优化算法采用的原理是梯度下降法,选取适当的初值params,不断迭代,进行目标函数的极小化,直到收敛。由于负梯度方向时使函数值下降最快的方向,在迭代的每一步,以负梯度方向更新params的值,从而达到减少函数值的目的。

Gradient descent in deep learning

在这里插入图片描述

Optimizer

class Optimizer:
    """
    优化器基类,默认是L2正则化
    """

    def __init__(self, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay

    def step(self, grads, params):
        # 计算当前时刻下降的步长
        decrement = self.compute_step(grads)
        if self.weight_decay:
            decrement += self.weight_decay * params
        # 更新参数
        params -= decrement

    def compute_step(self, grads):
        raise NotImplementedError

SGD

随机梯度下降
θ t = θ − η ⋅ g t \theta_t = \theta-\eta \cdot g_t θt=θηgt

  • 每次随机抽取一个batch的样本进行梯度下降

  • 对学习率敏感,太小收敛速度很慢,太大会在极小值附近震荡

  • 对于非凸函数,容易陷入局部最小值或鞍点

class SGD(Optimizer):
    """
    stochastic gradient descent
    """

    def __init__(self, lr=0.1, weight_decay=0.0):
        super().__init__(lr, weight_decay)

    def compute_step(self, grads):
        return self.lr * grads

SGDm

SGD中加入动量(momentum)模拟是物体运动时的惯性,即更新的时候在一定程度上保留之前更新的方向,同时利用当前batch的梯度微调最终的更新方向。这样一来,可以在一定程度上增加稳定性,从而学习地更快,并且还有一定摆脱局部最优的能力。
υ t = γ υ t − 1 + g t θ t = θ t − 1 − η υ t \upsilon_t = \gamma \upsilon_{t-1} + g_t \qquad \theta_t=\theta_{t-1} - \eta \upsilon_t υt=γυt1+gtθt=θt1ηυt

  • gt是当前时刻的梯度,vt是当前时刻参数的下降距离
  • 带动量的小球滚下山坡,可能会错过山谷
class SGDm(Optimizer):
    """
    stochastic gradient descent with momentum
    """

    def __init__(self, lr=0.1, momentum=0.9, weight_decay=0.0):
        super().__init__(lr, weight_decay)
        self.momentum = momentum
        self.beta = 0

    def compute_step(self, grads):
        self.beta = self.momentum * self.beta + (1 - self.momentum) * grads
        return self.lr * self.beta

Adagrad

θ t = θ t − 1 − η ∑ i = 0 t − 1 ( g i ) 2 g t − 1 \theta_t=\theta_{t-1} - \frac{\eta}{\sqrt{\sum^{t-1}_{i=0}{(g_i)^2}}}g_{t-1} θt=θt1i=0t1(gi)2 ηgt1

  • 自适应调节学习率
  • 对低频的参数做较大的更新,对高频的做较小的更新,也因此,对于稀疏的数据它的表现很好,很好地提高了 SGD 的鲁棒性
  • 缺点是分母梯度的累积,最后梯度消失
class Adagrad(Optimizer):
    """
    Divide the learning rate of each parameter by the
    root-mean-square of its previous derivatives
    """

    def __init__(self, lr=0.1, eps=1e-8, weight_decay=0.0):
        super().__init__(lr, weight_decay)
        self.eps = eps
        self.state_sum = 0

    def compute_step(self, grads):
        self.state_sum += grads ** 2
        decrement = grads / (self.state_sum ** 0.5 + self.eps) * self.lr
        return decrement

RMSProp

指数滑动平均更新梯度的平方,为解决Adagrad 梯度急剧下降而提出
υ 1 = g 0 2 υ t = α υ t − 1 + ( 1 − α ) ( g t − 1 ) 2 \upsilon_1 = g_0^2 \qquad \upsilon_t = \alpha\upsilon_{t-1} + (1-\alpha)(g_{t-1})^2 υ1=g02υt=αυt1+(1α)(gt1)2

θ t = θ t − 1 − η υ t g t − 1 \theta_t=\theta_{t-1} - \frac{\eta}{\sqrt{\upsilon_t}} g_{t-1} θt=θt1υt ηgt1

class RMSProp(Optimizer):
    """
    Root Mean Square Prop optimizer
    """

    def __init__(self, lr=0.1, alhpa=0.99, eps=1e-8, weight_decay=0.0):
        super().__init__(lr, weight_decay)
        self.eps = eps
        self.alpha = alhpa
        self.state_sum = 0

    def compute_step(self, grads):
        self.state_sum = self.alpha * self.state_sum + (1 - self.alpha) * grads ** 2
        decrement = grads / (self.state_sum ** 0.5 + self.eps) * self.lr
        return decrement

Adam

SGDmRMSProp的结合,Adam 算法通过计算梯度的一阶矩估计和二阶矩估计而为不同的参数设计独立的自适应性学习率。

  • SGDm

θ t = θ t − 1 − m t m t = β 1 m t − 1 + ( 1 − β 1 ) g t − 1 \theta_t=\theta_{t-1} - m_t \qquad m_t = \beta_1 m_{t-1} + (1-\beta_1)g_{t-1} θt=θt1mtmt=β1mt1+(1β1)gt1

  • RMSProp

θ t = θ t − 1 − η υ t g t − 1 \theta_t=\theta_{t-1} - \frac{\eta}{\sqrt{\upsilon_t}} g_{t-1} θt=θt1υt ηgt1

υ 1 = g 0 2 υ t = β 2 υ t − 1 + ( 1 − β 2 ) ( g t − 1 ) 2 \upsilon_1 = g_0^2 \qquad \upsilon_t = \beta_2\upsilon_{t-1} + (1-\beta_2)(g_{t-1})^2 υ1=g02υt=β2υt1+(1β2)(gt1)2

  • Adam

θ t = θ t − 1 − η υ t ′ + ε m t ′ \theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{\upsilon_t'+\varepsilon}} m_t' θt=θt1υt+ε ηmt

m t ′ = m t 1 − β 1 t v t ′ = v t 1 − β 2 t β 1 = 0.9 β 2 = 0.999 m_t' = \frac{m_t}{1-\beta_1^t} \qquad v_t' = \frac{v_t}{1-\beta_2^t} \qquad \beta_1=0.9 \quad \beta_2=0.999 mt=1β1tmtvt=1β2tvtβ1=0.9β2=0.999

class Adam(Optimizer):
    """
    combination of SGDm and RMSProp
    """

    def __init__(self, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        super().__init__(lr, weight_decay)
        self.eps = eps
        self.beta1, self.beta2 = betas
        self.mt = self.vt = 0
        self._t = 0

    def compute_step(self, grads):
        self._t += 1
        self.mt = self.beta1 * self.mt + (1 - self.beta1) * grads
        self.vt = self.beta2 * self.vt + (1 - self.beta2) * (grads ** 2)
        mt = self.mt / (1 - self.beta1 ** self._t)
        vt = self.vt / (1 - self.beta2 ** self._t)

        decrement = mt / (vt ** 0.5 + self.eps) * self.lr
        return decrement

我平时做视觉任务主要用SGDm和Adam两个优化器,感觉带正则化的SGDm的效果非常好,然后调一下学习率和衰减策略


参考资料:

torch.optim — PyTorch documentation
tinynn: A lightweight deep learning library

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/428194.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

web前端笔记

HTML(安装live server插件) 我们上网时所看到的所有内容其实就是body里面的内容! ! tab 快速生成一个html模板; https://www.runoob.com/tags/html-elementsdoctypes.html html的菜鸟教程! html中的元素分为两种,一种…

广州蓝景分享—13个Web开发人员都知道的基本JavaScript函数

各位编程爱好者,今天由小蓝与大家分享13个基本的JavaScript 函数,如果您是 Web前端开发人员,您应该熟悉这些函数。 您可以将本文所有 JavaScript 函数加入收藏至您的工具箱,以便在您的软件项目中尽可能使用这些片段。 1. 检索任…

恐怖的低代码平台,我 All in 了!

本文目录一、低代码平台是什么?二、目前低代码产品平台是如何分类的?三、低代码平台是怎么互相比较的?一个比喻就明白了!四、iVX平台的恐怖优势!我 All in 了!五、iVX的学习成本?总结&#xff1…

百度CTO王海峰做客《中国经济大讲堂》:文心一言,读书破万亿

当下,大语言模型热度空前,诸如文心一言、ChatGPT 等已经能够与人对话互动、回答问题、协助创作,逐渐应用于人们的工作和生活,也引发了社会热议。近日,百度首席技术官、深度学习技术及应用国家工程研究中心主任王海峰再…

asp.net Core 6 从空建立一个MVC项目,Razor组件使用

Razor组件使用MVC项目创建创建空的Web项目添加MVC框架Razor组件使用准备封装Razor组件MVC项目创建 创建一个空的项目,然后添加MVC。 创建空的Web项目 添加MVC框架 1.添加文件夹 2.添加控制器 3.添加界面 4.修改program.cs文件内容 //原生的 //var builder …

python入门:cl.exe‘ failed with exit status 2错误通用解决方案

文章目录 错误一错误二pypi.org独立安装正确安装错误一 error: Microsoft Visual C++ 14.0 or greater is required. Get it with "Microsoft C++ Build Tools": https://visualstudio.microsoft.com/visual-cpp-build-tools/ 这个错误在windows系统上安装python工…

用64位的plsql developer 连接虚拟机中的64位oracle数据库

背景:为了学习oracle,我在虚拟机上安装了oracle。并在实体机上安装了oracle客户端及plsql developer。 开始之前,先回答两个问题 为什么不在本机安装oracle? 因为oracle比较消耗资源,而我不会一直用,所以放到虚拟机里…

使用VMware虚拟机创建Ubuntu的linux系统,用Xshell连接这个系统,VScode作为编辑器时遇到的问题

使用VMware虚拟机创建Ubuntu的linux系统,用Xshell连接这个系统,VScode作为编辑器时遇到的问题1.软件2.Xshell和Xftp软件的使用3.VScode中安装了Remote Development扩展之后,点击远程资源管理器,下拉框里没有SSH-Targets4.将VScode…

Coremail AI技术发展前生今世

2023年3月15日凌晨,OpenAI发布大型多模态模型GPT-4,正式宣告AI迈入新的“黄金时代”。作为邮件安全厂商,Coremail不禁思索,在当今科技高速发展的节点上,如何将此类大型多模态模型落地至具体的邮件安全防护?…

PostgreSQL 系统表相关技术栈 实现原理(系统表初始化关系模型,SysCache RelCache)

文章目录前言基本介绍OIDpg_classpg_typepg_attribute系统表关系初始化编译阶段Initdb 阶段系统表的访问SysCache初始化 & 基本结构查找 & 插入 & 扩容RelCache初始化pg_filenode.mappg_internal.init初始化完整步骤dynahash 可扩展hash表extendible hashextendibl…

基于国产 FPGA + DSP+1553B总线 的大气数据测量装置的设计与实现

大气数据可供飞行器的控制管理系统使用,为飞行器提供飞行指导,因此实时精准 地获取大气数据在飞行器飞行过程中至关重要。本文设计并实现了一种基于 FPGA 和 DSP 的大气数据测量装置。测量装置包含五个压力传感器及两个温度传感器,可实时获取…

【springcloud 微服务】Spring Cloud Alibaba整合Sentinel详解

目录 一、前言 二、环境准备 2.1 部署sentinel管控台 2.1.1 官网下载sentinel的jar包 2.1.2 启动控制台 2.1.3 访问控制台 2.2 整合springcloud-alibaba 2.2.1 引入相关依赖 2.2.2 修改配置文件 2.2.3 增加一个测试接口 2.2.4 接口测试 三、sentinel 流控规则使用 …

基于HTML5/WebGL智慧楼宇三维可视化云平台

随着“双碳”目标政策的逐步推进,楼宇建筑作为连接人与空间的关键节点,节能潜力愈加凸显,行业热度与日俱增。如今,智慧楼宇已成群雄逐鹿的蓝海,在建筑信息化的浪潮之下,一场跨行业、跨品牌、跨领域的智慧建…

HTML5庆祝生日蛋糕烟花特效

HTML5庆祝生日蛋糕烟花特效 <!DOCTYPE html> <html> <head><meta charset"UTF-8"><title>HTML5 Birthday Cake Fireworks</title><style>canvas {position: absolute;top: 0;left: 0;z-index: -1;}</style> </h…

Kafka的概念|架构|搭建|查看命令

Kafka的概念|架构|搭建|查看命令一 Kafka 概述二 使用消息队列的好处三Kafka 定义3.1Kafka 简介3.2Kafka 的特性3.3 Kafka 系统架构3.4 Partation 数据路由规则四 kafka的架构五 搭建kafka5.1环境准备5.2安装kafka5.3 修改配置文件5.4 编辑其他二台虚拟机的配置文件5.5 编辑三台…

数据结构之第八章、二叉树

目录 一、树型结构&#xff08;了解&#xff09; 1.1概念 1.2专业术语&#xff08;重要&#xff09; 1.3树的表示形式&#xff08;了解&#xff09; ​编辑 1.4树的应用 二、二叉树&#xff08;重点&#xff09; 2.1概念 2.2两种特殊的二叉树 2.3二叉树的性质 2.4…

内、外连接查询-MySQL数据库 (头歌实践平台)

文章目的初衷是希望学习笔记分享给更多的伙伴&#xff0c;并无盈利目的&#xff0c;尊重版权&#xff0c;如有侵犯&#xff0c;请官方工作人员联系博主谢谢。 目录 第1关&#xff1a;内连接查询 任务描述 相关知识 内连接查询 编程要求 测试说明 第2关&#xff1a;外连接…

阿里云计算巢产品负责人何川:计算巢,通过数字化工具加速企业数字原生

让数字原生的中小企业用好云&#xff0c;基于云提高研发效率、构建敏捷组织、快速扩展业务&#xff0c;提高中小企业的发展韧性。在阿里云云峰会 2023 北京站的《数字原生企业创新论坛》中&#xff0c;阿里云智能计算巢产品负责人何川发表了《阿里云计算巢通过数字化工具加速企…

数据结构之第七章、队列(Queue)

目录 一、概念 二、队列 2.1队列的概念、 2.1单链表模拟实现队列 2.2双链表模拟实现队列 2.3队列的使用 2.4循环队列 2.4.1设计环形队列 三、双端队列 四、面试题 4.1用队列实现栈 4.2栈实现队列 一、概念 队列&#xff1a;只允许在一端进行插入数据操作&#xff0…

多功能财务项目管理

使用Zoho Projects的多功能财务项目管理软件改进流程并提供更好的结果。 一、使用Zoho Projects使财务项目管理更加清晰 了解为什么世界各地的财务团队都求助于Zoho Projects以获得强大且透明的财务项目管理软件。 1、跟踪每个数字 Zoho Projects的财务项目管理软件允许团队成…