本文将对【模型剪枝】基于DepGraph(依赖图)完成复杂模型的一键剪枝 文章中剪枝的模型进行蒸馏训练
一、逻辑蒸馏步骤
- 加载教师模型
- 定义蒸馏loss
- 计算蒸馏loss
- 正常训练
二、代码
1、加载教师模型
教师模型使用未进行剪枝,并且已经训练好的原始模型。
teacher_model = torch.load('./logs/before_prune.pth', map_location=device)
2、定义蒸馏loss
分割和分类的loss,都是用的softmax。
import torch.nn.functional as F
import torch.nn as nn
# 蒸馏温度
Tempature = 2
def KD_loss(teacher_pred, student_pred):
t_p = F.softmax(teacher_pred / Tempature, dim=1)
s_p = F.log_softmax(student_pred / Tempature, dim=1)
return nn.KLDivLoss(reduction='mean')(s_p, t_p) * (Tempature ** 2)
3、 计算蒸馏loss
teacher_outputs = t_model(imgs)
# 蒸馏loss
soft_loss = KD_loss(teacher_outputs, outputs)
# 总loss = 蒸馏loss*alpha + 原学生模型loss*(1-alpha)
alpha = 0.9
all_loss = loss * (1 - alpha) + soft_loss * alpha
4、正常训练
all_loss.backward()
用剪枝前训练好的模型对剪枝后模型进行蒸馏训练,训练后测试效果如下: