YOLOv8 : TAL与Loss计算

news2025/2/19 6:59:32

YOLOv8 : TAL与Loss计算

1. YOLOv8 Loss计算

        YOLOv8从Anchor-Based换成了Anchor-Free,检测头也换成了Decoupled Head,论文和网络资源中有大量的介绍,本文不做过多的概述。

        Decoupled Head具有提高收敛速度的好处,但另一方面讲,也会遇到分类与回归不对齐的问题。具体来讲,在一些网络中,会通过将feature map中的cell与ground truth进行IOU计算以分配预测所用cell,但用来分类和回归的最佳cell通常不一致。为了解决这一问题,引入了TAL技术。想详细了解这一部分,可以参考“TOOD: Task-aligned One-stage Object Detection(https://arxiv.org/abs/2108.07755v3)”这篇论文。

        YOLOv8采用了TAL(Task Alignment Learning)任务对齐分配技术(正负样本分配),并引入了DFL(Distribution Focal Loss)结合CIoU Loss做回归分支的损失函数,使用BCE做分类损失,使得分类和回归任务之间具有较高的对齐一致性。

2. TAL

        TAL一般用在decoupled head网络中,用于将不同的任务进行对齐。典型的,用来解决分类与回归cell一致性问题,更具体的,TAL用于为计算LOSS所构建的GT feature map的cell分配标签。TAL,一句话,就是给feature map中的每一个cell(当然,也有人称做anchor)分配ground truth框。当然,有的cell能够分配到gt(ground truth)框,有的cell分配不到gt框。根据fm(feature map)与gt的分配情况,构建用于Loss计算的target_labels、target_bboxes和target_scores。

下面结合官方代码(class TaskAlignedAssigner)进行理论与工程化相结合的讲解。

        第一步,计算位置掩码mask_gt,对齐度量矩阵align_metric和IOU矩阵overlaps,三者均为shape(bs, n_max_boxes, na),其中mask_gt标识每一个gt框的topk个匹配cells。align_metric计算方式如下:

此处需要注意,cell_scores是经过mask_gt过滤过的得分矩阵,α默认取值为1.0,默认取值为6.0。

        第二步,为每一个cell选择IOU最大的gt框,并标记。返回每一个cell匹配的gt索引target_gt_idx(shape(bs, na)),每一个cell匹配的gt数量fg_mask(shape(bs, na)),以及更新后的全局gt和anchor的匹配情况mask_pos(shape(bs, n_max_boxes, na))。

     第三步,根据target_gt_idx构建用于loss计算的target_labels(shape(bs, na)), target_bboxes(shape(bs, na, 4))和target_scores(shape(bs, na, num_class))。

接下来做一些代码方面的解释。

        在YOLOv8中,虽然使用了Anchor Free技术,但实际上也是存在Anchor的,那就是Feature Map本身的cell。接下来参照YOLOv8代码中的TaskAlignedAssigner做些了解。

(1) get_pos_mask

        这一部分主要是获得gt候选cell的标记(mask_pos),对齐度量矩阵(align_metric)和gt与cell的IOU矩阵(overlaps)。

mask_pos: shape(bs, n_max_boxes, na),经过筛选的gt候选cell位置标记;

align_metric: shape(bs, n_max_boxes, na), gt候选cell的度量值;

overlaps: shape(bs, n_max_boxes, na),gt与其候选cell的IOU值;

下面就几个关键的节点函数做一些讲解。

mask_gt: shape为(bs, n_max_labels, 1), 实际上,在处理的时候是构建一个GT tensor, shape为(bs, n_max_labels)。我们知道,batch中每一幅图片所拥有的gt box数量并不相同,因此我们需要使用一个mask来标记哪一些是有效的,哪一些是无效的。

select_candidates_in_gts

将每一个GT Box与所有的cells进行ltrb的计算,本质上是确定哪些cell的中心点落在了GT范围内。如图一所示,蓝色半透明框为GT,那么橙色狂所标识的cell都被选为候选cell。

图一 Candidates cells

最终返回一个shape(ngt, n_max_labels, na)的tensor。

get_box_metrics

一个关键的导入参数是mask_gt, 用来标记对应每一个gt,中心点位于该gt内部的cell索引,shape为(bs, n_max_boxes)。我们在此称gt候选cell

bbox_scores, shape(bs, n_max_boxes, na), 标识gt候选cell的得分,首先针对每一个gt,根据其lebel,获取对应所有cell的得分,然后通过mask_gt进行索引,得到每一个gt候选cell的得分。

overlaps,shape(bs, n_max_boxes, na), 标识gt候选cell的IOU信息。

align_metric, shape(bs, n_max_boxes, na), 对齐度量矩阵。

返回两个tensor, 其中第一个tensor是一种度量,shape为(bs, n_max_labels, total_cells)。第二各参数是gt与pred box的iou,shape为(bs, n_max_labels, total_cells)。

select_topk_candidates

首先通过torch.topk函数对metrics(align_metric)进行排序筛选,每个gt候选cell选取前topk个。得到topk_metricstopk_idxs, shape均为(bs, n_max_boxes, topk)。

counter_tensor, shape(bs, n_max_boxes, na), 取值非0即1,取值1代表当前cell的度量值位于前topk。

总结为如下4个步骤:

  • 构建gt候选cell;
  • 构建gt候选cell的得分矩阵,IOU矩阵和对齐度量矩阵;
  • 对对齐度量矩阵执行topk操作,标记符合topk的位置;
  • 使用topk、候选cell和mask_gt执行过滤。

(2) select_highest_overlaps

参数mask_pos实际上是gt候选cell标记矩阵。

fg_mask = mask_pos.sum(-2)

计算每一个cell对应的gt数量。

当某一个cell服务于多个gt时,我们将gt与cell的IOU进行排序,并取iou最大的gt作为cell所最终服务的gt。

(3) get_targets

构建用于计算loss的信息,包括target_labels, target_bboxes, target_scores。

target_labels: shape(bs, na, 1)

target_bboxes: shape(bs, na, 4)

target_scores: shape(bs, na, num_classes)

3. DFL

        DFL(Distribution Focal Loss),本质上是Focal Loss,是一种带权重的交叉熵。一般情况下,我们认为交叉熵常用作分类损失,根本上讲,是用在计算一种符合多项分布的预测Loss。

        在论文“Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection”中,作者认为预测的目标框坐标是固定的,不能够灵活的表示(如图二所示)。针对一些便捷比较模糊的目标,很难确定边界的具体位置。DFL将边界表示成一种分布,解决边界不明确的问题。关于DFL具体理论,我们将做一个专题讲解

图二 边界分布

        在官方代码中,网络输出pred_distri为一个shape(bs, 64, na)的Tensor,进一步permute为shape(bs, na, 64)的Tensor,再经过reshape为shape(bs, na, 4, 16)的Tensor,最后经过加权计算,获得shape(bs, na, 4)的LTRB输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/863748.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为OD机试真题 Java 实现【城市聚集度】【2023 B卷 200分】,附详细解题思路

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷&#…

Linux网络服务之DNS域名解析

重要的DNS域名解析 一、DNS概述1.1 DNS简介1.2 本地hosts文件1.3 DNS架构1.4 查询方式 二、DNS域名解析原理2.1 解析类型2.2 原理详解2.3 举例 三、bind服务端程序3.1 什么是bind?3.2 配置文件详解3.2.1 主配置文件概述及内容主要格式3.2.2 域名文件概述及内容主要格…

leetcode118. 119.杨辉三角

118 题目: 给定一个非负整数 numRows,生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中,每个数是它左上方和右上方的数的和。 思路: 可以发现从第三行开始,从第二个元素到倒数第二个元素,每个元素都…

电视盒子什么品牌好?实测20天后分享电视盒子推荐

电视盒子可以让老旧电视机重生,解决卡顿、资源少等问题,只需要联网就能观看海量视频资源。不过对于电视盒子如何选购很多人并不了解,我通过对比十几款主流电视盒子后整理了这份电视盒子推荐清单,跟着我一起看看电视盒子什么品牌好…

记一件异常访问记录

一、问题描述 某安全护网期间,web日志中大量出现异常账户请求,虽然报404错误,但是不同异常账号的连续尝试在特殊时期,还是令人担忧. 进程如下:/usr/bin/python2 -Es /usr/sbin/tuned -l -P 二、处理及说明 1&#x…

涛思数据联合长虹佳华、阿里云 Marketplace 正式发布 TDengine Cloud

近日,涛思数据联合长虹佳华,正式在阿里云 Marketplace 发布全托管的时序数据云平台 TDengine Cloud,为用户提供更加丰富的订购渠道。目前用户可通过阿里云 Marketplace 轻松实现 TDengine Cloud 的订阅与部署,以最低的成本搭建最高…

跨境多语言商城源码搭建--定制代码+源码开源

搭建一个跨境多语言商城需要以下步骤: 1. 确定需求:首先,需要明确商城的功能和需求,比如支持哪些语言、支持哪些支付方式、支持哪些货币等。根据需求来决定使用的开发语言和技术栈。 2. 寻找源码:可以在互联网上搜索…

赛码网-上台阶(dp) 100%AC代码(C)

———————————————————————————————————— ⏩ 大家好哇!我是小光,嵌入式爱好者,一个想要成为系统架构师的大三学生。 ⏩最近在准备秋招,一直在练习编程。 ⏩本篇文章对赛码网的上台阶 题目做一个…

【Pytorch:nn.Embedding】简介以及使用方法:用于生成固定数量的具有指定维度的嵌入向量embedding vector

文章目录 1、nn.Embedding2、使用场景 1、nn.Embedding 首先我们讲解一下关于嵌入向量embedding vector的概念 1)在自然语言处理NLP领域,是将单词、短语或其他文本单位映射到一个固定长度的实数向量空间中。嵌入向量具有较低的维度,通常在几…

医院国际医疗中心智能化系统规划设计方案[81页PPT]

导读:原文《医院国际医疗中心智能化系统规划设计方案[81页PPT]》(获取来源见文尾),本文精选其中精华及架构部分,逻辑清晰、内容完整,为快速形成售前方案提供参考。 完整版领取方式 完整版领取方式&#xff…

如何对分布式光伏发电站进行智能化监测?安科瑞 顾语欢

—、概述 随着“双碳”目标提出及逐步落实,本就呈现出较好发展势头的分布式光伏发展有望大幅提速。“双碳”目标意味 着国家产业结构的调整,未来10年,新能源装机将保持在110GW以上的年增速,这里面包含集中式光伏电站和分布式光伏…

02 - git 文件重命名

第一种方式: mv kongfu_person.txt kongfu.txt git add .第二种方式: git mv kongfu_person.txt kongfu.txt

Baklib: 逆袭语雀的在线帮助中心,知识库管理工具

1. 介绍 在现代的技术发展中,知识管理变得越来越重要。特别是对于企业来说,拥有一个高效的知识库管理工具可以极大地提高工作效率和团队合作。Baklib就是这样一款在线帮助中心和知识库管理工具,它可以帮助企业集中管理和共享知识&#xff0c…

红帽8.2版本CSA题库:第七题配置 NTP

红帽8.2版本CSA题库:第七题配置 NTP systemctl status chronyd #查看状态 yum -y install chrony #如果没有安装,就安装一下 vim /etc/chrony.conf server materials.example.com iburst :wq syste…

MySQL缓存策略

文章目录 一、MySQL缓存方案的作用二、提高MySQL访问性能的方式2.1 读写分离2.1.1 是什么?2.1.2 解决了什么?2.1.3 原理是什么? 2.2 连接池2.1.1 是什么?2.1.2 解决了什么?2.1.3 原理是什么? 2.3 异步连接2…

【论文阅读】NoDoze:使用自动来源分类对抗威胁警报疲劳(NDSS-2019)

NODOZE: Combatting Threat Alert Fatigue with Automated Provenance Triage 伊利诺伊大学芝加哥分校 Hassan W U, Guo S, Li D, et al. Nodoze: Combatting threat alert fatigue with automated provenance triage[C]//network and distributed systems security symposium.…

Arcgis中直接通过sde更新sqlserver空间数据库失败

问题 背景 不知道有没有人经历过这样一个情况,我们直接在Arcgis中通过sde更新serserver数据库会失败,就是虽然在sde更新sqlserver数据库,但是在Navicat中通过sql语句来查询,发现数据并没有更新,如:上图中,更新数据库后,第一张图是sde打开的sqlserver数据库,它的数据库…

自动测试框架airtest应用一:将XX读书书籍保存为PDF

一、Airtest的简介 Airtest是网易出品的一款基于图像识别和poco控件识别的一款UI自动化测试工具。Airtest的框架是网易团队自己开发的一个图像识别框架,这个框架的祖宗就是一种新颖的图形脚本语言Sikuli。Sikuli这个框架的原理是这样的,计算机用户不需要…

24届近5年南京工业大学自动化考研院校分析

今天给大家带来的是南京工业大学控制考研分析 满满干货~还不快快点赞收藏 一、南京工业大学 学校简介 南京工业大学(Nanjing Tech University),简称“南工”,位于江苏省南京市,由国家国防科技工业局、住…