从代码角度理解DETR

news2025/1/14 18:26:44

  1. 一个cnn的backbone, 提图像的feature, 比如, HWC.
  2. 同时对这个feature做position_embedding.
  3. 然后二者相加 (在Transformer里面就是二者相加)
  4. 输入encoder,
  5. 输入decoder (这里有object queries.)
  6. 然后接Prediction Heads, 比如分类和回归.

下面的代码参考自: https://github.com/facebookresearch/detr
commit-id: 3af9fa8
在这里插入图片描述
可以看到, 这里传入的有backbone, transformer, 输入的类别个数(用来确定head的输出维度), num_quries, 以及是否需要aux_loss等.

先看一下forward的解释
在这里插入图片描述
输入是一堆图片和对应的mask. 这里mask先不管. 后面再来细看其具体起的作用;
输出是logits, boxes, 还有aux_outputs(只有在用aux_loss的时候才会有这个的输出)

接下来看第一步
在这里插入图片描述

backbone部分

从这里在这里插入图片描述
可以看出来, backbone是这两个的结合.
在这里插入图片描述
也就是说最终backbone输出的第一个其实是图像backbone提的feature, 第二个是每个feature所对应的 position encoding.

Transormer部分

transformer的输入如下

在这里插入图片描述
由于没有跑代码. 这里
src: 先理解成是images的feature, pos[-1], 先理解成是position-encoding

这里input_proj. 是对输入的src做了一个FC.
在这里插入图片描述
这里的query_embed是

在这里插入图片描述
可以理解成是一个query的词典表.

之后如图
在这里插入图片描述
encoder部分和原始transformer是一致的, 而decoder部分, 原始transformer输入的是trg_seq. 而这里是一个全0的矩阵. 大小与query_embed一样大.

之前transformer的decoder中是trg_seq 与 src_seq的encoder的output 做encoder-decoder-attention. 但是现在detr里的decoder中, 到现在还没有用到groundtruth.

head部分

在这里插入图片描述
这里bbox的coord取sigmoid的原因是gt是按图像的长宽给的比例.
在这里插入图片描述

aux_loss

loss部分

bbox-loss. 采用的有 l1-loss 以及giou loss, 这里用giou-loss的原因是scale-不变.

在这里插入图片描述

SetCriterion

在这里插入图片描述

这里写得很清晰, 即对gt和dt做一次Hungarian匹配, 然后, 将匹配到的pair, 去算loss, loss包括类别和bbox.
我理解这个过程相当于label-assign, 只不过特殊的地方是这个label-assign 是一个一对一的, 这是和之前一些一阶段和二阶段检测不太一样的地方. 比如使用anchor的方法中, 可能多个anchor会对应到一个gt上面, 这也是为啥那些方法的后处理中要使用NMS, 相当于是一种搜索排查式的检测方式,先检测出一堆proposals, 再选出置信度较高的.

loss_cardinality

这其实不是一个loss, 就是为了统计预测的object的数量与ground-truth数量之间的差异. 用来观测.

HungarianMatcher

这个就是一个匈牙利匹配, 用的 from scipy.optimize import linear_sum_assignment, 这里用pytorch的方式封装了一下.

mask起的作用

二维positionEncoding的细节

paper-reading

Abstract

  1. 把目标检测视为一个集合预测问题. 从设计上去掉了很多的人为操作,比如anchor设定, nms 等.
  2. 更关注object与image context 之间的本质, 直接去预测最终的结果集合. 而非"搜索式检测"
  3. 不需要开发额外的库,比如roi-align, roi-pooling, 这些操作…
  4. 很容易换一个head就可以去做分割的任务,

pipline


整个Pipline看上去很好理解, 细节主要体现在 图像的backbone的features如何转化成为 word-embeding似的输入, 进入到transformer中.

在大目标上面要比小目标上好

在这里插入图片描述
这里解释说在大目标上效果比较好是因为transormer的non-local的机制, 这一点我的理解是, transformer由于内部的self-attention操作, 使得输入的一句话中每一个词彼此之间都会去算attention的加权分数。 所以哪怕是某一个词的预测,它也是依赖于整个句子的. 所以是一个non-local的操作.
而小目标, 因为占据的图像中的位置比较少. 别的位置对于这个小目标的attention不那么重要, 因此这种non-local的操作,对于小目标不太友好.
当然作者也提了可以用其他的方式来缓解小目标不好的问题, 比如FPN.

训练时间长

在这里插入图片描述

对于set prediction问题, 两个重要的部件

set prediction loss

这个主要用于在predictin和ground-truth之间建立one-one map.
DETR是预测N个objects, N是一个超参, 比如100.

能够预测objects以及他们之间关系的模型结构. 这里就是指的是DETR

backbone

在这里插入图片描述
比较好理解,就是正常的 2d-backbone.

encoder

在这里插入图片描述
关键的部分也都在上面标出来了.

decoder

在这里插入图片描述

说实话,没太理解 N个, object-query 到最后为啥能够预测 N个 final predictions. 背后的原因是啥?

prediction Heads

这个就是正常的heads.

auxiliary losses

为了帮助模型训练, 在每个decoder层后面, 加了PredictionHeads 和Hungarian loss 来做监督, 并且这些 层是share 参数的.
在这里插入图片描述

实验对比

encoder层的影响

从下表可以看出来, 用encoder还是很有用的.
在这里插入图片描述

而且有可视化结果证明, encoder 层似乎已经把目标分离开来了.
在这里插入图片描述

decoder层数的影响

在这里插入图片描述

  1. decoder层数的增加, 效果变好,
  2. 当decoder层数只有一层的时候, NMS有用.
  3. 当decoder层数大于一层的时候, NMS几乎没有什么用.
    这说明 单个decoder层不足以表现不同输出间的关系.
    在这里插入图片描述

FFN层的重要性

去掉之后会掉点.
在这里插入图片描述

PositionEncoding 的重要性

  1. spatioal position encoding 在encoder 和decoder 中都非常重要. 没有的话会掉6个点左右.
    在这里插入图片描述

object_query

从代码里看, 是这样的流程.

在这里插入图片描述
比如 num_queries 是100, 而 hidden_dim 是64的话.

那么query_embed.weight 也是 [100, 64] 维的.

这里进去encoder的, tgt 每次其实都是0. 而query_embed.weight 充当的是query_pos.

我理解这个就是上面说的, output position encodings.

因此当normalize_before=False的时候, decoder的时候会直接走forward_post,
从下面的代码可以看出
在这里插入图片描述

在decoder的时候, 最初的一层的输入, 因为tgt都是0, 所以 q=k=query_pos.

所以这里的object_query 其实就是随机初始化的 query_embed.weight.

这里解释下nn.embedding 是什么意思.

nn.embedding可以理解为是一个词嵌入模块. 它是有weight的. 这是一个可以学习的层, 有参数,类似于conv2d.

比如上面的例子中, query_embed 就可以理解为是一个词典, 只不过这个词典有点小, 只有100个词, 每个词的embedding的大小是64.
forward的时候,可以传入indices, 来得到对应的每个单词的embeddings. 可以传入batch的indices.

推理的时候query_embed充当什么角色?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/577176.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

单片机原理及应用——持续更新

目录 一、单片机概述 1、单片机简介 2、单片机的特点 3、MSC-51系列与AT89S5x系列单片机 (1)MSC-51系列单片机 (2)AT89S5x系列单片机 二、AT89S52单片机的片内硬件结构 1、AT89S52单片机的硬件组成 2、AT89S52单片机的引…

Springboot +spring security,OAuth2 四种授权模式概念

一.简介 这篇文章来讲下Spring Security OAuth2 四种授权模式。 二.什么是OAuth2 OAuth 2.0 是一种用于授权的开放标准,允许用户授权第三方应用程序访问他们的资源,例如照片、视频或其他个人信息。OAuth 2.0 提供了一些不同的授权模式,包括…

我有一个朋友,分享给我的字节跳动测试开发真题

朋友入职已经两周了,整体工作环境还是非常满意的!所以这次特意抽空给我写出了这份面试题,而我把它分享给小伙伴们,面试&入职的经验! 大概是在3月中的时候他告诉我投递了简历,5月的时候经过了3轮面试收获…

Windows10中搭建ftp服务器以实现文件传输

开启ftp服务: 1、打开控制面板》程序和功能》 启用或关闭Windows功能 2、找到Internet Information Services,开启以下服务 勾选之后,ftp服务开启成功。 配置IIS,搭建ftp 1、WinS键搜索iis,回车打开》右击网站 》添加…

QUIC 协议:特性、应用场景及其对物联网/车联网的影响

什么是 QUIC 协议 QUIC(Quick UDP Internet Connections)是由谷歌公司开发的一种基于用户数据报协议(UDP)的传输层协议,旨在提高网络连接的速度和可靠性,以取代当前互联网基础设施中广泛使用的传输控制协议…

/dev/kmem /proc/kallsyms

文章目录 前言概述使用 /dev/kmem使用 /proc/kallsyms验证进阶 前言 上篇文章我们介绍了 /dev/mem,今天再来介绍下它的好兄弟 /dev/kmem crw-r----- 1 root kmem 1, 1 May 26 06:10 /dev/mem crw-r----- 1 root kmem 1, 2 May 26 06:10 /dev/kmem对比一下&#xf…

第十四届全国大学生数学竞赛决赛(非数类)游记+答案解析

2023/5/27 20:08:今天早上9:00~12:00考了数学竞赛国赛。广州是真的热啊!西安才17度,还下着小雨,到广州之后那个艳阳直接给我人干废了,去酒店的路上步行了20分钟真的要死了已经。 拿到卷子的我是崩溃的,用正…

计算机视觉:填充(padding)技术

本文重点 在前面的课程中,我们学习了使用3*3的过滤器去卷积一个5*5的图像,那么最终会得到一个3*3的输出。那是因为 33 过滤器在 55 矩阵中,只可能有 33 种可能的位置。 这背后的数学解释是,如果我们有一个nn的图像,用ff的过滤器做卷积,那么输出的维度就是(n−f+1)(n−f…

码出高效_第一章 | 有意思的二进制表示及运算

目录 0与1的世界1.如何理解32位机器能够同时处理处理32位电路信号?2.如何理解负数的加减法运算3.溢出在运算中如何理解4.计算机种常用的存储单位及转换5.位移运算规则6.有趣的 && 和 & 浮点数1.定点小数(为什么会出现浮点数表示?…

Linux 系统烧写

目录 MfgTool 工具简介MfgTool 工作原理简介烧写方式系统烧写原理 烧写NXP 官方系统烧写自制的系统系统烧写网络开机自启动设置 改造我们自己的烧写工具改造MfgTool烧写测试解决Linux 内核启动失败 前面我们已经移植好了uboot 和linux kernle,制作好了根文件系统。但…

Spring 事件相关知识ApplicationEvent

Spring 事件相关知识ApplicationEvent 事件工作流程相关类ApplicationListenerApplicationEvent 我们可以发布自己的事件ApplicationEventPublisher Spring框架中提供了多种事件类型,常用的几个事件类型如下: Spring 事件驱动模型是 Spring 框架中的一个…

uCOSii中的事件标志组

事件标志管理 (EVENT FLAGS MANAGEMENT) OSFlagAccept() 无等待查询”事件标志组的事件标志位”是否建立 OSFlagPend() 需要等待”事件标志组的事件标志位”建立 OSFlagCreate() 建立一个事件标志组 OSFlagDel() 删除一个事件标志组 OSFlagPost() 置位或清0事件标志组中的…

SpringBoot整合百度云人脸识别功能

SpringBoot整合百度云人脸识别功能 1.在百度官网创建应用 首先需要在百度智能云官网中创建应用,获取AppID,API Key,Secret Key 官网地址:https://console.bce.baidu.com/ 2.添加百度依赖 添加以下依赖即可。其中版本号可在mav…

【问卷分析】调节效应检验的操作①

文章目录 1.首先要明白自己的调节和自变量是什么类别的2.实操演练2.1 当调节变量是连续变量时2.1.1 将ml中心化2.1.2 使用分层回归探讨自变量和ml的交互对adh的影响2.1.3 结果解读 1.首先要明白自己的调节和自变量是什么类别的 2.实操演练 在本次演练中,我们以自变…

马斯克要用人工智能对抗人工智能

导读:马斯克对人工智能可能变得失控并“摧毁人类”的担忧促使他采取行动,发起了一个名为“TruthGPT”的项目。 本文字数:1400,阅读时长大约:9分钟 亿万富翁埃隆马斯克在谈到人工智能(AI)的危险时…

瑞合信LED字幕WiFi卡使用教程(8.0版)

请按照提示下载和安装软件,同时请允许所有权限,如下图; 也可以在各大主流应用商店(华为、小米、OPPO、vivo、App Store等)搜索“瑞合信”直接安装。 首次使用APP时会提示注册登录软件,你可以选择“使用微信…

深度学习进阶篇-预训练模型[4]:RoBERTa、SpanBERT、KBERT、ALBERT、ELECTRA算法原理模型结构应用场景区别等详解

【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等 专栏详细介绍:【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化…

Full-Scanner是一个多功能扫描工具,支持被动/主动信息收集,漏洞扫描工具联动,可导入POC和EXP

github项目地址:https://github.com/Zhao-sai-sai/Full-Scanner gitee项目地址:https://gitee.com/wZass/Full-Scanner 工具简介 做挖漏洞渗透测试有的时候要去用这个工具那个工具去找感觉麻烦我自己就写了一个简单的整合工具,有互联网大佬不…

Presto之BroadCast Join的实现

一. 前言 在Presto中,Join的类型主要分成Partitioned Join和Broadcast Join,在Presto 之Hash Join的Partition_王飞活的博客-CSDN博客 中已经介绍了Presto的Partitioned Join的实现过程,本文主要介绍Broadcast Join的实现。 二. Presto中Broa…

ChatGPT免费使用的方法有哪些?

目录 一、ChatGpt是什么? 二、ChatGPT国内免费使用的方法: 第一点:电脑端 第二点:手机端 三、结语: 一、ChatGpt是什么? ChatGPt是美国OpenAI [1] 研发的聊天机器人程序 。更是人工智能技术驱动的自然语…