YOLOv8改进实战 | 注意力篇 | 引入基于跨空间学习的高效多尺度注意力EMA,小目标涨点明显

news2024/11/13 18:40:35

在这里插入图片描述


在这里插入图片描述
YOLOv8专栏导航:点击此处跳转


前言

YOLOv8 是由 YOLOv5 的发布者 Ultralytics 发布的最新版本的 YOLO。它可用于对象检测、分割、分类任务以及大型数据集的学习,并且可以在包括 CPU 和 GPU 在内的各种硬件上执行。

YOLOv8 是一种尖端的、最先进的 (SOTA) 模型,它建立在以前成功的 YOLO 版本的基础上,并引入了新的功能和改进,以进一步提高性能和灵活性。YOLOv8 旨在快速、准确且易于使用,这也使其成为对象检测、图像分割和图像分类任务的绝佳选择。具体创新包括一个新的骨干网络、一个新的 Ancher-Free 检测头和一个新的损失函数,还支持YOLO以往版本,方便不同版本切换和性能对比。


目录

  • 一、EMA介绍
  • 二、代码实现
    • 代码目录
    • 注册模块
    • 配置yaml文件
  • 三、模型测试
  • 四、模型训练
  • 五、总结

一、EMA介绍

在这里插入图片描述

论文链接:Efficient Multi-Scale Attention Module with Cross-Spatial Learning

在这里插入图片描述

论文提出了一种新颖的高效多尺度注意力(EMA)模块。EMA模块旨在保留每个通道的信息,同时减少计算开销。它通过重塑部分通道到批次维度,并将通道雏度分组为多个子特征,使得空间语义特征在每个特征组内均匀分布。此外,EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的成对关系。

在这里插入图片描述

创新点主要包括:

  1. 高效多尺度注意力(EMA):新型的注意力机制,同时减少计算开销和保留每个通道的关键信息

  2. 通道和批次维度的重组:通过重新组织通道维度和批次维度,提高了模型处理特征的能力。

  3. 跨维度交互:模块利用跨维度的交互来捕捉像素级别的关系

  4. 全局信息编码和通道权重校准:在并行分支中编码全局信息,用于通道权重的重新校准,增强了特征表示的能力。

二、代码实现

代码目录

  • 按下面文件夹结构创建文件(相比于在原有ultralytics/nn/modules文件夹下的相关文件中直接添加便于管理
    - ultralytics
    	- nn
    		- extra_modules
    			- __init__.py
    			- attention.py
    		- modules
    

ultralytics/nn/extra_modules/__init__.py中添加:

from .attention import *

ultralytics/nn/extra_modules/attention.py中添加:

import torch
from torch import nn

__all__ = ['EMA']


class EMA(nn.Module):
    def __init__(self, channels, factor=8):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

注册模块

ultralytics/nn/tasks.py文件开头添加:

from ultralytics.nn.extra_modules import *

ultralytics/nn/tasks.py文件中parse_model函数添加:

elif m in {EMA}:
    args = [ch[f], *args]

配置yaml文件

yolov8-ema.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 1, EMA, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 1, EMA, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 23 (P5/32-large)
  - [-1, 1, EMA, []]

  - [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)


三、模型测试

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

model = YOLO("yolov8n-ema.yaml")  # build a new model from scratch
                   from  n    params  module                                       arguments
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  6                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
  7                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
  8                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
  9                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 10                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 11             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 12                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 14             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 15                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 16                  -1  1       672  ultralytics.nn.extra_modules.attention.EMA   [64]
 17                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 18            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 19                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 20                  -1  1      2624  ultralytics.nn.extra_modules.attention.EMA   [128]
 21                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 22             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 23                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 24                  -1  1     10368  ultralytics.nn.extra_modules.attention.EMA   [256]
 25        [16, 20, 24]  1    897664  ultralytics.nn.modules.head.Detect           [80, [64, 128, 256]]
YOLOv8n-ema summary: 249 layers, 3170864 parameters, 3170848 gradients, 9.1 GFLOPs

四、模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

# Load a model
model = YOLO("yolov8n-ema.yaml")  # build a new model from scratch

# Use the model
model.train(
    data="./mydata/data.yaml",
    epochs=300,
    batch=32,
    imgsz=640,
    workers=8,
    device=0,
    project="runs/train",
    name='exp')  # train the model

五、总结

  • 模型的训练具有很大的随机性,您可能需要点运气和更多的训练次数才能达到最高的 mAP。
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2106183.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Centos安装配置Gitea(Ubuntu等系统也可参考)

准备工作 安装好宝塔面板,再进入宝塔面板安装好MySQL,添加配置一个MySQL数据库gitea,用户名和密码也为gitea (也可用命令行做相关操作,自行搜索教程) 通过终端下载安装git,添加普通用户git&a…

数据库管理-第238期 23ai:全球分布式数据库-架构与组件(20240904)

数据库管理238期 2024-09-04 数据库管理-第238期 23ai:全球分布式数据库-架构与组件(20240904)1 架构图2 分片数据库与分片3 Shard Catalog4 Shard Director5 Global Service6 管理界面总结 数据库管理-第238期 23ai:全球分布式数…

效率升级,创意无限:2024年必备录屏软件

随着科技的飞速发展与用户需求的多元化趋势,录屏软件市场迎来了前所未有的繁荣景象,各种功能强大、特色鲜明的软件如雨后春笋般涌现。今天,我们将聚焦于那些如同obs录屏般,能够提供快捷操控体验的专业录屏工具。 1.福昕录屏大师 …

第L5周:机器学习:决策树(分类模型)

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标: 1. 决策树算法是一种在机器学习和数据挖掘领域广泛应用的强大工具,它模拟人类决策过程,通过对数据集进行逐步的分析和…

Chrome 浏览器插件获取网页 window 对象(方案二)

前言 最近有个需求,是在浏览器插件中获取 window 对象下的某个数据,当时觉得很简单,和 document 一样,直接通过嵌入 content_scripts 直接获取,然后使用 sendMessage 发送数据到插件就行了,结果发现不是这…

Python(TensorFlow)和MATLAB及Java光学像差导图

🎯要点 几何光线和波前像差计算入瞳和出瞳及近轴光学计算波前像差特征矩阵方法计算光谱反射率、透射率和吸光度透镜像差和绘制三阶光线像差图和横向剪切干涉图分析瞳孔平面焦平面和大气湍流建模神经网络光学像差计算透镜光线传播几何偏差计算像差和像散色差纠正对齐…

【unity实战】利用Root Motion+Blend Tree+Input System+Cinemachine制作一个简单的角色控制器

文章目录 前言动画设置Blend Tree配置角色添加刚体和碰撞体代码控制人物移动那么我们接下来调整一下相机的视角效果参考完结 前言 Input System知识参考: 【推荐100个unity插件之18】Unity 新版输入系统Input System的使用,看这篇就够了 Cinemachine虚…

嵌入式全栈开发学习笔记---C++(函数/类模板)

目录 函数模板 模板机制 函数模板语法 函数模板和普通函数的区别 函数模板和普通函数调用规则 函数模板机制 排序模板函数 类模板 类模板语法 模板继承 类模板中的static关键字 模板声明 .hpp文件 类模板小结 上节学习了运算符重载,本节开始学习函数模…

使用 GZCTF 结合 GitHub 仓库搭建独立容器与动态 Flag 的 CTF 靶场+基于 Docker 的 Web 出题与部署+容器权限控制

写在前面 关于 CTF 靶场的搭建(使用 CTFd 或者 H1ve)以及 AWD 攻防平台的搭建,勇师傅在前面博客已经详细写过,可以参考我的《网站搭建》专栏,前段时间玩那个 BaseCTF,发现它的界面看着挺不错的&#xff0c…

LVGL 控件之复选框(lv_checkbox)和下拉列表(lv_dropdown)

目录 一、复选框1、组成2、设置复选框文本3、复选框部件的状态4、复选框事件5、API 函数 二、下拉列表1、组成2、选项2.1 添加选项2.2 获取当前选中的选项 3、设置3.1 设置列表展开方向3.2 设置下拉列表图标3.3 设置列表常显文本 4、事件5、API 函数 一、复选框 1、组成 复选…

Android studio 导出 release 版本的 .aar 文件

不同的android studio 版本可能会有不同的方案,我针对的是: 首先打开settings: Setting —> Experimental 界面 将选项:【configure all gradle tasks】勾上: 接着点击 File —> Sync Project with Gradle Files 然后&…

【js逆向专题】8.webpack打包

本教程仅供学习交流使用,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关,请各学员自觉遵守相关法律法规。小节目标: 熟悉 webpack打包原理熟悉 webpack打包方式了解 webpack多模块打包 一. webpack打包 概念: webpack 是…

【颤抖不再怕,帕金森患者的活力锻炼秘籍!】

Hey小伙伴们~👋 今天我们来聊聊一个温暖而重要的话题——如何帮助我们的亲人或自己,在帕金森病的挑战下,依然保持生活的活力与光彩!🌈 帕金森病,这个名字听起来或许让人心生畏惧,但它绝不是生活…

地产行业如何利用Java实现精准营销

在当今竞争激烈的地产市场中,如何有效触达潜在客户并促进销售转化,成为众多房企关注的焦点。106短信平台作为一种精准的营销工具,在地产行业中发挥着越来越重要的作用。 支持免费对接试用:乐讯通PaaS平台 找好用的短信平台,选择乐…

AUTO TECH 2025 华南展 第十二届广州国际汽车零部件加工技术及汽车模具展览会——探索未来出行的创新动力

AUTO TECH 2025 华南展 第十二届广州国际汽车零部件加工技术及汽车模具展览会——探索未来出行的创新动力 随着全球汽车工业的不断进步和新能源汽车技术的迅猛发展,2025年11月20-22日在广州保利世贸博览馆将迎来一场行业瞩目的盛会——2025 第十二届广州国际汽车零部…

外接串口板,通过串口打开adb模式

一、依赖库 import subprocess import serial from serial.tools import list_ports import logging import time 二、代码 import subprocessimport serial from serial.tools import list_ports import logging import timedef openAdb(com):# com []# for i in list_por…

无人机之地面站篇

无人机的地面站,又称无人机控制站,是整个无人机系统的重要组成部分,扮演着作战指挥中心的角色。以下是对无人机地面站的详细阐述: 一、定义与功能 无人机地面站是指具有对无人机飞行平台和任务载荷进行监控和操纵能力的一组设备&…

[数据集][目标检测]翻越栏杆行为检测数据集VOC+YOLO格式512张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):512 标注数量(xml文件个数):512 标注数量(txt文件个数):512 标注类别…

通过卷积神经网络(CNN)识别和预测手写数字

一:卷积神经网络(CNN)和手写数字识别MNIST数据集的介绍 卷积神经网络(Convolutional Neural Networks,简称CNN)是一种深度学习模型,它在图像和视频识别、分类和分割任务中表现出色。CNN通过模仿…

快排的深入学习

目录 交换类排序 一、冒泡排序 1. 算法介绍 2.算法流程 3. 算法性能分析 (1)时间复杂度分析 (2) 空间复杂度分析 冒泡排序的特性总结: 二、快速排序 1.算法介绍 2. 执行流程 1). hoare版本 2). 挖坑法 3)…