PaddleClas:训练技巧

news2024/11/18 7:32:37

训练技巧

1.优化器的选择

带momentum的SGD优化器有两个劣势,其一是收敛速度慢,其二是初始学习率的设置需要依靠大量的经验,然而如果初始学习率设置得当并且迭代轮数充足,该优化器也会在众多的优化器中脱颖而出,使得其在验证集上获得更高的准确率。

一些自适应学习率的优化器如Adam、RMSProp等,收敛速度往往比较快,但是最终的收敛精度会稍差一些。

如果追求更快的收敛速度,我们推荐使用这些自适应学习率的优化器,如果追求更高的收敛精度,我们推荐使用带momentum的SGD优化器。

2.学习率以及学习率下降策略的选择

学习率的选择往往和优化器以及数据和任务有关系。这里主要介绍以momentum+SGD作为优化器训练ImageNet-1k的学习率以及学习率下降的选择。

学习率下降策略:

在训练初始阶段,由于权重处于随机初始化的状态,损失函数相对容易进行梯度下降,所以可以设置一个较大的学习率。

在训练后期,由于权重参数已经接近最优值,较大的学习率无法进一步寻找最优值,所以需要设置一个较小的学习率。

在训练整个过程中,很多研究者使用的学习率下降方式是piecewise_decay,即阶梯式下降学习率,如在ResNet50标准的训练中,我们设置的初始学习率是0.1,每30epoch学习率下降到原来的1/10,一共迭代120epoch。除了piecewise_decay,很多研究者也提出了学习率的其他下降方式,如polynomial_decay(多项式下降)、exponential_decay(指数下降),cosine_decay(余弦下降)等,其中cosine_decay无需调整超参数,鲁棒性也比较高,所以成为现在提高模型精度首选的学习率下降方式。

Cosine_decay和piecewise_decay的学习率变化曲线如下图所示,容易观察到,在整个训练过程中,cosine_decay都保持着较大的学习率,所以其收敛较为缓慢,但是最终的收敛效果较peicewise_decay更好一些。

另外,从图中我们也可以看到,cosine_decay里学习率小的轮数较少,这样会影响到最终的精度,所以为了使得cosine_decay发挥更好的效果,建议迭代更多的轮数,如200轮

warmup策略

如果使用较大的batch_size训练神经网络时,我们建议您使用warmup策略。Warmup策略顾名思义就是让学习率先预热一下,在训练初期我们不直接使用最大的学习率,而是用一个逐渐增大的学习率去训练网络,当学习率增大到最高点时,再使用学习率下降策略中提到的学习率下降方式衰减学习率的值。实验表明,在batch_size较大时,warmup可以稳定提升模型的精度。在训练MobileNetV3等batch_size较大的实验中,我们默认将warmup中的epoch设置为5,即先用5epoch将学习率从0增加到最大值,再去做相应的学习率衰减。

3.batch_size的选择

batch_size是训练神经网络中的一个重要的超参数,该值决定了一次将多少数据送入神经网络参与训练。在论文[1]中,作者通过实验发现,当batch_size的值与学习率的值呈线性关系时,收敛精度几乎不受影响。在训练ImageNet数据时,大部分的神经网络选择的初始学习率为0.1,batch_size是256,所以根据实际的模型大小和显存情况,可以将学习率设置为0.1*k,batch_size设置为256*k。

4.weight_decay的选择

过拟合是机器学习中常见的一个名词,简单理解即为模型在训练数据上表现很好,但在测试数据上表现较差,在卷积神经网络中,同样存在过拟合的问题,为了避免过拟合,很多正则方式被提出,其中,weight_decay是其中一个广泛使用的避免过拟合的方式。Weight_decay等价于在最终的损失函数后添加L2正则化,L2正则化使得网络的权重倾向于选择更小的值,最终整个网络中的参数值更趋向于0,模型的泛化性能相应提高。在各大深度学习框架的实现中,该值表达的含义是L2正则前的系数,在paddle框架中,该值的名称是l2_decay,所以以下都称其为l2_decay。该系数越大,表示加入的正则越强,模型越趋于欠拟合状态。在训练ImageNet的任务中,大多数的网络将该参数值设置为1e-4,在一些小的网络如MobileNet系列网络中,为了避免网络欠拟合,该值设置为1e-5~4e-5之间。当然,该值的设置也和具体的数据集有关系,当任务的数据集较大时,网络本身趋向于欠拟合状态,可以将该值适当减小,当任务的数据集较小时,网络本身趋向于过拟合状态,可以将该值适当增大。下表展示了MobileNetV1_x0_25在ImageNet-1k上使用不同l2_decay的精度情况。由于MobileNetV1_x0_25是一个比较小的网络,所以l2_decay过大会使网络趋向于欠拟合状态,所以在该网络中,相对1e-4,3e-5是更好的选择。

模型L2_decayTrain acc1/acc5Test acc1/acc5
MobileNetV1_x0_251e-443.79%/67.61%50.41%/74.70%
MobileNetV1_x0_253e-547.38%/70.83%51.45%/75.45%

另外,该值的设置也和训练过程中是否使用其他正则化有关系。如果训练过程中的数据预处理比较复杂,相当于训练任务变的更难,可以将该值适当减小,下表展示了在ImageNet-1k上,ResNet50在使用randaugment预处理方式后使用不同l2_decay的精度。容易观察到,在任务变难后,使用更小的l2_decay有助于模型精度的提升。

模型L2_decayTrain acc1/acc5Test acc1/acc5
ResNet501e-475.13%/90.42%77.65%/93.79%
ResNet507e-575.56%/90.55%78.04%/93.74%

综上所述,l2_decay可以根据具体的任务和模型去做相应的调整,通常简单的任务或者较大的模型,推荐使用较大的l2_decay,复杂的任务或者较小的模型,推荐使用较小的l2_decay。

5.label_smoothing的选择

Label_smoothing是深度学习中的一种正则化方法,其全称是 Label Smoothing Regularization(LSR),即标签平滑正则化。在传统的分类任务计算损失函数时,是将真实的one hot标签与神经网络的输出做相应的交叉熵计算,而label_smoothing是将真实的one hot标签做一个标签平滑的处理,使得网络学习的标签不再是一个hard label,而是一个有概率值的soft label,其中在类别对应的位置的概率最大,其他位置概率是一个非常小的数。具体的计算方式参见论文[2]。在label_smoothing里,有一个epsilon的参数值,该值描述了将标签软化的程度,该值越大,经过label smoothing后的标签向量的标签概率值越小,标签越平滑,反之,标签越趋向于hard label,在训练ImageNet-1k的实验里通常将该值设置为0.1。 在训练ImageNet-1k的实验中,我们发现,ResNet50大小级别及其以上的模型在使用label_smooting后,精度有稳定的提升。下表展示了ResNet50_vd在使用label_smoothing前后的精度指标。

模型Use_label_smoothingTest acc1
ResNet50_vd077.9%
ResNet50_vd178.4%

同时,由于label_smoohing相当于一种正则方式,在相对较小的模型上,精度提升不明显甚至会有所下降,下表展示了ResNet18在ImageNet-1k上使用label_smoothing前后的精度指标。可以明显看到,在使用label_smoothing后,精度有所下降。

模型Use_label_smoohingTrain acc1/acc5Test acc1/acc5
ResNet18069.81%/87.70%70.98%/89.92%
ResNet18168.00%/86.56%70.81%/89.89%

综上所述,较大的模型使用label_smoohing可以有效提升模型的精度,较小的模型使用label_smoohing可能会降低模型的精度,所以在决定是否使用label_smoohing前,需要评估模型的大小和任务的难易程度。

6.针对小模型更改图片的crop面积与拉伸变换程度

在ImageNet-1k数据的标准预处理中,random_crop函数中定义了scale和ratio两个值,两个值分别确定了图片crop的大小和图片的拉伸程度,其中scale的默认取值范围是0.08-1(lower_scale-upper_scale),ratio的默认取值范围是3/4-4/3(lower_ratio-upper_ratio)。在非常小的网络训练中,此类数据增强会使得网络欠拟合,导致精度有所下降。为了提升网络的精度,可以使其数据增强变的更弱,即增大图片的crop区域或者减弱图片的拉伸变换程度。我们可以分别通过增大lower_scale的值或缩小lower_ratio与upper_scale的差距来实现更弱的图片变换。下表列出了使用不同lower_scale训练MobileNetV2_x0_25的精度,可以看到,增大图片的crop区域面积后训练精度和验证精度均有提升。

模型Scale取值范围Train_acc1/acc5Test_acc1/acc5
MobileNetV2_x0_25[0.08,1]50.36%/72.98%52.35%/75.65%
MobileNetV2_x0_25[0.2,1]54.39%/77.08%53.18%/76.14%

7.使用数据增广方式提升精度

一般来说,数据集的规模对性能影响至关重要,但是图片的标注往往比较昂贵,所以有标注的图片数量往往比较稀少,在这种情况下,数据的增广尤为重要。在训练ImageNet-1k的标准数据增广中,主要使用了random_crop与random_flip两种数据增广方式,然而,近些年,越来越多的数据增广方式被提出,如cutout、mixup、cutmix、AutoAugment等。实验表明,这些数据的增广方式可以有效提升模型的精度,下表列出了ResNet50在8种不同的数据增广方式的表现,可以看出,相比baseline,所有的数据增广方式均有收益,其中cutmix是目前最有效的数据增广。更多数据增广的介绍请参考数据增广章节。

模型数据增广方式Test top-1
ResNet50标准变换77.31%
ResNet50Auto-Augment77.95%
ResNet50Mixup78.28%
ResNet50Cutmix78.39%
ResNet50Cutout78.01%
ResNet50Gridmask77.85%
ResNet50Random-Augment77.70%
ResNet50Random-Erasing77.91%
ResNet50Hide-and-Seek77.43%

8. 通过train_acc和test_acc确定调优策略

在训练网络的过程中,通常会打印每一个epoch的训练集准确率和验证集准确率,二者刻画了该模型在两个数据集上的表现。通常来说,训练集的准确率比验证集准确率微高或者二者相当是比较不错的状态。如果发现训练集的准确率比验证集高很多,说明在这个任务上已经过拟合,需要在训练过程中加入更多的正则,如增大l2_decay的值,加入更多的数据增广策略,加入label_smoothing策略等;如果发现训练集的准确率比验证集低一些,说明在这个任务上可能欠拟合,需要在训练过程中减弱正则效果,如减小l2_decay的值,减少数据增广方式,增大图片crop区域面积,减弱图片拉伸变换,去除label_smoothing等。

9.通过已有的预训练模型提升自己的数据集的精度

在现阶段计算机视觉领域中,加载预训练模型来训练自己的任务已成为普遍的做法,相比从随机初始化开始训练,加载预训练模型往往可以提升特定任务的精度。一般来说,业界广泛使用的预训练模型是通过训练128万张图片1000类的ImageNet-1k数据集得到的,该预训练模型的fc层权重是是一个k*1000的矩阵,其中k是fc层以前的神经元数,在加载预训练权重时,无需加载fc层的权重。在学习率方面,如果您的任务训练的数据集特别小(如小于1千张),我们建议你使用较小的初始学习率,如0.001(batch_size:256,下同),以免较大的学习率破坏预训练权重。如果您的训练数据集规模相对较大(大于10万),我们建议你尝试更大的初始学习率,如0.01或者更大。

如果您觉得此文档对您有帮助,欢迎star我们的项目:GitHub - PaddlePaddle/PaddleClas: A treasure chest for visual classification and recognition powered by PaddlePaddle

参考文献

[1]P. Goyal, P. Dolla ́r, R. B. Girshick, P. Noordhuis, L. Wesolowski, A. Kyrola, A. Tulloch, Y. Jia, and K. He. Accurate, large minibatch SGD: training imagenet in 1 hour. CoRR, abs/1706.02677, 2017.

[2]C.Szegedy,V.Vanhoucke,S.Ioffe,J.Shlens,andZ.Wojna. Rethinking the inception architecture for computer vision. CoRR, abs/1512.00567, 2015.

 

参考:

1:https://paddleclas.readthedocs.io/zh_CN/latest/models/Tricks.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/727019.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于springboot的教学助手后端管理系统设计与实现(源码+文档+开题+数据库+任务书)

教学助手系统是高校教学、教学管理的重要应用系统。随着教研教改的深入发展,信息技术的运用已经成为打造高效课堂必不可少的重要手段。教学助手软件集教材资源分享、课前导学、课堂互动、在线检测、课后作业等功能为一体,拓展教学时空,增强教…

python爬虫_requests获取bilibili锻刀村系列的字幕并用分词划分可视化词云图展示

文章目录 ⭐前言⭐获取字幕步骤💖 查找heartbeat接口💖 字幕api接口💖 正则提取p标签的内容 ⭐分词⭐结束 ⭐前言 大家好,我是yma16,本文分享python的requests获取哔哩哔哩锻刀村的字幕并用分词划分可视化词云图展示。…

Linux MySQL数据迁移

背景:MySQL安装时如果数据文件存在系统盘,随着业务的增长,必定会占用越来越多的系统盘空间直至爆满。为了给系统盘腾出空间来维持服务器的正常运转,需要将MySQL数据文件转移到其他磁盘。 实现步骤:1.在其他磁盘上创建…

Redis集群环境搭建[CentOS7]

下载 cd /usr/local/src/ wget https://mirrors.huaweicloud.com/redis/redis-7.0.11.tar.gz编译安装 tar -xzvf /usr/local/src/redis-7.0.11.tar.gz -C /usr/local/src/ cd /usr/local/src/redis-7.0.11 make PREFIX/usr/local/redis-7.0.11 install制作集群配置模板 cat …

删除排序链表中的重复元素(保留一个重复元素或不保留重复元素)

题目1:给定一个已排序的链表的头 head , 删除所有重复的元素,使每个元素只出现一次 。返回 已排序的链表 。(重复元素不全部都删除,需要保留一个) 解题思路: 遍历链表,找到重复元素…

Openlayers Draw的用法、属性、方法、事件介绍

Openlayers Draw绘制功能比较常用,如我们需要手动绘制一些点、线、面、多边形,圆等图形,Openlayers为我们提供了相关的API,主要API都在ol/interaction/Draw里面,绘制的API使用起来也比较简单,首先创建一个Draw对象,然后再使用Map的addInteraction方法添加该对象,就可以…

造个CPU玩玩——从硬件到软件的设计

本文使用模拟电路制造CPU——纸上谈兵。 计算机中蕴藏的哲理 最基本的思想是:通过基本电路的接线,确立输入-输出规则,类似函数的入参和返回值,便构成一个功能电路单元。单元套单元组成新单元,如此往复。“一生二&…

vue实现指定div右键显示菜单,并实现复制内容到粘贴板

效果图 实现 全有注释&#xff0c;代码如下&#xff1a; <!--指定的需要右键菜单的div--><div class"content" contextmenu.prevent"showMenu($event, item)"><span class"content_msg">{{item}}</span></div>&…

Python如何批量将图片以超链接的形式插入Excel

【研发背景】 在日常办公中&#xff0c;我们经常需要将图片插入进Excel中&#xff0c;但是如果插入的图片太多的话&#xff0c;就会导致Excel的文件内存越来越大&#xff0c;但是如果我直插入图片的路径&#xff0c;或者只是更改某一列的数据设置为超链接&#xff0c;这样的话&…

拉格朗日乘子法

首先定义一个原始最优化问题&#xff1a; 引入广义拉格朗日函数&#xff0c;将约束问题转换为无约束优化问题&#xff1a; 参数和自变量x求偏导&#xff0c;分别为零&#xff0c;就能解出一个值&#xff08;极大值或者极小值&#xff09;。 直接求解有时候非常困难&#xff0c…

企业和公司扩展WordPress网站的4种方法

Netflix 通过邮递观看 DVD。Apple 是一家计算机公司&#xff0c;而不是电话公司。WordPress 是一个博客平台。 这三个陈述有什么共同点&#xff1f;十年前都是对的&#xff0c;现在都不是了。如今&#xff0c;Netflix 以数字方式提供原创内容而闻名。Apple 正在推出其广受欢迎…

从零开始 Spring Boot 62:过滤实体和关系

从零开始 Spring Boot 62&#xff1a;过滤实体和关系 图源&#xff1a;简书 (jianshu.com) JPA&#xff08;Hibernate&#xff09;中有一些注解可以用于筛选实体和关系&#xff0c;本文将介绍这些注解。 Where 有时候&#xff0c;我们希望对表中的数据进行“软删除”&#x…

Meta为全天候AR眼镜设计了AI系统的八大指导方针

众所周知&#xff0c;Meta不仅局限在Quest这类VR头显上&#xff0c;同时还在打造更轻量化的AR眼镜&#xff0c;目标就是让产品更好的融入到人们的日常生活中去。除了硬件上轻量化以外&#xff0c;在功能和交互体验上也至关重要&#xff0c;例如自然交互方式&#xff0c;比如手势…

什么是人工智能大模型?

目录 1. 人工智能大模型的概述&#xff1a;2. 典型的人工智能大模型&#xff1a;3. 人工智能大模型的应用领域&#xff1a;4. 人工智能大模型的挑战与未来&#xff1a;5. 人工智能大模型的开发和应用&#xff1a;6. 人工智能大模型的学习资源&#xff1a; 人工智能大模型是指具…

MySQL(创建、删除、查询数据库以及依据数据类型建表)

一、 1.创建数据库&#xff0c; mysql> CREATE DATABASE IF NOT EXISTS SECOND_DB; Query OK, 1 row affected (0.01 sec)2.删除数据库&#xff0c; mysql> DROP DATABASE IF EXISTS SECOND_DB; Query OK, 0 rows affected (0.11 sec)3.查询创建数据的语句&#xff0c;…

优化模型案例

案例1 生产决策问题 &#xff08;一个简单的线性规划问题&#xff09; 某工厂在计划期内要安排I、II两种产品生产。生产单位产品所需的设备台时&#xff0c;A&#xff0c;B两种原材料的消耗&#xff0c;资源的限制以及单件产品利润如下表所示 问工厂应分别生产多少单位产品I和…

修改开发板内核启动日志输出级别

1.用超级用户权限输入命令 2.将verbosity 1改成7&#xff0c;将console(控制&#xff09; both 改成 serial&#xff08;串口控制),然后wq保存退出 3.输入命令sudo reboot 查看启动日志输出级别

华为云CodeArts IDE Online:让你随时随地畅享云端编码乐趣

软件开发是把人类智慧以代码方式表达出来的过程&#xff0c;面对不可预知且快速变化的世界&#xff0c;开发者面临着前所未有的巨大挑战。例如&#xff0c;软件交付周期和迭代速度要求更高、开发者需要快速学习各种新技术、开发时间碎片化严重、分散的交付团队协同困难、开发与…

微信小程序接入第三方后,不能及时发送客服消息

微信小程序接入第三方后&#xff0c;不能及时发送客服消息 1、要把这里关了&#xff0c;后台才能及时收到用户发来的消息

机器学习16:使用 TensorFlow 进行神经网络编程练习

在【机器学习15】中&#xff0c;笔者介绍了神经网络的基本原理。在本篇中&#xff0c;我们使用 TensorFlow 来训练、验证神经网络模型&#xff0c;并探索不同 “层数节点数” 对模型预测效果的影响&#xff0c;以便读者对神经网络模型有一个更加直观的认识。 目录 1.导入依赖…