YOLOV7 添加 CBAM 注意力机制

news2024/10/5 14:17:30

用于学习记录

文章目录

  • 前言
  • 一、CBAM
    • 1.1 models/common.py
    • 1.2 models/yolo.py
    • 1.3 yolov7/cfg/training/CBAM.yaml
    • 2.4 CBAM 训练结果图


前言


一、CBAM

CBAM: Convolutional Block Attention Module

1.1 models/common.py

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)
        
class CBAM(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(CBAM, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()
        self.ca = ChannelAttention(c2)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x

    def fuseforward(self, x):
        return self.act(self.conv(x))

1.2 models/yolo.py

搜索 if m in 添加以下代码 CBAM

        if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC, 
                 SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv, 
                 Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC, 
                 RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,  
                 Res, ResCSPA, ResCSPB, ResCSPC, 
                 RepRes, RepResCSPA, RepResCSPB, RepResCSPC, 
                 ResX, ResXCSPA, ResXCSPB, ResXCSPC, 
                 RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC, 
                 Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
                 SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
                 SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC, C3, CBAM]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [DownC, SPPCSPC, GhostSPPCSPC, 
                     BottleneckCSPA, BottleneckCSPB, BottleneckCSPC, 
                     RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC, 
                     ResCSPA, ResCSPB, ResCSPC, 
                     RepResCSPA, RepResCSPB, RepResCSPC, 
                     ResXCSPA, ResXCSPB, ResXCSPC, 
                     RepResXCSPA, RepResXCSPB, RepResXCSPC,
                     GhostCSPA, GhostCSPB, GhostCSPC,
                     STCSPA, STCSPB, STCSPC,
                     ST2CSPA, ST2CSPB, ST2CSPC, C3]:
                args.insert(2, n)  # number of repeats
                n = 1

1.3 yolov7/cfg/training/CBAM.yaml

# parameters
nc: 60  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

backbone:
  # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  # [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2 
  [[-1, 1, CBAM, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2  
  
  #  [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4  
   [-1, 1, CBAM, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4    
   
   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 7
   
   [-1, 1, MP, []],  # 8-P3/8
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 14
   
   [-1, 1, MP, []],  # 15-P4/16
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 21
   
   [-1, 1, MP, []],  # 22-P5/32
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 28
  ]

# yolov7-tiny head
head:
  [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, SP, [5]],
   [-2, 1, SP, [9]],
   [-3, 1, SP, [13]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -7], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 37
  
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P4
   [[-1, -2], 1, Concat, [1]],
   
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 47
  
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P3
   [[-1, -2], 1, Concat, [1]],
   
   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 57
   
   [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 47], 1, Concat, [1]],
   
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 65
   
   [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 37], 1, Concat, [1]],
   
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 73
      
   [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [73, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],

   [[74,75,76], 1, Detect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

2.4 CBAM 训练结果图

在这里插入图片描述
在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/954942.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3升级了些什么

Vue 3 升级了以下几个方面的内容: 响应式系统:Vue 3 使用了 Proxy 对象来替代 Vue 2 中的 Object.defineProperty,这使得响应式系统更加高效和灵活。Vue 3 的响应式系统可以追踪更细粒度的依赖关系,提供了更好的性能和更细致的响应…

深兰科技荣膺“2023人工智能行业领航企业奖”

8月28日,由高科技行业门户OFweek维科网主办,OFweek物联网、OFweek人工智能承办的“维科杯OFweek 2023(第八届)物联网与人工智能行业年度评选(OFweek 8th IoT & AI Awards 2023)”在深圳福田会展中心成功举行。 深兰科技凭借在自动驾驶及新能源…

Apollo安装与配置使用

介绍 Apollo(阿波罗)是一款可靠的分布式配置管理中心,诞生于携程框架研发部,能够集中化管理应用不同环境、不同集群的配置,配置修改后能够实时推送到应用端,并且具备规范的权限、流程治理等特性&#xff0…

外贸新手必看的寄样品攻略!别再盲目踩雷了!

样品,被看做是订单前的一个敲门砖,寄样品这一步如果处理的不好,同样会对最终谈判结果产生较大影响。因此外贸新手在寄样品前,也需要对具体流程和注意事项做一个了解,以避免在这个过程中,造成无法挽回的后果…

西北大学计算机考研844经验分享(初试科目844-笔记课件整理)

西北大学计算机考研844经验分享 个人介绍 ​ 本人是西北大学22级软件工程研究生,考研专业课129分,过去一年里在各大辅导机构任职,辅导考研学生专业课844,辅导总时长达288小时,帮助多名学生专业课高分上岸。 前情回顾…

23款奔驰GLS450豪华型升级原厂电动吸合门,体验绅士的关门状态

电吸门的工作原理是在门框(或门板边缘)上安装一个电磁线圈。当门打开时,电流会流过线圈,形成电磁场。这样,由于磁力的作用,当门靠近门框关闭时,门会自动关闭。 另外,电吸门也有有用的一面。如果下车&#…

【考研数学】概率论与数理统计 —— 第二章 | 一维随机变量及其分布(1,基本概念与随机变量常见类型)

文章目录 引言一、一维随机变量及其分布1.1 随机变量1.2 分布函数 二、随机变量常见类型及分布2.1 离散型随机变量2.2 连续型随机变量及概率密度函数 写在最后 引言 暑假接近尾声了,争取赶一点概率论部分的进度。 一、一维随机变量及其分布 1.1 随机变量 设随机试…

Linux查看目录下文件及其大小

ls -lh在Linux下,"ls -lh"是一个用于显示文件和目录详细信息的命令。它会列出当前目录中的文件和目录,并显示它们的文件大小和权限等详细信息。 其中,参数"-l"是用来显示详细信息的选项,"h"表示以…

gif怎么转换成mp4格式视频

gif怎么转换成mp4格式视频?GIF格式是一种广泛应用的公用图像文件格式标准,具有许多优势。它占用的内存较小,可以实现自动循环播放,并且兼容多个平台。然而,GIF格式也存在一些缺点。例如,它无法处理复杂的图…

恒运资本:三大指数震荡走低,地产股大幅回撤,光刻胶概念逆市上涨

周四(8月31日),到上午收盘,A股三大指数震动走低。其间,上证指数跌0.53%,报3120.39点;深证成指和创业板指别离跌0.55%、0.54%;沪深两市算计成交额5290.51亿元,总体来看&am…

如何在小红书进行学习直播

诸神缄默不语-个人CSDN博文目录 因为我是从B站开始的,所以一些直播常识型的东西请见我之前写的如何在B站进行学习直播这一篇。 本篇主要介绍一些小红书之与B站不同之处。 小红书在手机端是可以直接点击“”选择直播的。 文章目录 1. 电脑直播-小红书直播软件2. 电…

【LeetCode算法系列题解】第11~15题

CONTENTS LeetCode 11. 盛最多水的容器(中等)LeetCode 12. 整数转罗马数字(中等)LeetCode 13. 罗马数字转整数(简单)LeetCode 14. 最长公共前缀(简单)LeetCode 15. 三数之和&#xf…

Unity碰撞检测(3D和2D)

Unity碰撞检测3D和2D 前言准备材料3D2D 代码3D使用OnCollisionEnter()进行碰撞Collider状态代码 使用OnTriggerEnter()进行碰撞Collider状态代码 2D使用OnCollisionEnter2D()进行碰撞Collider2D状态代码 使用OnTriggerEnter2D()进行碰撞Collider2D状态代码 区别3D代码OnCollisi…

Python中*(星号)传可变长的元组

在Python中,*(星号)可以用来传递变长元组参数,通常在函数定义和函数调用中使用。这是一种用于处理不定数量的参数的方式,使得函数能够接受任意数量的位置参数。 在函数定义中使用 *: 在函数定义时&#xff…

Jupyter lab 配置

切换jupyterlab的默认工作目录 在终端中输入以下命令 PS C:\Users\Administrator> jupyter-lab --generate-config Writing default config to: C:\Users\Administrator\.jupyter\jupyter_lab_config.py它就会生成JupyterLab的配置文件(如果之前有这个文件的话…

奥威BI数据可视化工具一出马,财务数据分析不再烧脑

数据可视化工具可以使财务数据分析更加直观和易于理解。这些工具可以将大量的财务数据简化为易于阅读和理解的图表、图形和表格,帮助财务人员更快地分析和发现问题。例如,通过将财务数据转化为柱状图、折线图、饼图等图形,可以更加清晰地展示…

《向量数据库指南》——腾讯云向量数据库(Tencent Cloud VectorDB) SDK 正式开源

腾讯云向量数据库 SDK 宣布正式开源。根据介绍,腾讯云向量数据库(Tencent Cloud VectorDB)的 Python SDK 与 Java SDK 是基于数据库设计模型,遵循 HTTP 协议,将 API 封装成易于使用的 Python 与 Java 函数或类,为开发者提供了更加友好、更加便捷的数据库使用和管理方式。…

刚来的00后进来就有18K,我三年工作经验就是个笑话

00后带来的压力 公司一位工作3年的老油条工资还没有刚来的00后高,她心中不平,对这件事情有不小的怨气,她觉得自己来公司三年了,三年内迟到次数都不超过5次,每天勤勤恳恳,要加班的时候也愿意加班&#xff0…

vue自定义事件 div 拖拽方法缩小

在main.js 引用 // 引入拖动js import dragMove from "./utils/dragMove.js" 创建 drawmove.js export default (app) > {app.directive(dragMove, (el, binding) > {const DragVindow el.querySelector(binding.value.DragVindow)// 按下鼠标处理事件con…

修改yum下载文件的位置,指定安装位置

yum update 的软件包,可以放在别的地方。即可。 修改/etc/yum.conf 指定安装位置 yum -c /etc/yum.conf --installroot/usr/local --releasever/ install 你需要安装的软件