DINO损失函数构造解析

news2025/2/23 6:38:26

损失函数

首先看下模型的输出结果:
output_cls:torch.Size([2, 900, 3])
output_box:torch.Size([2, 900, 4])
即设置batch-size=2,900个预测框
真值信息如下:第一张图片内有4个真值框,第二张图片亦然

在这里插入图片描述
在这里插入图片描述
随后我们看看如何损失函数的计算,首先是进行匹配,即将预测框与真值框使用二分匹配的方式进行匹配。
outputs_without_aux 原本的output实际为6层的输出结果,我们只取最后一层即可。

outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
device=next(iter(outputs.values())).device
indices = self.matcher(outputs_without_aux, targets)

通过indices = self.matcher(outputs_without_aux, targets)进行二分图匹配进入使预测框与真值框进行匹配。

在这里插入图片描述
统计target数量,如果使用多显卡则进行平均分配

num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

开始计算损失:
首先是加噪损失,即DINO中使用添加噪声的方式来辅助匈牙利算法收敛,之后便是预测框与真值框的计算损失过程:

if 'interm_outputs' in outputs:
    interm_outputs = outputs['interm_outputs']
    indices = self.matcher(interm_outputs, targets)
    if return_indices:
        indices_list.append(indices)
    for loss in self.losses:
        if loss == 'masks':
            # Intermediate masks losses are too costly to compute, we ignore them.
            continue
        kwargs = {}
        if loss == 'labels':
            # Logging is enabled only for the last layer
            kwargs = {'log': False}
        l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_boxes, **kwargs)
        l_dict = {k + f'_interm': v for k, v in l_dict.items()}
        losses.update(l_dict)

其损失包含类别损失与定位损失。
首先是通过匈牙利匹配算法来进行预测框与真值框的匹配,随后进行计算。获得匹配的编号

在这里插入图片描述
然后判断要计算哪些loss ,self.loss为 [‘labels’, ‘boxes’, ‘cardinality’]
调用get_loss方法计算损失值:

l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_boxes, **kwargs)

loss_map中保存loss_labels与loss_boxes等

def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
    loss_map = {
        'labels': self.loss_labels,
        'cardinality': self.loss_cardinality,
        'boxes': self.loss_boxes,
        'masks': self.loss_masks,
    }
    assert loss in loss_map, f'do you really want to compute {loss} loss?'
    return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

接下来我们看看labels损失与box损失是如何计算的

类别损失

首先是loss_labels :
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
传递的参数为:

在这里插入图片描述

src_logits存储的是类别输出结果

src_logits = outputs['pred_logits']   torch.Size([2, 900, 3])

随后获取匹配结果的batch与编号

 idx = self._get_src_permutation_idx(indices)

在这里插入图片描述
_get_src_permutation_idx 方法的定义可知,其返回结果为所属batch与预测框的编号

   def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

进行类别编号赋值,target_classes_o值为tensor([1, 1, 1, 1, 1, 1, 1, 1], device=‘cuda:0’),即这8个物体预测类别皆为1号类,随后是其余的全部划为背景类,即target_classestorch.Size([2, 900])
类目内容全为3,3为我们设定的背景类别
随后将对应编号的预测框设置为target_classes_o,即预测的类别
idx值为(tensor([0, 0, 0, 0, 1, 1, 1, 1]), tensor([696, 720, 721, 866, 0, 1, 2, 3]))
就相当于batch=0的696号置为target_classes_o设定的值,以次类推

target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o

随后进行one_hot编码
target_classes_onehottorch.Size([2, 900, 4])

在这里插入图片描述

随后计算类别损失,使用的是交叉熵损失sigmoid_focal_loss ,值为:tensor(0.8650, device=‘cuda:0’, grad_fn=)

多分类任务采取的softmax loss,而是使用了多标签分类中的sigmoid
loss(即逐个判断属于每个类别的概率,不要求所有概率的和为1,一个检测框可以属于多个类别),原因是sigmoid的形式训练过程中会更稳定。

并将值保存在 losses = {'loss_ce': loss_ce}
同时class_error为:

losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]

最终返回losses即可得到类别损失值。
接下来便是box回归损失的计算:

Box损失计算

box损失计算传入的参数与类别损失计算传入的参数相同,过程也即为类似,首先是通过匈牙利匹配获取预测框与真值框的匹配关系,紧接着进行L1损失与GIOU损失的计算。
idx值为:(tensor([0, 0, 0, 0, 1, 1, 1, 1]), tensor([696, 720, 721, 866, 0, 1, 2, 3]))
src_boxes值如下:

在这里插入图片描述

target_boxes值如下:

在这里插入图片描述

随后进行计算L1损失:

loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

计算完成的损失值还要除以 num_boxes,值为 8.0

losses['loss_bbox'] = loss_bbox.sum() / num_boxes

紧接着计算GIOU损失值

loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes

完整代码如下:

  def loss_boxes(self, outputs, targets, indices, num_boxes):
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        # calculate the x,y and h,w loss
        with torch.no_grad():
            losses['loss_xy'] = loss_bbox[..., :2].sum() / num_boxes
            losses['loss_hw'] = loss_bbox[..., 2:].sum() / num_boxes
        return losses

在这里插入图片描述

最终得到的loss_dict为,共有19个

在这里插入图片描述

而真正计算损失的只有9个,这却决于是否给定weight值:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/540821.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面试官:前端如果 100 个请求,你怎么用 Promise 去控制并发?

摘要: 时隔两年半,我,一个卑微的前端菜鸡,又来写面经了!以为钱是程序员年轻奋斗的动力!作为一个程序员,在一个地方慢慢成长后会产生一个能力小提升的一种傲娇!希望你们一跳涨好几丈。…

【问题总结】不使用回滚,如何删除/剔除/回退 远程仓库的错误commit。

场景描述 某次使用IDEA操作GIT时,将一些【被忽略】的文件都提交到commit中,并且被push到远程仓库,甚至还被合并到了主分支中该怎么办? 解决思路 分享两种思路 删除/回退/剔除 掉远程错误的分支重新commit一次正确的分支 删除…

企业落地数字化转型,如何部署战略规划

当前环境下,各领域企业通过数字化相关的一切技术,以数据为基础、以用户为核心,创建一种新的,或对现有商业模式进行重塑就是数字化转型。这种数字化转型给企业带来的效果就像是一次重构,会对企业的业务流程、思维文化、…

降低 Spark 计算成本 50.18 %,使用 Kyligence 湖仓引擎构建云原生大数据底座,为计算提速 2x

2023 中国开源未来发展峰会于 5 月 13 日成功举办。在大会开源原生商业分论坛,Kyligence 解决方案架构高级总监张小龙发表《云原生大数据底座演进 》主题演讲,向与会嘉宾介绍了他对开源发展的见解,数据底座向云原生湖仓一体架构演进的趋势&am…

建立在Safe生态的—GameFi SocialFi双赛道项目No.1头号玩家

最近大家关注的重点在BRC-20和MEME项目,人们似乎更在意短期的投机回报。而在这之外,一个web3的游戏——No.1头号玩家却得到了大量的玩家支持。 据了解,No.1是一个GameFi & SocialFi的双赛道web3游戏,中文名称为头号玩家。它是…

系统分析师上午题,经典易错题目

CRC即循环冗余校验码(Cyclic Redundancy Check)是数据通信领域中最常用的一种差错校验码,其特征是信息字段和校验字段的长度可以任意选定。在CRC校验方法中,进行多项式除法(模2除法)运算后的余数为校验字段。第一个空的分析,系统读记录的时间为33/11=3ms,对第一种情况:…

计算机毕业论文选题推荐|软件工程|系列七

文章目录 导文题目导文 计算机毕业论文选题推荐|软件工程 (***语言)==使用其他任何编程语言 例如:基于(***语言)门窗账务管理系统的设计与实现 得到:基于JAVA门窗账务管理系统的设计与实现 基于vue门窗账务管理系统的设计与实现 等等 题目 基于(***语言)的家政服务系统…

Android 打开webView黑屏闪烁问题排查

______ NO.1 ______ 前言 最近在研发项目的时候,有个模块调用webView功能; 点击搜索框,进入webView页面,出现了黑色过渡页面,效果如下: ______ NO.2 ______ 排查问题 个人在排查此问题的时候,用…

Redis缓存双写一致性之更新策略

Redis缓存双写一致性之更新策略 一 面试题引入二 缓存双写一致性三 双写双检加锁策略四 数据库和缓存一致性的集中更新策略4.1 最终一致性4.2 可以关机的情况下4.3 不能关机的情况下,四种更新策略4.3.1 先更新数据库,再更新缓存4.3.2 先更新缓存&#xf…

【算法学习系列】03 - 由[1-5]等概率随机实现[2-10]等概率随机

文章目录 约定条件说明解决方案构造 0 1 发生器函数 f2()计算需要几个二进制位验证 2-10 等概率返回某个整数 总结 约定条件说明 假定 f() 是一个函数,保证 [1, 5] 范围内等概率返回一个整数实现 2-10 等概率随机不能使用 Math.random() 函数,只能使用函…

栈与队列的性质互换

本期内容:栈,队列的定义性质,性质转换 栈,队列的定义性质,性质转换 认识栈实现栈 队列实现 性质转换 认识栈 栈(stack)又名堆栈,它是一种运算受限的线性表。限定仅在表尾进行插入和…

【渗透测试】web日志、linux命令、常用知识

文章目录 web日志分析基础知识1. 编码2. 解码工具3. 数据提交方式4. 常见脚本语言5. 日志还原 分析日志1. 分析日志的目的2. 攻击出现的位置3. 攻击常见的语句4. 攻击常见的特点5. 攻击日志分析流程 相关linux命令常用命令系统状态检测命令工作目录切换命令文本文件编辑命令文件…

BlueZ自动连接蓝牙耳机

问题:调好蓝牙之后,出现了一个客户问题,第一次连接好之后,开关机后没法自动连了。 解决方法: 针对这个情况,实际定位一下问题原因,原来是蓝牙耳机每次连时,都要求授权服务: Author…

sqlmap

1、Sqlmap简介: Sqlmap是一个开源的渗透测试工具,可以用来自动化的检测,利用SQL注入漏洞,获取数据库服务器的权限。它具有功能强大的检测引擎,针对各种不同类型数据库的渗透测试的功能选项,包括获取数据库…

Maven安装和配置(详细版)

Maven安装和配置 Maven安装1、安装链接:2、配置环境变量: Maven配置1、修改Maven仓库下载镜像及修改仓库位置:2、在Idea上配置Maven: 测试Maven安装能否安装jar包 Maven安装 1、安装链接: Maven – Download Apache …

使用A100 GPU搭建OBBDetection的运行环境

项目场景: 最近需要复现一篇目标检测论文的代码,文章提供了代码,因此自己根据仓库的说明尝试配置环境运行代码,但遇到了非常多的困难 问题描述 比较老的代码加上比较的GPU,导致了环境在配置的时候困难重重 OBBDetect…

xorm多表连接查询

SQL的连接查询可以将多个表的数据查询出来,形成一个中间表。在sql中为JOIN关键字。最常用的是LEFT JOIN,RIGHT JOIN,INNER JOIN,OUTER JOIN。 xorm框架是基于go语言的orm框架同样支持连接查询,由于xom及支持原生的sql查询也支持基于xorm的方法查询&…

openEuler用户软件仓(EUR)| 近期项目介绍

在操作系统的世界,软件包是一等公民,软件包的丰富程度和是否易于分发,一定程度上决定了操作系统用户和开发者的使用体验.。 EUR(openEuler User Repo)是openEuler社区针对开发者推出的个人软件包托管平台,目的在于为开发者提供一个…

【LeetCode训练营】用栈来实现队列+用队列来实现栈 详解

💯 博客内容:【LeetCode训练营】用栈来实现队列用队列来实现栈 详解 😀 作  者:陈大大陈 🚀 个人简介:一个正在努力学技术的准前端,专注基础和实战分享 ,欢迎私信! …

Requests-翻页请求实现

翻页请求实现 继https://blog.csdn.net/ssslq/article/details/130747686之后,本篇详述在获取了页面第一页之后,如何获取剩余页的标题内容。 网页:https://books.toscrape.com 找规律 同样还是进行页面的检查,切到网络一栏&…