YOLOv8 基于BN层的通道剪枝

news2024/12/23 17:46:41

YOLOv8 基于BN层的通道剪枝

1. 稀疏约束训练

在损失项中增加对BN层的缩放系数 γ \gamma γ和偏置项 β \beta β的稀疏约束, λ \lambda λ系数越大,稀疏约束越严重
L = ∑ ( x , y ) l ( f ( x ) , y ) + λ 1 ∑ γ g ( γ ) + λ 2 ∑ β g ( β ) L = \sum_{(x,y)}l(f(x),y)+\lambda_1 \sum_{\gamma}g(\gamma)+\lambda_2 \sum_{\beta}g(\beta) L=(x,y)l(f(x),y)+λ1γg(γ)+λ2βg(β)
对于 L 1 L_1 L1 稀疏约束,有:
g ( γ ) = ∣ γ ∣ , g ( β ) = ∣ β ∣ g(\gamma)=|\gamma|,\quad g(\beta) = |\beta| g(γ)=γ,g(β)=β
如果直接修改YOLOv8的损失,不方便控制L只传导对BN的参数更新,因此,采用修改BN的梯度的方式修改。

相对于原始的梯度项,BN的缩放系数和偏置项会增加以下梯度:
Δ γ = ∂ ( λ 1 ∗ g ( γ ) ) ∂ γ = λ 1 ∗ s i g n ( γ ) Δ β = ∂ ( λ 2 ∗ g ( β ) ) ∂ β = λ 2 ∗ s i g n ( β ) \Delta\gamma = \frac{\partial (\lambda_1*g(\gamma))}{\partial \gamma} = \lambda_1*sign(\gamma) \\ \Delta\beta = \frac{\partial (\lambda_2*g(\beta))}{\partial \beta} = \lambda_2*sign(\beta) Δγ=γ(λ1g(γ))=λ1sign(γ)Δβ=β(λ2g(β))=λ2sign(β)
在训练过程中,逐渐减小 λ 1 \lambda_1 λ1参数,减小对 γ \gamma γ的约束(稳定训练、增强训练和重调的一致性
λ 1 = 0.01 ∗ ( 1 − 0.9 ∗ e n e ) \lambda_1 = 0.01*(1-0.9*\frac{e}{ne}) λ1=0.01(10.9nee)
对于YOLOv8,我们只需要找到梯度更新的地方,然后修改即可。

修改YOLOv8代码:ultralytics/engine/trainer.py-390行

# Backward
self.scaler.scale(self.loss).backward()

# ========== 新增 ==========
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# ========== 新增 ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html

然后执行如下代码开启训练:

yolo = YOLO("yolov8n.pt")
yolo.train(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640, epochs=50)

2. 剪枝

稀疏训练之后呢,我们得到了一个best.pt和last.pt,由于需要微调,基于last.pt相对更好。

YOLOv8的结构如下:

在这里插入图片描述

该结构中每一个Conv层中均包含一个BN层,对BN进行通道剪枝的时候,一方面需要剪掉Conv的输出通道数和对应的权重,另一方面需要剪掉下一层Conv的输入通道数和权重。

由于前三层0,1,2通道数较少因此每个通道对特征提取均较为重要,因此不剪枝

由于第4,6,9层的输出涉及head层中的通道拼接,结构复杂不便于剪枝,因此不剪枝

此外,其它Conv非连续的部分,例如C2f内部Conv层与Bottleneck之间有split操作,FPN中C2f之间穿插了Upsample,Concat等操作。这些部分我们也不剪枝。

这样来看,我们可以剪枝的地方包括:

模块间

Backbone:

Conv(3) => C2f(4)
Conv(5) => C2f(6)
Conv(7) => C2f(8)
C2f(8)  => SPPF(9)

Head:

C2f(15) => [Conv(16),Conv(Detect.cv2[0][0]),Conv(Detect.cv3[0][0])]
C2f(18) => [Conv(19),Conv(Detect.cv2[1][0]),Conv(Detect.cv3[1][0])]
C2f(21) => [Conv(Detect.cv2[2]),Conv(Detect.cv3[2])]

模块内

除了上述模块之间的衔接,模块内的连续Conv主要包括两部分

Bottleneck in C2f

Conv(Bottleneck.cv1) => Conv(Bottleneck.cv2)

cv2, cv3 in Detect

Conv(Detect.cv2[0][0]) => Conv(Detect.cv2[0][1])
Conv(Detect.cv2[0][1]) => Conv2d(Detect.cv2[0][2])
Conv(Detect.cv3[0][0]) => Conv(Detect.cv3[0][1])
Conv(Detect.cv3[0][1]) => Conv2d(Detect.cv3[0][2])

Conv(Detect.cv2[1][0]) => Conv(Detect.cv2[1][1])
Conv(Detect.cv2[1][1]) => Conv2d(Detect.cv2[1][2])
Conv(Detect.cv3[1][0]) => Conv(Detect.cv3[1][1])
Conv(Detect.cv3[1][1]) => Conv2d(Detect.cv3[1][2])

Conv(Detect.cv2[2][0]) => Conv(Detect.cv2[2][1])
Conv(Detect.cv2[2][1]) => Conv2d(Detect.cv2[2][2])
Conv(Detect.cv3[2][0]) => Conv(Detect.cv3[2][1])
Conv(Detect.cv3[2][1]) => Conv2d(Detect.cv3[2][2])

剪枝代码如下:

import torch
from ultralytics import YOLO
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect


def prune_conv(conv1: Conv, conv2: Conv, threshold=0.01):
    # 剪枝top-bottom conv结构
    # 首先,剪枝conv1的bn层和conv层
    # 获取conv1的bn层权重和偏置参数作为剪枝的依据
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    # 索引列表,用于存储剪枝后保留的参数索引
    keep_idxs = []
    local_threshold = threshold
    # 保证剪枝后的通道数不少于8,便于硬件加速
    while len(keep_idxs) < 8:
        # 取绝对值大于阈值的参数对应的索引
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        # 降低阈值
        local_threshold = local_threshold * 0.5
    # print(local_threshold)
    # 剪枝后的通道数
    n = len(keep_idxs)
    # 更新BN层参数
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    # 更新conv层权重
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    # 更新conv层输出通道数
    conv1.conv.out_channels = n
    # 更新conv层偏置,如果存在的话
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    # 然后,剪枝conv2的conv层
    if not isinstance(conv2, list):
        conv2 = [conv2]

    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            # 更新输入通道数
            conv.in_channels = n
            # 更新权重
            conv.weight.data = conv.weight.data[:, keep_idxs]


def prune_module(m1, m2, threshold=0.01):
    # 剪枝 模块间衔接处结构,m1需要获取模块的bottom conv,m2需要获取模块的top conv
    # 打印出m1和m2的名字
    print(m1.__class__.__name__, end="->")
    if isinstance(m2, list):
        print([item.__class__.__name__ for item in m2])
    else:
        print(m2.__class__.__name__)
    if isinstance(m1, C2f):  # C2f as a top conv
        m1 = m1.cv2
    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    prune_conv(m1, m2, threshold)


def prune():
    # Load a model
    yolo = YOLO("last.pt")
    model = yolo.model
    # 统计所有的BN层权重和偏置参数
    ws = []
    bs = []

    for name, m in model.named_modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            w = m.weight.abs().detach()
            b = m.bias.abs().detach()
            ws.append(w)
            bs.append(b)
            # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

    # 保留80%的参数
    factor = 0.8
    ws = torch.cat(ws)
    # 从大到小排序,取80%的参数对应的阈值
    threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
    print(threshold)

    # 先剪枝整个网络bottleneck模块内部的结构
    for name, m in model.named_modules():
        if isinstance(m, Bottleneck):
            prune_conv(m.cv1, m.cv2, threshold)

    # 再剪枝backbone模块间衔接结构
    seq = model.model
    for i in range(3, 9):
        if i in [6, 4, 9]: continue
        prune_module(seq[i], seq[i + 1], threshold)

    # 再剪枝Head模块间衔接结构
    # Head模块间剪枝包括两部分,一部分是相邻下层连接,一部分是跨层到Detect层的输出
    # 从last_inputs到colasts是相邻下层连接,从last_inputs到detect是跨层到最后的输出
    detect: Detect = seq[-1]
    last_inputs = [seq[15], seq[18], seq[21]]
    colasts = [seq[16], seq[19], None]
    for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
        prune_module(last_input, [colast, cv2[0], cv3[0]], threshold)
        # 剪枝Detect层内部模块间衔接结构
        prune_module(cv2[0], cv2[1], threshold, )
        prune_module(cv2[1], cv2[2], threshold)
        prune_module(cv3[0], cv3[1], threshold)
        prune_module(cv3[1], cv3[2], threshold)

    # 设置所有参数为可训练,为retrain做准备
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    # 保存剪枝后的模型
    yolo.save("prune.pt")


if __name__ == '__main__':
    prune()

3. 重调

剪枝完成后需要进行重调,此时我们需要先取消稀疏约束,即将trainer中的约束代码重新注释掉

随后,重调的时候,需要防止代码重新根据yaml文件生成模型,而是直接读取权重模型

修改:在ultralytics/engine/model.py-808行后添加

self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
# 新增 ===================================
self.trainer.model.model = self.model.model
# 新增 ===================================
self.model = self.trainer.model

随后基于如下代码进行重调训练:

yolo = YOLO("prune.pt")
yolo.train(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640, epochs=50)

4. 对比

我们可以对比一下稀疏训练的原模型、剪枝后的模型、重调后的模型的精度、参数、计算量

def compare_prune():
    # 统计压缩前后的参数量,精度,计算量
    yolo = YOLO("last.pt")
    before_results = yolo.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)

    yolo_prune = YOLO("prune.pt")
    prune_results = yolo_prune.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)

    yolo_retrain = YOLO("retrain.pt")
    retrain_results = yolo_retrain.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)

    # 打印压缩前后的参数量,精度,计算量
    n_l, n_p, n_g, flops = yolo.info()
    prune_n_l, prune_n_p, prune_n_g, prune_flops = yolo_prune.info()
    retrain_n_l, retrain_n_p, retrain_n_g, retrain_flops = yolo_retrain.info()
    acc = before_results.box.map
    prune_acc = prune_results.box.map
    retrain_acc = retrain_results.box.map
    print(f"{'':<10}{'Before':<10}{'Prune':<10}{'Retrain':<10}")
    print(f"{'Params':<10}{n_p:<10}{prune_n_p:<10}{retrain_n_p:<10}")
    print(f"{'FLOPs':<10}{flops:<10}{prune_flops:<10}{retrain_flops:<10}")
    print(f"{'Acc':<10}{acc:<10}{prune_acc:<10}{retrain_acc:<10}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1963247.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华杉研发九学习日记18 集合 泛型

华杉研发九学习日记18 一&#xff0c;集合框架 1.1 集合和数组的区别 集合就是在java中用来保存多个对象的容器 集合是数组的升级版&#xff0c;集合中只能放置对象[object]. 数组: 在java中用来保存多个具有相同数据类型数据的容器 数组弊端&#xff1a; 1.数组只能保存…

2024AICoding公司全景图及评分

AI Coding背景 AI coding 领域的产品和公司在 2024 年开始爆发了&#xff0c;主要涉及技术进步、市场需求和开发者生态系统的变化。 本文会从技术背景&#xff0c;市场需求&#xff0c;生态以及相关评分为大家完整梳理一下相关内容。 底层技术 大规模预训练模型 技术背景&#…

C#使用OPC组件方式和AB的PLC通信

目录 一、PLC硬件配置 1、创建PLC程序 &#xff08;1&#xff09;程序工程选择 &#xff08;2&#xff09;变量和程序 2、配置程序在模拟器中运行 &#xff08;1&#xff09;打开RSLkin Classic &#xff08;2&#xff09;仿真器配置 &#xff08;3&#xff09;PLC程序…

我终于搭建完成了我的个人网站!(仅分享,非教程)

先看看我的个人网站~ https://yaoqx.pages.devhttps://yaoqx.pages.dev 来看看我搭建的过程吧&#xff01; &#xff08;仅分享&#xff0c;非教程&#xff09; 网站技术 前端框架&#xff1a;Astro主题&#xff1a;Frosti代码托管&#xff1a;Github网页部署&#xff1a;Cl…

Vscode ssh Could not establish connection to

错误表现 上午还能正常用vs code连接服务器看代码&#xff0c;中午吃个饭关闭vscode再重新打开输入密码后就提示 Could not establish connection to 然后我用终端敲ssh的命令连接&#xff0c;结果是能正常连接。 解决方法 踩坑1 网上直接搜Could not establish connectio…

浮点数如何存储

一、浮点数存储格式 符号&#xff08;sign&#xff09; s是符号位&#xff0c;1表示负&#xff0c;0表示正阶码&#xff08;exponent&#xff09; E的作用是对浮点数加权&#xff0c;这个权重是2的E次幂尾数&#xff08;significand&#xff09; M是一个二进制小数 二、举例说…

被爬网站用fingerprintjs来对selenium进行反爬,怎么破?

闲暇逛乎的时候&#xff0c;看到了这个问题&#xff1a; Fingerprintjs实际上就是专门用来识别和追踪浏览器的&#xff0c;要应对起来&#xff0c;确实并非易事。那么&#xff0c;我们要如何应对FingerprintJS的唯一标记技术呢&#xff1f; 接下来&#xff0c;我们将一起来探讨…

【自学深度学习梳理2】深度学习基础

一、优化方法 上一篇说到,使用梯度下降进行优化模型参数,可能会卡在局部最小值,或优化方法不合适永远找不到具有最优参数的函数。 1、局部最小值 梯度下降如何工作? 梯度下降是一种优化算法,用于最小化损失函数,即寻找一组模型参数,使得损失函数的值最小(局部最小值…

【Python体验】第五天:目录搜索、数据爬虫(评论区里写作业)

文章目录 目录搜索 os、shutil库数据爬虫 request、re作业&#xff1a;爬取案例的top250电影的关键信息&#xff08;名称、类型、日期&#xff09;&#xff0c;并保存在表格中 目录搜索 os、shutil库 os 模块提供了非常丰富的方法用来处理文件和目录。 os.listdir(path)&#x…

STM32的外部中断实现按键控制led灯亮灭(HAL库)

一&#xff1a;stm32外部中断概述 1&#xff1a;stm32的外部中断线 STM32的每个IO都可以作为外部中断输入。 STM32的中断控制器支持19个外部中断/事件请求&#xff1a; 线0~15&#xff1a;对应外部IO口的输入中断。 线16&#xff1a;连接到PVD输出。 线17&#xff1a;连接到R…

后端采用SpringBoot框架开发的:ADR药物不良反应智能监测系统源码,用于监测和收集药品在使用过程中发生的不良反应的系统

ADR药物不良反应智能监测系统是一套用于监测和收集药品在使用过程中发生的不良反应&#xff08;Adverse Drug Reaction, ADR&#xff09;的系统。该系统基于医院临床数据中心&#xff0c;运用信息技术实现药品不良反应的智能监测、报告管理、知识库查询、统计分析等功能&#x…

【Python学习手册(第四版)】学习笔记11.2-表达式语句(print函数)及打印操作(重定向等)详解

个人总结难免疏漏&#xff0c;请多包涵。更多内容请查看原文。本文以及学习笔记系列仅用于个人学习、研究交流。 主要介绍表达式语句&#xff08;print函数&#xff09;及打印操作&#xff08;重定向等&#xff09;。视需要选择目录阅读。 目录 表达式语句 错误示例&#xf…

实验3-7 统计学生成绩

//实验3-7 统计学生成绩 /* 本题要求编写程序读入N个学生的百分制成绩&#xff0c;统计五分制成绩的分布。 百分制成绩到五分制成绩的转换规则&#xff1a;大于等于90分为A&#xff1b;小于90且大于等于80为B&#xff1b;小于80且大于等于70为C&#xff1b;小于70且大于等于60为…

相机标定(Camera Calibration)

什么是 相机标定&#xff08;Camera Calibration&#xff09;&#xff1f; 相机标定&#xff08;CameraCalibration&#xff09;是确定相机内部参数&#xff08;如焦距、光学中心、畸变系数等&#xff09;和外部参数&#xff08;如相机在世界坐标系中的位置和姿态&#xff09;的…

黑马头条vue2.0项目实战(三)——个人中心功能的实现

1. Tabbar 处理 通过分析页面&#xff0c;可以看到&#xff0c;首页、问答、视频、我的 都使用的是同一个底部标签栏&#xff0c;我们没必要在每个页面中都写一个&#xff0c;所以为了通用方便&#xff0c;我们可以使用 Vue Router 的嵌套路由来处理。 父路由&#xff1a;一个…

激发潜能,Vatee万腾平台驱动企业持续发展

在当今这个日新月异的商业环境中&#xff0c;企业要想保持竞争力并实现持续发展&#xff0c;就必须不断挖掘自身潜能&#xff0c;探索新的增长点。而Vatee万腾平台&#xff0c;正是这样一位能够激发企业潜能、驱动其持续发展的强大伙伴。 一、智能化赋能&#xff0c;解锁企业潜…

了解ISO 22301:业务连续性管理的关键

在当今全球化和复杂化的商业环境中&#xff0c;企业面临着各种潜在的风险和灾难&#xff0c;这些可能对其运营和声誉造成严重影响。为了有效地应对这些挑战并保障持续经营&#xff0c;国际标准化组织&#xff08;ISO&#xff09;引入了ISO 22301标准&#xff0c;这是一项专注于…

智能制造与工业物联网CC2530——定时器查询和中断

一、项目目的&#xff1a; 熟悉 ZigBee 模块相关硬件接口。使用 IAR 开发环境设计程序&#xff0c;学习 CC2530 定时器的使用&#xff0c;利用 CC2530 的定时器 T1 查询方式控制 LED 周期性闪烁。 二、项目原理&#xff1a; LED及按键原理图&#xff0c;如下图所示&#xff…

使用“阿里云人工智能平台 PAI”制作数字人

体验 阿里云人工智能平台 PAI PAI-DSW免费试用 https://free.aliyun.com/?spm5176.14066474.J_5834642020.5.7b34754cmRbYhg&productCodelearn https://help.aliyun.com/document_detail/2261126.html 体验PAI-DSW https://help.aliyun.com/document_detail/2261126.…

一文详解香港机房服务器干什么用的

香港机房服务器干什么用的&#xff1f;香港机房服务器是用于数据存储和备份、网络服务、数据处理与分析、云计算服务、游戏托管服务、其他服务等。香港机房服务器在现代互联网业务中扮演着至关重要的角色&#xff0c;其主要用途可以归纳为以下几个方面&#xff1a; 1、数据存储…