【PyTorch实战演练】自调整学习率实例应用(附代码)

news2025/1/11 22:35:49

目录

0. 前言

1. 自调整学习率的常用方法

1.1  ExponentialLR 指数衰减方法

1.2 CosineAnnealingLR 余弦退火方法

1.3 ChainedScheduler 链式方法

2. 实例说明

3. 结果说明

3.1 余弦退火法训练过程

3.2 指数衰减法训练过程

3.3 恒定学习率训练过程

3.4 结果解读

4. 完整代码


0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文介绍深度学习训练中经常能使用到的实用技巧——自调整学习率,并基于PyTorch框架通过实例进行使用,最后对比同样条件下以自调整学习率和固定学习率在模型训练上的表现。

在深度学习模型训练过程中,经常会出现损失值不收敛的头疼情况,令人怀疑自己设计的网络模型存在不合理或者错误之处,但是导致这种情况的真正原因往往非常简单——就是学习率的设定不合理。

这就导致在深度学习模型训练的初期,可能需要“摸索”一个比较合适的学习率,在后期逐渐细化的过程可能还需要尝试其他的学习率,进行“分段训练”。学习率的这种“摸索”费时费力,因此自调整学习率的方法应运而生。

1. 自调整学习率的常用方法

自调整学习率是一种通用的方法,通过在训练期间动态地更新学习率以使模型更好地收敛,提高模型的精确度和稳定性。在PyTorch框架 torch.optim.lr_scheduler 中集成了大量的自调整学习率的方法:

lr_scheduler.LambdaLR

lr_scheduler.MultiplicativeLR

lr_scheduler.StepLR

lr_scheduler.MultiStepLR

lr_scheduler.ConstantLR

lr_scheduler.LinearLR

lr_scheduler.ExponentialLR

lr_scheduler.PolynomialLR

lr_scheduler.CosineAnnealingLR

lr_scheduler.ChainedScheduler

lr_scheduler.SequentialLR

lr_scheduler.ReduceLROnPlateau

lr_scheduler.CyclicLR

lr_scheduler.OneCycleLR

lr_scheduler.CosineAnnealingWarmRestarts

这里具体方法虽然很多,但是每种自动调整学习率的策略都比较简单,本文仅介绍以下两种常见的方法,其余可以查看PyTorch官网说明。

1.1  ExponentialLR 指数衰减方法

这种方法的学习率为:

lr = lr_{initial} \cdot \gamma^{epoch}

这里有两个参数:

  • lr_{initial}:初始学习率
  • \gamma:衰减指数,取值范围(0,1),一般来说取值要非常接近1(>0.95,甚至要>0.99),否则在动辄几百几千的迭代次数epoch条件下学习率很快会衰减到0

PyTorch调用代码为:

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch)
  • optimizer:训练该模型的优化算法
  • gamma:上面的衰减指数\gamma
  • last_epoch:不用理会,默认-1即可
1.2 CosineAnnealingLR 余弦退火方法

首先我要吐槽下不知道是哪个大聪明给起的这个名字,我学过《机械加工原理与工艺》,但我不明白这个算法和退火有什么关系?名字非常唬人,方法并不复杂。

这个方法就是让学习率按余弦曲线进行震荡,其震荡范围为[0, lr_initial],震荡周期为T_max:

PyTorch调用代码为:

torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, last_epoch)

其参数同上说明不再赘述。

1.3 ChainedScheduler 链式方法

这个方法说白了就是多种自调节学习率方法进行组合使用,例如以下:

scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(self.opt, gamma=0.9)
scheduler = ChainedScheduler([scheduler1, scheduler2])

最后再强调一下,无论使用哪种方法,在coding时别忘了每个epoch学习后加上 scheduler.step() ,这样才能更新学习率。

2. 实例说明

本文目的是验证自调节学习率的作用,选用一个简单的“平方网络”实例,即期望输出为输入值的平方。

训练输入数据x_train,输出数据y_train分别为:

x_train = torch.tensor([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.01,0.04,0.09,0.16,0.25,0.36,0.49,0.64,0.81,1],dtype=torch.float32).unsqueeze(-1)

验证输入数据x_val为:

x_val = torch.tensor([0.15,0.25,0.35,0.45,0.55,0.65,0.75,0.85,0.95],dtype=torch.float32).unsqueeze(-1)

网络模型使用5层全连接网络,对应这种简单问题足够用了:

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=1, out_features=3),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=3,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=10),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=10,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=1),
            torch.nn.ReLU(),
        )

优化器选用Adam,训练组及验证组的损失函数都选用MSE均方差。

3. 结果说明

对于本文对比的三种方法,其训练参数设定如下表:

学习率调整方法训练参数设定
指数衰减迭代次数epoch=2000,初始学习率initial_lr=0.0005,衰减指数gamma=0.99995
余弦退火迭代次数epoch=2000,初始学习率initial_lr=0.001(因为学习率均值为初始值的一半,公平起见给余弦退火方法初始学习率*2),震荡周期T_max=200
恒定学习率学习率lr=0.0005

各种方法的训练过程如下图:

3.1 余弦退火法训练过程

3.2 指数衰减法训练过程

3.3 恒定学习率训练过程

3.4 结果解读
  1. 指数衰减方法的gamma真的要选择非常接近1才行!理由我上面也解释了,下面是gamma=0.999时的loss下降情况,可见其下降速度之慢;

  2. 本实例中,使用余弦退火方法训练的模型在验证组表现最好,验证组损失值val=0.0027,其次是指数衰减法(val=0.0037),最差是恒定学习率(val=0.0053),但这个结果仅限本实例中,并不是说哪种方法就比哪种方法好,在不同模型上可能要因地制宜选择合适的方法
  3. 目前这些所谓的“自调节学习率”我认为仅能算是“半自动调节”,因为仍有超参数需要事前设定(例如衰减指数gamma)。而选择哪种方法最好,以及这种方法的参数如何设定呢?可能还是需要做一定的“摸索”,但是相比恒定学习率肯定会节省很多时间!

4. 完整代码

import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse


torch.manual_seed(25)

x_train = torch.tensor([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.01,0.04,0.09,0.16,0.25,0.36,0.49,0.64,0.81,1],dtype=torch.float32).unsqueeze(-1)
x_val = torch.tensor([0.15,0.25,0.35,0.45,0.55,0.65,0.75,0.85,0.95],dtype=torch.float32).unsqueeze(-1)

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=1, out_features=3),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=3,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=10),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=10,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=1),
            torch.nn.ReLU(),
        )

    def forward(self,x):
        return self.layers(x)


criterion = torch.nn.MSELoss()

def train_with_CosinAnneal(epoch, initial_lr, T_max):
    linear1 = Linear()
    opt = torch.optim.Adam(linear1.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=T_max, last_epoch=-1)
    last_epoch_loss = 0

    for i in tqdm(range(epoch)):
        opt.zero_grad()
        total_loss = 0
        for j in range(len(x_train)):
            output = linear1(x_train[j])
            loss = criterion(output,y_train[j])

            total_loss = total_loss + loss.detach()
            loss.backward()
            opt.step()
            if i == epoch-1:
                last_epoch_loss = total_loss
        scheduler.step()

        cur_lr = scheduler.get_lr()  #查看学习率变化情况
        # plt.scatter(i, cur_lr, s=2, c='g')  

        plt.scatter(i, total_loss,s=2,c='r')

    val = 0
    for x in x_val:
        val = val + ((linear1(x) - x * x) / x * x)**2

    plt.title('CosinAnneal---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()


def train_with_ExponentialLR(epoch, initial_lr, gamma):
    linear2 = Linear()
    opt = torch.optim.Adam(linear2.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=gamma, last_epoch=-1)
    last_epoch_loss = 0

    for i in tqdm(range(epoch)):
        opt.zero_grad()
        total_loss = 0
        for j in range(len(x_train)):
            output = linear2(x_train[j])
            loss = criterion(output,y_train[j])

            total_loss = total_loss + loss.detach()
            loss.backward()
            opt.step()
            if i == epoch-1:
                last_epoch_loss = total_loss

        scheduler.step()

        cur_lr = scheduler.get_lr()

        # plt.scatter(i, cur_lr, s=2, c='g')
        plt.scatter(i, total_loss,s=2,c='g')


    val = 0
    for x in x_val:
        val = val + ((linear2(x) - x * x) / x * x)**2

    plt.title('ExponentialLR---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()

def train_without_lr_scheduler(epoch, initial_lr):
    linear3 = Linear()
    opt = torch.optim.Adam(linear3.parameters(), lr=initial_lr)
    last_epoch_loss = 0

    for i in tqdm(range(epoch)):
        opt.zero_grad()
        total_loss = 0
        for j in range(len(x_train)):
            output = linear3(x_train[j])
            loss = criterion(output, y_train[j])
            total_loss = total_loss + loss.detach()
            loss.backward()
            opt.step()
            if i == epoch-1:
                last_epoch_loss = total_loss

        plt.scatter(i, total_loss, s=2, c='b')

    val = 0
    for x in x_val:
        val = val + ((linear3(x) - x * x) / x * x)**2


    plt.title('without_lr_scheduler---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()


if __name__ == '__main__' :
    epoch = 2000
    initial_lr = 0.0005
    gamma = 0.99995
    T_max = 200



    Cosine_Anneal_args = [epoch,initial_lr*2,T_max]
    train_with_CosinAnneal(*Cosine_Anneal_args)

    # ExponentialLR_args = [epoch, initial_lr, gamma]
    # train_with_ExponentialLR(*ExponentialLR_args)
    #
    # without_scheduler_args = [epoch, initial_lr]
    # train_without_lr_scheduler(*without_scheduler_args)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1122525.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jenkins 相关内容

Jenkins 相关内容 什么是 Jenkins,它是如何工作的?Jenkins 中自由式项目和管道之间的区别什么是Jenkins管道,它们如何工作?第一次如何安装Jenkins并进行设置?什么是 Jenkins 插件,如何安装它们?…

中间件安全-CVE复现WeblogicJenkinsGlassFish漏洞复现

目录 服务攻防-中间件安全&CVE复现&Weblogic&Jenkins&GlassFish漏洞复现中间件-Weblogic安全问题漏洞复现CVE_2017_3506漏洞复现 中间件-JBoos安全问题漏洞复现CVE-2017-12149漏洞复现CVE-2017-7504漏洞复现 中间件-Jenkins安全问题漏洞复现CVE-2017-1000353漏…

【问题记录】解决Qt连接MySQL报“QMYSQL driver not loaded”以及不支持MySQL事务操作的问题!

环境 Windows 11 家庭中文版,64 位操作系统, 基于 x64 的处理器Qt 5.15.2 MinGW 32-bitmysql Ver 14.14 Distrib 5.7.42, for Win32 (AMD64) 问题情况 在Qt 5.15.2 中编写连接MySQL数据库代码后,使用 MinGW 32-bit 构建套件进行编译运行后,报…

Python基础入门例程5-NP5 格式化输出(一)

描述 牛牛、牛妹和牛可乐正在Nowcoder学习Python语言,现在给定他们三个当中的某一个名字name, 假设输入的name为Niuniu,则输出 I am Niuniu and I am studying Python in Nowcoder! 请按以上句式输出相应的英文句子。 输入描述&#xff1…

驱动开发1 概念、内核模块编程、内核消息打印函数printk函数的使用、内核模块传参、内核导出符号

1 驱动相关概念 2 内核模块编程 内核模块编写实例代码注释 #include <linux/init.h> #include <linux/module.h>//入口函数&#xff0c;安装内核模块时执行 static int __init mycdev_init(void) {//static 修饰当前函数只能在本文件使用//int 函数的返回值类型&a…

063:mapboxGL常见错误:Style is not done loading(原因及解决办法)

第063个 点击查看专栏目录 作者在做vue+mapbox的项目,将geojson的数据加载到地图上来,形成的效果图如下 但是在处理的时候,遇到过这个一个错误,提示信息如下: vue.runtime.esm.js:3049 Error: Style is not done loadingat Qt._checkLoaded (mapbox-gl.js:36:1)at Qt.…

《计算机视觉中的多视图几何》笔记(14)

14 Affine Epipolar Geometry 本章主要是在仿射摄像机的情况下重新考虑对极几何&#xff0c;也就是仿射对极几何。 仿射摄像机的优点是它是线性的&#xff0c;所以很多最优化算法可以用线性代数的知识解决。如果是一般的投影摄像机&#xff0c;很多算法就不是线性的了&#x…

[架构之路-241]:目标系统 - 纵向分层 - 企业信息化与企业信息系统(多台企业应用单机组成的企业信息网络)

目录 前言&#xff1a; 一、什么是信息系统&#xff1a;计算机软件硬件系统 1.1 什么是信息 1.2 什么是信息系统 1.3 什么是信息技术 1.4 什么是信息化与信息化转型 1.5 什么是数字化与数字化转型&#xff08;信息化的前提&#xff09; 1.6 数字化与信息化的比较 1.7 …

Android 10.0 Launcher3定制化之动态时钟图标功能实现

1.概述 在10.0的系统产品rom定制化开发中,在Launcher3中的定制化的一些功能中,对于一些产品要求需要实现动态时钟图标功能,这就需要先绘制时分秒时针表盘,然后 每秒刷新一次时钟图标,时钟需要做到实时更新,做到动态时钟的效果,接下来就来分析这个功能的实现 如图: 2.动…

基于nodejs+vue市民健身中心网上平台mysql

市民健身中心网上平台分为用户界面和管理员界面&#xff0c; 用户信息模块&#xff1a;管理员可在后台添加、删除普通用户&#xff0c;查看、编辑普通用户的信息。 课程表管理模块&#xff1a;管理员可对课程表进行修改任课教师、新增某一堂课、删除某一堂课、查找课程、修改…

论文导读|9月MSOM文章精选:智慧城市运筹

推文作者&#xff1a;郭浩然 编者按 本期论文导读围绕“智慧城市运筹”这一话题&#xff0c;简要介绍并分析了近期的三篇MSOM文章&#xff0c;分别涉及了最后一公里配送中的新模式&#xff1a;“司机辅助”&#xff0c;易腐库存管理的新策略&#xff1a;“截断平衡”&#xff0…

Openssl数据安全传输平台004:套接字C语言API封装为C++类 / 客户端及服务端代码框架和代码实现

文章目录 0. 代码仓库1. 客户端C API2. 客户端C API的封装分析2.1 sckClient_init()和sckClient_destroy()2.2 sckClient_connect2.3 sckClient_closeconn()2.4 sckClient_send()2.5 sckClient_rev()2.6 sck_FreeMem 3. 客户端C API4. 服务端C API5. 服务端C6. 客户端和服务端代…

性能测试LoadRunner02

本篇主要讲&#xff1a;通过Controller设计简单的测试场景&#xff0c;可以简单的分析性能测试报告。 Controller 设计场景 Controller打开方式 1&#xff09;通过VUG打开 2&#xff09;之间双击Controller 不演示了&#xff0c;双击打开&#xff0c;选择Manual Scenario自…

buuctf[HCTF 2018]WarmUp 1

题目环境&#xff1a; 发现除了表情包&#xff0c;再无其他F12试试发现source.php文件访问这个文件&#xff0c;格式如下&#xff1a;url/source.php回显如下&#xff1a;PHP代码审计&#xff1a; <?php highlight_file(__FILE__); class emmm {public static function ch…

Linux笔记之diff工具软件P4merge的使用

Linux笔记之diff工具软件P4merge的使用 code review! 文章目录 Linux笔记之diff工具软件P4merge的使用1.安装和配置2.使用&#xff1a;p4merge a.cc b.cc3.配置git 参考博文: Ubuntu Git可视化比较工具 P4Merge 的安装/配置及使用 1.安装和配置 $ wget https://cdist2.per…

分享一下我家网络机柜,家庭网络设备推荐

家里网络机柜搞了几天终于搞好了&#xff0c;非专业的&#xff0c;走线有点乱&#xff0c;勿喷。 从上到下的设备分别是&#xff1a; 无线路由器&#xff08;当ap用&#xff09;:TL-XDR6088 插排&#xff1a;德木pdu机柜插排 硬盘录像机&#xff1a;TL-NVR6108-L8P 第二排左边…

OpenGL 环境搭建和 hello world 程序(LearnOpenGL P1)

文章目录 OpenGLGLFW & CMake链接到 Hello OpenGL&#xff01;GLAD运行测试 OpenGL 什么是 OpenGL&#xff0c;OpenGL 能做什么在此不再赘述 运行 OpenGL 需要准备的有&#xff1a; CMake&#xff1a;用于执行编译VS&#xff1a;我使用的是 Visual Studio 17 2022 版本G…

系统架构师备考倒计时13天(每日知识点)

1. 数据仓库四大特点 面向主题的。操作型数据库的数据组织面向事务处理任务&#xff0c;各个业务系统之间各自分离&#xff0c;而数据仓库中的数据是按照一定的主题域进行组织的。集成的。数据仓库中的数据是在对原有分散的数据库数据抽取、清理的基础上经过系统加工、汇总和整…

[ 云计算 | AWS 实践 ] Java 如何重命名 Amazon S3 中的文件和文件夹

本文收录于【#云计算入门与实践 - AWS】专栏中&#xff0c;收录 AWS 入门与实践相关博文。 本文同步于个人公众号&#xff1a;【云计算洞察】 更多关于云计算技术内容敬请关注&#xff1a;CSDN【#云计算入门与实践 - AWS】专栏。 本系列已更新博文&#xff1a; [ 云计算 | …

100 # mongoose 的使用

mongoose elegant mongodb object modeling for node.js https://mongoosejs.com/ 安装 mongoose npm i mongoose基本示例 const mongoose require("mongoose");// 1、连接 mongodb let conn mongoose.createConnection("mongodb://kaimo313:kaimo313loc…