paper:Distilling the Knowledge in a Neural Network
code:https://github.com/megvii-research/mdistiller/blob/master/configs/cifar100/kd.yaml
存在的问题
训练阶段,我们可以不考虑计算成本和训练时间,为了更高的精度训练一个很大的模型,或是训练多个模型,采用模型集成的方法进一步提高精度。但在部署时,往往受计算资源和推理时间的限制,需要采用剪枝、量化等方法对模型进行压缩、加速,或是直接将大模型替换成轻量化的小模型,使其满足实际应用需求。
本文的创新点
本文提出了知识蒸馏的概念,小模型的学习能力有限,将大模型学习到的知识传递给小模型可以帮助小模型的学习并且提高小模型的精度。同时提出了模型“知识”的具体表示方法,以及如何将知识从大模型传递给小模型的具体方法。
方法介绍
分类网络的最后一层通常会采用softmax将logits转化成各类别的最终预测概率,本文将带温度 \(T\) 的softmax输出作为大模型学习到的“知识”,并作为监督信号监督小模型的训练从而将大模型的知识传递给小模型。如下所示
温度 \(T\) 的引入可以使大模型输出的概率分布较为缓和,\(T\) 越大,分布越缓和。比如以MINIST分类为例,对于某张“2”的图像大模型输出3的概率为\(10^{-6}\),输出7的概率为 \(10^{-9}\)。对于另一张“2”的图像,输出的概率可能相反。这是很有用的信息,它表明了哪张2的外观更像3哪张更像7,但在知识传递过程中对交叉熵损失函数的影响很小,因为它的值太小了接近于0。引入 \(T\) 可以使softmax输出的概率分布更加缓和,概率分布曲线更加平滑,从而保留更多有用的信息。
小模型训练阶段,一方面采用不带温度的即 \(T=1\) 的softmax输出并与样本的真实标签即hard targets计算交叉熵损失,另一方面采用带温度的softmax输出并和大模型的softmax输出即soft targets计算KL散度损失,注意这里大小模型的 \(T\) 相等并且大于1。最后取两个损失的加权和作为小模型的最终损失。作者发现通常后者的权重取得比较小可以得到更好的结果,这是因为soft targets产生的梯度缩小为 \(1/T^{2}\),因此需要乘以更大的权重来平衡。
代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from ._base import Distiller
def kd_loss(logits_student, logits_teacher, temperature):
log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean() # (64,100)->(64)->()
loss_kd *= temperature**2
return loss_kd
class KD(Distiller):
"""Distilling the Knowledge in a Neural Network"""
def __init__(self, student, teacher, cfg):
super(KD, self).__init__(student, teacher)
self.temperature = cfg.KD.TEMPERATURE # 4
self.ce_loss_weight = cfg.KD.LOSS.CE_WEIGHT # 0.1
self.kd_loss_weight = cfg.KD.LOSS.KD_WEIGHT # 0.9
def forward_train(self, image, target, **kwargs): # (64,3,32,32),(64)
logits_student, _ = self.student(image) # (64,100)
with torch.no_grad():
logits_teacher, _ = self.teacher(image) # (64,100)
# losses
loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
loss_kd = self.kd_loss_weight * kd_loss(
logits_student, logits_teacher, self.temperature
)
losses_dict = {
"loss_ce": loss_ce,
"loss_kd": loss_kd,
}
return logits_student, losses_dict