【Transformer】detr梳理

news2024/12/27 12:38:38

every blog every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog

0. 前言

detr

detr

1. 引言

论文: https://arxiv.org/pdf/2005.12872v3.pdf
时间: 2020.5.26
作者: Nicolas Carion?, Francisco Massa?, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko Facebook AI

最大的特点是不需要预定义anchor,也不需要NMS,就可以实现端到端的目标检测。

在大目标上检测性能较好,而小目标较差。


2. 整体结构

整体结构如下
20240416163908

目标检测的图片一般都比较大,直接用Transformer计算量比较大,在SWin出现之前一般都是用CNN进行特征提取和尺寸缩放。

主要分三部分:

  • backbone
  • transformer
  • head

3. Backbone

主要是使用resnet50作为backbone,特征图进行了缩小,即(h/32,w/32)

其中还有一个位置编码,计算x和y方向的位置编码。位置编码长度为256,前128代表y方向,后128代表x方向。

20240416174048

4. Transformer

网络结构图如下:
20240416175554

4.1 关于输入shape

其中的输入shape变化: (b,c,h,w) -> (b,c,hw) -> (hw,b,c)
其中c是通道,表示token的维度,b为batch,表示token的个数,
按照参考文献所说,将像素作为一个token,但这种shape应该不是单个像素作为token。

# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)

而vit中的shape变化:(b,c,h,w) -> (b,c,hw) -> (b,hw,c),
这样的shape应该是单个像素作为一个token。

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = self.proj(x).flatten(2).transpose(1, 2)

4.2 关于Encoder位置编码

在Encoder部分,输入都加入了位置编码,并且只对q和k使用。
20240417093907

更清晰的图:
20240418094919

4.3 关于Decoder

4.3.1 理解1

20240417110538

Decoder包含多个DecoderLayer串联,而每个DecoderLayer中又包含两个串联的MultiheadAttention。

从Transformer局部代码中我们得到:

  • tgt:全0向量
  • memory:encoder的输出

DecoderLayer的输入:

  • tgt:第一个Decoder中该变量为为全0向量,后续该变量为上个Decoder的输出
  • memory:一直为encoder最后的输出。
  1. 对于Decoder中第一个MultiheadAttention

    • q: tgt,加上query_pos
    • k: tgt,加上query_pos
    • v: tgt
  2. 对于Decoder中第二个MultiheadAttention,

    • q: 第一个MultiheadAttention的输出(经过dropout和norm),加上query_pos
    • k: memory,加了pos位置信息,来源于Encoder输出
    • v: memory,来源于Encoder的输出

4.3.1 理解2-简化

更清晰的DecoderLayer图如下:

Decoder的最开始的输入一个一个全0的向量,所以:

q = query pos
k = query pos
v = 0

后续的DecodeLayer的输入是上一个DecodeLayer的输出。

20240417172915

4.4 小结

Encoder的输出shape:(hw,batch,256)

Decoder的输入和输出shape:(batch,100,256)

整个transformer的输出shape:(batch,100,256)

object queries 维度为 [100, batch, 256] ,类型为 nn.Embedding 说明这个张量是学习得到的, Object queries 充当的其实是位置编码的作用,只不过它是可以学习的位置编码。


Transformer局部:

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

DecoderLayer 局部:

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        
        # 第一个MultiheadAttention
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # 第二个MultiheadAttention
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

DecoderLayer 整体:

class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        
        # 第一个MultiheadAttention
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # 第二个MultiheadAttention
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)

Transformer整体代码如下:

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

5. Head

Head部分是使用一个简单的MLP进行分类和回归。

20240417114941

参考

  1. https://blog.csdn.net/baidu_36913330/article/details/120495817
  2. https://blog.csdn.net/m0_48086806/article/details/132155312
  3. https://zhuanlan.zhihu.com/p/348060767
  4. https://feizaipp.github.io/2021/04/27/transformer-%E5%9C%A8-CV-%E4%B8%AD%E7%9A%84%E5%BA%94%E7%94%A8(%E4%BA%8C)-DETR-%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B%E7%BD%91%E7%BB%9C/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1614954.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

陈奂仁联手 The Sandbox 推出“Hamsterz Doodles”人物化身系列

全新人物化身系列结合艺术与实用性 开创元宇宙新篇章 著名亚洲唱作歌手兼香港电影金像奖得主陈奂仁携手 The Sandbox,兴奋地宣布推出新的元宇宙人物化身系列 —— Hamsterz Doodles 仓鼠涂鸦。 陈奂仁在 The Sandbox 推出 Hamsterz Doodles 系列,将艺术与…

智能家居—ESP32开发环境搭建

相关文章 毕业设计——基于ESP32的智能家居系统(语音识别、APP控制) 智能家居—ESP32开发环境搭建 一、下载安装二、验证三、资料获取 一、下载安装 下载安装 vscode 安装插件 创建工程 二、验证 写一个简单的函数来验证一下功能 void setup() {// put your setup c…

类和对象(2)——封装(封装的概念、包、staic)

前言 面向对象程序三大特性:封装、继承、多态。而类和对象阶段,主要研究的就是封装特性。何为封装呢?简单来说就是套壳屏蔽细节。 一、什么是封装 1.1 概念 将数据和操作数据的方法进行有机结合,隐藏对象的属性和实现细节&…

【码农圈子】想加免费的程序员微信群的看过来

群名:码农圈子 很多人后台反应,最近有没有免费的微信技术交流社群 。今天特意写一篇文章来创建一些只有程序猿的微信群。(广告党慎入!) 这些微信技术群都是完全免费,后续也不会收取任何费用 。 群规则 …

Paragon NTFS如何手动更新? Paragon NTFS格式化硬盘会损失数据吗?

Paragon NTFS for Mac常被用于实现在Mac上读写NTFS格式硬盘,然而,有时用户可能会遇到软件无法自动更新的情况,需要进行手动更新操作。下面我们来看看Paragon NTFS如何手动更新,Paragon NTFS格式化硬盘会损失数据吗的相关内容。 一…

Python 使用 pip 安装 matplotlib 模块(精华版)

pip 安装 matplotlib 模块 1.使用pip安装matplotlib(五步实现):2.使用下载的matplotlib画图: 1.使用pip安装matplotlib(五步实现): 长话短说:本人下载 matplotlib 花了大概三个半小时屡屡碰壁,险些暴走。为了不让新来的小伙伴走我的弯路,特意…

【matlab 代码的python复现】 Matlab实现的滤波器设计实现与Python 的库函数相同实现Scipy

实现一个IIR滤波器的设计 背景 Matlab 设计的滤波器通常封装过于完整,虽然在DSP中能够实现更多功能的滤波器设计但是很难实现Python端口的实现。 我们以一段原始的生物电信号EEG信号进行处理。 EEG信号 1.信号获取 EEG信号通常通过头皮电极,经过多通道采样芯片采样,将获…

mysql面试题八(SQL语句)

目录 1.SQL 基本组成部分 常用操作示例 创建表 插入数据 查询数据 更新数据 删除数据 创建索引 授予用户权限 2.常见的聚合查询 1. 计数(COUNT) 2. 求和(SUM) 3. 平均值(AVG) 4. 最大值&…

使用FPGA实现超前进位加法器

介绍 前面已经向大家介绍过8位逐位进位加法器了,今天向大家介绍4位超前进位加法器。 对于逐位进位加法器来说,计算任意一位的加法运算时,必须等到低位的加法运算结束送来进位才能运行。这种加法器结构简单,但是运算慢。 对于超…

WSL安装-问题解决

WslRegisterDistribution failed with error: 0x8004032d WslRegisterDistribution failed with error: 0x80080005 Error: 0x80080005 ??????? 解决: 1、 winr输入:optionalfeatures.exe 2、打开这两项

钉钉报警的优势在哪里?如何配置钉钉机器人进行报警信息推送?

一、常见的报警方式 1、短信或者电话报警 这样的报警方式更适合高级别的报警提醒,用于处理紧急情况。出现级别不高而又频繁地发送短信会让人产生排斥感,而且电话或者短信的报警方式也存在一定的成本。 2、邮件报警 邮件报警更适用于工作时的提醒&…

支付方式模块代码示例

支付方式模块代码示例 效果展示 <view class"card"><uni-title type"h3" title"支付方式"></uni-title><radio-group change"radioChange"><label class"radio"><view class"zf-t…

ThingsBoard通过规则链使用邮件发送报警信息

1、描述 2、通过规则链路配置发送邮件只需 两步 3、案例 1、基础链路 2、选择变换节点里面的To Email 3、 编辑节点to email 4、 将创建告警与to email链接 5、选择外部节点中的send email 6、配置邮箱相关信息&#xff0c;如过不知道密钥如何获取的&#xff0c;请查看下…

yolo-驾驶行为监测:驾驶分心检测-抽烟打电话检测

在现代交通环境中&#xff0c;随着汽车技术的不断进步和智能驾驶辅助系统的普及&#xff0c;驾驶安全成为了公众关注的焦点之一 。 分心驾驶&#xff0c;尤其是抽烟、打电话等行为&#xff0c;是导致交通事故频发的重要因素。为了解决这一问题&#xff0c;研究人员和工程师们…

JRT质控数据录入

之前有时间做了质控物维护界面&#xff0c;有了维护之后就应该提供可以录入业务数据的功能了&#xff0c;当时给质控物预留了一个“项目批次业务数据”的功能说是业务数据会给每天拷贝维护数据。这次一起补上&#xff0c;展示JRT怎么写质控数据录入的界面。 界面如下&#xff…

【Linux基础】Linux基础概念

目录 前言 浅谈什么是文件&#xff1f; Linux下目录结构的认识及路径 目录结构 路径 家目录 什么是递归式的删除 重定向 输出重定向&#xff1a; 追加重定向&#xff1a; 输入重定向&#xff1a; 命令行管道 shell外壳 为什么需要shell外壳&#xff1f; shell外壳…

智能算法 | Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法

智能算法 | Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法 目录 智能算法 | Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法效果一览基本介绍程序设计参考资料效果一览 基本介绍 Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法 融合自适应…

Linux下SPI设备驱动实验:使用内核提供的读写SPI设备中的数据的函数

一. 简介 前面文章的学习&#xff0c;已经实现了 读写SPI设备中数据的功能。文章如下&#xff1a; Linux下SPI设备驱动实验&#xff1a;验证读写SPI设备中数据的函数功能-CSDN博客 本文来使用内核提供的读写SPI设备中的数据的API函数&#xff0c;来实现读写SPI设备中数据。 …

【竞技宝】中超:国安本轮4比1大胜,张稀哲表现不俗

国安在本轮中超主场跟青岛西海岸相遇,这场比赛球队进攻多点开花,最终以4比1将对手斩落马下,拿到了久违的大胜。其中,张稀哲、李可、王子铭都在比赛中有不俗表现。首先,张稀哲身为国安中场核心,他在比赛中传出了多脚有威胁的球,并且成功帮助队友得分。张稀哲在国安神兵天降的表现…

C# 6.0+JavaScript云LIS系统源码 云LIS实验室信息管理新型解决方案

C# 6.0JavaScript云LIS系统源码  云LIS实验室信息管理新型解决方案 什么是医院云LIS系统&#xff1f; 云LIS是为区域医疗提供临床实验室信息服务的计算机应用程序&#xff0c;可协助区域内所有临床实验室相互协调并完成日常检验工作&#xff0c;对区域内的检验数据进行集中管…