YOLOv8改进算法之添加CA注意力机制

news2024/11/27 15:32:23

1. CA注意力机制

CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。如下图:

1. 输入特征: CA 注意力机制的输入通常是一个特征图,它通常是卷积神经网络(CNN)中的某一层的输出,具有以下形状:[C, H, W],其中:

  • C 是通道数,表示特征图中的不同特征通道。
  • H 是高度,表示特征图的垂直维度。
  • W 是宽度,表示特征图的水平维度。

2. 全局平均池化: CA 注意力机制首先对输入特征图进行两次全局平均池化,一次在宽度方向上,一次在高度方向上。这两次操作分别得到两个特征映射:

  • 在宽度方向上的平均池化得到特征映射 [C, H, 1]
  • 在高度方向上的平均池化得到特征映射 [C, 1, W]

这两个特征映射分别捕捉了在宽度和高度方向上的全局特征。

3. 合并宽高特征: 将上述两个特征映射合并,通常通过简单的堆叠操作,得到一个新的特征层,形状为 [C, 1, H + W],其中 H + W 表示在宽度和高度两个方向上的维度合并在一起。

4. 卷积+标准化+激活函数: 对合并后的特征层进行卷积操作,通常是 1x1 卷积,以捕捉宽度和高度维度之间的关系。然后,通常会应用标准化(如批量标准化)和激活函数(如ReLU)来进一步处理特征,得到一个更加丰富的表示。

5. 再次分开: 分别从上述特征层中分离出宽度和高度方向的特征:

  • 一个分支得到特征层 [C, 1, H]
  • 另一个分支得到特征层 [C, 1, W]

6. 转置: 对分开的两个特征层进行转置操作,以恢复宽度和高度的维度,得到两个特征层分别为 [C, H, 1][C, 1, W]

7. 通道调整和 Sigmoid: 对两个分开的特征层分别应用 1x1 卷积,以调整通道数,使其适应注意力计算。然后,应用 Sigmoid 激活函数,得到在宽度和高度维度上的注意力分数。这些分数用于指示不同位置的重要性。

8. 应用注意力: 将原始输入特征图与宽度和高度方向上的注意力分数相乘,得到 CA 注意力机制的输出。

2. YOLOv8添加CA注意力机制

加入注意力机制,在ultralytics包中的nn包的modules里添加CA注意力模块,我这里选择在conv.py文件中添加CA注意力机制。

CA注意力机制代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

CA注意力机制的注册和引用如下:

 ultralytics/nn/modules/_init_.py文件中:

  ultralytics/nn/tasks.py文件夹中:

 在tasks.py中的parse_model中添加如下代码:

        elif m in {CoordAtt}:
            args=[ch[f],*args]

新建相应的yolov8s-CA.yaml文件,代码如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 8], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 15], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在main.py文件中进行训练:

if __name__ == '__main__':

    # 使用yaml配置文件来创建模型,并导入预训练权重.
    model = YOLO('ultralytics/cfg/models/v8/yolov8s-CA.yaml')
    # model.load('yolov8n.pt')
    model.train(**{'cfg': 'ultralytics/cfg/default.yaml', 'data': 'dataset/data.yaml'})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1052815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【RocketMQ】【源码】Dledger日志复制源码分析

消息存储 在 【RocketMQ】消息的存储一文中提到,Broker收到消息后会调用CommitLog的asyncPutMessage方法写入消息,在DLedger模式下使用的是DLedgerCommitLog,进入asyncPutMessages方法,主要处理逻辑如下: 调用serial…

leetCode 122.买卖股票的最佳时机 II 动态规划 + 状态转移 + 状态压缩

122. 买卖股票的最佳时机 II - 力扣(LeetCode) 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股 股票。你也可以先购买&…

006:连续跌三天,第四天上涨的概率--用python统计

我们已经可以获取到K线信息了,然后我们来进行一些统计,就统计连续三天下跌,第四天上涨的概率。 我们用宁波银行(002142)最近三年的数据来统计。先用上一篇的程序下载到K线数据,得到文件002142.csv。然后在…

Spring修炼之旅(4)静态/动态代理模式与AOP

一、代理模式概述 代理模式 为什么要学习代理模式,因为AOP的底层机制就是动态代理! 代理模式: 静态代理 动态代理 学习aop之前 , 我们要先了解一下代理模式! 1.1静态代理 静态代理角色分析 抽象角色 : 一般使用接口或者抽象…

【数据结构练习】二叉树相关oj题集锦二

目录 前言 1.平衡二叉树 2.对称二叉树 3.二叉树遍历 4.层序遍历 5.判断一棵树是不是完全二叉树 前言 编程想要学的好,刷题少不了,我们不仅要多刷题,还要刷好题!为此我开启了一个弯道超车必做好题锦集的系列,此为…

2023/9/30 使用消息队列完成进程间通信

发送方 ​ #include <myhead.h> //消息结构体 typedef struct {long msgtype; //消息类型char data[1024]; //消息正文 }Msg_ds;#define SIZE sizeof(Msg_ds) - sizeof(long) //正文大小 int main(int argc, const char *argv[]) {//1.创建key值key_t key ;if((key …

中断向量控制器(NVIC)

1. 什么是中断 在处理器中&#xff0c;中断是一个过程&#xff0c;即CPU在正常执行程序的过程中&#xff0c;遇到外部/内部的紧急事件需要处理&#xff0c;暂时中止当前程序的执行&#xff0c;转而去为处理紧急的事件&#xff0c;待处理完毕后再返回被打断的程序处继续往下执行…

Spring MVC 中的国际化和本地化

Spring MVC 中的国际化和本地化 国际化&#xff08;Internationalization&#xff0c;简称i18n&#xff09;和本地化&#xff08;Localization&#xff0c;简称l10n&#xff09;是构建多语言应用程序的重要概念。Spring MVC提供了丰富的支持&#xff0c;使开发人员能够轻松地处…

Python 笔记06(Mysql数据库)

一 基础 1.1 安装 MySQL下载参考&#xff1a;MySQL8.0安装配置教程【超级详细图解】-CSDN博客 测试是否安装并正确配置环境变量&#xff1a; 1.2 查看服务器是否正常运行 1.3 显示数据库 show databases; 1.4 退出 exit 1.5 python 连接 1.6 查主机IP ipconfig

2.springboot代理调用

1.概述 本文介绍在方法上开启声明式事务Transactional后(使用InfrastructureAdvisorAutoProxyCreator创建jdk动态代理)&#xff0c;springboot的调用该方法的过程&#xff1b; 2.结论(重点) 在方法开启声明式事务后&#xff0c;spring会为该对象创建动态代理。spring容器为该…

Android Jetpack组件架构:ViewModel的原理

Android Jetpack组件架构&#xff1a;ViewModel的原理 导言 本篇文章是关于介绍ViewModel的&#xff0c;由于ViewModel的使用还是挺简单的&#xff0c;这里就不再介绍其的基本应用&#xff0c;我们主要来分析ViewModel的原理。 ViewModel的生命周期 众所周知&#xff0c;一般…

聚观早报 | 2024款小鹏P5全新发布;华为发布13.2英寸MatePad Pro

【聚观365】9月26日消息 2024款小鹏P5全新发布 华为发布13.2英寸MatePad Pro 特斯拉发布人形机器人最新进展 百川智能发布Baichuan2-53B 软件行业仍将人才供不应求 2024款小鹏P5全新发布 继2024款小鹏G9问世仅一周&#xff0c;小鹏汽车再度发力新产品&#xff0c;2024款小…

【小沐学前端】Node.js实现UDP通信

文章目录 1、简介2、下载和安装3、代码示例3.1 HTTP3.2 UDP单播3.4 UDP广播 结语 1、简介 Node.js 是一个开源的、跨平台的 JavaScript 运行时环境。 Node.js 是一个开源和跨平台的 JavaScript 运行时环境。 它是几乎任何类型项目的流行工具&#xff01; Node.js 在浏览器之外…

2.4g无线收发芯片:Ci24R1(DFN8)

Ci24R1 采用GFSK/FSK数字调制与解调技术。数据传输速率与PA输出功率都可以调节&#xff0c;支持2Mbps, 1Mbps, 250Kbps三种数据速率。高的数据速率可以在更短的时间完成同样的数据收发&#xff0c;因此可以具有更低的功耗。 Ci24R1 是一颗工作在2.4GHz ISM频段&#xff0c;专为…

医疗实施-住院流程详解

住院就诊流程详解 1.病人入院登记2.病人进入病区3.医生操作病人4.医嘱录入与审核执行5. 医嘱收费前在对应业务系统的操作5.1.药物医嘱5.2.检查检验医嘱5.3.手术医嘱 6.住院医嘱费用的产生7. 医嘱收费后在对应业务系统的操作8. 病人出院 这篇文章是基于我的文章《医疗实施-住院就…

8.3Jmeter使用json提取器提取数组值并循环(循环控制器)遍历使用

Jmeter使用json提取器提取数组值并循环遍历使用 响应返回值例如&#xff1a; {"code":0,"data":{"totalCount":11,"pageSize":100,"totalPage":1,"currPage":1,"list":[{"structuredId":&q…

[React] 性能优化相关

文章目录 1.React.memo2.useMemo3.useCallback4.useTransition5.useDeferredValue 1.React.memo 当父组件被重新渲染的时候&#xff0c;也会触发子组件的重新渲染&#xff0c;这样就多出了无意义的性能开销。如果子组件的状态没有发生变化&#xff0c;则子组件是不需要被重新渲…

百度网盘的扩容

百度网盘的扩容怎么扩 百度网盘的扩容通常需要购买额外的存储空间。以下是扩容百度网盘存储空间的一般步骤&#xff1a; 登录百度网盘&#xff1a;首先&#xff0c;在您的计算机或移动设备上打开百度网盘&#xff0c;并使用您的百度账号登录。 选择扩容选项&#xff1a;一旦登…

数据结构题型12-链式队列

#include <iostream> //引入头文件 using namespace std;typedef int Elemtype;#define Maxsize 5 #define ERROR 0 #define OK 1typedef struct LinkNode {Elemtype data;struct LinkNode* next; }LinkNode;typedef struct {LinkNode* front;LinkNode* rear; }LinkQ…

java项目之小说阅读网站(ssm源码+文档)

项目简介 小说阅读网站实现了以下功能&#xff1a; 管理员&#xff1a;首页、个人中心、读者管理、作者管理、小说信息管理、小说分类管理、余额充值管理、购买小说管理、下载小说管理、系统管理。读者&#xff1a;个人中心、余额充值管理、购买小说管理、下载小说管理、我的…