【6s965-fall2022】剪枝✂pruningⅠ

news2024/12/27 17:02:17

在这里插入图片描述

模型剪枝的介绍

修剪,消除不必要的知识。DNN的知识可以理解为存在于其权重中。

事实证明,许多 DNN 模型可以被分解为权重张量,而权重张量经常包含统计冗余(稀疏性)。因此,你可以压缩 DNN 的权重张量,从而缩小 DNN 的存储、数据传输和执行时间以及能耗成本。想想看,这就像把每个 DNN 张量放入一个 “压缩文件”,它需要的比特数比原始文件少。

然而,压缩后的权重张量仍然需要在模型推理时用于计算,所以我们需要找到一个适合高效计算的压缩方案。

许多压缩技术的工作原理是从权重张量中去除多余的零值,这样一来,压缩后的大部分或所有剩余的值都是非零的。通常情况下,使用这样的去零方案可以直接在压缩的数据上进行计算,因为我们在DNN中主要的计算是乘法/累加,对于这些零操作数是无效的。
在这里插入图片描述然而,在许多DNN中,稀疏性并不直接以冗余零值的形式存在,而是表现为权重张量中的其他隐藏模式。于是,形成了剪枝的做法,假设张量有一些潜在的统计稀疏性,我们使用一些技术将这种稀疏性转换为零稀疏性。我们将DNN的一些权重翻转为零,并尽可能调整其他权重的值以进行补偿由于零值权重对乘法/累加无效,我们将零值权重视为不存在(它们被 "修剪 "了),因此我们不存储、传输或计算这些权重

因此,综上所述,我们通过将统计稀疏性转换为特定形式的稀疏性(零稀疏性),然后依靠硬件或软件来利用这种稀疏性做更少的计算,从而达到节省关键指标的目的。

通过神经生物学的类比来介绍修剪技术

在神经生物学中,轴突连接在计算和传递信息方面起重要的作用。轴突连接的数量在出生后急剧增加,并在一段时间内保持较高的水平,直到最终大量的轴突连接在青春期结束和进入成年后被破坏和消失。科学家认为,大脑这样做是为了促进生命早期的积极学习,然后在生命后期只磨练最关键的知识。这个破坏轴突连接的阶段,在神经生物学中被称为神经元轴突的 “修剪”。
在这里插入图片描述
因此,在深度学习中,将权重张量稀疏性转化为多余的零权重,然后对这些零权重进行处理的过程被命名为 “修剪”,类比人类发展中轴突破坏的阶段。有时,DNN在修剪前被训练完成(预训练),在这种情况下,神经生物学的类比是,模型的初始预训练就像婴儿和儿童阶段的积极努力学习,之后需要修剪,只为 "成年 "保留最重要的知识。

稀疏性

  • 稀疏性通常指的是包含多个值的数据结构中的基本统计冗余度。一般来说,稀疏性意味着有机会通过减少冗余来压缩数据结构。
  • 在这门课中,我们关注的数据结构是张量
  • 冗余度是由计算零值元素的比例得到的
  • 利用pytorch来计算模型的稀疏性
def get_sparsity(tensor: torch.Tensor) -> float:
    """
    calculate the sparsity of the given tensor
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    return 1 - float(tensor.count_nonzero()) / tensor.numel()


def get_model_sparsity(model: nn.Module) -> float:
    """
    calculate the sparsity of the given model
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    num_nonzeros, num_elements = 0, 0
    for param in model.parameters():
        num_nonzeros += param.count_nonzero()
        num_elements += param.numel()
    return 1 - float(num_nonzeros) / num_elements

剪枝的流程

在这里插入图片描述

  1. 模型预训练,训练模型所有的权重
  2. 选择特定的权重来修剪✂
    • 权重变为零
    • 零权重可以从计算中完全跳过 ⇒ 时间($),节省能耗
    • 零权重不需要被存储 ⇒ 节省存储和数据传输
    • 使用掩码来权重剪枝,掩码需要保存,训练之后还需使用它来剪枝
  3. 训练一轮,增强剩余权重的连接性🔗
    • 梯度下降/SGD
    • 在训练过程中,零权重可能会变成非零权重
  4. 使用之前的掩码来剪枝
  5. 再训练,重复以上2和3的步骤

修剪、权重更新的影响

  • 在这里,我们假设使用的是magnitude-based pruning的变体,但是无论如何,想法都是相似的。
  • 修剪过程的每一步都会对权重分布产生独特的影响
  • 在没有其他先验知识的情况下,设想初始权重分布是高斯分布
  • 修剪过程将那些绝对值接近零的权重剪掉,留下一个高于和低于零的权重值的双峰分布。
  • 重新训练/微调/权重更新增加了这些分布的范围,这是因为权重调整后准确性恢复。

请添加图片描述

  • 修剪牺牲了准确性,剪枝比例越高,对准确性的影响程度也越高
  • 准确率/剪枝比例的斜率是一个渐进的过程
  • 再训练/微调/权重更新可以恢复准确率
  • 在非常高的剪枝比例下,准确率几乎是断崖式的下降

请添加图片描述

剪枝的颗粒度 Granularity

在这里插入图片描述

  • 细粒度(又称非结构化)修剪 Fine-grained pruning
    • 灵活
      • 基本上是 “随机访问”,以修剪任何张量中的任何权重
    • 优点
      • 在相同的预训练准确率下,有较高的恢复准确率/较高的压缩率
    • 缺点
      • 有可能在定制硬件中利用;在CPU/GPU上有效利用具有挑战性或不可能性。
  • 基于模式的剪枝 Pattern-based pruning
    • 将剪枝指数限制为固定模式,这样硬件就更容易利用稀疏性
    • 一个典型的例子是N:M块剪枝,在每个大小为M的连续权重块中,N必须为非零。
    • 优点
      • 更容易在硬件中加以利用
      • 英伟达Ampere GPU架构利用了2:4块的稀疏性;速度大约提高了2倍
    • 缺点
      • 由于灵活性较低,与细粒度相比,恢复的准确性/压缩率较低
      • 在CPU上有效利用可能具有挑战性
  • 通道修剪 Channel pruning
    • 粗粒度修剪约束,只去除整个通道,有效改变层的几何形状
    • 优点
      • 可以在CPU和GPU上利用(实际上是改变了层的几何形状)
      • 也可以在自定义硬件中利用
    • 缺点
      • 由于几乎没有任何灵活性,恢复的准确率/压缩率最低
  • 在细粒度和通道修剪之间
    • 矢量级修剪
    • 内核级修剪

修剪权重的准则 Weight-pruning criteria

基于幅度的修剪——修剪最小的权重值 Magnitude pruning - prune the smallest weight values

绝对值较大的权重比绝对值小的权重更重要

在这里插入图片描述在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

利用pytorch实现 基于幅值的细粒度剪枝 magnitude-based fine-grained pruning

s p a r s i t y : = # Z e r o s # W = 1 − # N o n z e r o s # W \mathrm{sparsity} := \frac {\#\mathrm{Zeros}} {\#W} = 1 - \frac {\#\mathrm{Nonzeros}} {\#W} sparsity:=#W#Zeros=1#W#Nonzeros

式子中, # W \#W #W W W W中元素的个数, # Z e r o s \#\mathrm{Zeros} #Zeros是非零元素的个数

给定目标稀疏度 s s s,权重张量 W W W乘于二进制掩码(a binary mask) M M M来忽略要去除的权重

I m p o r t a n c e = ∣ W ∣ Importance=|W| Importance=W

v t h r = kthvalue ( I m p o r t a n c e , # W ⋅ s ) v_{\mathrm{thr}} = \texttt{kthvalue}(Importance, \#W \cdot s) vthr=kthvalue(Importance,#Ws)

M = I m p o r t a n c e > v t h r M = Importance > v_{\mathrm{thr}} M=Importance>vthr

W = W ⋅ M W = W \cdot M W=WM

I m p o r t a n c e Importance Importance W W W中每个元素的绝对值,张量的形状和 W W W一样;
kthvalue ( X , k ) \texttt{kthvalue}(X, k) kthvalue(X,k)找到张量 X X X的第 k k k个最小值, v t h r v_{\mathrm{thr}} vthr是阈值;
M M M中的元素只有0和1。

def fine_grained_prune(tensor: torch.Tensor, sparsity : float) -> torch.Tensor:
    """
    magnitude-based pruning for single tensor
    :param tensor: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    :return:
        torch.(cuda.)Tensor, mask for zeros
    """
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        tensor.zero_()
        return torch.zeros_like(tensor)
    elif sparsity == 0.0:
        return torch.ones_like(tensor)

    num_elements = tensor.numel()

    # Step 1: calculate the #zeros (please use round())
    num_zeros = round(num_elements * sparsity)
    # Step 2: calculate the importance of weight
    importance = tensor.abs()
    # Step 3: calculate the pruning threshold
    threshold = importance.view(-1).kthvalue(num_zeros).values
    # Step 4: get binary mask (1 for nonzeros, 0 for zeros)
    mask = torch.gt(importance, threshold)
    # Step 5: apply mask to prune the tensor
    tensor.mul_(mask)

    return mask

说明:

  • 在步骤1中,我们计算修剪后的零元素的个数(num_zeros)。请注意,num_zeros应该是一个整数。使用round()将浮动数转换成整数。
  • 在第二步,我们计算权重张量的重要性。Pytorch提供了torch.abs(), torch.Tensor.abs(), torch.Tensor.abs_() API.
  • 在第三步,我们计算修剪的 “阈值”,这样所有重要性小于 "阈值 "的突触将被删除。Pytorch提供了torch.kthvalue(), torch.Tensor.kthvalue(), torch.topk() API.
  • 在第四步,我们根据 "阈值 "计算修剪 “掩码”。mask'中的**1**表示突触将被保留,mask’中的0表示突触将被删除。mask = mportance > threshold。Pytorch提供torch.gt() API。

在这里插入图片描述
在这里插入图片描述

class FineGrainedPruner:
    def __init__(self, model, sparsity_dict):
        # keep a record of the pruning masks
        # make sure the model keep sparse all the time
        self.masks = FineGrainedPruner.prune(model, sparsity_dict)

    @torch.no_grad()
    def apply(self, model):
        # pruning
        for name, param in model.named_parameters():
            if name in self.masks:
                param *= self.masks[name]

    @staticmethod
    @torch.no_grad()
    def prune(model, sparsity_dict):
        masks = dict()
        for name, param in model.named_parameters():
            if param.dim() > 1: # we only prune conv and fc weights
                masks[name] = fine_grained_prune(param, sparsity_dict[name])
        return masks

prune the pre-trained model

pruner = FineGrainedPruner(model, sparsity)
pruner.apply(model)

fine-tune the pruned model

def train(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    scheduler: StepLR,
    callbacks = None,
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> None:
    model.train()

    for inputs, targets in tqdm(dataloader, desc='train', leave=False):
        # Move the data from CPU to GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Reset the gradients (from the last iteration)
        optimizer.zero_grad()

        # Forward inference
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward propagation
        loss.backward()

        # Update optimizer
        optimizer.step()

		# use callbacks to prune model after every train
        if callbacks is not None:
            for callback in callbacks:
                callback()

    # Update scheduler
    scheduler.step()


num_finetune_epochs = 5
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)

best_sparse_checkpoint = dict()
best_sparse_accuracy = 0
print(f'Finetuning Fine-grained Pruned Sparse Model')
for epoch in range(num_finetune_epochs):
    # At the end of each train iteration, we have to apply the pruning mask 
    # to keep the model sparse during the training
    train(model, dataloader['train'], criterion, optimizer, scheduler,
          callbacks=[lambda: pruner.apply(model)], device=device)
    accuracy = evaluate(model, dataloader['test'], device=device)
    # save the best model
    is_best = accuracy > best_sparse_accuracy
    if is_best:
        best_sparse_checkpoint['state_dict'] = copy.deepcopy(model.state_dict())
        best_sparse_accuracy = accuracy
    print(f'Epoch {epoch+1} Sparse Accuracy {accuracy:.2f}% / Best Sparse Accuracy: {best_sparse_accuracy:.2f}%')

效果

预训练模型
在这里插入图片描述
剪枝后
在这里插入图片描述微调后
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/172893.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[从零开始]用python制作识图翻译器·五

测试 通过以上步骤我们终于实现了系统,现在到了紧张刺激的测试环节。直接运行run.py文件: python run.py ::注意需要进入conda环境稍作等等,我们的系统就运行啦(啵唧啵唧)。 在使用之前,我们还需要在设置中输入自己的…

使用vscode进行C++代码开发(linux平台)

使用vscode进行C代码开发(linux平台一、插件安装二、常用快捷键三、重要配置文件四、实际例子1. 编译并运行一个含有多个文件夹和文件的代码工程2. 编译并运行一个依赖第三方库的代码工程参考资料一、插件安装 执行 ctrl shift x打开插件窗口,然后搜索c插件&…

鸡格线(map操作)

鸡格线 (nowcoder.com) 题目描述 你有一个长为n的数组a,你需要支持以下两种操作: 1、输入l, r, k,对区间[1,r]中所有数字执行a; f(a;)操作k次(式中等号表示赋值操作),之中f(z)round(10、c),round为四舍五入函数。 2、输出当前数组…

缓存一致性问题怎么解决

关于Redis的其他的一些面试问题已经写过了,比如常见的缓存穿透、雪崩、击穿、热点的问题,但是还有一个比较麻烦的问题就是如何保证缓存一致性。对于缓存和数据库的操作,主要有以下两种方式。先删缓存,再更新数据库先删除缓存&…

Java 多线程 笔记

文章目录实现方式1. 通过Thread类2. 通过Runnable接口3. 通过Callable接口线程状态线程方法1. 线程停止2. 线程休眠sleep3. 线程礼让yield4. 线程强制执行join5. 线程状态观测getState6. 线程优先级守护线程(daemon线程同步 TODO线程死锁Lock锁(可重入锁…

Typora+Gitee+PicGo搭建图床

引言 markdown原则上不建议使用base64内嵌图片,因为太麻烦。 如果只是在本机浏览、编辑的话,那引用相对路径或者绝对路径即可,但是考虑到要发布、分享的情况,使用图床是比较好的解决方案。 本教程可以快速得到一个相对稳定的免…

单片机数据、地址、指令、控制总线结构

数据、地址、指令:之所以将这三者放在一同,是因为这三者的实质都是相同的──数字,或者说都是一串‘0’和‘1’组成的序列。换言之,地址、指令也都是数据。 指令:具体可参考文章 由单片机芯片的设计者规则的一种数字…

c程序gcc编译常见报错及解决方法整理

目录一、简介二、常见报错及解决方法1、数组定义错误2、Not enough information to produce a SYMDEFs file3、文件乱码<U0000>4、未定义或未申明报错5、代码中误加入中文三、其他相关链接一、简介 本文主要是整理c程序编译过程的常见报错的解决方法&#xff0c;方便大家…

Leetcode.312 戳气球

题目链接 Leetcode.312 戳气球 题目描述 有 n个气球&#xff0c;编号为0到 n - 1&#xff0c;每个气球上都标有一个数字&#xff0c;这些数字存在数组 nums中。 现在要求你戳破所有的气球。戳破第 i 个气球&#xff0c;你可以获得 nums[i−1]∗nums[i]∗nums[i1]nums[i - 1]…

C++ 类 对象初学者学习笔记

C 类 & 对象 访问数据成员 #include <iostream>using namespace std;class Box {public:double length; // 长度double breadth; // 宽度double height; // 高度// 成员函数声明double get(void);void set( double len, double bre, double hei ); }; // 成员…

两种在CAD中加载在线卫星影像的方法

概述 经常使用CAD的朋友应该会有这样的一个烦恼&#xff0c;就是当加载卫星图到CAD中进行绘图的时候&#xff0c;由于CAD本身的限制和电脑性能等原因&#xff0c;往往不能加载太大的地图图片到CAD内&#xff0c;这里给大家介绍两种在CAD内加载在线卫星影像的方法&#xff0c;希…

用docker部署webstack导航网站-其一

序言 认识的好朋友斥资买了一个NAS&#xff0c;并搭建了webdev和其他一些web应用&#xff0c;用于存放电子书、电影&#xff0c;并用alist搭建了一个网盘。现在他还缺少一个导航页&#xff0c;于是拖我给他做一个导航页。我欣然接受了&#xff0c;他想做一个和TBox导航&#x…

Arduino开发:网页控制ESP8266三色LED灯闪烁

根据板卡原理RGB三色LED对应引脚&#xff1a;int LEDR12、int LEDG14、int LEDB13;设置串口波特率为115200Serial.begin(115200);源代码如下所示&#xff1a;#include <ESP8266WiFi.h> // 提供 Wi-Fi 功能的库#include <ESP8266WebServer.h> // 提供网站服务器功能…

分布式共识算法随笔 —— 从 Quorum 到 Paxos

分布式共识算法随笔 —— 从 Quorum 到 Paxos 概览: 为什么需要共识算法&#xff1f; 昨夜西风凋碧树&#xff0c;独上高楼&#xff0c;望尽天涯路 复制(Replication) 是一种通过将同一份数据在复制在多个服务器上来提高系统可用性和扩展写吞吐的策略, 。常见的复制策略有主从…

Flink CDC 原理

文章目录CDC&#xff0c;Change Data Capture 变更数据捕获 目前CDC有两种实现方式&#xff0c;一种是主动查询、一种是事件接收。 主动查询&#xff1a; 相关开源产品有Sqoop、Kafka JDBC Source等。 用户通常会在数据原表中的某个字段中&#xff0c;保存上次更新的时间戳或…

一篇博客教会你写序列化工具

文章目录什么是序列化&#xff1f;序列化格式JSON序列化精简序列化数据总结源码什么是序列化&#xff1f; 总所周知&#xff0c;在Java语言中&#xff0c;所有的数据都是以对象的形式存在Java堆中。 但是Java对象如果要存储在别的地方&#xff0c;那么单纯的Java对象就无法满…

我靠steam/csgo道具搬运实现财富自由

和莘莘学子一样&#xff0c;本着毕业对未来的憧憬&#xff0c;规划着漫漫人生&#xff0c;可是被残酷的事实打败。 直到一次偶然的同学聚会&#xff0c;谈及了现如今的生活才发现一片新大陆&#xff0c;通过信息差去赚取收益。 记得之前在校期间常常和哥几个通宵干CSGO,直到这…

Elasticsearch7.8.0版本高级查询—— 完全匹配查询文档

目录一、初始化文档数据二、完全匹配查询文档2.1、概述2.2、示例一、初始化文档数据 在 Postman 中&#xff0c;向 ES 服务器发 POST 请求 &#xff1a;http://localhost:9200/user/_doc/1&#xff0c;请求体内容为&#xff1a; { "name":"zhangsan", &quo…

【Nginx】使用Docker完成Nginx负载均衡+动静分离

前提是需要配置Nginx的反向代理&#xff0c;可以我看之前的文章 上篇Nginx配置动态代理的文章&#xff0c;我们在tomcat里写了两个简单html 这次我们依然采取同样的思路来演示负载均衡 一、负载均衡 1.在两个Tomcat容器&#xff08;我这里一个端口8081&#xff0c;一个8082…

Gradle vs Maven 基本对比(一)

Gradle 与Maven 的基本对比 对比目录&#xff1a; 1、工具包目录对比 2、创建项目结构对比 3、启动进程对比 4、性能对比 5、简洁性对比 什么是gradle: Gradle 是一个开源的运行在JVM上自动化构建工具&#xff0c;专注于灵活性和性能。Gradle 使用 Groovy 或 Kotlin DSL(领…