torch_pruning库介绍
在结构修剪中,**Group被定义为深度网络中最小的可移除单元。**每个组由多个相互依赖的层组成,需要同时修剪这些层以保持最终结构的完整性。然而,深度网络通常表现出层与层之间错综复杂的依赖关系,这对结构修剪提出了重大挑战。这项研究通过引入DepGraph轻松实现参数分组,并有助于修剪各种深度网络。
如修剪图中高亮的神经元,我们需要对所有有连线的层都进行修剪。(a) W_l、W_l+1 (b) W_l、W_l+1、W_l+2 © W_l、W_l+1、W_l+2 (d) W_l
一个简单的例子
我们修剪resnet18,把其第一个卷积的输出通道维度减少3。
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True).eval()
print(model)
# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))
# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
print(group)
# 3. prune all grouped layers that are coupled with model.conv1 (included).
if DG.check_pruning_group(group): # avoid full pruning, i.e., len(channels)=0.
group.prune()
print(model)
跟conv1相关的层都被修剪了,剪的是第2,6,9维,但这种方式不给灵活,只能修剪固定的索引,下面我实现更灵活的方式。
下面我们对整个resnet18剪枝,使其通道数减半,ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
这里需要我们做两件事,第一,重要性函数,抽象的说就是评估每个组的重要性,不重要的就可以修剪掉,可以使用内置重要性函数,也可以自定义。第二,配置剪枝器,如使用剪枝多少步,最后的稀疏性是多少,这里我们设置稀疏性是0.5。
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True)
# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()
ignored_layers = []
for m in model.modules():
if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
ignored_layers.append(m) # DO NOT prune the final classifier!
iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs,
importance=imp,
iterative_steps=iterative_steps,
ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
ignored_layers=ignored_layers,
)
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M")
for i in range(iterative_steps):
if isinstance(imp, tp.importance.TaylorImportance):
# Taylor expansion requires gradients for importance estimation
loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
loss.backward() # before pruner.step()
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"{i+1}/{iterative_steps} Pruning: MACs={macs / 1e9: .5f} G, #Params={nparams / 1e6: .5f} M")
# finetune your model here
# finetune(model),model.train()
# ...
在YOLOv8中剪枝
通过上面的步骤,发现剪枝好像并不困难,主要与ultralytics框架集成会麻烦点,核心剪枝代码在下面,有注释,完整代码在这里:yolov8_pruning.py
def prune(args):
# 加载模型,yaml,pt
model = YOLO(args.model)
# 加入train_v2方法在YOLO对象中,主要是重写train方法,为了不让model.train()每次新建一个模型来训练,这样会导致剪枝失败。
# 解决方法就是在训练完后重新加载训练完的权重到YOLO.model中。
model.__setattr__("train_v2", train_v2.__get__(model))
pruning_cfg = yaml_load(check_yaml(args.cfg))
batch_size = pruning_cfg['batch']
model.model.train()
# split操作不支持剪枝,使用两个卷积替换split操作
replace_c2f_with_c2f_v2(model.model)
initialize_weights(model.model) #设置 BN.eps, momentum, ReLU.inplace
# 开启梯度训练
for name, param in model.model.named_parameters():
param.requires_grad = True
example_inputs = torch.randn(1, 3, pruning_cfg["imgsz"], pruning_cfg["imgsz"]).to(model.device)
# 保存浮点数、参数量、mAP和剪枝mAP的记录
macs_list, nparams_list, map_list, pruned_map_list = [], [], [], []
# 计算浮点数和参数量
base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)
# 在剪枝操作之前先评估一次模型
pruning_cfg['name'] = f"baseline_val"
validation_model = deepcopy(model)
metric = validation_model.val(**pruning_cfg)
init_map = metric.box.map
# 保存浮点数、参数量、mAP和剪枝mAP的记录
macs_list.append(base_macs)
nparams_list.append(base_nparams)
map_list.append(init_map)
pruned_map_list.append(init_map)
print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M, mAP={init_map: .5f}")
# 每一步的剪枝率
pruning_ratio = 1 - math.pow((1 - args.target_prune_rate), 1 / args.iterative_steps)
print(pruning_ratio)
# 这里可以发现剪枝器可以在循环里或者循环外,虽然最终模型的稀疏性是没变化的,但每剪一次微调一次的效果会更好。
for i in range(args.iterative_steps):
model.model.train()
for name, param in model.model.named_parameters():
param.requires_grad = True
ignored_layers = []
unwrapped_parameters = []
# 忽略的层,一般都对头部网络进行忽略,如果是目标检测就换成Detect,记得先引入,这是个类。
for m in model.model.modules():
if isinstance(m, (Segment,)):
ignored_layers.append(m)
example_inputs = example_inputs.to(model.device)
pruner = tp.pruner.GroupNormPruner(
model.model,
example_inputs,
importance=tp.importance.GroupNormImportance(), # L2 norm pruning,
iterative_steps=1,
pruning_ratio=pruning_ratio,
ignored_layers=ignored_layers,
unwrapped_parameters=unwrapped_parameters
)
pruner.step()
# 剪枝完后先评估一遍模型
pruning_cfg['name'] = f"step_{i}_pre_val"
validation_model.model = deepcopy(model.model)
metric = validation_model.val(**pruning_cfg)
pruned_map = metric.box.map
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs.to(model.device))
current_speed_up = float(macs_list[0]) / pruned_macs
print(f"After pruning iter {i + 1}: MACs={pruned_macs / 1e9} G, #Params={pruned_nparams / 1e6} M, "
f"mAP={pruned_map}, speed up={current_speed_up}")
# 微调模型,重新训练,一般10-50epochs?
for name, param in model.model.named_parameters():
param.requires_grad = True
pruning_cfg['name'] = f"step_{i}_finetune"
pruning_cfg['batch'] = batch_size # restore batch size
model.train_v2(pruning=True, **pruning_cfg)
# 微调完后再评估一遍模型
pruning_cfg['name'] = f"step_{i}_post_val"
validation_model = YOLO(model.trainer.best)
metric = validation_model.val(**pruning_cfg)
current_map = metric.box.map
print(f"After fine tuning mAP={current_map}")
macs_list.append(pruned_macs)
nparams_list.append(pruned_nparams / base_nparams * 100)
pruned_map_list.append(pruned_map)
map_list.append(current_map)
# 移除剪枝器
del pruner
save_pruning_performance_graph(nparams_list, map_list, macs_list, pruned_map_list)
if init_map - current_map > args.max_map_drop and current_speed_up>=1.2:
print("Pruning early stop")
break
model.export(format='onnx')