剪枝与重参第七课:YOLOv8剪枝

news2025/2/28 17:46:11

目录

  • YOLOv8剪枝
    • 前言
    • 1.Overview
    • 2.Pretrain(option)
    • 3.Constrained Training
    • 4.Prune
      • 4.1 检查BN层的bias
      • 4.2 设置阈值和剪枝率
      • 4.3 最小剪枝Conv单元的TopConv
      • 4.4 最小剪枝Conv单元的BottomConv
      • 4.5 Seq剪枝
      • 4.6 Detect-FPN剪枝
      • 4.7 完整示例代码
    • 5.YOLOv8剪枝总结
    • 总结

YOLOv8剪枝

前言

手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。

本次课程主要讲解YOLOv8剪枝。

课程大纲可看下面的思维导图

在这里插入图片描述

1.Overview

YOLOV8剪枝的流程如下:

在这里插入图片描述

结论:在VOC2007上使用yolov8s模型进行的实验显示,预训练和约束训练在迭代50个epoch后达到了相同的mAP(:0.5)值,约为0.77。剪枝后,微调阶段需要65个epoch才能达到相同的mAP50。修建后的ONNX模型大小从43M减少到36M。

注意:我们需要将网络结构和网络权重区分开来,YOLOv8的网络结构来自yaml文件,如果我们进行剪枝后保存的权重文件的结构其实是和原始的yaml文件不符合的,需要对yaml文件进行修改满足我们的要求。

2.Pretrain(option)

步骤如下:

  • git clone https://github.com/ultralytics/ultralytics.git
  • use VOC2007, and modify the VOC.yaml(去除VOC2012的相关内容)
  • disable amp(禁用amp混合精度)
# FILE: ultralytics/yolo/engine/trainer.py
...
def check_amp(model):
    # Avoid using mixed precision to affect finetune
    return False # <============== modified(修改部分)
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices

    def amp_allclose(m, im):
        # All close FP32 vs AMP results
    ...

3.Constrained Training

约束训练是为了筛选哪些channel比较重要,哪些channel没有那么重要,也就是我们上节课所说的稀疏训练

  • prune the BN layer by adding L1 regularizer.
# FILE: ultralytics/yolo/engine/trainer.py
...
# Backward
self.scaler.scale(self.loss).backward()

# <============ added(新增)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni
...

注意1:在剪枝时,我们选择加载last.pt而非best.pt,因为由于迁移学习,模型的泛化性比较好,在第一个epoch时mAP值最大,但这并不是真实的,我们需要稳定下来的一个模型进行prune

注意2:我们在对Conv层进行剪枝时,我们只考虑1v1(如BottleNeck,C2f and SPPF)、1vm(如Backbone,Detect)的情形,并不考虑mv1的情形。

思考:Constrained Training的必要性?

约束训练可以使得模型更易于剪枝。在约束训练中,模型会学习到一些通道或者权重系数比较不重要的信息,而这些信息在剪枝过程中得到应用,从而达到模型压缩的效果。而如果直接进行剪枝操作,可能会出现一些问题,比如剪枝后的模型精度大幅下降、剪枝不均匀等。因此,在进行剪枝操作之前,通过稀疏训练的方式,可以更好地准确地确定哪些通道或者权重系数可以被剪掉,从而避免上述问题的发生。

4.Prune

4.1 检查BN层的bias

  • 剪枝后,确保BN层的大部分bias足够小(接近于0),否则重新进行稀疏训练
for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

4.2 设置阈值和剪枝率

  • threshold:全局或局部
  • factor:保持率,裁剪太多不推荐
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)

4.3 最小剪枝Conv单元的TopConv

Top-Bottom Conv如下图所示:

在这里插入图片描述

TopConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta  = conv1.bn.bias.data.detach()

    keep_idxs = []    
    local_threshold = threshold

    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0] 
        local_threshold = local_threshold * 0.5

    n = len(keep_idxs)
    print(n / len(gamma) * 100)  # 打印我们保留了多少的channel
    
    # prune
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data   = beta[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.running_var.data  = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.num_features   = n
    conv1.conv.weight.data  = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n

    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

# pattern to prune
# 1. prune all 1 vs 1 TB pattern e.g. bottleneck
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)

注意:由于NVIDIA的硬件加速的原因,我们保留的channels应该大于等于8,我们可以通过设置local_threshold,尽量小点,让更多的channel保留下来。

4.4 最小剪枝Conv单元的BottomConv

BottomConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):
    ...
    if not isinstance(conv2, list):
        conv2 = [conv2]
    
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

注意BottomConv存在两种情形,一种是单个Conv,还有一种是Conv列表。TopConv需要考虑conv2d+bn+act,而BottomConv只需要考虑conv2d

4.5 Seq剪枝

Seq剪枝的示例代码如下:

def prune(m1, m2):
    if isinstance(m1, C2f):
        m1 = m1.cv2
    
    if not isinstance(m2, list):
        m2 = [m2]
    
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    
    prune_conv(m1, m2)

# 2. prune sequential
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i+1])

注意:我们不考虑1vm的情形,因此在4,6,9的module我们是不进行剪枝的,后续head进行Concat时是对4,6,9的module进行拼接的。考虑到前几个conv的特征提取的重要性,我们也不剪枝它们。(那感觉没剪几个module呀😂)

4.6 Detect-FPN剪枝

Detect-FPN剪枝的示例代码如下:

# 3. prune FPN related block
detect: Detect = seq[-1]

last_inputs = [seq[15], seq[18], seq[21]]
colasts     = [seq[16], seq[19], None]

for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

for name, p in yolo.model.named_parameters():
    p.requires_grad = True

注意:一定要设置所有参数为需要训练。因为加载后的model会给弄成False,导致报错

4.7 完整示例代码

完整的示例代码如下:

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect

# Load a model
yolo = YOLO("epoch-50-constrained_weights/last.pt")  # build a new model from scratch
model = yolo.model

ws = []
bs = []

for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
# keep
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)


def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta  = conv1.bn.bias.data.detach()
    # if gamma.abs().min() > f or beta.abs().min() > 0.1:
    #     return
    
    # idxs = torch.argsort(gamma.abs() * coeff + beta.abs(), descending=True)
    keep_idxs = []
    local_threshold = threshold
    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data   = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n
    
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    if not isinstance(conv2, list):
        conv2 = [conv2]
        
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
    
def prune(m1, m2):
    if isinstance(m1, C2f):      # C2f as a top conv
        m1 = m1.cv2
    
    if not isinstance(m2, list): # m2 is just one module
        m2 = [m2]
        
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    
    prune_conv(m1, m2)

for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)
        
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i+1])
    
detect:Detect = seq[-1]
last_inputs   = [seq[15], seq[18], seq[21]]
colasts       = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

# ***step4,一定要设置所有参数为需要训练。因为加载后的model他会给弄成false。导致报错
# pipeline:
# 1. 为模型的BN增加L1约束,lambda用1e-2左右
# 2. 剪枝模型,比如用全局阈值
# 3. finetune,一定要注意,此时需要去掉L1约束。最终final的版本一定是去掉的
for name, p in yolo.model.named_parameters():
    p.requires_grad = True
    
# 1. 不能剪枝的layer,其实可以不用约束
# 2. 对于低于全局阈值的,可以删掉整个module
# 3. keep channels,对于保留的channels,他应该能整除n才是最合适的,否则硬件加速比较差
#    n怎么选,一般fp16时,n为8
#                int8时,n为16
#    cp.async.cg.shared
#

yolo.val()
# yolo.export(format="onnx")
# yolo.train(data="VOC.yaml", epochs=100)
print("done")

5.YOLOv8剪枝总结

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

总结

本次课程学习了YOLOv8的剪枝,主要是对前面剪枝课程的一个总结和实现吧,大体流程就是稀疏训练后进行剪枝最后微调,看着虽然简单,但实际细节把控还是非常多的,比如说哪些部分好剪,哪些部分不好剪,剪枝的过程中如何通过model获取想要prune的module等等,需要对YOLOv8整体网络结构和对ONNX模型的操作非常熟练,这还只是基础理论,实操部分的坑还没踩呢,在之后好好练习练习吧😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/426602.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【三十天精通Vue 3】第八天 Vue 3 生命周期钩子详解

✅创作者&#xff1a;陈书予 &#x1f389;个人主页&#xff1a;陈书予的个人主页 &#x1f341;陈书予的个人社区&#xff0c;欢迎你的加入: 陈书予的社区 &#x1f31f;专栏地址: 三十天精通 Vue 3 文章目录引言一、Vue 3 生命周期钩子概述1.1 生命周期钩子的简介1.2 生命周…

mulesoft MCIA 破釜沉舟备考 2023.04.14.11

mulesoft MCIA 破釜沉舟备考 2023.04.14.111. To implement predictive maintenance on its machinery equipment, ACME Tractors has installed thousands of IoT sensors that will send data for each machinery asset as sequences of JMS messages, in near real-time, to…

惠普Probook455电脑开机突然卡住无法进入桌面

惠普Probook455电脑开机突然卡住无法进入桌面解决方法分享。最近有用户使用的惠普Probook455电脑在开机的时候&#xff0c;电脑一直卡在开机的界面上&#xff0c;无法进入到系统中。无论是重启还是安全模式都无法解决问题。那么遇到这个情况怎么去进行问题的解决&#xff0c;来…

C++---状态压缩dp---炮兵阵地(每日一道算法2023.4.16)

注意事项&#xff1a; 本题为"状态压缩dp—蒙德里安的梦想"和"状态压缩dp—小国王"和"状态压缩dp—玉米田"的近似题&#xff0c;建议先阅读这三篇文章并理解。 题目&#xff1a; 司令部的将军们打算在 NM 的网格地图上部署他们的炮兵部队。 一个…

GoogleTest+VS code编译和编写简单测试用例

目录前言一、安装gtest二、 编译gtest与运行单元测试第一种编译方式第二种编译方式前言 在B站看了非常多Gtest的教学视频&#xff0c;CSDN上gtest博客也特别多&#xff0c;但是都非常陈旧或者根本不是用vscode。本篇目的在于&#xff0c;说明如何在vscode上编写简单单元测试。…

day01_Java概述丶环境搭建

Java背景知识 Java概述 概述&#xff1a;计算机语言就是人与计算机之间进行信息交流沟通的一种特殊语言。所谓计算机编程语言&#xff0c;就是人们可以使用编程语言对计算机下达命令&#xff0c;让计算机完成人们需要的功能。 Java语言&#xff1a;是美国Sun公司&#xff08…

Siamese network

文章目录一、相似性度量1. 欧氏距离2. 马氏距离二、Siamese network1. Siamese network 基础架构2. 损失函数3. 不同的Siamese network3.1. 行人重识别3.2 其他应用场景一、相似性度量 相似性度量是机器学习中一个非常基础的概念&#xff0c;是评定两个事物之间相似程度的一种度…

E: 仓库 “http://mirrors.aliyun.com/ubuntu eoan Release” 没有 Release 文件 —— 解决方案

Ubuntu 20.04 更新的时候&#xff0c;遇到如下问题&#xff1a; 可以通过修改源&#xff0c;来进行修复&#xff1a; 1、登录如下网址&#xff1a;LUGs repo file generator 2、选择对应的 Ubuntu 版本&#xff0c;这里我是 Ubuntu 20.04 点击 Download&#xff0c;会下载一个 …

DUBBO注册中心

注册中心上保存四种类型的数据: providers: 服务提供者目录,记录着服务提供者的ip、端口等信息。 consumers: 服务消费者目录,记录服务消费者的元数据信息,服务提供者并不会用到服务消费者的信息,这里要记录消费者的信息,是给服务治理中心(dubbo-admin)使用的。 route…

五分钟排查Linux的健康状态

五分钟排查Linux的健康状态1. CPU1.1 top命令1.2 什么是负载1.3 vmstat2. 内存2.1 观测命令2.2 CPU缓存2.3 HugePage2.4 预先加载3. I/O3.1 观测命令3.2 零拷贝4. 网络参考&#xff1a;《Linux运维实战》、xjjdog 操作系统作为所有程序的载体&#xff0c;对应用的性能影响是非常…

论文阅读 | Interpolated Convolutional Networks for 3D Point Cloud Understanding

前言&#xff1a;ICCV2019点云特征提取点卷积InterpoConv Interpolated Convolutional Networks for 3D Point Cloud Understanding 引言 点云是不规则、无序、且稀疏的 处理这样的点云数据有两大类方法 第一&#xff1a;voxel化 directly rasterize irregular point clouds…

《Netty》从零开始学netty源码(三十八)之PoolSubPage

PoolSubPage 上一节中我们提到了PooledByteBufAllocator类&#xff0c;先看下netty中有关内存的类关系&#xff1a; 从图中可以看到PoolSubPage为最小单位&#xff0c;所以我们先从最小的开始分析&#xff0c;先看下它的属性值&#xff1a; 为了更好的理解这些属性的意义&…

表情包MD编辑器简单使用

&#x1f337;1 表情包 ⭐️&#xff08;1&#xff09;常规表情图标 &#x1f600; &#x1f601; &#x1f602; &#x1f603; &#x1f604; &#x1f605; &#x1f606; &#x1f609; &#x1f60a; &#x1f60b; &#x1f60e; &#x1f60d; &#x1f618; &#x1…

(Deep Learning)交叉验证(Cross Validation)

交叉验证&#xff08;Cross Validation&#xff09; 交叉验证&#xff08;Cross Validation&#xff09;是一种评估模型泛化性能的统计学方法&#xff0c;它比单次划分训练集和测试集的方法更加稳定、全面。 交叉验证不但可以解决数据集中数据量不够大的问题&#xff0c;也可以…

CSS中flex属性的的使用以及应用场景有哪些

文章目录一. flex属性?(虚假的) --- 这里主要是回顾1.1 flex-grow1.2 flex-shrink1.3 flex-basis二. flex属性 ! (真正的!!!)三. flex一些常见的值, 以及使用场景3.1 flex:initial 使用场景3.2 flex:0 和 flex:node 适用场景3.3 flex:1 和 flex:auto3.4 总结一. flex属性?(虚…

54 openEuler搭建Mariadb数据库服务器-Mariadb介绍

文章目录54 openEuler搭建Mariadb数据库服务器-Mariadb介绍54.1 MariaDB的架构54.2 MariaDB的存储引擎54 openEuler搭建Mariadb数据库服务器-Mariadb介绍 MariaDB数据库管理系统是MySQL的一个分支&#xff0c;主要由开源社区在维护&#xff0c;采用GPL授权许可。MariaDB的目的…

零售数据分析操作篇14:利用内存计算做销售筛选分析

各位数据的朋友&#xff0c;大家好&#xff0c;我是老周道数据&#xff0c;和你一起&#xff0c;用常人思维数据分析&#xff0c;通过数据讲故事。 上一讲讲了图表间联动的应用场景&#xff0c;即当我们点击某个图表时&#xff0c;会影响其他图表一起变化&#xff0c;而变化背…

Vue3 Element-plus el-menu无限级菜单组件封装

对于element中提供给我们的el-menu组件最多可以实现三层嵌套&#xff0c;如果多一层数据只能自己通过变量去加一层&#xff0c;如果加了两层、三层这种往往是行不通的&#xff0c;所以只能进行封装 效果图一、定义数据 MenuData.ts export default [{id: "1",name:…

spring boot 访问HTML

HTML整合spring boot简介默认文件路径访问自定义文件路径访问或通过Controller控制器层跳转访问简介 SpringBoot默认的页面映射路径&#xff08;即模板文件存放的位置&#xff09;为“classpath:/templates/*.html”。静态文件路径为“classpath:/static/”&#xff0c;其中可…

三菱FX3U PLC计米轮功能块(完整ST代码)

计米轮功能块(wheel_FB)详细计米、测速原理请参看下面的博客: PLC高速脉冲输入计米轮模块(编码器测速/计米详细讲解)_RXXW_Dor的博客-CSDN博客线缆行业单绞机PLC控制算法详细解读可以参看下面的文章链接:线缆行业单绞机控制算法(详细图解+代码)_RXXW_Dor的博客-CSDN博客在…