知识蒸馏Matching logits与RocketQAv2

news2025/1/18 20:15:49

知识蒸馏Matching logits

公式推导

刚开始的\frac{\partial L}{\partial z_i}=q_i-p_i怎么来,可以转看下面证明梯度等于输出值-标签y

C是一个交叉熵,我们要求解的是这个交叉熵对z_i的这个梯度。z_i就是你可以理解成第i个类别的得分。z_i就是student model,被蒸馏的模型,它所输出的logits。

p_i是什么?是target probability对吧。q_i是什么?q_i认为就是这个distilled model的输出的那个probability。所以就是说这两个概率相减,再乘以这个T分之一T是什么?T是一个温度。

我们现在假定是说我们是用teacher model输出的这个label,然后去训练student model,或者说去训练distilled model。我们对这个第i个类别的梯度,就等于\frac{1}{T}{(q_i-p_i)},然后呢,q_ip_i可以做一个化简。

q_ip_i进行展开,概率都是用softmax算出来的,就可以得到这个式子。

通过e^x\approx 1+x来进行化简,这个式子在x比较小的时候是成立的。

在这里,当T足够大的时(相比z的logits,即z),\frac{z_i}{T}就足够的小,接近于0,此时e^{\frac{z_i}{T}}\approx 1+\frac{z_i}{T}

\sum_j e^{\frac{z_i}{T}}\approx \sum_j{1+\frac{z_i}{T}}=N+\sum_j{\frac{z_i}{T}} 

z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0,所以这两个分母直接就变成了N。

\frac{1}{T}({\frac{1+z_i/T}{N}}-{\frac{1+v_i/T}{N}})=\frac{1}{TN}{\frac{z_i-v_i}{T}}

则所求梯度

想说明的事情

它其实就想说明这样一个事情。我们试图用一个teacher model,或者说我们想用VI对应的那个概率叫p_iz_i对应的概率叫q_i。如果我们想用这个p_i作为label去用交叉商去训练q_i去用这个soft label的交叉商去训练q_i,那么其实我们可能不需要套用交叉商这个东西了,我们也不需要什么softmax的label的交叉商,然后去做这个事了。因为这个东西在我们的这样一通推导下就会发现,其实就等于均方误差,右边这一项其实就是什么均方误差的求导,它就是均方误差求导之后的结果,你可以这样认为。

我们就会发现说,原来对于交叉商对于这个知识蒸馏的这个交叉商,然后我们对他求导求出来的梯度其实是近似等同于我们直接用MSE去训练,然后得到的梯度的。那么既然这样,我们为什么不直接用MSE?

它的推导就告诉我们说我们对于两个模型,两个多分类模型来说,我们要用a模型去交B模型做蒸馏。我们没有必要让这两个模型生成分别生成什么label,然后再生成预测的概率,然后再加上去优化了。

我们直接让这两个多分类模型的这个logic,然后直接做MSE就可以了,就可以做到一种就是一种这种MSE就是一种什么蒸馏的特殊形式。就是蒸馏的一个最早期的雏形,其实在这个时候都还没有考虑用这个什么KL散度来做,就只是提出最简单的一个思想是什么,就是用MSE来做就够了。

我们一直即便到今天,我们做很多知识正溜的实验,我们依然会发现MIC可能有的时候都会比K要好。虽然大家都说自己用什么KL散度用什么JS散度,但是就是否现在就最优,还真不一定有的时候就是MSE效果好。

注:MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

需要注意的事

公式的推导基于两个假设:

1.T得足够大的(相比z的logits,即z_i)

2.模型输出的logic是零均值的(即均值为0),因为模型输出的logic是零均值的,这个z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0

 证明梯度等于输出值-标签y

softmax函数

归一化,使其输出的概率和为1

S_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

S_i代表的是第i个神经元的输出。

神经元的输出,一个神经元如下图:

z_i=\sum_jw_{ij}x_{ij}+b

其中w_{ij}是第i个神经元的第j个权重,b是偏移值。z_i表示该网络的第i个输出。

给这个输出加上一个softmax函数,得a_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

a_i代表softmax的第i个输出值

交叉熵损失函数 loss function

L=-\sum_i{y_i}{lna_i}

其中y_i表示真实的分类结果。

证明梯度等于输出值-标签y

loss对于神经元输出z_i的梯度为\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}

由于softmax公式的特性,它的分母包含了所有神经元的输出,对于不等于i的其他输出里面,也包含着z_i,所有的a都要纳入到计算范围中,并且后面的计算可以看到需要分为i=ji \ne j两种情况求导。

由于\frac{\partial (-\sum_{k\ne j}y_{k}ln a_k)}{\partial a_j}=0

\frac{\partial C}{\partial a_j}=\frac{\partial (-\sum_jy_jln a_j)}{\partial a_j}=-\sum_jy_j\frac{1}{a_J}

如果i=j

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_{k\ne i}e^{z_k}+e^{z_i}})}{\partial z_i}=\frac{\sum_ke^{z_k}e^{z_i}-(e^{z_i})^2}{\sum_k(e^{z_k})^2}
=(\frac{e^{z_i}}{\sum_ke^{z_k}})(1-\frac{e^{z_i}}{\sum_ke^{z_k}})=a_i(1-a_i)

这里\sum_ke^{z_k}=\sum_{k\ne i}e^{z_k}+e^{z_i}

如果i \ne j

这里\sum_ke^{z_k}=\sum_{k\ne j}e^{z_k}+e^{z_j}

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=-e^{z_j}(\frac{1}{\sum_ke^{z_k}})e^{z_i}=-a_ia_j

综上

\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}=(-\sum_jy_j\frac{1}{a_j})\frac{\partial a_j}{\partial z_i}=-\frac{y_i}{a_i}a_i(1-a_i)+\sum_{j\neq i}\frac{y_i}{a_j}a_ia_j
=-y_i+y_ia_i+\sum_{j\neq i}{y_ia_i}=-y_i+a_i\sum_{j}y_j

最后,针对分类问题,我们给定的结果y_i最终只会有一个类别是1,其他非标签类别都是0,因此,对于分类问题,这个梯度等于

\frac{\partial L}{\partial z_i}=a_i-y_i

知识蒸馏RocketQAv2

https://arxiv.org/pdf/2110.07367.pdf

这个模型有两部分组成一个retriever和一个ranker。这个做的事情就是说用label去监督re-ranker,然后用ranker去监督retriever。用KL散度去约束它约束,用这个K散路去让这个re-ranker的分布和retriever的分布对齐。

要注意就是说。这里就是他们就没有用MSE,就是说如果用MSE怎么做,就是说对应的这个直接相减,就对应位置直接相减,然后分MSE就行。这里用的是KL散度。

KL散度的定义,你可以认为是这样的,让这两个概率分别相除,除完了之后都要再取对数,然后再乘以这个概率。

DE,这个teacher model的概率乘以teacher model的概率乘以log,teacher model的概率除以student model的概率。然后把这么多概率给它都累加起来。

在这里,假定这里的是retriever给出来的一个概率分布假如说是十个候选,ranker也给了这样一个概率分布,那么就是十个概率分布对应的一项一项的去算这个KL度,即概率除概率,然后再取对数,然后再乘上ranker这个概率。

然后再把这十项给它累加起来,然后就是一个KL散度,这样的话,这个K散度其实是现在就是接受最多的一种损失函数。

因为KL散度就是天生的,可以捕获这个分布和分布之间的距离。像MSE缺点是什么?MSE的缺点是它没有整体的那种距离衡量的能力。MSE其实是对于细节的这种距离的衡量很强。如果MSE来的话,每一个每一项,这十项每一项的重要性对于MIC来说都是一样的。但是这个KL散度可能就会更在乎一个整体的一个分布上的一个区别了,就而不是说就在乎一些细节上的一些差别,因为有可能就是说。你某一些细节差距虽然大一些,但是你整体差距不大,所以KL散度也可以比较小。

实际上一切可以衡量两个分布之间距离的指标都可以用来做知识蒸馏,所以其实wasserstein距离也可以用来作为蒸馏的损失函数:

https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Wasserstein_Contrastive_Representation_Distillation_CVPR_2021_paper.pdf

为什么知识蒸馏会有效?

1. teacher model可以生成soft label,相比于原始数据的hard label,包含了更多信息量。

所以很多时候你与其说直接用一个数据集去训练一个模型,你还不如用这个数据集先训练一个大a模型比a模型要大的模型。再让大a模型去教会a模型去做,有可能效果就更好。就是因为大a模型这个teacher model可以生成soft label相比于原始数据的hard label,可以包含更多的信息量,从而就天然的有一种去燥的一种功能。

2. teacher model可以为大量的无标签数据打上label,然后为student提供一个大规模的训练集。然后从而可以给student提供一个更大尺度的训练集,然后防止student的一个过拟合,然后提高student model的一个泛化能力。也就是说,teacher model可以把自己的泛化能力交给student model

在这个知识蒸馏的过程当中,这也是为什么说很多大公司里边现在线上的模型都是蒸馏出来的小模型就是因为我们与其说直接训练小模型。还不如说就用这个蒸馏去蒸馏一个小模型反而泛化能力会更强一些

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1518709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大根堆排序

堆是完全二叉树,分为大根堆和小根堆 完全二叉树 从左到右依次变满,高度O(logn) 非完全二叉树: 需要知道的几个点【堆可以看做一段连续的数组来存放】 i是索引位置 i位置的左孩子:2 * i + 1 i位置的右孩子:2 * i + 2 i位置的父亲节点:( i - 1 ) / 2 大根堆【每一颗…

es 分词器详解

基本概念 分词器官方称之为文本分析器,顾名思义,是对文本进行分析处理的一种手段,基本处理逻辑为按照预先制定的分词规则,把原始文档分割成若干更小粒度的词项,粒度大小取决于分词器规则。 分词器发生的时期 1、分词…

兼容性测试策略

📋 个人简介 作者简介:大家好,我是凝小飞,软件测试领域作者支持我:点赞👍收藏⭐️留言📝 一.背景介绍 Android严重的碎片化,主要体现在品牌碎片化、设备碎片化、系统碎片化、分辨率碎…

HDFS的架构优势与基本操作

目录 写在前面一、 HDFS概述1.1 HDFS简介1.2 HDFS优缺点1.2.1 优点1.2.2 缺点 1.3 HDFS组成架构1.4 HDFS文件块大小 二、HDFS的Shell操作(开发重点)2.1 基本语法2.2 命令大全2.3 常用命令实操2.3.1 上传2.3.2 下载2.3.3 HDFS直接操作 三、HDFS的API操作3…

将Linux curl命令转换为windows平台的Python代码

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

怎么采集美团的数据

怎么使用简数采集器批量采集美团的活动、商家和商品相关信息呢? 简数采集器暂时不支持采集美团的相关数据,建议换其他网站采集,谢谢。 简数采集器采集网站文章数据特别高效方便,在简数智能向导模式下,只要填写要采集…

【Python】进阶学习:一文了解NotImplementedError的作用

【Python】进阶学习:一文了解NotImplementedError的作用 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望…

ios开发错误积累

1.xcode 下载模拟器报错 Could not download iOS 报错: 解决: 1、去官网下载自己需要 地址(https://developer.apple.com/download/all) 2、下载完成后,执行以下命令添加:xcrun simctl runtime add /路径…

桌面备忘录,电脑桌面备忘录怎么设置

在当今快节奏的生活中,备忘录成为了人们工作和生活中不可或缺的工具。然而,随着科技的发展,纸质备忘录逐渐被电子桌面备忘录所取代。在电脑桌面设置备忘录,可以更加高效地管理任务和提醒事项。 电脑桌面是我们日常工作和娱乐的主…

Dense Distinct Query for End-to-End Object Detection

摘要 对象检测中的一对一标签分配成功地消除了作为后处理的非极大值抑制( NMS )的需要,并使流水线端到端。然而,这引发了一个新的困境,因为广泛使用的稀疏查询无法保证高召回率,而密集查询不可避免地带来更…

论文篇00-【历年论文真题考点汇总】与【历年论文原题2009~2023年文字版记录】(2024年软考高级系统架构设计师冲刺知识点总结-论文篇-先导篇)

专栏系列文章推荐: 案例分析篇00-【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】 综合知识篇00-综合知识考点汇总目录 ...... 历年真题论文题考点汇总 历年软考系统架构设计师论文原题(2009-2022年) 因最新的2023年目前仅能搜索到回忆版,等楼主搜集到…

内容检索(2024.03.15)

随着创作数量的增加,博客文章所涉及的内容越来越庞杂,为了更为方便地阅读,后续更新发布的文章将陆续在此汇总并附上原文链接,感兴趣的小伙伴们可持续关注文章发布动态! 本期更新内容: 1. 信号完整性理论与…

找机厅 洛谷 BFS

P10234 [yLCPC2024] B. 找机厅 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) #include<bits/stdc.h> #define pii pair<int,int> #define fr first #define sc second using namespace std; string maze[2000]; int vis[2000][2000]; char dirs[2005][2005]; st…

数据和类型转换

文章目录 数据类型数字类型数字操作NaNJavaScript算术运算符的执行顺序 字符串类型&#xff08;string&#xff09;字符串拼接模板字符串 未定义类型&#xff08;undefined&#xff09;布尔类型&#xff08;boolean&#xff09;null&#xff08;空类型&#xff09; 类型转换显式…

Ubuntu上搭建TFTP服务

Ubuntu上搭建TFTP服务 TFTP服务简介搭建TFTP服务安装TFTP服务修改配置文件 重启服务 TFTP服务简介 TFTP是一个基于UDP协议实现的用于在客户机和服务器之间进行简单文件传输的协议&#xff0c;适用于开销不大、不复杂的应用场合。TFTP协议专门为小文件传输而设计&#xff0c;只…

Java中 常见的开源树库介绍

阅读本文之前请参阅------Java中 树的基础知识介绍 在 Java 中&#xff0c;有几种流行的开源树库&#xff0c;它们提供了丰富的树算法和高级操作&#xff0c;可以帮助开发者更高效地处理树相关的问题。以下是几种常见的 Java 树库及其特点和区别&#xff1a; JTree 特点…

移动硬盘无法读取怎么修复?教你四招快速解决!

随着科技的发展&#xff0c;移动硬盘已经成为我们日常生活和工作中不可或缺的数据存储设备。然而&#xff0c;有时候我们可能会遇到移动硬盘无法读取的问题&#xff0c;这不仅会给我们带来数据丢失的风险&#xff0c;还可能影响我们的工作进度。下面给大家分享四种针对移动硬盘…

Qt教程 — 3.1 深入了解Qt 控件:Buttons按钮

目录 1 Buttons按钮简介 1.1 Buttons按钮简介 1.2 Buttons按钮如何选择 2 如何使用Buttons按钮 2.1 QPushButton使用-如何自定义皮肤 2.2 QToolButton使用-如何设置帮助文档 2.3 QRadioButton使用-如何设置开关效果 2.4 QRadioButton使用-如何设置三态选择框 2.5 QCom…

【C++初阶】C++入门(上)

C的认识 ①什么是C&#xff1f; ​ C语言是结构化和模块化的语言&#xff0c;适合处理较小规模的程序。对于复杂的问题&#xff0c;规模较大的程序&#xff0c;需要高度的抽象和建模时&#xff0c;C语言则不合适。 ​ 于是1982年&#xff0c;Bjarne Stroustrup&#xff08;本…

虚拟游戏理财 - 华为OD统一考试(C卷)

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 100分 题解&#xff1a; Java / Python / C 题目描述 在一款虚拟游戏中生活&#xff0c;你必须进行投资以增强在虚拟游戏中的资产以免被淘汰出局。 现有一家Bank&#xff0c;它提供有若干理财产品m&#xff0c;风险及…