YOLOV7学习记录之训练过程

news2025/1/12 3:02:36

在前面学习YOLOV7的过程中,我们已经学习了其网络结构,然而实际上YOLOV7项目的难点并不在于其网络模型而是在于其损失函数的设计,即如何才能训练出来合适的bbox。
神经网络模型都有训练和测试(推理)过程,在YOLOV7的训练过程中,包含模型构造,标签分配和损失函数计算。其中模型在前面以及讲过了。在测试过程中包含加载模型,损失函数计算,输出值解码,非极大值抑制,MAP计算等,今天我们先讲一下YOLOV7的训练过程。
在训练过程中主要用到这几个文件:
在这里插入图片描述

YOLOV7训练过程中的一个重要思想便是正样本匹配策略,其更像是YOLOV5与YOLOX的一个结合体,那么接下来我们结合代码看一下其匹配策略:

YOLOV5,V7匹配策略

yolov5,v7与yolov3、yolov4最大的不同就是v3与v4一个gt只会与一个正样本匹配,而v5,v7一个 gt 可以被分配给多个anchor,还有可能被分配到三个不同的特征图中的两个甚至三个。

匹配策略:这里指的是不带辅助头,论文中,将负责最终输出的Head为lead Head,将用于辅助训练的Head称为auxiliary Head。本博客不重点讨论,原因是论文中后面的结构实验实现提升比较有限(0.3个点)

主要是参考了YOLOV5 和YOLOV6使用的当下比较火的simOTA.

S1.训练前,会基于训练集中gt框,通过k-means聚类算法,先验获得9个从小到大排列的anchor框。(可选)

S2.将每个gt与9个anchor匹配:Yolov5为分别计算它与9种anchor的宽与宽的比值(较大的宽除以较小的宽,比值大于1,下面的高同样操作)、高与高的比值,在宽比值、高比值这2个比值中,取最大的一个比值,若这个比值小于设定的比值阈值,这个anchor的预测框就被称为正样本。一个gt可能与几个anchor均能匹配上(此时最大9个)。所以一个gt可能在不同的网络层上做预测训练,大大增加了正样本的数量,当然也会出现gt与所有anchor都匹配不上的情况,这样gt就会被当成背景,不参与训练,说明anchor框尺寸设计的不好。

S3.扩充正样本。根据gt框的中心位置,将最近的2个邻域网格也作为预测网格,也即一个groundtruth框可以由3个网格来预测;可以发现粗略估计正样本数相比前yolo系列,增加了三倍(此时最大27个匹配)。图下图浅黄色区域,其中实线是YOLO的真实网格,虚线是将一个网格四等分,如这个例子中,GT的中心在右下虚线网格,则扩充右和下真实网格也作为正样本。

S4.获取与当前gt有top10最大iou的prediction结果。将这top10 (5-15之间均可,并不敏感)iou进行sum,就为当前gt的k。k最小取1。

S5.根据损失函数计算每个GT和候选anchor损失(前期会加大分类损失权重,后面减低分类损失权重,如1:5->1:3),并保留损失最小的前K个。

S6.去掉同一个anchor被分配到多个GT的情况。

正负样本分配

正负样本分配的函数build_targets(yolo_training.py),将其分为以下结构:

├── 数据准备
└── 遍历每个特征图
        ├── ①anchors和gt匹配,看哪些gt是当前特征图的正样本(find_3_positive)初筛
        └── ②将当前特征图的正样本分配给对应的grid(完成复筛:iou,类别)

步骤1:anchors和gt匹配,看哪些gt是当前特征图的正样本**(find_3_positive)**
这里要做的是从gt的上下左右分别偏移0.5来获取周边的单元格来进行预测,通过计算anchor的长宽比例是否合适(比例位于1/4与4之间)则认为符合,那么当前gt就能与当前特征图匹配。
如图所示:这是lead head的正样本匹配策略
在这里插入图片描述
在这里插入图片描述
YOLOV7中引入了辅助头,其正样本为:
在这里插入图片描述
如图:训练时,lead head和aux head中正样本分配图示(蓝色点代表着gt所处的位置,实线组成的网格代表着特征图grid,虚线代表着一个grid分成了4个象限以进行正负样本分配。如果一个gt位于蓝点位置,那么在lead head中,黄色grid将成为正样本。在aux head中,黄色+橙色grid将成为正样本)

初筛(find_3_positive)

设置偏移方向与偏移大小

 g = 0.5  # offsets 漂移的距离,为获取更多正样本
        off = torch.tensor([#漂移方向
            [0, 0],
            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
        ], device=targets.device).float() * g

在yolov5,v7中,会将一个特征点分为四个象限,针对步骤1中匹配的gt,会计算该gt(上图中蓝色点)处于四个象限中的哪一个,并将邻近的两个特征点也作为正样本。以上图举例,若gt偏向于右下角的象限,就会将gt所在grid的右边、下边特征点也作为正样本。

# 分别对应中心点、左、上、右、下
off = torch.tensor([[0, 0],
                    [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                    ], device=targets.device).float() * g

# gain = [1, 1, 特征图w, 特征图_h, 特征图w, 特征图_h]
gxy = t[:, 2:4]  # 以特征图左上角为原点,gt的xy坐标
gxi = gain[[2, 3]] - gxy  # 以特征图左上角为原点,gt的xy坐标
# jklm就分别代表左、上、右、下是否能作为正样本。g=0.5
# j和l, k和m是互斥的,(x,y)%1会得到两个值所以其最终可以组成四个方位
j, k = ((gxy % 1 < g) & (gxy > 1)).T
l, m = ((gxi % 1 < g) & (gxi > 1)).T
j = torch.stack((torch.ones_like(j), j, k, l, m))#组成五维
# 原本一个gt只会存储一份,现在复制成3份 拼接函数
t = t.repeat((5, 1, 1))[j]
# 偏移量
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]

关于上面的代码语法上理解如下
在这里插入图片描述
这里可以看出V5和V7是很类似的,相比较yolov3和v4一个gt只会匹配一个正样本的方式,该种方法能够分配更多的正样本,有助于训练加速,正负样本平衡。
完成增加正样本后(找到了t个先验框(正样本),我们需要判断这些框是属于哪张图片,其属于哪个类别,负责该样本预测的单元格的左上坐标为多少,以及该坐标的w,h的缩放比例

 # -------------------------------------------#
            #   b   代表属于第几个图片,即每个t属于的图片
            #   gxy 代表该真实框所处的x、y中心坐标
            #   gwh 代表该真实框的wh坐标
            #   gij 代表真实框所属的特征点坐标
            # -------------------------------------------#
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()#.long是取值,不要小数部分,如gxy(2.3,2.2)左移-0.5则为(1.8,2.2)取值(1,2)即由(1,2)的anchor来进行匹配,获得偏移后负责预测的单元格
            gi, gj = gij.T  # grid xy indices

            # -------------------------------------------#
            #   gj、gi不能超出特征层范围
            #   a代表属于该特征点的第几个先验框
            # -------------------------------------------#
            a = t[:, 6].long()  # anchor indices
            indices.append(
                (b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid indices
            anchors.append(anchors_i[a])  # anchors比例

而且由于每一个特征图中,都会将所有的gt与当前特征图的anchor计算能否分配正样本,也就说明一个gt可能会在多个特征图中都分配到正样本。
find_3_positive的返回结果为:
indices 的shape为:【3,4,正样本个数】
在这里插入图片描述
anchor的shape为:
在这里插入图片描述

复筛get_target()

至此,我们便完成了正样本的匹配即初筛工作,紧接着我们要对初筛得到的先验框进行复筛,此时便是要根据predictions的预测先验框与真实框计算IOU与类别进行复筛。

 #   取出这个真实框对应的预测结果
                # -------------------------------------------#
                fg_pred = prediction[b, a, gj, gi]
                #判断是否是物体与类别符合
                p_obj.append(fg_pred[:, 4:5])
                p_cls.append(fg_pred[:, 5:])

                # -------------------------------------------#
                #   获得网格后,进行解码,这里需要按照步长进行恢复,并得到我们的预测恢复结果
                # -------------------------------------------#
                grid = torch.stack([gi, gj], dim=1).type_as(fg_pred)
                pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
                pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
                pxywh = torch.cat([pxy, pwh], dim=-1)
                pxyxy = self.xywh2xyxy(pxywh)#该函数将xywh转化为左上,右下坐标形式
                pxyxys.append(pxyxy)

计算当前图片中,真实框与预测框的重合程度
ou的范围为0-1,取-log后为0~inf
重合程度越大,取-log后越小
因此,真实框与预测框重合度越大,pair_wise_iou_loss越小,所得为(真实框个数*候选框个数)

 pair_wise_iou = self.box_iou(txyxy, pxyxys)
 pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)

通过计算IOU后,选择前20,若没有20,则有多少选多少

 top_k, _ = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1)
            dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1) 
           #   gt_cls_per_image    种类的真实信息,转换为one -hot格式,复制操作           
            gt_cls_per_image = F.one_hot(this_target[:, 1].to(torch.int64), self.num_classes).float().unsqueeze(
                1).repeat(1, pxyxys.shape[0], 1)

预测类别并计算交叉熵

 cls_preds_ = p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * p_obj.unsqueeze(0).repeat(num_gt,
                                                                                                    1,    1).sigmoid_()
 y = cls_preds_.sqrt_()
 pair_wise_cls_loss = F.binary_cross_entropy_with_logits(torch.log(y / (1 - y)), gt_cls_per_image,
                                                            reduction="none").sum(-1)

求cost损失总和,topk函数:找前k个最大值

 cost = ( pair_wise_cls_loss+ 3.0 * pair_wise_iou_loss)
 matching_matrix = torch.zeros_like(cost)
 for gt_idx in range(num_gt):#从真实框中去找这里面损失最小的k个
                _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item())
                matching_matrix[gt_idx][pos_idx] = 1.0

在这里插入图片描述

为防止一个anchor预测多个gt,还需要将其转换一下取出最小的iou作为y预测

 anchor_matching_gt = matching_matrix.sum(0)#sum(0)求数组每一列的和
            if (anchor_matching_gt > 1).sum() > 0:#找出哪些sum>0,说明一个anchor正样本匹配到了多个gt
                _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)#找到最小的
                matching_matrix[:, anchor_matching_gt > 1] *= 0.0#其余赋值0
                matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0#最小的赋值1
            fg_mask_inboxes = matching_matrix.sum(0) > 0.0
            fg_mask_inboxes = fg_mask_inboxes.to(torch.device(device))#哪些是正样本
            matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)#正样本对应的真实框索引

最终我们将匹配的批次,我们得到的值为:
matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
下面来讲解其值含义:
matching_bs:匹配上的批次
在这里插入图片描述
matching_as:匹配上的anchor id [0,1,2]
在这里插入图片描述
matching_gjs, matching_gis:匹配上的单元格xy坐标(负责正样本预测)
在这里插入图片描述
在这里插入图片描述
matching_targets:匹配上的标签,通过标签中批次,xywh(真实框)可以与前面匹配上的anchor进行计算。在find-3-positive中其加上了anchorid,但在这里没有用到被删除了。
在这里插入图片描述
在这里插入图片描述
matching_anchs:匹配上的anchor缩放比例
在这里插入图片描述

计算损失

完成build-targe函数后获得上面提到的匹配的正样本信息:call函数中

 bs, as_, gjs, gis, targets, anchors = self.build_targets(predictions, targets, imgs)

开始计算损失:

 for i, prediction in enumerate(predictions):
            # -------------------------------------------#
            #   image, anchor, gridy, gridx
            # -------------------------------------------#
            b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]
            tobj = torch.zeros_like(prediction[..., 0], device=device)  # target obj

            # -------------------------------------------#
            #   获得目标数量,如果目标大于0
            #   则开始计算种类损失和回归损失
            # -------------------------------------------#
            n = b.shape[0]
            if n:
                prediction_pos = prediction[b, a, gj, gi]  # prediction subset corresponding to targets

                # -------------------------------------------#
                #   计算匹配上的正样本的回归损失
                # -------------------------------------------#
                # -------------------------------------------#
                #   grid 获得正样本的x、y轴坐标
                # -------------------------------------------#
                grid = torch.stack([gi, gj], dim=1)
                # -------------------------------------------#
                #   进行解码,获得预测结果,这里可以看到与build_target中是遥相呼应的是相同的计算方式
                # -------------------------------------------#
                xy = prediction_pos[:, :2].sigmoid() * 2. - 0.5
                wh = (prediction_pos[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                box = torch.cat((xy, wh), 1)
                # -------------------------------------------#
                #   对真实框进行处理,映射到特征层上
                # -------------------------------------------#
                selected_tbox = targets[i][:, 2:6] * feature_map_sizes[i]
                selected_tbox[:, :2] -= grid.type_as(prediction)
                # -------------------------------------------#
                #   计算预测框和真实框的回归损失
                # -------------------------------------------#
                iou = self.bbox_iou(box.T, selected_tbox, x1y1x2y2=False, CIoU=True)
                box_loss += (1.0 - iou).mean()
                # -------------------------------------------#
                #   根据预测结果的iou获得置信度损失的gt,使用iou来代替置信度
                # -------------------------------------------#
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio

                # -------------------------------------------#
                #   计算匹配上的正样本的分类损失
                # -------------------------------------------#
                selected_tcls = targets[i][:, 1].long()
                t = torch.full_like(prediction_pos[:, 5:], self.cn, device=device)  # targets
                t[range(n), selected_tcls] = self.cp
                cls_loss += self.BCEcls(prediction_pos[:, 5:], t)  # BCE

            # -------------------------------------------#
            #   计算目标是否存在的置信度损失
            #   并且乘上每个特征层的比例
            # -------------------------------------------#
            obj_loss += self.BCEobj(prediction[..., 4], tobj) * self.balance[i]  # obj loss

        # -------------------------------------------#
        #   将各个部分的损失乘上比例
        #   全加起来后,乘上batch_size
        # -------------------------------------------#
        box_loss *= self.box_ratio
        obj_loss *= self.obj_ratio
        cls_loss *= self.cls_ratio
        bs = tobj.shape[0]

        loss = box_loss + obj_loss + cls_loss
        return loss

根据模型预测结果计算回归值,回归值计算公式
在这里插入图片描述
至此,YOLOV7的正样本匹配与损失函数计算过程便完成了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/110615.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QT JS交互、调用JS、传值

本文详细的介绍了QT JS交互、调用JS、传值的各种操作&#xff0c;包括QT向JS传递String字符串、包括QT向JS传递Int数字、包括QT向JS传递List数组&#xff0c;同时也接收JS向QT返回的List数组、JS向QT返回的Json、JS向QT返回的数字、JS向QT返回的字符串。 本文作者原创&#xff…

Vue基础8之Vue组件化编程、非单文件组件与单文件组件

Vue基础8Vue组件化编程对组件的理解一些概念的理解非单文件组件基本使用几个注意点组件的嵌套VueComponent一个重要的内置关系先导篇&#xff1a;原型对象正文&#xff08;可以理解为类的继承&#xff09;单文件组件Vue组件化编程 对组件的理解 传统方式&#xff1a; 使用组…

计算机网络-交换方式

目录电路交换&#xff08;Circuit Switching&#xff09;分组交换&#xff08;Packet Switching&#xff09;报文交换&#xff08;Message Switching&#xff09;电路交换、报文交换、分组交换的对比电路交换&#xff08;Circuit Switching&#xff09; 在电话问世后不久&#…

扫雷游戏的设计(百分百还原电脑操作)

目录 &#x1f332;了解扫雷游戏的作用原理并梳理思路 &#x1f332;扫雷游戏前期部分完善 &#x1f337;文件的创建 &#x1f337;创建菜单&#xff0c;完善主函数 &#x1f333;代码呈现&#xff1a; &#x1f332;扫雷游戏主题内容 &#x1f334;第一步初始化棋盘 &#x1…

Gradle中如何修改Springboot引入的依赖版本

扫描漏洞升级 不知道各位是否遇到过以下问题&#xff1a; 当下层项目将spring引入的某个依赖版本升级之后&#xff0c;上层项目只要指定了Springboot版本&#xff0c;那么还是会将这个版本改回去&#xff1f; 比如&#xff1a;现在有两个Springboot项目A、B&#xff0c;B项目…

Git安装和配置

GitGitee 官网安装配置教程&#xff1a;https://gitee.com/help/articles/4104本文是以官网教程为基础而展开的实践笔记。初学者可以以本文为引入&#xff0c;但建议最终以官方文档为最终深入学习的参考。一、 下载和安装Git 1、官网下载&#xff1a;https://git-scm.com 如果对…

HTML5基础

HTML5 文章目录HTML5概述开发工具浏览器开发软件DemoHTML5语法HTML5标签HTML5标签属性HTML5文档注释HTML5文档结构头部内容主体内容DemoHTML5常见标签常见块级标签标题标签水平线标签段落标签换行标签引用标签预格式标签无序列表标签有序列表标签定义列表标签分区标签常见行级标…

【Java寒假打卡】Java基础-继承

【Java寒假打卡】Java基础-继承一、继承的好处和弊端二、继承的成员变量访问特点三、重写方法四、方法重写的注意事项五、权限修饰符六、构造方法一、继承的好处和弊端 继承的好处 提高了代码的复用性 提高了代码的维护性 让类和类之间产生了关系 是多态的前提 继承的弊端 …

Flink-使用filter和SideOutPut进行分流操作

文章目录1.什么是分流&#xff1f;2. 过滤器(filter)3. 使用侧输出流&#xff08;SideOutput&#xff09;&#x1f48e;&#x1f48e;&#x1f48e;&#x1f48e;&#x1f48e; 更多资源链接&#xff0c;欢迎访问作者gitee仓库&#xff1a;https://gitee.com/fanggaolei/learni…

四、网络层(七)网络层设备

目录 7.1 路由器的组成和功能 7.2 路由表与路由转发 7.1 路由器的组成和功能 路由器是一种具有多个输入/输出端口的专用计算机&#xff0c;其任务是连接不同的网络&#xff08;可以是异构的&#xff09;并完成路由转发。在多个逻辑网络&#xff08;即多个广播域&#xff…

Vulnhub靶机:HACKADEMIC_ RTB1

目录介绍信息收集主机发现主机信息探测网站探测Sql注入挂马提权介绍 系列&#xff1a;Hackademic&#xff08;此系列共2台&#xff09; 发布日期&#xff1a;2011年9月6日 难度&#xff1a;初级 运行环境&#xff1a;VMware Workstation 目标&#xff1a;取得 root 权限 flag…

5W2H分析法

什么是5W2H 5W2H分析法又叫七何分析法&#xff0c;是二战中美国陆军兵器修理部首创。简单、方便&#xff0c;易于理解、使用&#xff0c;富有启发意义&#xff0c;广泛用于企业管理和技术活动&#xff0c;对于决策和执行性的活动措施也非常有帮助&#xff0c;也有助于弥补考虑…

【UE4 第一人称射击游戏】07-添加“AK47”武器

素材资料地址&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1epyD62jpOZg-o4NjWEjiyg 密码&#xff1a;jlhr 效果&#xff1a; 步骤&#xff1a; 1.打开“WalkRun_BS”&#xff0c;将内插时间改为1 2.创建一个文件夹&#xff0c;命名为“Weapons” 进入“Weapons”…

可视化数据图表-FineReportJS实现清空控件内容

1. 概述 1.1 问题描述 在使用查询控件时&#xff0c;有时我们希望能够快捷重置控件的内容&#xff0c;或者重置所有控件的内容。效果如下图所示&#xff1a; 重置某个控件的内容&#xff1a;1.2 实现思路 在使用查询控件时&#xff0c;有时我们希望能够快捷重置控件的内容&a…

H3C 二层链路聚合

简介&#xff1a; 它通过将多条以太网物理链路捆绑在一起成为一条逻辑链路&#xff0c;从而实现增加链路带宽的目的。 成员端口&#xff1a; 选中&#xff08;Selected&#xff09;状态&#xff1a;此状态下的成员端口可以参与用户数据的转发&#xff0c;处于此状态的成员端口…

绝!OpenAI 年底上新,单卡 1 分钟生成 3D 点云,text-to 3D 告别高算力消耗时代

内容一览&#xff1a;继 DALL-E、ChatGPT 之后&#xff0c;OpenAI 再发力&#xff0c;于近日发布 PointE&#xff0c;可以依据文本提示直接生成 3D 点云。 关键词&#xff1a;OpenAI 3D 点云 PointE OpenAI 年底冲业绩&#xff0c;半个多月前发布的 ChatGPT 广大网友还没…

政务行业势能厂商 |美创科技入选《嘶吼2022中国网络安全产业势能榜》

近日&#xff0c;网络安全垂直媒体嘶吼网络安全产业研究院正式发布《嘶吼2022中国网络安全产业势能榜》评选结果。凭借在政务数据安全领域的服务深耕以及广泛的市场认可&#xff0c;美创科技入选势能榜“政务篇”&#xff0c;获评政务行业“专精型”安全厂商。 嘶吼安全产业研究…

Apache 之执行 CGI 脚本(Python 实现)

目录前言1 查看并挑选 Python 版本2 用 Python 实现一个简单的 CGI 脚本3 查看 CGI 环境变量总结前言 本文记录了一个搭建 CGI 环境的示例。前文推荐&#xff1a;《Apache 2.4.54 x64 安装及配置》。 【系统环境】 Win10-64bit Apache 2.4.54 x64 Python 3.11.1 1 查看并挑选…

PyInstaller的常用打包命令

学习了pyqt后&#xff0c;设计了界面&#xff0c;并且需要打包为exe程序。 每次打包时&#xff0c;都要查好久资料&#xff0c;故此记录一下常用的命令。 PyInstaller 是一个 Python 应用程序打包工具&#xff0c;它可以将 Python 程序打包为单个独立可执行文件。 要使用 P…

2022星空创造营应用创新大赛圆满落幕,获奖名单出炉!

​12月22日&#xff0c;2022星空创造营应用创新大赛在2022手机创新周暨第十届手机设计大赛颁奖典礼上作为特别专场正式公布获奖名单。2022星空创造营应用创新大赛由联通在线、手机设计大赛天鹅奖组委会联合主办&#xff0c;联通在线音乐公司及工信部赛迪研究院共同承办&#xf…