RT-DETR详解之 Decoder 层

news2024/11/24 10:32:05

在上一篇博客中,博主已经讲解了如何利用Uncertainty-minimal Query Selection选择出好的特征,接下来便要将这些特征输入到Decoder中进行解码,需要注意的是,在RT-DETR的Encoder中,使用的是标准的自注意力计算方法,而在其Decoder中,则使用的是可变形自注意力(deformable attention),可变形自注意力能够大幅的降低计算量,同时该部分还使用到了CUDA算子,能够加快运行速度,当然,这个可变形自注意力计算并非是RT-DETR的创新点,但其作用却是极大,在DINO,DN-Deformable-DETR中都有使用。

在这里插入图片描述

关于Deformable-DETR,博主曾经介绍过,大家如果有兴趣可以参考博主这篇博文:

Deformable DETR模型学习记录

Decoder参数

输入Decoder的参数如下:

out_bboxes, out_logits = self.decoder(
            target,
            init_ref_points_unact,
            memory,
            spatial_shapes,
            level_start_index,
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask)

target 是查询向量添加噪声以及查询向量筛选后的特征向量,即498=198+30

在这里插入图片描述

init_ref_point_unct 是参考点的xywhAnchor

在这里插入图片描述

memoryEncoder输出的特征向量

在这里插入图片描述

spatial_shapesEncoder输出的三个特征图的维度

在这里插入图片描述

记录每个特征图开始的索引(已将特征图展平)

在这里插入图片描述

attn_mask 特征图掩膜

在这里插入图片描述

query_pos_head的结构如下:

MLP(
  (layers): ModuleList(
    (0): Linear(in_features=4, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
  )
  (act): ReLU(inplace=True)
)

TransformerDecoderLayer的结构如下:

ModuleList(
  (0-2): 3 x TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (dropout1): Dropout(p=0.0, inplace=False)
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (cross_attn): MSDeformableAttention(
      (sampling_offsets): Linear(in_features=256, out_features=192, bias=True)
      (attention_weights): Linear(in_features=256, out_features=96, bias=True)
      (value_proj): Linear(in_features=256, out_features=256, bias=True)
      (output_proj): Linear(in_features=256, out_features=256, bias=True)
    )
    (dropout2): Dropout(p=0.0, inplace=False)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (linear1): Linear(in_features=256, out_features=1024, bias=True)
    (dropout3): Dropout(p=0.0, inplace=False)
    (linear2): Linear(in_features=1024, out_features=256, bias=True)
    (dropout4): Dropout(p=0.0, inplace=False)
    (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
)

多层Decoder Layer(TransformerDecoder)

多层DecoderLayer的操作如下:

def forward(self,
                tgt,
                ref_points_unact,
                memory,
                memory_spatial_shapes,
                memory_level_start_index,
                bbox_head,
                score_head,
                query_pos_head,
                attn_mask=None,
                memory_mask=None):
        output = tgt
        dec_out_bboxes = []
        dec_out_logits = []
        ref_points_detach = F.sigmoid(ref_points_unact)

        for i, layer in enumerate(self.layers):
            ref_points_input = ref_points_detach.unsqueeze(2)
            query_pos_embed = query_pos_head(ref_points_detach)

            output = layer(output, ref_points_input, memory,
                           memory_spatial_shapes, memory_level_start_index,
                           attn_mask, memory_mask, query_pos_embed)

            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))

            if self.training:
                dec_out_logits.append(score_head[i](output))
                if i == 0:
                    dec_out_bboxes.append(inter_ref_bbox)
                else:
                    dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))

            elif i == self.eval_idx:
                dec_out_logits.append(score_head[i](output))
                dec_out_bboxes.append(inter_ref_bbox)
                break

            ref_points = inter_ref_bbox
            ref_points_detach = inter_ref_bbox.detach(
            ) if self.training else inter_ref_bbox

        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)

其中,核心代码如下:
在这里插入图片描述

单层DecoderLayer(TransformerDecoderLayer)

在该部分中,数据输入单层DecoderLayer后执行的操作如下:

在这里插入图片描述

可变形注意力计算模型(MSDeformableAttention)

可变形注意力模块构造如下:

MSDeformableAttention(
  (sampling_offsets): Linear(in_features=256, out_features=192, bias=True)
  (attention_weights): Linear(in_features=256, out_features=96, bias=True)
  (value_proj): Linear(in_features=256, out_features=256, bias=True)
  (output_proj): Linear(in_features=256, out_features=256, bias=True)
)

我们对照着Deformable-DETR的结构图来观察一下输入参数,首先是Query Feature,其对应的参数是self.with_pos_embed(tgt, query_pos_embed)Reference Point 的维度为torch.Size([4, 498, 1, 4]),在计算时,我们只选用中心点坐标即可,Input Feature Maps对应的是memory,即Encoder输出的特征图。
在这里插入图片描述
关于这个过程的代码,我就不在此一一赘述了,我们只需知道最终得到的结果即可。
最终得到可变形交叉注意力的计算结果如下:

在这里插入图片描述

单层DecoderLayer结果

将可变形自注意力计算结果拿到后,便是一系列normal等操作,最后返回单层DecoderLayer的结果:

在这里插入图片描述

这个结果会进行如下操作:

在这里插入图片描述

多层DecoderLayer结果

上述过程是在循环里,代码中有3层,经过多层DecodeLayer计算后,最终得到输出的分类结果与回归结果,将其返回,该部分完整代码如下:

out_bboxes, out_logits = self.decoder(
            target,
            init_ref_points_unact,
            memory,
            spatial_shapes,
            level_start_index,
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask)

在这里插入图片描述

最终 Decoder 模块的输出结果如下:

pred_logits:300个特征向量产生的分类结果
pred_boxes:300个特征向量产生的Anchor
aux_outputs,每个Decoder层的结果,因为Decoder中有3层,因此其采用 list 形式存储,每个list中的结果如下:

在这里插入图片描述

dn_aux_outputs为每层Decoder加噪查询向量输出结果

在这里插入图片描述

此外,还有加噪向量

在这里插入图片描述

最终完成了Decoder的计算,接下来便是通过匈牙利匹配方法来匹配预测结果与目标了,同时进行损失计算。

在这里插入图片描述

可变形注意力分值计算方法如下:

def deformable_attention_core_func(value, value_spatial_shapes, sampling_locations, attention_weights):
    """
    Args:
        value (Tensor): [bs, value_length, n_head, c]
        value_spatial_shapes (Tensor|List): [n_levels, 2]
        value_level_start_index (Tensor|List): [n_levels]
        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]

    Returns:
        output (Tensor): [bs, Length_{query}, C]
    """
    bs, _, n_head, c = value.shape
    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape

    split_shape = [h * w for h, w in value_spatial_shapes]
    value_list = value.split(split_shape, dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[level].flatten(2).permute(
            0, 2, 1).reshape(bs * n_head, c, h, w)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].permute(
            0, 2, 1, 3, 4).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(
        bs * n_head, 1, Len_q, n_levels * n_points)
    output = (torch.stack(
        sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).reshape(bs, n_head * c, Len_q)

    return output.permute(0, 2, 1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1813501.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【多重背包 动态规划】2585. 获得分数的方法数

本文涉及知识点 动态规划汇总 背包问题汇总 C算法:前缀和、前缀乘积、前缀异或的原理、源码及测试用例 包括课程视频 LeetCode2585. 获得分数的方法数 考试中有 n 种类型的题目。给你一个整数 target 和一个下标从 0 开始的二维整数数组 types ,其中 …

TCP三次握手和四次挥手过程简介

接上篇 传输层部分 链路层、网络层、传输层和应用层协议详解分析-CSDN博客文章浏览阅读689次,点赞10次,收藏15次。wireshark抓包分析-CSDN博客wireshark是网络包分析工具网络包分析工具的主要作用是尝试捕获网络包,并尝试显示包的尽可能详细…

bugku---misc---贝斯手

1、下载附件,解压之后得到下面文件 2、zip需要密码,但是介绍里面给出了提示 3、再结合图片,是古力娜扎,搜索了以下她的生日是1992。应该就是密码 4、破解flag.zip得到一段文本 5、结合题目描述说的贝斯手,猜测应该是b…

win11 默认程序中找不到typora 和设置typora为md的默认打开程序

1.找到任意一个.md文件 2.在任意一个.md文件的上面右键,点击--》打开方式--》选择其他应用--》在电脑上选择应用--》(如果列表中没有,拉到最下面)在电脑上选择应用--》弹出文件浏览框 3.找到安装typora的时候的exe文件&#xff0c…

Flutter打包网络问题解决办法

问题情况":app:compileReleaseJavaWithJavac" 报错的最主要问题其实在下一句 Failed to find Build Tools revision 30.0.3,请查看自己的Android sdk版本,比如我的就是’34.0.0’版本. 解决办法: 在app/build.gradle中的android下添加,即可 buildToolsVersion 3…

【C++课程学习】:类和对象(拷贝构造和运算符重载)

🎁个人主页:我们的五年 🔍系列专栏:C课程学习 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 ✍拷贝构造: 🍉特点一: 🍉特点二: &…

实体类status属性使用枚举类型的步骤

1. 问题引出 当实体类的状态属性为Integer类型时,容易写错 2. 初步修改 把状态属性强制为某个类型,并且自定义一些可供选择的常量。 public class LessonStatus {public static final LessonStatus NOT_LEARNED new LessonStatus(0,"未学习"…

高边坡监测规范:确保边坡安全的科学准则

随着土木工程建设的不断发展,高边坡作为常见的土方边坡形式,其安全问题日益受到人们的关注。高边坡监测规范作为保障边坡安全的重要手段,对于预防边坡滑坡、坍塌等地质灾害具有重要意义。本文将对高边坡监测规范进行深入探讨,以期…

微信小程序组件传值

虽然微信小程序是比较轻量的,但是还是拥有组件的 这是文件的基本目录 我们的代码基本都在pages和components文件夹中 在component中创建组件 在component中 ,创建一个目录 我创建了一个 head目录 用于配置头部信息 我在这里创建了 一个头部组件&…

程序员,真有不变的技术和稳定的工作吗?

在程序员这个充满变化和创新的领域,很多人追求“稳定”的工作,认为找到一个合适的公司和岗位就能安心一辈子。然而,技术的快速更新迭代和市场需求的不断变化,使得真正的稳定变得越来越难以捉摸。作为程序员,我们需要反…

Java虚拟机 - JVM(类加载器、类加载的过程、双亲委派模型、GC垃圾回收)

一、JVM中数据内存区域划分 本地方法栈:里面保存的是native 关键字的方法,不是用Java写的,而是jvm内部用c实现的。 **程序计数器 和 虚拟机栈 每个线程都存在一份。 如果一个 JVM 进程 中有 10个 线程,那么就会存在 10份 程序计数…

通过文本指令生成3D模型纹理贴图

在3D建模的广阔领域中,我们总是追求更高效、更直观的方法来创建和编辑模型。今天,我要向大家介绍一种革新性的技术,它能够通过文本指令来精确地控制3D模型的细节,包括纹理贴图的生成。 1. 技术定位 这项技术主要定位于交互式3D建模领域,它为用户提供了一种全新的方式来创…

哈喽GPT-4o——对GPT-4o Prompt的思考与看法

目录 一、提示词二、提示词的优势1、提升理解能力2、增强专注力3、提高效率 三、什么样的算无效提示词?1、过于宽泛2、含糊不清3、太过复杂4、没有具体上下文5、缺乏明确目标6、过于开放7、使用专业术语但未定义8、缺乏相关性: 四、提示词正确的编写步骤…

Android 应用加固与重签名—使用AndroidStudio自带工具 apksigner

由 AndroidStudio 生成的release版本的app有自己的签名,但当应用加固后会删除原签名,需要重新签名。 一、加固方式: 使用基础版的腾讯云(乐固)进行免费加固,上传软件后等待在线加固完成后下载 apk 即可。…

设置SSHkeys多服务器免登录配置(ssh config)

一、背景: 多邮箱或者多git账号进行同一台电脑开发的情况。 有时候,开发时可能会面临一个情况,就是通过自己的电脑,可能同时需要开发多个不同地方的项目,或者说,自己建立的项目已经配置好SSH验证免密登录&a…

Qt C++ TCP服务端响应多客户端通讯

本示例使用的设备&#xff1a;WIFI无线4G网络RFID云读卡器远程网络开关物流网阅读器TTS语音-淘宝网 (taobao.com) #include "mainwindow.h" #include "ui_mainwindow.h" #include "QMessageBox" #include <QDebug> #include <exceptio…

【春秋云镜】Faculty Evaluation System未授权任意文件上传漏洞(CVE-2023-33440)

因为该靶场没有Write up,索性自己搞一下&#xff0c;方便别人&#xff0c;快乐自己&#xff01; 漏洞概述&#xff1a; Sourcecodester Faculty Evaluation System v1.0 is vulnerable to arbitrary code execution via /eval/ajax.php?actionsave_user. 漏洞复现&#xff…

Halcon 多相机统一坐标系

小杨说事-基于Halcon的多相机坐标系统一原理个人理解_多相机标定统一坐标系-CSDN博客 一、概述 最近在搞多相机标定等的相关问题&#xff0c;对于很大的场景&#xff0c;单个相机的视野是不够的&#xff0c;就必须要统一到一个坐标系下&#xff0c;因此我也用了4个相机&#…

SpringBoot Vue Bootstrap 旅游管理系统

SpringBoot Vue 旅游管理系统源码&#xff0c;附带环境安装&#xff0c;运行说明 源码地址 开发环境 jdk1.8,mysql8,nodejs16,navicat,idea 使用技术springboot mybatis vue bootstrap 部分功能截图预览

数据合规怎么做?哪些机构可以做数据合规

企业将数据资源入表的工作是一项复杂而全面的任务 财务部门负责统计数据资源的成本、销售数据等信息,并确保数据资源的会计处理符合会计要求&#xff1b; 数据部门则负责统计数据成本来源、价值实现路径等信息&#xff1b; 法务部门需要确认数据的收集和使用遵循相关的合规要求…