【MLP-Mixer】核心方法解读

news2025/1/23 0:59:37

 abstract:

我们提出MLP-Mixer架构(或简称“Mixer”),这是一个具有竞争力但在概念和技术上都很简单的替代方案,它不使用卷积或自关注。相反,Mixer的架构完全基于多层感知器(mlp),这些感知器可以在空间位置或特征通道上重复应用。Mixer仅依赖于基本的矩阵乘法例程、对数据布局的更改(重塑和换位)以及标量非线性。

intro:

图1描述了Mixer的宏观结构。它接受一系列线性投影图像补丁(也称为令牌),形状为“patches x channels”表,作为输入,并保持该维度。Mixer使用两种类型的MLP层:通道混合MLP和令牌混合MLP。信道混合mlp允许在不同信道之间进行通信;它们独立地操作每个令牌,并将表中的各个行作为输入。令牌混合mlp允许在不同的空间位置(令牌)之间进行通信;它们独立地在每个通道上操作,并将表中的各个列作为输入。这两种类型的层相互交织,以支持两个输入维度的交互。

在极端情况下,我们的架构可以看作是一个非常特殊的CNN,它使用1×1卷积进行通道混合,使用单通道深度卷积的完整接受场和参数共享进行令牌混合。然而,相反的情况并不成立,因为典型的cnn并不是Mixer的特例。此外,卷积比mlp中的普通矩阵乘法更复杂,因为它需要对矩阵乘法进行额外的昂贵简化和/或专门的实现。

mixer architecture:

现代深度视觉架构由混合特征的层组成(i)在给定的空间位置,(ii)在不同的空间位置之间,或同时混合特征。

在cnn中,(ii)是用N × N个卷积(对于N > 1)和池化实现的。更深层的神经元有更大的接受野[1,29]。同时,1×1卷积也执行(i),更大的核同时执行(i)和(ii)。

在Vision transformer和其他基于注意力的架构中,自注意力层允许(i)和(ii),而mlp块执行(i)。Mixer架构背后的思想是清楚地分离每个位置(通道混合)操作(i)和跨位置(令牌混合)操作(ii)。

这两个操作都是通过mlp实现的。图1总结了该体系结构。

代码:

class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    y = nn.LayerNorm()(x)
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)


class MlpMixer(nn.Module):
  """Mixer architecture."""
  patches: Any
  num_classes: int
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int
  model_name: Optional[str] = None

  @nn.compact
  def __call__(self, inputs, *, train):
    del train
    x = nn.Conv(self.hidden_dim, self.patches.size,
                strides=self.patches.size, name='stem')(inputs)
    x = einops.rearrange(x, 'n h w c -> n (h w) c')
    for _ in range(self.num_blocks):
      x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    x = jnp.mean(x, axis=1)
    if self.num_classes:
      x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                   name='head')(x)
    return x

mlpblock是一个简单的mlp块,包含两个全连接层和一个GELU激活函数。第一个全连接层增加维度,第二个全连接层减少维度回到原始输入的维度。

mixerblock是核心,首先对输入进行层归一化,然后应用token mixing,即对mlpblock交换维度,接着添加残差连接。之后再次对结果进行层归一化,并应用channel mixing,最后再次添加残差连接。

mlpmixer是完整的MLP-Mixer模型,首先使用一个卷积来处理输入图像,然后将图像patches展平。接着,通过多个mixerblock来处理patches。在所有MixerBlock之后,它应用层归一化并进行平均池化。如果指定了num_classes,它将添加一个全连接层作为分类头。

可以将公式总结为:

这里σ是单元非线性(GELU[16])。Ds和Dc分别是令牌混合和通道混合mlp中可调的隐藏宽度。注意,Ds的选择与输入补丁的数量无关。因此,网络的计算复杂度在输入patch的数量上是线性的,不像ViT的复杂度是二次的。由于Dc与patch大小无关,因此整体复杂度在图像的像素数上是线性的

如上所述,将相同的通道混合MLP(令牌混合MLP)应用于x的每一行(列)。将通道混合MLP的参数(在每一层内)捆绑在一起是一种自然的选择——它提供了位置不变性,这是卷积的一个突出特征。但是,跨通道绑定参数的情况要少见得多。例如,在一些cnn中使用的可分离卷积[9,40],独立于其他通道对每个通道应用卷积。然而,在可分离卷积中,不同的卷积核应用于每个通道,而不像Mixer中的令牌混合mlp,它为所有通道共享相同的核(完全接受场)。

当增加隐藏维度C或序列长度S时,参数绑定可以防止体系结构增长过快,并节省大量内存。

Mixer中的每一层(除了初始patch投影层)都有相同大小的输入。这种“各向同性”的设计最类似于变压器,或者其他领域的深度rnn,它们也使用固定的宽度。这与大多数具有金字塔结构的cnn不同:较深的层具有较低的分辨率输入,但有更多的通道。请注意,虽然这些是典型的设计,但也存在其他组合,如各向同性ResNets[38]和金字塔状vit[52]。

除了MLP层,Mixer还使用其他标准的体系结构组件:跳过连接[15]和层规范化[2]。与vit不同,Mixer不使用位置嵌入,因为令牌混合mlp对输入令牌的顺序很敏感。最后,Mixer使用具有全局平均池化层的标准分类头,然后是线性分类器。

注意,在多模态中,视频-文本,音频-文本应用了以上结构:

信息可以在不同的模态和时间序列之间流动。每个区块由两个MLP层和一个GELU激活功能(描述为𝛷). 此外,在每个块中也应用了跳连接。假设 是一个输入特征,其中 𝑡 是时间序列的长度,并且 𝑑 是模态的数量。在每一层中,MLP-communicator模块可表示如下:

其中,𝑖 从 1 到 𝑑 表示行数,𝑗 从 1 到 𝑡 表示列数。Norm() 表示 LayerNorm (层归一化),W 表示每个块中线性层的权重。输入特征 𝑋 首先通过 Time-Mixing MLP 过程,并通过跳连接生成 𝑍,这一步骤允许水平对应特征之间的通信。然后,𝑍 经过 Modality-Mixing MLP 过程生成 𝑌,特征在纵向方向上进行融合。最终得到的特征 𝑌 融合了两个方向的特征信息。该结构允许输入特征中的每个元素可以沿着两个维度与其他特征进行互动。

        解释:在Time-Mixing的过程中,假设我们有三个模态,X j,i 就是表示在第j个时间点,第i个模态的特征值。在这个公式中,我们将所有时间点下不同模态的特征进行了整合,然后Norm(X *,i)是对这个特征值进行正规化,使其分布在一个固定的区间内,通常是0到1或者-1到1。这个正规化的过程可以使得不同的特征值有可比性,且防止部分大数值特征对整个模型训练结果的过度影响。

        接着W_1和W_2就是两个权重矩阵,它们和正规化后的特征值进行乘积操作。φ是激活函数,例如本模型中的GELU激活函数,这个函数起到的作用类似一个开关,可以根据函数结果决定神经元是否被激活,即该特征是否被放入下一层中去。

        然后,我们再看第二个公式:

        这个公式和前一个公式的结构是一样的,都是先进行一次正规化,然后使用两个权重矩阵和输入的特征进行乘积操作,并通过激活函数处理。

        在这个公式中,我们将所有模态在同一个时间点的特征进行了整合,即在时间 j 的所有模态的特征都放入了这个公式中。这也就体现了这个模型中,在不同模态和时间系列进行信息流动和整合的目的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2205948.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

渗透测试 之 域渗透手法【域内用户枚举】手法 Kerbrute msf pyKerbrute 工具使用详解

说明一下: 域内用户枚举工具使用说说: Kerbrute pyKerbrute MSF模块的使用 域内用户名枚举原理分析: 域内用户枚举攻击防御: 流量检测: 日志层面: 说明一下: 域环境或者内网环境下,可以在没有域环…

深入理解Transformer的笔记记录(精简版本)---- ELMO->GPT->BERT

1、ELMO word embedding无法区分多义词的不同语义,其本质上是个静态的方式,所谓静态指的是训练好之后每个单词的表达就固定住了,以后使用的时候,不论新句子上下文单词是什么,这个单词的Word Embedding不会跟着上下文场景的变化而改变 ELMO根据当前上下文对Word Embed…

有趣的python库:用 difflib 实现文本差异的可视化

一,介绍 difflib 模块是Python标准库的一部分,提供了一系列用于比较序列的类和函数,特别适用于文本比较任务。这个模块可以帮助用户发现两个文本文件或字符串序列之间的差异,并以多种格式展示这些差异,比如这样&#…

400行程序写一个实时操作系统RTOS(开篇)

笔者之前突发奇想,准备写一个极其微小的实时操作系统内核,在经过数天的努力后,这个RTOS诞生了。令读者比较意外的是,它的程序只有400行左右。但就是这短短的400行,完成了动态内存管理、多线程、优先级、低功耗管理、调…

深度学习--------------------------------使用注意力机制的seq2seq

目录 动机加入注意力Bahdanau注意力的架构 总结Bahdanau注意力代码带有注意力机制的解码器基本接口实现带有Bahdanau注意力的循环神经网络解码器测试Bahdanau注意力解码器该部分总代码 训练从零实现总代码简洁实现代码 将几个英语句子翻译成法语该部分总代码 将注意力权重序列进…

BUG修复(不断整理想起什么就整理什么)

声明:此篇博文是记录本人从开始学习计算机过程中遇到的各种类型的报错以解决办法,希望给同道中人提供一点绵薄的帮助,也欢迎大家在评论区讨论或私信我交流问题 共同进步! 一、FPGA系列 1.Synthesis failed 错误:综合失败&#…

Python | Leetcode Python题解之第468题验证IP地址

题目: 题解: class Solution:def validIPAddress(self, queryIP: str) -> str:if queryIP.find(".") ! -1:# IPv4last -1for i in range(4):cur (len(queryIP) if i 3 else queryIP.find(".", last 1))if cur -1:return &q…

测试工作能干到退休!从会写一份成长型测试周报开始

测试周报则是反映团队工作进展和专业态度的一扇窗口。通过周报,我们不仅可以展示一周内的工作成果,更可以体现团队的工作心态——是积极进取、不断学习的成长型心态,还是仅仅满足于现状、缺乏动力的躺平型心态。本文将带您深入了解这两种不同…

Vue 项目文件大小优化

优化逻辑 任何优化需求,都有一个前提,即可衡量。 那 Vue 加载速度的优化需求,本质上是要降低加载静态资源的大小。 所以,优化前,需要有一个了解项目现状的资源加载大小情况。 主要分 3 步走: 找到方法测…

k8s jenkins 动态创建slave

k8s jenkins 动态创建slave 简述使用jenkins动态slave的优势:配置jenkins动态slave配置 Pod Template配置容器模板挂载卷 测试 简述 持续构建与发布是我们日常工作中必不可少的一个步骤,目前大多公司都采用 Jenkins 集群来搭建符合需求的 CI/CD 流程&am…

8. 多态、匿名内部类、权限修饰符、Object类

文章目录 一、多态 -- 花木兰替父从军1. 情境2. 小结 二、匿名内部类三、权限修饰符四、Object -- 所有类的父类(包括我们自己定义的类)五、内容出处 一、多态 – 花木兰替父从军 1. 情境 我们现在新建两个类HuaMuLan和HuaHu。HuMuLan是HuaHu的女儿,所以她会有她父…

利用编程思维做题之链表内指定区间反转

牛客网题目 1. 理解问题 给定一个单链表和两个整数 m 和 n,要求反转链表中从位置 m 到位置 n 的节点,最后返回反转后的链表头节点。 示例: 输入:链表 1 -> 2 -> 3 -> 4 -> 5 -> NULL,m 2,…

《市场营销学》PPT课件.ppt

网盘:https://pan.notestore.cn/s.html?id29https://pan.notestore.cn/s.html?id29

山西农业大学20241011

03-JAVASCRIPT 一.数组二.BOM1. window对象2. location对象3. history对象4. navigator对象5. screen对象6. cookie对象 三.DOM操作1. 概述2. 查找元素2.1 id方式2.2 标签名方式2.3 class名方式2.4 css选择器方式 一.数组 <script>// 1. 创建数组, 通过数组字面量// …

不卷且创新idea:KAN+特征提取!10篇高分套路拆解,快来抄作业!

今天和大家分享一种创新的深度学习技术&#xff1a;KAN特征提取。 这种技术通过引入KAN来增强模型的特征处理能力&#xff0c;借由KAN的自适应激活函数&#xff0c;动态调整数据特性&#xff0c;从而有效提取更加准确的特征&#xff0c;实现更高性能的模型表现。 这种优势让K…

离散微分几何基础:流形概念与网格数据结构

一、流形概念的引入 &#xff08;一&#xff09;微分几何核心概念——流形 在微分几何的广袤领域中&#xff0c;流形概念占据着核心地位。它如同一个神秘的基石&#xff0c;支撑着我们对各种几何形状和空间的深入理解。就像网格和抽象的单纯复数是我们探索拓扑结构&#xff08…

使用阿里云盘将服务器上的文件上传/下载到云盘/服务器

阿里云盘官方文档&#xff1a; 具体的操作步骤这里都有&#xff1a; https://github.com/tickstep/aliyunpan 具体步骤 &#xff1a; 安装&#xff1a; wget https://github.com/tickstep/aliyunpan/releases/download/v0.3.4/aliyunpan-v0.3.4-linux-amd64.zip【这里最好下…

服务器与内存市场|2025预测动态早知道

根据TrendForce的数据分析报告&#xff0c;三大DRAM供应商在2023年服务器总bit增长率经历了不同程度下滑后&#xff0c;2024年市场迎来了反弹&#xff0c;增长率分别达到了9.9%/12.3%/24.1%。这一转变表明服务器DRAM在三大供应商中的比例预计将会增加。与此同时&#xff0c;由于…

Java项目实战II养老||基于Java+Spring Boot+MySQL的社区智慧养老监护管理平台设计与实现(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 随着老龄化…

ConditionVideo: 无训练的条件引导视频生成 | AAAI 2024

作者&#xff1a;彭博&#xff0c;上海人工智能实验室与上海交大2023级联培博士。 最近的工作已经成功地将大规模文本到图像模型扩展到视频领域&#xff0c;产生了令人印象深刻的结果&#xff0c;但计算成本高&#xff0c;需要大量的视频数据。在这项工作中&#xff0c;我们介…