YOLOv5改进 | 注意力机制 | 通道和空间的双重作用的CBAM注意力机制

news2024/10/5 19:17:56

在深度学习目标检测领域,YOLOv5成为了备受关注的模型之一。本文给大家带来的是通道和空间的双重作用的CBAM注意力机制。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。


专栏地址YOLOv5改进+入门——持续更新各种有效涨点方法 

目录

1.原理 

2.YOLOv5添加CBAM注意力机制

2.1 CBAM注意力机制代码

2.2新增yaml文件

2.3 注册模块

2.4 执行程序

3.总结 


1.原理 

论文地址:CBAM: Convolutional Block Attention Module点击即可跳转

实现代码:CBAM代码实现点击即可跳转

CBAM(Convolutional Block Attention Module)是一种引入了注意力机制的卷积神经网络模块,旨在增强CNN模型的表征能力和性能。它由两个关键组件组成:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。

1. 通道注意力模块(CAM):
   CAM主要用于建模特征图在通道维度上的关系。它的目标是学习每个通道的重要性,并对不同通道的特征进行加权,以提升有用特征的影响力,抑制无用特征的干扰。
   CAM首先对输入的特征图进行全局平均池化(Global Average Pooling),将每个通道的特征图压缩成一个标量,然后通过全连接层(FC)学习得到每个通道的权重。这些权重用于对每个通道的特征图进行加权,得到加权后的特征表示。

2. 空间注意力模块(SAM):
   SAM用于捕捉特征图在空间维度上的重要性。它的目标是学习不同空间位置的权重,使网络能够更好地关注图像中的重要区域。
   SAM首先对特征图进行两种池化操作:最大池化和平均池化。这两种池化操作分别用于捕捉特征图中的局部显著性和全局分布信息。然后,将两种池化结果结合,并通过全连接层学习得到每个空间位置的权重,以产生最终的空间注意力图。
   
通过结合通道注意力和空间注意力,CBAM可以使网络更好地理解输入数据中的关键信息,并提高模型在各种视觉任务上的性能。这种注意力机制的引入使得网络能够自适应地调整特征图中不同通道和空间位置的重要性,从而有效地提升了模型的表现力和泛化能力。CBAM已经被成功应用于图像分类、目标检测、语义分割等多个计算机视觉任务中,取得了显著的性能提升。

CBAM结构简图

2.YOLOv5添加CBAM注意力机制

2.1 CBAM注意力机制代码

关键步骤一:将下面代码添加到 yolov5/models/common.py中任意位置

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # (特征图的大小-算子的size+2*padding)/步长+1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        #2*h*w
        x = self.conv(x)
        #1*h*w
        return self.sigmoid(x)


class CBAM(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

 CBAM(Convolutional Block Attention Module)注意力机制的流程可以总结如下:

1. 输入特征图:接收来自上一层或输入图像的特征图作为输入。

2. 通道注意力模块(Channel Attention Module,CAM):
   对输入特征图进行全局平均池化(Global Average Pooling),将每个通道的特征进行降维,得到每个通道的全局描述。
   通过全连接层(Fully Connected Layer)学习得到每个通道的权重向量,这些权重用于衡量每个通道的重要性。
   将学习到的权重与原始特征图相乘,以加权增强有用特征和抑制无用特征。

3. 空间注意力模块(Spatial Attention Module,SAM):
   对输入特征图进行最大池化(Max Pooling)和平均池化(Average Pooling),分别捕获局部显著性和全局分布信息。
   将两种池化结果进行组合(如相加),得到综合的空间注意力图。
   通过激活函数(如sigmoid)对空间注意力图进行归一化,得到每个空间位置的权重。

4. 结合通道和空间注意力:
   将通道注意力加权后的特征图与空间注意力加权后的特征图进行逐元素相乘,得到最终的注意力增强特征图。

5. 输出:最终的注意力增强特征图作为模块的输出,传递给下一层网络进行后续的处理,如分类、检测或分割等任务。

整个CBAM注意力机制的流程是将通道注意力和空间注意力相结合,使得网络能够自适应地调整不同通道和空间位置的重要性,从而提升模型的性能和泛化能力。

2.2新增yaml文件

关键步骤二:在 /yolov5/models/ 下新建文件 yolov5_cbam.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, CBAM, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

温馨提示:因为本文只是对yolov5s基础上添加CBAM模块,如果要对yolov5n/l/m/x进行添加则只需要修改对应的depth_multiple 和 width_multiple。


yolov5n/l/m/x对应的depth_multiple 和 width_multiple如下:

# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple

# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple

# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple
2.3 注册模块

关键步骤三:在yolov5/models/yolo.py中注册,大概在250行左右添加 ‘CBAM’

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_cbam.yaml的路径,如下图所示

建议大家写绝对路径,确保一定能找到

运行程序,如果出现下面的内容则说明添加成功🚀

我修改后的代码:链接: https://pan.baidu.com/s/1qoLGhu7t4noFMxvi7t0rzA?pwd=92im 提取码: 92im

3.总结 

CBAM(Convolutional Block Attention Module)是一种用于增强卷积神经网络(CNN)性能的注意力机制。它由两个子模块组成:通道注意力模块和空间注意力模块。通道注意力模块通过全局平均池化和全连接层学习通道间的关系,并利用学到的权重对每个通道的特征图进行加权,以增强有用的特征并抑制无用的特征。空间注意力模块则通过对特征图在空间维度上进行最大池化和平均池化操作,结合两种池化结果通过全连接层学习得到每个空间位置的权重,使得网络能够更好地关注图像中的重要区域。CBAM的引入可以帮助网络更好地理解输入数据中的关键信息,从而提高了模型在各种视觉任务上的性能,如图像分类、目标检测和语义分割等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1667365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kafka效率篇-提升效率三板斧

kafka在效率上做了很多的努力。最初的一个使用场景是处理网页上活跃的数据,它往往有非常大的体量,每个页面都能产生数十条写入。而且我们假设每条消息都会被至少一个消费者消费(通常是多个),因此,我们努力让…

对称加密介绍

一、什么是对称加密 对称密钥算法(Symmetric-key algorithm),又称为对称加密、私钥加密、共享密钥加密,是密码学中的一类加密算法。 对称加密的特点是,在加密和解密时使用相同的密钥,或是使用两个可以简单地相互推算的密钥。 这…

智慧公厕:数据驱动的公共厕所智慧化管理

公共厕所作为城市基础设施的重要组成部分,对于城市居民的生活质量和城市形象有着不可忽视的影响。然而,传统的公共厕所管理模式存在诸多问题,如设施老化、卫生状况不佳等,严重限制了公众对于公共厕所的使用体验。随着大数据和智能…

ViLT 浅析

ViLT 浅析 论文链接:ViLT 文章目录 ViLT 浅析创新点网络结构总结 创新点 本文先分析了4种不同类型的Vision-and-Language Pretraining(VLP) 其中每个矩形的高表示相对计算量大小,VE、TE和MI分别是visual embedding、text embedding和modality interact…

类型注解-Python

师从黑马程序员 类型注解的语法 类型注释的限制 import json import randomvar_1 : int10 var_2 : str"itheima" var_3 : boolTrueclass Student:pass stu :StudentStudent()my_list:list [1,2,3] my_tuple:tuple(1,2,3) my_dict:dict{"itheima":666}my_l…

智慧安监中的物联网主机E6000

物联网主机E6000的研发背景主要源于我国对物联网技术在安全生产、环境监测、火灾预警与防控、人员定位与紧急救援等领域的迫切需求。近年来,随着物联网技术的飞速发展,我国政府对智慧安监的重视程度不断提升,相关的政策扶持力度也在加大。在这…

乡村振兴与数字乡村建设:加强农村信息化建设,推动数字乡村发展,提升乡村治理和服务水平,构建智慧化的美丽乡村

目录 一、引言 二、数字乡村建设的必要性 1、推动农村经济转型升级 2、提升乡村治理水平 3、改善乡村民生福祉 三、数字乡村建设的现状与挑战 1、现状 2、挑战 四、数字乡村建设的未来发展路径 1、加强农村信息化基础设施建设 2、提升农民信息素养和技能水平 3、制…

解锁Spring Boot数据映射新利器:深度探索MapperStruct

解锁Spring Boot数据映射新利器:深度探索MapperStruct MapperStruct 是一个强大的 Java 映射工具,它的主要作用是简化对象之间的映射操作。在 Spring Boot 应用程序中,MapperStruct 通常用于将领域模型对象(Domain Model&#xff…

17_基于Flash和RAM的的文件系统选择

嵌入式系统常见文件系统 本文主要讲述在嵌入式系统中,常见的基于flash和内存(RAM)的文件系统类型,具体选择要结合实际需求灵活选配。 一、基于 Flash 的文件系统 基于 Flash 的文件系统主要包括 JFFS2、 YAFFS、 Cramfs 和 Romfs 等,各种文件系统具有不同的特点,本文将分…

基于微信小程序的预约挂号系统(源码)

博主介绍:✌程序员徐师兄、10年大厂程序员经历。全网粉丝12W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅&#x1f447…

【免费】2024年全新超强版本itvboxfast如意版影视APP源码 TV+手机双端后台PHP源码

首先,让我们了解一下ITVBox如意版影视源码的特点和优势。这一源码基于先进的技术和框架开发,具有稳定、高效的性能,能够满足影视网站的各种需求。与此同时,该源码还提供了丰富的功能和模块,包括影视资源管理、会员系统…

C语言——文件相关操作补充

一、文件读取结束的判定 当我们使用例如fgetc、fgets、fscanf、fread等函数来读取文件内容时,我们可能遇到需要判断文件读取的结束,一般情况下都是通过这些函数的返回值来判断文件读取是否结束。 1、fgetc 返回读取的字符的ASCII值,如果读…

能源效率:未来可持续发展的全球当务之急

当前全球正面临着严重的能源与气候危机,能源消耗不断增长导致环境污染、气候变化等问题日益严重。在这一背景下,提高能源效率成为了当务之急。今天,我们来简要探讨一下能源效率在全球可持续发展中的重要性,重点关注建筑物能源效率…

了解当前经济,VBA一键获取不同货币实时汇率

了解当前经济数据,VBA一键获取不同货币间实时汇率 当下较火的经济新闻:黄金价格、日元贬值、美元加息等,咱们不去分析了解这些经济变动背后的动机及原因,做一点本份的事,如何用VBA获取不同货币之间的实时汇率。这肯定是需要联网的,现从“外汇查询” 网站(https://www.wa…

Django国际化与本地化指南

title: Django国际化与本地化指南 date: 2024/5/12 16:51:04 updated: 2024/5/12 16:51:04 categories: 后端开发 tags: Django-i18n本地化-L10n多语言国际化翻译工具表单验证性能优化 引言 在数字化时代,网站和应用程序必须跨越地域限制,服务于全球…

微信小程序踩坑,skyline模式下,简易双向绑定无效

工具版本 基础库版本 Skline模式 页面json设置 问题描述 skyline模式下,textarea,input标签设置简易双向绑定 model:value是无效的,关闭skyline模式就正常使用了 截图展示 这里只展示了textarea标签,input标签的简易双向绑定也是无效的 总结 我在文档里面是没找到skyline里面不…

第3周 后端微服务基础架构与前端项目联调配备

第3周 后端微服务基础架构与前端项目联调配备 1. 微服务项目层次设计与Maven聚合1.1 项目层次设计1.2 父项目pom1.2.1 打包方式 1.3 创建通用 ************************************************************************************** 1. 微服务项目层次设计与Maven聚合 1.1…

【JS红宝书学习笔记】第3章 语言基础

第3章 语言基础 1. 语法 标识符(变量、函数、属性或函数参数的名称):一般使用驼峰法命名,关键字、保留字、true、false 和 null 不能作为标识符。 标识符的第一个字符必须是一个字母、下划线(_)或美元符号…

MySQL数据库基础(数据库操作,常用数据类型,表的操作)

MySQL数据库基础(数据库操作,常用数据类型,表的操作) 前言 数据库的操作1.显示当前数据库2.创建数据库3.使用数据库4.删除数据库 常用数据类型1.数值类型2.字符串类型3.日期类型 表的操作1.查看表结构2.创建表3.删除表 总结 前言 …

【电子实验3】简单变调电子门铃

🚩 WRITE IN FRONT 🚩 🔎 介绍:"謓泽"正在路上朝着"攻城狮"方向"前进四" 🔎🏅 荣誉:2021|2022年度博客之星物联网与嵌入式开发TOP5|TOP4、2021|2222年获评…