KD:Distilling the Knowledge in a Neural Network 原理与代码解析

news2025/1/13 13:35:26

paper:Distilling the Knowledge in a Neural Network

code:https://github.com/megvii-research/mdistiller/blob/master/configs/cifar100/kd.yaml

存在的问题

训练阶段,我们可以不考虑计算成本和训练时间,为了更高的精度训练一个很大的模型,或是训练多个模型,采用模型集成的方法进一步提高精度。但在部署时,往往受计算资源和推理时间的限制,需要采用剪枝、量化等方法对模型进行压缩、加速,或是直接将大模型替换成轻量化的小模型,使其满足实际应用需求。

本文的创新点

本文提出了知识蒸馏的概念,小模型的学习能力有限,将大模型学习到的知识传递给小模型可以帮助小模型的学习并且提高小模型的精度。同时提出了模型“知识”的具体表示方法,以及如何将知识从大模型传递给小模型的具体方法。

方法介绍

分类网络的最后一层通常会采用softmax将logits转化成各类别的最终预测概率,本文将带温度 \(T\) 的softmax输出作为大模型学习到的“知识”,并作为监督信号监督小模型的训练从而将大模型的知识传递给小模型。如下所示

温度 \(T\) 的引入可以使大模型输出的概率分布较为缓和,\(T\) 越大,分布越缓和。比如以MINIST分类为例,对于某张“2”的图像大模型输出3的概率为\(10^{-6}\),输出7的概率为 \(10^{-9}\)。对于另一张“2”的图像,输出的概率可能相反。这是很有用的信息,它表明了哪张2的外观更像3哪张更像7,但在知识传递过程中对交叉熵损失函数的影响很小,因为它的值太小了接近于0。引入 \(T\) 可以使softmax输出的概率分布更加缓和,概率分布曲线更加平滑,从而保留更多有用的信息。

小模型训练阶段,一方面采用不带温度的即 \(T=1\) 的softmax输出并与样本的真实标签即hard targets计算交叉熵损失,另一方面采用带温度的softmax输出并和大模型的softmax输出即soft targets计算KL散度损失,注意这里大小模型的 \(T\) 相等并且大于1最后取两个损失的加权和作为小模型的最终损失。作者发现通常后者的权重取得比较小可以得到更好的结果,这是因为soft targets产生的梯度缩小为 \(1/T^{2}\),因此需要乘以更大的权重来平衡。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def kd_loss(logits_student, logits_teacher, temperature):
    log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()  # (64,100)->(64)->()
    loss_kd *= temperature**2
    return loss_kd


class KD(Distiller):
    """Distilling the Knowledge in a Neural Network"""

    def __init__(self, student, teacher, cfg):
        super(KD, self).__init__(student, teacher)
        self.temperature = cfg.KD.TEMPERATURE  # 4
        self.ce_loss_weight = cfg.KD.LOSS.CE_WEIGHT  # 0.1
        self.kd_loss_weight = cfg.KD.LOSS.KD_WEIGHT  # 0.9

    def forward_train(self, image, target, **kwargs):  # (64,3,32,32),(64)
        logits_student, _ = self.student(image)  # (64,100)
        with torch.no_grad():
            logits_teacher, _ = self.teacher(image)  # (64,100)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_kd = self.kd_loss_weight * kd_loss(
            logits_student, logits_teacher, self.temperature
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_kd,
        }
        return logits_student, losses_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/191626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小程序提升篇-npm、数据共享、分包、自定义tabBar

npm 包的使用1.1 npm限制小程序支持npm第三方包,提高开发效率,有以下三种限制:不支持依赖node.js内置库包不支持依赖浏览器内置对象的包不支持依赖C插件的包限制较多,因此小程序可以使用的包不多1.2 Vant Weapp是一套开源的小程序…

带你读懂——频率响应与采样频率之间的关系

频响范围 频率响应:不同频率下的输入信号经过系统后响应之后的输出信号增益。大白话就是,输入信号频率是xxx Hz,幅值为yyy mg,观察此时的输出信号幅值为AyAyAy mg,此时升高或降低了AAA倍。 电压增益计算公式&#xff…

浅读人月神话笔记(2)

读书笔记:今日翻书浅读,从《为什么巴比伦塔会失败》开始至《干将莫邪》结束,巴比伦塔的建造对当下项目推进有广泛借鉴意义,今天这几个章节在PMBOK中有一些可以互相对照学习的内容,《为什么巴比伦塔会失败?》…

RPA自动化办公04——软件自动化(excel,word,浏览器)

参考:软件自动化_UiBot开发者指南 虽然我们可以使用前面的鼠标点击等操作打开excel表然后写入什么的,但是直接用Uibot里面的命令会更方便。 Excel 在旁边的命令里面打开excel簿 随便选一个excel表实验一下,然后读取区域,可以选。…

使用字典快速获取唯一值与重复值【单个字典对象】

在以前的博客《使用字典快速获取唯一值与重复值(交集与并集)》使用多个字典对象获取交集与并集,最近有同学提问,是否可以只使用一个字典对象实现相同的功能,对于有“编程洁癖”的同学来说,可能不喜欢使用多…

记录:windows+opencv3.4.16+vs2013+cmake编译

环境:vs2013,x64,opencv3.4.16,cmakeopencv官网:https://opencv.org/releases/1、opencv source下载:因为想用vs2013,现在opencv官网windows版安装包只有vc14和vc15了,只能自己编译了。找一个自…

良心无广的3款软件,每一款都逆天好用,且用且珍惜

闲话少说,直上干货! 1、清浊 清浊是一款强大到离谱的国产手机清理APP,追求简约至上,界面非常清爽,无任何弹弹屏广告,值得关注的是,这款软件完全免费使用,常规清理、应用清理、空文件…

活体识别4:论文笔记之《Face Spoofing Detection Using Colour Texture Analysis》

说明 本文是我对论文《Face Spoofing Detection Using Colour Texture Analysis》做的一个简单笔记。 这个论文是芬兰奥卢大学(Oulu)课题组的一篇很有代表性的论文,写于2016年,使用的是“手工特征SVM分类器”这种比较传统的方案,方案不复杂&…

吾爱2023新年红包题第三题

吾爱论坛2023年春节红包安卓题,随便玩一玩; https://www.52pojie.cn/thread-1738015-1-1.html 第三题:https://www.52pojie.cn/home.php?modtask&doview&id22 首先我们下载后,打开apk是提示要点击 999次即可通关&…

Docker - 4. Docker 帮助启动类命令

目录 1. 启动 docker 2. 停止 docker 3. 重启 docker 4. 查看 docker 状态 5. 保持开机自动启动 6. 显示 docker 版本信息 7. 显示 docker 系统信息 8. 查看 docker 总体帮助文档 9. 查看 docker 命令帮助文档 1. 启动 docker systemctl start docker 2. 停止 dock…

macm1安装tensorflow以及pycharm配置

macm1安装tensorflow以及pycharm配置 本文目录macm1安装tensorflow以及pycharm配置使用MacOS 12安装conda创建一个conda环境安装tensorflowpycharm配置使用MacOS 12 必需条件:macOS 12 安装conda 安装Miniforge(包含conda及一个python环境)…

RabbitMQ消息队列实战(2)—— Java调用RabbitMQ的三种方式

本文主要介绍Java中调用RabbitMQ的三种方式。三种方式实际上对应了三种不同的抽象级别:首先,通过Java原生代码来访问RabbitMQ。在这种方式下,需要手动创建Connection,创建Channel,然后通过Channel对象可以显式的创建Ex…

基于springboot+vue的问卷调查系统(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容:毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 功能分析…

学校机房高效稳定,一招见效

校园安全作为公共安全领域重要的一部分,一直以来都格外受到重视。近年来,各地区陆续发布了多项加强校园安全管理的政策、法规及标准规范,旨在贯彻落实构建“平安校园”的宗旨,不断完善校园的人防、物防、技防建设。 学校机房常见四…

AutoLisp演练(二)

一、自动绘制出多个等半径圆相切 1.输入基准点baspt 2.输入小圆半径rad 3. 输入欲相切的圆的数量num 4.自动绘制出多个等半径圆相切 5. 涉及到相关变量,设定为baspt、rad、num、midpt、cenpt、kk、ang1、ang2 二、程序代码实现 三、测试及效果 测试一 四、…

盘点一些惊艳一时的 CSS 属性

✨ 个人主页:山山而川~xyj ⚶ 作者简介:前端领域新星创作者,专注于前端各领域技术,共同学习共同进步,一起加油! 🎆 系列专栏: web 大前端 🚀 学习格言:与其临…

2023爬虫学习笔记 -- 某狗网站爬取数据

一、爬取某狗网站的首页1、导入需要的库文件import requests2、指定我们要访问的网址网页"https://www.sogou.com"3、获取服务器的返回的所有信息响应requests.get(网页)4、通过text属性,从返回信息中读取字符串内容响应内容响应.text5、查看读取到的内容…

唐宇迪机器学习实战课程笔记(全)

1. 线性回归1.1线性回归理论1.2线性回归实战2.分类模型评估(Mnist实战SGD_Classifier)2.1 K折交叉验证K-fold cross validation2.2 混淆矩阵Confusion Matrix2.3 准确率accuracy、精度precision、召回率recall、F12.4 置信度confidence2.5 ROC曲线3.训练调参基本功(LinearRegre…

1612_PC汇编语言_条件以及控制结构

全部学习汇总: GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 这一次简单看看条件分支以及控制结构,感觉看完这部分之后,汇编的大部分框架已经有个差不多了。我的目的并不是成为汇编高手,因此…

数据处理——增删改

文章目录插入数据方式一:values方式2:将查询结果插入到表中更新数据删除数据MySQL8新特性:计算列综合案例插入数据 用INSERT插入数据 方式一:values 使用这种语法一次只能向表中插入一条数据。 情况1:为表的所有字段…