DHVT:在小数据集上降低VIT与卷积神经网络之间差距,解决从零开始训练的问题

news2025/1/13 3:14:49

VIT在归纳偏置方面存在空间相关性和信道表示的多样性两大缺陷。所以论文提出了动态混合视觉变压器(DHVT)来增强这两种感应偏差。

在空间方面,采用混合结构,将卷积集成到补丁嵌入和多层感知器模块中,迫使模型捕获令牌特征及其相邻特征。

在信道方面,引入了MLP中的动态特征聚合模块和多头注意力模块中全新的“head token”设计,帮助重新校准信道表示,并使不同的信道组表示相互交互。

Dynamic Hybrid Vision Transformer (DHVT)

1、顺序重叠补丁嵌入 (Sequential Overlapping Patch Embedding )

改进后的补丁嵌入称为Sequential overlap patch embedding(SOPE),它包含了3×3步长s=2的卷积、BN和GELU激活的几个连续卷积层。卷积层数与patch大小的关系为P=2^k。SOPE能够消除以前嵌入模块带来的不连续性,保留重要的底层特征。它能在一定程度上提供位置信息。

在一系列卷积层前后分别采用两次仿射变换。该操作对输入特征进行了缩放和移位,其作用类似于归一化,使训练性能在小数据集上更加稳定。

SOPE的整个流程可以表述如下。

这里的α和β为可学习参数,分别初始化为1和0。

2、编码器整体架构

然后将特征映射重塑为补丁并与cls令牌连接,并发到编码器层。每个编码器包含层归一化、多头自注意力和前馈网络。将MHSA网络改进为头部交互多头自注意网络(HI-MHSA),将前馈网络改进为动态聚合前馈网络(DAFF)。在最后的编码器层之后,输出类标记将被馈送到线性头部进行最终预测。

3、动态聚合前馈 (Dynamic Aggregation Feed Forward )

ViT 中的普通前馈网络 (FFN) 由两个全连接层和 GELU 组成。DAFF 在 FFN 中集成了来自 MobileNetV1 的深度卷积 (DWCONV)。由于深度卷积带来的归纳偏差,模型被迫捕获相邻特征,解决了空间视图上的问题。它极大地减少了在小型数据集上从头开始训练时的性能差距,并且比标准 CNN 收敛得更快。还使用了与来自 SENet 的 SE 模块类似的机制。

Xc、Xp 分别表示类标记和补丁标记。类标记在投影层之前从序列中分离为 Xc。剩余的令牌 Xp 则通过一个内部有残差连接的深度集成多层感知器。

然后将输出的补丁标记平均为权重向量 W。在squeeze-excitation操作之后,输出权重向量将与类标记通道相乘。然后重新校准的类令牌将与输出补丁令牌以恢复令牌序列。

4、相互作用多头自注意(HI-MHSA)

在最初的MHSA模块中,每个注意头都没有与其他头交互。在缺乏训练数据的情况下,每个通道组的表征都太弱而无法识别。

在HI-MHSA中,每个d维令牌,包括类令牌,将被重塑为h部分。每个部分包含d个通道,其中d =d×h。所有分离的标记在它们各自的部分中取平均值。因此总共得到h个令牌,每个令牌都是d维的。所有这样的中间令牌将再次投影到d维,总共产生h个头部令牌。最后,将它们与补丁令牌和类令牌连接起来。

5、模型变体

DHVT-T: 12层编码器,嵌入维度为192,MLP比为4,CIFAR-100和DomainNet上的注意头为4,ImageNet-1K上的注意头为3。DHVT-S: 12层编码器,嵌入维度为384,MLP比4,CIFAR-100上注意头为8,DomainNet和ImageNet-1K上注意头为6。

需要说明的是:论文和模型的重点是在小数据集上从零开始训练。

结果展示

1、DomainNet & ImageNet-1K

在DomainNet上,DHVT表现出比标准ResNet-50更好的结果。在ImageNet-1K上,DHVT-T的准确率达到76.47,DHVT-S的准确率达到82.3。论文说这是在VIT的最佳性能。

2、CIFAR-100

DHVT-T在5.8M参数下达到83.54。DHVT-S仅用2280万个参数即可达到85.68。与其他基于vit的模型和CNN(ResNeXt, SENet, SKNet, DenseNet和Res2Net)相比,所提出的模型参数更少,性能更高。

3、消融研究

DeiT-T 4头,从头开始训练300次,基线成绩67.59。当移除绝对位置嵌入时,性能急剧下降至58.72。当采用SOPE并取消绝对位置嵌入时,性能下降幅度并不大。

同时采用SOPE和DAFF时,可以对位置信息进行全面编码,SOPE也有助于解决这里的不重叠问题,在早期保留了细粒度的底层特征。

table6发现了跨不同模型结构的head令牌带来的稳定性能增益。

当采用这三种修改时,获得了+13.26的精度增益,成功地弥合了与CNN的性能差距。

4、可视化

不同的head令牌在不同的补丁上激活

https://avoid.overfit.cn/post/806ce15b180440d988de5f76e22a2aaf

作者:Sik-Ho Tsang

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/604722.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

tcp shrinking window 之进退

一个有趣的问题:Unbounded memory usage by TCP for receive buffers, and how we fixed it 引出一个 kernel patch:[PATCH] Add a sysctl to allow TCP window shrinking in order to honor memory limits 但这 patch 把一个问题变成了两个问题&#…

apple pencil一代的平替有哪些品牌?平价电容笔推荐

要知道,真正的苹果原装电容笔,价格可不低,仅仅一支就是近千块。实际上,平替电容笔对没有太多预算的用户是个不错的选择。一款苹果的电容笔的售价,相当于平替电容笔的四倍,不过平替电容笔的书写体验&#xf…

pnpm对npm及yarn降维打击详解

目录 正文npm2yarnpnpm总结 正文 大家最近是不是经常听到 pnpm,我也一样。今天研究了一下它的机制,确实厉害,对 yarn 和 npm 可以说是降维打击。 那具体好在哪里呢? 我们一起来看一下。 我们按照包管理工具的发展历史&#xf…

3.5 凸多边形最优三角部分

博主简介:一个爱打游戏的计算机专业学生博主主页: 夏驰和徐策所属专栏:算法设计与分析 1.什么是多边形的三角剖分? 多边形三角剖分是指将多边形分割成互不相交的三角形的弦的集合T。 我的理解: 多边形三角剖分是将给…

uniapp本地存储详解

uniapp本地存储详解 前言 在开发uniapp应用时,我们常常需要使用本地存储来保存一些数据,比如用户登录信息、设置项等,使得应用能够在设备上保存和读取数据,以便提供更好的用户体验和离线功能支持,本文将简单介绍unia…

python编程——编译器与解释器

作者:Insist-- 个人主页:insist--个人主页 本文专栏:python专栏 专栏介绍:本专栏为免费专栏,并且会持续更新python基础知识,欢迎各位订阅关注。 目录 前言 一、编译器与解释器的介绍 二、编译器与解释器…

JDK1.8下载安装(优雅)

bug虐我千百遍,我待bug如初恋。 这里使用的环境是win11 64位系统,应该适配win8-win11 一、下载 这里提供两种下载方式,官网下载和第三方下载,区别就是下载速度不同 1. 官网下载 (1)官网下载:…

智慧物流货运系统源码 货运平台的功能介绍

网络货运平台源码 网络货运平台的功能 网络货运是指利用互联网平台,通过物流配送的方式进行商品销售和物流运输的一种新型商业模式。这种模式将传统的货运模式与互联网技术相结合,通过网络平台进行交易、物流配送和结算等一系列商业流程,从而…

用户画像如何创新破局数据驱动增长 | 数据增长

用户画像即用户信息标签化,就是企业通过收集与分析消费者社会属性、生活习惯、消费行为等主要信息的数据之后,完美地抽象出一个用户的商业全貌,是企业应用大数据技术的基本方式。例如:通过收集用户的人口属性、行为属性、消费习惯…

面向初学者的数据科学|要学习的内容概述

面向初学者的数据科学|要学习的内容概述 数据科学家是21世纪最性感的工作。每个人都想变得性感。该领域开始变得竞争激烈,提高了就业标准。 因此,仅仅知道如何使用不同的工具是不够的,求职者需要能够抓住基本的概念和技术,然后应用…

VMware Cloud Foundation 5.0 发布 - 领先的多云平台

VMware Cloud Foundation 5.0 发布 - 领先的多云平台 高效管理虚拟机 (VM) 和容器工作负载。为本地部署的全栈超融合基础架构 (HCI) 提供云的优势。 请访问原文链接:https://sysin.org/blog/vmware-cloud-foundation-5/,查看最新版。原创作品&#xff…

DEMO:F4帮助 收藏夹功能

货铺QQ群号:834508274微信群不能扫码进了,可以加我微信SAPliumeng拉进群,申请时请提供您哪个模块顾问,否则是一律不通过的。进群统一修改群名片,例如BJ_ABAP_森林木。群内禁止发广告及其他一切无关链接,小程…

没有硬件资源?免费使用Colab搭建你自己的Stable Diffiusion在线模型!保姆级教程...

部署 Stable Diffusion 需要一定的硬件资源,具体取决于要处理的图像大小和处理速度等因素。一般来说,至少需要一台具有较高计算能力的服务器,而对 GPU 的高要求就限制了我们学习和使用SD来生成我们想要的图像。 GPU是深度学习开发的重要硬件条…

C++ 学习 ::【基础篇:16】:C++ 类的基本成员函数:拷贝构造函数(认识、特征、注意点及典型使用场景)及其基本写法与调用

本系列 C 相关文章 仅为笔者学习笔记记录,用自己的理解记录学习!C 学习系列将分为三个阶段:基础篇、STL 篇、高阶数据结构与算法篇,相关重点内容如下: 基础篇:类与对象(涉及C的三大特性等&#…

FastJSON autoType is not support问题解决

概述 产品在使用内部的后台管理系统时反馈的问题。 于是登录平台,发现如下报错详情: 排查 经过分析,不难得知,请求是从gateway网关转发到对应的统计服务 statistics,此服务有个接口/api/statistics/data/overview…

华为OD机试真题 Java 实现【支持优先级的队列】【2023 B卷 100分】

一、题目描述 实现一个支持优先级的队列,高优先级先出队列,同优先级时先进先出。 如果两个输入数据和优先级都相同,则后一个数据不入队列被丢弃。 队列存储的数据内容是一个整数。 二、输入描述 一组待存入队列的数据(包含内…

Java官方笔记4类和对象

创建类 定义类Bicycle: public class Bicycle {// the Bicycle class has// three fieldspublic int cadence;public int gear;public int speed;// the Bicycle class has// one constructorpublic Bicycle(int startCadence, int startSpeed, int startGear) {gea…

李沐动手学习深度学习 2023年Win10 下安装 CUDA 和 Pytorch 跑深度学习(最新)

目录 一、安装Anaconda 1.下载Anaconda 测试是否安装成功 二、安装pytorch 验证pytorch是否安装成功 4.测试 3.配置pycharm 一、安装Anaconda 1.下载Anaconda 可以在官网下载,但是速度较慢,这里我选择了清华镜像源的下载 https://mirrors.tuna.t…

Gradio的web界面演示与交互机器学习模型,Blocks的事件侦听《7》

在第一篇文章我们就熟悉了Blocks的用法,使用Blocks比Interface更加灵活,这节重点关注Blocks里面的相关操作。 1、Blocks标准例子 import gradio as grdef greet(name):return "你好 " name "!"with gr.Blocks() as demo:name g…

简单的汉诺塔,神奇的预言,竟然需要5849亿年???(52)

小朋友们好,大朋友们好! 我是猫妹,一名爱上Python编程的小学生。 和猫妹学Python,一起趣味学编程。 今日主题 汉诺塔 古印度大梵天传说 Python玩转汉诺塔 递归 汉诺塔 汉诺塔(Hanoi)是一个著名的益智游戏,也称…