拆解DKD loss (建议读完论文哈)

news2025/3/1 13:47:03
  • 论文链接:https://arxiv.org/abs/2203.08679

 

def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    gt_mask = _get_gt_mask(logits_student, target)     # 获取掩码
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)  # 然后将学生和教师模型的输出通过softmax函数和温度参数进行缩放。
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    pred_student = cat_mask(pred_student, gt_mask, other_mask)    # 接着,函数将通过之前获取到的两个掩码对学生和教师模型的输出进行切片,
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)    # 来获取属于真实标签以及不属于真实标签的模型预测结果。
    log_pred_student = torch.log(pred_student)                    # 然后对学生模型中真实标签部分的输出取对数,
    tckd_loss = (                                                 # 并将其与教师模型的输出通过KL散度计算一种损失 tckd_loss。
        F.kl_div(log_pred_student, pred_teacher, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    pred_teacher_part2 = F.softmax(                               # 接着将学生模型中不属于真实标签部分的输出取对数
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1      # ,并将其与教师模型获取的剩余输出通过KL散度计算另一种损失nckd_loss。
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (                # 接着将学生模型中不属于真实标签部分的输出取对数,并将其与教师模型获取的剩余输出通过KL散度计算另一种损失nckd_loss。
        F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    return alpha * tckd_loss + beta * nckd_loss                    # 最终将这两种损失按照权重加权求和作为总的DKD损失返回。
def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):

定义这个函数来计算DKD的损失 

需要传入 学生模型的 logits单元和老师的logits单元

知识蒸馏综述笔记_:)�东东要拼命的博客-CSDN博客

target 表示真实标签GT

alpha, beta 两个权重 表示两个知识正交 无关 互不影响

temperature 表示蒸馏温度

gt_mask = _get_gt_mask(logits_student, target)
获取标签为真实标签的掩码和标签不是真实标签的掩码
other_mask = _get_other_mask(logits_student, target)

pred_student = F.softmax(logits_student / temperature, dim=1)
pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
然后将学生和教师模型的输出通过softmax函数和温度参数进行缩放
pred_student = cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
接着,函数将通过之前获取到的两个掩码对学生和教师模型的输出进行切片

这里应该理解起来有些困难 我去借个原文的图回来

 

 这个图很清晰了

函数是将pred_studentpred_teacher分别进行切片来获取属于真实标签以及不属于真实标签的部分。这里的切片是通过两个掩码来实现的,具体包含以下步骤:

  1. 首先,利用_get_gt_mask_get_other_mask两个帮助函数分别获取真实标签和非真实标签部分的掩码,掩码中元素的取值为0或1,1代表该类别属于真实标签;0代表该类别不属于真实标签。
  2. 然后,对于学生模型输出的概率分布pred_student和教师模型输出的概率分布pred_teacher,将其按照对应的掩码进行切片,对于属于真实标签的部分,保留对应的概率,对于不属于真实标签的部分,以0填充。
  3. 最终得到的是两个经过切片处理的概率分布pred_studentpred_teacher,其中分别包含了属于真实标签和不属于真实标签的部分。该方法可以减少真实标签以外的噪声对知识蒸馏效果的影响。

log_pred_student = torch.log(pred_student)
然后对学生模型中真实标签部分的输出取对数



 # 并将其与教师模型的输出通过KL散度计算一种损失 tckd_loss。
tckd_loss = (                                                

    F.kl_div(log_pred_student, pred_teacher, size_average=False)
    * (temperature**2)
    / target.shape[0]
)

其中F.kl_div计算的是log_pred_student(学生模型在真实标签上的预测分布取对数后得到的张量)和pred_teacher(教师模型在真实标签上的预测分布)之间的KL散度。

由于KL散度是没有单位的,所以为了方便理解和比较,一般会将其除以样本数目target.shape[0],这其实相当于计算平均KL散度。

为了进一步加强知识蒸馏的作用,我们还会乘以一个温度的平方temperature**2,这样做可以使预测结果更加平滑,并可以减轻分类器对于某些输出的过度自信。

其中 size_average=False 意味意味着 KL散度 函数不会对结果进行批次规范化,也就是不会除以批次大小。因此,输出结果是未经过规范化的,每个样本都有自己的损失值。在进行批次训练时,这些值可以被相加然后除以批次大小,以得到整个批次的平均损失。

 

pred_teacher_part2 = F.softmax(                              
    logits_teacher / temperature - 1000.0 * gt_mask, dim=1      
)
log_pred_student_part2 = F.log_softmax(
    logits_student / temperature - 1000.0 * gt_mask, dim=1
)

这两个是为nckd服务的

nckd_loss = (                
    F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)
    * (temperature**2)
    / target.shape[0]
)

这是一个对 logits_teacher 进行 softmax 函数的应用,除以一个温度常量并且减去一个大的负数值(通过与一个 ground truth mask 相乘得到)。

softmax 函数将 logits(未规范化的对数概率)转化为概率值,使它们相加等于1。

将 logits 除以温度常量可以控制结果分布的“软度”。

而 ground truth mask 是一个二进制掩码,对于目标 ground truth 标记值为1,对于所有其他标记值为0。

将其与一个大的负数值相乘可以将该标记的概率变为0,从而避免模型过度依赖真实标记,鼓励其探索其他可能性。该函数沿着第二个维度(通常是标记维度)应用。

def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask

此函数用于根据模型输出的逻辑值和目标标签生成一个掩码,

以获取与目标标签对应的类别的掩码。

这个函数首先将目标标签reshape为一维张量

(-1表示PyTorch将根据原始张量形状推断该维度的大小)。

接下来,它创建了一个与目标映射形状相同的零填充张量,

然后在沿第二维(即列)的位置上填充目标张量指示的位置,

并将1填充在这些位置上。

这是通过调用scatter_()方法来完成的,

其输入是要scatter的维度(在这种情况下为维度1),

位置索引(即目标值)和要scatter的值(即1)。

最后,它将生成的张量转换为布尔掩码以返回。

生成的掩码可用于各种目的,

例如仅选择与目标标签对应的逻辑值,以计算损失函数或计算给定批次输入的准确度等。

def _get_other_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask
def cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    t2 = (t * mask2).sum(1, keepdims=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt

这是一个函数,需要输入三个张量,分别是 mask1、mask2 和 t。

其中,mask1 和 mask2 分别表示掩码,t 是待处理的张量。

这个函数实现了 t 和 mask1、mask2 之间的逐个元素相乘操作,

然后在第二个维度上对结果进行求和。

使用 keepdims=True 在结果张量中保留了该维度。

函数输出将两个求和结果沿着第二个维度拼接起来得到一个新的张量 rt,并返回。

这个函数可用于根据提供的掩码从 t 中提取某些特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/404819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动化测试——数据驱动测试

数据驱动测试 在实际的测试过程中,我们会发现好几组用例都是相同的操作步骤,只是测试数据的不同,而我们往往需要编写多次用例来进行测试,此时我们可以利用数据驱动测试来简化该种操作。 参数化: 输入数据的不同从而产…

一篇文章教你彻底理解ThreadLocal

文章目录ThreadLocal是什么?ThreadLocal如何使用?特别注意ThreadLocal数据存储存取ThreadLocal原理解析Thread.threadLocals原理Thread.inheritableThreadLocals原理ThreadLocal内存泄漏内存泄漏原因对内存泄漏的补救用完就要删除(最终解决&a…

统一流程平台----执行流应用

在flowable平台中,执行流(Execution)完成了流程实例/执行分支/任务/子流程之间关系的建立。flowable整个体系以执行流为基础,完成上下游数据的关联,让bpmn图纸能按照约定进行流转,形成了第一层概念。1.执行…

vue3实战项目安装各种爆时候报错问题和解决

文章目录1.安装:npm install -g sass报错问题1.npm install失败,报错如下引入使用echarts 相关问题1. vue3中npm install echarts --save报错但是这个地方有报提示,问题待解决...........1.安装:npm install -g sass 注释: 多种安装方法 2.vue中局部引用,也可以设置全局css文件…

C++:vector和list的迭代器区别和常见迭代器失效问题

迭代器常见问题的汇总vector迭代器和list迭代器的使用vector迭代器list迭代器vector迭代器失效问题list迭代器失效问题vector和list的区别vector迭代器和list迭代器的使用 学习C,使用迭代器和了解迭代器失效的原因是每个初学者都需要掌握的,接下来我们就…

C++代码格式化-clang-format

文章目录前言c|vscode|clang-formatc|vs|clang-formatc|.clang-format其他附录Visual Studio格式在vs和vscode中不同无法从繁体切换到简体我的vs code配置前言 一个项目中的代码,可能来自不同的地方。不管是多人合作,还是ctrl-c/ctrl-v,都有…

剑指offer JZ6 从尾到头打印链表

Java 剑指offer JZ6 从尾到头打印链表 文章目录Java 剑指offer JZ6 从尾到头打印链表一、题目描述二、递归写法三、栈方法使用Java的递归和栈解决从尾到头打印链表的问题 一、题目描述 输入一个链表的头节点,按链表从尾到头的顺序返回每个节点的值(用数组…

spring cloud @RefreshScope 刷新机制

在学习 nacos 的配置修改发现用到了 RefreshScope 注解,将 spring boot 的日志调整如下logging:level:com:alibaba:cloud: debugorg:springframework:context: debugcloud: debug调用 nacos 的配置修改,看到如下信息2023-03-10 15:48:15.332 INFO [com.a…

三天吃透MySQL面试八股文

本文已经收录到Github仓库,该仓库包含计算机基础、Java基础、多线程、JVM、数据库、Redis、Spring、Mybatis、SpringMVC、SpringBoot、分布式、微服务、设计模式、架构、校招社招分享等核心知识点,欢迎star~ Github地址:https://github.com/…

MGRE综合实验

实验拓扑及相关要求: IP地址配置: ip规划如该拓扑上可视 缺省路由: [r1]ip route-static 0.0.0.0 0 15.0.0.2 [r2]ip route-static 0.0.0.0 0 25.0.0.2 [r3]ip route-static 0.0.0.0 0 35.0.0.2 [r4]ip route-static 0.0.0.0 0 45.0.0.2 公…

Java的二叉树、红黑树、B+树

数组和链表是常用的数据结构,数组虽然查找快(有序数组可以通过二分法查找),但是插入和删除是比较慢的;而链表,插入和删除很快(只需要改变一些引用值),但是查找就很慢&…

游戏引擎开发总结:序列化系统

序列化需要准备什么?首先,我们需要一个被序列化类实现序列化函数,以及序列化库。准备我的序列化库是Yaml-cpp,序列话函数就命名为Serialize,另外我们不需要关心组件类型是具体是什么,所以我这边使用多态&am…

Spring和MaBatis整合

Spring和MyBatis整合: 先瞅一眼各种文件路径: 将之前mybatis中的测试类中的SqlSessionFactory(通过其openSession()来获得对象SqlSession),和Mybatis配置文件中的数据源(url,username等&#…

【Java爬虫】HttpClient+Jsoup实现爬取校内新闻

警告网络爬虫作为一门技术,在使用过程中,应该遵守Robots协议。采集数据时应注意礼貌,不允许爬的网站尽量不要短时间大频率爬取,涉及hdd的网站更是不要去满足自己的好奇心,不然说不准哪天就和吴签一起吃大碗宽面了...介…

[洛谷-P2585][ZJOI2006]三色二叉树(树形DP+状态机DP)

[洛谷-P2585][ZJOI2006]三色二叉树(树形DP状态机DP)一、题目题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1提示数据规模与约定二、分析1、递归建树2、树形DP 状态机DP(1)状态表示(2)状态转移三、…

C++11异步编程

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录前言1、std::future和std::shared_future1.1 std:future1.2 std::shared_future2、std::async3、std::promise4、std::packaged_task前言 C11提供了异步操作相关的类…

Vue3电商项目实战-结算支付 2【03-结算-对话框组件封装、04-结算-收货地址-切换】

文章目录03-结算-对话框组件封装04-结算-收货地址-切换03-结算-对话框组件封装 目的:实现一个对话框组件可设置标题,动态插入内容,动态插入底部操作按钮,打开关闭功能。 大致步骤: 参照xtx-confirm定义一个基础布局实…

MFC常用控件使用(文本框、编辑框、下拉框、列表控件、树控件)

简介 本文章主要介绍下MFC常用控件的使用,包括静态文本框(Static Text)、编辑框(Edit Control)、下拉框(Combo Box)、列表控件(List Control)、树控件(Tree Control)的使用。 创建项目 我们选择 文件->新建->新建项目,选择MFC程序 选择基于对话…

二叉树的三种遍历

二叉树的遍历可以有:先序遍历、中序遍历、后序遍历先序遍历:根、左子树,右子树中序遍历:左子树、根、右子树后序遍历:左子树、右子树、根下面是我画图理解三种遍历:二叉树里都是分为左子树和右子树。分治思…

Linux文件基础I/O

文件IO文件的常识基础IO为什么要学习操作系统的文件操作C语言对于函数接口的使用接口函数介绍如何理解文件文件描述符重定向更新给模拟实现的shell增加重定向功能为什么linux下一切皆文件?缓冲区为什么要有缓冲区缓冲区对应的刷新策略缓冲区的位置在哪里文件的常识 …