论文阅读笔记:Semi-DETR: Semi-Supervised Object Detection with Detection Transformers

news2024/12/23 20:04:32

论文阅读笔记:Semi-DETR: Semi-Supervised Object Detection with Detection Transformers

  • 1 背景
    • 1.1 动机
    • 1.2 问题
  • 2 创新点
  • 3 方法
  • 4 模块
    • 4.1 分阶段混合匹配
    • 4.2 跨视图查询一致性
    • 4.3 基于代价的伪标签挖掘
    • 4.4 总损失
  • 效果
    • 5.1 和SOTA方法对比
    • 5.2 消融实验

论文:https://arxiv.org/pdf/2307.08095v1.pdf

代码:https://github.com/JCZ404/Semi-DETR

1 背景

1.1 动机

虽然DETR-based方法在全监督目标检测中实现了SOTA性能,但一个可行的DETR-based半监督目标检测(SSOD)框架仍然有待探索。

1.2 问题

问题1:1对1的分配策略具有NMS-free端到端检测的优点,在半监督场景的效率较低。

如果直接用检测器对未标记图像进行伪标注,当伪包围框不准确时,一对一分配策略会将单个不准确的提议匹配为正样本,而降其他潜在正确的提议匹配为负样本,从而噪声学习效率低下。

问题2:1对多的分配策略获得了质量更好的候选建议集吗,使得检测器优化效率更高,但会引入重复预测。

问题3:SSOD中常用的一致性正则化方法在DTER-based SSOD方法中不可行。

因为DETR-based检测器通过注意力机制不断更新query特征,随着query特征的变化,预测结果也会发生变化,即输入对象查询与其输出预测结果之间不存在确定的对应关系,这使得一致性正则化无法应用于DETR-based检测器中。

2 创新点

在这里插入图片描述

作者在TeacherStudent架构的基础上提出了一个新的基于DETR的SSOD框架Semi-DETR。如图1(b)所示。主要是

(1)提出了一个分阶段混合匹配模块,分别使用1对多分配1对1分配两个阶段训练。第一个阶段旨在通过1对多分配策略提高训练效率,从而为第二个阶段的1对1训练提供高质量的伪标签。

(2)引入了一个跨视图查询一致性模块,该模块构建了跨视图对象查询,以消除对象查询确定性对应的要求,并帮助检测器在两个增强试图之间学习对象查询的语义不变特征。

(3)基于高斯混合模型设计了一个基于代价的伪标签挖掘模块,该模块根据匹配代价分布动态的挖掘用于一致性学习的可靠伪框。

提出的方法效果如图2。
在这里插入图片描述

3 方法

在这里插入图片描述

提出的Semi-DETR的整体框架如图3所示。根据SSOD流行的教师学生模型,作者提出的Semi-DETR采用了一对具有完全相同网络结构的教师和学生模型(论文里采用的是DINO)。在每次训练迭代中,弱增强和强增强的未标记图像分别反馈给教师和学生网络。然后将教师生成的置信度大于 τ s \tau_s τs 的伪标签作为训练学生网络的监督。学生的参数参数通过反向传播更新,教师模型参数是学生模型的EMA。

4 模块

4.1 分阶段混合匹配

在学生的预测和教师生成的伪标注之间执行匈牙利匹配,可以得到一个最优的1对1分配 σ o 2 o \sigma_{o2o} σo2o
在这里插入图片描述

其中 ξ N \xi_N ξN 是 N个元素的置换构成的集合, C m a t c h ( y ^ i t , y ^ σ ( i ) s ) C_{match}(\hat{y}^t_i,\hat{y}^s_{\sigma(i)}) Cmatch(y^it,y^σ(i)s) 伪标签 y ^ i t \hat{y}_i^t y^it 和学生模型的第 σ ( i ) \sigma(i) σ(i) 个预测之间的匹配代价。

由于在SSOD训练的早期阶段,教师生成的伪标注通常是不准确和不可靠的,这使得在1对1分配策略下生成稀疏和低质量建议的风险很高。为了利用多个正查询来实现高效的半监督,作者提出使用1对多的分配代替1对1的分配:
在这里插入图片描述

其中 C N M C_N^M CNM 是 M 和 N 的组合,即 M 个提议的子集被分配给每个伪框 y ^ i t \hat{y}_i^t y^it 中。使用分类得分 s s s 和 IoU值 u u u 的高阶组合作为匹配代价度量:
在这里插入图片描述

其中 α \alpha α β \beta β 是分类得分和IoU的影响因子,论文中设 α = 1 , β = 6 \alpha=1,\beta=6 α=1,β=6。通过1对多分配,选择 m m m 值最大的 M 个提案作为正样本,其余为负样本。

分类损失和回归损失也做了相应修改:
在这里插入图片描述

其中 γ \gamma γ 设置为2。通过为每个伪标签分配多个正建议,潜在的高质量正建议也获得了被优化的机会,大大提高了收敛速度,进而获得更好的伪标签。然而每个伪标签的多个正建议会导致重复的预测,为了缓解这一问题,在第二阶段切换回1对1的分配训练。通过这样做,在第一阶段训练后享受高质量的伪标签,并逐步减少重复预测,以在第二阶段通过1对1分配训练出NMS-free的检测器。该阶段的损失为:
在这里插入图片描述

教师网络的结果会采用NMS去重。

4.2 跨视图查询一致性

在传统的非DETR-based的SSOD框架中,给定相同的输入 x x x 并采用不同的随机增广,一致性正则化通过最小化教师 f θ f_\theta fθ 和学生 f θ ′ f'_\theta fθ 的输出之差来监督模型:
在这里插入图片描述

然而对于 DETR-based 框架,由于输入对象查询与输出预测结果之间没有明确的对应关系,因此进行一致性正则化变得不可行 。
在这里插入图片描述

图4展示了提出的跨视图查询一致性模块。具体来说,对于每一幅未标图像,给定一组伪边框 b b b,用若干个 MLP 处理 RoI Align 提取的 ROI 特征:
在这里插入图片描述

其中, F t F_t Ft F S F_S FS 分别是教师和学生的骨干特征。随后, c t c_t ct c s c_s cs 被视为跨视图查询嵌入,和另一个视图中的原始对象查询合并,作为解码器的输入:
在这里插入图片描述

其中 q . q_. q. E . E_. E. 表示原始对象查询和编码特征, o ^ . \hat{o}_. o^. o . o_. o. 分别表示跨视图查询和原始对象查询的解码特征。下标 t t t s s s 分别表示教师和学生,为了避免信息泄露,还使用了注意力掩膜 A A A
在跨视图查询嵌入的语义引导下,解码特征的对应关系可以自然的得到保证,一致性损失如下:
在这里插入图片描述

4.3 基于代价的伪标签挖掘

为了在跨视图查询一致性学习中挖掘出更多具有有意义语义内容的伪框,作者提出了一种基于代价的伪标签挖掘伪框模块,动态地在伪标注数据中挖掘出可靠的伪框。具体来说,在初始过滤的伪框和预测建议之间进行额外的二分匹配,并利用匹配代价来描述伪框的可靠性:
在这里插入图片描述

其中 p i p_i pi b i b_i bi 表示第 i i i 个建议预测的分类和回归, p ^ j \hat{p}_j p^j b ^ j \hat{b}_j b^j 表示第 j j j 个伪标签的类别和框坐标。

最后,在每个训练批次中,通过拟合高斯混合模型的匹配代价分布,将初始伪框类分为两种状态,如图5所示,匹配代价和伪框的质量非常吻合。作者进一步将可靠聚类中心的代价值设置为阈值,并收集所有代价低于阈值的伪框用于跨视图查询一致性计算。
在这里插入图片描述

先通过教师模型预测的每幅图像的所有建议框置信度的均值假方差获得图像级的置信度阈值,使用阈值过滤得到的初始伪标签,如图(b)所示。

代码如https://github.com/JCZ404/Semi-DETR/blob/main/detr_ssod/models/dino_detr_ssod.py#L921:

avg_score = torch.mean(proposal_box[:, -1])
std_score = torch.std(proposal_box[:, -1])

pseudo_thr = avg_score + std_score

# filter the pseudo bbox
valid_inds = torch.nonzero(proposal_box[:, -1] >= pseudo_thr, as_tuple=False).squeeze().unique()

然后对学生模型预测的结果和伪标签将进行匈牙利匹配,计算每一批次内每个边界框的匹配代价,用GMM模型拟合,如图(a)所示。作者认为成本较低的伪框更可能是可靠的伪框,因此从GMM模型中取较低的阈值来再次过滤伪标签,得到(d)中呈现的可靠伪框。最终会用人为设定的阈值过滤出的伪框计算无监督损失,并将GMM模型过滤的伪框和人为阈值过滤的伪框合并,用于计算一致性损失。

代码如https://github.com/JCZ404/Semi-DETR/blob/main/detr_ssod/models/dino_detr_ssod.py#L332:

valid_inds = torch.nonzero(match_gt_cost <= thr_, as_tuple=False).squeeze().unique()
valid_gt_inds_1 = match_gt_inds[valid_inds]


valid_gt_inds_2 = torch.nonzero(gt_scores >= base_thr, as_tuple=False).squeeze().unique()

            
valid_gt_inds = torch.cat((valid_gt_inds_1.to(imgs.device), valid_gt_inds_2.to(imgs.device))).unique()
  
gt_bboxes_list.append(gt_bboxes[valid_gt_inds_2, :4])
gt_labels_list.append(gt_labels[valid_gt_inds_2])
gt_scores_list.append(gt_scores[valid_gt_inds_2])
 
# ==== High recall pseudo labels for consistency ====
unsup_bboxes_gmm_list.append(gt_bboxes[valid_gt_inds, :4])
unsup_labels_gmm_list.append(gt_labels[valid_gt_inds])
unsup_scores_gmm_list.append(gt_scores[valid_gt_inds]) 

4.4 总损失

总损失函数如下:
在这里插入图片描述

其中 w u = 4 , w c = 1 w_u=4,w_c=1 wu=4,wc=1 T 1 T_1 T1 是SHM模块的第一个阶段,后面实验中测试最佳轮次为60K, t t t 是当前训练轮次。

效果

5.1 和SOTA方法对比

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

5.2 消融实验

本文提出的各模块的消融实验。
在这里插入图片描述

比较不同方法生成的方法来为CQC生成伪标签,其中本文提出的基于Cost的GMM阈值过滤效果最好。
在这里插入图片描述

第一阶段1对多分配策略的消融实验。
在这里插入图片描述
第一阶段的训练轮数的消融实验。
在这里插入图片描述

伪标签阈值 τ s \tau_s τs 的消融实验。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2036542.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flink开发过程中遇到的问题

1. 任务启动报错Trying to access closed classloader. Exception in thread "Thread-5" java.lang.IllegalStateException: Trying to access closed classloader. Please check if you store classloaders directly or indirectly in static fields. If the st…

基于PSO-BP+BP多特征分类预测对比(多输入单输出) Matlab代码

基于PSO-BPBP多特征分类预测对比(多输入单输出) Matlab代码 1、和市面上的不同&#xff0c;运行一个main一键出对比图&#xff0c;非常方便 2、可以根据需要定制其他算法优化模型对比 程序已经调试好&#xff0c;无需更改代码替换数据集即可运行&#xff01;&#xff01;&…

Python | Leetcode Python题解之第334题递增的三元子序列

题目&#xff1a; 题解&#xff1a; class Solution:def increasingTriplet(self, nums: List[int]) -> bool:n len(nums)if n < 3:return Falsefirst, second nums[0], float(inf)for i in range(1, n):num nums[i]if num > second:return Trueif num > first…

C++字体库开发之EM长度单位转换九

freetype 设置EM // if (m_face) // FT_Set_Pixel_Sizes(*m_face, 0, pixelSize); // 动态宽&#xff0c;固定高 px // error FT_Set_Char_Size(face, /* face 对象的句柄 */ // 0, /* 以 …

Unity Audio

这章练习将介绍在unity中创建 audio&#xff08;音频&#xff09;的工具&#xff0c;培养的技能将帮助创建引人入胜的音频音景。完成本次学习后&#xff0c;能够使用 Unity 中的所有主要音频组件&#xff0c;为各种不同体验创建音频效果。 音频处理工具&#xff1a; Audacity…

Mintegral出海系列:解锁全球应用商店新增长路径

在全球化竞争的浪潮中&#xff0c;面对打法各异的应用和游戏品类&#xff0c;以及全球数百个环境不同的国家和地区&#xff0c;开发者们正面临着前所未有的挑战。Mintegral「出海ing」系列专题内容&#xff0c;助力出海开发者选准赛道探索新的增长路径。 据近期数据显示&#x…

LLM微调(精讲)-以高考选择题生成模型为例(DataWhale AI夏令营)

前言 你好&#xff0c;我是GISer Liu&#x1f601;&#xff0c;一名热爱AI技术的GIS开发者&#xff0c;上一篇文章中&#xff0c;作者介绍了基于讯飞开放平台进行大模型微调的完整流程&#xff1b;而在本文中&#xff0c;作者将对大模型微调的数据准备部分进行深入&#xff1b;…

凤凰端子音频矩阵应用领域

凤凰端子音频矩阵&#xff0c;作为一种集成了凤凰端子接口的音频矩阵设备&#xff0c;具有广泛的应用领域。以下是其主要应用领域&#xff1a; 一、专业音响系统 会议系统&#xff1a;在会议室中&#xff0c;凤凰端子音频矩阵能够处理多个话筒和音频源的信号&#xff0c;实现…

Luminar Neo for Mac/Win:创新AI图像编辑软件的强大功能

Luminar Neo&#xff0c;这款由Skylum公司倾力打造的图像编辑软件&#xff0c;为Mac和Windows用户带来了前所未有的创作体验与编辑便利。作为一款融合了先进AI技术的图像处理工具&#xff0c;Luminar Neo以其独特的功能和高效的操作流程&#xff0c;成为了摄影师、设计师及摄影…

使用Sanic和SSE实现实时股票行情推送

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐&#xff1a;「storm…

【Next】全局样式和局部样式

不同于 nuxt &#xff0c;next 的样式绝大部分都需要手动导入。 全局样式 使用 sass 先安装 npm i sass -D 。 我们可以定义一个 styles 文件&#xff0c;存放全局样式。 variables.scss $fs30: 30px;mixin border() {border: 1px solid red; }main.scss use ./variables …

业界首个OpenTelemetry结合eBPF的向导式可观测性平台APO正式开源

AutoPilot Observability (简称APO&#xff09;是什么&#xff1f; 开箱即用的可观测性平台&#xff1a;APO 致力于提供一键安装、开箱即用的可观测性平台。APO 的 OneAgent 支持一键免配置安装 Tracing 探针&#xff0c;支持采集应用的故障现场日志、基础设施指标、应用和下游…

主机防火墙IPV6 域名 测试环境搭建及测试方法

由于国内当前网站支持ipv6的很少,部分支持ipv6 的网站由于路由器的限制,也无法直接访通过ipv6进行访问,因此进行主机防火墙ipv6域名测试时,需要自己搭建环境进行测试,以下为搭建环境的步骤。 1 . 搭建DNS服务器 环境:安装有python,系统为Windows Server 2016 DNS服务…

【Vue3】vue模板中如何使用enum枚举类型

简言 有的时候&#xff0c;我们想在vue模板中直接使用枚举类型的值&#xff0c;来做一些判断。 ts枚举 枚举允许开发人员定义一组命名常量。使用枚举可以更容易地记录意图&#xff0c;或创建一组不同的情况。TypeScript 提供了基于数字和字符串的枚举。 枚举的定义这里不说了…

haproxy最强攻略

1、负载均衡 负载均衡&#xff08;Load Balance&#xff0c;简称 LB&#xff09;是高并发、高可用系统必不可少的关键组件&#xff0c;目标是 尽力将网络流量平均分发到多个服务器上&#xff0c;以提高系统整体的响应速度和可用性。 负载均衡的主要作用如下&#xff1a; 高并发…

接入谷歌支付配置

1.谷歌云创建项目 网址&#xff1a;https://console.cloud.google.com/ 按照步骤创建即可 创建好后选择项目&#xff0c;转到项目设置 选择服务账户&#xff0c;选择创建新的服务账户 名称输入好后访问权限吗账号权限都可以不用填写&#xff0c;默认就好了 然后点击电子邮…

爵士编曲:Bass编写,Walking Bass,SwingBass 爵士鼓 Swing Jazz律动 Moonkits

Walking Bass Line是乐曲构造中的基垫&#xff0c;“Walking”是在BassLine中的一种重要的感觉构成&#xff0c;等同于我们对于“行走”的理解&#xff0c;意义就是“一步接着一步”&#xff0c;先从每一步&#xff08;每一小节&#xff09;建立&#xff0c;并持续构建成一个完…

Android 10.0 SystemUI下拉状态栏QSTileView去掉着色效果显示彩色图标功能实现

1.前言 在10.0的系统rom定制化开发中,在关于SystemUI的下拉状态栏中QSTileView的背景颜色设置过程中,在由于 系统原生有着色效果,导致现在某些彩色背景显示不是很清楚效果不好,所以需要去掉QSTileView的默认着色 背景显示原生的彩色背景,接下来就来实现相关功能 如图: 2.…

【微信小程序】实现中英文切换

1、组织语言资源 创建两个文件夹&#xff0c;分别用于存放中文和英文的语言资源。例如&#xff0c;可以在 utils 文件夹下创建 lang 文件夹&#xff0c;然后在其中创建 zh.js 和 en.js 文件&#xff0c;分别存放中文和英文的文本内容。 zh.js: const zh {home: {title: 这里…

【人工智能】全景解析:【机器学习】【深度学习】从基础理论到应用前景的【深度探索】

目录 1. 人工智能的基本概念 1.1 人工智能的定义与发展 1.1.1 人工智能的定义 1.1.2 人工智能的发展历史 1.2 人工智能的分类 1.2.1 弱人工智能 1.2.2 强人工智能 1.2.3 超人工智能 1.3 人工智能的关键组成部分 1.3.1 数据 1.3.2 算法 1.3.3 计算能力 2. 机器学习…