DAB-Deformable-DETR源码学习记录之模型构建(二)

news2025/2/27 5:27:11

书接上回,上篇博客中我们学习到了Encoder模块,接下来我们来学习Decoder模块其代码是如何实现的。
其实Deformable-DETR最大的创新在于其提出了可变形注意力模型以及多尺度融合模块:
其主要表现在Backbone模块以及self-attention核cross-attention的计算上。这些方法都在DINO-DETR中得到继承,此外DAB-DETR中的Anchor Query设计与bounding box强化机制也有涉及。

Encoder模块

首先经过Encoder后的输出结果为 memory:torch.Size([2, 9620, 256]),其分别代表不同level的特征信息:tensor([ 0, 7220, 9044, 9500], device=‘cuda:0’)

memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

Two-Stage

核心思想:
Encoder会生成特征memory,再自己生成初步proposals(其实就是特征图上的点坐标 xywh)。
然后分别使用非共享检测头的分类分支对memory进行分类预测,得到对每个类别的分类结果;
再用回归分支进行回归预测,得到proposals的偏移量(xywh)。再用初步proposals偏移量 得到第一个阶段的预测proposals。
然后选取top-k个分数最高的那批预测proposals作为Decoder的参考点。
并且,Decoder的object query和 query pos都是由参考点通过位置嵌入(position embedding)再接上一个全连接层 + LN层处理生成的。

Two-Stage主要是应用在初始化参考点坐标上。
one-stage的参考点是get_reference_points函数生成的,而two-stage参考点是通过gen_encoder_output_proposals函数生成的。

one-stage初始化方法

def get_reference_points(spatial_shapes, valid_ratios, device):
    """
    生成参考点   reference points  为什么参考点是中心点?  为什么要归一化?
    spatial_shapes: 4个特征图的shape [4, 2]
    valid_ratios: 4个特征图中非padding部分的边长占其边长的比例  [bs, 4, 2]  如全是1
    device: cuda:0
    """
    reference_points_list = []
    # 遍历4个特征图的shape  比如 H_=100  W_=150
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        # 0.5 -> 99.5 取100个点  0.5 1.5 2.5 ... 99.5
        # 0.5 -> 149.5 取150个点 0.5 1.5 2.5 ... 149.5
        # ref_y: [100, 150]  第一行:150个0.5  第二行:150个1.5 ... 第100行:150个99.5
        # ref_x: [100, 150]  第一行:0.5 1.5...149.5   100行全部相同
        ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                      torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
        # [100, 150] -> [bs, 15000]  150个0.5 + 150个1.5 + ... + 150个99.5 -> 除以100 归一化
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
        # [100, 150] -> [bs, 15000]  100个: 0.5 1.5 ... 149.5  -> 除以150 归一化
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
        # [bs, 15000, 2] 每一项都是xy
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    # list4: [bs, H/8*W/8, 2] + [bs, H/16*W/16, 2] + [bs, H/32*W/32, 2] + [bs, H/64*W/64, 2] ->
    # [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 2]
    reference_points = torch.cat(reference_points_list, 1)
    # reference_points: [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 2] -> [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 1, 2]
    # valid_ratios: [1, 4, 2] -> [1, 1, 4, 2]
    # 复制4份 每个特征点都有4个归一化参考点 -> [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 4, 2]
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    # 4个flatten后特征图的归一化参考点坐标
    return reference_points

Two-Stage参考点初始化方法

def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
    """得到第一阶段预测的所有proposal box output_proposals和处理后的Encoder输出output_memory
    memory: Encoder输出特征  [bs, H/8 * W/8 + ... + H/64 * W/64, 256]
    memory_padding_mask: Encoder输出特征对应的mask [bs, H/8 * W/8 + H/16 * W/16 + H/32 * W/32 + H/64 * W/64]
    spatial_shapes: [4, 2] backbone输出的4个特征图的shape
    """
    N_, S_, C_ = memory.shape  # bs  H/8 * W/8 + ... + H/64 * W/64  256
    base_scale = 4.0
    proposals = []
    _cur = 0   # 帮助找到mask中每个特征图的初始index
    for lvl, (H_, W_) in enumerate(spatial_shapes):  # 如H_=76  W_=112
        # 1、生成所有proposal box的中心点坐标xy
        # 展平后的mask [bs, 76, 112, 1]
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        # grid_y = [76, 112]   76行112列  第一行全是0  第二行全是1 ... 第76行全是75
        # grid_x = [76, 112]   76行112列  76行全是 0 1 2 ... 111
        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
        # grid = [76, 112, 2(xy)]   这个特征图上的所有坐标点x,y
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)  # [bs, 1, 1, 2(xy)]
        # [76, 112, 2(xy)] -> [1, 76, 112, 2] + 0.5 得到所有网格中心点坐标  这里和one-stage的get_reference_points函数原理是一样的
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        # 2、生成所有proposal box的宽高wh  第i层特征默认wh = 0.05 * (2**i)
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
        # 3、concat xy+wh -> proposal xywh [bs, 76x112, 4(xywh)]
        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)
    # concat 4 feature map proposals [bs, H/8 x W/8 + ... + H/64 x W/64] = [bs, 11312, 4]
    output_proposals = torch.cat(proposals, 1)
    # 筛选一下 xywh 都要处于(0.01,0.99)之间
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    #用log(x/1-x)
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    # mask的地方是无效的 直接用inf代替
    output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
    # 再按条件筛选一下 不符合的用用inf代替
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    # 对encoder输出进行处理:全连接层 + LayerNorm
    output_memory = self.enc_output_norm(self.enc_output(output_memory))
    return output_memory, output_proposals

for循环里是对不同level的所有格点创建不同尺寸的anchor框,scale其实是对有效区域的处理,后续对output_proposals的处理是筛选掉边界附近的候选,输出是对应位置的特征和编码后的proposal, 对应位置的特征用于映射proposal的类别score以及校正偏差。值得注意的是proposal并没有直接使用原始坐标,而是进行了log的编码, 在forward中的two_stage情况提取reference_points是使用sigmoid函数进行了解码,我们假设偏置量为0,可以发现:

在这里插入图片描述

所谓的双阶段其实就是在Encoder后不是将数据直接送入Decoder,而是送入MLP与全连接层进行分类与回归后再送入Decoder。

enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) #torch.Size([2, 9620, 91])
enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals#torch.Size([2, 9620, 4])

在这里插入图片描述

随后选择topk

topk = self.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]#torch.Size([2, 300])
#torch.Size([2, 300])
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
#torch.Size([2, 300, 4])
topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
#torch.Size([2, 300, 4])
topk_coords_unact = topk_coords_unact.detach()

将其进行sigmoid,由于gen_encoder_output_proposals进行了log,此时sigmoid刚好可以变回初始值

reference_points = topk_coords_unact.sigmoid() #torch.Size([2, 300, 4])

在这里插入图片描述

随后得到初始化参考点坐标信息:
层归一化定义:

self.pos_trans_norm = nn.LayerNorm(d_model * 2)
#torch.Size([2, 300, 4])       
init_reference_out = reference_points
#pos_trans_norm是层归一化,得到结果torch.Size([2, 300, 512])
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)

最终得到:query_embed torch.Size([2, 300, 256]),tgt torch.Size([2, 300, 256])

在这里插入图片描述

在这里插入图片描述

Decoder模块

终于,进入了Decoder模块,我们首先来看其传入的参数:
tgt:torch.Size([2, 300, 256])
reference_points:torch.Size([2, 300, 4])
memory:torch.Size([2, 9620, 256])
spatial_shapes:

tensor([[76, 95],
        [38, 48],
        [19, 24],
        [10, 12]], device='cuda:0')

level_start_index:tensor([ 0, 7220, 9044, 9500], device=‘cuda:0’)
query_embed:torch.Size([2, 300, 256])
mask_flatten:torch.Size([2, 9620])

hs, inter_references = self.decoder(tgt, reference_points, memory,
                                            spatial_shapes, level_start_index, valid_ratios, 
                                            query_pos=query_embed if not self.use_dab else None, 
                                            src_padding_mask=mask_flatten)

进入Decoder层:
其后就与DAB-DETR一致了,只是将cross_attention替换为可变形注意力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/468649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

平台+AI:全面拥抱大模型的商业创新,打造企业数字化「柔性供应链」 | D3演讲实录

马斯克曾说&#xff1a;“高生产率解决诸多问题。” 在社会化内卷的大环境下&#xff0c;借助数智化“降本增效”已是不争事实。AI技术日新月异、大量信息繁杂涌现&#xff0c;无数原来烟囱式的模式亟需变革&#xff0c;平台与AI之间怎样融合&#xff0c;才能发挥更大的功效&a…

深度学习量化总结(PTQ、QAT)

背景 目前神经网络在许多前沿领域的应用取得了较大进展&#xff0c;但经常会带来很高的计算成本&#xff0c;对内存带宽和算力要求高。另外降低神经网络的功率和时延在现代网络集成到边缘设备时也极其关键&#xff0c;在这些场景中模型推理具有严格的功率和计算要求。神经网络…

如何减少项目在Corona和V-Ray中的3ds Max渲染时间?

相信在大多 3D 项目里&#xff0c;渲染是最耗费时间的部分&#xff0c;它不仅是建模和纹理化 3D 场景的过程&#xff0c;而是需要利用硬件来完成任务。我们在配备独立GPU和带有2到4个强大内核的CPU的中档计算机上&#xff0c;可以将3ds Max中创建和处理的项目轻松渲染完成&…

MATLAB实现车牌识别

车牌识别主要包括三个主要步骤&#xff1a;车牌区域定位、车牌字符分割、车牌字符识别。 本项目通过对拍摄的车牌图像进行灰度变换、边缘检测、腐蚀及平滑等过程来进行车牌图像预处理&#xff0c;并由此得到一种基于车牌颜色纹理特征的车牌定位方法&#xff0c;最终实现了车牌…

在Docker上部署SpringBoot项目

在Docker上部署SpringBoot项目 在学习中发现了部署的时候总是有各种问题,此文章只有操作步骤没有原理解释,只是用来提醒自己部署步骤 第一步:将SpringBoot项目打包成jar包 使用idea打包,点一下就行 第二部:编写Dockerfile文件 新建一个名为Dockerfile的文件,注意没有后缀…

improper Integral反常积分

笔记 笔记二 例题 hyperlink

安陆EGS20 SDRAM仿真

目录 一. 搭建仿真平台 二. 实现SDRAM连续写入1024个数据&#xff0c;然后再连续读出&#xff0c;并比较 1. 调试过程中问题&#xff1a; 2. 顶层代码 3. 功能代码 三. SDRAMFIFO实现上述功能调试 1. 代码设计要点 2. 仿真过程问题 3. 上板运行调试 安陆反馈&#xf…

80%的人都关注的电子合同签署疑问,君子签官方解答来了!

电子合同签错了在平台可以撤回吗&#xff1f;如果合同上名字签错了&#xff0c;有法律效力吗&#xff1f;签的电子合同&#xff0c;内容会不会被别人看见&#xff1f;… 最近&#xff0c;小编将80%的人都关注的电子合同签署问题进行了整理&#xff0c;官方专业解答帮助大家更好…

mac真机调试h5攻略

原因&#xff1a; h5项目想在mac本通过chrome://inspect/#devices调试 &#xff08;win上调试h5很简单&#xff0c;请参考&#xff1a;chrome真机调试Android_chrome 调试安卓_芒果终结者的博客-CSDN博客&#xff09; 调试步骤&#xff1a; 1. 需要下载安装安卓开发工具and…

信息化发展

信息系统是&#xff1a;管理模型、信息处理模型和系统实现条件结合的 信息系统生命周期&#xff1a; 可行性分析与项目开发计划 需求分析 概要设计 详细设计 编码 测试 可以简化为&#xff1a; 系统规划&#xff1a;现行情况的分析&#xff0c;可行性研究报告 -> 设计任务…

Java笔记_13(集合进阶2)

Java笔记_13 一、双列集合1.1、Map的常见API1.2、Map遍历方式一&#xff08;键找值&#xff09;1.3、Map集合遍历方法二&#xff08;键值对&#xff09;1.4、Map集合遍历方法三&#xff08;lambda表达式&#xff09;1.5、HashMap1.6、HashMap练习1.7、HashMap底层源码解析1.7、…

12秒内AI在手机上完成作画!谷歌提出扩散模型推理加速新方法

本文源自&#xff1a;量子位 只需12秒&#xff0c;只凭手机自己的算力&#xff0c;就能拿Stable Diffusion生成一张图像。 而且是完成了20次迭代的那种。 要知道&#xff0c;现在的扩散模型基本都超过了10亿参数&#xff0c;想要快速生成一张图片&#xff0c;要么基于云计算&…

Python 实现txt、excel、csv文件读写【附源码】

目录 前言 一、txt文件读写 二、excel文件读写 总结 前言 本文介绍使用Python进行文件读写操作&#xff0c;包括txt文件、excel文件(xlsx、xls、csv) 编译器使用的是PyCharm 一、txt文件读写 read() # 一次性读取全部内容readline() # 读取第一…

K8s入门教程:10分钟带你速览全程

K8s&#xff0c;英文全称为Kubernetes&#xff0c;就是基于容器的集群管理平台&#xff0c;是用于自动部署、扩缩和管理容器化应用程序的开源系统。 K8s是用来干啥的&#xff1f; 简单来说&#xff0c;可以用一句话来解释&#xff1a;K8s的特点就是所有主机上都装上docker&…

Win10老是蓝屏收集错误信息重启无效怎么办?

Win10老是蓝屏收集错误信息重启无效怎么办&#xff1f;有用户遇到了电脑开机蓝屏的情况&#xff0c;收集错误信息重启电脑之后&#xff0c;依然无法解决问题。那么这个问题要怎么去进行解决呢&#xff1f;接下来我们来看看以下具体的处理方法教学吧。 准备工作&#xff1a; 1、…

JAVA:基于Redis 实现计数器限流

1、简述 在现实世界中可能会出现服务器被虚假请求轰炸的情况&#xff0c;因此您可能希望控制这种虚假的请求。 一些实际使用情形可能如下所示&#xff1a; API配额管理-作为提供者&#xff0c;您可能希望根据用户的付款情况限制向服务器发出API请求的速率。这可以在客户端或服…

Bing 性能是如何跟随 .NET 一起迭代的?

大约两年前&#xff0c;我发表了一篇文章&#xff0c;详细的介绍了 Bing 的中央工作流引擎(XAP)从 .NET Framework 升级到 .NET 5 的过程。你可以通过这篇文章来了解 XAP 的工作原理&#xff0c;以及它在 Bing 全局中的位置。从那时起&#xff0c;XAP 一直是微软许多搜索和工作…

mysql语句高级用法使用记录和sql_mode=only_full_group_by错误解决

最近工作时用到的几种用法记录一下 sql_modeonly_full_group_by 报错 sql出错示例如下 column ‘qnaq.ta.issue_org_code’ which is not functionally dependent on columns in GROUP BY clause; this is incompatible with sql_modeonly_full_group_by 原因分析&#xff1a;…

云服务器使用jenkins+docker自动化部署SpringBoot项目

docker 安装jenkins&#xff0c;就这一步都恶心死了 //拉取镜像&#xff0c;踩了很多坑&#xff0c;用其它版本的镜像插件一直安装失败&#xff0c;最后用的是lts版本&#xff08;基础版&#xff09; 用其它版本要么是连不上插件的下载地址&#xff0c;要么是插件下载不成功 d…

Window10搭建GPU环境(CUDA、cuDNN)

一、查看CUDA版本 方法一&#xff0c;cmd命令 nvidia-smi下图的 CUDA 版本是11.7 方法二&#xff0c;点击 NVIDIA的图标 1.右键点击会出现nvidia 控制面板 或者 2.点击系统信息 3.点击组件 二.下载CUDA 到官网下载根据不同的版本 https://developer.nvidia.com/cud…