DETR的位置编码

news2024/12/22 6:02:44

记录一下,以防忘记。

首先,致谢知乎vincent DETR论文详解

DETR中有这样一个类和一个包装函数

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

假如batch_size=2,有两张图片分别为

im0 = torch.rand(3,200,200)
im1 = torch.rand(3,200,250)

我们使用 nested_tensor_from_tensor_list 函数将它们打包在一块,这里调用了NestedTensor类,它的作用就是构成 {tensor, mask} 这么一个数据结构,在这里,tensor就是图片的值,那mask是啥?

当一个batch中的图片大小不一样的时候,我们要把它们处理的整齐,简单说就是把图片都padding成最大的尺寸,padding的方式就是补零,那么batch中的每一张图都有一个mask矩阵,所以mask大小为(2, 200, 250), tensor大小为(2, 3, 200, 250)。

DETR - Backbone 

从DETR的角度来看,当我们用resnet50提取特征得到特征维度为 (2, 1024, 24, 32),这里输出的的mask的维度为 (2, 24, 32),mask使用F.interpolate得到。

DETR -  Position Encoding

首先,DETR官方源码中包括了正弦位置编码和可学习位置编码,我们这次讲下正弦位置编码。首先,Transformer带有位置信息的特征是通过 Feature Embedding + Position Embedding 相加得到的,至于为什么相加,请看这篇博文 为什么Transformer / ViT 中的Position Encoding能和Feature Embedding直接相加?

 在DETR中,位置编码构造方法与Transformer原文中的位置编码一致。

def forward(self, tensor_list: NestedTensor):   
        x = tensor_list.tensors #(2,1024, 24,32)
        mask = tensor_list.mask #(2, 24,32)
        assert mask is not None
        not_mask = ~mask #就是有像素值得位置
        y_embed = not_mask.cumsum(1, dtype=torch.float32) #沿y方向累加,(1,1,1)--(1,2,3)
        # (1,1,1,...) #y_embed
        # (2,2,2,...)
        # (3,3,3,...)
        # (...)
        x_embed = not_mask.cumsum(2, dtype=torch.float32) #沿x方向累加,(1,1,1).T--(1,2,3).T
        # (1,2,3,...) #x_embed
        # (1,2,3,...)
        # (1,2,3,...)
        # (...)
        if self.normalize: #进行归一化
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale #(2,24,32)
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale #(2,24,32)

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        # self.num_pos_feats=128, 
        # dim_t = [1,2,3,4,...,128]
        # 以下按上述公式计算
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t #(2,24,32,128)
        pos_y = y_embed[:, :, :, None] / dim_t #(2,24,32,128)
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos #(2,256, 24,32)

位置编码(positional encoding),最直观方式是将第一个pixel赋予1,第二个pixel赋予2,以此类推,但当pixel序列足够大时,会造成位置嵌入(positional embedding)的值过大,所以采用正弦曲线把值控制在-1到1之间,但由于正弦曲线的周期性,可能会造成不同位置值相同的情况。

因此,作者将positional embedding 扩充为一个d维的向量,这个向量用来为每个pixel提供位置信息,再和该位置的pixel embedding相加,增强模型输入,至于引用d维正弦函数的作用大致是控制不同通道位置编码的波长,波长随d由小到大。以下是positional encoding过程的举例。

 

下图为不同维度对应的正弦曲线,可以看到随d越大,正弦曲线的波长越大,这样做的原因是,如果每个Position只对应一个正弦曲线,那么由于正弦曲线的周期性,P2,P9(不同的位置点)可能计算出相同值。而采用每个位置对应多个维度,即多个不同波长的正弦曲线,任何两位置由d维向量表示,就不会发生位置不同,值相同的情况。

另外positional embedding的周期有 2� 到10000 ∗2� 变化,而每个位置在embedding demension上都会得到不同周期的sin和cos函数的取值组合,从而产生独一的纹理位置信息,最终使模型学到位置之间的依赖关系和自然语言的时序特征(Pixel的时序特征)。

DETR论文中引入attention机制使用的模块是transformer,第一步先要将feature map投射变换成Q,K,V,Q可以理解为语义空间向量投射变维的输出,K可以理解为字典,V为字典对应的输出,通过Q与K点乘(典型的attention操作)得到V的加权系数,然后对V加权求和,最后经过一个前向网络输出类别和坐标预测,transformer丢失位置信息,所以又添加了一个位置编码(Position Enconding),所谓PE是根据目标的位置坐标,将位置坐标转换为固定维度的编码,转换方式是将位置坐标代入不同波长的三角函数里,三角函数天生具有描述相对位置的作用,而之所以要使用不同波长的三角函数进行计算而不是单波长,是为了描述相对位置同时保存绝对位置,因为波长过小,位置较远的像素将超出同一个周期导致绝对位置丢失,所以大家可以粗略的理解为,小波长精确描述距离较近的相对位置,大波长描述距离较远绝对位置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/773189.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C知道,CSDN 出来的AI尝试

已经上图,算力不知道怎么样。 C知道 (csdn.net)

JDK、JRE与JVM三者之间的关系及区别

文章目录 0、关系1、JDK2、JRE3、JVM 0、关系 JDK JRE Java 开发工具包 [Java,Javac,Javadoc,Javap等]JRE JVM Java 的核心类库 1、JDK 什么是JDK,JDK是用于Java程序开发的最小环境,包含:Java程序设计语言,Java虚拟机&#…

git : 从入门到实战进阶

目录 0. 前言 1. git stash: 暂时保存本地修改 2. git push时发生冲突怎么办? 3. 访问过去的提交版本:git checkout 3.1 detached HEAD 3.2 “detached HEAD”状态下所作的修改会怎样呢? 3.3 “detached HEAD”状态下所作的修改如何汇…

leetcode100.相同的树

⭐️ 题目描述 🌟 leetcode链接:相同的树 1️⃣ 代码: bool isSameTree(struct TreeNode* p, struct TreeNode* q){// 判断两棵树当前结点是否为空if (p NULL && q NULL) {// 说明是相同的return true;}// 来到这里有几种情况// …

causal-learn ModuleNotFoundError: No module named ‘pygam‘

调用 causallearn 库包,测试CAM-UV算法时报错: No module named pygam 解决方法: pip install pygam 参考链接: 【Python Causal Learning Toolbox】causallearn 库包的使用、报错修改_板砖板砖我是兔子的博客-CSDN博客

java ArratList深拷贝

引用深拷贝 便捷方法 class Test {public static void main(String[] args) {ArrayList<User> list new ArrayList<>();for (int i 0; i < 3; i) {User user new User(i, "name" i);list.add(user);}ArrayList<User> list1 new ArrayList…

超细致的性能测试流程,你get了吗?

性能测试&#xff1a;利用工具模拟大量用户操作&#xff0c;验证系统承受的负载情况。 性能测试的目的&#xff1a;找到潜在的性能问题或瓶颈&#xff0c;分析并解决&#xff1b;找出性能变化趋势&#xff0c;为后续扩展系统提供参考。测试监控&#xff1a;基准测试、配置测试…

【JavaEE】DI与DL的介绍-Spring项目的创建-Bean对象的存储与获取

Spring的开发要点总结 文章目录 【JavaEE】Spring的开发要点总结&#xff08;1&#xff09;1. DI 和 DL1.1 DI 依赖注入1.2 DL 依赖查询1.3 DI 与 DL的区别1.4 IoC 与 DI/DL 的区别 2. Spring项目的创建2.1 创建Maven项目2.2 设置国内源2.2.1 勾选2.2.2 删除本地jar包2.2.3 re…

(栈队列堆) 剑指 Offer 31. 栈的压入、弹出序列 ——【Leetcode每日一题】

❓ 剑指 Offer 31. 栈的压入、弹出序列 难度&#xff1a;中等 输入两个整数序列&#xff0c;第一个序列表示栈的压入顺序&#xff0c;请判断第二个序列是否为该栈的弹出顺序。假设压入栈的所有数字均不相等。例如&#xff0c;序列 {1,2,3,4,5} 是某栈的压栈序列&#xff0c;序…

好的CRM需要有哪些特点?

CRM客户管理系统在企业中占有举足轻重的地位&#xff0c;既是战略工具又可以强化部门间的团队协作、优化销售流程、缩短销售周期。市面上crm做得比较好的公司有哪些&#xff1f; 1.上榜Gartner魔力象限 好的CRM市场的引领、产品研发的持续投入、技术创新以及不断增长的市场份…

面试题 02.07. 链表相交

给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中不存在环。 注意&#xff0c;函数返回结果后&#x…

划片机的作用将晶圆分割成独立的芯片

划片机是将晶圆分割成独立芯片的关键设备之一。在半导体制造过程中&#xff0c;晶圆划片机用于将整个晶圆切割成单个的芯片&#xff0c;这个过程被称为“晶圆分割”或“晶圆切割”。 晶圆划片机通常采用精密的机械传动系统、高精度的切割刀具和先进的控制系统&#xff0c;以确保…

Kafka - AR 、ISR、OSR,以及HW和LEO之间的关系

文章目录 引子举例说明 引子 AR&#xff08;Assigned Replication&#xff09;&#xff1a; 分区中的所有副本统称为AR&#xff08;Assigned Replicas&#xff09; ISR&#xff08;In-Sync Replicas&#xff09;&#xff1a;同步副本集合 ISR是指当前与主副本保持同步的副本集合…

JavaSwing+MySQL的酒店管理系统

点击以下链接获取源码&#xff1a; https://download.csdn.net/download/qq_64505944/88063706?spm1001.2014.3001.5503 JDK1.8、MySQL5.7 功能&#xff1a;散客开单&#xff1a;完成散客的开单&#xff0c;可一次最多开5间相同类型的房间。 2、团体开单&#xff1a;完成团体…

找不到类NoClassDefFoundError: ionetty.util.intemnal.Platformlependent0

解决方案&#xff0c;jdk版本的问题&#xff0c;在project structure 中把项目jdk改为1.8

[GXYCTF2019]simple CPP

前言 三个加密区域&#xff0c;第一次是基本运算&#xff0c;八位叠加&#xff0c;z3方程 分析 第一轮加密&#xff0c;和Dst中模27异或 &#xff08;出题人对动调有很大意见呢&#xff09; 将输入的字符串按八位存入寄存器中&#xff0c;然后将寄存器内容转存到内存 第一次…

Tauri自带命令生成各平台图标

npm命令&#xff1a; npm run tauri icon yarn命令&#xff1a; yarn tauri icon 1.在项目根目录中放置一个app-icon.png (图片)文件: 图片最好长宽比是1:1&#xff1a;(其他好像会报错) 2.执行命令: npm run tauri icon 可以到文件里面查看 如果本地测试&#xff0c;图标…

记一次真实MySQL百万数据优化

证实下确实是150万+数据哈 原SQL 原SQL执行计划 原SQL执行时间 5秒左右 原SQL分析 思路来源 整体看下SQL好像没啥可优化的。那咱们就大错特错了。 可能有人会说B表为啥在A表后面不正常呀,因为这是内连接查询不是左右连接查询。A,B表的顺序是可以交换的(实测无影响) 首先我们…

leetcode97. 交错字符串(算法:动态规划)

题目&#xff1a; 给定三个字符串 s1、s2、s3&#xff0c;请你帮忙验证 s3 是否是由 s1 和 s2 交错 组成的。 两个字符串 s 和 t 交错 的定义与过程如下&#xff0c;其中每个字符串都会被分割成若干 非空 子字符串&#xff1a; s s1 s2 ... sn t t1 t2 ... tm |n - …

Spring Batch之读数据库—StoredProcedureItemReader(四十)

一、StoredProcedureItemReader Spring Batch框架对存储过程提供了支持&#xff0c;StoredProcedureItemReader提供了对存储过程的支持&#xff0c;其运行和JdbcCursorItemReader类似&#xff0c;均是获取游标对象&#xff0c;然后转换为JavaBean对象。 StoredProcedureItemRe…