yolov8训练进阶:从配置文件读入配置参数

news2024/12/18 9:55:42

yolov8官方教程提供了2种训练方式,一种是通过命令行启动训练,一种是通过写代码启动。
image.png
image.png
命令行的方式启动方便,通过传入参数可以方便的调整训练参数,但这种方式不方便记录训练参数和调试训练代码。
自行写训练代码的方式更灵活,也比较方便调试,但官方的示例各种参数都是在代码中写死的方式,失去了灵活性。
其实我们可以结合这两种方法的优势,既能够通过命令行参数修改很容易变化的参数(如batch size, epoch, imgsz等),然后用配置文件保存很少需要变化的参数,或者这些变化需要保存下来方便对比(如各类增强比例)。

代码分析

首先我们需要知道我们能够设置哪些参数,尽管官方文档列出了命令行能够传入的参数列表,但每次设置大量参数还是不方便,而不设置的时候默认参数是多少我们也不知道,所以还是有必要分析一下代码。
通过模型的train接口我们会知道所有的Trainer均继承自BaseTrainer(yolo/engine/trainer.py),该类的构造函数如下:

def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.check_resume()
        ...

其中overrides就是我们设置的参数,我们未设置的参数则来源于DEFAULT_CFG,继续跟踪我们会发现这个DEFAULT_CFG实际来源于yolo/cfg/default.yaml:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# Default training settings and hyperparameters for medium-augmentation COCO training

task: detect  # YOLO task, i.e. detect, segment, classify, pose
mode: train  # YOLO mode, i.e. train, val, predict, export, track, benchmark

# Train settings -------------------------------------------------------------------------------------------------------
model:  # path to model file, i.e. yolov8n.pt, yolov8n.yaml
data:  # path to data file, i.e. coco128.yaml
epochs: 100  # number of epochs to train for
start_epoch: 0  # start epoch
patience: 50  # epochs to wait for no observable improvement for early stopping of training
batch: 16  # number of images per batch (-1 for AutoBatch)
imgsz: 640  # size of input images as integer or w,h
save: True  # save train checkpoints and predict results
save_period: -1 # Save checkpoint every x epochs (disabled if < 1)
cache: False  # True/ram, disk or False. Use cache for data loading
device:  # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
workers: 8  # number of worker threads for data loading (per RANK if DDP)
project:  # project name
name:  # experiment name, results saved to 'project/name' directory
exist_ok: False  # whether to overwrite existing experiment
pretrained: False  # whether to use a pretrained model
optimizer: SGD  # optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
verbose: True  # whether to print verbose output
seed: 0  # random seed for reproducibility
deterministic: True  # whether to enable deterministic mode
single_cls: False  # train multi-class data as single-class
rect: False  # rectangular training if mode='train' or rectangular validation if mode='val'
cos_lr: False  # use cosine learning rate scheduler
close_mosaic: 0  # (int) disable mosaic augmentation for final epochs
resume: False  # resume training from last checkpoint
amp: True  # Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
fraction: 1.0  # dataset fraction to train on (default is 1.0, all images in train set)
profile: False  # profile ONNX and TensorRT speeds during training for loggers
# Segmentation
overlap_mask: True  # masks should overlap during training (segment train only)
mask_ratio: 4  # mask downsample ratio (segment train only)
# Classification
dropout: 0.0  # use dropout regularization (classify train only)
...

我们所有能设置的参数就在这个文件中,如果我们设置了不在其中的参数则会报错(下一篇介绍怎么增加参数)。

自定义参数配置文件

我们可以将训练会调整的参数单独保存到一个yaml文件,如hyp.scratch.yaml作为从头训练的配置,进行多次实验时,就可以建立不同的配置参数文件:

lr0: 0.01  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
lrf: 0.001  # final learning rate (lr0 * lrf)
momentum: 0.937  # SGD momentum/Adam beta1
weight_decay: 0.0005  # optimizer weight decay 5e-4
warmup_epochs: 3.0  # warmup epochs (fractions ok)
warmup_momentum: 0.8  # warmup initial momentum
warmup_bias_lr: 0.1  # warmup initial bias lr
box: 7.5  # box loss gain
cls: 0.5  # cls loss gain (scale with pixels)
dfl: 1.5  # dfl loss gain
pose: 12.0  # pose loss gain
kobj: 1.0  # keypoint obj loss gain
label_smoothing: 0.0  # label smoothing (fraction)
nbs: 64  # nominal batch size
hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4  # image HSV-Value augmentation (fraction)
degrees: 0.0  # image rotation (+/- deg)
translate: 0.1  # image translation (+/- fraction)
scale: 0.5  # image scale (+/- gain)
shear: 0.0  # image shear (+/- deg)
perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
flipud: 0.0  # image flip up-down (probability)
fliplr: 0.5  # image flip left-right (probability)
mosaic: 0.1  # image mosaic (probability)
mixup: 0.05  # image mixup (probability)
copy_paste: 0.0  # segment copy-paste (probability)

workers: 12  # number of workers
# cache: disk

自定义训练脚本

建立了自定义参数文件,我们还要建立自己的训练脚本来载入配置文件,并且还有一些经常变化的参数需要通过命令行传入, 新建train.py:

from ultralytics import YOLO
import yaml
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='configs/data/phd.yaml', help='dataset.yaml path')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs')
parser.add_argument('--hyp', type=str, default='configs/hyp.yaml', help='size of each image batch')
parser.add_argument('--model', type=str, default='weights/yolov8n.pt', help='pretrained weights or model.config path')
parser.add_argument('--batch-size', type=int, default=64, help='size of each image batch')
parser.add_argument('--img-size', type=int, default=320, help='size of each image dimension')
parser.add_argument('--device', type=str, default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--project', type=str, default='yolo', help='project name')
parser.add_argument('--name', type=str, default='pretrain', help='exp name')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')

args = parser.parse_args()

assert args.data, 'argument --data path is required'
assert args.model, 'argument --model path is required'

if __name__ == '__main__':
    # Initialize
    model = YOLO(args.model)
    hyperparams = yaml.safe_load(open(args.hyp))
    hyperparams['epochs'] = args.epochs
    hyperparams['batch'] = args.batch_size
    hyperparams['imgsz'] = args.img_size
    hyperparams['device'] = args.device
    hyperparams['project'] = args.project
    hyperparams['name'] = args.name
    hyperparams['resume'] = args.resume

    model.train(data= args.data, **hyperparams)

该脚本通过argparse来接受命令行参数,并设置到超参数字典,和yolov5的启动脚本类似。
主要有以下几个参数(可以根据个人需要增删):

  • data: 数据集配置文件
  • hyp: 参数配置文件(上一节我们建立的)
  • model: 模型权重或者模型结构配置文件
    其他参数根据名字就显而易见了。

模型训练(单卡)

python train.py --model weights/yolov8n.pt --data
configs/data/objects365.yaml --hyp configs/hyp.yaml --batch-size 512 --img-size 416 --device
0 --project object365 --name yolov8n

模型训练(多卡DDP)

理论上,我们只需要将device设置为多张卡就可以进行多卡并行了,但我们直接运行会发生一下错误:

assert args.model, 'argument --model path is required'

也就是我们设置的参数并没有接收到,进一步分析,DDP情况下,实际运行的命令是:

DDP command: ['/root/miniconda3/bin/python', '-m', 'torch.distributed.run', '--nproc_per_node', '4', '--master_port', '39083', 'xxx/code/yolov8/train.py']
WARNING:__main__:

也就是yolov8实际是用pytorch的ddp脚本启动了我们写得train.py脚本,但是却没有把我们设置的参数传过来(应该算是个bug吧···),这个过程发生在BaseTrainer的train接口中:
image.png
我们对generate_ddp_command进行修改,将命令行参数增加到train.py后(file后增加*sys.argv[1:]):

cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file, *sys.argv[1:]]

完整的函数:

def generate_ddp_command(world_size, trainer):
    """Generates and returns command for distributed training."""
    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = str(Path(sys.argv[0]).resolve())

    safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$')  # allowed characters and maximum of 100 characters
    if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')):  # using CLI
        file = generate_ddp_file(trainer)
    dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
    port = find_free_network_port()
    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file, *sys.argv[1:]]
    return cmd, file

修改后,device设置多卡则能正常开启训练。

结语

本文介绍了如何使用自定义训练脚本的方式启动yolov8的训练,有效的结合命令行和配置文件的优点,即可以灵活的修改训练参数,又可以用配置文件来管理我们的训练超参数。并通过修改文件,支持了DDP训练。

f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/877434.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vue框架】用户和请求

前言 在上一篇 【Vue框架】Vuex状态管理 针对Vuex状态管理以getters.js进行说明&#xff0c;没有对其中state引入的对象进行详细介绍&#xff0c;因为整体都比较简单&#xff0c;也就不对全部做详细介绍了&#xff1b;但其中的user.js涉及到获取用户的信息、前后端请求的token…

今天来给大家聊一聊什么是Hierarchical-CTC模型

随着人工智能领域的不断发展&#xff0c;语音识别技术在日常生活和工业应用中扮演着越来越重要的角色。为了提高识别准确性和效率&#xff0c;研究人员不断探索新的模型和算法。在这个领域中&#xff0c;Hierarchical-CTC模型引起了广泛的关注和兴趣。本文将介绍什么是Hierarch…

JavaFx基础学习【二】:Stage

一、介绍 窗口Stage为图中标绿部分&#xff1a; 实际为如下部分&#xff1a; 不同的操作系统表现的样式不同&#xff0c;以下都是以Windows操作系统为例&#xff0c;为了使大家更清楚Stage是那部分&#xff0c;直接看以下图&#xff0c;可能更清楚&#xff1a; 有点潦草&…

MachineLearningWu_15/P70-P71_AdamAndConv

x.1 算法参数更新 我们使用梯度下降算法来自动更新参数&#xff0c;但是由于学习率的不好选择性&#xff0c;我们有时候会下降地很快&#xff0c;有时候下降地很慢&#xff0c;我们期望有一种方式能够自动调整学习率的变化&#xff0c;这里引入Adaptive Moment Estimation/Ada…

City Walk带动茶饮品牌售1200万,媒介盒子带你探究奥秘

年轻人生活趋势又出现了一个新鲜词——City Walk&#xff0c;简单来说&#xff0c;City Walk就是没有目的地&#xff0c;没有目标&#xff0c;只是出行&#xff0c;填充自己的生活。 其实这个词源于gap year&#xff0c;而这个说法一直是国外的一些毕业生&#xff0c;大多会在…

解决方案 | 法大大加速医疗器械行业创新升级

科技的不断进步&#xff0c;带动医疗器械产品不断创新升级。数字化、智能化的技术也开始广泛应用在医疗器械行业中。行业的蓬勃发展&#xff0c;进一步驱动了医疗器械行业规范化管理政策的出台&#xff0c;2019年&#xff0c;《医疗器械产品注册管理办法》&#xff08;2019&…

Mongodb (四十一)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、概述 1.1 相关概念 1.2 特性 二、应用场景 三、安装 四、目录结构 五、默认数据库 六、 数据库操作 6.1 库操作 6.2 文档操作 七、MongoDB数据库备份 7.1 备…

Mybatis-Plus(四 )--其他功能和ActiveRecord和MybatisX和AutoGenerator

一.其他功能 1.自动填充 有些时候我们可能会在插入或者更新数据时&#xff0c;希望有些字段可以自动填充数据&#xff0c;比如密码&#xff0c;version等。 【1】添加TableField注解 TableField(fillFieldFill.INSERT)//插入数据时进行填充 private String password; 除了…

【机密计算实践】OPEN Enclave SDK 安装与构建

机密计算是基于硬件支持的可信执行环境的&#xff0c;比如 Intel SGX 硬件技术上面的 enclave 以及 Arm Trustzone 上的 OT-TEE&#xff0c;不过这些异构的 TEE 之间差异还是蛮大的&#xff0c;所以亟需一种能够屏蔽 TEE 差异软件中间件或者 SDK&#xff0c;这就是本文将要提到…

我学会这些车载技术,是不是就可以进【小米】车企?

作者&#xff1a;阿刁 随着智能化和电动化的发展&#xff0c;车载开发领域的前景非常广阔。许多手机厂商也纷纷加入进来&#xff0c;华为、小米等手机巨头也相继推出新能源汽车。所以在未来&#xff0c;车载系统将成为汽车的核心部分&#xff0c;涵盖车辆的控制、信息娱乐、智能…

“先锋龙颜美学”,比亚迪宋L 完成工信部申报,单双电机正式上市

根据工信部最新发布的《道路机动车辆生产企业及产品公告》&#xff08;第 374 批&#xff09;&#xff0c;我们得知比亚迪汽车公司的新款车型宋 L 已经顺利完成申报&#xff0c;并成功获得核准。这款车型将会有两个版本&#xff0c;分别是单电机和双电机版本。 此外&#xff0c…

Redis——String类型详解

概述 Redis中的字符串直接按照二进制的数据存储&#xff0c;不会有任何的编码转换&#xff0c;因此存放什么样&#xff0c;取出来的时候就什么样。而MySQL默认的字符集是拉丁文&#xff0c;如果插入中文就会失败 Redis中的字符串类型不仅可以存放文本数据&#xff0c;还可以存…

GloVe、子词嵌入、BPE字节对编码、BERT相关知识(第十四次组会)

GloVe、子词嵌入、BPE字节对编码、BERT相关知识(第十四次组会) Glove子词嵌入上游、下游任务监督学习、无监督学习BERTGlove 子词嵌入 上游、下游任务 监督学习、无监督学习 BERT

强制Edge或Chrome使用独立显卡【WIN10】

现代浏览器通常将图形密集型任务卸载到 GPU&#xff0c;以改善你的网页浏览体验&#xff0c;从而释放 CPU 资源用于其他任务。 如果你的系统有多个 GPU&#xff0c;Windows 10 可以自动决定最适合 Microsoft Edge 自动使用的 GPU&#xff0c;但这并不一定意味着最强大的 GPU。 …

8.14 刷题【7道】

二叉树 1. 树中两个结点的最低公共祖先 原题链接 方法一&#xff1a;公共路径 分别找出根节点到两个节点的路径&#xff0c;则最后一个公共节点就是最低公共祖先了。 时间复杂度分析&#xff1a;需要在树中查找节点&#xff0c;复杂度为O(n) /*** Definition for a binary…

CUDA、cuDNN以及Pytorch介绍

文章目录 前言一、CUDA二、cuDNN三、Pytorch 前言 在讲解cuda和cuDNN之前&#xff0c;我们首先来了解一下英伟达&#xff08;NVIDA&#xff09;公司。 NVIDIA是一家全球领先的计算机技术公司&#xff0c;专注于图形处理器&#xff08;GPU&#xff09;和人工智能&#xff08;…

买机票系统---(java实现)

/* * 案例 * 卖机票 * 需求&#xff1a;机票价格按照淡季和旺季&#xff0c;头等舱和经济舱收费&#xff0c;输入机票原价&#xff0c;月份和头等舱或经济舱 * 旺季&#xff08;5-10月&#xff09;&#xff1a;头等舱9折&#xff0c;经济舱8.5折 * 淡季&#xff08;11-来年4月&…

小目标检测(5)——有线硬触发和有线软触发架构学习

文章目录 引言正文PLC介绍有线硬触发有线软触发硬件接口 总结引用 引言 之前花了很多时间也就是仅仅看懂了基本代码,最近和老师交流之后,发现还有很多东西都需要弄.最终的灯检机,并不是直接接上计算机就使用的,并不是单纯通过计算机控制的,还有一个叫做PLC(可编程逻辑控制器),…

Python自动化实战之使用Selenium进行Web自动化详解

概要 为了完成一项重复的任务&#xff0c;你需要在网站上进行大量的点击和操作&#xff0c;每次都要浪费大量的时间和精力。Python的Selenium库就可以自动化完成这些任务。 在本篇文章中&#xff0c;我们将会介绍如何使用Python的Selenium库进行Web自动化&#xff0c;以及如何…

免费敏捷工具做敏捷需求管理

传统的瀑布工作模式使用详细的需求说明书来表达需求&#xff0c;需求人员负责做需求调研&#xff0c;根据调研情况编制详细的需求说明书&#xff0c;进行需求评审&#xff0c;评审之后签字确认交给研发团队设计开发。在这样的环境下&#xff0c;需求文档是信息传递的主体&#…