yolov5及yolov7实战之剪枝

news2024/12/24 3:16:04

之前有讲过一次yolov5的剪枝:yolov5实战之模型剪枝_yolov5模型剪枝-CSDN博客
当时基于的是比较老的yolov5版本,剪枝对整个训练代码的改动也比较多。最近发现一个比较好用的剪枝库,可以在不怎么改动原有训练代码的情况下,实现剪枝的操作,这篇文章就简单介绍一下,剪枝的概念以及为什么要剪枝可以参看上一篇,这里就不赘述了。

Torch-Pruning

VainF/Torch-Pruning: [CVPR 2023] Towards Any Structural Pruning; LLMs / Diffusion / Transformers / YOLOv8 / CNNs (github.com)
今天我们要用到的就是这个剪枝库,这个库集成了很多剪枝的方法,毕竟使用比较简单。

用法

这个剪枝库既有low level的剪枝,也就是手动控制剪枝哪些层,也有high level的剪枝,就是使用预设的剪枝算法,自动选择剪枝的部分。对于我们来说,更适合使用high level剪枝。具体的这里使用和上一篇yolov5里面的剪枝一样的算法,在这个库里叫BNScalePruner。

安装

首先我们需要安装上面提到的库,有两种方式来安装:

pip install torch-pruning

或源码安装(当碰到bug发布版本没修复,源码修复的时候):

pip install git+https://github.com/VainF/Torch-Pruning.git

稀疏化训练

为了更好的剪枝,我们在训练剪枝前的网络时,推荐开启稀疏化训练,利用这个库,我们可以很方便的实现这个操作。
首先在我们的训练代码中定义好剪枝器, 这里的opt.prune是我自己加的来控制是否开启稀疏化训练的标志:

# prune
if opt.prune:
	examle_input = torch.randn(1, 3, imgsz, imgsz).to(device)
	imp = tp.importance.BNScaleImportance()
	pruner = tp.pruner.BNScalePruner(model, examle_input, imp,
									 reg=0.0001)

稀疏化训练主要需要设置reg参数,一般设置0.001~1e-6之间。
定义好剪枝器后,在训练代码的scaler.scale(loss).backward()之后,添加如下代码:

if opt.prune:
	pruner.regularize(model)

即可实现稀疏化训练。

剪枝

稀疏化训练后(也可以不做稀疏化训练),我们就可以进行剪枝操作了。这个库可以在训练中交互式进行多次剪枝,简单起见,我们这里分离剪枝和训练的代码,只进行剪枝操作。

import torch_pruning as tp
from models.experimental import attempt_load
import torch

weights = "yolov7.pt"
model = attempt_load(weights, map_location=torch.device('cuda:0'), fuse=False)
for p in model.parameters():
    p.requires_grad = True
ignored_layers = []
from models.yolo import Detect, IDetect
from models.common import ImplicitA, ImplicitM
for m in model.modules():
    if isinstance(m, (Detect,IDetect)):
        ignored_layers.append(m.m)
unwrapped_parameters = []
for name, m in model.named_parameters():
    if isinstance(m, (ImplicitA,ImplicitM,)):
        unwrapped_parameters.append((name,1)) # pruning 1st dimension of implicit matrix

print(ignored_layers)
example_inputs = torch.rand(1, 3, 416, 416, device='cuda:0')
imp = tp.importance.BNScaleImportance()
pruner = tp.pruner.BNScalePruner(model, example_inputs, imp,
                                   ignored_layers=ignored_layers,
                                   unwrapped_parameters=unwrapped_parameters,
                                   global_pruning=True,
                                   ch_sparsity=0.3,
                                   round_to=8,
                                   )

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
pruned_model = pruner.model
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
print(f"macs: {base_macs} -> {pruned_macs}")
print(f"nparams: {base_nparams} -> {pruned_nparams}")
macs_cutoff_ratio = (base_macs - pruned_macs) / base_macs
nparams_cutoff_ratio = (base_nparams - pruned_nparams) / base_nparams
print(f"macs cutoff ratio: {macs_cutoff_ratio}")
print(f"nparams cutoff ratio: {nparams_cutoff_ratio}")
save_path = weights.replace(".pt", "_pruned_bn_0.3.pt")

torch.save({"model": pruned_model.module if hasattr(pruned_model, 'module') else pruned_model}, save_path)

去掉一些计算剪枝比例的,保存代码等代码外,剪枝操作其实由pruner.step()这一步完成。这里我们主要需要设置的参数是:

  • ch_sparsity: 可以理解成剪枝的比例,越大剪得越多
  • global_pruning: True表示整个模型的权重按一个整体排序后剪枝,False表示按分组内部按比例剪枝
  • round_to: 剪枝后的通道保留为多少的倍数,一般在硬件上,保留8的倍数

微调

经过剪枝的网络,精度是下降比较明显的,需要再在数据上finetune一些epoch才能把精度拉回来。
yolov7默认是通过yaml文件创建模型结构,然后再载入权重进行训练的,而我们剪枝后的模型是没有模型结构文件的,因此需要对训练代码做一定的修改,具体而言,只是对模型的载入进行一点修改。其中opt.finetune是用来控制是否处于finetune模式的标志位。

if opt.finetune: # for model without cfg
	new = torch.load(weights, map_location=device)  # create
	model = new["model"]
	print("Finetune Mode...")
elif pretrained:
...

比较简单的改法是这样,从checkpoint中载入结构和权重,还有一种方式则是修改yolov7的Model类,这个在后面讲yolov7剪枝后蒸馏的时候再讲,暂时用上面这种方式就可以了。

评测

我在自己的任务上的效果是yolov7剪枝50%,微调后基本上能达到剪枝前的map,没记错的话这是和稀疏化训练的比,毕竟开启稀疏化训练本身也会掉点。大家可以在自己的任务上尝试一下,总体上精度还是可以的

结语

这篇文章简述了以下yolov7的剪枝,yolov5也可用,希望对大家有帮助。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1068827.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c#学生管理系统

一、系统概述 学生管理系统是一个旨在帮助学校、教育机构和教育者有效管理学生信息、课程安排和成绩记录的应用程序。该系统旨在简化学生管理的各个方面,提供高效的解决方案,以满足教育机构的需求。 二、功能模块 1. 学生信息管理 添加学生:录入学生…

HashMapConcurrentHashMap

文章目录 1、HashMap基础类属性node容量负载因子hash算法 2、数组链表/树为什么引入链表为什么jdk1.8会引入红黑树为什么一开始不就使用红黑树?HashMap的底层数组取值的时候,为什么不用取模,而是&数组的长度为什么是2的次幂如果指定数组的…

数据结构--》数组和广义表:从基础到应用的全面剖析

数据结构为我们提供了组织和处理数据的基本工具。而在这个广袤的数据结构领域中,数组和广义表是两个不可或缺的重要概念。它们作为线性结构的代表,在算法与应用中扮演着重要的角色。 无论你是初学者还是进阶者,本文将为你提供简单易懂、实用可…

青少年近视问题不容小觑,蔡司用专业技术助力孩子视力健康发展

根据国家卫健委公布的数据显示,2022年全国儿童青少年近视率达到53.6%,青少年近视已成为社会普遍的眼健康问题。对家长来说,也需要提高对孩子眼视光健康重要性的认知,日常培养青少年良好的用眼习惯,并通过矫正视力的方式…

如何使用 Tensor.art 实现文生图

摘要:Tensor.art 是一个基于 AI 的文本生成图像工具。本文介绍了如何使用 Tensor.art 来实现文生图的功能。 正文: 文生图是指将文本转换为图像的技术。它具有广泛的应用,例如在广告、教育和娱乐等领域。 Tensor.art 是一个基于 AI 的文本…

外汇天眼:真实记录,投资者在盗版MT4平台SCE Group上做交易的经历!

外汇市场是全球最大的金融市场,比起其他市场有着更多天然的优势,但也因为资讯的不对等,导致很多人上当受骗。而在外汇市场上最常见的骗局之一,就是黑平台使用盗版MT4/5交易软件,因为截至目前MT4/5仍是外汇市场交易使用…

汽车电子中的安森美深力科分享一款高性能车规级芯片NCV7520MWTXG

安森美深力科NCV7520MWTXG可编程六沟道低压侧 MOSFET 预驱动器,是一个 FLEXMOS™ 汽车级产品系列,用于驱动逻辑电平 MOSFET。该产品可通过串行 SPI 和并行输入组合控制。该器件提供可兼容 3.3 V/5 V 的输入,串行输出驱动器可基于 3.3 V 或 5 …

在模拟器上安装magisk实现Charles抓https包(三)

经过前两篇的内容,链接如下: 在模拟器上安装magisk实现Charles抓https包(一)_小小爬虾的博客-CSDN博客 在模拟器上安装magisk实现Charles抓https包(二)_小小爬虾的博客-CSDN博客 电脑端的Charles就可以抓…

VS2022 17.8 功能更新:现已支持 C11 线程

早在 VS2022 17.5 版本,Microsoft Visual C 库已经初步支持了 C11 atomics。今天,我们很高兴地宣布,在最新版本 VS2022 17.8 预览版 2 中已正式支持 C11 线程。开发者可以更轻松地将跨平台 C 应用程序移植到 Windows,而无需开发线…

华为OD机试 - 最小步骤数(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入:4 8 7 5 2 3 6 4 8 12、输出:23、说明:4、思路分析 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《…

网络安全总结

前言 本文内容主要摘抄网络规划设计师的教材和腾讯-SUMMER课堂,主要对网络安全进行简单梳理和总结 OSI安全体系 X轴表示8种安全机制,Y轴表示OSI7层模型,Z轴表示5种安全服务,图中X是水平,Y轴竖直,Z轴向外…

2023年中国喷头受益于技术创新,功能不断提升[图]

喷头行业是一个专注于生产和供应各种类型喷头的产业。喷头是一种用于将液体、气体或粉末等物质喷射或喷洒的装置,广泛应用于不同领域,包括工业、农业、家用、医疗等。 喷头行业分类 资料来源:共研产业咨询(共研网) 随…

Redis 获取、设置配置文件

以Ubuntu 为例 redis配置文件 cd /etc/redis sudo vim redis.conf 获取配置文件、修改配置文件

【轻松玩转MacOS】网络连接篇

引言 本篇让我们来聊聊网络连接。不论你是在家、在办公室,还是咖啡厅、机场,几乎所有的MacOS用户都需要连接到互联网。在这个部分,我们将向你展示如何连接到互联网和局域网。让我们开始吧! 一、连接到互联网 首先,我…

农业育种好策略:凌恩生物种质资源数字化全方位解决方案

动植物育种是通过创造遗传变异、改良遗传特性,以培育具有优良性状的动植物新品种的技术。随着高通量组学技术的发展和应用,分子育种等现代科学理论与技术得以发展和不断完善,是未来作物育种的不二选择,它的精准性、高效性都将带领…

NoSQL之 Redis命令工具及常用命令

目录 1 Redis 命令工具 1.1 redis-cli 命令行工具 1.2 redis-benchmark 测试工具 2 Redis 数据库常用命令 2.1 set:存放数据,命令格式为 set key value 2.2 get:获取数据,命令格式为 get key 2.3 keys 命令可以取符合规则的…

深入探索地理空间查询:如何优雅地在MySQL、PostgreSQL及Redis中实现精准的地理数据存储与检索技巧

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

入门级气传导耳机推荐哪款?安利几款好用的气传导耳机

​在当今的快节奏生活中,音乐成为了我们放松身心的重要方式。然而,我们在享受音乐的同时,也面临着耳机线缆的束缚和耳朵的压迫感。这时,气传导耳机应运而生,它们以一种更加先进、舒适的方式来传递音乐,为我…

【C++】-C++11中的知识点(上)--右值引用,列表初始化,声明

💖作者:小树苗渴望变成参天大树🎈 🎉作者宣言:认真写好每一篇博客💤 🎊作者gitee:gitee✨ 💞作者专栏:C语言,数据结构初阶,Linux,C 动态规划算法🎄 如 果 你 …

Linux基本指令一

Linux基本指令一 一、ls指令1、语法2、功能3、常用选项4、示例 二、pwd指令1、功能2、示例 三、cd指令1、语法2、功能3、常用操作4、示例 四、 touch指令1、语法2、功能3、示例 五、mkdir指令1、语法2、功能3、常用选项4、示例 六、rmdir指令1、语法2、适用对象3、功能4、常用选…