YOLOv8 训练自己的数据集

news2024/12/24 0:02:47

本范例我们使用 ultralytics中的YOLOv8目标检测模型训练自己的数据集,从而能够检测气球。

#安装
!pip install -U ultralytics -i https://pypi.tuna.tsinghua.edu.cn/simple
import ultralytics 
ultralytics.checks()

一,准备数据

公众号算法美食屋后台回复关键词:yolov8,获取本文notebook源代码和数据集。

训练yolo模型需要将数据集整理成yolo数据集格式。然后写一个yaml的数据集配置文件。

yolo_dataset
├── images
│   ├── train
│   │   ├── train0.jpg
│   │   └── train1.jpg
│   ├── val
│   │   ├── val0.jpg
│   │   └── val1.jpg
│   └── test
│       ├── test0.jpg
│       └── test1.jpg
└── labels
    ├── train
    │   ├── train0.txt
    │   └── train1.txt
    ├── val
    │   ├── val0.txt
    │   └── val1.txt
    └── test
        ├── test0.txt
        └── test1.txt

其中标签文件(如train0.txt)格式如下:

class_id center_x center_y bbox_width bbox_height
0 0.300926 0.617063 0.601852 0.765873
1 0.575 0.319531 0.4 0.551562

注意class_id从0开始,中心点坐标和高宽都是相对坐标。

使用 Labelme或者 makesense标注样本可以直接导出该种类型样本。

%%writefile balloon.yaml
# Ultralytics YOLO 🚀, GPL-3.0 license

path: /tf/liangyun2/torchkeras/notebooks/datasets/balloon   # dataset root dir
train: images/train  # train images (relative to 'path') 128 images
val: images/val  # val images (relative to 'path') 128 images
test:  # test images (optional)

# Classes
names:
  0: ballon
Overwriting balloon.yaml
import torch
from torch.utils.data import DataLoader
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils import DEFAULT_CFG,yaml_load 
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.data import build_yolo_dataset,build_dataloader

overrides = {'task':'detect',
             'data':'balloon.yaml',
             'imgsz':640,
             'workers':4
            }
cfg = get_cfg(cfg = DEFAULT_CFG,overrides=overrides)
data_info = check_det_dataset(cfg.data)
ds_train = build_yolo_dataset(cfg,img_path=data_info['train'],batch=cfg.batch,
                              data_info = data_info,mode='train',rect=False,stride=32)

ds_val = build_yolo_dataset(cfg,img_path=data_info['val'],batch=cfg.batch,data_info = data_info,
    mode='val',rect=False,stride=32)
#dl_train = build_dataloader(ds_train,batch=cfg.batch,workers=0)
#dl_val = build_dataloader(ds_val,batch=cfg.batch,workers =0,shuffle=False)
dl_train = DataLoader(ds_train,batch_size = cfg.batch, num_workers = cfg.workers,
                      collate_fn = ds_train.collate_fn)

dl_val = DataLoader(ds_val,batch_size = cfg.batch, num_workers = cfg.workers,
                      collate_fn = ds_val.collate_fn)
for batch in dl_val:
    break
batch.keys()
dict_keys(['im_file', 'ori_shape', 'resized_shape', 'ratio_pad', 'img', 'cls', 'bboxes', 'batch_idx'])

二,定义模型

from ultralytics.nn.tasks import DetectionModel

model = DetectionModel(cfg = 'yolov8n.yaml', ch=3, nc=1)
#weights = torch.hub.load_state_dict_from_url('https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt')
weights = torch.load('yolov8n.pt')
model.load(weights['model'])
model.args = cfg
model.nc = data_info['nc']  # attach number of classes to model
model.names = data_info['names']

三,训练模型

1,使用ultralytics原生接口

使用ultralytics的原生接口,只需要以下几行代码即可。

from ultralytics import YOLO 
yolo_model = YOLO('yolov8n.pt')

yolo_model.train(data='balloon.yaml',epochs=10)

0796ae19d15665d4d116e5ece0842c5f.png


2,使用torchkeras梦中情炉

尽管使用ultralytics原生接口非常简单,再使用torchkeras实现自定义训练逻辑似乎有些多此一举。

但ultralytics的源码结构相对复杂,不便于用户做个性化的控制和修改。

并且,torchkeras在可视化上会比ultralytics的原生训练代码优雅许多。

此外,掌握自定义训练逻辑对大家熟悉ultralytics这个库的代码结构也会有所帮助。

for batch in dl_train:
    break
from ultralytics.yolo.v8.detect.train import Loss 

model.cuda()
loss_fn = Loss(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) 


x = batch['img'].float()/255 

preds = model.forward(x.cuda())
loss = loss_fn(preds,batch)[0]
print(loss)
tensor(74.5465, device='cuda:0', grad_fn=<MulBackward0>)
from torchkeras import KerasModel 

#我们需要修改StepRunner以适应Yolov8的数据集格式

class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
    
    def __call__(self, batch):
        
        features = batch['img'].float() / 255
        
        #loss
        preds = self.net(features)
        loss = self.loss_fn(preds,batch)[0]

        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
        all_preds = self.accelerator.gather(preds)
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item()}
        
        #metrics
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
KerasModel.StepRunner = StepRunner
keras_model = KerasModel(net = model, 
                         loss_fn = loss_fn, 
                         optimizer = optimizer)
keras_model.fit(train_data=dl_train,
                val_data=dl_val,
                epochs = 200,
                ckpt_path='checkpoint.pt',
                patience=20,
                monitor='val_loss',
                mode='min',
                mixed_precision='no',
                plot= True,
                wandb = False,
                quiet = True
               )

9e02b43baca40414b19510ff7a3cb212.png

d242920e834f1e8615503e8581c88c0f.png

四,评估模型

为了便于评估 map等指标,我们将权重再次保存后,用ultralytics的原生YOLO接口进行加载后评估。

keras_model.evaluate(dl_val)
100%|██████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.32it/s, val_loss=28.7]



{'val_loss': 28.715129852294922}
from ultralytics import YOLO 
keras_model.load_ckpt('checkpoint.pt')
save_dic = dict(model = keras_model.net, train_args =dict(cfg))
torch.save(save_dic, 'best_yolo.pt')
from ultralytics import YOLO 
best_model = YOLO(model = 'best_yolo.pt')
metrics = best_model.val(data = cfg.data )
metrics.results_dict
{'metrics/precision(B)': 0.9188790992746612,
 'metrics/recall(B)': 0.74,
 'metrics/mAP50(B)': 0.8516599658911874,
 'metrics/mAP50-95(B)': 0.7321355695315829,
 'fitness': 0.7440880091675434}
import pandas as pd 
df = pd.DataFrame()
df['metric'] = metrics.keys
for i,c in best_model.names.items():
    df[c] = metrics.class_result(i)

df

f40c837b5009663440bae9772217fca8.png

五,使用模型

from pathlib import Path 
root_path = './datasets/balloon/'
data_root = Path(root_path)

best_model = YOLO(model = 'best_yolo.pt')
val_imgs = [str(x) for x in (data_root/'images'/'train').rglob("*.jpg") if 'checkpoint' not in str(x)]
img_path = val_imgs[5]
import os 
from PIL import Image 
result = best_model.predict(source = img_path,save=True)
best_model.predictor.save_dir/os.path.basename(img_path)
Image.open(best_model.predictor.save_dir/os.path.basename(img_path))

639b298b28c05879b3bef9a5c94303da.png

六,导出模型

best_model.export(format='onnx')
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
predictor = DetectionPredictor(
    overrides=dict(model='best_yolo.onnx'))
results = list(predictor.stream_inference(source=img_path))

公众号算法美食屋后台回复关键词:yolov8,获取本文notebook源代码和数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/610485.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript之DOM(九)

JavaScript之DOM 1、节点类型2、常用的属性与方法2.1、访问节点的常用方法2.2、增删改节点的常用方法2.3、class的常用方法2.4、css相关操作 DOM – Document Object Model (文档对象模型)&#xff0c;是 JS 操作 HTML 文档的接口&#xff0c;它最大的特点就是将文档表示为节点…

CloudQuery一体化数据库SQL操作安全管控平台

&#x1f497;wei_shuo的个人主页 &#x1f4ab;wei_shuo的学习社区 &#x1f310;Hello World &#xff01; CloudQuery一体化数据库SQL操作安全管控平台 导读 CloudQuery作为业界领先的面向企业的数据库安全解决方案&#xff0c;CloudQuery致力于打造一站式安全可靠的数据操…

【C++】右值引用和移动语义

1.左值和右值 在C中&#xff0c;每个表达式或者是左值&#xff0c;或者是右值。 左值(lvalue)&#xff1a;可以出现在赋值表达式左侧的值&#xff0c;例如变量名a、数据成员a.m、下标表达式a[n]、解引用表达式*p等。左值可以被赋值和取地址。右值(rvalue)&#xff1a;只能出现…

jdk动态代理源码分析

jdk动态代理源码分析 前言动态代理----demo 案例jdk动态代理源码创建代理对象获取类把二进制流生成文件 jdk 动态代理的原理 前言 上一篇中我们知道动态代理的使用, Javase 专题之 静态代理和动态代理 我们只知道其中的使用,但是原理是什么? 不明白原理只知皮毛不是我们的目的…

chatgpt赋能python:Python在原图上继续画的SEO

Python在原图上继续画的SEO Python是一种高级的多范式编程语言&#xff0c;它使用简单、易于阅读的语法以及丰富和强大的数据结构使其成为工程师的首选。Python已经成为了一种非常流行的编程语言&#xff0c;它用于多种应用领域&#xff0c;包括Web开发、数据科学、机器学习、…

区间预测 | MATLAB实现基于QRCNN-LSTM卷积长短期记忆神经网络多变量时间序列区间预测

区间预测 | MATLAB实现基于QRCNN-LSTM卷积长短期记忆神经网络多变量时间序列区间预测 目录 区间预测 | MATLAB实现基于QRCNN-LSTM卷积长短期记忆神经网络多变量时间序列区间预测效果一览基本介绍模型描述程序设计参考资料 效果一览 基本介绍 1.Matlab实现基于QRCNN-LSTM卷积神经…

注解、原生Spring、SchemaBased三种方式实现AOP【附详细案例】

目录 一、注解配置AOP 1. 开启注解支持 2. 在类和方法加入注解 3. 测试 4. 为一个类下的所有方法统一配置切点 二、原生Spring实现AOP 1. 引入依赖 2. 编写SpringAOP通知类 3. 编写配置类bean2.xml 4 测试 三、SchemaBased实现AOP 1. 配置切面 2. 测试 往期专栏…

音视频技术开发周刊 | 296

每周一期&#xff0c;纵览音视频技术领域的干货。 新闻投稿&#xff1a;contributelivevideostack.com。 22字声明、近400名专家签署、AI教父Hinton与OpenAI CEO领头预警&#xff1a;AI可能灭绝人类&#xff01; 这份声明一经发布&#xff0c;便迅速得到了多伦多大学计算机科学…

基于zookeeper的kafka中间件

一、Zookeeper 概述 1、Zookeeper 定义 Zookeeper是一个开源的分布式的&#xff0c;为分布式框架提供协调服务的Apache项目。 2、Zookeeper 工作机制 Zookeeper从设计模式角度来理解&#xff1a;是一个基于观察者模式设计的分布式服务管理框架&#xff0c;它负责存储和管理…

昨天,小灰做了人生的第一次直播!

熟悉小灰的朋友们都知道&#xff0c;小灰是一个非常腼腆的人。虽然我比较擅长写东西&#xff0c;但完全不擅长口头表达&#xff0c;在公开场合讲话很容易紧张。 因此&#xff0c;对于网上直播&#xff0c;小灰在以前完全不敢想象。 但是&#xff0c;人终究需要成长的。就在昨天…

Disco Diffusion 快速入门

Disco Diffusion 快速入门 简介快速开始进阶使用修改prompt给定指导图像修改基础参数运行参数设置运行建议模型设置参数详情 简介 Disco Diffusion&#xff08;DD&#xff09;是一个CLIP指导的AI图像生成技术&#xff0c;简单来说&#xff0c;Diffusion是一个对图像不断去噪的…

路径规划 | 图解RRT-Connect算法(附ROS C++/Python/Matlab仿真)

目录 0 专栏介绍1 RRT-Connect基本原理2 RRT-Connect vs. RRT3 ROS C算法实现4 Python算法实现5 Matlab算法实现 0 专栏介绍 &#x1f525;附C/Python/Matlab全套代码&#x1f525;课程设计、毕业设计、创新竞赛必备&#xff01;详细介绍全局规划(图搜索、采样法、智能算法等)…

chatgpt赋能python:Python实现奇数位偶数位互换的方法

Python实现奇数位偶数位互换的方法 Python是一种高级的、面向对象的编程语言&#xff0c;在当今的编程领域中具有广泛的应用。它被用于数据分析、机器学习、Web开发等众多领域&#xff0c;其简洁的语法和强大的库被开发者们广泛使用。本文将介绍Python中奇数位偶数位互换的方法…

驱动开发:内核实现SSDT挂钩与摘钩

在前面的文章《驱动开发&#xff1a;内核解析PE结构导出表》中我们封装了两个函数KernelMapFile()函数可用来读取内核文件&#xff0c;GetAddressFromFunction()函数可用来在导出表中寻找指定函数的导出地址&#xff0c;本章将以此为基础实现对特定SSDT函数的Hook挂钩操作&…

【Django 网页Web开发】07. 快捷的表单生成 Form与MoudleForm(保姆级图文)

目录 注意 正规写法是 ModelForm&#xff0c;下面文章我多实现效果url.py新建3个html文件数据库连接model.py 数据表1. 原始方法view.pytestOrgion.html 2. Form方法view.pytestForm.html 3. MoudleForm方法给字段设置样式面向对象的思路&#xff0c;批量添加样式错误信息的显示…

ASIC-WORLD Verilog(10)编写测试脚本Testbench的艺术

写在前面 在自己准备写一些简单的verilog教程之前&#xff0c;参考了许多资料----Asic-World网站的这套verilog教程即是其一。这套教程写得极好&#xff0c;奈何没有中文&#xff0c;在下只好斗胆翻译过来&#xff08;加了自己的理解&#xff09;分享给大家。 这是网站原文&…

干货!来自北大、KAUST、斯坦福、达摩院的大模型前沿动态:表格推理、代码生成、MiniGPT-4、生成式推理...

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入&#xff01; ChatGPT的发布使得国内外众多的研究机构掀起了一股AI热潮&#xff0c;而这也进一步推动了人们对大语言模型的深入研究。2023年4月26日&#xff0c;AI TIME举办的大模型专场四活动邀请了阿里巴巴达摩院NLP研究员…

在 IDEA 中配置 JavaFX 11

因为从 Java8/openjdk 之后&#xff0c;javafx 从 jdk 中移除&#xff0c;如果进行 JavaFX 开发需要在 module 中添加 lib&#xff0c;并对 IDE 进行配置&#xff0c;确保 jdk 可以与 javafx 正常调用。 javafx 下载路径&#xff0c;主页网址&#xff1a;https://openjfx.io/ …

开发实践|程序员是如何刷抖音、玩快手、看头条进行赚米的?

欢迎关注「全栈工程师修炼指南」公众号 点击 &#x1f447; 下方卡片 即可关注我哟! 设为「星标⭐」每天带你 基础入门 到 进阶实践 再到 放弃学习&#xff01; “ 花开堪折直须折&#xff0c;莫待无花空折枝。 ” 作者主页&#xff1a;[ https://www.weiyigeek.top ] 博客&…

【计算机组成原理与体系结构】数据的表示与运算

目录 一、进位计数制 二、信息编码 三、定点数数据表示 四、校验码 五、定点数补码加减运算 六、标志位的生成 七、定点数的移位运算 八、定点数的乘除运算 九、浮点数的表示 十、浮点数的运算 一、进位计数制 整数部分&#xff1a; 二进制、八进制、十六进制 --…