pytorch代码实现之SAConv卷积

news2024/11/17 13:53:40

SAConv卷积

SAConv卷积模块是一种精度更高、速度更快的“即插即用”卷积,目前很多方法被提出用于降低模型冗余、加速模型推理速度,然而这些方法往往关注于消除不重要的滤波器或构建高效计算单元,反而忽略了特征内部的模式冗余。
原文地址:Split to Be Slim: An Overlooked Redundancy in Vanilla Convolution

由于同一层内的许多特征具有相似却不平等的表现模式。然而,这类具有相似模式的特征却难以判断是否存在冗余或包含重要的细节信息。因此,不同于直接移除不确定的冗余特征方案,提出了一种基于Split的卷积计算单元(称之为SPConv),它运训存在相似模型冗余且仅需非常少的计算量。

SPConv结构图

首先,将输入特征拆分为representative部分与uncertain部分;然后,对于representative部分特征采用相对多的计算复杂度操作提取重要信息,对于uncertain部分采用轻量型操作提取隐含信息;最后,为重新校准与融合两组特征,作者采用了无参特征融合模块。该文所提SPConv是一种“即插即用”型模块,可用于替换现有网络中的常规卷积。

​无需任何技巧,在GPU端的精度与推理速度方面,基于SPConv的网络均可取得SOTA性能。该文主要贡献包含下面几个方面:
(1)重新对常规卷积中的特征冗余问题进行了再思考,提出了将输入分成两部分:representative与uncertain,分别针对两部分进行不同的信息提取;
(2)设计了一种“即插即用”型SPConv模块,它可以无缝替换现有网络中的常规卷积,且在精度与GPU推理速度上均可能优于SOTA性能,同时具有更少的FLOPs和参数量。

代码实现

class ConvAWS2d(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = weight / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return super()._conv_forward(x, weight, None)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.weight_gamma.data.fill_(-1)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            return
        weight = self.weight.data
        weight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,
                                       keepdim=True).mean(dim=3, keepdim=True)
        self.weight_beta.data.copy_(weight_mean)
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_gamma.data.copy_(std)
    
class SAConv2d(ConvAWS2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 s=1,
                 p=None,
                 g=1,
                 d=1,
                 act=True,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=s,
            padding=autopad(kernel_size, p),
            dilation=d,
            groups=g,
            bias=bias)
        self.switch = torch.nn.Conv2d(
            self.in_channels,
            1,
            kernel_size=1,
            stride=s,
            bias=True)
        self.switch.weight.data.fill_(0)
        self.switch.bias.data.fill_(1)
        self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))
        self.weight_diff.data.zero_()
        self.pre_context = torch.nn.Conv2d(
            self.in_channels,
            self.in_channels,
            kernel_size=1,
            bias=True)
        self.pre_context.weight.data.fill_(0)
        self.pre_context.bias.data.fill_(0)
        self.post_context = torch.nn.Conv2d(
            self.out_channels,
            self.out_channels,
            kernel_size=1,
            bias=True)
        self.post_context.weight.data.fill_(0)
        self.post_context.bias.data.fill_(0)
        
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        # pre-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch
        avg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")
        avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        out_s = super()._conv_forward(x, weight, None)
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        weight = weight + self.weight_diff
        out_l = super()._conv_forward(x, weight, None)
        out = switch * out_s + (1 - switch) * out_l
        self.padding = ori_p
        self.dilation = ori_d
        # post-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return self.act(self.bn(out))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1001996.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BUUCTF Reverse/[羊城杯 2020]login(python程序)

查看信息,python文件 动调了一下,该程序创建了一个线程来读入数据,而这个线程的代码应该是放在内存中直接执行的,本地看不到代码,很蛋疼 查了下可以用PyInstaller Extractor工具来解包,可以参考这个Python解包及反编译…

华为云云服务器云耀L实例评测 | 在华为云耀L实例上搭建电商店铺管理系统:一次场景体验

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

sqli第一关

1.在下使用火狐访问sqlilabs靶场并使用burpsuite代理火狐。左为sqlilabs第一关,右为burpsuite。 2.输入?id1 and 11 与?id1 and 12试试 可以看出没有变化哈,明显我们输入的语句被过滤了。在?id1后面尝试各种字符,发现单引号 包…

Linux内核分析与应用4-内存管理

本系列是对 陈莉君 老师 Linux 内核分析与应用[1] 的学习与记录。讲的非常之好,推荐观看 留此记录,蜻蜓点水,可作抛砖引玉 4.1 Linux内存管理机制 lscpu[2] 命令, 类似是优化后的 cat /proc/cpuinfo 实现虚拟内存的几种机制: 当 程序一旦跑起来,那就变成…

IDEA在创建包时如何把包分开实现自动分层

IDEA在创建包时如何把包分开实现自动分层 文章目录 IDEA在创建包时如何把包分开实现自动分层一、为什么要把包分开二、建包时如何把包自动分开三、如何编写配置文件路径? 一、为什么要把包分开 一开始的时候,我也一直以为包连在一起和分开没什么区别&am…

二叉搜索树/二叉排序树/二叉查找树

文章目录 1.概念2.操作3.实现3.1框架3.2BSTree.h3.3test.cpp 1.概念 二叉搜索树又称二叉排序树,它或者是一棵空树,或者是具有以下性质的二叉树: 若它的左子树不为空,则左子树上所有节点的值都小于根节点的值 若它的右子树不为空,…

python 学习笔记(5)——SMTP 使用QQ邮箱发送邮件

目录 发送邮件 1、准备工作: 2、发送纯文本信息内容: 3、发送 HTML 格式的内容: 4、发送带附件的邮件: 5、群发(一个邮件,发给多个人): 发送邮件 以下都 以 QQ邮箱 为发送方举…

敏捷开发方法管理项目,适应变化,引领未来

​敏捷开发方法是一种灵活且高效的项目管理方法,旨在应对不断变化的需求和快速发展的项目环境。使用敏捷开发方法可以帮助团队更好地应对不确定性,提高项目的质量和效率。以下是使用敏捷开发方法管理项目的具体步骤: 明确项目目标和范围 在…

算法通过村第六关-树白银笔记|层次遍历

文章目录 前言1. 层次遍历介绍2. 基本的层次遍历与变换2.1 二叉树的层次遍历2.2 层次遍历-自底向上2.3 二叉树的锯齿形层次遍历2.4 N叉树的层次遍历 3. 几个处理每层元素的题目3.1 在每棵树行中找出最大值3.2 在每棵树行中找出平均值3.3 二叉树的右视图3.4 最底层最左边 总结 前…

C高级day4(shell脚本)

一、Xmind整理: 二、上课笔记整理: 1.创建一个文件,给组用户可读权限,所属用户可写权限,其他用户可执行权限,使用if判断文件有哪些权限 #!/bin/bash touch 1 chmod 241 1 if [ -r 1 ] thenecho "文件…

为 DevOps 战士准备的 Linux 命令

点击链接了解详情 这篇文章将帮助理解DevOps工程师所需的大部分重要且经常使用的Linux命令。 要执行这些命令,你可以使用任何Linux机器、虚拟机或在线Linux终端来迅速开始使用这些命令。 系统信息命令: hostname - 显示系统主机的名称。 hostid - 显示…

openGauss学习笔记-66 openGauss 数据库管理-创建和管理schema

文章目录 openGauss学习笔记-66 openGauss 数据库管理-创建和管理schema66.1 背景信息66.2 注意事项66.3 操作步骤66.3.1 创建管理用户及权限schema66.3.2 使用schema66.3.3 schema的搜索路径66.3.4 schema的权限控制66.3.5 删除schema openGauss学习笔记-66 openGauss 数据库管…

Codeforces Round 827 (Div. 4) D 1e5+双重for循环技巧

Codeforces Round 827 (Div. 4) D 做题链接:Codeforces Round 827 (Div. 4) 给定一个由 n个正整数 a1,a2,…,an(1≤ai≤1000)组成的数组。求ij的最大值,使得ai和aj共质,否则−1,如果不存在这样的i&#…

github 创建自己的分支 并下载代码

github创建自己的分支 并下载代码 目录概述需求: 设计思路实现思路分析1.进入到master分支,git checkout master;2.master-slave的个人远程仓库3.爬虫调度器4.建立本地分支与个人远程分支之间的联系5.master 拓展实现 参考资料和推荐阅读 Survive by day…

Python基于Flask的招聘信息爬取、招聘信息可视化系统

招聘信息可视化系统 一、介绍 此系统是一个实时分析招聘信息的系统,应用Python爬虫、Flask框架、Echarts、VUE等技术实现。 二、系统运行图 1、数据概览 将爬取到的数据进行展示,点击公司信息和职位信息可以跳转到相应的网址,支持多条件…

这一次,我顿悟了

大家好,我是苍何。昨晚和编程导航 星球嘉宾也是我的引路人闫(yn) 小林大佬,畅聊了 4 个 小时,至今内心还是久久不能平静。 小林和我一样是跨界转行,他是医学院毕业,大二开始自学编程&#xff0…

【分布式】分布式事务:2PC

分布式事务的问题可以分为两部分: 并发控制 concurrency control原子提交 atomic commit 分布式事务问题的产生场景:一份数据被分片存在多台服务器上,那么每次事务处理都涉及到了多台机器。 可序列化(并发控制)&…

软件设计师学习笔记10-死锁资源数计算+进程资源图+段页式存储

目录 1.死锁资源数计算 1.1死锁 1.2进程管理与死锁资源的计算 2.进程资源图 3.段页式存储 3.1页式存储 3.1.1页式存储组织 3.1.2完整页表及页面淘汰原则 3.1.3页面置换算法(了解一下) 3.2段式存储 1.死锁资源数计算 1.1死锁 (1)死锁的概念:所谓死锁&…

C++-day4

仿照string类&#xff0c;完成myString 类 #include <iostream> #include <cstring> using namespace std; class myString { private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度 public://无参构造myString():size(10…

mac 13.x 打开第三方应用,提示已损坏无法打开

前排提示&#xff0c;不一定有效 1、先在终端执行下面这个&#xff0c;因为要提权&#xff0c;输入自己的密码 sudo xattr -r -d com.apple.quarantine 具体应用 # 具体应用是一个路径&#xff0c;拖入 访达——应用程序——第三方应用 到终端就行 # sudo xattr -r -d com.app…