伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析
ASP整个模块的结果如下:
.
├── COPYRIGHT
├── README.md
├── __init__.py
├── asp.py
├── permutation_lib.py
├── permutation_search_kernels
│ ├── CUDA_kernels
│ │ └── permutation_search_kernels.cu
│ ├── __init__.py
│ ├── call_permutation_search_kernels.py
│ ├── channel_swap.py
│ ├── exhaustive_search.py
│ └── permutation_utilities.py
├── permutation_tests
│ ├── README.md
│ ├── ablation_studies.sh
│ ├── permutation_test.py
│ ├── runtime_table.sh
│ └── unstructured_study.sh
├── sparse_masklib.py
└── test
├── checkpointing_test_part1.py
├── checkpointing_test_part2.py
├── checkpointing_test_reference.py
├── test_permutation_application.py
└── toy_problem.py
共包含三个主要文件:
- asp.py
- permutation_lib. py
- sparse_masklib.py
以及三个主要目录
- permutation_search_kernels
- permutation_tests
- test
其中目录test用于展示一些具体的实例,目录permutation_tests是一个单独的模块,用于复现论文中的实验,这两个目录不用关注。如果不需要使用通道置换算法的话,目录permutation_search_kernels和文件permutation_lib.py也不需要关注。
因此,ASP源代码中最主要的还是asp.py文件和sparse_masklib.py文件,如果需要使用通道置换算法的话,可以在此基础上探询一下permutation_search相关的算法和代码实现。
asp.py文件
ASP类
asp.py中主定义了ASP类,其成员函数定义了init_model_for_pruning
、init_optimizer_for_pruning
、compute_sparse_masks
、already_init_asp_model
、restore_pruned_weights
、is_sparsity_enabled
、prune_trained_model
、set_permutation_saving_params
八个静态方法,分别用于对模型、优化器进行稀疏初始化、计算稀疏mask、检查模型是否已经进行稀疏初始化,检查模型是否进行了稀疏化,恢复模型的权重以及为通道设置算法设置参数。其中最主要的是prune_trained_model
及其调用的init_model_for_pruning
、init_optimizer_for_pruning
、compute_sparse_masks
三个方法。
成员变量
__model = None # 待处理的模型
__verbosity = 0 # 输出信息的详细程度
__optimizer = None # 待处理的优化器
__sparse_parameters = [] # 用于保存稀疏参数信息
__calculate_mask = None # 一个函数指针,能够通过传入的tensor的shape为tensor生成相应的mask
__allow_permutation = True # 是否需要开启通道置换算法
__all_parameters = [] # 用于保存模型中所有参数的信息
__save_permutation_graph = False # 是否保存通道置换的graph
__permutation_output_dir = '' # 通道置换信息的输出目录
成员函数
- prune_trained_model
prune_trained_model是用法介绍中需要在模型训练文件中需要添加的两行代码之一,也是ASP模块的使用入口:
@classmethod
def prune_trained_model(cls, model, optimizer):
# add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False)
cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks()
prune_trained_model方法接受两个参数,分别是需要训练后的模型和优化器。
该方法中又分别调用了三个方法:首先使用init_model_for_pruning
,init_optimizer_for_pruning
方法分别对模型和优化器中的权重进行分析和初始化准备工作(为模型添加mask buffer),并通过compute_sparse_masks
方法为每个权重计算生成对应的稀疏mask。
- init_model_for_pruning
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MultiheadAttention],
allowed_layer_names=None,
disallowed_layer_names=[],
allow_recompute_mask=False,
custom_layer_dict={},
allow_permutation=True):
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
cls.__verbosity = verbosity
cls.__allow_permutation = allow_permutation
if isinstance(mask_calculator, str):
def create_mask_from_pattern(param):
return create_mask(param, mask_calculator).bool()
cls.__calculate_mask = create_mask_from_pattern
else:
cls.__calculate_mask = mask_calculator #user defined function
# function to extract variables that will be sparsified.
# idea is that you will add one of these functions for each module type that can be sparsified.
if torchvision_imported:
print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
torchvision_version = str(torchvision.__version__)
torchvision_version_major = int(torchvision_version.split('.')[0])
torchvision_version_minor = int(torchvision_version.split('.')[1])
if torchvision_version_major == 0 and torchvision_version_minor < 12:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'], torchvision.ops.misc.Conv2d: ['weight']}
else: # Torchvision remove APIs that were deprecated before 0.8 (#5386) in 0.12.0, torchvision.ops.misc.Conv2d is removed
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
if custom_layer_dict: # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune
sparse_parameter_list.update(custom_layer_dict)
whitelist += list(custom_layer_dict.keys())
for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module.dtype()
先看看官方给出的注释:
Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA, it does not enable use of sparse MMA.
注释指出init_model_for_pruning
方法仅仅为模型添加了额外的mask buffer,并没有实际上启用sparse MMA.
参数列表:
model The model
mask_calculator Either callable that computes mask given a tensor OR pattern string for sparse mask lib.
verbosity Integer controling verbosity level.
0 -> Only errors.
1 -> Errors and warnings.
2 -> Errors, warnings and info.
3 -> Errors, warnings, info and debug.
whitelist Module types approved for sparsity.
allowed_layer_names If not None, only layer names that appear in this list are considered for sparsity.
disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']}
allow_permutation If True, allow the input channel permutation to ease the influence of weight pruning.
init_model_for_pruning
方法主要做了这样几件事情:
- 使用传入的参数对静态类ASP进行初始化,以便后续的处理
cls.__model = model
cls.__verbosity = verbosity
cls.__allow_permutation = allow_permutation
- 设置了一个函数指针,用来为特定的tensor生成sparse mask。
if isinstance(mask_calculator, str):
def create_mask_from_pattern(param):
return create_mask(param, mask_calculator).bool()
cls.__calculate_mask = create_mask_from_pattern
else:
cls.__calculate_mask = mask_calculator #user defined function
""" returns a sparse mask """
def create_mask(tensor, pattern="m4n2_1d", density=0.5):
# Reshape tensor and mask.
shape = tensor.shape
ttype = tensor.type()
t = tensor.float().contiguous()
# 1d-tensor
if len(shape) == 1:
t = t.view(1, shape[0])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 2d-tensor (K, C)
elif len(shape) == 2:
# linear
t = t.view(shape[0], shape[1])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 3d-tensor (K, C, R)
elif len(shape) == 3:
# 1d convs
t = t.permute(0,2,1).contiguous().view(shape[0]*shape[2], shape[1])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
mask = mask.view(shape[0], shape[2], shape[1]).permute(0,2,1).contiguous()
return mask.view(shape).type(ttype)
# 4d-tensor (K, C, R, S)
elif len(shape) == 4:
"""
# transformers (bmm)
t = t.view(shape[0]*shape[1]*shape[2], shape[3])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
"""
# 2d convs
t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous()
return mask.view(shape).type(ttype)
def m4n2_1d(mat, density):
return mn_1d_best(mat, 4, 2)
def mn_1d_best(matrix, m, n):
# Find all possible patterns.
patterns = compute_valid_1d_patterns(m,n).cuda()
# Find the best m:n pattern (sum of non-masked weights).
mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)
mat,shape = reshape_1d(matrix,m)
pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1)
mask[:] = patterns[pmax[:]]
mask = mask.view(matrix.shape)
return mask
- 遍历模型中每一层的权重,为特定层的特定权重申请buffer并将权重加入
__sparse_parameters
中,用于后续mask的计算。
那么,如何确定到底为哪些层的哪些权重来申请buffer、生成mask呢?
init_model_for_pruning
方法首先会根据是否导入了torchvision、以及torchvision的版本来确定一个sparse_parameter_list,其实际是以一个字典的形式记录着目前所支持的被稀疏的模块以及对应的参数:
torchvision_imported=True
try:
import torchvision
except ImportError:
print("[ASP][Warning] torchvision cannot be imported.")
torchvision_imported=False
if torchvision_imported:
print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
torchvision_version = str(torchvision.__version__)
torchvision_version_major = int(torchvision_version.split('.')[0])
torchvision_version_minor = int(torchvision_version.split('.')[1])
if torchvision_version_major == 0 and torchvision_version_minor < 12:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'], torchvision.ops.misc.Conv2d: ['weight']}
else: # Torchvision remove APIs that were deprecated before 0.8 (#5386) in 0.12.0, torchvision.ops.misc.Conv2d is removed
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
除此之外,init_model_for_pruning
方法还会根据传入的custom_layer_dict
, whitelist
, allowed_layer_names
, disallowed_layer_names
等参数来最终确定到底需要为当前模型中具体哪个模块的哪个参数进行稀疏化。除此之外,还会检查这些参数的shape是否符合要求,如果不符合要求会跳过该参数,不做稀疏。
接下来,init_model_for_pruning
方法会为符合要求的参数创建一个buffer,命名为xxx_mma_mask
,如果allow_recompute_mask=True
,那么还会为参数创建一个额外的buffer,名为xxx_mma_pruned_p
。
最后,init_model_for_pruning
方法会将所有符合条件的参数的相关信息加入__sparse_parameters
中
关于permutation search的部分暂且不提。
# 找到需要稀疏化且支持进行稀疏化的模块
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
eligible_modules_list = []
for name, mod in model.named_modules():
if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:
if allowed_layer_names is not None and name not in allowed_layer_names:
continue
eligible_modules_list.append((name, mod))
return eligible_modules_list
# 对需要且支持进行稀疏化的模块进行处理
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)
# 对每个模块中的支持的参数类型进行处理
def add_sparse_attributes(module_name, module):
sparse_parameters = sparse_parameter_list[type(module)]
for p_name, p in module.named_parameters():
if p_name in sparse_parameters and p.requires_grad:
# check for NVIDIA's TC compatibility: we check along the horizontal direction
if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0):
#User defines FP32 and APEX internally uses FP16 math
continue
if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0):
#For Conv2d dim= K x CRS; we prune along C
continue
p = p.t().contiguous()
print("---------------{}", p.shape)
model.state_dict[p_name] = p
mask = torch.ones_like(p).bool()
buffname = p_name.split(".")[-1] # buffer names cannot contain "."
module.register_buffer('__%s_mma_mask' % buffname, mask)
# 如果需要多次计算mask,那么需要将模型中被剪枝的参数保存下来,方便重新计算mask的时候使用
# 因此需要额外申请一个用于存储原始数据的Buffer,以xxx_mma_pruned_p来命名
if allow_recompute_mask:
pruned = torch.zeros_like(p).cpu()
module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)
else:
pruned = None
cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
else:
continue
if allow_permutation:
......
- init_optimizer_for_pruning
Call this method to monkey patch optimizer step function so that masks can be applied to gradients and weights during training.
You must call init_model_for_pruning(…) before calling init_optimizer_for_pruning(…)
官方给出的注释中,说明了两点:
首先,init_optimizer_for_pruning
方法的作用是在训练时让mask参与梯度和权重的计算。
其次,强调调用init_optimizer_for_pruning
前必须调用init_optimizer_for_pruning
方法。
接下来是源代码:
@classmethod
def init_optimizer_for_pruning(cls, optimizer):
assert (cls.__optimizer is None), "ASP has initialized optimizer already."
assert (cls.__calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."
# store pointer to original optimizer step method
cls.__optimizer = optimizer
cls.__optimizer.__step = optimizer.step
def __step(opt_self, *args, **kwargs):
# prune gradients before step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if p.grad is not None: #thx pjudd
p.grad.mul_(mask)
# call original optimizer step method
rval = opt_self.__step(*args, **kwargs)
# prune parameters after step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.mul_(mask)
return rval
cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)
init_optimizer_for_pruning
方法主要通过对原来的optimizer的step方法进行重写,从而实现在optimizer每次执行step方法前后对梯度和权重进行剪枝。
首先ASP先将__optimier指向原始的optimizer,由于Python对复杂对象的赋值操作其实相当于是为optimizer建立了一个新的引用 ,二者指向同一个对象。
同时又为__optimizer创建了一个名为__step的引用,指向optimizer的step方法。
紧接着,init_optimizer_for_pruning
方法定义了一个内部方法__step,该方法调用了原来optimizer的step方法,并在调用前后分别对__sparse_parameters
中的梯度和参数进行剪枝。
最后,将新定义的__step方法绑定给__optimizer,并让optimizer的step方法指向它,实现optimizer的step方法的重写
为了方便理解,内存模型画了一个示意图:
- compute_sparse_masks
做完了准备工作,下面才是真正enable sparsity特性的时候。
为了方便阅读,删掉了打印提示信息的部分代码
@classmethod
def compute_sparse_masks(cls):
"""Call this method to enable sparsity.
If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
"""
with torch.no_grad():
if cls.__allow_permutation:
......
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
# mask在init_model_pruning中初始化为ones_like(p)
# 如果mask.sum() < mask.numel(),则代表mask和p是稀疏的,之前已经enable 过sparsity特性了,现在是再次调用compute_mask方法
if mask.sum() < mask.numel(): # when recalculating masks
# restore dense parameter if allow_recompute_mask is enabled
# allow_recompute_mask=True : pruned = zeros_like(p)
# allow_recompute_mask=False: pruned = None
assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
p.add_(pruned.cuda())
mask.set_(cls.__calculate_mask(p))
if pruned is not None: # stow away pruned weights to cpu
pruned.set_((p * (~mask)).cpu())
p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
跳过permutation search的部分,compute_sparse_masks
方法先通过ask.sum() < mask.numel()?判断之前是否计算过mask的值,从而判断之前是否已经对模型进行过剪枝。如果之前已经进行过剪枝,则需要先从pruned中将之前保存的完整参数进行恢复。随后调用init_model_for_pruning
方法中设置好的函数指针
.__calculate_mask
为每个参数计算sparse mask,并将其乘上对应的参数,从而实现对参数的剪枝。