Conditional DETR解读---带anchor的DETR

news2025/1/18 20:15:31

DETR存在的问题

1.收敛速度慢

2.对小目标物体检测效果不好,因为transformer计算量大,受限于计算规模,CNN提取特征时只采取了最后一层特征,没有用FPN等结构。所以对于小目标检测效果不好。

论文主要观点

  • 通过对DETRdecoder中的attentionmap进行可视化,发现query查询到的区域都是物体的extremity末端区域。所以论文认为attention尝试找到物体的边界区域。

  • 论文中认为DETRtransofmer结构中的信息主要可以分为两部分,一部分是与图像的特征(颜色纹理等)相关的信息,称为content,比如encoder或decoder的输出信息。另一部分是代表空间上的信息,称为spatial,比如position embedding等。

  • detr中的CNN与encoder只涉及图像特征向量提取;decoder中的self-attn只涉及query之间的交互去重;所以收敛慢的最可能原因发生在cross attn

  • Cross attention中的K包含encoder输出信息(content key Ck)与position embedding(spatial Key Pk),Q包含self attention的输出(content query Cq)和object query(spatial query Pq)信息。论文中发现去掉cross attention中的object基本不掉点,所以收敛慢很可能是content query难学习导致的。

  • 提出了reference point的概念,为每个query设定一个检测范围,使得匹配更加稳定,加快了收敛

  • 原始detr混合两者学习,使得content query难学习。所以将content与spatial进行解耦

在这里插入图片描述

变为

在这里插入图片描述

网络结构

在这里插入图片描述

对于object query生成了一个2D坐标embedding(上图中的s),用于限定当前query的预测范围。最终decoder的输出的是相对与s的偏移量

bbox回归输出

在这里插入图片描述

其中f是decoer的输出,S表示x,y的坐标。最终b是[x,y,w,h]的向量。

classifier分类输出

在这里插入图片描述

f是decoder的输出,输出每个候选框的类别

decoder Pq生成:

提出了reference point的概念,即图中的s,是一个2d的坐标(q_num,B,2),由object queries经过一个线性层生成,代表了每个query的预测范围。

s经过sigmoid和position embedding后(图中的Ps),跟FFN(decoder embedding)(即图中的T)做内积。得到空间特征Pq

在这里插入图片描述

在这里插入图片描述

代码spatial query这一部分的实现:

# query_pos [num_query,batch,d_model]
# reference_points_before_sigmoid [num_query,batch,2]  从query预测一个坐标,代表了这个query预测的大概范围
reference_points_before_sigmoid = self.ref_point_head(query_pos)    # [num_queries, batch_size, 2]
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
for layer_id, layer in enumerate(self.layers):
    # 图里的s,代表了query的预测大概范围
    obj_center = reference_points[..., :2].transpose(0, 1)      # [num_queries, batch_size, 2]

    # For the first decoder layer, we do not apply transformation over p_s
    ## pos_transformation代表图里的T,表示decoder embedding的特征经过ffn后其实得到的是相对于s的偏移量
    if layer_id == 0:
        pos_transformation = 1
    else:
        pos_transformation = self.query_scale(output)

    # get sine embedding for the query vector
    query_sine_embed = gen_sineembed_for_position(obj_center)     
    # apply transformation
    # 最终的Pq,代表空间特征信息
    query_sine_embed = query_sine_embed * pos_transformation
    output = layer(output, memory, tgt_mask=tgt_mask,
                   memory_mask=memory_mask,
                   tgt_key_padding_mask=tgt_key_padding_mask,
                   memory_key_padding_mask=memory_key_padding_mask,
                   pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                   is_first=(layer_id == 0))

decoder中cross attention的实现


# ========== Begin of Cross-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.ca_qcontent_proj(tgt)
k_content = self.ca_kcontent_proj(memory)
v = self.ca_v_proj(memory)

num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape

# k的位置编码
k_pos = self.ca_kpos_proj(pos)

# For the first decoder layer, we concatenate the positional embedding predicted from 
# the object query (the positional embedding) into the original query (key) in DETR.
if is_first:
    q_pos = self.ca_qpos_proj(query_pos)
    q = q_content + q_pos
    k = k_content + k_pos
else:
    q = q_content
    k = k_content

q = q.view(num_queries, bs, self.nhead, n_model//self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)
# decoder embedding cat spatial query
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
k = k.view(hw, bs, self.nhead, n_model//self.nhead)
# encoder embdeding cat position embedding
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)

tgt2 = self.cross_attn(query=q,
                           key=k,
                           value=v, attn_mask=memory_mask,
                           key_padding_mask=memory_key_padding_mask)[0]               
# ========== End of Cross-Attention =============

head的实现

# hs代表decoder embedding,reference代表s(reference point)
hs, reference = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])
reference_before_sigmoid = inverse_sigmoid(reference)
outputs_coords = []
for lvl in range(hs.shape[0]):
    # 回归head hs输出相对于 reference的偏移量,得到检测框
    tmp = self.bbox_embed(hs[lvl])
    tmp[..., :2] += reference_before_sigmoid
    outputs_coord = tmp.sigmoid()
    outputs_coords.append(outputs_coord)
outputs_coord = torch.stack(outputs_coords)
#分类head,hs输出分类结果
outputs_class = self.class_embed(hs)

总结思考

实际上conditional DETR有点像transfoermer版本的faster-RCNN。将特征信息与空间信息进行了解耦。reference point像anchor的概念,让网络自己为每个query设定一个anchor范围,从而使得二分匹配更加问题,所以加快了网络的收敛

作者论文解读:https://zhuanlan.zhihu.com/p/401916664
公式解释得更加详细

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1719883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java—— StringBuilder 和 StringBuffer

1.介绍 由于String的不可更改特性,为了方便字符串的修改,Java中又提供了StringBuilder和Stringbuffer类,这两个类大部分功能是相同的,以下为常用方法: public static void main(String[] args) {StringBuilder sb1 n…

乡村振兴与乡村旅游创新:创新乡村旅游产品,提升旅游服务水平,打造特色乡村旅游品牌,助力美丽乡村建设

目录 一、引言 二、乡村旅游产品的创新 (一)挖掘乡村特色资源 (二)注重产品体验性 (三)创新旅游产品形态 三、旅游服务水平的提升 (一)加强基础设施建设 (二&…

微信小程序-页面导航-导航传参

1.声明式导航传参 navigator组件的url属性用来指定将要跳转到的页面的路径,同时,路径的后面还可以携带参数: (1)参数与路径之间使用 ? 分割 (2)参数键与参数值用 相连 (3&…

《SpringBoot3+Vue3实战》系列文章目录

前后端分离(Frontend-Backend Separation)是一种软件架构设计模式,它将传统的Web应用中的前端(用户界面)和后端(服务器逻辑和数据存储)从应用层面进行解耦,使得两者可以独立地开发、…

conda与pip的镜像源与代理设置

conda与pip的镜像源与代理设置 一、前言二、conda镜像源设置2.1conda默认镜像源介绍2.2通过终端设置镜像源2.3通过配置文件设置镜像源 三、pip镜像源设置3.1pip默认镜像源介绍3.2通过终端临时设置镜像源3.3通过配置文件设置一个或多个镜像源 四、conda代理设置4.1通过终端设置代…

铁塔基站用能监控能效解决方案

截至2023年10月,我国5G基站总数达321.5万个,占全国通信基站总数的28.1%。然而,随着5G基站数量的快速增长,基站的能耗问题也逐渐日益凸显,基站的用电给运营商带来了巨大的电费开支压力,降低5G基站的能耗成为…

【论文速读】Self-Rag框架,《Self-Rag: Self-reflective Retrieval augmented Generation》

关于前面的文章阅读《When to Retrieve: Teaching LLMs to Utilize Information Retrieval Effectively》,有网友问与Self-Rag有什么区别。 所以,大概看了一下Self-Rag这篇论文。 两篇文章的方法确实非常像,Self-Rag相对更加复杂一些。 When …

大模型部署_书生浦语大模型 _作业2基本demo

本节课可以让同学们实践 4 个主要内容,分别是: 1、部署 InternLM2-Chat-1.8B 模型进行智能对话 1.1安装依赖库: pip install huggingface-hub0.17.3 pip install transformers4.34 pip install psutil5.9.8 pip install accelerate0.24.1…

系统架构设计师【第5章】: 软件工程基础知识 (核心总结)

文章目录 5.1 软件工程5.1.1 软件工程定义5.1.2 软件过程模型5.1.3 敏捷模型5.1.4 统一过程模型(RUP)5.1.5 软件能力成熟度模型 5.2 需求工程5.2.1 需求获取5.2.2 需求变更5.2.3 需求追踪 5.3 系统分析与设计5.3.1 结构化方法5.3.2 面向对象…

Kafka自定义分区器编写教程

1.创建java类MyPartitioner并实现Partitioner接口 点击灯泡选择实现方法,导入需要实现的抽象方法 2.实现方法 3.自定义分区器的使用 在自定义生产者消息发送时,属性配置上加入自定义分区器 properties.put(ProducerConfig.PARTITIONER_CLASS_CONFIG,&q…

RabbitMQ(Direct 订阅模型-路由模式)的使用

文章目录 RabbitMQ(Direct 订阅模型-路由模式)一,Direct 订阅模型-路由模式介绍(Routing)二,使用1.添加依赖2.修改配置文件3.创建配置类4.注入RabbitMQ模版引擎5.消息的发送6.消息的接收(监听)7.设置回调函…

2024.5.30学习记录

1 面经复习 LRU 手写 等 2 代码随想录二刷 3 rosebush完成 upload组件 初步完成 form组件

如何设置手机的DNS

DNS 服务器 IP 地址 苹果 华为 小米 OPPO VIVO DNS 服务器 IP 地址 中国大陆部分地区会被运营商屏蔽网络导致无法访问,可修改手机DNS解决。 推荐 阿里的DNS (223.5.5.5)或 114 (114.114.114.114和114.114.115.115) 更多公开DNS参考: 苹果…

鸿蒙开发接口媒体:【@ohos.multimedia.media (媒体服务)】

媒体服务 说明: 本模块首批接口从API version 6开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 开发前请熟悉鸿蒙开发指导文档: gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 媒体子系…

dubbo复习:(11)使用grpc客户端访问tripple协议的dubbo 服务器

一、服务器端依赖&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.…

尝试用智谱机器人+知识库,制作pytorch测试用例生成器

尝试用智谱机器人知识库,制作pytorch测试用例生成器 1 保存pytorch算子文档到txt2 创建知识库3 创建聊天机器人4 测试效果5 分享 背景:是否能将API的接口文档和sample放到RAG知识库,让LLM编写API相关的程序呢 小结:当前的实验效果并不理想,可以生成代码,但几乎都存在BUG 1 保存…

Window系统安装Docker

因为docker只适合在liunx系统上运行&#xff0c;如果在window上安装的话&#xff0c;就需要开启window的虚拟化&#xff0c;打开控制面板&#xff0c;点击程序&#xff0c;在程序和功能中可以看到启动和关闭window功能&#xff0c;点开后&#xff0c;找到Hyper-V&#xff0c;Wi…

DevExpress开发WPF应用实现对话框总结

说明&#xff1a; 完整代码Github​&#xff08;https://github.com/VinciYan/DXMessageBoxDemos.git&#xff09;DevExpree v23.2.4&#xff08;链接&#xff1a;https://pan.baidu.com/s/1eGWwCKAr8lJ_PBWZ_R6SkQ?pwd9jwc 提取码&#xff1a;9jwc&#xff09;使用Visual St…

FFmpeg 中 Filters 使用文档介绍

描述 这份文档描述了由libavfilter库提供的过滤器Filters、源sources和接收器sinks。 滤镜介绍 FFmpeg通过libavfilter库启用过滤功能。在libavfilter中,一个过滤器可以有多个输入和多个输出。为了说明可能的类型,我们考虑以下过滤器图: 这个过滤器图将输入流分成两个流,然…

微信小程序-wx.showToast超长文字展示不全

wx.showToast超长文字展示不全 问题解决方法1 问题 根据官方文档&#xff0c;iconnone&#xff0c;最多显示两行文字。所以如果提示信息较多&#xff0c;超过两行&#xff0c;就需要用其他方式解决。 解决方法1 使用vant组件里面的tost 根据官方例子使用&#xff1a; 1、在…