pytorch代码实现之空间通道重组卷积SCConv

news2025/3/1 21:48:09

空间通道重组卷积SCConv

空间通道重组卷积SCConv,全称Spatial and Channel Reconstruction Convolution,CPR2023年提出,可以即插即用,能够在减少参数的同时提升性能的模块。其核心思想是希望能够实现减少特征冗余从而提高算法的效率。一般压缩模型的方法分为三种,分别是network pruning, weight quantization, low-rank factorization以及knowledge distillation,虽然这些方法能够达到减少参数的效果,但是往往都会导致模型性能的衰减。另一种方法就是在构建模型时利用特殊的模块或操作减少模型参数,获得轻量级的网络模型,这种方法能够在保证性能的同时达到参数减少的效果。

原文地址:SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy

作者提出的SCConv包含两部分,分别是Spatial Reconstruction Unit (SRU)和Channel Reconstruction Unit (CRU),下面是SSConv的总体结构。
SCConv结构原理图
可以看出,SCConv模块设计,对于输入的特征图先利用1x1的卷积改变为适合的通道数,之后便分别是SRU和CRU两个模块对于特征图进行处理,最后在通过1x1的卷积将特征通道数恢复并进行残差操作。
SRU模块结构
CRU模块结构

代码实现如下:

import torch
import torch.nn.functional as F
import torch.nn as nn 


class GroupBatchnorm2d(nn.Module):
    def __init__(self, c_num:int, 
                 group_num:int = 16, 
                 eps:float = 1e-10
                 ):
        super(GroupBatchnorm2d,self).__init__()
        assert c_num    >= group_num
        self.group_num  = group_num
        self.gamma      = nn.Parameter( torch.randn(c_num, 1, 1)    )
        self.beta       = nn.Parameter( torch.zeros(c_num, 1, 1)    )
        self.eps        = eps

    def forward(self, x):
        N, C, H, W  = x.size()
        x           = x.view(   N, self.group_num, -1   )
        mean        = x.mean(   dim = 2, keepdim = True )
        std         = x.std (   dim = 2, keepdim = True )
        x           = (x - mean) / (std+self.eps)
        x           = x.view(N, C, H, W)
        return x * self.gamma + self.beta


class SRU(nn.Module):
    def __init__(self,
                 oup_channels:int, 
                 group_num:int = 16,
                 gate_treshold:float = 0.5 
                 ):
        super().__init__()
        
        self.gn             = GroupBatchnorm2d( oup_channels, group_num = group_num )
        self.gate_treshold  = gate_treshold
        self.sigomid        = nn.Sigmoid()

    def forward(self,x):
        gn_x        = self.gn(x)
        w_gamma     = self.gn.gamma/sum(self.gn.gamma)
        reweigts    = self.sigomid( gn_x * w_gamma )
        # Gate
        info_mask   = reweigts>=self.gate_treshold
        noninfo_mask= reweigts<self.gate_treshold
        x_1         = info_mask * x
        x_2         = noninfo_mask * x
        x           = self.reconstruct(x_1,x_2)
        return x
    
    def reconstruct(self,x_1,x_2):
        x_11,x_12 = torch.split(x_1, x_1.size(1)//2, dim=1)
        x_21,x_22 = torch.split(x_2, x_2.size(1)//2, dim=1)
        return torch.cat([ x_11+x_22, x_12+x_21 ],dim=1)


class CRU(nn.Module):
    '''
    alpha: 0<alpha<1
    '''
    def __init__(self, 
                 op_channel:int,
                 alpha:float = 1/2,
                 squeeze_radio:int = 2 ,
                 group_size:int = 2,
                 group_kernel_size:int = 3,
                 ):
        super().__init__()
        self.up_channel     = up_channel   =   int(alpha*op_channel)
        self.low_channel    = low_channel  =   op_channel-up_channel
        self.squeeze1       = nn.Conv2d(up_channel,up_channel//squeeze_radio,kernel_size=1,bias=False)
        self.squeeze2       = nn.Conv2d(low_channel,low_channel//squeeze_radio,kernel_size=1,bias=False)
        #up
        self.GWC            = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=group_kernel_size, stride=1,padding=group_kernel_size//2, groups = group_size)
        self.PWC1           = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=1, bias=False)
        #low
        self.PWC2           = nn.Conv2d(low_channel//squeeze_radio, op_channel-low_channel//squeeze_radio,kernel_size=1, bias=False)
        self.advavg         = nn.AdaptiveAvgPool2d(1)

    def forward(self,x):
        # Split
        up,low  = torch.split(x,[self.up_channel,self.low_channel],dim=1)
        up,low  = self.squeeze1(up),self.squeeze2(low)
        # Transform
        Y1      = self.GWC(up) + self.PWC1(up)
        Y2      = torch.cat( [self.PWC2(low), low], dim= 1 )
        # Fuse
        out     = torch.cat( [Y1,Y2], dim= 1 )
        out     = F.softmax( self.advavg(out), dim=1 ) * out
        out1,out2 = torch.split(out,out.size(1)//2,dim=1)
        return out1+out2


class ScConv(nn.Module):
    def __init__(self,
                op_channel:int,
                group_num:int = 16,
                gate_treshold:float = 0.5,
                alpha:float = 1/2,
                squeeze_radio:int = 2 ,
                group_size:int = 2,
                group_kernel_size:int = 3,
                 ):
        super().__init__()
        self.SRU = SRU( op_channel, 
                       group_num            = group_num,  
                       gate_treshold        = gate_treshold )
        self.CRU = CRU( op_channel, 
                       alpha                = alpha, 
                       squeeze_radio        = squeeze_radio ,
                       group_size           = group_size ,
                       group_kernel_size    = group_kernel_size )
    
    def forward(self,x):
        x = self.SRU(x)
        x = self.CRU(x)
        return x


if __name__ == '__main__':
    x       = torch.randn(1,32,16,16)
    model   = ScConv(32)
    print(model(x).shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/981175.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++信息学奥赛1190:上台阶

#include <iostream> using namespace std;long long arr[80]; // 用于存储斐波那契数列的数组int main() {int n;arr[1]1; // 初始化斐波那契数列的前三个元素arr[2]2;arr[3]4;for(int i4;i<71;i) { // 计算斐波那契数列的第4到第71个元素arr[i]arr[i-1]arr[i-2]…

为XDR扩展威胁检测响应提供响应解决方案

安全层面最本质的问题是检测与响应&#xff0c;而当前的检测与响应&#xff0c;还存在着一些痛点和难点亟需解决&#xff0c;响应运营层面仍存在着一些挑战。 各类安全防护设备每天会产生大量的安全告警&#xff0c;使得安全分析人员绝大部分时间和精力都“消耗”在告警信息中…

win11 使用 QEMU 配置龙芯 3A5000 虚拟环境

01 下载资源 本实验使用资源: 开源模拟器qemu 下载地址, qemu-w64-setup-20230822.exe loongarch 固件下载: QEMU_EFI_8.0.fd loongarch 基本镜像下载: archlinux-loong64.iso qemu安装在D:\install\qemu: D:\install\qemu>dir | findstr "qemu-system-loongarch"…

Qt 5.15集成Crypto++ 8.7.0(MSVC 2019)笔记

一、背景 笔者已介绍过在Qt 5.15.x中使用MinGW&#xff08;8.10版本&#xff09;编译并集成Crypto 8.7.0。 但是该编译出来的库&#xff08;.a和.dll&#xff09;不适用MSVC&#xff08;2019版本&#xff09;构建环境&#xff0c;需要重新编译&#xff08;.lib或和.dll&#xf…

LRTimelapse 6 for Mac(延时摄影视频制作软件)

LRTimelapse 是一款适用于macOS 系统的延时摄影视频制作软件&#xff0c;可以帮助用户创建高质量的延时摄影视频。该软件提供了直观的界面和丰富的功能&#xff0c;支持多种时间轴摄影工具和文件格式&#xff0c;并具有高度的可定制性和扩展性。 LRTimelapse 的主要特点如下&am…

Qt下SVG格式图片应用

SVG格式图片介绍 svg格式图片又称矢量图&#xff0c;该种格式的图片不同于png等格式的图片&#xff0c;采用的并不是位图的形式来组织图片&#xff0c;而是采用线条等组织图片&#xff0c;svg格式是图片的文件格式是xml&#xff0c;可以通过文件编译器打开查看svg格式内容。 …

halcon双目标定双相机标定

halcon双目标定 *取消更新 dev_update_off () *获取窗体句柄 dev_get_window (WindowHandle) *设置窗体字体样式 set_display_font (WindowHandle, 16, mono, true, false) *设置线条粗细 dev_set_line_width (3) *创建空对象 gen_empty_obj (ImageL) *读取指定文件内子集 li…

SpringMvc进阶

SpringMvc进阶 SpringMVC引言一、常用注解二、参数传递三、返回值 SpringMVC引言 在Web应用程序开发中&#xff0c;Spring MVC是一种常用的框架&#xff0c;它基于MVC&#xff08;Model-View-Controller&#xff09;模式&#xff0c;提供了一种结构化的方式来构建可维护和可扩…

python报错ModuleNotFoundError: No module named ‘XXX‘

记录一下改神经网络过程中遇到的小bug 在对网络结构进行更改时&#xff0c;不可避免要把别人的文件copy到自己的项目里。这时可能会遇到包导入的错误。正常情况下&#xff0c;导入的包应该大致包括三种方式&#xff1a; 1、导入外部包&#xff0c;如果这里错了就自己去pip ins…

Hadoop的第二个核心组件:MapReduce框架第四节

Hadoop的第二个核心组件&#xff1a;MapReduce框架 十、MapReduce的特殊应用场景1、使用MapReduce进行join操作2、使用MapReduce的计数器3、MapReduce做数据清洗 十一、MapReduce的工作流程&#xff1a;详细的工作流程第一步&#xff1a;提交MR作业资源第二步&#xff1a;运行M…

(二十三)大数据实战——Flume数据采集之采集数据聚合案例实战

前言 本节内容我们主要介绍一下Flume数据采集过程中&#xff0c;如何把多个数据采集点的数据聚合到一个地方供分析使用。我们使用hadoop101服务器采集nc数据&#xff0c;hadoop102采集文件数据&#xff0c;将hadoop101和hadoop102服务器采集的数据聚合到hadoop103服务器输出到…

【算法】选择排序

选择排序 选择排序代码实现代码优化 排序&#xff1a; 排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性&#xff1a; 假定在待排序的记录序列中&#xff0c;存在多个具有相同的关键字的记录&…

苍穹外卖 day12 Echats 营业台数据可视化整合

苍穹外卖-day12 课程内容 工作台Apache POI导出运营数据Excel报表 功能实现&#xff1a;工作台、数据导出 工作台效果图&#xff1a; 数据导出效果图&#xff1a; 在数据统计页面点击数据导出&#xff1a;生成Excel报表 1. 工作台 1.1 需求分析和设计 1.1.1 产品原型 工作台是系…

GuLi商城-前端基础Vue-整合ElementUI快速开发

npm安装 启动项目&#xff1a;npm run dev http://localhost:8082/#/hello

强大易用的开源 建站工具Halo

特点 可插拔架构 Halo 采用可插拔架构&#xff0c;功能模块之间耦合度低、灵活性提高。支持用户按需安装、卸载插件&#xff0c;操作便捷。同时提供插件开发接口以确保较高扩展性和可维护性。 ☑ 支持在运行时安装和卸载插件 ☑ 更加方便地集成三方平台 ☑ 统一的可配置设置表…

Python UI自动化 —— 关键字+excel表格数据驱动

步骤&#xff1a; 1. 对selenium进行二次封装&#xff0c;创建关键字的库 2. 准备一个表格文件来写入所有测试用例步骤 3. 对表格内容进行读取&#xff0c;使用映射关系来对用例进行调用执行 4. 执行用例 1. 对selenium进行二次封装&#xff0c;创建关键字的库 from time imp…

RabbitMQ原理和界面操作

参考 ## 原理 https://zhuanlan.zhihu.com/p/344298279### https://blog.csdn.net/qq_53263107/article/details/127844208 界面操作 界面术语 Channels 通道的属性&#xff1a; channel&#xff1a;名称。 Virtual host&#xff1a;所属的虚拟主机。 User name&#xff1a…

机器人任务挖掘与智能超级自动化技术解析

本文为上海财经大学教授、安徽财经大学学术副校长何贤杰出席“会计科技Acctech应对不确定性挑战”高峰论坛时的演讲内容整理。何贤杰详细介绍了机器人任务挖掘与智能超级自动化技术的发展背景、关键技术和应用场景。 从本质来说&#xff0c;会计是非常适合智能化、自动化的。会…

【算法】堆排序 详解

堆排序 详解 堆排序代码实现 排序&#xff1a; 排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性&#xff1a; 假定在待排序的记录序列中&#xff0c;存在多个具有相同的关键字的记录&#xff0c…

HarmonyOS实现静态与动态数据可视化图表

一. 样例介绍 本篇Codelab基于switch组件和chart组件&#xff0c;实现线形图、占比图、柱状图&#xff0c;并通过switch切换chart组件数据的动静态显示。要求实现以下功能&#xff1a; 实现静态数据可视化图表。打开开关&#xff0c;实现静态图切换为动态可视化图表 相关概念 s…