爆改YOLOv8|利用SCConv改进yolov8-即轻量又涨点

news2024/9/21 14:46:21

1,本文介绍

SCConv(空间和通道重构卷积)是一种高效的卷积模块,旨在优化卷积神经网络(CNN)的性能,通过减少空间和通道的冗余来降低计算资源的消耗。该模块由两个核心组件构成:

  1. 空间重构单元(SRU):通过分离和重构的方式,SRU 有效减少空间冗余。

  2. 通道重构单元(CRU):利用分割-变换-融合策略,CRU 旨在降低通道冗余

关于SCConv的详细介绍可以看论文:SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy (thecvf.com)

本文将讲解如何将SCConv融合进yolov8

话不多说,上代码!

2, 将SCConv融合进yolov8

2.1 步骤一

找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个SCConv.py文件,文件名字可以根据你自己的习惯起,然后将SCConv的核心代码复制进去.

import torch
import torch.nn.functional as F
import torch.nn as nn


__all__ = ['C2f_SCConv']
 
class GroupBatchnorm2d(nn.Module):
    def __init__(self, c_num: int,
                 group_num: int = 16,
                 eps: float = 1e-10
                 ):
        super(GroupBatchnorm2d, self).__init__()
        assert c_num >= group_num
        self.group_num = group_num
        self.weight = nn.Parameter(torch.randn(c_num, 1, 1))
        self.bias = nn.Parameter(torch.zeros(c_num, 1, 1))
        self.eps = eps
 
    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.group_num, -1)
        mean = x.mean(dim=2, keepdim=True)
        std = x.std(dim=2, keepdim=True)
        x = (x - mean) / (std + self.eps)
        x = x.view(N, C, H, W)
        return x * self.weight + self.bias
 
 
class SRU(nn.Module):
    def __init__(self,
                 oup_channels: int,
                 group_num: int = 16,
                 gate_treshold: float = 0.5,
                 torch_gn: bool = True
                 ):
        super().__init__()
 
        self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d(
            c_num=oup_channels, group_num=group_num)
        self.gate_treshold = gate_treshold
        self.sigomid = nn.Sigmoid()
 
    def forward(self, x):
        gn_x = self.gn(x)
        w_gamma = self.gn.weight / sum(self.gn.weight)
        w_gamma = w_gamma.view(1, -1, 1, 1)
        reweigts = self.sigomid(gn_x * w_gamma)
        # Gate
        w1 = torch.where(reweigts > self.gate_treshold, torch.ones_like(reweigts), reweigts)  # 大于门限值的设为1,否则保留原值
        w2 = torch.where(reweigts > self.gate_treshold, torch.zeros_like(reweigts), reweigts)  # 大于门限值的设为0,否则保留原值
        x_1 = w1 * x
        x_2 = w2 * x
        y = self.reconstruct(x_1, x_2)
        return y
 
    def reconstruct(self, x_1, x_2):
        x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)
        x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)
        return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)
 
 
class CRU(nn.Module):
    '''
    alpha: 0<alpha<1
    '''
 
    def __init__(self,
                 op_channel: int,
                 alpha: float = 1 / 2,
                 squeeze_radio: int = 2,
                 group_size: int = 2,
                 group_kernel_size: int = 3,
                 ):
        super().__init__()
        self.up_channel = up_channel = int(alpha * op_channel)
        self.low_channel = low_channel = op_channel - up_channel
        self.squeeze1 = nn.Conv2d(up_channel, up_channel // squeeze_radio, kernel_size=1, bias=False)
        self.squeeze2 = nn.Conv2d(low_channel, low_channel // squeeze_radio, kernel_size=1, bias=False)
        # up
        self.GWC = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=group_kernel_size, stride=1,
                             padding=group_kernel_size // 2, groups=group_size)
        self.PWC1 = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=1, bias=False)
        # low
        self.PWC2 = nn.Conv2d(low_channel // squeeze_radio, op_channel - low_channel // squeeze_radio, kernel_size=1,
                              bias=False)
        self.advavg = nn.AdaptiveAvgPool2d(1)
 
    def forward(self, x):
        # Split
        up, low = torch.split(x, [self.up_channel, self.low_channel], dim=1)
        up, low = self.squeeze1(up), self.squeeze2(low)
        # Transform
        Y1 = self.GWC(up) + self.PWC1(up)
        Y2 = torch.cat([self.PWC2(low), low], dim=1)
        # Fuse
        out = torch.cat([Y1, Y2], dim=1)
        out = F.softmax(self.advavg(out), dim=1) * out
        out1, out2 = torch.split(out, out.size(1) // 2, dim=1)
        return out1 + out2
 
 
class ScConv(nn.Module):
    def __init__(self,
                 op_channel: int,
                 group_num: int = 4,
                 gate_treshold: float = 0.5,
                 alpha: float = 1 / 2,
                 squeeze_radio: int = 2,
                 group_size: int = 2,
                 group_kernel_size: int = 3,
                 ):
        super().__init__()
        self.SRU = SRU(op_channel,
                       group_num=group_num,
                       gate_treshold=gate_treshold)
        self.CRU = CRU(op_channel,
                       alpha=alpha,
                       squeeze_radio=squeeze_radio,
                       group_size=group_size,
                       group_kernel_size=group_kernel_size)
 
    def forward(self, x):
        x = self.SRU(x)
        x = self.CRU(x)
        return x
class SCConv_yolov8(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, g=1, dilation=1):
        super().__init__()
        self.conv = Conv(in_channels, out_channels, k=1)
 
        self.RFAConv = ScConv(out_channels)
 
        self.bn = nn.BatchNorm2d(out_channels)
 
        self.gelu = nn.GELU()
 
    def forward(self, x):
        x = self.conv(x)
 
        x = self.RFAConv(x)
 
        x = self.gelu(self.bn(x))
        return x

class Bottleneck_SCConv(nn.Module):
    """Standard bottleneck."""
 
    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
        expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = SCConv_yolov8(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2
 
    def forward(self, x):
        """'forward()' applies the YOLO FPN to input data."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
 
 
class C2f_SCConv(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
        expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck_SCConv(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
 
    def forward(self, x):
        """Forward pass through C2f layer."""
        x = self.cv1(x)
        x = x.chunk(2, 1)
        y = list(x)
        # y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
 
    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

2.2 步骤二

在task.py导入我们的模块

from .modules.SCConv import C2f_SCConv

2.3 步骤三

在task.py的parse_model方法里面注册我们的模块

这里需要注意在两个位置进行添加,不要漏了

到此注册成功,复制后面的yaml文件直接运行即可

yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f_SCConv, [256]]  # 15 (P3/8-small)
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f_SCConv, [512]]  # 18 (P4/16-medium)
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f_SCConv, [1024]]  # 21 (P5/32-large)
 
  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

# 关于SCConv的使用,可以直接做卷积使用,也可以放在c2f或者bottleneck中做融合

不知不觉已经看完了哦,动动小手留个点赞吧--_--

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2109204.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

斯坦福UE4 C++课学习补充25:寻路EQS

文章目录 一、创建EQS二、修改行为树三、查询上下文 一、创建EQS 场景查询系统EQS&#xff1a;可用于收集场景相关的数据。然后该系统可以使用生成器&#xff0c;通过各种用户定义的测试就这些数据提问&#xff0c;返回符合所提问题类型的最佳项目Item。 EQS的一些使用范例包…

Unity【Colliders碰撞器】和【Rigibody刚体】的应用——小球反弹效果

目录 Collider 2D 定义&#xff1a; 类型&#xff1a; Rigidbody 2D 定义&#xff1a; 属性和行为&#xff1a; 运动控制&#xff1a; 碰撞检测&#xff1a; 结合使用 实用检测 延伸拓展 1、在Unity中优化Collider 2D和Rigidbody 2D的性能 2、Unity中Collider 2D…

香橙派列出附近所有的WiFi

使用 nmcli nmcli 是 NetworkManager 的命令行工具&#xff0c;它可以用来检索和管理网络连接。 nmcli device wifi list这个命令会列出所有周围的WiFi网络。

社区电商系统源码之卷轴模式:商业模式分析

随着互联网技术的发展&#xff0c;电商平台的竞争日益激烈&#xff0c;如何留住用户并提升用户粘性成为了各大电商平台关注的重点。卷轴模式作为一种新兴的用户参与和激励机制&#xff0c;在社区电商系统中得到了广泛的应用。本文将从技术角度探讨卷轴模式在社区电商系统中的实…

rust 命令行工具rsup管理前端npm依赖

学习了一年的 rust 了&#xff0c;但是不知道用来做些什么&#xff0c;也没能赋能到工作中&#xff0c;现在前端基建都已经开始全面进入 rust 领域了&#xff0c;rust 的前端生态是越来越好。但是自己奈何水平不够&#xff0c;想贡献点什么&#xff0c;无从下手。 遂想自己捣鼓…

Leetcode3256. 放三个车的价值之和最大 I

Every day a Leetcode 题目来源&#xff1a;3256. 放三个车的价值之和最大 I 解法1&#xff1a;贪心 从大到下排序矩阵所有值, 记为数组v。 转化此题&#xff1a;从r*c个数中选取3个数分别给到车1&#xff0c;车2&#xff0c;和车3&#xff0c;使得符合条件的三数之和最大。…

rancher upgrade 【rancher 升级】

文章目录 1. 背景2. 下载3. 安装4. 检查5. 测试5.1 创建项目5.2 创建应用5.3 删除集群5.4 注册集群 1. 背景 rancher v2.8.2 升级 v2.9.1 2. 下载 下载charts helm repo add rancher-latest https://releases.rancher.com/server-charts/latest helm repo update helm fetc…

NIO、Reactor模式与直接内存

1.NIO NIO有三大核心组件&#xff1a;Selector选择器、Channel管道、buffer缓冲区。、 1.1Selector Selector的英文含义是“选择器”&#xff0c;也可以称为为“轮询代理器”、“事件订阅器”、“channel容器管理机”都行。 Java NIO的选择器允许一个单独的线程来监视多个输…

鸿蒙MPChart图表自定义(四)短刻度线

对于图表中的x轴效果&#xff0c;我们有时想要实现如图所示的特定刻度线。若需绘制x轴的短刻度线&#xff0c;我们可以利用现有资源&#xff0c;将原本的网格线稍作修改&#xff0c;只需绘制一条简洁的短线即可达到目的。 具体的方法就是写一个类MyXAxisRender继承自XAxisRend…

iOS——runLoop

什么是runloop RunLoop实际上就是一个对象&#xff0c;这个对象管理了其需要处理的事件和消息&#xff0c;并提供了一个入口函数来执行相应的处理逻辑。线程执行了这个函数后&#xff0c;就会处于这个函数内部的循环中&#xff0c;直到循环结束&#xff0c;函数返回。 RunLoo…

【转载】golang内存分配

Go 的分配采用了类似 tcmalloc 的结构.特点: 使用一小块一小块的连续内存页, 进行分配某个范围大小的内存需求. 比如某个连续 8KB 专门用于分配 17-24 字节,以此减少内存碎片. 线程拥有一定的 cache, 可用于无锁分配. 同时 Go 对于 GC 后回收的内存页, 并不是马上归还给操作系…

Android13 Hotseat客制化--Hotseat修改布局、支持滑动、去掉开机弹动效果、禁止创建文件夹

需求如题&#xff0c;实现效果如下 &#xff1a; 固定Hotseat的padding位置、固定高度 step1 在FeatureFlags.java中添加flag,以兼容原生态代码 public static final boolean STATIC_HOTSEAT_PADDING true;//hotseat area fixed step2:在dimens.xml中添加padding值和高度值…

信息系统安全保障

关注这个证书的其他相关笔记&#xff1a;NISP 一级 —— 考证笔记合集-CSDN博客 0x01&#xff1a;信息系统 信息系统是具有集成性的系统&#xff0c;每一个组织中信息流动的综合构成一个信息系统。信息系统是根据一定的需要进行输入、系统控制、数据处理、数据存储与输出等活动…

职场关系课:职场上的基本原则(安全原则、进步原则、收益原则、逃生舱原则)

文章目录 引言安全原则进步原则收益原则逃生舱原则引言 职场上的王者,身体里都应该有三个灵魂: 一个文臣,谨小慎微,考虑风险; 一个武将,积极努力,谋求胜利; 一个商人,精打细算,心中有数。 安全原则 工作安全:保住自己的工作和位置信用安全:保住个人的信用,如果领…

《征服数据结构》差分数组

摘要&#xff1a; 1&#xff0c;差分数组的介绍 2&#xff0c;二维差分数组的介绍 1&#xff0c;差分数组的介绍 差分数组主要是操作区间的&#xff0c;关于区间操作的数据结构比较多&#xff0c;除了前面讲的《稀疏表》&#xff0c;还有树状数组&#xff0c;线段树&#xff0c…

高德地图SDK Android版开发 10 InfoWindow

高德地图SDK Android版开发 10 InfoWindow 前言相关类和方法默认样式Marker类AMap类AMap.OnInfoWindowClickListener 接口 自定义样式(视图)AMap 类AMap.ImageInfoWindowAdapter 接口 自定义样式(Image)AMap.ImageInfoWindowAdapter 接口 示例界面布局MapInfoWindow类常量成员变…

215篇【大模型医疗】论文合集(附PDF)

ChatGPT的横空出世引发了新一轮生成式大模型热潮&#xff0c;作为最新技术的"试验场"&#xff0c;医疗也成为众多大模型的热门首选。 我整理了215篇医疗和大模型的论文&#xff0c;供大家学习和参考。 领215篇医疗和大模型论文

欧拉下搭建第三方软件仓库—docker

1.创建新的文件内容 切换目录到etc底下的yum.repos.d目录&#xff0c;创建docker-ce.repo文件 [rootlocalhost yum.repos.d]# cd /etc/yum.repos.d/ [rootlocalhost yum.repos.d]# vim docker-ce.repo 编辑文件,使用阿里源镜像源&#xff0c;镜像源在编辑中需要单独复制 h…

vue3的学习(2)

属性绑定 1.将一个容器中的class和id使用vue用法赋上具体的值&#xff0c;这样就可以动态的给容器添加上自己想要给其添加的class或者id或者title。 2.关键语法&#xff0c;在容器中的class或者id或者title前面加上 "v-bind:"&#xff0c;当加上"v-bind关键语…

计算机网络基础 - 应用层(2)

计算机网络基础 应用层FTP 与 EMail文件传输协议 FTP电子邮件 EMail主要组成部分SMTP概述SMTP 与 HTTP1.1 邮件报文格式报文格式多媒体扩展 MIME 邮件访问协议概述POP3IMAP DNS概述域名结构工作机理集中式设计分布式、层次数据库根 DNS 服务器顶级域 DNS 服务器权威 DNS 服务器…