用自己的数据集训练YOLO-NAS目标检测器

news2024/12/23 22:31:30

YOLO-NAS 是 Deci 开发的一种新的最先进的目标检测模型。 在本指南中,我们将讨论什么是 YOLO-NAS 以及如何在自定义数据集上训练 YOLO-NAS 模型。

在这里插入图片描述

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器

为了训练我们的自定义模型,我们将:

  • 加载预训练的YOLO-NAS模型;
  • 从 Roboflow 加载自定义数据集,或者使用UnrealSynth制作合成数据集
  • 设置超参数值;
  • 使用超级梯度 Python 包根据我们的数据训练模型;
  • 评估模型以了解结果。

话不多说,让我们开始吧!

1、什么是 YOLO-NAS?

You Only Look Once  神经架构搜索(YOLO-NAS)是最新最先进的(SOTA)实时目标检测模型。 在 COCO 数据集上进行评估并与其前身 YOLOv6 和 YOLOv8  相比,YOLO-NAS 以更低的延迟实现了更高的 mAP 值。

YOLO-NAS 作为 Deci 维护的 super-gradient包的一部分提供。

下图展示了Deci在YOLO-NAS上的基准测试结果:
在这里插入图片描述

YOLO-NAS 与其他顶级实时检测器在 COCO 数据集上的性能对比图

YOLO-NAS 在 Roboflow 100 数据集基准测试中也是最好的,这表明它可以轻松地在自定义数据集上进行微调。

在这里插入图片描述

YOLO-NAS 和其他顶级实时检测器在 RF100 数据集上的性能对比图

2、Python环境设置

在开始训练之前,我们需要准备好Python环境。 让我们从安装三个 pip 包开始。 YOLO-NAS 模型本身是使用 super-gradient 包进行分发的。 请记住,该模型仍在积极开发中。 为了保持环境的稳定性,最好固定特定版本的包。 此外,我们将安装 roboflow 和监督,这将使我们能够从 Roboflow Universe 下载数据集并分别可视化我们的训练结果。

pip install super-gradients==3.1.1
pip install roboflow
pip install supervision

如果你在 Jupyter Notebook 中运行 YOLO-NAS,请不要忘记在安装完成后重新启动环境。

3、使用预训练模型进行推理

在开始培训之前,最好确保安装按计划进行。 最简单的方法是使用预先训练的模型之一进行测试推理。 同时,这也能让我们熟悉YOLO-NAS API。

3.1 加载YOLO-NAS模型

为了使用预训练的 COCO 模型进行推理,我们首先需要选择模型的大小。 YOLO-NAS提供三种不同的模型大小:yolo_nas_s、yolo_nas_m和yolo_nas_l。

yolo_nas_s 模型是最小且最快的,但它可能不会像较大的模型那么准确。 相反,yolo_nas_l 模型最大、最准确、最慢。 yolo_nas_m 模型提供了两者之间的中间立场。

import torch
from super_gradients.training import models

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_ARCH = 'yolo_nas_l'
#            'yolo_nas_m'
#            'yolo_nas_s'

model = models.get(MODEL_ARCH, pretrained_weights="coco").to(DEVICE)

3.2 YOLO-NAS模型推理

推理过程包括设置置信度阈值和调用预测方法。 预测方法将返回预测列表,其中每个预测对应于图像中检测到的对象。

CONFIDENCE_TRESHOLD = 0.35

result = list(model.predict(image, conf=CONFIDENCE_TRESHOLD))[0]

在这里插入图片描述

YOLO-NAS推理结果图示

3.3 YOLO-NAS 推理输出格式

YOLO-NAS 推理的输出是一个 ImageDetectionPrediction 对象,它封装了图像中检测到的对象的详细信息。 该对象包含三个字段:

  • image - 表示用于推理的图像的 NumPy 数组。
  • class_names - 模型训练期间使用的类别名称的 Python 列表。
  • Prediction -DetectionPrediction 类的实例,其中包含有关模型检测的详细信息。

DetectionPrediction对象具有三个字段:

  • bboxes_xyxy - 形状 (N, 4) 的 NumPy 数组,以 xyxy 格式表示检测到的对象的边界框。
  • confidence - 形状 (N,) 的 NumPy 数组,表示检测的置信度值。 每个值都在 0 和 1 之间。
  • labels - 形状 (N,) 的 NumPy 数组,表示检测到的对象的类 ID。 每个类 ID 对应于 class_names 列表中的一个索引。

4、使用开源数据集微调 YOLO-NAS

为了微调模型,我们需要数据。 我们将使用足球运动员检测图像数据集。

如果你已经有 YOLO 格式的数据集,请随意使用它。 如果没有,请看看 Roboflow Universe,那里拥有超过 200,000 个开源项目,并且所有项目都可以以任何格式导出。

另外一种获取数据集的方法是使用UnrealSynth,一个基于虚幻引擎开发的YOLO合成数据生成器,可以自动生成包括标注的训练数据集,非常方便:
在这里插入图片描述

https://tools.nsdt.cloud/UnrealSynth

import roboflow
from roboflow import Roboflow

roboflow.login()

rf = Roboflow()
project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
dataset = project.version(PROJECT_VERSION).download("yolov5")

要训练 YOLO-NAS 模型,你需要设置几个关键参数。

首先,你需要选择模型尺寸。 有三个选项可供选择:小型、中型和大型。 请记住,较大的模型可能需要更长的时间来训练并需要更多的内存,因此如果使用的资源有限,你可能需要考虑使用较小的模型。

接下来,你需要设置批量大小。 该参数指示在训练过程的每次迭代期间将有多少图像通过神经网络。 较大的批量大小将加快训练过程,但也需要更多的内存。

MODEL_ARCH = 'yolo_nas_l'
BATCH_SIZE = 8
MAX_EPOCHS = 25
CHECKPOINT_DIR = f'{HOME}/checkpoints'
EXPERIMENT_NAME = project.name.lower().replace(" ", "_")
LOCATION = dataset.location
CLASSES = sorted(project.classes.keys())

dataset_params = {
    'data_dir': LOCATION,
    'train_images_dir':'train/images',
    'train_labels_dir':'train/labels',
    'val_images_dir':'valid/images',
    'val_labels_dir':'valid/labels',
    'test_images_dir':'test/images',
    'test_labels_dir':'test/labels',
    'classes': CLASSES
}

from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, coco_detection_yolo_format_val)

train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['val_images_dir'],
        'labels_dir': dataset_params['val_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

最后,你需要设置训练过程的纪元数。 这本质上是整个数据集通过神经网络的次数。

5、训练自定义 YOLO-NAS 模型

你可能已经注意到,训练模型的过程比 YOLOv8 更加冗长。 Ultralytics 模型中的许多功能需要在 CLI 中传递参数,而对于 YOLO-NAS,则需要编写自定义逻辑。

最后,我们准备开始训练。 在调用 train 方法之前,值得运行 TensorBoard。 这将使我们能够实时跟踪培训的关键指标。 值得一提的是,YOLO-NAS还支持W&B等最流行的实验记录仪。
在这里插入图片描述

YOLO-NAS 训练期间获得的指标图

trainer.train(
    model=model, 
    training_params=train_params, 
    train_loader=train_data, 
    valid_loader=val_data
)

6、评估自定义 YOLO-NAS 模型

训练结束后,你可以使用Trainer提供的测试方法评估模型的性能。 你需要传入测试集数据加载器,训练器将返回一个指标列表,包括通常用于评估对象检测模型的平均精度(mAP)。

trainer.test(
    model=best_model,
    test_loader=test_data,
    test_metrics_list=DetectionMetrics_050(
        score_thres=0.1, 
        top_k_predictions=300, 
        num_cls=len(dataset_params['classes']), 
        normalize_targets=True, 
        post_prediction_callback=PPYoloEPostPredictionCallback(
            score_threshold=0.01, 
            nms_top_k=1000, 
            max_predictions=300,                                                                              
            nms_threshold=0.7
        )
    )
)

在这里插入图片描述

模型评估期间获得的预测与手动标注的比较

此外,你可以对测试集图像进行推理并可视化结果,以更好地了解模型在各个示例上的表现。 你还可以计算混淆矩阵,以更详细地了解每个类别的模型性能:

在这里插入图片描述

模型评估过程中创建的混淆矩阵

7、结束语

一夜之间,YOLO-NAS 成为实时物体检测器的新选择。 在为你的项目微调模型时,请记住要考虑所有方面——从模型准确性到推理速度,再到易于训练和许可限制。


原文链接:训练自己的YOLO-NAS — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1162067.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

7+单细胞分析+预后模型构建+验证实验思路,干湿结合也能拿高分

今天给同学们分享一篇单细胞分析肿瘤预后模型构建验证实验思路的生信文章“Identification of a novel immune-related gene signature for prognosis and the tumor microenvironment in patients with uveal melanoma combining single-cell and bulk sequencing data”&…

SpringMVC简单介绍与使用

目录 一、SpringMVC介绍 二、SpringMVC作用 三、SpringMVC核心组件 四、SpringMVC快速体验 一、SpringMVC介绍 Spring Web MVC是基于Servlet API构建的原始Web框架,从一开始就包含在Spring Framework中。正式名称“Spring Web MVC”来自其源模块的名称&#xff…

UE5数字孪生制作(一) - QGIS 学习笔记

1.下载 QGIS是免费的GIS工具,下载地址: https://www.qgis.org/en/site/ 2.安装 - 转中文 按照步骤安装,完成后,在菜单 设置settings里,选择options,修改语言 确定后,需要重启下软件 3.学习视…

聊聊展会接待接待客户会用到的一些英语话术

第三期广交会依然在进行中,周六也就结束了,不知道大家这次参展的效果如何?昨晚略看了一下毅冰老师的直播课,他讲的也是和展会有关的内容,稍微摘抄了一些客户来展位时的交流英语,大家可以一起看看。 作为参展…

Numpy数值计算Numpy初体验在线闯关_头歌实践教学平台

Numpy数值计算初体验 第1关 Numpy创建数组第2关 Numpy数组的基本运算第3关 Numpy数组的切片与索引第4关 Numpy数组的堆叠第5关 Numpy的拆分 第1关 Numpy创建数组 任务描述 本关的小目标是,使用 Numpy 创建一个多维数组。 测试说明 本关的测试过程如下: 平台运行ste…

C# Winform串口助手

界面设置 修改控件name属性 了解SerialPort类 实现串口的初始化,开关 创建虚拟串口 namespace 串口助手 {public partial class Form1 : Form{public Form1(){InitializeComponent();}private void Form1_Load(object sender, EventArgs e){//在设计页面已经预先…

手写一个uniapp的步骤条组件

在template实现 <template><view class"process_more"><!-- 步骤条 --><view class"set-2" :key"index" v-for"(item,index) in options"><!-- 图片 --><view class"img-border"><…

造物者:专注游戏音乐创造——奏响游戏世界乐章

游戏的世界宛如一幅壮丽的画卷&#xff0c;由华丽的图像和引人入胜的故事构成&#xff0c;然而&#xff0c;其完美之作还有一部分不可或缺的元素&#xff0c;那就是音乐。在这个数字时代&#xff0c;北京造物者科技有限公司&#xff08;以下简称造物者&#xff09;正崭露头角&a…

【RP-RV1126】配置一套简单的板级配置

文章目录 官方配置新建一套新配置新建板级pro-liefyuan-rv1126.mk配置文件新建一个Buildroot的defconfigs文件 吐槽&#xff1a;RP-RV1126 的SDK奇怪的地方make ARCHarm xxx_defconfig 生成的.config文件位置不一样savedefconfig命令直接替换原配置文件坑爹的地方 Buildroot上增…

【本周骑行香杆箐活动简介】- 探索秋天的美景与健康同行

校长骑行的骑友们&#xff0c;大家好&#xff01;在这个秋高气爽的季节里&#xff0c;是不是已经跃跃欲试&#xff0c;想要投入大自然的怀抱&#xff0c;感受那无比清新的空气和金黄色的落叶呢&#xff1f;别再犹豫了&#xff0c;让我们一起骑行在香杆箐&#xff0c;体验一次不…

91 前K个高频元素

前K个高频元素 题解1 大根堆(STL) 给你一个整数数组 nums 和一个整数 k &#xff0c;请你返回其中出现频率前 k 高的元素。你可以按 任意顺序 返回答案。 示例 1: 输入: nums [1,1,1,2,2,3], k 2 输出: [1,2] 示例 2: 输入: nums [1], k 1 输出: [1] 提示&#xff1a;…

KADP应用加密组件实现数据动态脱敏 安当加密

动态脱敏是一种针对敏感数据进行数据抽取、数据漂白和动态掩码的专业数据脱敏技术。它通过在不动数据库中原始数据的前提下&#xff0c;依据用户的角色、职责和其他IT定义身份特征&#xff0c;动态的对生产数据库返回的数据进行专门的屏蔽、加密、隐藏和审计。可确保不同级别的…

双十一数码推荐什么?双十一选购攻略大全!实用数码产品推荐!

​在双十一这个购物狂欢节里&#xff0c;各大品牌和商家都会推出各种优惠活动&#xff0c;为消费者提供丰富的购物选择。在这个特殊的日子里&#xff0c;你是否也准备为自己或亲朋好友选购一些数码好物呢?本次推荐将为你精选一些值得购买的数码产品&#xff0c;让你在双十一这…

MATLAB和西门子SMART PLC OPC通信

西门子S7-200SMART PLC OPC软件的下载和使用,请查看下面文章 Smart 200PLC PC Access SMART OPC通信_基于pc access smart的opc通信_RXXW_Dor的博客-CSDN博客文章浏览阅读2.7k次,点赞2次,收藏5次。OPC是一种利用微软COM/DCOM技术达成自动控制的协议,采用典型的C/S模式,针…

(01)Mycat说明与介绍

1、Mycat是什么 Mycat是一个数据库中间件&#xff0c;前身是阿里的cobar。 2、Mycat可以用来做什么 1.读写分离 2.数据分片 &#xff08;1&#xff09;垂直拆分 &#xff08;2&#xff09;水平拆分 &#xff08;3&#xff09;垂直水平拆分 3.多数据源整合 3、Mycat实现的…

前端出大事儿了

大家好&#xff0c;我是风筝 文章首发于 前端出大事儿了 最近这两天&#xff0c;在前端圈最火的图片莫过于下面这张了。 这是一段 React 代码&#xff0c;就算你完全没用过 React 也没关系&#xff0c;一眼看过去就能看到其中最敏感的一句代码&#xff0c;就是那句 SQL 。 咱…

Linux安装sysv-rc-conf报错:出现NO_PUBKEY...问题,急需安装证书的情况

Linux下安装MySQL时&#xff0c;出现一个使用chkconfig命令&#xff0c;但无该命令的情况&#xff01; chkconfig --add mysql # 出现chkconfig command not found于是就展开了一次替换的行动&#xff0c;将chkconfig替换为sysv-rc-conf 第一步&#xff1a; 尝试直接安装&am…

最新阿里云服务器优惠价格表,企鹅看了瑟瑟发抖!

今年2023年阿里云双十一优惠活动云服务器价格太低了&#xff0c;比腾讯云都便宜&#xff0c;轻量2核2G服务器3M带宽优惠价87元一年、2核4G4M带宽优惠价165元一年&#xff0c;云服务器ECS经济型e实例2核2G3M固定带宽优惠价格99元一年&#xff0c;还有2核4G、2核8G、4核8G、4核16…

Leetcode刷题---轮转数组

轮转数组 题目描述&#xff1a; Java中List是有序、可重复的单列集合&#xff0c;集合中的每个元素都有对应的顺序索引&#xff0c;我们可以通过该索引来访问指定位置上的集合元素。 思路&#xff1a; 首先选用list来存储中间结果。首先用k对n(数组长度)求余获取要移动的位数…

2023年【P气瓶充装】最新解析及P气瓶充装考试技巧

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 P气瓶充装最新解析参考答案及P气瓶充装考试试题解析是安全生产模拟考试一点通题库老师及P气瓶充装操作证已考过的学员汇总&#xff0c;相对有效帮助P气瓶充装考试技巧学员顺利通过考试。 1、【多选题】LNG加气站有哪些…