cond conv 代码-思想

news2024/12/23 22:48:42

参考博客:
1 解析图示最清楚
动态卷积之CondConv思想和代码实现_&永恒的星河&的博客-CSDN博客
2 知乎的解释,简洁明了
CondConv代码解析 - 知乎
知乎提供code:External-Attention-pytorch/CondConv.py at master · xmu-xiaoma666/External-Attention-pytorch · GitHub
2 中知乎的评论给出的更好的代码
CondConv-pytorch/condconv.py at master · nibuiro/CondConv-pytorch · GitHub

论文题目: CondConv: Conditionally Parameterized Convolutions for Efficient Inference

论文地址:  https://link.zhihu.com/?target=https%3A//arxiv.org/abs/1904.04971

代码地址:  https://link.zhihu.com/?target=https%3A//github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv

1 介绍

Cond conv: 是2019年发表在Google Brains上关于卷积CondConv的文章
即插即用的模块

常规卷积进行构建网络,有如下假设:所有的样本共享卷积网络中的卷积参数.
需求:
提升模型容量就需要增加网络的参数,深度,通道数,这将导致模型的计算量和参数量增加,模型部署难度大。若要模型的实时性高,这就需要模型拥有较低的参数量和计算量.
 

cond conv目的提升模型的容量,同时保持实时性。
cond conv核心思想:为了打破传统卷积的特性,CondConv将卷积核参数化为多个专家知识的线性组合
公式: (a1*W1+a2*W2+...+an*Wn)*x
a1,a2,a3,...an是通过梯度下降法学习的权重系数
x是输入样本.可以通过提升专家的数量来提升模型的容量,这比提升卷积核的尺寸更有效,同时专家知识只需要一次线性组合,就可以提升模型容量的同时保持高效的推理.

Mixture of Experts(MoE)公式:α1*(W1∗x)+. . .+αn(Wn∗x)
Mixture of Experts(MoE)结构: 采用更细粒度的集成方式,每一个卷积层都拥有多套权重,卷积层的输入分别经过不同的权重卷积之后组合输出,缺点是但这计算量依旧很大.


CondConv公式: (α1*W 1+ . . . + αn*Wn)∗x       =α1*(W1∗x)+. . .+αn(Wn∗x) 与MoE等同
Cond conv结构:可以解决MoE计算大问题,降低计算量。
既然输入相同,卷积是一种线性计算,COMBINE也是一个线性计算(比如加权求和),作者将多套权重加权组合之后,只做一次卷积就能完成相当的效果!
2者区别: MoE是每个卷积核分别与x计算再组合,cond conv是先组合卷积核,在与x计算。

细致实现流程图(感谢这位博主的绘制,参考博客1)(tensorflow version)

输入:X(N,H,W,C)
N:数据Batch的大小
H和W:输入图片的高和宽
C:输入图片的通道数
两条输出:右边输出,左边输出,最后对各自的输出进行整合.
(h,w,cin,cout)表示卷积核大小
h和w分别表示卷积核的高和宽
cin,cout分别表示卷积核的输入和输出通道数.

右边线路:由计算样本生成多个卷积核的各自权重
step1: 对输入X,进行GAP操作(GlobalAveragePooling2D)操作,具体在维度(H,W),
输出大小为(N,C)
step2: 之后经过FC层,学习不同输入样本对用num_experts个卷积的各自的权重系数,输出为(N,num_experts)
step3: 采用Sigmoid归一化到(0,1)之间,输出为(N,num_experts)
step4: step3输出权重系数和num_experts个卷积核权重通过矩阵的相乘,赋予到相应的卷积上,输出各个样本对应加权后的卷积核权重,输出大小为:(N,h*w*cin*cout)
step5: step4中的输出在N维度进行Split操作,得到各个样本对应加权后卷积核

左边线路:对输入X依次通过对应加权输出的卷积核权重,完成CondConv。
step1:将X在N维度进行split操作
step2:step1中输出结果和右边线输出对应卷积权重进行卷积操作,之后进行Concat。

上面已经说的很细致了,下面介绍实验效果


总结:
CondConv打破了静态卷积的假设:卷积核对所有输入“一视同仁”。
提升模型容量保持高效推理:提升卷积核生成函数的尺寸与复杂度。
由于卷积核参数仅需计算一次,相比卷积计算,这些额外的计算量可以忽略。
即:提升卷积核生成的计算量>(优于)添加更多卷积更多通道数

2 代码详解torch(来自知乎条件参数化卷积(CondConv))

为了方便理解 ,下面把原代码中的 initial_weights 和 bias 相关的部分删掉了。
rount_fn部分:是 attention 函数,输入为 [N, C, H, W]
需要两个参数,in_planes为输特征通道数,K为专家个数。
输出shape为[N, K],即这里是针对N个批次的K个卷积核的权重。

# 输入为 [N, C, H, W],需要两个参数,Cin为输入特征通道数,K 为专家个数
class Attention(nn.Module):
    def __init__(self,Cin,K):
        super().__init__()
        self.avgpool=nn.AdaptiveAvgPool2d(1)  # 池化操作
        self.net=nn.Conv2d(Cin, K, kernel_size=1)  # Cin通道变为K
        self.sigmoid=nn.Sigmoid()  # 归一化

    def forward(self,x):
        # 将输入特征全局池化为 [N,Cin,H,W]->[N, Cin, 1, 1]
        att=self.avgpool(x)
        # 使用1*1卷积,转化为 [N, Cin, 1, 1]->[N, K, 1, 1]
        att=self.net(att)
        # 将特征转化为二维 [N, K, 1, 1]->[N, K]
        att=att.view(x.shape[0],-1) 
        # 使用 sigmoid 函数输出归一化到 [0,1] 区间
        return self.sigmoid(att)

CondConv 是一种特殊的动态卷积,增加卷积核生成函数的大小和复杂性(增加capacity容量)
CondConv还利用样本的特点来提高模型性能(权重是基于样本生成的类似SE block)
 

class CondConv(nn.Module):
    def __init__(self,Cin,Cout,kernel_size,stride,padding=0,
                 groups=1,K=4):
        super().__init__()
        self.Cin = Cin  # 输入通道
        self.Cout = Cout  # 输出通道
        self.K = K  # K个权重
        self.groups = groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.attention = Attention(Cin=Cin,K=K)
        # weight [K, Cout, Cin, kernelz_size, kernel_size]
        self.weight = nn.Parameter(torch.randn(K,Cout,Cin//groups,
                                             kernel_size,kernel_size),requires_grad=True)

    def forward(self,x):
        ### part1 weight
        # 调用 attention 函数得到归一化的权重 [N,Cin,H,W]->[N, K]
        N,Cins, H, W = x.shape
        softmax_att=self.attention(x)
        
        ### part2 x
        # [N, Cin, H, W]->[1, N*Cin, H, W]
        x=x.view(1, -1, H, W)
        
        ### part3 conv
        # 生成随机weight[K, Cout, C_in/groups, 3, 3] (卷积核一般为3*3)
        # 注意添加了 requires_grad=True,这样里面的参数是可以优化的
        weight = self.weight
        # 改变 weight 形状为 [K,Cout,Cin,3,3]->[K, C_out*(C_in/groups)*3*3]
        weight = weight.view(self.K, -1) 

        # part4: 新的wconv = weight*conv
        # 矩阵相乘:[N, K]*[K, Cout*(Cin/groups)*3*3] = [N, Cout*(Cin/groups)*3*3]
        aggregate_weight = torch.mm(softmax_att,weight)
        # 改变形状为:[N, Cout*Cin/groups*3*3]->[N*Cout, Cin/groups, 3, 3],即新的卷积核权重
        aggregate_weight = aggregate_weight.view(
            N*self.Cout, self.Cin//self.groups,
            self.kernel_size, self.kernel_size)
        # 用新生成的卷积核进行卷积,[1, N*Cin, H, W] conv [N*Cout, Cin/groups, 3, 3]
        # 输出为 [1, N*Cout, H, W]
        output=F.conv2d(x,weight=aggregate_weight,
                        stride=self.stride, padding=self.padding,
                        groups=self.groups*N)
        # 形状恢复为 [N, C_out, H, W]        
        output=output.view(N, self.out_planes, H, W)
        return output

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/509087.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

详解MySQL索引失效

目录 B树结构 测试数据 索引失效的情况 没有用到索引 违反左前缀原则 范围查询断索引 like需要分情况 结果数据超过半数 B树结构 索引失效的根本原因其实就是违反了B树的结构特性,查找的时候没办法在B树上继续走下去,所以首先我们来回顾一下B树…

进程控制(中)

目录: 1.status获取子进程退出的退出码和信号 2.不进行位操作方式获取子进程的退出码和信号 3.waitpid 第三个参数options ------------------------------------------------------------------------------------------------------------------------------- 1.…

工业4.0,为什么数字化转型这么难,上了ERP还要上MES

工业4.0时代,中国制造企业已经面临着与国际先进水平的差距,更多的企业在寻找新的发展道路,数字化转型是制造业企业转型升级的必由之路。但是,许多制造型企业由于在传统生产过程中,业务数据不能得到有效监控、生产过程数…

人脸修复增强调研

Real-ESRGAN 工程地址:https://github.com/xinntao/Real-ESRGAN 效果: 人脸增强部分,调用的GFPGAN. GFPGAN 工程地址:https://github.com/TencentARC/GFPGAN 论文效果: BasicSR-ESRGAN: 项目地址&a…

巨杉数据库荣获新睿之星,赋能大湾区技术与产业升级

巨杉数据库凭借多年深耕分布式数据库的技术积累和创新能力,于广州投资年会上荣获新睿之星奖项,该奖项不仅是对巨杉数据库的肯定,也充分肯定广州培育本土高新企业的发展成果。 4月18日,2023年第九届广州国际投资年会在广州白云国际…

2023年10大最佳「内容日历」软件工具

随随便便运行一个社交媒体策略就能成功,这几乎是不可能。你需要提前规划排期,收集资源并与他人合作,来创造出能吸引受众注意力的内容。 所有这些规划、研究和创意都需要一个地方汇总聚合,这就是内容日历软件的用武之地。 有了合…

C++ 多线程:实现一个功能完整的线程池

C 多线程(四):实现一个功能完整的线程池 今天我们来聊一聊异步编程的知识。在分布式系统中,一个功能完整的线程池类是一切代码的前提。 一个『合格』的线程池该具备哪些功能? 首先,很自然地想到『线程池类…

被嫌弃可视化太丑?这种可视化大屏搭建方法,分分钟让老板满意

在数据可视化中,使用频率最高的展览方式一定是地图可视化。基本上现有的大屏都是以地图作为主视图来呈现的,没有一幅地图放到大屏中央,已经不好意思给同行说明自己企业数据分析有多牛了。在地图可视化中,最炫酷的一定是3D可视化大…

家用洗地机有什么推荐的吗?家用洗地机哪款好

洗地机是创新、高效的清洁工具,其具有高性能的清洁能力和卓越的操作体验。与传统的清洁工具相比,洗地机可以迅速而彻底地打扫地面,降低清洁时间和人力成本,让我们在工作之余不用花费大量的时间和精力去打扫卫生,下面就…

TCP 协议和 UDP 协议 的优势和劣势

TCP 最核心的价值是提供了可靠性,而 UDP 最核心的价值是灵活,你几乎可以用它来做任何事情。例如:HTTP 协议 1.1 和 2.0 都基于 TCP,而到了 HTTP 3.0 就开始用 UDP 了。UDP 在数据传输、网络控制、音视频、Web 技术中,都…

Chroma向量数据库

嵌入向量(vector embedding)是表示任何类型数据的 A.I 原生方式,使它们非常适合与各种 A.I 驱动的工具和算法一起使用。 它们可以表示文本、图像,很快还可以表示音频和视频。 有许多创建嵌入的选项,无论是在本地使用已…

Fiddler 微信小程序抓图教程(傻瓜式|汉化版|狗看了直呼内行)

前言 本篇文章主要给大家详细讲解如何用Fiddler爬取微信小程序的图片,内容图文并茂,流程非常简单,我们开始吧。 目录 获取软件并打开点击工具设置相关代理如何抓图答疑总结 一、获取软件并打开 1、通过百度网盘下载获取安装包(链接是永久的…

深度学习学习路线:从入门到精通

深度学习是机器学习的一个分支,已经成为人工智能领域最热门的技术之一。想要深入学习深度学习,需要具备一定的数学和编程基础。以下是一个深度学习学习路线的建议,可以帮助你快速入门深度学习: 数学基础: 线性代数&am…

从模式识别到图像文档分析——浅析场景文本识别研究

目录 一、场景文本识别工作简述二、基于视觉关系预测复杂场景文本识别2.1、FPN骨干网络2.2、文本片段检测模块2.3、候选关系对构建模块2.4、基于关系网络的连接关系预测2.5、损失函数 三、文档图像智能分析与处理前沿研究 文字作为人类语言的书面形式,是文本图像中最…

APP外包项目的线上维护方案

APP的使用已经非常普及,不论是2C还是2B的APP都已经渗透到了我们生活的方方面面,对于APP的开发公司来说APP项目的线上维护是一个非常重要的问题。如果APP项目比较重要而且用户规模比较大,那更需要专业的技术团队来维护。今天和大家分享这方面的…

神经注释精细化:用于肾上腺分析的新3D数据集的开发

文章目录 Neural Annotation Refinement:Development of a New 3D Dataset for Adrenal Gland Analysis摘要本文方法Neural Annotation Refinement 实验结果 Neural Annotation Refinement:Development of a New 3D Dataset for Adrenal Gland Analysis 摘要 人工注释是不完美…

如何将两张图片合成一张,多方式提高10倍效率

如何将两张图片合成一张?在日常工作和生活过程中,图片无处不在。图片能够及时记录当下生活,并且容易保存和传播。随着技术手段的创新和发展,图片质量也在不断提高。如何将多张图片进行有效合并,从而便于保存呢&#xf…

庄懂的TA笔记(十四)<特效:流动 + 扰动>

庄懂的TA笔记(十四)<特效:流动 扰动> 效果展示: 正文: 大纲: 一、增广: 1、排序问题: 造成这个问题的原因是,他在取背景前,小人的胳膊不算是透…

常用的排序算法--JavaScript

1.冒泡排序 这个函数使用了双重循环,第一个循环用于遍历数组中的每个元素,第二个循环用于比较相邻的元素,如果它们的顺序不正确,则交换它们的位置。在每次内部循环之后,最大的元素都会被移到数组的末尾,因此…

MSR015/MSR025低温漂、低功耗电压基准可pin对pin兼容REF015/REF025

MSR015/MSR025 是低温漂、低功耗、高精度 CMOS 电压基准, 具有0.05% 初始精度、低功耗特点。可pin对pin兼容REF015/REF025。该器件的低输出电压迟滞和低长期输出电压漂移特性,进一步提高稳定性和系统可靠性。 此外,器件的小尺寸和低运行电流特…