yolov8源码解读Detect层

news2024/11/24 12:43:42

yolov8源码解读Detect层

  • Detect层解读
  • 网络各层解读及detect层后的处理

关于网络的backbone,head,以及detect层后处理,可以参考文章结尾博主的文章。

Detect层解读

先贴一下全部代码,下面一一解读。

class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        shape = x[0].shape  # BCHW
        # print(">>>>", x[0].shape)
        # print(">>>>", x[1].shape)
        # print(">>>>", x[2].shape)
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)

        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        if self.export and self.format in ('tflite', 'edgetpu'):
            # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
            # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
            # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size
        # print(cls.shape)
        y = torch.cat((dbox, cls.sigmoid()), 1)
        # print(y.shape)
        return y if self.export else (y, x)
	dynamic = False #这个属性指示网格(通常是特征图上的锚框网格)是否需要动态地重建
    export = False  #这个属性用于指示模型是否处于导出模式。
    shape = None # 用于存储输入图像或特征图的尺寸。
    anchors = torch.empty(0)  # 创建了一个空的PyTorch张量
    strides = torch.empty(0)

步长(strides)是卷积神经网络中特征图相对于输入图像的缩小比例。
例如,如果步长是32,那么一个32x32像素的区域在特征图上就对应一个单元。
和anchors一样,这里的torch.empty(0)表示步长尚未初始化。

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

nc:类别数
nl:检测层的数量,目标检测中为3。
ch:传入的图片通道尺寸,在yolov8n,图片大小为640*640时。这里的ch为(256,128,64)
no:两个卷积再拼接后输出通道数,为4×reg_max+nc
c2,c3:计算卷积层的通道数。
cv2,cv3:定义的卷积操作,以输出有关类别和选框的特征图。
dfl:通过将分布式的概率分布转化为单一的预测值

class DFL(nn.Module):

    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Applies a transformer layer on input tensor 'x' and returns a tensor."""
        b, c, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)

self.conv:创建了一个输入通道为16,输出为1,没有偏置项,不需要进行梯度更新的卷积层。
这样的权重设置实际上模拟了一个积分过程,将卷积操作变成了加权求和的形式。
x:1到15的整数。
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)):
这里使用nn.Parameter将重塑后的张量设置为模型的参数,并且参数不会被更新。
假设前向传播中,x的形状为(1, 64, 8400),下面解释下forword中的变化。
1,x.view(b, 4, self.c1, a): 这个操作是对x的形状进行重塑。self.c1是16(因为输入通道数是64,即4*self.c1),那么a是8400(代表了所有锚点的数量)。b是批次大小,这里为1。所以x.view(b, 4, self.c1, a)将x从(1, 64, 8400)重塑为(1, 4, 16, 8400)。在这个形状中,我们得到了每个锚点的每个坐标轴(x, y, 宽度, 高度)上的16个预测值(可能代表某种概率分布)。
2,transpose(2, 1): 这个操作交换第二维和第三维。在应用transpose之后,张量的形状变为(1, 16, 4, 8400)。这样做的目的是让每组概率分布的16个预测值连续地排列在一起,为后面的softmax运算做准备。
3,softmax(1): softmax函数应用于第一维(现在是16个预测值的这一维)。softmax确保了这16个值之和为1,转换为一个有效的概率分布,表示每个预测值的可能性。
4,self.conv(…): 这个操作将配置好的卷积层应用在进行了softmax操作的张量上。由于卷积层的权重已被设置为从0到15的整数,并且不更新权重(不进行梯度下降优化),这个步骤实际上是在计算期望值。卷积层将每个离散的概率值乘以其相应的索引(也就是权重),然后对结果进行求和,得到该坐标的预测值。
5,view(b, 4, a): 最后一步是将张量的形状从卷积操作后的(1, 1, 4, 8400)转换回(1, 4, 8400)。这样确保了最终的输出张量与每个坐标轴的预测值(x, y, 宽度, 高度)和所有锚点的数量对齐。
总的来说,dfl层就是对预测的坐标求加权期望值。将(1,64,8400)先变为(1,16,4,8400),然后对这16个通道求加权期望,变为(1,4,8400)即这8400个锚点中的每一个锚点,x,y,width,hight的加权平均值。
接下来是前向传播的过程。打印传入的x形状,发现通道数是64,128,256。

在这里插入图片描述
在这里插入图片描述
原因:Detect层接受15,18,21层的输入。原本通道数是1024,512,256。但是yolov8n还需要乘0.25。
在这里插入图片描述
在这里插入图片描述

经过cv2,通道数变为64,经过cv3通道数变为nc,我这里nc为2(二分类)。在经过cat拼接,在通道维度上拼接,所以x[i]的通道数变为66。

如果处于训练模式,就直接返回x。
否则执行下面的代码,将特征图列表x(1×66×40×40,1×66×80×80,1×66×20×20)传递给make_anchors()函数。make_anchors函数用于生成锚点(anchors),它通常用在目标检测网络中。每个锚点代表了特征图上的一个点,可以用来预测相对于该点的边界框。strides是这些特征图相对于原始图像的下采样步长。简单来说,生成了8400个锚点(40×40+80×80+20×20),变量为anchors,形状为1×2×8400)。同时生成了8400个步长,变量为strides,形状为1×8400。参数0.5表示每个锚点处于每个像素块的中央。

        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

将xi按照2维度进行拼接,xi分别为1×66×40×40,1×66×80×80,1×66×20×20。拼接后的x_cat为1×66×8400

x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)

这段代码就是把x_cat进行拆分。box形状为1×64×8400,包含每个边界框的回归参数。cls形状为1×2×8400,会包含类别预测,2是因为我这里类别为2。

        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

dfl层就是对预测的坐标求加权期望值。将(1,64,8400)先变为(1,16,4,8400),然后对这16个通道求加权期望,变为(1,4,8400)即这8400个锚点中的每一个锚点,x1,y1,x2,y2的加权平均值。dist2bbox()函数的作用是将锚点x1,y1,x2,y2转换为x,y,width,hight的形式。最后在乘以步长,还原到原图的大小比例。

dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

在这里插入图片描述

此代码片段的作用是在模型导出为 Tensorflow Lite (tflite) 或 Edge TPU 兼容格式时,对预测框 (dbox) 进行归一化处理。

 if self.export and self.format in ('tflite', 'edgetpu'):
            # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
            # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
            # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size

此时y的形状为1×66×8400。代表有8400个锚点,每个锚点包含坐标框的x,y,width,hight,以及类别得分信息。

y = torch.cat((dbox, cls.sigmoid()), 1)

返回值,至此detect层结束。完整的预测,后续还需要进行一些处理。如进行非极大抑制,对这8400个锚点进行筛选

return y if self.export else (y, x)

另外,最后的bias_init()函数用于初始化一个目标检测模型中的Detect层的偏置。确保在训练开始时偏置值是基于合理假设的。这种方法的目标是为模型提供一个好的起点,并有助于加速训练过程中的收敛。

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

网络各层解读及detect层后的处理

关于backbone,head层,以及detect层参考下面博主的文章,讲的非常好。

链接: Yolov 8源码超详细逐行解读+ 网络结构细讲(自我用的小白笔记)
链接: 最细致讲解yolov8模型推理完整代码–(前处理,后处理)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1451918.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【开源】JAVA+Vue.js实现大学计算机课程管理平台

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 实验课程档案模块2.2 实验资源模块2.3 学生实验模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 实验课程档案表3.2.2 实验资源表3.2.3 学生实验表 四、系统展示五、核心代码5.1 一键生成实验5.2 提交实验5.3 批阅实…

如何选择一个适合自己的赛道

(点击即可收听) 最开始一定要先做好定位,也就是你做短视频的目的是什么?当然对大多数人来说,终极目的肯定是赚钱,但赚钱的速度是由定位决定的 如果你资金比较充裕,不急于赚钱,就可以…

配置oracle连接管理器(cman)

Oracle Connection Manager是一个软件组件,可以在oracle客户端上指定安装这个组件,Oracle连接管理器代理发送给数据库服务器的请求,在连接管理器中,我们可以通过配置各种规则来控制会话访问。 简而言之,不同于专用连接…

基于BP算法的SAR成像matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 BP算法的基本原理 4.2 BP算法的优点与局限性 5.完整工程文件 1.课题概述 基于BP算法的SAR成像。合成孔径雷达(SAR)是一种高分辨率的雷达系统,能够在各种天气和光…

DS:八大排序之直接插入排序、希尔排序和选择排序

创作不易,感谢三连支持!! 一、排序的概念及运用 1.1 排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起 来的操作。稳定性&…

Java图形化界面编程——五子棋游戏 笔记

2.8.5 五子棋 接下来,我们使用之前学习的绘图技术,做一个五子棋的游戏。 注意,这个代码只实现了五子棋的落子、删除棋子和动画等逻辑实现,并没有把五子棋的游戏逻辑编写完整,比较简单易上手。 图片素材 package…

.NET Core MongoDB数据仓储和工作单元模式实操

前言 上一章节我们主要讲解了MongoDB数据仓储和工作单元模式的封装,这一章节主要讲的是MongoDB用户管理相关操作实操。如:获取所有用户信息、获取用户分页数据、通过用户ID获取对应用户信息、添加用户信息、事务添加用户信息、用户信息修改、用户信息删除…

每日OJ题_算法_递归③力扣206. 反转链表

目录 力扣206. 反转链表 解析代码 力扣206. 反转链表 206. 反转链表 LCR 024. 反转链表 难度 简单 给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,…

Leetcode3010. 将数组分成最小总代价的子数组 I

Every day a Leetcode 题目来源:3010. 将数组分成最小总代价的子数组 I 题目描述: 给你一个长度为 n 的整数数组 nums 。 一个数组的代价是它的第一个元素。比方说,[1,2,3] 的代价是 1 ,[3,4,1] 的代价是 3 。 你需要将 num…

HTML5+CSS3小实例:彩色拨动开关

实例:彩色拨动开关 技术栈:HTML+CSS 效果: 源码: 【HTML】 <!DOCTYPE html> <html lang="zh-CN"><head><meta charset="UTF-8" /><meta http-equiv="X-UA-Compatible" content="IE=edge" /><…

双输入宽带混合 Doherty-Outphasing功率放大器设计(2021.02 MTT)-从理论到ADS版图

基于双输入的宽带混合 Doherty-Outphasing功率放大器设计(2021.02 MTT)-从理论到ADS版图 原文: Wideband Two-Way Hybrid Doherty Outphasing Power Amplifier 发表于FEBRUARY 2021&#xff0c;在微波顶刊IEEE T MTT上面&#xff0c;使用的GAN CGH40010F 全部工程下载&#…

Covalent Network(CQT)与卡尔加里大学建立合作,共同推动区块链技术创新

Covalent Network&#xff08;CQT&#xff09;作为领先的 Web3 数据索引器和提供者&#xff0c;宣布已经与卡尔加里大学达成了具备开创性意义的合作&#xff0c;此次合作标志着推动区块链数据研究和可访问性的重要里程碑。卡尔加里大学是首个以验证者的身份加入 Covalent Netwo…

linux-firewalld防火墙端口转发

目的:通过统一地址实现对外同一地址暴露 1.系统配置文件开启 ipv4 端口转发 echo "net.ipv4.ip_forward 1" >> /etc/sysctl.confsysctl -p 2.查看防火墙配置端口转发之前的状态 firewall-cmd --statefirewall-cmd --list-all 3.开启 IP 伪装 firewall-cm…

TIM(Timer)定时中断 P1

难点&#xff1a;定时器级联、主从模式 一、简介&#xff1a; 1.TIM&#xff08;Timer&#xff09;定时器 定时器可以对输入的时钟进行计数&#xff0c;并在计数值达到设定值时触发中断 补充&#xff1a; { 定时器本质上是一个计数器&#xff0c;可以工作在定时或计数模式&…

视频生成模型作为世界模拟器

我们探索了在视频数据上大规模训练生成模型。具体来说&#xff0c;我们联合训练文本条件扩散模型&#xff0c;处理不同持续时间、分辨率和宽高比的视频和图像。我们利用一种在时空补丁上操作视频和图像潜码的transformer架构。我们最大的模型&#xff0c;Sora&#xff0c;能够生…

SQL中的各种连接的区别总结

前言 今天主要的内容是要讲解SQL中关于Join、Inner Join、Left Join、Right Join、Full Join、On、 Where区别和用法&#xff0c;不用我说其实前面的这些基本SQL语法各位攻城狮基本上都用过。但是往往我们可能用的比较多的也就是左右连接和内连接了&#xff0c;而且对于许多初学…

C++中的volatile:穿越编译器的屏障

C中的volatile&#xff1a;穿越编译器的屏障 在C编程中&#xff0c;我们经常会遇到需要与硬件交互或多线程环境下访问共享数据的情况。为了确保程序的正确性和可预测性&#xff0c;C提供了关键字volatile来修饰变量。本文将深入解析C中的volatile关键字&#xff0c;介绍其作用、…

luigi,一个好用的 Python 数据管道库!

🏷️个人主页:鼠鼠我捏,要死了捏的主页 🏷️付费专栏:Python专栏 🏷️个人学习笔记,若有缺误,欢迎评论区指正 前言 大家好,今天为大家分享一个超级厉害的 Python 库 - luigi。 Github地址:https://github.com/spotify/luigi 在大数据时代,处理海量数据已经成…

Java on VS Code 2024年1月更新|JDK 21支持!测试覆盖率功能最新体验!

作者&#xff1a;Nick Zhu - Senior Program Manager, Developer Division At Microsoft 排版&#xff1a;Alan Wang 大家好&#xff0c;欢迎来到 Visual Studio Code for Java 2024年的第一期更新&#xff01;提前祝愿大家春节快乐&#xff01;在本博客中&#xff0c;我们将有…

ChatGPT高效提问—prompt实践(智能辅导-心理咨询-职业规划)

ChatGPT高效提问—prompt实践&#xff08;智能辅导-心理咨询-职业规划&#xff09; ​ 智能辅导是指利用人工智能技术&#xff0c;为学习者提供个性化、高效的学习辅助服务。它基于大数据分析和机器学习算法&#xff0c;可以针对学习者的学习行为、状态和能力进行评估和预测&a…