【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上,非结构化剪枝)

news2024/11/24 11:28:47

文章目录

      • 0. 前言
      • 1. 剪枝`prune`主要功能分类
      • 2. `torch.nn.utils.prune`中的方法介绍
      • 3. PyTorch实例
        • 3.1 `BasePruningMethod`
        • 3.2`PruningContainer`
        • 3.3 `identity`
        • 3.4`random_unstructured`
        • 3.5`l1_unstructured`
      • 4. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

PyTorch中的torch.nn.utils.prune模块是一个专门用于神经网络模型剪枝的工具集。模型剪枝是一种减少神经网络参数数量的技术,其目标是在保持模型性能的同时减少计算成本内存占用。这对于部署模型到资源受限的设备(如移动设备或嵌入式系统)特别有用。

本文将通过实例介绍torch.nn.utils.prune模块中的各个方法,由于内容较多分为上、下两篇。本篇主要介绍非结构化剪枝。

下篇非结构化剪枝链接:【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(下,结构化剪枝)

1. 剪枝prune主要功能分类

torch.nn.utils.prune模块提供了一系列的剪枝方法,包括但不限于:

  1. 无结构剪枝:这种剪枝方法可以独立地移除网络中的权重,而不考虑权重之间的结构关系。例如,L1UnstructuredRandomUnstructured 就是两种无结构剪枝方法,它们分别根据权重的绝对值大小和随机选择的方式移除权重。

  2. 结构化剪枝:与无结构剪枝相反,结构化剪枝会移除整个的结构单位(如整个神经元或通道),而不是单独的权重。RandomStructuredLnStructured 就是这样的例子,它们可以移除整个的通道。

  3. 自定义剪枝CustomFromMask 方法允许用户自定义剪枝策略,通过提供一个掩码来指定哪些权重应该被保留或移除。

  4. 剪枝管理:除了剪枝方法本身,torch.nn.utils.prune还提供了工具来管理和应用剪枝,例如,prune.global_unstructuredprune.remove 方法。前者允许跨多个层执行全局剪枝,而后者则用于移除剪枝操作,恢复原始权重或应用剪枝掩码。

2. torch.nn.utils.prune中的方法介绍

下面是本文将介绍的torch.nn.utils.prune中的方法:

  • BasePruningMethod: 抽象基类,用于创建新的剪枝类。
  • PruningContainer:允许组合多种不同的剪枝策略,并按顺序应用这些策略。
  • identity: 实现了一个不剪枝任何单元仅生成一个全为一的掩码的实用剪枝方法。
  • random_unstructured: 随机剪枝张量中的单元。
  • l1_unstructured: 根据L1范数(绝对值)剪枝张量中的单元。

3. PyTorch实例

为了介绍这些剪枝方法,我们将首先定义一个简单的模型,并使用torch.nn.utils.prune模块中的各种剪枝方法来处理这个模型的权重。我们将以一个简单的卷积层为例,然后应用上述提到的每种剪枝方法。

首先,让我们导入必要的库并定义一个包含单个卷积层的模型:

import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(888)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # 创建一个简单的卷积层
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)

model = SimpleModel()

这里有一个值得注意的地方就是prune的导入:如果不写from torch.nn.utils import prune,而直接在代码中使用torch.nn.utils.prune.xxxx(),会报错↓
在这里插入图片描述
这个报错我不太能理解,不知道会不会在后续版本中更正。

接下来,我们将逐一介绍并应用每种剪枝方法:

3.1 BasePruningMethod

这是一个抽象类,可以理解为自定义剪枝的类。

class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """
3.2PruningContainer

一开始我觉得这个方法和nn.Sequential差不多,但是实际并不是!

PruningContainer通常不会直接由用户实例化,而是作为torch.nn.utils.prune中其他剪枝方法的基础。当调用如l1_unstructuredrandom_unstructuredln_structured等剪枝方法时,内部会创建一个PruningContainer实例,并且将特定的剪枝方法添加到容器中。

3.3 identity

这个方法不会剪枝(改变)任何权重,它只会生成一个全为1的掩码。

print("Weight before Identity pruning:")
print(model.conv.weight)
prune.identity(model.conv, name="weight")
print("Weight after Identity pruning:")
print(model.conv.weight)
print("mask:")
print(model.conv.weight_mask)

输出为:

Weight before Identity pruning:
Parameter containing:
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]]], requires_grad=True)
Weight after Identity pruning:
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]]], grad_fn=<MulBackward0>)
mask:
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
3.4random_unstructured

这个方法会随机选择权重值进行剪枝:

prune.random_unstructured(model.conv, name="weight", amount=0.5) 
#amount参数指定的是要被剪枝(即置零)的权重比例。
print("Weight after RandomUnstructured pruning (50%):")
print(model.conv.weight)

输出为:

Weight after RandomUnstructured pruning (50%):
tensor([[[[-0.0000,  0.0000, -0.0000],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.0000]]]], grad_fn=<MulBackward0>)

可以明显看出,对比3.3节的输出结果,有4个(50%)参数被剪枝(置零)了。

3.5l1_unstructured

这个方法会根据权重的L1范数选择要剪枝的权重。

prune.l1_unstructured(model.conv, name="weight", amount=0.5)
print("Weight after L1Unstructured pruning (50%):")
print(model.conv.weight)

输出为:

Weight after L1Unstructured pruning (50%):
tensor([[[[-0.3017,  0.0000, -0.2468],
          [ 0.2107,  0.0000,  0.1923],
          [ 0.1887, -0.0000,  0.0000]]]], grad_fn=<MulBackward0>)

Process finished with exit code 0

对比3.3输出的结果,可以看出L1范数(绝对值)最小的4个(50%)参数被剪枝(置零)了。

4. 总结

本文介绍了PyTorch中的prune模型剪枝模块的中的非结构化剪枝,下一篇将介绍结构化剪枝。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1879676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

接口请求网关超时排查和引发的思考

问题描述 前端请求服务端接口&#xff0c;返回504 Gateway Timeout&#xff0c;请求接口为https://profile.noodles.com/user-mail-pub/api/user-mail/user-trash-mails?userId123456 原因分析 观察日志&#xff0c;发现网关链接超时&#xff0c;并且对应请求没有达到对应服…

四川省高等职业学校大数据技术专业建设暨专业质量监测研讨活动顺利开展

6月21日&#xff0c;省教育评估院在四川邮电职业技术学院组织开展全省高等职业学校大数据技术专业建设暨专业质量监测研讨活动。省教育评估院副院长赖长春&#xff0c;四川邮电职业技术学院党委副书记、校长冯远洪&#xff0c;四川邮电职业技术学院党委委员、副校长程德杰等出席…

【python】python知名品牌调查问卷数据分析可视化(源码+调查数据表)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

抓紧收藏!7 款令人惊艳的 AI 开源项目

&#x1f43c; 关注我, 了解更多 AI 前沿资讯和玩法&#xff0c;AI 学习之旅上&#xff0c;我与您一同成长&#xff01; &#x1f388; 进入公众号&#xff0c;回复 AI, 可免费领取超多实用的 AI 资料 和内容丰富的 AI 知识库地址。 自从去年 AIGC 兴起以来&#xff0c;AI 开源…

gin 服务端无法使用sse流式nginx配置

我在本地使用 gin 可以流式的将大模型数据传递给前端。但是当我部署到服务器中时&#xff0c;会阻塞一段时间&#xff0c;然后显示一大段文本。 起初我怀疑是gin 没有及时将数据刷到管道中&#xff0c;但是经过测试&#xff0c;还是会阻塞。 c.Writer.(http.Flusher).Flush()最…

使用LabVIEW报告生成工具包时报错97

问题详情&#xff1a; 在运行使用Excel/Word调用节点的程序时&#xff0c;收到错误97&#xff1a;LabVIEW&#xff1a;&#xff08;十六进制0x61&#xff09;输入中传递了一个空引用句柄或先前已删除的引用句柄。 当运行报告生成工具包中的一个示例程序时&#xff0c;收到错误…

【python】python入门day2——数据类型与运算

python数据类型与运算 一、Python中变量的数据类型1、数据类型分类2、数值类型3、布尔类型4、字符串类型5、其他类型(了解) 二、Python数据类型转换1、使用Python实现超市的收银系统2、Python数据类型的转换方法3、总结 三、Python运算符1、算术运算符3、赋值运算符4、复合赋值…

计算机科学基础简单介绍(1—6)

计算机影响了我们生活的方方面面&#xff0c;在我们这个时代完全渗透了我们的生活。 最早是算盘、星盘、时钟、尺卡等古老的计算工具&#xff0c;后来出现了进步计算机&#xff0c;类似与汽车里程表的一种机械工具&#xff0c;但是他也是手工制品。经过历史的演变与发展&#x…

Prompting已死?DSPy:自动优化LLM流水线

在 LLM 应用中&#xff0c;如何优化一个 pipeline 的流程一直是一个比较头疼的问题。提示词作为一个预定义字符串&#xff0c;往往也没有很好地优化方向。本文中的 DSPy 框架或许能在实际应用中对效果优化起到一定帮助。 当前&#xff0c;在 LLM 的应用中&#xff0c;大家都在探…

LSTM时间序列基础学习

时间序列 时间序列可以是一维&#xff0c;二维&#xff0c;三维甚至更高维度的数据&#xff0c;在深度学习的世界中常见的是三维时间序列&#xff0c;这三个维度分别是&#xff08;batch_size,time_step,input_dimensions&#xff09;。 其中time_step是时间步&#xff0c;它…

GPU配置pytorch环境(links for torch)

一、创建一个新的虚拟环境 二、激活虚拟环境 三、打开或新建一个pycharm项目&#xff0c;把环境选成我们刚刚新建的虚拟环境 四、从links for torch网站下载与自己cuda版本和python版本对应的torch 五、在pycharm的终端pip install 安装torch 直到显示成功安装 六、验证pytorch…

六月,允许自己做自己,别人做别人

今天结束后&#xff0c;2024 就过去一半了。 年初的规划完成一半了吗&#xff1f;如果没有也没关系&#xff0c;做你自己继续前进。 家人来北京旅游&#xff0c;我累趴了 六月初&#xff0c;我搬家了&#xff0c;这次租了一整套房&#xff0c;是一个小俩居、还带一个小阁楼。…

新手练习项目 6:图书管理系统

名人说&#xff1a;莫听穿林打叶声&#xff0c;何妨吟啸且徐行。—— 苏轼《定风波莫听穿林打叶声》 Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#xff09; 目录 一、项目描述二、项目结构三、项目步骤步骤1&#xff1a;定义Book类步骤2&#xff1a;实现主程序…

FHE全同态加密介绍——小白版

1. 何为FHE&#xff1f; FHE中的evluation key p k e v a l pk_{eval} pkeval​是public的&#xff0c;用于密文计算逻辑 f ( ⋅ ) f(\cdot) f(⋅)的evalute circuit中&#xff0c;但根据所处理数据加解密密钥的不同&#xff0c;可将FHE分为&#xff1a; 1&#xff09;对称F…

Web后端开发概述环境搭建项目创建servlet生命周期

Web开发概述 web开发指的就是网页向后再让发送请求,与后端程序进行交互 web后端(javaEE)程序需要运行在服务器中 这样前端才可以对其进行进行访问 什么是服务器? 解释1: 服务器就是一款软件,可以向其发送请求,服务器会做出一个响应.可以在服务器中部署文件&#xff0c;让…

【ai】trition:tritonclient.utils.shared_memory 仅支持linux

Can’t find tritonclient.utils.shared_memory on WIN10 #4149yolov4的python客户端 导入以后,windows 的pycharm 就是看不到折腾了很久:SaviorEnv 环境下安装tritonclient[all]也会失败 (base) C:\Users\zhangbin>conda create -n SaviorEnv python=3.8 Collecting pack…

计算机体系结构和指令系统

1.计算机体系结构 - 五大部件 - 冯 诺依曼 计算机的特点 1.计算机有五大部件组成 2.指令和数据以同等地位存储于存储器&#xff0c;可按照地址访问 3.指令和数据用二进制表示 4.指令由操作码和地址码组成 5。存储程序 6.以计算器为中心&#xff08;输入、输出设备与存储器…

成都市水资源公报(2000-2022年)

数据年限&#xff1a;2000-2022年&#xff0c;无2009年 数据格式&#xff1a;pdf、word、jpg 数据内容&#xff1a;降水量、地表水资源量、地下水资源量、水资源总量、蓄水状况、平原区浅层地下水动态、水资源情况分析、供水量、用水量、污水处理、洪涝干旱等

[XYCTF新生赛]-PWN:EZ1.0?(mips,mips的shellcode利用)

查看保护 查看ida 这里用的是retdec&#xff0c;没安装的可以看这个[CTF]-PWN:mips反汇编工具&#xff0c;ida插件retdec的安装-CSDN博客 这里直接看反汇编貌似看不出什么&#xff0c;所以直接从汇编找 完整exp&#xff1a; from pwn import* context(log_leveldebug,archmip…

【机器学习】在【Pycharm】中的应用:【线性回归模型】进行【房价预测】

专栏&#xff1a;机器学习笔记 pycharm专业版免费激活教程见资源&#xff0c;私信我给你发 python相关库的安装&#xff1a;pandas,numpy,matplotlib&#xff0c;statsmodels 1. 引言 线性回归&#xff08;Linear Regression&#xff09;是一种常见的统计方法和机器学习算法&a…