Rethinking Knowledge Distillation for Autoregressive LMs
Improving Knowledge Distillation with Adaptive Teaching Modes
Experiments
References
Introduction
The authors propose Autoregressive KD with Adaptive Teaching Modes (ATKD), which applies different learning strategies to hard and easy tokens in order to address the problem that "larger teachers might dramatically result in a poorer student, especially when the model capability gap is large". ATKD can serve as a general learning strategy that improves the accuracy of various existing KD algorithms.
Method
Rethinking Knowledge Distillation for Autoregressive LMs
Reformulation of $\mathcal{L}_{\mathrm{KL}}$. The KL divergence can be decomposed into a binary KL loss on the ground-truth class, $\mathrm{KL}(\mathbf{p}_{\mathrm{b}}^{t}\,\|\,\mathbf{q}_{\mathrm{b}}^{t})$, and a KL loss on the non-ground-truth classes, $\mathrm{KL}(\hat{\mathbf{p}}^{t}\,\|\,\hat{\mathbf{q}}^{t})$. The former helps the student learn target-related information and is called target-oriented knowledge distillation (TKD); the latter helps the student learn the knowledge contained in the non-target classes and is called diversity-oriented knowledge distillation (DKD). In the decomposition, the DKD term also carries a weight $p_{\backslash g_t}^{t}$, which reflects the teacher's uncertainty and is called the uncertainty coefficient (UNC):
$$
\begin{aligned}
\mathcal{L}_{\mathrm{KL}}
&=\sum_{t=1}^{T}\Big(p_{g_t}^{t}\log\Big(\frac{p_{g_t}^{t}}{q_{g_t}^{t}}\Big)+\sum_{j=1,\,j\neq g_t}^{C}p_{j}^{t}\log\Big(\frac{p_{j}^{t}}{q_{j}^{t}}\Big)\Big) \\
&=\sum_{t=1}^{T}\Big(p_{g_t}^{t}\log\Big(\frac{p_{g_t}^{t}}{q_{g_t}^{t}}\Big)+p_{\backslash g_t}^{t}\sum_{j=1,\,j\neq g_t}^{C}\hat{p}_{j}^{t}\Big(\log\Big(\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}}\Big)+\log\Big(\frac{p_{\backslash g_t}^{t}}{q_{\backslash g_t}^{t}}\Big)\Big)\Big) \\
&=\sum_{t=1}^{T}\Big(p_{g_t}^{t}\log\Big(\frac{p_{g_t}^{t}}{q_{g_t}^{t}}\Big)+p_{\backslash g_t}^{t}\log\Big(\frac{p_{\backslash g_t}^{t}}{q_{\backslash g_t}^{t}}\Big)+p_{\backslash g_t}^{t}\sum_{j=1,\,j\neq g_t}^{C}\hat{p}_{j}^{t}\log\Big(\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}}\Big)\Big) \\
&=\sum_{t=1}^{T}\Big(\mathrm{KL}(\mathbf{p}_{\mathrm{b}}^{t}\,\|\,\mathbf{q}_{\mathrm{b}}^{t})+p_{\backslash g_t}^{t}\,\mathrm{KL}(\hat{\mathbf{p}}^{t}\,\|\,\hat{\mathbf{q}}^{t})\Big)
\end{aligned}
$$

where
$T$ is the sequence length, $p$ and $q$ are the probability distributions of the teacher and the student respectively, $g_t$ is the ground-truth class predicted by the teacher, and

$$
p_{g_t}^{t}=\frac{\exp(z_{g_t}^{t})}{\sum_{j=1}^{C}\exp(z_{j}^{t})},\qquad
p_{\setminus g_t}^{t}=\frac{\sum_{k=1,\,k\neq g_t}^{C}\exp(z_{k}^{t})}{\sum_{j=1}^{C}\exp(z_{j}^{t})},\qquad
\hat{p}_{i}^{t}=\frac{\exp(z_{i}^{t})}{\sum_{j=1,\,j\neq g_t}^{C}\exp(z_{j}^{t})},
$$

so that $p_{i}^{t}=p_{\setminus g_t}^{t}\cdot\hat{p}_{i}^{t}$ and $\mathbf{p}_{\mathrm{b}}^{t}=[p_{g_t}^{t},\,p_{\setminus g_t}^{t}]$; the student-side quantities $\mathbf{q}_{\mathrm{b}}^{t}$ and $\hat{q}_{i}^{t}$ are defined analogously from the student logits.
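As a sanity check, the decomposition above can be verified numerically. Below is a minimal PyTorch sketch (not the authors' code; the helper name `kl_decomposition` is made up for illustration) that computes the TKD, UNC, and DKD terms from teacher/student logits and confirms that $\mathrm{TKD}+\mathrm{UNC}\cdot\mathrm{DKD}$ reproduces the per-token KL divergence.

```python
import torch
import torch.nn.functional as F

def kl_decomposition(teacher_logits, student_logits, targets):
    """Decompose per-token KL(p||q) into TKD + UNC * DKD.

    teacher_logits, student_logits: (T, C) logits; targets: (T,) indices g_t.
    Returns (tkd, unc, dkd), each of shape (T,).
    """
    p = F.softmax(teacher_logits, dim=-1)            # teacher distribution p^t
    q = F.softmax(student_logits, dim=-1)            # student distribution q^t

    idx = torch.arange(targets.shape[0])
    p_gt, q_gt = p[idx, targets], q[idx, targets]    # p_{g_t}^t, q_{g_t}^t
    p_not, q_not = 1.0 - p_gt, 1.0 - q_gt            # p_{\g_t}^t, q_{\g_t}^t (UNC)

    # TKD: binary KL over [ground truth, everything else]
    tkd = p_gt * torch.log(p_gt / q_gt) + p_not * torch.log(p_not / q_not)

    # DKD: KL between distributions renormalized over the non-target classes
    mask = F.one_hot(targets, p.shape[-1]).bool()
    p_hat = (p / p_not.unsqueeze(-1)).masked_fill(mask, 0.0)
    q_hat = (q / q_not.unsqueeze(-1)).masked_fill(mask, 0.0)
    safe = mask.float()                              # avoids log(0/0) at the g_t slot
    dkd = (p_hat * torch.log((p_hat + safe) / (q_hat + safe))).sum(-1)

    return tkd, p_not, dkd

if __name__ == "__main__":
    torch.manual_seed(0)
    T_len, C = 5, 11
    t_logits, s_logits = torch.randn(T_len, C), torch.randn(T_len, C)
    gts = torch.randint(0, C, (T_len,))

    tkd, unc, dkd = kl_decomposition(t_logits, s_logits, gts)
    full_kl = (F.softmax(t_logits, -1) *
               (F.log_softmax(t_logits, -1) - F.log_softmax(s_logits, -1))).sum(-1)
    assert torch.allclose(tkd + unc * dkd, full_kl, atol=1e-4)
```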
Empirical Analyses. (1) UNC measures the learning difficulties of tokens, where the hard-to-learn ones are more important for KD. Tokens can be split by the magnitude of $p_{\backslash g_t}^{t}$ into hard tokens (top-50% uncertainty) and easy tokens. Experiments show that the hard tokens matter more for the student's learning, especially when the gap between student and teacher is large, presumably because hard tokens let the student learn richer inter-class information while avoiding overfitting. (2) DKD contributes more (than TKD) but is greatly suppressed, especially for the larger teachers. The authors decouple TKD and DKD by removing the weight $p_{\backslash g_t}^{t}$ to study their individual effects, and find that DKD clearly outperforms TKD; in the standard KL loss, however, DKD is down-weighted by $p_{\backslash g_t}^{t}$, and this suppression is more pronounced for larger models, which the authors identify as the reason why larger teachers might dramatically result in a poorer student. (3) TKD plays different roles in tokens with different learning difficulties. On easy tokens, TKD can make the student overfit and hurt generalization; on hard tokens, it lowers the learning difficulty and improves student accuracy.
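To make the hard/easy split concrete, here is a small illustrative helper (hypothetical, assuming the same PyTorch setting as the sketch above): it ranks tokens by the uncertainty coefficient $p_{\backslash g_t}^{t}$ and marks the top-50% most uncertain ones as hard.

```python
import torch
import torch.nn.functional as F

def split_hard_easy(teacher_logits, targets, k=0.5):
    """Return boolean masks (hard, easy) over the T tokens of one sequence.

    Tokens whose uncertainty p_{\g_t}^t = 1 - p_{g_t}^t lies in the top-k
    fraction are treated as hard-to-learn; the remaining tokens are easy.
    """
    p = F.softmax(teacher_logits, dim=-1)
    idx = torch.arange(targets.shape[0])
    unc = 1.0 - p[idx, targets]                   # uncertainty coefficient per token

    n_hard = max(1, int(k * unc.numel()))
    hard = torch.zeros_like(unc, dtype=torch.bool)
    hard[torch.topk(unc, n_hard).indices] = True  # most uncertain tokens are hard
    return hard, ~hard
```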
Improving Knowledge Distillation with Adaptive Teaching Modes
Autoregressive KD with Adaptive Teaching Modes (ATKD). The observations above naturally suggest that tokens should be taught differently according to their learning difficulty: easy tokens use only DKD, while hard tokens (top-50% uncertainty) use DKD + TKD:
$$
\begin{aligned}
\mathcal{L}_{\mathrm{KL}}^{e} &= -\sum_{t\in\mathcal{D}_e}\mathrm{KL}(\hat{\mathbf{p}}^{t}\,\|\,\hat{\mathbf{q}}^{t}), \\
\mathcal{L}_{\mathrm{KL}}^{h} &= -\sum_{t\in\mathcal{D}_h}\mathrm{KL}(\mathbf{p}_{\mathrm{b}}^{t}\,\|\,\mathbf{q}_{\mathrm{b}}^{t})+\mathrm{KL}(\hat{\mathbf{p}}^{t}\,\|\,\hat{\mathbf{q}}^{t})
\end{aligned}
$$

The final loss is a weighted sum of the losses on easy and hard tokens:
$$
\mathcal{L}_{\mathrm{KL}}^{all}=\lambda\,\mathcal{L}_{\mathrm{KL}}^{e}+(1-\lambda)\,\mathcal{L}_{\mathrm{KL}}^{h}
$$

where $\lambda=0.2$.
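Under the assumptions of the two sketches above (and treating each KL term as a quantity to be minimized, leaving aside the sign convention of the equations), the adaptive loss could be assembled roughly as follows; this is an illustrative sketch, not the authors' implementation.

```python
# Assumes kl_decomposition and split_hard_easy from the sketches above.
def atkd_loss(teacher_logits, student_logits, targets, k=0.5, lam=0.2):
    """ATKD-style loss: DKD only on easy tokens, TKD + DKD on hard tokens."""
    tkd, _, dkd = kl_decomposition(teacher_logits, student_logits, targets)
    hard, easy = split_hard_easy(teacher_logits, targets, k=k)

    loss_easy = dkd[easy].sum()                  # diversity-oriented KD only
    loss_hard = (tkd[hard] + dkd[hard]).sum()    # target- + diversity-oriented KD
    return lam * loss_easy + (1.0 - lam) * loss_hard
```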
Experiments
Compared Results. $\mathcal{S}_{\textrm{NLG}}$ covers language generation tasks, scored by GPT-4; $\mathcal{S}_{\textrm{NLU}}$ covers language understanding tasks, reported as benchmark scores.
Ablation Study. (1) Impact of ratio $k$: $k$ determines the fraction of tokens (those with top-$k$ uncertainty) that are treated as hard. (2) Impact of coefficient $\lambda$: $\lambda$ weights the losses on easy versus hard tokens.
References
Zhong, Qihuang, et al. “Revisiting knowledge distillation for autoregressive language models.” arXiv preprint arXiv:2402.11890 (2024).