【YOLO系列】YOLOX(含代码解析)

news2024/9/29 10:51:27

文章目录

    • 环境配置
      • demo测试
      • 转换成onnx
    • YOLOX
      • 数据增广
      • decoupled head
      • Anchor-free
      • 标签分配
        • get_geometry_constraint
        • SimOTA
    • 总结
    • 参考

【YOLO系列】YOLO v3(网络结构图+代码)
【YOLO 系列】YOLO v4-v5先验知识
【YOLO系列】YOLO v4(网络结构图+代码)
【YOLO系列】YOLO v5(网络结构图+代码)

环境配置

购买的腾讯云服务器,配置为GPU计算型GN7/8核/32GB/5Mbps。conda创建虚拟环境,python版本选择的是3.7。

conda create -n yolox python=3.7
conda activate yolox
git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip install -v -e .

然而在pip 安装时,出现下述错误。issues/1368中存在同样的问题,需要在安装yolox之前,手动安装torch。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ht41Lcsc-1687770183365)(./1687142812458.png)]

在pytorch官方网站上找到GPU版本torch的安装命令。

conda install pytorch1.13.0 torchvision0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia

然后再运行pip install -v -e . 命令即可安装成功。我的环境配置如下所示,这个配置没有出现pycocotools的安装错误。

conda 23.3.1
python 3.7
torch 1.13.0
cuda 11.6
cudnn 8302

demo测试

下载yolox_tiny的预训练模型,然后执行如下命令。

python tools/demo.py image -f exps/default/yolox_tiny.py -c yolox_tiny.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device gpu

在这里插入图片描述

转换成onnx

tools目录下的export_onnx文件可以将pth文件转换成onnx文件,这样就可以清晰地查看网络结构。转换onnx模型时,需要检测是否下载了onnxruntime库文件。onnx的算子版本默认是11。

python tools/export_onnx.py --output-name yolox_tiny.onnx -n yolox-tiny -c yolox_tiny.pth

YOLOX

YOLOX的基础模型是YOLOv3-SPP(backbone为DarkNet53+SPP),相比于YOLOv3-SPP,YOLOX添加了EMA权重更新,cosine lr schedule,IoU loss和IoU-aware分支。

YOLOX的backbone和Neck与YOLOv3-v4相比,无太大改动,此篇文章不再赘述YOLOX的backbone和Neck部分。

数据增广

YOLOX的数据增广方式有RandomHorizontalFlip、ColorJitter、multi-scale 、Mosaic和MixUP。舍弃了RandomResizedCrop增广方式,因为官方发现RandomResizedCrop与Mosaic的功能有些重叠。

Mosaic和Mixup需要再训练结束前的15个epoch关掉。官方给出的解释是,Mosaic+Mixup生成的训练图片,脱离自然图片的真实分布,而且Mosaic大量的crop操作会带来很多不准确的标注框。提前关闭Mosaic和Mixup,能够使检测器避开不准确标注框的影响,在自然图片的数据分布下完成最终的收敛。

decoupled head

在YOLO v3-v5版本中,head头同时预测输出目标位置和类别,每个特征图的预测张量为 N × N × [ 3 ∗ ( 4 + 1 + c l a s s e s _ n u m ) ] N \times N \times [ 3*\left( 4 + 1 + classes\_num \right) ] N×N×[3(4+1+classes_num)],其中 N N N是特征图的大小。这种被称为耦合头(Coupled head),它的设计思路简单,仅需要几个全连接层或者卷积层即可。

YOLO中使用的解耦合头(Decoupled head),使用不同的分支分别预测目标位置和类别信息。首先先使用一个 1 × 1 1 \times 1 1×1的卷积减少通道维数,然后并行分类和回归两个分支,每个分支堆叠两个 3 × 3 3 \times 3 3×3的卷积。分类分支用来预测类别信息;回归分支用来输出bbox的位置信息。在回归分支中,添加了IoU 分支。

在这里插入图片描述

在训练过程中,解耦合头相比于耦合头,能够更快地收敛,并且能够学习地更精确。
在这里插入图片描述

Anchor-free

YOLOv1-v5版本皆基于anchor的模型。为了能够取得更好的性能,在训练之前,通常都会通过聚类分析训练数据集的方式确定anchor box的大小,确定的anchor box一般适用于该数据集,通用性不强。而且Anchor的存在增加了检测头的复杂度以及生成结果的数量。YOLOX是Anchor Free的模型,Anchor-Free机制显著减少了需要启发式微调的参数设计的数量和许多其他技巧,比如anchor聚类和grid sensitive。Anchor-Free的机制使得检测在训练和解码阶段,变得相对简单。

将YOLO模型转换成anchor-free的形式是很简单的。将每个位置的预测bbox由3个减少到1个,并且直接预测bbox的四个值,即网格左上角的两个偏移量和预测bbox的宽高。

下图是YOLOX-tiny模型的head部分,图像的输入尺寸是 416 × 416 × 3 416 \times 416 \times 3 416×416×3。YOLOX的基础模型是YOLOv3-SPP,Neck是FPN,有三个预测head头,每个head的输入特征尺寸分别是 52 × 52 52 \times 52 52×52 26 × 26 26 \times 26 26×26 13 × 13 13 \times 13 13×13,相比于原输入图像尺寸,下采样分别是 8 8 8 16 16 16 32 32 32倍。三个head的输出合并转置之后,特征图大小为 1 × 3549 × 85 1 \times 3549 \times 85 1×3549×85,其中 85 85 85是COCO数据集中80个类别预测概率+4个坐标信息+1个IoU值;YOLOX每个网格仅仅预测一个bbox,那么就有 3549 3549 3549个预测bbox。

如果是anchor-base的模型,同样的图像输入,同样的模型结构,则有 3 × ( 13 × 13 + 26 × 26 + 52 × 52 ) × 85 = 904995 3\times \left( 13 \times 13+26 \times 26+52 \times 52\right) \times 85=904995 3×(13×13+26×26+52×52)×85=904995个预测结果。Anchor-Free的预测结果个数为 3549 × 85 = 301665 3549 \times 85 = 301665 3549×85=301665,约等于anchor-base模型的 1 / 3 1/3 1/3

在这 3549 3549 3549个预测框中,其中有2704个预测框所对应的锚框的大小为 8 × 8 8\times 8 8×8; 有676个预测框所对应的锚框的大小为 16 × 16 16\times 16 16×16;有169个预测框所对应的锚框的大小为 32 × 32 32 \times 32 32×32。这 3549 3549 3549个预测框中,只有少部分是正样本,绝大多数是负样本,那么哪些是正样本?如何将正样本预测框挑选出来呢?思路就是将这 3549 3549 3549个预测框和图片上所有的gt框进行关联,从而挑选出正样本,这种关联方式,被称为标签分配。

在这里插入图片描述

标签分配

在YOLOX的官方代码中models/yolo_head.py中get_assignments函数定义了如何进行标签分配,调用了两个主要的函数get_geometry_constraintsimota_matching。get_geometry_constraint函数的主要目的就是:提取落在gt bboxes 一定范围内的所有候选框。simota_matching函数是SimOTA求解。标签分配问题可以转换为标准的OTA问题,但是Sinkhorn-Knopp算法需要多次迭代才能求得最优解,官方发现该算法会导致25%的额外训练时间,因此采用简化版的OTA方法,SimOTA,求解近似最优解。

get_geometry_constraint

get_geometry_constraint的代码如下所示。首先根据偏移量(x_shifts, y_shifts)恢复预测锚框在图片的中心位置;接下来以gt锚框的中心点gt_bboxes_per_image[:, 0:2]为中心,设置边长为 2 × c e n t e r _ d i s t 2 \times center\_dist 2×center_dist的正方形,计算出正方形的左上角(gt_l, gt_t)和右下角(gt_r, gt_b);然后计算预测锚框的中心点与正方形左上角(gt_l, gt_t)和右下角(gt_r, gt_b)的距离c_l, c_t, c_r, c_b,判断c_l, c_t, c_r, c_b是否都大于0,这样就可以将落在正方形范围之内的候选框提取出来了。

def get_geometry_constraint(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts):
       """
       Calculate whether the center of an object is located in a fixed range of an anchor. This is used to avert inappropriate matching. It can also reduce the number of candidate anchors so that the GPU memory is saved.
       """
       expanded_strides_per_image = expanded_strides[0]
       # 锚框在图片上的的中心点(x,y)
       x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
       y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)

       # in fixed center
       center_radius = 1.5
       center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius
       # 左上角(gt_l, gt_t),右上角(gt_r, gt_b)
       gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist
       gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
       gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
       gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist

       # 计算锚框中心点(x,y)与左上角(gt_l, gt_t)、右下角(gt_r, gt_b)两个角点的相应距离
       c_l = x_centers_per_image - gt_bboxes_per_image_l
       c_r = gt_bboxes_per_image_r - x_centers_per_image
       c_t = y_centers_per_image - gt_bboxes_per_image_t
       c_b = gt_bboxes_per_image_b - y_centers_per_image
       # 堆叠
       center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
       # 是否都大于0
       is_in_centers = center_deltas.min(dim=-1).values > 0.0
       anchor_filter = is_in_centers.sum(dim=0) > 0
       # 将落在gt矩形范围内的所有anchors,都提取出来了
       geometry_relation = is_in_centers[:, anchor_filter]

       return anchor_filter, geometry_relation

SimOTA

get_geometry_constraint函数提取落在gt bboxes 一定范围内的所有候选框之后,根据返回的候选框mask分别提取候选预测框bboxes,类别分数cls_preds和前景背景目标分数obj_preds,然后再分别进行回归和分类交叉熵Loss函数计算、计算cost值。相关代码和注释如下所示。

fg_mask, geometry_relation = self.get_geometry_constraint(gt_bboxes_per_image,expanded_strides,x_shifts,y_shifts)
# 依据候选框mask提取候选预测框
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
# 依据候选框mask提取类别分数
cls_preds_ = cls_preds[batch_idx][fg_mask]
# 依据候选框mask提取前景背景目标分数
obj_preds_ = obj_preds[batch_idx][fg_mask]
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

# 计算gt bboxes和预测bboxes之间的IoU
pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
gt_cls_per_image = (F.one_hot(gt_classes.to(torch.int64),
self.num_classes).float())
# 回归loss
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

with torch.cuda.amp.autocast(enabled=False):
    cls_preds_ = (cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()).sqrt()
    # 分类交叉熵loss
    pair_wise_cls_loss = F.binary_cross_entropy(
        cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
        gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
        reduction="none"
    ).sum(-1)
del cls_preds_
# pair-wise matching degree,cost计算
cost = (pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + float(1e6) * (~geometry_relation))

SimOTA的代码和注释如下所示。经过get_geometry_constraint函数从原 3549 3549 3549个预测框初筛一部分预测框之后,SimOTA算是精细化筛选。首先设置候选框的最小数量,从pair_wise_ious中挑选出前k个IoU最大的候选框;接下来再通过topk_ious的信息,动态选择候选框;然后得到matching_matrix;最后过滤掉公用的候选框。

def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    # 第一步:设置候选框的数量
    # 根据cost值的大小,新建一个全0的matching_matrix
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
    # 设置候选框的最小数量
    n_candidate_k = min(10, pair_wise_ious.size(1))
    # 从pair_wise_ious中挑选n_candidate_k个IoU最大的候选框
    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
    
    # 第二步:通过topk_ious动态挑选候选框
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
    # 第三步:得到matching_matrix
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False)
        matching_matrix[gt_idx][pos_idx] = 1
    del topk_ious, dynamic_ks, pos_idx

    # 第四步:过滤公用的候选框
    anchor_matching_gt = matching_matrix.sum(0)
    # deal with the case that one anchor matches multiple ground-truths
    if anchor_matching_gt.max() > 1:
        multiple_match_mask = anchor_matching_gt > 1
        _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
        matching_matrix[:, multiple_match_mask] *= 0
        matching_matrix[cost_argmin, multiple_match_mask] = 1
    fg_mask_inboxes = anchor_matching_gt > 0
    num_fg = fg_mask_inboxes.sum().item()

    fg_mask[fg_mask.clone()] = fg_mask_inboxes

    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    gt_matched_classes = gt_classes[matched_gt_inds]

    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes]
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

总结

上述描述的anchor-free方法,每个物体仅仅只选择了一个正样本,这样会忽略掉一些高质量的预测框。参考FCOS算法,简单地赋予中心 3 × 3 3 \times 3 3×3区域为正样本,这种方法被成为multi positives。

YOLOX在YOLO v3 的基础上,增添了一些技巧。decoupled head提升了1.1%,Mosaic和MixUP更强的增强方式又提升了2.4%,anchor-free的方式提升了0.9%,multi positives也提升了2.1%,SimOTA的提出不仅减少了训练时间,AP还提升了2.3%,比ultralytics的YOLO v3的AP还增加3%。

在这里插入图片描述

参考

  1. Megvii-BaseDetection/YOLOX
  2. YOLOX: Exceeding YOLO Series in 2021
  3. 如何评价旷视开源的YOLOX,效果超过YOLOv5?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/687610.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年MathorCup 高校数学建模挑战赛-A 题 量子计算机在信用评分卡组合优化中的应用-思路详解(模型代码答案)

一、题目简析 运筹优化类题目,不同于目标规划,该题限制了必须使用量子退火算法QUBO来进行建模与求解。本身题目并不难,但是该模型较生僻,给出的参考文献需要耗费大量时间去钻研。建议擅长运筹类题目且建模能力强的队伍选择。 二…

用C语言进行学生成绩排序(插入排序算法)

一.排序算法 1.排序 从今天开始我们就要开始学习排序算法啦! 排序,就是重新排列表中的元素,使表中的元素满足按关键字有序的过程。为了查找方便,通常希望计算机中的表是按关键字有序的。 2.稳定性 除了我们之前了解的时间复杂度和空间复…

基于springboot+Redis的前后端分离项目之分布式锁(四)-【黑马点评】

🎁🎁资源文件分享 链接:https://pan.baidu.com/s/1189u6u4icQYHg_9_7ovWmA?pwdeh11 提取码:eh11 分布式锁 分布式锁1 、基本原理和实现方式对比2 、Redis分布式锁的实现核心思路3 、实现分布式锁版本一4 、Redis分布式锁误删情况…

S3版本控制,复制和生命周期配置

Hello大家好。 在本课时我们将讨论S3的三个功能特性,这三个特性有一些相关性,即版本控制,复制和生命周期配置。 S3版本控制 首先版本控制,是将对象的多个版本保存在同一存储桶的方法。换句话说,您上传一个对…

数据结构--顺序表的查找

数据结构–顺序表的查找 顺序表的按位查找 目标: GetElem(L,i):按位查找操作。获取表L中第i个位置的元素的值。 代码实现 #define MaxSize 10 typedef struct {ElemType data[MaxSize];int len; }Sqlist;ElemType GetElem(Sqlist L, int i) {return L.data[i-1]…

海外问卷调查项目可靠吗?是违法的吗?

可靠。 最近,一个备受瞩目的创业项目在社会上引起了广泛关注,这个项目集创业、全职和兼职于一体,被称为"海外问卷调查项目",成为了无数人追逐的新选择。 然而,自中美贸易摩擦以来,中国人对&quo…

使用CloudOS快速实现K8S容器化部署

关于容器技术 容器技术(以docker和Kubernetes为代表)呱呱坠地到如今,在国内经历了如下3个阶段: 婴儿期:2014-2016年的技术探索期; 少儿期:2017-2018年的行业试水期; 少年期&…

1.设计模式之七大原则和介绍

0.为什么我要学习设计模式呢? 我发现mysql的jdbc有factory有工厂模式(编程思想,不指定语言都可以用) mq有一个QueueBuilder().setArg().xxx().build建造者模式,单例模式貌似也遇到过,前端也遇到了好几个设计模式的问题,比如prototype深拷贝和浅拷贝 所以我决定系统的学习一下设…

TC8:SOMEIP_ETS_004-005

SOMEIP_ETS_004: Burst_Test 目的 检查DUT是否可以在短时间内处理突发请求并返回所有请求的响应 测试步骤 Tester:新建有效SOME/IP消息Tester:使用method echoUINT8发送突发SOME/IP Request消息DUT:返回每个请求消息的响应消息期望结果 3、DUT:返回每个请求消息的响应消息…

学redis这一篇就够了

目录 1.下载安装启动 1.1 临时启动服务 2.2 默认服务安装 2.常用五大基本数据类型 2.1 key操作 2.2 字符串(String) 2.3 列表(List) 2.4 Set(集合) 2.5 Hash(哈希) 2.6 Zs…

分离表示学习:通用图像融合框架

IFSepR: A General Framework for Image Fusion Based on Separate Representation Learning (IFSepR:一种基于分离表示学习的通用图像融合框架) 提出了一种基于分离表示学习的图像融合框架IFSepR。我们认为,基于先验知识的共模…

Fast Segment Anything Model(FastSAM)

Fast Segment Anything Model(FastSAM) Fast Segment Anything Model(FastSAM)是一个仅使用SAM作者发布的SA-1B数据集的2%进行训练的CNN Segment Anything模型。FastSAM在50倍的运行速度下实现了与SAM方法相当的性能。 SAM代码&a…

pubg 依赖安装

一、安装python 1、进入官网 https://www.python.org/ 2、勾选Add python.exe to PTHA 3、自定义下载 测试和文档不需要勾选,然后next 4、自定义安装路径 点击install安装 安装成功,点击close。 5、测试 windr键,输入cmd 输入python回…

基于SSM的餐厅点餐系统设计与实现(Java+MySQL)

目 录 第一章 绪论 1 1.1系统研究背景和意义 1 1.2研究现状 1 1.3论文结构 2 第二章 相关技术说明 3 2.1 JSP(Java Server Page)简介 3 2.2 Spring框架简介 4 2.3 Spring MVC框架简介 5 2.4 MyBatis 框架简介 5 2.4 MySql数据库简介 5 2.6 Tomcat简介 6 2.7 jQuery简介 7 2.8系…

计算机毕业论文内容参考|基于大数据的信息物理融合系统的分析与设计方法

文章目录 导文摘要前言绪论课题背景国内外现状与趋势:课题内容:相关技术与方法介绍:系统架构设计:数据采集与处理:数据存储与管理:数据分析与挖掘:系统优化与调试:应用场景:挑战与机遇:研究方向:系统分析:系统设计:系统实现:系统测试:总结与展望:

SpringBoot原理(1)--@SpringBootApplication注解使用和原理/SpringBoot的自动配置原理详解

文章目录 前言主启动类的配置SpringBootConfiguration注解验证启动类是否被注入到spring容器中 ComponentScan 注解ComponentScan 注解解析与路径扫描 EnableAutoConfiguration注解 问题解答1.AutoConfigurationPackage和ComponentScan的作用是否冲突起因回答 2.为什么能实现自…

WIN10上必不可少的5款优质软件

噔噔噔噔,作为一个黑科技软件爱好者,电脑里肯定是不会缺少这方面的东西,今天的5款优质软件闪亮登场了。 颜色拾取器——ColorPix ​ ColorPix是一个颜色拾取器工具,可以让你快速地获取屏幕上任意位置的颜色值,如RGB、…

ivshmem-plain设备原理分析

文章目录 前言基本原理共享内存协议规范 具体实现设备模型数据结构设备初始化 测试验证方案流程Libvirt配置Qemu配置测试步骤 前言 ivshmem-plain设备是Qemu提供的一种特殊设备,通过这个设备,可以实现虚机内存和主机上其它进程共存共享,应用…

618美妆个护28个榜单:欧莱雅稳住冠军?珀莱雅大爆发第二?

存量时代的购物造节大竞争,作为消费复苏后的首场大促,今年的618堪称史上最“卷”,也承载着消费振兴、经济复苏等希望。 不过,今年所有平台都未公布具体GMV,某种程度说明大促造节的时代俨然已成过去式了。 5月18日&am…

怎么去除视频里的背景音乐?其实非常简单!

如何去除视频背景音乐?在视频处理中,有时我们需要从视频中提取声音并进行处理,而不仅仅是简单地去除整个背景音乐。我们可能需要有选择性地去除人声或背景音乐。这个处理过程对于选用合适的工具至关重要。在本文中,我将分享两种可…