YOLOv8目标检测算法改进之融合SCconv的特征提取方法

news2024/9/25 11:18:38

引言

YOLO目标检测算法历经发展,目前已经成为了目标检测领域的经典算法。当前,YOLO目标检测算法已经更新到YOLOv10,但从大家的反映来看,YOLOv10的效果并不理想(该算法的创新点是提升检测速度,并不提升精度,去除了NMS的后处理过程),YOLOv9则针对反向传播中距离远的模块学习效果差的问题,提出使用可编程梯度。那么博主今天为何要介绍YOLOv8呢,因为它实在是太好用了。

为什么要学习YOLOv8?

前面博主说YOLOv8好用,那么体现在哪里呢?
首先,相较于其他的YOLO目标检测算法,YOLOv8更像一个算法框架,他里面集成了从v3v9YOLO算法(没错,它有v9)以及RT-DETR,因此环境是通用的,方便我们做对比实验
其次,该框架没有拘泥于目标检测领域,他还包含了分类、分割、姿态估计、追踪等一系列计算机视觉任务。
最后,该算法框架其实都是在调包,改进起来十分简单,方便水创新点。(当然还是希望大家能够知其然知其所以然)
综上所述,无论是从发论文还是项目学习角度来看,YOLOv8真的非常适合大家去学习。

那么我们便开始YOLOv8的学习吧!

YOLOv8模型结构

下面是YOLOv8的模型结构,可以看到,YOLO目标检测模型的整体架构基本已经确定了,其包含骨干特征提取层(Backbone),特征融合模块(Neck),检测头(Head)。

当然,今天博主并不是要讲这些模块的实现和原理,而是从源码的角度告诉大家该如何去改进算法。

在这里插入图片描述

从水论文的角度来看,改进较为简单的便是模块的增改,或是损失函数的替换(这里博主说的是直接用别人的成果,如果你自己提出了新的模块或损失函数的话那是非常厉害的),这种改进就是缝缝补补,没有什么创新的,但事实上,为了应对学业要求,很大部分人不得不采用这种方式去来水一篇论文。

那么,话不多说,我们开整!

添加SCConv模块

本章创新为融合SCconv的特征提取方法,顾名思义就是将SCconv模块融合到YOLOv8的骨干特征提取网络部分(backbone),首先我们了解一些SCconv

SCConvCVPR2023收录的一个即插即用的空间和通道重建卷积模块,其结构如下:

在这里插入图片描述

SCconv论文下载地址

同时,在论文中也提供了实现代码:


'''
Description: 
Date: 2023-07-21 14:36:27
LastEditTime: 2023-07-27 18:41:47
FilePath: /chengdongzhou/ScConv.py
'''
import torch
import torch.nn.functional as F
import torch.nn as nn 


class GroupBatchnorm2d(nn.Module):
    def __init__(self, c_num:int, 
                 group_num:int = 16, 
                 eps:float = 1e-10
                 ):
        super(GroupBatchnorm2d,self).__init__()
        assert c_num    >= group_num
        self.group_num  = group_num
        self.weight     = nn.Parameter( torch.randn(c_num, 1, 1)    )
        self.bias       = nn.Parameter( torch.zeros(c_num, 1, 1)    )
        self.eps        = eps
    def forward(self, x):
        N, C, H, W  = x.size()
        x           = x.view(   N, self.group_num, -1   )
        mean        = x.mean(   dim = 2, keepdim = True )
        std         = x.std (   dim = 2, keepdim = True )
        x           = (x - mean) / (std+self.eps)
        x           = x.view(N, C, H, W)
        return x * self.weight + self.bias


class SRU(nn.Module):
    def __init__(self,
                 oup_channels:int, 
                 group_num:int = 16,
                 gate_treshold:float = 0.5,
                 torch_gn:bool = True
                 ):
        super().__init__()
        
        self.gn             = nn.GroupNorm( num_channels = oup_channels, num_groups = group_num ) if torch_gn else GroupBatchnorm2d(c_num = oup_channels, group_num = group_num)
        self.gate_treshold  = gate_treshold
        self.sigomid        = nn.Sigmoid()

    def forward(self,x):
        gn_x        = self.gn(x)
        w_gamma     = self.gn.weight/sum(self.gn.weight)
        w_gamma     = w_gamma.view(1,-1,1,1)
        reweigts    = self.sigomid( gn_x * w_gamma )
        # Gate
        w1          = torch.where(reweigts > self.gate_treshold, torch.ones_like(reweigts), reweigts) # 大于门限值的设为1,否则保留原值
        w2          = torch.where(reweigts > self.gate_treshold, torch.zeros_like(reweigts), reweigts) # 大于门限值的设为0,否则保留原值
        x_1         = w1 * x
        x_2         = w2 * x
        y           = self.reconstruct(x_1,x_2)
        return y
    
    def reconstruct(self,x_1,x_2):
        x_11,x_12 = torch.split(x_1, x_1.size(1)//2, dim=1)
        x_21,x_22 = torch.split(x_2, x_2.size(1)//2, dim=1)
        return torch.cat([ x_11+x_22, x_12+x_21 ],dim=1)


class CRU(nn.Module):
    '''
    alpha: 0<alpha<1
    '''
    def __init__(self, 
                 op_channel:int,
                 alpha:float = 1/2,
                 squeeze_radio:int = 2 ,
                 group_size:int = 2,
                 group_kernel_size:int = 3,
                 ):
        super().__init__()
        self.up_channel     = up_channel   =   int(alpha*op_channel)
        self.low_channel    = low_channel  =   op_channel-up_channel
        self.squeeze1       = nn.Conv2d(up_channel,up_channel//squeeze_radio,kernel_size=1,bias=False)
        self.squeeze2       = nn.Conv2d(low_channel,low_channel//squeeze_radio,kernel_size=1,bias=False)
        #up
        self.GWC            = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=group_kernel_size, stride=1,padding=group_kernel_size//2, groups = group_size)
        self.PWC1           = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=1, bias=False)
        #low
        self.PWC2           = nn.Conv2d(low_channel//squeeze_radio, op_channel-low_channel//squeeze_radio,kernel_size=1, bias=False)
        self.advavg         = nn.AdaptiveAvgPool2d(1)

    def forward(self,x):
        # Split
        up,low  = torch.split(x,[self.up_channel,self.low_channel],dim=1)
        up,low  = self.squeeze1(up),self.squeeze2(low)
        # Transform
        Y1      = self.GWC(up) + self.PWC1(up)
        Y2      = torch.cat( [self.PWC2(low), low], dim= 1 )
        # Fuse
        out     = torch.cat( [Y1,Y2], dim= 1 )
        out     = F.softmax( self.advavg(out), dim=1 ) * out
        out1,out2 = torch.split(out,out.size(1)//2,dim=1)
        return out1+out2


class ScConv(nn.Module):
    def __init__(self,
                op_channel:int,
                group_num:int = 4,
                gate_treshold:float = 0.5,
                alpha:float = 1/2,
                squeeze_radio:int = 2 ,
                group_size:int = 2,
                group_kernel_size:int = 3,
                 ):
        super().__init__()
        self.SRU = SRU( op_channel, 
                       group_num            = group_num,  
                       gate_treshold        = gate_treshold )
        self.CRU = CRU( op_channel, 
                       alpha                = alpha, 
                       squeeze_radio        = squeeze_radio ,
                       group_size           = group_size ,
                       group_kernel_size    = group_kernel_size )
    
    def forward(self,x):
        x = self.SRU(x)
        x = self.CRU(x)
        return x


if __name__ == '__main__':
    x       = torch.randn(1,32,16,16)
    model   = ScConv(32)
    print(model(x).shape)

当然,博主在这里并不是要对SCconv的结构抑或是背后的逻辑进行解读,博主只是想告诉大家如何将该模块应用到YOLOv8(当然也可以是任意一个YOLOv8里集成的模型),其实,要做到这一点并不难,大家按照这个步骤来即可:

确定添加位置

对于这样的模块,我们一般会将其添加到骨干特征提取模块(Backbone),当然这个位置任意,只要效果好了即可,这里博主选择Backbone中的ConvC2f 模块后面。
这里顺带对模型的配置文件进行讲解,其遵循如下定义规则:

[from, repeats, module, args]表示层的来源、重复次数、模块类型和参数。
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2  -1代表将上层的输入作为本层的输入。第0层的输入是640*640*3的图像。Conv代表卷积层,相应的参数:64代表输出通道数,3代表卷积核大小k,2代表stride步长。 
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]#C2f模块,3代表本层重复3次。128代表输出通道数,True表示Bottleneck有shortcut。
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9

我们首先开启一次训练,看下其结构,我打算将SCconv模块加入到ConvC2f之间

在这里插入图片描述

可以看到,模块的输入维度是由上一个模块的输出维度所确定的,以C2f模块为例,其输入维度即为Conv的输出维度64,那么我们要将SCConv模块加入到ConvC2f之间的话,就需要让SCConv的输入维度为64(对应Conv的输出维度),让SCConv的输出维度为64(对应C2f的输入维度)

确定模块的输入输出

在确定了ConvC2f模块的关联后,我们就需要让SCConv与其进行适配了,首先通过代码来模拟一下其结果:

if __name__ == '__main__':
    x       = torch.randn(1,256,80,80)
    model   = ScConv(256)
    print(model(x).shape)

输出结果为:torch.Size([1, 256, 80, 80]),即输入输出前后没有发生维度上的变化,符合要求。

创建SCConv类

接下来便是如何修改代码,从而让该模块与YOLOv8相融合了
我们在首先在如下文件夹下新建一个mine.py,这个名字随便起,随后将SCConv的代码粘贴到该文件中。

添加引用(可跳过)

当然如果你想要更规范的话,可以将代码粘贴到\ultralytics\nn\modules\block.py文件中,并
在ultralytics/nn/modules/init.py文件中加入引用,这看你个人要求。

from .block import (
    C1,
    C2,
    C3,
    C3TR,
    DFL,
    SPP,
    SPPELAN,
    SPPF,
    ADown,
    BNContrastiveHead,
    Bottleneck,
    BottleneckCSP,
    C2f,
    C2fAttn,
    C3Ghost,
    C3x,
    CBFuse,
    CBLinear,
    ContrastiveHead,
    GhostBottleneck,
    HGBlock,
    HGStem,
    ImagePoolingAttn,
    Proto,
    RepC3,
    RepNCSPELAN4,
    ResNetLayer,
    Silence,
    ScConv,
)
__all__ = (
    "Conv",
    "Conv2",
    "LightConv",
    "RepConv",
    "DWConv",
    "DWConvTranspose2d",
    "ConvTranspose",
    "Focus",
    "GhostConv",
    "ChannelAttention",
    "SpatialAttention",
    "CBAM",
    "Concat",
    "TransformerLayer",
    "TransformerBlock",
    "MLPBlock",
    "LayerNorm2d",
    "DFL",
    "HGBlock",
    "HGStem",
    "SPP",
    "SPPF",
    "C1",
    "C2",
    "C3",
    "C2f",
    "C2fAttn",
    "C3x",
    "C3TR",
    "C3Ghost",
    "GhostBottleneck",
    "Bottleneck",
    "BottleneckCSP",
    "Proto",
    "Detect",
    "Segment",
    "Pose",
    "Classify",
    "TransformerEncoderLayer",
    "RepC3",
    "RTDETRDecoder",
    "AIFI",
    "DeformableTransformerDecoder",
    "DeformableTransformerDecoderLayer",
    "MSDeformAttn",
    "MLP",
    "ResNetLayer",
    "OBB",
    "WorldDetect",
    "ImagePoolingAttn",
    "ContrastiveHead",
    "BNContrastiveHead",
    "RepNCSPELAN4",
    "ADown",
    "SPPELAN",
    "CBFuse",
    "CBLinear",
    "Silence",
    "ScConv",#all中加入引用
)

随后在tasks.py中引入SCConv模块,让配置文件知道该去哪里读取:

from ultralytics.nn.modules.mine import ScConv

在这里插入图片描述

修改模型结构

随后只需要修改对应的yaml配置文件即可:

    [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 1, ScConv, [64]]
  - [-1, 6, C2f, [256, True]]

至此,修改便完成了,是不是非常简单,然后我们训练一下看看是否能够跑通?
在加载模型时,发现只能加载一个,即我们修改了yaml文件,就不能只有YOLOv8的预训练模型,由于模型将重头开始训练,这是十分令人苦恼的。

在这里插入图片描述

随后开启训练,根据下面的网络层数,参数量以及GFLOPs来看,加入了SCconv模块后的计算复杂度提升了很多,这就需要更多的计算资源。

在这里插入图片描述
在这里插入图片描述

由于没有使用预训练模型,因此刚开始的训练效果很差

yi

随着训练轮次的增加,其效果开始提升:

在这里插入图片描述
Closing dataloader mosaic
最终的训练结果如下:

在这里插入图片描述

从结果来看,似乎并不太理想,当然这很大程度上是由于博主没有使用YOLOv8的预训练模型所导致的。

随后,博主让其加载了YOLOv8的预训练权重:

	from ultralytics import YOLO
    model = YOLO("/ultralytics\cfg\models/v8\yolov8.yaml")  # build a new model from YAML
    model.load("yolov8n.pt")
    results = model.train(data="/ultralytics\cfg\datasets\cocomine.yaml", epochs=100, imgsz=640,batch=8,workers=2)

由开始的训练效果可知,使用了预训练模型后,起点高了那么一点,至于最终的效果,敬请期待。

在这里插入图片描述

总结

经过上述过程,我们便将SCconv模块插入到了YOLOv8模型中,当然这个改进是十分简单的,我们可以对SCconv模块再进行改进,让其更加的适配YOLOv8检测模型,同时我们需要记住的是,写一篇论文不但要求你的创新点要新颖,如何去描述你的创新更是重中之重,正所谓做的好不如说的好(博主还是希望能够既做的好,又说得好)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1971276.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM: 方法调用

文章目录 一、介绍二、方法调用的原理1、静态绑定2、动态绑定&#xff08;1&#xff09;介绍&#xff08;2&#xff09;原理 一、介绍 在JVM中&#xff0c;一共有五个字节码指令可以执行方法调用&#xff1a; invokestatic: 调用静态方法。invokespecial&#xff1a;调用对象…

大模型参与城市规划中的应用

人工智能咨询培训老师叶梓 转载标明出处 传统的城市规划往往依赖于专业规划师的经验和判断&#xff0c;耗时耗力&#xff0c;且难以满足居民多样化的需求。近年来&#xff0c;大模型&#xff08;LLMs&#xff09;的崛起为城市规划领域带来了新的机遇。清华大学电子工程系的Zhil…

微信小程序多端框架实现app内自动升级

多端框架生成的app&#xff0c;如果实现app内自动升级&#xff1f; 一、Android 实现app自动升级&#xff0c;华为应用市场 1、获取 应用市场地址 下载地址 2、在微信开放平台进行配置 应用下载地址&#xff1a;应用市场点击分享&#xff0c;里面有一个复制连接功能 应用市…

XMLDecoder反序列化

XMLDecoder反序列化 基础知识 就简单讲讲吧&#xff0c;就是为了解析xml内容的 一般我们的xml都是标签属性这样的写法 比如person对象以xml的形式存储在文件中 在decode反序列化方法后&#xff0c;控制台成功打印出反序列化的对象。 就是可以根据我们的标签识别是什么成分…

QT多媒体编程(一)——音频编程知识详解及MP3音频播放器Demo

目录 引言 一、QtMultimedia模块简介 主要类和功能 二、QtMultimedia相关类及函数解析 QAudioInput QAudioOutput QAudioFormat QMediaPlayer QMediaPlaylist QCamera 三、音频项目实战Demo UI界面 核心代码 运行结果 四、结论 引言 在数字时代&#xff0c;音频…

ArcGIS for js 分屏(vue项目)

一、引入依赖 import {onMounted, ref} from "vue"; import Map from "arcgis/core/Map"; import MapView from "arcgis/core/views/MapView"; import WebTileLayer from "arcgis/core/layers/WebTileLayer"; 二、页面布局 <tem…

22. Hibernate 性能之缓存

1. 前言 本节和大家一起聊聊性能优化方案之&#xff1a;缓存。通过本节学习&#xff0c;你将了解到&#xff1a; 什么是缓存&#xff0c;缓存的作用&#xff1b;HIbernate 中的缓存级别&#xff1b;如何使用缓存。 2. 缓存 2.1 缓存是什么 现实世界里&#xff0c;缓存是一个…

纪念二2024.07 federated-解决mysql跨库联表问题

若需要创建FEDERATED引擎表&#xff0c;则目标端实例要开启FEDERATED引擎。从MySQL5.5开始FEDERATED引擎默认安装 只是没有启用&#xff0c;进入命令行输入 show engines ; FEDERATED行状态为NO。 mysql安装配置文件 一、连接工具查看是否开启federated show engines 二、m…

VMware Workstation17 安装 CentOS7 教程

今天给伙伴们分享一下VMware Workstation17 安装 CentOS7 教程&#xff0c;希望看了有所收获。 我是公众号「想吃西红柿」「云原生运维实战派」作者&#xff0c;对云原生运维感兴趣&#xff0c;也保持时刻学习&#xff0c;后续会分享工作中用到的运维技术&#xff0c;在运维的路…

JS【详解】内存泄漏(含泄漏场景、避免方案、检测方法),垃圾回收 GC (含引用计数、标记清除、标记整理、分代式垃圾回收)

内存泄漏 在执行一个长期运行的应用程序时&#xff0c;应用程序分配的内存没有被释放&#xff0c;导致可用内存逐渐减少&#xff0c;最终可能导致浏览器崩溃或者应用性能严重下降的情况&#xff0c;即 JS 内存泄漏 可能导致内存泄漏的场景 不断创建全局变量未及时清理的闭包&…

Graylog 收集网络设备日志的详细配置指南

需求:网络日志接入到日志服务中,做日志的备份和查询。 交换机或是其它网络设备日志需要接入到graylog日志服务中进行备份和查询。 软件版本 graylog5.1 架构图 一、添加inputs 接受日志信息 二、编辑inputs 配置 第1个红框 title 代表通道的名称,您可以根据需要自由定义…

【CTF-Crypto】格密码基础(例题较多,非常适合入门!)

格密码相关 文章目录 格密码相关格密码基本概念&#xff08;属于后量子密码&#xff09;基础的格运算&#xff08;行列式运算&#xff09;SVP&#xff08;shortest Vector Problem&#xff09;最短向量问题CVP&#xff08;Closet Vector Problem&#xff09;最近向量问题 做题要…

浏览器用户文件夹详解 - ShortCuts(六)

1. Shortcuts简介 1.1 什么是Shortcuts文件&#xff1f; Shortcuts文件是Chromium浏览器中用于存储用户创建的快捷方式信息的一个重要文件。每当用户在浏览器中创建快捷方式时&#xff0c;这些信息都会被记录在Shortcuts文件中。通过这些记录&#xff0c;用户可以方便地快速访…

《小迪安全》学习笔记02

域名默认存放目录和IP默认存放目录不一样。 IP地址是WWW文件里的&#xff0c;域名访问是WWW里的一个子目录里的&#xff08;比如是blog&#xff09;。 Nmap: Web源码拓展 拿到一个网站的源码&#xff0c;要分析这几个方面↑。 不同类型产生的漏洞类型也不一样 在网站中&…

MSPM0G3507_2024电赛自动行驶小车(H题)_问题与感悟

这次电赛题目选的简单了&#xff0c;还规定不能使用到摄像头&#xff0c;这让我之前学习的Opencv 4与树莓派无用武之地了&#xff0c;但我当时对于三子棋题目饶有兴趣&#xff0c;但架不住队友想稳奖&#xff0c;只能选择这个H题了...... 之后我还想抽空将这个E题三子棋题目做…

快手批量取关

目录 突然发现快手木有批量取关功能&#xff0c;没有功能就创造功能 执行代码中 逐渐变少 后面关注列表没人了&#xff0c;总数还有32&#xff0c;不知道是不是帮测出个bug还是咋的(^_^) 突然发现快手木有批量取关功能&#xff0c;没有功能就创造功能 刚开始1000多人 执行代…

中间件之异步通讯组件rocketmq入门

一、概述 1.1介绍 RocketMQ是阿里巴巴2016年MQ中间件&#xff0c;使用Java语言开发&#xff0c;RocketMQ 是一款开源的分布式消息系统&#xff0c;基于高可用分布式集群技术&#xff0c;提供低延时的、高可靠的消息发布与订阅服务。同时&#xff0c;广泛应用于多个领域&#…

暖水袋 亚马逊日本站认证 PSE认证步骤

暖水袋是用来加热取暖的生活用品&#xff0c;有内置热水来加热的类型和利用微波炉加热后使用的类型等。内置热水的暖水袋有塑料制、橡胶制、陶器制等多种类型&#xff0c;但是利用加热石头而不是利用热水来取暖的产品类型为审查对象外商品。 审查资料 每个 ASIN 的文件&#x…

成为AI产品经理,为何应选择LLMs方向?

前言 随着人工智能&#xff08;AI&#xff09;技术的快速发展&#xff0c;越来越多的人开始考虑如何在这个领域找到自己的位置。对于那些希望成为AI产品经理的人来说&#xff0c;选择LLMs&#xff08;Large Language Models&#xff0c;大型语言模型&#xff09;方向是一个非常…

mac下通过brew安装mysql的环境调试

mac安装mysql 打开终端&#xff0c;运行命令&#xff08;必须已经装过homebrew哦&#xff09;&#xff1a; 安装brewbin/zsh -c "$(curl -fsSL https://gitee.com/cunkai/HomebrewCN/raw/master/Homebrew.sh)"已安装brew直接运行&#xff1a;brew install mysql8.0报…