知识蒸馏介绍

news2024/12/26 21:01:27

在这里插入图片描述

一、知识蒸馏介绍

1.1 概念介绍

知识蒸馏(knowledge distillation)是模型压缩的一种常用的方法,不同于模型压缩中的剪枝和量化,知识蒸馏是通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息,来训练这个小模型,以期达到更好的性能和精度。
最早是由Hinton在2015年首次在文章《Distilling the Knowledge in a Neural Network》中提出并应用在分类任务上面,这个复杂模型我们称之为teacher(教师模型),小模型我们称之为Student(学生模型)。来自Teacher模型输出的监督信息称之为knowledge(知识),而student学习迁移来自teacher的监督信息的过程称之为Distillation(蒸馏)。

1.2 为什么要有知识蒸馏?

深度学习在计算机视觉、语音识别、自然语言处理等内的众多领域中均取得了令人难以置信的性能。但是,大多数模型在计算上过于昂贵,无法在移动端或嵌入式设备上运行。因此需要对模型进行压缩,且知识蒸馏是模型压缩中重要的技术之一。

提升模型精度
如果对目前的网络模型A的精度不是很满意,那么可以先训练一个更高精度的teacher模型B(通常参数量更多,时延更大),然后用这个训练好的teacher模型B对student模型A进行知识蒸馏,得到一个更高精度的A模型。

降低模型时延,压缩网络参数
如果对目前的网络模型A的时延不满意,可以先找到一个时延更低,参数量更小的模型B,通常来讲,这种模型精度也会比较低,然后通过训练一个更高精度的teacher模型C来对这个参数量小的模型B进行知识蒸馏,使得该模型B的精度接近最原始的模型A,从而达到降低时延的目的。

标签之间的域迁移
假如使用狗和猫的数据集训练了一个teacher模型A,使用香蕉和苹果训练了一个teacher模型B,那么就可以用这两个模型同时蒸馏出一个可以识别狗、猫、香蕉以及苹果的模型,将两个不同域的数据集进行集成和迁移。

降低标注量
该功能可以通过半监督的蒸馏方式来实现,用户利用训练好的teacher网络模型来对未标注的数据集进行蒸馏,达到降低标注量的目的。

因此,在工业界中对知识蒸馏和迁移学习也有着非常强烈的需求。

补充模型压缩的知识::模型压缩大体上可以分为 5 种:

模型剪枝:即移除对结果作用较小的组件,如减少 head 的数量和去除作用较少的层,共享参数等,ALBERT属于这种;
量化:比如将 float32 降到 float8;
知识蒸馏:将 teacher 的能力蒸馏到 student上,一般 student 会比 teacher 小。我们可以把一个大而深的网络蒸馏到一个小的网络,也可以把集成的网络蒸馏到一个小的网络上。
参数共享:通过共享参数,达到减少网络参数的目的,如 ALBERT 共享了 Transformer 层;
参数矩阵近似:通过矩阵的低秩分解或其他方法达到降低矩阵参数的目的;

1.3 这与从头开始训练模型有何不同?

显然,对于更复杂的模型,理论搜索空间要大于较小网络的搜索空间。但是,如果我们假设使用较小的网络可以实现相同(甚至相似)的收敛,则教师网络的收敛空间应与学生网络的解空间重叠。

不幸的是,仅此一项并不能保证学生网络在同一位置收敛。学生网络的收敛可能与教师网络的收敛大不相同。但是,如果指导学生网络复制教师网络的行为(教师网络已经在更大的解空间中进行了搜索),则可以预期其收敛空间与原始教师网络收敛空间重叠。
在这里插入图片描述

二. 知识蒸馏方式

2.1 知识蒸馏基本框架

知识蒸馏采取Teacher-Student模式:将复杂且大的模型作为Teacher,Student模型结构较为简单,用Teacher来辅助Student模型的训练,Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力。复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。

训练方式
训练学生和教师模型的方法主要有三种,即离线、在线和自我蒸馏。
蒸馏训练方法的分类取决于教师模型是否与学生模型同时修改

迁移方法
知识蒸馏是对模型的能力进行迁移,迁移的方法不同可以简单分为基于目标蒸馏(也称为Soft-target蒸馏或Logits方法蒸馏)和基于特征蒸馏的算法两个大的方向

2.2 目标蒸馏-Logits方法

一个很直白且高效的迁移泛化能力的方法就是:
使用softmax层输出的类别的概率来作为“Soft-target” 。用温度系数平滑结果,在整个知识蒸馏过程中,我们先让温度 升高,然后在测试阶段恢复“低温“( ),从而将原模型中的知识提取出来,因此将其称为是蒸馏,实在是妙啊。

什么是logits?

p = softmax (loggits)
logits = [0.3,1.2, 0.7, 0.6, -1.2], i代表第i个类别,Zi代表属于第i类的可能性。
因为Logits并非概率值,所以一般在Logits数值上会用Softmax函数进行变换,得出的概率值作为最终分类结果概率。

用softmax将logits转变成Soft-target

Softmax一方面把Logits数值在各类别之间进行概率归一,使得各个类别归属数值满足概率分布;
另外一方面,它会放大Logits数值之间的差异,使得Logits得分两极分化,Logits得分高的得到的概率值更偏大一些,而较低的Logits数值,得到的概率值则更小。

  • Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。
  • Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。

softmax层的输出,除了正例之外,负标签也带有Teacher模型归纳推理的大量信息

比如某些负标签对应的概率远远大于其他负标签,则代表 Teacher模型在推理时认为该样本与该负标签有一定的相似性。由此我们可见Soft-target蕴含着比Hard-target更多的信息。
而传统的 Hard-target 的训练方式,所有的负标签都会被平等对待。因此,Soft-target 给 Student模型带来的信息量要大于 Hard-target,并且Soft-target分布的熵相对高时,其Soft-target蕴含的知识就更丰富。
同时,使用 Soft-target 训练时,梯度的方差会更小,训练时可以使用更大的学习率,所需要的样本也更少。这也解释了为什么通过蒸馏的方法训练出的Student模型相比使用完全相同的模型结构和训练数据只使用Hard-target的训练方法得到的模型,拥有更好的泛化能力。

用温度系数平滑Soft-target 结果

直接使用softmax层的输出值作为soft target,这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此"温度"这个变量就派上了用场。下面的公式是加了温度这个变量之后的softmax函数:
T越高,softmax的output probability distribution越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。

温度影响
当想从负标签中学到一些信息量的时候,温度 应调高一些;
当想减少负标签的干扰的时候,温度 应调低一些;
总的来说, 的选择和Student模型的大小有关,Student模型参数量比较小的时候,相对比较低的温度就可以了。因为参数量小的模型不能学到所有Teacher模型的知识,所以可以适当忽略掉一些负标签的信息。
最后,在整个知识蒸馏过程中,我们先让温度 升高,然后在测试阶段恢复“低温“( ),从而将原模型中的知识提取出来,因此将其称为是蒸馏,实在是妙啊

2.2.1 蒸馏的一种特殊形式:直接Matching Logits

直接Matching Logits指的是,直接使用softmax层的输入logits(而不再是输出)作为Soft- target,需要最小化的目标函数是Teacher模型和Student模型的logits之间的平方差,

2.2.2 知识蒸馏训练的具体方法

训练好Teacher模型;利用高温Thigh产生 Soft-target;
使用{soft label, Thigh}和 {hard label, T=1}同时训练 Student模型;

在这里插入图片描述

2.3 特征蒸馏

另外一种知识蒸馏思路是特征蒸馏方法,它不像Logits方法那样,Student只学习Teacher的Logits这种结果知识,而是学习Teacher网络结构中的中间层特征。

最早采用这种模式的工作来自于论文《FITNETS:Hints for Thin Deep Nets》,它强迫Student某些中间层的网络响应,要去逼近Teacher对应的中间层的网络响应。这种情况下,Teacher中间特征层的响应,就是传递给Student的知识。在此之后,出了各种新方法,但是大致思路还是这个思路,本质是Teacher将特征级知识迁移给Student。因此,接下来我们以这篇论文为主,详细介绍特征蒸馏方法的原理。

2.4 训练方式

2.4.1. 对抗蒸馏

对抗性学习最近在生成对抗网络的背景下被概念化,用于训练一个生成器模型,该模型学习生成尽可能接近真实数据分布的合成数据样本,以及一个鉴别器模型,该模型学习区分真实数据和合成数据样品。这一概念已应用于知识蒸馏,使学生和教师模型能够更好地表示真实数据分布。

为了达到学习真实数据分布的目的,可以:
● 老师作为鉴别器模型,学生通过学习老师作为鉴别器来区分生成器生成的数据
● 学生作为生成器模型,通过学习老师作为生成器来生成接近老师生成的数据
● 学生作为生成器模型,通过学习老师作为生成器来生成接近原始数据的数据,使用在线蒸馏技术,同时优化学生和老师模型
在这里插入图片描述

2.4.2. Multi-Teacher蒸馏

在多教师蒸馏中,学生模型从几个不同的教师模型中获取知识,使用教师模型的集合可以为学生模型提供不同种类的知识,这比从单个教师那里获取的知识更有益模型。

来自多位教师的知识可以合并为所有模型的平均响应。通常从教师那里传授的知识类型是基于响应和特征。多位教师可以传递不同种类的知识。
在这里插入图片描述

2.4.3. 跨模态蒸馏

图 8 显示了跨模态蒸馏训练方案。在这里,教师以一种方式接受培训,其知识被提炼到需要不同方式知识的学生中。当数据或标签在训练或测试期间不可用于特定模态时,就会出现这种情况,因此需要跨模态转移知识。

跨模态蒸馏最常用于视觉领域。例如,来自受过标记图像数据训练的教师的知识可用于对具有未标记输入域(如光流、文本或音频)的学生模型进行蒸馏。在这种情况下,从教师模型的图像中学习到的特征用于学生模型的监督训练。跨模态蒸馏对于视觉问答、图像字幕等应用非常有用。
在这里插入图片描述

2.4.4 基于图的蒸馏

使用图而不是从教师到学生的单个实例知识来捕获数据内关系。
图有两种使用方式——作为知识转移的手段,以及控制教师知识的转移。
在基于图的蒸馏中,图的每个顶点代表一个自监督的教师,它可能分别基于响应或基于特征的知识,如逻辑和特征图。
在这里插入图片描述

2.4.5 基于注意力的蒸馏

● 基于使用注意力图从特征嵌入中转移知识。
在这里插入图片描述

2.4.6 无数据蒸馏

● 由于隐私、安全或保密原因,无数据蒸馏是在没有训练数据集的情况下基于合成数据。合成数据通常是从预训练教师模型的特征表示中生成的。在其他应用中,GAN 也用于生成合成训练数据。

2.4.7 量化蒸馏

● 量化蒸馏用于将知识从高精度教师模型(例如 32 位浮点)转移到低精度学生网络(例如 8 位)。
在这里插入图片描述

2.4.8 终身蒸馏

● 终身蒸馏基于持续学习、终身学习和元学习的学习机制,其中积累以前学到的知识并将其转移到未来的学习中。

2.4.8 基于神经架构搜索的蒸馏

● 基于神经架构搜索的蒸馏用于识别合适的学生模型架构,以优化从教师模型中的学习。

三. 支持知识蒸馏的平台

TensorFlow Model Optimization Toolkit: Google发布的一款工具包,可以压缩和加速TensorFlow模型,并提供支持量化和蒸馏等压缩方法。
Hugging Face Distiller: hugging face公司的一款用于训练和蒸馏自然语言处理模型的Python库。它支持多种蒸馏策略和结构,并可以在多个平台上使用。
Nvidia TensorRT: Nvidia推出的一款用于加速深度学习推理的高性能推理引擎。TensorRT可以用于优化、蒸馏和量化深度学习模型,以提高模型的推理速度和准确性。
Keras Distiller: 一款基于Keras的模型蒸馏工具,可以对Keras模型进行蒸馏和压缩,并支持对数据集进行分布式训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2202692.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目经理是怎么慢慢废掉的?这些无意识行为可能会毁了你!

工作久了,每个人都或多或少会有一些无力感和疲惫感。如果没有调整过来,久而久之,会感觉自己好像废掉了,做什么事情都打不起精神。 如果你是项目经理,工作中有这样一些迹象,比如总是拖延时间、丧失自己的判…

【进程间通信(三)】【system V共享内存】

目录 1. 原理2. 编码通信2.1 创建共享内存2.2 shmat && shmdt && shmctl2.3 通信 3. 共享内存的特性3.1 共享内存的属性3.2 加入管道实现同步机制 前面的文章介绍了管道通信,其中包括匿名管道、命名管道。这篇文章介绍另一种进程间通信的方式 -----…

NVP的含义?如何理解其在AEM|FLUKE线缆认证测试中的意义?不同的NVP会出现怎样的结果?

在AEM|FLUKE铜缆认证测试中,有很多朋友对NVP设置有疑问,不知道应该怎么去设置它,并很好的应用它,那我们基于此,做一个简单的分析。 什么是NVP? NVP是Nominal Velocity of Propagation的缩写?简单直接译过…

Java基础-泛型机制

文章目录 为什么引入泛型泛型的基本使用泛型类泛型接口泛型方法泛型数组正确的数组声明使用场景如何理解Java中的泛型是伪泛型?泛型中类型擦除 泛型数组:如何正确的初始化泛型数组实例? 为什么引入泛型 引入泛型的意义在于: 适用…

KEYSIGHT B1500A 半导体器件参数分析仪

新利通 B1500A 半导体器件参数分析仪 ——一体化器件表征分析仪—— 简述 Keysight B1500A 半导体参数分析仪是一款一体化器件表征分析仪,能够测量 IV、CV、脉冲/动态 IV 等参数。 主机和插入式模块能够表征大多数电子器件、材料、半导体和有源/无源元器件。 B…

关于相机的一些零碎知识点

热成像,英文为Thermal Imaging,例如型号500T,其实指的就是热成像500分辨率。 相机的CMOS,英文为Complementary Metal Oxide Semiconductor,是数码相机的核心成像部件,是一种互补金属氧化物导体器件。 DPI…

PVC刻字膜高精度模切应用

PVC刻字膜是一种由聚氯乙烯(PVC)为主要成分制成的薄膜材料,具有耐磨、耐刮、耐水、耐油以及良好的化学稳定性等特点。这种薄膜在多个行业中得到广泛应用,特别是在服装、鞋业、箱包、汽车内饰等领域,用于制作各种标识、…

NDC美国药品编码目录数据库查询方法

NDC(National Drug Code)翻译为“国家药品代码”,是美国食品药品监督管理局(FDA)制定的一种药品标识系统,用于唯一标识药品。这个编码系统主要目的是为精准识别和追踪不同药品而建设,行业人员和…

2024最新【Pycharm】史上最全PyCharm安装教程,图文教程(超详细)

1. PyCharm下载安装 完整安装包下载(包含Python和Pycharm专业版注册码):点击这里 1)访问官网 https://www.jetbrains.com/pycharm/download/#sectionwindows 下载「社区版 Community」 安装包。 2)下载完成后&#…

【斯坦福CS144】Lab7

一、实验目的 在本课程中,你已经实现了互联网基础设施的重要部分。这个检查点不是关于实现,而是关于测量实际的互联网并报告特定路径的长期统计数据。 二、实验内容 1.收集数据 选择一个远程主机,其往返时间(RTT)从…

Unity3D相关知识点总结

Unity3D使用的是笛卡尔三维坐标系,并且是以左手坐标系进行展示的。 1.全局坐标系(global) 全局坐标系描述的是游戏对象在整个世界(场景)中的相对于坐标原点(0,0,0)的位置…

处理 Vue3 中隐藏元素刷新闪烁问题

一、问题说明 页面刷新,原本隐藏的元素会一闪而过。 效果展示: 页面的导航栏通过路由跳转中携带的 meta 参数控制导航栏的 显示/隐藏,但在实践过程中发现,虽然元素隐藏了,但是刷新页面会出现闪烁的问题。 项目源码&…

MLP优化KAN

一:spline概念介绍 在数学学科数值分析中,样条(spline)是一种特殊的函数,由多项式分段定义。样条的英语单词spline来源于可变形的样条工具,那是一种在造船和工程制图时用来画出光滑形状的工具 样条有两个特…

Adversarial and Adaptive Tone Mapping Operatorfor High Dynamic Range Images

Abstract 这项工作涉及色调映射,这是一种将高动态范围 (HDR) 图像转换为低动态范围 (LDR) 图像的常用方法。 我们通过使用自适应色调映射来解决这个问题。 我们建议部署条件生成对抗网络来构建对抗性和自适应色调映射算子(adTMO)&#xff0c…

游戏盾是如何解决游戏行业攻击问题

随着游戏行业的迅猛发展,其高额的利润和激烈的市场竞争吸引了众多企业和创业者的目光。然而,这一行业也面临着前所未有的业务和安全挑战,尤其是DDoS(分布式拒绝服务)攻击,已经成为游戏行业的一大威胁。今天…

Metasploit渗透测试之MSFvenom

简介 到目前为止,你应该已经对MSFvenom不陌生了,因为在之前的文章中已经介绍多次了。MSFvenom是用于生成有效攻击载荷和编码的工具。它由msfpayload和msfencode演变而来。并于2015年6月8日取代了这两者。 在本文中,我们将更深入地研究可用的…

MySQL进阶 - 索引

01 索引概述 【1】概念:索引就是一种有序的数据结构,可用于高效查询数据。在数据库表中除了要保存原始数据外,数据库还需要去维护索引这种数据结构,通过这种数据结构来指向原始数据,这样就可以根据这些数据结构实现高…

如何高效开发一套医院绩效核算系统

医院绩效核算系统是一种专为医疗机构设计的软件系统,旨在通过科学、系统的方法评估和核算医院内各科室及员工的绩效。该系统与医院的信息化系统紧密集成,特别是与医院信息系统(HIS)对接,以确保数据的准确性和实时性。 …

nginx配置多域名共用服务器80端口

nginx配置多域名共用服务器80端口 多个域名,比如两个域名,这两个域名其实共用一台服务器(意味着域名解析到同一个IP),一个域名为abc.com(可以是http://abc.com或者www.abc.com),另外一个域名为x…

腾讯地图接口报错此key每日调用量已达到上限

需要在 配额管理 的 账户额度 中进行配额的分配