【深入理解DETR】DETR的原理与算法实现

news2024/11/29 20:51:35

1 DETR算法概述

在这里插入图片描述
在这里插入图片描述
①端到端
②Transformer-model

之前的方法都需要进行NMS操作去掉冗余的bounding box或者手工设计anchor, 这就需要了解先验知识,增加从超参数anchor的数量,

1.1 训练测试框架

一次从图像中预测n个object的类别

在这里插入图片描述

训练阶段我们将一张图像喂入DETR模型,会得到100个bounding box,并且得到这些预测框的类别信息和坐标信息
100个是超参数,因为大部分的图像中的object的数量不会超过100个
通过label我们知道图像中有2个object
然后使用匈牙利算法从预测出的100个候选框中筛选出2个预测框,与两个标注框一起计算损失,然后反向传播,优化模型参数

在这里插入图片描述

测试阶段:通过网络预测出100个预测框,把这100个预测框的置信度去和阈值进行比较,大于阈值的预测框保留。

这样在DETR里面是没有用到anchor也没有NMS操作的

算法的两个重点:一是基于集合的全局损失,通过二分类匹配得到与标注框匹配的独一无二的损失;二是引入encoder-decoder框架,

在这里插入图片描述

object queries是可学习的参数,通过他的尺寸指定输出的预测框的个数,在transforme中输出的token个数是等于输出的token个数,

没有固定的框架:只要框架能够支持这些,就能支持DETR
CNN+位置编码+encoder-decoder+MLP

2 DETR模型结构讲解

inference

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
①图像预处理,输入(batch_size,3,800,1066)
②经过CNN的backbone,得到feature map是(batch_size,2048,25,34),下采样了32倍,channel数是2048
③特征图再经过一个1x1的卷积,输入的通道数是2048,输出的通道数是256,这个卷积层的目的就是减少channel数,输出(batch_size,256,25,34)
④维度flatten,得到(batch_size,256,850)
⑤再把维度调换一下,得到(850,batch_size,256),850就是后面transformer的token的个数,256就是每个token的特征向量的长度
⑥特征图(850,batch_size,256)和位置编码都要传入encoder中,并且位置编码需要在每个多头自注意力层里都要加到key和query上,这就和标准的transformer不一样了。对比标准的transformer结构,位置编码是直接加到输入上的,但是DETR的encoder的位置编码,在每个堆叠的encoder-decoder中都要使用位置编码
⑦query的初始值是0,(100,256),object query也是(100,256),encoder的输出包含了图像提取的全局信息,通过两个检测头得到预测框的坐标和类别

在这里插入图片描述
⑧decoder的下面部分可以理解为在学习anchor特征 ,decoder的上面部分可以理解为在得到encoder输出的全局信息后,以及anchor的特征基础上,学习和预测bounding box的坐标和目标的类别

在这里插入图片描述

代码

在这里插入图片描述

输入包括了两个参数:①src:从backbone里面得到的 image features ②pos 就是位置编码

两种位置编码方法:
在这里插入图片描述
在这里插入图片描述
可以二选一

在这里插入图片描述
src做dropout和跨层连接,模拟resnet,
src2 是FFN层 再经过relu

在这里插入图片描述
在这里插入图片描述

decoder:
参数:①tgt:queries (100,256) ②memory:就是encoder的输出 (850,batch_size,256) ③pos:位置编码 (850,batch_size,256)

④query_pos:就是Object queries (100,256)

①首先用with_pos_embed将queries和Object queries相加得到k,q,v就是queries
②然后对q,k,v进行Self-attention操作
③dropout和残差
④linear_norm1,覆盖tgt
⑤下一个query等于tgt加上Object queries,下一个k等于encoder输出的memory加上位置编码,下一个v就等于encoder输出的memory,再进行Multi-head Self-attention,得到tgt2
⑥dropout和残差
⑦linear_norm2,覆盖tgt
⑧FFN层包括一个全连接层,一个relu激活层,一个dropout,一个全连接层,输出tgt2
⑨dropout和残差
⑩linear_norm3

最终输出 (batch_size,100,256)
训练阶段是(6,batch_size,100,256)
因为堆叠了6个encoder-decoder,一次得到了6个

在这里插入图片描述
两个检测头,分别预测类别和bounding box的坐标
检测类别的FFN只是一个全连接层,92是因为coco数据集有91个类别,再加一个背景类别
检测bounding box的坐标的FFN是一个MLP,包括3个全连接层,前两个全连接层的输入和输出尺寸都是256,第3个的输入是256,输出是4,4是bounding box的(x,y,w,h),因为是需要相对坐标,所以做一个sigmoid归一化(0,1)

在测试阶段,设置一个类别置信度阈值,对于100个bounding box取置信度最大的那个类,作为bounding box的类别,

3 DETR损失函数

在这里插入图片描述
训练阶段能从网络中得到输出:是一个字典,包括了3个部分,

pred_logits和pred_boxes是decoder输出的类别预测和坐标预测结果(batch_size,100,92)和(batch_size,100,4)
batch_size这里被设置为2,aux_outputs是decoder的5个中间层的输出结果,中间层的输出和最终的decoder的检测头是一样的

在这里插入图片描述
在这里插入图片描述
要往矩阵中填的是预测框与真实的损失,其中包括两个部分,前半部分是类别损失,后半部分是坐标损失, c i c_i ci不为空,表示不计算背景的损失

在这里插入图片描述
outputs是预测值,targets是标注值,先把outputs中的预测类别提取出来,即out_prob(2,100,92) 2是batch_size,100是100个预测框,92是类别,flatten为(200,92)
第62行把标注里面的类别取出来,可以看到第一张图中有两个类别,分别是第82和第79个类别;第二张图中有4个类别,分别是第1、1、34、1个类,

第68行:要从预测的200个bounding box中提取出对应的损失,绿色和紫色分别表示第1和2张图中的类别损失,取负号就是公式的前半部分

在这里插入图片描述

匈牙利算法损失的第二部分是用来给bounding box打分的,传统的L1损失会存在问题:对于不同尺度的box计算的损失是相似的,为了缓解这一问题,采用L1损失和GIoU损失的线性结合,

在这里插入图片描述
在这里插入图片描述
第59行:从预测结果中提出坐标部分,(2,100,4),flatten成(200,4)
第63行:从targets中提出两张图像的标注坐标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1450025.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构——顺序表专题

目录 1. 数据结构的相关概念什么是数据结构为什么需要数据结构? 2. 顺序表顺序表的概念及结构顺序表分类静态顺序表动态顺序表 3. 动态顺序表的实现准备工作顺序表的初始化顺序表的扩容尾插头插尾删头删指定位置插入数据指定位置删除数据 4. 全部完整代码**test.c**…

构建智慧交通平台:架构设计与实现

随着城市交通的不断发展和智能化技术的迅速进步,智慧交通平台作为提升城市交通管理效率和水平的重要手段备受关注。本文将探讨如何设计和实现智慧交通平台的系统架构,以应对日益增长的城市交通需求,并提高交通管理的智能化水平。 ### 1. 智慧…

中小学信息学奥赛CSP-J认证 CCF非专业级别软件能力认证-入门组初赛模拟题第二套(阅读程序题)

CSP-J入门组初赛模拟题二 二、阅读程序题 (程序输入不超过数组或字符串定义的范围&#xff0c;判断题正确填√错误填X;除特殊说明外&#xff0c;判断题 1.5分&#xff0c;选择题3分&#xff0c;共计40分) 第一题 1 #include<bits/stdc.h> 2 using namespace std; 3 i…

算法——数论——快速幂

目录 快速幂 费马小定理 一、试题 算法训练 A的B的C次方次方 快速幂 快速幂是一种用于快速计算幂运算的算法。计算复杂度 O(log n)基本思想是利用指数 n 的二进制展开形式&#xff0c;将 转化为多个 a 的幂的乘积&#xff0c;然后通过迭代快速计算。 快速幂的示例代码&…

鸿蒙开发系列教程(二十一)--轮播处理

轮播处理 Swiper本身是一个容器组件&#xff0c;当设置了多个子组件后&#xff0c;可以对这些子组件进行轮播显示 在自身尺寸属性未被设置时&#xff0c;会自动根据子组件的大小设置自身的尺寸 参数&#xff1a; 通过loop属性控制是否循环播放&#xff0c;该属性默认值为tr…

[GXYCTF2019]禁止套娃

进来发现只有这句话&#xff0c;习惯性访问一下flag.php&#xff0c;发现不是404&#xff0c;那就证明flag就在这了&#xff0c;接下来要想办法拿到flag.php的源码。 这道题是.git文件泄露网页源码&#xff0c;githack拿到index.php源码 这里观察到多次判断&#xff0c;首先要…

【C/C++语法基础】2.输入与输出(✨新手推荐阅读)

前言 在C中&#xff0c;输入与输出是程序与用户进行交互的基本方式。C提供了多种方式进行数据的输入与输出&#xff0c;其中最常用的是printf、scanf、cin和cout。此外&#xff0c;我们还会讨论如何取消cin和cout的同步流&#xff0c;以及了解各种转义字符的用法。 1.printf函…

算法学习——LeetCode力扣回溯篇3

算法学习——LeetCode力扣回溯篇3 491. 非递减子序列 491. 非递减子序列 - 力扣&#xff08;LeetCode&#xff09; 描述 给你一个整数数组 nums &#xff0c;找出并返回所有该数组中不同的递增子序列&#xff0c;递增子序列中 至少有两个元素 。你可以按 任意顺序 返回答案。…

第23讲 微信用户管理实现

package com.java1234.entity;import com.baomidou.mybatisplus.annotation.TableName; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import lombok.Data;import java.io.Serializable; import java.util.Date;/*** 微信用户信息实体* author java1234_小…

docker (三)-开箱即用常用命令

一 docker架构 拉取镜像仓库中的镜像到本地&#xff0c;镜像运行产生一个容器 registry 镜像仓库 registry可以理解为镜像仓库&#xff0c;用于保存docker image。 Docker Hub 是docker官方的镜像仓库&#xff0c;docker命令默认从docker hub中拉取镜像。我们也可以搭建自己…

【学网攻】 第(28)节 -- OSPF虚链路

系列文章目录 目录 系列文章目录 文章目录 前言 一、什么是OSPF虚链路&#xff1f; 二、实验 1.引入 实验目标 实验背景 技术原理 实验步骤 实验设备 实验拓扑图 实验配置 扩展 实验拓扑图 实验配置 实验验证 文章目录 【学网攻】 第(1)节 -- 认识网络【学网攻…

mac电脑上使用android studio创建flutter项目

mac电脑环境配置可以看这篇文章&#xff1a;https://xiaoshen.blog.csdn.net/article/details/136068650 配置玩环境之后&#xff0c;开始创建第一个flutter项目&#xff1a;点击new flutter project或者new project都可以 然后选择flutter&#xff1a; 并将sdk配置为解压后的…

Linux——网络通信TCP通信常用的接口和tco服务demo

文章目录 TCP通信所需要的套接字socket()bind()listen()acceptconnect() 封装TCP socket TCP通信所需要的套接字 socket() socket()函数主要作用是返回一个描述符&#xff0c;他的作用就是打开一个网络通讯端口&#xff0c;返回的这个描述符其实就可以理解为一个文件描述符&a…

【数位dp】【动态规划】【状态压缩】【推荐】1012. 至少有 1 位重复的数字

作者推荐 视频算法专题 本文涉及知识点 动态规划汇总 LeetCode:1012. 至少有 1 位重复的数字 给定正整数 n&#xff0c;返回在 [1, n] 范围内具有 至少 1 位 重复数字的正整数的个数。 示例 1&#xff1a; 输入&#xff1a;n 20 输出&#xff1a;1 解释&#xff1a;具有至…

Vue.js2+Cesium1.103.0 十五、计算方位角

Vue.js2Cesium1.103.0 十五、计算方位角 Demo <template><divid"cesium-container"style"width: 100%; height: 100%;"/> </template><script> /* eslint-disable no-undef */ /* eslint-disable new-cap */ /* eslint-disable n…

Java学习第十四节之冒泡排序

冒泡排序 package array;import java.util.Arrays;//冒泡排序 //1.比较数组中&#xff0c;两个相邻的元素&#xff0c;如果第一个数比第二个数大&#xff0c;我们就交换他们的位置 //2.每一次比较&#xff0c;都会产生出一个最大&#xff0c;或者最小的数字 //3.下一轮则可以少…

《Git 简易速速上手小册》第10章:未来趋势与扩展阅读(2024 最新版)

文章目录 10.1 Git 与开源社区10.1.1 基础知识讲解10.1.2 重点案例&#xff1a;Python 社区使用 Git10.1.3 拓展案例 1&#xff1a;Git 在大型开源项目中的角色10.1.4 拓展案例 2&#xff1a;支持开源项目的 Git 托管平台 10.2 新兴技术与 Git 的整合10.2.1 基础知识讲解10.2.2…

【机器学习笔记】5 机器学习实践

数据集划分 子集划分 训练集&#xff08;Training Set&#xff09;&#xff1a;帮助我们训练模型&#xff0c;简单的说就是通过训练集的数据让我们确定拟合曲线的参数。 验证集&#xff08;Validation Set&#xff09;&#xff1a;也叫做开发集&#xff08; Dev Set &#xf…

【蓝桥杯单片机入门记录】LED灯(附多个例程)

目录 一、LED灯概述 1.1 LED发光原理 1.2电路原理图 1.3电路实物图 1.4 开发板LED灯原理图 1.4.1共阳极LED灯操控原理&#xff08;本开发板&#xff09; &#xff08;非实际原理图&#xff0c;便于理解版本&#xff09;由图可以看出&#xff0c;每个LED灯的左边&#xf…

深度理解实分析:超越公式与算法的学习方法

在数学的学习旅程中&#xff0c;微积分和线性代数为许多学生提供了直观且具体的入门体验。它们通常依赖于明确的公式、算法以及解题步骤&#xff0c;而这些元素往往可以通过记忆和机械练习来掌握。然而&#xff0c;当我们迈入实分析的领域时&#xff0c;我们面临着一种全新的挑…