深度学习中的学习率设置技巧与实现详解

news2024/9/30 5:34:24

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

深度学习中的学习率设置技巧与实现详解

(封面图由文心一格生成)

深度学习中的学习率设置技巧与实现详解

深度学习中的学习率是一个非常重要的超参数,对模型的训练和结果影响极大。在深度学习模型中,学习率决定了参数更新的步长,因此合理设置学习率对于优化算法的收敛速度、模型的训练效果以及泛化性能都有很大的影响。本文将介绍深度学习中的学习率设置技巧,包括常用的学习率衰减方法、自适应学习率方法以及学习率预热等。

1. 常用的学习率衰减方法

1.1 学习率衰减

学习率衰减是一种常见的优化算法,它可以随着训练的进行,逐渐减小学习率,从而使得模型在训练初期能够快速地收敛,而在训练后期能够更加稳定地更新参数。学习率衰减的方法有很多种,包括Step Decay、Exponential Decay、Polynomial Decay等。

Step Decay是一种常见的学习率衰减方法,它是在训练的过程中,根据固定的步数对学习率进行逐步地降低。例如,假设初始学习率为0.1,每隔10个epoch将学习率降低10倍,那么当训练到第11个epoch时,学习率将变为0.01,当训练到第21个epoch时,学习率将变为0.001,以此类推。这种方法简单易行,但是需要手动设置衰减的步数和衰减的幅度,不太灵活。

Exponential Decay是一种常见的指数衰减方法,它可以根据训练的epoch数来逐渐减小学习率。具体而言,Exponential Decay方法的学习率衰减规则如下:

α = α 0 ⋅ e − k t \alpha=\alpha_0 · e^{-kt} α=α0ekt

其中, α 0 \alpha_0 α0表示初始学习率, k k k为衰减系数, t t t表示训练的epoch数。随着训练的进行, t t t会不断增大,因此学习率会不断减小。Exponential Decay方法可以通过设置不同的 k k k值来控制学习率的衰减速度,从而达到更好的训练效果。

Polynomial Decay是一种常见的多项式衰减方法,它可以通过多项式函数来逐渐减小学习率。具体而言,Polynomial Decay方法的学习率衰减规则如下:
α = α ⋅ ( 1 − t T ) p \alpha=\alpha\cdot (1 - \frac{t}{T})^p α=α(1Tt)p
其中, α 0 \alpha_0 α0表示初始学习率, p p p为多项式的幂次数, t t t表示训练的epoch数, T T T为总的训练epoch数。随着训练的进行, t t t会不断增大,因此学习率会不断减小,同时随着 p p p的增大,学习率的衰减速度也会加快。

1.2 余弦退火

余弦退火(Cosine Annealing)是一种新兴的学习率衰减方法,它通过余弦函数来逐渐减小学习率,从而达到更好的训练效果。具体而言,余弦退火方法的学习率衰减规则如下:
α = α 0 ⋅ 1 + cos ⁡ ( π ⋅ t T ) 2 \alpha = \alpha_0 \cdot \frac{1 + \cos(\frac{\pi \cdot t}{T})}{2} α=α021+cos(Tπt)

其中, α 0 \alpha_0 α0表示初始学习率, t t t表示训练的epoch数, T T T为总的训练epoch数。随着训练的进行, t t t会不断增大,因此学习率会不断减小,同时余弦函数的周期也会不断缩小,从而使得学习率在训练过程中逐渐降低。

1.3 One Cycle Learning Rate

One Cycle Learning Rate是一种比较新的学习率衰减方法,它通过在训练初期使用一个较大的学习率,从而快速地收敛到一个局部最优解,然后在训练后期使用一个较小的学习率,从而逐步地优化模型。具体而言,One Cycle Learning Rate方法的学习率变化规则如下:

  • 在训练初期,使用较大的学习率(如初始学习率的10倍),从而快速地收敛到一个局部最优解;
  • 然后在训练中期,使用较小的学习率,从而逐步地优化模型;
  • 最后在训练后期,再次使用较小的学习率,从而让模型更加稳定。
    One Cycle Learning Rate方法可以有效地提高模型的训练速度和泛化性能,但是需要仔细调整超参数,否则容易导致模型的过拟合。

2. 自适应学习率方法

除了学习率衰减方法之外,深度学习中还有很多自适应学习率方法,包括Adagrad、Adadelta、Adam等。这些方法都是基于梯度信息来自适应地调整学习率,从而在训练过程中更加稳定和高效。

2.1 Adagrad

Adagrad是一种自适应学习率方法,它可以根据参数梯度的大小来动态地调整学习率。具体而言,Adagrad方法的学习率更新规则如下:

其中, α 0 \alpha_0 α0表示初始学习率, g i g_i gi表示参数的梯度, ϵ \epsilon ϵ是一个非常小的常数,用于防止分母为0。Adagrad方法的优点在于它可以根据参数的梯度大小自适应地调整学习率,从而更好地适应不同的数据分布和参数更新。但是Adagrad方法也有一些缺点,比如需要存储梯度平方和的累积值,导致内存占用较大;另外,由于学习率是逐渐减小的,因此可能会导致模型在后期训练时收敛速度变慢。

2.2 Adadelta

Adadelta是一种自适应学习率方法,它可以根据参数梯度的大小和历史梯度信息来动态地调整学习率Adadelta方法的优点在于它可以动态地调整学习率,并且不需要存储梯度平方和的累积值,因此内存占用较小。但是Adadelta方法也有一些缺点,比如需要手动设置一些超参数,例如平均梯度的衰减率和初始的平均梯度值等。

2.3 Adam

Adam是一种自适应学习率方法,它可以根据参数梯度的大小和历史梯度信息来动态地调整学习率,并且还可以适应不同的数据分布和参数更新。具体而言,Adam方法的学习率更新规则如下:
m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 m ^ t = m t 1 − β 1 t v ^ t = v t 1 − β 2 t Δ x t = − α v ^ t + ϵ ⋅ m ^ t \begin{aligned}m_t &= \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t \\ v_t &= \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\ \hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \\ \hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\ \Delta x_t &= -\frac{\alpha}{\sqrt{\hat{v}_t}+\epsilon} \cdot \hat{m}_t\end{aligned} mtvtm^tv^tΔxt=β1mt1+(1β1)gt=β2vt1+(1β2)gt2=1β1tmt=1β2tvt=v^t +ϵαm^t

其中, g t g_t gt表示参数的梯度, m t m_t mt v t v_t vt分别表示梯度的一阶和二阶矩, β 1 \beta_1 β1 β 2 \beta_2 β2是衰减率, m ^ t \hat{m}_t m^t v ^ t \hat{v}_t v^t分别表示修正后的一阶和二阶矩, α \alpha α表示初始学习率, ϵ \epsilon ϵ是一个非常小的常数,用于防止分母为0。
Adam方法的优点在于它不仅可以动态地调整学习率,还可以适应不同的数据分布和参数更新,从而在训练过程中更加稳定和高效。但是Adam方法也有一些缺点,比如需要手动设置一些超参数,例如衰减率和初始学习率等。
三、学习率预热
学习率预热是一种常见的训练技巧,它可以在训练初期使用一个较小的学习率,从而避免模型在训练初期过度更新参数,导致模型不稳定。具体而言,学习率预热的方法是在训练前先使用一个较小的学习率进行一些预热操作,例如在训练初期进行一些预热的epoch,然后再逐步地增加学习率,从而使得模型更加稳定和高效。
学习率预热的方法可以有效地避免模型在训练初期过度更新参数,导致模型不稳定,同时也可以加速模型的收敛速度,提高训练效率和泛化性能。

3. 代码实现

下面是使用PyTorch实现常见的学习率衰减方法和自适应学习率方法的代码示例:

3.1 Step Decay

import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(num_epochs):
	# train
	train_loss, train_acc = train(...)
	# update learning rate
	scheduler.step()
	# validation
	val_loss, val_acc = validate(...)
	# print results
	print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))

3.2 Exponential Decay

import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

for epoch in range(num_epochs):
	# train
	train_loss, train_acc = train(...)
	
	# update learning rate
	scheduler.step()
	
	# validation
	val_loss, val_acc = validate(...)
	
	# print results
	print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))

3.3 Cosine Annealing

import torch.optim as optim

optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(num_epochs):
    # train
    train_loss, train_acc = train(...)
    
    # update learning rate
    scheduler.step()
    
    # validation
    val_loss, val_acc = validate(...)
    
    # print results
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(
        epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))

3.4 Adagrad

import torch.optim as optim

optimizer = optim.Adagrad(net.parameters(), lr=0.1)

for epoch in range(num_epochs):
    # train
    train_loss, train_acc = train(...)
    
    # update learning rate
    optimizer.step()
    optimizer.zero_grad()
    
    # validation
    val_loss, val_acc = validate(...)
    
    # print results
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(
        epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))

3.5 Adadelta

import torch.optim as optim

optimizer = optim.Adadelta(net.parameters(), lr=0.1, rho=0.9, eps=1e-6)

for epoch in range(num_epochs):
    # train
    train_loss, train_acc = train(...)
    
    # update learning rate
    optimizer.step()
    optimizer.zero_grad()
    
    # validation
    val_loss, val_acc = validate(...)
    
    # print results
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(
        epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))

3.6 Adam

import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.1, betas=(0.9, 0.99), eps=1e-8)

for epoch in range(num_epochs):
    # train
    train_loss, train_acc = train(...)
    
    # update learning rate
    optimizer.step()
    optimizer.zero_grad()
    
    # validation
    val_loss, val_acc = validate(...)
    
    # print results
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))

4. 总结

本文介绍了深度学习中常见的学习率设置技巧,包括学习率衰减方法和自适应学习率方法。学习率衰减方法可以根据训练的进展情况动态地调整学习率,从而提高模型的训练效率和泛化性能;自适应学习率方法可以根据参数梯度的大小和历史梯度信息来动态地调整学习率,从而在训练过程中更加稳定和高效。此外,学习率预热也是一种常见的训练技巧,它可以在训练初期使用一个较小的学习率,从而避免模型在训练初期过度更新参数,导致模型不稳定。

在代码实现方面,PyTorch提供了许多内置的学习率调度器和自适应学习率优化器,可以方便地实现各种学习率设置技巧。通过合理地选择和使用这些工具,可以帮助我们更加高效地训练深度学习模型,并获得更好的训练效果和泛化性能。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/429267.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机网络】第二章 应用层 3

Email应用的特性 异步应用,方便用户 提供一对多通信 价格低廉 主要包含: o 用户代理(user agents,UA) o 邮件服务器(mail servers) o 邮件传输协议:SMTP o 邮件访问协议:POP3或IMAP 用户代理 o 客户端程序 o 提供编辑、发…

【Linux】认识协议

🎇Linux: 博客主页:一起去看日落吗分享博主的在Linux中学习到的知识和遇到的问题博主的能力有限,出现错误希望大家不吝赐教分享给大家一句我很喜欢的话: 看似不起波澜的日复一日,一定会在某一天让你看见坚持…

【Java 编程语言】——JDK 安装

JDK 安装 文章目录JDK 安装一、JDK的选择与下载1.JDK的选择2.JDK的下载二、Java环境变量的配置一、JDK的选择与下载 1.JDK的选择 目前的JDK的版本更新很快,已经到了JDK20了。但是对于普通的开发或者学习人员来说,选择较为稳定的JDK是更为合适的选择。当…

干货丨AI常见问题及处理方法

AI软件在运行时经常会容易报错或者操作不成功,问题及处理方法分享给大家 01 当AI中色板里面没有颜色可选 原因:将图片素材直接以新窗口打开,所以显示的是位图文件。 解决办法:重新新建文件,然后将图片拖入新建文件中…

让Ai来告诉你Linux应该怎么学

今天在slack上添加了Claude,他属于ChatGPT的最强竞品,支持中文,体验非常舒适,也并不像国内某些自建AI那样弱智。 至于Linux要怎么学,就让Claude来回答吧。 你能告诉我Liunx应该怎么学吗? 学习Linux,我有…

Elasticsearch:使用 Elastic APM 监控 Android 应用程序

作者:Alexander Wert, Cesar Munoz 人们通过私人和专业的移动应用程序在智能手机上处理越来越多的事情。 拥有成千上万甚至数百万的用户,确保出色的性能和可靠性是移动应用程序和相关后端服务的提供商和运营商面临的主要挑战。 了解移动应用程序的行为、…

【Mysql系列】——详细剖析数据库中的存储引擎

【Mysql系列】——详细剖析数据库中的存储引擎😎前言🙌存储引擎什么是存储引擎?Mysql的体系结构:Mysql的体系结构分为四层:连接层服务层引擎层存储层存储引擎的查看存储引擎的指定存储引擎的特点InnoDB介绍InnoDB特点I…

论文浅尝 | 大语言模型在in-context learning中的不同表现

笔记整理:毕祯,浙江大学博士,研究方向为知识图谱、自然语言处理链接:https://arxiv.org/pdf/2303.03846.pd本文是谷歌等机构最新发表的论文,旨在研究大模型上下文学习的能力。这篇论文研究了语言模型中的上下文学习是如…

数影周报:现代汽车发生数据泄露事件;淘宝天猫集团完成组织调整

本周看点:现代汽车发生数据泄露事件;微软会议应用Teams 新功能可禁用/启用脏话过滤器;欧洲隐私监管机构创建ChatGPT工作组;淘宝天猫集团完成组织调整;阿里巴巴再向Lazada投资3.529亿美元...... 数据安全那些事 现代汽车…

C语言数据结构-队列的知识总结归纳

队列的知识总结归纳一.队列的基本概念二.循环队列的顺序存储常见的基本操作以及详细图解1.队列的顺序存储结构类型定义2.初始化队列初始化队列示意图3.判断队空4.判断队列是否满的三种方法图示5.入队或进队入队的示意图6出队或退队出队的图示三. 队列的链式存储结构四. 链式队列…

AutoGPT自主人工智能用法和使用案例

介绍 AutoGPT是什么:自主人工智能,不需要人为的干预,自己完成思考和决策【比如最近比较热门的用AutoGPT创业,做项目–>就是比较消耗token】 AI 自己上网、自己使用第三方工具、自己思考、自己操作你的电脑【就是操作你的电脑…

缺省函数,函数重载,引用简单介绍的补充说明

TIPS 命名空间域的作用实际上相当于把部分变量的名称给他隔离起来,这样的话就可以减少变量名的冲突。命名空间是对全局域当中的这些变量啊,函数啊,类型啊进行一个封装与隔离,可以防止你和我之间的冲突,也可以防止与库…

leetcode:各位相加(数学办法详解)

前言:内容包括:题目,代码实现,大致思路 目录 题目: 代码实现: 大致思路: 题目: 给定一个非负整数 num,反复将各个位上的数字相加,直到结果为一位数。返回…

【云原生Docker】11-Docker镜像仓库

【云原生|Docker】11-Docker Registry(官方仓库) 文章目录【云原生|Docker】11-Docker Registry(官方仓库)前言docker registry简介操作示例hyper/docker-registry-web前言 ​ 前面我们所有的docker操作,使用的镜像都是在docker官方的镜像仓库下载,当然这…

总结825

学习目标: 4月(复习完高数18讲内容,背诵21篇短文,熟词僻义300词基础词) 今日复习: 手绘高数第11讲思维导图,回顾线性代数第一讲 学习内容: 第12讲二重积分视频,纠正11讲…

手势控制的机器人手臂

将向你展示如何构建机械手臂并使用手势和计算机视觉来控制它。下面有一个在开发阶段的机械手臂的演示视频。展示开发中的手臂的演示视频:https://youtu.be/KwiwetZGv0s如图所示,该过程首先用摄像头捕捉我的手及其标志。通过跟踪特定的界标,例…

300到400的蓝牙耳机有哪些推荐?2023年值得入手的性价比蓝牙耳机

今年依旧是真无线蓝牙耳机快速发展的一年,市面上都有着各式各样的蓝牙耳机,一时间难以辨认哪些款式更适合自己,今天给大家介绍的是300元左右的蓝牙耳机,那这个价位的耳机到底怎么样呢?其实,300左右的蓝牙耳…

Qt 窗口置顶

文章目录一、前言二、示例代码三、补充说明四、窗口透明五、参考一、前言 我们使用QT进行界面开发时,可能会遇到需要将窗口置顶的情况。最常见的就是,需要制作一个悬浮工具栏,悬浮菜单,甚至是悬浮的画板。这就意味这我们需要将这个…

Javascript40行代码实现基础MVC原理。

参考文章 M数据层 V视图 C控制器 先来一个dom结构&#xff0c;一个p标签&#xff0c;用来展示输入的内容&#xff0c;一个input标签&#xff0c;用来输入内容⬇️ <p id"mvcp"></p> <input id"mvc"></input>创建Model类&#x…

第二部分——长难句——第一章——并列句

conjunction(and,but,if,when(while)) 想把两个句子&#xff08;多件事&#xff09;连在一块&#xff0c;就必须加上连词。 所以长难句到底是啥&#xff1f; 所以长难句&#xff08;直白表达&#xff0c;并不是语法表述&#xff09;就是几个简单句多家上几个连接词就齐活了&am…