YOLOv9改进策略【卷积层】| HWD,引入`Haar小波变换`到下采样模块中,减少信息丢失

news2024/9/19 13:37:40

一、本文介绍

本文记录的是利用Haar小波下采样对YOLOv9网络进行改进的方法研究。传统的卷积神经网络中常用的最大池化平均池化步长为2的卷积等操作进行下采样可能会导致信息丢失,为了解决信息丢失问题,HWD作者受无损信息变换方法的启发,引入Haar小波变换到下采样模块中,旨在尽可能地保留图像信息,以便后续层能够提取更具判别性的特征,从而提高分割性能。

文章目录

  • 一、本文介绍
  • 二、Haar小波下采样原理
    • 2.1、原理
    • 2.2、优势
  • 三、HWD的实现代码
  • 四、添加步骤
    • 4.1 修改common.py
    • 4.2 修改yolo.py
    • 4.3 修改train_dual.py
  • 五、yaml模型文件
    • 5.1 模型改进
  • 六、成功运行结果


二、Haar小波下采样原理

Haar小波下采样:一个简单但有效的语义分割下采样模块。

2.1、原理

HWD模块由两个主要块组成:无损特征编码块特征表示学习块

  • 无损特征编码块:利用Haar小波变换层有效地降低特征图的空间分辨率,同时保留所有信息。Haar小波变换是一种广泛认可的、紧凑的、二进的和正交的变换,在图像编码、边缘提取和二进制逻辑设计中有着广泛的应用。当对二维信号(如灰度图像)应用Haar小波变换时,会产生四个分量,每个分量的空间分辨率是原始信号的一半,而特征图的通道数则变为原来的四倍。这意味着Haar小波变换可以将部分空间维度的信息编码到通道维度中,而不会丢失任何信息。
  • 特征表示学习块:由标准的1×1卷积层批量归一化层ReLU激活函数组成。该块用于调整特征图的通道数,使其与后续层对齐,并尽可能地过滤冗余信息,使后续层能够更有效地学习代表性特征。

在这里插入图片描述

2.2、优势

  • 提高分割性能:通过在三个不同模态的图像数据集上进行的广泛实验表明,HWD模块能够有效提高分割性能。在Camvid数据集上,与七种最先进的分割架构相结合,使用HWD模块的模型在平均交并比(mIoU)上相比基线有1 - 2%的提升,特别是对于小尺度对象(如行人、自行车、围栏和标志符号等)的性能有显著改善。
  • 减少信息不确定性:利用结构相似性(SSIM)、峰值信噪比(PSNR)和提出的特征熵指数(FEI)评估下采样对特征图的有效性,结果表明HWD模块能够提高SSIM(7.78%)和PSNR(2.14 dB),并大幅降低信息不确定性。在所有21个模型中,HWD模块相比原始下采样方法,使特征不确定性降低了58.2%(FEI)和46.8%(FEI_B)。
  • 通用性和易用性HWD模块可以直接替换现有分割架构中的现有下采样方法(如最大池化、平均池化或步幅卷积),而不会引入额外的复杂性,并且能够显著提高分割性能。
  • 在参数和计算量上的平衡:与传统的下采样方法(如平均池化和步幅卷积)相比,HWD模块在参数和浮点运算(FLOPs)上提供了一种平衡。虽然平均池化在参数和FLOPs方面表现更好,但HWD模块所需的参数少于步幅卷积的两倍,并且当通道数C大于一时,步幅卷积的计算开销超过HWD模块
  • 对浅层CNN的有效性:在MOST数据集上的实验表明,当使用ResNet - 18和ResNet - 34作为特征提取的骨干网络时,HWD模块显著提高了分割性能,这表明浅层CNN对信息的需求更高,而HWD模块能够满足这种需求。

HWD模块与其他下采样模块对比

保留信息能力:传统的下采样方法(如最大池化、平均池化和步幅卷积等)会导致信息丢失,而HWD模块通过引入Haar小波变换,能够在降低特征图空间分辨率的同时尽可能保留信息。

论文:https://doi.org/10.1016/j.patcog.2023.109819
源码:https://github.com/apple1986/HWD

三、HWD的实现代码

HWD模块的实现代码如下:

class HWD(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(HWD, self).__init__()
        from pytorch_wavelets import DWTForward
        self.wt = DWTForward(J=1, mode='zero', wave='haar')
        self.conv = Conv(in_ch * 4, out_ch, 1, 1)
 
    def forward(self, x):
        yL, yH = self.wt(x)
        y_HL = yH[0][:, :, 0, ::]
        y_LH = yH[0][:, :, 1, ::]
        y_HH = yH[0][:, :, 2, ::]
        x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)
        x = self.conv(x)
 
        return x


四、添加步骤

4.1 修改common.py

此处需要修改的文件是models/common.py

common.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

HWD的实现过程中使用的pytorch_wavelets包需要自行安装:

pip install pytorch_wavelets

HWD模块添加后如下:

在这里插入图片描述

注意❗:在4.2小节中的yolo.py文件中需要声明的模块名称为:HWD

4.2 修改yolo.py

此处需要修改的文件是models/yolo.py

yolo.py用于函数调用,我们只需要将common.py中定义的新的模块名添加到parse_model函数下即可。

HWD模块添加后如下:

在这里插入图片描述

还需在此函数下添加如下代码:

在这里插入图片描述

elif m in (HWD,):
    args = [ch[f], ch[f]]

4.3 修改train_dual.py

train_dual.py文件的第314行关闭amp,将其设置为False

with torch.cuda.amp.autocast(False):
     pred = model(imgs)  # forward
     loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
     if RANK != -1:
         loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
     if opt.quad:
         loss *= 4.

在这里插入图片描述


五、yaml模型文件

5.1 模型改进

在代码配置完成后,配置模型的YAML文件。

此处以models/detect/yolov9-c.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-hwd.yaml

yolov9-c.yaml中的内容复制到yolov9-c-hwd.yaml文件下,修改nc数量等于自己数据中目标的数量。

📌 修改方法是将HWD模块替换YOLOv9网络中的ADown模块HWD受无损信息变换方法的启发,引入Haar小波变换到下采样模块中,旨在尽可能地保留图像信息,使改进后的模型在下采样过程中能够提取更具判别性的特征,从而提高模型性能。

# YOLOv9

# parameters
nc: 1  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],  
   
   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3

   # avg-conv down
   [-1, 1, HWD, [256]],  # 4-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

   # avg-conv down
   [-1, 1, HWD, [512]],  # 6-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

   # avg-conv down
   [-1, 1, HWD, [512]],  # 8-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   
   # multi-level reversible auxiliary branch
   
   # routing
   [5, 1, CBLinear, [[256]]], # 23
   [7, 1, CBLinear, [[256, 512]]], # 24
   [9, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[25, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
   
   
   
   # detection head

   # detect
   [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]


六、成功运行结果

分别打印网络模型可以看到HWD模块已经加入到模型中,并可以进行训练了。

yolov9-c-hwd

                 from  n    params  module                                  arguments                     
  0                -1  1         0  models.common.Silence                   []                            
  1                -1  1      1856  models.common.Conv                      [3, 64, 3, 2]                 
  2                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
  3                -1  1    212864  models.common.RepNCSPELAN4              [128, 256, 128, 64, 1]        
  4                -1  1    262656  models.common.HWD                       [256, 256]                    
  5                -1  1    847616  models.common.RepNCSPELAN4              [256, 512, 256, 128, 1]       
  6                -1  1   1049600  models.common.HWD                       [512, 512]                    
  7                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
  8                -1  1   1049600  models.common.HWD                       [512, 512]                    
  9                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
 10                -1  1    656896  models.common.SPPELAN                   [512, 512, 256]               
 11                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 12           [-1, 7]  1         0  models.common.Concat                    [1]                           
 13                -1  1   3119616  models.common.RepNCSPELAN4              [1024, 512, 512, 256, 1]      
 14                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 15           [-1, 5]  1         0  models.common.Concat                    [1]                           
 16                -1  1    912640  models.common.RepNCSPELAN4              [1024, 256, 256, 128, 1]      
 17                -1  1    164352  models.common.ADown                     [256, 256]                    
 18          [-1, 13]  1         0  models.common.Concat                    [1]                           
 19                -1  1   2988544  models.common.RepNCSPELAN4              [768, 512, 512, 256, 1]       
 20                -1  1    656384  models.common.ADown                     [512, 512]                    
 21          [-1, 10]  1         0  models.common.Concat                    [1]                           
 22                -1  1   3119616  models.common.RepNCSPELAN4              [1024, 512, 512, 256, 1]      
 23                 5  1    131328  models.common.CBLinear                  [512, [256]]                  
 24                 7  1    393984  models.common.CBLinear                  [512, [256, 512]]             
 25                 9  1    656640  models.common.CBLinear                  [512, [256, 512, 512]]        
 26                 0  1      1856  models.common.Conv                      [3, 64, 3, 2]                 
 27                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
 28                -1  1    212864  models.common.RepNCSPELAN4              [128, 256, 128, 64, 1]        
 29                -1  1    164352  models.common.ADown                     [256, 256]                    
 30  [23, 24, 25, -1]  1         0  models.common.CBFuse                    [[0, 0, 0]]                   
 31                -1  1    847616  models.common.RepNCSPELAN4              [256, 512, 256, 128, 1]       
 32                -1  1    656384  models.common.ADown                     [512, 512]                    
 33      [24, 25, -1]  1         0  models.common.CBFuse                    [[1, 1]]                      
 34                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
 35                -1  1    656384  models.common.ADown                     [512, 512]                    
 36          [25, -1]  1         0  models.common.CBFuse                    [[2]]                         
 37                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       
 38[31, 34, 37, 16, 19, 22]  1  21542822  DualDDetect                             [1, [512, 512, 512, 256, 512, 512]]
yolov9-c-hwd summary: 601 layers, 51583014 parameters, 49258246 gradients, 239.5 GFLOPs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2146359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python本地进程通讯----共享内存变量

背景 最近在开发实践中,接触到了需要多进程开发的场景。众所周知,进程和线程最大的区别就在于:进程是资源分配的最小单位,线程是cpu调度的最小单位。对于多进程开发来说,每一个进程都占据一块独立的虚拟内存空间&#…

大数据新视界 --大数据大厂之探索ES:大数据时代的高效搜索引擎实战攻略

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

致计算机新生们

欢迎你们踏入计算机科学的世界,这是一个充满挑战与机遇的领域。在你们即将开始的大学旅程中,了解计算机专业的就业方向和行业现状是非常重要的。以下是一些关于计算机专业就业方向和行业现状的介绍,希望能够帮助你们更好地规划自己的未来。 …

土豆王国小乐队携手阿派朗创造力乐园,打造2024年okgo儿童音乐节

艺术与科技的完美融合,为首都少年儿童带来音乐盛宴 北京,2024年9月19日 —— 备受期待的2024年okgo儿童音乐节即将于9月21日至22日在北京阿派朗创造力乐园盛大开幕。这场由土豆王国小乐队与阿派朗创造力乐园联合举办的音乐节,旨在为首都及全国…

波分技术基础 -- WDM/OTN介绍

什么是WDM WDM(Wavelength Division Multiplexing):波分复用技术,将不同波长的光信号复用到一根光纤中进行传送的方式(每个波长承载一个业务信号),主要功能是传送和复用。在波分技术出现之前&am…

Gephi 0.9.2中文版百度云下载(附教程)

如大家所了解的,Gephi常用于各种图形和网络的可视化和探索,是最受欢迎的网络可视化软件之一。在生物科学领域,常用于基因共表达网络、蛋白互作网络、微生物相互关系网络等等类似的网络图形绘制。 目前用的比较多的版本为Gephi 0.9.2&#xf…

使用rust自制操作系统内核

一、系统简介 本操作系统是一个使用rust语言实现,基于32位的x86CPU的分时操作系统。 项目地址(求star):GitHub - CaoGaorong/os-in-rust: 使用rust实现一个操作系统内核 详细文档:自制操作系统 语雀 1. 项目特性 …

深度学习自编码器 - 使用自编码器学习流形篇

序言 在数据科学的浩瀚宇宙中,深度学习如同一颗璀璨的星辰,引领着我们对复杂数据内在规律的探索。其中,自编码器作为深度学习家族中的一位独特成员,以其非凡的能力——通过无监督学习捕捉数据的有效表示,而备受瞩目。…

Tomcat_WebApp

Tomcat的目录的介绍 /bin: 这个目录包含启动和关闭 Tomcat 的脚本。 startup.bat / startup.sh:用于启动 Tomcat(.bat 文件是 Windows 系统用的,.sh 文件是 Linux/Unix 系统用的)。shutdown.bat / shutdown.sh&#xf…

Java 实现桌面烟花秀

前言 今天,我们将展示如何使用 Java Swing 创建一个烟花效果,覆盖整个桌面。我们将重点讲解如何在桌面上展示烟花、如何实现发射和爆炸效果,以及如何将这些效果整合到一个完整的程序中。 效果展示 如上图所示,我们在桌面实现了&…

【开源大模型生态9】百度的文心大模型

这张图展示了百度千帆大模型平台的功能架构及其与BML-AI开发平台和百度百舸AI异构计算平台的关系。以下是各个模块的解释: 模型广场: 通用大模型:提供基础的自然语言处理能力。行业大模型:针对不同行业的定制化模型。大模型工具链…

新的 MathWorks 硬件支持包支持从 MATLAB 和 Simulink 模型到高通 Hexagon 神经处理单元架构的自动化代码生成

MathWorks 今天宣布,推出针对 Qualcomm Hexagon™ 神经处理单元(NPU)的硬件支持包。该处理单元嵌入在 Snapdragon 系列处理器中。MathWorks 硬件支持包,则专门针对 Qualcomm Technologies 的 Hexagon NPU 架构进行优化&#xff0c…

基于SSM的“校园外卖管理系统”的设计与实现(源码+数据库+文档+开题报告)

基于SSM的“校园外卖管理系统”的设计与实现(源码数据库文档开题报告) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 消费者系统结构图 商户系统结构图 管理员系统结构图 校…

数据脱敏 (Jackson + Hutool 工具包)

一、简介 系统使用 Jackson 序列化策略,对标注了 Sensitive 注解的属性进行脱敏处理 基于Hutool 脱敏案列: Retention(RetentionPolicy.RUNTIME) Target(ElementType.FIELD) JacksonAnnotationsInside// 表示只对有此注解的字段进行序列化 JsonSeriali…

MySQL高阶1831-每天的最大交易

题目 编写一个解决方案,报告每天交易金额 amount 最大 的交易 ID 。如果一天中有多个这样的交易,返回这些交易的 ID 。 返回结果根据 transaction_id 升序排列。 准备数据 Create table If Not Exists Transactions (transaction_id int, day date, …

吹爆上海交大的大模型实战教程!!—《动手学大模型》附实战教程及ppt

今天分享一个上海交大的免费的大模型课程,有相关教程文档和Slides,目前是2.2K星标,还是挺火的! 《动手学大模型》系列编程实践教程, 由上海交通大学2024年春季《人工智能安全技术》课程(NIS3353&#xff09…

深入剖析Docker容器安全:挑战与应对策略

随着容器技术的广泛应用,Docker已成为现代应用开发和部署的核心工具。它通过轻量级虚拟化技术实现应用的隔离与封装,提高了资源利用率。然而,随着Docker的流行,其安全问题也成为关注焦点。容器化技术虽然提供了良好的资源隔离&…

SHAP 模型可视化 + 参数搜索策略在轴承故障诊断中的应用

往期精彩内容: Python-凯斯西储大学(CWRU)轴承数据解读与分类处理 Python轴承故障诊断入门教学-CSDN博客 Python轴承故障诊断 (13)基于故障信号特征提取的超强机器学习识别模型-CSDN博客 Python轴承故障诊断 (14)高创新故障识别模型-CSDN…

Linux用户组管理

目录 一、增删改用户组 1.1. 创建一个新的用户组 1.2. 创建用户组并指定ID 1.3. 修改用户组的名 1.4. 修改用户组的ID 1.5. 删除一个用户组 二、用户组中的用户操作 2.1. 添加用户到一个已存在的用户组 2.2. 从用户组中移除用户 注:本章内容全部基于Centos…

论文阅读--Planning-oriented Autonomous Driving(二)

自动驾驶框架的各种设计比较。 ( a )大多数工业解决方案针对不同的任务部署不同的模型。 ( b )多任务学习方案共享一个具有分割任务头的主干。 ( c )端到端范式将感知和预测模块统一起来。以往的尝试要么采用( c.1 )中对规划的直接优化,要么采用( c.2 )中的部分元…