YOLOv9有效提点|加入SE、CBAM、ECA、SimAM等几十种注意力机制(一)

news2025/1/22 21:37:21


专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!!


一、本文介绍

        本文将以SE注意力机制为例,演示如何在YOLOv9种添加注意力机制!


 《Squeeze-and-Excitation Networks》

        SENet提出了一种基于“挤压和激励”(SE)的注意力模块,用于改进卷积神经网络(CNN)的性能。SE块可以适应地重新校准通道特征响应,通过建模通道之间的相互依赖关系来增强CNN的表示能力。这些块可以堆叠在一起形成SENet架构,使其在多个数据集上具有非常有效的泛化能力。


《CBAM:Convolutional Block Attention Module》

        CBAM模块能够同时关注CNN的通道和空间两个维度,对输入特征图进行自适应细化。这个模块轻量级且通用,可以无缝集成到任何CNN架构中,并可以进行端到端训练。实验表明,使用CBAM可以显著提高各种模型的分类和检测性能。


《ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks》

        通道注意力模块ECA,可以提升深度卷积神经网络的性能,同时不增加模型复杂性。通过改进现有的通道注意力模块,作者提出了一种无需降维的局部交互策略,并自适应选择卷积核大小。ECA模块在保持性能的同时更高效,实验表明其在多个任务上具有优势。


《SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks》

        SimAM一种概念简单且非常有效的注意力模块。不同于现有的通道/空域注意力模块,该模块无需额外参数为特征图推导出3D注意力权值。具体来说,SimAM的作者基于著名的神经科学理论提出优化能量函数以挖掘神经元的重要性。该模块的另一个优势在于:大部分操作均基于所定义的能量函数选择,避免了过多的结构调整

适用检测目标:   YOLOv9模块通用改进


二、改进步骤

        以下以SE注意力机制为例在YOLOv9中加入注意力代码,其他注意力机制同理!

 2.1 复制代码

        将SE的代码辅助到models包下common.py文件中。

 2.2 修改yolo.py文件

        在yolo.py脚本的第700行(可能因YOLOv9版本变化而变化)增加下方代码。

        elif m in (SE,):
            args.insert(0, ch[f])

2.3 创建配置文件

        创建模型配置文件(yaml文件),将我们所作改进加入到配置文件中(这一步的配置文件可以复制models  - > detect 下的yaml修改。)。对YOLO系列yaml文件不熟悉的同学可以看我往期的yaml详解教学!

YOLO系列 “.yaml“文件解读-CSDN博客

# YOLOv9
# Powered bu https://blog.csdn.net/StopAndGoyyy
# parameters
nc: 80  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],  
   
   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3

   # avg-conv down
   [-1, 1, ADown, [256]],  # 4-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

   # avg-conv down
   [-1, 1, ADown, [512]],  # 6-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

   # avg-conv down
   [-1, 1, ADown, [512]],  # 8-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   
   # multi-level reversible auxiliary branch
   
   # routing
   [5, 1, CBLinear, [[256]]], # 23
   [7, 1, CBLinear, [[256, 512]]], # 24
   [9, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[25, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
   [-1, 1, SE, [16]],  # 38

   
   
   # detection head

   # detect
   [[31, 34, 38, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]

3.1 训练过程

        最后,复制我们创建的模型配置,填入训练脚本(train_dual)中(不会训练的同学可以参考我之前的文章。),运行即可。

YOLOv9 最简训练教学!-CSDN博客

​​

​​


SE代码

class SE(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

CBAM代码

class CBAMBlock(nn.Module):

    def __init__(self, channel=512, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

ECA代码

class ECAAttention(nn.Module):

    def __init__(self, kernel_size=3):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        y = self.gap(x)  # bs,c,1,1
        y = y.squeeze(-1).permute(0, 2, 1)  # bs,1,c
        y = self.conv(y)  # bs,1,c
        y = self.sigmoid(y)  # bs,1,c
        y = y.permute(0, 2, 1).unsqueeze(-1)  # bs,c,1,1
        return x * y.expand_as(x)

SimAM代码

class SimAM(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()

        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)

如果觉得本文章有用的话给博主点个关注吧!


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1482281.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【六袆 - React】Next.js:React 开发框架;Next.js开发框架的特点

Next.js:React 开发框架 Next.js的特点 1.直观的、基于页面的路由系统(并支持动态路由) Next.js 提供了基于文件系统的路由,意味着你可以通过创建页面文件来定义路由。 伪代码示例: // pages/index.js export defa…

css【详解】—— 圣杯布局 vs 双飞翼布局 (含手写清除浮动 clearfix)

两者功能效果相同&#xff0c;实现方式不同 效果预览 两侧宽度固定&#xff0c;中间宽度自适应&#xff08;三栏布局&#xff09;中间部分优先渲染允许三列中的任意一列成为最高列 圣杯布局 通过左右栏填充容器的左右 padding 实现&#xff0c;更多细节详见注释。 <!DOCTYP…

day03-Vue-Element

一、Ajax 1 Ajax 介绍 1.1 Ajax 概述 概念&#xff1a;Asynchronous JavaScript And XML&#xff0c;异步 的 JavaScript 和 XML。 作用&#xff1a; 数据交换&#xff1a;通过 Ajax 可以给服务器发送请求&#xff0c;并获取服务器响应的数据。异步交互&#xff1a;可以在 不…

排序(2)——希尔排序

希尔排序&#xff08;缩小增量排序&#xff09; 基本思想 希尔排序法又称缩小增量法。希尔排序法的基本思想是&#xff1a;先选定一个整数&#xff0c;把待排序文件中所有记录分成个组&#xff0c;所有距离为的记录分在同一组内&#xff0c;并对每一组内的记录进行排序。然后&…

lotus 从矿工可用余额扣除扇区质押

修改 miner配置文件 # Whether to use available miner balance for sector collateral instead of sending it with each message## type: bool# env var: LOTUS_SEALING_COLLATERALFROMMINERBALANCE#CollateralFromMinerBalance falseCollateralFromMinerBalance true质押金…

手写数字识别(慕课MOOC人工智能之模式识别)

问题&#xff1a;手写数字识别 数据集 数据集链接请点击我 代码 %mat2vector.m function [data_] mat2vector(data,num)[row,col,~] size(data);data_zeros(num,row*col);for page 1:numfor rows 1:rowfor cols1:coldata_(page,((rows-1)*colcols)) im2double(data(rows,cols…

应用稳定性优化1:ANR问题全面解析

闪退、崩溃、无响应、重启等是应用稳定性常见的问题现象&#xff0c;稳定性故障大体可归类为ANR/冻屏、Crash/Tombstone、资源泄露三大类。本文通过对三类故障的产生原因、故障现象、触发机制及如何定位等&#xff0c;展开深度解读。 本文将详解ANR类故障&#xff0c;并通过一…

java数据结构与算法刷题-----LeetCode437. 路径总和 III(前缀和必须掌握)

java数据结构与算法刷题目录&#xff08;剑指Offer、LeetCode、ACM&#xff09;-----主目录-----持续更新(进不去说明我没写完)&#xff1a;https://blog.csdn.net/grd_java/article/details/123063846 文章目录 1. 深度优先2. 前缀和 1. 深度优先 解题思路&#xff1a;时间复…

图书推荐||Word文稿之美

让你的文档从平凡到出众&#xff01; 本书内容 《Word文稿之美》是一本全面介绍Word排版技巧和应用的实用指南。从初步认识数字排版到高效利用模板、图文配置和表格与图表的排版技巧&#xff0c;再到快速修正错误和保护文件&#xff0c;全面系统地讲解数字排版的技术和能力&…

多行业万能预约门店小程序源码系统 支持多门店预约小程序 带完整的安装代码包以及搭建教程

随着消费者对于服务体验要求的不断提升&#xff0c;门店预约系统成为了许多行业提升服务质量、提高运营效率的重要工具。然而&#xff0c;市面上的预约系统往往功能单一&#xff0c;无法满足多行业、多场景的个性化需求。下面&#xff0c;小编集合了多年的行业经验和技术积累&a…

Linux 安装k8s

官网 常见的三种安装k8s方式 1.kubeadm 2.kops&#xff1a;自动化集群制备工具 3.kubespray&#xff1a; 提供了 Ansible Playbook 下面以kubeadm安装k8s kubeadm的安装是通过使用动态链接的二进制文件完成的&#xff0c;目标系统需要提供 glibc ##使用 ss 或者 netstat 检测端…

基于Springboot的高校实习信息发布网站的设计与实现(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的高校实习信息发布网站的设计与实现&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xf…

jmeter如何请求访问https接口

添加线程组http请求 新建线程组&#xff0c;添加http请求 填入协议&#xff0c;ip&#xff0c;端口&#xff0c;请求类型&#xff0c;路径&#xff0c;以及请求参数&#xff0c;查看结果树等。 然后最关键的一步来了。 导入证书 步骤&#xff1a;获取证书&#xff0c;重新生…

Windows11家庭版安装Docker

文章目录 安装Docker安装hyper-v继续解决报错完成效果图进一步测试是否完成安装 安装Docker windows如何安装docker 装好之后&#xff0c;我打开报错。 安装hyper-v 按这个视频操作&#xff1a;Windows 11 家庭版安装 Hyper-V bat文件里的代码是&#xff1a; pushd "…

机器学习 | 模型性能评估

目录 一. 回归模型的性能评估1. 平均平方误差(MSE)2. 平均绝对误差(MAE)3. R 2 R^{2} R2 值3.1 R 2 R^{2} R2优点 二. 分类模型的性能评估1. 准确率&#xff08;Accuracy&#xff09;2. 召回率&#xff08;Recall&#xff09;3. 精确率&#xff08;Precision&#xff09;4. …

指针篇章-(1)

指针&#xff08;1&#xff09;学习流程 —————————————————————————————————————————————————————————————————————————————————————————————————————————————…

帝国cms7.5仿非小号区块链门户资讯网站源码 带手机版

帝国cms7.5仿非小号区块链门户资讯网站源码 带手机版 带自动采集 开发环境&#xff1a;帝国cms 7.5 安装环境&#xff1a;phpmysql 包含火车头采集规则和模块&#xff0c;采集目标站非小号官网。 专业的数字货币大数据平台模板&#xff0c;采用帝国cms7.5内核仿制&#xff0…

爱普生的SG2016系列高频,低相位抖动spxo样品

精工爱普生公司(TSE: 6724&#xff0c;“爱普生”)已经开始发货样品的新系列简单封装晶体振荡器(SPXO)与差分输出1。该系列包括SG2016EGN、SG2016EHN、SG2016VGN和SG2016VHN。它们在基本模式下都具有低相位抖动&#xff0c;并且采用尺寸为2.0 x 1.6 mm的小封装&#xff0c;高度…

【go从入门到精通】什么是go?为什么要选择go?

go的出生&#xff1a; go语言&#xff08;或Golang&#xff09;是Google开发的开源编程语言&#xff0c;诞生于2006年1月2日下午15点4分5秒&#xff0c;于2009年11月开源&#xff0c;2012年发布go稳定版。Go语言在多核并发上拥有原生的设计优势&#xff0c;Go语言从底层原生支持…

深度学习 精选笔记(10)简单案例:房价预测

学习参考&#xff1a; 动手学深度学习2.0Deep-Learning-with-TensorFlow-bookpytorchlightning ①如有冒犯、请联系侵删。 ②已写完的笔记文章会不定时一直修订修改(删、改、增)&#xff0c;以达到集多方教程的精华于一文的目的。 ③非常推荐上面&#xff08;学习参考&#x…