【深度学习】学习率及多种选择策略

news2024/11/15 9:43:20

学习率是最影响性能的超参数之一,如果我们只能调整一个超参数,那么最好的选择就是它。相比于其它超参数学习率以一种更加复杂的方式控制着模型的有效容量,当学习率最优时,模型的有效容量最大。本文从手动选择学习率到使用预热机制介绍了很多学习率的选择策略。

这篇文章记录了我对以下问题的理解:

  • 学习速率是什么?学习速率有什么意义?
  • 如何系统地获得良好的学习速率?
  • 我们为什么要在训练过程中改变学习速率?
  • 当使用预训练模型时,我们该如何解决学习速率的问题?

本文的大部分内容都是以 fast.ai 研究员写的内容 [1], [2], [5] 和 [3] 为基础的。本文是一个更为简洁的版本,通过本文可以快速获取这些文章的主要内容。如果您想了解更多详情,请参阅参考资料。

首先,什么是学习速率?

学习速率是指导我们该如何通过损失函数的梯度调整网络权重的超参数。学习率越低,损失函数的变化速度就越慢。虽然使用低学习率可以确保我们不会错过任何局部极小值,但也意味着我们将花费更长的时间来进行收敛,特别是在被困在高原区域的情况下。

下述公式表示了上面所说的这种关系。

  1. new_weight = existing_weight — learning_rate * gradient

理解深度学习中的学习率及多种选择策略

采用小学习速率(顶部)和大学习速率(底部)的梯度下降。来源:Coursera 上吴恩达(Andrew Ng)的机器学习课程。

一般而言,用户可以利用过去的经验(或其他类型的学习资料)直观地设定学习率的最佳值。

因此,想得到最佳学习速率是很难做到的。下图演示了配置学习速率时可能遇到的不同情况。

理解深度学习中的学习率及多种选择策略

不同学习速率对收敛的影响(图片来源:cs231n)

此外,学习速率对模型收敛到局部极小值(也就是达到最好的精度)的速度也是有影响的。因此,从正确的方向做出正确的选择意味着我们可以用更短的时间来训练模型。

  
  
  1. Less training time, lesser money spent on GPU cloud compute. 😃

有更好的方法选择学习速率吗?

在「训练神经网络的周期性学习速率」[4] 的 3.3 节中,Leslie N. Smith 认为,用户可以以非常低的学习率开始训练模型,在每一次迭代过程中逐渐提高学习率(线性提高或是指数提高都可以),用户可以用这种方法估计出最佳学习率。

理解深度学习中的学习率及多种选择策略

在每一个 mini-batch 后提升学习率

如果我们对每次迭代的学习进行记录,并绘制学习率(对数尺度)与损失,我们会看到,随着学习率的提高,从某个点开始损失会停止下降并开始提高。在实践中,学习速率的理想情况应该是从图的左边到最低点(如下图所示)。在本例中,是从 0.001 到 0.01。

理解深度学习中的学习率及多种选择策略

上述方法看似有用,但该如何应用呢?

目前,上述方法在 fast.ai 包中作为一个函数进行使用。fast.ai 包是由 Jeremy Howard 开发的一种高级 pytorch 包(就像 Keras 之于 Tensorflow)。

在训练神经网络之前,只需输入以下命令即可开始找到最佳学习速率。

  
  
  1. # learn is an instance of Learner class or one of derived classes like ConvLearner
  2. learn.lr_find()
  3. learn.sched.plot_lr()

使之更好

现在我们已经知道了什么是学习速率,那么当我们开始训练模型时,怎样才能系统地得到最理想的值呢。接下来,我们将介绍如何利用学习率来改善模型的性能。

传统的方法

一般而言,当已经设定好学习速率并训练模型时,只有等学习速率随着时间的推移而下降,模型才能最终收敛。

然而,随着梯度达到高原,训练损失会更难得到改善。在 [3] 中,Dauphin 等人认为,减少损失的难度来自鞍点,而不是局部最低点。

理解深度学习中的学习率及多种选择策略

误差曲面中的鞍点。鞍点是函数上的导数为零但不是轴上局部极值的点。(图片来源:safaribooksonline)

所以我们该如何解决这个问题?

我们可以采取几种办法。[1] 中是这么说的:

…无需使用固定的学习速率,并随着时间的推移而令它下降。如果训练不会改善损失,我们可根据一些周期函数 f 来改变每次迭代的学习速率。每个 Epoch 的迭代次数都是固定的。这种方法让学习速率在合理的边界值之间周期变化。这是有益的,因为如果我们卡在鞍点上,提高学习速率可以更快地穿越鞍点。

在 [2] 中,Leslie 提出了一种「三角」方法,这种方法可以在每次迭代之后重新开始调整学习速率。

理解深度学习中的学习率及多种选择策略

Leslie N. Smith 提出的「Triangular」和「Triangular2」学习率周期变化的方法。左图中,LR 的最小值和最大值保持不变。右图中,每个周期之后 LR 最小值和最大值之间的差减半。

另一种常用的方法是由 Loshchilov&Hutter [6] 提出的预热重启(Warm Restarts)随机梯度下降。这种方法使用余弦函数作为周期函数,并在每个周期最大值时重新开始学习速率。「预热」是因为学习率重新开始时并不是从头开始的,而是由模型在最后一步收敛的参数决定的 [7]。

下图展示了伴随这种变化的过程,该过程将每个周期设置为相同的时间段。

理解深度学习中的学习率及多种选择策略

SGDR 图,学习率 vs 迭代次数。

因此,我们现在可以通过周期性跳过「山脉」的办法缩短训练时间(下图)。

理解深度学习中的学习率及多种选择策略

比较固定 LR 和周期 LR(图片来自 ruder.io)

研究表明,使用这些方法除了可以节省时间外,还可以在不调整的情况下提高分类准确性,而且可以减少迭代次数。

迁移学习中的学习速率

在 fast.ai 课程中,非常重视利用预训练模型解决 AI 问题。例如,在解决图像分类问题时,会教授学生如何使用 VGG 或 Resnet50 等预训练模型,并将其连接到想要预测的图像数据集。

我们采取下面的几个步骤,总结了 fast.ai 是如何完成模型构建(该程序不要与 fast.ai 包混淆)的:

1. 启用数据增强,precompute = True

2. 使用 lr_find() 找到损失仍在降低的最高学习速率

3. 从预计算激活值到最后一层训练 1~2 个 Epoch

4. 在 cycle_len = 1 的情况下使用数据增强(precompute=False)训练最后一层 2~3 次

5. 修改所有层为可训练状态

6. 将前面层的学习率设置得比下一个较高层低 3~10 倍

7. 再次使用 lr_find()

8. 在 cycle_mult=2 的情况下训练整个网络,直到过度拟合

从上面的步骤中,我们注意到步骤 2、5 和 7 提到了学习速率。这篇文章的前半部分已经基本涵盖了上述步骤中的第 2 项——如何在训练模型之前得出最佳学习率。

在下文中,我们会通过 SGDR 来了解如何通过重启学习速率来减少训练时间和提高准确性,以避免梯度接近零。

在最后一节中,我们将重点介绍差异学习(differential learning),以及如何在训练带有预训练模型中应用差异学习确定学习速率。

什么是差异学习

差异学习(different learning)在训练期间为网络中的不同层设置不同的学习速率。这种方法与人们常用的学习速率配置方法相反,常用的方法是训练时在整个网络中使用相同的学习速率。

理解深度学习中的学习率及多种选择策略

在写这篇文章的时候,Jeremy 和 Sebastian Ruder 发表的一篇论文深入探讨了这个问题。所以我估计差异学习速率现在有一个新的名字——差别性的精调。😃

为了更清楚地说明这个概念,我们可以参考下面的图。在下图中将一个预训练模型分成 3 组,每个组的学习速率都是逐渐增加的。

理解深度学习中的学习率及多种选择策略

具有差异学习速率的简单 CNN 模型。图片来自 [3]

这种方法的意义在于,前几个层通常会包含非常细微的数据细节,比如线和边,我们一般不希望改变这些细节并想保留它的信息。因此,无需大量改变权重。

相比之下,在后面的层,以绿色以上的层为例,我们可以从中获得眼球、嘴巴或鼻子等数据的细节特征,但我们可能不需要保留它们。

这种方法与其他微调方法相比如何?

在 [9] 中提出,微调整个模型太过昂贵,因为有些模型可能超过了 100 层。因此人们通常一次一层地对模型进行微调。

然而,这样的调整对顺序有要求,不具并行性,且因为需要通过数据集进行微调,导致模型会在小数据集上过拟合。

下表证明 [9] 中引入的方法能够在各种 NLP 分类任务中提高准确度且降低错误率。

理解深度学习中的学习率及多种选择策略

参考文献:

[1] Improving the way we work with learning rate.

[2] The Cyclical Learning Rate technique.

[3] Transfer Learning using differential learning rates.

[4] Leslie N. Smith. Cyclical Learning Rates for Training Neural Networks.

[5] Estimating an Optimal Learning Rate for a Deep Neural Network

[6] Stochastic Gradient Descent with Warm Restarts

[7] Optimization for Deep Learning Highlights in 2017

[8] Lesson 1 Notebook, fast.ai Part 1 V2

[9] Fine-tuned Language Models for Text Classification

原文链接:https://towardsdatascience.com/understanding-learning-rates-and-how-it-improves-performance-in-deep-learning-d0d4059c1c10

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1245405.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

插入排序(形象类比)

最近在看riscv手册的时候,里面有一段代码是插入排序,但是单看代码的时候有点迷,没看懂咋操作的,后来又查资料复习了一下,最终才把代码看明白,所以写篇博客记录一下。 插入排序像打扑克牌 这是我听到过比较形…

RubyMine 2023:提升Rails/Ruby开发效率的强大利器

在Rails/Ruby开发领域,JetBrains RubyMine一直以其强大的功能和优秀的性能而备受开发者的青睐。现如今,我们迎来了全新的RubyMine 2023版本,它将为开发者们带来更高效的开发体验和无可比拟的工具支持。 首先,RubyMine 2023提供了…

IDEA安装教程

文章目录 1 下载IntelliJ IDEA2 安装3 IDEA配置4 创建项目 1 下载IntelliJ IDEA ​ 官方网站上下载最新版本的IntelliJ IDEA。官方网站提供了两个版本:Community版和Ultimate版。 Community版是免费的,适用于个人和非商业用途。Ultimate版则需要付费购…

ESP32之避障

ESP32之避障 图片 程序 int Led27;//定义LED 接口 int buttonpin4; //定义光遮断传感器接口 int val;//定义数字变量val void setup() { pinMode(Led,OUTPUT);//定义LED 为输出接口 pinMode(buttonpin,INPUT);//定义避障传感器为输出接口 } void loop() {Serial.begin(9600);…

JVMj之console Java监视与管理控制台

jconsole Java监视与管理控制台 1、jconsole介绍 jconsole (java monitoring and management console)是一款基于JMX (Java Management Extensions) 的可视化监视和管理工具。 2、启动jconsole 1、在linux和windwos下通过jconsole启动即可。 2、然后会自动搜索本机运行的…

【栈】不同字符的最小子序列

题目: /*** 思路:栈,使用数组记录每个字母出现的次数,再用一个数组标记字符是否在栈中* 遍历栈,存储字符时比较栈顶字符,若小于栈顶字符并且后面有重复的字符则* 栈顶元素出栈,否则入栈。** au…

超级利器!Postman自动化接口测试让你提升测试效率,节省宝贵时间!

Postman自动化接口测试 该篇文章针对已经掌握 Postman 基本用法的读者,即对接口相关概念有一定了解、已经会使用 Postman 进行模拟请求的操作。 当前环境: Window 7 - 64 Postman 版本(免费版):Chrome App v5.5.3 …

数字乡村:科技赋能农村产业升级

数字乡村:科技赋能农村产业升级 数字乡村是指通过信息技术和数字化手段,推动农业现代化、农村经济发展和农民增收的一种新模式。近年来,随着互联网技术的飞速发展,数字乡村开始在全国范围内迅速兴起,为乡村经济注入了新…

CVE-2022-0543(Redis 沙盒逃逸漏洞)

简介 CVE-2022-0543是一个与Redis相关的安全漏洞。在Redis中,用户连接后可以通过eval命令执行Lua脚本,但在沙箱环境中脚本无法执行命令或读取文件。然而,攻击者可以利用Lua沙箱中遗留的变量package的loadlib函数来加载动态链接库liblua5.1.s…

tcp/ip协议2实现的插图,数据结构2 (19 - 章)

(68) 68 十九1 选路请求与消息 函rtalloc,rtalloc1,rtfree (69)

解决mv3版本浏览器插件,不能注入js脚本问题

文章目录 背景引入ifream解决ifream和父页面完全跨域问题参考链接 背景 浏览器插件升级mv3版本后,不能再使用content_script内容脚本向原浏览器(top)注入script标签达到注入脚本的目的。浏览器认为插入未经审核的脚本是不安全的行为。 引入…

从0开始学习JavaScript--JavaScript元编程

JavaScript作为一门灵活的动态语言,具备强大的元编程能力。元编程是一种通过操作程序自身结构的编程方式,使得程序能够在运行时动态地创建、修改、查询自身的结构和行为。本文将深入探讨JavaScript中元编程的各个方面,包括原型、反射、代理等…

揭秘周杰伦《最伟大的作品》MV,绝美UI配色方案竟然藏在这里

色彩在UI设计的基本框架中占据着举足轻重的位置。实际上,精心挑选和组合的色彩配色,往往就是UI设计成功的不二法门。在打造出一个实用的UI配色方案过程中,我们需要有坚实的色彩理论知识,同时还需要擅于从生活中观察和提取灵感。以…

MySQL索引事务基础

目录 1. 索引 1.1索引的概念 1.2索引的特点 1.3 索引的使用场景 1.4索引的使用 1.4.1查看索引 1.4.2创建索引 1.4.3删除索引 1.5索引保存的数据结构 2.事务 2.1经典例子 2.2事务的概念 2.3事务的使用 2.4事务的4个核心特性 2.5事务的并发问题 2.5.1脏读 2.5.2不可…

Python Opencv实践 - 全景图片拼接stitcher

做一个全景图片切片的程序Spliter 由于手里没有切割好的全景图片资源,因此首先写了一个切片的程序spliter。 如果有现成的切割好的待拼接的切片文件,则不需要使用spliter。 对于全景图片的拼接,需要注意一点,各个切片图片之间要有…

Linux之实现简易的shell

1.打印提示符并获取命令行 我们在使用shell的时候&#xff0c;发现我们在输入命令是&#xff0c;前面会有&#xff1a;有用户名&#xff0c;版本&#xff0c;当前路径等信息&#xff0c;这里我们可以用环境变量去获取: 1 #include <stdio.h>2 #include <stdlib.h>…

【论文解读】在上下文中学习创建任务向量

一、简要介绍 大型语言模型&#xff08;LLMs&#xff09;中的上下文学习&#xff08;ICL&#xff09;已经成为一种强大的新的学习范式。然而&#xff0c;其潜在的机制仍未被很好地了解。特别是&#xff0c;将其映射到“标准”机器学习框架是具有挑战性的&#xff0c;在该框架中…

Python BDD 框架比较之 pytest-bdd vs behave

pytest-bdd和behave是 Python 的两个流行的 BDD 测试框架&#xff0c;两者都可以用来编写用户故事和可执行的测试用例&#xff0c; 具体选择哪一个则需要根据实际的项目状况来看。 先简单看一下两者的功能&#xff1a; pytest-bdd 基于pytest测试框架&#xff0c;可以与pytest…

美团技术博客即将十周岁啦 | 欢迎分享你跟它的故事

种一棵树最好的时间是十年前&#xff0c;其次是现在。 2013年12月04日&#xff0c; 美团技术博客发布了第一篇技术文章。 时光荏苒&#xff0c;岁月如歌。 美团技术博客即将迎来自己十周岁的生日。 感谢大家的一路相伴。 十年来&#xff0c;美团技术博客累计发布了570多篇技术文…

STM32_6(TIM)

TIM定时器&#xff08;第一部分&#xff09; TIM&#xff08;Timer&#xff09;定时器定时器可以对输入的时钟进行计数&#xff0c;并在计数值达到设定值时触发中断16位计数器、预分频器、自动重装寄存器的时基单元&#xff0c;在72MHz计数时钟下可以实现最大59.65s的定时不仅…