目标检测算法改进系列之添加SCConv空间和通道重构卷积

news2025/1/18 10:00:51

SCConv-空间和通道重构卷积

SCConv(空间和通道重构卷积)的高效卷积模块,以减少卷积神经网络(CNN)中的空间和通道冗余。SCConv旨在通过优化特征提取过程,减少计算资源消耗并提高网络性能。该模块包括两个单元:
1.空间重构单元(SRU):SRU通过分离和重构方法来减少空间冗余。
2.通道重构单元(CRU):CRU采用分割-变换-融合策略来减少通道冗余。

论文地址:SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy

SCConv结构

YOLOv8添加SCConv卷积

SCConv代码实现

import torch
import torch.nn.functional as F
import torch.nn as nn
 
 
class GroupBatchnorm2d(nn.Module):
    def __init__(self, c_num: int,
                 group_num: int = 16,
                 eps: float = 1e-10
                 ):
        super(GroupBatchnorm2d, self).__init__()
        assert c_num >= group_num
        self.group_num = group_num
        self.weight = nn.Parameter(torch.randn(c_num, 1, 1))
        self.bias = nn.Parameter(torch.zeros(c_num, 1, 1))
        self.eps = eps
 
    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.group_num, -1)
        mean = x.mean(dim=2, keepdim=True)
        std = x.std(dim=2, keepdim=True)
        x = (x - mean) / (std + self.eps)
        x = x.view(N, C, H, W)
        return x * self.weight + self.bias
 
 
class SRU(nn.Module):
    def __init__(self,
                 oup_channels: int,
                 group_num: int = 16,
                 gate_treshold: float = 0.5,
                 torch_gn: bool = True
                 ):
        super().__init__()
 
        self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d(
            c_num=oup_channels, group_num=group_num)
        self.gate_treshold = gate_treshold
        self.sigomid = nn.Sigmoid()
 
    def forward(self, x):
        gn_x = self.gn(x)
        w_gamma = self.gn.weight / sum(self.gn.weight)
        w_gamma = w_gamma.view(1, -1, 1, 1)
        reweigts = self.sigomid(gn_x * w_gamma)
        # Gate
        w1 = torch.where(reweigts > self.gate_treshold, torch.ones_like(reweigts), reweigts)  # 大于门限值的设为1,否则保留原值
        w2 = torch.where(reweigts > self.gate_treshold, torch.zeros_like(reweigts), reweigts)  # 大于门限值的设为0,否则保留原值
        x_1 = w1 * x
        x_2 = w2 * x
        y = self.reconstruct(x_1, x_2)
        return y
 
    def reconstruct(self, x_1, x_2):
        x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)
        x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)
        return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)
 
 
class CRU(nn.Module):
    '''
    alpha: 0<alpha<1
    '''
 
    def __init__(self,
                 op_channel: int,
                 alpha: float = 1 / 2,
                 squeeze_radio: int = 2,
                 group_size: int = 2,
                 group_kernel_size: int = 3,
                 ):
        super().__init__()
        self.up_channel = up_channel = int(alpha * op_channel)
        self.low_channel = low_channel = op_channel - up_channel
        self.squeeze1 = nn.Conv2d(up_channel, up_channel // squeeze_radio, kernel_size=1, bias=False)
        self.squeeze2 = nn.Conv2d(low_channel, low_channel // squeeze_radio, kernel_size=1, bias=False)
        # up
        self.GWC = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=group_kernel_size, stride=1,
                             padding=group_kernel_size // 2, groups=group_size)
        self.PWC1 = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=1, bias=False)
        # low
        self.PWC2 = nn.Conv2d(low_channel // squeeze_radio, op_channel - low_channel // squeeze_radio, kernel_size=1,
                              bias=False)
        self.advavg = nn.AdaptiveAvgPool2d(1)
 
    def forward(self, x):
        # Split
        up, low = torch.split(x, [self.up_channel, self.low_channel], dim=1)
        up, low = self.squeeze1(up), self.squeeze2(low)
        # Transform
        Y1 = self.GWC(up) + self.PWC1(up)
        Y2 = torch.cat([self.PWC2(low), low], dim=1)
        # Fuse
        out = torch.cat([Y1, Y2], dim=1)
        out = F.softmax(self.advavg(out), dim=1) * out
        out1, out2 = torch.split(out, out.size(1) // 2, dim=1)
        return out1 + out2
 
 
class ScConv(nn.Module):
    def __init__(self,
                 op_channel: int,
                 group_num: int = 4,
                 gate_treshold: float = 0.5,
                 alpha: float = 1 / 2,
                 squeeze_radio: int = 2,
                 group_size: int = 2,
                 group_kernel_size: int = 3,
                 ):
        super().__init__()
        self.SRU = SRU(op_channel,
                       group_num=group_num,
                       gate_treshold=gate_treshold)
        self.CRU = CRU(op_channel,
                       alpha=alpha,
                       squeeze_radio=squeeze_radio,
                       group_size=group_size,
                       group_kernel_size=group_kernel_size)
 
    def forward(self, x):
        x = self.SRU(x)
        x = self.CRU(x)
        return x
 
 
if __name__ == '__main__':
    x = torch.randn(1, 32, 16, 16)
    model = ScConv(32)
    print(model(x).shape)

SCConv嵌入时额外添加调用函数

YOLOv8中直接嵌入会报错而且参数对不上,所以需要额外定义一个函数作为中转

class SCConv_yolov8(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, g=1, dilation=1):
        super().__init__()
        self.conv = Conv(in_channels, out_channels, k=1)
 
        self.RFAConv = ScConv(out_channels)
 
        self.bn = nn.BatchNorm2d(out_channels)
 
        self.gelu = nn.GELU()
 
    def forward(self, x):
        x = self.conv(x)
 
        x = self.RFAConv(x)
 
        x = self.gelu(self.bn(x))
        return x

将SCConv嵌入C2f与Bottleneck模块

class Bottleneck_SCConv(nn.Module):
    """Standard bottleneck."""
 
    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
        expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = SCConv_yolov8(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2
 
    def forward(self, x):
        """'forward()' applies the YOLO FPN to input data."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
 
 
class C2f_SCConv(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
        expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck_SCConv(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
 
    def forward(self, x):
        """Forward pass through C2f layer."""
        x = self.cv1(x)
        x = x.chunk(2, 1)
        y = list(x)
        # y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
 
    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

参考案例

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f_SCConv, [256]]  # 15 (P3/8-small)
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f_SCConv, [512]]  # 18 (P4/16-medium)
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f_SCConv, [1024]]  # 21 (P5/32-large)
 
  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1280679.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【小沐学Python】Python实现Web服务器(Flask+celery,生产者-消费者)

文章目录 1、简介2、安装和下载2.1 flask2.2 celery2.3 redis 3、功能开发3.1 创建异步任务的方法3.1.1 使用默认的参数3.1.2 指定相关参数3.1.3 自定义Task基类 3.2 调用异步任务的方法3.2.1 app.send_task3.2.2 Task.delay3.2.3 Task.apply_async 3.3 获取任务结果和状态 4、…

Spring cloud - gateway

什么是Spring Cloud Gateway 先去看下官网的解释&#xff1a; This project provides an API Gateway built on top of the Spring Ecosystem, including: Spring 6, Spring Boot 3 and Project Reactor. Spring Cloud Gateway aims to provide a simple, yet effective way t…

javaweb mybatis(手动jar包)

基础&#xff1a;https://blog.csdn.net/qq_67832732/article/details/134764134 条件查询 在映射文件的SQL配置中配置参数 使用parameterType来指定参数类型 使用#{参数名}来接收参数的值 parameterType"string" 表示sql语句需要一个参数&#xff0c;类型为字符…

计算UDP报文CRC校验的总结

概述 因公司项目需求&#xff0c;遇到需要发送带UDP/IP头数据包的功能&#xff0c;经过多次尝试顺利完成&#xff0c;博文记录以备忘。 环境信息 操作系统 ARM64平台的中标麒麟Kylin V10 工具 tcpdump、wireshark、vscode 原理 请查看大佬的博文 UDP伪包头定义&#x…

2023年【G1工业锅炉司炉】考试试题及G1工业锅炉司炉模拟考试题库

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2023年【G1工业锅炉司炉】考试试题及G1工业锅炉司炉模拟考试题库&#xff0c;包含G1工业锅炉司炉考试试题答案和解析及G1工业锅炉司炉模拟考试题库练习。安全生产模拟考试一点通结合国家G1工业锅炉司炉考试最新大纲及…

【开源】基于Vue和SpringBoot的校园二手交易系统

项目编号&#xff1a; S 009 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S009&#xff0c;文末获取源码。} 项目编号&#xff1a;S009&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 二手商品档案管理模…

Docker容器(一)概述

一、虚拟化概述 1.1引⼊虚拟化技术的必要性 服务器只有5%的时间是在⼯作的&#xff1b;在其它时间服务器都处于“休眠”状态. 虚拟化前 每台主机⼀个操作系统; 软硬件紧密结合; 在同⼀个主机上运⾏多个应⽤程序通常会遭遇冲突; 系统的资源利⽤率低; 硬件成本⾼昂⽽且不够灵活…

TIA西门子博途V19安装教程及注意事项

TIA西门子博途V19安装教程及注意事项 前提条件: TIA Portal V19需要.Net Framework 3.5环境,所以在安装TIA V19之前要先安装它。 如下图所示,否则可能会提示报错信息: 大家可以在控制面板中的程序和功能中检查是否已经安装,如果没有,可以参考以下步骤自行安装:

OpenWrt作为旁路由(网关)配置

目录 背景前提条件环境操作步骤物理层连接设置与主路由同一网段禁用IPv6取消LAN接口桥接防火墙配置 背景 本文简介如何配置OpenWrt&#xff0c;使其作为旁路由&#xff08;网关&#xff09;运行。 旁路由大概有以下这几种工作方式&#xff1a; 主路由开DHCP&#xff0c;网关未…

十、FreeRTOS之FreeRTOS任务调度

本节需要掌握的内容如下&#xff1a; 1&#xff0c;开启任务调度器&#xff08;熟悉&#xff09; 2&#xff0c;启动第一个任务&#xff08;熟悉&#xff09; 3&#xff0c;任务切换&#xff08;掌握&#xff09; 一&#xff0c;开启任务调度器&#xff08;熟悉&#xff09…

java高校实验室排课学生考勤系统springboot+vue

随着各高校办学规模的迅速扩大,学科专业的不断拓宽,传统的实验教学和实验室管理方法已经不能适应学校管理的要求,特别是化学实验室的管理,化学实验室仪器药品繁杂多样,管理任务繁重,目前主要使用人工记录方法管理,使用不便,效率低下,而且容易疏漏.时间一长将产生大量的文件和数…

【C++】异常处理 ⑦ ( 异常类的继承层次结构 | 抛出 / 捕获 多个类型异常对象 | 抛出子类异常对象 / 捕获并处理 父类异常对象 )

文章目录 一、抛出 / 捕获 多个类型异常对象1、抛出 / 捕获 多个类型异常对象2、操作弊端3、完整代码示例 二、异常类的继承层次结构1、抛出子类异常对象 / 捕获并处理 父类异常对象2、完整代码示例 - 抛出子类异常对象 / 捕获并处理 父类异常对象 自定义的 异常类 , 可能存在 …

Java数据结构 之 包装类简单认识泛类

生命不息&#xff0c;奋斗不止 目录 1. 什么是包装类&#xff1f; 1.1 装箱和拆箱 1.2 自动装箱和自动拆箱 2. 什么是泛型 3. 引出泛型 3.1 语法 4 泛型类的使用 4.1 语法 4.2 示例 4.3 类型推导(Type Inference) 5. 裸类型(Raw Type) &#xff08;了解&#xff09…

〖大前端 - 基础入门三大核心之JS篇㊻〗- JS + CSS实现动画

说明&#xff1a;该文属于 大前端全栈架构白宝书专栏&#xff0c;目前阶段免费&#xff0c;如需要项目实战或者是体系化资源&#xff0c;文末名片加V&#xff01;作者&#xff1a;不渴望力量的哈士奇(哈哥)&#xff0c;十余年工作经验, 从事过全栈研发、产品经理等工作&#xf…

JSP格式化标签 parseDate将指定时间格式字符串转为真正的date对象

格式化标签最后一个就是 parseDate 作用 将一个日期/时间格式字符串 转为 真正的date时间类型 有点无语 这种 东西应该都是在java中去做的 而不是在java中 这个建议也是做个了解即可 作用不是那么大 基本语法如下 这里 我们 直接编写代码如下 <% page contentType"…

智能优化算法应用:基于天牛须算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于天牛须算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于天牛须算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.天牛须算法4.实验参数设定5.算法结果6.参考文献7.…

2024年AMC8美国初中数学竞赛最后一个月复习指南(附资料)

还有一个半月的时间&#xff0c;2024年AMC8&#xff08;大家默认都直接叫这个比赛的英文名&#xff0c;而不叫中文名美国数学竞赛或美国初中数学竞赛了&#xff09;就要开始了。 有志于在2024年AMC8的比赛中拿到奖项的孩子已经在“磨拳霍霍”了。那么最后一个半月的时间该如何…

智能优化算法应用:基于热交换算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于热交换算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于热交换算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.热交换算法4.实验参数设定5.算法结果6.参考文献7.…

Java线程池的使用和最佳实践

第1章&#xff1a;引言 处理并发问题时&#xff0c;如果每次都新建线程&#xff0c;那系统的压力得有多大&#xff1f;这时候&#xff0c;线程池就像一个英雄一样出现了&#xff0c;它帮我们有效地管理线程&#xff0c;提高资源利用率&#xff0c;降低开销。那么&#xff0c;为…

还搞不懂什么是参数,超参数吗?三分钟快速了解参数与超参数的概念和区别!!!

文章目录 前言一、参数是什么&#xff1f;二、超参数是什么三&#xff0c;常使用的超参数有哪些 前言 参数是模型中可被学习和调整的参数&#xff0c;通过训练数据进行学习和优化&#xff1b; 而超参数则是手动设置的参数&#xff0c;用于控制模型的行为和性能&#xff0c;超…