使用torch_pruning对YOLOv8进行剪枝(新版、detect、segment)

news2024/9/24 3:21:51

torch_pruning库介绍

group原理
在结构修剪中,**Group被定义为深度网络中最小的可移除单元。**每个组由多个相互依赖的层组成,需要同时修剪这些层以保持最终结构的完整性。然而,深度网络通常表现出层与层之间错综复杂的依赖关系,这对结构修剪提出了重大挑战。这项研究通过引入DepGraph轻松实现参数分组,并有助于修剪各种深度网络。

如修剪图中高亮的神经元,我们需要对所有有连线的层都进行修剪。(a) W_l、W_l+1 (b) W_l、W_l+1、W_l+2 © W_l、W_l+1、W_l+2 (d) W_l

一个简单的例子

我们修剪resnet18,把其第一个卷积的输出通道维度减少3。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()
print(model)
# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
print(group)
# 3. prune all grouped layers that are coupled with model.conv1 (included).
if DG.check_pruning_group(group): # avoid full pruning, i.e., len(channels)=0.
    group.prune()
    
print(model)

剪枝结果1
跟conv1相关的层都被修剪了,剪的是第2,6,9维,但这种方式不给灵活,只能修剪固定的索引,下面我实现更灵活的方式。
下面我们对整个resnet18剪枝,使其通道数减半,ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
这里需要我们做两件事,第一,重要性函数,抽象的说就是评估每个组的重要性,不重要的就可以修剪掉,可以使用内置重要性函数,也可以自定义。第二,配置剪枝器,如使用剪枝多少步,最后的稀疏性是多少,这里我们设置稀疏性是0.5。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M")
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"{i+1}/{iterative_steps} Pruning: MACs={macs / 1e9: .5f} G, #Params={nparams / 1e6: .5f} M")
    # finetune your model here
    # finetune(model),model.train()
    # ...

剪枝结果2

在YOLOv8中剪枝

通过上面的步骤,发现剪枝好像并不困难,主要与ultralytics框架集成会麻烦点,核心剪枝代码在下面,有注释,完整代码在这里:yolov8_pruning.py

def prune(args):
    # 加载模型,yaml,pt
    model = YOLO(args.model)
    # 加入train_v2方法在YOLO对象中,主要是重写train方法,为了不让model.train()每次新建一个模型来训练,这样会导致剪枝失败。
    # 解决方法就是在训练完后重新加载训练完的权重到YOLO.model中。
    model.__setattr__("train_v2", train_v2.__get__(model))
    pruning_cfg = yaml_load(check_yaml(args.cfg))
    batch_size = pruning_cfg['batch']

    model.model.train()
    # split操作不支持剪枝,使用两个卷积替换split操作
    replace_c2f_with_c2f_v2(model.model)
    initialize_weights(model.model)  #设置 BN.eps, momentum, ReLU.inplace
    # 开启梯度训练
    for name, param in model.model.named_parameters():
        param.requires_grad = True

    example_inputs = torch.randn(1, 3, pruning_cfg["imgsz"], pruning_cfg["imgsz"]).to(model.device)
    # 保存浮点数、参数量、mAP和剪枝mAP的记录
    macs_list, nparams_list, map_list, pruned_map_list = [], [], [], []
    # 计算浮点数和参数量
    base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)

    # 在剪枝操作之前先评估一次模型
    pruning_cfg['name'] = f"baseline_val"
    validation_model = deepcopy(model)
    metric = validation_model.val(**pruning_cfg)
    init_map = metric.box.map
    # 保存浮点数、参数量、mAP和剪枝mAP的记录
    macs_list.append(base_macs)
    nparams_list.append(base_nparams)
    map_list.append(init_map)
    pruned_map_list.append(init_map)
    print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M, mAP={init_map: .5f}")

    # 每一步的剪枝率
    pruning_ratio = 1 - math.pow((1 - args.target_prune_rate), 1 / args.iterative_steps)
    print(pruning_ratio)
    # 这里可以发现剪枝器可以在循环里或者循环外,虽然最终模型的稀疏性是没变化的,但每剪一次微调一次的效果会更好。
    for i in range(args.iterative_steps):

        model.model.train()
        for name, param in model.model.named_parameters():
            param.requires_grad = True

        ignored_layers = []
        unwrapped_parameters = []
        # 忽略的层,一般都对头部网络进行忽略,如果是目标检测就换成Detect,记得先引入,这是个类。
        for m in model.model.modules():
            if isinstance(m, (Segment,)):
                ignored_layers.append(m)

        example_inputs = example_inputs.to(model.device)
        pruner = tp.pruner.GroupNormPruner(
            model.model,
            example_inputs,
            importance=tp.importance.GroupNormImportance(),  # L2 norm pruning,
            iterative_steps=1,
            pruning_ratio=pruning_ratio,
            ignored_layers=ignored_layers,
            unwrapped_parameters=unwrapped_parameters
        )

        pruner.step()
        # 剪枝完后先评估一遍模型
        pruning_cfg['name'] = f"step_{i}_pre_val"
        validation_model.model = deepcopy(model.model)
        metric = validation_model.val(**pruning_cfg)
        pruned_map = metric.box.map
        pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs.to(model.device))
        current_speed_up = float(macs_list[0]) / pruned_macs
        print(f"After pruning iter {i + 1}: MACs={pruned_macs / 1e9} G, #Params={pruned_nparams / 1e6} M, "
              f"mAP={pruned_map}, speed up={current_speed_up}")

        # 微调模型,重新训练,一般10-50epochs?
        for name, param in model.model.named_parameters():
            param.requires_grad = True
        pruning_cfg['name'] = f"step_{i}_finetune"
        pruning_cfg['batch'] = batch_size  # restore batch size
        model.train_v2(pruning=True, **pruning_cfg)

        # 微调完后再评估一遍模型
        pruning_cfg['name'] = f"step_{i}_post_val"
        validation_model = YOLO(model.trainer.best)
        metric = validation_model.val(**pruning_cfg)
        current_map = metric.box.map
        print(f"After fine tuning mAP={current_map}")

        macs_list.append(pruned_macs)
        nparams_list.append(pruned_nparams / base_nparams * 100)
        pruned_map_list.append(pruned_map)
        map_list.append(current_map)

        # 移除剪枝器
        del pruner

        save_pruning_performance_graph(nparams_list, map_list, macs_list, pruned_map_list)

        if init_map - current_map > args.max_map_drop and current_speed_up>=1.2:
            print("Pruning early stop")
            break

    model.export(format='onnx')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1946307.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[大牛直播SDK]Windows平台RTMP直播推送模块功能设计

技术优势 全自研框架,易于扩展,自适应算法让延迟更低、采集编码传输效率更高;所有功能以SDK接口形式提供,所有状态,均有event回调,完美支持断网自动重连;SDK模块化,可和大牛直播播放…

DBeaver Ultimate 22.1.0 连接数据库(MySQL+Mongo+Clickhouse)

前言 继续书接上文 Docker Compose V2 安装常用数据库MySQL+Mongo,部署安装好之后我本来是找了一个web端的在线连接数据库的工具,但是使用过程中并不丝滑,最终还是选择了使用 DBeaver ,然后发现 mongo 还需要许可,又折腾整理了半下午,终于大功告成。 DBeaver 版本及说明…

SpringBoot集成Sharding-JDBC实现分库分表

本文已收录于专栏 《中间件合集》 目录 版本介绍背景介绍拆分方式集成并测试1.引入依赖2.创建库和表3.pom文件配置3.编写测试类Entity层Mapper接口MapperXML文件测试类 4.运行结果 自定义分片规则定义分片类编写pom文件 总结提升 版本介绍 SpringBoot的版本是: 2.3.…

SpringBoot上传超大文件导致Cannot read more than 2,147,483,647 into a byte array,问题解决办法

问题描述 报错: java.lang.IllegalArgumentException: Cannot read more than 2,147,483,647 into a byte array at org.apache.commons.io.IOUtils.lambda$toByteArray$0(IOUtils.java:2403) ~[commons-io-2.11.0.jar:2.11.0] at org.apache.commons.io.output.Thre…

python每日学习12:pandas库的用法(1)

python每日学习12:pandas库的用法(1) 安装 pip install pandas设定系统环境 import pandas as pd #设定自由列表输出最多为 10 行 pd.options.display.max_rows 10 # 显示当前 Pandas 版本号 pd.__version__进入jupyter notebook 页面 在终端…

氧气传感器在汽车制氧检测中的应用

在当今汽车工业中,技术的快速发展不仅带来了驾驶安全性和舒适性的显著提升,还为车辆增加了各种智能功能,以应对不同的驾驶环境和需求。氧气传感器作为一种关键的技术装置,在汽车制氧检测系统中的应用,尤其是针对疲劳驾…

困于数字化泥潭的软件公司|专题报告集

数字化专题报告集链接:https://tecdat.cn/?p36964 在探讨企业数字化转型的进程中,软件公司无疑扮演着举足轻重的角色。它们不仅是技术创新的驱动力,更是连接管理与技术的桥梁。然而,正如许多观察家所指出的那样,软件…

每天五分钟计算机视觉:目标检测模型从RCNN到Fast R-CNN的进化

本文重点 前面的课程中,我们学习了RCNN算法,但是RCNN算法有些慢,然后又有了基于RCNN的Fast-RCNN,Fast R-CNN是一种深度学习模型,主要用于目标检测任务,尤其在图像中物体的识别和定位方面表现出色。它是R-CNN系列算法的一个重要改进版本,旨在解决R-CNN中计算量大、速度慢…

jackson序列化(jackson codec)

Jackson 是一个用于 Java 平台的开源 JSON 库,它提供了灵活且高效的方式来处理 JSON 数据的序列化(Java对象 → JSON字符串)和反序列化(JSON 字符串→ Java对象)。 以下是 Jackson 的一些主要特点和功能: 高性能:Jackson 通过使用基于流的处理…

Ubuntu安装QQ教程

Ubuntu安装QQ教程 腾讯更新Linux版的QQ,这里安装一下; 首先,进入官网找到合适对应的安装包; QQLinux版本官网:https://im.qq.com/linuxqq/index.shtml 我们是ubuntu系统选择X86下的deb版本,如果是arm开…

TikTok养号的网络环境及相关代理IP知识

TikTok作为一个流行的短视频分享平台,其用户量非常庞大,很多商家和个人都会使用TikTok来进行引流和推广。由于TikTok的规则和政策限制了每个用户每天发布视频的数量,因此许多用户会使用多个账号来发布更多的视频以提高曝光率。 然而&#xff…

Android Studio Build窗口出现中文乱码问题

刚安装成功的android studio软件打开工程,编译时下方build窗口中中文是乱码。 解决: 可点击studio状态栏的Help—>Edit Custom VM Options ,在打开的studio64.exe.vmoptions文件后面添加:(要注意不能有空格,否则st…

FL Studio Producer Edition 21.2.3.4004中文直版及FL Studio 204如何激活详细教程

在数字化音乐制作的浪潮中,FL Studio 24.1.1.4234的发布无疑又掀起了一股新的热潮。这款由Image-Line公司开发的数字音频工作站(DAW)软件,以其强大的功能和易用的界面,赢得了全球无数音乐制作人的青睐。本文将深入探讨…

python 图片转文字、语音转文字、文字转语音保存音频并朗读

一、python图片转文字 1、引言 pytesseract是基于Python的OCR工具, 底层使用的是Google的Tesseract-OCR 引擎,支持识别图片中的文字,支持jpeg, png, gif, bmp, tiff等图片格式 2、环境配置 python3.6PIL库安装Google Tesseract OCR 3、安…

使用axios请求后端的上传图片接口

安装axios npm install axios 创建input文件上传标签 <input type"file" name"" id"" change"handleChange" /> 使用axios请求后端的图片上传接口 function handleChange(val) {// new FormData() js内置构造函数&#xff0c…

C++入门基础:C++中的常用操作符练习

开头介绍下C语言先&#xff0c;C是一种广泛使用的计算机程序设计语言&#xff0c;起源于20世纪80年代&#xff0c;由比雅尼斯特劳斯特鲁普在贝尔实验室开发。它是C语言的扩展&#xff0c;增加了面向对象编程的特性。C的应用场景广泛&#xff0c;包括系统软件、游戏开发、嵌入式…

【Nginx】Mac电脑安装nginx

使用brew安装nginx brew install nginx查看nginx信息 brew info nginx启动nginx brew services start nginx验证是否启动成功 浏览器输入地址&#xff1a;127.0.0.1:8080 停止服务 brew services stop nginx进入nginx文件目录 cd /opt/homebrew/etc/nginx重启服务 bre…

轻量化YOLOv7系列:结合G-GhostNet | 适配GPU,华为诺亚提出G-Ghost方案升级GhostNet

轻量化YOLOv7系列&#xff1a;结合G-GhostNet | 适配GPU&#xff0c;华为诺亚提出G-Ghost方案升级GhostNet 需要修改的代码models/GGhostRegNet.py代码 创建yaml文件测试是否创建成功 本文提供了改进 YOLOv7注意力系列包含不同的注意力机制以及多种加入方式&#xff0c;在本文…

分布式光伏并网AM5SE-IS防孤岛保护装置介绍——安科瑞 叶西平

产品简介 功能&#xff1a; AM5SE-IS防孤岛保护装置主要适用于35kV、10kV及低压380V光伏发电、燃气发电等新能源并网供电系统。当发生孤岛现象时&#xff0c;可以快速切除并网点&#xff0c;使本站与电网侧快速脱离&#xff0c;保证整个电站和相关维护人员的生命安全。 应用…

【简历】吉林某一本大学:JAVA秋招简历指导,简历通过率比较低

注&#xff1a;为保证用户信息安全&#xff0c;姓名和学校等信息已经进行同层次变更&#xff0c;内容部分细节也进行了部分隐藏 简历说明 这是一份吉林某一本大学25届计算机专业同学的Java简历。因为学校是一本&#xff0c;所以求职目标以中厂为主。因为学校背景在中厂是正常…