yolov5源码解析(9)--输出

news2025/2/23 20:02:27

本文章基于yolov5-6.2版本。主要讲解的是yolov5是怎么在最终的特征图上得出物体边框、置信度、物体分类的。

一。总体框架

首先贴出总体框架,直接就拿官方文档的图了,本文就是接着右侧的那三层输出开始讨论。

  • BackboneNew CSP-Darknet53
  • NeckSPPFNew CSP-PAN
  • HeadYOLOv3 Head

这三个输出层分别就是浅、中、深层啦,浅层特征图分辨率是80乘80,中层是40乘40,深层是20乘20,一般来说浅层用于预测小物体,深层用于预测大物体。另外说明一下,浅、中、深三层的特征图输出通道数不一定是256、512、1024,要看你用的是哪一种规格的模型。比如yolov5s的话,那这三层的通道数分别是128、256、512,可以导出onnx格式用Netron看一下模型结构来确定。

 

简要说一下原因,这个是由对应的模型配置文件,即models目录里的yolov5s.yaml,yolov5m.yaml等等来决定的,看你用哪一个,第二个红框里的就是每一层的输出通道数了,但是它是要乘上第一个红框里的值的,即width_multiple这个配置,你会发现几个模型配置文件的内容都差不多,区别就区别在这里的depth_multiple和width_multiple。

 二。输出物体边框、置信度、物体分类

接下来进入正题,每层特征图最终都会经过1乘1卷积,变成(5+分类数)乘3个通道:

0)首先为什么乘以3,因为每一层都有3个anchor,后面再细讲

下面讲的是每一anchor对应的(5+分类数)个通道,假设分类数为2,那一共就是7个通道了,这7个通道分别是xywh(4个通道),置信度(1个通道),分类(此处2分类,就是2个通道)

1)物体边框的4个值,x,y,w,h啦,不过这个x,y并不直接是物体框中心点的坐标,而是它相对于自身所处的格子左上角的偏移,比如下图红色的这个格子(假设现在特征图就是4乘4),这个格子预测出7个值,前4个就是xywh,然后x是0.2,y是0.2,那么中心点就差不多在蓝点所处的位置了(其实这其中还有玄机,一步步来)。然后再把这个中心点的相对值作用到原图的尺度得到最终的坐标。

但是呢如果像上面这样直接预测一个相对格子左上角的偏移的这样一个值呢,会比较不稳定,它可能预测的值很大,比如x给你预测一个10出来,那就是往右数10个格子了,偏差这么大不利用网络收敛,也没有意义,因为这个格子里的特征跟右边第10个格子的特征相差可能很大了。

所以要加一个限制,首先给它sigmoid一下,这样其值范围就变成0-1了(小数),此时它的波动就在自己的这个格子内,然后乘以2再减0.5,如下图(直接拿官方文档的图了~~)

 

 这样它的波动范围就是下图的黄框的范围。

 限制为0-1好理解,自己这个格子的预测范围就在自己格子内麻,为啥又变成了-0.5-1.5呢,因为这样更容易得到0-1范围内的值。如果的范围限制为0-1,而且是用sigmoid来限制的话,那接近0和1这两个位置的导数就会很小,梯度更新的时候就会慢。

然后就是宽高,宽高也不是直接预测出物体边框的宽高啦,而是基于anchor的,预测出来的值会乘上anchor的宽高得出最终的宽高,并且,这里仍然是先用sigmoid将输出值限制为0-1,然后再乘以2,再来个平方,这样最终的值的范围就是0-4了。

之前说了每一层有3个anchor,这些anchor还是配置在模型的配置文件里的,比如models/yolov5s.yaml,P3就是浅层的(80乘80的格子),P4是中层的(40乘40),P5是深层的(20乘20),然后这里的anchor的大小呢就是绝对值(按照640乘640的图来算的,如果你的输入图不是640乘640,那输入图是会resize一下再进行推理的)

比如现在是深层的输出,2分类,那么深层的特征图经过最后的1乘1卷积后,会得到3乘(5+2)=21个通道,每7个通道就对应一个anchor了,现在看第2个7个通道(即7-13,从0开始算),那么它对应的anchor就应该是156,198这个,那么预测出来的宽高值经过sigmoid,再乘2,再平方之后,还分别要乘上156和198,得出最终的物体宽高(基于640乘640的图的),然后再按比例得到原图的物体宽高。

2)置信度

代表预测出的物体边框和分类的可信度,最终的范围肯定是0-1了(小数),跟前面的一样,会用sigmoid来把它的范围限制为0-1。

这边可能有一个问题,那个xy不是sigmoid()乘2减0.5吗,这里咋不这么干,那是因为xy的值真的是可以达到-0.5或1.5的,那样的话就变成预测的物体中心点跑到相邻格子里去了,这也不是不行的啦。但置信度只能是0-1!

3)分类

有几个分类,就会再加几个通道,分别代表对应分类的概率,都是用sigmoid把他们的概率限制为0-1,在计算损失的时候,标签对应分类所在通道的直值为1,其它都为0了,然后分别计算BCE损失。

三。源码

最终输出层的相关源码主要就是models/yolo.py的Detect类的源码了,添加了相应的注释。

class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors,除以2是因为[10,13, 16,30, 33,23]这个长度是6,对应3个anchor
        self.grid = [torch.zeros(1)] * self.nl  # init grid,下面会计算grid,grid就是每个格子的x,y坐标(整数,比如0-19)
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2),注意后面就可以通过self.anchors来访问它了
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv,3个输出层最后的1乘1卷积
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):  # 三个输出层分别处理
            x[i] = self.m[i](x[i])  # conv,经过这个1乘1卷积就变成(5+分类数)个通道了
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)--这里的85对应coco数据集,5+80个分类
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    # 这里的grid[i]即对应输出层的3个anchor层的每个格子的坐标,方便进行批量计算,乘上对应的stride[i](下采样率),就得到基于640乘640的图的坐标了
                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
                    # anchor_grid[i]也是一样,不过它的形状是(1, self.na, 1, 1, 2),跟y[..., 2:4]计算时是会自动广播的,最终得到的宽高也是基于640乘640的图的宽高
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    # 这段是非inplace操作,计算方法是一样的
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

    def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        # grid其实就是特征图网络的坐标,比如20乘20的,其坐标分别是0,0 0,1...0,19 1,0 1,1...19,19,第2个维度na就是anchor数啦
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        if torch_1_10:  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
            yv, xv = torch.meshgrid(y, x, indexing='ij')
        else:
            yv, xv = torch.meshgrid(y, x)
        # 注意这边先给它把0.5给减了
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        # anchor_grid即每个格子对应的anchor宽高,stride是下采样率,三层分别是8,16,32,这里为啥要乘呢,因为在外面已经把anchors给除了对应的下采样率,这里再乘回来
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid

此处单独说一下torch.meshgrid,它其实就是用于得到网格坐标的,简化代码如下,假设现在是2乘2的网络

y, x = torch.arange(2), torch.arange(2)
yv, xv = torch.meshgrid(y, x, indexing='ij')
print(f'yv={yv}')
print(f'xv={xv}')

grid = torch.stack((xv, yv), 2)
print(f'grid={grid}')

 输出如下

 grid对应的就是如下图,得到这个网络坐标就可以直接跟输出层的x,y做批量运算了。

 四。NMS

Detect类foward之后确实是整个网络最终的输出,不过这个输出还得再经过NMS,提取出最终的答案,即这张图上到底有几个物体,边框、置信度、分类分别是什么。NMS后面再讨论~~

下一篇:

yolov5源码解析(10)--损失计算与anchor_扫地僧1234的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/405903.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaWeb酒店管理系统

酒店管理系统 一、项目介绍 1、项目用到的技术栈 开发工具:idea语言:java、js、htmlajax数据库:MySQL服务器:Tomcat框架:mybatis、jQuery 2、项目实现功能 管理员和用户登录和退出功能以及用户注册功能&#xf…

【第二趴】uni-app开发工具(手把手带你安装HBuilderX、搭建第一个多端项目初体验)

文章目录写在前面HBuilderXHBuilderX 优势HBuilderX 安装uni-app 初体验写在最后写在前面 聚沙成塔——每天进步一点点,大家好我是几何心凉,不难发现越来越多的前端招聘JD中都加入了uni-app 这一项,它也已经成为前端开发者不可或缺的一项技能…

Eolink 治愈了后端开发者的痛

一、前后端的爱恨情仇 最近公司的一个前端同事和一个后端同事吵了一架,事情大概是这样的。后端说要联调接口,前端说你的数据尽量按我的要求来,后端不干,说你这个没用。前端就讲道理呀,传统的前后端分离返回的格式要尽…

【node进阶】深入浅出websocket即时通讯(二)-实现简易的群聊私聊

✅ 作者简介:一名普通本科大三的学生,致力于提高前端开发能力 ✨ 个人主页:前端小白在前进的主页 🔥 系列专栏 : node.js学习专栏 ⭐️ 个人社区 : 个人交流社区 🍀 学习格言: ☀️ 打不倒你的会使你更强&a…

保姆级教程:Ant Design Vue中 a-table 嵌套子表格

前端为Ant Design Vue 版本为1.6.2,使用的是vue2 Ant Design Vue中 a-table 嵌套子表格,说的可能稍微墨迹了点,不过重点内容都说的比较详细,利于新人理解,高手可以自取完整代码 内容概述:完成样式及完整代…

在收到消息后秒级使网站变灰,不改代码不上线,如何实现?

注意:文本不是讲如何将网站置灰的那个技术点,那个技术点之前汶川地震的时候说过。 本文不讲如何实现技术,而是讲如何在第一时间知道消息后,更快速的实现这个置灰需求的上线。 实现需求不是乐趣,指挥别人去实现需求才…

[Vue warn]: Error in render: “TypeError: Cannot read properties of undefined(reading“category1Name“

明明页面正常显示,但是控制台却一直报 如下 错误 [Vue warn]:渲染错误:"TypeError:无法读取未定义的属性(读取category1Name)" 中发现的 Detail 的 vuex 仓库 import { reqDetail } from "/api" export default{actions:{async getDetail({co…

【前端修炼场】 — 这些标签你学会了么?快速拿下 “hr”

此文为【前端修炼场】第四篇&#xff0c;上一篇文章链接&#xff1a;上一篇 文章目录前言一、 常用标识符1.1 特殊标识符1.1.1 "<" 和 ">"&#xff08;<&#xff1b;&#xff09;1.1.2 空格&#xff08;&emsp&#xff1b;&#xff09;1.1.3 商…

uniapp微信小程序无法使用本地静态资源图片,背景图在真机不显示方法

前言 首先要说明&#xff0c;使用HBuilder或者vs Code工具开发的时候&#xff0c;在微信开发者工具调试的时候&#xff0c;我们使用本地图片是OK的&#xff0c;但是一旦放到真机上调试的时候&#xff0c;图片就显示不出来。 先看uniapp官网对背景图片的说明 错误用法 <tem…

uniapp 微信小程序和H5的弹窗滚动穿透解决

滚动穿透&#xff1a; 页面里的弹窗也是可以滚动的&#xff0c;然后页面本身内容多所以也是滚动的&#xff0c;就造成&#xff0c;在弹窗滚动的时候&#xff0c;页面内容也跟着滚动了。如图所示 ps: 电脑端分鼠标滚轮滚动和长按鼠标拖拽滚动&#xff0c;手机端只有触屏滑屏滚…

视频实时行为检测——基于yolov5+deepsort+slowfast算法

文章目录前言一、核心功能设计二、核心实现步骤1.yolov5实现目标检测2.deepsort实现目标跟踪3.slowfast动作识别三、核心代码解析1.参数2.主函数3.将结果保存成视频总结前言 前段时间打算做一个目标行为检测的项目&#xff0c;翻阅了大量资料&#xff0c;也借鉴了不少项目&…

【Java】运算符

我不去想是否能够成功 既然选择了远方 便只顾风雨兼程 —— 汪国真 目录 1. 认识运算符 1.1 认识运算符 1.2 运算符的分类 2. 算术运算符 2.1 四则运算符 2.2 复合赋值运算符 2.3 自增 / 自减 运算符 3.关系运算符 4.逻辑运算符 4.1 逻辑与 && 4.2 逻…

什么是异步

文章目录 前言一、异步是什么&#xff1f;二、举个例子来理解异步 1.异步最典型的例子就是“回调函数”总结前言 在vue的过程中&#xff0c;我们一定会遇到诸如&#xff1a; function&#xff08;参数&#xff09;.then(res>{}) 形式的代码。到底怎么编译执行的呢 &#xf…

【Jetpack】ViewModel 架构组件 ( 视图 View 和 数据模型 Model | ViewModel 作用 | ViewModel 生命周期 | 代码示例 | 使用注意事项 )

文章目录一、Activity 遇到的问题二、视图 View 和 数据模型 Model三、ViewModel 架构组件作用四、ViewModel 代码示例1、ViewModel 视图模型2、Activity 组件3、UI 布局文件4、运行效果五、ViewModel 生命周期六、ViewModel 使用注意事项一、Activity 遇到的问题 Activity 遇到…

宝塔部署nodejs项目

前言 部署操作很简单&#xff0c;网上也有很多教程&#xff0c;不过我还是踩坑了&#xff0c;这里记录一下&#xff0c;给其他人也避避坑吧。 步骤 首先你已经有了服务器&#xff0c;并且打开了宝塔面板&#xff0c;其次准备好你的nodejs项目。 在宝塔安装pm2管理器&#xf…

Nginx 调整文件上传大小限制

使用3A服务器做了网页&#xff0c;感觉挺不错的&#xff0c;使用LNMP环境 用Nginx部署了前端&#xff0c;发现上传附件大一点就会报错&#xff0c;查看配置文件&#xff0c;发现spring的附件配置已经配置了。那么就看下Nginx的body设置。nginx文件上传默认是1MB。 在 server 模…

VUE3TS: Vue3+TS的项目搭建

简介 通过 Vue-cli4 创建的 Vue3TS 的项目&#xff0c;并进行一些基础使用的举例。 此例是以 VSCode编辑器 进行的编码。 一、项目搭建 1. 进入命令提示符窗口 在要搭建项目的文件夹中&#xff0c;点击路径&#xff0c;输入CMD并按回车 2. 查看node版本、Vue-cli版本 2…

Android 架构之长连接技术

上文中我们提到了HttpDNS&#xff0c;虽然它比系统DNS更优&#xff0c;但终归还是要做DNS操作。而长连接都是IP直接连接&#xff0c;因此没有DNS相关的开销和耗时。 3. 如果有大量网络请求&#xff0c;可以明显减少网络延时&#xff0c;节省带宽 对于大型App而言&#xff0c;…

npm——安装、卸载与更新

npm 官方文档&#xff1a;https://docs.npmjs.com/ 什么是npm npm&#xff08;“Node 包管理器”&#xff09;是 JavaScript 运行时 Node.js 的默认程序包管理器。 它也被称为“Ninja Pumpkin Mutants”&#xff0c;“Nonprofit Pizza Makers”&#xff0c;以及许多其他随机…

Vue通知提醒框(Notification)

项目相关依赖版本信息 可自定义设置以下属性&#xff1a; 自动关闭的延时时长&#xff08;duration&#xff09;&#xff0c;单位ms&#xff0c;默认4500ms消息从顶部弹出时&#xff0c;距离顶部的位置&#xff08;top&#xff09;&#xff0c;单位像素px&#xff0c;默认24p…