学习率范围测试(LR Finder)脚本

news2024/12/24 9:29:06

简介

深度学习中的学习率是模型训练中至关重要的超参数之一。合适的学习率可以加速模型的收敛,提高训练效率,而不恰当的学习率可能导致训练过慢或者无法收敛。为了找到合适的学习率,LR Finder成为了一种强大的工具。

学习率范围测试(LR Finder)是一种通过逐渐增加学习率来观察模型在不同学习率下的性能变化的方法。这个过程可以帮助我们找到一个合适的初始学习率,有助于训练过程的稳定和加速。在本文中,我们将深入探讨 LR Finder 的原理、实现和应用,以及如何在实际的分类项目中充分利用这个强大的工具。

学习率在深度学习中的关键

深度学习当中,学习率是一个至关重要的超参数,直接影响了模型的训练性能和收敛速度。因为它决定了每一次参数更新的步长,过大的学习率可能会导致模型在训练过程中发散,而过小的学习率则可能导致模型训练缓慢甚至停滞。

学习率调整方法

下面是一些常见的学习率调整方法,它们在不同的场景和问题上表现出色。

1.常数学习率

常数学习率是最简单的学习率调整方法,即在整个训练过程中,学习率始终保持不变,但是它只是比较适用于数据集较小、模型已经比较稳定的情况。

用法很简单,就是在训练开始时选择一个合适的学习率,比如0.001或0.01,后续的训练过程中就不再变化了。这个值的设定就比较依赖于的经验了。

2.学习率衰减

学习率衰减是随着训练的进行逐渐减小学习率。它可以分为定期调整和根据模型性能调整两种方式,它适用于需要在训练初期更加收敛,而在后期更加稳定的情况。

  • 定期调整:每隔一段训练轮次或步骤,学习率按照一定的衰减率进行调整。
  • 性能调整:根据模型在验证集上的性能进行调整,性能停滞时降低学习率,性能提高时保持学习率不变或轻微提高。

3.自适应学习率

自适应学习率方法根据模型参数的梯度信息和历史信息来自动调整学习率。主要应用于不同参数具有不同特性或数据分布不均匀的情况。

  • Adam: 结合了动量法和自适应学习率的方法,对不同参数应用不同的学习率。
  • Adagrad: 根据参数的历史梯度信息自适应地调整学习率。
  • RMSProp: 在 Adagrad 的基础上加入了一个衰减系数,以防止学习率过早地降低。

学习率范围测试(LR Finder)

学习率范围测试是一种通过逐渐增加学习率来观察模型性能的方法,从而找到一个合适的初始学习率,适合在训练初期选择合适的学习率范围。

实现的方法:

逐渐增加学习率,从一个较小的学习率开始,逐渐增加学习率直到性能开始下降,在学习率范围内观察模型的性能曲线,找到性能峰值对应的学习率。

class FindLR(_LRScheduler):
    """
    exponentially increasing learning rate

    Args:
        optimizer: optimzier(e.g. SGD)
        num_iter: totoal_iters
        max_lr: maximum  learning rate
    """
    def __init__(self, optimizer, max_lr=10, num_iter=100, last_epoch=-1):

        self.total_iters = num_iter
        self.max_lr = max_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * (self.max_lr / base_lr) ** (self.last_epoch / (self.total_iters + 1e-32)) for base_lr in self.base_lrs]

在 get_lr 方法中,通过指数增长的方式计算当前迭代次数下每个参数的学习率。这个方法返回一个学习率列表,其中每个元素对应网络中的一个参数。这个学习率列表将被用于更新优化器中的学习率。

import torch
import torch.optim as optim
from pyzjr.core.lr_scheduler import _LRScheduler
import matplotlib.pyplot as plt

class FindLR(_LRScheduler):
    def __init__(self, optimizer, max_lr=10, num_iter=100, last_epoch=-1):
        self.total_iters = num_iter
        self.max_lr = max_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * (self.max_lr / base_lr) ** (self.last_epoch / (self.total_iters + 1e-32)) for base_lr in self.base_lrs]

x = torch.arange(0, 100, 1)
y = torch.sin(0.1 * x) + 0.1 * torch.randn(100)

model = torch.nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

lr_finder = FindLR(optimizer, max_lr=10, num_iter=100)
lr_values = []

for epoch in range(100):
    outputs = model(x.unsqueeze(1).float())
    loss = torch.nn.functional.mse_loss(outputs.squeeze(), y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    lr_finder.step()

    current_lr = optimizer.param_groups[0]['lr']
    lr_values.append(current_lr)

    print(f"Iteration: {epoch}, Loss: {loss.item():.4f}, LR: {current_lr:.8f}")

plt.switch_backend('TkAgg')
plt.plot(lr_values)
plt.xlabel('Iteration')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Range Test')
plt.show()


迭代与学习率的曲线图:

下面我们将采用pyzjr中定义的 lr_finder 对CIFAR-100进行学习率范围测试:

import torch.nn as nn

import matplotlib
matplotlib.use('Agg')
from utils import get_network
from dataset import get_train_loader
from pyzjr.dlearn.learnrate import lr_finder

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

if __name__ == '__main__':
    class parser_args():
        def __init__(self):
            self.net = "vgg16"
            self.batch_size = 64
            self.base_lr = 1e-7
            self.max_lr = 10
            self.num_iter = 100
            self.Cuda = True

    args = parser_args()

    mean = CIFAR100_TRAIN_MEAN
    std = CIFAR100_TRAIN_STD
    train_loader = get_train_loader(mean, std, batch_size=4)

    net = get_network(args)  # 网络模型定义

    loss_function = nn.CrossEntropyLoss()
    lrfinder = lr_finder(net, train_loader, loss_function)
    lrfinder.update()   # 不断迭代
    lrfinder.plotshow()   # 绘制图像并显示

    lrfinder.save(path="result.jpg")   # 保存图像到指定路径

输出图像result.jpg:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1199703.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Django的ORM操作

文章目录 1.ORM操作1.1 表结构1.1.1 常见字段和参数1.1.2 表关系 2.ORM2.1 基本操作2.2 连接数据库2.3 基础增删改查2.3.1 增加2.3.2 查找2.3.4 删除2.3.4 修改 1.ORM操作 orm,关系对象映射,本质翻译的。 1.1 表结构 实现:创建表、修改表、…

思维模型 暗示效应

本系列文章 主要是 分享 思维模型,涉及各个领域,重在提升认知。无形中引导他人的思想和行为。 1 暗示效应的应用 1.1 暗示效应在商业品牌树立中的应用 可口可乐的品牌形象:可口可乐通过广告、包装和营销活动,向消费者传递了一种…

【递归】求根节点到叶节点数字之和(Java版)

目录 1.题目解析 2.讲解算法原理 3.代码 1.题目解析 LCR 049. 求根节点到叶节点数字之和 给定一个二叉树的根节点 root ,树中每个节点都存放有一个 0 到 9 之间的数字。 每条从根节点到叶节点的路径都代表一个数字: 例如,从根节点到叶节点…

伙伴(buddy)系统原理

一、伙伴算法的由来 在实际情况中,操作系统必须能够在任意时刻申请和释放任意大小的内存,该函数的实现需要考虑延时问题和碎片问题。 延时问题指的是系统查找到可分配单元的时间变长,例如程序请求分配一个64KB的内存空间,系统查看…

Technology Strategy Patterns 学习笔记9 - bringing it all together

1 Patterns Map 2 Creating the Strategy 2.1 Ansoff Growth Matrix 和owth-share Matrix 区别参见https://fourweekmba.com/bcg-matrix-vs-ansoff-matrix/ 3 Communicating

STM32F407: CMSIS-DSP库的移植(基于库文件)

目录 1. 源码下载 2. DSP库源码简介 3.基于库的移植(DSP库的使用) 3.1 实验1 3.2 实验2 4. 使用V6版本的编译器进行编译 上一篇:STM32F407-Discovery的硬件FPU-CSDN博客 1. 源码下载 Github地址:GitHub - ARM-software/CMSIS_5: CMSIS Version 5…

开发者测试2023省赛--Square测试用例

测试结果 官方提交结果 EclEmma PITest 被测文件 [1/7] Square.java /*** This class implements the Square block cipher.** <P>* <b>References</b>** <P>* The Square algorithm was developed by <a href="mailto:Daemen.J@banksys.co…

AWS云服务器EC2实例进行操作系统迁移

AWS云服务器EC2实例进行操作系统迁移 文章目录 AWS云服务器EC2实例进行操作系统迁移1. 亚马逊EC2云服务器简介1.2 亚马逊EC2云务器与弹性云服务器区别 2. 亚马逊EC2云服务器配置流程2.1 亚马逊EC2云服务器实例配置2.1.1 EC2实例购买教程2.1.1 EC2实例初始化配置2.1.2 远程登录E…

Gold-YOLO:基于收集-分配机制的高效目标检测器

文章目录 摘要1、简介2、相关工作2.1、实时目标检测器2.2、基于Transformer的目标检测2.3、用于目标检测的多尺度特征 3、方法3.1、预备知识3.2、低级收集和分发分支3.3、高阶段收集和分发分支3.4、增强的跨层信息流3.5、遮罩图像建模预训练 4、实验4.1、设置4.2、比较4.3.2、 …

recycleView(三)动态修改背景色

效果图 1.关键代码 1. // 定义一个变量来记录滑动的距离var scrollDistance 0// 在RecycleView的滑动监听器中更新滑动的距离binding.recyclerView.addOnScrollListener(object : RecyclerView.OnScrollListener() {override fun onScrolled(recyclerView: RecyclerView, …

Pinia 状态管理器 菠萝

Pinia介绍&#xff1a; Pinia 是 Vue 的专属状态管理库&#xff0c;它允许你跨组件或页面共享状态。 Pinia 大小只有 1kb 左右&#xff0c;超轻量级&#xff0c;你甚至可能忘记它的存在&#xff01; 相比 Vuex,Pinia 的优点&#xff1a; 更贴合 Vue 3 的 Composition API 风…

Leetcode—191.位1的个数【简单】

2023每日刷题&#xff08;二十七&#xff09; Leetcode—191.位1的个数 实现代码 int hammingWeight(uint32_t n) {int ans 0;for(int i 0; i < 32; i) {if(n & ((long long)1 << i)) {ans;}}return ans; }运行结果 翻转比特1思路 就解法一的代码实现来说&am…

基于Transformer架构的ChatGPT:三步带你了解它的工作原理

作者&#xff1a;Insist-- 个人主页&#xff1a;insist--个人主页 梦想从未散场&#xff0c;传奇永不落幕&#xff0c;博主会持续更新优质网络知识、Python知识、Linux知识以及各种小技巧&#xff0c;愿你我共同在CSDN进步 目录 一、Transformer架构 1. 自注意力层 2. 前馈神…

Django中简单的增删改查

用户列表展示 建立列表 views.py def userlist(request):return render(request,userlist.html) urls.py urlpatterns [path(admin/, admin.site.urls),path(userlist/, views.userlist), ]templates----userlist.html <!DOCTYPE html> <html lang"en">…

PCB知识补充

系列文章目录 文章目录 系列文章目录参考文献PCB知识互连线电阻过孔/铜箔电流能力铜箔载流能力过孔载流能力 热设计电磁兼容及部分要求 参考文献 [1]牛森,张敏娟,银子燕.高速PCB多板互联的电源完整性分析[J].单片机与嵌入式系统应用,2023,23(09). [2]陈之秀,刘洋,张涵舒等.高…

【Python】KDtree的调用

前言 查询点集中与目标点最近的k个点 from scipy.spatial import cKDTree import numpy as npdata np.array([[1,2],[1,3],[4,5]]) # 生成 100 个三维数据 tree cKDTree(data) # 创建 K-D Tree result tree.query(np.array([5, 5]), k2) # 查询与 [0.5, 0.5, 0.5] 最近的三…

(离散数学)命题及命题的真值

答案&#xff1a; &#xff08;5&#xff09;不是命题&#xff0c;因为真值不止一个 &#xff08;6&#xff09;不是命题&#xff0c;因为不是陈述句 &#xff08;7&#xff09;不是命题&#xff0c;因为不是陈述句 &#xff08;8&#xff09;不是命题&#xff0c;真值不唯一

NodeJs - 实现当前线程唯一的单例对象

NodeJs - 实现当前线程唯一的单例对象 一. 实现当前线程唯一的单例对象 一. 实现当前线程唯一的单例对象 Java 里面&#xff0c;一般都把这种和当前线程绑定的单例对象存储到ThreadLocal里面&#xff0c;但是Node里面没有这种存储&#xff0c;那咋办呢&#xff1f;直接上代码&…

Android---动态权限适配问题

在 Android6.0&#xff0c;即 API 23 之前&#xff0c;App 需要的权限都会在安装阶段向用户展示&#xff0c;而在 App 运行期间不需要动态判断权限是否已申请。从 6.0 之后的版本开始&#xff0c;Android 系统做了一次大的改动。对于部分权限&#xff0c;App 需要在代码中动态申…

No177.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…