YOLOv8改进 | 注意力篇 | YOLOv8引入CBAM注意力机制

news2024/11/23 23:33:21

1.CBAM介绍

摘要:我们提出了卷积块注意力模块(CBAM),这是一种用于前馈卷积神经网络的简单而有效的注意力模块。 给定中间特征图,我们的模块沿着两个独立的维度(通道和空间)顺序推断注意力图,然后将注意力图乘以输入特征图以进行自适应特征细化。 由于 CBAM 是一个轻量级通用模块,因此它可以无缝集成到任何 CNN 架构中,且开销可以忽略不计,并且可以与基础 CNN 一起进行端到端训练。 我们通过在 ImageNet-1K、MS COCO 检测和 VOC 2007 检测数据集上进行大量实验来验证我们的 CBAM。 我们的实验表明各种模型的分类和检测性能得到了一致的改进,证明了 CBAM 的广泛适用性。 代码和模型将公开。

官方论文地址:CBAM论文 

官方代码地址:CBAM代码

简单介绍:CBAM的主要思想是通过关注重要的特征并抑制不必要的特征来增强网络的表示能力。模块首先应用通道注意力,关注"重要的"特征,然后应用空间注意力,关注这些特征的"重要位置"。通过这种方式,CBAM有效地帮助网络聚焦于图像中的关键信息,提高了特征的表示力度,下图为其原理结构图。

2.核心代码

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, channels: int) -> None:
        """Initializes the class and sets the basic configurations and instance variables required."""
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    """Convolutional Block Attention Module."""

    def __init__(self, c1, kernel_size=7):
        """Initialize CBAM with given input channel (c1) and kernel size."""
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))

3.YOLOv8中添加CBAM方式  

3.1 在ultralytics/nn下新建Extramodule

3.2 在Extramodule里创建CBAM

在CBAM.py文件里添加给出的CBAM代码

添加完CBAM代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在task.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在task.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {CBAM}:
            c2 = ch[f]
            args = [c2, *args]

4.新建一个yolov8CBAM.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 2 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
  - [-1, 1, CBAM, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
  - [-1, 1, CBAM, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
  - [-1, 1, CBAM, []]

  - [[15, 19, 24], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5.模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'E:\csdn\ultralytics-main\datasets\yolov8CBAM.yaml')
    model.train(data=r'E:\csdn\ultralytics-main\datasets\data.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # 是否是单类别检测
                batch=16,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )

模型结构打印,成功运行 :

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv8有效涨点专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2087280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python | Leetcode Python题解之第384题打乱数组

题目: 题解: class Solution:def __init__(self, nums: List[int]):self.nums numsself.original nums.copy()def reset(self) -> List[int]:self.nums self.original.copy()return self.numsdef shuffle(self) -> List[int]:for i in range(l…

C++ | Leetcode C++题解之第383题赎金信

题目&#xff1a; 题解&#xff1a; class Solution { public:bool canConstruct(string ransomNote, string magazine) {if (ransomNote.size() > magazine.size()) {return false;}vector<int> cnt(26);for (auto & c : magazine) {cnt[c - a];}for (auto &am…

群晖(Docker Compose)配置 frp 服务

为了方便远程电脑&#xff0c;访问自己电脑上的ComfyUI等服务&#xff0c;配置了 frp 服务。 配置 frp 服务后&#xff0c;发现群晖中的一些服务也可以 stcp 安全的暴露出来。 直接在群晖通过 Docker Compose 方式部署 frps 和 frpc&#xff0c;访问者通过 frpc 安全访问暴露…

计算机三级网络第3套练习记背

计算机三级网络第3套练习记背

【C++ | 设计模式】抽象工厂模式的详解与实现

1. 概念 抽象工厂模式&#xff08;Abstract Factory Pattern&#xff09;是一种创建型设计模式&#xff0c;用于创建一系列相关或相互依赖的对象&#xff0c;而无需指定它们具体的类。它允许客户端代码通过工厂接口来创建一组对象&#xff0c;而无需了解它们的具体实现细节。 …

从暴力到秩序:解锁权力奥秘

从暴力到秩序&#xff1a;解锁权力奥秘 - 孔乙己大叔权力的诞生 在人类社会的最初形态中&#xff0c;权力往往源自最原始的力量——暴力。一个人&#xff0c;起初仅拥有一把枪&#xff0c;他的权力简单而直接&#xff1a;决定对谁开枪。然而&#xff0c;随着他利用这把…

【58同城-注册安全分析报告】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞 …

【Scala】Windows下安装Scala(全面)

目录 1.下载 2.安装 3.配置环境变量 1.新增系统环境变量 2.环境变量Path 4.验证 1.下载 官网下载地址&#xff1a;https://downloads.lightbend.com/scala/2.11.12/scala-2.11.12.msi 2.安装 双击下载的.msi文件&#xff1a; 勾选"I accept the terms in the Li…

Flink 1.14.* Flink窗口创建和窗口计算源码

解析Flink如何创建的窗口&#xff0c;和以聚合函数为例&#xff0c;窗口如何计算聚合函数 一、构建不同窗口的build类1、全局窗口2、创建按键分流后的窗口 二、在使用窗口处理数据流时&#xff0c;不同窗口创建的都是窗口算子WindowOperator1、聚合函数实现2、创建全局窗口(入参…

智能合约开发与测试1

智能合约开发与测试 任务一&#xff1a;智能合约设计 &#xff08;1&#xff09;编写新能源智能合约功能需求文档。 区块链新能源管理智能合约功能需求包括资产与能源绑定、用户管理、能源交易、智能结算等&#xff0c;确保安全性、隐私保护和可扩展性&#xff0c;提高能源利…

2024年第六届控制与机器人国际会议(ICCR 2024)即将召开!

2024年第六届控制与机器人国际会议&#xff08;ICCR 2024&#xff09;将于2024年12月5日至7日在日本横滨举行。智能机器人结合了多种概念、学科和技术&#xff0c;共同创造出各种有用的设备、操作器和自主实体&#xff0c;为特定人类社区服务&#xff0c;如制造设备、医疗和远程…

【练习】哈希表的使用

&#x1f3a5; 个人主页&#xff1a;Dikz12&#x1f525;个人专栏&#xff1a;算法(Java)&#x1f4d5;格言&#xff1a;吾愚多不敏&#xff0c;而愿加学欢迎大家&#x1f44d;点赞✍评论⭐收藏 目录 1.哈希表简介 2.两数之和 题目描述 题解 代码实现 2.面试题.判定是否…

代码随想录Day 28|题目:122.买卖股票的最佳时机Ⅱ、55.跳跃游戏、45.跳跃游戏Ⅱ、1005.K次取反后最大化的数组和

提示&#xff1a;DDU&#xff0c;供自己复习使用。欢迎大家前来讨论~ 文章目录 题目题目一&#xff1a;122.买卖股票的最佳时机 II贪心算法&#xff1a;动态规划 题目二&#xff1a;55.跳跃游戏解题思路&#xff1a; 题目三&#xff1a; 45.跳跃游戏 II解题思路方法一方法二 题…

在Centos中的mysql的备份与恢复

1.物理备份 冷备份&#xff1a;关闭数据库时进行热备份&#xff1a;数据库运行时进行&#xff0c;依赖于数据库日志文件温备份&#xff1a;数据库不可写入但可读的状态下进行 2.逻辑备份 对数据库的表或者对象进行备份 3.备份策略 完全备份&#xff1a;每次都备份完整的数…

每日OJ_牛客_Rational Arithmetic(英文题模拟有理数运算)

目录 牛客_Rational Arithmetic&#xff08;英文题模拟有理数运算&#xff09; 解析代码 牛客_Rational Arithmetic&#xff08;英文题模拟有理数运算&#xff09; Rational Arithmetic (20)__牛客网 解析代码 本题看上去不难&#xff0c;但是存在几个问题&#xff1a; 除…

【C++】汇编分析

传参 有的是用寄存器传参&#xff0c;有的用push传参 我在MSVC编译测出来的是PUSH传参&#xff08;debug模式&#xff09;&#xff0c;具体过程如下 long func(long a, long b, long c, long d,long e, long f, long g, long h) {long sum;sum (a b c d e f g h);ret…

《机器学习》文本数据分析之关键词提取、TF-IDF、项目实现 <上>

目录 一、如何进行关键词提取 1、关键词提取步骤 1&#xff09;数据收集 2&#xff09;数据准备 3&#xff09;模型建立 4&#xff09;模型结果统计 5&#xff09;TF-IDF分析 2、什么是语料库 3、如何进行中文分词 1&#xff09;导包 2&#xff09;导入分词库 3&#xff09…

智能优化特征选择|基于鲸鱼WOA优化算法实现的特征选择研究Matlab程序(SVM分类器)

智能优化特征选择|基于鲸鱼WOA优化算法实现的特征选择研究Matlab程序&#xff08;SVM分类器&#xff09; 文章目录 一、基本原理鲸鱼智能优化特征选择&#xff08;WOA&#xff09;结合SVM分类器的详细原理和流程原理流程 二、实验结果三、核心代码四、代码获取五、总结 智能优化…

js | XMLHttpRequest

是什么&#xff1f; 和serve交互数据的对象&#xff1b;能够达到页面部分刷新的效果&#xff0c;也就是获取数据之后&#xff0c;不会使得整个页面都刷新&#xff1b;虽然名字是XML&#xff0c;但不限于XML数据。 怎么用&#xff1f; function reqListener() {console.log(thi…

理解数据库系统的内部结构

数据库系统在我们的数字世界中扮演着关键角色。本文将介绍数据库系统的内部结构&#xff0c;帮助初学者了解其基本概念。 数据库系统的三级模式 数据库系统内部采用三级模式二级映像结构&#xff0c;包括外模式、模式和内模式。这种结构确保了数据的逻辑独立性和物理独立性。…