【SGDR】《SGDR:Stochastic Gradient Descent with Warm Restarts》

news2024/12/29 9:54:15

在这里插入图片描述

arXiv-2016

code: https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py


文章目录

  • 1 Background and Motivation
  • 2 Related Work
  • 3 Advantages / Contributions
  • 4 Method
  • 5 Experiments
    • 5.1 Datasets and Metric
    • 5.2 Single-Model Results
    • 5.3 Ensemble Results
    • 5.4 Experiments on a Dataset of EEG Recordings
    • 5.5 Preliminary Experiments on a downsampled ImageNet Dataset
  • 6 Conclusion(own) / Future work


1 Background and Motivation

训练深度神经网络的过程可以视为找下面这个方程 min 解的过程

在这里插入图片描述
或者用二阶导的形式

在这里插入图片描述

然而 inverse Hessian 不好求(【矩阵学习】Jacobian矩阵和Hessian矩阵)

虽然有许多改进的优化方法来尽可能的逼近 inverse Hessian,但是,目前在诸多计算机视觉相关任务数据集上表现最好的方法还是 SGD + momentum

有了实践中比较猛的 optimization techniques 后,The main difficulty in training a DNN is then associated with the scheduling of the learning rate and the amount of L2 weight decay regularization employed.

本文,作者从 learning rate schedule 角度出发,提出了 SGDR 学习率策略,periodically simulate warm restarts of SGD

在这里插入图片描述
使得深度学习任务收敛的更快更好

2 Related Work

In applied mathematics, multimodal optimization deals with optimization tasks that involve finding all or most of the multiple (at least locally optimal) solutions of a problem, as opposed to a single best solution.

  • restarts in gradient-free optimization

    based on niching methods(见文末总结部分)

  • restarts in gradient-based optimization
    《Cyclical Learning Rates for Training Neural Networks》(WACV-2017)
    closely-related to our approach in its spirit and formulation but does not focus on restarts
    一个 soft restart 一个 hard restart(SGDR)

3 Advantages / Contributions

SGD + warm restart 技术的结合,或者说 warm restart 在 SGD 上的应用

两者均非原创,在一些小数据集(输入分辨率有限)上有提升,泛化性能还可以

速度上比SGD收敛的要快一些,x2 ~x4

4 Method

periodically simulate warm restarts of SGD

SGD with momentum
在这里插入图片描述
再加 warm start

在这里插入图片描述
蓝色和红色是之前的 step 式 learning rate schedule( A common learning rate schedule is to use a constant learning rate and divide it by a fixed constant in (approximately) regular interval)

其余颜色是作者的 SGDR 伴随不同的参数配置

核心公式

在这里插入图片描述

  • η \eta η 是学习率

  • i − t h i-th ith run

  • T c u r T_{cur} Tcur accounts for how many epochs have been performed since the last restar, T c u r = 0 T_{cur} = 0 Tcur=0 学习率最大,为 η m a x \eta_{max} ηmax T c u r = T i T_{cur}=T_i Tcur=Ti 时学习率最小为 η m i n \eta_{min} ηmin

  • T i T_i Ti cosine 的一个下降周期对应的 epoch 数或 iteration 数

每次 new start 的时候, η m i n \eta_{min} ηmin 或者 η m a x \eta_{max} ηmax 可调整

让每个周期变得越来越长的话,可以设置 T m u l t T_{mult} Tmult > 1(eg =2, it doubles the maximum number of epochs for every new restart. The main purpose of this doubling is to reach good test error as soon as possible)

code,来自 Cosine Annealing Warm Restart

# 导包
from torch import optim
from torch.optim import lr_scheduler

# 定义模型
model, parameters = generate_model(opt)

# 定义优化器
if opt.nesterov:
    dampening = 0
else:
    dampening = 0.9
optimizer = opt.SGD(parameters, lr=0.1, momentum=0.9, dampening=dampending, weight_decay=1e-3, nesterov=opt.nesterov)

# 定义热重启学习率策略
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0, last_epoch=-1)

在这里插入图片描述

5 Experiments

5.1 Datasets and Metric

CIFAR-10:Top1 error
CIFAR-100:Top1 error
a dataset of electroencephalographic (EEG)
downsampled ImageNet(32x32): Top1 and Top5 error

5.2 Single-Model Results

在这里插入图片描述
WRN 网络 with depth d and width k

CIFAR10 上 T m u l t i = 1 T_{multi} = 1 Tmulti=1 比较猛

CIFAR100 上 T m u l t i = 2 T_{multi} = 2 Tmulti=2 比较猛

在这里插入图片描述
收敛速度快的优势

Since SGDR achieves good performance faster, it may allow us to train larger networks

在这里插入图片描述

CIFAR10 上 T m u l t i = 1 T_{multi} = 1 Tmulti=1 比较猛,黑白色

CIFAR100 上 T m u l t i = 2 T_{multi} = 2 Tmulti=2 比较猛

只看收敛效果的话,白色最猛, cosine learning rate

5.3 Ensemble Results

这里是复刻下《SNAPSHOT ENSEMBLES: TRAIN 1, GET M FOR FREE》(ICLR-2017)中的方法
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
来自 优化器: Snapshots Ensembles 快照集成


在这里插入图片描述
具体的实验细节不是很了解,N =16,M=3 表示总共 200 epoch,16个 restart 周期,选出来3个模型平均?还是说跑了 16*200个 epoch,每次run(200epoch)选出来 M 个模型平均?

M = 3,30,70,150
M = 2, 70,150
M = 1 ,150

5.4 Experiments on a Dataset of EEG Recordings

在这里插入图片描述

5.5 Preliminary Experiments on a downsampled ImageNet Dataset

our downsampled ImageNet contains exactly the same images from 1000 classes as the original ImageNet but resized with box downsampling to 32 × 32 pixels.

在这里插入图片描述
在这里插入图片描述

6 Conclusion(own) / Future work

  • 实战中 T i T_i Ti 怎么设计比较好,是设置成最大 epoch 吗?还是多 restart 几次, T m u l T_{mul} Tmul 是不是大于 1 比等于 1 效果好?

  • Restart techniques are common in gradient-free optimization to deal with multimodal functions

  • Stochastic subGradient Descent with restarts can achieve a linear convergence rate for a class of non-smooth and non-strongly convex optimization problems

  • Cyclic Learning rate和SGDR-学习率调整策略论文两篇

    可以将 SGDR 称为hard restart,因为每次循环开始时学习率都是断崖式增加的,相反,CLR应该称为soft restart
    在这里插入图片描述

  • 理解深度学习中的学习率及多种选择策略

  • 什么是ill-conditioning 对SGD有什么影响? - Martin Tan的回答 - 知乎
    https://www.zhihu.com/question/56977045/answer/151137770

在这里插入图片描述
在这里插入图片描述


  • niching

    What is niching scheme?

    Niching methods:
    在这里插入图片描述

    小生境(Niche):来自于生物学的一个概念,是指特定环境下的一种生存环境,生物在其进化过程中,一般总是与自己相同的物种生活在一起,共同繁衍后代。例如,热带鱼不能在较冷的地带生存,而北极熊也不能在热带生存。把这种思想提炼出来,运用到优化上来的关键操作是:当两个个体的海明距离小于预先指定的某个值(称之为小生境距离)时,惩罚其中适应值较小的个体。

    海明距离(Hamming Distance):在信息编码中,两个合法代码对应位上编码不同的位数称为码距,又称海明距离。例如,10101和00110从第一位开始依次有第一位、第四、第五位不同,则海明距离为3。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Modality-Aware Contrastive Instance Learning with Self-Distillation ... 论文阅读

Modality-Aware Contrastive Instance Learning with Self-Distillation for Weakly-Supervised Audio-Visual Violence Detection 论文阅读 ABSTRACT1 INTRODUCTION2 RELATEDWORKS2.1 Weakly-Supervised Violence Detection2.2 Contrastive Learning2.3 Cross-Modality Knowle…

基于Java SpringBoot+Vue的校园周边美食探索及分享平台的研究与实现,附源码

博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&#x1f3…

vue+node使用RSA非对称加密,实现登录接口加密密码

背景 登录接口,密码这种重要信息不可以用明文传输,必须加密处理。 这里就可以使用RSA非对称加密,后端生成公钥和私钥。 公钥:给前端,公钥可以暴露出来,没有影响,因为公钥加密的数据只有私钥才…

类和对象(中)(构造函数、析构函数和拷贝构造函数)

1.类的六个默认成员函数 任何类在什么都不写时,编译器会自动生成以下6个默认成员函数。 //空类 class Date{}; 默认成员函数:用户没有显示实现,编译器会自动生成的成员函数称为默认成员函数 2.构造函数 构造函数 是一个 特殊的成员函数&a…

网络分析工具

为了实现业务目标,每天都要在网络上执行大量操作,网络管理员很难了解网络中实际发生的情况、谁消耗的带宽最多,并分析是否正在发生任何可能导致带宽拥塞的活动。对于大型企业和分布式网络来说,这些挑战是多方面的,为了…

[Leetcode]用栈实现队列

用栈实现队列: 请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作(push、pop、peek、empty): 实现 MyQueue 类: void push(int x) 将元素 x 推到队列的末尾int pop() 从队列的开头移除并返回元…

【智能算法】鸡群优化算法(CSO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2014年,X Meng等人受到鸡群社会行为启发,提出了鸡群优化算法(Chicken Swarm Optimization, CSO)。 2.算法原理 2.1算法思想 CSO算法的思想是基于对…

(六)PostgreSQL的组织结构(3)-默认角色和schema

PostgreSQL的组织结构(3)-默认角色和schema 基础信息 OS版本:Red Hat Enterprise Linux Server release 7.9 (Maipo) DB版本:16.2 pg软件目录:/home/pg16/soft pg数据目录:/home/pg16/data 端口:57771 默认角色 Post…

软考135-上午题-【软件工程】-软件配置管理

备注: 该部分考题内容在教材中找不到。直接背题目 一、配置数据库 配置数据库可以分为以下三类: (1) 开发库 专供开发人员使用,其中的信息可能做频繁修改,对其控制相当宽松 (2) 受控库 在生存期某一阶段工作结束时发布的阶段产…

手机拍摄视频怎么做二维码?现场录制视频一键生成二维码

随着手机摄像头的像素不断提升,现在经常会通过手机的拍摄视频,然后发送给其他人查看。当我们想要将一个视频分享给多人去查看时,如果一个个去发送会比较的浪费时间,而且对方还需要下载接受视频后才可以查看,时间成本高…

简化PLC图纸绘制流程:利用SOLIDWORKS Electrical提升效率与准确性

效率一向是工程师比较注重的问题,为了提高工作效率,工程师绞尽脑汁。而在SOLIDWORKS Electrical绘制plc原理图时能有效提高PLC图纸的出图效率,并且可以减少数据误差。 在SOLIDWORKS Electrical绘制PLC图纸时,可以先创建PLC输入/输…

域名被污染了只能换域名吗?

域名污染是指域名的解析结果受到恶意干扰或篡改,使得用户在访问相关网站时出现异常。很多域名遭遇过污染的情况,但是并不知道是域名污染,具体来说,域名污染可能表现为以下情况:用户无法通过输入正确的域名访问到目标网…

24华中杯数学建模C题详解速通

本文针对光纤传感领域的曲线重构问题,提出了一套完整的数学建模与求解方法。通过对三个具体问题的分析和求解,揭示了曲率测量、曲线重构、误差分析等环节的内在联系和数学原理。本文综合运用了光纤传感、数值分析、微分几何等学科的知识,构建了波长-曲率转换模型、曲率连续化模…

Oracle——领先的企业级数据库解决方案

一、WHAT IS ORACLWE: ORACLE 数据库系统是美国 ORACLE 公司(甲骨文)提供的以分布式数据库为核心的一组软件产品,是目前最流行的客户/服务器(CLIENT/SERVER)或B/S 体系结构的数据库之一,ORACLE 通常应用于大型系统的数…

护眼台灯哪个牌子好?五大护眼效果好的护眼台灯强力推荐

护眼台灯哪个牌子好?护眼台灯比较好的牌子有书客、雷士、爱德华医生等,这些护眼台灯由于研发实力比较强,除了基本的照明功能,在护眼效果方面的表现也是比较不错的,这样的护眼台灯能够真正地起到对眼睛的保护作用&#…

读《AI营销画布》步骤五 保收获(十)

前言 正如书中所说,做到前四步就已经很了不起了,但是,现如今有很多公司的IT部门正从原来的公司分离,成立了不同的科技公司,以确保从费用成本中心变为利润中心,当然,分离出来不一定是AI促进的&am…

【问题处理】银河麒麟操作系统实例分享,adb读写缓慢问题分析

1.问题环境 处理器: HUAWEI Kunpeng 920 5251K 内存: 512 GiB 整机类型/架构: TaiShan 200K (Model 2280K) BIOS版本: Byosoft Corp. 1.81.K 内核版本 4.19.90-23.15.v2101.ky10.aarch64 第三方应用 数据库 2.问题…

OCP Java17 SE Developers 复习题14

答案 C. Since the question asks about putting data into a structured object, the best class would be one that deserializes the data. Therefore, ObjectInputStream is the best choice, which is option C. ObjectWriter, BufferedStream, and ObjectReader are no…

karpathy build make more --- 2

1 Introduction 用多层神经网络实现更复杂一点名字预测器。 2 方案 采用两层全连接层,中间采用tanh作为激活函数,最后一层用softmax,loss用cross-entropy. 2.1 实施 step1: 生成输入的字符,输入三个字符,输出一个字符. 采用了…

java的spring循环依赖、Bean作用域等深入理解

前言 通过之前的几篇文章将Spring基于XML配置的IOC原理分析完成,但其中还有一些比较重要的细节没有分析总结,比如循环依赖的解决、作用域的实现原理、BeanPostProcessor的执行时机以及SpringBoot零配置实现原理(ComponentScan、Import、Impo…