Torch-Pruning 库入门级使用介绍

news2025/1/12 2:57:14

在这里插入图片描述

项目地址:https://github.com/VainF/Torch-Pruning

Torch-Pruning 是一个专用于torch的模型剪枝库,其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法(如 Magnitude Pruning 或 Taylor Pruning)相结合可以达到良好的剪枝效果。

本博文结合项目官网案例,对信息进行结构话,抽离出剪枝技术说明、剪枝模型保存与加载、剪枝技术的基本使用,剪枝技术的具体使用案例。并结合外部信息,分析剪枝对模型性能精度的影响。

1、基本说明

1.1 项目安装

打开https://github.com/VainF/Torch-Pruning,下载项目
在这里插入图片描述
然后在终端中,进入项目目录,并执行pip install -r requirements.txt 安装项目依赖库
在这里插入图片描述
然后在执行 pip install -e . ,将项目安装在当前目录下,并设置为editing模式。
在这里插入图片描述
验证安装:执行命令python -c "import torch_pruning", 如果没有输出报错信息则表示安装成功。
在这里插入图片描述

1.2 DepGraph 技术说明

在结构修剪中,组被定义为深度网络中最小的可移除单元。每个组由多个相互依赖的层组成,需要同时修剪这些层以保持最终结构的完整性。然而,深度网络通常表现出层与层之间错综复杂的依赖关系,这对结构修剪提出了重大挑战。这项研究通过引入一种名为 DepGraph 的自动化机制来解决这一挑战,该机制可以轻松实现参数分组,并有助于修剪各种深度网络。
在这里插入图片描述

直接剪枝会会破坏layer间的依赖关系,会导致forward流程报错。具体如下面代码,移除model.conv1模块中的idxs为0与1的channel,导致后续的bn1层输入输入与参数格式对不上号,然后报错。

from torchvision.models import resnet18
import torch_pruning as tp
import torch

model = resnet18().eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
output = model(torch.randn(1,3,224,224)) # test

在这里插入图片描述
基本在后续层添加剪枝,运行代码也会保存,因为batchnorm的下一层要求的输出channel是64。

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) 
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1])
output = model(torch.randn(1,3,224,224)) 

使用DepGraph剪枝代码如下,先使用tp.DependencyGraph().build_dependenc构建出依赖图,然后基于DG.get_pruning_group函数获取目标剪枝层的依赖关系组,最后在检验关系并进行剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. prune all grouped layers that are coupled with model.conv1 (included).
print(group)
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
    
# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object

代码执行后的输出如下所示,可以看到捕捉到group对应的依赖layer

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

1.3 剪枝模型的保存与加载

剪枝后的模型由于网络结构改变了,如果只保存模型参数,是无法支持原始网络结构,需要将模型结构连参数一并保存。加载时连同参数一起加载。

model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model

或者基于tp库中tp.state_dict函数提取目标参数进行保存,并基于tp.load_state_dict函数将剪枝后的参数赋值到原始模型中形成剪枝模型。

# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')

# create a new model, e.g. resnet18
new_model = resnet18().eval()

# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
print(new_model) # This will be a pruned model.

2、剪枝基本案例

2.1 具有目标结构的剪枝

以下代码使用TaylorImportance指标进行剪枝,设置忽略输出层的剪枝。并设置MagnitudePruner中对通道剪枝50%,一共分iterative_steps步完成剪枝,每一次剪枝都进行微调。
整体来说,具备目标结构的剪枝,效果是最差的。 基于https://blog.csdn.net/a486259/article/details/140407147 分析的数据得出的结论。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    #pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

代码的输出信息如下所示,可以看到macs与nparams在逐步降低。最终输出的模型结构,所有的chanel都减半了,只有输出层例外。

iter 0 | rate:0.8092  0.8111
iter 1 | rate:0.6469  0.6445
iter 2 | rate:0.4971  0.4979
iter 3 | rate:0.3718  0.3695
iter 4 | rate:0.2674  0.2614
ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
PS D:\开源项目\Torch-Pruning-master>
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)

2.2 自动结构剪枝

这里的自动结构是有一个预设目标,即将总体channel剪枝到原模型的多少,但没有预定的目标结构。可能有的laye通道剪枝数多,有的剪枝数少。 与2.1中的代码相比,主要是增加了参数 global_pruning=True。但这个剪枝方式比具有目标结构的剪枝更加有效。就像裁员一样,要求各个部门内裁员比例相同与在公司内控制裁员比例(各个部门裁员比例按重要度排列,裁员比例不一样),必然是第二种方式更有效。第一种方式,使低效率部门的靠前但无用员工保留下来了。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 3 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50%的channel
    ignored_layers=ignored_layers,
    global_pruning=True
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

2.3 MagnitudePruner中的参数

指定特定层的剪枝比例 通过pruning_ratio_dict参数,指定model.layer2的剪枝比例为20%,这里适用于有先验经验的layer,控制对特定layer的剪枝比例。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    imp,
    pruning_ratio = 0.5,
    pruning_ratio_dict = {model.layer2: 0.2}
)
pruner.step()
print(model)

代码执行后的层为:ResNet{64, 128, 256, 512} => ResNet{32, 102, 128, 256}

设置最大剪枝比例 通过 max_pruning_ratio 参数设置最大剪枝比例,避免由于稀疏剪枝或者自动剪枝时某个层被严重剪枝或者移除。

剪枝次数与剪枝调度器 您打算分多轮修剪模型,iterative_steps 会很有用。默认情况下,修剪器会逐渐增加模型的稀疏度,直到达到所需的 pruning_ratio。如以下代码,分5次实现剪枝目标。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
)

# prune the model, iteratively if necessary.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Round %d/%d, Params: %.2f M" % (i+1, iterative_steps, nparams/1e6))
    # finetune your model here
    # finetune(model)
    # ...
print(model)

对应输出如下
Round 1/5, Params: 9.44 M
Round 2/5, Params: 7.45 M
Round 3/5, Params: 5.71 M
Round 4/5, Params: 4.20 M
Round 5/5, Params: 2.93 M

设置忽略的层 这主要是避免对输出层进行剪枝,修改模型的输出结构。使用代码如下,通过ignored_layers参数传入忽略的layer对象。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels
    ignored_layers=[model.conv1, model.fc] # ignore the first & last layers
)
pruner.step()
print(model)

channel取整 在很多的时候都认为channel为16的倍数,gpu运行效率最高。使用代码如下,通过round_to参数,保持channel是特定数的倍数。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.3, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    round_to=10 # round to 10x. Note: 10x is not a good practice.
)

pruner.step()
print(model)

channel_groups 某些层(例如 nn.GroupNorm 和 nn.Conv2d)具有 group 参数,这会在层内引入额外的依赖项。修剪后,保持所有组的大小相同至关重要。为了满足这一要求,引入了参数 channel_groups 以启用对这些通道的手动分组。如以下代码,通过channel_groups参数,控制model.group_conv1中的参数为8个一组

pruner = tp.pruner.MagnitudePruner(
            model,
            example_inputs=example_inputs,
            importance=importance,
            iterative_steps=1,
            pruning_ratio=0.5,
            channel_groups = {model.group_conv1: 8} # For Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), groups=8)
        )

额外参数剪枝 有些时候模型具备的可训练参数并非conv、fc等传统layer中,需要基于unwrapped_parameters参数将额外的可剪枝参数传入到剪枝器中。具体如下所示:

from torchvision.models.convnext import CNBlock, ConvNeXt
unwrapped_parameters = []
for m in model.modules():
    if isinstance(m, CNBlock):
        unwrapped_parameters.append( (m.layer_scale, 0) )

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    unwrapped_parameters=unwrapped_parameters 

限定剪枝范围 root_module_types 参数用于指定组的“根”或第一层。在许多情况下,我们专注于修剪线性层和卷积 (Conv) 层。要专门针对这些层启用修剪,我们可以使用以下参数:root_module_types=[nn.Conv2D, nn.Linear]。这可确保将修剪应用于所需的层。

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    root_module_types=[nn.Conv2D, nn.Linear]

3、具体应用案例

3.1 timm模型剪枝

官方代码为:examples\timm_models\prune_timm_models.py
具体详情如下,这里有一个特殊用法,是通过num_heads参数实现对于transformer layer的支持

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
os.environ['TIMM_FUSED_ATTN'] = '0'
import torch
import torch.nn as nn 
import torch.nn.functional as F
from typing import Sequence
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp
import argparse

parser = argparse.ArgumentParser(description='Prune timm models')
parser.add_argument('--model', default=None, type=str, help='model name')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='channel pruning ratio')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--pretrained', default=False, action='store_true', help='global pruning')
parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
args = parser.parse_args()

def main():
    timm_models = timm.list_models()
    if args.list_models:
        print(timm_models)
    if args.model is None: 
        return
    assert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)

    imp = tp.importance.GroupNormImportance()
    print("Pruning %s..."%args.model)
        
    input_size = model.default_cfg['input_size']
    example_inputs = torch.randn(1, *input_size).to(device)
    test_output = model(example_inputs)
    ignored_layers = []
    num_heads = {}

    for m in model.modules():
        if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(model.head)
            print("Ignore classifier layer: ", m.head)
       
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                num_heads[m.qkv] = m.num_heads
                print("Attention layer: ", m.qkv, m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                num_heads[m.qkv_proj] = m.num_heads

    print("========Before pruning========")
    print(model)
    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
    pruner = tp.pruner.MetaPruner(
                    model, 
                    example_inputs, 
                    global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
                    importance=imp, # importance criterion for parameter selection
                    iterative_steps=1, # the number of iterations to achieve target pruning ratio
                    pruning_ratio=args.pruning_ratio, # target pruning ratio
                    num_heads=num_heads,
                    ignored_layers=ignored_layers,
                )
    for g in pruner.step(interactive=True):
        g.prune()

    for m in model.modules():
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                m.num_heads = num_heads[m.qqkv_projkv]
                m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

    print("========After pruning========")
    print(model)
    test_output = model(example_inputs)
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
    print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
    print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))

if __name__=='__main__':
    main()

3.2 llm模型剪枝

在examples\LLMs\prune_llama.py中提供了一个对于llama模型的剪枝案例.
核心代码如下,可以看到也是基于num_heads记录transformer的结构信息,然后在剪枝后将num_heads数据赋值到对应模型参数上。与原始代码相比,这里删除了模型精度验证相关的代码。


# Code adapted from 
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wanda

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))

import argparse
import os 
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
import time
import torch
import torch.nn as nn
from collections import defaultdict
import fnmatch
import numpy as np
import random

print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="./cache"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float16, 
        cache_dir=cache_dir, 
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings 
    return model

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--cache_dir", default="./cache", type=str )
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument("--eval_zero_shot", action="store_true")
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)       
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    ##############
    # Pruning
    ##############
    print("----------------- Before Pruning -----------------")
    print(model)
    text = "Hello world."
    inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
    import torch_pruning as tp 
    num_heads = {}
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_key_value_heads
            num_heads[m.v_proj] = model.config.num_key_value_heads
            
    head_pruning_ratio = args.pruning_ratio
    hidden_size_pruning_ratio = args.pruning_ratio
    pruner = tp.pruner.MagnitudePruner(
        model, 
        example_inputs=inputs,
        importance=tp.importance.GroupNormImportance(),
        global_pruning=False,
        pruning_ratio=hidden_size_pruning_ratio,
        ignored_layers=[model.lm_head],
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=head_pruning_ratio,
    )
    pruner.step()

    # Update model attributes
    num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
    num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
    model.config.num_attention_heads = num_heads
    model.config.num_key_value_heads = num_key_value_heads
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            m.hidden_size = m.q_proj.out_features
            m.num_heads = num_heads
            m.num_key_value_heads = num_key_value_heads
        elif name.endswith("mlp"):
            model.config.intermediate_size = m.gate_proj.out_features
    print("----------------- After Pruning -----------------")
    print(model)

    #ppl_test = eval_ppl(args, model, tokenizer, device)
    #print(f"wikitext perplexity {ppl_test}")

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

if __name__ == '__main__':
    main()

3.3 目标检测模型剪枝

在Torch-Pruning 库中提供了针对yolov8、yolov7、yolov5的剪枝案例。关于yolov8还提供了剪枝后的训练策略,其主要技巧在与对不可剪枝层的可剪枝话处理(C2f模块的剪枝,其含split操作,不利于剪枝索引)。后续会补充博客,说明对yolov8的剪枝使用。

4、其他信息

4.1 剪枝器中的评价指标

在torch_pruning\pruner\importance.py中有很多个剪枝评价指标

__all__ = [
    # Base Class
    "Importance",

    # Basic Group Importance
    "GroupNormImportance",
    "GroupTaylorImportance",
    "GroupHessianImportance",

    # Aliases
    "MagnitudeImportance",
    "TaylorImportance",
    "HessianImportance",

    # Other Importance
    "BNScaleImportance",
    "LAMPImportance",
    "RandomImportance",
]

整体来看是TaylorImportance最好,一直使用该值即可。
来看

4.2 剪枝对性能精度的影响

在博客https://blog.csdn.net/a486259/article/details/140407147?spm=1001.2014.3001.5501 中基本确定了剪枝50%,对模型精度是没有任何影响的。这里对Torch-Pruning 库相关的论文数据进行二次核验,以致于分析出剪枝中速度提升对精度的影响。

以DepGraph: Towards Any Structural Pruning数据为例,可以发现最高支持6x速度剪枝后保持模型性能。
在这里插入图片描述
以LLM-Pruner: On the Structural Pruning of Large Language Models 论文数据为例,可以发现使用Vector评价方法的剪枝,移除10%的参数,zero-shot下对模型精度影响不大。而图4更表明,剪枝方法正确的话,移除50%的参数对模型性能影响也不大。
在这里插入图片描述
以论文 Structural Pruning for Diffusion Models 的数据为分析,同样可以发现剪枝50%左右的通道,对结果影响不对。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1926505.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python:setattr()函数和__setattr__()魔术方法

相关阅读 Pythonhttps://blog.csdn.net/weixin_45791458/category_12403403.html?spm1001.2014.3001.5482 setattr()是一个python内置的函数,用于设置一个对象的属性值,一般情况下,可以通过点运算符(.)完成相同的功能,但是getat…

[激光原理与应用-112]:南京科耐激光-激光焊接-焊中检测-智能制程监测系统IPM介绍 - 16 - 常见的产品指标

目录 一、光学传感器技术指标:实时信号采集与信号处理 (1)适用激光器的功率范围宽: (2)感光范围:350nm~1750nm:从可见光到红外光 (3)信号类型&#xff1a…

NSSCTF_RE(二)暑期

[CISCN 2021初赛]babybc LLVM是那个控制流平坦化混淆,bc是IR指令文件 得到64位elf文件 然后就慢慢分析,感觉太妙了我靠 一个数独游戏,用二个二维数组添加约束,一个二维数组作地图,慢慢看 最后用 z3 来解数独&#xf…

k8s快速部署一个网站

1)使用Deployment控制器部署镜像: kubectl create deployment web-demo --imagelizhenliang/web-demo:v1 kubectl get deployment,pods[rootk8s-matser ~]# kubectl get pods NAME READY STATUS RESTARTS A…

25届平安产险校招测评IQ新16PF攻略:全面解析与应试策略

尊敬的读者,您好。随着平安产险校招季的到来,许多应届毕业生正积极准备着各项测评。本文旨在提供一份详尽的测评攻略,帮助您更好地理解平安产险的校招测评流程,以及如何有效应对。 25届平安产险平安IQ(新)测…

Java 设计模式系列:外观模式

简介 外观模式(Facade Pattern)是一种设计模式,又名门面模式,是一种通过为多个复杂的子系统提供一个一致的接口,而使这些子系统更加容易被访问的模式。该模式对外有一个统一接口,外部应用程序不用关心内部…

2024-07-14 Unity插件 Odin Inspector2 —— Essential Attributes

文章目录 1 说明2 重要特性2.1 AssetsOnly / SceneObjectsOnly2.2 CustomValueDrawer2.3 OnValueChanged2.4 DetailedInfoBox2.5 EnableGUI2.6 GUIColor2.7 HideLabel2.8 PropertyOrder2.9 PropertySpace2.10 ReadOnly2.11 Required2.12 RequiredIn(*)2.…

基于Python thinker GUI界面的股票评论数据及投资者情绪分析设计与实现

1.绪论 1.1背景介绍 Python 的 Tkinter 库提供了创建用户界面的工具,可以用来构建股票评论数据及投资者情绪分析的图形用户界面(GUI)。通过该界面,用户可以输入股票评论数据,然后通过情感分析等技术对评论进行情绪分析…

昇思25天学习打卡营第14天 | ShuffleNet图像分类

昇思25天学习打卡营第14天 | ShuffleNet图像分类 文章目录 昇思25天学习打卡营第14天 | ShuffleNet图像分类ShuffleNetPointwise Group ConvolutionChannel ShuffleShuffleNet模块网络构建 模型训练与评估数据集训练模型评估模型预测 总结打卡 ShuffleNet ShuffleNetV1是旷世科…

大模型系列3--pytorch dataloader的原理

pytorch dataloader运行原理 1. 背景2. 环境搭建2.1. 安装WSL & vscode2.2. 安装conda & pytorch_gpu环境 & pytorch 2.112.3 命令行验证python环境2.4. vscode启用pytorch_cpu虚拟环境 3. 调试工具3.1. vscode 断点调试3.2. py-spy代码栈探测3.3. gdb attach3.4. …

基于锚框的物体检测过程

说明:基于锚框的物体检测过程:分为单阶段和两阶段 整体步骤: 提供目标候选区域: 锚框提供了一组预定义的候选区域,这些区域可以覆盖各种尺度和长宽比的目标。通过这些锚框,可以在不同的位置和不同的尺度上…

02-Charles的安装与配置

一、Charles的安装 Charles的下载地址:https://www.charlesproxy.com/。 下载之后,傻瓜式安装即可。 二、Charles组件介绍 主导航栏介绍: 请求导航栏介绍: 请求数据栏介绍: 三、Charles代理设置 四、客户端-windows代理…

【Linux】多线程_6

文章目录 九、多线程7. 生产者消费者模型生产者消费者模型的简单代码结果演示 未完待续 九、多线程 7. 生产者消费者模型 生产者消费者模型的简单代码 Makefile: cp:Main.ccg -o $ $^ -stdc11 -lpthread .PHONY:clean clean:rm -f cpThread.hpp: #i…

React学习笔记02-----

一、React简介 想实现页面的局部刷新,而不是整个网页的刷新。AJAXDOM可以实现局部刷新 1.特点 (1)虚拟DOM 开发者通过React来操作原生DOM,从而构建页面。 React通过虚拟DOM来实现,可以解决DOM的兼容性问题&#x…

NSSCTF_RE(一)暑期

[SWPUCTF 2021 新生赛]简单的逻辑 nss上附件都不对 没看明白怎么玩的 dnspy分析有三个 AchievePoint , game.Player.Bet - 22m; for (int i 0; i < Program.memory.Length; i) { byte[] array Program.memory; int num i; array[num] ^ 34; } Environment.SetEnvironment…

【CICID】GitHub-Actions-SpringBoot项目部署

[TOC] 【CICID】GitHub-Actions-SpringBoot项目部署 0 流程图 1 创建SprinBoot项目 ​ IDEA创建本地项目&#xff0c;然后推送到 Github 1.1 项目结构 1.2 Dockerfile文件 根据自身项目&#xff0c;修改 CMD ["java","-jar","/app/target/Spri…

Scrapy框架实现数据采集的详细步骤

需求描述&#xff1a; 本项目目标是使用Scrapy框架从宁波大学经济学院网站&#xff08;nbufe.edu.cn&#xff09;爬取新闻或公告详情页的内容。具体需求如下&#xff1a; 1、通过遍历多个页面&#xff08;共55页&#xff09;构建翻页URL。 2、使用scrapy自带的xpath从每页的…

STM32智能机器人避障系统教程

目录 引言环境准备智能机器人避障系统基础代码实现&#xff1a;实现智能机器人避障系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;机器人导航与避障问题解决方案与优化收尾与总结 1. 引言 智能机器人避…

Android ImageDecoder把瘦高/扁平大图相当于fitCenter模式decode成目标小尺寸Bitmap,Kotlin

Android ImageDecoder把瘦高/扁平大图相当于fitCenter模式decode成目标小尺寸Bitmap&#xff0c;Kotlin val sz Size(MainActivity.SIZE, MainActivity.SIZE)val src ImageDecoder.createSource(mContext?.contentResolver!!, uri)val bitmap ImageDecoder.decodeBitmap(sr…

iPhone数据恢复篇:在 iPhone 上恢复找回短信的 5 种方法

方法 1&#xff1a;检查最近删除的文件夹 iOS 允许您在 30 天内恢复已删除的短信。您需要先从“设置”菜单启用“过滤器”。让我们来实际检查一下。 步骤 1&#xff1a;打开“设置” > “信息”。 步骤 2&#xff1a;选择“未知和垃圾邮件”&#xff0c;然后切换到“过滤…