简要介绍 | 知识蒸馏:轻量级模型的智慧之源

news2025/1/10 11:34:38

注1:本文系“简要介绍”系列之一,仅从概念上对知识蒸馏进行非常简要的介绍,不适合用于深入和详细的了解。

知识蒸馏:轻量级模型的智慧之源

在这里插入图片描述

A Gentle Introduction to Hint Learning & Knowledge Distillation | by LA Tran | Towards AI

在深度学习领域,使用大型神经网络模型通常能够获得更好的性能。然而,这些模型往往具有 高计算复杂度 ,不适合在边缘设备上部署。知识蒸馏(Knowledge Distillation)是一种将大型模型的知识迁移到小型模型中以提高其性能的技术,具有广泛的应用前景。本文将从背景、原理、研究现状、挑战和未来展望等方面介绍知识蒸馏。

背景介绍

随着人工智能的快速发展,深度学习模型在诸如计算机视觉、自然语言处理等领域取得了显著的成果。然而,这些模型的参数量和计算复杂度往往很高,难以直接应用于资源有限的环境,如手机、嵌入式设备等。因此, 模型压缩 成为了一个重要的研究方向。

知识蒸馏 是一种有效的模型压缩方法,通过训练一个较小的学生模型来模拟大型教师模型的行为,从而在保持较低计算复杂度的同时,提高模型性能。

原理介绍和推导

知识蒸馏的基本原理是让学生模型学习教师模型的输出概率分布。以下是详细的推导过程:

x x x 是输入数据, y y y 是真实标签, T T T 是教师模型, S S S 是学生模型。我们用 P T P_T PT P S P_S PS 分别表示教师模型和学生模型预测的概率分布,它们可以通过 softmax 函数计算得到:

P T ( y ∣ x ) = e z T y / T ∑ y ′ e z T y ′ / T , P S ( y ∣ x ) = e z S y / T ∑ y ′ e z S y ′ / T P_T(y|x) = \frac{e^{z_T^y/T}}{\sum_{y'} e^{z_T^{y'}/T}}, \quad P_S(y|x) = \frac{e^{z_S^y/T}}{\sum_{y'} e^{z_S^{y'}/T}} PT(yx)=yezTy/TezTy/T,PS(yx)=yezSy/TezSy/T

其中, z T z_T zT z S z_S zS 分别表示教师模型和学生模型的 logits 输出, T T T 是一个温度参数,用于控制概率分布的平滑程度。

知识蒸馏的目标是最小化学生模型和教师模型的概率分布之间的差异,通常使用 KL 散度 来衡量。定义损失函数为:

L K D = ∑ x K L ( P T ( ⋅ ∣ x ) ∣ ∣ P S ( ⋅ ∣ x ) ) L_{KD} = \sum_x KL(P_T(\cdot|x) || P_S(\cdot|x)) LKD=xKL(PT(x)∣∣PS(x))

为了兼顾真实标签信息,我们还可以将知识蒸馏损失和常规的分类损失相结合:

L = ( 1 − α ) L C E + α T 2 L K D L = (1 - \alpha)L_{CE} + \alpha T^2 L_{KD} L=(1α)LCE+αT2LKD

其中, L C E L_{CE} LCE 是分类损失(如交叉熵损失), α \alpha α 是一个权重参数,用于平衡两者的影响。

研究现状

自 2015 年 Hinton 等人首次提出知识蒸馏以来,该领域的研究取得了丰富的成果。在此,我们将简要介绍一些知识蒸馏的研究方向和关键技术。

  • 蒸馏方法: 除了基本的蒸馏方法,还有许多改进和扩展,例如使用额外的监督信息、动态调整温度参数等。

  • 自蒸馏: 一种将模型自身的知识迁移到同一模型的方法。通过在不同训练阶段或不同网络层之间进行蒸馏,可以提高模型性能,并有助于泛化。

  • 多教师蒸馏: 通过结合多个教师模型的知识,可以进一步提高学生模型的性能。

  • 在线蒸馏: 在训练过程中实时生成教师模型,与学生模型共同学习。

在这里插入图片描述

Knowledge distillation in deep learning and its applications [PeerJ]

挑战

尽管知识蒸馏已经取得了一定的成功,但仍面临许多挑战,包括:

  • 优化方法: 知识蒸馏涉及到两个或多个模型之间的相互作用,如何有效地优化这些模型仍是一个开放性问题。

  • 知识迁移的有效性: 由于教师模型和学生模型的结构差异,部分知识可能难以迁移。如何设计更通用的知识迁移方法仍需进一步研究。

  • 计算效率: 知识蒸馏需要训练多个模型,可能导致较高的计算开销。如何减少蒸馏过程的计算成本是另一个重要问题。

未来展望

随着边缘计算和物联网的发展,轻量级模型在实际应用中的需求将越来越大。因此,知识蒸馏等模型压缩技术将持续受到关注。在未来,我们期待以下几个方向的发展:

  • 新的蒸馏方法: 探索更高效、更通用的知识蒸馏方法,以适应各种应用场景。

  • 跨领域知识迁移: 将知识蒸馏应用于不同领域和任务之间的知识迁移,实现更广泛的泛化能力。

  • 自动化模型设计: 结合自动机器学习(AutoML)技术,在知识蒸馏过程中自动搜索最优的学生模型结构和参数。

  • 与其他模型压缩技术的融合: 将知识蒸馏与其他模型压缩技术(如剪枝、量化等)相结合,实现更高效的模型压缩和性能提升。

总结

知识蒸馏是一种将大型神经网络模型的知识迁移到小型模型的技术。通过训练一个较小的学生模型来模拟大型教师模型的行为,知识蒸馏旨在在保持较低计算复杂度的同时提高模型性能。尽管知识蒸馏在过去几年中取得了显著的进展,但仍存在许多挑战和未来的研究方向。随着边缘计算和物联网的发展,知识蒸馏等模型压缩技术将在未来继续受到关注,为智能设备和应用提供支持。

在这里插入图片描述

On Representation Knowledge Distillation for Graph Neural Networks | Chaitanya K. Joshi

代码示例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# 定义教师模型
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义学生模型
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 计算软目标损失
def soft_target_loss(y_s, y_t, T):
    y_s = torch.nn.functional.log_softmax(y_s / T, dim=1)
    y_t = torch.nn.functional.softmax(y_t / T, dim=1)
    return torch.mean(torch.sum(-y_t * y_s, dim=1))

# 超参数
T = 2.0  # 温度参数
alpha = 0.5  # 软目标损失权重
epochs = 10
learning_rate = 0.001

# 初始化模型、损失函数和优化器
teacher_model = TeacherModel()
student_model = StudentModel()
criterion_hard = nn.CrossEntropyLoss()
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

# 模拟训练数据
inputs = Variable(torch.randn(100, 784))
labels = Variable(torch.randint(0, 10, (100,)))

# 开始训练
for epoch in range(epochs):
    optimizer.zero_grad()

    # 教师模型和学生模型的输出
    teacher_output = teacher_model(inputs)
    student_output = student_model(inputs)

    # 计算硬目标损失和软目标损失
    loss_hard = criterion_hard(student_output, labels)
    loss_soft = soft_target_loss(student_output, teacher_output, T)

    # 计算总损失
    loss = (1 - alpha) * loss_hard + alpha * loss_soft

    # 反向传播和优化
    loss.backward()
    optimizer.step()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}")

以上代码展示了一个简单的知识蒸馏示例,其中一个简化的教师模型将知识传递给一个更小的学生模型。在每次训练迭代中,我们计算硬目标损失(基于真实标签)和软目标损失(基于教师模型的输出)。请注意,这个示例仅用于演示知识蒸馏的基本概念,并未涉及数据加载、模型评估等实际应用中的其他关键步骤。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/721217.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Some about RMI

(备份防止忘掉) 一开始编译javac所有文件 这个问题概率遇到 解决方法: 然后java编译impl文件 直到出现bound in registry这一句 然后这个cmd不要关掉 重新在eclipse这个类的工作路径打开一个cmd 看到这个之后回到eclipse里面运行app这个文…

数据结构--树的性质

数据结构–树的性质 树的常考性质 常见考点 1 : 结点数 总度数 1 \color{red}常见考点1:结点数总度数1 常见考点1:结点数总度数1 结点的度 ―― 结点有几个孩子(分支) 树的度 ―― 各结点的度的最大值 m叉树 ―― 每个结点最多只能有m个孩子的树 常见考点 2 : 度为 m 的树、…

MySQL安装以及体系结构

1.简述mysql体系结构 MySQL 最重要、最与众不同的特性是它的存储引擎架构,这种架构的设计将查询处理 (Query Processing)及其他系统任务(Server Task)和数据的存储/提取相分离。这种 处理和存储分离的设计可以在使用时…

【简单认识LVS及LVS-NAT负载均衡群集的搭建】

文章目录 一、LVS群集简介1、群集的含义2、性能扩展方式3、群集的分类4、负载均衡群集架构1、负载均衡的结构 5、三种负载调度工作模式1、NAT模式2、TUN模式3、DR模式 二、LVS虚拟服务器1、Linux Virtual Server简介2、启用LVS虚拟服务3、LVS调度算法(1)…

YoloV5/YoloV7改进---注意力机制:高斯上下文变换器GCT,性能优于ECA、SE等注意力模块 | CVPR2021

目录 1.GCT介绍 实验结果 2.GCT引入到yolov5 2.1 加入common.py中: 2.2 加入yolo.py中: 2.3 yolov5s_GCT.yaml 2.4 yolov5s_GCT1.yaml 1.GCT介绍 论文:https://openaccess.thecvf.com/content/CVPR2021/papers/Ruan_Gaussian_Context_Tra…

Spring源码解析(二):bean容器的创建、默认后置处理器、扫描包路径bean

Spring源码系列文章 Spring源码解析(一):环境搭建 Spring源码解析(二): 目录 一、Spring源码基础组件1、bean定义接口体系2、bean工厂接口体系3、ApplicationContext上下文体系 二、AnnotationConfigApplicationContext注解容器1、创建bean工厂-beanFa…

计算机网络概述(三)

常见的计算机网络体系结构 OSI体系结构: 物理层→数据链路层→网络层→运输层→会话层→表示层→应用层 TCP/IP体系结构: 网络接口层→网际层→运输层→应用层 一般用户的设备都有TCP/IP协议用于连接因特网,TCP/IP的网络接口层并没有规定使用…

Linux基础+html和script一些基本语法

文章目录 linux 基础名字含义指令 html 语法style 样式属性样式标签属性颜色margin 边距ransform 旋转角度重复样式opacity 透明度div 方块元素box-shadow 阴影属性浮动 script获取节点onclick 点击触发setTimeout 定时器利用定时器实现 动画效果 javascript强弱语言区分parseI…

简单详细的MySQL数据库结构及yum和通用二进制安装mysql的方法

目录 mysql体系结构mysql的安装方法一,yum安装1,首先下载一个网络源仓库:2,然后安装 mysql-community-server3,启动mysqld 服务4,然后登录数据库5,初次登录要设置密码,而且不能太简单…

小型电子声光礼花器电子烟花爆竹电路设计

节日和庆典时燃放礼花,其绚丽缤纷的图案,热烈的爆炸声、欢乐的气氛,能给人们留下美好的印象,但有一定的烟尘污染和爆炸危险隐患。本电路可以模拟礼花燃放装置,达到声型兼备的效果,给人们在安全、环保的环境…

redis rehash

dict结构 dictEntry即键值对,每个桶就是dictEntry连接的链表 typedef struct dictEntry {void *key;union {void *val; // 自定义类型uint64_t u64;int64_t s64;double d;} v;struct dictEntry *next; } dictEntry;数据真正指向的地方 typedef struct dictht {di…

京东网站登录二维码显示不出来

环境: 360急速浏览器 Win10专业版 问题描述: 京东网站登录二维码显示不出来 解决方案: 1.打开安全卫士 2.功能大全找到断网急救箱 3.全面诊断一下有问题修复一下,重启浏览器解决

数字迷宫:探秘统计位数为偶数的奇妙世界

本篇博客会讲解力扣“1295. 统计位数为偶数的数字”的解题思路,这是题目链接。 统计位数是偶数的数据个数,关键在于如何统计一个整数的位数。方法是:对于一个整数n,每次/10,都会缩小一位,所以一直进行/10操…

【爬虫】AOI

目前几个大厂,高德百度腾讯,都支持POI爬取,而AOI是需要自己找接口的。 换言之,爬虫需谨慎 1 百度AOI 参考链接是: 这两个链接是选定范围爬取范围内选定类别的AOI 黑科技 | 百度地图抓取地块功能(上&#x…

DeepSpeed-Chat 打造类ChatGPT全流程 笔记二之监督指令微调

文章目录 系列文章0x0. 前言0x1. 🐕 Supervised finetuning (SFT) 教程翻译🏃 如何训练模型🏃 如何对SFT checkpoint进行评测?💁 模型和数据☀️来自OPT-1.3B及其SFT变体(使用不同微调数据)的提示示例☀️…

关于layui实现按钮点击添加行的功能

关于layui实现按钮点击添加行的功能 实现效果 代码实现 <!DOCTYPE html> <html lang"zh" xmlns:th"http://www.thymeleaf.org"> <head><meta charset"UTF-8"><title>Title</title><link rel"styl…

帅气的头像-InsCode Stable Diffusion 美图活动一期

1.运行地址 Stable Diffusion 模型在线使用地址&#xff1a; https://inscode.csdn.net/inscode/Stable-Diffusion 界面截图&#xff1a; 2.模型版本及相关配置 模型&#xff1a;chilloutmix-Ni.safetensor [7234b76e42] 采样迭代步数&#xff08;steps&#xff09;: 30 采样…

QtDesigner的使用

QtDesigner的使用 1、快速入门2、布局管理 1、快速入门 主窗口 菜单栏、工具栏、状态栏 快捷预览方式&#xff0c;工具箱 对象查看器 属性编辑器 英文名作用objectName控件对象名称geometry相对坐标系sizePolicy控件大小策略minnimumSize最小宽度、高度maximumSize最大宽度…

基于jsp+Servlet+mysql学生信息管理系统V2.0

基于jspServletmysql学生信息管理系统V2.0 一、系统介绍二、功能展示1.项目骨架2.数据库表3.项目内容4.登陆界面5.学生-学生信息6、学生-修改密码7、管理员-学生管理8、管理员-添加学生9.管理员-修改学生信息10.管理员-班级信息11.管理员-教师信息 四、其它1.其他系统实现五.获…

旅游卡系统旅行社小程序APP

旅游业的不断发展&#xff0c;旅游卡系统、旅行社小程序APP等数字化工具已经成为了旅行社提升业务效率、提高客户体验的重要手段。下面&#xff0c;我们将为您介绍旅游卡系统旅行社小程序APP的相关内容。 一、旅游卡系统 旅游卡系统是一种将旅游门票、优惠券等资源整合…