torch.optim.lr_scheduler.OneCycleLR 学习与理解

news2024/11/18 3:29:07

一、功能和参数

1.1、通过图像直观地理解 OneCycleLR 的过程:

补充:

生成该图像的代码:

来自:torch.optim.lr_scheduler.OneCycleLR用法_dxz_tust的博客-CSDN博客

import cv2
import torch.nn as nn
import torch
from torchvision.models import AlexNet
import matplotlib.pyplot as plt
#定义2分类网络
steps = []
lrs = []
# ## !!!!下面这一行如果感觉太慢,可以使用:model = torch.nn.Linear(2, 1) !!!!
model = AlexNet(num_classes=2)
# ------------------------------------------
lr = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

#total_steps:总的batch数,这个参数设置后就不用设置epochs和steps_per_epoch,anneal_strategy 默认是"cos"方式,当然也可以选择"linear"
#注意,这里的max_lr和你优化器中的lr并不是同一个
#注意,无论你optim中的lr设置是啥,最后起作用的还是max_lr
scheduler =torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr=0.9,total_steps=100, verbose=True)
# ------------------------------------------
for epoch in range(10):
    for batch in range(10):
        scheduler.step()
        lrs.append(scheduler.get_lr()[0])
        steps.append(epoch*10+batch)
 
 
plt.figure()
plt.legend()
plt.plot(steps, lrs, label='OneCycle')
plt.savefig("dd.png")

1.2、参数介绍

上图中的值,就是参数的默认值。

详细介绍:(英文挺容易理解,就不再翻译了。)

optimizer (Optimizer): Wrapped optimizer.
max_lr (float or list): Upper learning rate boundaries in the cycle
    for each parameter group.
total_steps (int): The total number of steps in the cycle. Note that
    if a value is not provided here, then it must be inferred by providing
    a value for epochs and steps_per_epoch.
    Default: None
epochs (int): The number of epochs to train for. This is used along
    with steps_per_epoch in order to infer the total number of steps in the cycle
    if a value for total_steps is not provided.
    Default: None
steps_per_epoch (int): The number of steps per epoch to train for. This is
    used along with epochs in order to infer the total number of steps in the
    cycle if a value for total_steps is not provided.
    Default: None
pct_start (float): The percentage of the cycle (in number of steps) spent
    increasing the learning rate.
    Default: 0.3
anneal_strategy (str): {'cos', 'linear'}
    Specifies the annealing strategy: "cos" for cosine annealing(退火), "linear" for
    linear annealing(退火).
    Default: 'cos'
cycle_momentum (bool): If ``True``, momentum is cycled inversely(相反地)
    to learning rate between 'base_momentum' and 'max_momentum'.
    Default: True
base_momentum (float or list): Lower momentum boundaries in the cycle
    for each parameter group. Note that momentum is cycled inversely
    to learning rate; at the peak of a cycle, momentum is
    'base_momentum' and learning rate is 'max_lr'.
    Default: 0.85
max_momentum (float or list): Upper momentum boundaries in the cycle
    for each parameter group. Functionally,
    it defines the cycle amplitude(振幅) (max_momentum - base_momentum).
    Note that momentum is cycled inversely
    to learning rate; at the start of a cycle, momentum is 'max_momentum'
    and learning rate is 'base_lr'

    Default: 0.95
div_factor (float): Determines the initial learning rate via
    initial_lr = max_lr/div_factor
    Default: 25
final_div_factor (float): Determines the minimum learning rate via
    min_lr = initial_lr/final_div_factor (可以看出,最终的lr是非常非常小的
    Default: 1e4
three_phase (bool): If ``True``, use a third phase of the schedule to annihilate(消灭) the
    learning rate according to 'final_div_factor' instead of modifying the second
    phase (the first two phases will be symmetrical about the step indicated by
    'pct_start').

    Default: False
last_epoch (int): The index of the last batch. This parameter is used when
    resuming a training job. Since `step()` should be invoked after each
    batch instead of after each epoch, this number represents the total
    number of *batches* computed, not the total number of epochs computed.(这句话的意思是:last_epoch表示已经训练了多少个batches,而不是训练了多少个epochs)
    When last_epoch=-1, the schedule is started from the beginning.
    Default: -1
verbose (bool): If ``True``, prints a message to stdout for
    each update. Default: ``False``.

关于参数我的一些理解:

  1. optimizer 学习率是要在优化器中使用的。这个参数就是用于指定这个学习率用于哪个优化器。根据我的观察,其实就是给optimizer加了一个调整学习率的HOOK
  2. max_lr 就是【1.1、】图的上界;
  3. total_steps 或者 (epochs + steps_per_epoch)二者必须设置其中一个(官方文档:You must either provide a value for total_steps or provide a value for both epochs and steps_per_epoch.)原因是计算每一步的学习率需要用到 total_steps 。通过(epochs + steps_per_epoch)可以计算出total_steps(total_steps = epochs * steps_per_epoch)
  4. pct_start 表示【1.1、】图的上升阶段占 total_steps 的比例;
  5. anneal_strategy 表示【1.1、】途中下降阶段的策略:cos、linear;
  6. cycle_momentum 这个不怎么理解,直接使用默认值就可以了
  7. max_momentum 这个也不怎么理解,使用默认值就可以了;关于单词“inversely”的含义可以参考其中标红的部分
  8. div_factor 确定初始学习率,即【1.1、】图最左侧的起点,使用默认值就可以
  9. final_div_factor 确定最终的学习率,即【1.1、】图最右侧的终点,使用默认值就可以
  10. three_phase 这个后面单独细讲
  11. last_epoch 看英语解释就可以,比较容易理解
  12. verbose 看英语解释就可以,比较容易理解

1.2.1、关于参数 three_phase 的介绍

【1.1、】中图只有上升和下降两个阶段。将【1.1、】图的代码增加 three_phase=True 参数后,

scheduler =torch.optim.lr_scheduler.OneCycleLR(optimizer,total_steps=100,max_lr=0.9,three_phase=True)

其图像会变成下面这样:

 整个图将分为3个部分:①上升阶段,②第一个下降阶段,③第二个下降阶段

并且:① 和 ② 是对称的

①和②在total_steps中的占比都是pct_start。所以pct_start如果大于等于0.5,那么即使设置了tree_phase参数,也不会出现③这个阶段。

补充:如果上升和下降都采用线性(linear)的方法,图像会类似于下图:

(来自:侵权立删)

1.2.2、three_phase 对训练效果影响的注意点

官方有这么一句话:The default behaviour of this scheduler follows the fastai implementation of 1cycle, which claims that "unpublished work has shown even better results by using only two phases". To mimic the behaviour of the original paper instead, set ``three_phase=True``.

即,论文中的 OneCycleLR() 是三阶段的,但是有人验证过了,三阶段的训练效果没有二阶段的训练效果好。所以直接使用默认值 False 更好一些。

1.3、官方的使用例子

data_loader = torch.utils.data.DataLoader(...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
for epoch in range(10):
    for batch in data_loader:
        train_batch(...)
        scheduler.step()

二、一个可以直接使用的OneCycleLR配置

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, total_steps=100)

three_phase 使用默认值(False)效果更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/654396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nodejs二、内置模块

零、文章目录 Nodejs二、内置模块 1、fs 文件系统模块 (1)fs 文件系统模块是什么 fs 模块是 Node.js 官方提供的、用来操作文件的模块。它提供了一系列的方法和属性,用来满足用户对文件的操作需求。 fs.readFile() :用来读取指…

【深度学习-第2篇】CNN卷积神经网络30分钟入门!足够通俗易懂了吧(图解)

网络上有着很多关于CNN入门的教程,但是总还是觉得缺少足够简易、直观、全面的文章,能让人通读下来酣畅淋漓,将CNN概念尽收囊中。本篇文章就想尝试一下,真正地带小白同学们轻松入门。 这篇文章包含很多图片,为了花这些…

k8s-containerd容器运行时默认50G存储位置更换

containerd作为k8s主要的cri,它默认存储位置是使用的/根目录挂载的资源。当容器运行的越来越多,默认的50G不够使用了。有2种方法可以进行解决。 方式1、增加/根分区的磁盘空间。 方式2、修改containerd配置文件,修改默认配置为/home 这里我…

【汤4操作系统】深入理解信号量的使用-三大问题的变体

主要从生产者消费者、读写者、哲学家问题中的经典变体进行讲述,均使用伪代码实现 生产者消费者变体 顾客看作是生产出的产品,理发师看作是消费者,沙发有空位,顾客就进去,沙发有顾客,理发师就去理发 和生产者…

Redis客户端 - Jdies快速入门

原文首更地址,阅读效果更佳! Redis客户端 - Jdies快速入门 | CoderMast编程桅杆Redis客户端 - Jdies快速入门 简介 Jedis is a Java client for Redis designed for performance and ease of use. Jedis是Redis 的 Java 客户端,专为性能和易…

Python中使用matplotlib绘制各类图表示例

折线图 折线图是一种用于表示数据随时间、变量或其他连续性变化的趋势的图表。通过在横轴上放置时间或如此类似的连续变量,可以在纵轴上放置数据点的值,从而捕捉到数据随时间发生的变化。折线图可以用于比较不同变量的趋势,轻松地发现不同的…

不写代码如果解决Jmeter跨线程组取参数值问题?

目录 前言 定义属性法 文件转接法 总结: 前言 如果你工作中已经在用jmeter做接口测试,或性能测试了,你可能会遇到一个麻烦。 那就是jmeter的变量值不能跨线程组传递。 看,官方就已经给出了解释: 这个不是jmeter的…

机器学习——识别足球和橄榄球

一、选题的背景 橄榄球起源于足球,二者即相似又有所区别。计算机技术发展至今,AI技术也有了极大的进步,通过机器学习不断的训练,AI对于足球和橄榄球的识别能力可以帮助人们对足球和橄榄球的分辨。机器学习是一种智能技术&#xff…

虚拟机使用docker安装MySql出现的问题,Navicat连不上MySql

文章目录 一、问题引入 二、问题分析 三、问题解决 ​四、总结 一、问题引入 今天是学习谷粒商城的第一天,既然是第一天,肯定就是先对项目先有个基本的了解,比如是项目所用到的技术栈,项目整体的架构等,还对分布…

操作系统闲谈09——内存管理算法

操作系统闲谈09——内存管理算法 Buddy伙伴系统 假设存在一段连续的页框,阴影部分表示已经被使用的页框,现在需要申请一个连续的5个页框。这个时候,在这段内存上不能找到连续的5个空闲的页框,就会去另一段内存上去寻找5个连续的页…

华为OD机试真题B卷 JavaScript 实现【乱序整数序列两数之和绝对值最小】,附详细解题思路

一、题目描述 给定一个随机的整数(可能存在正整数和负整数)数组 nums,请你在该数组中找出两个数,其和的绝对值(|nums[x]nums[y]|)为最小值,并返回这个两个数(按从小到大返回)以及绝对值。 每种…

Android 行业就业难! 我是否该负重前行~

不知从何时开始,互联网市场岗位开始以收缩趋势进行发展,使得不少互联网行业的从业者面临者工作难找的难题,对于我们开发人群来说很不友好。 以前可以靠着跳槽实现涨薪梦,而如今是能不动就不动,能稳住是最好。 为什么这…

Docker——安装MySQL

一、安装并拉取MySQL镜像 先把docker启动起来 systemctl restart docker systemctl status docker 安装MySQL docker search mysql拉取镜像, 如果拉取不成功或者显示超时,可以去配置加速镜像源。 二、查看本地镜像并启动MySQL 但是光有镜像没有把镜像…

Redis面试之数据类型及底层原理

废话不多说直接上类型 string(字符串) hash(哈希) list(列表) set(集合) zset(有序集合) stream(流) geospatial(地…

CRM软件有哪些?这9款值得推荐

业内有一句流传已久的话:你的左手不知道你的右手在做什么。同一个企业内部,不同部门之间往往存在信息不同步,数据不对称的情况,比如销售和营销部门关于某个市场活动所带来的效果产生分歧。CRM软件的存在就可以解决这类问题。 在正…

实验4 Cache性能分析【计算机系统结构】

实验4 Cache性能分析【计算机系统结构】 前言推荐实验四 Cache性能分析1 实验目的2 实验平台3 实验内容和步骤3.1 Cache容量对不命中率的影响3.2 相联度对不命中率的影响3.3 Cache块大小对不命中率的影响3.4 替换算法对不命中率的影响 4 实验总结与心得5 请思考 最后 前言 202…

8年测试工程师分享,我是怎么开展性能测试的(基础篇)

第一节 测试的一般步骤 性能测试的工作是基于系统功能已经完备或者已经趋于完备之上的,在功能还不够完备的情况下没有多大的意义(后期功能完善上会对系统的性能有影响,过早进入性能测试会出现测试结果不准确、浪费测试资源)&…

足不出户怎么在家赚钱,暑假在家别闲着,给自己赚点生活费吧

在当今快节奏的现代生活中,人们面临着越来越大的竞争压力。为了过上舒适的生活、提前退休、创业或增加收入,许多人都希望能够在家中赚钱。那么,在家里如何可以找到赚钱的项目呢?本文将为您详细介绍一些方法。 一、在家工作有很多好…

《计算之魂》读书笔记——第2章,从递推到递归

我们人类的固有思维方式常常是出于直观的,由近及远、从少到多,这样的思维方式让我们很容易理解具体的事物,却也限制了我们的抽象思维,所以当我们理解远离我们生活经验的事物时,就容易出现障碍。我们人类这种自底向上、…

调用万维易源实现天气预测

作者介绍 房庚晨,男,西安工程大学电子信息学院,22级研究生 研究方向:机器视觉与人工智能 电子邮件:1292475736qq.com 王泽宇,男,西安工程大学电子信息学院,2022级研究生&#xff0…