模型优化之剪枝

news2025/1/11 2:44:21

文章目录

  • 什么是神经网络剪枝
  • 剪枝的好处
  • 不同粒度的剪枝
  • 剪枝的分类
    • 非结构化剪枝
    • 结构化剪枝
  • 哪些层的参数更容易被剪掉
  • 剪枝效果

什么是神经网络剪枝

神经网络剪枝

  • 在训练期间删除连接
  • 密集张量将变得稀疏(用零填充)
  • 可以通过结构化块( n m nm nm)或( 11 11 11)删除连接

在这里插入图片描述

剪枝的好处

  • 减少过拟合
  • 稀疏性优势
  • 文件中有大量的0,如果有适当的稀疏张量表示方法,模型二进制文件尺寸减小。
  • 模型更小,可以减少内存带宽消耗量。
  • 对于特定模式的稀疏模型,可以开发优化算子,实现加速推理。

不同粒度的剪枝

在这里插入图片描述
什么时候做剪枝?

one-shot pruning : 一次性修剪,包括三个步骤训练模型、剪枝、再训练
剪枝:通常根据某种标准(如权重的大小、梯度的大小等)一次性去除大量权重。
再训练:剪枝后,模型通常需要进行一定数量的额外训练(称为fine-tuning或再训练)来恢复剪枝过程中可能损失的性能。

iterative pruning: 迭代式训练,特点如下:
初始训练:首先,对未剪枝的完整模型进行训练,直到达到满意的性能水平。
剪枝:然后,根据某种剪枝策略(例如基于权重的大小或敏感度)剪除模型的部分组件(如权重、神经元或通道)。
再训练:剪枝后,重新训练模型以恢复因剪枝而丢失的性能。
迭代:重复剪枝和再训练的过程,直到达到所需的剪枝率或性能标准。

automated gradual pruning: 自动化渐进剪枝,特点如下:
剪枝策略:采用一种预定义的剪枝策略,例如基于权重阈值、敏感度分析等,该策略在整个剪枝过程中保持一致。
渐进剪枝:在整个训练过程中逐渐增加剪枝率,通常从较低的剪枝率开始,逐步增加到目标剪枝率。
无需再训练:在整个剪枝过程中,模型持续被训练,而不是在剪枝后重新训练。
自动化:整个过程高度自动化,可以减少人为干预的需求

在这里插入图片描述

剪枝的分类

结构化剪枝(Structured Pruning)和非结构化剪枝(Unstructured Pruning)是两种常见的神经网络剪枝方法,它们的主要区别在于剪枝后网络结构的变化以及剪枝操作的粒度。

非结构化剪枝

不改变网络结构或者参数数量,把连接上的参数置0即为剪枝。
基于某种度量(如权重的绝对值大小)对所有权重进行排序,然后根据预先设定的剪枝比例(例如去除50%的最小权重)来决定哪些权重被设置为零。这种剪枝方法不会考虑权重在模型中的位置或结构,只关注权重本身的价值。示例代码:

# 导入剪枝函数
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# 计算两轮之后完成剪枝时对应的迭代次数end_step
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# 定义剪枝模型参数,开始模型从50%稀疏度(权重为0的参数数量百分比),到80%稀疏度
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# 当使用函数`prune_low_magnitude`包装了一下模型后,需要重新编译一下
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

model_for_pruning.summary()

logdir = "./logs/mnist_pruning"

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)
# --------------------------------------------------
# 评估模型,对比剪枝前后模型的准确率变化
# 经过剪枝,这里有一个小的准确率下降,和没有进行剪枝相比的话
# --------------------------------------------------

_, model_for_pruning_accuracy = model_for_pruning.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

结构化剪枝

结构化剪枝改变了网络结构,即网络层输出元素个数,比如卷积核的减少会影响特征图数量。
在下面的例子中是基于选择的模型层做剪枝,所以需要指出哪些层去做结构化剪枝。比如剪枝第二个卷积层和第一个全连接层,剪枝策略为pruning_params_2_by_4,表示该层剪枝比例为2 / 4,即该层保留一半(2/4)的权重,而将另一半设为零。
注意:第一个卷积层不能被结构化剪枝。要是结构化剪枝的话,应该至少大于一个input channels(本例所用图片为单通道灰度图),所以我们对第一个卷积层使用随机剪枝。

model = keras.Sequential([
    prune_low_magnitude(
        keras.layers.Conv2D(
            32, 5, padding='same', activation='relu',
            input_shape=(28, 28, 1),
            name="pruning_sparsity_0_5"),
        **pruning_params_sparsity_0_5),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    prune_low_magnitude(
        keras.layers.Conv2D(
            64, 5, padding='same',
            name="structural_pruning"),
        **pruning_params_2_by_4),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    keras.layers.Flatten(),
    prune_low_magnitude(
        keras.layers.Dense(
            1024, activation='relu',
            name="structural_pruning_dense"),
        **pruning_params_2_by_4),
    keras.layers.Dropout(0.4),
    keras.layers.Dense(10)
])

哪些层的参数更容易被剪掉

因为卷积层(conv)中的参数相比全连接层(fc)来说参数量少,所以卷积层参数的压缩比没有全连接层参数的压缩比大。换句话说,就是卷积层参数更加敏感,剪掉对准确率影响相对更大。越靠后的卷积层或卷积层之后的那些全连接层往往参数越容易被剪掉。

剪枝效果

  • 一般50%-70%左右的稀疏性,准确率降低幅度并不大
  • 剪枝是独立于量化技巧,通常与量化配合效果不错
  • 可以通过微调尝试不同的参数组合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2053563.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【无标题】playbook的基本使用

1、使用ansible安装并启动ftp服务 [root1 ~]# vim /etc/ansible/hosts s0 ansible_ssh_host10.0.0.12 ansible_ssh_port22 ansible_ssh_userroot ansible_ssh_pass1 s1 ansible_ssh_host10.0.0.13 ansible_ssh_port22 ansible_ssh_userroot ansible_ssh_pass1 s2 ansible_s…

Android 12系统源码_屏幕设备(二)DisplayAdapter和DisplayDevice的创建

前言 在Android 12系统源码_屏幕设备(一)DisplayManagerService的启动这篇文章中我们具体分析了DisplayManagerService 的启动流程,本篇文章我们将在这个的基础上具体来分析下设备屏幕适配器的创建过程。 一、注册屏幕适配器 系统是在Disp…

43.x86游戏实战-DXX寻找吸怪坐标

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 工具下载: 链接:https://pan.baidu.com/s/1rEEJnt85npn7N38Ai0_F2Q?pwd6tw3 提…

Xshell中弹出“ssh服务器拒绝了密码请再试一次”时,如何解决

在使用Xshell连接Ubuntu系统时,可能会弹出这个错误 可能原因如下​ 密码输入错误Ubantu系统默认禁止root用户登录ssh。 解决方法: 1. 先用root登录 (由于我买的是云服务器,所以拿这个来举例) 注:要在本地shell中登录…

基于jqury和canvas画板技术五子棋游戏设计与实现(论文+源码)_kaic

摘 要 网络五子棋游戏如今面临着一些新的挑战和机遇。一方面,网络游戏需要考虑到网络延迟和带宽等因素,保证游戏的实时性和稳定性。另一方面,网络游戏需要考虑到游戏的可玩性和趣味性,以吸引更多的玩家参与。本文基于HTML5和Canv…

银河麒麟V10忘记Root密码怎么办?

银河麒麟V10忘记Root密码怎么办? 一:进入GRUB模式二:输入GRUB账号密码三:修改启动参数四:修改root密码五:重启系统六:验证root密码 💖The Begin💖点点关注,收…

就想刷题过?新手必看的华为认证题库最强背题经验技巧

华为认证作为网络和IT领域的重要资格认证,其难度不容小觑。许多考生为了顺利通过考试,选择背题库作为备考策略。 (重点说一下啊,不提倡刷题,能学知识,把技术学透,肯定是最佳的。) …

Java基于数据库、乐观锁、悲观锁、Redis、Zookeeper分布式锁的简单案例实现(保姆级教程)

1. 分布式锁的定义 分布式锁是一种在分布式系统中用来协调多个进程或线程对共享资源进行访问的机制。它确保在分布式环境下,多个节点(如不同的服务器或进程)不会同时访问同一个共享资源,从而避免数据不一致、资源竞争等问题。 2…

简单记录:两台服务器如何超快速互传文件/文件夹

在服务器间传输文件和文件夹是一个常见的任务,尤其是在需要同步数据或进行备份时。以下是使用 scp 命令在两台服务器之间进行文件传输的基本步骤。 服务器A 至 服务器B:文件传输指南 前提条件 确保服务器A和服务器B之间网络互通。确认您有权限访问目标…

如何让孩子喜欢上读书?

1.选择合适的书籍:根据孩子的兴趣和年龄选择合适的书籍,让孩子参与选书的过程,这样可以增加他们对阅读的主动性和兴趣。同时,避免过分强调阅读的功利性,让孩子自由选择他们感兴趣的书籍。   2.定期的阅读时间和活动&…

谷粒商城实战笔记-211~212-商城业务-认证服务-环境搭建

这一部分的主要内容是开发商城的认证服务。 文章目录 一,211-商城业务-认证服务-环境搭建1,创建模块2,引入相关依赖3,各种配置3.1 注册中心配置3.2 启用注册中心3.3 本节域名配置 4,页面模板4.1 html模板4.2 静态资源上…

python---数据可视化篇

目录 1.matplotlib简介 2.安装并且导入对应的模块 3.设置中文字体 4.创建画布 5.绘制折线图 6.对于折线图的美化 7.散点图的绘制 8.双y轴叠加图 9.簇形柱状图 10.百分比堆积柱状图 11.绘制多个子图(一个画布上面) 1.matplotlib简介 matplotl…

C盘扩容遇到恢复分区怎么办?

文章目录 1.0 问题描述2.0 了解恢复分区是啥3.0 恢复分区可以删除吗?(需确认好!)4.0 删除恢复分区(需要谨慎操作)4.0.1 管理员打开CMD4.0.2 查看磁盘 给C盘扩容 1.0 问题描述 想要给C盘扩容,但…

Hyper-v ubuntu22 上外网方法

1. 前置步骤 步骤一,首先新建一个虚拟网络交换机,我这里名称为vEthernet (hyper-v-ubuntu),选【内部网络】 步骤二, 在网络设置中,找到可以上网的网卡,这里我用的是无线网卡WLAN,设置共享连接…

【SpringBoot】SpringBoot的运行原理

SpringBoot项目中都有一个如下的启动类。 SpringBootApplication public class MyApplication {public static void main(String[] args) {SpringApplication.run(MyApplication.class,args);} }其中SpringBootApplication是这个启动类的核心注解,在它下面又有三个子…

Spring Cloud Gateway动态路由及路由插件实现方案

前言 sim-framework之前使用Zuul作为网关,结合Eureka实现了动态路由及灰度路由,但是存在以下几个问题: 性能问题:Zuul基于线程隔离,一个请求需要一个线程处理,而Gateway基于事件驱动,少量线程…

Go项目布局

Go项目布局,自举语言,源码是靠Go自己实现的 所以Go源码可以参考作为项目布局 源码放在src目录下 cmd放main internal目录下放不希望外部访问的代码(业务) common目录下可以放直接 import外部访问的 etc放配置文件yaml

第二届海南大数据创新应用大赛 - 算法赛道冠军比赛攻略_海南新境界队

关联比赛: 第二届海南大数据创新应用大赛 - 智能算法赛 第二届海南大数据创新应用大赛 - 算法赛道冠军比赛攻略 首先很幸运能拿到这次初赛冠军,本着积极学习和提升自我的态度,团队成员通力合作是获胜关键,再次感谢。 赛题背景分析和理解 …

gpio的使用----->4412的裸机的使用(第三节)

这一节主要是 4412 的裸机的使用 0 4412 的硬件原理图 数据手册 然后是数据手册的解析: 每一组都有这几个 寄存器。 需要注意: 1、 4412 的中断是 与输入,输出在同一个级别的,与stm32不同。 2、 我是在uboot 上进行编程的&#x…

重头开始嵌入式第二十二天(Linux系统编程 进程)

进程 目录 进程 1.进程的概念 2.PCB(process control block) 3.进程和程序有什么区别? 4.进程的内存分布 5.进程的分类 守护进程 6.进程的作用 7.进程的状态 8.进程的调度 9.查询进程的相关指令 1.ps aux 2.top 3.kill和killa…