【Python】学习率调整策略详解和示例

news2025/1/12 17:47:01

学习率调整得当将有助于算法快速收敛和获取全局最优,以获得更好的性能。本文对学习率调度器进行示例介绍。

  • 学习率调整的意义
  • 基础示例
    • 无学习率调整方法
    • 学习率调整方法一
    • 多因子调度器
    • 余弦调度器
  • 结论

学习率调整的意义

首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果(陷入局部最优)。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 简而言之,我们希望速率衰减,但要比慢,这样能成为解决凸问题的不错选择

另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。本文将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

基础示例

我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。

无学习率调整方法

import math
import torch
from torch import nn
from torch.optim import lr_scheduler, SGD
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def load_data_fashion_mnist(batch_size):
    # 定义数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 加载训练集和测试集
    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader
def net_fn():
    model = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10))
    return model


def train(net, train_loader, test_loader, num_epochs, loss, optimizer, device, scheduler=None):
    net.to(device)
    running_loss = 0.0
    train_losses = []
    test_losses = []
    test_accuracies = []
    for epoch in range(num_epochs):
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = net(inputs)
            loss_value = loss(outputs, labels)

            # Backward and optimize
            loss_value.backward()
            optimizer.step()

            # Print statistics
            running_loss += loss_value.item()

            # if i % 200 == 199:  # print every 200 mini-batches
            #     print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200}')
            #     running_loss = 0.0
        train_losses.append(running_loss / len(train_loader))
        # Evaluate the model on the test dataset
        test_loss, test_acc = evaluate(net, test_loader, device)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)
        print(f'Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Test Acc: {test_accuracies[-1]:.2f}')

        if scheduler:
            if scheduler.__module__ == lr_scheduler.__name__:

                scheduler.step()
            else:

                for param_group in  optimizer.param_groups:
                    param_group['lr'] = scheduler(epoch)

    plt.figure(figsize=(10, 6))
    plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
    plt.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
    plt.plot(range(1, num_epochs + 1), test_accuracies, label='Test Accuracy')
    plt.title('Training, Test Losses and Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Loss / Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig("1.jpg")
    plt.show()





def evaluate(model, data_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            test_loss += nn.CrossEntropyLoss(reduction='sum')(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
    test_loss /= len(data_loader.dataset)
    accuracy = correct / len(data_loader.dataset)
    #accuracy = 100. * correct / len(data_loader.dataset)
    return test_loss, accuracy


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the model
model = net_fn()

# Define the loss function
loss = nn.CrossEntropyLoss()

# Define the optimizer
lr=0.3
optimizer = SGD(model.parameters(), lr=lr)



# Load the dataset
batch_size=128
train_loader, test_loader=load_data_fashion_mnist(batch_size)
num_epochs=30
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device)

这里没有使用学习率调整策略。训练过程和结果如下图所示:

.
.
.
.
Epoch 23, Train Loss: 0.1247, Test Loss: 0.3939, Test Acc: 0.90
Epoch 24, Train Loss: 0.1236, Test Loss: 0.4370, Test Acc: 0.89
Epoch 25, Train Loss: 0.1167, Test Loss: 0.4117, Test Acc: 0.89
Epoch 26, Train Loss: 0.1169, Test Loss: 0.4440, Test Acc: 0.89
Epoch 27, Train Loss: 0.1163, Test Loss: 0.4336, Test Acc: 0.89
Epoch 28, Train Loss: 0.1055, Test Loss: 0.4312, Test Acc: 0.90
Epoch 29, Train Loss: 0.1065, Test Loss: 0.4942, Test Acc: 0.89
Epoch 30, Train Loss: 0.1051, Test Loss: 0.4763, Test Acc: 0.89

在这里插入图片描述

学习率调整方法一

设置在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。 例如,以动态的方式来响应优化的进展情况。

在代码最后添加SquareRootScheduler类,并更新train()函数参数,其它内容不变。

class SquareRootScheduler:
    def __init__(self, lr=0.1):
        self.lr = lr

    def __call__(self, num_update):
        return self.lr * pow(num_update + 1.0, -0.5)

scheduler = SquareRootScheduler(lr=0.1)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行代码,可得相应参数值和变化过程,如下所示。

Epoch 23, Train Loss: 0.1823, Test Loss: 0.2811, Test Acc: 0.90
Epoch 24, Train Loss: 0.1801, Test Loss: 0.2800, Test Acc: 0.90
Epoch 25, Train Loss: 0.1767, Test Loss: 0.2819, Test Acc: 0.90
Epoch 26, Train Loss: 0.1747, Test Loss: 0.2800, Test Acc: 0.91
Epoch 27, Train Loss: 0.1720, Test Loss: 0.2818, Test Acc: 0.90
Epoch 28, Train Loss: 0.1689, Test Loss: 0.2856, Test Acc: 0.90
Epoch 29, Train Loss: 0.1669, Test Loss: 0.2907, Test Acc: 0.90
Epoch 30, Train Loss: 0.1641, Test Loss: 0.2813, Test Acc: 0.90

在这里插入图片描述
我们可以看出曲线比没有策略时平滑了很多,效果有所提升。

多因子调度器

多因子调度器。
在这里插入图片描述
在这里插入图片描述
代码部分修改:

scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)

运行结果为:
在这里插入图片描述
可见效果不理想,出现过拟合现象。

余弦调度器

余弦调度器是 (Loshchilov and Hutter, 2016)提出的一种启发式算法。 它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。 这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在
之间。
在这里插入图片描述
代码中添加CosineScheduler类和修改scheduler。

class CosineScheduler:
    def __init__(self, max_update, base_lr=0.01, final_lr=0,
               warmup_steps=0, warmup_begin_lr=0):
        self.base_lr_orig = base_lr
        self.max_update = max_update
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.warmup_begin_lr = warmup_begin_lr
        self.max_steps = self.max_update - self.warmup_steps

    def get_warmup_lr(self, epoch):
        increase = (self.base_lr_orig - self.warmup_begin_lr) \
                       * float(epoch) / float(self.warmup_steps)
        return self.warmup_begin_lr + increase

    def __call__(self, epoch):
        if epoch < self.warmup_steps:
            return self.get_warmup_lr(epoch)
        if epoch <= self.max_update:
            self.base_lr = self.final_lr + (
                self.base_lr_orig - self.final_lr) * (1 + math.cos(
                math.pi * (epoch - self.warmup_steps) / self.max_steps)) / 2
        return self.base_lr


#scheduler = SquareRootScheduler(lr=0.1)
#scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)
scheduler = CosineScheduler(max_update=20, base_lr=0.3, final_lr=0.01)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行结果如下:

在这里插入图片描述
过拟合现象消失,效果提升。

结论

在开发时应根据自己需要,选择合适的学习率调整策略。优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。

注:部分内容摘选子书籍《动手学深度学习》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1548273.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件测试技术之登录页面测试用例的设计方法

相信大家都有过写登录测试用例的经验&#xff0c;相较于开发人员编写代码而言&#xff0c;测试人员编写用例同样重要。本文作者总结了一些关于登录用例的经验。 一、功能测试用例设计&#xff1a; 1、正常登录场景 测试用例1&#xff1a;输入正确的用户名和密码&#xff0c;验证…

对于提高Web安全,WAF能有什么作用

数字化时代&#xff0c;网络安全已经成为了一个不可忽视的重要议题。网络攻击事件频发&#xff0c;各种安全隐患层出不穷&#xff0c;如何有效地保护我们的网络空间&#xff0c;确保信息安全&#xff0c;已成为一项迫切的任务。而Web应用防火墙&#xff0c;正是守护网络安全的一…

GitHub学生认证

文件、证明之类的一定要用英文。 我先是用有道网页翻译把学信网的报告翻译成了英文&#xff0c;然后截图传上去&#xff0c; 给我这个答复 所以要先2FA认证、支付信息填好。 2FA认证&#xff1a;Github开启2FA双重验证 - 知乎 (zhihu.com) 支付信息&#xff1a;点击Setting…

找茬游戏小程序源码系统:封面广告+插屏广告 自带流量主低成本 带完整的安装代码包以及搭建教程

近年来&#xff0c;小程序市场持续火爆&#xff0c;各类小程序层出不穷。找茬游戏小程序作为其中的一种&#xff0c;以其独特的游戏形式和良好的用户体验&#xff0c;吸引了大量用户。然而&#xff0c;对于许多开发者和商家来说&#xff0c;开发一款高质量的找茬游戏小程序并非…

无人车+工厂车间集成无缝,这款网关产品了解一下

​诸位朋友们,大家好!今天给大家介绍一款引领工业无人化发展的黑科技 —— 星创易联科技的SV900-5G车载网关。 相信大家对无人驾驶技术都很感兴趣,它代表着未来出行和生产的全新方式。而要实现真正的"无人化",离不开无人车网关这个智能大脑的作用。SV900就是一款专为…

Openlayers 入门教程(一):应该如何学习 Openlayers

还是大剑师兰特&#xff1a;曾是美国某知名大学计算机专业研究生&#xff0c;现为航空航海领域高级前端工程师&#xff1b;CSDN知名博主&#xff0c;GIS领域优质创作者&#xff0c;深耕openlayers、leaflet、mapbox、cesium&#xff0c;canvas&#xff0c;webgl&#xff0c;ech…

联机分析处理技术

目录 一、OLAP概述&#xff08;一&#xff09;OLAP的定义&#xff08;二&#xff09;OLAP的12条准则&#xff08;三&#xff09;OLAP的简要准则&#xff08;四&#xff09;OLAP系统的基本结构 二、OLAP的多维分析操作&#xff08;一&#xff09;切片&#xff08;二&#xff09;…

电脑访问网页获取路由器WAN口内网IP

因为运维过程中容易出现路由器配置了固定IP但是没人知道后台密码&#xff0c;不确定这个办公室的IP地址&#xff0c;且使用tracert路由追踪也只会出现路由器的LAN口网关并不会出现WAN口IP。 今日正好遇到了个好方法&#xff0c;经过测试可以正常使用。 方法如下&#xff1a; 内…

O2OA(翱途)开发平台-快速入门开发一个门户实例

O2OA(翱途)开发平台[下称O2OA开发平台或者O2OA]拥有门户页面定制与集成的能力&#xff0c;平台通过门户定制&#xff0c;可以根据企业的文化&#xff0c;业务需要设计符合企业需要的统一信息门户&#xff0c;系统首页等UI界面。本篇主要介绍通过门户管理系统如何快速的进行一个…

宝宝洗衣机哪个牌子质量好?四大高热度婴儿洗衣机不容错过

相信大部分的用户家里都会备有一台传统的大型洗衣机&#xff0c;不过&#xff0c;如果家里有了初生的婴儿的话&#xff0c;细心的宝爸宝妈还是会为了宝宝的衣物的卫生&#xff0c;而选择分开单独清洗宝宝的衣物&#xff0c;并且很多宝爸宝妈都会自己手工洗。由于刚出生的宝宝的…

Java 基础学习(二十)Maven、XML与WebServer

1 Maven 1.1 什么是Maven 1.1.1 Maven概述 Maven是一种流行的构建工具&#xff0c;用于管理Java项目的构建&#xff0c;依赖管理和项目信息管理。它使用XML文件来定义项目结构和构建步骤&#xff0c;并使用插件来执行各种构建任务。Maven可以自动下载项目依赖项并管理它们的…

I/O(输入/输出流的概述)

文章目录 前言一、流的概述二、输入/输出流 1.字节/字符输入流2.字节/字符输出流总结 前言 在变量、数组和对象中储存的数据是暂时的&#xff0c;程序结束后它们就会丢失。如果想要永久地储存程序创建的数据&#xff0c;需要将其保存在磁盘文件中&#xff0c;这样就可以在程序中…

Pillow教程07:调整图片的亮度+对比度+色彩+锐度

---------------Pillow教程集合--------------- Python项目18&#xff1a;使用Pillow模块&#xff0c;随机生成4位数的图片验证码 Python教程93&#xff1a;初识Pillow模块&#xff08;创建Image对象查看属性图片的保存与缩放&#xff09; Pillow教程02&#xff1a;图片的裁…

《探索移动开发的未来之路》

移动开发作为当今科技领域中最为炙手可热的领域之一&#xff0c;正以惊人的速度不断迭代和发展。从技术进展到应用案例&#xff0c;再到面临的挑战与机遇以及未来的趋势&#xff0c;移动开发都呈现出了令人瞩目的发展前景。本文将围绕移动开发的技术进展、行业应用案例、面临的…

定义类强化——移动的圆

1.构造一个Location类&#xff1a; 1)该类有两个double型私有成员变量x和y&#xff0c;分别表示横坐标和纵坐标&#xff1b; 2)该类有一个有参构造方法&#xff0c;能初始化成员变量x和y&#xff1b; 3)该类具有成员变量的x和y的访问方法和赋值方法。 2.构造一个Circle类&a…

30---SDRAM电路设计

视频链接 SDRAM电路设计01_哔哩哔哩_bilibili SDRAM电路设计 1、SDRAM简介 SDRAM&#xff1a;Synchronous Dynamic Random Access Memory&#xff0c;同步动态随机存储器。 同步是指其时钟频率和CPU前端总线的系统时钟相同&#xff0c;并且内部命令的发送与数据的传输都以…

【保姆级讲解如何Stable Diffusion本地部署】

&#x1f308;个人主页:程序员不想敲代码啊&#x1f308; &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家&#x1f3c6; &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提…

基于Java的校园疫情防控管理系统(Vue.js+SpringBoot)

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 学生2.2 老师2.3 学校管理部门 三、系统展示四、核心代码4.1 新增健康情况上报4.2 查询健康咨询4.3 新增离返校申请4.4 查询防疫物资4.5 查询防控宣传数据 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBoot…

基于nodejs+vue饮食分享平台python-flask-django-php

本系统采用了nodejs语言的express框架&#xff0c;数据采用MySQL数据库进行存储。进行开发设计&#xff0c;功能强大&#xff0c;界面化操作便于上手。本系统具有良好的易用性和安全性&#xff0c;系统功能齐全&#xff0c;可以满足饮食分享管理的相关工作。 前端技术&#xff…

设计模式学习笔记 - 设计模式与范式 -结构型:3.装饰器模式

概述 上篇文章《设计模式与范式 -结构型&#xff1a;2.桥接模式》&#xff0c;我们介绍了桥接模式&#xff0c;桥接模式的理解方式有两种。第一种理解方式是 “将抽象与实现解耦&#xff0c;让它们能独立开发”。这种理解方式比较特别&#xff0c;应用场景也不多。另一种理解方…