Yolov8模型用torch_pruning剪枝

news2025/2/28 0:18:10

目录

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

原理

 遍历所有分组

高级剪枝器


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxv

原理

传统剪枝方法的缺陷

在复杂的网络结构中, 参数之间可能存在依赖关系, 这种依赖要求算法对这类参数进行同步移除以保证结构正确性,这就涉及到耦合参数的分组问题. 我们的工作通过提供一种自动化机制来对参数进行分组. 具体而言, Torch-Pruning使用伪输入来运 行模型, 跟踪网络计算图, 并记录层之间的依赖关系. 当剪枝某一层时, Torch-Pruning会识别所有耦合层, 并返回包含这些耦合信息的tp.Group.

一种通用的结构化剪枝框架DepGraph(Dependency Graph),可以应用于任意类型的神经网络架构(包括CNN、RNN、GNN和Transformer等)进行结构化剪枝。主要原理如下:

1. 神经网络内部存在着层与层之间的依赖关系,需要同时剪枝依赖的层组,否则会破坏网络结构。

2. 结构化剪枝的优势

结构化剪枝的做法是,找到网络中相互依赖的层组,把整个层组同时全部保留或全部删除,从而保证网络结构的完整性。这种做法虽然灵活性较低,但可以有效避免了网络结构被破坏的问题。

3. DepGraph通过建模层与层之间的依赖关系,明确每一层所属的层组。具体分为两种依赖关系:

   a) 层间依赖(Inter-layer Dependency): 相邻连接的层之间存在依赖   层间不依赖:resnet

   b) 层内依赖(Intra-layer Dependency): 同一层的输入和输出具有相同的剪枝方式时存在依赖   层内不依赖:没有共享权重的

4. 通过图遍历算法在DepGraph上找到最大连接分量作为层组,实现自动化的层组划分。总的来说,DepGraph解决了之前结构化剪枝算法依赖人工设计层组划分规则、缺乏通用性的问题,提出了一种自动建模层组依赖关系和组级剪枝重要性评估的通用框架。

5. DepGraph的工作原理

以ResNet的基本模块为例,如果要删除某个卷积层的滤波器核,由于残差连接的存在,我们必须同时删除该模块中所有层(BN层、ReLU层等)对应的通道。DepGraph通过建模层与层之间的依赖关系,自动将这些相互依赖的层划分到同一个层组中。在剪枝时,整个层组被统一评分,决定是完全保留还是完全删除,从而实现安全的结构化剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. 指定剪枝的通道维度
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )

print(pruning_group.details())  # or print(pruning_group)

# 3. 检查剩余通道数是否<=0, 并执行剪枝
if DG.check_pruning_group(pruning_group):
    pruning_group.prune()

这个例子演示了使用 DepGraph剪枝的基本流程, resnet.conv1实际上会与多个层耦合在一起.通过打印返回的组, 可以看到组内各个层之间的剪枝是如何互相“触发”的.在以下输出中, “A => B”表示剪枝操作“A”触发剪枝操作“B”.group[0]是用户在DG.get_pruning_group中给出的剪枝操作. 

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #idxs=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), #idxs=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), #idxs=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), #idxs=3
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), #idxs=3
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), #idxs=3
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), #idxs=3
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #idxs=3
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #idxs=3
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
--------------------------------
 遍历所有分组

可以利用DG.get_all_groups(ignored_layers, root_module_types)来按顺序扫描所有的分组. 每个分组都会以一个"root_module_types"中所指定的层作为起点. 默认情况下, 这些组包含了完整的剪枝索引idxs=[0,1,2,3,...,K], 这个索引列表包含了所有的可修剪参数的索引. 如果我们希望对一个group进行剪枝, 我们需要使用group.prune(idxs=idxs)来指定具体的修剪通道/维度.

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)
高级剪枝器
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# 重要性指标
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性   重要性如何计算??什么是重要的?值大还是小?是损失吗

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # 迭代式剪枝, 该示例会分五步完成50%通道剪枝 (10%->20%->...->50%)
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # 整体移除50%通道, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1504506.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TYPE C模拟耳机POP音产生缘由

关于耳机插拔的POP音问题&#xff0c;小白在之前的文章中讲述过关于3.5mm耳机的POP音产生原因。其实这类插拔问题的POP音不仅仅存在于3.5mm耳机&#xff0c;就连现在主流的Type C模拟耳机的插拔也存在此问题&#xff0c;今天小白就来讲一讲这类耳机产生POP音的缘由。 耳机左右…

计算机视觉——P2PNet基于点估计的人群计数原理与C++模型推理

简介 人群计数是计算机视觉领域的一个核心任务&#xff0c;旨在估算静止图像或视频帧中的行人数量。在过去几十年中&#xff0c;研究人员在这个领域投入了大量的精力&#xff0c;并在提高现有主流基准数据集性能方面取得了显著进展。然而&#xff0c;训练卷积神经网络需要大规…

书与我

和书深深结缘&#xff0c;始于需求&#xff0c;得益于通勤时间长。 读什么书 一直没有停止过编码&#xff0c;工作性质也要求我必须了解很多的新技术&#xff0c;从踏上工作岗位后&#xff0c;就需要不停的看书。从《JAVA编程思想》、《java与模式》、《TCP/IP详解》、《深入…

131.分割回文串

// 定义一个名为Solution的类 class Solution {// 声明一个成员变量&#xff0c;用于存储所有满足条件的字符串子序列划分结果List<List<String>> lists new ArrayList<>(); // 声明一个成员变量&#xff0c;使用LinkedList实现的双端队列&#xff0c;用于临…

Windows下安装pip

一、下载pip 官网地址&#xff1a;https://pypi.org/project/pip/#files 1.1、pip工具查找方法 单击官网首页“PyPi”选项 在弹出来的搜索框中输入“pip” 选择最新的pip版本&#xff0c;点进去 下载pip安装包包 二、安装pip 解压“pip-24.0.tar.gz”&#xff0c;进…

【深度学习笔记】6_5 RNN的pytorch实现

注&#xff1a;本文为《动手学深度学习》开源内容&#xff0c;部分标注了个人理解&#xff0c;仅为个人学习记录&#xff0c;无抄袭搬运意图 6.5 循环神经网络的简洁实现 本节将使用PyTorch来更简洁地实现基于循环神经网络的语言模型。首先&#xff0c;我们读取周杰伦专辑歌词…

b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器

文章目录 一、损失函数1.简要介绍2.代码 二、优化器1.简要介绍2.代码 一、损失函数 1.简要介绍 可参考博客&#xff1a; 常见的损失函数总结 损失函数的全面介绍 pytorch学习之十九种损失函数 损失函数&#xff08;Loss Function&#xff09;是用来衡量模型预测输出与实际…

开发指南002-前后端信息交互规范-概述

前后端之间采用restful接口&#xff0c;服务和服务之间使用feign。信息交互遵循如下平台规范&#xff1a; 前端&#xff1a; 建立api目录&#xff0c;按照业务区分建立不同的.js文件&#xff0c;封装对后台的调用操作。其中qlm*.js为平台预制的接口文件&#xff0c;以qlm_user.…

离线数仓(五)【数据仓库建模】

前言 今天开始正式数据仓库的内容了, 前面我们把生产数据 , 数据上传到 HDFS , Kafka 的通道都已经搭建完毕了, 数据也就正式进入数据仓库了, 解下来的数仓建模是重中之重 , 是将来吃饭的家伙 ! 以及 Hive SQL 必须熟练到像喝水一样 ! 第1章 数据仓库概述 1.1 数据仓库概念 数…

【stm32 外部中断】

中断&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转而去处理中断程序&#xff0c;处理完成后又返回原来被暂停的位置继续运行 中断优先级&#xff1a;当有多个中…

mybatis-plus整合spring boot极速入门

使用mybatis-plus整合spring boot&#xff0c;接下来我来操作一番。 一&#xff0c;创建spring boot工程 勾选下面的选项 紧接着&#xff0c;还有springboot和依赖我们需要选。 这样我们就创建好了我们的spring boot&#xff0c;项目。 简化目录结构&#xff1a; 我们发现&a…

未来城市:探索数字孪生在智慧城市中的实际应用与价值

目录 一、引言 二、数字孪生与智慧城市的融合 三、数字孪生在智慧城市中的实际应用 1、智慧交通管理 2、智慧能源管理 3、智慧建筑管理 4、智慧城市管理 四、数字孪生在智慧城市中的价值 五、挑战与展望 六、结论 一、引言 随着科技的飞速发展&#xff0c;智慧城市已…

R统计学2 - 数据分析入门问题21-40

往期R统计学文章&#xff1a; R统计学1 - 基础操作入门问题1-20 21. 如何对矩阵按行 (列) 作计算&#xff1f; 使用函数 apply() vec 1:20 # 转换为矩阵 mat matrix (vec , ncol4) # [,1] [,2] [,3] [,4] # [1,] 1 6 11 16 # [2,] 2 7 12 17 # [3,] …

前端框架的发展历史介绍

前端框架的发展历史是Web技术进步的一个重要方面。从最初的简单HTML页面到现在的复杂单页应用程序&#xff08;SPA&#xff09;&#xff0c;前端框架和库的发展极大地推动了Web应用程序的构建方式。以下是一些关键的前端框架和库&#xff0c;以及它们的发布年份、创建者和主要特…

UnicodeDecodeError: ‘gbk‘和Error: Command ‘pip install ‘pycocotools>=2.0

今天重新弄YOLOv5的时候发现不能用了&#xff0c;刚开始给我报这个错误 subprocess.CalledProcessError: Command ‘pip install ‘pycocotools&#xff1e;2.0‘‘ returned non-zero exit statu 说这个包安装不了 根据他的指令pip install ‘pycocotools&#xff1e;2.0这个根…

从零开始:神经网络(2)——MP模型

声明&#xff1a;本文章是根据网上资料&#xff0c;加上自己整理和理解而成&#xff0c;仅为记录自己学习的点点滴滴。可能有错误&#xff0c;欢迎大家指正。 神经元相关知识&#xff0c;详见从零开始&#xff1a;神经网络——神经元和梯度下降-CSDN博客 1、什么是M-P 模型 人…

CorelDRAW Graphics Suite2024专业图形设计软件Windows/Mac最新25.0.0.230版

CorelDRAW Graphics Suite 2024是一款专业的图形设计软件&#xff0c;它集成了CorelDRAW Standard 2024和其他高级图形处理工具&#xff0c;为用户提供了全面的图形设计和编辑解决方案。 该软件拥有强大的矢量编辑功能&#xff0c;用户可以轻松创建和编辑矢量图形&#xff0c;…

数字化转型导师坚鹏:科技金融政策、案例及数字化营销

科技金融政策、案例及数字化营销 课程背景&#xff1a; 很多银行存在以下问题&#xff1a; 不清楚科技金融有哪些利好政策&#xff1f; 不知道科技金融有哪些成功案例&#xff1f; 不知道科技金融如何数字化营销&#xff1f; 课程特色&#xff1a; 以案例的方式解读原…

聚类简单讲解

聚类任务 聚类任务是指将一组数据分成多个不同的组&#xff08;或簇&#xff09;&#xff0c;使得同一组内的数据点彼此相似&#xff0c;而不同组之间的数据点尽可能不相似的过程。聚类任务的目标是发现数据中的固有结构&#xff0c;而不需要事先知道数据的类别信息。聚类算法…

IntelliJ IDEA Dev 容器

​一、dev 容器 开发容器&#xff08;dev 容器&#xff09;是一个 Docker 容器&#xff0c;配置为用作功能齐全的开发环境。 IntelliJ IDEA 允许您使用此类容器来编辑、构建和运行您的项目。 IntelliJ IDEA 还支持多个容器连接&#xff0c;这些连接可以使用 Docker Compose …