YOLOv8改进 | 卷积模块 | SAConv可切换空洞卷积

news2024/10/6 14:29:10

秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 


💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡


专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有40+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进——点击即可跳转


许多现代目标检测器通过使用“观察和思考两次”的机制展现了卓越的性能。我们在目标检测的主干设计中探讨了这种机制。在宏观层面上,提出了递归特征金字塔(Recursive Feature Pyramid),它将额外的反馈连接从特征金字塔网络整合到底向上的主干层中。在微观层面上,提出了可切换空洞卷积(Switchable Atrous Convolution),它用不同的空洞率对特征进行卷积,并使用切换函数收集结果。将它们结合起来就得到了DetectoRS,它显著提高了目标检测的性能。 文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

目录

1. 原理

2. 将PConv添加到YOLOv8代码

2.1 PConv代码实现

2.2 更改init.py文件 

2.3 新增yaml文件

2.4 注册模块

2.5 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6. 总结


1. 原理

论文地址:DetectoRS: Detecting Objects with Recursive Feature Pyramid and Switchable Atrous Convolutio——点击即可跳转
官方代码:官方代码仓库——点击即可跳转

可切换空洞卷积 (SAC) 是一种复杂的卷积技术,用于增强深度学习模型的性能,特别是在对象检测和分割等任务中。以下是根据提供的文档对其主要原理的细分:

空洞卷积

空洞卷积,也称为扩张卷积,用于增加过滤器的视野,而不会增加参数数量或计算量。这是通过在连续的过滤器值之间引入零来实现的。插入这些零的速率称为空洞速率 ( r )。对于大小为 ( k \times k ) 的过滤器,有效内核大小变为 ( k + (k-1)(r-1) )。这允许网络捕获多尺度信息,这对于检测不同大小的物体特别有用。

可切换空洞卷积 (SAC)

SAC 以空洞卷积为基础,引入了一种在卷积操作期间动态切换不同空洞率的机制。其核心思想是允许卷积层根据输入特征自适应地选择适当的空洞率。这种适应性有助于模型更好地处理同一图像中不同尺度的对象。

SAC 的组件

SAC 由三个主要组件组成:主 SAC 组件和两个附加在 SAC 组件之前和之后的全局上下文模块。

  • 主 SAC 组件:这是实际发生可切换空洞卷积的核心部分。它涉及两个具有不同空洞率的卷积操作的加权组合。

SAC 的数学公式如下: y = S(x) \cdot \text{Conv}(x, w, 1) + (1 - S(x)) \cdot \text{Conv}(x, w + \Delta w, r) 其中:

  • ( x ) 是输入。

  • ( w )( w + \Delta w ) 是卷积运算的权重。

  • ( r ) 是空洞率。

  • ( S(x) ) 是一个开关函数,实现为一个平均池化层,具有 5x5 内核,后跟一个 1x1 卷积层,确定两个卷积运算之间的平衡。

  • 全局上下文模块:这些模块在 SAC 操作前后为特征添加了图像级上下文,增强了模型理解图像全局结构的能力。

SAC的优势

  1. 多尺度检测:通过动态切换不同的空洞率,SAC 使模型能够更好地检测各种尺寸的物体。

  2. 参数效率:SAC 在不增加额外参数的情况下增加了感受野,同时保持了计算效率。

  3. 适应性:切换功能允许卷积层根据输入进行适应,使模型更加灵活,能够处理不同的图像和物体尺度。

实现示例

SAConv结构

骨干网络(例如 ResNet)中的每个 3x3 卷积层都转换为 SAC,然后执行具有不同空洞率的卷积的加权组合。

综上所述,可切换空洞卷积结合了空洞卷积和动态适应性的优点,增强了深度学习模型,提高了其处理不同物体尺度图像的能力,同时保持参数和计算方面的效率。

2. 将PConv添加到YOLOv8代码

2.1 PConv代码实现

关键步骤一:将下面代码粘贴到在/ultralytics/ultralytics/nn/modules/conv.py中,并在该文件的__all__中添加“SAConv”

class ConvAWS2d(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                                            keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = weight / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return super()._conv_forward(x, weight, None)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.weight_gamma.data.fill_(-1)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            return
        weight = self.weight.data
        weight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,
                                                                 keepdim=True).mean(dim=3, keepdim=True)
        self.weight_beta.data.copy_(weight_mean)
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_gamma.data.copy_(std)


class SAConv2d(ConvAWS2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 s=1,
                 p=None,
                 g=1,
                 d=1,
                 act=True,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=s,
            padding=autopad(kernel_size, p),
            dilation=d,
            groups=g,
            bias=bias)
        self.switch = torch.nn.Conv2d(
            self.in_channels,
            1,
            kernel_size=1,
            stride=s,
            bias=True)
        self.switch.weight.data.fill_(0)
        self.switch.bias.data.fill_(1)
        self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))
        self.weight_diff.data.zero_()
        self.pre_context = torch.nn.Conv2d(
            self.in_channels,
            self.in_channels,
            kernel_size=1,
            bias=True)
        self.pre_context.weight.data.fill_(0)
        self.pre_context.bias.data.fill_(0)
        self.post_context = torch.nn.Conv2d(
            self.out_channels,
            self.out_channels,
            kernel_size=1,
            bias=True)
        self.post_context.weight.data.fill_(0)
        self.post_context.bias.data.fill_(0)

        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        # pre-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch
        avg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")
        avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        out_s = super()._conv_forward(x, weight, None)
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        weight = weight + self.weight_diff
        out_l = super()._conv_forward(x, weight, None)
        out = switch * out_s + (1 - switch) * out_l
        self.padding = ori_p
        self.dilation = ori_d
        # post-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return self.act(self.bn(out))

可切换空洞卷积 (SAconv) 通过动态调整卷积层中的空洞率来增强图像处理,从而实现更有效的多尺度特征提取。 SAconv 处理图像的主要工作流程概述:

SAconv 的主要工作流程

  1. 输入图像: 该过程从输入到神经网络的输入图像开始。

  2. 初始卷积层: 图像首先经过几个标准卷积层,提取边缘、纹理和基本形状等低级特征。这些层通常具有固定的内核大小和步长值。

  3. 可切换空洞卷积 (SAC) 模块: 核心组件 SAC 模块取代了网络主干中的传统卷积层(例如 ResNet)。以下是 SAC 操作的详细分解:

  • 特征提取: 输入特征 ( x ) 首先通过两个不同的卷积操作进行处理,权重分别为 ( w ) 和 ( w + \Delta w ),其中 ( \Delta w )表示对权重的调整。这些卷积具有不同的空洞率,使它们能够捕获多个尺度的特征。

  • 切换函数: 切换函数 S(x)  确定两个卷积之间的平衡。它通常实现为具有 5x5 内核的平均池化层,后跟 1x1 卷积层。切换函数的输出是一组权重,用于平衡两个卷积的贡献。

  • 加权组合: SAC 的最终输出是两个卷积的加权组合: y = S(x) \cdot \text{Conv}(x, w, 1) + (1 - S(x)) \cdot \text{Conv}(x, w + \Delta w, r) 这里,( \text{Conv}(x, w, 1) ) 表示标准卷积,( \text{Conv}(x, w + \Delta w, r) )表示速率为 ( r ) 的空洞卷积。

  1. 全局上下文模块: 在 SAC 模块之前和之后,使用全局上下文模块来合并图像级上下文。这些模块通常涉及全局平均池化和全连接层等操作,以捕获图像的整体结构。这有助于细化特征并为后续层提供更好的上下文。

  2. 中间层: SAC 模块和全局上下文模块的输出被馈送到网络的进一步层,其中可能包括额外的卷积、规范化和激活函数。这些层继续细化和处理 SAC 模块提取的特征。

  3. 输出层: 最后,处理后的特征到达输出层,生成所需的结果。这可能是分类标签、分割掩码或用于对象检测的边界框,具体取决于网络设计用于的任务。

工作流程说明

在提供的文档中,图 4 显示了集成到 ResNet 架构中的 SAC 模块的示例。在此图中,ResNet 主干中的每个 3x3 卷积层都被 SAC 模块替换。此修改允许网络根据输入特征动态调整其感受野,从而提高需要多尺度特征检测的任务的性能。

2.2 更改init.py文件 

关键步骤二:修改modules文件夹下的__init__.py文件,先导入函数

然后在下面的__all__中声明函数

2.3 新增yaml文件

关键步骤三:在 \ultralytics\ultralytics\cfg\models\v8下新建文件 yolov8_SAConv.yaml并将下面代码复制进去

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [ 0.33, 0.25, 1024 ]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [ -1, 1, SAConv2d, [ 64, 3, 2 ] ]  # 0-P1/2
  - [ -1, 1, SAConv2d, [ 128, 3, 2 ] ]  # 1-P2/4
  - [ -1, 3, C2f, [ 128, True ] ]
  - [ -1, 1, SAConv2d, [ 256, 3, 2 ] ]  # 3-P3/8
  - [ -1, 6, C2f, [ 256, True ] ]
  - [ -1, 1, SAConv2d, [ 512, 3, 2 ] ]  # 5-P4/16
  - [ -1, 6, C2f, [ 512, True ] ]
  - [ -1, 1, SAConv2d, [ 1024, 3, 2 ] ]  # 7-P5/32
  - [ -1, 3, C2f, [ 1024, True ] ]
  - [ -1, 1, SPPF, [ 1024, 5 ] ]  # 9

# YOLOv8.0n head
head:
  - [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]
  - [ [ -1, 6 ], 1, Concat, [ 1 ] ]  # cat backbone P4
  - [ -1, 3, C2f, [ 512 ] ]  # 12

  - [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]
  - [ [ -1, 4 ], 1, Concat, [ 1 ] ]  # cat backbone P3
  - [ -1, 3, C2f, [ 256 ] ]  # 15 (P3/8-small)

  - [ -1, 1, SAConv2d, [ 256, 3, 2 ] ]
  - [ [ -1, 12 ], 1, Concat, [ 1 ] ]  # cat head P4
  - [ -1, 3, C2f, [ 512 ] ]  # 18 (P4/16-medium)

  - [ -1, 1, SAConv2d, [ 512, 3, 2 ] ]
  - [ [ -1, 9 ], 1, Concat, [ 1 ] ]  # cat head P5
  - [ -1, 3, C2f, [ 1024 ] ]  # 21 (P5/32-large)

  - [ [ 15, 18, 21 ], 1, Detect, [ nc ] ]  # Detect(P3, P4, P5)

温馨提示:因为本文只是对yolov8基础上添加模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。 


# YOLOv8n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
max_channels: 1024 # max_channels
 
# YOLOv8s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
max_channels: 1024 # max_channels
 
# YOLOv8l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
max_channels: 512 # max_channels
 
# YOLOv8m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
max_channels: 768 # max_channels
 
# YOLOv8x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple
max_channels: 512 # max_channels

2.4 注册模块

关键步骤四:在parse_model函数中进行注册,添加SAConv,

2.5 执行程序

在train.py中,将model的参数路径设置为yolov8_SAConv.yaml的路径

建议大家写绝对路径,确保一定能找到

from ultralytics import YOLO
 
# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from YAML
# model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
 
model = YOLO(r'/projects/ultralytics/ultralytics/cfg/models/v8/yolov8_PConv.yaml')  # build from YAML and transfer weights
 
# Train the model
model.train(batch=16)

🚀运行程序,如果出现下面的内容则说明添加成功🚀 

3. 完整代码分享

https://pan.baidu.com/s/1z2KWa9guFsyHTETM9gXdPA?pwd=kbu5

 提取码: kbu5 

4. GFLOPs

关于GFLOPs的计算方式可以查看百面算法工程师 | 卷积基础知识——Convolution

未改进的YOLOv8nGFLOPs

改进后的GFLOPs

5. 进阶

可以与其他的注意力机制或者损失函数等结合,进一步提升检测效果

6. 总结

可切换空洞卷积 (SAConv) 可动态调整卷积层中的空洞率,通过基于输入特征平衡其贡献的切换函数组合不同的空洞卷积,使网络能够更有效地捕获多尺度特征。这种适应性增强了网络检测不同大小物体的能力,并提高了物体检测和分割等任务的性能,而无需添加额外的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1889824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[数据集][目标检测]螺丝螺母检测数据集VOC+YOLO格式2400张2类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2400 标注数量(xml文件个数):2400 标注数量(txt文件个数):2400 标注…

JavaSE (Java基础):面向对象(上)

8 面向对象 面向对象编程的本质就是:以类的方法组织代码,以对象的组织(封装)数据。 8.1 方法的回顾 package com.oop.demo01;// Demo01 类 public class Demo01 {// main方法public static void main(String[] args) {int c 10…

为什么有些人思考得多,决策反而不好?避免过度拟合的终极指南:决策高手的秘密:灰度认知,黑白决策

在决策过程中,过度关注细节可能导致决策效果不佳,这被称为“过度拟合”。为了避免这种情况,我们需要进行“灰度认知,黑白决策”,即接受不确定性,关注整体趋势,设定明确目标,简化选择…

微信小程序 调色板

注意:是在uniapp中直接使用的一个color-picker插件,改一下格式即可在微信小程序的原生代码中使用 https://github.com/KirisakiAria/we-color-picker 这是插件的地址,使用的话先把这个插件下载下来,找到src,在项目创…

STM32CubeMX实现矩阵按键(HAL库实现)

功能描述: 实现矩阵按键验证,将矩阵按键的按键值,通过串口显示,便于后面使用。 实物图 原理图: 编程原理: 原理很简单,就是通过循环设置引脚为低电平,另外引脚扫描读取电平值&…

[Leetcode 128][Medium] 最长连续序列

目录 题目描述 整体思路 具体代码 题目描述 原题链接 整体思路 首先看到找连续升序排序的最长序列长度,想到对数组进行排序预处理。但是排序算法时间复杂度需要O(nlogn),题目要求时间复杂度为O(n)。因此不能进行排序与处理 接着想到数据结构哈希表&a…

苹果电脑内存满了怎么清理空间垃圾 苹果电脑内存不足怎么办 MacBook优化储存空间

在日常使用苹果电脑过程中,某些用户可能经常会遇到存储空间不足的问题,尤其是当硬盘存储了大量的文件。这不仅影响电脑的运行速度,还可能导致应用程序运行不稳定。 一、节省 MacBook Pro 的空间 苹果电脑的操作系统(macOS&#x…

PDF 生成(6)— 服务化、配置化

当学习成为了习惯,知识也就变成了常识。 感谢各位的 关注、点赞、收藏和评论。 新视频和文章会第一时间在微信公众号发送,欢迎关注:李永宁lyn 文章已收录到 github 仓库 liyongning/blog,欢迎 Watch 和 Star。 回顾 前面我们分…

Android 15 应用适配默认全屏的行为变更(Android V的新特性)

简介 Android V 上默认会使用全面屏兼容方式,影响应用显示,导致应用内跟导航标题重合,无法点击上移的内容。 默认情况下,如果应用以 Android 15(API 级别 35)为目标平台,在搭载 Android 15 的设…

鸿蒙应用笔记

安装就跳过了,一直点点就可以了 配置跳过,就自动下了点东西。 鸿蒙那个下载要12g个内存,大的有点吓人。 里面跟idea没区别 模拟器或者真机运行 真机要鸿蒙4.0,就可以实机调试 直接在手机里面跑,这个牛逼&#xf…

现代智能宠物喂食器方案定制

现代智能宠物喂食器不仅具备定时喂食功能,帮助宠物主人管理宠物的饮食时间和食量,还加入了录音功能和摄像头,使得宠物主人即使不在家也能与宠物保持互动,并实时监控宠物的状况。此外,一些产品还具备紧急预警功能&#…

数据分析:基于聚类的LASSO预测模型包----clustlasso

介绍 clustlasso是结合lasso和cluster-lasso策略的R包,并发表在Interpreting k-mer based signatures for antibiotic resistance prediction。 标准交叉验证lasso分类或回归流程如下: 选择交叉验证数据集(数据分割)&#xff1…

【计算机网络】网络层(作业)

【一】 1、某主机的 IP 地址为 166.199.99.96/19。若该主机向其所在网络发送广播 IP 数据报, 则目的地址可以是(D)。 A. 166.199.99.255B. 166.199.96.255C. 166.199.96.0D. 166.199.127.255 解析: 166.199.99.96/19166.199.0…

YOLOv5初学者问题——用自己的模型预测图片不画框

如题,我在用自己的数据集训练权重模型的时候,在训练完成输出的yolov5-v5.0\runs\train\exp2目录下可以看到,在训练测试的时候是有输出描框的。 但是当我引用训练好的best.fangpt去进行预测的时候, 程序输出的图片并没有描框。根据…

nginx转发的问题

我在项目配置的时候遇到一个问题: 配置了域名转发,且配置了https nginx配置如下: server {listen 443 ssl;server_name yourdomain.com;ssl_certificate /path/to/your/certificate.crt;ssl_certificate_key /path/to/your/private.key;loca…

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【17】认证服务01

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【17】认证服务01 环境搭建验证码倒计时短信服务邮件服务验证码短信形式:邮件形式: 异常机制MD5参考 环境搭建 C:\Windows\System32\drivers\etc\hosts 192.168.…

2024年软件测试面试题,精选100道,内附文档。。。

测试技术面试题 1、我现在有个程序,发现在 Windows 上运行得很慢,怎么判别是程序存在问题还是软硬件系统存在问题? 2、什么是兼容性测试?兼容性测试侧重哪些方面? 3、测试的策略有哪些? 4、正交表测试用…

lua中判断2个表是否相等

当我们获取 table 长度的时候无论是使用 # 还是 table.getn 其都会在索引中断的地方停止计数,而导致无法正确取得 table 的长度,而且还会出现奇怪的现象。例如:t里面有3个元素,但是因为最后一个下表是5和4,却表现出不一…

mac M2芯片系统版本macOS Sonoma14.4.1 Navicat Premium意外退出问题,报错:Translated Report (Full Report Below)

前言 Mac电脑正在使用的navicat客户端突然闪退了!!!! 之前用的好好的。做了可能影响navicat客户端闪退的事情就是把电脑系统升级到了macOS Sonoma14.4.1。后悔莫及~ 现象:navicat能正常创建连接&#xff0…

大数据处理引擎选型之 Hadoop vs Spark vs Flink

随着大数据时代的到来,处理海量数据成为了各个领域的关键挑战之一。为了应对这一挑战,多个大数据处理框架被开发出来,其中最知名的包括Hadoop、Spark和Flink。本文将对这三个大数据处理框架进行比较,以及在不同场景下的选择考虑。…