【目标检测】YOLOv5算法实现(八):模型验证

news2024/9/29 21:31:30

  本系列文章记录本人硕士阶段YOLO系列目标检测算法自学及其代码实现的过程。其中算法具体实现借鉴于ultralytics YOLO源码Github,删减了源码中部分内容,满足个人科研需求。
  本系列文章主要以YOLOv5为例完成算法的实现,后续修改、增加相关模块即可实现其他版本的YOLO算法。

文章地址:
YOLOv5算法实现(一):算法框架概述
YOLOv5算法实现(二):模型加载
YOLOv5算法实现(三):数据集加载
YOLOv5算法实现(四):损失计算
YOLOv5算法实现(五):预测结果后处理
YOLOv5算法实现(六):评价指标及实现
YOLOv5算法实现(七):模型训练
YOLOv5算法实现(八):模型验证
YOLOv5算法实现(九):模型预测(编辑中…)

本文目录

  • 1 引言
  • 2 模型验证(validation.py)

1 引言

  本篇文章综合之前文章中的功能,实现模型的验证。模型验证的逻辑如图1所示。
在这里插入图片描述

图1 模型验证流程

2 模型验证(validation.py)

def validation(parser_data):
    device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
    print("Using {} device validation.".format(device.type))

    # read class_indict
    label_json_path = './data/object.json'
    assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
    with open(label_json_path, 'r') as f:
        class_dict = json.load(f)

    category_index = {v: k for k, v in class_dict.items()}

    data_dict = parse_data_cfg(parser_data.data)
    test_path = data_dict["valid"]

    # 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
    batch_size = parser_data.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using %g dataloader workers' % nw)

    # load validation data set
    val_dataset = LoadImagesAndLabels(test_path, parser_data.img_size, batch_size,
                                      hyp=parser_data.hyp,
                                      rect=False)  # 将每个batch的图像调整到合适大小,可减少运算量(并不是512x512标准尺寸)

    val_dataset_loader = torch.utils.data.DataLoader(val_dataset,
                                                     batch_size=batch_size,
                                                     shuffle=True,
                                                     num_workers=nw,
                                                     pin_memory=True,
                                                     collate_fn=val_dataset.collate_fn)

    # create model
    model = Model(parser_data.cfg, ch=3, nc=parser_data.nc)
    weights_dict = torch.load(parser_data.weights, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict, strict=False)
    model.to(device)

    # evaluate on the test dataset

    # 计算PR曲线和AP
    stats = []
    iouv = torch.linspace(0.5, 0.95, 10, device=device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()
    # 混淆矩阵
    confusion_matrix = ConfusionMatrix(nc=3, conf=0.6)
    model.eval()

    with torch.no_grad():
        for imgs, targets, paths, shapes, img_index in tqdm(val_dataset_loader, desc="validation..."):
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            nb, _, height, width = imgs.shape  # batch size, channels, height, width
            targets = targets.to(device)
            preds = model(imgs)[0]  # only get inference result
            preds = non_max_suppression(preds, conf_thres=0.3, iou_thres=0.6, multi_label=False)
            targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)
            outputs = []
            for si, pred in enumerate(preds):
                '''
                labels: [clas, x, y, w, h] (训练图像上绝对坐标)
                pred: [x,y,x,y,obj,cls] (训练图像上绝对坐标)
                predn: [x,y,x,y,obj,cls] (输入图像上绝对坐标)
                labels: [x,y,x,y,class] (输入图像上绝对坐标)
                shapes[si][0]: 输入图像大小
                shapes[si][1]
                '''
                labels = targets[targets[:, 0] == si, 1:]  # 当前图片的标签信息
                nl = labels.shape[0]  # number of labels # 当前图片标签数量
                if pred is None:
                    npr = 0
                else:
                    npr = pred.shape[0]  # 预测结果数量
                correct = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # 判断在不同IoU下预测是否预测正确
                path, shape = Path(paths[si]), shapes[si][0]  # 当前图片shape(原图大小)
                if npr == 0:  # 若没有预测结果
                    if nl:  # 没有预测结果但有实际目标
                        # 不同IoU阈值下预测准确率,目标类别置信度,预测类别,实际类别
                        stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))
                        # 混淆矩阵计算(类别信息)
                        confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
                    continue
                predn = pred.clone()
                scale_boxes(imgs[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
                if nl:  # 有预测结果且有实际目标
                    tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
                    scale_boxes(imgs[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                    labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                    correct = process_batch(predn, labelsn, iouv)
                    confusion_matrix.process_batch(predn, labelsn)
                stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0]))  # 预测结果在不同IoU是否预测正确, 预测置信度, 预测类别, 实际类别
        confusion_matrix.plot(save_dir=parser_data.save_path, names=["normal", 'defect', 'leakage'])

    # 图片:预测结果在不同IoU下预测结果,预测置信度,预测类别,实际类别
    stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, names=["normal", 'defect', 'leakage'])
        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        print(map50)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__)

    # 使用设备类型
    parser.add_argument('--device', default='cuda', help='device')

    # 检测目标类别数
    parser.add_argument('--nc', type=int, default=3, help='number of classes')
    file = 'yolov5s'
    cfg = f'cfg/models/{file}.yaml'
    parser.add_argument('--cfg', type=str, default=cfg, help="*.cfg path")
    parser.add_argument('--data', type=str, default='data/my_data.data', help='*.data path')
    parser.add_argument('--hyp', type=str, default='cfg/hyps/hyp.scratch-med.yaml', help='hyperparameters path')
    parser.add_argument('--img-size', type=int, default=640, help='test size')

    # 训练好的权重文件
    weight_1 = f'./weights/{file}/{file}' + '-best_map.pt'
    weight_2 = f'./weights/{file}/{file}' + '.pt'
    weight = weight_1 if os.path.exists(weight_1) else weight_2
    parser.add_argument('--weights', default=weight, type=str, help='training weights')
    parser.add_argument('--save_path', default=f'results/{file}', type=str, help='result save path')

    # batch size
    parser.add_argument('--batch_size', default=2, type=int, metavar='N',
                        help='batch size when validation.')

    args = parser.parse_args()

    validation(args)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1378977.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【DevOps-08-3】Jenkins容器内部使用Docker

一、简要描述 构建镜像和发布镜像到harbor都需要使用到docker命令。而在Jenkins容器内部安装Docker官方推荐直接采用宿主机带的Docker即可。 设置Jenkins容器使用宿主机Docker。 二、配置和操作步骤 1、修改宿主机docker.sock权限 # 修改docker.sock 用户和用户组都为root $ …

如何配置Kafka账号密码

背景 我们需要与第三方系统进行数据同步,需要搭建公网Kafka,Kafka默认是没有用户密码校验的,所以我们需要配置用户名密码校验。 配置 新增JAAS配置文件 在conf目录下新增kafka_server_jaas.conf文件,文件内容如下:…

高压消防泵:科技与安全性的完美结合

在现代社会,随着科技的不断发展,各种高科技设备层出不穷,为我们的生活带来了极大的便利。在森林火灾扑救领域,恒峰智慧科技研发的高压消防泵作为一种高效、节能、绿色、环保的优质设备,将科技与安全性完美地结合在一起…

最强联网Chat GPT 火爆全网高速 永久免费

🔴高速联网 秒响应支持语音通话🎈 首先介绍一下她的功能吧😁 女友消息代回机👌🏻 朋友圈文案👌🏻 聊天话术👌🏻 高情商回复👌🏻 脱单助…

Windows使用(版本8.11)ElasticSearch、elasticsearch-head、kibana

下载安装引用这篇文章 目录 1、ES基本知识核心术语核心概念倒排索引ES字典树ES怎么保证读写一致 2、Window启动ES步骤elasticsearch-8.11.3elasticsearch-head-masterkibana-8.11.3 3、Kibana 调用ES API示例 1、ES基本知识 核心术语 ● 索引:index (相…

持续构建行业影响力|HarmonyOS SDK荣膺年度“技术卓越”奖项

自2023年9月华为宣布鸿蒙原生应用全面启动以来,HarmonyOS SDK通过将HarmonyOS系统级能力对外开放,支撑开发者高效打造更纯净、更智能、更精致、更易用的鸿蒙原生应用,和开发者共同成长。 通过在开发者社区和HarmonyOS开发者持续的内容共创与技…

uniapp小程序超出一行显示...并展示更多按钮

注意:全部标签需要浮动在父盒子右边哦 循环获取所有需要展示数据标签的高度 this.goods this.goods.map(item > ({...item,showBtn: false}));this.$nextTick(() > {uni.createSelectorQuery().in(this).selectAll(".cart-info").boundingClientRect((data)…

RocketMQ源码阅读-Producer发消息

RocketMQ源码阅读-Producer发消息 1. 从单元测试入手2. 启动过程3. 同步消息发送过程4. 异步消息发送过程5. 小结 Producer是消息的生产者。 Producer和Consummer对Rocket来说都是Client,Server是NameServer。 客户端在源码中是一个单独的Model,目录为ro…

LeetCode - 1371 每个元音包含偶数次的最长子字符串(Java JS Python C)

题目来源 1371. 每个元音包含偶数次的最长子字符串 - 力扣(LeetCode) 题目描述 给你一个字符串 s ,请你返回满足以下条件的最长子字符串的长度:每个元音字母,即 a,e,i,o&#xff0…

【Git】的工作流程简介

目录 Git的工作区域Git的基本流程 1.将工作区的代码添加到暂存区2.将暂存区的文件提交到本地仓库3.将暂存区的文件提交到远程仓库 Git的工作区域 Git的基本流程 图形化方式操作 命令行模式(Linux系统常用)操作 1.将工作区的代码添加到暂存区 查看文件状…

架构03 - 理解构架的视角

学习架构时,首要任务是弄清楚不同视角对于架构的理解,因为每个人对于架构的理解可能存在差异。不同职位对于架构的关注点也不同。开发人员更多关注开发架构,售前人员更多关注业务架构,运维人员更多关注运维架构,技术支…

基于SSM的电脑测评系统(有报告)。Javaee项目。ssm项目。

演示视频: 基于SSM的电脑测评系统(有报告)。Javaee项目。ssm项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,通过Spring Spri…

书生·浦语大模型--第二节课作业

书生浦语大模型--第二节课作业 基础部分生成300字小故事hugging face 下载功能 进阶部分浦语灵笔的图文理解及创作部署Lagent 工具调用 Demo 创作部署 基础部分 生成300字小故事 hugging face 下载功能 hugging face被墙了,在本地电脑无论是不是科学上网&#xff…

用通俗易懂的方式讲解:对 embedding 模型进行微调,我的大模型召回效果提升了太多了

QA对话目前是大语言模型的一大应用场景,在QA对话中,由于大语言模型信息的滞后性以及不包含业务知识的特点,我们经常需要外挂知识库来协助大模型解决一些问题。 在外挂知识库的过程中,embedding模型的召回效果直接影响到大模型的回…

使用Sqoop将数据从Hadoop导出到关系型数据库

当将数据从Hadoop导出到关系型数据库时,Apache Sqoop是一个非常有用的工具。Sqoop可以轻松地将大数据存储中的数据导出到常见的关系型数据库,如MySQL、Oracle、SQL Server等。本文将深入介绍如何使用Sqoop进行数据导出,并提供详细的示例代码&…

Android Studio 实现网易新闻App (简单方便易懂)

🍅文章末尾有获取完整项目源码方式🍅 目录 前言 一、任务介绍 1.1 背景 1.2目的和意义 二、 实现介绍 视频演示 2.1 启动页实现 2.2 注册页面实现 2.3 登陆页面实现 2.4 首页实现 2.5 详情页面实现 三、获取源码 前言 随着移动互联网的持续发…

力扣120. 三角形最小路径和(Java 动态规划)

Problem: 120. 三角形最小路径和 文章目录 题目描述思路解题方法复杂度Code 题目描述 思路 Problem:64. 最小路径和 本题目可以看作是在上述题目的基础上改编而来,具体的思路: 1.记录一个int类型的大小的 n 乘 n n乘n n乘n的数组(其中 n n n为…

第九讲 单片机驱动彩色液晶屏 控制RA8889软件:显存操作

单片机驱动TFT彩色液晶屏系列讲座 目录 第一讲 单片机最小系统STM32F103C6T6通过RA8889驱动彩色液晶屏播放视频 第二讲 单片机最小系统STM32F103C6T6控制RA8889驱动彩色液晶屏硬件框架 第三讲 单片机驱动彩色液晶屏 控制RA8889软件:如何初始化 第四讲 单片机驱动彩色液晶屏 控…

日志审计系统Agent项目创建——读取日志文件(Linux版本)

紧接着上一篇的分享,继续做日志文件的读取,点击连接即可日志文件初始化https://blog.csdn.net/wjl990316fddwjl/article/details/135553238 1、将指针移动到文件末尾 //文件移动到结尾fseek(fp, 0, SEEK_END); 2、定义当前指针的位置 lastPosition ft…

人工智能:我的学习之旅与认知探索(第1版)

🌟🌌 欢迎来到知识与创意的殿堂 — 远见阁小民的世界!🚀 🌟🧭 在这里,我们一起探索技术的奥秘,一起在知识的海洋中遨游。 🌟🧭 在这里,每个错误都…