【CNN轻量化】ParameterNet: Parameters Are All You Need 参数就是你所需要的

news2024/11/24 8:31:00

论文链接:http://arxiv.org/abs/2306.14525
代码链接:https://github.com/huawei-noah/Efficient-AI-Backbones

一、摘要

  现有的低FLOPs模型(轻量化模型)无法从大规模预训练中受益。本文旨在增加大规模视觉预训练模型中的参数数量,同时最大限度地减少FLOPs的增加。利用动态卷积将额外的参数引入网络中,而仅仅带来了FLOPs的轻微增加。ParameterNet方法使得低FLOPs网络能够充分利用大规模视觉预训练。此外,将ParameterNet概念扩展到语言领域,以提高推理结果的同时保持推理速度。在大规模ImageNet-22K上的实验证明了我们ParameterNet方案的优越性。例如,ParameterNet-600M在ImageNet上的准确率比广泛使用的Swin Transformer更高(81.6% vs. 80.9%),而且FLOPs要低得多(0.6G vs. 4.5G)。在语言领域,经ParameterNet增强的LLaMA-1B比原始LLaMA准确率高出2%。

二、关键问题与创新点

1、关键问题

  如下图所示,随着FLOPs的逐渐增加,准确性增加,无论是在ImageNet-1K还是ImageNet-22K的预训练中。对于具有高FLOPs(>10G)的模型,在ImageNet-22K上的预训练优于在ImageNet-1K上的预训练。然而,对于具有较低FLOPs(<4G)的模型,更多数据的预训练并不会提高性能。如图中在ImageNet-22K上预训练的FLOPs小于2G的EfficientNetV2模型无法比在ImageNet-1K上预训练的模型表现更好。

  通过对Transformer和CNN网络的观察,得出了一个经验性结论:低FLOPs模型无法从大规模预训练中受益,这被称为低FLOPs陷阱。

fig2
fig3

2、主要贡献

  • 低FLOPs陷阱,即高FLOPs模型的性能随着更多训练数据而增加,但低FLOPs模型的性能不增加。

  • 提出参数比FLOPs对于大规模视觉预训练更为重要,并进一步引入ParameterNet方案,通过增加更多参数而保持低FLOPs。

  • 提出的ParameterNet方案可以克服低FLOPs陷阱,实验结果表明,在视觉和语言任务中,ParameterNet在大规模预训练中取得了显著更高的性能。

二、理论

  当FLOPs高于5G FLOPs时,基于Transformer的模型在计算成本相似的情况下始终优于CNN。至于更小的模型,特别是在600M FLOPs内的移动级模型,具有局部性和平移等变性的CNN仍然占主导地位。文中选择CNN作为基础模型GhostNet用于视觉任务的高效主干网络,设计原则是在保持低FLOPs的同时增加更多参数。

  引入了参数增强函数,旨在引入更多参数:
W ′ = f ( W ) W^{′} = f(W) W=f(W),
函数 f f f应满足两个基本规则:1)不需要太多的计算成本,2)大幅增加模型的容量或可训练参数。例如动态卷积和重新参数化卷积

1、动态卷积

  文中主要考虑高效的动态卷积(图4中的一种多专家MoE层),它在几乎不增加额外FLOPs的情况下多倍增加了参数数量。
  具有 M M M个动态专家的动态卷积可以写成:
Y = X ∗ W ′ , W ′ = ∑ i = 1 M = α i W i , Y = X * W^{′},W^{′} = \sum_{i=1}^{M} = \alpha_{i} W_{i}, Y=XWW=i=1M=αiWi,
其中 W i ∈ R C o u t × C i n × H × W W_{i} \in R^{Cout×Cin×H×W} WiRCout×Cin×H×W是第i个卷积权重张量, α i \alpha_{i} αi是相应的动态系数。系数 α i \alpha_{i} αi是根据不同的输入样本动态生成的,一种典型的方式是基于输入使用MLP模块生成:
α = s o f t m a x ( M L P ( P o o l ( X ) ) ) , \alpha = softmax(MLP(Pool(X))), α=softmax(MLP(Pool(X))),
其中 α ∈ R M \alpha \in R^{M} αRM。方程式4中的系数生成与原始卷积层相比只带来了可忽略的FLOPs。通过这种方式,使用动态卷积实现的ParameterNet可以大幅增加更多参数,同时最小化FLOPs的增加。
fig4

  对于标准卷积层,参数数量为 C o u t ⋅ C i n ⋅ K ⋅ K C_{out} \cdot C_{in} \cdot K \cdot K CoutCinKK,FLOPs数量为 H ′ ⋅ W ′ ⋅ C o u t ⋅ C i n ⋅ K ⋅ K H^{′} \cdot W^{′} \cdot C_{out} \cdot C_{in} \cdot K \cdot K HWCoutCinKK。动态卷积包括系数生成模块、动态权重融合和卷积过程。具有 C i n C_{in} Cin个隐藏维度的系数生成模块需要 C i n 2 + C i n M C^{2}_{in} + C_{in} M Cin2+CinM个参数和 C i n 2 + C i n M C^{2}_{in} + C_{in} M Cin2+CinM个FLOPs。动态权重融合是无参数的,具有 M ⋅ C o u t ⋅ C i n ⋅ K ⋅ K M \cdot C_{out} \cdot C_{in} \cdot K \cdot K MCoutCinKK个FLOPs。因此,动态卷积的总参数和FLOPs数量分别为
C i n 2 + C i n M + M ⋅ C o u t ⋅ C i n ⋅ K ⋅ K C^{2}_{in} + C_{in} M + M \cdot C_{out} \cdot C_{in} \cdot K \cdot K Cin2+CinM+MCoutCinKK C i n 2 + C i n M + M ⋅ C o u t ⋅ C i n ⋅ K ⋅ K + H ′ ⋅ W ′ ⋅ C o u t ⋅ C i n ⋅ K ⋅ K C^{2}_{in} + C_{in} M + M \cdot C_{out} \cdot C_{in} \cdot K \cdot K + H^{′} \cdot W^{′} \cdot C_{out} \cdot C_{in} \cdot K \cdot K Cin2+CinM+MCoutCinKK+HWCoutCinKK

  动态卷积相对于标准卷积的参数比率为:
R p a r a m = C i n 2 + C i n M + M C o u t C i n K 2 C o u t C i n K K = C i n C o u t K 2 + M C o u t K 2 + M ≈ 1 K 2 + M . ( M ≪ C o u t K 2 , C i n ≈ C o u t R_{param} = \frac{C_{in}^{2} + C_{in}M + M C_{out} C_{in} K^{2}}{C_{out} C_{in} K K} = \frac{C_{in}}{C_{out} K^{2}} + \frac{M}{C_{out} K ^{2}} + M \approx \frac{1}{K^{2}} + M.(M ≪ C_{out} K^{2},C_{in} \approx C_{out} Rparam=CoutCinKKCin2+CinM+MCoutCinK2=CoutK2Cin+CoutK2M+MK21+M.MCoutK2CinCout
  FLOPs比率为:
R f l o p s = C i n 2 + C i n M + M C o u t C i n K 2 + H ′ W ′ C o u t C i n K 2 H ′ W ′ C o u t C i n K 2 = C i n H ′ W ′ C o u t C i n K 2 + M H ′ W ′ C o u t C i n K 2 + M H ′ W ′ + 1 ≈ 1. ( 1 < M ≪ H ′ W ′ , C i n ≈ C o u t R_{flops} = \frac{C_{in}^{2} + C_{in} M + M C_{out} C_{in} K^{2} + H^{′} W^{′} C_{out} C_{in} K^{2}}{H^{′} W^{′} C_{out} C_{in} K^{2}} = \frac{C_{in}}{H^{′} W^{′} C_{out} C_{in} K^{2}} + \frac{M}{H^{′} W^{′} C_{out} C_{in} K^{2}} + \frac{M}{H^{′} W^{′}} + 1 \approx 1.(1 < M ≪ H^{′} W^{′},C_{in} \approx C_{out} Rflops=HWCoutCinK2Cin2+CinM+MCoutCinK2+HWCoutCinK2=HWCoutCinK2Cin+HWCoutCinK2M+HWM+11.1<MHWCinCout

  因此,与标准卷积相比,动态卷积具有大约M倍的参数,而额外的FLOPs可以忽略不计。

2、将ParameterNet扩展到语言领域

  稀疏激活的专家混合(MoE)模型最初在自然语言处理领域引入,允许在保持每个标记或样本的计算负载不变的情况下大幅增加参数数量。许多后续研究深入探讨了高效的路由机制,并展示了MoE在各种大型语言模型(LLM)中的有效性,如T5[38]、NLLB[26]、LLaMA[54]和Palm[8]。在这种情况下,重点主要是低FLOPs语言模型,以验证提出的假设,即合并更多参数可以增强大规模预训练对低FLOPs模型的好处,文中按比例减少并构建一个缩减版本LLaMA-1B

  与MoE类似,获取一个标记表示 x x x,然后将其路由到从一组N个确定的专家中确定的前k个专家。路由器模块生成表示为 h ( x ) = s o f t m a x ( r o u t e r ( x ) ) h(x) = softmax(router(x)) h(x)=softmax(router(x))的逻辑值logits,通过softmax函数在该特定层上的N个可用专家之间创建一个归一化分布。然后选择前k个专家(在实验中 k = 1 k = 1 k=1)来路由标记 x x x。专家容量上的训练损失(每个专家计算的标记数量)遵循Switch Transformer[12]中的设置。

三、实验

1、数据集和设置

  采用ImageNet-22K进行大规模预训练,并使用ImageNet-1K作为正常训练数据进行比较。

  • ImageNet-22K是一个包含14,197,122张图片,属于21841个类别的大规模图像数据集。ImageNet-1K是ImageNet-22K的一个子集,包含1000个对象类别。其中包含1,281,167张训练图片和50,000张验证图片。

  • 在ImageNet-1K上训练。按照常见设置,使用AdamW优化器对模型进行300个epoch的训练,其中包括20个warm up。使用批量大小为1024。基础学习率设置为0.001,并按照余弦调度进行衰减。

  • 数据增强策略包括RandAugment和随机擦除。采用权重衰减和标签平滑进行正则化。更多细节在表1。

  在ImageNet-22K上预训练的模型有ImageNet-22K 25.6M 12.0G 80.0 , EfficientNetV2-B0 ImageNet-22K 7.1M 0.72G 77.6 , EfficientNetV2-B1 ImageNet-22K 8.1M 1.2G 79.0 , Swin-T ImageNet-22K 28M 4.5G 80.9 . GhostNet-600M模型在ImageNet-22K上进行90个epoch的预训练,其中包括5个热身epoch。批量大小为4096,基础学习率设置为0.004。其他设置基本遵循ImageNet-1K上的设置如表1。

  在ImageNet-1K上微调:在ImageNet-1K上对预训练模型进行30个epoch的微调,不包括warm up。批量大小为512,基础学习率设置为0.0005。权重衰减设置为1e-8,并关闭随机擦除以更好地适应ImageNet-1K。其他设置基本遵循ImageNet-1K上的设置,如表1所示。

table1

table2

2、ParameterNet(约300MFLOPs和约600MFLOPs)

  通过调整宽度和深度构建了基准GhostNet,其具有不同的FLOPs(约300M和约600M)。ParameterNet是通过用动态卷积替换传统卷积层构建的。默认情况下,专家数量设置为4。网络架构的详细信息可在附录中找到。结果如表2所示。仅在ImageNet-1K上训练,ParameterNet的性能优于原始GhostNet约0.4-xx的准确率。对于GhostNet,仅在ImageNet-22K上进行预训练并不能提高性能。在ImageNet-22K上预训练的ParameterNet可以比ImageNet-1K获得超过2%的改进。这表明我们的ParameterNet具有更多参数但类似FLOPs的优势,可以从大规模视觉预训练中受益。

  与SOTA的比较。文中将ParameterNet与其他在ImageNet-22K或更大数据集(如JFT-300M 和IG-1B-Targeted)上预训练的代表模型进行比较。从表3的结果可以看出,ParameterNet在拥有更少FLOPs的情况下胜过其他在大规模数据集上预训练的模型。例如,ParameterNet-600M实现了81.6%的top-1准确率,其FLOPs约为ResNet50或Swin-T的1/7。

  推理速度:我们评估了ParameterNet和其他代表模型的推理速度以进行比较。我们在Intel Xeon上使用ONNX工具包运行模型,如图5所示。

table3
fig5

3、消融实验

(1) 动态专家的数量。

  动态卷积的动态专家数量是动态卷积的一个重要超参数,直接控制参数和FLOPs。如表4所示,更多的专家将大幅增加参数数量,稍微影响FLOPs。更多专家的性能优于较少专家。我们默认使用4个专家以进行效率权衡。

请添加图片描述

(2) 动态卷积与重新参数化卷积。

  正如之前讨论的,有各种方法来构建ParameterNet,例如动态卷积和重新参数化卷积。比较这两种方法,其中动态卷积有4个专家,重新参数化卷积在原始卷积基础上增加了3个并行分支。从表6的结果来看,尽管重新参数化卷积增加了训练参数,但其参数和FLOPs在推断时保持不变,即模型容量没有增加,ImageNet-22K预训练性能也没有提高。

table6

(3) 其他网络架构的ParameterNet

  除了CNN,将ParameterNet扩展到Transformer架构(即Swin Transformer)。为构建一个较小版本,将Swin-T的token维度设置为24,得到大约300M FLOPs的Swin-300M。从表5的结果来看,原始的Swin-300M在对ImageNet-22K进行预训练时有显著的准确率下降。我们的策略可以从ImageNet-22K预训练中获得+2.2%的性能提升。

table5

4、语言领域

  训练数据集是由几个来源混合而成,包括C4 [39]、维基百科 [54] 和 ArXiv [31]。这些数据都是公开可用的,直接混合它们而没有进行任何质量过滤。总体而言,训练数据集的网络架构。通过按比例减少原始 LLaMA [54] 的维度和层数来构建基准 LLaMA-1B,如表8所示。具体来说,隐藏大小、中间大小、头数和层数分别为2048、8191、16和12。分词器与 LLaMA 相同。结果和分析。按照之前的工作[2],在几个常识推理任务上呈现相应的训练损失和零样本结果,其中模型对提出的答案进行排名。FLOPs 是在输出响应长度设置为1的情况下计算的。路由器模块采用线性层实现,输入通道为隐藏大小,输出通道等于专家数量。如表7所示,我们观察到更多的专家为基线模型带来了额外的参数,从而显著提高了下游性能。例如,在上投影层上具有8个专家的 LLaMA-1B 平均获得了2.37% 的准确率提升。此外,增加的参数有助于减少训练损失,表明通过将 ParameterNet 引入语言模型,可以增强对输入数据的理解。此外,实验结果表明,LLaMA 的 FFN 中的三个线性投影具有类似的效果。

table7
table8

四、总结

  ParameterNet是一种通用方法,有各种实现方法,如动态卷积和重新参数化卷积。在实验中使用动态卷积来构建ParameterNet模型。ParameterNet能够克服低FLOPs的缺陷,并从大规模视觉预训练中获益良多。在ImageNet-22K大规模数据集上的实验证明了所提出的ParameterNet的有效性,文中还验证了我们的方法在语言领域的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1532775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UE4_官方动画内容示例1.3_ 运动混合空间(Locomotion BlendSpace)

如何使用运动&#xff08;Locomotion&#xff09;混合空间将Actor在不同方向上及不同速度的运动混合起来。&#xff08;例如&#xff0c;展示了一个混合了以不同速度向后、前、左和右走路/跑步动作的Actor&#xff09;。 一、相关知识点&#xff1a; 混合空间是允许根据多个输…

UniTask 异步任务

文章目录 前言一、UniTask是什么&#xff1f;二、使用步骤三、常用的UniTask API和示例1.编写异步方法2.处理异常3.延迟执行4.等待多个UniTask或者一个UniTas完成5.异步加载资源示例6.手动控制UniTask的完成状态7.UniTask.Lazy延迟任务的创建8.后台线程切换Unity主线程9.不要返…

java数据结构与算法刷题-----LeetCode406. 根据身高重建队列

java数据结构与算法刷题目录&#xff08;剑指Offer、LeetCode、ACM&#xff09;-----主目录-----持续更新(进不去说明我没写完)&#xff1a;https://blog.csdn.net/grd_java/article/details/123063846 文章目录 1. 从高到底排序 1. 从高到底排序 解题思路&#xff1a;时间复杂…

MCU技术的创新浪潮与产业变革

MCU技术的创新浪潮与产业变革 一、MCU技术的创新发展 MCU&#xff0c;即微控制器&#xff0c;作为现代电子设备的核心部件&#xff0c;一直在不断地创新与发展。随着科技的进步&#xff0c;MCU的性能得到了极大的提升&#xff0c;功能也越来越丰富。从8位到32位&#xff0c;再…

MYSQL数据库管理基本操作

一、数据库的基本操作 1、登录数据库 [rootmysql-server ~]#mysql -uroot -p123456 ###直接回车&#xff0c;则进入数据库[rootmysql-server ~]#mysql -u root -p ###直接回车 Enter password: ###输入密码 方法一&#xff1a…

OpenGL学习笔记【2】——开发环境配置(GLFW,VS,Cmake),创建第一个项目

学OpenGL的都会知道&#xff0c;OpenGL只提供了绘图功能&#xff0c;创建窗口是需要自己完成的。这就需要学习相应操作系统的创建窗口方法&#xff0c;为简化创建窗口的过程&#xff0c;可以使用专门的窗口库&#xff0c;例如GLFW。使用GLFW之前需要先进行配置&#xff0c;那怎…

SQLiteC/C++接口详细介绍sqlite3_stmt类(四)

返回&#xff1a;SQLite—系列文章目录 上一篇&#xff1a;SQLiteC/C接口详细介绍sqlite3_stmt类&#xff08;三&#xff09; 下一篇&#xff1a;SQLiteC/C接口详细介绍sqlite3_stmt类&#xff08;五&#xff09; 7. sqlite3_bind_parameter_count函数 sqlite3_bind_param…

章节10实验--Ubuntu18.04 Qt MySQL libqsqlmysql.so

前言: 内容参考《操作系统实践-基于Linux应用与内核编程》一书的示例代码和教材内容&#xff0c;所做的读书笔记。本文记录再这里按照书中示例做一遍代码编程实践加深对操作系统的理解。 引用: 《操作系统实践-基于Linux应用与内核编程》 作者&#xff1a;房胜、李旭健、黄…

软考高级:结构化需求分析概念和例题

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

通过jsDelivr实现Github的图床CDN加速

最近小伙伴们是否发现访问我的个人博客http://xiejava.ishareread.com/图片显示特别快了&#xff1f; 我的博客的图片是放在github上的&#xff0c;众所周知的原因&#xff0c;github访问不是很快&#xff0c;尤其是hexo博客用github做图床经常图片刷不出来。一直想换图床&…

构建一个前端智能停车可视化系统

引言 随着城市化进程的加速&#xff0c;停车难问题日益突出。智能停车可视化系统通过实时展示停车场的车位信息&#xff0c;帮助用户快速找到空闲车位&#xff0c;提高停车效率。 目录 引言 一、系统设计 二、代码实现 1. 环境准备 2. 安装依赖 3. 创建停车场组件 4. 集…

时序预测 | Matlab实现BiTCN-BiLSTM双向时间卷积神经网络结合双向长短期记忆神经网络时间序列预测

时序预测 | Matlab实现BiTCN-BiLSTM双向时间卷积神经网络结合双向长短期记忆神经网络时间序列预测 目录 时序预测 | Matlab实现BiTCN-BiLSTM双向时间卷积神经网络结合双向长短期记忆神经网络时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现BiTCN…

本地主机连接Linux虚拟机中的mongodb,并使用studio 3T连接,同时项目启动连接mongodb刷新数据库

本部分只做个人纪录 ** 1.安装mongodb ** 本部分为尚硅谷的电影推荐系统的文档&#xff0c;具体以实际存放位置为准 // 通过WGET下载Linux版本的MongoDB [bigdatalinux ~]$ wget https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-rhel62-3.4.3.tgz// 将压缩包解压到…

Python 深度学习第二版(GPT 重译)(二)

四、入门神经网络&#xff1a;分类和回归 本章涵盖 您的第一个真实世界机器学习工作流示例 处理矢量数据上的分类问题 处理矢量数据上的连续回归问题 本章旨在帮助您开始使用神经网络解决实际问题。您将巩固从第二章和第三章中获得的知识&#xff0c;并将所学应用于三个新…

Java newInstance方法学习

用newInstance与用new是有区别的&#xff0c;区别在于创建对象的方式不一样&#xff0c;前者是使用类加载机制&#xff1b; newInstance方法要求该 Class 对应类有无参构造方法&#xff1b; 执行 newInstance()方法实际上就是使用对应类的无参构造方法来创建该类的实例&#x…

Golang 异步(bsd/linux)io

Golang 异步(bsd/linux)io 在日常开发中&#xff0c;读写文件的底层调用函数是syscall.Read/Write。一切都是围绕这两个函数展开的&#xff0c;不过有时候需要或者就是单纯想异步执行。liburing是linux上一个很好的原生异步io库&#xff0c;这里需要适配bsd派系的系统&#xf…

Redis面试题以及答案

1. 什么是Redis&#xff1f;它主要用来什么的&#xff1f; Redis&#xff0c;英文全称是Remote Dictionary Server&#xff08;远程字典服务&#xff09;&#xff0c;是一个开源的使用ANSI C语言编写、支持网络、可基于内存亦可持久化的日志型、Key-Value数据库&#xff0c;并…

00. 认识 Java 语言与安装教程

认识 Java Java 在 20 多年发展过程中&#xff0c;与时俱进&#xff0c;为了适应时代的需要&#xff0c;经历过两次重大的版本升级&#xff0c;一个是 Java 5&#xff0c;它提供了泛型等重要的功能。另一个是提供了 Lambda 表达式等重要的功能的 Java 8。 一些重要的 Java 的…

GitHub配置SSH Key(详细版本)

GitHub配置SSH Key的目的是为了帮助我们在通过git提交代码是&#xff0c;不需要繁琐的验证过程&#xff0c;简化操作流程。比如新建的仓库可以下载, 但是提交需要账号密码。 步骤 一、设置git的user name和email 如果你是第一次使用&#xff0c;或者还没有配置过的话需要操作…

linux内核input子系统概述

目录 一、input子系统二、关键数据结构和api2.1 数据结构2.1.1 input_dev2.1.2 input_handler2.1.3 input_event2.1.4 input_handle 2.2 api接口2.2.1 input_device 相关接口input_device 注册流程事件上报 2.2.2 input handle 相关接口注册 handle指定 handle 2.2.3 input han…