Paddle:SSLD 知识蒸馏实战

news2024/11/25 10:45:33

SSLD 知识蒸馏实战

论文:https://arxiv.org/abs/2103.05959

1. 算法介绍

1.1 简介

PaddleClas 融合已有的知识蒸馏方法 [2,3],提供了一种简单的半监督标签知识蒸馏方案(SSLD,Simple Semi-supervised Label Distillation),基于 ImageNet1k 分类数据集,在 ResNet_vd 以及 MobileNet 系列上的精度均有超过 3% 的绝对精度提升。

1.2 SSLD蒸馏策略

SSLD蒸馏原理实现如下图所示,其核心策略为:大数据集训练+ImageNet1k蒸馏Finetune。选择合适的教师模型,首先在挑选得到的500万数据集上进行训练,然后在ImageNet1k训练集上进行finetune,最终得到蒸馏后的学生模型。

 

首先,我们从 ImageNet22k 中挖掘出了近 400 万张图片,同时与 ImageNet-1k 训练集整合在一起,得到了一个新的包含 500 万张图片的数据集。然后,我们将学生模型与教师模型组合成一个新的网络,该网络分别输出学生模型和教师模型的预测分布,与此同时,固定教师模型整个网络的梯度,而学生模型可以做正常的反向传播。最后,我们将两个模型的 logits 经过 softmax 激活函数转换为 soft label,并将二者的 soft label 做 JS 散度作为损失函数,用于蒸馏模型训练。

以 MobileNetV3(该模型直接训练,精度为 75.3%)的知识蒸馏为例,该方案的核心策略优化点如下所示。

实验ID策略Top-1 acc
1baseline75.60%
2更换教师模型精度为82.4%的权重76.00%
3使用改进的JS散度损失函数76.20%
4迭代轮数增加至360epoch77.10%
5添加400W挖掘得到的无标注数据78.50%
6基于ImageNet1k数据微调78.90%
  • 注:其中baseline的训练条件为
    • 训练数据:ImageNet1k数据集
    • 损失函数:Cross Entropy Loss
    • 迭代轮数:120epoch

SSLD 蒸馏方案的一大特色就是无需使用图像的真值标签,因此可以任意扩展数据集的大小,考虑到计算资源的限制,我们在这里仅基于 ImageNet22k 数据集对蒸馏任务的训练集进行扩充。在 SSLD 蒸馏任务中,我们使用了 Top-k per class 的数据采样方案 [3] 。

具体步骤如下。

  • (1)训练集去重。我们首先基于 SIFT 特征相似度匹配的方式对 ImageNet22k 数据集与 ImageNet1k 验证集进行去重,防止添加的 ImageNet22k 训练集中包含 ImageNet1k 验证集图像,最终去除了 4511 张相似图片。

  • (2)大数据集 soft label 获取,对于去重后的 ImageNet22k 数据集,我们使用 ResNeXt101_32x16d_wsl 模型进行预测,得到每张图片的 soft label 。
  • (3)Top-k 数据选择,ImageNet1k 数据共有 1000 类,对于每一类,找出属于该类并且得分最高的 k 张图片,最终得到一个数据量不超过 1000*k 的数据集(某些类上得到的图片数量可能少于 k 张)。
  • (4)将该数据集与 ImageNet1k 的训练集融合组成最终蒸馏模型所使用的数据集,数据量为 500 万。

1.3 SKL-UGI蒸馏策略

此外,在无标注数据选择的过程中,我们发现使用更加通用的数据,即使不需要严格的数据筛选过程,也可以帮助知识蒸馏任务获得稳定的精度提升,因而提出了SKL-UGI (Symmetrical-KL Unlabeled General Images distillation)知识蒸馏方案。

通用数据可以使用ImageNet数据或者与场景相似的数据集。更多关于SKL-UGI的应用,请参考:超轻量图像分类方案PULC使用教程。

2. 实验

  • PaddleClas的蒸馏策略为大数据集训练+ImageNet1k蒸馏finetune的策略。选择合适的教师模型,首先在挑选得到的500万数据集上进行训练,然后在ImageNet1k训练集上进行finetune,最终得到蒸馏后的学生模型。

2.1 教师模型的选择

为了验证教师模型和学生模型的模型大小差异和教师模型的模型精度对蒸馏结果的影响,我们做了几组实验验证。训练策略统一为:cosine_decay_warmup,lr=1.3, epoch=120, bs=2048,学生模型均为从头训练。

Teacher ModelTeacher Top1Student ModelStudent Top1
ResNeXt101_32x16d_wsl84.2%MobileNetV3_large_x1_075.78%
ResNet50_vd79.12%MobileNetV3_large_x1_075.60%
ResNet50_vd82.35%MobileNetV3_large_x1_076.00%

从表中可以看出

教师模型结构相同时,其精度越高,最终的蒸馏效果也会更好一些。

教师模型与学生模型的模型大小差异不宜过大,否则反而会影响蒸馏结果的精度。

因此最终在蒸馏实验中,对于ResNet系列学生模型,我们使用ResNeXt101_32x16d_wsl作为教师模型;对于MobileNet系列学生模型,我们使用蒸馏得到的ResNet50_vd作为教师模型

2.2 大数据蒸馏

基于PaddleClas的蒸馏策略为大数据集训练+imagenet1k finetune的策略。

针对从ImageNet22k挑选出的400万数据,融合imagenet1k训练集,组成共500万的训练集进行训练,具体地,在不同模型上的训练超参及效果如下。

Student Modelnum_epochl2_ecaybatch size/gpu cardsbase lrlearning rate decaytop1 acc
MobileNetV13603e-54096/81.6cosine_decay_warmup77.65%
MobileNetV23601e-53072/80.54cosine_decay_warmup76.34%
MobileNetV3_large_x1_03601e-55760/243.65625cosine_decay_warmup78.54%
MobileNetV3_small_x1_03601e-55760/243.65625cosine_decay_warmup70.11%
ResNet50_vd3607e-51024/320.4cosine_decay_warmup82.07%
ResNet101_vd3607e-51024/320.4cosine_decay_warmup83.41%
Res2Net200_vd_26w_4s3604e-51024/320.4cosine_decay_warmup84.82%

2.3 ImageNet1k训练集finetune

对于在大数据集上训练的模型,其学习到的特征可能与ImageNet1k数据特征有偏,因此在这里使用ImageNet1k数据集对模型进行finetune。finetune的超参和finetune的精度收益如下。

Student Modelnum_epochl2_ecaybatch size/gpu cardsbase lrlearning rate decaytop1 acc
MobileNetV1303e-54096/80.016cosine_decay_warmup77.89%
MobileNetV2301e-53072/80.0054cosine_decay_warmup76.73%
MobileNetV3_large_x1_0301e-52048/80.008cosine_decay_warmup78.96%
MobileNetV3_small_x1_0301e-56400/320.025cosine_decay_warmup71.28%
ResNet50_vd607e-51024/320.004cosine_decay_warmup82.39%
ResNet101_vd307e-51024/320.004cosine_decay_warmup83.73%
Res2Net200_vd_26w_4s3604e-51024/320.004cosine_decay_warmup85.13%

2.4 数据增广以及基于Fix策略的微调

  • 基于前文所述的实验结论,我们在训练的过程中加入自动增广(AutoAugment)[4],同时进一步减小了l2_decay(4e-5->2e-5),最终ResNet50_vd经过SSLD蒸馏策略,在ImageNet1k上的精度可以达到82.99%,相比之前不加数据增广的蒸馏策略再次增加了0.6%。
  • 对于图像分类任务,在测试的时候,测试尺度为训练尺度的1.15倍左右时,往往在不需要重新训练模型的情况下,模型的精度指标就可以进一步提升[5],对于82.99%的ResNet50_vd在320x320的尺度下测试,精度可达83.7%,我们进一步使用Fix策略,即在320x320的尺度下进行训练,使用与预测时相同的数据预处理方法,同时固定除FC层以外的所有参数,最终在320x320的预测尺度下,精度可以达到84.0%

2.5 实验过程中的一些问题

 bn的计算方法

  • 在预测过程中,batch norm的平均值与方差是通过加载预训练模型得到(设其模式为test mode)。在训练过程中,batch norm是通过统计当前batch的信息(设其模式为train mode),与历史保存信息进行滑动平均计算得到,在蒸馏任务中,我们发现通过train mode,即教师模型的bn实时变化的模式,去指导学生模型,比通过test mode蒸馏,得到的学生模型性能更好一些,下面是一组实验结果。因此我们在该蒸馏方案中,均使用train mode去得到教师模型的soft label。
Teacher ModelTeacher Top1Student ModelStudent Top1
ResNet50_vd82.35%MobileNetV3_large_x1_076.00%
ResNet50_vd82.35%MobileNetV3_large_x1_075.84%

三. 蒸馏模型的应用

3.1 使用方法

  • 中间层学习率调整。蒸馏得到的模型的中间层特征图更加精细化,因此将蒸馏模型预训练应用到其他任务中时,如果采取和之前相同的学习率,容易破坏中间层特征。而如果降低整体模型训练的学习率,则会带来训练收敛速度慢的问题。因此我们使用了中间层学习率调整的策略。具体地:
    • 针对ResNet50_vd,我们设置一个学习率倍数列表,res block之前的3个conv2d卷积参数具有统一的学习率倍数,4个res block的conv2d分别有一个学习率参数,共需设置5个学习率倍数的超参。在实验中发现。用于迁移学习finetune分类模型时,[0.1,0.1,0.2,0.2,0.3]的中间层学习率倍数设置在绝大多数的任务中都性能更好;而在目标检测任务中,[0.05,0.05,0.05,0.1,0.15]的中间层学习率倍数设置能够带来更大的精度收益。
    • 对于MoblileNetV3_large_1x0,由于其包含15个block,我们设置每3个block共享一个学习率倍数参数,因此需要共5个学习率倍数的参数,最终发现在分类和检测任务中,[0.25,0.25,0.5,0.5,0.75]的中间层学习率倍数能够带来更大的精度收益。
  • 适当的l2 decay。不同分类模型在训练的时候一般都会根据模型设置不同的l2 decay,大模型为了防止过拟合,往往会设置更大的l2 decay,如ResNet50等模型,一般设置为1e-4;而如MobileNet系列模型,在训练时往往都会设置为1e-5~4e-5,防止模型过度欠拟合,在蒸馏时亦是如此。在将蒸馏模型应用到目标检测任务中时,我们发现也需要调节backbone甚至特定任务模型模型的l2 decay,和预训练蒸馏时的l2 decay尽可能保持一致。以Faster RCNN MobiletNetV3 FPN为例,我们发现仅修改该参数,在COCO2017数据集上就可以带来最多0.5%左右的精度(mAP)提升(默认Faster RCNN l2 decay为1e-4,我们修改为1e-5~4e-5均有0.3%~0.5%的提升)。

3.2 迁移学习finetune

  • 为验证迁移学习的效果,我们在10个小的数据集上验证其效果。在这里为了保证实验的可对比性,我们均使用ImageNet1k数据集训练的标准预处理过程,对于蒸馏模型我们也添加了蒸馏模型中间层学习率的搜索。
  • 对于ResNet50_vd,baseline为Top1 Acc 79.12%的预训练模型基于grid search搜索得到的最佳精度,对比实验则为基于该精度对预训练和中间层学习率进一步搜索得到的最佳精度。下面给出10个数据集上所有baseline和蒸馏模型的精度对比。
DatasetModelBaseline Top1 AccDistillation Model Finetune
Oxford102 flowersResNete50_vd97.18%97.41%
caltech-101ResNete50_vd92.57%93.21%
Oxford-IIIT-PetsResNete50_vd94.30%94.76%
DTDResNete50_vd76.48%77.71%
fgvc-aircraft-2013bResNete50_vd88.98%90.00%
Stanford-CarsResNete50_vd92.65%92.76%
SUN397ResNete50_vd64.02%68.36%
cifar100ResNete50_vd86.50%87.58%
cifar10ResNete50_vd97.72%97.94%
Food-101ResNete50_vd89.58%89.99%
  • 可以看出在上面10个数据集上,结合适当的中间层学习率倍数设置,蒸馏模型平均能够带来1%以上的精度提升。

3.3 目标检测

我们基于两阶段目标检测Faster/Cascade RCNN模型验证蒸馏得到的预训练模型的效果。

  • ResNet50_vd

设置训练与评测的尺度均为640x640,最终COCO上检测指标如下。

Modeltrain/test scalepretrain top1 accfeature map lrcoco mAP
Faster RCNN R50_vd FPN640/64079.12%[1.0,1.0,1.0,1.0,1.0]34.8%
Faster RCNN R50_vd FPN640/64079.12%[0.05,0.05,0.1,0.1,0.15]34.3%
Faster RCNN R50_vd FPN640/64082.18%[0.05,0.05,0.1,0.1,0.15]36.3%

在这里可以看出,对于未蒸馏模型,过度调整中间层学习率反而降低最终检测模型的性能指标。基于该蒸馏模型,我们也提供了领先的服务端实用目标检测方案,详细的配置与训练代码均已开源,可以参考PaddleDetection。

3. 注意事项

  • 用户在使用SSLD蒸馏之前,首先需要在目标数据集上训练一个教师模型,该教师模型用于指导学生模型在该数据集上的训练。
  • 如果学生模型没有加载预训练模型,训练的其他超参数可以参考该学生模型在ImageNet-1k上训练的超参数,如果学生模型加载了预训练模型,学习率可以调整到原来的1/10或者1/100。
  • 在SSLD蒸馏的过程中,学生模型只学习soft-label导致训练目标变的更加复杂,建议可以适当的调小l2_decay的值来获得更高的验证集准确率。
  • 若用户准备添加无标签的训练数据,只需要将新的训练数据放置在原本训练数据的路径下,生成新的数据list即可,另外,新生成的数据list需要将无标签的数据添加伪标签(只是为了统一读数据)。
  • 教师模型的选择。在进行知识蒸馏时,如果教师模型与学生模型的结构差异太大,蒸馏得到的结果反而不会有太大收益。相同结构下,精度更高的教师模型对结果也有很大影响。相比于79.12%的ResNet50_vd教师模型,使用82.4%的ResNet50_vd教师模型可以带来0.4%的绝对精度收益(75.6%->76.0%)。
  • 更多的迭代轮数。蒸馏的baseline实验只迭代了120个epoch。实验发现,迭代轮数越多,蒸馏效果越好,最终迭代了360epoch,精度指标可以达到77.1%(76.2%->77.1%)

四. 实操

训练配置

# model architecture
Arch:
  name: "DistillationModel"
  class_num: &class_num 1000
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - True
  - False
  infer_model_name: "Student"
  models:
    - Teacher:
        name: ResNet50_vd
        class_num: *class_num
        pretrained: True
        use_ssld: True
    - Student:
        name: PPLCNet_x2_5
        class_num: *class_num
        pretrained: False
 
# loss function config for traing/eval process
Loss:
  Train:
    - DistillationDMLLoss:
        weight: 1.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - CELoss:
        weight: 1.0

 训练baseline

  • teacher模型:res50_vd,精度0.71
  • student模型:mbv3,精度 0.52

训练蒸馏:

  • SSLD蒸馏:精度0.53

整体看:

  • 有提升,但不是特别多。因为extra数据也是起到很大作用的。在实际的业务数据中,应该会有不错的表现。

五. 参考文献

[1] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.

[2] Bagherinezhad H, Horton M, Rastegari M, et al. Label refinery: Improving imagenet classification through label progression[J]. arXiv preprint arXiv:1805.02641, 2018.

[3] Yalniz I Z, Jégou H, Chen K, et al. Billion-scale semi-supervised learning for image classification[J]. arXiv preprint arXiv:1905.00546, 2019.

[4] Cubuk E D, Zoph B, Mane D, et al. Autoaugment: Learning augmentation strategies from data[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2019: 113-123.

[5] Touvron H, Vedaldi A, Douze M, et al. Fixing the train-test resolution discrepancy[C]//Advances in Neural Information Processing Systems. 2019: 8250-8260.

参考链接:

https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5/docs/zh_CN/training/advanced/ssld.md

一、模型压缩方法简介 — PaddleClas 文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/716289.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何解释商业智能BI?商业智能BI未来的发展趋势?

商业智能BI能够成为当前商业世界中备受企业欢迎的数据类技术解决方案其实是有原因的,早在1958年,IBM研究员就将商业智能BI的早期形态定义为:“对事物相互关系的一种理解能力,并依靠这种能力去指导决策,以达到预期的目标…

【Python】正则表达式匹配大部分Url

正则表达式: r’(\w://)?(\w)(.\w).(\w)(/\w)(.\w)?(?(\w\w)(&\w\w))?’ 解释:

World macheine和Houdini这两个软件在游戏地形制作中如何选择?

本文仅针对“World macheine和Houdini这两个软件在游戏地形制作中如何选择?”做出回答。 简单介绍 World Machine: World Machine是一款专业的地形生成软件。它提供了一套强大的工具和节点系统,用于创建高度图和地形。World Machine可以帮助…

opencv使用applyColorMap()函数,可以将灰度图或彩色图转换成自定义的彩色图,或opencv提供的20多种色彩值

文章目录 1、applyColorMap()函数的使用:(1)函数原型:void applyColorMap(InputArray src, OutputArray dst, int colormap)void applyColorMap(InputArray src, OutputArray dst, InputArray userColor) (2&#xff0…

PMP证书为什么可以不用考试就可以获得CSPM二级证书?

一、PMP证书介绍 PMP是指项目管理专业人士(Project Management Professional),是国际上公认的项目管理领域的权威认证。PMP认证由美国项目管理协会(PMI)颁发,是全球范围内最具权威性、最受认可的项目管理专…

Jmeter使用之:怎么编写扩展函数(一)

目录 前言: 1、首先编写我们的java类,如具体代码如下: 2、使用eclipse把java导出成jar包,如timetool.jar。 3、Jmeter 测试计划底部的library找到timetool.jar,导入进去。 4、在Jmeter测试计划下新建线程组&#…

咨询第三方软件测试机构报价时,软件企业应该准备什么?

随着软件行业的快速发展,软件企业也面临着越来越大的市场竞争压力。为了确保软件产品的质量和稳定性,许多企业开始选择外包软件测试服务。然而,在咨询第三方软件测试机构报价之前,软件企业需要做好一些准备工作,以获得…

AI很渴:chatGPT交流一次=喝掉一瓶水,GPT3训练=填满核反应堆

流行的大型语言模型(LLM),如OpenAI的ChatGPT和Google的Bard,耗能巨大,需要庞大的服务器农场提供足够的数据来训练这些强大的程序。对这些数据中心进行冷却也使得AI聊天机器人对水的需求量极大。新的研究表明&#xff0…

小白到运维工程师自学之路 第四十五集 (生产级Redis Cluster部署)

一、概述 Redis Cluster是Redis数据库的一种分布式解决方案,用于在多个节点上分布和 管理数据。它通过将数据分片存储在不同的节点上,实现数据的分布式存储和处理。 Redis Cluster采用主从复制的方式来保证数据的高可用性和容错性,每个主节…

ai绘画二次元软件免费的哪个好?这些二次元ai绘画软件比较好

小伙伴好呀,今天我要和你们分享一个超酷的话题——ai绘画二次元作品!是不是感觉很时髦?没错,现在我们不再局限于传统的绘画方式,而是可以通过ai技术来创造出令人赞叹的二次元世界。你不需要成为一名艺术大师&#xff0…

从文档智能开始洞察一切

文档智能 Document Intelligence 即使在当今数字至上的时代,许多交易仍依赖于发票、合同、法律文件、员工记录、财务报表等纸质文件。当企业希望对纸质记录进行数字化处理,以便搜索、保存和提取有价值的数据以用于决策和市场开拓,AI支持的文…

【教学类-36-05】动物头饰制作2.0(midjounery动物简笔画四图)一页两种动物

作品展示 背景需求: 头饰1.0的教学实践发现,完全可以利用裁剪的边缘纸条作为头饰的套环。因此重新设计word模板,合理布局图案位置,设计了一页2份的头饰。 原来样式:一页一份动物(4个) 现在样式…

16-Linux背景知识

目录 1.Linux是什么? 2.Unix & Linux 发展历程图 3.Linux 发行版 PS:CentOS 和 RedHat 的关系 4.关于 Linux 学习什么? 4.1.基础命令(重点) PS:使用命令相比于使用图形界面的主要好处 4.2.系统编程 &…

IP地址定位在电商行业中的应用

最新数据显示,随着电商行业的快速发展越来越多的企业开始将IP地址定位技术应用于其业务中。IP地址定位是一种利用互联网上的IP地址来确定用户地理位置的技术它通过识别用户的IP地址,从而可以实时追踪和定位他们的位置。 在电商行业中,IP地址定…

Docker 搭建sonarqube,并集成阿里P3C规则

简介 本文安装的sonarqube是7.6-community版本,未安装最新版是因为7.9之后不再支持mysql。如果你安装的是其他版本的sonarqube,那么不要使用插件包中的插件,会有版本兼容性问题。 插件 插件包 插件包中包含java语音插件,汉化插…

linux下postgresql的安装和部署

1.官网下载安装包 PostgreSQL: File Browser 2. 下载成功后上传到Linux服务器 3.解压文件 tar -zxvf postgresql-14.5.tar.gz 4.编译(后边的地址指定的就是安装数据库目录) ./configure --prefix/usr/local/postgresql 出现异常:configure: error: readline lib…

Static Timing Analysis for Nanometer Designs A Practical Approach

分享电子书籍:静态时序分析圣经 Static Timing Analysis for Nanometer Designs A Practical Approach 1 setup time Setup time (建立时间)是数据信号(D)在时钟事件(这里以时钟上升沿为例)发生之前保持稳定的最小时间。以便时钟可靠地对数据进行采样。适用于同步电路,如触…

如何制作3D虚拟人物?这篇文章告诉你

3D虚拟人物制作是一种利用计算机技术来创建并模拟逼真的虚拟角色的过程。随着科技的不断发展和创新,3D虚拟人物制作在影视特效、游戏开发、虚拟主播、辅助医疗等领域得到了广泛应用和重视。 3D虚拟人物制作是一项复杂而精细的工作。它需要具备扎实的绘画基础和美学…

如何组织一次有价值的业务巡检

1.背景 随着业务的快速迭代,开发自测需求与QA测试的需求比例相当,对于开发自测的需求,需求质量我们无法把控,并且随着自测需求的增多,QA对业务的熟悉程度也会出现断层; 部分业务整体已趋于稳定&#xff0c…

如何在Microsoft Excel中使用RANK函数快速计算排名

Excel 中的 RANK 函数是一个内置的统计函数,它返回给定数字数组中数值的秩。根据特定数据点相对于列表中其他值的大小,将等级分配给该数据点。 RANK 的公式是:=RANK(number,ref,[order]),该函数接受两个强制参数 number 和 ref,第三个参数 order 是可选的,其中: number…