DETR论文详解

news2024/9/23 23:26:20

文章目录

  • 前言
  • 一、DETR理论
  • 二、模型架构
    • 1. CNN
    • 2. Transformer
    • 3. FFN
  • 三、损失函数
  • 四、代码实现
  • 总结

前言

 DETR是Facebook团队在2020年提出的一篇论文,名字叫做《End-to-End Object Detection with Transformers》端到端的基于Transformers的目标检测,DETR是Detection Transformers的缩写。DETR摆脱了传统目标检测模型中复杂的组件,例如NMS、先验框等,是一种基于Transformer的简单的端到端的目标检测架构,并且取的了与Faster R-CNN差不多的成绩。接下来对DETR这篇论文进行介绍。

一、DETR理论

 DETR将目标检测看做一个直接的预测集合问题,集合的大小是固定的,集合中的每个元素可以当作一个检测框,可以理解为DETR根据图片的输入去预测集合中元素的属性(分类,坐标)。
 DETR的推理概如下图所示:输入是一张图片,图片首先经过CNN网络进行下采样,然后将CNN输出的特征图拉直成向量进入Transformer的编码器中,Transformer的解码器的输入是object queries,(object queries是可学习的参数,可以理解为object queries学习预测框的先验知识)然后与Transformer的编码器输出做交叉注意力并行得到最终的检测框。在DETR论文中这个集合的大小取100,也就是对于每张图片,都会一口气预测出100个框,对于预测框中的分类预测为’no object’的则不显示,然后我们可以设置一个阈值,把集合中置信度低于阈值的预测框去掉,从而得到最终的输出。
 DETR根据Transformer的全局建模能力对图像进行全局的上下文推理,因此对于大物体的检测效果很好,但是对于小物体的检测效果不如Faster R-CNN。但是由于其架构的简单性深受大家喜欢。
在这里插入图片描述

二、模型架构

DETR的模型结构主要由三部分组成,分别是CNN,Transformer和FFN。如下图所示:
在这里插入图片描述

1. CNN

 DETR使用CNN骨干网络(例如ResNet50)将输入的照片由[3,H,W]下采样成[2048,H/32,W/32],然后在使用一个卷积核用来减少通道数,由[2048,H/32,W/32]变为[d,H/32,W/32]。然后将其进行展平与位置编码表进行相加送入到Transformer编码器中。

2. Transformer

 在论文中,DETR使用了6层的Transformer。
编码器 DETR中的编码器架构与经典的Transformer编码器相同,由多头自注意力层和FFN组成。下图为encoder部分的自注意力可视化,可以看到encoder主要负责预测物体的主体部分。
在这里插入图片描述

解码器 DETR中的解码器架构也与经典的Transformer解码器相同,稍微不同的是,经典的Transformer模型是自回归方式,而DETR在每个解码器层并行解码N个对象,因此N个输入嵌入必须不同,输入嵌入是可以学习的位置编码,称之为’object query’,在经过解码器的解码后并行计算出N个输出。下图为解码器的可视化输出,可以看到解码器主要用于区分物体的边缘或者轮廓。
在这里插入图片描述

3. FFN

 当得到解码器的N个输出后,使用两个线性层和ReLU激活函数将每个输出分别映射到预测框的输出类别和坐标位置。由于集合中的边界框比实际照片中的物体数量多,所以集合中剩余的那些边界框预测的标签则为’no object’,表示预测为背景。

三、损失函数

 DETR在进行训练时,例如我们设置集合的大小为100,那么模型最终会输出100个框,Ground True可能只有2个,那么我们如何算Loss呢?

1.论文里首先是使用匈牙利算法进行最佳二分图匹配。 如何理解最佳二分图匹配呢?

例如:有100个预测框,最终只有2个预测框与真实的Ground True相匹配。最佳二分图匹配就是那么选哪两个预测框与2个Ground True进行匹配得到的最终Loss最小。

那么我们如何衡量一个预测框与Groud True之间的损失呢(匹配损失)?公式如下:
在这里插入图片描述
通过这个公式我们可以获得预测框分别于Ground True进行匹配的Loss,最终使用匈牙利算法选定哪两个预测框与Ground True相对应,而其他的98个预测框可以看作与’no object‘相匹配。

2. 接着我们得出100个预测框如何与Ground True进行匹配,使用如下公式算的最终的Loss:
在这里插入图片描述
在实际中,为了保持正负样本的均衡,当预测框的种类为’no object’,其对数概率项的权重被缩小了10倍。同时为了使Bounding box loss ( L b o x L_{box} Lbox)对于不同大小的预测框的惩罚项尽可能公平, L b o x L_{box} Lbox L o s s I o u Loss_{Iou} LossIou L 1 L_1 L1组成,公式如下:
在这里插入图片描述

3 辅助损失函数。 在训练过程中Transformer的decoder中加入了辅助损失函数,也就是将每层decoder的输出都通过参数共享的FFN映射成预测框然后计算Loss。

四、代码实现

 DETR的前向推理过程代码如下所示:

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
	def __init__(self, num_classes, hidden_dim, nheads,
		num_encoder_layers, num_decoder_layers):
		super().__init__()
		# We take only convolutional layers from ResNet-50 model
		self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
		self.conv = nn.Conv2d(2048, hidden_dim, 1)
		self.transformer = nn.Transformer(hidden_dim, nheads,
		num_encoder_layers, num_decoder_layers)
		self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
		self.linear_bbox = nn.Linear(hidden_dim, 4)
		self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
		self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
		self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
	def forward(self, inputs):
		x = self.backbone(inputs)
		h = self.conv(x)
		H, W = h.shape[-2:]
		pos = torch.cat([
		self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
		self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
		], dim=-1).flatten(0, 1).unsqueeze(1)
		h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
		self.query_pos.unsqueeze(1))
		return self.linear_class(h), self.linear_bbox(h).sigmoid()

总结

  DETR是基于Transformer架构的端到端的目标检测模型,把目标检测看作一个直接的集合预测问题,他简化了传统目标检测模型繁杂的前处理和后处理过程,并且随着Transformer架构的增加模型的性能也有所增加。DETR相当于使用’object queries’替代了’anchor’,使用二分图匹配去掉了之前的NMS。同时DETR还有一些不足,因为他是Transformer架构的所以不好优化,对于小目标的物体检测不足等等。但由于DETR的简单性有效性后续出来了一大批工作对于DETR做出了改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1976440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数仓入门:数据分析模型、数仓建模、离线实时数仓、Lambda、Kappa、湖仓一体

往期推荐 大数据HBase图文简介-CSDN博客 数仓分层ODS、DWD、DWM、DWS、DIM、DM、ADS-CSDN博客 数仓常见名词解析和名词之间的关系-CSDN博客 目录 0. 前言 0.1 浅谈维度建模 0.2 数据分析模型 1. 何为数据仓库 1.1 为什么不直接用业务平台的数据而要建设数仓? …

ChatGPT能代替网络作家吗?

最强AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频百万播放量https://aitools.jurilu.com/ 当然可以!只要你玩写作AI玩得6,甚至可以达到某些大神的水平! 看看大神、小白、AI输出内容的区…

重塑企业知识库:AI搜索的深度应用与变革

在数字化浪潮的推动下,企业知识库已成为企业智慧的核心载体。而AI搜索技术的融入,让海量信息瞬间变得井然有序,触手可及。它不仅革新了传统的搜索方式,更开启了企业知识管理的新纪元,引领着企业向更加智能化、高效化的…

【人工智能】FPGA实现人工智能算法硬件加速学习笔记

一. FPGA的优势 FPGA拥有高度的重配置性和并行处理能力,能够同时处理多个运算单元和多个数据并行操作。FPGA与卷积神经网络(CNN)的结合,有助于提升CNN的部署效率和性能。由于FPGA功耗很低的特性进一步增强了其吸引力。此外,FPGA可以根据具体算法需求量身打造硬件加速器。针对动…

[CR]厚云填补_SEGDNet

Structure-transferring edge-enhanced grid dehazing network Abstract 在过去的二十年里,图像去雾问题在计算机视觉界受到了极大的关注。在雾霾条件下,由于空气中水汽和粉尘颗粒的散射,图像的清晰度严重降低,使得许多计算机视觉…

鸿蒙媒体开发【基于AVCodec能力的视频编解码】音频和视频

基于AVCodec能力的视频编解码 介绍 本实例基于AVCodec能力,提供基于视频编解码的视频播放和录制的功能。 视频播放的主要流程是将视频文件通过解封装->解码->送显/播放。视频录制的主要流程是相机采集->编码->封装成mp4文件。 播放支持的原子能力规…

【从0到1进阶Redis】Jedis 操作 Redis

笔记内容来自B站博主《遇见狂神说》:Redis视频链接 Jedis 是一个用于 Java 的 Redis 客户端库,它提供了一组 API 用于与 Redis 数据库进行交互。Redis 是一个高性能的键值存储数据库,广泛用于缓存、消息队列等场景。Jedis 使得 Java 开发者能…

图欧科技-IMYAI智能助手24年5月~7月更新日志大汇总

上一篇推文盘点了我们图欧科技团队近一年来的更新日志,可以说是跟随着人工智能时代的发展,我们的IMYAI也丝毫不落后于这场时代的浪潮!近三个月以来,我们的更新频率直线上升,现在我们AI网站已经成为一个集GPT、Claude、…

《学会 SpringMVC 系列 · 消息转换器 MessageConverters》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

Inno Setup 安装界面、卸载界面+美化

Inno Setup Inno Setup用Delphi写成,其官方网站同时也提供源程序免费下载。它虽不能与Installshield这类恐龙级的安装制作软件相比,但也当之无愧算是后起之秀。Inno Setup是一个免费的安装制作软件,小巧、简便、精美是其最大特点,…

arduino程序—模拟输出(基础知识)

arduino程序—模拟输出(基础知识) 1-25 模拟输出1-analogWrite电路效果演示模拟输出analog output复合运算符示例程序Analogwrite() 1-26 模拟输出2-PWMPWM概念(极其重要) 1-27 模拟输出3-for电路效果演示程…

【Verilog-CBB】开发与验证(2)——单比特信号CDC同步器

引言 多时钟域的设计中,CDC处理的场景还是蛮多的。单比特信号在CDC时,为保证信号采样的安全性,降低亚稳态,必须要对信号做同步处理。CDC从时钟的快慢关系来说分为两种case:快到慢、慢到快。对于脉冲型的控制信号&…

『C++实战项目 负载均衡式在线OJ』一、项目介绍与效果展示(持续更新)

文章目录 一、项目介绍二、开发环境三、第三方库四、相关技术五、项目整体框架代码目录框架 代码仓库连接 点击这里✈ 一、项目介绍 本项目是实现一个仿 leetcode 的 OJ (Online-Judge)系统。更准确的说应该称之为leetcode 的裁剪版。因为本项目只实现了leetcode中…

‘#‘ is not followed by a macro parameter 关于宏定义的错误

今天在项目代码上想定义一个这样的宏,结果编译错误,这个宏定义类似这样的: #define DELETE_FILE_DPP(key) \ #ifdef PLATFORM_DPP \delete_file(&key); \ #endif 因为有平台之分需要用到编译宏,但不想每个调用的地方都写 #i…

HTML 专业词汇与语法规则

目录 1. 专业词汇 2. 语法规则 1. 专业词汇 标签&#xff08;tag&#xff09;&#xff1a;一堆尖叫号&#xff08;<>&#xff09;&#xff0c; 属性&#xff08;attribute&#xff09;&#xff1a;对标签特征设置的方式&#xff1b; 文本&#xff08;text&#xff0…

【外排序】--- 文件归并排序的实现

Welcome to 9ilks Code World (๑•́ ₃ •̀๑) 个人主页: 9ilk (๑•́ ₃ •̀๑) 文章专栏&#xff1a; 数据结构 我们之前学习的八大排序&#xff1a;冒泡&#xff0c;快排&#xff0c;插入&#xff0c;堆排等都是内排序&#xff0c;这些排序算法处理的都是…

java对接kimi详细说明,附完整项目

需求&#xff1a; 使用java封装kimi接口为http接口&#xff0c;并把调用kimi时的传参和返回数据&#xff0c;保存到mysql数据库中 自己记录一下&#xff0c;以做备忘。 具体步骤如下&#xff1a; 1.申请apiKey 访问&#xff1a;Moonshot AI - 开放平台使用手机号手机号验证…

SuccBI+低代码文档中心 — 低代码应用(SuccAP)(概论)

概述&#xff1a; 低代码是什么&#xff1f; 低代码就是通过易用的、可视化的操作、加上少量的代码或脚本的方式快速的搭建业务应用。 低代码的优势&#xff1f; 低代码可以提升开发人员的效率&#xff0c;也可以让非开发人员也能进行应用开发。 低代码的分类&#xff1a;…

基于SpringBoot的大学生信息兼职服务网站系统,源码、部署+讲解

目 录 摘 要 Abstract 目 录 绪 论 1 系统分析 1.1可行性分析 1.1.1经济可行性分析 1.1.2技术可行性分析 1.1.3操作可行性分析 1.2需求分析 1.2.1从学生的角度 1.2.2从企业的角度 1.2.3从管理员的角度 1.3用例建模 1.3.1识别参与者用例 1.3.2用…

3.5 菜单资源

菜单分类 窗口的顶层菜单弹出式菜单&#xff08;鼠标右键的那些选项&#xff0c;记事本窗口左上角点击“文件”弹出的这些&#xff09;系统菜单&#xff08;记事本左上角的图标&#xff09; HMENU类型表示菜单&#xff0c;ID表示菜单项 资源相关 资源脚本文件:*.rc文件编译器…