Pytorch实现warm up和consine decay

news2024/11/24 4:05:59

在深度学习领域,模型训练过程中的不稳定性是一个常见的问题。为了解决这个问题,在Resnet这篇论文也提及了Warm Up的方法,通过逐渐增加学习率,引导模型在训练初期更稳定地收敛。同时在warm up之后结合consine decay的方法让训练变得更有效。

warm up和consine decay的意义

  • warm up来自于这篇文章:https://arxiv.org/pdf/1706.02677.pdf
  • consine decay来自于这篇文章:https://arxiv.org/pdf/1812.01187.pdf

在第一轮训练的时候,每个数据点对模型来说都是新的,模型会很快地进行数据分布修正,如果这时候学习率就很大,极有可能导致开始的时候就对该数据“过拟合”,后面要通过多轮训练才能拉回来,浪费时间。
当训练了一段时间(比如两轮、三轮)后,模型已经对每个数据点看过几遍了,或者说对当前的batch而言有了一些正确的了解,较大的学习率就不那么容易会使模型学偏,所以可以适当调大学习率。这个过程就可以看做是warmup。
那么为什么之后还要decay呢?当模型训到一定阶段后(比如10个epoch),模型的分布就已经比较固定了,或者说能学到的新东西就比较少了。如果还沿用较大的学习率,就会破坏这种稳定性,用我们通常的话说,就是已经接近loss的local optimal了,为了靠近这个最低点,我们就要慢慢来。

代码实现

1. 非常简单的前期准备工作

import torch,math
import matplotlib.pyplot as plt

# 定义一个简单的网络
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(1, 1, bias=False)

    def forward(self, x):
        return self.linear(x)

# 声明一些超参数
epochs = 100000
warm_up_epochs = epochs*0.3
milestones = [epochs*0.3, epochs*0.7]

# 优化器
optimizer = torch.optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=5e-4)

2. 设置scheduler调度器


# MultiStepLR without warm up
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

# warm_up_with_multistep_lr
warm_up_with_multistep_lr = lambda epoch: epoch / warm_up_epochs if epoch <= warm_up_epochs else 0.1**len([m for m in milestones if m <= epoch])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)

# warm_up_with_cosine_lr
warm_up_with_cosine_lr = lambda epoch: epoch / warm_up_epochs if epoch <= warm_up_epochs else 0.5 * ( math.cos((epoch - warm_up_epochs) /(epochs - warm_up_epochs) * math.pi) + 1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)

上面的三段代码分别是:

  1. 不使用warm up + multistep learning rate 衰减
  2. 使用warm up + multistep learning rate 衰减
  3. 使用warm up + consine learning rate衰减

代码均使用pytorch中的torch.optim.lr_scheduler.LambdaLR自定义学习率衰减器

3. 模拟训练过程

lrs = []
for item in range(epochs):
    lr = optimizer.param_groups[0]["lr"]
    optimizer.zero_grad()
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr())
   
plt.plot(range(epochs),lrs)
plt.show()

结果

1. MultiStepLR without warm up

在这里插入图片描述

2. warm_up_with_multistep_lr

warm_up_with_multistep_lr

3. warm_up_with_cosine_lr

warm_up_with_cosine_lr

参考

  1. https://pytorch.org/docs/stable/nn.html
  2. 知乎 https://zhuanlan.zhihu.com/p/148487894
  3. 知乎 https://zhuanlan.zhihu.com/p/424373231

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/744266.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计讯物联网关型水利遥测终端机TY910确保闸站自动化监测长效运行

闸站是我国水利建设工程的重要组成部分&#xff0c;具备调度水源、防洪排涝、灌溉等能力&#xff0c;在农业、水路运输、养殖业等行业领域起着关键作用&#xff0c;进而解决区域水资源不均衡的问题&#xff0c;促进水资源多方面的利用。当前&#xff0c;我国闸站存在数量多、分…

SonarQube安装、汉化及使用

一、环境准备 SonarQube下载&#xff1a;https://www.sonarqube.org/downloads/Java 11 或以上版本安装 多环境共存&#xff0c;指定 SonarQube 的java运行版本&#xff08;注意地址改为自己的java路径&#xff0c;最后面必须带java&#xff09;&#xff1a; 解压目录下 &#…

Atlas 200I DK A2视频保存

Atlas 200I DK A2开发者套件内置案例第一个目标检测&#xff0c;视频保存下来无法打开&#xff0c;修改为cv2保存可正常在本地展示。 原代码部分 def infer_video(video_path, model, labels_dict, cfg, output_pathoutput.mp4):"""视频推理"""…

YOLOv5解析 | 第四篇:common.py文件详解

前言 文件位置:**./models/commonpy** 该文件是实现YOLO算法中各个模块的地方,如果我们需要修改某一模块(例如C3),那么就需要修改这个文件中对应模块的的定义。这里我先围绕代码,带大家过一遍各个模块的定义,详细介绍我将在后续的教案中逐步展开。由于YOLOv5版本问题,同…

SpringBoot 如何处理 CORS 跨域?

Springboot跨域问题&#xff0c;是当前主流web开发人员都绕不开的难题。但我们首先要明确以下几点 跨域只存在于浏览器端&#xff0c;不存在于安卓/ios/Node.js/python/ java等其它环境跨域请求能发出去&#xff0c;服务端能收到请求并正常返回结果&#xff0c;只是结果被浏览器…

SpringMVC 中的数据绑定如何使用 @InitBinder 注解

SpringMVC 是一款基于 Java 的 Web 开发框架&#xff0c;它提供了许多方便开发的功能&#xff0c;其中包括数据绑定。在 SpringMVC 中&#xff0c;数据绑定的工作是由 DataBinder 类完成的。DataBinder 可以将 HTTP 请求中的数据绑定到 Java 对象中&#xff0c;并且还可以将 Ja…

《安富莱嵌入式周报》第316期:垂直降落火箭模型,超低噪声测量,开源电流探头,吸尘器BLDC,绕过TrustZone,提高频率计精度,CMSIS V6.0文档

周报汇总地址&#xff1a;嵌入式周报 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬汉嵌入式论坛 - Powered by Discuz! 视频版&#xff1a; https://www.bilibili.com/video/BV1rz4y1H71w/ 《安富莱嵌入式周报》第316期&#xff1a;垂直降落…

实验室服务器 环境配置记录

前言 本篇文章为本人自己(Toniht)在实验室服务器上配置环境的一些记录&#xff0c;我也是个半吊子&#xff0c;很多步骤都不知其所以然&#xff0c;主打一个能用就行。主要目的是方便后续遇见问题及时定位&#xff0c;或者后续再次需要时不用上网到处查找。次要目的是希望能帮…

从0到1学习Yalmip工具箱(1)-入门学习

博客中所有内容均来源于自己学习过程中积累的经验以及对yalmip官方文档的翻译&#xff1a;YALMIP 1.Yalmip工具箱的下载与安装 1.1下载 Yalmip的作者是Johan Lfberg&#xff0c;是由Matlab平台编程实现的一个免费开源数学优化工具箱&#xff0c;在官网上就可以下载。官方下载…

8-1-1、kuberbetes学习-service、deployment、ReplicaSet、pod

Kubernetes资源对象Pod、ReplicaSet、Deployment、Service之间的关系_CodingSoldier的博客-CSDN博客 Pod、ReplicaSet、Deployment、Service之间的关系如下图: deployment根据pod的标签关联到pod,是为了管理pod的生命…

Unity 事件函数的执行顺序

脚本生命周期流程图 Awake&#xff1a;在所有 Start 函数之前&#xff0c;以及 prefab 实例化之后调用。&#xff08;如果一个 GameObject 在启动期间处于非活动状态&#xff0c;则在激活之前不会调用它。&#xff09;OnEnable&#xff08;仅在对象处于活动状态时调用&#xff…

为你揭开ai绘画女生软件的神秘面纱

黄琳&#xff1a;嘿&#xff0c;我最近听说了一种叫做ai绘画的东西&#xff0c;你知道它是什么吗&#xff1f; 罗娜&#xff1a;听说这是一种通过人工智能技术来生成艺术作品的过程和方法。 黄琳&#xff1a;哦&#xff0c;那它生成的效果如何呢&#xff1f;有什么软件可以实…

2023IKCEST “一带一路” 国际大数据竞赛重磅启动!

2023IKCEST第五届“一带一路”国际大数据竞赛暨第九届百度&西安交大大数据竞赛&#xff0c;由联合国教科文组织国际工程科技知识中心&#xff08;IKCEST&#xff09;、中国工程科技知识中心&#xff08;CKCEST&#xff09;、百度及西安交通大学共同主办&#xff0c;旨在放眼…

LINUX安装nginx详细步骤,部署web前端项目

1. 安装依赖包 //一键安装上面四个依赖 yum -y install gcc zlib zlib-devel pcre-devel openssl openssl-devel 2. 下载并解压安装包 可以去https://nginx.org/download里面找最新的包&#xff0c;nginx-1.25.1.tar.gz及以后的&#xff0c;里面资源比较多&#xff0c;耐心寻…

飞行动力学 - 第7节-起飞性能 之 基础点摘要

飞行动力学 - 第7节-起飞性能 之 基础点摘要 1. 气动特性2. 起飞性能3. 性能指标3.1 地面滑跑阶段3.2 起飞滑跑距离估算 4. 跑道4.1 编号4.2 等级 5. 参考资料 1. 气动特性 起飞不仅需要考虑升力&#xff0c;还需要在有限跑道长度上加速&#xff0c;因此襟翼放出的角度不能太大…

Claude 2正式上线;Prompt在手,天下我有

&#x1f989; AI新闻 &#x1f680; Claude 2正式上线&#xff0c;AI能力全面提升 摘要&#xff1a;Claude 2正式上线&#xff01;作为ChatGPT的强力挑战者&#xff0c;Claude 2的到来绝对是一个重磅事件。Claude 2性能更强&#xff0c;响应更快&#xff0c;并正式推出了网页…

MPI转以太网模块西门子200以太网通讯设置

你有没有想过&#xff0c;微生物发酵行业的生产控制可以如此先进&#xff1f;今天我们要介绍的是一项关于MPI转以太网模块在发酵集散控制系统中的应用。 这个系统由上位机和下位机组成&#xff0c;可以实现工程师站和操作员站之间的无缝连接&#xff0c;同时还可以实现远程工作…

微信小程序之网络数据请求 wx:request的简单使用

网络数据请求 1. 网络数据请求 wx:request2. 请求格式3. 关闭request的合法检验 1. 网络数据请求 wx:request 出于安全性方面的考虑&#xff0c;小程序官方对数据接口的请求做出了两个限制&#xff1a;只能请求 HTTPS 类型的接口必须将接口的域名添加到信任列表中. 在自己的微…

数智化转型下,财务共享各类RPA建设如何避坑?

企业数智化转型时代的热词——RPA是业务流程优化的利器之一。但对于部分非IT人士对RPA在企业管理领域的运用优势及实施注意点还不太了解&#xff0c;今天与大家快速科普一下。 RPA全称为Robotic Process Automation, 即机器人流程自动化&#xff0c;是一种能够在计算机/手机等…

【C++】 Qt-线程挂起、恢复和退出

文章目录 线程挂起和恢复内核对象线程退出 线程挂起和恢复 我们给设置线程的函数创建一个线程句柄用来接收返回值&#xff0c;并且将状态改为挂起状态 然后在循环中当第五秒时恢复线程&#xff0c;第八秒时连续挂起两次线程&#xff0c;并且返回输出挂起计数器的值&#xff08…