yolov8训练进阶:自定义训练脚本,从配置文件载入训练超参数

news2024/10/6 16:24:35

yolov8官方教程提供了2种训练方式,一种是通过命令行启动训练,一种是通过写代码启动。
image.png
image.png
命令行的方式启动方便,通过传入参数可以方便的调整训练参数,但这种方式不方便记录训练参数和调试训练代码。
自行写训练代码的方式更灵活,也比较方便调试,但官方的示例各种参数都是在代码中写死的方式,失去了灵活性。
其实我们可以结合这两种方法的优势,既能够通过命令行参数修改很容易变化的参数(如batch size, epoch, imgsz等),然后用配置文件保存很少需要变化的参数,或者这些变化需要保存下来方便对比(如各类增强比例)。

代码分析

首先我们需要知道我们能够设置哪些参数,尽管官方文档列出了命令行能够传入的参数列表,但每次设置大量参数还是不方便,而不设置的时候默认参数是多少我们也不知道,所以还是有必要分析一下代码。
通过模型的train接口我们会知道所有的Trainer均继承自BaseTrainer(yolo/engine/trainer.py),该类的构造函数如下:

def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.check_resume()
        ...

其中overrides就是我们设置的参数,我们未设置的参数则来源于DEFAULT_CFG,继续跟踪我们会发现这个DEFAULT_CFG实际来源于yolo/cfg/default.yaml:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# Default training settings and hyperparameters for medium-augmentation COCO training

task: detect  # YOLO task, i.e. detect, segment, classify, pose
mode: train  # YOLO mode, i.e. train, val, predict, export, track, benchmark

# Train settings -------------------------------------------------------------------------------------------------------
model:  # path to model file, i.e. yolov8n.pt, yolov8n.yaml
data:  # path to data file, i.e. coco128.yaml
epochs: 100  # number of epochs to train for
start_epoch: 0  # start epoch
patience: 50  # epochs to wait for no observable improvement for early stopping of training
batch: 16  # number of images per batch (-1 for AutoBatch)
imgsz: 640  # size of input images as integer or w,h
save: True  # save train checkpoints and predict results
save_period: -1 # Save checkpoint every x epochs (disabled if < 1)
cache: False  # True/ram, disk or False. Use cache for data loading
device:  # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
workers: 8  # number of worker threads for data loading (per RANK if DDP)
project:  # project name
name:  # experiment name, results saved to 'project/name' directory
exist_ok: False  # whether to overwrite existing experiment
pretrained: False  # whether to use a pretrained model
optimizer: SGD  # optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
verbose: True  # whether to print verbose output
seed: 0  # random seed for reproducibility
deterministic: True  # whether to enable deterministic mode
single_cls: False  # train multi-class data as single-class
rect: False  # rectangular training if mode='train' or rectangular validation if mode='val'
cos_lr: False  # use cosine learning rate scheduler
close_mosaic: 0  # (int) disable mosaic augmentation for final epochs
resume: False  # resume training from last checkpoint
amp: True  # Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
fraction: 1.0  # dataset fraction to train on (default is 1.0, all images in train set)
profile: False  # profile ONNX and TensorRT speeds during training for loggers
# Segmentation
overlap_mask: True  # masks should overlap during training (segment train only)
mask_ratio: 4  # mask downsample ratio (segment train only)
# Classification
dropout: 0.0  # use dropout regularization (classify train only)
...

我们所有能设置的参数就在这个文件中,如果我们设置了不在其中的参数则会报错(下一篇介绍怎么增加参数)。

自定义参数配置文件

我们可以将训练会调整的参数单独保存到一个yaml文件,如hyp.scratch.yaml作为从头训练的配置,进行多次实验时,就可以建立不同的配置参数文件:

lr0: 0.01  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
lrf: 0.001  # final learning rate (lr0 * lrf)
momentum: 0.937  # SGD momentum/Adam beta1
weight_decay: 0.0005  # optimizer weight decay 5e-4
warmup_epochs: 3.0  # warmup epochs (fractions ok)
warmup_momentum: 0.8  # warmup initial momentum
warmup_bias_lr: 0.1  # warmup initial bias lr
box: 7.5  # box loss gain
cls: 0.5  # cls loss gain (scale with pixels)
dfl: 1.5  # dfl loss gain
pose: 12.0  # pose loss gain
kobj: 1.0  # keypoint obj loss gain
label_smoothing: 0.0  # label smoothing (fraction)
nbs: 64  # nominal batch size
hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4  # image HSV-Value augmentation (fraction)
degrees: 0.0  # image rotation (+/- deg)
translate: 0.1  # image translation (+/- fraction)
scale: 0.5  # image scale (+/- gain)
shear: 0.0  # image shear (+/- deg)
perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
flipud: 0.0  # image flip up-down (probability)
fliplr: 0.5  # image flip left-right (probability)
mosaic: 0.1  # image mosaic (probability)
mixup: 0.05  # image mixup (probability)
copy_paste: 0.0  # segment copy-paste (probability)

workers: 12  # number of workers
# cache: disk

自定义训练脚本

建立了自定义参数文件,我们还要建立自己的训练脚本来载入配置文件,并且还有一些经常变化的参数需要通过命令行传入, 新建train.py:

from ultralytics import YOLO
import yaml
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='configs/data/phd.yaml', help='dataset.yaml path')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs')
parser.add_argument('--hyp', type=str, default='configs/hyp.yaml', help='size of each image batch')
parser.add_argument('--model', type=str, default='weights/yolov8n.pt', help='pretrained weights or model.config path')
parser.add_argument('--batch-size', type=int, default=64, help='size of each image batch')
parser.add_argument('--img-size', type=int, default=320, help='size of each image dimension')
parser.add_argument('--device', type=str, default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--project', type=str, default='yolo', help='project name')
parser.add_argument('--name', type=str, default='pretrain', help='exp name')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')

args = parser.parse_args()

assert args.data, 'argument --data path is required'
assert args.model, 'argument --model path is required'

if __name__ == '__main__':
    # Initialize
    model = YOLO(args.model)
    hyperparams = yaml.safe_load(open(args.hyp))
    hyperparams['epochs'] = args.epochs
    hyperparams['batch'] = args.batch_size
    hyperparams['imgsz'] = args.img_size
    hyperparams['device'] = args.device
    hyperparams['project'] = args.project
    hyperparams['name'] = args.name
    hyperparams['resume'] = args.resume

    model.train(data= args.data, **hyperparams)

该脚本通过argparse来接受命令行参数,并设置到超参数字典,和yolov5的启动脚本类似。
主要有以下几个参数(可以根据个人需要增删):

  • data: 数据集配置文件
  • hyp: 参数配置文件(上一节我们建立的)
  • model: 模型权重或者模型结构配置文件
    其他参数根据名字就显而易见了。

模型训练(单卡)

python train.py --model weights/yolov8n.pt --data
configs/data/objects365.yaml --hyp configs/hyp.yaml --batch-size 512 --img-size 416 --device
0 --project object365 --name yolov8n

模型训练(多卡DDP)

理论上,我们只需要将device设置为多张卡就可以进行多卡并行了,但我们直接运行会发生一下错误:

assert args.model, 'argument --model path is required'

也就是我们设置的参数并没有接收到,进一步分析,DDP情况下,实际运行的命令是:

DDP command: ['/root/miniconda3/bin/python', '-m', 'torch.distributed.run', '--nproc_per_node', '4', '--master_port', '39083', 'xxx/code/yolov8/train.py']
WARNING:__main__:

也就是yolov8实际是用pytorch的ddp脚本启动了我们写得train.py脚本,但是却没有把我们设置的参数传过来(应该算是个bug吧···),这个过程发生在BaseTrainer的train接口中:
image.png
我们对generate_ddp_command进行修改,将命令行参数增加到train.py后(file后增加*sys.argv[1:]):

cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file, *sys.argv[1:]]

完整的函数:

def generate_ddp_command(world_size, trainer):
    """Generates and returns command for distributed training."""
    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = str(Path(sys.argv[0]).resolve())

    safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$')  # allowed characters and maximum of 100 characters
    if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')):  # using CLI
        file = generate_ddp_file(trainer)
    dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
    port = find_free_network_port()
    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file, *sys.argv[1:]]
    return cmd, file

修改后,device设置多卡则能正常开启训练。

结语

本文介绍了如何使用自定义训练脚本的方式启动yolov8的训练,有效的结合命令行和配置文件的优点,即可以灵活的修改训练参数,又可以用配置文件来管理我们的训练超参数。并通过修改文件,支持了DDP训练。

f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/880322.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0基础学习VR全景平台篇 第84篇:智慧眼-权限如何设置?

一、功能说明 指智慧眼项目在运行的整个过程中&#xff0c;相关成员所需要用到的一系列操作权利&#xff1a;如提交&#xff08;问题&#xff09;、审核、整治、驳回、撤销、成员管理、查看数据、导出数据等等&#xff0c;这些都可以称为权限。权限需要通过后台先创建出来&…

item_get_sales-获取TB商品销量详情

一、接口参数说明&#xff1a; item_get_sales-获取商品销量详情&#xff0c;点击更多API调试&#xff0c;请移步注册API账号点击获取测试key和secret 公共参数 请求地址: https://api-gw.onebound.cn/taobao/item_get_sales 名称类型必须描述keyString是调用key&#xff08…

消防态势标绘工具,为消防基层工作助力

背景介绍 无人机测绘技术在消防领域的应用越来越普及&#xff0c;高清的二维正射影像和倾斜摄影实景三维模型能为消防态势标绘提供高质量的素材&#xff0c;消防队急需一个简便易用的、能够基于这些二三维的高清地图成果进行态势标绘的工具软件&#xff0c;使得消防“六熟悉”…

读《芯片浪潮》,学习台积电张忠谋的管理之道

大家知道&#xff0c;台积电一个公司就占据了全球晶圆代工市场一半的份额。 5纳米及以下最先进工艺的芯片&#xff0c;台积电可占到惊人的90%以上的市场。全球最新最强的智能手机、笔记本电脑的核心计算芯片都必须仰仗台积电一个企业的供应。 换一个说法&#xff0c;如果没有…

NeMo 中文ASR话者分离(说话人日志)实战

话者分离或者叫说话人日志,主要是解决说话人什么时候说了什么的问题。典型的应用场景:多人会议、坐席销售/客服场景。 典型的实现过程是基于管道。 首先基于VAD(声音活动检测)的MarbleNet,分割声音片段,然后基于TitaNet-L提取话者特征,然后通过聚类区分话者,最后通过神…

探寻Pytest的重难点:挑战与突破

Pytest作为Python社区中广受欢迎的测试框架&#xff0c;以其简洁优雅的语法和强大的功能&#xff0c;成为了许多开发者的首选。然而&#xff0c;在使用Pytest的过程中&#xff0c;我们不可避免地会遇到一些重难点&#xff0c;这些挑战也正是我们不断学习和成长的机会。本文将带…

成功将虚拟机映射到局域网,小伙伴都可以访问

一、添加入站规则 这张是添加所要映射端口的入站规则的图片&#xff0c;在此之前已将所有防火墙已关闭 如果没关就看下边的防火墙属性 二、添加虚拟机映射 添加之后&#xff0c;这里有个应用点一下&#xff0c;让NAT重启。 三、测试 等NAT 重启完成后比如你iP是182.1.1.1 …

当速度很重要时:使用 Hazelcast 和 Redpanda 进行实时流处理

在本教程中&#xff0c;了解如何构建安全、可扩展、高性能的应用程序&#xff0c;以释放实时数据的全部潜力。 在本教程中&#xff0c;我们将探索 Hazelcast 和 Redpanda 的强大组合&#xff0c;以构建对实时数据做出反应的高性能、可扩展和容错的应用程序。 Redpanda 是一个流…

SOLIDWORKS PDM—文件版本的管控

SOLIDWORKS产品数据管理 (PDM) 解决方案可帮助您控制设计数据&#xff0c;并且从本质上改进您的团队就产品开发进行管理和协作的方式。使用 SOLIDWORKS PDM Professional&#xff0c;您的团队能够&#xff1a;1. 安全地存储和索引设计数据以实现快速检索&#xff1b;2. 打消关于…

4WRZ25E3-220-5X/6A24NZ4/D3M不带位移反馈比例阀放大器

该先导阀是一个由比例电磁铁控制的三通减压阀&#xff0c;它的作用是将一个输入信号转化为一个与其成比例的压力输出信号&#xff0c;可用于所有的4WRZ...和5WRZ...型比例阀的控制。比例电磁铁是可调试&#xff0c;湿式直流电磁铁结构&#xff0c;带中心螺纹&#xff0c;线圈可…

淘宝搜索店铺列表API:关键字搜索店铺信息 获取店铺主页 店铺所在地 服务评级

接口名称&#xff1a;item_search_seller 基本功能介绍 该API可以通过传入关键字&#xff0c;获取到淘宝商城的店铺列表&#xff0c;支持翻页显示。指定参数page获取到指定页的数据。返回的店铺信息包括&#xff1a;店铺名、店铺ID、店铺主页、宝贝图片、掌柜名字、店铺所在地…

照明灯具哪个品牌好?护眼台灯该怎么选

现在儿童近视率越来越高了&#xff0c;用眼过度疲劳是导致近视的主要因素&#xff0c;学习环境的光线是否合适&#xff0c;都会直接影响用眼的疲劳程度。所以给孩子营造一个良好的学习环境非常重要&#xff01;一款护眼台灯可以很好的预防近视&#xff0c;为大家推荐五款护眼台…

探索数字孪生的数据之美:实时、多源、多维的未来

在数字孪生的世界里&#xff0c;数据不再是孤立的数字&#xff0c;而是构成了一个真实、动态的虚拟映像&#xff0c;其独特的特点为现代社会带来了前所未有的机遇。 首先&#xff0c;数字孪生的数据特点之一是实时性。在制造业中&#xff0c;数字孪生可以通过实时传感器数据&am…

4WRAP6W7-08-30=G24K4/M=00比例先导阀控制放大器

先导控制阀是直动式比例阀。控制边的尺寸经过优化&#xff0c;可用作比例方向阀型号 4WRKE 的先导控制阀。 比例电磁铁为带可拆卸线圈的耐压密闭型湿式插脚交流线圈。 它们可将电流按比例转换为机械力。电流强度的增加会导致磁力相应增加。设定的磁力会在整个控制行程中保持不…

华为AI战略的CANN

基于TVM的华为昇腾体系中—— 异构计算架构&#xff08;CANN&#xff09;是对标英伟达的CUDA CuDNN的核心软件层&#xff0c;向上支持多种AI框架&#xff0c;向下服务AI处理器&#xff0c;发挥承上启下的关键作用&#xff0c;是提升昇腾AI处理器计算效率的关键平台 主要包括有…

Java SpringBoot Vue ERP系统

系统介绍 该ERP系统基于SpringBoot框架和SaaS模式&#xff0c;支持多租户&#xff0c;专注进销存财务生产功能。主要模块有零售管理、采购管理、销售管理、仓库管理、财务管理、报表查询、系统管理等。支持预付款、收入支出、仓库调拨、组装拆卸、订单等特色功能。拥有商品库存…

【网络基础】应用层协议

【网络基础】应用层协议 文章目录 【网络基础】应用层协议1、协议作用1.1 应用层需求1.2 协议分类 2、HTTP & HTTPS2.1 HTTP/HTTPS 简介2.2 HTTP工作原理2.3 HTTPS工作原理2.4 区别 3、URL3.1 编码解码3.2 URI & URL 4、HTTP 消息结构4.1 HTTP请求方法4.2 HTTP请求头信…

虹科干货 | 化身向量数据库的Redis Enterprise——快速、准确、高效的非结构化数据解决方案!

用户期望在他们遇到的每一个应用程序和网站都有搜索功能。然而&#xff0c;超过80%的商业数据是非结构化的&#xff0c;以文本、图像、音频、视频或其他格式存储。Redis Enterprise如何实现矢量相似性搜索呢&#xff1f;答案是&#xff0c;将AI驱动的搜索功能集成到Redis Enter…

聊聊计算机技术

目录 1.计算机的概念 2.计算机的发展过程 3.计算机的作用 4.计算机给人类带来的福利 1.计算机的概念 计算机是一种用于处理和存储数据的电子设备。它能够执行各种操作&#xff0c;比如计算、逻辑操作、数据存储和检索等。计算机由硬件和软件两部分组成。 计算机的硬件包括中…

SAP ABAP 直接把内表转换成PDF格式(smartform的打印函数输出OTF格式数据)

直接上代码&#xff1a; REPORT zcycle055.DATA: lt_tab TYPE TABLE OF zpps001. DATA: ls_tab TYPE zpps001.ls_tab-werks 1001. ls_tab-gamng 150.00. ls_tab-gstrp 20201202. ls_tab-aufnr 000010000246. ls_tab-auart 标准生产. ls_tab-gltrp 20201205. ls_tab-matn…