Focal and Global Knowledge Distillation for Detectors(CVPR 2022)原理与代码解析

news2024/11/22 16:12:34

paper:Focal and Global Knowledge Distillation for Detectors

official implementation:https://github.com/yzd-v/FGD

存在的问题 

如图1所示,前景区域教师和学生注意力之间的差异非常大,背景区域则相对较小。此外通道注意力的差异也非常明显。

作者还设计了实验解耦了蒸馏过程中的前景和背景,结果如表1所示,令人惊讶的是,前景背景一起进行蒸馏的效果是最差的,比单独蒸馏前景或背景还差。

 

上述结果表明,特征图中的不均匀差异会对蒸馏产生负面效果。这种不均匀差异不仅存在于前背景之间,也存在于不同像素位置和通道之间。

本文的创新点

针对前背景、空间位置、通道之间的差异,本文提出了focal distillation,在分离前背景的同时,还计算了教师特征不同空间位置和通道的注意力,使得学生专注于学习教师的关键像素和通道。

但是只关注关键信息还不够,在检测任务中全局语义信息也很重要。为了弥补focal蒸馏中缺失的全局信息,作者还提出了global distillation,其中利用GcBlock来提取不同像素之间的关系,然后传递给学生。

方法介绍

Focal Distillation

首先用一个binary mask \(M\) 来分离前背景

 

其中 \(r\) 是ground truth box,\(i,j\) 表示像素位置的坐标。

为了消除不同大小的gt box的尺度的影响和不同图片中前背景比例的差异,作者又设置了一个scale mask \(S\)

其中 \(H_{r},W_{r}\) 表示gt box \(r\) 的高和宽,如果一个像素属于不同的target,选择最小的box来计算 \(S\)。

接着作者借鉴SENet和CBAM的方法提取通道注意力和空间注意力

\(G^{S},G^{C}\) 分别表示空间和通道attention map,然后attention mask按下式计算

其中 \(T\) 是温度系数。

利用binary mask \(M\)、scale mask \(S\)、attention mask \(A^{S},A^{C}\),特征损失 \(L_{fea}\) 如下

其中 \(A^{S},A^{C}\)  表示教师的空间和通道attention mask,\(F^{T},F^{S}\) 分别表示教师和学生的feature map,\(\alpha, \beta\) 是balance超参。

此外作者还提出了注意力损失 \(L_{at}\) 让学生模仿教师的attention mask

\(l\) 表示L1损失。

完整的focal损失就是特征损失和注意力损失的和

Global Distillation

如图4所示,作者用GcBlock来提取全局关系信息,关于GcBlock的详细介绍可以参考GCNet: Global Context Network(ICCV 2019)原理与代码解析

全局损失 \(L_{global}\) 如下

 

\(W_{k},W_{v1},W_{v2}\) 是卷积层,\(LN\) 表示layer normalization,\(N_{p}\) 是特征中所有像素个数,\(\lambda\) 是balance超参。

Overall loss

完整的损失函数如下,包括原本的训练损失和蒸馏损失,蒸馏损失又包括focal损失和global损失

实验结果

  

其中inheriting strategry是《Instance-conditional knowledge distillation for object detection》这篇文章中提出的用教师的neck和head参数初始化学生网络,可以得到更好的效果。

代码解析

主要实现在mmdet/distillation/losses/fgd.py中,函数forward中,首先教师和学生的attention mask,即文中的式(5)~(8)

S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)  # (N,H,W),(N,C)
S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)
def get_attention(self, preds, temp):
    """ preds: Bs*C*W*H """
    N, C, H, W = preds.shape

    value = torch.abs(preds)
    # Bs*W*H
    fea_map = value.mean(axis=1, keepdim=True)
    S_attention = (H * W * F.softmax((fea_map / temp).view(N, -1), dim=1)).view(N, H, W)

    # Bs*C
    channel_map = value.mean(axis=2, keepdim=False).mean(axis=2, keepdim=False)
    C_attention = C * F.softmax(channel_map / temp, dim=1)

    return S_attention, C_attention

接下来为了减小不同target尺度和前背景比例的影响,计算scale mask,即文中的式(2)~式(4)。其中内层的for循环是当一个像素属于不同的target时,选择最小的box来计算。

Mask_fg = torch.zeros_like(S_attention_t)
Mask_bg = torch.ones_like(S_attention_t)
wmin, wmax, hmin, hmax = [], [], [], []
for i in range(N):
    new_boxxes = torch.ones_like(gt_bboxes[i])
    new_boxxes[:, 0] = gt_bboxes[i][:, 0] / img_metas[i]['img_shape'][1] * W
    new_boxxes[:, 2] = gt_bboxes[i][:, 2] / img_metas[i]['img_shape'][1] * W
    new_boxxes[:, 1] = gt_bboxes[i][:, 1] / img_metas[i]['img_shape'][0] * H
    new_boxxes[:, 3] = gt_bboxes[i][:, 3] / img_metas[i]['img_shape'][0] * H

    wmin.append(torch.floor(new_boxxes[:, 0]).int())
    wmax.append(torch.ceil(new_boxxes[:, 2]).int())
    hmin.append(torch.floor(new_boxxes[:, 1]).int())
    hmax.append(torch.ceil(new_boxxes[:, 3]).int())

    area = 1.0 / (hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1)) / (wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1))

    for j in range(len(gt_bboxes[i])):
        Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1] = \
            torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1], area[0][j])

    Mask_bg[i] = torch.where(Mask_fg[i] > 0, 0, 1)
    if torch.sum(Mask_bg[i]):
        Mask_bg[i] /= torch.sum(Mask_bg[i])

接着就是完整的feature损失,即文中的式(9)

fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
                                     C_attention_s, C_attention_t, S_attention_s, S_attention_t)
def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
    loss_mse = nn.MSELoss(reduction='sum')

    Mask_fg = Mask_fg.unsqueeze(dim=1)
    Mask_bg = Mask_bg.unsqueeze(dim=1)

    C_t = C_t.unsqueeze(dim=-1)
    C_t = C_t.unsqueeze(dim=-1)

    S_t = S_t.unsqueeze(dim=1)

    fea_t = torch.mul(preds_T, torch.sqrt(S_t))
    fea_t = torch.mul(fea_t, torch.sqrt(C_t))
    fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
    bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))

    fea_s = torch.mul(preds_S, torch.sqrt(S_t))
    fea_s = torch.mul(fea_s, torch.sqrt(C_t))
    fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
    bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))

    fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(Mask_fg)
    bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(Mask_bg)

    return fg_loss, bg_loss

文中作者还提出了用L1 loss的attention损失,即式(10)

mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
def get_mask_loss(self, C_s, C_t, S_s, S_t):
    mask_loss = torch.sum(torch.abs((C_s - C_t))) / len(C_s) + torch.sum(torch.abs((S_s - S_t))) / len(S_s)

    return mask_loss

feature loss和attention loss一起组成的focal loss,为了弥补全局语义信息的缺失,作者又引入了全局蒸馏损失,其中用到了GcBlock,即式(12)

rela_loss = self.get_rela_loss(preds_S, preds_T)
def get_rela_loss(self, preds_S, preds_T):
    loss_mse = nn.MSELoss(reduction='sum')

    context_s = self.spatial_pool(preds_S, 0)
    context_t = self.spatial_pool(preds_T, 1)

    out_s = preds_S
    out_t = preds_T

    channel_add_s = self.channel_add_conv_s(context_s)
    out_s = out_s + channel_add_s

    channel_add_t = self.channel_add_conv_t(context_t)
    out_t = out_t + channel_add_t

    rela_loss = loss_mse(out_s, out_t) / len(out_s)

    return rela_loss
def spatial_pool(self, x, in_type):
    batch, channel, width, height = x.size()
    input_x = x
    # [N, C, H * W]
    input_x = input_x.view(batch, channel, height * width)
    # [N, 1, C, H * W]
    input_x = input_x.unsqueeze(1)
    # [N, 1, H, W]
    if in_type == 0:
        context_mask = self.conv_mask_s(x)
    else:
        context_mask = self.conv_mask_t(x)
    # [N, 1, H * W]
    context_mask = context_mask.view(batch, 1, height * width)
    # [N, 1, H * W]
    context_mask = F.softmax(context_mask, dim=2)
    # [N, 1, H * W, 1]
    context_mask = context_mask.unsqueeze(-1)
    # [N, 1, C, 1]
    context = torch.matmul(input_x, context_mask)
    # [N, C, 1, 1]
    context = context.view(batch, channel, 1, 1)

    return context


self.channel_add_conv_s = nn.Sequential(
    nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
    nn.LayerNorm([teacher_channels//2, 1, 1]),
    nn.ReLU(inplace=True),  # yapf: disable
    nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
self.channel_add_conv_t = nn.Sequential(
    nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
    nn.LayerNorm([teacher_channels//2, 1, 1]),
    nn.ReLU(inplace=True),  # yapf: disable
    nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/846215.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【错误记录】Uncaught SyntaxError: Not available in legacy mode

错误记录:Uncaught SyntaxError: Not available in legacy mode 错误描述:在vite脚手架项目当中,使用vue-i18n插件进行国际化多语言时,报错 解决方案: 在引入vue-i18n 处,添加 legacy: false 如果对项目…

玩转Vue3:计算属性和监视属性深度解析

计算属性computed Vue中的计算属性是一种特殊的属性,它可以根据依赖的数据动态计算并返回结果。计算属性的值是通过getter函数计算得到的,当依赖的数据发生变化时,计算属性会自动重新计算并更新视图。计算属性具有缓存机制,只有当…

SSM的知识点考试系统java在线问答试卷管理jsp源代码mysql

本项目为前几天收费帮学妹做的一个项目,Java EE JSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。 一、项目描述 SSM的知识点考试系统 系统1权限:管理员 …

把大模型装进手机,分几步?

点击关注 文 | 姚 悦 编 | 王一粟 大模型“跑”进手机,AI的战火已经从“云端”烧至“移动终端”。 “进入AI时代,华为盘古大模型将会来助力鸿蒙生态。”8月4日,华为常务董事、终端BG CEO、智能汽车解决方案BU CEO 余承东介绍&#xff0c…

【计算机网络】UDP服务器实现网络聊天室

前言 上一篇文章我们简单了解了一下什么是套接字编程,这篇文章我们利用UDP套接字来实现一个简单的网络聊天室。 编写UDP套接字服务器 成员变量 // 1. socket的id,相当于文件id int _sock; // 2. port uint16_t _port;// 3 一个线程负责收放消息&…

JUC并发、JVM相关

文章目录 JUC并发synchronized锁对象底层原理 synchronized锁升级reentrantlock公平锁和非公平锁可重入锁 / 递归锁 死锁死锁产生条件如何排查死锁?如果解决死锁? LockSupport与中断机制中断机制中断相关的三大API如何中断运行中的线程? LockSupportLoc…

【C++】C++11--- 线程库及详解lock_guard与unique_lock

目录 一、thread类的介绍二、线程函数参数三、 原子性操作库四、lock_guard与unique_lock4.1、mutex的种类4.2 lock_guard4.3 unique_lock 一、thread类的介绍 在C11之前,涉及到多线程问题,都是和平台相关的,比如**windows和linux下各有自己…

【css】属性选择器

有些场景中需要在相同元素中获取具有特定属性的元素&#xff0c;比如同为input&#xff0c;type属性有text、button&#xff0c;可以通过属性选择器设置text和button的不同样式。 代码&#xff1a; <style> input[typetext] {width: 150px;display: block;margin-bottom…

自动配置要点解读

目录 要点1&#xff1a;什么是自动配置&#xff1f; 要点2&#xff1a;配置文件与默认配置 要点3&#xff1a;自动配置设置思想来源 要点4&#xff1a;spring.factories文件作用 要点5&#xff1a;自动配置的核心 本文只对自动配置的思想进行基本的解读&#xff0c;不涉…

21、p6spy输出执行SQL日志

文章目录 1、背景2、简介3、接入3.1、 引入依赖3.2、修改database参数&#xff1a;3.3、 创建P6SpyLogger类&#xff0c;自定义日志格式3.4、添加spy.properties3.5、 输出样例 4、补充4.1、参数说明 1、背景 在开发的过程中&#xff0c;总希望方法执行完了可以看到完整是sql语…

通用人工智能操作系统

随着科技的飞速发展&#xff0c;人工智能已经成为了当今世界最热门的技术领域之一。从智能手机、自动驾驶汽车到智能家居系统&#xff0c;人工智能技术已经渗透到了我们生活的方方面面。然而&#xff0c;尽管人工智能在很多领域取得了显著的成果&#xff0c;但它仍然存在一些局…

matplotlib+tkinter实现一个简单的绘图系统

文章目录 封装成类布局实现绘图功能 绘图系统系列&#xff1a;将matplotlib嵌入到tkinter 封装成类 在理解matplotlib嵌入到tkinter中的原理之后&#xff0c;就已经具备了打造绘图系统的技术基础&#xff0c;接下来要做的&#xff0c;就是做一个较有可读性的绘图类&#xff0…

Java异常体系总结(下篇)

目录 1. 异常处理的三种方法 1.1 JVM 默认处理异常 1.2 通过 try...catch...自己处理异常 1.3 使用 throws和throw 抛出异常 1.3.1 使用 throws 抛出异常 1.3.2 使用 throw 抛出异常 2. try...catch.. 捕获到异常之后代码的执行顺序&#xff1f; 3. try...catch... 相关…

Mysql进阶(中) -- 索引

索引上部分 -> Mysql进阶(上) -- 存储引擎&#xff0c;索引_千帐灯无此声的博客-CSDN博客 &#x1f442; 爸爸妈妈 - 王蓉 - 单曲 - 网易云音乐 &#x1f448;目录看左栏 目录 &#x1f33c;索引 &#x1f43b;性能分析 - show profiles &#x1f43b;性能分析 - exp…

Cocos 适配 HarmonyOS NEXT,亮相 HDC2023,携手华为共筑鸿蒙生态!

HDC 2023 8月4-6日&#xff0c;作为华为合作伙伴&#xff0c;Cocos 引擎应邀参加了华为开发者大会 2023 - HDC 2023 暨 HarmonyOS 4 发布会&#xff0c;并获得了【鸿蒙生态能力共创奖】。 8月5日&#xff0c;在华为开发者大会&#xff08;HDC.Together&#xff09;游戏服务论坛…

SpringBoot系列---【使用jasypt把配置文件密码加密】

使用jasypt把配置文件密码加密 1.引入pom坐标 <dependency><groupId>com.github.ulisesbocchio</groupId><artifactId>jasypt-spring-boot-starter</artifactId><version>3.0.5</version> </dependency> 2.新增jasypt配置 2.1…

HCIP-linux知识

linux安装教程参考&#xff0c;https://blog.51cto.com/cloudcs/5245337 yum源配置 本地yum源配置&#xff1a; 8版本配置&#xff1a;将光盘iso挂载到某个目录&#xff0c;/dev/cdrom是/dev/sr0软链接&#xff0c;# mount /dev/cdrom /mnt&#xff0c;# ls /mnt AppStream B…

Elastic:linux设置elasticsearch、kibana开机自启

0. 引言 每次启动服务器都要手动启动es服务&#xff0c;相当之不方便&#xff0c;为此&#xff0c;书写一个脚本&#xff0c;实现es、kibana的开机自启 1. 原理 首先任何服务要实现开机自启&#xff0c;都可分为如下三步&#xff1a; 1、在/etc/init.d目录下创建启动、关闭服…

跳表与Redis

跳表原理 跳表是Redis有序集合ZSet底层的数据结构 首先有一个头结点 这个头结点里面的数据是null 就是他就是这个链表的最小值 就算是Math.Min也比它大 然后我们新建一个节点的时候是怎么操作的呢 先根据参数(假如说是5)创建一个节点 然后把它放在对应位置 就是找到小于他的最…

(JS逆向专栏十一)某融平台网站登入RSA

声明: 本文章中所有内容仅供学习交流&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff0c;若有侵权&#xff0c;请联系我立即删除&#xff01; 名称:点融 目标:登入参数 加密类型:RSA 目标网址:https://www.dianrong.com/accoun…