YOLOV5中添加CBAM模块详解——原理+代码

news2024/9/19 19:59:35

目录

  • 一、前言
  • 二、CAM
        • 1. CAM计算过程
        • 2. 代码实现
        • 3. 流程图
  • 三、SAM
        • 1. SAM计算过程
        • 2. 代码实现
        • 3. 流程图
  • 四、YOLOv5中添加CBAM模块
  • 参考文章

一、前言

  由于卷积操作通过融合通道和空间信息来提取特征(通过 N × N N×N N×N的卷积核与原特征图相乘,融合空间信息;通过不同通道的特征图加权求和,融合通道信息),论文提出的Convolutional Block Attention Module(CBAM)沿两个独立的维度(通道和空间)依次学习特征,然后与学习后的特征图与输入特征图相乘,进行自适应特征细化。

在这里插入图片描述

图1-1 CBAM结构图

  上图可以看到,CBAM包含CAM(Channel Attention Module)和SAM(Spartial Attention Module)两个子模块,分别进行通道和空间上的Attention。这样不只能够节约参数和计算力,并且保证了其能够做为即插即用的模块集成到现有的网络架构中去。

二、CAM

1. CAM计算过程

在这里插入图片描述

图2-1 CAM结构图

  输入特征图 F F F首先经过两个并行的MaxPool层和AvgPool层,将特征图的维度从 C × H × W C×H×W C×H×W变为 C × 1 × 1 C×1×1 C×1×1,然后经过Shared MLP模块。在该模块中,它先将通道数压缩为原来的 1 / r 1/r 1/r倍,再经过ReLU激活函数,然后扩张到原通道数。将这两个输出结果进行逐元素相加,再通过一个sigmoid激活函数得到Channel Attention的输出结果,然后将这个输出结果与原图相乘,变回 C × H × W C×H×W C×H×W的大小。

  上述过程的计算公式如下:

M c ( F ) = σ ( M L P ( A v g P o o l ( F ) ) + M L P ( M a x P o o l ( F ) ) ) M_{c}(F)=\sigma (MLP(AvgPool(F))+MLP(MaxPool(F))) Mc(F)=σ(MLP(AvgPool(F))+MLP(MaxPool(F)))
= σ ( W 1 ( W 0 ( F a v g c ) ) + W 1 ( W 0 ( F m a x c ) ) ) =\sigma (W_{1}(W_{0}(F^{c}_{avg}))+W_{1}(W_{0}(F^{c}_{max}))) =σ(W1(W0(Favgc))+W1(W0(Fmaxc)))

  其中, σ \sigma σ代表sigmoid激活函数, W 0 ∈ R C / r × C W_{0}\in R^{C/r\times C} W0RC/r×C W 1 ∈ R C × C / r W_{1}\in R^{C\times C/r} W1RC×C/r,且MLP的权重 W 0 W_{0} W0 W 1 W_{1} W1对于输入来说是共享的,ReLU激活函数位于 W 0 W_{0} W0之后, W 1 W_{1} W1之前。

2. 代码实现

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) # 上面公式中的W0
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) # 上面公式中的W1

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return torch.mul(x, out)

3. 流程图

  CAM过程的详细流程如下图所示:

在这里插入图片描述

图2-2 CAM流程图

三、SAM

1. SAM计算过程

在这里插入图片描述

图3-1 SAM结构图

  将Channel Attention的输出结果通过最大池化和平均池化得到两个 1 × H × W 1×H×W 1×H×W的特征图,然后经过Concat操作对两个特征图进行拼接,再通过 7 × 7 7×7 7×7卷积将特征图的通道数变为 1 1 1(实验证明 7 × 7 7×7 7×7效果比 3 × 3 3×3 3×3好),再经过一个sigmoid得到Spatial Attention的特征图,最后将输出结果与原输入特征图相乘,变回CHW大小。

  上述过程的计算公式如下:

M s ( F ) = σ ( f 7 × 7 ( [ A v g P o o l ( F ) ; M a x P o o l ( F ) ] ) ) M_{s}(F)=\sigma (f^{7\times 7}([AvgPool(F);MaxPool(F)])) Ms(F)=σ(f7×7([AvgPool(F);MaxPool(F)]))

= σ ( f 7 × 7 ( [ F a v g s ; F m a x s ] ) ) =\sigma (f^{7\times 7}([F^{s}_{avg};F^{s}_{max}])) =σ(f7×7([Favgs;Fmaxs]))

  其中, σ \sigma σ代表sigmoid激活函数, f 7 × 7 f^{7\times 7} f7×7代表卷积核大小为 7 × 7 7×7 7×7的卷积过程。

2. 代码实现

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv(out))
        return torch.mul(x, out)

3. 流程图

  SAM过程的详细流程如下图所示:

在这里插入图片描述

图3-2 SAM流程图

四、YOLOv5中添加CBAM模块

  • 修改common.py
    在common.py中添加下列代码:
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return torch.mul(x, out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv(out))
        return torch.mul(x, out)


class CBAMC3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAMC3, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        self.channel_attention = ChannelAttention(c2, 16)
        self.spatial_attention = SpatialAttention(7)

    def forward(self, x):
   		# 将最后的标准卷积模块改为了注意力机制提取特征
        return self.spatial_attention(
            self.channel_attention(self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))))
  • 修改yolo.py
    在yolo.py的if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3TR,......]中添加CBAMC3,即修改后的代码为:
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
                 C3, C3TR, ASPP, CBAMC3]:
            c1, c2 = ch[f], args[0]  
            if c2 != no:  
                c2 = make_divisible(c2 * gw, 8)  
            args = [c1, c2, *args[1:]] 
  • 修改yolov5s.yaml
    修改后的yolov5s.yaml如下:
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, CBAMC3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, CBAMC3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, CBAMC3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, CBAMC3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

参考文章

CBAM——即插即用的注意力模块(附代码)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/393710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

模板学堂丨妙用Tab组件制作多屏仪表板并实现自动轮播

DataEase开源数据可视化分析平台于2022年6月正式发布模板市场(https://dataease.io/templates/)。模板市场旨在为DataEase用户提供专业、美观、拿来即用的仪表板模板,方便用户根据自身的业务需求和使用场景选择对应的仪表板模板,并…

2021年MathorCup数学建模D题钢材制造业中的钢材切割下料问题全过程文档及程序

2021年第十一届MathorCup高校数学建模 D题 钢材制造业中的钢材切割下料问题 原题再现 某钢材生产制造商的钢材切割流程如图 1 所示。其中开卷上料环节将原材料钢卷放在开卷机上,展开放平送至右侧操作区域(见图 2)。剪切过程在剪切台上完成&…

如何使用 NFTScan NFT API 检索单个 NFT 资产

一、什么是 NFT API API 是允许两个应用组件使用一组定义和协议相互通信的机制。一般来说,这是一套明确定义的各种软件组件之间的通信方法。API 发送请求到存储数据的服务器,接着把调用的数据信息返回。开发者可以通过调用 API 函数对应用程序进行开发&…

Qt 单例模式第一次尝试

文章目录摘要单例模式如何使用Qt 的属性系统总结关键字: Qt、 单例、 的、 Q_GLOBAL_STATIC、 女神节摘要 世界上第一位电脑程序设计师是名女性:Ada Lovelace (1815-1852)是一位英国数学家兼作家,她是第一位主张计算机不只可以用来算数的人…

【安装mxnet】

安装mxnet 通过创建python3.6版本的虚拟环境安装mxnet 1、安装anaconda 2、打开Anaconda prompt 3、查看环境 conda env list conda info -e 4、创建虚拟环境 conda create -n your_env_name python3.6 5、激活或者切换虚拟环境 activate your_env_name 6、安装mxnet,下面两…

规划数据指标体系方法(中)——UJM 模型

上文我跟大家分享了关于规划数据指标体系的 OSM 模型,从目标-战略-度量的角度解读了数据指标的规划方法,今天来跟大家讲一讲另一种规划数据指标体系的方法——UJM 模型。 了解 UJM UJM 模型,全称为 User-Journey-Map 模型,即用户…

日常文档标题级别规范

这里写自定义目录标题欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants创建一个自定义列表如何创建一个注…

使用 try-catch 捕获异常会影响性能吗?大部分人都会答错!

使用 try-catch 捕获异常会影响性能吗?大部分人都会答错!前言一、JVM 异常处理逻辑二、关于JVM的编译优化三、关于测试的约束四、测试代码五、解释模式下执行测试六、编译模式测试七、结论前言 不知道从何时起,传出了这么一句话:…

web实现环形旋转、圆形、弧形、querySelectorAll、querySelector、clientWidth、sin、cos、PI

文章目录1、HTML部分2、css部分3、JavaScript部分4、微信小程序演示1、HTML部分 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge">&l…

Java定时器Timer的使用

一、Timer常用方法 Timer应用场景&#xff1a; 1、每隔一段时间执行指定的代码逻辑&#xff08;即按周期执行任务&#xff09; 2、指定时间执行指定的代码逻辑 为方便测试并查看运行效果&#xff0c;首先先建一个类并继承TimerTask&#xff0c;代码如下: package timerTest…

[2.1.4]进程管理——进程通信

文章目录第二章 进程管理进程通信&#xff08;IPC&#xff09;为什么进程通信需要操作系统支持&#xff1f;&#xff08;一&#xff09;共享存储&#xff08;1&#xff09;基于存储区的共享&#xff08;2&#xff09;基于数据结构的共享&#xff08;二&#xff09;消息传递什么…

程序员的逆向思维

前要&#xff1a; 为什么你读不懂面试官提问的真实意图&#xff0c;导致很难把问题回答到面试官心坎上? 为什么在面试结束时&#xff0c;你只知道问薪资待遇&#xff0c;不知道如何高质量反问? 作为一名程序员&#xff0c;思维和技能是我们职场生涯中最重要的两个方面。有时候…

RWEQ模型的土壤风蚀模数估算、其变化归因分析

土壤风蚀是一个全球性的环境问题。中国是世界上受土壤风蚀危害最严重的国家之一&#xff0c;土壤风蚀是中国干旱、半干旱及部分湿润地区土地荒漠化的首要过程。中国风蚀荒漠化面积达160.74104km2&#xff0c;占国土总面积的16.7%&#xff0c;严重影响这些地区的资源开发和社会经…

使用pprof分析golang内存泄露问题

问题现象 生产环境有个golang应用上线一个月来&#xff0c;占用内存不断增多&#xff0c;约30个G&#xff0c;这个应用的DAU估计最多几十&#xff0c;初步怀疑有内存泄露问题。下面是排查步骤&#xff1a; 分析 内存泄露可能点&#xff1a; goroutine没有释放time.NewTicke…

【前端学习】D2-2:CSS基础

文章目录前言系列文章目录1 Emmet语法1.1 快速生成HTML语法结构1.2 快速生成CSS样式语法1.3 快速格式化代码2 CSS复合选择器2.1 什么是复合选择器2.2 后代选择器&#xff08;*&#xff09;2.3 子选择器2.4 并集选择器&#xff08;*&#xff09;2.5 伪类选择器2.6 链接伪类选择器…

企业文件数据泄露防护(DLP)

什么是数据丢失防护 数据丢失防护 &#xff08;DLP&#xff09; 是保护数据不落入坏人之手的做法。如今&#xff0c;数据传输的主要问题是使大量数据容易受到未经授权的传输。通过设置足够的安全边界&#xff0c;您可以控制数据在网络中的移动。由于您的数据非常有价值&#x…

Java方法的使用

目录 一、方法的概念及使用 1、什么是方法(method) 2、方法定义 3、方法调用的执行过程 4、实参和形参的关系 二、方法重载 1、为什么需要方法重载 2、方法重载概念 3、方法签名 三、递归 1、递归的概念 2、递归执行过程分析 一、方法的概念及使用 1、什么是方法(met…

MySQL 字符串函数

点击上方蓝字关注我平生文字为吾累&#xff0c;此去声名不厌低。 寒上纵归他日马&#xff0c;城中不斗少年鸡。MySQL提供了许多常用的字符串函数&#xff0c;以下是其中一些常用的字符串函数和用法&#xff1a;CONCATCONCAT函数用于连接两个或多个字符串。以下是一个示例&#…

MGAT: Multimodal Graph Attention Network for Recommendation

模型总览如下&#xff1a; 图1&#xff1a;多模态图注意力网络背景&#xff1a;本论文是对MMGCN&#xff08;Wei et al., 2019&#xff09;的改进。MMGCN简单地在并行交互图上使用GNN&#xff0c;平等地对待从所有邻居传播的信息&#xff0c;无法自适应地捕获用户偏好。 MMGCN…

Qt学习5-Qt Creator文件操作(哔站视频学习记录)

实现文件编辑器代码 目录 一、代码要点 二、重点函数 1、conncet 2、getOpenFileName 3、getSaveFileName 4、读取文件到textEdit 5、textEdit保存到文件 三、全部代码 mainwindow.h mainwindow.cpp 一、代码要点 MainWindow的菜单栏实现&#xff1b;connect函数连接…