YOLOv8改进实战 | 引入多维协作注意模块MCA,实现暴力涨点

news2024/11/15 17:35:57

在这里插入图片描述


在这里插入图片描述
YOLOv8专栏导航:点击此处跳转


前言
YOLOv8 是由 YOLOv5 的发布者 Ultralytics 发布的最新版本的 YOLO。它可用于对象检测、分割、分类任务以及大型数据集的学习,并且可以在包括 CPU 和 GPU 在内的各种硬件上执行。

YOLOv8是一种尖端的、最先进的 (SOTA) 模型,它建立在以前成功的 YOLO 版本的基础上,并引入了新的功能和改进,以进一步提高性能和灵活性。YOLOv8 旨在快速、准确且易于使用,这也使其成为对象检测、图像分割和图像分类任务的绝佳选择。具体创新包括一个新的骨干网络、一个新的 Ancher-Free 检测头和一个新的损失函数,还支持YOLO以往版本,方便不同版本切换和性能对比。


目录

  • 一、MCANet介绍
  • 二、代码实现
    • 代码目录
    • 注册模块
    • 配置yaml文件
  • 三、模型测试
  • 四、模型训练
  • 五、总结

一、MCANet介绍

在这里插入图片描述

论文链接:MCA: Multidimensional collaborative attention in deep convolutional neural networks for image recognition
Pytorch code:MCANet

在这里插入图片描述

作者提出了一种基于高效轴向注意力的多尺度交叉轴注意(MCA)方法。MCA通过计算两个并行轴向注意力之间的双向交叉注意力,以更好地捕获全局信息。此外,为了处理病变区域或器官在个体大小和形状上的显著变化,还在每个轴向注意力路径中使用不同大小的条形卷积核进行多次卷积,以提高编码空间信息的效率。

在这里插入图片描述
顶部分支用于捕捉空间维度 W 中特征之间的交互作用。同样,中间分支用于捕捉空间维度 H 中特征之间的交互作用。底部分支负责捕捉通道之间的交互作用。在前两个分支中,我们使用置换操作来捕捉通道维度和两个空间维度之间的长程依赖关系。最后,在集成阶段,来自所有三个分支的输出通过简单平均进行聚合。

在这里插入图片描述

二、代码实现

代码目录

  • 按下面文件夹结构创建文件(相比于在原有ultralytics/nn/modules文件夹下的相关文件中直接添加便于管理
    - ultralytics
    	- nn
    		- extra_modules
    			- __init__.py
    			- attention.py
    		- modules
    

ultralytics/nn/extra_modules/__init__.py中添加:

from .attention import *

ultralytics/nn/extra_modules/attention.py中添加:

import torch
from torch import nn
import math

__all__ = ['MCALayer', 'MCAGate']


class StdPool(nn.Module):
    def __init__(self):
        super(StdPool, self).__init__()

    def forward(self, x):
        b, c, _, _ = x.size()

        std = x.view(b, c, -1).std(dim=2, keepdim=True)
        std = std.reshape(b, c, 1, 1)

        return std


class MCAGate(nn.Module):
    def __init__(self, k_size, pool_types=['avg', 'std']):
        super(MCAGate, self).__init__()

        self.pools = nn.ModuleList([])
        for pool_type in pool_types:
            if pool_type == 'avg':
                self.pools.append(nn.AdaptiveAvgPool2d(1))
            elif pool_type == 'max':
                self.pools.append(nn.AdaptiveMaxPool2d(1))
            elif pool_type == 'std':
                self.pools.append(StdPool())
            else:
                raise NotImplementedError

        self.conv = nn.Conv2d(1, 1, kernel_size=(1, k_size), stride=1, padding=(0, (k_size - 1) // 2), bias=False)
        self.sigmoid = nn.Sigmoid()

        self.weight = nn.Parameter(torch.rand(2))

    def forward(self, x):
        feats = [pool(x) for pool in self.pools]

        if len(feats) == 1:
            out = feats[0]
        elif len(feats) == 2:
            weight = torch.sigmoid(self.weight)
            out = 1 / 2 * (feats[0] + feats[1]) + weight[0] * feats[0] + weight[1] * feats[1]
        else:
            assert False, "Feature Extraction Exception!"

        out = out.permute(0, 3, 2, 1).contiguous()
        out = self.conv(out)
        out = out.permute(0, 3, 2, 1).contiguous()

        out = self.sigmoid(out)
        out = out.expand_as(x)

        return x * out


class MCALayer(nn.Module):
    def __init__(self, inp, no_spatial=False):
        super(MCALayer, self).__init__()

        lambd = 1.5
        gamma = 1
        temp = round(abs((math.log2(inp) - gamma) / lambd))
        kernel = temp if temp % 2 else temp - 1

        self.h_cw = MCAGate(3)
        self.w_hc = MCAGate(3)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.c_hw = MCAGate(kernel)

    def forward(self, x):
        x_h = x.permute(0, 2, 1, 3).contiguous()
        x_h = self.h_cw(x_h)
        x_h = x_h.permute(0, 2, 1, 3).contiguous()

        x_w = x.permute(0, 3, 2, 1).contiguous()
        x_w = self.w_hc(x_w)
        x_w = x_w.permute(0, 3, 2, 1).contiguous()

        if not self.no_spatial:
            x_c = self.c_hw(x)
            x_out = 1 / 3 * (x_c + x_h + x_w)
        else:
            x_out = 1 / 2 * (x_h + x_w)

        return x_out

注册模块

ultralytics/nn/tasks.py文件开头添加:

from ultralytics.nn.extra_modules import *

ultralytics/nn/tasks.py文件中parse_model函数添加:

elif m in {MCALayer}:
    args = [ch[f], *args]

配置yaml文件

yolov8-mca.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 1, MCALayer, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 1, MCALayer, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 23 (P5/32-large)
  - [-1, 1, MCALayer, []]

  - [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

三、模型测试

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

model = YOLO("yolov8n-mca.yaml")  # build a new model from scratch
                   from  n    params  module                                       arguments
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  6                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
  7                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
  8                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
  9                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 10                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 11             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 12                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 14             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 15                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 16                  -1  1        15  ultralytics.nn.extra_modules.attention.MCALayer[64]
 17                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 18            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 19                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 20                  -1  1        15  ultralytics.nn.extra_modules.attention.MCALayer[128]
 21                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 22             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 23                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 24                  -1  1        17  ultralytics.nn.extra_modules.attention.MCALayer[256]
 25        [16, 20, 24]  1    751507  ultralytics.nn.modules.head.Detect           [1, [64, 128, 256]]
YOLOv8n-mca summary: 282 layers, 3011090 parameters, 3011074 gradients, 8.2 GFLOPs

四、模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

# Load a model
model = YOLO("yolov8n-mca.yaml")  # build a new model from scratch

# Use the model
model.train(
    data="./mydata/data.yaml",
    epochs=300,
    batch=32,
    imgsz=640,
    workers=8,
    device=0,
    project="runs/train",
    name='exp')  # train the model

五、总结

  • 模型的训练具有很大的随机性,您可能需要点运气和更多的训练次数才能达到最高的 mAP。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2102492.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows安装anaconda注意事项及jupyter notebook更换目录

anaconda的介绍就不罗嗦了,既然准备安装了,说明你已经有所了解了。直入主题,Anaconda官网下载,实在太慢,可到https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/下载,注意,这是清华镜像站…

Mysql中的锁机制详解

一、概述 锁是计算机协调多个进程或线程并发访问某一资源的机制。 在数据库中,除了传统的计算资源(如CPU、RAM、I/O等)的争用以外,数据也是一种供需要用户共享的资源。如何保证数据并发访问的一致性、有效性是所有数据库必须解决…

新电脑Win11系统想要降级为Win10怎么操作?

前言 现在的电脑大部分都是Windows 11系统,组装机还好一些,如果想要使用Windows 10,只需要在安装系统的时候选择Windows 10镜像即可。 但是对于新笔记本、厂商的成品机、一体机来说,只要是全新的电脑,基本上都是Wind…

黑马JavaWeb开发笔记14——Tomcat(介绍、安装与卸载、启动与关闭)、入门程序解析(起步依赖、SpringBoot父工程、内嵌Tomcat)

文章目录 前言一、Web服务器-Tomcat1. 简介1.1服务器概述1.2 Web服务器1.3 Tomcat 2. 基本使用2.1 下载2.2 安装与卸载2.3 启动与关闭2.4 常见问题 二、入门程序解析1. 起步依赖2. SpringBoot父工程3. 内嵌Tomcat 总结 前言 本篇文章是2023年最新黑马JavaWeb开发笔记14&#x…

[Java]SpringBoot登录认证流程详解

登录认证 登录接口 1.查看原型 2.查看接口 3.思路分析 登录核心就是根据用户名和密码查询用户信息,存在则登录成功, 不存在则登录失败 4.Controller Slf4j RestController public class LoginController {Autowiredprivate EmpService empService;/*** 登录的方法** param …

【Python】一文详细向您介绍 bisect_left 函数

【Python】一文详细向您介绍 bisect_left 函数 下滑即可查看博客内容 🌈 欢迎莅临我的个人主页 👈这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地!🎇 🎓 博主简介:985高校的普通本硕&#x…

shell脚本1----编程规范与变量

shell脚本 shell的功能 Shell(壳程序)是一个特殊的应用程序,它介于操作系统内核与用户之间,充当了一个“命令解释器”的角色,负责接收用户输入的操作指令(命令)并进行解释,将需要执…

(前端)面试300问之(3)this的指向判断

一、this的相关理解与解读 1、各角度看this。 1)ECMAScript规范: this 关键字执行为当前执行环境的 ThisBinding。 2)MDN: In most cases, the value of this is determined by how a function is called. 在绝大多数情况下&…

图片损坏,如何修复?

在数字化时代,图片已成为我们日常生活和工作中不可或缺的一部分。然而,有时我们可能会遇到图片损坏的情况,无论是珍贵的家庭照片、工作文档中的关键图像,还是社交媒体上的分享内容,图片损坏都可能带来不小的困扰。那么…

网络传输加密及openssl使用样例(客户端服务器)

文章目录 背景常用加密方式SSLOpenSSL主要功能 库结构 交互流程证书生成生成 RSA 私钥私钥的主要组成部分私钥的格式 创建自签名证书: 签发证书服务器端代码客户端代码常见错误版本问题证书问题证书格式 背景 网络传输中为保证数据安全,通常需要加密 常用加密方式…

Open3D 基于曲率大小的特征点提取

目录 一、概述 1.1原理 1.2实现步骤 1.3应用场景 二、代码实现 三、实现效果 3.1原始点云 3.2提取特征点 Open3D点云算法汇总及实战案例汇总的目录地址: Open3D点云算法与点云深度学习案例汇总(长期更新)-CSDN博客 一、概述 基于曲率…

STM32 外部中断(EXTI)

STM32 外部中断(EXTI) 实验:配置一个引脚的下降沿作为外部中断。 参考:江协科技 相关缩写 RCC(Reset and Clock Control) 复位和时钟控制 GPIO(General Purpose Input/Output) 通用输入/输出 AFIO(Alternate Function Input Output) 复用功能输入输…

6.Lab five —— Lazy Page Allocation

首先先切换到lazy分支 git checkout lazy make clean Xv6应用程序使用sbrk()系统调用向内核请求堆内存。sbrk()分配物理内存并将其映射到进程的虚拟地址空间。内核为一个大请求分配和映射内存可能需要很长时间。为了提高效率,故采用懒分配的策略 Eliminate alloc…

Scratch在线玩:3D地铁跑酷

小虎鲸Scratch资源站-免费Scratch作品源码,素材,教程分享平台! 作品介绍: 欢迎体验在 Scratch 上重新制作的 3D 地铁跑酷游戏!这款游戏完全采用 3D 技术打造,带来流畅的视觉效果和出色的游戏体验。游戏的目标是避免列车和障碍物,同…

力扣 1419. 数青蛙

力扣 1419. 数青蛙 1. 题目 2. 思路 本题就是一道 字符串模拟题; 题目说到了, 会混杂着青蛙的叫声, 如果字符串 croakOfFrogs 不是由若干有效的 “croak” 字符混合而成,请返回 -1, 那就是说如果有多余的 c, r, o等等, 比如 &quo…

【机器学习】机器学习引领未来:赋能精准高效的图像识别技术革新

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀目录 🔍1. 引言📒2. 机器学习基础与图像识别原理🍁机器学习概述:监督学习、无监督学习与强化学…

「深入理解」HTML Meta标签:网页元信息的配置

「深入理解」HTML Meta标签&#xff1a;网页元信息的配置 HTML的<meta>元素用于提供关于HTML文档的元数据&#xff08;metadata&#xff09;&#xff0c;这些信息对于浏览器和其他处理HTML文档的应用程序来说是非常有用的&#xff0c;如&#xff1a;<base>、<li…

虚幻引擎VR游戏开发02 | 性能优化设置

常识&#xff1a;VR需要保持至少90 FPS的刷新率&#xff0c;以避免用户体验到延迟或晕眩感。以下是优化性能的一系列设置&#xff08;make sure the frame rate does not drop below a certain threshold&#xff09; In project setting-> &#xff08;以下十个设置都在pr…

用于全栈自动化测试的最佳Python工具

我知道大多数测试人员会说Java是他们创建自动化测试的首选语言。 但是我最喜欢的是Python。为什么?为什么是Python ? Al Sweigart&#xff0c;《自动化那些无聊的东西》的作者&#xff0c;Python一直是他的首选语言&#xff0c;因为:它有一个温和的学习曲线。它适用于Windows…

42.哀家要长脑子了!

1.965. 单值二叉树 - 力扣&#xff08;LeetCode&#xff09; 深度优先搜索&#xff0c;看边两端的结点是不是一样的值 class Solution { public:bool isUnivalTree(TreeNode* root) {if(!root) return true;if(root->right) {if(root->val ! root->right->val || …