gMLP:Pay Attention to MLPs--模型代码讲解

news2025/1/10 23:48:09

gMLP模型代码讲解

  • Introduction
  • gMLP网络结构
    • Spatial Gating Unit (SGU)
  • code
    • gMLPBlock
    • Spatial Gating Unit

基于MLP-Mixer 的改进…

Introduction

总的来说,gMLP 在视觉和NLP领域的惊人有效性表明,自我注意并不是扩大机器学习模型的必要因素,尽管它根据任务的不同可以是一个有用的补充随着数据和计算量的增加。具有gMLP等更简单的空间交互机制的模型可以像变压器一样强大,分配给自我注意的能力可以被删除或大幅减少。

gMLP网络结构

gMLP 的输入仍为若干图像块(即将一张图像切割成若干图像块),输出为若干个向量(token)堆叠组成的矩阵,例如token的维度为L,个数为N,则输出为N ∗ L 的矩阵,通过池化等操作转换为最终的特征向量。
由若干个基本构成单元堆叠而成
在这里插入图片描述

设输入矩阵(即图中的Input Embeddedings)为 n ∗ d n∗d nd 的矩阵X , n为序列长度, d为特征维度,则gMLP的unit结构可以简化为 Z = δ ( X U ) Z ~ = s ( Z ) Y = δ ( Z ~ V ) + X Z=\delta (XU)\\ \tilde{Z} = s(Z)\\ Y=\delta(\tilde{Z}V)+X Z=δ(XU)Z~=s(Z)Y=δ(Z~V)+X
U , V U,V U,V为可学习的矩阵, δ \delta δ 为激活函数, s ( z ) s(z) s(z) 为图中的Spatial Gating Unit.

Spatial Gating Unit (SGU)

为了能有跨token的交互, s ( ⋅ ) s(\cdot) s() 操作须在空间维度。可以简单的使用线性映射表示: f W , b ( Z ) = W Z + b s ( Z ) = Z ⊙ f W , b ( Z ) f_{W,b}(Z)=WZ+b\\ s(Z)=Z⊙f_{W,b}(Z) fW,b(Z)=WZ+bs(Z)=ZfW,b(Z) Z Z Z n ∗ d n∗d nd 的矩阵,则 W W W n ∗ n n∗n nn 的矩阵,表示空间交互的映射参数,b 为n 维向量(WZ+b表示WZ的第一行元素与b的第一维元素相加),为了保证训练的稳定性,W 初始化值接近于0(貌似用[-1,1]的均匀分布初始化),b 的初始值为1,此时 f W , b ( Z ) ≈ 1 , s ( Z ) ≈ Z f_{W,b}(Z)\approx1,s(Z)\approx Z fW,b(Z)1,s(Z)Z,这种初始化确保了每个gMLP块在训练的早期阶段像一个常规的FFN,其中每个token 都被独立处理,并且只在学习过程中逐步跨token注入空间信息。

更进一步的作者发现将Z 沿着channel维度切割成 Z 1 , Z 2 Z_1,Z_2 Z1,Z2 ( Z 1 , Z 2 Z_1,Z_2 Z1,Z2的维度分别为 n ∗ d 1 , n ∗ d 2 , d 1 + d 2 = n n*d_1,n*d_2,d_1+d_2=n nd1,nd2,d1+d2=n)两个部分更为有效,此时s(Z)操作变为
s ( Z ) = Z 1 ⊙ f W , b ( Z 2 ) s(Z)=Z_1\odot f_{W,b}(Z_2) s(Z)=Z1fW,b(Z2)

code

先看整体结构,在整个gMLP结构中,gmlp代替self-attention设计了框架结构。每一个层级使用gMLPBlock作为一个block阶段。整个残差形式为gmlp(norm(x))+x.

class gMLP(nn.Module):
    def __init__(
            self,
            *,
            ...
    ):
        super().__init__()
        dim_ff = dim * ff_mult
        self.seq_len = seq_len
        self.prob_survival = prob_survival

        self.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()

        self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act))) for i in range(depth)])
        #  gmlp(norm(x))+x

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        ) if exists(num_tokens) else nn.Identity()

    def forward(self, x):
        x = self.to_embed(x)
        layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
        out = nn.Sequential(*layers)(x)
        return self.to_logits(out)
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

gMLPBlock

class gMLPBlock(nn.Module):
    def __init__(
            self,
            *,
            dim,
            dim_ff,
            seq_len,
            attn_dim = None,
            causal = False,
            act = nn.Identity()
    ):
        super().__init__()
        self.proj_in = nn.Sequential(
            nn.Linear(dim, dim_ff),
            nn.GELU()
        )
		# dim_ff = dim * ff_mult(4)
		# dim -> dim*4
        self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else None

        self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act)
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        gate_res = self.attn(x) if exists(self.attn) else None
		# 默认的attn是None,即不进行该操作
        x = self.proj_in(x)
        x = self.sgu(x, gate_res = gate_res)
        x = self.proj_out(x)
        return x

Spatial Gating Unit

class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        super().__init__()
        dim_out = dim // 2
        self.causal = causal

        self.norm = nn.LayerNorm(dim_out)
        self.proj = nn.Conv1d(dim_seq, dim_seq, 1)

        self.act = act

        init_eps /= dim_seq
        nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
        nn.init.constant_(self.proj.bias, 1.)

    def forward(self, x, gate_res = None):
        device, n = x.device, x.shape[1]

        res, gate = x.chunk(2, dim = -1)
        # self-atten 用的dim
        # sgu用的dim_ff = dim * ff_mult(4),即4倍
        # chunk之后,每个为2倍,用两倍的值进行attention
        gate = self.norm(gate)

        weight, bias = self.proj.weight, self.proj.bias
        if self.causal:
            ...

        gate = F.conv1d(gate, weight, bias)
		# 1d卷积混合w*h维度的信息,patch通道的混合
        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2170905.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于STM32的DHT11功能实现(操作时序)

1.引脚定义 Pin名称注释1VDD供电 3-5.5V2GND接地,电源负极3DATA串行数据,单总线4NC空脚,请悬空 2.数据格式 DHT11 采用单总线协议与单片机通信,单片机发送一次复位信号后,DHT11 从低功耗模式转换到高速模式&#xff…

新品:新一代全双工音频对讲模块SA618F22-C1

SA618F22-C1是我司一款升级版的无线数字和音频二合一全双工传输模块,支持8路并发高音质通话。用户不仅可以通过串口实现数据的无线传输,还可以通过I2S数字音频或模拟音频接口来传输语音信号。该模块内置高速微控制器、回声消除电路、ESD静电防护、高性能…

全国自闭症学校全寄宿制:为特殊儿童提供专业教育护理

在全国范围内,随着对自闭症儿童教育需求的日益增长,全寄宿制自闭症学校逐渐成为了一种重要的教育模式。这些学校以其专业的教育团队、全面的生活护理以及个性化的教学方案,为自闭症儿童提供了一个安全、稳定、充满爱的成长环境。在广州&#…

走进灯塔工厂,腾讯云携手业界专家共筑AI智造未来

现在,我国工业正处于从数字化向智能化转型的关键阶段,而人工智能、云计算和大数据等前沿技术正成为推动这进程的核心力量。以大模型为例,大模型通过高效处理和分析海量数据,帮助企业挖掘出有价值的规律和趋势,有效拓展…

使用双向链表和哈希表实现LRU缓存

在日常开发中,缓存 是一个非常常见且重要的技术手段,能够显著提升系统性能。为了保证缓存的有效性,需要实现一种机制,在缓存空间不足时,能够自动淘汰最久未被使用的数据。这种机制就是**LRU(Least Recently…

CSS文本格式化

通过 CSS 中的文本属性您可以像操作 Word 文档那样定义网页中文本的字符间距、对齐方式、缩进等等,CSS 中常用的文本属性如下所示: text-align:设置文本的水平对齐方式;text-decoration:设置文本的装饰;te…

面试题-部分

目录 1. 从输入url到渲染页面,中间经历了什么? 2. vue中的v-if和v-show有什么区别 3. 什么是Css中的回流(重排)与重绘 4. 介绍一下let、const、var的区别 5. 箭头函数和普通函数有什么区别 6. Css中常用的水平垂直居中解决方…

传输大咖49 | 科学研究院跨网文件交换高效、安全解决方案

在科学研究领域,数据的价值堪比黄金。科学研究所的日常运作依赖于大量的数据交换,高效的文件交换系统离不开内部合作和与外部合作伙伴的交流。然而,随着数据量的激增和网络环境的复杂性,传统的文件交换方法很难满足需求。本文将讨…

RK3568 android11 适配鼎桥MT5710-CN 5G模块

一,概述 鼎桥MT571X设备和Android系统主要通过USB接口进行数据通信,Android系统上的Linux内核需要根据鼎桥模块设备上报的USB设备接口加载USB驱动,USB驱动正确加载后,鼎桥模块才能正常工作。 Android系统中支持鼎桥模块设备相关的Linux内核驱动架构,如下图所示: 在Lin…

js删除emoji表情问题

emoji标签占位两个 &#xff0c;直接删除后一位会出现乱码符&#xff1b; 判断是否是emoji function isEmoji(char) {let code char.charCodeAt(0);return code>55296&&code<57343 } // 使用方法&#xff0c;传入单字符 console.log(isEmoji(1)); // false con…

Kubernetes 配置管理

一、什么是 ConfigMap&#xff1f; 在传统架构中&#xff0c;配置文件往往被保存在宿主机上&#xff0c;程序启动是可以指定某个配置文件&#xff0c;但是使用容器部署时&#xff0c;容器所在的节点并不固定&#xff0c;所以不能使用这种方式&#xff0c;此处在构建镜像时&…

【Redis】主从复制(下)--主从复制原理和流程

文章目录 主从复制原理主从节点建立复制流程图数据同步 psyncpsync的语法格式 psync运行流程全量复制全量复制的流程全量复制的缺陷有磁盘复制 vs 无磁盘复制 部分复制部分复制的流程复制积压缓冲区 实时复制 主从复制原理 主从节点建立复制流程图 保存主节点的信息从节点(sla…

感悟:糟糠之妻不下堂和现在女性觉醒的关系

古人说“糟糠之妻不下堂”真是害惨了中国女性&#xff0c;古代之所以有这一说法&#xff0c;大概是因为男子可以三妻四妾&#xff0c;妻子永远是正妻&#xff0c;也不需要讲究什么从一而终&#xff0c;更不会讲什么男德&#xff0c;只会要求女性学习女德、女训之类&#xff0c;…

性能测试:性能测试报告

性能测试报告是性能测试的产出物之一&#xff0c;它是对系统性能测试结果和数据的总结和分析&#xff0c;记录了系统在不同负载和场景下的性能表现和性能问题。性能测试报告提供了有关系统性能的详细信息&#xff0c;供项目团队、开发人员和其他相关利益相关者参考。 性能测试…

原生app云打包,更换图标,和名称。PDA的安装正式包

原生app云打包 复制下载即可&#xff0c;是正式版

Android下MVP和MVVM模式的实践

转载注明出处&#xff1a;https://blog.csdn.net/skysukai 1、前言 MVP和MVVM诞生已经好些年头了&#xff0c;记得刚毕业才参加工作的时候&#xff0c;第一次见到了有上万行的Activity&#xff0c;这种巨无霸的Activity维护起来简直就是噩梦。这时候&#xff0c;就需要进行代…

商标价值如何评估与增值?

商标是企业的标志&#xff0c;代表着企业的产品或服务质量、信誉和形象。一个具有高知名度和美誉度的商标&#xff0c;能够为企业带来巨大的商业价值。它不仅可以帮助企业在市场中脱颖而出&#xff0c;吸引消费者的关注和购买&#xff0c;还可以作为企业的重要资产进行融资、并…

无人便利店无人超市云值守收银系统源码

随着人力成本越来越高&#xff0c;很多门店越来想做无人值守模式&#xff0c;尤其是晚上休息时间等想让云值守客服来看店。自然要求收银系统需要可以在【有收银员值守】和【无收银员值守】两种模式灵活切换。 1. 有收银员值守模式 白天有收银员在店&#xff0c;收银员可以协助…

Tomcat服务与运用

案例准备 1.规划节点 IP 主机名 节点 192.168.20.20 tomcat Tomcat 2.基础准备 使用VMWare Workstation软件安装CentOS 7.2操作系统&#xff0c;镜像使用提供的CentOS-7-x86_64-DVD-1804.iso&#xff0c;最小化安装CentOS 7.2系统 案例实施 1.基础环境配置 1.1修改…

微信小程序的 button 标签的边框如何去除?

目录 问题描述&#xff1a; 问题原因&#xff1a; 解决办法&#xff1a; 方案一 方案二 问题描述&#xff1a; 实际开发中会发现这个 button 自带有样式&#xff0c;当背景颜色设置为白色的时候还有一个黑色的边框&#xff0c;刚开始那个边框怎么都去不掉 无法去除的边框…