YOLOv8改进 - 注意力篇 - 引入EMA注意力机制

news2024/10/7 6:05:25

一、本文介绍

作为入门性篇章,这里介绍了EMA注意力在YOLOv8中的使用。包含EMA原理分析,EMA的代码、EMA的使用方法、以及添加以后的yaml文件及运行记录。

二、EMA原理分析

EMA官方论文地址:EMA文章

EMA代码:EMA代码

EMA注意力机制(高效的多尺度注意力):通过重塑部分通道到批次维度,并将通道维度分组为多个子特征,以保留每个通道的信息并减少计算开销。EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的关系。

相关代码:

EMA注意力的代码,如下:

class EMA_attention(nn.Module):
    def __init__(self, channels, c2=None, factor=32):
        super(EMA_attention, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

四、YOLOv8中EMA使用方法

1.YOLOv8中添加EMA模块:

首先在ultralytics/nn/modules/conv.py最后添加EMA模块的代码。

2.在conv.py的开头__all__ = 内添加EMA模块的类别名EMA_attention:

3.在同级文件夹下的__init__.py内添加EMA的相关内容:(分别是from .conv import EMA_attention ;以及在__all__内添加EMA_attention)

4.在ultralytics/nn/tasks.py进行EMA_attention注意力机制的注册,以及在YOLOv8的yaml配置文件中添加EMA_attention即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面(或者直接加载含有c2f模块的注册代码内,注册方式一样)添加以下注册代码:

        if m in (EMA_attention):
            c1, c2 = ch[f], args[0]

然后,就是新建一个名为YOLOv8_EMA.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_EMA.yaml)其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, EMA_attention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在根目录新建一个train.py文件,内容如下:

from ultralytics import YOLO
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    model = YOLO('ultralytics/cfg/models/v8/YOLOv8_EMA.yaml')  # 从YAML建立一个新模型
    results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:​

五、总结

以上就是EMA的原理及使用方式,但具体EMA注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2193639.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Oracle中MONTHS_BETWEEN()函数详解

文章目录 前言一、MONTHS_BETWEEN()的语法二、主要用途三、测试用例总结 前言 在Oracle数据库中,MONTHS_BETWEEN()函数可以用来计算两个日期之间的月份差。它返回一个浮点数,表示两个日期之间的整月数。 一、MONTHS_BETWEEN()的语法 MONTHS_BETWEEN(dat…

毕业设计项目 基于大数据人才岗位数据分析

文章目录 1 前言1. 数据集说明2. 数据处理2.1 数据清洗2.2 数据导入 3. 数据分析可视化3.1 整体情况(招聘企业数、岗位数、招聘人数、平均工资)3.2 企业主题行业情况公司类型最缺人的公司 TOP平均薪资最高的公司 TOP工作时间工作地点福利词云 3.3 岗位主…

晶体管最佳效率区域随频率逆时针旋转原因分析

晶体管最佳效率区域随频率逆时针旋转原因分析 在功率放大器的设计时,晶体管最佳区域随频率逆时针旋转。但是,对于一般的微带电路,匹配阻抗区域是随着频率顺时针旋转的(也有称这个特性是Foster特性),因此功…

运动员场景分割系统源码&数据集分享

运动员场景分割系统源码&数据集分享 [yolov8-seg-HGNetV2&yolov8-seg-aux等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目来源AAAI Global Al lnnovati…

基于SpringBoot+Vue+MySQL的在线学习交流平台

系统展示 用户前台界面 管理员后台界面 系统背景 随着互联网技术的飞速发展,在线学习已成为现代教育的重要组成部分。传统的面对面教学方式已无法满足广大学习者的需求,特别是在时间、地点上受限的学习者。因此,构建一个基于SpringBoot、Vue.…

CPU 多级缓存

在多线程并发场景下,普通的累加很可能错的 CPU 多级缓存 Main Memory : 主存Cache : 高速缓存,数据的读取存储都经过此高速缓存CPU Core : CPU 核心Bus : 系统总线 CPU Core 和 Cache 通过快速通道连接,Main menory 和 Cache 都挂载到 Bus 上…

gm/ID设计方法学习笔记(二)

一、任务 设计一个二级运放,第一级为有源负载差动对(五管OTA),第二级为电流源负载的共源极。 二、参数指标 GBW≥50MHz|Av|≥80dBPM60~70SR≥50V/us10pF 本文使用smic13mmrf_1233工艺库进行设计。 三、电路设计 (…

【进阶OpenCV】 (6)--指纹识别

文章目录 指纹识别1. 计算指纹间匹配点的个数2. 获取指纹编号3. 获取对应姓名4. 代码实现 总结 指纹识别 假设,现在我们有一个小的指纹库,此时,有一个指纹图片需要我们识别是不是指纹库中某一个人的。如果是,是谁的呢&#xff1f…

力扣110:判断二叉树是否为平衡二叉树

利用二叉树遍历的思想编写一个判断二叉树,是否为平衡二叉树 示例 : 输入:root [3,9,20,null,null,15,7] 输出:true思想: 代码: int getDepth(struct TreeNode* node) {//如果结点不存在,返回…

有趣幽默彩虹屁文案生成工具微信小程序源码

有趣幽默彩虹屁文案生成工具小程序源码 此文案小程序主要功能为分享各种有趣幽默的文案 免服务器免域名,源码只提供彩虹屁,朋友圈,毒鸡汤API接口,其他需自行查找替代 小程序拥有复制收藏功能,可自行体验,设…

FineReport 11 在线学习

文章目录 学习路线图FineReport 11 在线学习资源链接分享帆软report 特点 学习路线图 学习生态 自测题 FineReport 11 在线学习资源链接分享 帮助中心https://help.fanruan.com/finereport/ FineReport 入门学习路径https://edu.fanruan.com/guide/finereport 普通报表…

【c++】string类 (一)

简介 由于c的历史包袱,c要兼容c语言,c的字符串要兼容c语言,在 C 中,字符串通常使用两种主要的方式来表示: C风格字符串(C-style strings): 依然是以 \0 结尾的字符数组。这种表示方…

【Java 并发编程】初识多线程

前言 到目前为止,我们学到的都是有关 “顺序” 编程的知识,即程序中所有事物在任意时刻都只能执行一个步骤。例如:在我们的 main 方法中,都是多个操作以 “从上至下” 的顺序调用方法以至结束的。 虽然 “顺序” 编程能够解决相当…

YOLO11改进|卷积篇|引入SPDConv

目录 一、【SPD】卷积1.1【SPD】卷积介绍1.2【SPD】核心代码 二、添加【SPD】卷积2.1STEP12.2STEP22.3STEP32.4STEP4 三、yaml文件与运行3.1yaml文件3.2运行成功截图 一、【SPD】卷积 1.1【SPD】卷积介绍 SPD-Conv卷积的结构图如下,下面我们简单分析一下其处理过程…

贪心算法.

序幕 贪心算法(Greedy Algorithm)是一种在求解问题时采取逐步构建解决方案的策略,每一步都选择当前状态下局部最优的解,期望通过局部最优解能够得到全局最优解。 以上为了严谨性,引用了官方用语。 而用大白话总结就是&…

如何移除 iPhone 上的网络锁?本文筛选了一些适合您的工具

您是否对 iPhone 运营商的网络感到困惑?不用担心,我们将向您介绍 8 大免费 iPhone 解锁服务。这些工具可以帮助您移除 iPhone 上的网络锁,并使您能够永久在网络上使用您的设备。如果您想免费解锁 iPhone,请阅读本文并找到最适合您…

APC论文总结

论文详情 论文标题:APC: Adaptive Patch Contrast for Weakly Supervised Semantic Segmentation 论文作者:Wangyu Wu,Tianhong Dai,Zhenhong Chen,Xiaowei Huang,Fei Ma,Jimin Xiao 发表时间…

股票期货高频数据获取方法

导语:在量化交易领域,获取高质量的股票期货高频数据是进行有效分析和策略开发的基础。本文从专业量化角度出发,介绍了几种获取股票期货高频数据的方法。一、使用专业数据供应商1. 【推荐】**银河数据库(yinhedata.com)…

【Blender Python】6.修改物体模式

概述 Blender对象共有6种编辑模式,物体模式、编辑模式、雕刻模式、顶点绘制、权重绘制和纹理绘制。 在Blender Python中通过bpy.ops.object的mode_set()方法可以修改物体的编辑模式。只需要传入相应的mode参数就行了。 让物体进入编辑模式 >>> bpy.ops.…

leetcode 力扣算法题 快慢指针 双指针 19.删除链表的倒数第n个结点

删除链表的倒数第N个结点 题目要求题目示例解题思路从题目中的已知出发思考寻找目标结点条件转换核心思路 需要注意的点改进建议 完整代码提交结果 题目要求 给你一个链表,删除链表的倒数第n个结点,并且返回链表的头结点。 题目示例 示例 1&#xff1…