DINO-DETR匈牙利匹配与加噪过程学习记录

news2024/9/27 21:31:21

今天再来回顾一下DINO中匈牙利匹配与损失函数部分,该部分大致与DETR相似,却又略有不同。
为了查看数据方便,博主将num_query改为20,max_select值也为20。

匈牙利匹配过程

首先是数据送入匈牙利匹配中进行标签匹配过程了。

获取预测的类别,box信息

bs, num_queries = outputs["pred_logits"].shape[:2]
#获取预测值信息
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()  
#[batch_size * num_queries, num_classes]  torch.Size([40, 4]) 4为类别数目
out_bbox = outputs["pred_boxes"].flatten(0, 1)  
# [batch_size * num_queries, 4]  torch.Size([40, 4]) 4为xywh数据

获取真实框的类别与box信息

tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])

计算Focal_loss:pos_cost_class与neg_cost_class皆为:torch.Size([40, 4]),得到cost_class为:torch.Size([40, 5]),cost_class为每个query与target的损失。

alpha = self.focal_alpha  #0.25
gamma = 2.0   
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

计算L1距离得到cost_bbox为:torch.Size([40, 5])

cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

计算giou,得到cost-giou为:torch.Size([40, 5])

cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

构成cost矩阵

C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu()

中间步骤C:

在这里插入图片描述

最终形成的C:torch.Size([2, 20, 5])

在这里插入图片描述

获取每个batch中对应的标签个数

sizes = [len(v["boxes"]) for v in targets]

使用匈牙利匹配算法进行计算,得出匹配的标签与预测框。

indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

得出的indices是list形式,内部每个元素为tuple,其内为array对应标签id与预测框id。

在这里插入图片描述

将indices转换为tensor向量形式。

return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

返回的indices如下:

在这里插入图片描述

CDN构造过程

CDN是DINO的创新点之一,其是如何构造的呢?举个例子:
设置batch_size=2,第一张图片有一个标注框,第二张图片有4个标注框
开始设置参数dn_number=100,即添加噪声的query有100个,同时要设置对照组,也是100
dn_number=200,注意这里是设置dn_query的个数
随后判断设置多少个对照组,根据每个batch中最大的tgt数目设置dn_group。

		known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
		#[tensor([1], device='cuda:0'), tensor([1, 1, 1, 1], device='cuda:0')]
        batch_size = len(known)#2
        known_num = [sum(k) for k in known]
        #[tensor(1, device='cuda:0'), tensor(4, device='cuda:0')]
        if int(max(known_num)) == 0:
            dn_number = 1
        else:
            if dn_number >= 100:
                dn_number = dn_number // (int(max(known_num) * 2))
                #确定dn_number=25

什么意思呢,就是说总共我每个batch中设置25组即可。
然后总共正样本数为(1+4)x25=125,同理负样本数也是如此,两者加起来总共有250个
随后对标签进行加噪。
分别得到编码后的标签类别与box:(input_label_embed等都需经过embed编码)

input_label_embed:torch.Size([250, 256])
input_bbox_embed:torch.Size([250, 4])

由于我们设置了dn_query数目固定为200,生成dn_query:
初始时全为0,

padding_label = torch.zeros(pad_size, hidden_dim).cuda()#torch.Size([200, 256])
padding_bbox = torch.zeros(pad_size, 4).cuda()#torch.Size([200, 4])
随后复制batch维度:
input_query_label = padding_label.repeat(batch_size, 1, 1)#torch.Size([2, 200, 256])
input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)#torch.Size([2, 200, 4])

可以看到,此时其全部为0,那么如何将我们加噪后的query放进去呢?

if len(known_num):
   map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
   map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
if len(known_bid):
   input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
   input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

第一个判断获取了标识index,第二个判断结合batch_id与indice来进行填充:

举个例子:

input_query_box[0,0]填充input_bbox_embed[0]
input_query_box[1,0]填充input_bbox_embed[1]
input_query_box[1,1]填充input_bbox_embed[2]
input_query_box[1,2]填充input_bbox_embed[3]
input_query_box[1,3]填充input_bbox_embed[4]
input_query_box[0,4]填充input_bbox_embed[5]
input_query_box[1,4]填充input_bbox_embed[6],以此类推

至此便构造出dn_query了,值得一提的是只有最大tgt数目的图像中的query是全部有非零值的,在本次例子中,第一个batch中2x100个query中只有2x25个非零值,而在第二个batch中,全部都是被填充的。

known_bid值如下所示:五个一组

在这里插入图片描述

indices值如下所示,五个一组,配合batch_id可以将加噪后的值填入到query中,共有250个,但其值最大到199,刚好与200对应。

在这里插入图片描述

计算Label Loss

首先看传入的label_loss的参数:

def loss_labels(self, outputs, targets, indices, num_boxes, log=True):

outputs为预测结果:labels:torch.Size([2, 200, 4]) box:torch.Size([2, 200, 4])
targets为真实值
indices为匈牙利匹配结果:

在这里插入图片描述
num_boxes为box的个数,此时为125,需要注意的是在第一次跳入loss_labels时,实际上计算的是DN的损失。

使用dn_query计算loss不易查看(有200个),我们使用匈牙利匹配的结果来查看:
target共有5个,其中第一个batch有一个,第二个batch有4个。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/656922.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

qt.qpa.plugin: Could not load the Qt platform plugin “xcb“ in

兄弟们看看是不是这个错: QObject::moveToThread: Current thread (0xe5205f0) is not the objects thread (0xa14d0f0). Cannot move to target thread (0xe5205f0)qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "/xxx/python3.…

Esp32+Blynk实现云端控制LED开灭

目录 环境配置依赖库安装blynk 基础设置 GPIO 点灯实验 环境配置 依赖库安装 参考 blynk 官方快速上手文档 如果要使用 blynk 提供的环境,我们就必须安装对应的库 选择基于 blynk 且适用于 ESP32 的库并安装到 arduino 上: blynk 基础设置 进入官网并且…

Question1:harbor登录成功,推送镜像失败

denied: requested access to the resource is denied 解决方案 查看用户的权限 Harbor 用户角色权限速查 系统级角色: Harbor 系统管理员:“Harbor 系统管理员”拥有最多的权限。除了上述权限外,“Harbor 系统管理员”还可以列出所有项目、…

一种令人拍案叫绝的 ChatGPT 攻击手段!

公众号关注 “GitHubDaily” 设为 “星标”,每天带你逛 GitHub! 最近看到一个非常巧妙的 ChatGPT 攻击手段,跟大家分享一下,也算是做个提醒。 不论你是否懂技术,我都建议你了解一下这种攻击手段,有备无患。…

Golang的trace性能分析

文章目录 一、trace概述二、trace的使用方式代码中trace采集通过pprof采集 三、trace分析细节trace的web界面trace中需要关注的关注GC的频率关注goroutine调度情况关注goroutine的数量理想情况 四、GC分析当前服务GC情况设置GOGC设置GOMEMLIMITGC阈值的讨论GC的特点 五、gorout…

【每日挠头算法题(8)】最后一个单词的长度|重新排列字符串

文章目录 一、最后一个单词的长度思路1:从后往前遍历具体代码如下: 思路2:具体代码如下: 二、重新排列字符串思路具体代码如下: 一、最后一个单词的长度 点我直达~ 思路1:从后往前遍历 从后往前遍历&…

Stable DiffusionAI绘画一键启动整合包

点击"仙网攻城狮”关注我们哦~ 不当想研发的渗透人不是好运维 让我们每天进步一点点 简介 搞了个Stable DiffusionAI绘画整合包,里面有二次元风格、3D风格、真人模型,需要的后台回复“AI绘画”即可获取下载链接,放几个用SD生成的图。 实战 1.下载好…

调用万维易源API实现图像性别转换

目录 1、作者介绍2、调用万维易源API2.1 API介绍2.2 API调用过程 3、代码实现3.1 实现步骤3.2 完整代码 4、问题与分析 1、作者介绍 梁随欣,男,西安工程大学电子信息学院,2022级研究生 研究方向:模式识别与人工智能 电子邮箱&…

【Java技术专题】「入门到精通系列教程」深入探索Java特性中并发编程体系的原理和实战开发指南( 线程基础技术专题)

深入探索Java特性中并发编程体系的原理和实战开发指南 并发编程介绍什么是并发编程并发编程的好处是什么并发编程的挑战是什么并发编程模型有哪些如何学习并发编程本系列专题文章大全 实战原理计算的问题简单的方法:更快的CPU来遍历靠谱的方法:分而治之来…

Redis客户端 - RedisSerializer

原文首更地址,阅读效果更佳! Redis客户端 - RedisSerializer | CoderMast编程桅杆https://www.codermast.com/database/redis/redistemplate-redis-serializer.html 前景回顾 在上一篇中,我们实现了一个简单的案例,操作一个 St…

【22AP20 解码处理器(Hi3536AV100)】

22AP20 解码处理器(Hi3536AV100) 一、产品简介 22AP20 是针对多路高清/超高清(1080p/4M/5M/4K)智能 NVR 产品应用开发的新一代专业高端 SoC 芯片。22AP20 集成了 ARM Cortex-A55 八核处理器和性能强大的图像分析工具处理器,支持多种智能算法…

webpack无损压缩本地静态资源图片image-minimizer-webpack-plugin

开发如果项目中引用了较多图片,那么图片体积会比较大,将来请求速度比较慢。 我们可以对图片进行压缩,减少图片体积。 一、image-minimizer-webpack-plugin介绍 Image-minimizer-webpack-plugin是一个用于优化和压缩图片的Webpack插件。它使…

Qt5.15.2安装Android开发环境。

文章目录 1.下载并安装JDK1.1.下载1.2.安装 2.修改sdk_definitions.json文件3.QtCreator的配置3.1.设置JDK、Android SDK的路径3.2.设置openssl 现在(20230617)利用QtCreator来配置android开发环境还是挺方便的。基本三步搞定(不过你要先安装…

《分布式中间件技术实战:Java版》学习笔记(一):抢红包

数据库建表 (1)red_send_record 记录用户发送了若干总金额的若干个红包。 CREATE TABLE red_send_record (id int(0) NOT NULL AUTO_INCREMENT,user_id int(0) NOT NULL COMMENT 用户id,red_packet varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL…

笔记本电脑介绍:记录生活,激发灵感

笔记本电脑是一种轻便、便携的电脑,它的出现改变了人们的工作和生活方式,它的优势在于它的小巧、轻便、便携性,可以满足用户的不同需求。本文将从笔记本电脑的结构、功能、优势和应用四个方面进行详细阐述。 一、笔记本电脑的结构 笔记本电…

mysql常见错误汇总

mysql常见错误汇总 列别名问题 可以在查询选择列表中使用别名为列提供 不同名称。可以使用 、 或子句中的别名来引用该列:GROUP BY ORDER BY HAVING SELECT SQRT(a*b) AS root FROM tbl_nameGROUP BY root HAVING root > 0;SELECT id, COUNT(*) AS cnt FROM t…

Grafana安装和实现可视化和告警

1、Grafana安装和实现可视化和告警 Prometheus UI 提供了快速验证 PromQL 以及临时可视化支持的能力,但其可视化能力却比较弱。一般情况下, 我们都用 Grafana 来实现对 Prometheus 的可视化实现。 1.1 什么是 Grafana Grafana 是一个用来展示各种各样…

【Linux】可重入函数

文章目录 前言一. 场景二. 可重入与线程安全结束语 前言 在Linux中,进程/线程可能因为时间片到达,或者其他中断,或者调用系统,需要从用户态切换到内核态,而内核空间会保存切换前,用户代码执行处的上下文&a…

环境搭建【1】VM和ubuntun 环境搭建

1.安装VMware 1.1 下载安装包 (1)官网下载:https://customerconnect.vmware.com/en/downloads/info/slug/desktop_end_user_computing/vmware_workstation_pro/16_0 (2)百度网盘:https://pan.baidu.com/s/…

5.pixi.js编写的塔防游戏(类似保卫萝卜)-子弹跟随精灵移动

游戏说明 一个用pixi.js编写的h5塔防游戏,可以用electron打包为exe,支持移动端,也可以用webview控件打包为app在移动端使用 环境说明 cnpm6.2.0 npm6.14.13 node12.22.7 npminstall3.28.0 yarn1.22.10 npm config list electron_mirr…