【深度学习|目标检测】YOLO系列anchor-based原理详解

news2025/3/18 21:41:23

YOLO之anchor-based

  • 一、关于anchors的设置
  • 二、网络如何利用anchor来训练
    • 关于register_buffer
    • 训练阶段的anchor使用
    • 推理阶段的anchor使用
  • 三、训练时的正负样本匹配
    • 静态策略:
      • 跨分支采样
      • 跨anchor采样
      • 跨grid采样
    • 动态策略

总结起来其实就是:基于anchor-based的yolo就是基于三个检测头的分支上的grids和anchors(通过正样本匹配选择出来的)来计算预测与gt的偏移量,同时考量该grid中是否含有物体,并且是什么样的物体,然后在满足这三者条件的最小loss下不断迭代模型的权重参数。那么会有人问,在推理的时候,模型怎么知道应该基于哪个grid下的anchor进行回归呢?这就可以交给objectness置信度了,根据模型权重最后得到的特征图在计算后的objectness低于阈值的话,我们根本就不会考虑这个位置的回归框。

一、关于anchors的设置

在yolov5的模型yaml文件中已经设置了一套默认的anchors的尺寸(针对640*640的输入):

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

其中p3,p4,p5分别代表三个尺度的特征图,中括号内的数字两两一对,表示在640*640输入基准上的anchor尺寸。8,16,32分别代表各自的stride步长,即相对输入的降采样倍数,也代表了该特征图每一个grid的边长。
那么为什么降采样越多的检测头预设的anchor大小越大呢,因为降采样越多,grid的个数越少,因此每一个grid对于大目标的感知能力越强。降采样越少的,grid个数越多,对于小目标的感知能力越强,模型参数收敛的更快。因此我们会发先降采样越多的头,预设的anchor越大。即为了更快的收敛。

虽然给我们设置好了默认的anchors,但是我们可以不使用这个默认anchor,在训练的时候,打开autoanchor开关:

parser.add_argument("--noautoanchor", action="store_true", help="disable AutoAnchor")

然后会进入check_anchors函数中进行判断是否需要重新生成预设anchor:

check_anchors(dataset, model=model, thr=hyp["anchor_t"], imgsz=imgsz)  # run AutoAnchor

其中thr表示这一波数据集的标注中,宽高比的最大阈值,这个参数在hyp的yaml文件中有设置好,是4。核查的代码如下所示:

def metric(k):  # compute metric
    """Computes ratio metric, anchors above threshold, and best possible recall for YOLOv5 anchor evaluation."""
    r = wh[:, None] / k[None]
    x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
    best = x.max(1)[0]  # best_x
    aat = (x > 1 / thr).float().sum(1).mean()  # anchors above threshold
    bpr = (best > 1 / thr).float().mean()  # best possible recall
    return bpr, aat

返回的bpr参数就是用于来判断是否需要重新计算anchor的主要依据,当大于0.98时,我们则不需要为了这个数据集重新计算anchor。当小于0.98时,会自动重新计算。计算的方式就是使用kmeans聚类算法,根据我们这个数据集的标注情况来调整anchor的宽高和比例,计算完成后,还会再计算一次现在的bpr,然后和默认的bpr进行比较,如果小于默认的bpr则继续使用默认的anchor设置。


二、网络如何利用anchor来训练

在train.py中我们看到了train的函数中,建立网络整体架构的代码:

model = Model(cfg or ckpt["model"].yaml, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create

我们进入Model类中,会进入到DetectionModel类中,这个类在初始化的过程中,会读取我们传入的模型结构yaml文件,然后通过parse_model方法来返回一个pytorch搭建的网络框架,包含了backbone+head。backbone就按照正常的流程构建,当搭建到head时,会进入Detect类中,初始化如下:

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes                     // 目标的类别个数
        self.no = nc + 5  # number of outputs per anchor      // 每一个grid的输出维度
        self.nl = len(anchors)  # number of detection layers  // 在几个尺度的特征层上进行设置anchor
        self.na = len(anchors[0]) // 2  # number of anchors   // 每一个grid上有的anchor的个数
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

其中register_buffer即将anchor注册到网络结构中的一步,在下面会详细讲解这个函数的作用。
head的最后输出会经过一层头部的卷积层(p3,p4,p5分别对应一个),这个卷积层的输入维度以yolov5s为例子的话,分别是(128,256,512),输出维度是(xywh + objectness + nc)* 3,其中3是每个grid上的锚框个数。以640* 640的输入为例,我们将三个锚框的维度移动到grid个数上,那么最后三个检测头的输出维度分别是:(bs, 80* 80 * 3, xywh + objectness + nc) ,(bs, 40* 40 * 3, xywh + objectness + nc),(bs, 20* 20 * 3, xywh + objectness + nc),将这三个头合起来之后就是网络的输出结果,即(bs,25200,xywh + objectness + nc)。


关于register_buffer

nn.Module.register_buffer 是 PyTorch 提供的方法,用于向模型中注册一个缓冲区(buffer)。这些缓冲区是与模型相关的固定数据,在模型训练和保存时非常有用,但它们不会参与梯度计算,也不会被优化器更新。
使用场景:
● 存储模型所需的固定参数(例如锚点、均值、方差等)。
● 在保存和加载模型时,确保这些数据一并存储和恢复。
● 用于在模型中共享某些不可训练的参数。

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 注册一个缓冲区
        self.register_buffer("my_buffer", torch.tensor([1.0, 2.0, 3.0]))

    def forward(self, x):
        # 使用缓冲区中的数据
        return x + self.my_buffer

# 创建模型实例
model = MyModel()
print("缓冲区内容:", model.my_buffer)

# 保存模型
torch.save(model.state_dict(), "model.pth")

# 加载模型
new_model = MyModel()
new_model.load_state_dict(torch.load("model.pth"))
print("加载后的缓冲区内容:", new_model.my_buffer)

示例说明:

  1. register_buffer 的作用:
    ○ self.register_buffer(name, tensor) 会将 tensor 注册为缓冲区,名称为 name。
    ○ 例如,代码中的 my_buffer 是模型的一部分,但它不会参与梯度计算。
  2. 访问缓冲区:
    ○ 缓冲区可以通过模型属性直接访问,例如 model.my_buffer。
  3. 保存与加载:
    ○ 缓冲区会被存储在模型的 state_dict 中,使用 torch.save 和 torch.load 保存/加载。

训练阶段的anchor使用

在训练阶段中,并没有在检测头中直接使用anchor,而是在其他两个地方使用anchor的信息,一个是正负样本匹配的过程中,一个是损失函数计算的过程中。
那么我们看一下,训练的时候,DetectionModel类到底做了什么吧:
parse_model之后便是拿到模型的最后一层,记为m,然后做个判断,是否是检测头或者分割头,是的话进入代码段中,然后以256 * 256的尺寸为例,输入_forward_once中进行一次模型的输出,得到了最后三个检测头的输出,遍历这三个输出,分别取特征图的尺寸,用输入除以特征图的尺寸得到stride张量,说实话个人感觉这一步有点脱裤子放屁的意思。然后就是检查anchor和stride的对应关系,以及将每一个检测头的anchor尺寸缩放到该特征图尺度上的对应大小。然后就是初始化偏置和初始化权重了:

# Build strides, anchors
m = self.model[-1]  # Detect()
if isinstance(m, (Detect, Segment)):

    def _forward(x):
        """Passes the input 'x' through the model and returns the processed output."""
        return self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)

        s = 256  # 2x min stride
    m.inplace = self.inplace
    m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
    check_anchor_order(m)
    m.anchors /= m.stride.view(-1, 1, 1)
    self.stride = m.stride
    self._initialize_biases()  # only run once

    # Init weights, biases
    initialize_weights(self)
self.info()
LOGGER.info("")

推理阶段的anchor使用

由于在我们的train.py中,我们首先是创建网络,此时还没有到model.train()的状态,因此在训练调试代码时,会先进入检测头的not self.training中,这里我们就可以看到在推理阶段是如何使用anchor来进行预测的。首先我们会使用_make_grid方法来对我们的三个尺度的特征图进行grid的和anchor的搭建,即预设好grid的左上角的框的坐标以及对应的anchor的大小。然后我们将这个预设好的grid点的左上角坐标和每个grid上的anchor给到模型最后的output,从output的最后一个维度拆分,拿到xy,wh,conf =(objectness,nc)一共三组结果,其中conf的结果是可以直接使用的,但是xy,wh还需要和我们预设好的grids和anchors,以及每个检测头的stride来得到最后的精确检测结果。最后将每个检测头上设置的anchor的个数乘到grid上,即代表了我们的结果一共有na * (80 * 80 + 40 * 40 + 20 * 20)个,然后输出的结果维度便是(1,25200,6)。以下代码是回归的计算方式:

         if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))

如图所示:
在这里插入图片描述


三、训练时的正负样本匹配

首先,什么是正负样本匹配。正样本匹配即通过计算gt与被选中的grid和anchor来计算偏移量来调整网络的权重参数。正样本的匹配然后loss计算是为了让模型朝着更小的损失去迭代更新这样的权重,负样本的作用则是让模型的权重往更远离能检测出错误目标的权重方向迭代。

yolov5在正负样本匹配在v3, 和v4的基础上作出了改进,yolov5通过跨分支(检测头)和跨grid和跨anchor来匹配多个grid和多个anchor,目的就是增加正样本的数量。跨分支采样就是根据每个分支的相对位置来选择grid,确定了grid的范围之后,开始匹配anchor。
整个正负样本匹配分为静态策略和动态策略,静态策略是为了给正样本增样,动态策略则是在训练的过程中动态的引导网络关注高质量的正样本,给予一定的权重来加速收敛。

静态策略:

跨分支采样

在这里插入图片描述
即优化了yolo系列之前只在单分支上进行正样本采样的缺陷,实现在三个检测头上同时采样。

跨anchor采样

在YoloV5网络中,一共设计了9个不同大小的先验框。每个输出的特征层对应3个先验框。
对于任何一个真实框gt,YoloV5不再使用iou进行正样本的匹配,而是直接采用高宽比进行匹配,即使用真实框和9个不同大小的先验框计算宽高比。
如果真实框与某个先验框的宽高比例大于设定阈值,则说明该真实框和该先验框匹配度不够,将该先验框认为是负样本。
比如此时有一个真实框,它的宽高为[200, 200],是一个正方形。YoloV5默认设置的9个先验框为[10,13], [16,30], [33,23], [30,61], [62,45], [59,119], [116,90], [156,198], [373,326]。设定阈值门限为4。
此时我们需要计算该真实框和9个先验框的宽高比例。比较宽高时存在两个情况,一个是真实框的宽高比先验框大,一个是先验框的宽高比真实框大。因此我们需要同时计算:真实框的宽高/先验框的宽高;先验框的宽高/真实框的宽高。然后在这其中选取最大值。
下个列表就是比较结果,这是一个shape为[9, 4]的矩阵,9代表9个先验框,4代表真实框的宽高/先验框的宽高;先验框的宽高/真实框的宽高。

[[20.         15.38461538  0.05        0.065     ]
 [12.5         6.66666667  0.08        0.15      ]
 [ 6.06060606  8.69565217  0.165       0.115     ]
 [ 6.66666667  3.27868852  0.15        0.305     ]
 [ 3.22580645  4.44444444  0.31        0.225     ]
 [ 3.38983051  1.68067227  0.295       0.595     ]
 [ 1.72413793  2.22222222  0.58        0.45      ]
 [ 1.28205128  1.01010101  0.78        0.99      ]
 [ 0.53619303  0.61349693  1.865       1.63      ]]

我们自然可以看出[59,119], [116,90], [156,198], [373,326]是满足条件的anchor。也可以通过以下图例来理解这一过程:
在这里插入图片描述

跨grid采样

确定了满足条件的anchor之后,我们就该找具体是哪里的grid了。
在过去的Yolo系列中,grid的选择是看gt框的中心点所处的网格的坐上角(即当前grid)。对于yolov5而言,对于被选中的特征层,首先计算gt落在哪个网格内,此时该网格左上角特征点便是一个负责预测的特征点。同时利用四舍五入规则,找出最近的两个网格,将这三个网格都认为是负责预测该真实框的。如下图所示:
在这里插入图片描述
红色点表示该真实框的中心,除了当前所处的网格外,其2个最近的邻域网格也被选中。从这里就可以发现预测框的XY轴偏移部分的取值范围不再是0-1,而是0.5-1.5。
找到对应特征点后,对应特征点的刚才anchor匹配中被选中的anchor负责该真实框的预测。

动态策略

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317433.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux 入门:权限的认识和学习

目录 一.shell命令以及运行原理 二.Linux权限的概念 1.Linux下两种用户 cannot open directory .: Permission denied 问题 2.Linux权限管理 1).是什么 2).为什么(权限角色目标权限属性) 3).文件访问者的分类(角色) 4).文…

搭建opensbi+kernel+rootfs及基本设备驱动开发流程

目录 一.编译qemu 运行opensbikernelrootfs 1.编译qemu-9.1.1 2.安装riscv64编译器 3. 编译opensbi 4.编译kernel 5.编译rootfs 设备驱动开发流程 1.安装 RISC-V 交叉编译工具链 2.驱动开发准备 3.编写简易中断控制器驱动(PLIC)​ 4.配置内核…

QT非UI设计器生成界面的国际化

目的 UI设计器生成界面的国际化,比较容易实现些,因为有现成的函数可以调用,基本过程如下: void MainWindow::on_actLang_CN_triggered() {//中文界面qApp->removeTranslator(trans);delete trans;transnew QTranslator;trans…

python | 输入日期,判断这一天是这一年的第几天

题目: 使用 python 编程,实现输入日期,判断这一天是这一年的第几天? 具体实现代码如下: import datetime year input(请输入年份:) month input(请输入月份:) day input(请输入天:) date…

单片机开发资源分析的实战——以STM32F103C8T6为例子的单片机资源分析

目录 第一点:为什么叫STM32F103C8T6 从资源手册拿到我们的对STM32F103C8T6的资源描述 第二件事情,关心我们的GPIO引脚输出 第三件事情:去找对应外设的说明部分 前言 本文章隶属于项目: Charliechen114514/BetterATK: This is…

Maven | 站在初学者的角度配置

目录 Maven 是什么 概述 常见错误 创建错误代码示例 正确代码示例 Maven 的下载 Maven 依赖源 Maven 环境 环境变量 CMD测试 Maven 文件配置 本地仓库 远程仓库 Maven 工程创建 IDEA配置Maven IDEA Maven插件 Maven 是什么 概述 Maven是一个项目管理和构建自…

【css酷炫效果】纯CSS实现3D翻转卡片动画

【css酷炫效果】纯CSS实现3D翻转卡片动画 缘创作背景html结构css样式完整代码效果图 想直接拿走的老板,链接放在这里:https://download.csdn.net/download/u011561335/90490472 缘 创作随缘,不定时更新。 创作背景 刚看到csdn出活动了&am…

并发编程面试题二

1、java线程常见的基本状态有哪些,这些状态分别是做什么的 (1)创建(New):new Thread(),生成线程对象。 (2)就绪(Runnable):当调用线程对象的sta…

Spring Cloud Stream - 构建高可靠消息驱动与事件溯源架构

一、引言 在分布式系统中,传统的 REST 调用模式往往导致耦合,难以满足高并发和异步解耦的需求。消息驱动架构(EDA, Event-Driven Architecture)通过异步通信、事件溯源等模式,提高了系统的扩展性与可观测性。 作为 S…

突破连接边界!O9201PM Wi-Fi 6 + 蓝牙 5.4 模块重新定义笔记本无线体验

在当今数字化时代,笔记本电脑已成为人们工作、学习和娱乐的必备工具。而无线连接技术,作为笔记本电脑与外界交互的关键桥梁,其性能的优劣直接关乎用户体验的好坏。当下,笔记本电脑无线连接领域存在诸多痛点,严重影响着…

Python----计算机视觉处理(Opencv:图像颜色替换)

一、开运算 开运算就是对图像先进行腐蚀操作, 然后进行膨胀操作。开运算可以去除二值化图中的小的噪点,并分离相连的物体。 其主要目的就是消除那些小白点 在开运算组件中,有一个叫做kernel的参数,指的是核的大小,通常…

一周学会Flask3 Python Web开发-SQLAlchemy查询所有数据操作-班级模块

锋哥原创的Flask3 Python Web开发 Flask3视频教程: 2025版 Flask3 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili 我们来新建一个的蓝图模块-班级模块,后面可以和学生模块,实现一对多的数据库操作。 blueprint下新建g…

Matlab 风力发电机磁悬浮轴承模型pid控制

1、内容简介 略 Matlab 174-风力发电机磁悬浮轴承模型pid控制 可以交流、咨询、答疑 2、内容说明 磁悬浮轴承具有无接触、无摩擦、高速度、高精度、能耗低、不需要需润滑无油污染、可靠性高、寿命长和密封等一系列显著的优点。将磁悬浮技术应用于风力发电机中可以降低风机切入…

FPGA中级项目1——IP核(ROM 与 RAM)

FPGA中级项目1——IP核(ROM 与 RAM) IP核简介 在 FPGA(现场可编程门阵列)设计中,IP 核(Intellectual Property Core,知识产权核)是预先设计好的、可重用的电路模块,用于实…

Matlab 基于专家pid控制的时滞系统

1、内容简介 Matlab 185-基于专家pid控制的时滞系统 可以交流、咨询、答疑 2、内容说明 略 在处理时滞系统(Time Delay Systems)时,使用传统的PID控制可能会面临挑战,因为时滞会导致系统的不稳定或性能下降。专家PID控制通过结…

Unity 笔记:在EditorWindow中绘制 Sorting Layer

在Unity开发过程中,可能会对旧资源进行批量修改,一个个手动修改费人费事,所以催生出了一堆批量工具。 分享一下在此过程中绘制 Sorting Layer 面板的代码脚本。 示意图: 在 EditorGUI 和 EditorGUILayer 中内置了 SortingLayerF…

2024浙江大学计算机考研上机真题

2024浙江大学计算机考研上机真题 2024浙江大学计算机考研复试上机真题 2024浙江大学计算机考研机试真题 2024浙江大学计算机考研复试机试真题 历年浙江大学计算机复试上机真题 历年浙江大学计算机复试机试真题 2024浙江大学计算机复试上机真题 2024浙江大学计算机复试机试真题 …

蓝桥杯嵌入式赛道复习笔记2(按键控制LED灯,双击按键,单击按键,长按按键)

硬件原理解释 这张图展示了一个简单的按键电路原理图,其中包含四个按键(PB0、PB1、PB2、PB3、PA0),每个按键通过一个10kΩ的上拉电阻连接到VDD(电源电压),并接地(GND)。 …

每天五分钟深度学习PyTorch:循环神经网络RNN的计算以及维度信息

本文重点 前面我们学习了RNN从何而来,以及它的一些优点,我们也知道了它的模型的大概情况,本文我们将学习它的计算,我们来看一下RNN模型的每一个时间步在计算什么? RNN的计算 ht-1是上一时刻的输出,xt是本时刻的输入,然后二者共同计算得到了ht,然后yt通过ht计算得到,…

Ubuntu docker安装milvusdb

一、安装docker 1.更新软件包 sudo apt update sudo apt upgrade sudo apt-get install docker-ce docker-ce-cli containerd.io查看是否安装成功 docker -v二、使用国内的镜像下载 milvusdb Docker中国区官方镜像: https://registry.docker-cn.com milvusdb/milvus - Doc…