RT-DETR训练自己的数据集(从代码下载到实例测试)

news2024/12/26 21:56:38

文章目录

  • 前言
  • 一、RT-DETR简介
  • 二、环境搭建
  • 三、构建数据集
  • 四、修改配置文件
    • ①数据集文件配置
    • ②模型文件配置
    • ③训练文件配置
  • 五、模型训练和测试
    • 模型训练
    • 模型测试
  • 总结


前言

提示:本文是RT-DETR训练自己数据集的记录教程,需要大家在本地已配置好CUDA,cuDNN等环境,没配置的小伙伴可以查看我的往期博客:在Windows10上配置CUDA环境教程

RT-DETR(Real-Time Detection Transformer)是百度提出的一种实时目标检测模型。RT-DETR 采用了与 DETR 相同的编码器和解码器结构,但对其进行了大量的优化,在保持较高检测精度的同时,实现了实时目标检测,为目标检测领域提供了一种新的有效解决方案,具有广泛的应用前景,如自动驾驶、智能监控、机器人等领域。

在这里插入图片描述

论文地址:https://arxiv.org/abs/2304.08069
代码地址:https://github.com/lyuwenyu/RT-DETR

以上是RT-DETR官方的论文和代码,详细介绍了RT-DETR的构成和实现。在YOLOv10中也实现了RT-DETR的代码,所以本文使用YOLOv10的项目文件进行RT-DETR的训练和测试。YOLOv10的源码地址为:

代码地址:https://github.com/THU-MIG/yolov10

在这里插入图片描述


一、RT-DETR简介

核心设计

  • 采用Transformer结构RT-DETR 采用了与 DETR 相同的编码器和解码器结构,但对其进行了大量的优化。编码器用于对输入图像进行特征提取,解码器则用于预测目标的类别、位置和边界框等信息。
  • 优化计算成本:使用了更小的特征图来减少计算成本,并且使用更少的注意力头,以减少模型中的参数数量。此外,还引入了一种新的分组注意力机制,可以进一步提高性能。
  • 高效的混合编码器
    • 解耦内部尺度交互和跨尺度融合:设计了高效的混合编码器(Hybrid Encoder),通过解耦尺度内特征交互(AIFI)跨尺度特征融合(CCFM),能够高效地处理多尺度特征。其中 AIFI 基于注意力机制实现尺度内特征交互,CCFM 则基于 CNN 进行跨尺度特征融合,这样可以更好地捕捉不同尺度下的目标信息,同时降低计算复杂度。
  • IOU感知的查询选择
    • 改进目标查询初始化:提出了 IOU 感知的查询选择(IOU-aware Query Selection)方法,以改进目标查询的初始化。传统的查询选择方式可能忽略了检测器需要同时对对象的类别和位置进行建模的事实,而该方法显式地构建和优化认知不确定性,对编码器特征的联合潜在变量进行建模,从而为解码器提供高质量(高分类分数和高 IOU 分数)的初始查询,有助于提高检测的准确性。

二、环境搭建

在配置好CUDA环境,并且获取到RT-DETR源码后,建议新建一个虚拟环境专门用于RT-DETR模型的训练。将RT-DETR加载到环境后,安装剩余的包。requirements.txt 中包含了运行所需的包和版本,利用以下命令批量安装:

pip install -r requirements.txt

三、构建数据集

RT-DETR模型的训练需要原图像及对应的YOLO格式标签,还未制作标签的可以参考我这篇文章:LabelImg安装与使用教程。

我的原始数据存放在根目录的data文件夹(新建的)下,里面包含图像和标签。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
标签内的格式如下:
在这里插入图片描述
具体格式为 class_id x y w h,分别代表物体类别,标记框中心点的横纵坐标(x, y),标记框宽高的大小(w, h),且都是归一化后的值,图片左上角为坐标原点。

将原本数据集按照8:1:1的比例划分成训练集、验证集和测试集三类,划分代码如下。

# 将图片和标注数据按比例切分为 训练集和测试集
import shutil
import random
import os
 
# 原始路径
image_original_path = "data/images/"
label_original_path = "data/labels/"
 
cur_path = os.getcwd()
# 训练集路径
train_image_path = os.path.join(cur_path, "datasets/images/train/")
train_label_path = os.path.join(cur_path, "datasets/labels/train/")
 
# 验证集路径
val_image_path = os.path.join(cur_path, "datasets/images/val/")
val_label_path = os.path.join(cur_path, "datasets/labels/val/")
 
# 测试集路径
test_image_path = os.path.join(cur_path, "datasets/images/test/")
test_label_path = os.path.join(cur_path, "datasets/labels/test/")
 
# 训练集目录
list_train = os.path.join(cur_path, "datasets/train.txt")
list_val = os.path.join(cur_path, "datasets/val.txt")
list_test = os.path.join(cur_path, "datasets/test.txt")
 
train_percent = 0.8
val_percent = 0.1
test_percent = 0.1
 
 
def del_file(path):
    for i in os.listdir(path):
        file_data = path + "\\" + i
        os.remove(file_data)
 
 
def mkdir():
    if not os.path.exists(train_image_path):
        os.makedirs(train_image_path)
    else:
        del_file(train_image_path)
    if not os.path.exists(train_label_path):
        os.makedirs(train_label_path)
    else:
        del_file(train_label_path)
 
    if not os.path.exists(val_image_path):
        os.makedirs(val_image_path)
    else:
        del_file(val_image_path)
    if not os.path.exists(val_label_path):
        os.makedirs(val_label_path)
    else:
        del_file(val_label_path)
 
    if not os.path.exists(test_image_path):
        os.makedirs(test_image_path)
    else:
        del_file(test_image_path)
    if not os.path.exists(test_label_path):
        os.makedirs(test_label_path)
    else:
        del_file(test_label_path)
 
 
def clearfile():
    if os.path.exists(list_train):
        os.remove(list_train)
    if os.path.exists(list_val):
        os.remove(list_val)
    if os.path.exists(list_test):
        os.remove(list_test)
 
 
def main():
    mkdir()
    clearfile()
 
    file_train = open(list_train, 'w')
    file_val = open(list_val, 'w')
    file_test = open(list_test, 'w')
 
    total_txt = os.listdir(label_original_path)
    num_txt = len(total_txt)
    list_all_txt = range(num_txt)
 
    num_train = int(num_txt * train_percent)
    num_val = int(num_txt * val_percent)
    num_test = num_txt - num_train - num_val
 
    train = random.sample(list_all_txt, num_train)
    # train从list_all_txt取出num_train个元素
    # 所以list_all_txt列表只剩下了这些元素
    val_test = [i for i in list_all_txt if not i in train]
    # 再从val_test取出num_val个元素,val_test剩下的元素就是test
    val = random.sample(val_test, num_val)
 
    print("训练集数目:{}, 验证集数目:{}, 测试集数目:{}".format(len(train), len(val), len(val_test) - len(val)))
    for i in list_all_txt:
        name = total_txt[i][:-4]
 
        srcImage = image_original_path + name + '.jpg'
        srcLabel = label_original_path + name + ".txt"
 
        if i in train:
            dst_train_Image = train_image_path + name + '.jpg'
            dst_train_Label = train_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_train_Image)
            shutil.copyfile(srcLabel, dst_train_Label)
            file_train.write(dst_train_Image + '\n')
        elif i in val:
            dst_val_Image = val_image_path + name + '.jpg'
            dst_val_Label = val_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_val_Image)
            shutil.copyfile(srcLabel, dst_val_Label)
            file_val.write(dst_val_Image + '\n')
        else:
            dst_test_Image = test_image_path + name + '.jpg'
            dst_test_Label = test_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_test_Image)
            shutil.copyfile(srcLabel, dst_test_Label)
            file_test.write(dst_test_Image + '\n')
 
    file_train.close()
    file_val.close()
    file_test.close()

 
if __name__ == "__main__":
    main()

划分完成后将会在datasets文件夹下生成划分好的文件,其中images为划分后的图像文件,里面包含用于train、val、test的图像,已经划分完成;labels文件夹中包含划分后的标签文件,已经划分完成,里面包含用于train、val、test的标签;train.tet、val.txt、test.txt中记录了各自的图像路径。

在这里插入图片描述

在这里插入图片描述
在训练过程中,也是主要使用这三个txt文件进行数据的索引。

四、修改配置文件

①数据集文件配置

数据集划分完成后,在根目录文件夹下新建data.yaml文件,替代coco.yaml。用于指明数据集路径和类别,我这边只有一个类别,只留了一个,多类别的在name内加上类别名即可。data.yaml中的内容为:

path: ../datasets  # 数据集所在路径
train: train.txt  # 数据集路径下的train.txt
val: val.txt  # 数据集路径下的val.txt
test: test.txt  # 数据集路径下的test.txt

# Classes
names:
  0: wave

在这里插入图片描述

②模型文件配置

ultralytics/cfg/models/rt-detr文件夹下存放的是RT-DETR的各个版本的模型配置文件,检测的类别是coco数据的80类。在训练自己数据集的时候,只需要将其中的类别数修改成自己的大小。在根目录文件夹下新建rtdetr-l-test.yaml文件,此处以rtdetr-l.yaml文件中的模型为例,将其中的内容复制到rtdetr-l-test.yaml文件中 ,并将nc: 1 # number of classes 修改类别数` 修改成自己的类别数,如下:

在这里插入图片描述

在这里插入图片描述

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]] # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
  - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
  - [[-1, 17], 1, Concat, [1]] # cat Y4
  - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]] # cat Y5
  - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1

  - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


修改完成后,模型文件就配置好啦。

③训练文件配置

项目的超参数配置在ultralytics/cfg文件夹下的default.yaml文件中

在这里插入图片描述

在模型训练中,比较重要的参数是weights、data、epochs、batch、imgsz、device以及workers。

  • weight是配置预训练权重的路径,可以指定模型的yaml文件或pt文件。

  • data是配置数据集文件的路径,用于指定自己的数据集yaml文件。

  • epochs指训练的轮次,默认是100次,只要模型能收敛即可。

  • batch是表示一次性将多少张图片放在一起训练,越大训练的越快,如果设置的太大会报OOM错误,我这边在default中设置16,表示一次训练16张图像。设置的大小为2的幂次,1为2的0次,16为2的4次。

  • imgsz表示送入训练的图像大小,会统一进行缩放。要求是32的整数倍,尽量和图像本身大小一致。

  • device指训练运行的设备。该参数指定了模型训练所使用的设备,例如使用 GPU 运行可以指定为device=0,或者使用多个 GPU 运行可以指定为 device=0,1,2,3,如果没有可用的 GPU,可以指定为 device=cpu 使用 CPU 进行训练。

  • workers是指数据装载时cpu所使用的线程数,默认为8,过高时会报错:[WinError 1455] 页面文件太小,无法完成操作,此时就只能将workers调成0了。

模型训练的相关基本参数就是这些啦,其余的参数可以等到后期训练完成进行调参时再详细了解。

五、模型训练和测试

模型训练

由于项目中未提供单独的训练程序用于训练,而只是使用命令行进行训练,此处提供两种训练方法,一是在终端使用命令行进行训练;二是新建训练程序,配置参数进行训练。

(1)、在终端使用命令行进行训练

打开终端或新建终端后,输入命令:

yolo detect train data=data.yaml model=rtdetr-l-test.yaml epochs=300 batch=16 imgsz=640 device=0 workers=8

(2)、新建训练程序,配置参数进行训练

在项目根目录下新建train.py文件,输入以下内容后运行当前文件即可开始训练。

from ultralytics.models import RTDETR
 
if __name__ == '__main__':
    model = RTDETR(model='rtdetr-l-test.yaml')
    model.train(pretrained=True, data='data.yaml', epochs=300, batch=32, device=0, imgsz=640, workers=8)

在这里插入图片描述

在这里插入图片描述

训练完成后,将会在runs/detect/train/exp/weights文件夹下存放训练后的权重文件。

模型测试

(1)、在终端使用命令行进行测试

打开终端或新建终端后,输入命令:

yolo detect val data=data.yaml model=runs/detect/train/exp/weights/best.pt, batch=32, imgsz=640, split=test, device=0, workers=8

(2)、新建训练程序,配置参数进行训练

在项目根目录下新建val.py文件,输入以下内容后运行当前文件即可开始测试。

from ultralytics.models import RTDETR
 
 
if __name__ == '__main__':
    model = RTDETR(model='runs/train/exp/weights/best.pt')
    model.val(data='data.yaml',batch=32, device='0', imgsz=640, workers=8)

总结

以上就是RT-DETR训练自己数据集的全部过程啦,欢迎大家在评论区交流~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2134592.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

302状态如何进行重定向

文章目录 一、302状态是什么意思二、遇到的使用场景三、如何处理customservice.wxmlcustomservice.js 一、302状态是什么意思 302状态码是临时重定向(Move Temporarily),表示所请求的资源临时地转移到新的位置。此外还有一个301永久重定向&a…

【spring】maven引入okhttp的日志拦截器打开增量注解进程

HttpLoggingInterceptor 是在logging-interceptor库中的:这个logging库老找不到 import okhttp3.OkHttpClient; import okhttp3.logging.HttpLoggingInterceptor;发现这仨是独立的库 pom中三个依赖 <!-- OKHTTP3 --><

在group by分组的时候,某个key过多导致数据倾斜

解决方案&#xff1a;将 key 打散&#xff0c;给 key 增加随机前缀 在进行 group by 之前&#xff0c;先给每个 user_id 增加一个随机前缀&#xff0c;使得原本相同的 user_id 被打散到不同的分组中。 按带前缀的 key 进行分组 对带有随机前缀的 user_id 进行分组和聚合。 …

重要涉密文件如何防窃取?四个方法有效防止文件泄密【文件保密管理】

随着信息化时代的发展&#xff0c;数据安全问题变得日益突出&#xff0c;特别是对于一些重要的涉密文件&#xff0c;其泄密将带来严重后果。因此&#xff0c;企业和个人在处理机密文件时&#xff0c;必须采取有效的措施来防止文件被窃取。 小编在本文将介绍四个有效的方法&…

三招教你搞定GPU服务器配置→收藏推荐配置

在AI人工智能应用日益渗透各行各业的今天&#xff0c;图形处理器&#xff08;GPU&#xff09;市场呈现出蓬勃发展的态势&#xff0c;其中GPU服务器市场更是炙手可热&#xff0c;其热度始终居高不下。随着人工智能、深度学习、大数据分析等前沿领域的不断拓展与深化&#xff0c;…

python+matplotlib 画一个漂亮的折线统计图

pythonmatplotlib 画一个漂亮的折线统计图 有详细的注释说明…… import matplotlib.pyplot as plt import numpy as np import mathdef draw_line_chart(Line_data_list,title,pic_name)::param Line_data_list: 折线数据源:param title: 图表名称:param pic_name: 保存图片名…

免费!OpenAI发布最新模型GPT-4o mini,取代GPT3.5,GPT3.5退出历史舞台?

有个小伙伴问我&#xff0c;GPT-4O mini是什么&#xff0c;当时我还一脸懵逼&#xff0c;便做了一波猜测&#xff1a; 我猜测哈&#xff0c;这个可能是ChatGPT4o的前提下&#xff0c;只支持文本功能的版本&#xff0c;速度更快 结果&#xff0c;大错特错。 让我们一起看看Open…

理解高并发

文章目录 1、如何理解高并发2、高并发的关键指标3、高并发系统设计的目标是什么&#xff1f;1_宏观目标2_微观目标1.性能指标2.可用性指标3.可扩展性指标 4、高并发的实践方案有哪些&#xff1f;1_通用的设计方法1.纵向扩展&#xff08;scale-up&#xff09;2.横向扩展&#xf…

【隐私计算】Cheetah安全多方计算协议-阿里安全双子座实验室

2PC-NN安全推理与实际应用之间仍存在较大性能差距&#xff0c;因此只适用于小数据集或简单模型。Cheetah仔细设计DNN&#xff0c;基于格的同态加密、VOLE类型的不经意传输和秘密共享&#xff0c;提出了一个2PC-NN推理系统Cheetah&#xff0c;比CCS20的CrypTFlow2开销小的多&…

数据结构—线性表和顺序表

线性表&#xff1a; 线性表是一个由n个具有相同特性的数据元素构成的有限序列。常用到的线性表都有&#xff1a;链表、队列、栈、顺序表.... 顺序表&#xff1a; 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构&#xff08;顺序表的元素类型是包装类&#x…

[苍穹外卖]-10WebSocket入门与实战

WebSocket WebSocket是基于TCP的一种新的网络协议, 实现了浏览器与服务器的全双工通信, 即一次握手,建立持久连接,双向数据传输 区别 HTTP是短连接, WebSocket是长连接HTTP单向通信, 基于请求响应模型WebSocket支持双向通信 相同 HTTP和WebSocket底层都是TCP连接 应用场景…

Android 通过相机和系统相册获取图片,压缩,结果回调

一、需求背景 在常规的App开发中&#xff0c;很多时候需要用户上传图片来进行一些业务上的实现&#xff0c;例如用户反馈&#xff0c;图片凭证等。 二、实现功能 1.选择弹窗&#xff08;即选择拍照或者相册&#xff09; 2.申请权限&#xff08;相机权限&#xff09; 3.相机…

油耳用什么掏耳朵比较好?可视挖耳勺推荐平价

掏耳朵是一个轻松又舒服的感觉&#xff0c;很多人就会用棉签和普通耳勺越掏越进&#xff0c;在盲掏的过程中容易弄伤耳膜。所以我们在掏耳时要选好工具。市面上的智能可视挖耳勺&#xff0c;顶端带有摄像头&#xff0c;可以通过清楚的观察到耳道中的情况。但现在市面上关于可视…

【Unity学习心得】如何使用Unity制作“饥荒”风格的俯视角2.5D游戏

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、需要导入的素材二、要实现的步骤 俯视角2D人物移动控制2.5D风格的实现使用协程实现相机绕玩家旋转效果总结 前言 由于要找工作开始重新拾起学习Unity&#…

系统资源智能管理:zTasker软件的监控与优化

在创新的引领下&#xff0c;科技不断迭代升级&#xff0c;为我们应对快节奏生活的挑战提供了强大的工具。它让我们在协调工作与家庭的同时&#xff0c;也能保持内心的宁静与平衡——而自动化工具的出现&#xff0c;正是科技力量在提升工作效率和生活质量方面的体现。zTasker&am…

System.out源码解读——err 和 out 一起用导致的顺序异常Bug

前言 笔者在写一个小 Demo 的过程中&#xff0c;发现了一个奇怪的问题。问题如下&#xff1a; // 当 flagtrue 时打印 a1 &#xff1b;当 flagfalse 时打印 a2。 public static void main(String[] args) {boolean flag false;for (int i 0; i < 10; i) {if (flag) {Sys…

AI 与大模型如何助力金融研发效能最大化?

在金融行业&#xff0c;技术创新与严格合规的需求并行存在&#xff0c;推动着研发团队不断寻求更高效的解决方案。面对日益增长的市场竞争和技术进步&#xff0c;金融机构必须迅速适应变化&#xff0c;同时确保所有创新措施都符合监管要求。这种需求催生了对高效研发流程和先进…

深入掌握:如何进入Docker容器并运行命令

感谢浪浪云支持发布 浪浪云活动链接 &#xff1a;https://langlangy.cn/?i8afa52 文章目录 查看正在运行的容器使用 docker exec 命令进入容器进入容器的交互式 shell在容器中运行命令 使用 docker attach 命令附加到容器检查容器日志退出容器从 docker exec 方式退出从 docke…

趣味SQL | 从围棋收官到秦楚大战的数据库SQL语言实现

目录 0 前言 1 秦孝公大战商鞅 2 收官类型与城池特征 3 收官顺序与攻城策略 4 秦孝公展示SQL神功 5 写在最后 欲知后事如何&#xff0c;想进一步了解SQL这门艺术语言的&#xff0c;可以订阅我的专栏数字化建设通关指南&#xff0c;且听下回分解。专栏 原价99&#xff0c…

MacBook上怎么查找历史复制记录?

你是否经常遇到这样的情况:做内容或方案时,需要用到素材就去找,找到后回来粘贴,然后再去找,再回来粘贴?这个过程是不是很繁琐? 那么找到的素材要不要保存下来呢?每个都存成文件似乎太麻烦了。但如果不单独保存,过两天想再利用又找不到了,怎么办? 在网上看到的一段好文案、…