yolo 训练

news2025/1/21 12:17:15

这里写目录标题

  • 分配训练集&Validation数量
  • 数据集读取
    • 读取全部文件夹
    • 替换路径
  • loss weight
  • NMS
  • BBox_IOU
    • EIou
  • Optimizer

分配训练集&Validation数量

validation_size = training_size * validation_ratio / (1 - validation_ratio)

training_size = 219
validation_ratio = 0.2
validation_size = 219*0.2/(1-0.2)

如果你有 346 张验证图像,使用 k=5 的交叉验证方法,你可以将这些图像分成 5 个不同的折叠(fold),每个折叠包含 69 或 70 张图像。

平均分配图像的方法:
num_images_per_fold = len(val_images) // num_folds

from sklearn.model_selection import KFold

# 假设你有 291 个训练图像和 55 个验证图像
train_images = range(291)
valid_images = range(291, 291+55)

# 将训练图像分成五个部分
kf = KFold(n_splits=5, shuffle=True)
for fold, (train_idx, valid_idx) in enumerate(kf.split(train_images)):
    # 选择一个部分作为验证集,其余部分作为训练集
    train_images_fold = [train_images[i] for i in train_idx]
    valid_images_fold = [train_images[i] for i in valid_idx]

    # 在每次交叉验证中,使用训练集进行训练,并使用验证集进行验证
    # TODO: 训练和验证模型

    # 记录模型的性能指标
    # TODO: 记录模型性能指标

# 将五个结果的平均值作为模型的性能指标
# TODO: 计算模

数据集读取

读取全部文件夹

p = Path(p) # p = WindowsPath('E:/data/helmet_head/train')
glob.glob(str(p / '**' / '*.*'), recursive=True) 

这行 Python 代码使用了 pathlib 模块中的 WindowsPath 类来创建一个 Windows 路径对象 p,表示了一个名为 train 的目录,该目录位于 E:/data/helmet_head/ 目录下。接下来,使用 glob 函数来获取该目录及其所有子目录中的所有文件(包括子目录中的文件)。

p / '**' / '*.*' 表示将 p 对象的路径添加上 '**'(表示所有子目录),然后再添加上 '*.*'(表示所有类型的文件)路径,得到一个包含通配符的字符串路径。这个字符串路径会被转换为一个 WindowsPath 对象并传递给 glob 函数。
glob(str(p / '**' / '*.*'), recursive=True) 表示使用 glob 函数获取符合给定路径模式的文件列表。recursive=True 表示要递归地查找子目录中的文件。

替换路径

x = 'E:\\data\\helmet_head\\train\\collect20120420\\JPEGImages\\000000.jpg'
sa = '\\JPEGImages\\'
sb = '\\Annotations\\'
sb.join(x.rsplit(sa, 1)) 

将指定路径 x 中的子目录名称 ‘JPEGImages’ 替换为 ‘Annotations’

x.rsplit(sa, 1) 使用 rsplit 函数将路径 x 按照指定的子目录名称 ‘JPEGImages’ 进行分割,并将分割结果作为一个列表返回。其中,sa 是分割字符串,1 表示只分割一次,即只分割最后一次出现的位置。
x.rsplit(sa, 1) 的结果为:
[‘E:\data\helmet_head\train\collect20120420’, ‘000000.jpg’]。

'\\Annotations\\'.join(x.rsplit(sa, 1)) 使用 ‘\Annotations\’ 字符串将分割后的列表中的元素连接起来,得到一个新的路径。这个路径将原路径中的子目录名称 ‘JPEGImages’ 替换为 ‘Annotations’。
‘\Annotations\’.join(x.rsplit(sa, 1)) 的结果为:

'E:\data\helmet_head\train\collect20120420\\Annotations\\000000.jpg’

loss weight

Pytorch: BCEWithLogitsLoss

YOLO v5 中的loss weight 在data/hyp/scratch.yaml 中的cls_pw,同时 obj_pw 也可以使用同样的vector
通过train.py 中:

model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
weights = np.bincount(classes, minlength=nc)
weights[weights == 0] = 1  # replace empty bins with 1
weights = 1 / weights  # number of targets per class
weights /= weights.sum()
  • 得到3个种类在数量上应该受到的提高为: tensor([0.12096, 2.43440, 0.44464])
  • 其中数字越大的代表数据量越少
target = torch.ones([4,3,8,8], dtype=torch.float32)
output = torch.full([4,3,8,8], 1.5)
pos_weight = torch.ones([3,8,8])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)

这里的pos_weight = torch.ones([3,8,8]) 一定确保倒数的维度和output/target 相同


初始的cls_pw 是一个scalar, 如下:

target = torch.ones([4,3,8,8], dtype=torch.float32)
output = torch.full([4,3,8,8], 1.5)
pos_weight = torch.ones([1])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)

NMS

以下是torch 版的 NMS, 代替了torchvision.ops.nms(boxes, scores, iou_thres)

def nms(bboxes, scores, iou_thresh=0.5):
    _, order = scores.sort(0, descending=True)
    keep = []

    while order.numel() > 0:
        if order.numel() == 1:  # 保留框只剩一个
            i = order.item()
            keep.append(i)
            break
        else:
            i = order[0].item()  # 保留scores最大的那个框box[i]
            keep.append(i)

        iou = bbox_iou_new(bboxes[i], bboxes[order[1:]]).squeeze()

        idx = (iou <= iou_thresh).nonzero().squeeze()  # 注意此时idx为[N-1,] 而order为[N,]
        if idx.numel() == 0:
            break
        order = order[idx + 1]  # 修补索引之间的差值

    return keep

以下是计算bbox的,这个function 被用于替换 NMS中计算IOU。这样做可以帮助做些关于IOU相关的ablation 分析

def bbox_iou_new(box1, box2, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, Focal=True, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
    w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
    w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU or EIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = cw ** 2 + eps
                ch2 = ch ** 2 + eps
                if Focal:
                    gamma = 0.5
                    return (iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)) * torch.pow(inter / (union + eps), gamma).mean()  # Focal_EIou
                return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    elif SIoU:
        # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
        s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
        s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
        sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
        sin_alpha_1 = torch.abs(s_cw) / sigma
        sin_alpha_2 = torch.abs(s_ch) / sigma
        threshold = pow(2, 0.5) / 2
        sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
        angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
        rho_x = (s_cw / cw) ** 2
        rho_y = (s_ch / ch) ** 2
        gamma = angle_cost - 2
        distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
        omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
        omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
        shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
        return iou - 0.5 * (distance_cost + shape_cost)
    return iou  # IoU

BBox_IOU

CIOU 弥补了 GIOU 只考虑重合不考虑临近但也可用的框。但是CIOU中的那几个criteria 依然还有缺点:
在这里插入图片描述

  1. v 不太靠谱

    • 如果 w, h 的量级为 ground-truth 的整倍数,那么v 为 0 (无效,一样大才是想要的,但不是倍数大) 比如 K倍:
    • w = k w g t , h = k w g t w= kw^{gt}, h = kw^{gt} w=kwgt,h=kwgt v = 4 π 2 ( a r c t a n w g t h g t − a r c t a n k w g t k h g t ) 2 = 0 v = \frac{4}{\pi ^2}(arctan \frac{w^{gt}}{h^{gt}}-arctan\frac{kw^{gt}}{kh^{gt}})^2 = 0 v=π24(arctanhgtwgtarctankhgtkwgt)2=0
  2. 还是 v 的这一步在gradient 这里引起的不靠谱

    • w, h 会有相反的符号, 当 w 和 h 都比 ground truth 大/小 时,两个量按理也应该同时扩大/缩小。
    • 符号导致对w,h 处理不公。
    • 细节原因:
      做partial gradient 后,w, h 会因为 V中 原本用来算宽高比之差的地方,导致各自在gradient 时,遭遇不公
      这里的v 如果不放进训练过程,倒也还是make sense, 可以看作是宽高比的 norm-2 的计算。但放进训练,就要搞gradient,也就是搞partial gradient,就从🆗到不太行了。

EIou

Optimizer

Adam 与 AdamW 都是用于Yolo v5 中的optimiser. 他们在小数据集上可以很快降低loss, 但随着训练增加,他们不如SGD 会 平稳,反而会 oscillation.

Adam 可以看作是extend 自 L2 regularisation 的 optimiser. (Pytorch) (Paper)

请添加图片描述

AdamW 可以看作是extend自 weight decay 的 optimiser. (Pytorch) (Paper)

请添加图片描述
论文给 SGD和 Adam 都试了 weight decay 和 L2 regularization.红色的是传统使用 L2 regularization 做法, 绿色是使用weight decay的做法。

L2 regulrisation 都是针对Gradient做。
而 Weight decay 是在对 parameters 做 update时做。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/527365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Java+SpringMVC+vue+element实现前后端分离校园失物招领系统详细设计

基于JavaSpringMVCvueelement实现前后端分离校园失物招领系统详细设计 博主介绍&#xff1a;5年java开发经验&#xff0c;专注Java开发、定制、远程、指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获…

单轴丝杠平台实现搬运功能

1. 功能说明 本文示例将实现R279样机单轴丝杠平台搬运的功能。 该机构是由一个丝杠模组和一个 舵机关节模组 组合而成&#xff0c;关节模组上安装了一个电磁铁。 注意限位开关【①触碰传感器、②近红外传感器】的安装位置&#xff1a; 2. 丝杠传动机构原理 丝杠传动机构是一个将…

基于海思Hi3531 ARM+K7 FPGA高性能综合视频图像处理平台

板卡概述 XM703是自主研制的一款基于PCIE总线架构的高性能综合视频图像处理平台&#xff0c;该平台采用Xilinx的高性能Kintex UltraScale系列FPGA加上华为海思的高性能视频处理器来实现。 华为海思的HI3531DV200是一款集成了ARM A53四核处理器性能强大的神经网络引擎&#xff…

最新域名查询-中文域名注册到期查询软件

最新域名查询 最新域名查询指的是查询最新注册的域名或者快速确认某个域名是否被注册等相关信息的工具。以下是一些常用的最新域名查询工具&#xff1a; 域名Whois查询工具&#xff1a;Whois查询是一种查询域名注册信息的方式&#xff0c;可以查询已经注册的域名的所有信息&am…

http/https

http 基本概念 超文本传输协议&#xff0c;是互联网应用最广泛的协议之一&#xff0c;用于从 WWW 服务器传输超文本到本地浏览器的传输协议&#xff0c;它可以使浏览器更加高效&#xff0c;使网络传输减少。 https 基本概念 HTTPS是HTTP over SSL的简称&#xff0c;即工作…

LTI连续线性时不变系统能控性证明(格拉姆判据、秩判据)

一、能控性和能达性 1.1、能控性和能达性的定义 能控性&#xff1a;如果在一个有限的时间间隔内&#xff0c;可以用幅值没有限制的输入作用&#xff0c;使偏离系统平衡状态的某个初始状态回复到平衡状态&#xff0c;就称这个初始状态是能控的。 能达性&#xff1a;系统在外控…

【网红营销】海外网红营销怎么做?及注意事项?

随着互联网的发展和全球化的进程&#xff0c;海外网红营销逐渐成为企业推广产品和服务的重要途径。海外网红可以借助其社交媒体平台上的影响力&#xff0c;帮助企业扩大品牌知名度、提升销售业绩。然而&#xff0c;海外网红营销存在着一定的挑战和风险&#xff0c;企业需要制定…

探索将大语言模型用作推荐系统

编者按&#xff1a;目前大语言模型主要问答、对话等场景&#xff0c;进行被动回答。是否可以将大模型应用于推荐系统&#xff0c;进行主动推送呢&#xff1f; 这篇文章回顾了可以将大模型作为推荐系统的理论基础&#xff0c;并重点描述了基于英文和阿拉伯语的购物数据集微调T5-…

菜鸟健身-新手使用哑铃锻炼手臂的动作与注意事项

目录 一、前言 二、哑铃锻炼手臂的好处 三、哑铃锻炼手臂的注意事项 四、哑铃锻炼手臂的基本动作 1. 哑铃弯举 2. 哑铃推举 3. 哑铃飞鸟 五、哑铃锻炼手臂的进阶动作 1. 哑铃侧平举 2. 哑铃俯身划船 六、哑铃锻炼手臂的训练计划 七、总结 一、前言 哑铃是一种非常…

2023年5月天津/南京/成都/深圳CDGA/CDGP数据治理认证报名

6月18日DAMA-CDGA/CDGP数据治理认证考试开放报名中&#xff01; 考试开放地区&#xff1a;北京、上海、广州、深圳、长沙、呼和浩特、杭州、南京、济南、成都、西安。其他地区凑人数中… DAMA-CDGA/CDGP数据治理认证班进行中&#xff0c;报名从速&#xff01; DAMA认证为数据管…

MySQL 日志管理与恢复

MySQL日志管理 MySQL的默认日志保存位置为/usr/local/mysql/data 日志开启方式有两种&#xff1a;通过配置文件或者是通过命令 通过命令修改开启的日志是临时的&#xff0c;关闭或重启服务后就会关闭 MySQL日志管理 日志的分类 1.错误日志 用来记录当MySQL启动、停止或运行时…

畅购商城4.0

畅购商城4.0 1.走进电商 1.1电商行业分析 近年来&#xff0c;世界经济正向数字化转型&#xff0c;大力发展数字经济成为全球共识。党的十九大报告明确提出要建设“数字中国”“网络强国”&#xff0c;我国数字经济发展进入新阶段&#xff0c;市场规模位居全球第二&#xff0c;数…

DC-DC直流隔离升压电源模块高压稳压可调输出12v24v48v转60V80V110V150V200V220V250V300V400V500V

特点 效率高达 80%以上1*2英寸标准封装单电压输出价格低稳压输出工作温度: -40℃~85℃阻燃封装&#xff0c;满足UL94-V0 要求温度特性好可直接焊在PCB 上 应用 HRB W2~40W 系列模块电源是一种DC-DC升压变换器。该模块电源的输入电压分为&#xff1a;4.5~9V、9~18V、及18~36V、…

我们拆了一款将ChatGPT“落地”的AI语音交互机器人,八核A7全志R58主控

视频版本拆机&#xff1a;【60块钱&#xff0c;垃圾佬的第一台机器人&#xff0c;国产8核CPU全志R58】 https://www.bilibili.com/video/BV1Qk4y177ja/?share_sourcecopy_web&vd_source6ec797f0de1d275e996fb7de54dea06b 公子小白是一对由狗尾草智能科技推出的人工智能机…

Pytorch代码——持续更新

1 连续两个argsort 返回张量中每个元素对应的排名 torch.argsort(torch.argsort(pred, dim1, descendingTrue),dim1,descendingFalse) 例子 使用一个argsort后得到的是张量中按列降序排序后的索引&#xff0c; 再使用一个argsort后是张量中每一个元素的排名。 例如第2行中…

港联证券|股票分批技巧是什么?分批买进的手续费如何计算?

股票分批是股市中常用操作&#xff0c;根基股票的波动不同&#xff0c;将资金分批投资在不同股价还在时间上。那么股票分批技巧是什么&#xff1f;分批买进的手续费如何计算&#xff1f;下面就由港联证券为大家分析&#xff1a; 股票分批技巧是什么&#xff1f; 1、补仓股票选…

国药集团蒸汽表内网图像识别案例

一、项目需求 项目背景&#xff1a;国药集团MES系统硬件仪表数据采集项目 为了实现现场蒸汽表计数据的采集和存储&#xff0c;我们提供了本地内网图像离线识别方案&#xff0c;它可以在不接线的情况下实现对现场蒸汽表计数据的采集&#xff0c;并通过485接口将数据传输到客户内…

Facebook商店和亚马逊店铺:双管齐下,实现多渠道销售

在当今数字化时代&#xff0c;电子商务已成为商业领域中不可或缺的一部分。随着消费者购物行为的转变&#xff0c;企业需要利用多种渠道来吸引潜在客户并增加销售额。 在这个过程中&#xff0c;Facebook商店和亚马逊店铺成为了两个备受关注的选择。本文将深入探讨如何通过同时…

基于Web智慧工业园3D可视化安全生产管控系统

建设背景 随着经济飞速发展和产业创新升级&#xff0c;作为新经济形式的重要载体&#xff0c;工业园区污染严重、安全生产难以监管等问题日益突出。工业园区作为工业高质量发展的重要载体和平台&#xff0c;工厂聚集&#xff0c;安全生产风险集中&#xff0c;在这个背景下&…

数据结构(堆)

文章目录 一、概念二、堆的使用三、PriorityQueue 介绍3.1 PriorityQueue 的特性3.2 PriorityQueue 的方法3.3 集合框架中PriorityQueue的比较方式 四、堆的应用 一、概念 1.什么是优先级队列 队列是一种先进先出(FIFO)的数据结构&#xff0c;但有些情况下&#xff0c;操作的数…