2024强化学习的结构化剪枝模型RL-Pruner原理及实践

news2025/1/12 1:11:14

[2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration

目录

  • [2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration
    • 一、论文说明
    • 二、原理
    • 三、实验与分析
      • 1、环境配置
        • 在Windows配置git bash链接conda环境
      • 2、项目代码运行
        • 1、训练预训练权重
        • 2、模型压缩
        • 3、模型验证
    • 四、总结

一、论文说明

论文标题:使用强化学习进行结构化剪枝用于卷积神经网路压缩和加速

机构:伊利诺伊大学厄巴纳-香槟分校

论文链接:https://arxiv.org/pdf/2411.06463

代码链接:https://github.com/Beryex/RLPruner-CNN

论文简介: 卷积神经网络(ConvolutionalNeural Networks, CNNs)近年来表现出卓越的性能。压缩这些模型不仅减少了存储需求,使其在边缘设备上的部署变得可行,还加速了推理,从而降低了延迟和计算成本。结构化剪枝,它在层级上去除过滤器,直接修改了模型架构。这种方法实现了更紧凑的架构,同时保持目标准确性,确保压缩模型具有较好的兼容性和硬件效率。所提方法基于一个关键观察:

  • 1、神经网络中不同层的过滤器对模型性能的重要性各不相同。

  • 2、当修剪的过滤器数量固定时,不同层之间的最佳修剪分配是不均匀的,以最小化性能损失

  • 3、对修剪敏感的层应该占据更小的修剪分配比例。

为了利用这一洞察,文中提出了RL-Pruner,它使用强化学习来学习最佳修剪分配。RL-Pruner可以自动提取输入模型中过滤器之间的依赖关系并执行修剪,无需特定于模型的修剪实现。在GoogleNet、ResNet和MobileNet 等模型上进行了实验,将所提方法与其他结构化剪枝方法进行了比较,以验证其有效性。

在这里插入图片描述

二、原理

RL-Pruner 首先在模型中的层之间构建依赖图,然后分几个步骤进行剪枝。在每个步骤中:1) 基于基础分布生成一个新的剪枝稀疏分布作为动作 ,这作为策略;2)根据相应的稀疏度,使用泰勒准则(Taylorcriterion)对每一层进行剪枝;3) 评估压缩后的模型以获得奖励,并将动作和奖励存储在回放(replay)经验池中。每个步骤后,基础分布根据经验池更新,如果计算资源足够,则对压缩模型应用后训练阶段,使用知识蒸馏(knowledge distillation),其中原始模型作为教师。具体框图如图2所示。

三、实验与分析

1、环境配置

实验平台及软件

  • Windows 10
  • git bash
  • conda环境

这里主要介绍如何在windows系统上让git bash链接conda环境。

在Windows配置git bash链接conda环境

由于工程代码中需要使用bash命令运行代码,因此需要保证git bash能调用conda环境运行对应的脚本文件。

C:\Users\username\.bashrc文件内设置conda.sh位置(文中示例为:D:\\Anaconda3\\etc\\profile.d\\conda.sh),并激活配置。在git bash界面输入具体命令如下:

echo "D:\\Anaconda3\\etc\\profile.d\\conda.sh" >> ~/.bashrc  
source ~/.bashrc

然后关闭git bash界面,再重新打开一个git bash界面,最后输入命令激活conda环境conda activate 虚拟环境名字。如果命令提示中出现如下图所示的字样,即为配置成功,否则根据提示的要求进行配置,比如输入conda init,重新打开一个新的git bash界面。

在这里插入图片描述

2、项目代码运行

克隆项目文件,具体命令如下:

git clone https://github.com/Beryex/RLPruner-CNN.git --depth 1
cd RLPruner-CNN

安装python第三方包,具体命令如下(如果之前有conda环境,可以不用进行下面这一步,等报错了再根据提示安装对应的包即可):

conda create -n RLPruner python=3.10 -y
conda activate RLPruner
pip install -r requirements.txt

官方代码提供了一步到位的运行脚本,从预训练模型、模型压缩到模型验证,仅需在命令行中输入如下代码:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

为了更好地了解每一步的设置,下面内容将分为预训练模型、模型压缩、模型验证三个步骤进行介绍。

1、训练预训练权重

训练模型得到对应的预训练权重,这里以resnet32googlenet为例,在git bash输入具体命令(默认使用cuda)如下:

./scripts/train.sh googlenet cifar100
./scripts/train.sh resnet32 cifar100

或者使用参考指定配置命令:

python -m train --model ${MODEL} --dataset ${DATASET} --device cuda \
                --output_dir ${PRETRAINED_MODEL_DIR} \
                --log_dir ${LOG}

其中,
${MODEL}代表backbone的类型([“vgg11”, “vgg13”, “vgg16”, “vgg19”, “resnet18”, “resnet34”, “resnet50”, “resnet101”, “resnet152”, “resnet8”, “resnet14”, “resnet20”, “resnet32”, “resnet44”, “resnet56”, “resnet110”, “densenet121”, “densenet161”, “densenet169”, “densenet201”, “mobilenetv3_small”, “mobilenetv3_large”, “googlenet”]);
${DATASET}代表数据集名称,如cifar10或者cifar100。
${PRETRAINED_MODEL_DIR}代表输出权重文件路径,默认在pretrained_model文件夹下;
${LOG}代表输出日志路径,默认在log文件夹下。

在CIFAR100数据集上训练resnet32的结果(最佳准确率:0.706)如下图所示。
在这里插入图片描述
在CIFAR100数据集上训练googlenet的结果(最佳准确率:0.774)如下图所示。
在这里插入图片描述

2、模型压缩

模型结构化剪枝这里以0.2的稀疏度,taylor剪枝策略和Q_FLOP_coef=0,Q_Para_coef=0的参数进行测试。在git bash输入具体命令(默认使用cuda)如下:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

同理,也可以使用参考指定配置命令:

python -m compress --model ${MODEL} --dataset ${DATASET} --device cuda \
                   --sparsity ${SPARSITY} --prune_strategy ${prune_strategy} --ppo \
                   --Q_FLOP_coef ${Q_FLOP_coef} --Q_Para_coef ${Q_Para_coef} \
                   --pretrained_pth ${PRETRAINED_MODEL_PTH} \
                   --compressed_dir ${COMPRESSED_MODEL_DIR} \
                   --checkpoint_dir ${CKPT_DIR} \
                   --log_dir ${LOG} --save_model

测试结果如下图所示:
在这里插入图片描述

3、模型验证

在数据集上验证模型的识别性能,在git bash输入具体命令(默认使用cuda)如下:

./scripts/evaluate.sh googlenet cifar100

同理,也可以使用参考指定配置命令:

python -m evaluate --model ${MODEL} --dataset ${DATASET} --device cuda \
                   --pretrained_pth ${PRETRAINED_MODEL_PTH} \
                   --compressed_pth ${COMPRESSED_MODEL_PTH} \
                   --log_dir ${LOG}

测试结果如下图所示:
在这里插入图片描述

四、总结

本文提出了RL-Pruner,一种结构化剪枝方法,能够学习各层之间的最优稀疏性分布,并支持无模型特定修改的一般剪枝。希望所提方法能够认识到每一层对模型(model)性能的重要性不同,这将影响未来在神经网络压缩领域的工作,包括无结构剪枝和量化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2244856.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电脑超频是什么意思?超频的好处和坏处

嗨,亲爱的小伙伴!你是否曾经听说过电脑超频?在电脑爱好者的圈子里,这个词似乎非常熟悉,但对很多普通用户来说,它可能还是一个神秘而陌生的存在。 今天,我将带你揭开超频的神秘面纱,…

uniapp: vite配置rollup-plugin-visualizer进行小程序依赖可视化分析减少vender.js大小

一、前言 在之前文章《uniapp: 微信小程序包体积超过2M的优化方法(主包从2.7M优化到1.5M以内)》中,提到了6种优化小程序包体积的方法,但并没有涉及如何分析common/vender.js这个文件的优化,而这个文件的大小通常情况下…

SQL Server Management Studio 的JDBC驱动程序和IDEA 连接

一、数据库准备 (一)启用 TCP/IP 协议 操作入口 首先,我们要找到 SQL Server 配置管理器,操作路径为:通过 “此电脑” 右键选择 “管理”,在弹出的 “计算机管理” 窗口中,找到 “服务和应用程…

STM32F103系统时钟配置

时钟是单片机运行的基础,时钟信号推动单片机内各个部分执行相应的指令。时钟系统就是CPU的脉搏,决定CPU速率,像人的心跳一样 只有有了心跳,人才能做其他的事情,而单片机有了时钟,才能够运行执行指令&#x…

鸿蒙进阶篇-Math、Date

“在科技的浪潮中,鸿蒙操作系统宛如一颗璀璨的新星,引领着创新的方向。作为鸿蒙开天组,今天我们将一同踏上鸿蒙基础的探索之旅,为您揭开这一神奇系统的神秘面纱。” 各位小伙伴们我们又见面了,我就是鸿蒙开天组,下面让我们进入今…

RAID存储技术 详解

RAID(Redundant Array of Independent Disks,独立磁盘冗余阵列)是一种将多个物理硬盘组合为一个逻辑存储单元的技术。它通过分布数据、冗余校验和容错能力,提高存储系统的性能、可靠性和容量利用率。 以下从底层原理和源代码层面…

MTK主板定制_联发科主板_MTK8766/MTK8768/MTK8788安卓主板方案

主流市场上的MTK主板通常采用联发科的多种芯片平台,如MT8766、MT6765、MT6762、MT8768和MT8788等。这些芯片基于64位Cortex-A73/A53架构,提供四核或八核配置,主频可达2.1GHz,赋予设备卓越的计算与处理能力。芯片采用12纳米制程工艺…

免费微调自己的大模型(llama-factory微调llama3.1-8b)

目录 1. 名词/工具解释2. 微调过程3. 总结 本文主要介绍通过llama-factory框架,使用Lora微调方法,微调meta开源的llama3.1-8b模型,平台使用的是趋动云GPU算力资源。 微调已经经过预训练的大模型目的是,通过调整模型参数和不断优化…

MySQL 中 InnoDB 支持的四种事务隔离级别名称,以及逐级之间的区别?

MySQL中的InnoDB存储引擎支持四种事务隔离级别,这些级别定义了事务在并发环境中的行为和相互之间的可见性。以下是这四种隔离级别的名称以及它们之间的区别: 读未提交(Read Uncommitted) 特点:这是最低的隔离级别&…

【YOLOv10改进[注意力]】引入并行分块注意力PPA(2024.3.16) + 适于微小目标

本文将进行在YOLOv10中引入并行分块注意力PPA魔改v10 的实践,文中含全部代码、详细修改方式。助您轻松理解改进的方法。 一 HCF 论文题目:Hierarchica

共建智能软件开发联合实验室,怿星科技助力东风柳汽加速智能化技术创新

11月14日,以“奋进70载,智创新纪元”为主题的2024东风柳汽第二届科技周在柳州盛大开幕,吸引了来自全国的汽车行业嘉宾、技术专家齐聚一堂,共襄盛举,一同探寻如何凭借 “新技术、新实力” 这一关键契机,为新…

在ubuntu下,使用Python画图,无法显示中文怎么解决

1.首先需要下载中文字体,推荐simsun,即宋体,地址如下 https://www.freefonts.io/download/simsun/ 2.下载完要把字体文件放进字体目录,具体方法如下; a.创建字体目录:sudo mkdir -p /usr/share/fonts/truet…

鸿蒙实战:使用显式Want启动Ability

文章目录 1. 实战概述2. 实现步骤2.1 创建鸿蒙应用项目2.2 修改Index.ets代码2.3 创建SecondAbility2.4 创建Second.ets 3. 测试效果4. 实战总结5. 拓展练习 - 启动文件管理器5.1 创建鸿蒙应用项目5.2 修改Index.ets代码5.3 测试应用运行效果 1. 实战概述 本实战详细阐述了在 …

《Python浪漫的烟花表白特效》

一、背景介绍 烟花象征着浪漫与激情,将它与表白结合在一起,会创造出别具一格的惊喜效果。使用Python的turtle模块,我们可以轻松绘制出动态的烟花特效,再配合文字表白,打造一段专属的浪漫体验。 接下来,让…

springboot中设计基于Redisson的分布式锁注解

如何使用AOP设计一个分布式锁注解&#xff1f; 1、在pom.xml中配置依赖 <dependency><groupId>org.springframework</groupId><artifactId>spring-aspects</artifactId><version>5.3.26</version></dependency><dependenc…

绕过CDN寻找真实IP

在新型涉网案件中&#xff0c;我们在搜集到目标主站之后常常需要获取对方网站的真实IP去进一步的信息搜集&#xff0c;但是现在网站大多都部署了CDN&#xff0c;将资源部署分发到边缘服务器&#xff0c;实现均衡负载&#xff0c;降低网络堵塞&#xff0c;让用户能够更快地访问自…

【Redis】redis缓存击穿,缓存雪崩,缓存穿透

一、什么是缓存&#xff1f; 缓存就是与数据交互中的缓冲区&#xff0c;它一般存储在内存中且读写效率高&#xff0c;提高响应时间提高并发性能&#xff0c;如果访问数据的话可以先访问缓存&#xff0c;避免数据查询直接操作数据库&#xff0c;造成后端压力过大。 但是可能会面…

linux复习5:C prog

编辑 缩排 为了使C源代码更加整洁易读&#xff0c;可以使用一些工具来自动格式化代码&#xff0c;例如cb&#xff08;C程序美化器&#xff09;、bcpp&#xff08;C美化器&#xff09;和indent等。 编译 编译并链接C文件 gcc hello.c -o hello 将 hello.c 编译并链接成可执行文…

uni-app快速入门(十)--常用内置组件(下)

本文介绍uni-app的textarea多行文本框组件、web-view组件、image图片组件、switch开关组件、audio音频组件、video视频组件。 一、textarea多行文本框组件 textarea组件在HTML 中相信大家非常熟悉&#xff0c;组件的官方介绍见&#xff1a; textarea | uni-app官网uni-app,un…

世界坐标系、相机坐标系、图像物理坐标系、像素平面坐标系

坐标系及其转换在计算机视觉领域占据核心地位。理解如何从一个坐标系转换到另一个坐标系&#xff0c;不仅是理论上的需要&#xff0c;也是实际应用中不可或缺的技能。 一、世界坐标系的定义 世界坐标系是一个全局的坐标系统&#xff0c;用于定义场景中物体的位置。在这个坐标…