模型优化【2】-剪枝[局部剪枝]

news2024/11/19 10:32:11

模型剪枝是一种常见的模型压缩技术,它可以通过去除模型中不必要的参数和结构来减小模型的大小和计算量,从而提高模型的效率和速度。在 PyTorch 中,我们可以使用一些库和工具来实现模型剪枝。

pytorch实现剪枝的思路是生成一个掩码,然后同时保存原参数、mask、新参数,如下图:
在这里插入图片描述

Pytorch实现模型剪枝的基本步骤

  1. 加载模型:我们首先需要加载一个已经训练好的模型,可以使用 PyTorch 提供的模型库或者自己训练的模型。

  2. 定义剪枝方法:我们需要定义一种剪枝方法,来决定哪些参数和结构需要被剪枝。

  3. 执行剪枝操作:我们需要执行剪枝操作,将不必要的参数和结构从模型中去除。

  4. 保存剪枝后的模型:我们需要将剪枝后的模型保存下来,以便后续使用。

pytorch 剪枝分为 局部剪枝、全局剪枝、自定义剪枝;

局部剪枝

局部剪枝是指在什么网络的单个层或局部范围内进行剪枝。
Pytorch中与剪枝有关的接口封装在torch.nn.utils.prune中。下面开始演示三种剪枝在LeNet网络中的应用效果,首先给出LeNet网络结构。

加载模型

import torch
from torch import nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = LeNet().to(device=device)

局部剪枝实验,假定对模型的第一个卷积层中的权重进行剪枝

# 打印输出剪枝前的参数
module = model.conv1
print(list(module.named_parameters()))
print(list(module.buffers()))
print(module.weight)

运行结果

[('weight', Parameter containing:
tensor([[[[ 0.1158, -0.0091, -0.2742],
          [-0.1132,  0.1059, -0.0381],
          [ 0.0430, -0.1634, -0.1345]]],
		
		...

        [[[-0.0226,  0.2091, -0.1479],
          [ 0.2302, -0.0988,  0.2117],
          [-0.2000, -0.2531,  0.2770]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2658,  0.2096, -0.2639, -0.3063, -0.1453,  0.1201], device='cuda:0',
       requires_grad=True))]
[]
Parameter containing:
tensor([[[[ 0.1158, -0.0091, -0.2742],
          [-0.1132,  0.1059, -0.0381],
          [ 0.0430, -0.1634, -0.1345]]],


        ...


        [[[-0.0226,  0.2091, -0.1479],
          [ 0.2302, -0.0988,  0.2117],
          [-0.2000, -0.2531,  0.2770]]]], device='cuda:0', requires_grad=True)

定义剪枝+执行剪枝

# 修剪是从 模块 中 删除 参数(如 weight),并用 weight_orig 保存该参数
# random_unstructured 是一种裁剪技术,随机非结构化裁剪
# 第一个参数:modeul,代表要进行剪枝的特定模型,之前我们已经制定了module=module.conv1,说明这里要对第一个卷积层执行剪枝
# 第二个参数:name,指定要对选中模块中的那些参数执行剪枝,这里设定为name='weight',意味着对连接网络的weight剪枝,而不死bias剪枝
# 第三个参数:amount,指定要对模型中的多大比例的参数执行剪枝,amount是一个介于0.0~1.0的float数值,或者一个正整数指定裁剪多少条连接边。
prune.random_unstructured(module, name="weight", amount=0.3)      # weight    bias
print(list(module.named_parameters()))
# 通过修剪技术会创建一个mask命名为 weight_mask 的模块缓冲区
print(list(module.named_buffers()))

# 经过裁剪操作后的模型,原始的参数存放在了weight-orig中,
# 对应的剪枝矩阵存放在weight-mask中,而将weight-mask视作掩码张量
# 再和weight-orig相乘的结果就存放在了weight中
print(module.weight)
print(module.bias)

运行结果

[('bias', Parameter containing:
tensor([ 0.1303,  0.1208, -0.0989, -0.0611, -0.1103, -0.2433], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-1.1443e-01,  3.2276e-01, -2.4664e-02],
          [ 4.6659e-02,  1.8311e-01,  6.6681e-02],
          [-2.5493e-01, -1.1471e-01,  2.8336e-01]]],


       ...

        [[[ 1.4041e-01,  2.0963e-02,  2.2884e-01],
          [ 3.5870e-02,  7.5861e-02,  8.4728e-02],
          [ 4.1965e-02, -1.2838e-01,  8.8462e-02]]]], device='cuda:0',
       requires_grad=True))]
[('weight_mask', tensor([[[[1., 0., 0.],
          [1., 0., 0.],
          [0., 0., 0.]]],


        ...


        [[[1., 0., 0.],
          [1., 0., 1.],
          [0., 1., 0.]]]], device='cuda:0'))]
tensor([[[[-1.1443e-01,  0.0000e+00, -0.0000e+00],
          [ 4.6659e-02,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00]]],


        ...

        [[[ 1.4041e-01,  0.0000e+00,  0.0000e+00],
          [ 3.5870e-02,  0.0000e+00,  8.4728e-02],
          [ 0.0000e+00, -1.2838e-01,  0.0000e+00]]]], device='cuda:0',
       grad_fn=<MulBackward0>)
Parameter containing:
tensor([ 0.1303,  0.1208, -0.0989, -0.0611, -0.1103, -0.2433], device='cuda:0',
       requires_grad=True)

保存剪枝后的模型

# 保存剪枝后的模型
torch.save(model.state_dict(), 'pruned_model.pth')

模型经历剪枝以后,原始的权重矩阵weight参数不见了,变成了weight_orig。并且剪枝前打印为空的列表module.name_buffers(),此时拥有了一个weight_mask参数。经过剪枝操作后的模型,原始的参数存放在了weight_orig中,对应的剪枝矩阵存在weight_mask中,而将weight_mask视作掩码张量,再和weight_orig相乘的结果就存在了weight中。

Q1:打印经过剪枝处理的 weight 参数。这个 weight 实际上是原始的 weight_orig 和 weight_mask 的元素乘积,其中被剪枝的权重会被设置为0。这个weight不是剪枝了嘛?为什么还能打印出来?

答:在Pytorch的剪枝过程中,当我们说剪枝一个权重的参数时,并不是真的从网络中移除这些参数,而是通过一个掩码来“禁用”它们。这是通过将某些权重的值设为0来实现的,从而在网络的前向传播中这些权重不会有任何作用,这种方法允许我们在保留原始权重信息的同时,实现剪枝的效果。

在 PyTorch 的剪枝过程中,当我们说“剪枝”一个权重参数时,并不是真的从网络中移除这些权重,而是通过应用一个掩码来“禁用”它们,为什么禁用就可以达到模型压缩的目的?为什么剪枝完,执行print(list(module.named_parameters())),没有显示weight属性,但是执行print(module.weight)时,weight依然存在?

  • 为什么“禁用”权重可以达到模型压缩的目的?
    虽然剪枝后的权重仍然占据内存空间,但在实际计算中,值为0的权重不会对前向传播产生任何影响。这意味着在计算层面可以忽略这些权重,从而减少计算量。
  • 为什么 print(list(module.named_parameters())) 没有显示 weight 属性,但执行 print(module.weight) 时 weight 依然存在?
    • 修改参数列表:当执行剪枝操作时,PyTorch 会修改模块的参数列表。原始的 weight 参数被重命名为 weight_orig,并且创建了一个新的名为 weight_mask 的缓冲区。原始的 weight 参数(现在是 weight_orig)和 weight_mask 通过一个钩子(hook)相结合,生成了新的 weight 属性。

    • 动态权重生成:在调用 module.weight 时,由于剪枝过程中添加的前向钩子,weight 参数是动态生成的,它是 weight_orig 和 weight_mask 的元素乘积。因此,尽管 weight 在 named_parameters 列表中看起来已经不存在,但它实际上是在运行时动态生成的。

    • 参数与属性的区别:在 PyTorch 中,模块的参数(可通过 named_parameters 访问)和模块的属性(如直接通过 module.weight 访问)是不同的。module.weight 被视为一个可访问的属性,但由于剪枝过程的内部处理,它可能不再直接列在模块的参数列表中。

既然原始的 weight 参数被重命名为 weight_orig,那参数是不是并没有发生变化,又怎么能达到剪枝的效果呢?
原始的 weight 参数在剪枝过程中被重命名为 weight_orig,并且保持不变。剪枝的效果是通过以下几个关键步骤实现的:

  1. 掩码(Mask)创建:
  • 在剪枝过程中,PyTorch 创建了一个掩码(weight_mask),它是一个与 weight 形状相同的二进制张量(由0和1组成)。
  • 在这个掩码中,1表示相应的权重保持不变,而0表示相应的权重被“剪枝”(实际上是被禁用)。
  1. 动态权重更新:
  • 尽管 weight_orig 保持不变,但是模块的 weight 属性被动态更新为 weight_orig 和 weight_mask 的元素乘积。
  • 这意味着,在模型的前向传播过程中,实际使用的 weight 是被掩码修改过的。在这个新的 weight 中,被剪枝的权重(在 weight_mask 中对应0的位置)的值为0,而其他位置的权重保持原始值。
  1. 前向传播的影响:
  • 当模型进行前向传播时,使用的是被掩码修改过的 weight。因此,尽管原始的 weight 参数(现在是 weight_orig)没有变化,模型实际上使用的权重已经被剪枝修改了。
  • 剪枝过程实际上通过使某些权重值为0,从而在模型的计算过程中禁用了这些权重。
  1. 模型复杂度的降低:
  • 通过这种方式,模型的复杂度在实际运行时降低了,因为一部分权重不再对输出产生影响。
  • 这可以提高计算效率,并且在某些情况下,可以通过专门的硬件和软件优化来利用权重的这种稀疏性。
    综上所述,虽然原始的 weight 参数作为 weight_orig 保留下来,但是实际上模型使用的是被掩码修改过的权重,这就是剪枝效果的实现方式。这种方法的一个优势是可以在不永久性地移除权重的情况下,测试和评估剪枝的影响,甚至可以在必要时撤销剪枝操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1252504.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jQuery_05 事件的绑定

jQuery可以给dom对象添加事件 在程序执行期间动态的处理事件 jQuery如何绑定事件呢&#xff1f; 1. $("选择器").事件名称(事件处理函数) $("选择器") &#xff1a; 选择0或者多个dom对象 给他们添加事件 事件名称&#xff1a;就是js中事件名称去掉on的部…

C#调用ffmpeg从视频提取图片

微信公众号“CSharp编程大全”的文章《C#从视频提取图片&#xff1f;》介绍了基于Microsoft.DirectX.AudioVideoPlayback.Video类实现从视频提取图片的方式&#xff0c;本来是想学习并测试该类的用法&#xff0c;但实际测试过程中却没有测通。百度从视频提取图片&#xff0c;网…

c语言-字符函数和字符串函数详解

文章目录 1. 字符分类函数2. 字符转换函数3. strlen的使用和模拟实现4. strcpy的使用和模拟实现5. strncpy函数的使用6. strcat的使用和模拟实现7. strncat函数的使用8. strcmp的使用和模拟实现9. strncmp函数的使用10. strstr的使用和模拟实现11. strtok函数的使用12. strerro…

【云原生】什么是 Kubernetes ?

什么是 Kubernetes &#xff1f; Kubernetes 是一个开源容器编排平台&#xff0c;管理着一系列的 主机 或者 服务器&#xff0c;它们被称作是 节点&#xff08;Node&#xff09;。 每一个节点运行了若干个相互独立的 Pod。 Pod 是 Kubernetes 中可以部署的 最小执行单元&#x…

JAVA之异常详解

1. 异常的概念与体系结构 1.1 异常的概念 在Java中&#xff0c;将程序执行过程中发生的不正常行为称为异常 1. 算术异常 public class Test {public static void main(String[] args) {System.out.println(10/0);} } 因为 0 不能当被除数&#xff0c;所以报出了异常&#…

java协同过滤算法 springboot+vue游戏推荐系统

随着人们生活质量的不断提高以及个人电脑和网络的普及&#xff0c;人们的业余生活质量要求也在不断提高&#xff0c;选择一款好玩&#xff0c;精美&#xff0c;画面和音质&#xff0c;品质优良的休闲游戏已经成为一种流行的休闲方式。可以说在人们的日常生活中&#xff0c;除了…

高级驾驶辅助系统 (ADAS)介绍

随着汽车技术持续快速发展,推动更安全、更智能、更高效的驾驶体验一直是汽车创新的前沿。高级驾驶辅助系统( ADAS ) 是这场技术革命的关键参与者,是 指集成到现代车辆中的一组技术和功能,用于增强驾驶员安全、改善驾驶体验并协助完成各种驾驶任务。它使用传感器、摄像头、雷…

vue开发中遇到的问题记录

文章目录 前言1、css 即时使用了scoped子组件依然会生效2、路由配置如果出现重复name&#xff0c;只会生效最后一个&#xff0c;且前端的路由无效3、组件之间事件传递回调需要先定义emits: []&#xff0c;不然会警告提示总结如有启发&#xff0c;可点赞收藏哟~ 前言 1、css 即…

DBeaver连接Oracle时报错:Undefined Error

连接信息检查了很多遍&#xff0c;应该是没问题的&#xff0c;而且驱动也正常下载了&#xff0c;但是就是连不上。 找了好久&#xff0c;终于找到一个可用的方式了&#xff0c;记录一下。 在安装目录修改dbeave.ini文件&#xff0c;最后一行添加 -Duser.nameTest。重启就可以…

鸿蒙开发板——环境搭建(南派开发)

概述 为了帮大家理清楚鸿蒙开发的套路&#xff0c;我们从头再梳理一遍相关的脉络。并为大家总结一些重点性的内容。在介绍OpenHarmony特性前&#xff0c;需要大家先明确以下两个基本概念&#xff1a; 子系统 OpenHarmony整体遵从分层设计&#xff0c;从下向上依次为&#xf…

FPGA模块——SPI协议(读写FLASH)

FPGA模块——SPI协议&#xff08;读写FLASH&#xff09; &#xff08;1&#xff09;FLASH芯片 W25Q16BV&#xff08;2&#xff09;SPI协议&#xff08;3&#xff09;芯片部分命令1.Write Enable&#xff08;06h&#xff09;2.Chip Erase (C7h / 60h)3.写指令&#xff08;02h&am…

TI 毫米波雷达开发系列之mmWave Studio 和 Visuiallizer 的异同点雷达影响因素分析

TI 毫米波雷达开发之mmWave Studio 和 Visuiallizer 的异同点 引入整个雷达系统研究的目标分析影响这个目标的因素硬件影响因素 —— 雷达系统的硬件结构&#xff08;主要是雷达收发机&#xff09;AWR1642芯片硬件系统组成MSS 和 DSS 概述MSS 和 DSS 分工BSS的分工AWR1642 组成…

crontab计划任务

银河麒麟v10服务器版和桌面版执行周期计划任务分为两类&#xff1a;系统任务调度和用户任务调度。系统任务是由 cron (crond) 这个系统服务来控制的&#xff0c;这个系统服务是默认启动的&#xff0c;通过vim /etc/crontab执行。用户自己设置的计划任务则使用crontab 命令 配置…

Everything进行内网穿透搜索

文章目录 1\. 部署内网穿透1.1. 注册账号1.2. 登录1.3. 创建隧道 2\. 从外网访问Everything 借助cpolar可以让我们在公网上访问到本地的电脑 1. 部署内网穿透 1.1. 注册账号 在使用之前需要先进行注册cpolar cpolar secure introspectable tunnels to localhost 1.2. 登录 C…

【Web】[GKCTF 2021]easycms

直接点击登录按钮没有反应 扫目录扫出来/admin.php 访问 弱口令admin 12345直接登录成功 点开设计--主题--自定义 编辑页头&#xff0c;类型选择php源代码 点保存显示权限不够 设计--组件--素材库 先随便上传一个文件&#xff0c;之后改文件名称为../../../../../system/tmp…

计算机应用基础_错题集_PPT演示文稿_操作题_计算机多媒体技术操作题_文字处理操作题---网络教育统考工作笔记007

PPT演示文稿操作题 提示:PPT部分操作题 将第2~第4张幻灯片背景效果设为渐变预置的“雨后初晴”效果(2)设置幻灯片放映方式

卷积神经网络(Inception-ResNet-v2)交通标志识别

文章目录 一、前言二、前期工作1. 设置GPU&#xff08;如果使用的是CPU可以忽略这步&#xff09;2. 导入数据3. 查看数据 二、构建一个tf.data.Dataset1.加载数据2. 配置数据集 三、构建Inception-ResNet-v2网络1.自己搭建2.官方模型 五、设置动态学习率六、训练模型七、模型评…

【小沐学写作】免费在线AI辅助写作汇总

文章目录 1、简介2、文涌Effidit&#xff08;腾讯&#xff09;2.1 工具简介2.2 工具功能2.3 工具体验 3、PPT小助手&#xff08;officeplus&#xff09;3.1 工具简介3.2 使用费用3.3 工具体验 4、DeepL Write&#xff08;仅英文&#xff09;4.1 工具简介4.2 工具体验 5、天工AI…

数组题目:645. 错误的集合、 697. 数组的度、 448. 找到所有数组中消失的数字、442. 数组中重复的数据 、41. 缺失的第一个正数

645. 错误的集合 思路&#xff1a; 我们定义一个数组cnt&#xff0c;记录每个数出现的次数。然后我们遍历数组&#xff0c;从1开始&#xff0c;如果cnt[i] 0 那就说明这个是错误的数&#xff0c;如果 cnt[i] 2&#xff0c;那就说明是重复的数。 代码&#xff1a; class So…

嵌入式虚拟机原理

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和技术。关…