【深度学习】如何找到最优学习率

news2024/11/27 7:32:25

经过了大量炼丹的同学都知道,超参数是一个非常玄乎的东西,比如batch size,学习率等,这些东西的设定并没有什么规律和原因,论文中设定的超参数一般都是靠经验决定的。但是超参数往往又特别重要,比如学习率,如果设置了一个太大的学习率,那么loss就爆了,设置的学习率太小,需要等待的时间就特别长,那么我们是否有一个科学的办法来决定我们的初始学习率呢?

在这篇文章中,我会讲一种非常简单却有效的方法来确定合理的初始学习率。

学习率的重要性

目前深度学习使用的都是非常简单的一阶收敛算法,梯度下降法,不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形,所以初始学习率对深层网络的收敛起着决定性的作用,下面就是梯度下降法的公式

深度学习:如何找到最优学习率

这里 α 就是学习率,如果学习率太小,会导致网络loss下降非常慢,如果学习率太大,那么参数更新的幅度就非常大,就会导致网络收敛到局部最优点,或者loss直接开始增加,如下图所示。

深度学习:如何找到最优学习率

学习率的选择策略在网络的训练过程中是不断在变化的,在刚开始的时候,参数比较随机,所以我们应该选择相对较大的学习率,这样loss下降更快;当训练一段时间之后,参数的更新就应该有更小的幅度,所以学习率一般会做衰减,衰减的方式也非常多,比如到一定的步数将学习率乘上0.1,也有指数衰减等。

这里我们关心的一个问题是初始学习率如何确定,当然有很多办法,一个比较笨的方法就是从0.0001开始尝试,然后用0.001,每个量级的学习率都去跑一下网络,然后观察一下loss的情况,选择一个相对合理的学习率,但是这种方法太耗时间了,能不能有一个更简单有效的办法呢?

一个简单的办法

Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”中的3.3节描述了一个非常棒的方法来找初始学习率,同时推荐大家去看看这篇论文,有一些非常启发性的学习率设置想法。

这个方法在论文中是用来估计网络允许的最小学习率和最大学习率,我们也可以用来找我们的最优初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。

下面就是随着迭代次数的增加,学习率不断增加的曲线,以及不同的学习率对应的loss的曲线。

深度学习:如何找到最优学习率
深度学习:如何找到最优学习率

从上面的图片可以看到,随着学习率由小不断变大的过程,网络的loss也会从一个相对大的位置变到一个较小的位置,同时又会增大,这也就对应于我们说的学习率太小,loss下降太慢,学习率太大,loss有可能反而增大的情况。从上面的图中我们就能够找到一个相对合理的初始学习率,0.1。

之所以上面的方法可以work,因为小的学习率对参数更新的影响相对于大的学习率来讲是非常小的,比如第一次迭代的时候学习率是1e-5,参数进行了更新,然后进入第二次迭代,学习率变成了5e-5,参数又进行了更新,那么这一次参数的更新可以看作是在最原始的参数上进行的,而之后的学习率更大,参数的更新幅度相对于前面来讲会更大,所以都可以看作是在原始的参数上进行更新的。正是因为这个原因,学习率设置要从小变到大,而如果学习率设置反过来,从大变到小,那么loss曲线就完全没有意义了。

实现

上面已经说明了算法的思想,说白了其实是非常简单的,就是不断地迭代,每次迭代学习率都不同,同时记录下来所有的loss,绘制成曲线就可以了。下面就是使用PyTorch实现的代码,因为在网络的迭代过程中学习率会不断地变化,而PyTorch的optim里面并没有把learning rate的接口暴露出来,导致显示修改学习率非常麻烦,所以我重新写了一个更加高层的包mxtorch,借鉴了gluon的一些优点,在定义层的时候暴露初始化方法,支持tensorboard,同时增加了大量的model zoo,包括inceptionresnetv2,resnext等等,提供预训练权重,model zoo参考于Cadene的repo。目前这个repo刚刚开始,欢迎有兴趣的小伙伴加入我。

下面就是部分代码,近期会把找学习率的代码合并到mxtorch中。这里使用的数据集是kaggle上的dog breed,使用预训练的resnet50,ScheduledOptim的源码点这里。

   
   
  1. criterion = torch.nn.CrossEntropyLoss()
  2. net = model_zoo.resnet50(pretrained=True)
  3. net.fc = nn.Linear(2048, 120)
  4.  
  5. with torch.cuda.device(0):
  6. net = net.cuda()
  7.  
  8. basic_optim = torch.optim.SGD(net.parameters(), lr=1e-5)
  9. optimizer = ScheduledOptim(basic_optim)
  10.  
  11.  
  12. lr_mult = (1 / 1e-5) ** (1 / 100)
  13. lr = []
  14. losses = []
  15. best_loss = 1e9
  16. for data, label in train_data:
  17. with torch.cuda.device(0):
  18. data = Variable(data.cuda())
  19. label = Variable(label.cuda())
  20. # forward
  21. out = net(data)
  22. loss = criterion(out, label)
  23. # backward
  24. optimizer.zero_grad()
  25. loss.backward()
  26. optimizer.step()
  27. lr.append(optimizer.learning_rate)
  28. losses.append(loss.data[0])
  29. optimizer.set_learning_rate(optimizer.learning_rate lr_mult)
  30. if loss.data[0] < best_loss:
  31. best_loss = loss.data[0]
  32. if loss.data[0] > 4 best_loss or optimizer.learning_rate > 1.:
  33. break
  34.  
  35. plt.figure()
  36. plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]), (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1))
  37. plt.xlabel(‘learning rate’)
  38. plt.ylabel(‘loss’)
  39. plt.plot(np.log(lr), losses)
  40. plt.show()
  41. plt.figure()
  42. plt.xlabel(‘num iterations’)
  43. plt.ylabel(‘learning rate’)
  44. plt.plot(lr)

one more thing

通过上面的例子我们能够有一个非常有效的方法寻找初始学习率,同时在我们的认知中,学习率的策略都是不断地做decay,而上面的论文别出心裁,提出了一种循环变化学习率的思想,能够更快的达到最优解,非常具有启发性,推荐大家去阅读阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1254775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

扩散模型实战(十二):使用调度器DDIM反转来优化图像编辑

推荐阅读列表&#xff1a; 扩散模型实战&#xff08;一&#xff09;&#xff1a;基本原理介绍 扩散模型实战&#xff08;二&#xff09;&#xff1a;扩散模型的发展 扩散模型实战&#xff08;三&#xff09;&#xff1a;扩散模型的应用 扩散模型实战&#xff08;四&#xff…

python之pyqt专栏4-代码控制部件

通过前面的学习&#xff0c;我们已经回创建新的pyqt项目、对项目结构有了了解、也了解Qt Designer设计UI界面并 把"xx.ui"转换为“xxx.py”。 pyqt模块与类 pyqt6 由模块组成&#xff0c;而模块里面又有很多的类 在pyqt官网Modules — PyQt Documentation v6.6.0页面…

函数的防抖与节流

一、函数防抖 &#xff08;一&#xff09;防抖的理解 防抖就是将所有的触发都取消&#xff0c;在规定的时间结束过后才会执行最后一次&#xff0c;也就是说连续快速的触发只会执行最后一次结果。 也可以理解为游戏里的回城按钮&#xff0c;每点一下就会重新刷新回城进度&…

SSM 框架整合

1 整合配置 1.1 流程 1.2 Spring 整合 MyBatis 1.3 Spring 整合 SpringMVC 1.4 配置代码 JdbcConfig.java public class JdbcConfig {Value("${jdbc.driver}")private String driver;Value("${jdbc.url}")private String url;Value("${jdbc.usern…

【挑战业余一周拿证】CSDN官方课程目录

一、亚马逊云科技简介 二、在云中计算 三、全球基础设施和可靠性 四、联网 五、存储和数据库 六、安全性 七、监控和分析 八、定价和支持 九、迁移和创新 十、云之旅 关注订阅号 CSDN 官方中文视频&#xff08;免费&#xff09;&#xff1a;点击进入 一、亚马逊云科…

Apache POI(处理Miscrosoft Office各种文件格式)

文章目录 一、Apache POI介绍二、应用场景三、使用步骤1.导入maven坐标2.写入代码讲解3.读取代码讲解 总结 一、Apache POI介绍 Apache POI 是一个处理Miscrosoft Office各种文件格式的开源项目。简单来说就是&#xff0c;我们可以使用 POI 在 Java 程序中对Miscrosoft Office…

N7 LUP.2.3 DRC如何解决?

这个问题在Design Rule中的介绍如下图&#xff1a; 解决办法是od 15 um的范围要加LUP_GR* cell&#xff0c;需要提高密度(加的位置需要符合tcic)去fix。

ubuntu 安装 jetbrains-toolbox

ubuntu 安装 jetbrains-toolbox 官网下载 jetbrains-toolbox jetbrains 官网 jetbrains 官网&#xff1a;https://www.jetbrains.com/ jetbrains-toolbox 官网下载页面 在下载页面点击 Download 安装 jetbrains-toolbox 解压 jetbrains-toolbox 安装包 到指定目录 本案例将…

【Dockerfile】将自己的项目构建成镜像部署运行

目录 1.Dockerfile 2.镜像结构 3.Dockerfile语法 4.构建Java项目 5.基于Java8构建项目 1.Dockerfile 常见的镜像在DockerHub就能找到&#xff0c;但是我们自己写的项目就必须自己构建镜像了。 而要自定义镜像&#xff0c;就必须先了解镜像的结构才行。 2.镜像结构 镜…

Python——常见内置模块

Python 模块&#xff08;Modules&#xff09;1、概念模块函数类变量2、分类3、模块导入的方法&#xff1a;五种4、使用import 导入模块5、使用from……import部分导入6、使用as关键字为导入模块或功能命名别名7、模块的搜索目录8、自定义模块 常见内置模块一、math模块二、rand…

[pyqt5]PyQt5之如何设置QWidget窗口背景图片问题

目录 PyQt5设置QWidget窗口背景图片 QWidget 添加背景图片问题QSS 背景图样式区别PyQt设置窗口背景图像&#xff0c;以及图像自适应窗口大小变化 总结 PyQt5设置QWidget窗口背景图片 QWidget 添加背景图片问题 QWidget 创建的窗口有时并不能直接用 setStyleSheet 设置窗口部分…

手机技巧:安卓微信8.0.44测试版功能介绍

目录 一、更新介绍 二、功能更新介绍 拍一拍撤回功能 聊天设置界面文案优化 关怀模式新增了非常实用的安静模式 微信设置中新增翻译设置选项 近期腾讯官方终于发布了安卓微信8.0.44测试版&#xff0c;今天小编继续给大家介绍一个本次安卓微信8.0.44测试版本更新的内容&am…

网络运维与网络安全 学习笔记2023.11.26

网络运维与网络安全 学习笔记 第二十七天 今日目标 NAT场景与原理、静态NAT、动态NAT PAT原理与配置、动态PAT之EasyIP、静态PAT之NAT Server NAT场景与原理 项目背景 为节省IP地址和费用&#xff0c;企业内网使用的都是“私有IP地址” Internet网络的组成设备&#xff0c…

3 时间序列预测入门:TCN

0 引言 TCN&#xff08;全称Temporal Convolutional Network&#xff09;&#xff0c;时序卷积网络&#xff0c;是在2018年提出的一个卷积模型&#xff0c;但是可以用来处理时间序列。 论文&#xff1a;https://arxiv.org/pdf/1803.01271.pdf 一维卷积&#xff1a;在时间步长方…

轻量应用服务器推荐,入门首选

轻量应用服务器: 配置高&#xff0c;价格低&#xff0c;高性价比&#xff0c;已经成为开发者和中小企业首选服务器。 轻量-爆款 腾讯云轻量应用服务器 2核2G 3M带宽 40G SSD盘 月流量200G 一月45元。 推荐理由&#xff1a; 腾讯云这次活动预备了多款轻量应用服务器。如果做小…

【Java基础系列】文件上传功能

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

BUUCTF刷题之路-web-[GXYCTF2019]Ping Ping Ping1

启动环境后&#xff0c;是一个简简单单的页面&#xff1a; 看样子是能够触发远程执行漏洞的。尝试下ping 127.0.0.1&#xff0c;如果有回显说明我们的想法是对的。 最近才学习的nc反弹shell。想着是否能用nc反弹shell的办法。控制服务器然后输出flag呢&#xff1f;于是我测试下…

女生儿童房装修:原木上下铺搭配粉色调。福州中宅装饰,福州装修

你是否正在为女生儿童房的装修而发愁呢&#xff1f;该如何让房间既适合孩子生活&#xff0c;又能够满足日常学习的需要呢&#xff1f;这里有一个精美的装修案例&#xff0c;或许能够为你提供一些灵感。 1️⃣ 原木上下铺 房间的上下铺采用了原木色调&#xff0c;带来了自然、温…

RT-DETR 更换损失函数之 SIoU / EIoU / WIoU / Focal_xIoU

文章目录 更换方式CIoUDIoUEIoUGIoUSIoUWIoUFocal_CIoUFocal_DIoUFocal_EIoUFocal_GIoUFocal_SIoU提示更换方式 第一步:将ultralytics/ultralytics/utils/metrics.py文件中的bbox_iou替换为如下的代码:class

聊一聊索引覆盖的好处

问&#xff1a;索引覆盖啥意思&#xff1f; 答&#xff1a;若查询的字段在二级索引的叶子节点中&#xff0c;则可直接返回结果&#xff0c;无需回表。这种通过组合索引避免回表的优化技术也称为索引覆盖&#xff08;Covering Index&#xff09;。在叶子节点中的包括索引字段和主…