YOLOv7改进:CBAM注意力机制

news2024/11/17 21:18:26

 目录

1.介绍

1.1、论文的出发点

1.2、论文的主要工作

1.3、CBAM模块的具体介绍

2.YOLOv7改进

2.1yaml 配置文件如下

2.2common.py配置

2.3yolo.py配置


1.介绍

1.1、论文的出发点

cnn基于其丰富的表征能力,极大地推动了视觉任务的完成,为了提高cnn网络的性能,最近的研究主要聚焦在网络的三个重要因素:深度、宽度和基数。除了这些因素,作者还研究了网络架构的一个不同方面——注意力。注意力研究的目标是通过使用注意机制来增加表现能力:关注重要的特征,并抑制不必要的特征。在本文中,作者提出了一个新的网络模块,名为“卷积块注意模块”(CBAM),该模块用来强调这两个主要维度上的有意义的特征:通道和空间轴,该模块实现方式是通过学习强调或抑制哪些信息,有效地帮助信息在网络中流动。

1.2、论文的主要工作

1. 提出了一种简单而有效的注意力模块(CBAM),可广泛应用于增强cnn的表示能力。
2. 作者验证了该注意模块的有效性,通过广泛的消融试验。
3. 通过插入CBAM,作者验证了在多个基准测试(ImageNet-1K, MS COCO,和VOC 2007)上,各种网络的性能都得到了极大的改善。

1.3、CBAM模块的具体介绍

CBAM注意力机制是由通道注意力机制(channel)和空间注意力机制(spatial)组成。

传统基于卷积神经网络的注意力机制更多的是关注对通道域的分析,局限于考虑特征图通道之间的作用关系。CBAM从 channel 和 spatial 两个作用域出发,引入空间注意力和通道注意力两个分析维度,实现从通道到空间的顺序注意力结构。空间注意力可使神经网络更加关注图像中对分类起决定作用的像素区域而忽略无关紧要的区域,通道注意力则用于处理特征图通道的分配关系,同时对两个维度进行注意力分配增强了注意力机制对模型性能的提升效果。
 

CBAM中的通道注意力机制模块流程图如下。先将输入特征图分别进行全局最大池化和全局平均池化,对特征映射基于两个维度压缩,获得两张不同维度的特征描述。池化后的特征图共用一个多层感知器网络,先通过一个全连接层下降通道数,再通过另一个全连接恢复通道数。将两张特征图在通道维度堆叠,经过 sigmoid 激活函数将特征图的每个通道的权重归一化到0-1之间。将归一化后的权重和输入特征图相乘。

 

 该模块有两个顺序子模块:通道(Channel)和空间(Spatial)。

1. Channel attention modul

 

目的:利用特征的通道间关系生成通道注意图。

方法通道维度不变,压缩输入特征图的空间维度。

2. Spatial attention module

 

目的:利用特征间的空间关系生成空间注意图。

方法:空间维度不变,压缩通道维度。

步骤:

(1)AP和MP操作:首先特征图F'使用AP(average pooling)和MP(max pooling)操作得到两个1*H*W的特征图。

(2)拼接和卷积:将它们拼接在一起得到一个2*H*W的特征图,再通过一个7x7的卷积重新得到1*H*W的特征图。

(3)sigmoid:最后,通过一个sigmoid函数,得到包含空间注意力的特征图。原文中没有给予这个特征图命名符,为了方便将其称之为z。

2.YOLOv7改进

2.1yaml 配置文件如下

# YOLOv7 🚀, GPL-3.0 license
# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [12,16, 19,36, 40,28]  # P3/8
  - [36,75, 76,55, 72,146]  # P4/16
  - [142,110, 192,243, 459,401]  # P5/32

# yolov7 backbone by yoloair
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [32, 3, 1]],  # 0
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4 
   [-1, 1, CNeB, [128]], 
   [-1, 1, Conv, [256, 3, 2]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 16-P3/8
   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [1024, 1, 1]],          
   [-1, 1, MP, []],
   [-1, 1, Conv, [512, 1, 1]],
   [-3, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, CNeB, [1024]],
   [-1, 1, Conv, [256, 3, 1]],
  ]

# yolov7 head by yoloair
head:
  [[-1, 1, SPPCSPC, [512]],
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [31, 1, Conv, [256, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, C3C2, [128]],
   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [18, 1, Conv, [128, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, C3C2, [128]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, CBAM, [128]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3, 44], 1, Concat, [1]],
   [-1, 1, C3C2, [256]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]], 
   [[-1, -3, 39], 1, Concat, [1]],
   [-1, 3, C3C2, [512]],

# 检测头 -----------------------------
   [49, 1, RepConv, [256, 3, 1]],
   [55, 1, RepConv, [512, 3, 1]],
   [61, 1, RepConv, [1024, 3, 1]],

   [[62,63,64], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

2.2common.py配置

./models/common.py文件增加以下模块

class ChannelAttentionModule(nn.Module):
    def __init__(self, c1, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = c1 // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=c1, out_features=mid_channel),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(in_features=mid_channel, out_features=c1)
        )
        self.act = nn.Sigmoid()
        #self.act=nn.SiLU()
    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)
        maxout = self.shared_MLP(self.max_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)
        return self.act(avgout + maxout)
        
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.act = nn.Sigmoid()
    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.act(self.conv2d(out))
        return out

class CBAM(nn.Module):
    def __init__(self, c1,c2):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(c1)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        out = self.spatial_attention(out) * out
        return out
  

2.3yolo.py配置

在 models/yolo.py文件夹下

  • 定位到parse_model函数中
  • for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
  • 对应位置 下方只需要新增以下代码
elif m is CBAM:
    c1, c2 = ch[f], args[0]
    if c2 != no:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1047236.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySql】3- 实践篇(一)

文章目录 1. 普通索引和唯一索引的选择1.1 查询过程1.2 更新过程1.2.1 change buffer1.2.2 change buffer 的使用场景 1.3 索引选择和实践1.4 change buffer 和 redo log2. MySQL为何有时会选错索引?2.1 优化器的逻辑2.1.1 扫描行数是怎么判断的?2.1.2 重新统计索引信息 2.2 …

一站式吃鸡利器,提升游戏战斗力,助您稳坐鸡王宝座!

各位吃鸡玩家们,听说过绝地求生作图工具吗?想知道如何提高游戏战斗力、分享顶级作战干货、查询装备皮肤库存?还在为游戏账号安全而担心吗?别急,今天就为您介绍一款一站式吃鸡利器,满足您的所有需求&#xf…

【使用工具】IDEA创建类及已有类添加注释-详细操作

1.背景 很多开发好多时候其实不太会给类添加注释,尤其是已经有的类,上网查询,好多文档错误百出,而且不全 2.正文 2.1新建类添加注释 idea给新建类创建注释有两种方式 先写一个简单的模板 /** * description: TODO * autho…

kotlin协程CoroutineScope Dispatchers.IO launch 线程Id

kotlin协程CoroutineScope Dispatchers.IO launch 线程Id import kotlinx.coroutines.*fun main(args: Array<String>) {println("main 线程id:${Thread.currentThread().threadId()}")CoroutineScope(Dispatchers.IO).launch {println("launch 线程id:$…

【JVM】第二篇 JVM内存模型深度剖析与优化

目录 一. JDK体系结构与跨平台特性介绍二. JVM内存模型深度剖析三. 从Jvisualvm来研究下对象内存流转模型四. GC Root与STW机制五. JVM参数设置通用模型一. JDK体系结构与跨平台特性介绍 二. JVM内存模型深度剖析 按照线程是否共享来划分 TLAB(Thread Local Allocation Buffer…

mybatis核心组件

title: “mybatis核心组件” createTime: 2021-12-08T12:19:5708:00 updateTime: 2021-12-08T12:19:5708:00 draft: false author: “ggball” tags: [“mybatis”] categories: [“java”] description: “mybatis核心组件” #mermaid-svg-AYu4pQutsPsK0P5T {font-family:&quo…

stm32 - 初识2

stm32 - 初识2 工程架构点灯程序寄存器方式点灯库函数的方式点灯 工程架构 启动文件 中断向量表&#xff0c;中断服务函数&#xff0c;其他中断等 中断服务函数中的&#xff0c;复位中断是整个程序的入口&#xff0c;调用systeminit&#xff0c;和main函数 点灯程序 寄存器方式…

自适应阈值分割-OTSU

OTSU 在前面固定阈值中选取了一个阈值为127进行阈值分割&#xff0c;那如何知道选的这个阈值效果好不好呢&#xff1f;答案是&#xff1a;不断尝试&#xff0c;所以这种方法在很多文献中都被称为经验阈值。 Otsu阈值法就提供了一种自动高效的二值化方法。Otsu算法也称最大类间…

C++之std::atomic类模板原子操作应用总结(二百三十九)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

JAVA+SpringBoot+VUE工厂车间管理系统(含论文)源码

springboot169基于vue的工厂车间管理系统的设计录像(毕业设计jdz2023) 一、源码描述 JAVASpringBootVUE工厂车间管理系统,包含源码数据库论文等,含MySQL脚本&#xff0c;基于B/S和Web开发的&#xff0c;感兴趣的朋友可以下载看看 二、功能介绍 1、个人中心 2、人员管理 3、设备…

计算机图像处理:图像轮廓

图像轮廓 图像阈值分割主要是针对图片的背景和前景进行分离&#xff0c;而图像轮廓也是图像中非常重要的一个特征信息&#xff0c;通过对图像轮廓的操作&#xff0c;就能获取目标图像的大小、位置、方向等信息。画出图像轮廓的基本思路是&#xff1a;先用阈值分割划分为两类图…

Textpad 缺少Java编译和运行功能

一、问题 缺少Java编译和运行功能 二、处理方法 1、点击菜单Configure->Preferences 2、点击 Tools -> Add -> Java SDK Commands 3、点击应用和确认 三、结果

现代 GPU 容易受到新 GPU.zip 侧通道攻击

来自四所美国大学的研究人员开发了一种新的 GPU 侧通道攻击&#xff0c;该攻击利用数据压缩在访问网页时泄露现代显卡中的敏感视觉数据。 研究人员通过 Chrome 浏览器执行跨源 SVG 过滤器像素窃取攻击&#xff0c;证明了这种“ GPU.zip ”攻击的有效性。 研究人员于 2023 年 …

【JVM】第五篇 垃圾收集器G1和ZGC详解

导航 一. G1垃圾收集算法详解1. 大对象Humongous说明2. G1收集器执行一次GC运行的过程步骤3. G1垃圾收集分类4. G1垃圾收集器参数设置5. G1垃圾收集器的优化建议6. 适合使用G1垃圾收集器的场景?二. ZGC垃圾收集器详解1. NUMA与UMA2. 颜色指针3. ZGC的运作过程4. ZGC垃圾收集器…

mysql面试题2:说一说MySQL的架构设计?一条 MySQL 语句执行的步骤?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:说一说MySQL的架构设计? MySQL的架构设计主要包括以下几个组件: 连接器(Connector):负责与客户端建立连接,并进行身份验证和授权。 查询缓存…

文件的随机读写函数:fseek

目录 函数介绍&#xff1a; fseek&#xff1a; 原型&#xff1a; 参数说明&#xff1a; int origin&#xff1a; 举例&#xff1a; 文件内容展示&#xff1a; 正常的使用fgetc函数&#xff1a; 结果&#xff1a; 使用了fseek之后&#xff1a; SEEK_SET :从开始位置进行…

VBA技术资料MF60:从二维变体数组中删除一列数据

【分享成果&#xff0c;随喜正能量】如果一个人能看破红尘&#xff0c;特别是如果能看破贪欲的本质&#xff0c;他的身心一定会非常自在&#xff0c;外在的气色会很好&#xff0c;内心也会非常安乐。如果一个人总是贪执某人&#xff0c;那他的戒律肯定不会清净&#xff0c;他的…

渗透测试之打点

请遵守中华人民共和国网络安全法 打点的目的是获取一个服务器的控制权限 1. 企业架构收集 &#xff08;1&#xff09;官网 &#xff08;2&#xff09;网站或下属的子网站&#xff0c;依次往下 天眼查 企查查 2. ICP 备案查询 ICP/IP地址/域名信息备案管理系统 使用网站…

轮廓检测及透视变换

文章目录 注&#xff1a;代码来自b站&#xff1a;阿头目G # 导入工具包 import numpy as np import argparse import cv2# # 设置参数 # ap argparse.ArgumentParser() # ap.add_argument("-i", "--image", required True, # help "Path to the i…

flutter GetMaterialAPP unknownRoute 失效

这个问题。。。找了半天&#xff0c;最后去了 官网的ISSu找到了答案&#xff0c;分享给同样困惑的小伙伴们 为了让unknownRoute能成功工作, 你的initialRoute一定不能是/, 否则当有未定义的导航时, Get只会跳到initialRoute页面去了. GetPage 里面不要把root设置为“/” 页面&a…