pytorch08:学习率调整策略

news2025/1/16 5:43:57

在这里插入图片描述

目录

  • 一、为什么要调整学习率?
    • 1.1 class _LRScheduler
  • 二、pytorch的六种学习率调整策略
    • 2.1 StepLR
    • 2.2 MultiStepLR
    • 2.3 ExponentialLR
    • 2.4 CosineAnnealingLR
    • 2.5 ReduceLRonPlateau
    • 2.6 LambdaLR
  • 三、学习率调整小结
  • 四、学习率初始化

一、为什么要调整学习率?

学习率(learning rate):控制更新的步伐
一般在模型训练过程中,在开始训练的时候我们会设置学习率大一些,随着模型训练epoch的增加,学习率会逐渐设置小一些。

1.1 class _LRScheduler

学习率调整的父类函数
在这里插入图片描述
主要属性:
• optimizer:关联的优化器
• last_epoch:记录epoch数
• base_lrs:记录初始学习率
主要方法:
• step():更新下一个epoch的学习率,该操作必须放到epoch循环下面
• get_lr():虚函数,计算下一个epoch的学习率

二、pytorch的六种学习率调整策略

2.1 StepLR

在这里插入图片描述

功能:等间隔调整学习率
主要参数:
• step_size:调整间隔数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现:

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 1 Step LR ------------------------------
# flag = 0
flag = 1
if flag:

    scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # 设置学习率下降策略,50轮下降一次,每次下降10倍

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):

        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        scheduler_lr.step()  # 学习率更新策略

    plt.plot(epoch_list, lr_list, label="Step LR Scheduler")
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果:
在这里插入图片描述

因为我们设置每50个epoch降低一次学习率,所以在7774554

2.2 MultiStepLR

在这里插入图片描述

功能:按给定间隔调整学习率
主要参数:
• milestones:设定调整时刻数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现

flag = 1
if flag:

    milestones = [50, 125, 160]  # 设置学习率下降的位置
    scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):

        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="Multi Step LR Scheduler\nmilestones:{}".format(milestones))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

根据我们设置milestones = [50, 125, 160],发现学习率在这三个地方发生下降。

2.3 ExponentialLR

在这里插入图片描述

功能:按指数衰减调整学习率
主要参数:
• gamma:指数的底
调整方式:lr = lr * gamma^epoch;这里的gamma通常设置为接近1的数值,例如:0.95

代码实现

flag = 1
if flag:

    gamma = 0.95
    scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):

        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="Exponential LR Scheduler\ngamma:{}".format(gamma))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

可以发现学习率是呈指数下降的。

2.4 CosineAnnealingLR

在这里插入图片描述

功能:余弦周期调整学习率
主要参数:
• T_max:下降周期
• eta_min:学习率下限
调整方式:
在这里插入图片描述

代码实现

flag = 1
if flag:
    t_max = 50
    scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)
    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)
        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        scheduler_lr.step()
    plt.plot(epoch_list, lr_list, label="CosineAnnealingLR Scheduler\nT_max:{}".format(t_max))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

2.5 ReduceLRonPlateau

在这里插入图片描述

功能:监控指标,当指标不再变化则调整,例如:可以监控我们的loss或者准确率,当其不发生变化的时候,调整学习率。
主要参数:
• mode:min/max 两种模式
min模式:当某一个值不下降的时候我们调整学习率,通常用于监控损失
max模型:当某一个值不上升的时候我们调整学习率,通常用于监控精确度
• factor:调整系数
• patience:“耐心”,接受几次不变化
• cooldown:“冷却时间”,停止监控一段时间
• verbose:是否打印日志
• min_lr:学习率下限
• eps:学习率衰减最小值

代码实现

flag = 1
if flag:
    loss_value = 0.5
    accuray = 0.9
    factor = 0.1  # 学习率变换参数
    mode = "min"
    patience = 10  # 能接受多少轮不变化
    cooldown = 10  # 停止监控多少轮
    min_lr = 1e-4  # 设置学习率下限
    verbose = True  # 打印更新日志
    scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience,
                                                        cooldown=cooldown, min_lr=min_lr, verbose=verbose)
    for epoch in range(max_epoch):
        for i in range(iteration):
            # train(...)
            optimizer.step()
            optimizer.zero_grad()
        #if epoch == 5:
           # loss_value = 0.4
        scheduler_lr.step(loss_value) #监控的标量是否下降

输出结果
在这里插入图片描述

2.6 LambdaLR

在这里插入图片描述
功能:自定义调整策略
主要参数:
• lr_lambda:function or list

代码实现

flag = 1
if flag:

    lr_init = 0.1

    weights_1 = torch.randn((6, 3, 5, 5))
    weights_2 = torch.ones((5, 5))

    optimizer = optim.SGD([
        {'params': [weights_1]},
        {'params': [weights_2]}], lr=lr_init)

    # 设置两种不同的学习率调整方法
    lambda1 = lambda epoch: 0.1 ** (epoch // 20)  # 每到20轮的时候学习率变为原来的0.1倍
    lambda2 = lambda epoch: 0.95 ** epoch  # 将学习率进行指数下降

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        for i in range(iteration):
            # train(...)

            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

        lr_list.append(scheduler.get_lr())
        epoch_list.append(epoch)

        print('epoch:{:5d}, lr:{}'.format(epoch, scheduler.get_lr()))

    plt.plot(epoch_list, [i[0] for i in lr_list], label="lambda 1")
    plt.plot(epoch_list, [i[1] for i in lr_list], label="lambda 2")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.title("LambdaLR")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

通过lambda方法定义了两种不同的学习率下降策略。

三、学习率调整小结

  1. 有序调整:Step、MultiStep、Exponential 和 CosineAnnealing
  2. 自适应调整:ReduceLROnPleateau
  3. 自定义调整:Lambda

四、学习率初始化

1、设置较小数:0.01、0.001、0.0001
2、搜索最大学习率: 参考该篇《Cyclical Learning Rates for Training Neural Networks》
方法:我们可以设置学习率逐渐从小变大观察精确度的一个变化,下面这幅图,当学习率为0.055左右的时候模型精确度最高,当学习率大于0.055的时候精确度出现下降情况,所以在模型训练过程中我们可以设置学习率为0.055作为我们的初始学习率。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1358367.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

怎么查询网络出口IP

怎么查询自己的网络的出口IP 背景 一般跟第三方服务进行接口数据交互的时候,对方都会让我们提供调用接口的网络的出口IP,对方会把该IP地址加到对方的白名单中。这样我们才能有权限进行接口的访问。 解决办法 下面介绍三种常用的查询网络出口IP的办法…

DevOps(10)

目录 56.Docker的架构? 57.Docker镜像相关操作有哪些? 58.Docker容器相关操作有哪些? ​编辑59.如何查看Docker容器的日志? 60.如何启动Docker容器?参数含义? 61.如何进入Docker后台模式?有…

基于Java SSM框架实现旅游资源网站系统项目【项目源码+论文说明】

基于java的SSM框架实现旅游资源网站系统演示 摘要 本论文主要论述了如何使用JAVA语言开发一个旅游资源网站 ,本系统将严格按照软件开发流程进行各个阶段的工作,采用B/S架构,面向对象编程思想进行项目开发。在引言中,作者将论述旅…

Redis:原理+项目实战——Redis实战2(Redis实现短信登录(原理剖析+代码优化))

👨‍🎓作者简介:一位大四、研0学生,正在努力准备大四暑假的实习 🌌上期文章:Redis:原理项目实战——Redis实战1(session实现短信登录(并剖析问题)&#xff09…

Ambiq推出语音增强人工智能以消除物联网应用中的噪声

超低功耗半导体解决方案供应商Ambiq推出了其最新产品——神经网络语音增强器 (NNSE),并已将该方案加入到neuralSPOT的(开源模型)Model Zoo中。这一高度优化过的AI模型可以高效实时地将背景噪声从设备对话中去除,从而在嘈杂的环境中…

基于Kettle开发的web版数据集成开源工具(data-integration)-应用篇

目录 📚第一章 基本流程梳理📗页面基本操作📗对应后台服务流程 📚第二章 二开思路📗前端📗后端 🔼上一集:基于Kettle开发的web版数据集成开源工具(data-integration)-介绍篇 *️⃣主…

对话小仙炖副总裁张勇:内容价值将成为直播电商的核心趋势和竞争力

“ 激活中医典籍里的智慧,坚持内容化之路,服务好消费者。” 整理 | 飞族 编辑 | 渔舟 出品|极新&北京电子商务协会 随着直播电商的影响力越来越大,对品牌而言,直播不仅是一种单纯的卖货渠道,…

Open3D聚类算法

按照官网的例子使用聚类,发现结果是全黑的。 经过多次测试发现 eps3.3, min_points1这里是关键 min_points必须等于1否则无效果 import time import open3d as o3d; import numpy as np; import matplotlib.pyplot as plt#坐标 mesh_coord_frame o3d.geometry.Tria…

力扣题:高精度运算-1.3

力扣题-1.3 [力扣刷题攻略] Re:从零开始的力扣刷题生活 力扣题1:43. 字符串相乘 解题思想:类似计算时采用的竖式乘法。首先取得num2的低位,并补齐对应的0,然后与num1进行相乘,然后进行字符串的相加操作。…

使用 pg_stat_statements 优化查询

使用 pg_stat_statements 优化查询 就使用量和社区规模而言,PostgreSQL 是增长最快的数据库之一,得到许多专业开发人员的支持,并得到广泛的工具、连接器、库和可视化应用程序生态系统的支持。 PostgreSQL 也是可扩展的:使用 Postg…

实现HSRP-热备份路由协议

实现HSRP-热备份路由协议 <HSRP多组实现> 网络工程师必会的企业网络常用双机热备协议之HSRP。 实验拓扑: 实验目的: 通过配置多组HSRP实现网关自动切换和链接负载均衡,既当网络正常时PC1,PC3通过R1到达R3,PC2,PC4通过R2到达R3,当R1或R2发生故障时网关能自动切换,以确…

企业微信开发:自建应用:获取企业微信IP段(用于防火墙配置)

概述 在企业微信开发流程中&#xff0c;为了确保与企业微信API的网络通信安全&#xff0c;并适应防火墙配置要求&#xff0c;开发者需要获取企业微信API服务的IP地址范围。这样&#xff0c;仅允许与企业微信官方通信的合法请求通过防火墙&#xff0c;从而保障数据传输的安全性…

教学目标是什么

教学目标&#xff0c;作为教学活动的灵魂之所在&#xff0c;对于教育者和学生都至关重要。然而&#xff0c;你是否曾对此产生过疑问&#xff1a;教学目标究竟是什么&#xff1f;它又如何影响我们的教学活动呢&#xff1f; 教学目标就像一座灯塔&#xff0c;为教学活动指明方向&…

MapInfo Pro和Python基础知识

MapInfo Pro用户长期以来一直使用MapBasic脚本语言来自动化任务、构建自定义应用程序、创建Pro的特定领域自定义、将Pro与其他工具集成等。 MapBasic主要是一种编译语言&#xff0c;这对非程序员来说有点障碍。 我们确实有MapBasic窗口&#xff0c;它允许MapBasic语句和代码直接…

罗德与施瓦茨FSVA40信号和频谱分析仪

罗德与施瓦茨FSVA40是一款功能信号和频谱分析仪&#xff0c;适用于从事射频系统的开发、生产、安装和服务的用户。FSVA40信号和频谱分析仪系列始终提供最佳的价格和性能组合&#xff0c;无论是根据最新通信标准测试生产中的无线设备&#xff0c;还是测量低相位噪声、高灵敏度和…

根本记不住MySQL进阶查询语句

1 MySQL进阶查询 1.1 MySQL进阶查询的语句 全文以数据库location和Store_Info为实例 ---- SELECT ----显示表格中一个或数个字段的所有数据记录 语法&#xff1a;SELECT "字段" FROM "表名"; select 列名 from 表名 ; ---- DISTINCT ----不显示重复的数…

农业银行RPA实践 3大典型案例分析

零接触开放金融服务在疫情之下被越来越多的银行和客户所认同&#xff0c;引起了更广泛的持续关注&#xff0c;各家银行纷纷开展产品服务创新&#xff0c;加速渠道迁移&#xff0c;同时通过远程办公、构建金融生态等方式积极推进零接触开放金融体系建设。 随着商业银行科技力量的…

后端开发——JDBC的学习(三)

本篇继续对JDBC进行总结&#xff1a; ①通过Service层与Dao层实现转账的练习&#xff1b; ②重点&#xff1a;由于每次使用连接就手动创建连接&#xff0c;用完后就销毁&#xff0c;这样会导致资源浪费&#xff0c;因此引入连接池&#xff0c;练习连接池的使用&#xff1b; …

新年福利|这款价值数万的报表工具永久免费了

随着数据资产的价值逐渐凸显&#xff0c;越来越多的企业会希望采用报表工具来处理数据分析&#xff0c;了解业务经营状况&#xff0c;从而辅助经营决策。不过&#xff0c;企业在选型报表工具的时候经常会遇到以下几个问题&#xff1a; 各个报表工具有很多功能和特性&#xff0c…

Python数据处理库之tablib详解

概要 Python 提供了许多库和工具来处理数据&#xff0c;其中之一就是 tablib。tablib 是一个功能强大且易于使用的库&#xff0c;用于处理各种数据格式&#xff0c;包括Excel、CSV、JSON等。它不仅可以用于数据导入和导出&#xff0c;还支持数据转换、过滤、合并等操作。本文将…