【6s965-fall2022】剪枝✂pruningⅡ

news2024/11/25 4:46:56

剪枝比例

问题:我们应该如何找到每层的剪枝比率?

  • 较浅的层,低层次的特征
  • 较深的层,抽象的特征

问题:哪些层的冗余度最高?

  • 非统一剪枝(每一层的稀疏度不一样)比统一剪枝(每一层的稀疏度都设成一样)的效果更好
  • 较深的层有更多的冗余(更多的通道,更多的重复使用的特征),所以它们可以被更积极地修剪

分析每个层的敏感性

  • 敏感性:当该层被修剪时,准确率下降了多少
  • 敏感度较高的层应减少修剪,敏感度较低的层表明有冗余。

敏感度分析

  • 在模型中挑选一个层 L i L_i Li
    • 设定一组剪枝比例,即 r ∈ { 0.1 , 0.2 , … , 0.9 } r\in \{0.1,0.2,\dots,0.9\} r{0.1,0.2,,0.9}
    • 观察每一个比例 r r r对准确率的影响,记录每个比例时的 Δ Acc i r \Delta{\text{Acc}_i^r} ΔAccir
  • 对所有层进行重复
  • 设定恢复的准确率阈值 T T T,找到每个层对应的剪枝比例阈值

请添加图片描述
优点

  • 很容易看到哪些层对修剪最不敏感
  • 实现简单

缺点

  • 忽略了各层之间的相互作用,如果两个层同时被修剪,准确性会如何下降?
  • 忽略了层的参数大小,对大层进行少量修剪比对小层进行大量修剪要好。

实现

在上一篇的基础上,实现敏感度扫描记录

@torch.inference_mode()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    verbose = True,
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> float:
    model.eval()

    num_samples = 0
    num_correct = 0

    for inputs, targets in tqdm(dataloader, desc="eval", leave=False, disable=not verbose):
        # Move the data from CPU to GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Inference
        outputs = model(inputs)

        # Convert logits to class indices
        outputs = outputs.argmax(dim=1)

        # Update metrics
        num_samples += targets.size(0)
        num_correct += (outputs == targets).sum()

    return (num_correct / num_samples * 100).cpu().item()


@torch.no_grad()
def sensitivity_scan(model, dataloader, scan_step=0.1, scan_start=0.4, scan_end=1.0, verbose=True):
    sparsities = np.arange(start=scan_start, stop=scan_end, step=scan_step)
    accuracies = []
    named_conv_weights = [(name, param) for (name, param) in model.named_parameters() if param.dim() > 1]
    # choose one layer
    for i_layer, (name, param) in enumerate(named_conv_weights):
    	# keep the param to recover
        param_clone = param.detach().clone()
        accuracy = []
        # choose one sparsity
        for sparsity in tqdm(sparsities, desc=f'scanning {i_layer}/{len(named_conv_weights)} weight - {name}'):
            # prune the layer
            fine_grained_prune(param.detach(), sparsity=sparsity)
            acc = evaluate(model, dataloader, verbose=False)
            if verbose:
                print(f'\r    sparsity={sparsity:.2f}: accuracy={acc:.2f}%', end='')
            # restore
            param.copy_(param_clone)
            accuracy.append(acc)
        if verbose:
            print(f'\r    sparsity=[{",".join(["{:.2f}".format(x) for x in sparsities])}]: accuracy=[{", ".join(["{:.2f}%".format(x) for x in accuracy])}]', end='')
        accuracies.append(accuracy)
    return sparsities, accuracies
sparsities, accuracies = sensitivity_scan(model, dataloader['test'], scan_step=0.1, scan_start=0.4, scan_end=1.0)

在这里插入图片描述
可视化

def plot_sensitivity_scan(sparsities, accuracies, dense_model_accuracy):
    lower_bound_accuracy = 100 - (100 - dense_model_accuracy) * 1.5
    fig, axes = plt.subplots(3, int(math.ceil(len(accuracies) / 3)),figsize=(15,8))
    axes = axes.ravel()
    plot_index = 0
    for name, param in model.named_parameters():
        if param.dim() > 1:
            ax = axes[plot_index]
            curve = ax.plot(sparsities, accuracies[plot_index])
            line = ax.plot(sparsities, [lower_bound_accuracy] * len(sparsities))
            ax.set_xticks(np.arange(start=0.4, stop=1.0, step=0.1))
            ax.set_ylim(80, 95)
            ax.set_title(name)
            ax.set_xlabel('sparsity')
            ax.set_ylabel('top-1 accuracy')
            ax.legend([
                'accuracy after pruning',
                f'{lower_bound_accuracy / dense_model_accuracy * 100:.0f}% of dense model accuracy'
            ])
            ax.grid(axis='x')
            plot_index += 1
    fig.suptitle('Sensitivity Curves: Validation Accuracy vs. Pruning Sparsity')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()
plot_sensitivity_scan(sparsities, accuracies, dense_model_accuracy)

在这里插入图片描述
根据上图进行敏感度分析,设定每一层的稀疏度

sparsity_dict = {
    'backbone.conv0.weight': 0.55,
    'backbone.conv1.weight': 0.85,
    'backbone.conv2.weight': 0.8,
    'backbone.conv3.weight': 0.75,
    'backbone.conv4.weight': 0.7,
    'backbone.conv5.weight': 0.8,
    'backbone.conv6.weight': 0.8,
    'backbone.conv7.weight': 0.9,
    'classifier.weight': 0.9
}
pruner = FineGrainedPruner(model, sparsity_dict)

微调

num_finetune_epochs = 5
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_sparse_model_checkpoint = dict()
best_accuracy = 0
print(f'Finetuning Fine-grained Pruned Sparse Model')
for epoch in range(num_finetune_epochs):
    # At the end of each train iteration, we have to apply the pruning mask 
    #    to keep the model sparse during the training
    train(model, dataloader['train'], criterion, optimizer, scheduler,
          callbacks=[lambda: pruner.apply(model)], device=device)
    accuracy = evaluate(model, dataloader['test'], device=device)
    # save the best model
    is_best = accuracy > best_accuracy
    if is_best:
        best_sparse_model_checkpoint['state_dict'] = copy.deepcopy(model.state_dict())
        best_accuracy = accuracy
    print(f'    Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')

在这里插入图片描述

基于通道的剪枝 Channel Pruning

  • # o u t _ c h a n n e l s n e w = # o u t _ c h a n n e l s o r i g i n ⋅ ( 1 − s p a r s i t y ) \#\mathrm{out\_channels}_{\mathrm{new}} = \#\mathrm{out\_channels}_{\mathrm{origin}} \cdot (1 - \mathrm{sparsity}) #out_channelsnew=#out_channelsorigin(1sparsity)

  • 通道修剪后,权重张量 W W W仍然是密集的

  • 对所有层使用一样的剪枝比例

  • 找到不那么重要的通道权重来删除

  • i m p o r t a n c e i = ∥ W i ∥ 2 ,      i = 0 , 1 , 2 , ⋯   , # i n _ c h a n n e l s − 1 importance_{i} = \|W_{i}\|_2, \;\; i = 0, 1, 2,\cdots, \#\mathrm{in\_channels}-1 importancei=Wi2,i=0,1,2,,#in_channels1

  • 对于每个输入的通道,使用Frobenius规范来计算 i m p o r t a n c e importance importance

  • 将通道权重从重要到不重要进行排序,然后保留每层中最重要的 k k k个通道

def get_num_channels_to_keep(channels: int, prune_ratio: float) -> int:
    """A function to calculate the number of layers to PRESERVE after pruning
    Note that preserve_rate = 1. - prune_ratio
    """
    return int(round(channels * (1. - prune_ratio)))


@torch.no_grad()
def channel_prune(model: nn.Module, 
                  prune_ratio: Union[List, float]) -> nn.Module:
    """Apply channel pruning to each of the conv layer in the backbone
    Note that for prune_ratio, we can either provide a floating-point number,
    indicating that we use a uniform pruning rate for all layers, or a list of
    numbers to indicate per-layer pruning rate.
    """
    # sanity check of provided prune_ratio
    assert isinstance(prune_ratio, (float, list))
    n_conv = len([m for m in model.backbone if isinstance(m, nn.Conv2d)])
    # note that for the ratios, it affects the previous conv output and next
    # conv input, i.e., conv0 - ratio0 - conv1 - ratio1-...
    if isinstance(prune_ratio, list):
        assert len(prune_ratio) == n_conv - 1
    else:  # convert float to list
        prune_ratio = [prune_ratio] * (n_conv - 1)

    # we prune the convs in the backbone with a uniform ratio
    model = copy.deepcopy(model)  # prevent overwrite
    # we only apply pruning to the backbone features
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]
    # apply pruning. we naively keep the first k channels
    assert len(all_convs) == len(all_bns)
    for i_ratio, p_ratio in enumerate(prune_ratio):
        prev_conv = all_convs[i_ratio]
        prev_bn = all_bns[i_ratio]
        next_conv = all_convs[i_ratio + 1]
        original_channels = prev_conv.out_channels  # same as next_conv.in_channels
        n_keep = get_num_channels_to_keep(original_channels, p_ratio)

        # prune the output of the previous conv and bn
        prev_conv.weight.set_(prev_conv.weight.detach()[:n_keep])
        prev_bn.weight.set_(prev_bn.weight.detach()[:n_keep])
        prev_bn.bias.set_(prev_bn.bias.detach()[:n_keep])
        prev_bn.running_mean.set_(prev_bn.running_mean.detach()[:n_keep])
        prev_bn.running_var.set_(prev_bn.running_var.detach()[:n_keep])

        # prune the input of the next conv
        next_conv.weight.set_(next_conv.weight.detach()[:, :n_keep])

    return model
# function to sort the channels from important to non-important
def get_input_channel_importance(weight):
    in_channels = weight.shape[1]
    importances = []
    # compute the importance for each input channel
    for i_c in range(weight.shape[1]):
        channel_weight = weight.detach()[:, i_c]
        importance = torch.norm(channel_weight)
        importances.append(importance.view(1))
    return torch.cat(importances)


@torch.no_grad()
def apply_channel_sorting(model):
    model = copy.deepcopy(model)  # do not modify the original model
    # fetch all the conv and bn layers from the backbone
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]
    # iterate through conv layers
    for i_conv in range(len(all_convs) - 1):
        # each channel sorting index, we need to apply it to:
        # - the output dimension of the previous conv
        # - the previous BN layer
        # - the input dimension of the next conv (we compute importance here)
        prev_conv = all_convs[i_conv]
        prev_bn = all_bns[i_conv]
        next_conv = all_convs[i_conv + 1]
        # note that we always compute the importance according to input channels
        importance = get_input_channel_importance(next_conv.weight)
        # sorting from large to small
        sort_idx = torch.argsort(importance, descending=True) 

        # apply to previous conv and its following bn
        prev_conv.weight.copy_(torch.index_select(
            prev_conv.weight.detach(), 0, sort_idx))
        for tensor_name in ['weight', 'bias', 'running_mean', 'running_var']:
            tensor_to_apply = getattr(prev_bn, tensor_name)
            tensor_to_apply.copy_(
                torch.index_select(tensor_to_apply.detach(), 0, sort_idx)
            )
        
        # apply to the next conv input (hint: one line of code)
        next_conv.weight.copy_(
            torch.index_select(next_conv.weight.detach(), 1, sort_idx)
        )


    return model

剪枝后

channel_pruning_ratio = 0.3  # pruned-out ratio

print(" * Without sorting...")
pruned_model = channel_prune(model, channel_pruning_ratio)
pruned_model_accuracy = evaluate(pruned_model, dataloader['test'], device=device)
print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")


print(" * With sorting...")
sorted_model = apply_channel_sorting(model)
pruned_model = channel_prune(sorted_model, channel_pruning_ratio)
pruned_model_accuracy = evaluate(pruned_model, dataloader['test'], device=device)
print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")

在这里插入图片描述
微调后

num_finetune_epochs = 5
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_accuracy = 0
for epoch in range(num_finetune_epochs):
    train(pruned_model, dataloader['train'], criterion, optimizer, scheduler, device=device)
    accuracy = evaluate(pruned_model, dataloader['test'], device=device)
    is_best = accuracy > best_accuracy
    if is_best:
        best_accuracy = accuracy
    print(f'Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/174083.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python学习 --- 集合基础

目录 一、什么是集合? 二、集合的创建方式 1、直接使用{} 2、使用内置函数set() 三、集合的相关操作 1、集合元素的判断 2、集合元素的新增 3、集合元素的删除 四、集合间的关系 五、集合的数学操作 1、交集操作 2、并集操作 3、差集操作 4、对称差集…

基于微信小程序的校园商铺系统小程序

文末联系获取源码 开发语言:Java 框架:ssm JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7/8.0 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包:Maven3.3.9 浏览器…

2022大数据产业年度“国产化优秀代表厂商”榜单发布,亚信科技AntDB数据库位列其中

国产化/信创亚信科技 ‍数据智能产业创新服务媒体 ——聚焦数智 改变商业 亚信科技也做数据库?实际上亚信科技AntDB是我国最早的国产数据库产品之一,是在21世纪初国外品牌数据库无法满足我国暴涨的通信需求的情况下,为了帮助通信运用商更好…

普中科技手把手教你学ESP32--基于MicroPython-02

第四讲:LED灯实验 MicroPython函数使用 本来需要加入machine.的,但是我引入了form machine import Pin就可以直接省略了 machine是一个模块,Pin是machine中的一个功能 Micropython官网学习 点击DOCS文档 选择相应的芯片 查看对应的模块 …

LeetCode 1801. 积压订单中的订单总数(C++)

思路: 该题主要是对比销售、采购的价格来进行数组\队列的pop和push操作来实现;采用优先队列来实现排序,其中销售和采购对应小队列和大队列 对于 销售 操作;如果采购的积压订单中有出价格比自己的销售价格高,就出 对于 …

C++设计模式(7)——外观模式

外观模式 亦称: Facade 意图 外观模式是一种结构型设计模式, 能为程序库、 框架或其他复杂类提供一个简单的接口。 问题 假设你必须在代码中使用某个复杂的库或框架中的众多对象。 正常情况下, 你需要负责所有对象的初始化工作、 管理其…

29.Isaac教程--调整导航

调整导航 ISAAC教程合集地址: https://blog.csdn.net/kunhe0512/category_12163211.html 文章目录调整导航定位器全局规划器局部规划器控制器定位器 定位器是导航堆栈的关键部分,因为了解机器人的位置对于正确导航到目的地至关重要。 因此,快速准确的定…

2、threejs官网本地化部署启动和Parcel热加载:Web应用打包工具介绍及使用

一、Three.js 官网 背景: threejs 是国外的网站,访问有时候比较卡,所以建议本地化部署启动一下,方便随时访问学习。 部署方案: 1、访问Threejs官网 2、点击github 选择 dev版本下载 3、下载完之后,解压…

Java中的this关键字

介绍 this关键字用于引用当前实例,在Java语言中,当创建一个对象后,Java虚拟机就会为其分配一个指向对象本身的指针,这个指针就是“this”。 Java关键字this只能用于方法方法体内,在类中的非静态方法中使用&#xff0…

14 Java集合(集合框架+泛型+ArrayList类+LinkedList类+Vector类+HashSet类等)

本篇主要是集合框架基础和List集合,Map集合等等后续更 集合14.1 集合框架14.1.1 概念14.1.2 集合架构14.2 Collection接口14.2.1 常用方法14.3 迭代器14.3.1 迭代器原理14.3.2 迭代器使用常见问题14.4 泛型基本使用14.5 ArrayList类14.5.1 常用方法14.5.2 实现原理1…

【手写 Vue2.x 源码】第三十三篇 - diff算法-收尾+阶段性总结

一,前言 上篇,diff算法-乱序比对,主要涉及以下几个点: 介绍了乱序比对的方案介绍了乱序比对的过程分析实现了乱序比对的代码逻辑 本篇,diff 算法的阶段性梳理 二,初渲染与视图更新流程 Vue 初渲染时&…

注册商标需要哪些材料和条件?

申请注册商标条件是什么1、申请人必须是申请认定商标的所有人,是在当省区域内的自然人、法人和其他组织;2、该商标自核准注册之起连续使用满三年并继续有效,且无权属争议;3、该商标为相关公众所熟知,在相关市场内具有较高的知名度;4、该商标核…

亚信科技AntDB数据库荣获2022年度技术卓越奖

近日,业界知名IT垂直媒体IT168发布了“2022技术卓越奖”主题奖项,亚信科技AntDB数据库荣获技术卓越奖。 2022 “技术卓越奖”由行业CIO/CTO大咖、技术专家及IT媒体三方联合评选,评判标准代表了用户和媒体声音。经过多方评审,亚信科…

jvm参数简介

Xmx3550m:设置JVM最大堆内存为3550M。 -Xms3550m:设置JVM初始堆内存为3550M。此值可以设置与-Xmx相同,以避免每次垃圾回收完成后JVM重新分配内存。 -Xss128k:设置每个线程的栈大小。JDK5.0以后每个线程栈大小为1M,之…

【SCL】1200应用案例:交通灯模拟自动装料控制

使用博图SCL语言来编写 交通灯模拟控制 和 自动装料应用案例 文章目录 目录 前言 一、应用:交通灯模拟控制 1.控制要求 2.I\o分配和接线 3.程序编写和效果 4.小结 二、自动装料模拟控制 1.控制要求 2.I/O分配 3.程序编写 4.小结 总结 前言 本篇文章我们继续学习西…

宏任务和微任务

宏任务和微任务1. 什么是宏任务和微任务2. 宏任务和微任务的执行顺序3. 去银行办业务的场景4. 分析以下代码输出的顺序5. 经典面试题1. 什么是宏任务和微任务 JavaScript 把异步任务又做了进一步的划分,异步任务又分为两类,分别是: ① 宏任…

寄存器、RAM、ROM、Flash

单片机寄存器简述 寄存器详细请点这里 1、单片机寄存器就是单片机片内存储器(片内RAM)一部分,每一个都有地址。只不过这几个寄存器有特殊的作用,比如指令:MUL AB,这条指令用到两个寄存器A,B进行乘法,结果存到BA里面&a…

kaggle竞赛 | Quora Insincere Question | 文本情感分析

目录赛题背景赛题评价指标数据集分析pytorch建模之前发布了一遍实战类的情感分析的文章,包括微博爬虫,数据分析,相关模型。 可以参考: https://blog.csdn.net/lijiamingccc/article/details/126963413 比赛链接: http…

Spring Boot学习篇(十二)

Spring Boot学习篇(十二) shiro安全框架使用篇(四) 2 在主页显示用户登录状态、用户信息和完成默认注销(不改shiro原来的配置)操作 2.1 变更SysUserController类 2.1.1 在SysUserController类中注入sysUserMapper Autowired SysUserMapper sysUserMapper;2.1.2 在SysUserC…

1598_AURIX_TC275_GPIO功能以及部分寄存器梳理1

全部学习汇总: GreyZhang/g_TC275: happy hacking for TC275! (github.com) 接下来,看一下GPIO的寄存器以及部分相关的功能。这部分将会是接下来这个章节剩余的全部,可能内容偏雷同,因此都是跳跃式看。但是中间需要临时关注一下的…