【YOLOv5/v7改进系列】引入卷积块注意力模块CBAM注意力机制

news2024/11/23 22:53:12
一、导言

CBAM(Convolutional Block Attention Module)是一种简单而有效的注意力机制模块,旨在增强卷积神经网络(CNN)的表现力。该模块通过引入两个独立的注意力机制——通道注意力和空间注意力——来适应性地精炼特征图,从而提升模型的整体性能。

通道注意力

通道注意力关注于“哪些”特征应该被强调或抑制。它基于全局信息,通过平均池化和最大池化操作从特征图中提取两种类型的全局描述符。这两种描述符分别捕捉了特征图中不同通道的重要性和分布情况。之后,这些描述符被传递给一系列卷积层,最终生成一个与输入特征图通道数相同的注意力图。此注意力图被乘以原始特征图,从而实现对不同通道特征的选择性放大或减弱。

空间注意力

空间注意力则关注于“哪里”的问题,即特征图中的哪些位置应该被强调或抑制。它通过分析特征图的局部模式来生成一个二维的注意力图。该过程首先计算特征图中每个位置的平均值和最大值,然后将这两个统计量连接起来并传递给一个卷积层,生成一个单一的二维注意力图。这个注意力图同样被乘以原始特征图,使得模型能够聚焦于关键区域。

CBAM模块的工作流程

CBAM模块依次应用通道注意力和空间注意力机制。具体来说:

  1. 通道注意力:对输入特征图进行平均池化和最大池化,生成两种全局描述符,接着通过卷积层计算出通道注意力图。
  2. 空间注意力:将经过通道注意力处理后的特征图再进行平均和最大值计算,生成两种空间描述符,然后通过卷积层计算出空间注意力图。

最后,通道注意力图和空间注意力图分别与输入特征图相乘,以实现特征的精炼。CBAM模块轻量化且通用,可以无缝集成到任何卷积神经网络架构中,并且整个模块可以与基础CNN一起端到端训练。

实验结果

通过在ImageNet-1K、MS COCO检测和VOC 2007检测数据集上的广泛实验,验证了CBAM的有效性。实验结果显示,CBAM能够显著提高各种基准模型的分类和检测性能,而且这种性能提升不是因为增加了大量的额外参数,而是因为更有效地利用了现有的特征。此外,使用轻量级骨干网络时,CBAM也显示出了良好的适用性,这意味着它对于资源受限的设备同样有益。

总之,CBAM是一种强大的注意力机制,能够帮助CNN更好地聚焦于目标对象,从而提高其识别和定位能力。

优点
  1. 通用性:CBAM可以无缝地集成到任何CNN架构中,而且不会增加显著的计算负担,这使得它能够作为一种“即插即用”的组件被广泛采用。
  2. 有效性验证:作者通过在多个数据集上的实验验证了CBAM的有效性,包括ImageNet-1K分类任务、MS COCO目标检测以及VOC 2007目标检测任务,证明了其在分类和检测任务上的一致改进。
  3. 轻量级设计:CBAM的设计考虑到了计算效率和参数数量,这意味着它可以被轻易地添加到现有的深度学习模型中而不增加过多的开销。
  4. 注意力机制的分解:CBAM将注意力机制分解为通道注意力和空间注意力两个独立的部分,这种方法简化了注意力的学习过程,并且降低了计算复杂度和参数数量。
  5. 性能提升:通过在不同模型上实现CBAM,作者表明这种机制不仅可以提高分类准确率,还可以提高目标检测任务的性能,甚至在某些情况下达到了最先进的水平。

二、准备工作

首先在YOLOv5/v7的models文件夹下新建文件moreattention.py,导入如下代码

from models.common import *


# CBAM注意力机制 https://arxiv.org/pdf/1807.06521
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(CBAM, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1) if act else nn.Identity()
        self.ca = ChannelAttention(c2)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x

    def fuseforward(self, x):
        return self.act(self.conv(x))

其次在在YOLOv5/v7项目文件下的models/yolo.py中在文件首部添加代码

from models.moreattention import CBAM

并搜索def parse_model(d, ch)

定位到如下行添加以下代码

                 CBAM,

三、YOLOv7-tiny改进工作

完成二后,在YOLOv7项目文件下的models文件夹下创建新的文件yolov7-tiny-cbam.yaml,导入如下代码。

# parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# yolov7-tiny backbone
backbone:
  # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2

   [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4

   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 7

   [-1, 1, MP, []],  # 8-P3/8
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 14

   [-1, 1, MP, []],  # 15-P4/16
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 21

   [-1, 1, MP, []],  # 22-P5/32
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 28
  ]

# yolov7-tiny head
head:
  [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, SP, [5]],
   [-2, 1, SP, [9]],
   [-3, 1, SP, [13]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -7], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 37

   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P4
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 47

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P3
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 57

   [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 47], 1, Concat, [1]],

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 65

   [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 37], 1, Concat, [1]],

   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 73
   
   [-1, 1, CBAM, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 74

   [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [74, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],

   [[75,76,77], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

                 from  n    params  module                                  arguments                     
  0                -1  1       928  models.common.Conv                      [3, 32, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
  2                -1  1      2112  models.common.Conv                      [64, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  3                -2  1      2112  models.common.Conv                      [64, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  4                -1  1      9280  models.common.Conv                      [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  5                -1  1      9280  models.common.Conv                      [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  6  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
  7                -1  1      8320  models.common.Conv                      [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  8                -1  1         0  models.common.MP                        []                            
  9                -1  1      4224  models.common.Conv                      [64, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 10                -2  1      4224  models.common.Conv                      [64, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 11                -1  1     36992  models.common.Conv                      [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 12                -1  1     36992  models.common.Conv                      [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 13  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 14                -1  1     33024  models.common.Conv                      [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 15                -1  1         0  models.common.MP                        []                            
 16                -1  1     16640  models.common.Conv                      [128, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 17                -2  1     16640  models.common.Conv                      [128, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 18                -1  1    147712  models.common.Conv                      [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 19                -1  1    147712  models.common.Conv                      [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 20  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 21                -1  1    131584  models.common.Conv                      [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 22                -1  1         0  models.common.MP                        []                            
 23                -1  1     66048  models.common.Conv                      [256, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 24                -2  1     66048  models.common.Conv                      [256, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 25                -1  1    590336  models.common.Conv                      [256, 256, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 26                -1  1    590336  models.common.Conv                      [256, 256, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 27  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 28                -1  1    525312  models.common.Conv                      [1024, 512, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 29                -1  1    131584  models.common.Conv                      [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 30                -2  1    131584  models.common.Conv                      [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 31                -1  1         0  models.common.SP                        [5]                           
 32                -2  1         0  models.common.SP                        [9]                           
 33                -3  1         0  models.common.SP                        [13]                          
 34  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 35                -1  1    262656  models.common.Conv                      [1024, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 36          [-1, -7]  1         0  models.common.Concat                    [1]                           
 37                -1  1    131584  models.common.Conv                      [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 38                -1  1     33024  models.common.Conv                      [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 39                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 40                21  1     33024  models.common.Conv                      [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 41          [-1, -2]  1         0  models.common.Concat                    [1]                           
 42                -1  1     16512  models.common.Conv                      [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 43                -2  1     16512  models.common.Conv                      [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 44                -1  1     36992  models.common.Conv                      [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 45                -1  1     36992  models.common.Conv                      [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 46  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 47                -1  1     33024  models.common.Conv                      [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 48                -1  1      8320  models.common.Conv                      [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 49                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 50                14  1      8320  models.common.Conv                      [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 51          [-1, -2]  1         0  models.common.Concat                    [1]                           
 52                -1  1      4160  models.common.Conv                      [128, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 53                -2  1      4160  models.common.Conv                      [128, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 54                -1  1      9280  models.common.Conv                      [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 55                -1  1      9280  models.common.Conv                      [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 56  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 57                -1  1      8320  models.common.Conv                      [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 58                -1  1     73984  models.common.Conv                      [64, 128, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
 59          [-1, 47]  1         0  models.common.Concat                    [1]                           
 60                -1  1     16512  models.common.Conv                      [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 61                -2  1     16512  models.common.Conv                      [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 62                -1  1     36992  models.common.Conv                      [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 63                -1  1     36992  models.common.Conv                      [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 64  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 65                -1  1     33024  models.common.Conv                      [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 66                -1  1    295424  models.common.Conv                      [128, 256, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
 67          [-1, 37]  1         0  models.common.Concat                    [1]                           
 68                -1  1     65792  models.common.Conv                      [512, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 69                -2  1     65792  models.common.Conv                      [512, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 70                -1  1    147712  models.common.Conv                      [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 71                -1  1    147712  models.common.Conv                      [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 72  [-1, -2, -3, -4]  1         0  models.common.Concat                    [1]                           
 73                -1  1    131584  models.common.Conv                      [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 74                -1  1     74338  models.moreattention.CBAM               [256, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 75                57  1     73984  models.common.Conv                      [64, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 76                65  1    295424  models.common.Conv                      [128, 256, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 77                74  1   1180672  models.common.Conv                      [256, 512, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
 78      [75, 76, 77]  1     17132  models.yolo.IDetect                     [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]

Model Summary: 277 layers, 6089326 parameters, 6089326 gradients, 13.2 GFLOPS

运行后若打印出如上文本代表改进成功。

四、YOLOv5s改进工作

完成二后,在YOLOv5项目文件下的models文件夹下创建新的文件yolov5s-cbam.yaml,导入如下代码。

# Parameters
nc: 1  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, CBAM, [1024]],# 24 (P5/32-large)+attention

   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

                 from  n    params  module                                  arguments                     
  0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]              
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                
  2                -1  1     18816  models.common.C3                        [64, 64, 1]                   
  3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
  4                -1  2    115712  models.common.C3                        [128, 128, 2]                 
  5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              
  6                -1  3    625152  models.common.C3                        [256, 256, 3]                 
  7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]              
  8                -1  1   1182720  models.common.C3                        [512, 512, 1]                 
  9                -1  1    656896  models.common.SPPF                      [512, 512, 5]                 
 10                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]              
 11                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1    361984  models.common.C3                        [512, 256, 1, False]          
 14                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              
 15                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     90880  models.common.C3                        [256, 128, 1, False]          
 18                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1    296448  models.common.C3                        [256, 256, 1, False]          
 21                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]              
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1   1182720  models.common.C3                        [512, 512, 1, False]          
 24                -1  1    296034  models.moreattention.CBAM               [512, 512]                    
 25      [17, 20, 24]  1     16182  models.yolo.Detect                      [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]

Model Summary: 284 layers, 7318360 parameters, 7318360 gradients, 16.2 GFLOPs

运行后若打印出如上文本代表改进成功。

五、YOLOv5n改进工作

完成二后,在YOLOv5项目文件下的models文件夹下创建新的文件yolov5n-cbam.yaml,导入如下代码。

# Parameters
nc: 1  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, CBAM, [1024]],# 24 (P5/32-large)+attention

   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

                 from  n    params  module                                  arguments                     
  0                -1  1      1760  models.common.Conv                      [3, 16, 6, 2, 2]              
  1                -1  1      4672  models.common.Conv                      [16, 32, 3, 2]                
  2                -1  1      4800  models.common.C3                        [32, 32, 1]                   
  3                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                
  4                -1  2     29184  models.common.C3                        [64, 64, 2]                   
  5                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
  6                -1  3    156928  models.common.C3                        [128, 128, 3]                 
  7                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              
  8                -1  1    296448  models.common.C3                        [256, 256, 1]                 
  9                -1  1    164608  models.common.SPPF                      [256, 256, 5]                 
 10                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              
 11                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1     90880  models.common.C3                        [256, 128, 1, False]          
 14                -1  1      8320  models.common.Conv                      [128, 64, 1, 1]               
 15                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     22912  models.common.C3                        [128, 64, 1, False]           
 18                -1  1     36992  models.common.Conv                      [64, 64, 3, 2]                
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1     74496  models.common.C3                        [128, 128, 1, False]          
 21                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1    296448  models.common.C3                        [256, 256, 1, False]          
 24                -1  1     74338  models.moreattention.CBAM               [256, 256]                    
 25      [17, 20, 24]  1      8118  models.yolo.Detect                      [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [64, 128, 256]]

Model Summary: 284 layers, 1839608 parameters, 1839608 gradients, 4.3 GFLOPs

运行后打印如上代码说明改进成功。

六、注意

还有一些可添加的位置,常见可分为骨干和颈部,可作用于局部或全局。

本文只是一个示例修改,实际上还可以将注意力机制添加在更多地方,另外需要注意的是

第二步中self.act = nn.LeakyReLU(0.1) if act else nn.Identity()适用于YOLOv7-tiny,若你采用的是YOLOv5或YOLOv7,则需要修改为SiLU(),即self.act = nn.SiLU() if act else nn.Identity()

注意力机制能加在哪?会在下一篇文章中具体阐述给出,敬请关注。

更多文章产出中,主打简洁和准确,欢迎关注我,共同探讨!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1983158.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大模型学习笔记 - LLM 参数高效微调技术

LLM 参数高效微调技术 LLM 高效微调技术 Adapter Tuning (2019 google,在Transformer 结构中嵌入adapter,tuning on BERT)Prefix Tuning (2021.01 斯坦福,通过virtual token构造连续型隐式prompt)Prompt Tuning(2021.04,google)P…

CSS学习 02 利用鼠标悬停制造按钮边框的渐变方向变化

效果 页面背景为深灰色,使用Karla字体。容器内的按钮居中显示,按钮有一个彩色渐变的边框。按钮的背景为黑色,文字为浅灰色。当鼠标悬停在按钮边框上时,边框的渐变方向变化,按钮文字变为白色,并且按钮内边距…

类和对象——日期类实现

目录 前言: 一、构造 二、运算符重载 三、与-- 四、实现与 操作 4.1 实现 操作 4..2 实现操作 4.3 效率分析 五、实现 - 与- 操作符 六、日期减日期 七、流插入和流提取 八、完整代码 Date.h Date.cpp 前言: 我们之前已经学习过了类和对象&…

交换机-端口安全

端口安全 1 概述 端口安全(Port Security)通过将接口学习到的动态MAC地址转换为安全MAC地址(包括安全动态MAC、安全静态MAC和Sticky MAC),阻止非法用户通过本接口和交换机通信,从而增强设备的安全性。 2 …

力扣热题100_二叉树_94_二叉树的中序遍历

文章目录 题目链接解题思路解题代码 题目链接 94. 二叉树的中序遍历 给定一个二叉树的根节点 root ,返回 它的 中序 遍历 。 示例 1: 输入:root [1,null,2,3] 输出:[1,3,2] 示例 2: 输入:root [] 输…

全面提升PDF编辑效率,2024年五大顶级PDF编辑器推荐!

在这个数字化飞速发展的时代,PDF文件已经成为我们日常工作和学习中不可或缺的一部分。然而,面对PDF文件的编辑和管理,许多人仍然感到困惑和无助。今天,就让我们一起探索几款高效、易用的PDF编辑器,它们将彻底改变你的工…

抖音短视频素材网站有哪些?做抖音的几个素材站分享

随着短视频平台尤其是抖音的飞速发展,更多的内容创作者正加入到这一潮流中,无论是生活分享、才艺展示还是知识传播,一个吸引人的视频都离不开优质的素材。本文将为追求流行的抖音视频创作者介绍几个寻找热门素材的优选平台,帮助您…

c语言-编译预处理

9 编译预处理 预处理命令 一、预处理命令的作用 -- 在编译之前,需要执行的命令 -- 编译预处理是在对源程序正式编译之前的处理,如inlcude先替换后计算 二、预处理命令 1、#include -- 包含包括 (1) #include&l…

考研数学|张宇18讲太多太难,强化要不要换武老?

面对张宇18讲的难度,考虑是否更换辅导资料是一个值得深思的问题。首先,要认识到每个人的学习风格和需求是不同的。 张宇18讲可能因其深度和广度而受到一些考生的青睐,但同时也可能因其难度让部分考生感到挑战重重。 在决定是否更换辅导资料…

如何优雅的使用枚举类型,可以这样做!

使用枚举有时候会给我们带来了一些困扰: 前端展示数据,需要将枚举转成用户可读的数据显示。 代码中的枚举类型要存储数据库得转成数值类型。 那么本文介绍一种自动转换方案,大大方便前后端使用枚举类型 我们以用户状态为例,用…

创建型模式(Creational Patterns)之工厂模式(Factory Pattern)之工厂方法模式(Factory Method Pattern)

1. 工厂方法模式(Factory Method Pattern) 将对象的创建延迟到子类中实现,每个具体产品类都有一个对应的工厂类负责创建,从而使得系统更加灵活。客户端可以通过调用工厂方法来创建所需的产品,而不必关心具体的实现细节。这种模式符合开放-封闭…

客户数据分析模型:RFM模型的深度解析与应用探索

RFM模型,作为客户数据分析中的经典工具,凭借其简单而强大的分析能力,被广泛应用于各行各业。本文旨在深入探讨RFM模型的核心原理、应用价值,并详细阐述其在2C(面向消费者)和2B(面向企业&#xf…

Lumerical 光束重叠率计算

Lumerical 光束重叠率计算 引言正文不同位置处 FDE Solver 对仿真结果的影响FDE Solver 不同尺寸下的光束重叠率计算引言 在 Lumerical 光纤模式仿真 一文中我们介绍了如何进行光纤模式的仿真。本文,我们将继续使用 SMF28 来进行光束重叠率计算说明。 正文 不同位置处 FDE …

大数据技术原理-spark的安装

摘要 本实验报告详细记录了在"大数据技术原理"课程中进行的Spark安装与应用实验。实验环境包括Spark、Hadoop和Java。实验内容涵盖了Spark的安装、配置、启动,以及使用Spark进行基本的数据操作,如读取本地文件、文件内容计数、模式匹配和行数…

【未授权访问漏洞复现~~~】

一: Redis未授权访问漏洞 步骤一:进入vulhub目录使用以下命令启动靶机… 进入目录:cd /vulhub-master/redis/4-unacc 启动:docker-compose up-d 检查:docker-compose ps步骤二:在Kali上安装redis程序进行服务的链接. #安装redis apt-get install redis #redis链接 redis-cli…

【Linux进程篇】并发艺术:Linux条件变量与生产消费者模型的完美融合

W...Y的主页 😊 代码仓库分享💕 前言:我们上篇博客中提到了线程互斥、引出了互斥锁解决了线程中的未将临界资源保护的问题。但是随之出来的问题——竞争锁是自由竞争的,竞争锁的能力太强的线程会导致其他线程抢不到票&#xff0…

Linux——线程互斥与同步

一、线程互斥 1.1 线程间互斥的概念 在学习管道的时候,管道是自带同步与互斥的。而在线程中,当多个线程没有加锁的情况下同时访问临界资源时会发生混乱。在举例之前,先了解几个概念。 临界资源:多个线程执行流共享的资源叫做临…

软甲测试定义和分类

软件测试定义 使用人工和自动手段来运行或测试某个系统的过程,其目的在于检验他是否满足规定的需求或弄清预期结果与实际结果之间的差别 软件测试目的 为了发现程序存在的代码或业务逻辑错误 – 第一优先级发现错误为了检验产品是否符合用户需求 – 跟用户要求实…

WPF学习(3)- WrapPanel控件(瀑布流布局)+DockPanel控件(停靠布局)

WrapPanel控件(瀑布流布局) WrapPanel控件表示将其子控件从左到右的顺序排列,如果第一行显示不了,则自动换至第二行,继续显示剩余的子控件。我们来看看它的结构定义: public class WrapPanel : Panel {pub…

【网页设计】基于HTML+CSS上海旅游网站网页作业制作

一、👨‍🎓网站题目 旅游,当地特色,历史文化,特色小吃等网站的设计与制作。 二、✍️网站描述 👨‍🎓静态网站的编写主要是用HTML DIVCSS 等来完成页面的排版设计👩‍&#x1f39…