YOLOv9改进策略【模型轻量化】| ShufflenetV2,通过通道划分构建高效网络

news2024/9/24 11:29:53

一、本文介绍

本文记录的是基于ShufflenetV2的YOLOv9目标检测轻量化改进方法研究FLOPs是评价模型复杂独的重要指标,但其无法考虑到模型的内存访问成本和并行度,因此本文在YOLOv9的基础上引入ShufflenetV2使其在在保持准确性的同时提高模型的运行效率

模型参数量计算量推理速度(bs=32)
YOLOv9-c50.69M236.6GFLOPs32.1ms
Improved42.88M194.5GFLOPs23.2ms

文章目录

  • 一、本文介绍
  • 二、ShuffleNet V2设计原理
  • 三、ShuffleNet V2基础模块的实现代码
  • 四、添加步骤
    • 4.1 修改common.py
    • 4.2 修改yolo.py
  • 五、yaml模型文件
    • 5.1 模型改进⭐
  • 六、成功运行结果


二、ShuffleNet V2设计原理

ShuffleNet V2是一种高效的卷积神经网络架构,其模型结构及优势如下:

  1. 模型结构
    • 回顾ShuffleNet v1ShuffleNet是一种广泛应用于低端设备的先进网络架构,为增加在给定计算预算下的特征通道数量,采用了点组卷积和瓶颈结构,但这增加了内存访问成本(MAC),且过多的组卷积和元素级“Add”操作也存在问题。
    • 引入Channel Split和ShuffleNet V2:为解决上述问题,引入了名为Channel Split的简单操作。在每个单元开始时,将 c c c个特征通道的输入分为两个分支,分别具有 c − c ′ c - c' cc c ′ c' c个通道。一个分支保持不变,另一个分支由三个具有相同输入和输出通道的卷积组成,以满足G1(平衡卷积,即相等的通道宽度可最小化MAC)。两个 1 × 1 1 \times 1 1×1卷积不再是组式的,这部分是为了遵循G2(避免过多的组卷积增加MAC),部分是因为拆分操作已经产生了两个组。卷积后,两个分支连接,通道数量保持不变,并使用与ShuffleNet v1相同的“通道洗牌”操作来实现信息通信。对于空间下采样,单元进行了略微修改,删除了通道拆分操作,使输出通道数量加倍。
    • 整体网络结构:通过反复堆叠构建块来构建整个网络,设置 c ′ = c / 2 c' = c/2 c=c/2,整体网络结构与ShuffleNet v1相似,并在全局平均池化之前添加了一个额外的 1 × 1 1 \times 1 1×1卷积层来混合特征。
  2. 优势
    • 高效且准确:遵循了高效网络设计的所有准则,每个构建块的高效率使其能够使用更多的特征通道和更大的网络容量,并且在每个块中,一半的特征通道直接通过块并加入下一个块,实现了一种特征重用模式,类似于DenseNet,但更高效。
    • 速度优势明显:在与其他网络架构的比较中,ShuffleNet v2在速度方面表现出色,特别是在GPU上明显快于其他网络(如MobileNet v2、ShuffleNet v1和Xception)。在ARM上,ShuffleNet v1、Xception和ShuffleNet v2的速度相当,但MobileNet v2较慢,这是因为MobileNet v2的MAC较高。
    • 兼容性好:可以与其他技术(如Squeeze - and - excitation模块)结合进一步提高性能。

论文:https://arxiv.org/pdf/1807.11164.pdf
源码:https://gitcode.com/gh_mirrors/sh/ShuffleNet-Series/blob/master/ShuffleNetV2/blocks.py?utm_source=csdn_github_accelerator&isLogin=1

三、ShuffleNet V2基础模块的实现代码

ShuffleNet V2基础模块的实现代码如下:

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class conv_bn_relu_maxpool(nn.Module):
    def __init__(self, c1, c2):  # ch_in, ch_out
        super(conv_bn_relu_maxpool, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        return self.maxpool(self.conv(x))


class Shuffle_Block(nn.Module):
    def __init__(self, inp, oup, stride):
        super(Shuffle_Block, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)  # 按照维度1进行split
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out



四、添加步骤

4.1 修改common.py

此处需要修改的文件是models/common.py

common.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

此时需要将上方实现的代码添加到common.py中。

在这里插入图片描述

注意❗:在4.2小节中的yolo.py文件中需要声明的模块名称为:conv_bn_relu_maxpoolShuffle_Block

4.2 修改yolo.py

此处需要修改的文件是models/yolo.py

yolo.py用于函数调用,我们只需要将common.py中定义的新的模块名添加到parse_model函数下即可。

conv_bn_relu_maxpool模块以及Shuffle_Block模块添加后如下:

在这里插入图片描述


五、yaml模型文件

5.1 模型改进⭐

在代码配置完成后,配置模型的YAML文件。

此处以models/detect/yolov9-c.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-shufflenetv2.yaml

yolov9-c.yaml中的内容复制到yolov9-c-shufflenetv2.yaml文件下,修改nc数量等于自己数据中目标的数量。

📌 模型的修改方法是将YOLOv9的骨干网络替换成Shufflenet V2ShuffleNet V2 在设计上注重减少内存访问成本并提高并行度,这有助于在保持准确性的同时提高模型的运行效率。相比YOLOv9原骨干网络,ShuffleNet V2 具有更低的计算复杂度,能够在相同或更少的计算资源下完成推理,对于实时性要求较高的任务具有重要意义。

结构如下:

# YOLOv9

# parameters
nc: 1  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],  
   
   # conv down
   [-1, 1, conv_bn_relu_maxpool, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Shuffle_Block, [ 128, 2 ]],  # 2-P2/4
   [-1, 3, Shuffle_Block, [ 128, 1 ]],  # 3

   [-1, 1, Shuffle_Block, [ 256, 2 ]],  # 4-P4/16 
   [-1, 7, Shuffle_Block, [ 256, 1 ]],  # 5

   [-1, 1, Shuffle_Block, [ 512, 2 ]],  # 6-P4/16
   [-1, 3, Shuffle_Block, [ 512, 1 ]],  # 7
   
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 3], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 11], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 8], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   
   # multi-level reversible auxiliary branch
   
   # routing
   [3, 1, CBLinear, [[256]]], # 23
   [5, 1, CBLinear, [[256, 512]]], # 24
   [7, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[21, 22, 23, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[22, 23, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[23, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
   
   
   
   # detection head

   # detect
   [[29, 32, 35, 14, 17, 20], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]


六、成功运行结果

分别打印网络模型可以看到Shuffle_Block已经加入到模型中,并可以进行训练了。

yolov9-c-shufflenetv2

                 from  n    params  module                                  arguments                     
  0                -1  1         0  models.common.Silence                   []                            
  1                -1  1      1856  models.common.conv_bn_relu_maxpool      [3, 64]                       
  2                -1  1     14080  models.common.Shuffle_Block             [64, 128, 2]                  
  3                -1  3     27456  models.common.Shuffle_Block             [128, 128, 1]                 
  4                -1  1     52736  models.common.Shuffle_Block             [128, 256, 2]                 
  5                -1  7    242816  models.common.Shuffle_Block             [256, 256, 1]                 
  6                -1  1    203776  models.common.Shuffle_Block             [256, 512, 2]                 
  7                -1  3    404736  models.common.Shuffle_Block             [512, 512, 1]                 
  8                -1  1    656896  models.common.SPPELAN                   [512, 512, 256]               
  9                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 10           [-1, 5]  1         0  models.common.Concat                    [1]                           
 11                -1  1   2988544  models.common.RepNCSPELAN4              [768, 512, 512, 256, 1]       
 12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 13           [-1, 3]  1         0  models.common.Concat                    [1]                           
 14                -1  1    814336  models.common.RepNCSPELAN4              [640, 256, 256, 128, 1]       
 15                -1  1    164352  models.common.ADown                     [256, 256]                    
 16          [-1, 11]  1         0  models.common.Concat                    [1]                           
 17                -1  1   2988544  models.common.RepNCSPELAN4              [768, 512, 512, 256, 1]       
 18                -1  1    656384  models.common.ADown                     [512, 512]                    
 19           [-1, 8]  1         0  models.common.Concat                    [1]                           
 20                -1  1   3119616  models.common.RepNCSPELAN4              [1024, 512, 512, 256, 1]      
 21                 3  1     33024  models.common.CBLinear                  [128, [256]]                  
 22                 5  1    197376  models.common.CBLinear                  [256, [256, 512]]             
 23                 7  1    656640  models.common.CBLinear                  [512, [256, 512, 512]]        
 24                 0  1      1856  models.common.Conv                      [3, 64, 3, 2]                 
 25                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
 26                -1  1    212864  models.common.RepNCSPELAN4              [128, 256, 128, 64, 1]        
 27                -1  1    164352  models.common.ADown                     [256, 256]                    
 28  [21, 22, 23, -1]  1         0  models.common.CBFuse                    [[0, 0, 0]]                   
 29                -1  1    847616  models.common.RepNCSPELAN4              [256, 512, 256, 128, 1]       
 30                -1  1    656384  models.common.ADown                     [512, 512]                    
 31      [22, 23, -1]  1         0  models.common.CBFuse                    [[1, 1]]                      
 32                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
 33                -1  1    656384  models.common.ADown                     [512, 512]                    
 34          [23, -1]  1         0  models.common.CBFuse                    [[2]]                         
 35                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
 36[29, 32, 35, 14, 17, 20]  1  21542822  models.yolo.DualDDetect                 [1, [512, 512, 512, 256, 512, 512]]
yolov9-c-shufflenetv2 summary: 870 layers, 43094374 parameters, 43094342 gradients, 195.9 GFLOPs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2093616.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

十一. 常用类

文章目录 一、包装类1.1 包装类的继承关系1.2 包装类和基本数据类型的转换1.3 包装类与String之间的转换1.4 包装类的常用方法 二、String类2.1 String类的理解和创建对象2.2 String的创建方式2.3 字符串的特性2.4 String的常用方法 三、StringBuffer和StringBuilder类3.1 Stri…

重塑PDF编辑体验:最新在线工具深度评测

现在用PDF的场景多了&#xff0c;随之而来的加速了PDF编辑、转换工具的飞速发展&#xff0c;很多时候因为便捷大家更喜欢使用在线的工具。今天我就分享几款pdf在线编辑工具提高你文档处理效率。 1.福昕PDF编辑器 链接一下>>https://editor.foxitsoftware.cn 在沉浸阅…

Datawhale X 李宏毅苹果书 AI夏令营|机器学习基础之线性模型

1. 线性模型 线性模型是机器学习中最基础和常见的模型之一。在线性模型中&#xff0c;预测变量&#xff08;输入特征&#xff09;和目标变量&#xff08;输出&#xff09;之间的关系被建模为一个线性组合。数学形式可以表示为&#xff1a; 其中&#xff1a;x 是输入特征向量&a…

加速 PyTorch 模型:使用 ROCm 在 AMD GPU 上应用 torch.compile

Accelerate PyTorch Models using torch.compile on AMD GPUs with ROCm — ROCm Blogs 介绍 PyTorch 2.0 引入了一个名为*torch.compile()*的工具&#xff0c;可以极大地加速 PyTorch 代码和模型。通过将 PyTorch 代码转换为高度优化的内核&#xff0c;torch.compile 在现有代…

【深入理解SpringCloud微服务】深入理解微服务配置中心原理,并手写一个微服务配置中心

【深入理解SpringCloud微服务】深入理解微服务配置中心原理&#xff0c;并手写一个微服务配置中心 为什么要使用配置中心配置中心原理如何手写一个配置中心使用PropertySourceLocator监听配置变更&#xff0c;刷新配置 实现一个微服务配置中心服务端库表ConfigCenterController…

全国中学基础信息 API 数据接口

全国中学基础信息 API 数据接口 基础数据&#xff0c;高校高考&#xff0c;提供全国初级高级中学基础数据&#xff0c;定时更新&#xff0c;多维度筛选。 1. 产品功能 2024 年数据已更新&#xff1b;提供最新全国中学学校基本信息&#xff1b;包含全国初级中学与高等中学&#…

JavaWeb:实验一JSP运行环境安装及配置

一、实验目的 1&#xff0e;掌握JSP程序运行环境配置的基本要求。 2&#xff0e;熟悉HTML的常用标签使用以及静态网页的制作。 二、实验性质 验证性实验 三、实验内容 制作一个静态网站的基本页面index.html&#xff0c;要求如下&#xff1a; 1&#xff…

SQL 语言简明入门:从历史到实践

SQL&#xff08;Structured Query Language&#xff09;是数据库领域的核心语言。自20世纪70年代中期由IBM公司开发以来&#xff0c;SQL已经成为全球最广泛使用的数据库管理语言。 本文将以简洁明了的方式为您介绍SQL的历史、基本结构、核心语言组成以及其独特的特点和书写规则…

【软件部署】JumpServer堡垒机搭建及使用

目录 一、linux服务器docker搭建 二、linux服务器单机部署 1.环境安装 2.安装数据库和Redis 3.下载linux安装包并部署 4.安装启动后命令 5.通过浏览器访问部署服务器的地址 三、JumpServer产品使用 1、添加系统用户 2、创建资产 3、将主机/资源进行授权给用户 4、登录…

Day 7:条件编译

GCC编译器 预处理阶段&#xff1a; 1.gcc - E 文件名 &#xff1a;预处理文件 2.gcc - o 文件名 &#xff1a;重命名 gcc -E gcc.c gcc-o gcc.i&#xff1a;生成预处理文件。 vi gcc.i&#xff1a; 作用&#xff1a;展开头文件&#xff0c;宏进行替换。 编译阶段: gcc -…

希尔排序的图解展示与实现

什么是希尔排序 对整个数组进行预排序&#xff0c;即分组排序&#xff1a;按间距为gap分为一组&#xff0c;分组进行插入排序。 预排序的作用与特点 大的数更快地到后面&#xff0c;小的数更快地到前面&#xff1b; gap越大&#xff0c;跳得越快&#xff0c;排完接近有序慢&…

数据结构与算法---排序算法

文章目录 排序选择排序冒泡排序插入排序 希尔排序归并排序快速排序桶排序计数排序基数排序堆排序 排序 排序是指将一组数据按照特定的规则或顺序进行排列&#xff0c;比如一个数组[1, 5, 2, 4, 3]按照从小到大的顺序排列后就是[1,2,3,4,5]。 排序算法&#xff08;Sorting alg…

lay数据表格(table)的多选框限制单选

TOC lay数据表格(table)的多选框限制单选 使用layui弹窗显示表格数据提供选择&#xff0c;最初使用单选框&#xff0c;选中后无法取消勾选&#xff0c;后该成多选框限制成单选&#xff0c;可点击已勾选复选框实现取消功能。 PS&#xff1a;easyui数据表格提供简单实现 多选框…

有哪些数据分析类的软件可供参考?

对于数据分析师来说&#xff0c;掌握以下数据分析工具很有必要&#xff0c;一个好的数据分析工具&#xff0c;可以使得数据分析工作事半功倍&#xff0c;相对于整个数据分析学习流程来说&#xff0c;掌握数据分析工具是学习数据分析的关键。 日常数据分析中&#xff0c;有80%的…

【电力系统】使用电力系统稳定器 (PSS) 和静态 VAR 补偿器 (SVC) 提高瞬态稳定性

摘要 电力系统在面对故障和扰动时&#xff0c;其瞬态稳定性是确保系统安全运行的关键因素。本文探讨了通过使用电力系统稳定器&#xff08;PSS&#xff09;和静态VAR补偿器&#xff08;SVC&#xff09;来提高电力系统瞬态稳定性的策略。通过仿真分析&#xff0c;证明了PSS和SV…

折腾 Quickwit,Rust 编写的分布式搜索引擎 - 可观测性之分布式追踪

概述 分布式追踪是一种跟踪应用程序请求流经不同服务(如前端、后端、数据库等)的过程。它是一个强大的工具&#xff0c;可以帮助您了解应用程序的工作原理并调试性能问题。 Quickwit 是一个用于索引和搜索非结构化数据的云原生引擎&#xff0c;这使其非常适合用作追踪数据的后端…

提升农业信息化水平,C# ASP.NET Vue果树生长信息管理系统带来全新管理体验

&#x1f34a;作者&#xff1a;计算机毕设匠心工作室 &#x1f34a;简介&#xff1a;毕业后就一直专业从事计算机软件程序开发&#xff0c;至今也有8年工作经验。擅长Java、Python、微信小程序、安卓、大数据、PHP、.NET|C#、Golang等。 擅长&#xff1a;按照需求定制化开发项目…

【算法每日一练及解题思路】找出模式匹配字符串的异位词在原始字符串中出现的索引下标

【算法每日一练及解题思路】找出模式匹配字符串的异位词在原始字符串中出现的索引下标 一、题目&#xff1a;找出模式匹配字符串的异位词在原始字符串中出现的索引下标 二、举例&#xff1a; 两个字符串原始字符串initStr123sf3rtfb,模式匹配字符串regxf3s&#xff0c;找到模…

【读书笔记-《30天自制操作系统》-12】Day13

本篇的内容仍然是定时器的相关讲解。上一篇内容中对于中断程序做了许多优化&#xff0c;但是这些优化到底起了多少作用呢&#xff1f;本篇用一种测试方法来进行测试。然后本篇继续引入链表与哨兵的概念&#xff0c;进一步加快超时的中断处理。 1. 主程序简化 开发过程到了这…

nacos获取服务实例流程

一、客户端获取服务实例流程(以dubbo为例) 1.dubbo元数据服务初始化需要订阅的服务列表 1.1.获取与当前服务相同分组和集群的NACOS的注册服务列表。 1.2 首先是从spring-cloud-common的通用注册中心中&#xff0c;使用组合注册客户端类获取服务&#xff0c;此组合会逐个调用注…