知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

news2024/9/24 14:59:39

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nn


def cosine_similarity(a, b, eps=1e-8):
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)


def inter_class_relation(soft_student_outputs, soft_teacher_outputs):
    return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()


def intra_class_relation(soft_student_outputs, soft_teacher_outputs):
    return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))


class DIST(nn.Module):
    def __init__(self, beta=1.0, gamma=1.0, temp=1.0):
        super(DIST, self).__init__()
        self.beta = beta
        self.gamma = gamma
        self.temp = temp

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = (student_preds / self.temp).softmax(dim=1)
        soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)
        inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)
        intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F

# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):
    def __init__(self, temp=1.0):
        super(KLDiv, self).__init__()
        self.temp = temp

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)
        soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)
        kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()
        kd_loss *= self.temp ** 2
        return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as F


def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    tckd_loss = (
            F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
            * (temperature ** 2)
    )
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
            F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')
            * (temperature ** 2)
    )
    return alpha * tckd_loss + beta * nckd_loss


def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask


def _get_other_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask


def cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    t2 = (t * mask2).sum(1, keepdims=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt


class DKD(nn.Module):
    def __init__(self, alpha=1., beta=2., temperature=1.):
        super(DKD, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature

    def forward(self, z_s, z_t, **kwargs):
        target = kwargs['target']
        if len(target.shape) == 2:  # mixup / smoothing
            target = target.max(1)[1]
        kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)
        return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1285055.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

免费的SEO外链发布工具,提升排名的利器

互联网已经成为信息传播和商业发展的重要平台。而对于拥有网站的个人、企业来说,如何让自己的网站在搜索引擎中脱颖而出?SEO(Search Engine Optimization)作为提高网站在搜索引擎中排名的关键手段. 什么是SEO外链? S…

C#,数值计算——计算实对称矩阵所有特征值和特征向量的雅可比(Jacobi)方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { /// <summary> /// Computes all eigenvalues and eigenvectors of /// a real symmetric matrix by Jacobis method. /// </summary> public class Jacobi { private …

09、pytest多种调用方式

官方用例 # content of myivoke.py import sys import pytestclass MyPlugin:def pytest_sessionfinish(self):print("*** test run reporting finishing")if __name__ "__main__":sys.exit(pytest.main(["-qq"],plugins[MyPlugin()]))# conte…

Linux 输入输出重定向

Linux 系统默认的输入输出有3种类型&#xff0c;分别为标准输入、标准输出、错误输出&#xff0c;并且Linux 还为这几类设备分别分配了一个所谓的文件描述符&#xff0c;如下是他们之间的对应关系。 输入输出类型文件描述符系统中设备名通常对应的物理设备标准输入设备0/dev/s…

IntelliJ IDEA的下载安装配置步骤详解

引言 IntelliJ IDEA 是一款功能强大的集成开发环境&#xff0c;它具有许多优势&#xff0c;适用于各种开发过程。本文将介绍 IDEA 的主要优势&#xff0c;并提供详细的安装配置步骤。 介绍 IntelliJ IDEA&#xff08;以下简称 IDEA&#xff09;之所以被广泛使用&#xff0c;…

SpringBoot集成i18n(多语言)

配置文件 spring: messages: basename: il8n/messages # 配置国际化资源文件路径 fallback-to-system-locale: true # 是否使用系统默认的语言环境作为备选项 国际化配置 import org.springframework.context.annotation.Bean; import org.spr…

基于Eclipse+Mysql+Tomcat开发的 教学评价管理系统

基于EclipseMysqlTomcat开发的 教学评价管理系统 项目介绍&#x1f481;&#x1f3fb; 随着教育信息化的发展&#xff0c;教学评价管理系统已经成为了学校、教育机构等场所必不可少的一部分。本项目是基于EclipseMysqlTomcat开发的一套教学评价管理系统&#xff0c;旨在帮助教育…

成为AI产品经理——回归模型评估(MSE、RMSE、MAE、R方)

分类问题的评估是看实际类别和预测类别是否一致&#xff0c;它的评估指标主要有混淆矩阵、AUC、KS。回归问题的评估是看实际值和预测值是否一致&#xff0c;它的评估指标包括MAE、MSE、RMSE、R方。 如果我们预测第二天某支股票的价格&#xff0c;给一个模型 y1.5x&#xff0c;…

计算机网络之IP篇

目录 一、IP 的基本认识 二、DNS 三、ARP 四、DHCP 五、NAT 六、ICMP 七、IGMP 七、ping 的工作原理 ping-----查询报文的使用 traceroute —— 差错报文类型的使用 八、断网了还能 ping 通 127.0.0.1 吗&#xff1f; 8.1、什么是 127.0.0.1 &#xff1f; 8.2、为…

利用 FormData 实现文件上传、监控网路速度和上传进度(前端原生,后端 koa)

利用 FormData 实现文件上传 基础功能&#xff1a;上传文件 演示如下&#xff1a; 概括流程&#xff1a; 前端&#xff1a;把文件数据获取并 append 到 FormData 对象中后端&#xff1a;通过 ctx.request.files 对象拿到二进制数据&#xff0c;获得 node 暂存的文件路径 前端…

12.2旋转,SPLAY树的各种操作(SPLAY与AVL是两种BST)

Splay树和AVL树是两种不同的自平衡二叉搜索树实现。 1. 平衡条件&#xff1a;AVL树通过维护每个节点的平衡因子&#xff08;左子树高度减去右子树高度&#xff09;来保持平衡&#xff0c;要求每个节点的平衡因子的绝对值不超过1。Splay树则通过经过每次操作后将最近访问的节点…

Mybatis 操作续集(连着上文一起看)

"查"操作(企业开发中尽量不使用*,需要哪些字段就写哪些字段,都需要就全写上) Mybatis 会自动地根据数据库的字段名和Java对象的属性名进行映射,如果名称一样就进行赋值 但是那些名称不一样的,我们想要拿到,该怎么拿呢? 一开始数据库字段名和Java对象属性名如下图…

mfc 设置excel 单元格的列宽

CString strTL, strBR;strTL.Format(L"%s%d", GetExcelColName(cd.nCol), cd.nRow);strBR strTL;CRange rangeMerge range.get_Range(_variant_t(strTL), _variant_t(strBR));rangeMerge.put_ColumnWidth(_variant_t((long)(20))); 宽度设置函数为 &#xff1a; pu…

软件测试要学习的基础知识——黑盒测试

黑盒测试概述 黑盒测试也叫功能测试&#xff0c;通过测试来检测每个功能是否都能正常使用。在测试中&#xff0c;把程序看作是一个不能打开的黑盒子&#xff0c;在完全不考虑程序内部结构和内部特性的情况下&#xff0c;对程序接口进行测试&#xff0c;只检查程序功能是否按照…

Milvus 再上新!支持 Upsert、Kafka Connector、集成 Airbyte,助力高效数据流处理

Milvus 已支持 Upsert、 Kafka Connector、Airbyte&#xff01; 在上周的文章中《登陆 Azure、发布新版本……Zilliz 昨夜今晨发生了什么&#xff1f;》&#xff0c;我们已经透露过 Milvus&#xff08;Zilliz Cloud&#xff09;为提高数据流处理效率&#xff0c; 先后支持了 Up…

为告警设备设置服务端属性,在tb中标记存在告警的设备

有位读者想要实现标记系统中存在告警的设备,于是我给他做了三个方案。各有优缺点。 第一个方案时,告警是在规则链里手动创建的,通过告警数,+1,-1来标记设备告警属性。 第二种是当设备通过设备配置创建,清空告警。这种情况只适用于一次遥测创建,清空一个告警。不支持单次…

【Vue】使用 Vue CLI 脚手架创建 Vue 项目(使用GUI创建)

前言 在开始使用Vue进行开发之前&#xff0c;我们需要先创建一个Vue项目。Vue CLI&#xff08;Command Line Interface&#xff09;是一个官方提供的脚手架工具&#xff0c;可以帮助我们快速创建Vue项目。Vue CLI也提供了一个可视化的GUI界面来创建和管理Vue项目。 步骤 打开终…

【离散差分】LeetCode2953:统计完全子字符串

作者推荐 [二分查找]LeetCode2040:两个有序数组的第 K 小乘积 本题其它解法 【滑动窗口】LeetCode2953:统计完全子字符串 涉及知识点 分块循环 离散差分 题目 给你一个字符串 word 和一个整数 k 。 如果 word 的一个子字符串 s 满足以下条件&#xff0c;我们称它是 完全…

爬虫程序为什么一次写不好?需要一直修改BUG?

从我学习编程以来&#xff0c;尤其是在学习数据抓取采集这方面工作&#xff0c;经常遇到改不完的代码&#xff0c;我毕竟从事了8年的编程工作&#xff0c;算不上大佬&#xff0c;但是也不至于那么差。那么哪些因素导致爬虫代码一直需要修改出现BUG&#xff1f;下面来谈谈我的感…

网络协议与 IP 编址

网络协议与 IP 编址 之前大概了解过了网络的一些基础概念&#xff0c;见文章&#xff1a; 网络基础概念。 之前简单了解OSI模型分层&#xff1a; TCP/IP模型OSI模型TCP/IP对等模型应用层应用层表示层应用层会话层主机到主机层传输层传输层因特网层网络层网络层网络接入层数据链…