人工智能-优化算法之学习率调度器

news2025/1/12 6:04:12

学习率调度器

到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。 然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:

  • 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

  • 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。

  • 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。

  • 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 (Izmailov et al, 2018)来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。 在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。 由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
from d2l import torch as d2l


def net_fn():
    model = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10))

    return model

loss = nn.CrossEntropyLoss()
device = d2l.try_gpu()

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

# 代码几乎与d2l.train_ch6定义在卷积神经网络一章LeNet一节中的相同
def train(net, train_iter, test_iter, num_epochs, loss, trainer, device,
          scheduler=None):
    net.to(device)
    animator = d2l.Animator(xlabel='epoch', xlim=[0, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])

    for epoch in range(num_epochs):
        metric = d2l.Accumulator(3)  # train_loss,train_acc,num_examples
        for i, (X, y) in enumerate(train_iter):
            net.train()
            trainer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            trainer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            train_loss = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % 50 == 0:
                animator.add(epoch + i / len(train_iter),
                             (train_loss, train_acc, None))

        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch+1, (None, None, test_acc))

        if scheduler:
            if scheduler.__module__ == lr_scheduler.__name__:
                # UsingPyTorchIn-Builtscheduler
                scheduler.step()
            else:
                # Usingcustomdefinedscheduler
                for param_group in trainer.param_groups:
                    param_group['lr'] = scheduler(epoch)

    print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')

让我们来看看如果使用默认设置,调用此算法会发生什么。 例如设学习率为0.3并训练30次迭代。 留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。 两条曲线之间的间隙表示过拟合。

lr, num_epochs = 0.3, 30
net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device)

train loss 0.128, train acc 0.951, test acc 0.885

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1276240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

知识管理平台Confluence:win10安装confluence

文章目录 介绍主要功能 安装教程安装java运行平台JRE安装数据库Postgresql在Postgresql创建confluence使用的数据库创建数据库用户创建数据库 安装confluence注册confluence启动confluence 参考链接 介绍 Confluence 是由澳大利亚软件公司 Atlassian 开发的企业协作平台。它提…

flutter开发实战-ValueListenableBuilder实现局部刷新功能

flutter开发实战-ValueListenableBuilder实现局部刷新功能 在创建的新工程中,点击按钮更新counter后,通过setState可以出发本类的build方法进行更新。当我们只需要更新一小部分控件的时候,通过setState就不太合适了,这就需要进行…

canvas基础:渲染文本

canvas实例应用100 专栏提供canvas的基础知识,高级动画,相关应用扩展等信息。 canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重要的帮助。 文章目录 示例…

java设计模式学习之【桥接模式】

文章目录 引言桥接模式简介定义与用途:实现方式 使用场景优势与劣势桥接模式在Spring中的应用绘图示例代码地址 引言 想象你正在开发一个图形界面应用程序,需要支持多种不同的窗口操作系统。如果每个系统都需要写一套代码,那将是多么繁琐&am…

scrapy爬虫中间件和下载中间件的使用

一、关于中间件 之前文章说过,scrapy有两种中间件:爬虫中间件和下载中间件,他们的作用时间和位置都不一样,具体区别如下: 爬虫中间件(Spider Middleware) 作用: 爬虫中间件主要负…

SQL Server 2016(基本概念和命令)

1、文件类型。 【1】主数据文件:数据库的启动信息。扩展名为".mdf"。 【2】次要(辅助)数据文件:主数据之外的数据都是次要数据文件。扩展名为".ndf"。 【3】事务日志文件:包含恢复数据库的所有事务…

深入理解前端路由:构建现代 Web 应用的基石(下)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

2024年天津天狮学院专升本专业课报名缴费流程

天津天狮学院高职升本缴费流程 一、登录缴费系统 二、填写个人信息,进行缴费 1.在姓名处填写“姓名”,学号处填写“身份证号”,如下图所示: 此处填写身份证号 2.单击查询按钮,显示报考专业及缴费列表,…

JPA数据源Oracle异常记录

代码执行异常 ObjectOptimisticLockingFailureException org.springframework.orm.ObjectOptimisticLockingFailureException: Batch update returned unexpected row count from update [0]; actual row count: 0; expected: 1; nested exception is org.hibernate.StaleSta…

从0开始学习JavaScript--JavaScript ES6 模块系统

JavaScript ES6(ECMAScript 2015)引入了官方支持的模块系统,使得前端开发更加现代化和模块化。本文将深入探讨 ES6 模块系统的各个方面,通过丰富的示例代码详细展示其核心概念和实际应用。 ES6 模块的基本概念 1 模块的导出 ES…

java原子类型

AtomicBoolean AtomicInteger AtomicLong AtomicReference<V> StringBuilder - 不是原子类型。StringBuilder 是 java.lang 包下的类 用法&#xff1a;无需回调改变数值

基于springboot + vue框架的网上商城系统

qq&#xff08;2829419543&#xff09;获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;springboot 前端&#xff1a;采用vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xf…

Linux:vim的简单使用

个人主页 &#xff1a; 个人主页 个人专栏 &#xff1a; 《数据结构》 《C语言》《C》《Linux》 文章目录 前言一、vim的基本概念二、vim的基本操作三、vim正常模式命令集四、vim底行模式命令集五、.xxx.swp的解决总结 前言 本文是对Linux中vim使用的总结 一、vim的基本概念 …

C语言:求十个数中的平均数

分析&#xff1a; 程序中定义了一个average函数&#xff0c;用于计算分数的平均值。该函数接受一个包含10个分数的数组作为参数&#xff0c;并返回平均值。在主函数main中&#xff0c;首先提示输入10个分数&#xff0c;然后使用循环读取输入的分数&#xff0c;并将它们存储在名…

iris+vue上传到本地存储【go/iris】

iris部分 //main.go package mainimport ("fmt""io""net/http""os" )//上传视频文件部分 func uploadHandler_video(w http.ResponseWriter, r *http.Request) {// 解析上传的文件err : r.ParseMultipartForm(10 << 20) // 设置…

Nacos 架构原理

基本架构及概念​ 服务 (Service)​ 服务是指一个或一组软件功能&#xff08;例如特定信息的检索或一组操作的执行&#xff09;&#xff0c;其目的是不同的客户端可以为不同的目的重用&#xff08;例如通过跨进程的网络调用&#xff09;。Nacos 支持主流的服务生态&#xff0c…

基于springboot + vue在线考试系统

qq&#xff08;2829419543&#xff09;获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;springboot 前端&#xff1a;采用vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xf…

更改Jupyter Notebook 默认存储路径

import osprint(os.path.abspath(.)) 然后打开cmd,输入&#xff1a; jupyter notebook --generate-config 按照路径在本地文件夹中找到那个文件。 然后找到"c.NotebookApp.notebook_dir"这条语句&#xff1a;&#xff08;直接通过"crtlf"输入关键字找阿 …

2661. 找出叠涂元素 : 常规哈希表运用题

题目描述 这是 LeetCode 上的 「2661. 找出叠涂元素」 &#xff0c;难度为 「中等」。 Tag : 「模拟」、「哈希表」、「计数」 给你一个下标从 开始的整数数组 arr 和一个 的整数矩阵 mat。 arr 和 mat 都包含范围 &#xff0c; 内的所有整数。 从下标 开始遍历 arr 中的每…

经典神经网络——VGGNet模型论文详解及代码复现

论文地址&#xff1a;1409.1556.pdf。 (arxiv.org)&#xff1b;1409.1556.pdf (arxiv.org) 项目地址&#xff1a;Kaggle Code 一、背景 ImageNet Large Scale Visual Recognition Challenge 是李飞飞等人于2010年创办的图像识别挑战赛&#xff0c;自2010起连续举办8年&#xf…