PyTorch自定义学习率调度器实现指南

news2024/11/20 19:31:16

在深度学习训练过程中,学习率调度器扮演着至关重要的角色。这主要是因为在训练的不同阶段,模型的学习动态会发生显著变化。

在训练初期,损失函数通常呈现剧烈波动,梯度值较大且不稳定。此阶段的主要目标是在优化空间中快速接近某个局部最小值。然而,过高的学习率可能导致模型跳过潜在的优质局部最小值,从而限制了模型性能的充分发挥。

尽管PyTorch提供了多种预定义的学习率调度器,但在特定研究场景或需要更精细控制时,这些标准实现可能无法完全满足需求。在这种情况下,实现自定义学习率调度器成为了一个可行的解决方案。

本文将详细介绍如何通过扩展PyTorch的

LRScheduler

类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。

1、继承LRScheduler类

在PyTorch中实现自定义学习率调度器时,首先需要继承

torch.optim.lr_scheduler.LRScheduler

类。这个基类提供了管理学习率调度所需的核心功能。

通过继承

LRScheduler

,我们可以利用以下关键特性:

  1. self.optimizer:对优化器的引用,用于调整其学习率。
  2. self.base_lrs:存储优化器中所有参数组的初始学习率,可在自定义调度器中进行访问和修改。
  3. self.last_epoch:跟踪当前训练轮次,用于根据轮次数调整学习率。
  4. step()方法:在每个训练轮次后调用,用于自动更新学习率。
  5. 参数组处理:LRScheduler设计支持优化器中的多个参数组,允许对模型的不同部分应用不同的学习率调整策略。

以下是继承

LRScheduler

的基本代码结构:

 fromtorch.optim.lr_schedulerimportLRScheduler
 
 classCosineWarmupScheduler(LRScheduler):
     pass

通过继承

LRScheduler

,我们获得了上述所有功能,只需要通过实现

get_lr()

方法来定义学习率的具体变化逻辑。

2、实现构造函数

在自定义学习率调度器中,构造函数(

__init__

方法)用于初始化调度器的关键参数。这些参数定义了学习率调整的具体策略,包括预热期的长度、总训练轮次和最小学习率等。

以下是构造函数的实现示例:

 classCosineWarmupScheduler(LRScheduler):
     def__init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
         self.warmup_epochs=warmup_epochs    # 学习率线性增加的预热轮次
         self.total_epochs=total_epochs      # 总训练轮次
         self.min_lr=min_lr                  # 学习率下限
         super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)

参数说明:

  • optimizer:PyTorch优化器实例,其学习率将被调整。
  • warmup_epochs:预热阶段的轮次数,在此期间学习率线性增加。
  • total_epochs:训练的总轮次,包括预热阶段和衰减阶段。
  • min_lr:学习率的下限,衰减阶段的最终学习率不会低于此值。
  • last_epoch:上一轮的索引,用于恢复训练。默认为-1,表示从头开始训练。

3、调用父类构造函数

在自定义调度器的构造函数中,通过

super()

调用父类(

LRScheduler

)的构造函数是非常重要的。这确保了基类被正确初始化,使我们能够访问诸如

self.optimizer

self.base_lrs

self.last_epoch

等关键属性。

 super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)

这行代码不仅初始化了基类,还使得自定义调度器能够继承

LRScheduler

的其他有用方法,如

step()

get_last_lr()

4、实现get_lr()方法

get_lr()

方法是自定义调度器的核心,它定义了学习率如何随训练轮次变化的具体逻辑。在本例中,我们实现了一个包含预热阶段的余弦衰减调度策略:

预热阶段:在前

warmup_epochs

轮中,学习率从0线性增加到初始学习率。

余弦衰减阶段:预热结束后,学习率按余弦函数从初始值衰减到最小值。

以下是

get_lr()

方法的实现:

 importmath
 
 classCosineWarmupScheduler(LRScheduler):
     def__init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
         self.warmup_epochs=warmup_epochs
         self.total_epochs=total_epochs
         self.min_lr=min_lr
         super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)
 
     defget_lr(self):
         epoch=self.last_epoch+1
         
         ifepoch<=self.warmup_epochs:
             # 预热阶段:线性增加学习率
             return [base_lr*epoch/self.warmup_epochsforbase_lrinself.base_lrs]
         else:
             # 余弦衰减阶段
             decay_epochs=self.total_epochs-self.warmup_epochs
             cosine_decay=0.5* (1+math.cos(math.pi* (epoch-self.warmup_epochs) /decay_epochs))
             return [self.min_lr+ (base_lr-self.min_lr) *cosine_decayforbase_lrinself.base_lrs]

这个实现确保了学习率在预热阶段平滑增加,然后在剩余的训练过程中逐渐衰减,最终达到指定的最小值。

5、在训练流程中应用自定义调度器

实现自定义学习率调度器后,下一步是将其集成到训练流程中。以下示例展示了如何在PyTorch训练循环中初始化和使用自定义调度器:

 importtorch
 importtorch.optimasoptim
 
 # 定义模型(此处使用简单的线性模型作为示例)
 model=torch.nn.Linear(10, 1)
 
 # 初始化优化器
 optimizer=optim.SGD(model.parameters(), lr=0.1)
 
 # 初始化自定义学习率调度器
 scheduler=CosineWarmupScheduler(optimizer, warmup_epochs=5, total_epochs=50, min_lr=0.001)
 
 # 训练循环
 num_epochs=50
 
 forepochinrange(num_epochs):
     model.train()
     fordata, targetindataloader:
         optimizer.zero_grad()
         output=model(data)
         loss=criterion(output, target)
         loss.backward()
         optimizer.step()
 
     # 在每个epoch结束时更新学习率
     scheduler.step()
 
     # 记录当前学习率(用于监控)
     current_lr=scheduler.get_last_lr()[0]
     print(f"Epoch {epoch+1}/{num_epochs}, Learning Rate: {current_lr:.6f}")

在这个示例中,我们执行以下关键步骤:

  1. 定义模型和优化器。
  2. 使用之前实现的CosineWarmupScheduler初始化学习率调度器。
  3. 在每个训练epoch中:- 执行标准的前向传播、损失计算和反向传播步骤。- 调用optimizer.step()更新模型参数。- 在epoch结束时调用scheduler.step()更新学习率。
  4. 使用scheduler.get_last_lr()获取并记录当前学习率,用于监控训练过程。

关键组件说明

  • scheduler.step():这个方法在每个epoch结束时调用,根据当前epoch更新学习率。它是动态调整学习率的核心机制。
  • scheduler.get_last_lr():返回当前的学习率。在多参数组的情况下,它返回一个列表,每个元素对应一个参数组的学习率。

总结

通过继承PyTorch的

LRScheduler

类并实现自定义的

get_lr()

方法,我们可以创建灵活的学习率调度策略,以满足特定的训练需求。本指南展示的带预热的余弦衰减调度器只是众多可能实现的一个例子。

自定义学习率调度器的关键优势在于:

  1. 灵活性:可以实现任何所需的学习率调整策略。
  2. 精确控制:能够根据训练动态和模型特性精细调整学习过程。
  3. 适应性:可以轻松适应不同的模型架构和数据集特性。

在实际应用中,可能需要进行大量实验来确定最适合特定问题的学习率调度策略。通过掌握自定义调度器的实现技巧,研究人员和工程师可以更灵活地优化深度学习模型的训练过程,从而潜在地提高模型性能和训练效率。

https://avoid.overfit.cn/post/aa1e90e02eb24d9f982e2c933bdd97a7

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2165893.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ResNet残差网络:深度学习的里程碑

引言 在深度学习领域&#xff0c;卷积神经网络&#xff08;CNN&#xff09;的发展一直推动着图像识别、目标检测等任务的进步。然而&#xff0c;随着网络层数的增加&#xff0c;传统的CNN面临着梯度消失和梯度爆炸等难题&#xff0c;限制了深层网络的训练效果。为了克服这些挑…

oracle direct path read处理过程

文章目录 缘起处理过程1.AWR Report 分析2.调查direct path read发生的table3.获取sql text4.解释sql并输出执行计划&#xff1a; 结论&#xff1a;补充direct path read等待事件说明 缘起 记录direct path read处理过程 处理过程 1.AWR Report 分析 问题发生时间段awr如下…

FortiGate OSPF动态路由协议配置

1.目的 本文档针对 FortiGate 的 OSPF 动态路由协议说明。OSPF 路由协议是一种 典型的链路状态(Link-state)的路由协议,一般用于同一个路由域内。在这里,路由 域是指一个自治系统,即 AS,它是指一组通过统一的路由政策或路由协议互相交 换路由信息的网络。在这个 AS 中,所有的 …

基于JSP+Servlet+Layui实现的博客系统

> 这是一个使用 Java 和 JSP 开发的博客系统&#xff0c;并使用 Layui 作为前端框架。 > 它包含多种功能&#xff0c;比如文章发布、评论管理、用户管理等。 > 它非常适合作为 Java 初学者的练习项目。 一、项目演示 - 博客首页 - 加载动画 - 右侧搜索框可以输入…

开源服务器管理软件Nexterm

什么是 Nexterm &#xff1f; Nexterm 是一款用于 SSH、VNC 和 RDP 的开源服务器管理软件。 安装 在群晖上以 Docker 方式安装。 在注册表中搜索 nexterm &#xff0c;选择第一个 germannewsmaker/nexterm&#xff0c;版本选择 latest。 本文写作时&#xff0c; latest 版本对…

【STM32】RTT-Studio中HAL库开发教程七:IIC通信--EEPROM存储器FM24C04

文章目录 一、简介二、模拟IIC时序三、读写流程四、完整代码五、测试验证 一、简介 FM24C04D&#xff0c;4K串行EEPROM&#xff1a;内部32页&#xff0c;每个16字节&#xff0c;4K需要一个11位的数据字地址进行随机字寻址。FM24C04D提供4096位串行电可擦除和可编程只读存储器&a…

Excel 设置自动换行

背景 版本&#xff1a;office 专业版 11.0 表格内输入长信息&#xff0c;发现默认状态时未自动换行的&#xff0c;找了很久设置按钮&#xff0c;遂总结成经验帖。 操作 1&#xff09;选中需设置的单元格/区域/行/列。 2&#xff09;点击【开始】下【对齐方式】中的【自动换…

HAproxy,nginx实现七层负载均衡

环境准备&#xff1a; 192.168.88.25 &#xff08;client&#xff09; 192.168.88.26 &#xff08;HAproxy&#xff09; 192.168.88.27 &#xff08;web1&#xff09; 192.168.88.28 (web2) 192.168.88.29 &#xff08;php1&#xff09; 192.168.88.30…

基于微信小程序的教学质量评价系统ssm(lw+演示+源码+运行)

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了基于微信小程序的教学质量评价系统的开发全过程。通过分析基于微信小程序的教学质量评价系统管理的不足&#xff0c;创建了一个计算机管理基于微信小程序的教学…

【Anti-UAV410】论文阅读

摘要 无人机在红外视频中的感知&#xff0c;对于有效反无人机是很重要的。现有的跟踪数据集存在目标大小和环境问题&#xff0c;不能完全表示复杂的逼真场景。因此作者就提出了Anti-UAV410数据集&#xff0c;该数据集总共410个视频和超过438K个标注框。为了应对复杂环境无人机跟…

丹摩智算(damodel)部署stable diffusion实验

名词解释&#xff1a; 丹摩智算&#xff08;damodel&#xff09;&#xff1a;是一款带有RTX4090&#xff0c;Tesla-P40等显卡的公有云服务器。 stable diffusion&#xff1a;是一个大模型&#xff0c;可支持文生图&#xff0c;图生图&#xff0c;文生视频等功能 一.实验目标 …

Linux-TCP重传

问题描述&#xff1a; 应用系统进行切换&#xff0c;包含业务流量切换&#xff08;即TongWeb主备切换&#xff09;和MYSQL数据库主备切换。首先进行流量切换&#xff0c;然后进行数据库主备切换。切换后发现备机TongWeb上有两批次慢请求&#xff0c;第一批慢请求响应时间在133…

【HarmonyOS】应用引用media中的字符串资源如何拼接字符串

【HarmonyOS】应用引用media中的字符串资源如何拼接字符串 一、问题背景&#xff1a; 鸿蒙应用中使用字符串资源加载&#xff0c;一般文本放置在resoutces-base-element-string.json字符串配置文件中。便于国际化的处理。当然小项目一般直接引用字符串&#xff0c;不需要加载s…

计算机毕业设计 基于Python国潮男装微博评论数据分析系统的设计与实现 Django+Vue 前后端分离 附源码 讲解 文档

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

LeetCode 149. 直线上最多的点数

LeetCode 149. 直线上最多的点数 给你一个数组 points &#xff0c;其中 points[i] [xi, yi] 表示 X-Y 平面上的一个点。求最多有多少个点在同一条直线上。 示例 1&#xff1a; 输入&#xff1a;points [[1,1],[2,2],[3,3]] 输出&#xff1a;3 示例 2&#xff1a; 输入&…

【数据结构之线性表】有序表的合并(链表篇)

链表有序表的合并 思路图 将链表L1和L2按照顺序合并到L3中&#xff08;注&#xff1a;三个链表都是带头结点的&#xff09; A、要实现有序合并&#xff0c;必须先比较L1,L2两表中结点的大小&#xff0c;这里我们暂时先不讨论&#xff0c;直接根据图中来进行思路整理&#xff…

pve主要架构和重要服务介绍

Proxmox VE (PVE) 是一款开源的虚拟化平台&#xff0c;它基于 KVM (Kernel-based Virtual Machine) 和 LXC (Linux Containers) 技术&#xff0c;支持虚拟机和容器的运行。PVE 还提供高可用集群管理、软件定义存储、备份和恢复以及网络管理等企业级功能。下面介绍 PVE 的主要架…

jenkins中多个vue项目共用一个node_modules减少服务器内存的占用,对空间造成资源浪费

多个vue项目使用的node_modules一致&#xff0c;每个项目都安装一遍依赖&#xff0c;对空间造成资源浪费。 通过服务器上的软连接mklink(windows服务器&#xff0c;如果是linux服务器用ln)来共用一套node_modules windows mklink /d [链接文件或目录] [原始文件或目录] 进入…

二叉树的基本概念(下)

文章目录 &#x1f34a;自我介绍&#x1f34a;二叉树的分类满二叉树完全二叉树 &#x1f34a;二叉树的存储顺序存储[完全二叉树]链式存储 你的点赞评论就是对博主最大的鼓励 当然喜欢的小伙伴可以&#xff1a;点赞关注评论收藏&#xff08;一键四连&#xff09;哦~ &#x1f34…

无人机避障——4D 毫米波雷达 SLAM篇(一)

做无人机避障相关工作&#xff0c;3D毫米波避障测试顺利后&#xff0c;开始做4D毫米波雷达无人机避障遇到4D雷达点云需要进行处理的问题&#xff0c;查阅文献&#xff0c;发现以下这篇文章中的建图方法应该为后续思考的方向&#xff0c;特此将这个开源项目进行复现和学习&#…