[ACL 2024] Revisiting Knowledge Distillation for Autoregressive Language Models

news2024/11/28 20:52:25

Contents

  • Introduction
  • Method
    • Rethinking Knowledge Distillation for Autoregressive LMs
    • Improving Knowledge Distillation with Adaptive Teaching Modes
  • Experiments
  • References

Introduction

  • 作者提出 Autoregressive KD with Adaptive Teaching Modes (ATKD),通过对难易样本采用不同的学习策略来解决 larger teachers might dramatically
    result in a poorer student
    , especially when the model capability gap is large 的问题,可以作为一种通用的学习策略提升不同的已有 KD 算法的精度
    在这里插入图片描述

Method

Rethinking Knowledge Distillation for Autoregressive LMs

  • Reformulation of L K L \mathcal L_{\mathbf {KL}} LKL. KL 散度可以被分解为 ground truth 类别上的 binary KL loss K L ( p b t ∣ ∣ q b t ) \mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t) KL(pbt∣∣qbt) 和非 ground truth 类别上的 KL loss K L ( p ^ t ∣ ∣ q ^ t ) \mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t}) KL(p^t∣∣q^t),前者可以帮助 student 学习 target 相关的信息,被称为 target-oriented knowledge distillation (TKD),后者可以帮助 student 学习 non-target 中蕴含的知识,被称为 diversity-oriented knowledge distillation (DKD);此外,这两部分的蒸馏损失被加上了一个权值 p \ g t t p_{\backslash g_t}^t p\gtt,该项反映了 teacher 的 uncertainty,被称为 uncertainty coefficient (UNC)
    L K L = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + ∑ j = 1 , j ≠ g t C p j t log ⁡ ( p j t q j t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t )       + p \ g t t ∑ j = 1 , j ≠ g t C p ^ j t ( log ⁡ ( p ^ j t q ^ j t ) + log ⁡ ( p \ g t t q \ g t t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + p ∖ g t t log ⁡ ( p ∖ g t t q ∖ g t t )       + p ∖ g t t ∑ j = 1 , j ≠ g t C p ^ j t log ⁡ ( p ^ j t q ^ j t ) = ∑ t = 1 T ( K L ( p b t ∣ ∣ q b t ) + p \ g t t K L ( p ^ t ∣ ∣ q ^ t ) ) \begin{aligned} \mathcal{L}_{\mathrm{KL}}& =\sum_{t=1}^{T}(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+\sum_{j=1,j\neq g_{t}}^{C}p_{j}^{t}\log(\frac{p_{j}^{t}}{q_{j}^{t}})) \\&=\sum_{t=1}^T\left(p_{g_t}^t\log(\frac{p_{g_t}^t}{q_{g_t}^t})\right. \\ &\ \ \ \ \ +p_{\backslash g_{t}}^{t}\sum_{j=1,j\neq g_{t}}^{C}\hat{p}_{j}^{t}\left(\log(\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}})+\log(\frac{p_{\backslash g_{t}}^{t}}{q_{\backslash g_{t}}^{t}})\right) \\ &=\sum_{t=1}^{T}\left(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+p_{\setminus g_{t}}^{t}\log(\frac{p_{\setminus g_{t}}^{t}}{q_{\setminus g_{t}}^{t}})\right. \\ &\ \ \ \ \ +p_{\setminus g_t}^t\sum_{j=1,j\neq g_t}^C\hat{p}_j^t\log(\frac{\hat{p}_j^t}{\hat{q}_j^t}) \\ &=\sum_{t=1}^T\left(\mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t)+p_{\backslash g_t}^t\mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t})\right) \end{aligned} LKL=t=1T(pgttlog(qgttpgtt)+j=1,j=gtCpjtlog(qjtpjt))=t=1T(pgttlog(qgttpgtt)     +p\gttj=1,j=gtCp^jt(log(q^jtp^jt)+log(q\gttp\gtt))=t=1T(pgttlog(qgttpgtt)+pgttlog(qgttpgtt)     +pgttj=1,j=gtCp^jtlog(q^jtp^jt)=t=1T(KL(pbt∣∣qbt)+p\gttKL(p^t∣∣q^t))其中, T T T 为序列长度, p , q p,q p,q 分别为 teacher 和 student 的概率分布, g t gt gt 为 teacher 预测的 ground-truth 类别, p g t t = exp ⁡ ( z g t t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ∖ g t t = ∑ k = 1 , k ≠ g t C exp ⁡ ( z k t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ^ i t = exp ⁡ ( z i t ) ∑ j = 1 , j ≠ g t C exp ⁡ ( z j t ) p_{g_t}^t=\frac{\exp(z_{g_t}^t)}{\sum_{j=1}^C\exp(z_j^t)},p_{\setminus g_t}^t=\frac{\sum_{k=1,k\neq g_t}^C\exp(z_k^t)}{\sum_{j=1}^C\exp(z_j^t)},\hat{p}_i^t=\frac{\exp(z_i^t)}{\sum_{j=1,j\neq g_t}^C\exp(z_j^t)} pgtt=j=1Cexp(zjt)exp(zgtt),pgtt=j=1Cexp(zjt)k=1,k=gtCexp(zkt),p^it=j=1,j=gtCexp(zjt)exp(zit) p i t = p ∖ g t t ⋅ p ^ i t p_i^t=p_{\setminus g_t}^t\cdot \hat{p}_i^t pit=pgttp^it p b t = [ p g t t , p ∖ g t t ] \mathrm{p}_{\mathrm{b}}^t=[p_{g_t}^t,p_{\setminus g_t}^t] pbt=[pgtt,pgtt]
  • Empirical Analyses. (1) UNC measures the learning difficulties of tokens, where the hard-to-learn ones are more important for KD. 根据 p \ g t t p_{\backslash g_t}^t p\gtt 的大小可以把 tokens 分为难样本 (top-50% uncertainty) 和简单样本,实验发现难样本对 student 的学习更重要,尤其是 student 和 teacher 差距比较大的时候,这可能是因为难样本能让 student 学到丰富的类间信息,同时避免过拟合
    在这里插入图片描述(2) DKD contributes more (than TKD) but is greatly suppressed, especially for the larger teachers. 作者对 TKD 和 DKD 做了解耦,去除了权重 p \ g t t p_{\backslash g_t}^t p\gtt 来考察它们各自的作用,作者发现 DKD 显著优于 TKD,但在 KL loss 中,由于 p \ g t t p_{\backslash g_t}^t p\gtt 的存在,DKD 的权值被降低了,并且这一现象在更大规模的模型中尤为显著,这也是作者认为的导致 larger teachers might dramatically result in a poorer student 的原因在这里插入图片描述在这里插入图片描述(3) TKD plays different roles in tokens with different learning difficulties. TKD 在简单样本上可能会导致 student 过拟合,从而影响泛化性;在难样本上能降低难样本的学习难度,从而提升 student 精度
    在这里插入图片描述

Improving Knowledge Distillation with Adaptive Teaching Modes

  • Autoregressive KD with Adaptive Teaching Modes (ATKD). 基于上述观察很容易想到,不同的 tokens 根据其难易程度,应该有不同的学习策略;简单样本仅使用 DKD,难样本 (top-50% uncertainty) 使用 DKD + TKD
    L K L e = − ∑ t ∈ D e K L ( p ^ t ∣ ∣ q ^ t ) , L K L h = − ∑ t ∈ D h K L ( p b t ∣ ∣ q b t ) + K L ( p ^ t ∣ ∣ q ^ t ) \begin{aligned} &\mathcal{L}_\mathrm{KL}^{e} =-\sum_{t\in\mathcal{D}_e}\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}), \\ &\mathcal{L}_{\mathrm{KL}}^h =-\sum_{t\in\mathcal{D}_h}\mathrm{KL}(\mathbf{p_b^t}||\mathbf{q_b^t})+\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}) \end{aligned} LKLe=tDeKL(p^t∣∣q^t),LKLh=tDhKL(pbt∣∣qbt)+KL(p^t∣∣q^t)最终的损失函数为简单样本和难样本上损失的加权和 L K L a l l = λ ∗ L K L e + ( 1 − λ ) ∗ L K L h \mathcal{L}_{\mathrm{KL}}^{all}=\lambda*\mathcal{L}_{\mathrm{KL}}^e+(1-\lambda)*\mathcal{L}_{\mathrm{KL}}^h LKLall=λLKLe+(1λ)LKLh其中, λ = 0.2 \lambda=0.2 λ=0.2

Experiments

  • Compared Results. S NLG \mathcal S_{\textrm{NLG}} SNLG 为语言生成任务,由 GPT-4 打分; S NLU \mathcal S_{\textrm{NLU}} SNLU 为语言理解任务,为 benchmark 得分
    在这里插入图片描述在这里插入图片描述
  • Ablation Study. (1) Impact of ratio k k k. k k k 用于确定 top- k k k uncertainty 的 tokens 为难样本;(2) Impact of coefficient λ λ λ. 用于确定难易样本损失的权重
    在这里插入图片描述

References

  • Zhong, Qihuang, et al. “Revisiting knowledge distillation for autoregressive language models.” arXiv preprint arXiv:2402.11890 (2024).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2062362.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Go语言开发通过本地数据xdb文件​查询获取IP地址的归属地区及运营商名称

说明: 用本地数据,离线识别ip属地,用于显示用户ip属地,不依赖第三方的api接口,本地数据包解析,解析速度快10微秒级别的查询效率。返回数据固定格式:国家|区域|省份|城市|ISP,例如&a…

c++11(三)

一、可变参数 1、可变参数模板 c语言中的 scanf 和 printf 可以支持我们传入任意个数的参数&#xff0c;原理就是用了参数包。 //可变参数包 template<class ...Args> void Print(Args... args) {} Args&#xff1a;模板参数包 args&#xff1a;函数形参参数包 声明…

检查linux系统中异常进程

1、查看非root运行的进程 [rootbastion-IDC ~]# ps -U root -u root -N 2、查看root运行的进程 [rootbastion-IDC ~]# ps -u root 注意&#xff1a;UID为0的进程&#xff0c;查看该进程所打开的端口和文件 [rootbastion-IDC ~]#ps -ef 查看进程 [rootbastion-IDC ~]# l…

Lesson 77 Terrible toothache

Lesson 77 Terrible toothache 词汇 appointment n. 预约 构成&#xff1a;point v. 指&#xff0c;指向 用法&#xff1a;point to 人 / 物    指着&#xff0c;指向……    point out 指出&#xff08;问题&#xff09; 相关&#xff1a;game point 局点    matc…

statsmodels学习笔记

statsmodels学习笔记 统计模型、假设检验和数据探索。statsmodels是一个python模块&#xff0c;提供了用于估计许多不同统计模型的类和函数&#xff0c;以及用于统计测试和统计数据探索。每个估计器都有一个广泛的结果统计列表。根据现有的统计软件包对结果进行测试&#xff0c…

【C++】深入解析C/C++内存管理:new与delete的使用及原理

C语法相关知识点可以通过点击以下链接进行学习一起加油&#xff01;命名空间缺省参数与函数重载C相关特性类和对象-上篇类和对象-中篇类和对象-下篇日期类 本章将分享C为何放弃malloc/free系列&#xff0c;选择新系列new/delete去管理内存。深度探索new/delete的使用及其原理,m…

VBA注释 (<*> + <*>)

在VBA&#xff08;Visual Basic for Applications&#xff09;中&#xff0c;注释是一种用于向代码中添加说明或解释文本的方法&#xff0c;这些文本不会被执行。注释对于理解代码的目的、逻辑或特定部分的代码功能非常有帮助&#xff0c;尤其是在处理复杂或长的代码时。 一、…

当《黑神话:悟空》遇上openKylin,国产力量的极致碰撞!

万众瞩目的国产3A游戏巨作《黑神话&#xff1a;悟空》终于上线啦&#xff01;&#xff01;&#xff01; 在正式发售后不到24小时&#xff0c;Steam在线玩家峰值突破222万&#xff0c;在Steam所有游戏在线玩家历史峰值中排名第二。第一拨玩家纷纷晒出好评&#xff0c;称这款现象…

Python安装Crypto库报错:ModuleNotFoundError: No module named ‘Crypto‘

目录 from Crypto.Cipher import AES 1.解决方法 1、卸载Crypto和pycrypto库 2、安装pycryptodome库 二、另一种解决方法&#xff08;看的别人遇到的情况&#xff0c;我没有遇到这种情况&#xff09; from Crypto.Cipher import AES 在网上搜的教程使用第三方库实现AES算法…

消息中心业务系统集成方案:提升企业信息流动性与协作效率

在信息化时代&#xff0c;企业的业务系统之间需要实现高效的信息流动与协作&#xff0c;以支持动态的业务需求和快速的决策过程。消息中心作为企业信息管理的重要组成部分&#xff0c;通过整合各类消息和通知&#xff0c;能够提升信息传递的效率和准确性。本文将详细探讨消息中…

Nginx 配置指南

一、Nginx 简介 1.1 概述 Nginx 是一款高性能、轻量级的开源 Web 服务器和反向代理服务器&#xff0c;以其可靠性、丰富的功能和简单的配置而闻名。由 Igor Sysoev 开发&#xff0c;最初用于解决 C10K 问题&#xff0c;与传统的 Web 服务器相比&#xff0c;Nginx 采用异步事件…

使用stream()流合并两个列表

List<Author>结构如下&#xff1a; List<Reader>结构如下&#xff1a; 需求&#xff1a;将Author列表和Reader列表根据相同id合并到一个列表中 private static void mergeList() {List<Author> authors Author.getAuthors();List<Reader> readers …

阅读、分析和维护高质量开源软件有感——小计一笔

目录 一、问题分析 软件开发问题分析 动机 学什么 目的 二、要求 阅读 理解 运用 分析 评估 认知 三、案例选择 MiNotes”开源软件 方式 实践支撑软件工具 操作流程 应该学到的知识 学习过程 四、任务与输出 1.阅读开源软件 2.标注开源软件 3.分析开源…

iLogtail 开源两周年:感恩遇见,畅想未来

早在上世纪 60 年代&#xff0c;早期的计算机&#xff08;例如 ENIAC 和 IBM 的大型机&#xff09;在操作过程中会输出一些基本的状态信息和错误报告&#xff0c;这些记录通常通过打印机输出到纸带或纸卡上&#xff0c;用于跟踪操作流程和调试&#xff0c;最早期的日志系统借此…

前端必备:高效处理树形数据与数组的实用函数

​&#x1f308;个人主页&#xff1a;前端青山 &#x1f525;系列专栏&#xff1a;Vue篇 &#x1f516;人终将被年少不可得之物困其一生 依旧青山,本期给大家带来Vuet篇专栏内容:Vue-树形数据处理|数组:实用函数封装 大家好&#xff0c;依旧青山&#xff0c;在开发项目过程中&a…

3、springboot时代背景

一、微服务 二、分布式 三、云原生 原生应用如何上云。 Cloud Native 上云的困难 服务自愈弹性伸缩服务隔离自动化部署灰度发布流量治理...... 上云的解决

人工智能算法工程师(中级)课程21-深度学习中各种优化器算法的应用与实践、代码详解

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下人工智能算法工程师(中级)课程21-深度学习中各种优化器算法的应用与实践、代码详解。本文将介绍PyTorch框架下的几种优化器&#xff0c;展示如何使用PyTorch中的优化器&#xff0c;我们将使用MNIST数据集和一个简单…

云游戏畅玩黑神话悟空:使用 NVIDIA 4090 体验极致画质

​ 黑神话悟空 爽啦&#xff01;没有好配置又想玩《黑神话&#xff1a;悟空》的朋友们都爽啦&#xff01;自己没有好的 GPU&#xff0c;体验《黑神话&#xff1a;悟空》时画质不好玩的不舒心&#xff1f;厚德云来帮你解决问题&#xff01;厚德云上线了《黑神话&#xff1a;悟空…

机器学习第五十三周周报 MAG

文章目录 week53 MAG摘要Abstract1. 题目2. Abstract3. 预测标准3.1 问题提出3.2 数据预处理3.3 模型架构MAG3.4 时域故障模式识别3.5 故障检测器 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过程4.4 实验结果4.5 结果分析 5. 结论小结参考文献 week53 MAG 摘要 本周阅读…

【ASPLOS2024】RECom:通过编译器技术加速推荐模型推理,论文中选并获得荣誉奖项!

2024年5月&#xff0c;关于推荐模型自动编译优化的论文《RECom: A Compiler Approach to Accelerate Recommendation Model Inference with Massive Embedding Columns》在系统领域顶会ASPLOS 2024上中选并进行了展示&#xff0c;并被授予了Distinguished Artifact Award 荣誉&…