【深度学习】学习率预热和学习率衰减 (learning rate warmup decay)

news2024/11/19 21:22:07

背景

在深度学习中学习率这个超参数,在选取和调整都是有一定策略的,俗称炼丹。有时我们遇到 loss 变成 NaN 的情况大多数是由于学习率选择不当引起的。

神经网络在刚开始训练的时候模型的权重(weights)是随机初始化的,选择一个较大的学习率,可能带来模型的不稳定(振荡),因此刚训练时的学习率应当设置一个比较小的值,进而确保网络能够具有良好的收敛性。但较小的学习率会使得训练过程变得非常缓慢,于是learning rate warmup这个概念应运而生。

warmup采用以较低学习率逐渐增大至较高学习率的方式实现网络训练的“热身”阶段,随着训练的进行学习率慢慢变大,到一定程度后就可以设置的预设的学习率进行训练了,随着模型的拟合,需要的学习率也会越来越小,这时也会需要将学习率调小。学习率的warmup和学习率衰减可如下图走势:
在这里插入图片描述

warmup

Resnet论文中使用一个110层的ResNet在cifar10上训练时,先用0.01的学习率训练直到训练误差低于80%(大概训练了400个steps),然后使用0.1的学习率进行训练。其不足之处在于从一个很小的学习率一下变为比较大的学习率可能会导致训练误差突然增大。18年Facebook提出了gradual warmup来解决这个问题,即从最初的小学习率开始,每个step增大一点点,直到达到最初设置的比较大的学习率时,采用最初设置的学习率进行训练。有了这种思想后,也逐渐衍生了多种warmup方式。

在transformers有对应的warmup,如下:

  • get_constant_schedule_with_warmup,
  • get_cosine_schedule_with_warmup,
  • get_cosine_with_hard_restarts_schedule_with_warmup,
  • get_linear_schedule_with_warmup,
  • get_polynomial_decay_schedule_with_warmup

可通过from transformers import get_constant_schedule_with_warmup 方式导入。其中get_cosine_schedule_with_warmup在transformers官网中的学习率变化曲线如下:
在这里插入图片描述

pytorch中adjust learning rate

这部分内容参考了:pytorch之warm-up预热学习策略。pytorch官方文档:https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate也介绍了几种根据训练epoch调整学习率的方法:torch.optim.lr_scheduler,如果想自定义warmup和decay的话,可以使用torch.optim.lr_scheduler.LambdaLR,书写方法可以参考transformers中的写法如下:

def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

源码链接:https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/optimization.py#L50

常用的有基于余弦退火的lr_scheduler:lr_scheduler.CosineAnnealingLR,

torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)

以初始学习率为最大学习率,以 2 ∗ T _ m a x 2 ∗ T\_ {m a x} 2T_max 为周期,在一个周期内先下降,后上升。其中:

  • T_max(int):学习率下降到最小值时的epoch数,即当epoch=T_max时,学习率下降到余弦函数最小值,当epoch>T_max时,学习率将增大
  • eta_min(float):学习率的最小值,即在一个周期中,学习率最小会下降到 eta_min,默认值为 0
  • 上一个epoch数,这个变量用来指示学习率是否需要调整。当last_epoch符合设定的间隔时,就会对学习率进行调整;当为-1时,学习率设置为初始值。

更多的lr_schedule 可参考pytorch官方文档。

总结

当然,这种使用warmup和decay的learning rate schedule大多是在bert这种预训练的大模型的微调应用中遇见的。如果是做自然语言处理相关任务的,transformers已经封装了好几个带有warmup 和 decay的lr schedule。如果不是做研究的话,这些已经封装的lr schedule直接拿来用即可。当然也可以使用pytorch中的相关模块自定义。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/77185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

干货 | 浅谈携程大住宿研发效能提升实践

作者简介Mia &#xff0c;携程高级项目经理&#xff0c;负责酒店Devops实践&#xff0c;关注Devops/敏捷等领域。YY&#xff0c;携程敏捷教练&#xff0c;负责团队敏捷转型&#xff0c;研发效能提升实践&#xff0c;关注Agile、Devops、研发效能等领域。一、前言管理大师彼得德…

[ChatGPT为你支招]如何提高博客的质量,找到写作方向,保持动力,增加粉丝数?

0. 引言 作为一个博主&#xff0c;您可能会面临很多挑战&#xff0c;比如如何提高博客的质量&#xff0c;如何找到自己的写作方向&#xff0c;如何保持持续写作的动力&#xff0c;以及如何增加博客的粉丝数量。在这篇文章中&#xff0c;我们将为您提供一些有用的建议&#xff…

Maven打包报错:找不到符号,类BASE64Encoder,程序包sun.misc

背景 一个基于若依单体架构的多模块 Maven 项目的国产化迁移适配&#xff0c;由于是客户的代码&#xff0c;我们不用关心具体的功能实现&#xff0c;直接来做迁移即可。实施时&#xff0c;按照我们总结的整改建议调整源码&#xff0c;具体迁移适配过程可参考本专栏的其他文章。…

ADI Blackfin DSP处理器-BF533的开发详解19:LAN的网口设计

硬件准备 ADSP-EDU-BF533&#xff1a;BF533开发板 AD-HP530ICE&#xff1a;ADI DSP仿真器 软件准备 Visual DSP软件 硬件链接 功能介绍 BF533说实话用来做LAN的应用有些许勉强&#xff0c;因为他自己不带网口&#xff0c;要做的话&#xff0c;需要在总线上挂&#xff0c;那…

彻底搞懂ESLint与Prettier在vscode中的代码自动格式化

前言 前端代码格式化社区提供了两种比较常用的工具ESLint和Prettier&#xff0c;他们分别提供了对应的vscode的插件&#xff0c;二者在代码格式化方面有重叠的部分&#xff0c;规则不一致时会导致冲突。vscode作为前端开发编辑器已经越来越普遍了&#xff0c;这需要开发者在vs…

ChatGPT与搜索引擎合体,谷歌都不香了,LeCun转发|在线可玩

Alex Pine 发自 凹非寺量子位 | 公众号 QbitAI见惯了列表式搜索引擎&#xff0c;你有没有想过给它换种画风&#xff1f;有人脑洞大开&#xff0c;把艳惊四座的ChatGPT和必应搜索结合起来&#xff0c;搞出了一个智能搜索引擎&#xff1a;既有ChatGPT式的问答&#xff0c;又像普通…

VS——路径说明

$(TargetDir)输出目标的路径 $(TargetPath) 输出文件.exe的路径 $(TargetName) 项目名字 $(TargetFileName) 输出的.exe的名字 $(TargetExt) 文件的扩展名 $(ProjectDir)工程目录 目录根据下面的文件来的 $(IntDir)获取中间文件 $(OutDir)文件输出路径 $(Solu…

神马操作Kafka 竟宣布弃用 Java 8

Kafka 3.0.0 发布了&#xff1a; ​ 编辑切换为居中 添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09; 主要更新如下&#xff1a; The deprecation of support for Java 8 and Scala 2.12 Kafka Raft support for snapshots of the metadata topic and ot…

ADI Blackfin DSP处理器-BF533的开发详解15:RS232串口的实现(含源代码)

硬件准备 ADSP-EDU-BF533&#xff1a;BF533开发板 AD-HP530ICE&#xff1a;ADI DSP仿真器 软件准备 Visual DSP软件 硬件链接 实现原理 ADSP-EDU-BF533 开发板上设计了一个 RS232 接口&#xff0c;该接口通过 ADSP-BF53x 上的 UART 接口&#xff0c;扩展 RS232协议的芯片实…

Java学习之Object类——equals方法

Object 类是类层次结构的根类。每个类都使用 Object 作为超类。所有对象&#xff08;包括数组&#xff09;都实现这个类的方法&#xff0c;学习Object 类的六个方法——equals(Object obj)、finalize、toString、hashCode、getClass、clone 目录 和equals的对比 一、的作用 …

ChatGPT惊人语录大赏

文 | 智商掉了一地这几天ChatGPT实在太火了&#xff0c;笔者的朋友圈已经被ChatGPT的各种金句刷屏了&#xff0c;实在忍不住整理下来&#xff0c;分享给大家。ChatGPT惊人语录1&#xff1a;建议娶奶奶为妻注&#xff1a;贾母是贾宝玉的奶奶ChatGPT惊人语录2&#xff1a;角色扮演…

【allegro 17.4软件操作保姆级教程十】文件输出

目录 1.1添加光绘层叠 1.1.1添加线路层 1.1.2添加表底阻焊层 1.1.3添加表底钢网层 1.1.4添加表底丝印层 1.1.5添加钻孔层 ​1.2输出文件 1.2.1输出光绘文件 1.2.2输出钻孔文件 1.2.3输出坐标文件 1.2.4输出文件打包 1.1添加光绘层叠 在输出文件之前需要先添加光绘层…

PyTorch中学习率调度器可视化介绍

神经网络有许多影响模型性能的超参数。一个最基本的超参数是学习率(LR)&#xff0c;它决定了在训练步骤之间模型权重的变化程度。在最简单的情况下&#xff0c;LR值是0到1之间的固定值。 选择正确的LR值是具有挑战性。一方面较大的学习率有助于算法快速收敛&#xff0c;但它也…

【车载开发系列】UDS诊断---输入输出控制($0x2F)

【车载开发系列】UDS诊断—输入输出控制&#xff08;$0x2F&#xff09; UDS诊断---输入输出控制&#xff08;$0x2F&#xff09;【车载开发系列】UDS诊断---输入输出控制&#xff08;$0x2F&#xff09;一.概念定义1&#xff09;与0x31例程控制服务的关系2&#xff09;与0x22读取…

数据传送类指令(PUSH,POP,LEA)

目录 数据传送类指令 堆栈的概念&#xff1a; 进栈指令 &#xff08;PUSH&#xff09; 出栈指令&#xff08;POP&#xff09; 练习 LEA取偏移地址&#xff08;有效地址EA&#xff09;指令&#xff08;去括号&#xff09; LEA和OFFSET区别&#xff1a; 用法注意 LEA和MO…

微服务框架 SpringCloud微服务架构 微服务保护 31 限流规则 31.6 热点参数限流

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式&#xff0c;系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 微服务保护 文章目录微服务框架微服务保护31 限流规则31.6 热点参数限流31.6.1 热点参数限流31.6.2 案例31 限流规则 31.6 热点参数限流 3…

【代码审计-.NET】基于.NET框架开发的基本特征

目录 一、.NET基本架构 1、基本构成 2、可支持语言 3、封装 4、文件 5、指向解析 6、安全认证 二、工具 1、ILSpyi 2、dnSpy 3、Reflector &#xff08;网上找的一张图谱&#xff09; 本博客只面向讲安全相关内容 一、.NET基本架构 1、基本构成 可支持语言&#xf…

web shell控制目标

文章目录一、封神台五1、为什么提权2、如何寻找exp3、使用exp提权一、封神台五 1、为什么提权 进入目标机器后权限可能不够导致无法执行高权限操作 右键地址进入终端 发现没有操作权限 提权原理&#xff1a;借助高权限的进程执行我们的指令 2、如何寻找exp 什么是exp&a…

【MOSMA】基于粘菌算法求解多目标优化问题附matlab代码

✅作者简介&#xff1a;热爱科研的Matlab仿真开发者&#xff0c;修心和技术同步精进&#xff0c;matlab项目合作可私信。 &#x1f34e;个人主页&#xff1a;Matlab科研工作室 &#x1f34a;个人信条&#xff1a;格物致知。 更多Matlab仿真内容点击&#x1f447; 智能优化算法 …

【Java版oj】day02倒置字符串

目录 一、原题再现 二、问题分析 三、完整代码 一、原题再现 倒置字符串_牛客题霸_牛客网 描述 将一句话的单词进行倒置&#xff0c;标点不倒置。比如 I like beijing. 经过函数后变为&#xff1a;beijing. like I 输入描述&#xff1a; 每个测试输入包含1个测试用例&#x…