【yolov8】yolov8剪枝训练流程

news2024/11/19 7:26:29

yolov8剪枝训练流程

流程:

  • 约束
  • 剪枝
  • 微调

一、正常训练

yolo train model=./weights/yolov8s.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=train

二、约束训练

2.1 修改YOLOv8代码:

ultralytics/yolo/engine/trainer.py
添加内容:

# Backward

self.scaler.scale(self.loss).backward()

# ========== 新增 ==========

l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))

# ========== 新增 ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html

if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni

2.2 训练

需要注意的就是amp=False

yolo train model=prunt/train/weights/best.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=constraint

训练完会得到一个best.pt和last.pt,推荐用last.pt

三、剪枝

上一步得到的last.pt作为剪枝对象,运行项目中的prun.py文件:

*这里的剪枝代码仅适用yolov8原模型,如有模块/模型的更改,则需要修改剪枝代码*

运行完会得到prune.pt和prune.onnx可以在netron.app网站拖入onnx文件查看是否剪枝成功了,成功的话可以看到某些通道数字为单数或者一些不规律的数字,如下图:

在这里插入图片描述

左侧为未剪枝的模型,右侧为剪枝后的模型。

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

四、 回调训练(finetune)

回调训练的唯一关键点就在于不让模型从yaml文件加载结构,直接加载pt文件

两种方法(因yolov8版本不同而选择不同方法):

方法一:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的443行左右

self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)

# ========== 新增该行代码 ==========

self.model = weights

# ========== 新增该行代码 ==========

return ckpt

方法二:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的335行左右

if not args.get('resume'):  # manually set model only if not resuming
    # self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
    # self.model = self.trainer.model
    ######################上面两行注释掉,添加下面一行#####
    self.trainer.model = self.model.train()
    ##########################修改####################
    self.trainer.hub_session = self.session  # attach optional HUB session
3.3 修改完代码就可以进行finetun训练了

命令行输入:

yolo train model=prun/prune/weights/last_prune.pt data="yolo_bvn.yaml" amp=False epochs=100 project=prun name=finetune device=0

五、结果展示:

5.1模型大小:ONNX模型大小从42M减少到34M

在这里插入图片描述

5.2PR曲线:

正常训练约束训练100轮微调
在这里插入图片描述在这里插入图片描述在这里插入图片描述

5.3实测视频在ubuntu上检测速度:

未剪枝:平均每帧5毫秒

剪枝后:平均每帧3.7毫秒

六、问题及解决:

对剪枝完的yolov8进行finetune时遇到RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

self.proj 可能不在与 pred_dist 相同的设备上。这可能是因为 self.proj 被指定在 CPU 上,而 pred_dist 在 GPU 上(或反之)。
要解决这个问题,需要确保两个张量位于相同的设备上。可以使用 to() 方法将 self.proj 放到与 pred_dist 相同的设备上。

解决:在loss.py添加如下代码:

def bbox_decode(self, anchor_points, pred_dist):
    """Decode predicted object bounding box coordinates from anchor points and distribution."""
    if self.use_dfl:
        b, a, c = pred_dist.shape  # batch, anchors, channels
        ####添加
        device = pred_dist.device
        self.proj = self.proj.to(device)
        #####

        pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
    return dist2bbox(pred_dist, anchor_points, xywh=False)

七、参考:

7.1 【yolov8系列】 yolov8 目标检测的模型剪枝_yolov8 剪枝-CSDN博客
7.2 YOLOv8剪枝全过程-CSDN博客

7.3 剪枝与重参第七课:YOLOv8剪枝-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1636959.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习高频问答题总结

机器学习问答题总结 第一章 线性回归1.什么是线性回归?解释主要原理2.解释线性回归中最小二乘法的原理吗?3.如何评估线性回归模型的性能?4.线性回归中正则化的目的是什么吗?L1正则化和L2正则化有什么不同? 第二章 逻辑…

深入解析yolov5,为什么算法都是基于yolov5做改进的?(一)

YOLOv5简介 YOLOv5是一种单阶段目标检测算法,它在YOLOv4的基础上引入了多项改进,显著提升了检测的速度和精度。YOLOv5的设计哲学是简洁高效,它有四个版本:YOLOv5s、YOLOv5m、YOLOv5l、YOLOv5x,分别对应不同的模型大小…

神经网络与深度学习--网络优化与正则化

文章目录 前言一、网络优化1.1网络结构多样性1.2高维变量的非凸优化1.鞍点2.平坦最小值3.局部最小解的等价性 1.3.改善方法 二、优化算法2.1小批量梯度下降法(Min-Batch)2.2批量大小选择2.3学习率调整1.学习率衰减(学习率退火)分段…

MouseBoost PRO for Mac激活版:强大的 鼠标增强软件

在追求高效工作的今天,MouseBoost PRO for Mac成为了许多Mac用户的得力助手。这款功能强大的鼠标增强软件,以其独特的智能化功能和丰富的实用工具,让您的电脑操作更加便捷、高效。 MouseBoost PRO for Macv3.4.0中文激活版下载 MouseBoost PR…

nginxconfig.io项目nginx可视化配置--搭建-视频

项目地址 https://github.com/digitalocean/nginxconfig.io搭建视频 nginxconfig.io搭建 nginxconfig.io搭建 展示效果 找到这个项目需要的docker镜像,有项目需要的node的版本 docker pull node:20-alpine运行这个node容器,在主机中挂载一个文件夹到容器中 主机&a…

Python 与 TensorFlow2 生成式 AI(四)

原文:zh.annas-archive.org/md5/d06d282ea0d9c23c57f0ce31225acf76 译者:飞龙 协议:CC BY-NC-SA 4.0 第九章:文本生成方法的崛起 在前几章中,我们讨论了不同的方法和技术来开发和训练生成模型。特别是在第六章“使用 …

LLM应用:让大模型prompt总结生成Mermaid流程图

生成内容、总结文章让大模型Mermaid流程图展示: mermaid 美人鱼, 是一个类似 markdown,用文本语法来描述文档图形(流程图、 时序图、甘特图)的工具,您可以在文档中嵌入一段 mermaid 文本来生成 SVG 形式的图形 Prompt 示例:用横向…

基于OSAL 实现UART、LED、ADC等基础示例 4

1 UART 实验目的 串口在我们开发单片机项目是很重要的,可以观察我们的代码运行情况,本节的目的就 是实现串口双工收发。 虽然说 osal 相关的代码已经跟硬件关系不大了,但是我们还是来贴出相关的硬件原理图贴出来。 1.1 串口初始化 osal_ini…

latex使用bib引用参考文献时,正文编号顺序乱序解决办法,两分钟搞定!

一、背景 用Latex写文章时,使用bib添加参考文献是一种最为简便的方式。但有的期刊模板,如机器人顶会IROS,会出现正文参考文献序号没按顺序排列的情况,如下图所示。按理说文献[4]应该是文献[2],[2]应该是[3]&#xff0…

Go中为什么不建议用锁?

Go语言中是不建议用锁,而是用通道Channel来代替(不要通过共享内存来通信,而通过通信来共享内存),当然锁也是可以用,锁是防止同一时刻多个goroutine操作同一个资源; GO语言中,要传递某个数据给另一个gorout…

Java项目:88 springboot104学生网上请假系统设计与实现

作者主页:源码空间codegym 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 本学生网上请假系统管理员,教师,学生。 管理员功能有个人中心,学生管理,教师管理,…

22 重构系统升级-实现不停服的数据迁移和用户切量

专栏的前 21 讲,从读、写以及扣减的角度介绍了三种特点各异的微服务的构建技巧,最后从微服务的共性问题出发,介绍了这些共性问题的应对技巧。 在实际工作中,你就可以参考本专栏介绍的技巧构建新的微服务,架构一个具备…

vue3 安装-使用之第一篇

首先需要node版本高于V16.14.1 安装 执行 npm create vitelatest 具体选择按照自己实际需要的来 Project name:项目名称 Select a framework:选择用哪种框架 (我选择vue) Select a variant: 选择用JS还是TS(我选择JS)找到项目&…

STM32 HAL库F103系列之IIC实验

IIC总线协议 IIC总线协议介绍 IIC:Inter Integrated Circuit,集成电路总线,是一种同步 串行 半双工通信总线。 总线就是传输数据通道 协议就是传输数据的规则 IIC总线结构图 ① 由时钟线SCL和数据线SDA组成,并且都接上拉电阻…

(7)快速调优

文章目录 前言 1 安装脚本 2 运行 QuikTune 3 高级配置 前言 VTOL QuikTune Lua 脚本简化了为多旋翼飞行器的姿态控制参数寻找最佳调整的过程。 脚本会缓慢增加相关增益,直到检测到振荡。然后,它将增益降低 60%,并进入下一个增益。所有增…

微服务保护和分布式事务(Sentinel、Seata)笔记

一、雪崩问题的解决的服务保护技术了解 二、Sentinel 2.1Sentinel入门 1.Sentinel的安装 (1)下载Sentinel的tar安装包先 (2)将jar包放在任意非中文、不包含特殊字符的目录下,重命名为 sentinel-dashboard.jar &…

Spark Structured Streaming 分流或双写多表 / 多数据源(Multi Sinks / Writes)

博主历时三年精心创作的《大数据平台架构与原型实现:数据中台建设实战》一书现已由知名IT图书品牌电子工业出版社博文视点出版发行,点击《重磅推荐:建大数据平台太难了!给我发个工程原型吧!》了解图书详情,…

数据库管理-第179期 分库分表vs分布式(20240430

数据库管理179期 2024-04-30 数据库管理-第179期 分库分表vs分布式(20240430)1 分库分表1.1 分库1.2 分表1.3 组合1.4 问题 2 分布式3 常见分布式数据库4 期望总结 数据库管理-第179期 分库分表vs分布式(20240430) 作者&#xff1…

git 第一次安装设置用户名密码

git config --global user.name ljq git config --global user.email 15137659164qq.com创建公钥命令 输入后一直回车 ssh-keygen -t rsa下面这样代表成功 这里是公钥的 信息输入gitee 中 输入下面命令看是否和本机绑定成功 ssh -T gitgitee.com如何是这样,恭喜…

spring的高阶使用技巧1——ApplicationListener注册监听器的使用

Spring中的监听器,高阶开发工作者应该都耳熟能详。在 Spring 框架中,这个接口允许开发者注册监听器来监听应用程序中发布的事件。Spring的事件处理机制提供了一种观察者模式的实现,允许应用程序组件之间进行松耦合的通信。 更详细的介绍和使…