MobileNetV3基于NNI剪枝操作

news2025/1/23 2:04:51

NNI剪枝入门可参考:nni模型剪枝_benben044的博客-CSDN博客_nni 模型剪枝

1、背景

本文的剪枝操作针对CenterNet算法的BackBone,即MobileNetV3算法。

该Backbone最后的输出格式如下:

假如out = model(x),则x[-1]['hm']可获得heatmap的shape。

2、直接添加nni操作

直接添加的示例代码如下:

import torch
from torch import nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

class hswish(nn.Module):
    def __init__(self):
        super(hswish, self).__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = x * self.relu6(x + 3) / 6
        return out

class hsigmoid(nn.Module):
    def __init__(self):
        super(hsigmoid, self).__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = self.relu6(x + 3) / 6
        return out

# 注意力机制
class SE(nn.Module):
    def __init__(self, in_channels, reduce=4):
        super(SE, self).__init__()

        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduce, 1, bias=False),
            nn.BatchNorm2d(in_channels // reduce),
            nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels // reduce, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
            hsigmoid()
        )

    def forward(self, x):
        out = self.se(x)
        out = x * out
        return out

class Block(nn.Module):
    def __init__(self, kernel_size, in_channels, expand_size, out_channels, stride, se=False, nolinear='RE'):
        super(Block, self).__init__()

        self.se = nn.Sequential()
        if se:
            self.se = SE(expand_size)

        if nolinear == 'RE':
            self.nolinear = nn.ReLU6(inplace=True)
        elif nolinear == 'HS':
            self.nolinear = hswish()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, expand_size, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expand_size),
            self.nolinear,

            nn.Conv2d(expand_size, expand_size, kernel_size, stride=stride, padding=kernel_size // 2, groups=expand_size, bias=False),
            nn.BatchNorm2d(expand_size),
            self.se,
            self.nolinear,

            nn.Conv2d(expand_size, out_channels, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Sequential()
        if stride == 1 and in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        self.stride = stride

    def forward(self, x):
        out = self.block(x)

        if self.stride == 1:
            out += self.shortcut(x)

        return out

class MobileNetV3(nn.Module):
    def __init__(self, class_num):
        super(MobileNetV3, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish()
        )

        self.neck = nn.Sequential(
            Block(3, 16, 16, 16, 2, se=True),
            Block(3, 16, 72, 24, 2),
            Block(3, 24, 88, 24, 1),
            Block(5, 24, 96, 40, 2, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 120, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 144, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 288, 96, 2, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 576, 1, bias=False),
            nn.BatchNorm2d(576),
            hswish()
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.conv3 = nn.Sequential(
            nn.Conv2d(576, 1280, 2, bias=False),
            nn.BatchNorm2d(1280),
            hswish()
        )

        self.hm = nn.Conv2d(20, class_num, kernel_size=1)
        self.wh = nn.Conv2d(20, 2, kernel_size=1)
        self.reg = nn.Conv2d(20, 2, kernel_size=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.neck(x)
        x = self.conv2(x)
        x = self.conv3(x)

        y = x.view(x.shape[0], -1, 128, 128)
        z = {}
        z['hm'] = self.hm(y)
        z['wh'] = self.wh(y)
        z['reg'] = self.reg(y)
        return [z]


if __name__ == '__main__':
    model = MobileNetV3(10)
    print('-----------raw model------------')
    print(model)

    config_list = [{
        'sparsity_per_layer': 0.8,
        'op_types': ['Conv2d']
    }]

    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    for name, mask in masks.items():
        print(name, ' sparsity: ', '{:.2f}'.format(mask['weight'].sum() / mask['weight'].numel()))
    pruner._unwrap_model()
    ModelSpeedup(model, torch.rand(2, 3, 516, 516), masks).speedup_model()

    print('------------after speedup------------')
    print(model)

如果参考nni入门直接添加nni压缩的代码,则会报如下错误:
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions。

 File "D:\programs\python37\lib\site-packages\nni\common\graph_utils.py", line 78, in _trace
    self.trace = torch.jit.trace(model, dummy_input, **kw_args)
  File "D:\programs\python37\lib\site-packages\torch\jit\_trace.py", line 742, in trace
    _module_class,
  File "D:\programs\python37\lib\site-packages\torch\jit\_trace.py", line 940, in trace_module
    _force_outplace,
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions

 原因,返回的数据不符合torch.jit.trace的要求,而示例model返回的是一个dict,它不是tensors | lists | tuples of tensors | dictionary of tensors中的一种

所以需要对MobileNetv3进行改造,以满足torch.jit.trace的返回要求。

3、MobileNetV3针对NNI的改造

改造方法:

(1)将输出从dict修改为tuple形式

(2)hm、wh、reg的定义从__init__()函数移到forward中。因为hm中conv的in_channel是会变化的,未剪枝前是A,剪枝后是B,所以在__init__()中定义没法动态修改in_channel值,只能放到forward中进行处理。

改造后的示例代码如下:

import torch
from torch import nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup


class hswish(nn.Module):
    def __init__(self):
        super(hswish, self).__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = x * self.relu6(x + 3) / 6
        return out

class hsigmoid(nn.Module):
    def __init__(self):
        super(hsigmoid, self).__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = self.relu6(x + 3) / 6
        return out

# 注意力机制
class SE(nn.Module):
    def __init__(self, in_channels, reduce=4):
        super(SE, self).__init__()

        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduce, 1, bias=False),
            nn.BatchNorm2d(in_channels // reduce),
            nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels // reduce, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
            hsigmoid()
        )

    def forward(self, x):
        out = self.se(x)
        out = x * out
        return out

class Block(nn.Module):
    def __init__(self, kernel_size, in_channels, expand_size, out_channels, stride, se=False, nolinear='RE'):
        super(Block, self).__init__()

        self.se = nn.Sequential()
        if se:
            self.se = SE(expand_size)

        if nolinear == 'RE':
            self.nolinear = nn.ReLU6(inplace=True)
        elif nolinear == 'HS':
            self.nolinear = hswish()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, expand_size, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expand_size),
            self.nolinear,

            nn.Conv2d(expand_size, expand_size, kernel_size, stride=stride, padding=kernel_size // 2, groups=expand_size, bias=False),
            nn.BatchNorm2d(expand_size),
            self.se,
            self.nolinear,

            nn.Conv2d(expand_size, out_channels, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Sequential()
        if stride == 1 and in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        self.stride = stride

    def forward(self, x):
        out = self.block(x)

        if self.stride == 1:
            out += self.shortcut(x)

        return out

class MobileNetV3(nn.Module):
    def __init__(self, class_num, sparsity_ratio):
        super(MobileNetV3, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish()
        )

        self.neck = nn.Sequential(
            Block(3, 16, 16, 16, 2, se=True),
            Block(3, 16, 72, 24, 2),
            Block(3, 24, 88, 24, 1),
            Block(5, 24, 96, 40, 2, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 120, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 144, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 288, 96, 2, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 576, 1, bias=False),
            nn.BatchNorm2d(576),
            hswish()
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.conv3 = nn.Sequential(
            nn.Conv2d(576, 1280, 2, bias=False),
            nn.BatchNorm2d(1280),
            hswish()
        )

        self.class_num = class_num

    def forward(self, x):
        x = self.conv1(x)
        x = self.neck(x)
        x = self.conv2(x)
        x = self.conv3(x)

        y = x.view(x.shape[0], -1, 128, 128)

        in_channel = y.shape[1]
        hm = nn.Conv2d(in_channel, self.class_num, kernel_size=1)
        wh = nn.Conv2d(in_channel, self.class_num, kernel_size=1)
        reg = nn.Conv2d(in_channel, self.class_num, kernel_size=1)

        return (hm(y), wh(y), reg(y))

if __name__ == '__main__':
    model = MobileNetV3(10, 0.2)
    print('-----------raw model------------')
    print(model)

    config_list = [{
        'sparsity_per_layer': 0.2,
        'op_types': ['Conv2d']
    }]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    for name, mask in masks.items():
        print(name, ' sparsity: ', '{:.2f}'.format(mask['weight'].sum() / mask['weight'].numel()))
    pruner._unwrap_model()
    ModelSpeedup(model, torch.rand(2, 3, 516, 516), masks).speedup_model()

    print('------------after speedup------------')
    print(model)

    input = torch.randn(2, 3, 516, 516)   # batch_size =1 会报错
    out = model(input)
    print(out[0].shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/92456.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring框架04(Spring框架中AOP)

一、spring中bean的生命周期 1.singleton 容器启动的时候创建对象,容器正常关闭时销毁对象 2.prototype 获取对象的时候创建对象,spring容器不负责对象的销毁 生命周期的过程: 1.调用无参创建对象 2.调用set方法初始化属性 3.调用初始化…

知识付费系统源码,可直接打包成app、H5、小程序

知识付费,在近几年来,越来越受到大家的关注。知识付费系统源码是将知识通过互联网渠道变现的方式。以知识为载体,通过付费获得在线知识以及在线学习所带来的收益。知识付费平台主要以分享知识内容,内容分为直播、录播、图文等形式…

【从零开始学爬虫】采集收视率排行数据

l 采集网站 ​【场景描述】采集收视率排行数据。 【源网站介绍】收视率排行网提供收视率排行,收视率查询,电视剧收视率,综艺节目收视率和电视台收视率信息。 【使用工具】前嗅ForeSpider数据采集系统 【入口网址】http://www.tvtv.hk/archives/category/tv 【采集内容】 …

产线工控安全

场景描述 互联网飞速发展,工业4.0的大力推行,让工控产线更加智能化,生产网已经发展成一个组网的计算机环境。这些工控产线组网中的所有工控设备现在统称为主机。 信息化虽然提高各大企业的生产效率,但也会遭遇各类安全问题&…

Problem B: 算法10-15~10-17:基数排序

Problem Description 基数排序是一种并不基于关键字间比较和移动操作的排序算法。基数排序是一种借助多关键字排序的思想对单逻辑关键字进行排序的方法。 通过对每一个关键字分别依次进行排序,可以令整个关键字序列得到完整的排序。而采用静态链表存储记录&#xf…

FAST-LIO论文阅读

1. 摘要 本文提出一个开销较小且鲁棒的激光惯性里程计框架。使用迭代扩展卡尔曼滤波器来实现激光雷达特征点和IMU的紧耦合,可以在快速运动、有噪声或重复纹理等退化环境中鲁棒地定位。为了在测量数据量很大的情况下降低开销,提出了计算卡尔曼增益的新公…

如何做电商运营

电商是通过电子设备和网络技术进行的商业模式,通俗的来说也就是通过网络结识买家完成最终交易。电子商务凭借它便宜,丰富和方便的特性,迅速占领了中国一大半的经济市场,作为个人怎么才能做好电商呢?掌握这几个要点就不…

物联网开发笔记(63)- 使用Micropython开发ESP32开发板之控制ILI9341 3.2寸TFT-LCD触摸屏进行LVGL图形化编程:显示中文

一、目的 这一节我们学习如何使用我们的ESP32开发板来控制ILI9341 3.2寸TFT-LCD触摸屏进行LVGL图形化编程的第一步:显示中文。 二、环境 ESP32 3.2寸 ILI9341触摸屏 Thonny IDE 几根杜邦线 Win10 接线方法:请看上一篇文章。 三、流程介绍 …

Verilog刷题HDLBits——Conwaylife

Verilog刷题HDLBits——Conwaylife题目描述代码结果题目描述 Conway’s Game of Life is a two-dimensional cellular automaton. The “game” is played on a two-dimensional grid of cells, where each cell is either 1 (alive) or 0 (dead). At each time step, each c…

【图像融合】小波变换(加权平均法+局域能量+区域方差匹配)图像融合【含Matlab源码 1819期】

⛄一、小波变换彩色图像融合简介 1 前言 图像融合是将不同传感器所获得的多个图像根据某种算法进行融合处理,取长补短,使一幅图像能够更清楚、更准确地反映多幅图像的信息,多聚焦彩色图像融合是图像融合的一个分支。目前在各种图像采集与分析系统中已使用的CCD数码相机,对于聚…

分享7 个VUE项目用得上的JavaScript库

借助开源库加速VUE项目的开发进度是现代前端开发比较常见的方式,平常收集一些JavaScript库介绍,在遇到需要的时候可以信手拈来。 VUE 生态有很多不错的依赖库或者组件,是使用VUE开发前端的原因之一。 1. vueuse 这是 GitHub 上星最多的库之…

【coarse-to-fine:基于频谱和空间损失约束】

UPanGAN: Unsupervised pansharpening based on the spectral and spatial loss constrained Generative Adversarial Network (UPanGAN:基于频谱和空间损失约束的生成式对抗网络的无监督全色锐化) 研究发现,在大多数基于神经网…

扎根底层核心技术:OPPO发布旗舰蓝牙音频SoC芯片

OPPO自研芯片能力更进一步。 2022年12月14日,OPPO发布自研芯片马里亚纳MariSilicon Y,作为一颗旗舰蓝牙音频SoC,实现了三大核心技术突破,使OPPO具备了计算连接能力的蓝牙SoC平台的设计能力。 这是OPPO发布的第二款自研芯片。去年…

初学者数据分析——Python职位全链路分析

最近在做Python职位分析的项目,做这件事的背景是因为接触Python这么久,还没有对Python职位有一个全貌的了解。所以想通过本次分析了解Python相关的职位有哪些、在不同城市的需求量有何差异、薪资怎么样以及对工作经验有什么要求等等。分析的链路包括&…

用了那么久的Vue,你了解Vue的报错机制吗?

Vue的5种处理Vue异常的方法 相信大家对Vue都不陌生。在使用Vue的时候也会遇到报错,也会使用浏览器的F12 来查看报错信息。但是你知道Vue是如何进行异常抛出的吗?vue 是如何处理异常的呢?接下来和大家介绍介绍,Vue是如何处理这几种…

【数据结构】树以及二叉树的概念

作者:一个喜欢猫咪的的程序员 专栏:《数据结构》 喜欢的话:世间因为少年的挺身而出,而更加瑰丽。 ——《人民日报》 目录 树的概念: 树的相关概念: 树如何表示&#xff…

Anaconda中Python虚拟环境的创建、使用与删除

本文介绍在Anaconda环境下,创建、使用与删除Python虚拟环境的方法。 在Python的使用过程中,我们常常由于不同Python版本以及不同第三方库版本的支持情况与相互之间的冲突情况,而需要创建不同的Python虚拟环境;在Anaconda的帮助下&…

如何使用Python构建Telegram机器人来生成随机引语

使用Python构建Telegram机器人以生成随机引语 聊天机器人是用于进行在线聊天对话的软件应用程序,通过文本或文本转语音的方式实现客户服务的自动化。[聊天机器人]可以用于提醒、预约等事情,也可以在社交媒体平台上使用。 在本教程中,我们将…

会自动化就能拿20K?不,你这顶多算会点皮毛···

前段时间公司要招2个自动化测试,同事面了几十个候选人,发现了一个很奇怪的现象,面试的时候,如果问的是框架api、脚本编写这些问题,基本上个个都能对答如流,等问到实际项目的时候,类似“怎么从0开…

如何实现一个基于WebRTC的音视频通信系统

文章有点长,推荐先收藏前言 目前市场上音视频技术方案大致分为以下几类,WebRTC因其超低延时、集成音视频采集传输等优点,是在线教育、远程会议等领域首选技术。 前言 目前市场上音视频技术方案大致分为以下几类,WebRTC因其超低延…