CrossKD 原理与代码解析

news2024/9/20 10:42:50

paper:CrossKD: Cross-Head Knowledge Distillation for Dense Object Detection

official implementation: https://github.com/jbwang1997/CrossKD

前言 

蒸馏可以分为预测蒸馏prediction mimicking和特征蒸馏feature imitation两种,2015年Geoffrey Hinton提出的KD知识蒸馏开山之作KD:Distilling the Knowledge in a Neural Network 原理与代码解析属于预测模拟,而FitNets: Hints for Thin Deep Nets 原理与代码解析就属于典型的特征模拟。然而长期以来,大家发现预测模拟相比于特征模拟更加低效,LD for Dense Object Detection(CVPR 2022)原理与代码解析表明,预测模拟具有转移特定任务知识的能力,这有利于学生同时进行预测模拟和特征模拟。这促使作者进一步探索和改进预测模拟。

本文的创新点

在预测模拟中,学生模型的预测需要同时模拟GT和教师模型的预测,但是教师模型的预测常常会和GT有很大的差异,学生模型在蒸馏过程中经历了一个矛盾的学习过程,作者认为这是阻碍预测模型获得更高性能的主要原因。

为了缓解学习目标冲突的问题,本文提出了一种新的蒸馏方法CrossKD,将学生检测头的中间特征送入教师的检测头,得到的预测结果与教师的原始预测结果进行蒸馏,这种方法有两个好处,首先KD损失不影响学生检测头的权重更新,避免了原始检测损失和KD损失的冲突。此外由于交叉头的预测和教师的预测共享了部分教师的检测头,两者的预测相对一致,缓解了学生-教师之间的预测差异,提高了预测模拟的训练稳定性。

预测模拟、特征模拟以及本文提出的cross kd分别为图1的(a)(b)(c)所示

方法介绍 

CrossKD的整体架构如图3所示

给定一个dense detector比如RetinaNet,每个检测head通常由一系列卷积组成,表示为 \(\left \{ C_{i} \right \} \)。为了简便,我们假设每个检测头共有 \(n\) 个卷积层,(比如RetinaNet中n=5,包括4个隐含层和1个预测层)。我们用 \(f_{i},i\in\left \{ 1,2,...,n-1 \right \} \) 来表示 \(C_{i}\) 的输出特征图,\(f_{0}\) 表示 \(C_{1}\) 的输出特征图。预测 \(p\) 是由最后一个卷积层 \(C_{n}\) 的输出,教师和学生的最终预测结果可以分别表示为 \(p^{t},p^{s}\)。

CrossKD将学生检测头的中间特征 \(f_{i}^{s},i\in\left \{ 1,2,...,n-1 \right \} \) 送入 \(C^{t}_{i+1}\),即教师检测头的第 \((i+1)\) 个卷积层,得到交叉头的预测 \(\hat{p}^{s}\)。和之前的方法不同,我们不计算 \(p^{s}\) 和 \(p^{t}\) 之间的KD损失,而是计算 \(\hat{p}^{s}\) 和 \(p^{t}\) 之间的KD损失,如下

其中 \(\mathcal{S}(\cdot)\) 和 \(|\mathcal{S}|\) 分别是region selection principle和归一化因子。本文作者没有涉及复杂的 \(\mathcal{S}(\cdot)\),分类分支 \(\mathcal{S}(\cdot)\) 是常量值1,回归分支前景区域 \(\mathcal{S}(\cdot)\) 为1背景区域 \(\mathcal{S}(\cdot)\) 为0。

实验结果

首先是一些消融实验,教师网络采用ResNet-50+GFL,学生网络为ResNet-18。

Positions to apply CrossKD.

上面说过将学生检测头的第 \(i\) 个卷积层的输出送入教师网络,这里作者比较了不同 \(i\) 的值对最终结果的影响,当 \(i=0\) 时表示直接将FPN的输出特征送入教师网络的head,具体结果如下

可以看出当 \(i=3\) 时, 模型的最终精度最高,因此后续实验都采用默认配置 \(i=3\)。

CrossKD v.s. Feature Imitation.

作者对比了CrossKD和特征蒸馏的SOTA方法PKD,为了公平起见,与CrossKD相同的位置上执行PKD,包括 \(i=0\) 的neck和 \(i=3\) 的head,结果如下

可以看出,无论PKD在什么位置,效果都不如CrossKD。

CrossKD for Lightweight Detectors.

作者将CrossKD轻量的检测器上的结果如下

 

可以看出,教师网络为ResNet-101+GFL,学生网络为ResNet-50、ResNet-34、ResNet-18,CrossKD都可以显著提升精度。

Comparison with SOTA KD Methods

和其它目标检测的SOTA蒸馏方法的对别如下表,可以看出,CrossKD优于现有的所有方法。

代码解析

官方的实现是基于mmdetection,并将crosskd用到了atss、fcos、gfl、retinanet中,以atss为例,代码在mmdet/models/detectors/crosskd_atss.py中,loss部分代码如下。首先原始输入 batch_inputs分别经过教师和学生的backbone和neck,self.teacher_extract_feat就是教师网络的backbone和neck,self.extract_feat就是学生网络的backbone和neck。

 def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        tea_x = self.teacher.extract_feat(batch_inputs)
        tea_cls_scores, tea_bbox_preds, tea_centernesses, tea_cls_hold, tea_reg_hold = \
            multi_apply(self.forward_hkd_single, 
                        tea_x,
                        self.teacher.bbox_head.scales, 
                        module=self.teacher)
            
        stu_x = self.extract_feat(batch_inputs)
        stu_cls_scores, stu_bbox_preds, stu_centernesses, stu_cls_hold, stu_reg_hold = \
            multi_apply(self.forward_hkd_single, 
                        stu_x,
                        self.bbox_head.scales, 
                        module=self)
            
        reused_cls_scores, reused_bbox_preds, reused_centernesses = multi_apply(
            self.reuse_teacher_head, 
            tea_cls_hold, 
            tea_reg_hold, 
            stu_cls_hold,
            stu_reg_hold, 
            self.teacher.bbox_head.scales)


        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs
        losses = self.loss_by_feat(tea_cls_scores, 
                                   tea_bbox_preds,
                                   tea_centernesses,
                                   tea_x,
                                   stu_cls_scores,
                                   stu_bbox_preds,
                                   stu_centernesses,
                                   stu_x,
                                   reused_cls_scores,
                                   reused_bbox_preds,
                                   reused_centernesses,
                                   batch_gt_instances,
                                   batch_img_metas, 
                                   batch_gt_instances_ignore)
        return losses

得到的neck输出特征tea_xstu_x,然后分别进入函数self.forward_hkd_single,实现如下

    def forward_hkd_single(self, x, scale, module):
        cls_feat, reg_feat = x, x
        cls_feat_hold, reg_feat_hold = x, x
        for i, cls_conv in enumerate(module.bbox_head.cls_convs):
            cls_feat = cls_conv(cls_feat, activate=False)
            if i + 1 == self.reused_teacher_head_idx:
                cls_feat_hold = cls_feat
            cls_feat = cls_conv.activate(cls_feat)
        for i, reg_conv in enumerate(module.bbox_head.reg_convs):
            reg_feat = reg_conv(reg_feat, activate=False)
            if i + 1 == self.reused_teacher_head_idx:
                reg_feat_hold = reg_feat
            reg_feat = reg_conv.activate(reg_feat)
        cls_score = module.bbox_head.atss_cls(cls_feat)
        bbox_pred = scale(module.bbox_head.atss_reg(reg_feat)).float()
        centerness = module.bbox_head.atss_centerness(reg_feat)
        return cls_score, bbox_pred, centerness, cls_feat_hold, reg_feat_hold

其中分别经过教师和学生的head,包括cls分支和reg分支,self.reused_teacher_head_idx就是学生head中要送入教师的检测头的特征的索引,将这个位置的特征保存下来后续送入教师head,即函数reuse_teacher_head

    def reuse_teacher_head(self, tea_cls_feat, tea_reg_feat, stu_cls_feat,
                           stu_reg_feat, scale):
        reused_cls_feat = self.align_scale(stu_cls_feat, tea_cls_feat)
        reused_reg_feat = self.align_scale(stu_reg_feat, tea_reg_feat)
        if self.reused_teacher_head_idx != 0:
            reused_cls_feat = F.relu(reused_cls_feat)
            reused_reg_feat = F.relu(reused_reg_feat)

        module = self.teacher.bbox_head
        for i in range(self.reused_teacher_head_idx, module.stacked_convs):
            reused_cls_feat = module.cls_convs[i](reused_cls_feat)
            reused_reg_feat = module.reg_convs[i](reused_reg_feat)
        reused_cls_score = module.atss_cls(reused_cls_feat)
        reused_bbox_pred = scale(module.atss_reg(reused_reg_feat)).float()
        reused_centerness = module.atss_centerness(reused_reg_feat)
        return reused_cls_score, reused_bbox_pred, reused_centerness

注意这里有个align_scale的步骤,论文中没有提及,即将学生head的特征减去均值除以方差后,再乘以教师head对应位置特征的方差接着再加上教师特征的均值,如下

    def align_scale(self, stu_feat, tea_feat):
        N, C, H, W = stu_feat.size()
        # normalize student feature
        stu_feat = stu_feat.permute(1, 0, 2, 3).reshape(C, -1)
        stu_mean = stu_feat.mean(dim=-1, keepdim=True)
        stu_std = stu_feat.std(dim=-1, keepdim=True)
        stu_feat = (stu_feat - stu_mean) / (stu_std + 1e-6)
        #
        tea_feat = tea_feat.permute(1, 0, 2, 3).reshape(C, -1)
        tea_mean = tea_feat.mean(dim=-1, keepdim=True)
        tea_std = tea_feat.std(dim=-1, keepdim=True)
        stu_feat = stu_feat * tea_std + tea_mean
        return stu_feat.reshape(C, N, H, W).permute(1, 0, 2, 3)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/736755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode】HOT 100(26)

题单介绍: 精选 100 道力扣(LeetCode)上最热门的题目,适合初识算法与数据结构的新手和想要在短时间内高效提升的人,熟练掌握这 100 道题,你就已经具备了在代码世界通行的基本能力。 目录 题单介绍&#…

什么是 AOP?对于 Spring IoC 和 AOP 的理解?Spring AOP 和 AspectJ AOP 有什么区别?

什么是 AOP? AOP(Aspect-Oriented Programming),即 面向切面编程, 它与OOP( ObjectOriented Programming, 面向对象编程) 相辅相成,提供了与OOP 不同的抽象软件结构的视角 在 OOP 中, 我们以类(class)作为我们的基本单元 而 A…

Zynq 多个UDP客户端组网启动问题(Auto negotiation error)PS:附UDP客户端初始化代码

最近正在进行一个Zynq项目,根据设计需求,需要将上位机作为UDP服务器,而FPGA则充当UDP客户端。同时,服务器需要能够接收和控制多个UDP客户端。 开发过程中,我是基于lwip UDP Perf Client 官方模版开发的。我遇到了以下几…

重磅!高通大客户「跳单」,智能座舱SoC赛道进入「混战期」

哪家车企是高通座舱芯片的最大客户?答案不是蔚来、理想、小鹏等智能化布局领先的新势力,而是比亚迪。 高工智能汽车研究院监测数据显示,2022年中国市场(不含进出口)乘用车智能座舱前装标配高通芯片交付367.45万辆&…

DALL·E2(unCLIP)、Stable Diffusion、IS、FID要点总结

DALLE 1 DALLE 1可以看成是VQ-VAE和文本经过BPE编码得到的embedding AE(Auto Encoder) encoder decoder结构,AE在生成任务时只会模仿不会创造,所有有了后面的VAE VAE(Variational AutoEncoder) 不再学习固定的bottleneck特征…

2023-07-07-liunx环境,python调用ITK(c++版本)批量生成Drr

文章目录 一、前言二、配置过程2.1.CMake与ITK的配置2.2.改写ITK的生成drr代码2.3.编译代码2.4.python调用cpp 三、总结四、参考博客 一、前言 最近在做配准,需要用ITK来生成数据,windows版本可以通过cmake与visual studio可以跑通生成。但是想要在linu…

纯干货,全文手码:如何利用低/无代码平台建立集团信息化系统

信息化系统是企业管理体系的延伸 对于一家集团企业而言,要实现信息化,首先需要考虑是否已经建立了完备的信息化管理制度。早在上世纪九十年代卡特彼勒引入了6 Sigma,使整个集团公司的运营规范化、系统化。通过多年的实践积累,卡特…

【雕爷学编程】Arduino动手做(157)---MX1508双路电机驱动模块

37款传感器与执行器的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&am…

计算机视觉:通过边缘检测探究卷积的特征提取功能

本文重点 在前面的课程中,我们学习了卷积核的运算,同时我们也学习了卷积核的含义,我们可以将卷积核理解为特征提取器,也就是说一个卷积核就是一个特征提取器,很多人对这种说法不了解,下面我们就通过一个边缘检测的例子来看一下卷积核是如何进行边缘特征的提取的。 什么…

[ICML 2023] Fast inference from transformers via speculative decoding

Contents IntroductionSpeculative DecodingStandardized SamplingSpeculative Sampling AnalysisNumber of Generated TokensCalculating α \alpha αWalltime ImprovementNumber of Arithmetic OperationsChoosing γ \gamma γ ExperimentsReferences Introduction 为了…

springboot+vue膳食营养健康网站零食美食品商城_4d8g9

随着社会的不断进步与发展,人们对生活质量要求逐步提升。如果开发一款膳食营养健康网站,可以让用户在最短的时间里享受到最好的服务;而开发本网站,又能够提高网站整体工作水平,简化工作程序,这对管理员和用…

北京“数据二十条”发布,筑牢数据安全根基,加速释放数据产能

要实现2000亿的目标,基础保障如何做好? 7月5日,中共北京市委、北京市人民政府印发《关于更好发挥数据要素作用进一步加快发展数字经济的实施意见》(《实施意见》分为9个部分,共涉及20项具体任务,也被称作“…

TypeScript 中的 enum 枚举类型的使用解读。

前言 在 TypeScript 中,新增了很多具有特性的一些数据类型处理方法,enum 【枚举】就是其中,很具有代表性的一种,所以本章节就来聊聊 在 TypeScript 中如何去运用 enum 【枚举】。 枚举得的概念: 枚举(Enum&…

Web前端工程师笔试题(合集)

Web前端开发工程师笔试题篇1 1. 在一个框架的属性面板中,不能设置下面哪一项。( C ) A.源文件 ; B.边框颜色 ; C.边框宽度 D.滚动条 2. CSS样式表根据所在网页的位置,可分为?(B ) A.行内样式表、内嵌样式表、混合样式表 B.行内样式表、内嵌样式表…

MySQL数据库——多表查询练习2

一、练习素材 创建表 --创建部门表dept create table dept ( dept1 int , dept_name varchar(11));--创建员工表emp create table emp ( sid int , name varchar(11), age int, worktime_start date, incoming int, dept2 int); 插入数据 --部门表插入数据 insert into dep…

C++ 栈和队列(stack and queue)语法使用及底层实现原理

本篇文章会对C中的容器stack和queue用法进行详解,也包含对优先队列(priority_queue)的讲解。同时会模拟实现stack、queue和priority_queue底层。希望本篇文章会对你有所帮助! 目录 一、stack 栈 1、1 什么是适配器 1、2 stack 语法…

C++ 线程池实现

思路 创建多个工作线程同时维护一个公共的任务队列, 任务队列非空时通过信号量唤醒阻塞等待的工作线程, 工作线程通过互斥锁互斥的从任务队列中取出任务, 然后执行任务 实现 信号量类 class sem {//封装信号量类 public:sem(int num 0) {if (sem_init(&m_sem, 0, num)…

Kernel-Pwn-FGKASLR保护绕过

FGKASLR FGASLR(Function Granular KASLR)是KASLR的加强版,增加了更细粒度的地址随机化。因此在开启了FGASLR的内核中,即使泄露了内核的程序基地址也不能调用任意的内核函数。 layout_randomized_image 在fgkaslr.c文件中存在着…

支持中文手写和多画布的Handraw

什么是 Handraw ? Handraw 是支持中文手写和多画布的 Excalidraw 白板工具。 官网上项目名称还是 Excalidraw-CN,所以 Handraw 应该是基于 Excalidraw 二开的,特点是支持中文手写字体和多画布 官方也提供了免费使用的站点:https://handraw.t…

ModaHub魔搭社区:向量数据库Zilliz Cloud插入 Entity教程

目录 开始前 插入单个 Entity 批量插入 Entity 准备数据 插入数据 写入操作 本文介绍如何将 Entity 插入到 Zilliz Cloud 集群中的 Collection。 Entity 是 Collection 中的基本数据单元。同一个 Collection 中的 Entity 具有相同的属性,这些属性共同定义在 Schema 中…