人工智能|深度学习——知识蒸馏

news2024/12/23 4:38:10

一、引言

1.1 深度学习的优点

特征学习代替特征工程:深度学习通过从数据中自己学习出有效的特征表示,代替以往机器学习中繁琐的人工特征工程过程,举例来说,对于图片的猫狗识别问题,机器学习需要人工的设计、提取出猫的特征、狗的特征输入到机器学习模型中才能进行进一步的分类,这个过程非常依赖人的经验和领域知识,而深度学习模型会自己直接从猫狗图片中学习出猫和狗的有效特征表示。

端到端学习代替多模块学习:在一些任务中,传统机器学习方法需要将一个任务的输入和输出之间,人为的分割成多个子模块,也就是分割成多个阶段,每个子模块分开进行训练学习,比如对于一个自然语言理解问题,一般需要切分成分词、词性标注、句法分析、语法分析等多个模块,而端到端学习不进行模块和阶段的划分,直接优化任务的总体目标,中间过程不需要人为干预,训练数据呈现 输入-输出 对的形式,不再需要额外的信息。

1.2 深度学习的缺点

依赖数据量规模:深度学习要想发挥出理想的效果,需要大规模的数据,当数据量偏少时可能还不如传统的机器学习方法。

模型体积过大:深度学习要想从数据中学习出更有效的特征表示,一般会通过加深模型层数的方法,随着残差连接和多种正则化方法的提出,训练更深层的模型变为可能,这也导致了深度学习模型的体积变的越来越大,无法部署在那些资源受限的设备上,往往只是理论上能达到最优,但是无法真正进行落地使用。

可解释性差:在深度学习的眼中,万事万物都是向量(更准确的说叫张量),外界对象需要被表示为向量才能输入到模型中进行进一步的处理,在深度学习中把将外界对象表示为向量这个过程叫做嵌入,比如将一个词语表示为向量叫做词嵌入,但是表示成向量之后,它的解释性就很差,比如用 [0.3,0.4,9.2] 这个向量表示‘我’这个词,你就不知道这几个数字究竟表示什么意义。

二、什么是知识蒸馏

2.1 模型压缩

模型压缩在不降低或者只是轻微降低原模型准确率的同时,大幅缩小原模型的体积,使其可以真正进行线上部署,常用的模型压缩方法包括

参数裁剪:删除掉原模型中一些无用的参数,缩小模型的体积

精度转换:降低原模型中参数的存储精度

神经网络结构搜索:寻找原模型中真正对最终结果起作用的网络层,删除掉影响不大的网络层,降低模型的体积。

2.2 什么是学习

赫尔伯特.西蒙曾经给学习下过定义:“如果一个系统能够通过执行某个过程改进它的性能,这就是学习。”

具体到深度学习的过程,也就是训练的过程,就是神经网络根据损失函数的约束,从输入的数据中发掘信息,从信息中再获取到对于最终任务起关键性作用的知识。

这些学习到的知识以参数的形式固化在神经网络中,当我们将数据输入到训练完毕的神经网络中,可以获取到神经网络关于数据形成的知识。

2.3 什么是知识蒸馏

知识蒸馏也是一种模型压缩方法,参数裁剪、精度转换、神经网络结构搜索这些模型压缩方法会破坏原模型的结构,也就是会损坏原模型从数据中学习到的知识,而知识蒸馏通过蒸馏的手段保护原模型中学习到的知识,然后将这些知识迁移到压缩模型中,使压缩模型虽然相比原模型体积要小的多,但是也可以学习到相同的知识。

2.4 知识蒸馏的一般流程

类比人类的学习过程,在知识蒸馏中称要进行压缩的模型为教师神经网络(Teacher Model),压缩之后的模型为学生神经网络(Student Model),一般情况下,教师神经网络的体积要远大于学生神经网络。

一般的知识蒸馏过程为

首先利用数据集训练教师神经网络,让教师神经网络充分学习数据中包含的知识

然后在利用数据集训练学生神经网络时,通过蒸馏方法将教师神经网络中已经学习到的知识提取出来,指导学生神经网络的训练,这样学生神经网络相当于从教师神经网络那里获取到了关于数据集的先验信息。

也就是在知识蒸馏中,教师神经网络是预先在数据集上进行过训练的,然后在学生神经网络的训练过程中利用自身学习到的知识对其进行指导,帮助提高学生神经网络的准确率。

使用知识蒸馏要解决的关键问题是

  • 如何发掘教师神经网络中包含的知识
  • 如何将教师神经网络中的知识通过蒸馏无损的迁移到学生神经网络中,也就是蒸馏方法的设计
  • 如何设计学生神经网络的结构

三 知识蒸馏的分类

3.1 模型结构的种类

深度学习中虽然模型众多,但是其结构可以归为四种

前馈神经网络:也叫多层感知机,MLP,前馈神经网络由 线性变换+非线性激活 组成,通过线性变换将输入空间中的数据变换到特征空间,利用非线性激活函数无限逼近真实的判别函数。

卷积神经网络:CNN,卷积神经网络是连接受限的前馈神经网络,适合处理具有局部相关性的数据,比如图像

循环神经网络:RNN,循环神经网络会携带网络处理过程中产生的历史信息进行接下来的处理,适合处理那些具有时序性特征的数据

Transformer:带有注意力机制的前馈神经网络,利用注意力机制获取数据中的关键信息,可以利用有限的计算资源处理更重要的信息。

综上,多种网络结构其实可以统一看成前馈神经网络。

3.2 知识的分类

在知识蒸馏中,将教师神经网络中的知识分为三种

输出层知识:图中的Response-based Knowledge,是教师神经网络最后一层的输出,这个输出未经过Softmax层转换为概率,一般称为Logits,关于Logits的具体介绍可见Logits

中间层知识:图中的Feature-Based Knowledge,指的是教师神经网络中间网络层的输出、包含的参数

关系型知识:图中的Relation-Based Knowledge,将教师神经网络不同层知识之间的关系作为一种知识,也叫结构型知识。

3.3 如何蒸馏

四、输出层知识蒸馏

《Distilling the Knowledge in a Neural Network 》 2015
Hinton2015年在这篇文章中首次提出知识蒸馏的概念和方法,并在MNIST手写体数字识别数据集上验证了方法的有效性。

假设我们现在的任务是利用神经网络识别 1~5 的手写体数字图片,也就是将一张手写体数字图片输入到神经网络中,神经网络要判断出这张图片中的数字究竟是几。

 但是这些数字的大小相差太大,类似于归一化,先想办法在不改变它们原有分布的情况下,改变这些数值的大小,使其具有可比性,Hinton在这里引入了一个称为'温度'的参数,对Logtis进行平滑处理,知识蒸馏这个词语也是来自于这个过程,具体平滑公式是

 蒸馏过程为

 


五、中间层知识蒸馏

《Learning Metrics from Teachers: Compact Networks for Image Embedding》2019 CVPR
可以将神经网络看作是一个解决问题的过程,最后神经网络的输出结果就是神经网络对问题的解,而中间的网络层就是解决问题的步骤,既然可以让学生神经网络直接学习教师神经网络输出的问题结果,也可以利用蒸馏损失函数让学生神经网络学习教师神经网络的解题过程,也就是学习教师神经网络中间层的知识。

这篇论文中的中间层知识蒸馏过程如下图

 首先利用数据集训练教师神经网络,然后在学生神经网络训练的过程中,将数据同样输入到教师神经网络中,获取教师神经网络每个中间网络层输出的特征图,同样获取学生神经网络的特征图,然后定义中间层蒸馏损失函数为

 中间层知识蒸馏的一般流程为

 首先训练教师神经网络,然后获取教师神经网络中间层的知识,在训练学生神经网络时获取教师神经网络、学生神经网络中间层的知识,利用蒸馏损失函数进行中间层知识蒸馏。

六、关系层知识蒸馏

《Relational Knowledge Distillation》2019 CVPR


普通的知识蒸馏(上图中左侧的Conventional KD)中学生神经网络学习到的是一对一的教师神经网络产生的知识,而在关系层知识蒸馏(上图中右侧Relational KD)中学生神经网络学习的是知识之间的结构关系(Structure Knowledge),增强知识蒸馏的泛化性。

在这篇文章中,作者提出的关系层知识蒸馏过程如下

在这篇文章中,作者每次选择两个样本进行关系型知识蒸馏,定义关系抽取函数为

其中的 u 是一个正则化因子定义为

在本文中使用的衡量相似性函数为

最终关系型蒸馏损失函数定义为

七、其它知识蒸馏方法

7.1 多教师知识蒸馏


相当于是一种模型集成的方法,利用知识蒸馏将多个教师神经网络中的知识迁移到学生神经网络中,让学生神经网络在多个不同的特征空间中进行学习,可以大幅提高学生神经网络的准确率。

7.2 融合图神经网络的知识蒸馏


将教师神经网络中的知识构建成图,然后利用图神经网络中的方法获取其中的关系型知识,进行关系层的知识蒸馏。

7.3 结合多模态的知识蒸馏


数据集是一个多模态的数据集,比如包含声音、图像、文本,在A模态上训练教师神经网络,然后在利用B模态训练学生神经网络的时候利用知识蒸馏,相当于一种多模态融合方法。

八、知识蒸馏代码的编写

知识蒸馏代码编写中,教师神经网络按照普通的深度学习流程在数据集上进行训练即可,重点在于神经网络中知识的获取和蒸馏损失函数的编写,拿PyTorch举例

8.1 如何获取到神经网络中的知识

利用hook机制获取神经网络中的知识,hook机制能够使我们获取神经网络中指定层在前向传递过程中的输出,也就是相应的知识

import torch
 
 
# 获取模型中的知识
class GetFeatures:
    
    # 指定想要获取知识的模型和相应的网络层
    def __init__(self, model, layer_num):
        # 获取到的知识
        self.features = None
        # 注入hook
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)
 
    # hook函数
    def hook_fn(self, module, input, output):
        # 获取模型相应网络层的输出
        self.features = output.cuda()
 
    # 移除hook
    def remove(self):
        self.hook.remove()
 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1304013.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

产品<Axure的安装以及组件介绍

Axure介绍: Axure是一款用户体验设计工具,可以用于创建交互式原型、线框图和设计文档。它支持快速原型开发、界面设计、信息架构、流程图和注释等功能,可以帮助设计师快速地创建和共享交互式原型,从而更好地与客户和团队协作。 …

从 MQTT、InfluxDB 将数据无缝接入 TDengine,接入功能与 Logstash 类似

利用 TDengine Enterprise 和 TDengine Cloud 的数据接入功能,我们现在能够将 MQTT、InfluxDB 中的数据通过规则无缝转换至 TDengine 中,在降低成本的同时,也为用户的数据转换工作提供了极大的便捷性。由于该功能在实现及使用上与 Logstash 类…

「差生文具多系列」推荐两个好看的 Redis 客户端

📢 声明: 🍄 大家好,我是风筝 🌍 作者主页:【古时的风筝CSDN主页】。 ⚠️ 本文目的为个人学习记录及知识分享。如果有什么不正确、不严谨的地方请及时指正,不胜感激。 直达博主:「…

总结6种@Transactional注解的失效场景

作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO 联系qq:184480602,加我进群,大家一起学习,一起进步,一起对抗互联网寒冬 引言 昨天有粉丝咨询了…

【漏洞修复】Cisco IOS XE软件Web UI权限提升漏洞及修复方法

关于Cisco IOS XE软件Web UI权限提升漏洞及修复方法 文章目录 漏洞基本信息漏洞影响范围确认设备是否受影响漏洞修复方法推荐阅读 漏洞基本信息 Cisco IOS XE Unauthenticatd Remote Command Execution (CVE-2023-20198) (Direct Check) Severity:Critical Vulnerability Pri…

【Jeecg Boot 3 - 第二天】2.1、nginx 部署 JEECGBOOT VUE3

一、场景 二、实战 ▶ 2.1、打包(build 前端) > Stage 1:修改配置文件 .env.production(作用:指向后端接口地址) > Stage 2:点击build(作用&#xff1…

动态规划——数塔问题(三维数组的应用)

一、例题要求及理论分析 声明:理论指导《算法设计与分析 第四版》 因为这个地方用到了三维数组,感觉很有意思就故意挑出来分享给大家(三维数组可以看成很多页二维数组) 4.5.1认识动态规划数塔问题: 如图4-12所示的一…

小型洗衣机哪个牌子质量好?迷你洗衣机排名前十名

随着内衣洗衣机的流行,很多小伙伴在纠结该不该入手一款内衣洗衣机,专门来洗一些贴身衣物,答案是非常有必要的,因为我们现在市面上的大型洗衣机只能做清洁,无法对我们的贴身衣物进行一个高强度的清洁,而小小…

2023年最新prometheus + grafana搭建和使用+gmail邮箱告警配置

一、安装prometheus 1.1 安装 prometheus官网下载地址 sudo -i mkdir -p /opt/prometheus #移动解压后的文件名到/opt/,并改名prometheus mv prometheus-2.45 /opt/prometheus/ #创建一个专门的prometheus用户: -M 不创建家目录, -s 不让登录 useradd…

web服务器之——搭建两个基于不同端口访问的网站

要求如下: 建立一个使用web服务器默认端口的网站,设置DocumentRoot为/www/port/80,网页内容为:the port is 80。建立一个使用10000端口的网站,设置DocumentRoot为/www/port/10000,网页内容为:t…

Centos7 首次 安装Mysql8.0

随笔记录 背景介绍:重装Centos7 系统,没有安装mysql 目录 1. 查看否有MariaDB与MySQL 2. MySQL官网下载适用于centos7的mysql安装包 2.1 查询服务器是x86_64架构还是arm架构 2.2 查系统版本 2.3 下载适用于系统版本安装包 2.3.1 国内镜像源下载…

@Transactional失效问题

作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO 联系qq:184480602,加我进群,大家一起学习,一起进步,一起对抗互联网寒冬 关于Transactional 日…

应用在LED灯光控制触摸屏中的触摸芯片

LED灯光控制触摸屏方法,包括:建立触摸屏的触摸轨迹信息与LED灯光驱动程序的映射关系;检测用户施加在触摸屏上的触摸轨迹,生成触摸轨迹信息;根据生成的触摸轨迹信息,调用对应的LED灯光驱动程序,控…

算法-05-二分查找

二分查找(Binary Search)算法,也叫折半查找算法,是一种针对有序数据集合的查找算法。 1-二分查找的思想 我们生活中猜数字的游戏,告诉你一个数据范围,比如0-100,然后你说出一个数字&#xff0c…

​pathlib --- 面向对象的文件系统路径​

3.4 新版功能. 源代码 Lib/pathlib.py 该模块提供表示文件系统路径的类,其语义适用于不同的操作系统。路径类被分为提供纯计算操作而没有 I/O 的 纯路径,以及从纯路径继承而来但提供 I/O 操作的 具体路径。 如果以前从未用过此模块,或不确定…

1、springboot项目运行报错

问题1:获取不到配置文件的参数 我的配置文件获取的参数如下: public class Configures{Value("${configmdm.apk.apkName}")private static String apkName;private void setApkName(String apkName) {Configures.apkName apkName;}private …

k8s详细教程(一)

—————————————————————————————————————————————— 博主介绍:Java领域优质创作者,博客之星城市赛道TOP20、专注于前端流行技术框架、Java后端技术领域、项目实战运维以及GIS地理信息领域。 🍅文末获取源码…

OpenSSL 编程指南

目录 前言初始化SSL库创建SSL 上下文接口(SSL_CTX)安装证书和私钥加载证书(客户端/服务端证书)加载私钥/公钥加载CA证书设置对端证书验证例1 SSL服务端安装证书例2 客户端安装证书创建和安装SSL结构建立TCP/IP连接客户端创建socket服务端创建连接创建SSL结构中的BIOSSL握手服务…

数据结构基础介绍

一.起源及重要性 1968 年,美国的高德纳 Donakl E . Kn uth 教授在其所写的《 计算机程序艺术》第一卷《基本算法 》 中,较系统地阐述了数据的逻辑结构和存储结构及其操作, 开创了数据结构的课程体系 ,数据结构作为一门独立的…

基于单片机的定时插座在智能家居中的应用

近年来,随着科学技术的发展迅速,人们对智能化的要求越来越高。越来越多的智能化产品进入千家万户,如电脑电视、扫地机器人、智能空气净化器等。这些家居电器和电子产品大都需要连接电源,为满足多种用电器的正常使用,延…