YOLOv8改进 - 注意力篇 - 引入CBAM注意力机制

news2024/9/20 6:40:18

一、本文介绍

作为入门性第一篇,这里介绍了CBAM注意力在YOLOv8中的使用。包含CBAM原理分析,CBAM的代码、CBAM的使用方法、以及添加以后的yaml文件及运行记录。

二、CBAM原理分析

CBAM官方论文地址:CBAM论文

CBAM的pytorch版代码:CBAM的pytorch版代码

CBAM:卷积块注意力模块,由通道注意力和空间注意力组成。其中通道注意力机制主要检测目标的内容信息,空间注意力主要检测目标位置信息。模块先应用通道注意力,再利用空间注意力;其原理结构如下图所示。

相关代码:

在YOLOv8中,作者已经集成了cbam注意力的代码,仅未应用。

class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, channels: int) -> None:
        """Initializes the class and sets the basic configurations and instance variables required."""
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    """Convolutional Block Attention Module."""

    def __init__(self, c1, kernel_size=7):
        """Initialize CBAM with given input channel (c1) and kernel size."""
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))

四、YOLOv8中CBAM使用方法

YOLOv8中CBAM模块,作者存于ultralytics/nn/modules/conv.py中。

我们使用CBAM模块,仅需在ultralytics/nn/tasks.py进行CBAM注意力机制的注册,以及在YOLOv8的yaml配置文件中添加CBAM即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:

        elif m in {CBAM,MHSA,SEAttention,ECA,ShuffleAttention,ECA_SA,SE_SA,SA_ECA,SA_ShuffleAttention,CBAM_base,PECA_SA,CPAM, CPAM_SA, MSCA_SA, SKAttention, DoubleAttention,PCPAM_SA,SA_CPAM,CoordAtt}:#自己加的注意力模块
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

然后,就是新建一个名为YOLOv8_CBAM.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_CBAM.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, CBAM, [1024,7]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLO


# 加载一个模型
model = YOLO('ultralytics/cfg/models/v8/YOLOv8_CBAM.yaml')  # 从YAML建立一个新模型
# 训练模型
results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:

五、总结

以上就是CBAM的原理及使用方式,但具体CBAM注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2143252.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Gateway网关的实现

API网关 网关路由必须支持负载均衡,服务列表是从注册中心拉取的客户端发出请求的URL指向的是网关,URL还必须要包含目标信息网关收到URL,通过一定的规则,要能识别出交给哪个实例去处理网关有能力对请求响应进行修改 引入依赖包 …

图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)

DFS 基础 BFS 基础 Leetcode 815. 公交路线 思路&#xff1a; class Solution { public:int numBusesToDestination(vector<vector<int>>& routes, int source, int target) {// 记录经过车站x的公交车编号 hashunordered_map<int, vector<int>> …

frp内网穿透功能使用教程

frp 是一款高性能的反向代理应用&#xff0c;专注于内网穿透。它支持多种协议&#xff0c;包括 TCP、UDP、HTTP、HTTPS 等&#xff0c;并且具备 P2P 通信功能。使用 frp&#xff0c;您可以安全、便捷地将内网服务暴露到公网&#xff0c;通过拥有公网 IP 的节点进行中转。 文档地…

Python3网络爬虫开发实战(15)Scrapy 框架的使用(第一版)

文章目录 一、Scrapy 框架介绍1.1 数据流1.2 项目结构1.3 Scrapy 入门 二、Selector 解析器2.1 XPath 和 CSS 选择器2.2 信息提取2.3 正则提取 三、Spider 的使用3.1 Spider 运行流程3.2 Spider 类分析3.3 Request3.4 Response 四、Download Middleware 的使用4.1 process_requ…

55.【C语言】字符函数和字符串函数(strstr函数)

11.strstr函数 *简单使用 strstr: string string cplusplus的介绍 点我跳转 翻译: 函数 strstr const char * strstr ( const char * str1, const char * str2 ); 或另一个版本char * strstr ( char * str1, const char * str2 ); 寻找子字符串 返回指向第一次出现在字…

PostMan使用变量

环境变量 使用场景 当测试过程中&#xff0c;我们需要对开发环境、测试环境、生产环境进行测试 不同的环境对应着不同的服务器&#xff0c;那么这个时候我们就可以使用环境变量来区分它们 避免切换测试环境后&#xff0c;需要大量的更改接口的url地址 全局变量 使用场景 当…

chapter2-站点首页功能实现

1. 导航功能实现 导航数据的存储&#xff0c;必须要找出导航数据的结构&#xff0c;也就是有哪些属性&#xff1f; 导航位置、导航名称、导航链接、导航序号、是否显示、是否外链、添加时间、更新时间、是否删除 E-R图&#xff0c;如下&#xff1a; 1.1 创建模型 home/model…

Ardupilot开源飞控之VTOL之旅:且败且战

Ardupilot开源飞控之VTOL之旅&#xff1a;且败且战 1. 源由2. 希望3. 打印件3.1 Back_cover_AKK_Race_Ranger_VTX_SMA3.2 CRSF-Ant_mount_TPU3.3 H743_mount 4. 其他问题5. 总结6. 参考资料 1. 源由 折腾了这么久&#xff0c;可是VTOL就是一直没有装起来&#xff0c;主要问题还…

vs2022快捷键异常解决办法

安装了新版本的vs2022&#xff0c;安装成功后&#xff0c;发现快捷键发生异常&#xff0c;之前常用的快捷键要么发生改变&#xff0c;要么无法使用&#xff0c;比如原来注释代码的快捷键是ctrlec&#xff0c;最新安装版本变成了ctrlkc&#xff0c;以前编译代码的快捷键是F6或者…

【delphi】正则判断windows完整合法文件名,包括路径

在 Delphi 中&#xff0c;可以使用正则表达式来检查 Windows 文件名称或路径是否合法。合法的文件名和路径要求符合以下几点&#xff1a; 禁止的字符&#xff1a;文件名和路径不能包含以下字符&#xff1a;<, >, :, ", /, \, |, ?, *。文件名不能以空格或点结束。…

一文讲懂Mac中的环境变量

你是否曾经因为环境变量配置不当而浪费了宝贵的开发时间?你是否好奇为什么有时候在终端输入命令会提示"command not found",而有时候又能正常运行?如果你是一名Mac用户,并且希望真正掌握环境变量的奥秘,那么这篇文章将为你揭开Mac中环境变量的神秘面纱,帮助你成为一…

BFS 解决边权为1的最短路问题

边权为1的最短路问题 最短路问题&#xff1a; 比如说从D->K&#xff0c;找出最短的那条&#xff0c;其中每条路都是有权值&#xff0c;此篇主要讲解的边权为1的最短路问题。 即边权都是一样的。 解法就是从起点开始&#xff0c;做一次BFS&#xff1a; 需要一个队列、一个…

深入理解IP地址分类及子网划分详解

在互联网时代&#xff0c;IP地址是网络通信的基础。无论是访问网站、发送电子邮件&#xff0c;还是进行数据传输&#xff0c;IP地址都扮演着至关重要的角色。本文将详细解析IP地址的分类及子网划分的原理&#xff0c;帮助你更好地理解网络架构及其应用。 一、什么是IP地址 IP…

通信工程学习:什么是TDMA时分多址

TDMA时分多址 TDMA&#xff08;Time Division Multiple Access&#xff0c;时分多址&#xff09;是一种在无线通信中广泛使用的多址接入技术。它通过将时间划分为不重叠的时间帧&#xff0c;并将每个时间帧进一步划分为多个时隙&#xff0c;每个时隙分配给不同的用户或通信系统…

8.JMeter+Ant(基于工具的实现接口自动化,命令行方式)

一、JMeterAnt&#xff08;基于工具的实现接口自动化&#xff09; 如果想要实现自动化&#xff0c;就必须使用命令行。 1.jmeter命令 -n 使用非界面的方式去执行脚本 -t 指定jmeter的脚本位置 -l 生成jtl报告&#xff0c;可以通过查看结果树来解析 -e 生产html格式的报告 -o …

p14 使用阿里云服务器的docker部署NGINX

拉取NGINX的镜像 这里因为之前已经配置过从阿里云的镜像仓库里面拿镜像所以这里直接就执行docker pull nginx拉取NGINX镜像就OK了 运行NGINX镜像 这里执行docker run -d --name nginx01 -p 3344:80 nginx这里3344是服务器访问的端口80是容器内部的端口&#xff0c;可以看到…

Flask-JWT-Extended登录验证, 不用自定义

"""安装:pip install Flask-JWT-Extended创建对象 初始化与app绑定jwt JWTManager(app) # 初始化JWTManager设置 Cookie 的选项:除了设置 cookie 的名称和值之外&#xff0c;你还可以指定其他的选项&#xff0c;例如&#xff1a;过期时间 (max_age)&#xff1…

【贪心】【数据结构-小根堆,差分】力扣2406. 将区间分为最少组数

给你一个二维整数数组 intervals &#xff0c;其中 intervals[i] [lefti, righti] 表示 闭 区间 [lefti, righti] 。 你需要将 intervals 划分为一个或者多个区间 组 &#xff0c;每个区间 只 属于一个组&#xff0c;且同一个组中任意两个区间 不相交 。 请你返回 最少 需要…

vue3 ref的用法及click事件的说明

1、ref可以定义一个简单的属性&#xff0c;也可以是一个复杂的列表、数组等等。 2、为什么要使用 ref&#xff1f;简单的let个变量不行吗&#xff1f;const个变量不行吗&#xff1f; 其实这个跟vue的响应式的系统有关&#xff0c;官方的说明如下&#xff1a; 3、为 ref() 标注…

VMWare中的Centos8:Errors during downloading metadata for repository ‘appstream‘

在VMWare的环境中&#xff0c;安装和部署好Centos8&#xff0c;待设置好网络环境后&#xff0c;安装部署C开发和编译环境&#xff0c;遇到报错&#xff1a; dnf gcc gcc-c -y 解决问题的办法如下, 1. 进入仓库源文件夹&#xff1a;cd /etc/yum.repos.d/ 2. 修改镜像配置{这…