深度学习中的学习率调度器(scheduler)分析并作图查看各方法差异

news2025/1/16 0:46:47

文章目录

    • 1. 指数衰减调度器(Exponential Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 2. 余弦退火调度器(Cosine Annealing Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 3. 步长衰减调度器(Step Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 4. 多项式衰减与预热调度器(Polynomial Decay with Warm-up)
      • 工作原理
      • 适用场景
      • 实现示例
    • 5. 多步衰减调度器(MultiStep Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 总结
    • 参考资料

在深度学习模型训练过程中, 学习率调度器(Learning Rate Scheduler)是优化过程中不可或缺的重要组成部分。它们能够在训练的不同阶段自动调整学习率,从而提高模型的收敛速度和最终性能。选择合适的学习率调度器对于优化训练过程至关重要,不同的调度器适用于不同的训练需求和模型架构。本文将介绍几种常用的学习率调度器,并通过 PyTorch 提供的 torch.optim.lr_schedulertransformers 库中的调度器,展示具体的实现示例及其适用场景。可以通过 运行示例代码来作图查看学习率变化情况,能帮助大家更好的了解不同方法的区别。

1. 指数衰减调度器(Exponential Decay Scheduler)

请添加图片描述

工作原理

指数衰减调度器通过在每个训练步骤中以固定的速率减小学习率,从而逐步降低学习率。这种调度器适用于需要平稳且持续减小学习率的训练过程,有助于模型在训练后期稳定收敛。

适用场景

  • 稳定收敛:适用于希望学习率在整个训练过程中持续且缓慢降低,以避免训练后期的震荡。
  • 简单调整:当训练过程相对稳定,不需要复杂的学习率调整策略时,指数衰减是一个简单有效的选择。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR

# 定义优化器和参数
initial_lr = 5e-5
num_training_steps = 3000
decay_rate = 0.99
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义指数衰减调度器
scheduler = ExponentialLR(optimizer, gamma=decay_rate)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Exponential Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

2. 余弦退火调度器(Cosine Annealing Scheduler)

请添加图片描述

工作原理

余弦退火调度器通过余弦函数调整学习率,使其在训练过程中呈现周期性变化。这种调度器特别适用于处理模型训练中的振荡现象,能够在训练末期提供较低的学习率以帮助模型更好地收敛。

适用场景

  • 避免局部最优:通过周期性调整学习率,可以帮助模型跳出局部最优解。
  • 动态调整:适用于需要在训练过程中动态调整学习率以应对不同训练阶段需求的场景。
  • 模型复杂度较高:对于复杂模型,如深层神经网络,余弦退火有助于更好地探索参数空间。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# 优化器和参数定义同上
initial_lr = 5e-5
num_training_steps = 3000
T_max = 1000  # 一个周期内的步数
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义余弦退火调度器
scheduler = CosineAnnealingLR(optimizer, T_max=T_max)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Cosine Annealing Scheduler')
plt.legend()
plt.grid(True)
plt.show()

3. 步长衰减调度器(Step Decay Scheduler)

在这里插入图片描述

工作原理

步长衰减调度器在训练过程中每隔一定的步数(step_size)后按指定的因子(gamma)降低学习率。这种调度器适用于需要在训练过程中分阶段减小学习率的场景,有助于模型在不同训练阶段进行有效的学习。

适用场景

  • 分阶段训练:适用于需要在训练的特定阶段进行学习率调整的任务,如先快速学习再细致调整。
  • 明确的训练阶段:当训练过程可以划分为多个明确的阶段,每个阶段需要不同学习率时,步长衰减是理想选择。
  • 资源受限的训练:在有限的训练资源下,通过分阶段调整学习率可以更有效地利用计算资源。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

# 调度器参数
initial_lr = 5e-5
num_training_steps = 3000
step_size = 500  # 每隔 step_size 个 step,学习率衰减一次
gamma = 0.1      # 衰减因子
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义步长衰减调度器
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Step Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

4. 多项式衰减与预热调度器(Polynomial Decay with Warm-up)

在这里插入图片描述

工作原理

多项式衰减与预热调度器结合了学习率预热和多项式衰减的优势。训练初期通过预热阶段逐步增加学习率,随后按照多项式函数逐步降低学习率。这种调度器适用于如 BERT 等复杂模型的训练,有助于在训练初期稳定模型参数并在后期促进收敛。

适用场景

  • 复杂模型训练:适用于需要在训练初期进行稳定性的复杂模型,如 Transformer、BERT 等。
  • 防止初期震荡:通过预热阶段逐步增加学习率,可以防止训练初期由于学习率过高导致的梯度震荡。
  • 需要精细控制:适用于需要对学习率进行精细控制,以实现最佳收敛效果的任务。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from transformers import get_polynomial_decay_schedule_with_warmup

# 调度器参数
initial_lr = 5e-5
warmup_steps = 100
num_training_steps = 3000
lr_end = 1e-7  # 最低学习率
power = 2.0    # 多项式衰减的幂次
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义多项式衰减与预热调度器
scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=warmup_steps, 
    num_training_steps=num_training_steps, 
    lr_end=lr_end, 
    power=power
)  # 二次衰减

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.axvline(x=warmup_steps, color='r', linestyle='--', label='End of Warm-up')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Polynomial Decay Scheduler with Warm-up')
plt.legend()
plt.grid(True)
plt.show()

5. 多步衰减调度器(MultiStep Decay Scheduler)

在这里插入图片描述

工作原理

多步衰减调度器在预设的多个步数(milestones)时刻按指定的因子(gamma)降低学习率。这种调度器允许在训练过程中在多个关键点调整学习率,适用于需要在多个阶段显著改变学习率的训练任务。

适用场景

  • 多阶段训练:适用于训练过程中有多个关键阶段,每个阶段需要不同学习率的任务。
  • 灵活调整:当训练过程不规则或需要根据训练进展手动调整学习率时,多步衰减提供了灵活性。
  • 特定任务需求:适用于一些特定任务或模型架构,需要在特定步数后调整学习率以优化性能。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import MultiStepLR

# 调度器参数
initial_lr = 5e-5
num_training_steps = 3000
milestones = [1000, 2000]  # 指定的步数
gamma = 0.1  # 衰减因子
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义多步衰减调度器
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
for i, milestone in enumerate(milestones):
    if i == 0:
        plt.axvline(x=milestone, color='r', linestyle='--', label=f'Milestone at Step {milestone}')
    else:
        plt.axvline(x=milestone, color='r', linestyle='--')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('MultiStep Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

注意:在多步衰减调度器的绘图代码中,plt.axvline 函数仅在第一个里程碑处添加标签,后续的里程碑标签设置为 None'_nolegend_',以避免图例中出现重复的标签。

总结

以上示例代码展示了不同学习率调度器的实现方式以及学习率随训练步骤变化的过程。选择合适的调度器可以根据具体任务和模型的需求来优化训练效果。以下是各类调度器的快速参考:

  • 指数衰减调度器(Exponential Decay Scheduler):适用于希望学习率持续且缓慢降低,稳定收敛的训练过程。
  • 余弦退火调度器(Cosine Annealing Scheduler):适用于需要动态调整学习率以避免局部最优,尤其适合复杂模型。
  • 步长衰减调度器(Step Decay Scheduler):适用于分阶段训练,明确划分训练阶段的任务。
  • 多项式衰减与预热调度器(Polynomial Decay with Warm-up):适用于复杂模型训练,防止初期震荡并促进后期收敛。
  • 多步衰减调度器(MultiStep Decay Scheduler):适用于多阶段训练,需要在多个关键点调整学习率的任务。

在实际应用中,可以根据模型的复杂度、数据集的特性以及训练的阶段性需求,灵活选择和调整学习率调度策略,以实现最佳的训练效果。

参考资料

  • PyTorch 官方文档 - Learning Rate Scheduler
  • Transformers 库 - 调度器

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2277259.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于springboot+vue+微信小程序的宠物领养系统

基于springbootvue微信小程序的宠物领养系统 一、介绍 本项目利用SpringBoot、Vue和微信小程序技术,构建了一个宠物领养系统。 本系统的设计分为两个层面,分别为管理层面与用户层面,也就是管理者与用户,管理权限与用户权限是不…

Nginx安全加固系列:Referrer-Policy

假设页面有一个链接,点击这个链接,会向服务器发送Http请求,加载这个链接指向的页面,在这个Http请求头里,会包含一个Referrer的标头,用于向服务器说明这个Http请求是从哪个页面跳转过来的,那么这…

SQL面试题1:连续登陆问题

引言 场景介绍: 许多互联网平台为了提高用户的参与度和忠诚度,会推出各种连续登录奖励机制。例如,游戏平台会给连续登录的玩家发放游戏道具、金币等奖励;学习类 APP 会为连续登录学习的用户提供积分,积分可兑换课程或…

LeetCode_5. 最长回文子串

最长回文子串https://leetcode.cn/problems/longest-palindromic-substring?envTypeproblem-list-v2&envId2cktkvj 给你一个字符串 s,找到 s 中最长的 回文子串 示例 1: 输入:s "babad" 输出:"bab" …

3D目标检测数据集——Waymo数据集

Waymo数据集簡介 发布首页:https://waymo.com/open/ 论文:https://openaccess.thecvf.com/content_CVPR_2020/papers/Sun_Scalability_in_Perception_for_Autonomous_Driving_Waymo_Open_Dataset_CVPR_2020_paper.pdf github:https://github.…

如何在 Linux、MacOS 以及 Windows 中打开控制面板

控制面板不仅仅是一系列图标和菜单的集合;它是通往优化个人计算体验的大门。通过它,用户可以轻松调整从外观到性能的各种参数,确保他们的电脑能够完美地适应自己的需求。无论是想要提升系统安全性、管理硬件设备,还是简单地改变桌…

Mycat读写分离搭建及配置超详细!!!

目录 一、Mycat产生背景二、Mycat介绍三、Mycat安装四、Mycat搭建读写分离1、 搭建MySQL数据库主从复制2、 基于mysql主从复制搭建MyCat读写分离 五、Mycat启动常见错误处理1、Caused by: io.mycat.config.util.ConfigException: SelfCheck### schema TESTDB refered by user u…

空指针:HttpSession异常,SpringBoot集成WebSocket

异常可能性: 404 : 请检查拦截器是否将请求拦截WebSocket握手期间HttpSession为空 HttpSession为空 方法一 : 网上参考大量的文档,有说跟前端请求域名有关系的。 反正对我来说,没啥用无法连接。 需使用 localhost&a…

【大数据】机器学习------决策树

一、基本流程 决策树是一种基于树结构的分类和回归方法,它通过对特征空间进行划分,每个内部节点表示一个特征测试,每个分支代表一个测试输出,每个叶节点代表一个类别或回归值。 特征选择:根据某种准则(如信…

服务器数据恢复—raid5故障导致上层ORACLE无法启动的数据恢复案例

服务器数据恢复环境&故障: 一台服务器上的8块硬盘组建了一组raid5磁盘阵列。上层安装windows server操作系统,部署了oracle数据库。 raid5阵列中有2块硬盘的硬盘指示灯显示异常报警。服务器操作系统无法启动,ORACLE数据库也无法启动。 服…

Day05-后端Web基础——TomcatServletHTTP协议SpringBootWeb入门

目录 Web基础知识课程内容1. Tomcat1.1 简介1.2 基本使用1.2.1 下载1.2.2 安装与卸载1.2.3 启动与关闭1.2.4 常见问题 2. Servlet2.1 快速入门2.1.1 什么是Servlet2.1.2 入门程序2.1.3 注意事项 2.2 执行流程 3. HTTP协议3.1 HTTP-概述3.1.1 介绍3.1.2 特点 3.2 HTTP-请求协议3…

【已解决】【记录】2AI大模型web UI使用tips 本地

docker desktop使用 互动 如果需要发送网页链接,就在链接上加上【#】号 如果要上传文件就点击这个➕号 中文回复 命令它只用中文回复,在右上角打开【对话高级设置】 输入提示词(提示词使用英文会更好) Must reply to the us…

Deep4SNet: deep learning for fake speech classification

Deep4SNet:用于虚假语音分类的深度学习 摘要: 虚假语音是指即使通过人工智能或信号处理技术产生的语音记录。生成虚假录音的方法有"深度语音"和"模仿"。在《深沉的声音》中,录音听起来有点合成,而在《模仿》中…

Docker save load 镜像 tag 为 <none>

一、场景分析 我从 docker hub 上拉了这么一个镜像。 docker pull tomcat:8.5-jre8-alpine 我用 docker save 命令想把它导出成 tar 文件以便拷贝到内网机器上使用。 docker save -o tomcat-8.5-jre8-alpine.tar.gz 镜像ID 当我把这个镜像传到别的机器,并用 dock…

备战蓝桥杯 队列和queue详解

目录 队列的概念 队列的静态实现 总代码 stl的queue 队列算法题 1.队列模板题 2.机器翻译 3.海港 双端队列 队列的概念 和栈一样,队列也是一种访问受限的线性表,它只能在表头位置删除,在表尾位置插入,队列是先进先出&…

工厂物流管理系统方案(二):危险品车辆专用导航系统架构设计深度剖析

本文专为IT架构师、物流技术专家、软件开发工程师及对危险品运输导航技术有深入探索需求的读者撰写,旨在全面解析危险品车辆专用导航系统的架构设计,展现其技术深度与复杂性,为行业同仁提供权威的技术参考与实践指导。如需获取危险品车辆专用…

用 Python 从零开始创建神经网络(十九):真实数据集

真实数据集 引言数据准备数据加载数据预处理数据洗牌批次(Batches)训练(Training)到目前为止的全部代码: 引言 在实践中,深度学习通常涉及庞大的数据集(通常以TB甚至更多为单位)&am…

No.1|Godot|俄罗斯方块复刻|棋盘和初始方块的设置

删掉基础图标新建assets、scenes、scripts文件夹 俄罗斯方块的每种方块都是由四个小方块组成的,很适合放在网格地图中 比如网格地图是宽10列,高20行 要实现网格的对齐和下落 Node2D节点 新建一个Node2D 添加2个TileMapLayer 一个命名为Board&…

蓝桥云客第 5 场 算法季度赛

题目: 2.开赛主题曲【算法赛】 - 蓝桥云课 问题描述 蓝桥杯组委会创作了一首气势磅礴的开赛主题曲,其歌词可用一个仅包含小写字母的字符串 S 表示。S 中的每个字符对应一个音高,音高由字母表顺序决定:a1,b2,...,z26。字母越靠后…

刀客doc:快手的商业化架构为什么又调了?

一、 1月10日,快手商业化及电商事业部进行新一轮的架构调整。作为2025年快手的第一次大调整,变动最大的是负责广告业务的商业化事业部。快手商业化将原来的8个业务中心,现在统合成了5个,行业归拢看上去更加明晰了。 根据自媒体《…