DETR:End-to-End Object Detection with Transformers笔记

news2025/1/11 10:13:30

文章目录

  • End-to-End Object Detection with Transformers
    • 摘要
    • 本文方法
      • 损失函数
    • 代码实现

End-to-End Object Detection with Transformers

摘要

提出了一种将目标检测视为直接集预测问题的新方法。我们的方法简化了检测管道,有效地消除了许多手工设计的组件,如非最大抑制过程或锚生成,这些组件显式地编码了我们对任务的先验知识。新框架的主要组成部分,称为检测变压器或DETR,是一个基于集合的全局损失,通过二部匹配强制进行唯一预测,以及一个变压器编码器-解码器架构。给定一组固定的小学习对象查询,DETR对对象和全局图像上下文的关系进行推理,以直接并行输出最终的预测集。新模型在概念上很简单,不像许多其他现代探测器那样需要专门的库。
代码地址

本文方法

在这里插入图片描述
DETR通过将普通CNN与transformer架构相结合,直接预测(并行)最终检测集。

详细结构:
在这里插入图片描述
DETR使用传统的CNN主干来学习输入图像的二维表示。在将其传递到变压器编码器之前,模型将其扁平化并使用位置编码进行补充。然后,转换器解码器将少量固定数量的学习到的位置嵌入(我们称之为对象查询)作为输入,并额外关注编码器输出。我们将解码器的每个输出嵌入传递给一个共享前馈网络(FFN),该网络预测检测(类和边界框)或“无对象”类。

损失函数

在通过解码器的单次传递中,DETR推断出固定大小的N个预测集,其中N被设置为明显大于图像中典型对象的数量。训练的主要困难之一是根据真实情况对预测对象(类别、位置、大小)进行评分。我们的损失在预测对象和真实对象之间产生最优的二部匹配,然后优化对象特定的(边界盒)损失

让我们用y表示对象的基本真实集。假设N大于图像中对象的数量,我们也将y视为大小为N的集合,其中填充了?(无对象)。为了找到这两个集合之间的二部匹配,我们搜索代价最小的N个元素σ 2 SN的排列:
在这里插入图片描述
在这里插入图片描述

代码实现

第一步:提取CNN特征加上位置编码

features, pos = self.backbone(samples)

提取CNN特征的主要函数

backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

将得到的特征送入transformer

hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

transformer里面包含编码和解码,通过维度变换已经转为词嵌入格式了

		bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

编码器的代码:

 		src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

解码器的代码,返回的仍然是词嵌入格式

		tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

然后全连接到类别数以及box的数量得到
outputs_class = {Tensor: (6, 2, 100, 11)}#类别的
outputs_coord = {Tensor: (6, 2, 100, 4)} #框的

最后增加一个辅助损失:

out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/694359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Win10实时保护老是自动开启怎么办?

Win10实时保护老是自动开启怎么办?使用Win10电脑的用户遇到了实时保护老是自动开启的问题,想知道怎么操作才能解决此问题,这时候用户需要打开电脑的组策略编辑器,然后找到管理模板中的Windows Defender选项,点击关闭Wi…

Intellij IDEA HTTP Request 请求设置Cookie

使用Intellij IDEA 的 HTTP Request 请求中带有Cookie时,需要将 Cookie单词要写成全小写的“cookie”,否则设置的Cookie不会生效。 POST http://localhost:9091/rest/miracle/findList Content-Type: application/json cookie: JSESSIONIDce22a4ed-b185…

Redis基本数据类型

string(字符串):最常见的用户是缓存用户信息,将用户信息结构体使用JSON序列化成字符串,然后将序列化后的字符串塞进Redis来缓存,然后取用户信息的过程会经历一次反序列化的过程。 Redis的字符串是动态字符…

互联网+洗鞋店预约小程序新模式;

互联网洗鞋店预约小程序 1、线上线下业务的结合。 传统的线下业务消费者到店可以向其推介线上的预约到家服务,让线下的消费者成为小程序内的会员,留存客户之后线上可直接触达,减少与消费者的距离,从等待客户到可以主动出击&…

Mac 配置Flutter开发环境遇到的坑

1. flutter doctor 命令无反应; 加sudo 执行; sudo flutter doctor执行完后, 在执行flutter doctor就好使了, 就很神奇;(还不行就重启再试试) 2. 创建项目提示没权限, 照提示命令赋予权限即可;(应该是前面sudo命令引起的, 但没找到其他好办法) 最后一个提示的没权限, 照提示运…

MAC中clion的默认clang的bug

问题描述 用clion运行上述代码info报错 >>报错,这个应该是代码风格的问题,还有部分大括号也会报错 解决方法 怀疑是编译器的问题,把clang改成gcc就好了

2023年7月1日(星期六):骑行小石林

2023年7月1日(星期六):骑行小石林(大青山),早8:30到9:00, 大观公园门囗集合,9:30点准时出发 【因迟到者,骑行速度快者,可自行追赶偶遇。】 偶遇地点: 大观公园门囗集合,…

《Pytorch深度学习和图神经网络(卷 1)》学习笔记——第六章

实例5:识别黑白图中的服装图案 Fashion-MNIST是手写数字数据集MNIST的一个替代品,常常被用来测试网络模型,如果在该数据集上效果都不好,其他数据集上的效果可想而知。其单个样本为28X28,6万张训练集,1万张…

【多线程】实现一个线程池

1. 线程池的概念 1.1 什么是线程池? 线程池也是一种线程的使用方式,前面刚开始学习多线程的时候,我们了解到线程太多,会带来 CPU 的调度开销。 以前我们都是一个线程执行一个任务(一个run方法),就好比搬砖,…

【Neo4j】图数据库安装和演示

部署图库 环境Win10Docker Desktop Neo4j 寻找容器,拉取容器,查询容器 docker search neo4j docker pull neo4j docker images参考说明 docker run -d --name neo4j \ //-d表示容器后台运行 --name指定容器名字-p 17474:7474 -p 17687:7687 \ //映射…

Tex表格代码--stat期刊

Tex表格代码1: \begin{center} \begin{table*}[t]% \caption{AAAAAA.\label{Table:BBB}} \centering \begin{tabular*}{500pt}{{\extracolsep\fill}lccD{.}{.}{3}c{\extracolsep\fill}} \toprule &\multicolumn{2}{{}c{}}{\textbf{Spanned heading\tnote{1}}} …

Python(六)函数

函数是一个工具,在输入和输出之间构造一个关系;使用函数方便了代码的复用,避免重新造轮子; 目录 函数的分类 内置函数 自定义函数 函数几种格式对比 无参数,无返回值 有参数,无返回值 无参数&#…

ElasticSearch——地理坐标查询

Elasticsearch 语雀(完整笔记) 所谓的地理坐标查询,其实就是根据经纬度查询,官方文档:Geo queries | Elasticsearch Guide [8.8] | Elastic 常见的使用场景包括: 携程:搜索我附近的酒店滴滴…

Linux服务器Jenkins部署打包Flutter

程序猿日常 记Jenkins部署打包Flutter参考Linux服务器Jenkins部署打包Flutter 安装Flutter环境 Flutter SDK 下载地址 配置服务器Flutter环境变量 创建任务 #!/bin/bash -ilex source /etc/profileflutter clean flutter pub get flutter build apk

8.OpenCV-识别身份证号码(Python)

需求描述: 通过OpenCV识别身份证照片上的身份证号码(仅识别身份证号码) 实现思路: 1.将身份证号中的0,1,2,3,4,5,6,7,8,9作为模板,与身份证照片中的身份证号码区域进行模板匹配。 2.先要制作一个身份证号码模板&am…

坚鹏:中国邮储银行金融科技前沿技术发展与应用场景第1期培训

中国邮政储蓄银行金融科技前沿技术发展与应用场景第1期培训圆满结束 中国邮政储蓄银行拥有优良的资产质量和显著的成长潜力,是中国领先的大型零售银行。2016年9月在香港联交所挂牌上市,2019年12月在上交所挂牌上市。中国邮政储蓄银行拥有近4万个营业网点…

基于java+swing+mysql图书管理系统V6.0

基于javaswingmysql图书管理系统V6.0 一、系统介绍二、功能展示1.项目骨架2.数据库表3.项目内容4.登陆界面5.管理员-读者注册6、管理员-书籍入库7、管理员-书籍更新8、管理员-书库管理9、管理员-读者更新10、用户-还书11、用户-借书 四、其它1.其他系统实现五.获取源码 一、系统…

【3Ds Max】常用的基本初始化设置

目录 一、单位设置 二、首选项设置 2.1 撤销次数设置 2.2 设置保存时压缩 2.3 设置自动保存时间间隔 2.4 选中模型时高亮显示 一、单位设置 我们以设置毫米单位为例 在 “自定义-》单位设置” 中进行设置 点击“系统单位设置”按钮 如下设置就表示:1个单位长度…

Jmeter_响应数据为空以及中文乱码

目录 一、响应数据为空 解决方法 二、响应中文乱码 产生原因 解决方法 一、响应数据为空 最近做测试接口,使用同样的请求方式、地址、参数和header,在postman中能正常响应,接收数据的也正常,但是在Jmeter中,虽然…

FPGA-DFPGL22学习4-仿真平台学习

文章目录 前言一、仿真的步骤二、使用步骤1.PDS编译仿真库2.编写仿真tb文件3.选择行为仿真4.查看观察窗口5.修改代码后重新编译 总结 前言 和原子哥一起学习FPGA 开发环境:正点原子 ATK-DFPGL22G 开发板 参考书籍: 《ATK-DFPGL22G之FPGA开发指南_V1.1…