【代码实现】DETR原文解读及代码实现细节

news2024/11/29 11:55:11

1 模型总览

在这里插入图片描述

宏观上来说,DETR主要包含三部分:以卷积神经网络为主的骨干网(CNN Backbone)、以TRM(Transformer)为主的特征抽取及交互器以及以FFN为主的分类和回归头,如DETR中build()函数所示。DETR最出彩的地方在于,它摒弃了非端到端的处理过程,如NMS、anchor generation等,以集合预测的方式来端到端建模目标检测过程,并且将Transformer引入到目标检测中,打开新领域的大门)。

def build(args):

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = DETR(
        backbone,# 骨干网
        transformer,# 重点部分
        num_classes=81,
        num_queries=100,# object query数量,作用相当于spatial embedding
        aux_loss=args.aux_loss)
    matcher = build_matcher(args)# 二分图匹配
    weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    losses = ['labels', 'boxes', 'cardinality']
    postprocessors = {'bbox': PostProcess()}

    return model, criterion, postprocessors

2 DETR的基本流程

  • CNN backbone 提取图像的 feature
  • Transformer Encoder 通过 self-attention 建模全局关系对 feature 进行增强
  • Transformer Decoder 的输入是 object queries(spatial embedding) 和 Transformer encoder 的输出(content embedding),主要包含 self-attention 和 cross-attention 的过程。Self-attention 主要是对每个 query 之间做交互,让每个 query 能看到其他 query 在查询什么东西,从而不重复,类似与 NMS 的作用;Cross- attention 主要是将 object query 当做查询,encoder feature 当做 key,为了查询和 query 有关的区域。
  • 对 Decoder 输出的查询好了的 query,使用 FFN 提取出目标框的位置和类别信息

顺着上边的基本流程,从代码入手一点点理解原文的思想,下面开始!

3 backbone

首先是构建backbone模块的函数

def build_backbone(args):

    position_embedding = build_position_encoding(args)
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model

3.1 根据特征图生成位置编码

def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding

对于输入特征x,假定其尺寸为BxCxHxW,位置编码需要在H W两个维度上进行位置编码,所以一般会将hidden_dim(C, 通道)切分为两部分,一部分代表H另一部分代表W,最后在通道维度上进行拼接。DETR中位置编码主要是sine和learning position embedding。我自己仿照TRM写了一个可学习位置编码的实现,可以运行试试

# 验证learnable pos embedding机制
x = torch.randn((8, 3, 32, 32))
h,w=x.shape[-2:]
row_embed,col_embed=nn.Embedding(50,256),nn.Embedding(50,256)

i,j=torch.arange(w,device=x.device),torch.arange(h,device=x.device)
x_emb,y_emb=col_embed(i),row_embed(j)
x_cat=x_emb.unsqueeze(0).repeat(h,1,1)
y_cat=y_emb.unsqueeze(1).repeat(1,w,1)

pos=torch.cat([x_cat,y_cat],dim=-1)
pos_learn=pos.permute(2,0,1).unsqueeze(0).repeat(x.shape[0],1,1,1)# shape:(8,512,32,32)

构建backbone

这部分不太难,相关注释已经写在代码块中

class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:# 是否返回中间层,在多尺度融合操作时会用到
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:# 一般情况下只返回最后一层的输出
            return_layers = {'layer4': "0"}
        # IntermediateLayerGetter作用类似于Sequential,将多个神经层组合并可以指定返回中间输出
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)# 将数据传入网络,实例化网络得到输出,xs即为经过resnet四部分后的输出
        out: Dict[str, NestedTensor] = {}# 定义输出格式
        for name, x in xs.items():# 如果返回中间层,out可以按照name存储,返回最后一层则只有layer4
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]# 利用插值法产生不同尺度下的mask
            out[name] = NestedTensor(x, mask)
        return out

将上述二者对应组合

class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)# [0]代表backbone的输出
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))# [1]代表position_embedding

        return out, pos# 返回抽取后的特征及对应的位置编码

Transformer

在这里插入图片描述
TRM部分其实跟Attention is all you need的模型结构完全相同,不同的部分只是decoder部分输入。在原始transformer中,decoder的输入是对应目标序列融合位置编码后的embedding,而在本文中,则使用初始化为全0的tgt作为目标序列,然后再融合query_embed。这里非常容易混淆的一点是:tgt全零序列才是content embedding, 代码中的query_embed是代表目标框集合位置的spatial embedding
由于encoder部分主要是特征提取,对边界定位影响不大,所以我们考虑decoder的cross-attention部分,其输入主要包括三部分:query key value

  • queries:每个 query 都是 decoder 第一层 self-attention 的输出( content query )+ object query( spatial query ),这里的 object query 就是 DETR 中提出的概念,每个 object query 都是候选框的信息,经过 FFN 后能输出位置和类别信息(本文 object query 个数 N 为 100)
  • keys:每个 key 都是 encoder 的输出特征( content key ) + 位置编码( spatial key )构成
  • values:只有来自 encoder 的输出
class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,return_intermediate=return_intermediate_dec)

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)# flatten(k)表示将[k:n-1]拉平为一个维度
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# sine
        # num_queries x hidden_dim to num_queries x N x hidden_dim
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        # NxHxW to NxHW
        mask = mask.flatten(1)

        # decoder embedding,初始化为全0
        tgt = torch.zeros_like(query_embed)
        # encoder特征抽取,得到memory,shape同tgt
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # decoder特征交互,得到hs
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

DETR

class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):

        super().__init__()
        self.num_queries = num_queries# object query nums
        self.transformer = transformer
        hidden_dim = transformer.d_model# 隐层维度
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)# 分类头,最后的类别为:类别数+背景
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)# 检测头,使用三层全连接层进行映射,最后投影到xywh
        self.query_embed = nn.Embedding(num_queries, hidden_dim)# object query
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)# 将得到的feature map通道归一
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        # 得到的feature可能是C3-C5几层,DETR只拿最后一层输入TRM
        src, mask = features[-1].decompose()
        assert mask is not None
        # self.transformer()[0]表示取dncoder的输出,序列1表示encoder输出
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/975457.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

看了字节跳动月薪25K+测试岗面试题,让我这个工作2年的测试工程师,直冒冷汗....

朋友入职字节已经两周了,整体工作环境还是非常满意的!所以这次特意抽空给我整理了这份面试宝典,而我把它分享给伙伴们,面试&入职的经验! 大概是在8月中的时候他告诉我投递了字节跳动并且简历已通过,8月…

JavaScript - 一个好玩的打字动画

效果预览&#xff1a; <!DOCTYPE html> <html> <head><title>打字动画示例</title><style>.typewriter {color: #000;overflow: hidden; /* 隐藏溢出的文本 */white-space: nowrap; /* 不换行 */border-right: .15em solid #000; /* 添加…

IDEA中debug调试模拟时显示不全(不显示null)的解决

IDEA中debug调试模拟时显示不全&#xff08;不显示null&#xff09;的解决 1、在IDEA中找到File&#xff08;文件&#xff09;->Settings&#xff08;设置&#xff09; 2、依次找到以下内容进行设置&#xff08;原版、汉化版&#xff09;&#xff1a; 打开Build, Executio…

STL常用容器 (C++核心基础教程之STL容器详解)String的API

在C的标准模板库&#xff08;STL&#xff09;中&#xff0c;有多种容器可供使用。以下是一些常见的容器类型&#xff1a; 序列容器&#xff08;Sequential Containers&#xff09;&#xff1a; std::vector&#xff1a;动态数组&#xff0c;支持快速随机访问。 std::list&…

LLM大模型推理加速 vLLM;Qwen vLLM使用案例

参考&#xff1a; https://github.com/vllm-project/vllm https://zhuanlan.zhihu.com/p/645732302 https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html ##文档 1、vLLM 这里使用的cuda版本是11.4&#xff0c;tesla T4卡 加速原理&#xff1a; Paged…

项目招标投标公众号小程序开源版开发

项目招标投标公众号小程序开源版开发 以下是一个招标投标公众号小程序的功能列表&#xff1a; 用户注册与登录&#xff1a;用户可以注册账号并登录公众号小程序。项目发布&#xff1a;用户可以发布招标项目的详细信息&#xff0c;包括项目名称、招标单位、项目描述、招标要求…

巨人互动|游戏出海H5游戏出海规模如何?

H5游戏出海是指将H5游戏推广和运营扩展到国外市场的行为&#xff0c;它的规模受到多个因素的影响。本文小编讲一些关于H5游戏出海规模的详细介绍。 1、市场规模 H5游戏出海的规模首先取决于目标市场的规模。不同国家和地区的游戏市场规模差异很大&#xff0c;有些市场庞大而成…

error:03000086:digital envelope routines::initialization error

项目背景 前端vue项目启动突然报错error:03000086:digital envelope routines::initialization error 我用的开发工具是vscode&#xff0c;node版本是v18.17.0 前端项目版本如下↓ 具体报错如下↓ 报错原因 node版本过高 解决方法 1输入命令 $env:NODE_OPTIONS"--op…

cron介绍

cron表达式在线生成 在使用定时调度任务的时候&#xff0c;我们最常用的&#xff0c;就是cron表达式了。通过cron表达式来指定任务在某个时间点或者周期性的执行。 cron表达式的组成 cron表达式是一个字符串&#xff0c;由6到7个字段组成&#xff0c;用空格分隔。其中前6个字…

融云观察:国内外社交应用差异分析,如何更快补齐功能和切换交互?

9 月 21 日&#xff0c;融云直播课 社交泛娱乐出海最短变现路径如何快速实现一款 1V1 视频社交应用&#xff1f; 欢迎点击上方小程序报名~ 本文中&#xff0c;我们将通过对 WhatsApp、微信、Telegram、Discord 几大社交软件的主要功能和交互体验对比&#xff0c;解析社交软件…

【AIGC专题】Stable Diffusion 从入门到企业级应用0414

一、前言 本文是《Stable Diffusion 从入门到企业级应用实战》系列的第四部分能力进阶篇《Stable Diffusion ControlNet v1.1 图像精准控制》的第0414篇 利用Stable Diffusion ControlNet 法线贴图模型精准控制图像生成。本部分内容&#xff0c;位于整个Stable Diffusion生态体…

融云出海:社交泛娱乐出海,「从 0 到 1」最全攻略

9 月 21 日&#xff0c;融云直播课社交泛娱乐出海最短变现路径如何快速实现一款 1V1 视频社交应用&#xff1f; 欢迎点击上方小程序报名~ 本期我们翻到《地图》的实践篇&#xff0c;从赛道/品类选择、目标地区适配、用户增长、变现模式、本地化运营、跨国团队管理等方面完整描绘…

大健云仓:股价正在艰难反弹,被严重低估了

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 业务和基本面 大健云仓&#xff08;GCT)是一家从事大型家居、家用电器、健身器材和园艺等国际大件商品出口的B2B交易平台&#xff0c;是全球领先的国际贸易数字化服务商。依靠国际化商业架构&#xff0c;自有全球化物流仓储…

Ubuntu20.04同时安装ROS1和ROS2

Ubuntu20.04同时安装ROS1和ROS2 Excerpt 每版的Ubuntu系统版本都有与之对应ROS版本的&#xff0c;每一版ROS都有其对应版本的Ubuntu版本&#xff0c;不可随便装&#xff0c;ubuntu20.04对应ROS1 noetic和ROS2 foxy版本。_ros1和ros2共存 Ubuntu20.04同时安装ROS1和ROS2共存 文…

MES管理系统对电子企业来说有什么优点

引言&#xff1a;在电子制造企业中&#xff0c;MES管理系统已经成为提高生产效率、降低成本、提高订单履行速度和准确性的重要工具。电子企业MES管理系统是一套集成的信息系统&#xff0c;用于监控和控制电子企业的生产过程。本文将探讨MES管理系统对于电子企业来说有哪些优点。…

多线程场景下谨慎使用@Transactional注解,你不信我也没办法

最近遇到一个很诡异的bug&#xff0c;觉得很有趣也很值得分享&#xff0c;于是想写篇文章记录下来&#xff0c;希望有缘人看到以后少踩坑~ 先简单说下场景&#xff1a;有个任务平台&#xff0c;功能很多但我们只关注 提交任务和取消任务 两个功能&#xff0c;并且取消任务后会有…

全民拼购模式:美妆行业的新机遇和挑战

美妆是一个充满创意和变化的行业&#xff0c;每个人都想拥有自己独特的美丽风格。但是&#xff0c;美妆产品的价格和品质却不尽相同&#xff0c;很多消费者在购买时会遇到困惑和不满。有没有一种方法&#xff0c;可以让消费者以更低的价格买到更好的美妆产品&#xff0c;同时还…

八个针对高级职位的高级 JavaScript 面试题

JavaScript 是一种功能强大的语言&#xff0c;是网络的主要构建块之一。这种强大的语言也有一些怪癖。例如&#xff0c;您是否知道 0 -0 的计算结果为 true&#xff0c;或者 Number("") 的结果为 0&#xff1f; 问题是&#xff0c;有时这些怪癖会让你摸不着头脑&…

了解 glTF 2.0 格式

推荐&#xff1a;使用 NSDT场景编辑器快速搭建3D应用场景 介绍 glTF 代表 GL 传输格式。 glTF 是一种用于存储和加载 3D 场景的标准化文件格式&#xff0c;其基本目的是由 3D 创建工具轻松生成并被任何图形应用程序使用&#xff0c;无论使用何种 API&#xff0c;处理最少。 …

软考(1)-面向对象的概念

目录 一. 软考基本信息 1. 软考时间&#xff1a; 2. 软考科目&#xff1a; 3.专业知识介绍 -- 综合知识考点分布 4. 专业介绍 -- 软件设计考点分布 二. 面向对象概念 1. 封装 考点一&#xff1a;对象 考点二&#xff1a;封装private 2. 继承 考点三&#xff1a;类 考…