模型压缩之剪枝

news2024/9/20 0:54:47

(1)通道选择

这里要先解释一下:

(1)通道剪枝

那我们实际做法不是上面直接对所有层都添加L1正则项,而是仅仅对BN层权重添加L1正则项。通道剪枝具体步骤如下:
 

1.BN层权重添加L1正则项,进行稀疏训练

2.BN层权重的scale factor进行排序,对scale factor低于阈值的通道进行裁剪,得到剪枝模型

3.对剪枝模型进行finetune
 

注:进行finetune的目的是因为剪枝完整个网络结构发生了变化,之前的训练的模型无法再加载进入,必须要finetune(或者这里用重新训练更合适),否则会发现推理结果都是0.

在深度学习中,Batch Normalization(BN)层通常用于加速训练过程并提高模型的泛化能力。BN层的权重参数包括scale factor(缩放因子)和shift factor(偏移因子)。通过对BN层的scale factor添加L1正则化,我们可以实现通道剪枝。

下面是一个示例代码,展示了如何对BN层的scale factor添加L1正则化,并进行通道剪枝和微调(finetune)。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# L1正则化参数
lambda_l1 = 0.001

# 稀疏训练
for epoch in range(10):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        
        # 计算BN层scale factor的L1正则化项
        l1_regularization = 0
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                l1_regularization += torch.norm(module.weight, p=1)
        
        loss += lambda_l1 * l1_regularization
        
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

print("稀疏训练完成")

# 通道剪枝
def prune_channels(model, sparsity_threshold):
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            weights = module.weight.data
            mask = torch.abs(weights) > sparsity_threshold
            module.weight.data = weights[mask]
            module.bias.data = module.bias.data[mask]
            module.num_features = int(torch.sum(mask))
            # 更新卷积层的输入通道数
            if hasattr(module, 'conv'):
                conv_module = getattr(module, 'conv')
                conv_module.out_channels = int(torch.sum(mask))
                conv_module.weight.data = conv_module.weight.data[mask]
                if conv_module.bias is not None:
                    conv_module.bias.data = conv_module.bias.data[mask]

# 设置稀疏性阈值
sparsity_threshold = 0.01

# 剪枝
prune_channels(model, sparsity_threshold)

print("通道剪枝完成")

# 微调
for epoch in range(10):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    
    print(f'Finetune Epoch {epoch + 1}, Loss: {loss.item()}')

print("微调完成")

(2)卷积核剪枝

1.conv层权重添加L1正则项,进行稀疏训练

2.conv层权重进行排序,对权重低于阈值的卷积核进行裁剪,得到剪枝模型

3.对剪枝模型进行finetune

下面我写了一个简单的示例代码,展示了如何在训练过程中计算权重的稀疏性,并根据稀疏性剪掉通道。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# L1正则化参数
lambda_l1 = 0.001

# 训练模型
for epoch in range(10):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        
        # 计算L1正则化项
        l1_regularization = 0
        for param in model.parameters():
            l1_regularization += torch.norm(param, p=1)
        
        loss += lambda_l1 * l1_regularization
        
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

print("训练完成")

# 根据稀疏性剪掉通道
def prune_channels(model, sparsity_threshold):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            weights = module.weight.data
            abs_weights = torch.abs(weights)
            channel_sums = torch.sum(abs_weights, dim=(1, 2, 3))
            mask = channel_sums > sparsity_threshold
            module.weight.data = weights[mask]
            module.out_channels = int(torch.sum(mask))
            if module.bias is not None:
                module.bias.data = module.bias.data[mask]

# 设置稀疏性阈值
sparsity_threshold = 0.01

# 剪枝
prune_channels(model, sparsity_threshold)

print("剪枝完成")

在上面例子中,我们在训练完成后,通过 prune_channels 函数根据稀疏性剪掉通道。

具体步骤如下:

  1. 计算权重的稀疏性:对于每个卷积层的权重,我们计算每个通道的权重绝对值之和。

  2. 剪枝:根据设定的稀疏性阈值,我们创建一个掩码(mask),只保留那些权重绝对值之和大于阈值的通道,并更新卷积层的权重和偏置。

通过这种方式,我们可以根据权重的稀疏性剪掉不重要的通道,从而减少模型的复杂度和计算量。

通道剪枝和卷积核剪枝小结:

卷积核剪枝(Kernel Pruning)和通道剪枝(Channel Pruning)是两种不同的模型剪枝技术,它们在剪枝的对象和目标上有所区别。

卷积核剪枝(Kernel Pruning)

卷积核剪枝 是指从卷积层中移除整个卷积核(kernel)。一个卷积核通常由一组权重组成,这些权重在卷积操作中与输入特征图的局部区域进行卷积运算。卷积核剪枝的目标是移除那些对模型性能贡献较小的卷积核,从而减少模型的计算量和参数数量。

  • 剪枝对象:卷积核(kernel)。

  • 剪枝目标:移除整个卷积核。

  • 影响:减少卷积层的输出通道数。

通道剪枝(Channel Pruning)

通道剪枝 是指从卷积层或全连接层中移除整个通道(channel)。一个通道通常由一组权重组成,这些权重在卷积操作中与输入特征图的所有位置进行卷积运算。通道剪枝的目标是移除那些对模型性能贡献较小的通道,从而减少模型的计算量和参数数量。

  • 剪枝对象:通道(channel)。

  • 剪枝目标:移除整个通道。

  • 影响:减少卷积层的输入或输出通道数。

主要区别

  1. 剪枝对象

    • 卷积核剪枝针对的是卷积核,即卷积层中的单个权重组。

    • 通道剪枝针对的是通道,即卷积层或全连接层中的整个权重集合。

  2. 剪枝目标

    • 卷积核剪枝的目标是移除整个卷积核。

    • 通道剪枝的目标是移除整个通道。

  3. 影响

    • 卷积核剪枝主要影响卷积层的输出通道数。

    • 通道剪枝既可以影响卷积层的输入通道数,也可以影响输出通道数。

卷积核剪枝代码:
 

def prune_kernels(model, sparsity_threshold):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            weights = module.weight.data
            abs_weights = torch.abs(weights)
            kernel_sums = torch.sum(abs_weights, dim=(1, 2, 3))
            mask = kernel_sums > sparsity_threshold
            module.weight.data = weights[mask]
            module.out_channels = int(torch.sum(mask))
            if module.bias is not None:
                module.bias.data = module.bias.data[mask]

通道剪枝代码:

def prune_channels(model, sparsity_threshold):
    #遍历模型中的所有模块
    for module in model.modules():
        #检查模块是否为BN层
        if isinstance(module, nn.BatchNorm2d):
            #获取BN层的权重
            weights = module.weight.data
            #根据稀疏性阈值创建掩码
            mask = torch.abs(weights) > sparsity_threshold
            #应用掩码到BN层的权重和偏置
            module.weight.data = weights[mask]
            module.bias.data = module.bias.data[mask]
            module.num_features = int(torch.sum(mask))
            #检查BN层是否有与之关联的卷积层
            if hasattr(module, 'conv'):
                conv_module = getattr(module, 'conv')
                #应用掩码到卷积层的权重和偏置
                conv_module.out_channels = int(torch.sum(mask))
                conv_module.weight.data = conv_module.weight.data[mask]
                if conv_module.bias is not None:
                    conv_module.bias.data = conv_module.bias.data[mask]

在通道剪枝中,我们不仅需要剪枝Batch Normalization(BN)层的权重,还需要相应地剪枝与之关联的卷积层的权重。具体来说,BN层的权重(scale factor)决定了哪些通道是重要的,因此我们需要根据BN层的权重来剪枝卷积层的通道。

通过这种方式,我们确保了BN层的剪枝与卷积层的剪枝是一致的,即剪枝后的BN层和卷积层具有相同的通道数。这样可以保证模型在剪枝后的结构是有效的,并且能够正常工作。

总结来说,通道剪枝不仅涉及BN层的权重剪枝,还涉及与之关联的卷积层的权重剪枝,以确保剪枝后的模型结构的一致性和有效性。

(3)特征图重构
 

特征图重构是一种在通道剪枝中常用的方法,旨在最小化剪枝后特征图与原始特征图之间的差异。通过这种方式,我们可以更直接地控制剪枝的力度,并确保剪枝后的模型在性能上与原始模型尽可能接近。

下面是一个示例代码,展示了如何使用最小二乘法(linear least squares)来实现特征图重构,从而控制通道剪枝的力度。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
for epoch in range(10):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

print("训练完成")

# 特征图重构
def feature_map_reconstruction(model, train_loader, alpha=0.01):
    model.eval()
    original_features = []
    pruned_features = []
    
    # 收集原始特征图
    with torch.no_grad():
        for data, _ in train_loader:
            output = model(data)
            original_features.append(output)
    
    # 剪枝
    def prune_channels(model, sparsity_threshold):
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                weights = module.weight.data
                mask = torch.abs(weights) > sparsity_threshold
                module.weight.data = weights[mask]
                module.bias.data = module.bias.data[mask]
                module.num_features = int(torch.sum(mask))
                # 更新卷积层的输入通道数
                if hasattr(module, 'conv'):
                    conv_module = getattr(module, 'conv')
                    conv_module.out_channels = int(torch.sum(mask))
                    conv_module.weight.data = conv_module.weight.data[mask]
                    if conv_module.bias is not None:
                        conv_module.bias.data = conv_module.bias.data[mask]

    # 设置稀疏性阈值
    sparsity_threshold = 0.01
    prune_channels(model, sparsity_threshold)

    # 收集剪枝后的特征图
    with torch.no_grad():
        for data, _ in train_loader:
            output = model(data)
            pruned_features.append(output)
    
    # 计算特征图差异
    original_features = torch.cat(original_features, dim=0)
    pruned_features = torch.cat(pruned_features, dim=0)
    diff = original_features - pruned_features
    loss = alpha * torch.norm(diff, p=2)
    
    # 反向传播和优化
    loss.backward()
    optimizer.step()
    
    print(f'Feature Map Reconstruction Loss: {loss.item()}')

# 特征图重构
feature_map_reconstruction(model, train_loader)

print("特征图重构完成")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2103998.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ElementUI实现el-table组件的合并行功能

前言 有时遇到一些需求,需要实现ElementUI中,el-tabled组件合并单元格的功能,稍微了解一下它的数据格式,不难可以写出比合并方法。但是在鼠标经过单元行时,会出现高亮的行与鼠标经过的行不一致的BUG。因此还需要实现c…

超级右键 - 为 Mac 的右键菜单升级一下

是不是有很多小伙伴,希望 Mac 也能像 Windows 一样,拥有丰富的右键菜单,快速完成新建、剪切、发送文件等操作。 一个叫作超级右键的工具就能做到,它能为 Mac 右键菜单增添多个功能选项,如 Win 系统般一键新建 / 剪切文…

vue通过html2canvas+jspdf生成PDF问题全解(水印,分页,截断,多页,黑屏,空白,附源码)

前端导出PDF的方法不多,常见的就是利用canvas画布渲染,再结合jspdf导出PDF文件,代码也不复杂,网上的代码基本都可以拿来即用。 如果不是特别追求完美的情况下,或者导出PDF内容单页的话,那么基本上也就满足业…

我的大模型岗位面试总结!太卷了!!!—我面试了24家大模型岗位 只拿了9个offer!

这段时间面试了很多家(共24家,9个offer,简历拒了4家,剩下是面试后拒的),也学到了超级多东西。 大模型这方向真的卷,面试时好多新模型,新paper疯狂出,东西出的比我读的快…

传统CV算法——基于opencv的答题卡识别判卷系统

基于OpenCV的答题卡识别系统,其主要功能是自动读取并评分答题卡上的选择题答案。系统通过图像处理和计算机视觉技术,自动化地完成了从读取图像到输出成绩的整个流程。下面是该系统的主要步骤和实现细节的概述: 1. 导入必要的库 系统首先导入…

误删的PPT怎么恢复回来?

在日常工作和学习中,PPT已成为我们不可或缺的工具。然而,有时不小心误删重要的PPT文件,可能会让人倍感焦虑。别担心,本文将为你提供几种实用的方法,帮助你轻松恢复误删的PPT文件。 一、从回收站恢复 当你误删文件时&…

【Grafana】Prometheus结合Grafana打造智能监控可视化平台

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

香港一带一路研究院国际事务研究中心副主任陈景才阐述香港在一带一路建设及区块链金融领域的关键作用

2024年8月28日,香港金管局举行Ensemble项目沙盒(以下简称沙盒)启动仪式,并宣布首阶段试验将涵盖四大代币化资产用例主题,标志着金融业在代币化技术的实际应用进程中迈出重要一步。香港一带一路研究院国际事务研究中心副…

解剖学上合理的分割:通过先验变形显式保持拓扑结构|文献速递--基于深度学习的医学影像病灶分割

Title 题目 Anatomically plausible segmentations: Explicitly preserving topology through prior deformations 解剖学上合理的分割:通过先验变形显式保持拓扑结构 01 文献速递介绍 进行环向应变或壁厚度的计算,这些测量通常用于诊断肥厚性心肌病…

IDEA 安装lombok插件不兼容的问题及解决方法

解决:IDEA 安装lombok插件不兼容问题,plugin xxxx is incompatible 一、去官网下载最新的2024版本 地址传送通道: lombok插件官网地址https://plugins.jetbrains.com/plugin/6317-lombok/versions/stable 二、修改参数的配置 在压缩包路径…

理解C++的【内部链接】和【外部链接】

一、前言 最近在看《大规模C程序设计》一书,看第一章关于内部链接和外部链接这部分时,有点不太明白。通过书本理解和网上查阅文献,在此记录一下自己对这部分知识点的理解。 首先,提几个问题: 什么是内部链接&#x…

全域运营公司哪家做得好?全域运营系统综合评测结果揭晓!

作为当前火爆的风口项目,一直以来,全域运营都以其广阔的业务范围和巨大的收益潜力吸引着一批又一批的创业者入局分羹,使得全域运营公司哪家做得好等问题一度成为了相关创业者交流群内的讨论重点。 从目前的市场情况来看,由于进入…

定期加强医疗器械维修技能学习重要性

医学影像技术是现代医疗的重要支撑,是辅助临床诊断和治疗不可或缺的技术手段。影像医疗设备成像质量的优劣程度在一定程度上决定了疾病诊断结果的准确性,而术中使用的影像设备的优劣甚至可能影响手术的成功率。因此保证设备正常使用是重中之重,设备售后维修保养也就…

Langchain-Chatchat+Qwen实现本地知识库

1.基础介绍 Langchain-Chatchat一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。大致过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化…

《OpenCV计算机视觉》—— 对图片的各种操作

文章目录 1、安装OpenCV库2、读取、显示、查看图片3、对图片进行切割4、改变图像的大小5、图片打码6、图片组合7、图像运算8、图像加权运算 1、安装OpenCV库 使用pip是最简单、最快捷的安装方式 pip install opencv-python3.4.2还需要安装一个包含了其他一些图像处理算法函数的…

vector中的push_back()和emplace_back()的区别、以及使用场景

目录 前言 1. 基本区别 2. 性能差异 3. 构造参数传递 4. 使用场景总结 前言 push_back() 更适合在已经有对象实例的情况下使用。emplace_back() 则更适合需要在容器内部直接构造对象的场景,特别是在性能敏感的情况下。 1. 基本区别 push_back(): 作用&#xff…

酒店智能触摸开关在酒店管理中的作用

在众多智能化设备中,酒店智能触摸开关以其便捷性、高效性和节能环保的特性,正逐步成为提升住客体验、优化酒店运营管理的关键元素。本文将深入探讨酒店智能触摸开关在酒店管理中的多重作用。 一、提升住客体验,增强服务品质 便捷操作&#xf…

护眼灯真的可以保护眼睛吗?曝光劣质护眼台灯常见的三个特征

护眼灯真的可以保护眼睛吗?随着时代的发展,我们注意到越来越多的孩子开始佩戴眼镜。这一趋势引起了许多细心家长的关注,他们认识到这不仅是个别情况,而是现代生活方式和环境对孩子视力健康的挑战。自然而然地,“儿童是…

【淘宝采集项目经验分享】商品评论采集 |商品详情采集 |关键词搜索商品信息采集

商品评论采集 1、输入商品ID 2、筛选要抓取评论类型 3、填写要抓取的页数 4、立刻提交-启动测试 5、等爬虫结束后就可以到“爬取结果”里面下载数据 商品详情采集 1、输入商品ID 2、立刻提交-启动爬虫 3、等爬虫结束后就可以到“爬取结果”里面下载数据 taobao.item_…

报名开启!IDEA研究院编程语言MoonBit全球编程创新挑战赛启动

"懂语言者得天下"。探索编程之革新,参与AI时代编程语言之构建。2024年MoonBit全球编程创新挑战赛,为你开启! 我们向每一位怀揣才华与创意的编程爱好者发出邀请,一起在这场创新与挑战的盛会中,将理想照进现实…