DETR整体模型结构解析

news2024/11/15 17:42:31

DETR流程

  1. Backbone用卷积神经网络抽特征。最后通过一层1*1卷积转化到d_model维度fm(B,d_model,HW)。

  2. position embedding建立跟fm维度相同的位置编码(B,d_model,HW)。

  3. Transformer Encoder,V为fm,K,Q为fm+position embedding。因为V代表的是图像特征。所以不添加位置编码

  4. Transformer Decoder。生成一个固定大小(query_num)的object query(B,q_num,d_model)比如100个预测框。Decoder输入tgt与object query形状相同。代码中为torch.zero()。第一层selfattention K,V为tgt+query,Q为tgt。第二层Q为上一层输出+query。V为encoder输出,K为encoder输出+position。这里V仍然代表图像特征所以不添加位置编码

  5. 用输出的100个object query框和ground truth框做一个匹配,然后在一一配对好的框中去计算目标检测的loss(分类loss与回归loss(L1+IOU))

  6. 二分图匹配与匈牙利算法

    DETR 预测了一组固定大小的 N = 100 个边界框

    将 ground-truth 也扩展成 N = 100 个检测框

    使用一个额外的特殊类标签 ϕ 来表示在未检测到任何对象,或者认为是背景类别。

    这样预测和真实都是两个100 个元素的集合了

    采用匈牙利算法进行二分图匹配,对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。

  7. 推理过程不需要二分图匹配,只需要取最大得分框即可

代码详细参考:

transformer 在 CV 中的应用(二) DETR 目标检测网络 -

网络结构

参数说明:B:batchsize大小,C通道数,H,W:CNN输出特征图的高宽。d_model设定的特征维度大小如512。
Q,K,V:自注意力矩阵。l_q:Q矩阵的长度,l_kv:K,V矩阵的长度。KV矩阵的长度必须相同,Q矩阵长度可以跟KV矩阵长度不同
Q矩阵维度:(B,l_q,d_model)
K矩阵维度:(B,l_kv,d_model)
V矩阵维度:(B,l_kv,d_model)
object_query维度(B,q_num,d_model)

Backbone:

img→CNNbackbone→fm特征图(B,C,H,W) → fm特征图输入到transformer中时要再经过一层卷积将通道数转化成d_ model。C→d_model.

position embedding(B,d_model,H*W)。backbone通过CNN提取图像特征,然后通过特征图生成尺度对应的位置编码。

position embedding:

位置编码官方实现了两种,一种是固定位置编码,另一种是自学习位置编码,这里就介绍固定位置编码。

位置编码要考虑 x, y 两个方向,图像中任意一个点 (h, w) 有一个位置,这个位置编码长度为 256 ,前 128 维代表 h 的位置编码, 后 128 维代表 w 的位置编码,把这两个 128 维的向量拼接起来就得到一个 256 维的向量,它代表 (h, w) 的位置编码。位置编码的计算公式如下图所示

在这里插入图片描述

Transformer
DETRtransformer结构图
在这里插入图片描述

接受CNN提取的特征(B,d_model,HW),位置编码(B,d_model,HW),querys(B,query_num,d_model)

encoder:q,k添加位置编码。v代表图像本身特征,不添加位置编码。multi_head_attention跟FFN后都带了两个残差连接。

# post代表norm放在后面
def forward_post(self,
                 src,
                 src_mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
                 pos: Optional[Tensor] = None):
    q = k = self.with_pos_embed(src, pos)  #q,k增加position
    src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]
    src = src + self.dropout1(src2)   # 残差
    src = self.norm1(src)
    # ffn
    src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
    # 残差
    src = src + self.dropout2(src2)
    src = self.norm2(src)
    return src

decoder:

设定一个object queries(num_query,d_model)

有两层multihead self attention

  • 第一层obquery添加到K,Q上作为position embedding

第二层的Q来于decoder,K,V来自于encoder输出。

第二层self attention K添加编码,Q增加object queries。V代表图像特征,不添加任何信息

KV要有相同维度,Q可以跟KV在长度维度上不同,d_model维度相同

softmax(QKt/(d^0.5))V→矩阵乘法Q*Kt:(l_q,d_model)@(d_model_l_kv)→(l_q,l_kv)

再乘以V(l_q,l_kv)@(l_kv,d_model)→(l_q,d_model)

def forward_post(self, tgt, memory,
                 tgt_mask: Optional[Tensor] = None,
                 memory_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 memory_key_padding_mask: Optional[Tensor] = None,
                 pos: Optional[Tensor] = None,
                 query_pos: Optional[Tensor] = None):
    '''
    :param tgt: query_pos rep 2tensor of shape (bs, c, h, w) ->tgt = torch.zeros_like(query_embed)
    :param memory:
    :param tgt_mask:
    :param memory_mask:
    :param tgt_key_padding_mask:
    :param memory_key_padding_mask:
    :param pos:
    :param query_pos:
    :return:
    '''
    q = k = self.with_pos_embed(tgt, query_pos)
    tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                          key_padding_mask=tgt_key_padding_mask)[0]
    tgt = tgt + self.dropout1(tgt2)
    tgt = self.norm1(tgt)
    tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                               key=self.with_pos_embed(memory, pos),
                               value=memory, attn_mask=memory_mask,
                               key_padding_mask=memory_key_padding_mask)[0]
    tgt = tgt + self.dropout2(tgt2)
    tgt = self.norm2(tgt)
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
    tgt = tgt + self.dropout3(tgt2)
    tgt = self.norm3(tgt)
    return tgt

  • DETR在计算attention的时候没有使用masked attention,因为将特征图展开成一维以后,所有像素都可能是互相关联的,因此没必要规定mask。

  • object queries的转换过程:object queries是预定义的目标查询的个数,代码中默认为100。它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。最终预测得到的shape应该为[N, 100, C],N为Batch Num,100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)

得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。

匹配

匈牙利匹配算法

匈牙利匹配算法,二分图匹配算法

scipy.optimize.linear_sum_assignment(cost_matrix, maximize=False)
#cost_matrix 二分图开销矩阵

https://blog.csdn.net/CV_Autobot/article/details/129096035

https://blog.csdn.net/lemonxiaoxiao/article/details/108672039

query与gt匹配

transformer通过query输出n_q数量的bbox与对应分类置信度

真实框[gt1,gt2,…gtn]

每个bbox与gt之间有一个距离度量。

距离度量由三部分组成:真实类别的置信度得分+边界框的L1loss+边界框的IOUloss

通过匈牙利算法找出距离最小的query_bbox为gt对应的prebbox

loss训练

整体流程

pred输出→100(num_queries)class,100(num_queries)boxes

gt(tagert)→100class,100boxes(包含背景类)

pred,gt→计算相互loss,得到二分图成本矩阵,然后计算匈牙利匹配算法→return匹配上的classes与boxes

匹配成功的框→计算真正的class损失(),box回归损失(GLOUloss)。。

预测框与真实框的差异来自于两方面:1.二分图匹配时带来的差异。2.预测框与真实框之间的差异。

  • 分类损失:交叉熵损失,针对所有predictions。没有匹配到的querybbox应该分类为背景

  • 回归损失:bbox loss采用了L1 loss和giou loss,针对匹配成功的querybbox

  • cardinality 损失,对应函数是 loss_cardinality ; cardinality 损失是计算预测有物体的个数的绝对损失,值是为了记录,不参与反向传播

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1714238.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QT基础初学

目录 1.什么是QT 2.环境搭建 QT SDK的下载 QT的使用 QT构建项目 快捷指令 QT的简单编写 对象树 编码问题 组件 初识信号槽 窗口的释放 窗口坐标体系 1.什么是QT QT 是一个跨平台的 C 图形用户界面库,支持多个系统,用于开发具有图形界面的应…

File name ‘xxxx‘ differs from already included file name ‘xxxx‘ only in casing.

一、报错信息 VSCode报错如下: File name ‘d:/object/oral-data-management/src/components/VisitLogPopup/Info.vue’ differs from already included file name ‘d:/object/oral-data-management/src/components/VisitLogPopup/INfo.vue’ only in casing. The…

AI企业需要“联盟营销”?一文带你探索AI企业营销新玩法!

为什么联盟营销对AI业务有较大优势 联盟营销在电商领域、saas领域与其他产品领域同样有效。在AI业务中,它有效的原因与其他领域大不相同。 高好奇心和试用率 AI领域是创新的热点。它吸引了一群渴望探索和尝试每一项新技术的人群。这种蓬勃的好奇心为聪明的AI企业提…

大模型助力企业提效,九章云极DataCanvas公司联合腾讯搜狗输入法发布私有化解决方案

近日,九章云极DataCanvas公司与腾讯搜狗输入法的合作再次升级。在搜狗输入法开发者中心正式推出之际,九章云极DataCanvas公司作为搜狗输入法的首批开发合作伙伴,双方联合发布“企业知识管理助手”私有化解决方案。 “企业知识管理助手”整体私…

AI虚拟试穿革命:I2VEdit技术引领电商视频内容创新

在当今快速迭代的电子商务领域,用户体验与内容创新是企业竞争力的核心要素。随着AI技术的飞速进步,AI虚拟试穿已不再局限于静态图像,而是迈向了动态视频的新纪元。本文将深入解析一项革新性技术——I2VEdit,如何以其独到之处,为电商尤其是服装零售行业带来一场内容创作与产…

Opencv图像处理技术(图像轮廓)

1图像轮廓概念: 图像轮廓是指图像中连续的像素边界,这些边界通常代表了图像中的物体或者物体的边缘。在数字图像处理中,轮廓是由相同像素值组成的曲线,它们连接相同的颜色或灰度值,并且具有连续性。轮廓可以用来描述和…

【几何】输入0-360度任意的角度,求上面直线与椭圆相切点的坐标计算公式

输入0-360度任意的角度,求上面直线与椭圆相切点的坐标计算公式 使用积分计算 使用到的公式有椭圆公式: x 2 a 2 + y 2 b 2 = 1 \frac{x^2}{a^2}+\frac{y^2}{b^2} = 1 a2x2​+b2y2​=1 平面旋转公式 X r = cos ⁡ θ ∗ ( X s − X O ) − sin ⁡ θ ∗ ( Y s − Y O ) + X …

文心智能体平台:快来创建你的Java学习小助理,全方位辅助学习

文章目录 一、文心智能体平台1.1平台介绍1.2智能体介绍 二、智能体创建三、体验与总结 一、文心智能体平台 文心智能体平台是百度推出的基于文心大模型的智能体(Agent)平台,支持广大开发者根据自身行业领域、应用场景,选取不同类…

Three.js 中的场景与相机基础

Three.js 中的场景与相机基础 一、场景(Scene) 在 Three.js 中,场景是所有 3D 对象存在和交互的容器。艾斯视觉作为行业ui设计与前端开发服务商很高兴能在这里与你共同探讨:它就像是一个虚拟的 3D 空间,我们可以在其中…

[vue3后台管理一]vue3下载安装及环境配置教程

[vue3后台管理二]vue3下载安装element plus 一、vue3下载安装element plus cnpm install element plus二:修改main.js import { createApp } from "vue"; import App from "./App.vue"; import ElementPlus from "element-plus"; …

好用的国产大文件传输软件有哪些,快来看看吧

在这个数字化飞速发展的时代,我们每天都在与各种文件打交道,从简单的文档到庞大的视频素材,文件的体积越来越大,传统的文件传输方式逐渐显得力不从心。面对这个挑战,大文件传输软件应运而生,它们不仅解决了…

Khoj:开源个人AI助手能连接你的在线和本地文档充当你的第二大脑

项目简介 Khoj是一个开源的、个人化的AI助手,旨在充当你的第二大脑。它能够帮助你回答任何问题,不论这些问题是在线上的还是在你自己的笔记中。Khoi 支持使用在线AI模型(例如 GPT-4)或私有、本地的语言模型(例如 Llama3)。你可以选择自托管 Khoj&#x…

自定义tabbar

一: 在这里配置true,则会找根目录下的文件,作为路由渲染到每个tabbar页面下.原本是默认情况,自定义内部路由,放到每个tabbar页面下. 二: 上面做到了,每个tabbar页面的底部都是这个组件. 三: Tabbar 标签栏 - Vant Weapp (youzan.github.io) 使用组件库,配置组件内容. <va…

金融反欺诈指南:车险欺诈为何如此猖獗?

目录 车险欺诈猖獗的原因 车险欺诈的识别难点 多重合作打击车险欺诈 保险企业需要提升反欺诈能力 监管部门需要加强协同合作 青岛市人民检察院在其官方微信公众号上发布的梁某保险诈骗案显示&#xff0c;2020 年以来&#xff0c;某汽修厂负责人梁某、某汽车服务公司负责人孙某&…

pyqt Qtreeview分层控件

pyqt Qtreeview分层控件 介绍效果代码 介绍 QTreeView 是 PyQt中的一个控件&#xff0c;它用于展示分层数据&#xff0c;如目录结构、文件系统等。QTreeView 通常与模型&#xff08;如 QStandardItemModel、QFileSystemModel 或自定义模型&#xff09;一起使用&#xff0c;以管…

2024 GIAC 全球互联网架构大会:拓数派向量数据库 PieCloudVector 架构设计与案例实践

5月24-25日&#xff0c;msup 和高可用架构联合举办了第11届 GIAC 全球互联网架构大会。会议聚焦“共话AI技术的最新进展、架构实践和未来趋势”主题&#xff0c;邀请了 100 余位行业内的领军人物和革新者&#xff0c;分享”Agent/RAG 技术、云原生、基座大模型“等多个热门技术…

07 FreeRTOS 事件组(event group)

1、事件组概念 1.1 基本概念 使用事件组可以等待某个事件、若干事件中的任意一个事件、若干事件中的所有事件&#xff0c;但是不能指定若干事件中的某些事件。 事件组可以简单地认为就是一个整数&#xff1a;这个整数的每一位表示一个事件&#xff1b;每一位事件的含义由程序员…

真实测评:9款电脑加密软件最新排名

2024年已经过半&#xff0c;电脑加密软件市场发生了很多变化&#xff0c;根据资料汇总&#xff0c;一些电脑加密软件排名也发生了变化&#xff0c;下面是最近的排名。 1、安企神&#xff1a; 可以试试7天的免费试用&#xff0c;用过之后就回不去了 试用版https://work.weix…

【408真题】2009-26

“接”是针对题目进行必要的分析&#xff0c;比较简略&#xff1b; “化”是对题目中所涉及到的知识点进行详细解释&#xff1b; “发”是对此题型的解题套路总结&#xff0c;并结合历年真题或者典型例题进行运用。 涉及到的知识全部来源于王道各科教材&#xff08;2025版&…

29-ESP32-S3-WIFI_Driver-00 STA模式扫描全部 AP

ESP32-S3 WIFI_Driver 引言 ESP32-S3是一款集成了Wi-Fi和蓝牙功能的芯片。关于WIFI的部分&#xff0c;其实内容比我想象的要多得多。所以通常来说&#xff0c;如果你想要编写自己的Wi-Fi应用程序&#xff0c;最快捷的方法就是先找一个类似的示例应用&#xff0c;然后将它的相…