深度学习_Learning Rate Scheduling

news2025/3/1 4:02:39

我们在训练模型时学习率的设置非常重要。

  • 学习率的大小很重要。如果它太大,优化就会发散,如果它太小,训练时间太长,否则我们最终会得到次优的结果。
  • 其次,衰变率同样重要。如果学习率仍然很大,我们可能会简单地在最小值附近反弹,从而无法达到最优

我们可以通过学习率时间表(Learning Rate Scheduling)有效地管理准确性

一、基于FashionMNIST任务的学习率时间表实践准备

构建简单网络

def net_fn():
    model = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10))
    return model

模型结构如下(左-netron
在这里插入图片描述
简单的训练框架
全部脚本可以查看笔者的github: LearningRateScheduling.ipynb

def train(model, train_iter, test_iter, config, scheduler=None):
    device = config.device
    loss = config.loss
    opt = config.opt
    num_epochs = config.num_epochs
    model.to(device)
    animator = Animator(xlabel='epoch', xlim=[0, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    
    ep_total_steps = len(train_iter)
    for ep in range(num_epochs):
        tq_bar = tqdm(enumerate(train_iter))
        tq_bar.set_description(f'[ Epoch {ep+1}/{num_epochs} ]')
        # train_loss, train_acc, num_examples
        metric = Accumulator(3)
        for idx, (X, y) in tq_bar:
            final_flag = (ep_total_steps == idx + 1) & (num_epochs == ep + 1)
            model.train()
            opt.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            l = loss(y_hat, y)
            l.backward()
            opt.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
            train_loss = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            tq_bar.set_postfix({
                "loss" : f"{train_loss:.3f}",
                "acc" : f"{train_acc:.3f}",
            })
            if (idx + 1) % 50 == 0:
                animator.add(ep + idx / len(train_iter), (train_loss, train_acc, None), clear_flag=not final_flag)

        test_acc = evaluate_accuracy_gpu(model, test_iter)
        animator.add(ep+1, (None, None, test_acc), clear_flag=not final_flag)
        if scheduler:
            if scheduler.__module__ == lr_scheduler.__name__:
                # 使用 PyTorch In-Built scheduler
                scheduler.step()
            else:
                # 使用自定义 scheduler
                for param_group in opt.param_groups:
                    param_group['lr'] = scheduler(ep) 

    print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    plt.show()

二、基于FashionMNIST任务的学习率时间表实践

2.1 无learning rate Scheduler 训练

def test(train_iter, test_iter, scheduler=None):
    net = net_fn()
    cfg = Namespace(
        device=try_gpu(),
        loss=nn.CrossEntropyLoss(),
        lr=0.3, 
        num_epochs=10,
        opt=torch.optim.SGD(net.parameters(), lr=0.3)
    )
    train(net, train_iter, test_iter, cfg, scheduler)

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
test(train_iter, test_iter)

在这里插入图片描述

2.2 Square Root Scheduler训练

更新方式为
η = η ∗ n u m _ u p d a t e + 1 \eta =\eta *\sqrt{num\_update + 1} η=ηnum_update+1
本次试验是每一个epoch更新一次

def get_lr(scheduler):
    lr = scheduler.get_last_lr()[0]
    scheduler.optimizer.step()
    scheduler.step()
    return lr

def plot_scheduler(scheduler, num_epochs=10):
    s = scheduler.__class__.__name__
    if scheduler.__module__ == lr_scheduler.__name__:
        print('pytorch build lr_scheduler')
        plot_y = [get_lr(scheduler) for _ in range(num_epochs)]
    else:
        plot_y = [scheduler(t) for t in range(num_epochs)]

    plt.title(f'train with learning rate scheduler: {s}')
    plt.plot(torch.arange(num_epochs), plot_y)
    plt.xlabel('num_epochs')
    plt.ylabel('learning_rate')
    plt.show()


class SquareRootScheduler:
    """
    使用均方根scheduler
    每一个epoch更新一次
    """
    def __init__(self, lr=0.1):
        self.lr = lr

    def __call__(self, num_update):
        return self.lr * pow(num_update + 1.0, -0.5)

scheduler = SquareRootScheduler(lr=0.1)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

从下图中可以看出:曲线比以前更平滑了。其次,过度拟合较少。
在这里插入图片描述

2.3 FactorScheduler训练

学习率更新方式: η t + 1 ← m a x ( η m i n , η t ⋅ α ) \eta_{t+1} \leftarrow \mathop{\mathrm{max}}(\eta_{\mathrm{min}}, \eta_t \cdot \alpha) ηt+1max(ηmin,ηtα)

class FactorScheduler:
    def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):
        self.factor = factor
        self.stop_factor_lr = stop_factor_lr
        self.base_lr = base_lr

    def __call__(self, num_update):
        self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)
        return self.base_lr

scheduler = FactorScheduler(factor=0.8, stop_factor_lr=1e-2, base_lr=0.6)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.4 Multi Factor Scheduler训练

保持学习率分段恒定,并每隔一段时间将其降低一个给定的量。也就是说,给定一组何时降低速率的时间比如$ (s = {3, 8} )$
d e c r e a s e ( η t + 1 ← η t ⋅ α )    t ∈ s decrease (\eta_{t+1} \leftarrow \eta_t \cdot \alpha) \ \ t \in s decrease(ηt+1ηtα)  ts

net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
scheduler = lr_scheduler.MultiStepLR(trainer, milestones=[3, 8], gamma=0.5)

plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.5 Cosine Scheduler训练

Loshchilov和Hutter提出了一个相当令人困惑的启发式方法。它依赖于这样一种观察,即我们可能不想在一开始就大幅降低学习率,此外,我们可能希望在最后使用非常小的学习率来“完善”解决方案。这导致了一个类似余弦的时间表,具有以下函数形式,用于范围内的学习率 t ∈ [ 0 , T ] t \in [0, T] t[0,T]

η t = η T + η 0 − η T 2 ( 1 + cos ⁡ ( π t T ) ) \eta_t = \eta_T + \frac{\eta_0 - \eta_T}{2} \left(1 + \cos(\frac{\pi t}{T})\right) ηt=ηT+2η0ηT(1+cos(Tπt))

注:

  • η T \eta_T ηT: 为最终的学习率
  • η 0 \eta_0 η0: 为最开始的学习率
class CosineScheduler:
    def __init__(self, max_update, base_lr=0.01, final_lr=0,
               warmup_steps=0, warmup_begin_lr=0):
        self.base_lr_orig = base_lr
        self.max_update = max_update
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.warmup_begin_lr = warmup_begin_lr
        self.max_steps = self.max_update - self.warmup_steps

    def get_warmup_lr(self, step):
        increase = (self.base_lr_orig - self.warmup_begin_lr) \
                       * float(step) / float(self.warmup_steps)
        return self.warmup_begin_lr + increase

    def __call__(self, step):
        if step < self.warmup_steps:
            return self.get_warmup_lr(step)
        if step <= self.max_update:
            self.base_lr = self.final_lr + (
                self.base_lr_orig - self.final_lr) * (1 + math.cos(
                math.pi * (step - self.warmup_steps) / self.max_steps)) / 2
        return self.base_lr

scheduler = CosineScheduler(max_update=10, base_lr=0.2, final_lr=0.02)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.6 Warmup

在某些情况下,初始化参数不足以保证良好的解决方案。对于一些先进的网络设计来说,这尤其是一个问题(Transformer的训练常用该方法),可能会导致不稳定的优化问题。
我们可以通过选择一个足够小的学习率来解决这个问题,以防止一开始就出现分歧。不幸的是,这意味着进展缓慢。相反,学习率高最初会导致差异。

对于这种困境,一个相当简单的解决方案是使用一个预热期,在此期间学习速率增加到其初始最大值,并冷却速率直到优化过程结束。为了简单起见,通常使用线性增加来实现这一目的。

scheduler = CosineScheduler(max_update=10, warmup_steps=3, base_lr=0.2, final_lr=0.02)
plot_scheduler(scheduler, 15)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

小结

从上述的5个策略上来看,一般情况我们用 Cosine Scheduler 或者线性衰减就能得到较好的结果。不过对于较大的模型,需要用warmup 并且需要特意去设计,比如NoamOpt等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/412523.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL NDB Cluster使用docker compose一键部署

本文主要用来学习MySQL NDB Cluster 解决学习过程中的痛点&#xff1a;需要开启N台VMware虚拟机&#xff0c;电脑不堪重负 使用docker部署&#xff0c;完美解决 本文使用的docker image: mysql/mysql-cluster:8.0 创建mysql_cluster目录&#xff0c;后续操作都在这个目录下 …

基于html+css的盒子展示7

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

系统集成项目管理工程师软考第三章习题(每天更新)

第一章指路&#xff1a;系统集成项目管理工程师软考第一章习题&#xff08;已完结&#xff09;_程序猿幼苗的博客-CSDN博客 第二章指路&#xff1a;系统集成项目管理工程师软考第二章习题&#xff08;已完结&#xff09;_程序猿幼苗的博客-CSDN博客 第3章信息系统集成专业技术…

基于密集学习的半监督目标检测

文章目录Dense Learning based Semi-Supervised Object Detection摘要本文方法实验结果Dense Learning based Semi-Supervised Object Detection 摘要 提出了一种基于密集学习(DSL)的无锚点的半监督目标检测算法用于分配多层级和精确的密集像素伪标签的自适应过滤器用于生成稳…

C++13:搜索二叉树

目录 搜索二叉树概念 模拟实现搜索二叉树 插入函数实现 插入函数实现&#xff08;递归&#xff09; 查找函数实现 删除函数实现 删除函数实现&#xff08;递归&#xff09; 中序遍历实现 拷贝构造函数实现 析构函数实现 赋值重载 我们在最开始学习二叉树的时候&#xff0c;…

【网络】Internet 协议版本 6 (IPv6)

文章目录IPv6 寻址文本表示形式地址类型IPv6 路由邻居发现IPv6 自动配置自动配置类型IPv6 移动性禁用或启用 IPv6配置步骤代码启用操作系统启用来源Internet 协议版本 6 (IPv6) 是 Internet 的网络层的标准协议套件。 IPv6 旨在解决当前版本的 Internet 协议套件&#xff08;称…

二叉树前中后层遍历(递归/非递归)(简单易懂(*^ー^))

文章目录二叉树的遍历1 先序遍历1.1 递归1.2 非递归2 中序遍历2.1 递归2.2 非递归3 后序遍历3.1 递归3.2 非递归4 层序遍历5 前中后层序完整可运行代码&#xff08;C&#xff09;二叉树的遍历 1 先序遍历 1.1 递归 先序遍历(Preorder Traversal)&#xff0c;即根左右的顺序遍…

Anaconda3安装配置/创建删除虚拟环境/在特定虚拟环境下安装库

1. Anaconda3彻底卸载 先说Anaconda3的卸载&#xff0c;在Anaconda3安装路径下有一个Uninstall-Anaconda3.exe&#xff0c;右键以“管理员身份运行”&#xff0c;可执行完全卸载 2. 下载与安装Anaconda3 官网地址https://repo.anaconda.com/ 点击Anaconda Distribution&…

自学大数据第14天NoSQL~MongoDB及其命令

这几天主要是看了一下mongodb的一些知识,网上也有一些教程,今天主要是复习一下mongodb 启动mongodb 在连接mongodb前首先要创建数据存放目录与日志存放目录,还得保证当前用户对这两个目录有相应的读写操作 mongod --dbpath/usr/local/mongodb/data/db/ --logpath/usr/lcoal/mon…

(四)【软件设计师】计算机系统—基础单位进制

文章目录一、计算机基础单位二、进制1.进制表示符号2.进制之间的转换&#xff1a;(1)十进制转换为二进制&#xff08;例子&#xff1a;173&#xff09;(2)十进制转换为八进制&#xff08;3&#xff09;十进制转换为十六进制&#xff08;4&#xff09;二进制转换为十进制&#x…

Linux入门 - 最常用基础指令汇总

目录 ls指令 pwd指令 cd指令 touch指令 mkdir指令 rmdir指令 && rm 指令 man指令&#xff08;重要&#xff09; cp指令&#xff08;重要&#xff09; mv指令&#xff08;重要&#xff09; cat指令 more指令 less指令&#xff08;重要&#xff09; head指令…

交换机PCB板布局布线注意事项

由于板卡在工作中会受到各种各样的干扰&#xff0c;这些干扰不仅影响系统运行的稳定性&#xff0c;同时也有可能带来误差&#xff0c;因此考虑如何抑制干扰&#xff0c;提高电磁兼容性是PCB布局布线时的一项重要任务。海翎光电的小编现将PCB布局布线中需要主要考虑的因素列在下…

银行数字化转型导师坚鹏:深度解读《中华人民共和国数据安全法》

深度解读《中华人民共和国数据安全法》 ——中国数据安全立法 助力企业稳健发展课程背景&#xff1a; 很多金融机构存在以下问题&#xff1a; 不清楚数据安全法立法背景&#xff1f; 不知道如何理解数据安全法相关政策&#xff1f; 不清楚如何数据安全进行合规建设&#xf…

【前端之旅】Vue入门笔记

一名软件工程专业学生的前端之旅,记录自己对三件套(HTML、CSS、JavaScript)、Jquery、Ajax、Axios、Bootstrap、Node.js、Vue、小程序开发(Uniapp)以及各种UI组件库、前端框架的学习。 【前端之旅】Web基础与开发工具 【前端之旅】手把手教你安装VS Code并附上超实用插件…

计算机组成原理第二章数据的表示与运算(中)

提示&#xff1a;且行且忘且随风&#xff0c;且行且看且从容 文章目录前言2.2.0 奇偶校验码(大纲已删)2.2.1 电路的基本原理 加法器设计2.2.2 并行进位加法器2.2.3 补码加减运算器2.2.4 标志位的生成2.2.5 定点数的移位运算2.2.62.2.6.1 原码的乘法运算2.2.6.2 补码的乘法运算2…

Linux下异步socket客户端

文章目录socket 客户端1. 创建socketsocket()函数返回值2. 设置socket的属性connect函数sockaddr_in结构体inet_pton函数3. fcntl设置非阻塞4. recv函数socket 客户端 1. 创建socket socket()函数 #include <sys/socket.h> int socket(int domain, int type, int proto…

四、线程安全,synchronized,volatile(JMM)【4/12 5/12 6/12】【多线程】

4. 多线程带来的的风险-线程安全 (重点) 4.1 观察线程不安全 static class Counter {public int count 0;void increase() {count;} } public static void main(String[] args) throws InterruptedException {final Counter counter new Counter();Thread t1 new Thread(()…

数据结构——排序(4)

作者&#xff1a;几冬雪来 时间&#xff1a;2023年4月12日 内容&#xff1a;数据结构排序内容讲解 目录 前言&#xff1a; 1.快速排序中的递归&#xff1a; 2.小区间优化&#xff1a; 3.递归改非递归&#xff1a; 4.归并排序&#xff1a; 5.归并排序的非递归形式&…

Revit中如何绘制倾斜的屋顶及一键成板?

Revit中如何绘制倾斜的屋顶&#xff1f;如下图所示&#xff0c;像这种坡屋顶有两种方法进行绘制。 第一种&#xff1a;定义坡度。 1、点击建筑选项卡中的屋顶按钮。选择使用矩形工具。 2、在选项栏中&#xff0c;偏离值修改为500&#xff0c;把屋顶迹线绘制出来。 3、取消这三…

软件测试今天你被内卷了吗?

认识一个人&#xff0c;大专学历非计算机专业的&#xff0c;是前几年环境好的时候入的行&#xff0c;那时候软件测试的要求真的很低&#xff0c;他那时好像是报了个班&#xff0c;然后入门的&#xff0c;但学的都是些基础&#xff0c;当时的他想的也简单&#xff0c;反正也能拿…