yolov5模型Detection输出内容与源码详细解读

news2024/11/20 8:45:24

文章目录

  • 前言
  • 一、Detiction类源码说明
  • 二、Detection类初始化参数解读
  • 三、Detection的训练输出源码解读
  • 四、Detection的预测输出源码解读
    • 1、self.grid内容解读
    • 2、xy/wh内容解读
    • 3、推理输出解读
  • 总结


前言

最近,需要修改yolov5推理结果,通过推理特征添加一些其它操作(如蒸馏)。显然,你需要对yolov5推理输出内容有详细了解,方可被你使用。为此,本文将记录个人对yolov5输出内容源码解读,这样对于你修改源码或蒸馏操作可提供理论参考。


一、Detiction类源码说明

yolov5的detection类输出包含2个部分,一个是训练的输出,一个是预测的输出。而我将在这里解释类参数与训练、预测输出内容。为什么我使用一篇文章来说明?显然是训练与预测输出内容与原始图像尺寸、模型尺寸、特征尺寸以及回归box含义与使用细节。

整体源码如下:

class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

二、Detection类初始化参数解读

我已voc数据为例说明,voc数据类别为20。我将其具体解释注释其代码中,如下:

def __init__(self, nc=20, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
	 self.nc = nc  #   =20 类别数量
	 self.no = nc + 5  # =25 每个类别需添加位置与置信度
	 self.nl = len(anchors)  # =3 yolov5检测层为3,每层特征有一个anchor,可使用anchors替代 
	 self.na = len(anchors[0]) // 2  # =3 获得每个grid的anchor数量
	 self.grid = [torch.zeros(1)] * self.nl  # =[tensor([0.]), tensor([0.]), tensor([0.])],初始化grid,后期会每个特征变成[1,3,feature_w,feature_h,2],共三个
	 self.anchor_grid = [torch.zeros(1)] * self.nl  # 初始化anchor grid,与上面self.grid类似,直白说就是占位初始化
	 self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # 将anchors值变成shape(3,3,2)与对应格式
	 self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # 3个模型数据输出转换,每个特征图输出转换,
	 self.inplace = inplace  # use in-place ops (e.g. slice assignment)

anchors值如下:

  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

self.m模块:

ModuleList(
  (0): Conv2d(320, 75, kernel_size=(1, 1), stride=(1, 1))
  (1): Conv2d(640, 75, kernel_size=(1, 1), stride=(1, 1))
  (2): Conv2d(1280, 75, kernel_size=(1, 1), stride=(1, 1))
)

其中ch是输入,与模型图像大小关联。

三、Detection的训练输出源码解读

为便于快速理解模型训练输出内容,我直接去掉推理输出代码,这样一目了然知道模型输出内容。

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # 预测输出
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()    

        return x 

输出x值为三个特征层的值输出,都表示[batch,3,h,w,cls+5],含义为batch size图像的h*w个像素有3个预测,每个预测都是类别概率、置信度与box位置。其输出如下:

在这里插入图片描述
注:特别说明,x输出都没有经过sigmoid函数。

四、Detection的预测输出源码解读

为便于快速理解模型推理输出内容,我直去掉不必要判断,给出推理输出代码,这样一目了然知道模型输出内容。

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

          	 # inference
             if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                 self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

             y = x[i].sigmoid()
             if self.inplace:
                 y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                 y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
             else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                 xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                 wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                 y = torch.cat((xy, wh, y[..., 4:]), -1)
             z.append(y.view(bs, -1, self.no))

        return   (torch.cat(z, 1), x)

1、self.grid内容解读

grid为网格,可以理解是对应特征图的网格,以第一个特征宽高为80说明,grid[0]存的是像素位置坐标如下tensort示意。相当于,每个batch的高宽对应像素重复3次分别从0到79的x坐标和y坐标得到grid。而grid有三个特征,其它二个特征图也是同样原理,如下图。

在这里插入图片描述

tensor([[[[[ 0.,  0.],
           [ 1.,  0.],
           [ 2.,  0.],
           ...,
           [77.,  0.],
           [78.,  0.],
           [79.,  0.]],

          [[ 0.,  1.],
           [ 1.,  1.],
           [ 2.,  1.],
           ...,
           [77.,  1.],
           [78.,  1.],
           [79.,  1.]],

          [[ 0.,  2.],
           [ 1.,  2.],
           [ 2.,  2.],
           ...,
           [77.,  2.],
           [78.,  2.],
           [79.,  2.]],

          ...,

2、xy/wh内容解读

yolov5在经历一次sigmoid为y = x[i].sigmoid(),获得dx、dy、dw、dh值(我暂时这么理解),在使用以下代码实现在该特征图对应xy值与wh值。特别注意,假设该特征图是8080尺寸,仅是获得在该特征8080层恢复到模型输入特征图像尺寸,还未到原始图像输入尺寸。源码如下:

y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh

这里也说明一下,dw/dh转到对应尺寸是直接乘以self.anchor_grid,而self.anchor_grid保留是anchors值如下:

  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

最终扩展到对应特征图维度,如下显示:
在这里插入图片描述

注:这里知道dx/dy/dw/dh恢复到模型输入尺寸(如640)即可。dw/dh并没有*strid,而是直接*anchors值。

3、推理输出解读

最后将输出内容通过view转换为[batch,-1,25]形式。

z.append(y.view(bs, -1, self.no))

即可获得(torch.cat(z, 1), x)这样的推理输出。我想说x实际还是训练输出内容,没有变化。

推理输出为2个列表,第一个列表就是上面说到的z值,其经历了sigmoid,包含类、置信度、box都使用了sigmoid,且box转为了模型输出对应尺寸位置。第二个列表为三个特征层的值输出,都表示[batch,3,h,w,cls+5],含义为batch size图像的h*w个像素有3个预测,每个预测都是类别概率、置信度与box位置,是没有经历过sigmoid,实际就是训练输出内容。其结果示意如下:

在这里插入图片描述

注:x没有经过sigmoid,而z经过sigmoid后转成模型输出尺寸box。


总结

掌握yolov5的Detection类训练与预测输出内容,有利于对源码更改提供理论依据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1384292.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BDD(Behavior-Driven Development)行为驱动开发介绍

为什么需要BDD? “开发软件系统最困难的部分就是准确说明开发什么” (“The hardest single part of building a software system is deciding precisely what to build” — No Silver Bullet, Fred Brooks) 。 看一下下面的开发场景: 场景一&#xf…

python 通过定时任务执行pytest case

这段Python代码使用了schedule库来安排一个任务,在每天的22:50时运行。这个任务执行一个命令来运行pytest,并生成一个报告。 代码开始时将job_done变量设为False,然后运行预定的任务。一旦任务完成,将job_done设置为True并跳出循…

【昕宝爸爸小模块】线程的几种状态,状态之间怎样流转

➡️博客首页 https://blog.csdn.net/Java_Yangxiaoyuan 欢迎优秀的你👍点赞、🗂️收藏、加❤️关注哦。 本文章CSDN首发,欢迎转载,要注明出处哦! 先感谢优秀的你能认真的看完本文&…

react项目运行卡在编译:您当前运行的TypeScript版本不受@TypeScript eslint/TypeScript estree的官方支持

1.问题 错误信息具体如下: 搜索了一下,是typescript版本的问题,提示我版本需要在3.3.0和4.5.0中间,我查看了package.json,显示版本为4.1.3,然后一直给我提示我的版本是4.9.5,全局搜索一下&…

读写分离的手段——主从复制,解决读流量大大高于写流量的问题

应用场景 假设说有这么一种业务场景,读流量显著高于写流量,你要怎么优化呢。因为写是要加锁的,可能就会阻塞你读请求。而且其实读多写少的场景还很多见,比如电商平台,用户浏览n多个商品才会买一个。 大部分人的思路可…

智慧园区数字孪生智能可视运营平台解决方案:PPT全文82页,附下载

关键词:智慧园区解决方案,数字孪生解决方案,数字孪生应用场景及典型案例,数字孪生可视化平台,数字孪生技术,数字孪生概念,智慧园区一体化管理平台 一、基于数字孪生的智慧园区建设目标 1、实现…

Linux-命名管道

文章目录 前言一、命名管道接口函数介绍二、使用步骤 前言 上章内容,我们介绍与使用了管道。上章内容所讲的,是通过pipe接口函数让操作系统给我们申请匿名管道进行进程间通信。 并且这种进程间通信一般只适用于父子进程之间,那么对于两个没有…

什么是二分查找

一、是什么 在计算机科学中,二分查找算法,也称折半搜索算法,是一种在有序数组中查找某一特定元素的搜索算法 想要应用二分查找法,则这一堆数应有如下特性: 存储在数组中有序排序 搜索过程从数组的中间元素开始&…

Python-AST语法树

一、抽象语法树 1、什么是抽象语法树 在计算机科学中,抽象语法树(abstract syntax tree ,AST),是源代码的抽象语法结构的树状表现形式,这里特指编程语言的源代码。AST是编译器或解释器在处理源代码时所使…

<软考高项备考>《论文专题 - 65 质量管理(4) 》

4 过程3-管理质量 4.1 问题 4W1H过程做什么为了评估绩效,确保项目输出完整、正确且满足客户期望,而监督和记录质量管理活动执行结果的过程作用:①核实项目可交付成果和工作已经达到主要干系人的质量要求,可供最终验收;②确定项目…

class_4:car类

#include <iostream> using namespace std; class Car{ public://成员数据string color; //颜色string brand; //品牌string type; //车型int year; //年限//其实也是成员数据&#xff0c;指针变量&#xff0c;指向函数的变量&#xff0c;并非真正的成员函数void (*…

剪映国际版,免费无限制使用

随着抖音的爆火短视频的崛起&#xff0c;相信每一个人都感受到了短视频快节奏下的生活洪流。 现如今每个人都能成为自己生活的记录者&#xff0c;每一个人都有掌握着剪辑的基本技能。而剪映就是很多人都会使用的剪辑软件。 相对于PR、AE等剪辑软件来说&#xff0c;作为一款国…

读《Open-Vocabulary Video Anomaly Detection》

2023 西北工业大学和新大 引言 视频异常检测(VAD)旨在检测不符合预期模式的异常事件&#xff0c;由于其在智能视频监控和视频内容审查等应用前景广阔&#xff0c;已成为学术界和工业界日益关注的问题。通过几年蓬勃发展&#xff0c;VAD 在许多不断涌现的工作中取得了重大进展。…

spring-mvc(1):Hello World

虽然目前大多数都是使用springboot来开发java程序&#xff0c;或者使用其来为其他端提供接口&#xff0c;而为其他端提供接口&#xff0c;这些功能都是依靠springmvc实现的&#xff0c;所以有必要学习一下spring-mvc&#xff0c;这样才能更好的学习springboot。 一&#xff0c…

好物周刊#36:程序员简历

村雨遥的好物周刊&#xff0c;记录每周看到的有价值的信息&#xff0c;主要针对计算机领域&#xff0c;每周五发布。 一、项目 1. SmartDNS 一个运行在本地的 DNS 服务器&#xff0c;它接受来自本地客户端的 DNS 查询请求&#xff0c;然后从多个上游 DNS 服务器获取 DNS 查询…

陶瓷碗口缺口检测-图像分割

图像分割 由于对碗口进行缺口检测&#xff0c;因此只需要碗口的边界信息。得到陶瓷碗区域填充后的图像&#xff0c;对图像进行边缘检测。这是属于图像分割中的内容&#xff0c;在图像的边缘中&#xff0c;可以利用导数算子对数字图像求差分&#xff0c;将边缘提取出来。 本案…

案例:新闻数据加载

文章目录 介绍相关概念相关权限约束与限制完整示例 代码结构解读构建主界面数据请求下拉刷新总结 介绍 本篇Codelab是基于ArkTS的声明式开发范式实现的样例&#xff0c;主要介绍了数据请求和touch事件的使用。包含以下功能&#xff1a; 数据请求。列表下拉刷新。列表上拉加载…

基于面向对象,C++实现双链表

双链表同单链表类似&#xff0c;由一个值和两个指针组成 Node.h节点头文件 #pragma once class Node { public:int value;Node* prev;Node* next;Node(int value);~Node(); };Node.cpp节点源文件 #include "Node.h"Node::Node(int value) {this->value value…

Python-高阶函数

在Python中&#xff0c;高阶函数是指能够接收函数作为参数&#xff0c;或者能够返回函数的函数。这种特性使得函数在Python中可以被灵活地传递和使用。以下是一些关于Python高阶函数的详细解释&#xff1a; 函数作为参数&#xff1a; 高阶函数可以接收其他函数作为参数。这样的…

C语言 - 最简单,最易懂的指针、引用讲解

一、变量、地址、变量值 二、直接上代码&#xff0c;一边看上图&#xff0c;一边讲解 #include <stdio.h>struct Hello {int a;int b; };int main() {struct Hello h;h.a 10;h.b 20;struct Hello *hp;hp &h;printf("1: h的地址是%d&#xff0c;hp地址是%d \…