Yolov10(yolov8代码里兼容版本)推理代码解析,抛去nms,大道至简

news2024/11/17 1:41:45

一、模型的输出头

下载官方的yolov8代码库https://github.com/ultralytics/ultralytics
打开ultralytics/nn/modules/head.py,主要需要看一下模型的输出头是如何做训练和预测推理。
在这里插入图片描述
v10检测头继承与常规的检测头Detect,初始化里重构了一下分类的输出头self.cv3,多加了一些卷积层。并将end2end这个参数置为True
在这里插入图片描述
再来看Detect检测头里如何兼容v10检测的

由于end2end是True.
在这里插入图片描述
所以走forward_end2end()

    def forward_end2end(self, x):
        """
        Performs forward pass of the v10Detect module.

        Args:
            x (tensor): Input tensor.

        Returns:
            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
                           If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}

        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

将网络端到端的3个输出头拼接得到one2one
在这里插入图片描述
one2one为1对1训练输出头
x为1对多训练输出头
如果self.training为True,
即你在训练的时候返回一个字典 {“one2many”: x, “one2one”: one2one},用于e2e训练。
在这里插入图片描述
如果是评估或者预测图片,先走推理self._inference再做后处理self.postprocess
在这里插入图片描述
推理只需要获取1对1输出头的结果即可,对box进行编码self.decode_bboxes。

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

return的输出是dbox(4)+cls(14),batch=1,为(1,18,8400)
在这里插入图片描述

二、后处理

在这里插入图片描述

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes YOLO model predictions.

        Args:
            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
                format [x, y, w, h, class_probs].
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.

        Returns:
            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
                dimension format [x, y, w, h, max_class_prob, class_index].
        """
        batch_size, anchors, predictions = preds.shape  # i.e. shape(16,8400,84)
        boxes, scores = preds.split([4, nc], dim=-1)
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        scores, index = scores.flatten(1).topk(max_det)
        i = torch.arange(batch_size)[..., None]  # batch indices
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)

1、首先获取到批次,预测锚框,和预测box+cls

batch_size, anchors, predictions = preds.shape

在这里插入图片描述
2、单独获取预测box和每个box的所有分类得分

 boxes, scores = preds.split([4, nc], dim=-1)

在这里插入图片描述
3、获取分类得分值最大的前300个框

index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)

topk()会获取到scores的value和对应的索引indices,这里我们只需要获取到索引即可.topk(min(max_det, anchors))[1]
在这里插入图片描述
在这里插入图片描述
4、根据索引挑出这300个框

boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))

在这里插入图片描述
5、根据索引挑出300个分类得分,即300个boxes对应的14类别的分类得分.

scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))

在这里插入图片描述
6、挑出所有分类的得分中的前300个框。

scores, index = scores.flatten(1).topk(max_det)

在这里插入图片描述
7、获取batch的索引

i = torch.arange(batch_size)[..., None]  # batch indices

在这里插入图片描述
8、返回300个得分值最高的框

 torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)

index // nc:得分值前300的索引所属的box,某一个框可能有2个分类得分都在前300,则这个框会出现2次。
scores[…, None]:框的分数
(index % nc)[…, None].float():框的类别

在这里插入图片描述

三、预测输出

ultralytic/models/yolo/detect/predict.py

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions and returns a list of Results objects."""
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            classes=self.args.classes,
        )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results

v8源码里v10的predict的后处理也走了nms函数但是并没有做nms处理
在这里插入图片描述
nms函数中,如果判断预测结果是属于v10的end2end的模型预测结果,则直接从300个候选框中输出大于置信度conf_thres的框作为最终的输出结果。

四、总结

1、端到端的模型抛弃了复杂的后处理过程,不再需要转模型的时候对齐精度,直拿直用,必定是未来研究的重点趋势。
2、跟同事讨论,她在她的大数据集上测试v10的结果甚至优于v8的训练结果。这可能得出一个结论,当你的数据集足够大且足够干净的情况下,v10的结果反而会更好,当然这需要各位再多多测试了。
3、后续博主也准备把分割关键点旋转框等都修改成v10的端到端模式,敬请期待。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2088265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pacs图像打不开怎么办 --日常工作总结

先强调一下,我不是专门做图像入库和图像归档,我负责的是临床这边的影像,下面是占在我的业务日常分析总结的哈,(不太专业,勿喷) 我们经常会遇到在打开某个检查的时候,出现黑框,日志定位wado服务取不到图 这种情况一般分为 (1) 工作站,工作组,路由,存储卷配置缺失 ---对应的wad…

【C语言从不挂科到高绩点】05-流程控制语句-switch语句

Hello!彦祖们,俺又回来了!!!,继续给大家分享 《C语言从不挂科到高绩点》课程 本套课程将会从0基础讲解C语言核心技术,适合人群: 大学中开设了C语言课程的同学想要专升本或者考研的…

Linux上启动redis

1.默认启动方式:在系统的任意位置执行 redis-server即可启动 ps:这是前端界面启动,无法直接连接redis,想要连接的话只能另外启动一个窗口,因此下面我们介绍后台启动redis 2.指定配置启动: redis的配置文件位置&#xff1a…

二叉树的基本知识

(写给未来遗忘的自己) 1.二叉树的种类 1. 满二叉树:所有分支都有数(都填满) 2. 完全二叉树:除了最底层没填满外其他的都满了,而且最底层从左到右是存在数的位置是连续的 3.二叉搜索树&#xf…

大学开学必备好物清单有哪些?开学必备清单大全,超详细版!

即将踏入大学校园的新生们,是否已经准备好迎接全新的挑战与机遇呢?在开学之前,将必备物品筹备妥当是极为重要的事情,因为这能够助力大家更为良好地适应大学生活。接下来,为大家提供一份实用的大学生开学必备物品清单&a…

性能测试面试题总结

最近这一年,对性能测试有了更多的认知。 压力、强度测试:在一定软硬件环境下,通过高负载的手段来使服务器资源(强调服务器资源,硬件资源)处于极限状态,测试系统在极限状态下长时间运行是否稳定…

大模型微调---qwen实战

一、Qwen大模型的介绍 Qwen是阿里云开发的大语言模型,整个qwen系列的模型,由base模型、rm模型、chat模型、code模型、math模型等等。 qwen采用chatml样式的格式来进行模型训练,chatml格式可以时模型有效区分各类信息,可以增强模…

blender插件库

插件安装教程:blender4.2中安装插件的方式-CSDN博客 blender官网插件库地址:Add-ons — Blender Extensions 1,ExtraObjects:提供更多网格形状, 链接:https://caiyun.139.com/m/i?2gov6Lw5RAib8 提取码:0ayj 复制内…

有向图的转置:算法分析与实现

有向图的转置:算法分析与实现 前言1. 邻接链表表示法2. 邻接矩阵表示法结论前言 在计算机科学中,图是一种非常重要的数据结构,用于表示对象之间的复杂关系。有向图(Directed Graph)是一种图,其边具有方向性。有向图的转置(Transpose)是一种基本操作,它将图中所有边的…

LLM面经(持续更新中)

Tokenizer Norm Batch Norm 好处 使得模型训练收敛的速度更快 每层的数据分布都不一样的话(解决Internal Covariance Shift),将会导致网络非常难收敛和训练,而如果把每层的数据都在转换在均值为零,方差为1的状态下,这样每层数据…

第一个golang项目

第一个golang项目 开发环境安装golangVisual Studio Code安装golang语言插件初始化项目创建目录初始化golang配置 开始开发安装所需依赖创建main.go创建配置文件创建命令版本命令查看指定目录指定后缀文件并将指定内容替换为新内容 打包并运行 前因后果:因为工作需要…

不可错过的10款电脑监控软件推荐,电脑监控软件哪个好?宝藏安利

电脑监控软件已成为企业管理和家庭安全的重要工具。 无论是为了提升工作效率、保障信息安全,还是为了监督孩子的学习情况,一款优秀的电脑监控软件都能发挥巨大作用。 本文将为您推荐10款不可错过的电脑监控软件,并详细分析它们的优势与特点&…

Elastic Stack(三):Logstash介绍及安装

目录 1 Logstash介绍1.1 组件介绍1.2 Logstash 工作原理 2 Logstash安装2.1 logstash-源码包安装8.1.01、logstash安装2、创建配置文件3、启动4、配置快速启动文件 1 Logstash介绍 1.1 组件介绍 Logstash是一个开源数据收集引擎,具有实时管道功能。Logstash可以动…

财富趋势金融大模型已通过备案

财富趋势金融大模型已通过备案 8月28日晚,国内领先的证券软件与信息服务提供商——财富趋势,公布了其2024年上半年财务报告: 今年上半年,财富趋势营收1.48亿元,同比增长0.14%;实现归母净利润为1亿元&#x…

适用于 Windows 的文件恢复软件

我很遗憾我在 Windows中从 PC 中删除了数据并再次移动了它们。当我检查时,什么都没有。是否有任何 Windows 数据恢复软件,或者是否可以想象?我会看到任何援助的价值。 文档、图像、音频等数据文件可能会因意外删除、感染攻击、系统崩溃等不良…

mac os系统

各种各样的系统优缺点-CSDN博客 目录 一:mac os是什么系统?图形用户界面的革命性操作系统 二:mac os是什么系统:高性能和无缝衔接,功能丰富、安全可靠 三:mac os是什么系统:全新界面设计和卓…

Tomcat 环境配置及部署Web项目

一.环境 Java Tomcat 二.Java环境 1.下载安装JDK 2.修改及新建环境变量 3.查看Java 版本 三.Tomcat 环境 1.下载及解压Tomcat 2.配置环境变量 3.验证安装,运行startup.bat 访问:http://localhost:8080/ 三.Web项目 1.修改Tomcat配置文件 2.拷贝W…

Python将Latex公式插入到Word中

推荐一个库,可以使用python将Latex公式插入到Word中显示 使用pip进行安装: pip install latex2word 示例将如下公式插入到word 公式1: f(x) \int_{-\infty}^\infty \hat f(x)\xi\,e^{2 \pi i \xi x} \,\mathrm{d}\xi 公式2: \int x^{\mu}…

重生奇迹MU 小清新职业智弓MM

游戏中有一种令人迷醉的职业——智弓MM,她们以高超的射箭技能闻名于世。本文将为您介绍这个悠闲的小清新职业,在游戏中的特点以及如何成为一名出色的智弓MM。跟随我们一起探索这个奇妙而神秘的职业吧! 悠闲的游戏节奏是游戏的初衷之一&#…

Dataease1.8.23 local本地安装

1、安装视频 手把手带你安装DataEase(一)Local模式部署 DataEase 免费开源BI工具 开源数据可视化分析工具 2、图文 安装模式 - DataEase 文档 注意点: 1、数据库:mysql 1)my.cnf 新增配置: #忽略大小…