BERT知识蒸馏TinyBERT

news2024/11/25 14:10:07

1. 概述

诸如BERT等预训练模型的提出显著的提升了自然语言处理任务的效果,但是随着模型的越来越复杂,同样带来了很多的问题,如参数过多,模型过大,推理事件过长,计算资源需求大等。近年来,通过模型压缩的方式来减小模型的大小也是一个重要的研究方向,其中,知识蒸馏也是常用的一种模型压缩方法。TinyBERT[1]是一种针对transformer-based模型的知识蒸馏方法,以BERT为Teacher模型蒸馏得到一个较小的模型TinyBERT。四层结构的TinyBERT在GLUE benchmark上可以达到BERT的96.8%及以上的性能表现,同时模型缩小7.5倍,推理速度提升9.4倍。六层结构的TinyBERT可以达到和BERT同样的性能表现。

2. 算法原理

为了能够将原始的BERT模型蒸馏到TinyBERT,因此,在[1]中提出了一种新的针对Transformer网络特殊设计的蒸馏方法,同时,因为BERT模型的训练分成了两个部分,分别为预训练和针对特定任务的Fine-tuning,因此在TinyBERT模型的蒸馏训练过程中也设计了两阶段的学习框架,在预训练和Fine-tuning阶段都进行蒸馏,以确保TinyBERT模型能够从BERT模型中学习到一般的语义知识和特定任务知识。

2.1. 知识蒸馏

知识蒸馏(knowledge distillation)[2]是模型压缩的一种常用的方法,对于一个完整的知识蒸馏过程,有两个模型,分别为Teacher模型和Student模型,通过学习将已经训练好的Teacher模型中的知识迁移到小的Student模型中。其具体过程如下图所示:

在这里插入图片描述
对于Student模型,其目标函数有两个,分别为蒸馏的loss(distillation loss)和自身的loss(student loss),Student模型最终的损失函数为:

L = α L s o f t + β L h a r d L=\alpha L_{soft}+\beta L_{hard} L=αLsoft+βLhard

其中, L s o f t L_{soft} Lsoft表示的是蒸馏的loss, L h a r d L_{hard} Lhard表示的是自身的loss。

2.2. Transformer Distillation

BERT模型是由多个Transformer模块(Self-Attention+FFN)组成,单个Self-Attention+FFN模块如下图所示:

在这里插入图片描述

假设BERT模型中有 N N N层的Transformer Layer,在蒸馏的过程中,BERT模型作为Teacher模型,而需要蒸馏的模型TinyBERT模型作为Student模型,其Transformer Layer的层数假设为 M M M,则有 M < N M<N M<N,此时需要找到一个对应关系: n = g ( m ) n = g\left ( m \right ) n=g(m),表示的是在Student模型中的第 m m m层对应于Teacher模型中的第 n n n层,即 g ( m ) g\left ( m \right ) g(m)层。TinyBERT的Embedding层和预测层也是从BERT的相应层学习知识的,其中Embedding层对应的层数为 0 0 0,预测层对应的层数为 M + 1 M+1 M+1,对应到BERT中的层数分别为 0 = g ( 0 ) 0=g\left (0 \right ) 0=g(0) N + 1 = g ( M + 1 ) N + 1 = g\left ( M+1 \right ) N+1=g(M+1)。在形式上,学生模型可以通过最小化以下的目标函数来获取教师模型的知识:

L m o d e l = ∑ x ∈ χ ∑ m = 0 M + 1 λ m L l a y e r ( f m S ( x ) , f g ( m ) T ( x ) ) L_{model}=\sum _{x\in \chi }\sum_{m=0}^{M+1}\lambda _mL_{layer}\left ( f_m^S\left ( x \right ),f_{g\left ( m \right )}^T\left ( x \right ) \right ) Lmodel=xχm=0M+1λmLlayer(fmS(x),fg(m)T(x))

其中, L l a y e r L_{layer} Llayer是给定的模型层的损失函数, f m f_m fm表示的是由第 m m m层得到的结果, λ m \lambda_{m} λm表示第 m m m层蒸馏的重要程度。在TinyBERT的蒸馏过程中,又可以分为以下三个部分:

  • transformer-layer distillation
  • embedding-layer distillation
  • prediction-layer distillation。

2.2.1. Transformer-layer Distillation

Transformer-layer的蒸馏由Attention Based蒸馏和Hidden States Based蒸馏两部分组成,具体如下图所示:

在这里插入图片描述

其中,在BERT中多头注意力层能够捕获到丰富的语义信息,因此,在蒸馏到TinyBERT中,提出了Attention Based蒸馏,其目的是希望使得蒸馏后的Student模型能够从Teacher模型中学习到这些语义上的信息。具体到模型中,就是让TinyBERT网络学习拟合BERT网络中的多头注意力矩阵,目标函数定义如下:

L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) L_{attn}=\frac{1}{h}\sum_{i=1}^{h}MSE\left ( A_i^S,A_i^T \right ) Lattn=h1i=1hMSE(AiS,AiT)

其中, h h h代表注意力头数, A i ∈ R l × l A_i \in \mathbb{R}^{l\times l} AiRl×l代表Student或者Teacher模型中的第 i i i个注意力头对应的注意力矩阵, l l l代表输入文本的长度。在[1]中使用注意力矩阵 A A A而不是 s o f t m a x ( A ) softmax\left ( A \right ) softmax(A)是因为实验结果显示这样可以得到更快的收敛速度和更好的性能表现。

Hidden States Based的蒸馏是对Transformer层进行了知识蒸馏处理,目标函数定义为:

L h i d n = M S E ( H S W h , H T ) L_{hidn}=MSE\left ( H^SW_h,H^T \right ) Lhidn=MSE(HSWh,HT)

其中,矩阵 H S ∈ R l × d ′ H^S\in \mathbb{R}^{l\times {d}'} HSRl×d H T ∈ R l × d H^T\in \mathbb{R}^{l\times d} HTRl×d分别代表Student网络和Teacher网络的隐状态,且都是FFN的输出。 d d d d ′ {d}' d代表Teacher网络和Student网络的隐藏状态大小,且 d ′ < d {d}' < d d<d,因为Student网络总是小于Teacher网络。 W h ∈ R d ′ × d W_h\in \mathbb{R}^{{d}'\times d} WhRd×d是一个参数矩阵,将Student网络的隐藏状态投影到Teacher网络隐藏状态所在的空间。

2.2.2. Embedding-layer Distillation

Embedding层的蒸馏与Hidden States Based蒸馏一致,其目标函数为:

L e m b d = M S E ( E S W e , E T ) L_{embd}=MSE\left ( E^SW_e,E^T \right ) Lembd=MSE(ESWe,ET)

其中 E S E^S ES E T E^T ET分别代表Student网络和Teacher网络的Embedding, W e W_e We的作用与 W h W_h Wh的作用一致。

2.2.3. Prediction-layer Distillation

除了对中间层做蒸馏,同样对于最终的预测层也要进行蒸馏,其目标函数为:

L p r e d = C E ( z T t , z S t ) L_{pred}=CE\left ( \frac{z^T}{t},\frac{z^S}{t} \right ) Lpred=CE(tzT,tzS)

其中, z S z^S zS z T z^T zT分别是Student网络和Teacher网络预测的logits向量, C E CE CE表示的是交叉熵损失, t t t是温度值,在实验中得知,当 t = 1 t = 1 t=1时效果最好。

综合上述三个部分的Loss函数,则可以得到Teacher网络和Student网络之间对应层的蒸馏损失如下:

L l a y e r = { L e m b d , m = 0 L h i d n + L a t t n , M ≥ m > 0 L p r e d , m = M + 1 L_{layer}=\begin{cases} L_{embd}, & m=0 \\ L_{hidn} + L_{attn}, & M \geq m > 0 \\ L_{pred}, & m=M+1 \end{cases} Llayer= Lembd,Lhidn+Lattn,Lpred,m=0Mm>0m=M+1

2.3. 两阶段的训练

对于BERT的训练来说分为两个阶段,分别为预训练和fine-tunning,预训练阶段可以使得BERT模型能够学习到更强的语义信息,fine-tunning阶段是为了使模型更适配具体的任务。因此在蒸馏的过程中也需要针对两个阶段都蒸馏,即general distillation和task-specific distillation,具体如下图所示:

在这里插入图片描述

在general distillation阶段,通过蒸馏使得TinyBERT能够学习到BERT中的语义知识,能够提升TinyBERT的泛化能力,而task-specific distillation可以进一步获取到fine-tuned BERT中的知识。

3. 总结

在TinyBERT中,精简了BERT模型的大小,设计了三种层的蒸馏,分别为transformer-layer,embedding-layer以及prediction-layer。同时,为了能够对以上三层的蒸馏,文中设计了两阶段的训练过程,分别与BERT的训练过程对应,即预训练和fine-tunning。

参考文献

[1] Jiao X, Yin Y, Shang L, et al. Tinybert: Distilling bert for natural language understanding[J]. arXiv preprint arXiv:1909.10351, 2019.

[2] 知识蒸馏基本原理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/4273.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PX4基本配置

目录 下载固件 下载原生稳定版固件 安装PX4 Master, Beta或自定义固件 FMUv2 Bootloader 更新 机架设置 飞行控制器/传感器方向 计算朝向 设置朝向 罗盘校准 执行校准 陀螺仪校准 # 执行校准 加速度计 执行校准 空速计校准 执行校准 水平平面校准 执行校准 …

Spring Cloud Zookeeper 升级为Spring Cloud Kubernetes

这里是weihubeats,觉得文章不错可以关注公众号小奏技术&#xff0c;文章首发。拒绝营销号&#xff0c;拒绝标题党 背景 现有的微服务是使用的Spring Cloud Zookeeper这一套&#xff0c;实际应用在Kubernetes中部署并不需要额外的注册中心&#xff0c;本身Kubernetes自己就支持…

10道不得不会的 Java容器 面试题

博主介绍&#xff1a; &#x1f680;自媒体 JavaPub 独立维护人&#xff0c;全网粉丝15w&#xff0c;csdn博客专家、java领域优质创作者&#xff0c;51ctoTOP10博主&#xff0c;知乎/掘金/华为云/阿里云/InfoQ等平台优质作者、专注于 Java、Go 技术领域和副业。&#x1f680; 最…

FFmpeg内存IO模式

ffmpeg 支持从网络流 或者本地文件读取数据&#xff0c;然后拿去丢给解码器解码&#xff0c;但是有一种特殊情况&#xff0c;就是数据不是从网络来的&#xff0c;也不在本地文件里面&#xff0c;而是在某块内存里面的。 这时候 av_read_frame() 函数怎样才能从内存把 AVPacket…

TensorFlow的GPU使用相关设置整理

前言 TensorFlow是一个在机器学习和深度学习领域被广泛使用的开源软件库&#xff0c;用于各种感知和语言理解任务的机器学习。 默认情况下&#xff0c;TensorFlow 会映射进程可见的所有 GPU&#xff08;取决于 CUDA_VISIBLE_DEVICES&#xff09;的几乎全部内存。这是为了减少内…

国考省考行测:问题型材料主旨分析,有问题有对策,主旨是对策,有问题无对策,要合理引申对策

国考省考行测&#xff1a;问题型材料主旨分析&#xff0c;有问题有对策&#xff0c;主旨是对策&#xff0c;有问题无对策&#xff0c;要合理引申对策 2022找工作是学历、能力和运气的超强结合体! 公务员特招重点就是专业技能&#xff0c;附带行测和申论&#xff0c;而常规国考…

【Linux】Linux背景、环境的搭建以及用XShell实现远程登陆

目录Linux 背景Linux环境搭建Linux远程登陆Linux 背景 肯尼斯蓝汤普森最早用汇编语言创建了UNIX系统&#xff0c;后来与他的好“基友”丹尼斯里奇&#xff08;C语言之父&#xff09;&#xff0c;他们两个一同用C语言重新写了UNIX系统&#xff0c;但是操作系统的使用是需要收费…

ActiveState Platform - November 2022

ActiveState Platform - November 2022 ActiveState平台定期更新新的、修补的和版本化的软件包和语言。 Python 3.10.7、3.9.14、3.8.14-解决了许多安全问题的点发布。 Python C库-ibxml 2.10.3、libxslt 1.1.37、libexpat 2.4.9、zlib 1.2.13、curl 7.85.0和sqlite3 3.39.4&am…

Python添加水印简简单单,三行代码教你批量添加

环境使用: Python 3.8Pycharm 如何配置pycharm里面的python解释器? 选择file(文件) >>> setting(设置) >>> Project(项目) >>> python interpreter(python解释器)点击齿轮, 选择add添加python安装路径 pycharm如何安装插件? 选择file(文件) …

使用Python PyQt5完成残缺棋盘覆盖仿真作业

摘要&#xff1a;本文内容是关于如何实现残缺棋盘覆盖仿真软件&#xff0c;算法课作业要求设计开发一个残缺棋盘覆盖仿真软件。使用”分治算法“求解问题&#xff0c;Python编程语言实现功能&#xff1b;使用PyQt5和Python热力图实现界面和仿真效果展示。 1 残缺棋盘覆盖仿真作…

[Linux打怪升级之路]-yun安装和gcc的使用

前言 作者&#xff1a;小蜗牛向前冲 名言&#xff1a;我可以接受失败&#xff0c;但我不能接受放弃 如果觉的博主的文章还不错的话&#xff0c;还请点赞&#xff0c;收藏&#xff0c;关注&#x1f440;支持博主。如果发现有问题的地方欢迎❀大家在评论区指正。 本期学习目标&am…

Java:外包Java项目有什么好处?

Java已经成为众多解决方案的通用开发语言&#xff0c;包括web应用、游戏、软件开发等等。超过710万全球的Java程序员都在忙着为业界下一个最好的应用程序编码。 随着企业努力在当今的全球市场中保持竞争力&#xff0c;对Java项目外包的需求不断增加。 以下是你的企业通过外包Ja…

python基于PHP+MySQL的论坛管理系统

互联网给了我们一个互通互信的途径,但是如何能够更加高效的进行各种问题的分享和交流是很多人关心的问题,市面上比较知名的一些分享交流平台也很多,比如百度的贴吧,知乎等高质量内容分享平台,本系统是一个类似这样的论坛分享系统 随着互联网的发展人们分享和交流的分享也变的越…

leetcode刷题(128)——1575. 统计所有可行路径,动态规划解法

leetcode刷题&#xff08;127&#xff09;——1575. 统计所有可行路径&#xff0c;DFS解法 给你一个 互不相同 的整数数组&#xff0c;其中 locations[i] 表示第 i 个城市的位置。同时给你 start&#xff0c;finish 和 fuel 分别表示出发城市、目的地城市和你初始拥有的汽油总…

【CSS】CSS字体样式【CSS基础知识详解】

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;花无缺 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 花无缺 原创 本文章收录于专栏 【CSS】 【CSS专栏】已发布文章 &#x1f4c1;【CSS基础认知】 &#x1f4c1;【CSS选择器全解指南】 本文目录【CS…

物联网感知-光纤光栅传感器技术

一、光纤光栅传感技术 光纤光栅是利用光纤材料的光敏性&#xff0c;通过紫外光曝光的方法将入射光相干场图样写入纤芯&#xff0c;将周期性微扰作用于光纤纤芯&#xff0c;在纤芯内产生沿纤芯轴向的折射率周期性变化&#xff0c;从而形成永久性空间的相位光栅&#xff0c;其作用…

MySQL数据库的基本操作及存储引擎的使用

大家好呀&#xff01;我是猿童学&#x1f435;&#xff0c;最近在学习Mysql数据库&#xff0c;给初学者分享一些知识&#xff0c;也是学习的总结&#xff0c;关注我将会不断地更新数据库知识&#xff0c;也欢迎大家指点一二&#x1f339;。 目录 一、常用的MySQL语句 二、创建…

使用ThinkMusic网站源码配合cpolar,发布本地音乐网站

1、前言 在我们的日常生活中&#xff0c;音乐已经成为不可或缺的要素之一&#xff0c;听几首喜欢的音乐&#xff0c;能让原本糟糕的心情变得好起来。虽然现在使用电脑或移动电子设备听歌都很方便&#xff0c;但难免受到诸多会员或VIP限制&#xff0c;难免让我们回想起音乐网站…

【JavaScript】常用内置对象——数组(Array)对象

文章目录什么是数组创建数组访问数组数组常用方法和属性投票传送门什么是数组 数组&#xff08;Array&#xff09;是最基本的集合类型&#xff0c;由于JavaScript是弱类型语言&#xff0c;因此JavaScript的数组和大多数语言的数组有所区别。在大多数语言中&#xff0c;当声明一…

ubuntu 20.04 qemu u-boot-2022.10 开发环境搭建

开发环境 ubuntu 20.04 VMware Workstation Pro 16 基于qemu&#xff08;模拟器&#xff09;&#xff0c;vexpress-a9 平台 搭建 u-boot-2022.10 (当前最新版本&#xff09; 准备工作 u-boot下载&#xff0c;下载最新稳定版本&#xff0c;当前为 u-boot-2022.10&#xff0…