即插即用模块详解SCConv:用于特征冗余的空间和通道重构卷积

news2024/11/29 12:51:54

目录

一、摘要

二、创新点说明

2.1 Methodology

 2.2SRU for Spatial Redundancy​编辑

2.3CRU for Channel Redundancy

三、实验

3.1基于CIFAR的图像分类

3.2基于ImageNet的图像分类

3.3对象检测

四、代码详解

五、总结


论文:https://openaccess.thecvf.com/content/CVPR2023/papers/Li_SCConv_Spatial_and_Channel_Reconstruction_Convolution_for_Feature_Redundancy_CVPR_2023_paper.pdf

代码:GitHub - cheng-haha/ScConv: SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy

一、摘要

卷积神经网络(cnn)在各种计算机视觉任务中取得了显著的性能,但这是以巨大的计算资源为代价的,部分原因是卷积层提取冗余特征。最近的作品要么压缩训练有素的大型模型,要么探索设计良好的轻量级模型。在本文中,我们尝试利用特征之间的空间和通道冗余来进行CNN压缩,并提出了一种高效的卷积模块,称为SCConv (spatial and channel reconstruction convolution),以减少冗余计算并促进代表性特征的学习。提出的SCConv由空间重构单元(SRU)和信道重构单元(CRU)两个单元组成。SRU采用分离重构的方法来抑制空间冗余,CRU采用分离变换融合的策略来减少信道冗余。此外,SCConv是一种即插即用的架构单元,可直接用于替代各种卷积神经网络中的标准卷积。实验结果表明,SCConv嵌入模型能够通过减少冗余特征来获得更好的性能,并且显著降低了复杂度和计算成本。

论文贡献总结:

       1. 提出了一种空间重构单元SRU,该单元根据权重分离冗余特征并进行重构,以抑制空间维度上的冗余,增强特征的表征能力。
       2. 我们提出了一种信道重构单元,称为CRU,它利用分裂变换和融合策略来减少信道维度的冗余以及计算成本和存储。
        3.我们设计了一种名为SCConv的即插即用操作,将SRU和CRU以顺序的方式组合在一起,以取代标准卷积,用于在各种骨干cnn上操作。结果表明,SCConv可以大大节省计算负荷,同时提高模型在挑战性任务上的性能。

二、创新点说明

2.1 Methodology

SCConv,它由两个单元组成,空间重建单元(SRU)和通道重建单元(CRU),以顺序的方式放置。具体而言,对于瓶颈残差块中的中间输入特征X,我们首先通过SRU运算获得空间细化特征Xw,然后利用CRU运算获得信道细化特征Y。我们在SCConv模块中利用了特征之间的空间冗余和通道冗余,可以无缝集成到任何CNN架构中,以减少中间特征映射之间的冗余并增强CNN的特征表示。

 2.2SRU for Spatial Redundancy

为了利用特征的空间冗余,我们引入了空间重构单元(SRU),如图2所示,它利用了分离和重构操作。分离操作的目的是将信息丰富的特征图与空间内容对应的信息较少的特征图分离开来。

2.3CRU for Channel Redundancy

为了利用特征的信道冗余,我们引入了信道重构单元(CRU),如图3所示,它利用了分裂-转换-融合策略。

三、实验

3.1基于CIFAR的图像分类

3.2基于ImageNet的图像分类

3.3对象检测

四、代码详解

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 的函数库
import torch.nn as nn  # 导入 PyTorch 的神经网络模块

# 自定义 GroupBatchnorm2d 类,实现分组批量归一化
class GroupBatchnorm2d(nn.Module):
    def __init__(self, c_num:int, group_num:int = 16, eps:float = 1e-10):
        super(GroupBatchnorm2d,self).__init__()  # 调用父类构造函数
        assert c_num >= group_num  # 断言 c_num 大于等于 group_num
        self.group_num  = group_num  # 设置分组数量
        self.gamma      = nn.Parameter(torch.randn(c_num, 1, 1))  # 创建可训练参数 gamma
        self.beta       = nn.Parameter(torch.zeros(c_num, 1, 1))  # 创建可训练参数 beta
        self.eps        = eps  # 设置小的常数 eps 用于稳定计算

    def forward(self, x):
        N, C, H, W  = x.size()  # 获取输入张量的尺寸
        x           = x.view(N, self.group_num, -1)  # 将输入张量重新排列为指定的形状
        mean        = x.mean(dim=2, keepdim=True)  # 计算每个组的均值
        std         = x.std(dim=2, keepdim=True)  # 计算每个组的标准差
        x           = (x - mean) / (std + self.eps)  # 应用批量归一化
        x           = x.view(N, C, H, W)  # 恢复原始形状
        return x * self.gamma + self.beta  # 返回归一化后的张量

# 自定义 SRU(Spatial and Reconstruct Unit)类
class SRU(nn.Module):
    def __init__(self,
                 oup_channels:int,  # 输出通道数
                 group_num:int = 16,  # 分组数,默认为16
                 gate_treshold:float = 0.5,  # 门控阈值,默认为0.5
                 torch_gn:bool = False  # 是否使用PyTorch内置的GroupNorm,默认为False
                 ):
        super().__init__()  # 调用父类构造函数

         # 初始化 GroupNorm 层或自定义 GroupBatchnorm2d 层
        self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d(c_num=oup_channels, group_num=group_num)
        self.gate_treshold  = gate_treshold  # 设置门控阈值
        self.sigomid        = nn.Sigmoid()  # 创建 sigmoid 激活函数

    def forward(self, x):
        gn_x        = self.gn(x)  # 应用分组批量归一化
        w_gamma     = self.gn.gamma / sum(self.gn.gamma)  # 计算 gamma 权重
        reweights   = self.sigomid(gn_x * w_gamma)  # 计算重要性权重

        # 门控机制
        info_mask    = reweights >= self.gate_treshold  # 计算信息门控掩码
        noninfo_mask = reweights < self.gate_treshold  # 计算非信息门控掩码
        x_1          = info_mask * x  # 使用信息门控掩码
        x_2          = noninfo_mask * x  # 使用非信息门控掩码
        x            = self.reconstruct(x_1, x_2)  # 重构特征
        return x

    def reconstruct(self, x_1, x_2):
        x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)  # 拆分特征为两部分
        x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)  # 拆分特征为两部分
        return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)  # 重构特征并连接

# 自定义 CRU(Channel Reduction Unit)类
class CRU(nn.Module):
    def __init__(self, op_channel:int, alpha:float = 1/2, squeeze_radio:int = 2, group_size:int = 2, group_kernel_size:int = 3):
        super().__init__()  # 调用父类构造函数

        self.up_channel     = up_channel = int(alpha * op_channel)  # 计算上层通道数
        self.low_channel    = low_channel = op_channel - up_channel  # 计算下层通道数
        self.squeeze1       = nn.Conv2d(up_channel, up_channel // squeeze_radio, kernel_size=1, bias=False)  # 创建卷积层
        self.squeeze2       = nn.Conv2d(low_channel, low_channel // squeeze_radio, kernel_size=1, bias=False)  # 创建卷积层

        # 上层特征转换
        self.GWC            = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=group_kernel_size, stride=1, padding=group_kernel_size // 2, groups=group_size)  # 创建卷积层
        self.PWC1           = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=1, bias=False)  # 创建卷积层

        # 下层特征转换
        self.PWC2           = nn.Conv2d(low_channel // squeeze_radio, op_channel - low_channel // squeeze_radio, kernel_size=1, bias=False)  # 创建卷积层
        self.advavg         = nn.AdaptiveAvgPool2d(1)  # 创建自适应平均池化层

    def forward(self, x):
        # 分割输入特征
        up, low = torch.split(x, [self.up_channel, self.low_channel], dim=1)
        up, low = self.squeeze1(up), self.squeeze2(low)

        # 上层特征转换
        Y1 = self.GWC(up) + self.PWC1(up)

        # 下层特征转换
        Y2 = torch.cat([self.PWC2(low), low], dim=1)

        # 特征融合
        out = torch.cat([Y1, Y2], dim=1)
        out = F.softmax(self.advavg(out), dim=1) * out
        out1, out2 = torch.split(out, out.size(1) // 2, dim=1)
        return out1 + out2

# 自定义 ScConv(Squeeze and Channel Reduction Convolution)模型
class ScConv(nn.Module):
    def __init__(self, op_channel:int, group_num:int = 16, gate_treshold:float = 0.5, alpha:float = 1/2, squeeze_radio:int = 2, group_size:int = 2, group_kernel_size:int = 3):
        super().__init__()  # 调用父类构造函数

        self.SRU = SRU(op_channel, group_num=group_num, gate_treshold=gate_treshold)  # 创建 SRU 层
        self.CRU = CRU(op_channel, alpha=alpha, squeeze_radio=squeeze_radio, group_size=group_size, group_kernel_size=group_kernel_size)  # 创建 CRU 层

    def forward(self, x):
        x = self.SRU(x)  # 应用 SRU 层
        x = self.CRU(x)  # 应用 CRU 层
        return x

if __name__ == '__main__':
    x       = torch.randn(1, 32, 16, 16)  # 创建随机输入张量
    model   = ScConv(32)  # 创建 ScConv 模型
    print(model(x).shape)  # 打印模型输出的形状

五、总结

在本文中,我们提出了一种新的空间和信道重构模块(SCConv),这是一种有效的架构单元,可以降低计算成本和模型存储,同时通过减少标准卷积中广泛存在的空间和信道冗余来提高CNN模型的性能。我们使用两个不同的模块SRU和CRU来减少特征映射中的冗余,在减少大量计算负载的同时实现了相当大的性能改进。此外,SCConv是一个即插即用的模块,可以替代标准的卷积,不需要任何模型架构的调整。此外,各种SOTA方法在图像分类和目标检测方面的大量实验表明,scconvn嵌入模型在性能和模型效率之间取得了更好的平衡。最后,我们希望所提出的方法可以启发研究更有效的建筑设计。

参考:大佬

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1598970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于 Operator 部署 Prometheus 监控 k8s 集群

目录 一、环境准备 1.1 选择版本 1.2 过滤镜像 1.3 修改 yaml 镜像 1.4 移动 *networkPolicy*.yaml 1.5 修改 service 文件 1.6 提前下载镜像并推送到私有镜像仓库 1.7 修改镜像&#xff08;可选&#xff09; 二、执行创建 三、查看 pod 状态 四、访问 prometheus、…

【Spring】依赖注入(DI)时常用的注解@Autowired和@Value

目录 1、Autowired 自动装配 1.1、要实现自动装配不是一定要使用Autowired 1.2、Autowired的特性 &#xff08;1&#xff09;首先会根据类型去spring容器中找(bytype),如果有多个类型&#xff0c;会根据名字再去spring容器中找(byname) &#xff08;2&#xff09;如果根据名…

冯喜运:4.16中东对抗风暴下,黄金原油市场何去何从?

黄金行情走势分析&#xff1a;4小时图布林道开始收口&#xff0c;昨日下探至下轨附近&#xff0c;也是此前的起涨低点2320启稳上升&#xff0c;十字K线配合单阳拉起&#xff0c;重新去摸高上轨。目前4小时图处于摸高当中。周线和日线留意多空转换&#xff0c;摸高之后是强势延续…

使用Scrapy选择器提取豆瓣电影信息,并用正则表达式从介绍详情中获取指定信息

本文同步更新于博主个人博客&#xff1a;blog.buzzchat.top 一、Scrapy框架 1. 介绍 在当今数字化的时代&#xff0c;数据是一种宝贵的资源&#xff0c;而网络爬虫&#xff08;Web Scraping&#xff09;则是获取网络数据的重要工具之一。而在 Python 生态系统中&#xff0c;S…

03-echarts如何画立体柱状图

echarts如何画立体柱状图 一、创建盒子1、创建盒子2、初始化盒子&#xff08;先绘制一个基本的二维柱状图的样式&#xff09;1、创建一个初始化图表的方法2、在mounted中调用这个方法3、在方法中写options和绘制图形 二、画图前知识1、坐标2、柱状图图解分析 三、构建方法1、创…

第七周学习笔记DAY.1-封装

学完本次课程后&#xff0c;你能够&#xff1a; 理解封装的作用 会使用封装 会使用Java中的包组织类 掌握访问修饰符&#xff0c;理解访问权限 没有封装的话属性访问随意&#xff0c;赋值也可能不合理&#xff0c;为了解决这些代码设计缺陷&#xff0c;可以使用封装。 面向…

RabbitMQ入门实战

文章目录 RabbitMQ入门实战基本概念安装快速入门单向发送多消费者 RabbitMQ入门实战 官方&#xff1a;https://www.rabbitmq.com 基本概念 AMQP协议&#xff1a;https://www.rabbitmq.com/tutorials/amqp-concepts.html 定义&#xff1a;高级信息队列协议&#xff08;Advanc…

LangChain入门:19.探索结构化工具对话

引言 在人工智能的浪潮中&#xff0c;对话代理技术正逐渐成为企业和开发者关注的焦点。LangChain&#xff0c;作为对话代理领域的一颗新星&#xff0c;自2021年9月诞生以来&#xff0c;以其强大的功能和灵活的应用场景迅速赢得了市场的认可。本文将带你深入了解LangChain中的S…

【从浅学到熟知Linux】进程控制上篇=>进程创建、进程终止与进程等待(含_exit与exit的区别、fork函数详解、wait与waitpid详解)

&#x1f3e0;关于专栏&#xff1a;Linux的浅学到熟知专栏用于记录Linux系统编程、网络编程等内容。 &#x1f3af;每天努力一点点&#xff0c;技术变化看得见 文章目录 进程创建fork函数写时拷贝 进程退出进程退出操作系统做了什么&#xff1f;进程退出场景进程退出的常见方法…

openstack修改实例名称但是gnocchi监控数据中实例名称没有变更的问题处理

文章目录 一、问题描述二、调研过程1、变更实例名称2、查看grafana中的监控数据3、libvirt服务中的xml文件4、现有的监控数据流转架构 总结 一、问题描述 openstack修改实例名称但是gnocchi监控数据中实例名称没有变更的问题处理。 通过修改实例名称的功能修改了实例名称&…

大模型赋能:爬虫技术的全新革命

大模型加持下的爬虫技术革新&#xff1a;从BS4到提示工程的飞跃 在爬虫技术的演进历程中&#xff0c;内容解析一直是一个核心环节。传统的爬虫技术&#xff0c;如使用BeautifulSoup&#xff08;BS4&#xff09;等工具&#xff0c;需要逐个解析网页内容&#xff0c;通过XPath或C…

XILINX 7系列时钟资源

文章目录 前言一、时钟概要1.1、CC1.2、BUFR、BUFIO、BUFMR1.3、CMT1.4、BUFH1.5、BUFG 二、时钟路由资源三、CMT 前言 本文主要参考xilinx手册ug472 一、时钟概要 7系列FPGA时钟资源主要有CC、BUFR、BUFIO、BUFMR、CMT、BUFG、BUFH和GTE_COMMON 1.1、CC “CC”&#xff0…

OpenHarmony开发案例:【自定义通知】

介绍 本示例使用[ohos.notificationManager] 等接口&#xff0c;展示了如何初始化不同类型通知的通知内容以及通知的发布、取消及桌面角标的设置&#xff0c;通知类型包括基本类型、长文本类型、多行文本类型、图片类型、带按钮的通知、点击可跳转到应用的通知。 效果预览&am…

TensorFlow实战Google深度学习框架 PDF书籍分享

今天又来给大家推荐一本TensorFlow方面的书籍<TensorFlow实战Google深度学习框架>。本书适用于想要使用深度学习或TensorFlow的数据科学家、工程师&#xff0c;希望了解大数据平台工程师&#xff0c;对人工智能、深度学习感兴趣的计算机相关从业人员及在校学生等。 下载当…

【数据结构与算法】用两个栈实现一个队列

题目 用两个栈&#xff0c;实现一个队列功能 add delete length 队列 用数组可以实现队列&#xff0c;数组和队列的区别是&#xff1a;队列是逻辑结构是一个抽象模型&#xff0c;简单地可以用数组、链表实现&#xff0c;所以数组和链表是一个物理结构&#xff0c;队列是一个逻…

Cannot access ‘androidx.activity.FullyDrawnReporterOwner‘

Android Studio新建项目就报错&#xff1a; Cannot access ‘androidx.activity.FullyDrawnReporterOwner’ which is a supertype of ‘cn.dazhou.osddemo.MainActivity’. Check your module classpath for missing or conflicting dependencies 整个类都报错了。本来原来一直…

阿里面试:DDD中的实体、值对象有什么区别?

在领域驱动设计&#xff08;DDD&#xff09;中&#xff0c;有两个基础概念&#xff1a;实体&#xff08;Entity&#xff09;和值对象&#xff08;Value Object&#xff09;。 使用这些概念&#xff0c;我们可以把复杂的业务需求映射成简单、明确的数据模型。正确使用实体和值对…

【环境】原则

系列文章目录 【引论一】项目管理的意义 【引论二】项目管理的逻辑 【环境】概述 【环境】原则 一、培养项目系统性思维 1.1 系统性思维 1.2 系统性思维的价值 1.3 建模和推演&数字孪生 二、项目的复杂性和如何驾驭复杂性 2.1 复杂性的三个维度 2.2 如何驾驭复杂性 三、…

【御控物联】Java JSON结构转换(4):对象To对象——规则属性重组

文章目录 一、JSON结构转换是什么&#xff1f;二、术语解释三、案例之《JSON对象 To JSON对象》四、代码实现五、在线转换工具六、技术资料 一、JSON结构转换是什么&#xff1f; JSON结构转换指的是将一个JSON对象或JSON数组按照一定规则进行重组、筛选、映射或转换&#xff0…

谷歌pixel6/7pro等手机WiFi不能上网,显示网络连接受限

近期在项目中遇到一个机型出现的问题,先对项目代码进行排查,发现别的设备都能正常运行,就开始来排查机型的问题,特意写出来方便后续查看,也方便其它开发者来自查。 设备机型:Pixel 6a 设备安卓版本:13 该方法无需root,只需要电脑设备安装adb(即Android Debug Bridge…