基于torch实现模型剪枝

news2025/1/8 6:01:52

一、剪枝分类
所谓模型剪枝,其实是一种从神经网络中移除"不必要"权重或偏差(weigths/bias)的模型压缩技术。关于什么参数才是“不必要的”,这是一个目前依然在研究的领域。
1.1、非结构化剪枝
非结构化剪枝(Unstructured Puning)是指修剪参数的单个元素,比如全连接层中的单个权重、卷积层中的单个卷积核参数元素或者自定义层中的浮点数(scaling floats)。其重点在于,剪枝权重对象是随机的,没有特定结构,因此被称为非结构化剪枝。
1.2、结构化剪枝
与非结构化剪枝相反,结构化剪枝会剪枝整个参数结构。比如,丢弃整行或整列的权重,或者在卷积层中丢弃整个过滤器(Filter)。
1.3、本地与全局修剪
剪枝可以在每层(局部)或多层/所有层(全局)上进行。
二、pytorch的剪枝
目前 PyTorch 框架支持的权重剪枝方法有:
Random: 简单地修剪随机参数。
Magnitude: 修剪权重最小的参数(例如它们的 L2 范数)
以上两种方法实现简单、计算容易,且可以在没有任何数据的情况下应用。
2.1、torch剪枝工作原理
剪枝功能在 torch.nn.utils.prune 类中实现,代码在文件 torch/nn/utils/prune.py 中,主要剪枝类如下图所示:
在这里插入图片描述
剪枝原理是基于张量(Tensor)的掩码(Mask)实现。掩码是一个与张量形状相同的布尔类型的张量,掩码的值为 True 表示相应位置的权重需要保留,掩码的值为 False 表示相应位置的权重可以被删除。

Pytorch 将原始参数 复制到名为 _original 的参数中,并创建一个缓冲区来存储剪枝掩码 _mask。同时,其也会创建一个模块级的 forward_pre_hook 回调函数(在模型前向传播之前会被调用的回调函数),将剪枝掩码应用于原始权重。

pytorch 剪枝的 api 大致可作如下区分:
在这里插入图片描述
pytorch 中进行模型剪枝的工作流程如下:
1、选择剪枝方法(或者子类化 BasePruningMethod 实现自己的剪枝方法)。
2、指定剪枝模块和参数名称。
3、设置剪枝方法的参数,比如剪枝比例等。

2.2,局部剪枝
Pytorch 框架中的局部剪枝有非结构化和结构化剪枝两种类型,值得注意的是结构化剪枝只支持局部不支持全局。

2.2.1,局部非结构化剪枝
对应的函数原型如下:

def random_unstructured(module, name, amount)

1)函数功能:
用于对权重参数张量进行非结构化剪枝。该方法会在张量中随机选择一些权重或连接进行剪枝,剪枝率由用户指定。
2)函数参数定义:
module (nn.Module): 需要剪枝的网络层/模块,例如 nn.Conv2d() 和 nn.Linear()。
name (str): 要剪枝的参数名称,比如 “weight” 或 “bias”。
amount (int or float): 指定要剪枝的数量,如果是 0~1 之间的小数,则表示剪枝比例;如果是整数,则直接剪去参数的绝对数量。比如amount=0.2 ,表示将随机选择 20% 的元素进行剪枝。
3)使用示例:

import torch
import torch.nn.utils.prune as prune
conv = torch.nn.Conv2d(1, 1, 4)
prune.random_unstructured(conv, name="weight", amount=0.5)
conv.weight
"""
tensor([[[[-0.1703,  0.0000, -0.0000,  0.0690],
          [ 0.1411,  0.0000, -0.0000, -0.1031],
          [-0.0527,  0.0000,  0.0640,  0.1666],
          [ 0.0000, -0.0000, -0.0000,  0.2281]]]], grad_fn=<MulBackward0>)
"""

可以看书输出的 conv 层中权重值有一半比例为 0。

2.2.2、局部结构化剪枝
局部结构化剪枝有两种函数,对应的两种函数原型如下:

def random_structured(module, name, amount, dim)
def ln_structured(module, name, amount, n, dim, importance_scores=None)

1)函数功能:
与非结构化移除的是连接权重不同,结构化剪枝移除的是整个通道权重。
2)参数定义:
与局部非结构化函数非常相似,唯一的区别是必须定义 dim 参数(ln_structured 函数多了 n 参数)。

n 表示剪枝的范数,dim 表示剪枝的维度。

对于 torch.nn.Linear:
dim = 0:移除一个神经元。
dim = 1:移除与一个输入的所有连接。

对于 torch.nn.Conv2d:
dim = 0(Channels) : 通道 channels 剪枝/过滤器 filters 剪枝
dim = 1(Neurons): 二维卷积核 kernel 剪枝,即与输入通道相连接的 kernel

2.2.3,局部结构化剪枝示例代码
在写示例代码之前,我们先需要理解 Conv2d 函数参数、卷积核 shape、轴以及张量的关系。

首先,Conv2d 函数原型如下:

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

torch 中conv2d的卷积核权重 shape 都为(C_out, C_in, kernel_height, kernel_width),所以在代码中卷积层权重 shape 为 [3, 2, 3, 3],dim = 0 对应的是 shape [3, 2, 3, 3] 中的 3。这里我们 dim 设定了哪个轴,那自然剪枝之后权重张量对应的轴机会发生变换。
在这里插入图片描述
理解了前面的关键概念,下面就可以实际使用了,dim=0 的示例如下所示:

conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3])
print(norm1)
"""
tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)
print(conv.weight)
"""
tensor([[[[-0.0005,  0.1039,  0.0306],
          [ 0.1233,  0.1517,  0.0628],
          [ 0.1075, -0.0606,  0.1140]],

         [[ 0.2263, -0.0199,  0.1275],
          [-0.0455, -0.0639, -0.2153],
          [ 0.1587, -0.1928,  0.1338]]],


        [[[-0.2023,  0.0012,  0.1617],
          [-0.1089,  0.2102, -0.2222],
          [ 0.0645, -0.2333, -0.1211]],

         [[ 0.2138, -0.0325,  0.0246],
          [-0.0507,  0.1812, -0.2268],
          [-0.1902,  0.0798,  0.0531]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>)
"""

从运行结果可以明显看出,卷积层参数的最后一个通道参数张量被移除了(为 0 张量),其解释参见下图:
在这里插入图片描述
dim = 1 的情况:

conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3])
print(norm1)
"""
tensor([3.1487, 3.9088], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)
print(conv.weight)
"""
tensor([[[[ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]],

         [[-0.2140,  0.1038,  0.1660],
          [ 0.1265, -0.1650, -0.2183],
          [-0.0680,  0.2280,  0.2128]]],


        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]],

         [[-0.2087,  0.1275,  0.0228],
          [-0.1888, -0.1345,  0.1826],
          [-0.2312, -0.1456, -0.1085]]],


        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[-0.0891,  0.0946, -0.1724],
          [-0.2068,  0.0823,  0.0272],
          [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>)
"""

很明显,对于 dim=1的维度,其第一个张量的 L2 范数更小,所以shape 为 [2, 3, 3] 的张量中,第一个 [3, 3] 张量参数会被移除(即张量为 0 矩阵) 。

2.3,全局非结构化剪枝
前文的 local 剪枝的对象是特定网络层,而 global 剪枝是将模型看作一个整体去移除指定比例(数量)的参数,同时 global 剪枝结果会导致模型中每层的稀疏比例是不一样的。

全局非结构化剪枝函数原型如下:

def global_unstructured(parameters, pruning_method, **kwargs)

def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):

1)函数功能:
随机选择全局所有参数(包括权重和偏置)的一部分进行剪枝,而不管它们属于哪个层。
2)参数定义:
parameters((Iterable of (module, name) tuples)): 修剪模型的参数列表,列表中的元素是 (module, name)。
pruning_method(function): 目前好像官方只支持 pruning_method=prune.L1Unstuctured,另外也可以是自己实现的非结构化剪枝方法函数。
importance_scores: 表示每个参数的重要性得分,如果为 None,则使用默认得分。
**kwargs: 表示传递给特定剪枝方法的额外参数。比如 amount 指定要剪枝的数量。
3)global_unstructured 函数的示例代码如下所示:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# 计算卷积层和整个模型的稀疏度
# 其实调用的是 Tensor.numel 内内函数,返回输入张量中元素的总数
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
# 程序运行结果
"""
Sparsity in conv1.weight: 3.70%
Global sparsity: 20.00%
"""

运行结果表明,虽然模型整体(全局)的稀疏度是 20%,但每个网络层的稀疏度不一定是 20%。
三、总结

pytorch 框架还提供了一些帮助函数:
torch.nn.utils.prune.is_pruned(module): 判断模块 是否被剪枝。
torch.nn.utils.prune.remove(module, name):用于将指定模块中指定参数上的剪枝操作移除,从而恢复该参数的原始形状和数值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/549096.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是可持续能源?

随着全球经济的不断发展和人口的不断增长&#xff0c;能源问题越来越受到关注。传统能源已经不能满足人们对能源的需求&#xff0c;同时也对环境和健康带来了严重的影响。为了解决这些问题&#xff0c;出现了可持续能源的概念。那么&#xff0c;什么是可持续能源呢&#xff1f;…

逐渐从土里长出来的小花

从土里逐渐长出来的小花&#xff08;这是长出来后的样子&#xff0c;图片压缩了出现了重影~&#xff09; 代码在这里&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title&g…

MySQL-索引(2)

本文主要讲解MySQL-索引相关的知识点 联合索引前缀索引覆盖索引索引下推索引的优缺点什么时候适合创建索引,什么时候不适合?如何优化索引 ? 索引失效场景 ? 为什么SQL语句使用了索引,却还是慢查询 ? 使用索引有哪些注意事项 ? InnoDB引擎中的索引策略 目录 联合索引 联合…

LeetCode高频算法刷题记录6

文章目录 1. 编辑距离【困难】1.1 题目描述1.2 解题思路1.3 代码实现 2. 寻找两个正序数组的中位数【困难】2.1 题目描述2.2 解题思路2.3 代码实现 3. 合并区间【中等】3.1 题目描述3.2 解题思路3.3 代码实现 4. 爬楼梯【简单】4.1 题目描述4.2 解题思路4.3 代码实现 5. 排序链…

chatgpt赋能Python-python3_9安装numpy

Python 3.9 安装 NumPy 的完整指南 Python是一种功能强大的编程语言&#xff0c;已成为数据分析、人工智能和科学计算领域的主流语言之一。NumPy是一个Python库&#xff0c;用于执行高效的数值计算和科学计算操作。Python 3.9是Python最新版本&#xff0c;带来了许多新功能和改…

一款非常有趣的中国版本的Excalidraw作图工具drawon(桌案)

桌案工具集成了很多有趣的在线作图工具&#xff0c; 思维导图&#xff0c; 流程图&#xff0c;以及草图&#xff0c;在线ppt等功能。 而草图是基于国外有名的Excalidraw而改造而来&#xff0c;使得它更符合国人的使用习惯。 最近在 使用excalidraw时&#xff0c;发现了很多新功…

Excel | 基因名都被Excel篡改了怎么办呢!?~(附3种解决方案)

1写在前面 今天和大家分享一下在做表达矩阵处理时尝尝会遇到的一个问题&#xff0c;但又经常被忽视&#xff0c;就是Excel会修改你的基因名。&#x1f637; 无数大佬在这里都踩过坑&#xff0c;这些普遍的问题已经被写成了paper&#xff08;左右滑动&#xff09;&#xff1a;&a…

75.建立一个主体样式第一部分

我们的目标如下图所示 ● 首先建立文件夹&#xff0c;生成框架代码 ● 把页面上面的HTML元素写进去 <header><nav><div>LOGO</div><div>NAVIGATION</div></nav><div><h1>A healty meal delivered to your door, ever…

Java基础--->并发部分(2)【Java中的锁】

文章目录 synchronized和ReentrantLock的区别Java中锁的名词synchronized锁ReentrantLock锁 synchronized和ReentrantLock的区别 synchronized 和 ReentrantLock 都可以用来实现 Java 中的线程同步。它们的作用类似&#xff0c;但是在用法和特性上还是有一些区别的。 synchroni…

【2023/05/20】Visual Basic

Hello&#xff01;大家好&#xff0c;我是霜淮子&#xff0c;2023倒计时第15天。 Visual Basic是一种广泛应用于Windows操作系统的编程语言&#xff0c;它是Microsoft公司开发的一种面向对象的编程语言&#xff0c;以其简单、易学、易用的特点受到广泛欢迎。本文旨在介绍Visual…

2023年申请美国大学,需要SAT/ACT成绩吗?

受疫情影响&#xff0c;2021 和 2022 年申请美国大学时&#xff0c;许多大学都放宽了SAT/ACT门槛&#xff0c;不强行要求学生提交标化成绩。今年3月&#xff0c;理工大牛院校 MIT 率先打破了这个局面&#xff0c;宣布恢复 SAT/ACT 标化成绩要求&#xff0c;随后几个大学也陆续宣…

Java --- 云尚办公用户管理模块实现

目录 一、用户管理 1.1、数据库表 1.2、使用代码生成器生成相关代码 1.3、后端代码 1.4、前端代码 二、用户与角色功能实现 一、用户管理 1.1、数据库表 CREATE TABLE sys_user (id BIGINT(20) NOT NULL AUTO_INCREMENT COMMENT 会员id,username VARCHAR(20) NOT NU…

Java面试知识点(全)-spring面试知识点一

Java面试知识点(全) 导航&#xff1a; https://nanxiang.blog.csdn.net/article/details/130640392 注&#xff1a;随时更新 Spring原理 Spring ioc概念&#xff1a;控制权由对象本身转向容器&#xff1b;由容器根据配置文件去创建实例并创建各个实例之间的依赖关系。核心&am…

学术会议参会经验分享一(参会前的准备工作)

前前后后参加了两次学术会议&#xff0c;一次是今年三月份在深圳&#xff0c;另一次是在五月份在南宁&#xff0c;并且两次都进行了主题演讲。总的来说&#xff0c;我感觉参加学术会议重要的是自身能力的提升&#xff0c;比如说演讲、PPT制作等更方面的能力。下面我来分享一些我…

USRP概念基础

GBIC Gigabit Interface Converter的缩写,是将千兆位电信号转换为光信号的接口器件。GBIC设计上可以为热插拔使用。 SFP SFP (Small Form Pluggable)可以简单理解为GBIC(Gigabit Interface Converter的缩写)升级版本,是将千兆位电信号转换为光信号的接口器件,可以热插…

python字符串拼接

首先 什么是字符串拼接 我们来看一个段代码 print("你好""小猫猫")运行结果如下 这是一个最简单的演示 字符串 与 字符串的拼接 两个字符串字面量可以直接用加号 合并成一个字符串 当然 直接这里 字面量字面量 直接写上去看着会非常傻 所以 一般都是 变…

( 动态规划) 516. 最长回文子序列 ——【Leetcode每日一题】

❓516. 最长回文子序列 难度&#xff1a;中等 给你一个字符串 s &#xff0c;找出其中最长的回文子序列&#xff0c;并返回该序列的长度。 子序列定义为&#xff1a;不改变剩余字符顺序的情况下&#xff0c;删除某些字符或者不删除任何字符形成的一个序列。 示例 1&#xf…

MarkDown语法2

MarkDown语法2 一、基本语法 1. 标题 一级标题&#xff1a;# 一级标题二级标题&#xff1a;## 二级标题 2. 字体 斜体&#xff1a;*斜体*,_斜体_粗体&#xff1a;**粗体**&#xff0c;__粗体__粗斜体&#xff1a;***粗斜体***, ___粗斜体___ 3. 线 分割线&#xff1a;&a…

java常用工具之Objects类

目录 简介一、对象判空二、 对象为空时抛异常三、 判断两个对象是否相等四、 获取对象的hashcode五、 比较两个对象六、比较两个数组七、 小结 简介 Java 的 Objects 类是一个实用工具类&#xff0c;包含了一系列静态方法&#xff0c;用于处理对象。它位于 java.util 包中&…

七、Spring Cloud Alibaba-Sentinel

一、引言 1、了解服务可用性问题&#xff0c;服务挂掉原因 缓存击穿、单点故障、流量激增、线程池爆满、CPU飙升、DB超时、缺乏容错机制或保护机制、负载不均、服务雪崩、异常没处理等。 服务雪崩效应&#xff1a;因服务提供者的不可用导致服务调用者的不可用&#xff0c;并将…