模型剪枝入门

news2025/1/10 21:00:39

一、定义

1.定义
2. 案例1
3. 全局剪枝案例
4. 全局剪枝案例
5. 自定义剪枝
6. 特定网络剪枝
7. 多参数模块剪枝
8. torch.nn.utils.prune 解读

二、实现

  1. 定义
    在这里插入图片描述
  2. 接口:
import torch.nn.utils.prune as prune
  1. 案例1
import torch.nn as nn
import torch.nn.utils.prune as prune

import torch

def prune_first_layer(module, inputs):
    # 对权重矩阵进行L1剪枝,保留80%的元素
    prune.l1_unstructured(module,"weight", amount=0.8)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        x = nn.functional.relu(x)
        x = self.fc3(x)
        return x

net = Net()

# 将钩子函数与第一个全连接层关联起来
handle = net.fc1.register_forward_pre_hook(prune_first_layer)   #fc1 执行之前进行剪枝

# 进行前向传递
output = net(torch.randn(1, 784))
# 移除钩子函数
handle.remove()
  1. 全局剪枝案例
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

# 首先打印初始化模型的状态字典
print(model.state_dict().keys())
print('*'*50)

# 构建参数集合, 决定哪些层, 哪些参数集合参与剪枝
parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'),
            (model.fc3, 'weight'))

# 调用prune中的全局剪枝函数global_unstructured执行剪枝操作, 此处针对整体模型中的20%参数量进行剪枝
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

# 最后打印剪枝后的模型的状态字典
print(model.state_dict().keys())


print(
    "Sparsity in conv1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0))
    / float(model.conv1.weight.nelement())
    ))

print(
    "Sparsity in conv2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv2.weight == 0))
    / float(model.conv2.weight.nelement())
    ))

print(
    "Sparsity in fc1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc1.weight == 0))
    / float(model.fc1.weight.nelement())
    ))

print(
    "Sparsity in fc2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc2.weight == 0))
    / float(model.fc2.weight.nelement())
    ))

print(
    "Sparsity in fc3.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc3.weight == 0))
    / float(model.fc3.weight.nelement())
    ))

print(
    "Global sparsity: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0)
               + torch.sum(model.conv2.weight == 0)
               + torch.sum(model.fc1.weight == 0)
               + torch.sum(model.fc2.weight == 0)
               + torch.sum(model.fc3.weight == 0))
         / float(model.conv1.weight.nelement()
               + model.conv2.weight.nelement()
               + model.fc1.weight.nelement()
               + model.fc2.weight.nelement()
               + model.fc3.weight.nelement())
    ))

# 当采用全局剪枝策略的时候(假定20%比例参数参与剪枝),
# 仅保证模型总体参数量的20%被剪枝掉,
# 具体到每一层的情况则由模型的具体参数分布情况来定.
import torchfrom torch.nn.utils import prune# 
定义模型model = torch.nn.Sequential(   
 torch.nn.Linear(100, 50), 
 torch.nn.ReLU(),  
 torch.nn.Linear(50, 10))
 # 剪枝网络结构
prune.random_unstructured(model, amount=0.2)
# 剪枝权重
prune.l1_unstructured(model, amount=0.2)
  1. 自定义剪枝
#用户自定义剪枝(Custom pruning).
# 剪枝模型通过继承class BasePruningMethod()来执行剪枝,
# 内部有若干方法: call, apply_mask, apply, prune, remove等等.
# 一般来说, 用户只需要实现__init__, 和compute_mask两个函数即可完成自定义的剪枝规则设定.
import time
# 自定义剪枝方法的类, 一定要继承prune.BasePruningMethod
class myself_pruning_method(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # 内部实现compute_mask函数, 完成程序员自己定义的剪枝规则, 本质上就是如何去mask掉权重参数
    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        # 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数量的50%被mask掉
        mask.view(-1)[::2] = 0
        return mask

# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def myself_unstructured_pruning(module, name):
    myself_pruning_method.apply(module, name)
    return module


# 实例化模型类
model = LeNet().to(device=device)

start = time.time()
# 调用自定义剪枝方法的函数, 对model中的第三个全连接层fc3中的偏置bias执行自定义剪枝
myself_unstructured_pruning(model.fc3, name="bias")

# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc3.bias_mask)

# 打印一下自定义剪枝的耗时
duration = time.time() - start
print(duration * 1000, 'ms')

# 打印出来的bias_mask张量, 完全是按照预定义的方式每隔一位遮掩掉一位,
#  0和1交替出现, 后续执行remove操作的时候,
# 原始的bias_orig中的权重就会同样的被每隔一位剪枝掉一位.
  1. 特定网络剪枝
# 第一种: 对特定网络模块的剪枝(Pruning Model).

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

# 序列化一个剪枝模型(Serializing a pruned model):
# 对于一个模型来说, 不管是它原始的参数, 拥有的属性值, 还是剪枝的mask buffers参数
# 全部都存储在模型的状态字典中, 即state_dict()中.
# 将模型初始的状态字典打印出来
print(model.state_dict().keys())
print('*'*50)

# 对模型进行剪枝操作, 分别在weight和bias上剪枝
module = model.conv1
prune.random_unstructured(module, name="weight", amount=0.3)
prune.l1_unstructured(module, name="bias", amount=3)

# 再将剪枝后的模型的状态字典打印出来
print(model.state_dict().keys())

# 对模型执行剪枝remove操作.
# 通过module中的参数weight_orig和weight_mask进行剪枝, 本质上属于置零遮掩, 让权重连接失效.
# 具体怎么计算取决于_forward_pre_hooks函数.
# 这个remove是无法undo的, 也就是说一旦执行就是对模型参数的永久改变.

# 打印剪枝后的模型参数
print(list(module.named_parameters()))
print('*'*50)

# 打印剪枝后的模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)

# 打印剪枝后的模型weight属性值
print(module.weight)
print('*'*50)

# 打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)
print('*'*50)

# 执行剪枝永久化操作remove
prune.remove(module, 'weight')
print('*'*50)

# remove后再次打印模型参数
print(list(module.named_parameters()))
print('*'*50)

# remove后再次打印模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)

# remove后再次打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)

# 对模型的weight执行remove操作后, 模型参数集合中只剩下bias_orig了,
# weight_orig消失, 变成了weight, 说明针对weight的剪枝已经永久化生效.
# 对于named_buffers张量打印可以看出, 只剩下bias_mask了,
# 因为针对weight做掩码的weight_mask已经生效完毕, 不再需要保留了.
# 同理, 在_forward_pre_hooks中也只剩下针对bias做剪枝的函数了.
  1. 多参数模块剪枝
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 第二种: 多参数模块的剪枝(Pruning multiple parameters).
model = LeNet().to(device=device)

# 打印初始模型的所有状态字典
print(model.state_dict().keys())
print('*'*50)

# 打印初始模型的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)

# 对于模型进行分模块参数的剪枝
for name, module in model.named_modules():
    # 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)
    # 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝
    elif isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)

# 打印多参数模块剪枝后的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)

# 打印多参数模块剪枝后模型的所有状态字典名称
print(model.state_dict().keys())
  1. torch.nn.utils.prune 解读:用于修剪模块参数的实用类和函数。
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938835.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入理解Linux网络(二):UDP接收内核探究

深入理解Linux网络(二):UDP接收内核探究 一、UDP 协议处理二、recvfrom 系统调⽤实现 一、UDP 协议处理 udp 协议的处理函数是 udp_rcv。 //file: net/ipv4/udp.c int udp_rcv(struct sk_buff *skb) {return __udp4_lib_rcv(skb, &udp_…

IntelliJ IDEA 直接在软件中更新为最新版

当我们的 IDEA 工具许久没有更新,已经拖了好几个版本,想跨大版本更新,比如从2020.2.1 -> 2023.x.x 此时,我们菜单栏点击 Help -> Check for Updates… ,右下角会有提示更新,如下图: 点…

【深大计算机系统(2)】实验一 实验环境配置与使用 附常用指令

目录 一、 实验目标: 二、实验环境与工件: 三、实验内容与步骤 1. 学习并熟悉Linux基本操作,按照要求创建用户。(30分) 2.新建用户主目录下创建子目录:gdbdebug,并进入gdbdebug子目录。将过程和…

亲测--linux下安装ffmpeg最新版本---详细教程

下载地址 Download FFmpeg 下载最新的https://ffmpeg.org/releases/ffmpeg-7.0.1.tar.xz 上传到服务器 解压 tar xvf ffmpeg-7.0.1.tar.xz 编译 cd ffmpeg-7.0.1 ./configure --prefix=/usr/local/ffmpeg make && make install 报错: 解决:在后面加 跳过检测…

粉尘传感器助力面粉厂安全生产

在面粉加工行业中,粉尘问题一直是一个不容忽视的难题。从原料的破碎、研磨到成品的包装,整个生产流程中都会伴随着大量的粉尘产生。这些粉尘不仅影响生产环境,更对工作人员的健康、设备的安全运行以及环境保护构成严重威胁。因此,…

【Unity实战100例】Unity声音可视化多种显示效果

目录 一、技术背景 二、界面搭建 三、 实现 UIAudioVisualizer 基类 四、实现 AudioSampler 类 五、实现 IAudioSample 接口 六、实现MusicAudioVisualizer 七、实现 MicrophoneAudioManager 类 八、实现 MicrophoneAudioVisualizer 类 九、源码下载 Unity声音可视化四…

云计算数据中心(三)

目录 四、自动化管理(一)自动化管理的特征(二)自动化管理实现阶段(三)Facebook自动化管理 五、容灾备份(一)容灾系统的等级标准(二)容灾备份的关键技术&#…

基于图片中的表格检测与识别

1、项目介绍 本文将会使用Microsoft开源的表格检测模型table-transformer-detection来实现表格检测与入门。 以下将分三部分进行介绍: 表格检测:检测图片或PDF文件中的表格所在的区域表格结构识别:对于检测后的表格区域,再详细识…

Langchain[3]:Langchain架构演进与功能扩展:流式事件处理、事件过滤机制、回调传播策略及装饰器应用

Langchain[3]:Langchain架构演进与功能扩展:流式事件处理、事件过滤机制、回调传播策略及装饰器应用 1. Langchain的演变 v0.1: 初始版本,包含基本功能。 从0.1~0.2完成的特性: 通过事件流 API 提供更好的流式支持。标准化工具调用支持Tool…

全国区块链职业技能大赛国赛考题前端功能开发

任务3-1:区块链应用前端功能开发 1.请基于前端系统的开发模板,在登录组件login.js、组件管理文件components.js中添加对应的逻辑代码,实现对前端的角色选择功能,并测试功能完整性,示例页面如下: 具体要求如下: (1)有明确的提示,提示用户选择角色; (2)用户可看…

Java中静态代理和动态代理介绍和使用

前言 在Java中,代理模式是一种常用的设计模式,用于为其他对象提供一种代理以控制对这个对象的访问。代理模式主要有两种实现方式:静态代理和动态代理。 一、静态代理 静态代理是由程序员手动创建或指定代理类,代理类在程序运行…

【C语言】详解结构体(下)(位段)

文章目录 前言1. 位段的含义2. 位段的声明3. 位段的内存分配(重点)3.1 存储方向的问题3.2 剩余空间利用的问题 4. 位段的跨平台问题5. 位段的应用6. 总结 前言 相信大部分的读者在学校或者在自学时结构体的知识时,可能很少会听到甚至就根本没…

sip代理服务器、SIP用户代理服务器、sip服务器的区别和联系

一.SIP代理服务器(SIP Proxy Server)和SIP用户代理服务器(SIP User Agent Server,简称SIP UAS)的区别和联系。 1. 区别 1)功能定位 SIP代理服务器:主要负责将SIP请求消息从发起方…

VBA技术资料MF175:利用文本框和列表框实现多列数据录入

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。“VBA语言専攻”提供的教程一共九套,分为初级、中级、高级三大部分,教程是对VBA的系统讲解&#…

学习周报:文献阅读+水动力学方程推导

目录 摘要 Abstract 文献阅读:物理信息神经网络学习自由表面流 文献摘要 讨论|结论 预备知识 浅水方程SWE(Shallow Water Equations) 质量守恒方程: 动量守恒方程: Godunov通量法: 基本原理&…

分布式会话拦截器

1.分布式会话拦截器-构建拦截器 背景:对于不同的用户进行权限拦截(基于token的判断) 实现过程:在api下构建包以及相关的文件,创建UserTokenInterceptor,实现implements handlerInterceptor.重写三种主要方法。 preHandle postHandle afterCo…

MongoDB文档整理

过往mongodb文档: https://blog.csdn.net/qq_46921028/article/details/123361633https://blog.csdn.net/qq_46921028/article/details/131136935https://blog.csdn.net/qq_46921028/article/details/139247847 1. MongoDB前瞻 1、MongoDB概述: MongoDB是…

【Rust日报】在 Linux 文件系统中使用 Rust 的讨论

SIMD 加速的迭代器 单指令流多数据流(Single Instruction Multiple Data,缩写:SIMD)是一种采用一个控制器来控制多个处理器,同时对一组数据(又称"数据向量")中的每一个分别执行相同的…

PDF压缩软件电脑版 电脑pdf压缩怎么压缩文件

在数字化时代,pdf文件因其良好的兼容性和稳定性,已成为工作与生活中不可或缺的文件格式。然而,随着内容的增多,pdf文件的体积也随之增大,给文件的传输和存储带来了一定的困扰。本文将为你详细介绍如何在电脑上压缩pdf文…

【手撕数据结构】拿捏单链表

目录 单链表介绍链表的初始化打印链表增加节点尾插头插再给定位置之后插入在给定位置之前插入 删除节点尾删头删删除给定位置的节点删除给定位置之后的节点 查找节点 单链表介绍 单链表也叫做无头单向非循环链表,链表也是一种线性结构。他在逻辑结构上一定连续&…