YOLO-V5 算法和代码解析系列(二)—— 【train.py】核心内容

news2025/4/8 12:59:17

文章目录

    • 调试设置
    • 整体结构
    • 代码解析
      • Model
      • Trainloader
      • 分布式训练
      • Freeze
      • Optimizer
      • Scheduler
      • EMA

调试设置

  1. 调试平台:Ubuntu,VSCode

  2. 调试设置,打开【/home/slam/kxh-1/2DDection/yolov5/.vscode/launch.json】,操作如下图所示,
    在这里插入图片描述

    内容配置如下代码片段所示,根据需求修改【debug】参数,

    {
        // 使用 IntelliSense 了解相关属性。 
        // 悬停以查看现有属性的描述。
        // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
        "version": "0.2.0",
        "configurations": [
            // train.py
            {
                "name": "Python: Current File",
                "type": "python",
                "request": "launch",
                "program": "train.py",
                "console": "integratedTerminal",
                "justMyCode": true,
                "args":["--data", "coco128.yaml", 
                        "--cfg", "", 
                        "--weights", "yolov5s.pt", 
                        "--batch-size", "1", 
                        "--device", "1", 
                        "--epoch", "10",
                        "--workers", "1"]
           }
        ]
    }                
    

整体结构

  该篇博文主要目的不是帮助读者深入理解每一个知识点,每一行代码。【主要目的】:从整体上弄清楚【yolov5】整体的运行逻辑和结构,一个完整的深度学习项目,应该包含哪些模块,哪些设置。下面的思维导图,大概展示了【train.py】相关的所有知识点,以及整体的运行结构。
请添加图片描述

代码解析

  作者为了代码的健壮性,稳定性,【train.py】内容非常多,可能会感觉有一些繁杂。初次阅读代码可能会抓不住重点,进而在一些不重要的功能模块浪费时间。下面的内容,不会解析所有代码,而是解析【核心模块】和【核心代码】。

  1. 函数入口如下,parse_opt():配置参数;main():核心功能函数

    if __name__ == "__main__":
    	# 参数配置
        opt = parse_opt()
        # 核心功能函数:基本设置,train(),动态调参
        main(opt)
    
  2. 参数配置

    基本功能:解析命令行参数,包括训练参数,学习率,模型文件,数据路径,测试参数等;
    基本使用:库 【argparse】的使用,

    (1)导入库【argparse】;
    (2)创建解析对象【parser】;
    (3)添加命令行参数,以及默认的参数和选项【add_argument()】;
    (4)解析添加的参数【parser.parse_args()】;

    流程非常简单,示例如下,

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default=ROOT /'yolov5s.pt', help='initial weights path')
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    
  3. 解析【main()】 函数
    如【整体结构】中的思维导图所示,【main()】主要包括三大部分:基本设置,训练核心函数(train),动态调参。下面逐步解析每一个部分的核心内容:


    基本设置:Resume

    为了调试 Resume,需要添加【debug】参数:【“–resume”, “runs/train/exp/weights/best.pt”】,表示从该路径下的【best.pt】恢复训练。具体解释,见代码注释,

    # Resume
    # if的判断条件:
    #     opt.resume:true;
    #     check_wandb_resume(opt):false,通常不会用wandb工具;
    #     opt.evolve:动态调参的参数,通常为false;
    if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
        # 判断【opt.resume】是否为字符串,true
        # ckpt: 'runs/train/exp/weights/best.pt'
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        # 检查【ckpt】文件是否存在
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        # 重新加载中断训练前保存的配置文件信息,保存名字为【opt.yaml】
        with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
        # 恢复训练,重新更新模型参数
        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
        # 打印日志:Resuming training from runs/train/exp/weights/best.pt
        LOGGER.info(f'Resuming training from {ckpt}')
    else:
    	# 返回参数的具体路径
    	# opt.data: 'coco128.yaml' ---> '/home/slam/kxh-1/2DDection/yolov5/data/coco128.yaml'
    	# opt.hyp: PosixPath('data/hyps/hyp.scratch-low.yaml') ---> 'data/hyps/hyp.scratch-low.yaml'
    	# opt.weights: 'yolov5s.pt' ---> 'yolov5s.pt'
    	# opt.project: PosixPath('runs/train') ---> 'runs/train'
        opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
                check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
        # cfg:模型的结构配置文件,weights:pt文件。两者必须提供一个,才能构建出网络结构
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        if opt.evolve:
            if opt.project == str(ROOT / 'runs/train'):  # if default project name, rename to runs/evolve
                opt.project = str(ROOT / 'runs/evolve')
            opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
        if opt.name == 'cfg':
                opt.name = Path(opt.cfg).stem  # use model.yaml as name
        # 建立模型保存路径
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
    

    基本设置:DDP Mode

    DDP 是官方推荐的并行训练模式,下面的代码片段仅仅是设置【多进程组】。该篇文章并不解释DDP的训练模式,后续博文文章会专门解释DDP,此处不作深入解析。

    # DDP mode
    # 设置设备ID,以及检查设备的显卡,cuda信息,具体输出如下:
    # YOLOv5 🚀 v6.1-263-g0537e8dd Python-3.7.13 torch-1.10.0 CUDA:1 (NVIDIA GeForce RTX 2080 Ti, 11019MiB)
    device = select_device(opt.device, batch_size=opt.batch_size)
    if LOCAL_RANK != -1:
        msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
        assert not opt.image_weights, f'--image-weights {msg}'
        assert not opt.evolve, f'--evolve {msg}'
        assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
        assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multip of WORLD_SIZE'
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device('cuda', LOCAL_RANK)
        # DDP Mode:首先初始化多进程组
        dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
    
  4. 解析【train】函数
    训练的核心模块,包含很多内容(参考上述的整体结构图)。下面主要介绍核心的模块,具体如下

    Model

      这一部分包含模型的构建和初始化,后续博文会详细讲解,这里不作陈述。

    Trainloader

      这一部分主要是设置数据读取迭代器,数据读取类,数据增强等,后续博文会详细讲解,这里不作陈述。

    分布式训练

      这一部分主要是设置DP,DDP分布式训练,后续博文会详细讲解,这里不作陈述。

    Freeze

      下面的代码片段,简单展示了冻结层的基本设置。简单说,就是将该层的权重更新关闭,也就冻结了当前层的权重更新。

    冻结层操作通常用于迁移学习,具体可以参考 https://github.com/ultralytics/yolov5/issues/1314.

    # Freeze
    # 设置要冻结的层的ID
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
    # 遍历 model 的参数,k:层的名字,v: 权重值
    # k: 'model.0.conv.weight'
    # v.shape: torch.Size([32, 3, 6, 6])
    for k, v in model.named_parameters():
    	# 设置是否需要计算梯度的标记,True:优化该层,False:冻结该层
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False
    

    Optimizer

    下面的代码主要是设置损失累积,根据实际需要调整,

    # Optimizer
    # 预先设定的【batch-size】
    nbs = 64  # nominal batch size
    # 累计【accumulate】次损失后,在进行反向传播优化,变向的是在增大batch-size?
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    # 实际 batch_size <= nbs,权重衰减值不变,需要累积损失;
    # 实际 batch_size > nbs, 权重值乘以大于1的系数,不需要累积损失;
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
    

    优化器参数组:借助参数组,满足不同的优化需求,代码如下,

    # 初始化参数组的保存列表
    g = [], [], []  # optimizer parameter groups
    # 列举出【nn】模块中的名字和值(比如,('Conv2d', <class 'torch.nn.modules.conv.Conv2d'>),
    # 然后,把带有【Norm】字段的名字【k】对应的值【v】的添加到【bn】
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
    # 将当前模型中的【weights】,【bias】,【bn_weight】分开,并分别存入空列表【g】中
    # 存储结果:len(g)=3,len(g[0])=60,len(g[1])=57,len(g[2])=60
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
            g[2].append(v.bias)
        if isinstance(v, bn):  # weight (no decay)
            g[1].append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g[0].append(v.weight)
    

    选择优化器:选择是【SGD】,默认传入待优化参数为【g[2],bias】,然后再添加其它待优化的参数,

    # 提供三种优化器, 传入不同的参数
    if opt.optimizer == 'Adam':
        optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    elif opt.optimizer == 'AdamW':
        optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    
    # 更新优化器的待优化参数     
    # 显然还要添加核心的优化参数,【weights,并且指定了单独的权重衰减值】,【bn】和【bias】无法使用权重衰减系数
    optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1]})  # add g1 (BatchNorm2d weights)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
    del g
    

    关于建立优化器组(optimizer group)的简明解释:

    模型的参数可以被分成不同的组,然后赋予单独的优化参数,具体例子如下

    optim.SGD([
                  {'params': model.base.parameters()},
                  {'params': model.classifier.parameters(), 'lr': 1e-3}
              ], lr=1e-2, momentum=0.9) 
    

    上述代码片段表示:【model.base】的参数使用默认的学习率【1e-2】,【model.classifier】的参数使用学习率【1e-3】,所有的参数动使用的动量为【0.9】

    Scheduler

    学习率调整策略,

    # Scheduler
    if opt.cos_lr:
    	# 余弦调整策略,自定义调整规则,然后传入Pytorch API接口
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    else:
    	# 线性调整,传入参数x
        lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)
    

    绘制两种学习率的曲线图,官方已提供该函数:

    from utils.plots import plot_lr_scheduler
    plot_lr_scheduler(optimizer, scheduler, epochs)
    

    绘制结果如下,左图是余弦调整方式,右图是线性调整方式,

    EMA

    Tensorflow框架 —— tf.ExponentalMovingAverage(),解析了EMA的理论部分,可以大概了解一下具体原理。 下面是【yolo-v5】实现的指数移动平均(EMA)的代码,具体解析见注释如下,

    EMA的操作对象:网络训练中,模型的优化变量,比如weights,bias.
    EMA的操作方式:对一定步骤内的历史变量进行指数移动平均,也就是对历史变量计算平均值。
    EMA的作用:保证模型测试的稳定性,不至于由于异常值,导致推理结果波动太大。

    class ModelEMA:
        """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
        Keeps a moving average of everything in the model state_dict (parameters and buffers)
        For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
        """
        def __init__(self, model, decay=0.9999, tau=2000, updates=0):
            # Create EMA
            # 构建EMA,.eval()固定BN的参数,不参与平均
            self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
            # if next(model.parameters()).device.type != 'cpu':
            #     self.ema.half()  # FP16 EMA
            self.updates = updates  # number of EMA updates
            # 每一次更新的decay值
            self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
            # 将参数全部置为False,不进行梯度计算
            for p in self.ema.parameters():
                p.requires_grad_(False)
                
        def update(self, model):
            # Update EMA parameters
            with torch.no_grad(): # 暂时停止自动求导模块运算
                self.updates += 1
                d = self.decay(self.updates) # 更新新的decay值
                
                # 获取模型的所有参数,包括训练参数和非训练参数
                msd = de_parallel(model).state_dict()  # model state_dict
                # 计算新的参数,并更新对应的影子参数(存储在ema,默认为False)
                for k, v in self.ema.state_dict().items():
                    if v.dtype.is_floating_point:
                        v *= d
                        v += (1 - d) * msd[k].detach()
                        
        def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
            # Update EMA attributes
            copy_attr(self.ema, model, include, exclude)
    

    YOLO-V5中使用流程,

    初始化【EMA】对象

    # EMA 	
    ema = ModelEMA(model) if RANK in {-1, 0} else None
    

    更新参数

    ema.update(model) 
    
  5. 动态调参
    实际未用,暂不解析,后续若用,会更新博文。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/117983.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GNN基础知识

1. 泰勒公式 背景background 有一个很复杂的方程&#xff0c;我们直接计算方程本身的值可能非常麻烦。 所以我们希望能够找到一个近似的方法来获得一个足够近似的值 本质&#xff1a; 近似&#xff0c;求一个函数的近似值 one point is 近似的方法another point is 近似的…

【Java 数据结构】-优先级队列以及Java对象的比较

作者&#xff1a;学Java的冬瓜 博客主页&#xff1a;☀冬瓜的主页&#x1f319; 专栏&#xff1a;【Java 数据结构】 分享&#xff1a;美妙人生的关键在于你能迷上什么东西。——《球状闪电》 主要内容&#xff1a;优先级队列底层的堆&#xff0c;大堆的创建&#xff0c;插入&a…

Openssl 生成自签名证书

最近在调试Ingress需要使用多份证书&#xff0c;对证书的生成和使用做了简单的整理。 不用翻垃圾桶一条过 #!/bin/sh output_dir"/opt/suops/k8s/ingress-files/certs/fanht-create-ssl/" read -p "Enter your domain [www.example.com]: " DOMAIN echo…

C++11特性-线程

并发 一个程序执行多个独立任务&#xff0c;提高性能 单核cpu是通过(任务切换)&#xff0c;即上下文切换&#xff0c;有时间开销 多核cpu(当核数>任务数)&#xff0c;硬件并发 进程 运行起来的一个可执行程序&#xff08;一段程序的运行过程&#xff09; 资源分配的最小单…

百数应用中心上新了——餐饮门店管理系统

随着智能化时代的来临&#xff0c;传统的餐饮门店管理方式逐渐暴露出缺陷。不少餐饮业的掌门人都纷纷对管理方式进行了转型&#xff0c;由传统模式转变为数字化系统的管理。然而数字化管理方式也没那么容易进行&#xff0c;想要百分百满足需求的系统耗时耗力耗钱&#xff0c;成…

不懂PO 设计模式?这篇实战文带你搞定 PO

1080442 73.1 KB 为UI页面写测试用例时&#xff08;比如web页面&#xff0c;移动端页面&#xff09;&#xff0c;测试用例会存在大量元素和操作细节。当UI变化时&#xff0c;测试用例也要跟着变化&#xff0c; PageObject 很好的解决了这个问题&#xff01; 使用UI自动化测试工…

钉钉 ANR 治理最佳实践 | 定位 ANR 不再雾里看花

作者&#xff1a;姜凡(步定) 本文为《钉钉 ANR 治理最佳实践》系列文章首篇《定位 ANR 不再雾里看花》&#xff0c;主要介绍了钉钉自研的 ANRCanary 通过监控主线程的执行情况&#xff0c;为定位 ANR 问题提供更加丰富的信息。 后续将在第二篇文章中讲述钉钉基于分析算法得出 …

【TuyaOS开发之旅】BK7231N GPIO的简单使用

接口讲解 GPIO初始化 /*** brief gpio 初始化* * param[in] pin_id: 需要初始化的GPIO编号&#xff0c; 对应TUYA_GPIO_NUM_E枚举* param[in] cfg: gpio 配置** return OPRT_OK on success. Others on error, please refer to tuya_error_code.h*/ OPERATE_RET tkl_gpio_ini…

基于SpringBoot工程开发Docker化微服务

目录 1. 微服务容器化治理的优缺点 1.1 微服务容器化的优点 1.2 微服务容器化的缺点 2. 微服务的两种模式 2.1 Microservice SDK 2.2 ServiceMesh 3. 微服务容器化治理的推荐模式 4.Windows下开发容器化微服务&#xff08;非K8S&#xff09; 4.1 开发环境 4.2 代码框架…

全网最新、最详细的使用burpsuite验证码识别绕过爆破教程(2023最新)

1、前沿 最近一直在研究绕过验证码进行爆破的方法&#xff0c;在这里对自己这段时间以来的收获进行一下分享。在这里要分享的绕过验证码爆破的方法一共有2个&#xff0c;分为免费版本&#xff08;如果验证码比较奇怪可能会有识别错误的情况&#xff09;和付费版本&#xff08;…

【Qt】QtCreator远程部署、调试程序

1、添加远程设备 1)QtCreator 工具–> 选项 --> 设备 --> 添加 2)设备设置向导选择–> Generic Linux Device --> 开启向导 3)填写“标识配置的名称”(随便写)、设备IP、用户名 --> 下一步 4)选择配对秘密文件,第一次配对,可以不填写,点击“下一…

嵌入式:ARM嵌入式系统开发流程概述

文章目录嵌入式开发的具体过程开发流程图嵌入式软件开发环境交叉开发环境远程调试结构图嵌入式应用软件开发的基本流程软件模拟环境目标板与评估板嵌入式软件开发的可移植性和可重用性嵌入式开发的具体过程 系统定义与需求分析阶段方案设计阶段详细设计阶段软硬件集成测试阶段…

Tomcat架构分析—— Engine

文章目录一、Tomcat的核心模块&#xff08;核心组件&#xff09;二、Engine 组件1.核心类与依赖图2.核心类源码分析构造函数&#xff1a;初始化方法 init&#xff1a;启动方法 start&#xff1a;3.Engine的启动过程总结一、Tomcat的核心模块&#xff08;核心组件&#xff09; …

机器学习之支持向量机(手推公式版)

文章目录前言1. 间隔与支持向量2. 函数方程描述3. 参数求解3.1 拉格朗日乘数3.2 拉格朗日对偶函数前言 支持向量机(Support(Support(Support VectorVectorVector Machine,SVM)Machine,SVM)Machine,SVM)源于统计学习理论&#xff0c;是一种二分类模型&#xff0c;是机器学习中获…

mysql查询当天,近一周,近一个月,近一年的数据

1.mysql查询当天的数据 select * from table where to_days(时间字段) to_days(now()); 2.mysql查询昨天的数据 select * from table where to_days(now( ) ) - to_days( 时间字段名) < 1 3.mysql查询近一周的数据 SELECT * FROM table WHERE date(时间字段) > D…

MySQL表的创建修改删除

目录 1、表的创建 2、查看表结构 3、表的修改 4、表的删除 1、表的创建 CREATE TABLE table_name ( field1 datatype, field2 datatype, field3 datatype ) character set 字符集 collate 校验规则 engine 存储引擎&#xff1b;说明&#xff1a; field 表示列名 datatype 表…

计算机系统基础实验 - 定点数加减法的机器级表示

实验序号&#xff1a;2 实验名称&#xff1a;定点数加减法的机器级表示 适用专业&#xff1a;软件工程 学 时 数&#xff1a;2学时 一、实验目的 1、掌握定点数加法的机器级表示。 2、掌握定点数减法的机器级表示。 3、掌握EFLAGS中4个牵涉到计算的标志位的计算方法。 4、掌握…

python实现动态柱状图

目录 一.基础柱状图 反转x轴&#xff0c;y轴&#xff0c;设置数值标签在右侧 小结 二.基础时间线柱状图 三.GDP动态柱状图绘制 1.了解列表的sort方法并配合lambda匿名函数完成列表排序 2.完成图表所需数据 3.完成GDP动态图表绘制 添加主题类型 设置动态标题 四.完整代码…

5.6 try语句块和异常处理

文章目录throw表达式(异常检测)try语句块&#xff08;异常处理&#xff09;编写处理代码函数在寻找处理代码的过程中退出标准异常异常是指存在于运行时的反常行为&#xff0c;这些行为超出了函数正常功能的范围。典型的异常包括失去数据库连接以及遇到意外输入等。当程序的某部…

Android Studio实现一个旅游课题手机app

文章目录&#xff1a; 目录 一、课题介绍 二、软件的运行环境 三、软件运行截图 四、软件项目总结 一、课题介绍 本次课题是实现了一个外出旅游的app&#xff0c;通过app可以显示景点的信息&#xff0c;以及根据地区查询&#xff0c;具体功能如下&#xff1a; 客户端 1.用…