gMLP(NeurIPS 2021)原理与代码解析

news2024/11/16 21:30:42

paper:Pay Attention to MLPs

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mlp_mixer.py

方法介绍

gMLP和MLP-Mixer以及ResMLP都是基于MLP的网络结构,非常简单,关于MLP-Mixer和ResMLP的介绍见MLP-Mixer(NeurIPS 2021, Google)论文与源码解读-CSDN博客、ResMLP(NeurIPS 2021,Meta)论文与代码解析-CSDN博客。

在MLP-Mixer中每个block包含两个MLP,每个MLP包含两个线性层(即全连接层),一个MLP用于token间的信息交互,另一个MLP用于通道间的信息交互,每个MLP都用了residual connection,标准化采用LayerNorm。而在ResMLP中,第一个包含两个线性层的token MLP换成了单个线性层,此外在线性层前后包含两个标准化层pre-normalization和post-normalization,pre-normalization采用了简单的仿射变换,post-normalization采用了CaiT中的LayerScale。

gMLP的结构和伪代码如图1所示。可以看到gMLP将token_mlp(即这里的spatial gating unit)和channel_mlp放到了一起,只包含一个skip-connection,而不是像MLP-Mixer和ResMLP中每个mlp都采用一个skip-connection。此外block内的结构和MLP-Mixer以及ResMLP中的先token_mlp后channel_mlp不同,这里采用了channel+token+channel的形式。最后作者专门为token_mlp设计了一个门控机制,将输入split开一分为二,一半经过spatial proj得到的输出再和另一半相乘得到最终输出。

以上就是gMLP和MLP-Mixer以及ResMLP不同之处,总共包括三点,整体结构也非常简单。下面就直接用代码来解释具体的实现细节。

代码解析

一个完整的block的代码如下,forward函数中可以看到只包含一个skip-connection,self.mlp_channels包含了图1中第一个Channel Proj到最后的Channel Proj。

class SpatialGatingBlock(nn.Module):
    """ Residual Block w/ Spatial Gating

    Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
    """
    def __init__(
            self,
            dim,
            seq_len,
            mlp_ratio=4,
            mlp_layer=GatedMlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            drop=0.,
            drop_path=0.,
    ):
        super().__init__()
        channel_dim = int(dim * mlp_ratio)  # 512x6=3072
        self.norm = norm_layer(dim)
        sgu = partial(SpatialGatingUnit, seq_len=seq_len)  # 196
        self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):  # (1,196,512)
        x = x + self.drop_path(self.mlp_channels(self.norm(x)))
        return x

上面的mlp_layer的代码如下,self.fc1和self.fc2对应两个Channel Proj。

class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            gate_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):  # (1,196,512)
        # Linear(in_features=512, out_features=3072, bias=True)
        x = self.fc1(x)  # (1,196,3072)
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)  # (1,196,1536)
        x = self.norm(x)
        # Linear(in_features=1536, out_features=512, bias=True)
        x = self.fc2(x)  # (1,196,512)
        x = self.drop2(x)
        return x

gate_layer的代码如下,其中x.chunk(2, dim=-1)表示将x沿最后一个维度均分为2份。

class SpatialGatingUnit(nn.Module):
    """ Spatial Gating Unit

    Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
    """
    def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):
        super().__init__()
        gate_dim = dim // 2
        self.norm = norm_layer(gate_dim)
        self.proj = nn.Linear(seq_len, seq_len)  # 196,196

    def init_weights(self):
        # special init for the projection gate, called as override by base model init
        nn.init.normal_(self.proj.weight, std=1e-6)
        nn.init.ones_(self.proj.bias)

    def forward(self, x):  # (1,196,3072)
        u, v = x.chunk(2, dim=-1)  # (1,196,1536),(1,196,1536)
        v = self.norm(v)
        v = self.proj(v.transpose(-1, -2))  # (1,1536,196)
        return u * v.transpose(-1, -2)  # (1,196,1536) * (1,196,1536)

实验结果

作者设计了三个不同大小的gMLP,具体参数配置如下

和其它模型在ImageNet上的分类性能对比如下,可以看到和类似大小的MLP-Mixer与ResMLP相比,gMLP用更少的参数得到了更好的性能。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1864411.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

太阳初升:born 诞生

在《long long ago》中,我们分析出了首字母l的形象,就是长长的脐带的形象,ong就是脐带冗余蔓连于婴儿肚子上的形象,整个场景为婴儿呱呱坠地脐带尚未剪掉时的情景,而且on通汉字“旦”,通“one”,…

红酒品鉴秘籍:一键解锁味觉宇宙,开启你的味觉探险新纪元

红酒,这种优雅的液体,蕴藏着丰富的口感和层次,每一次的品鉴都是一次味觉的探险。今天,就让我们一起探索红酒品鉴的奥秘,解锁味觉的新世界,而在这个过程中,雷盛红酒将成为我们的向导,…

GraphQL:简介

GraphQL 图片来源: 我们将探索GraphQL 的基础知识,并学习如何使用Apollo将其与 React 和 React Native 等前端框架连接起来。这将帮助您了解如何使用 GraphQL、React、React Native 和 Apollo 构建现代、高效的应用程序。 什么是 GraphQL?…

[深度学习] 生成对抗网络GAN

生成对抗网络(Generative Adversarial Networks,GANs)是一种由 Ian Goodfellow 等人在2014年提出的深度学习模型Generative Adversarial Networks。GANs的基本思想是通过两个神经网络(生成器和判别器)的对抗过程&#…

Nodejs使用mqtt库连接阿里云服务器

建项目 命令行输入: npm init 输入项目名,自动化生成项目列表。 6.3 编写代码 新建mqtt_demo_aliyun.js,代码如下: // mqtt_demo_aliyun.jsconst mqtt require("mqtt"); const connectUrl "ws://post-cn-nw**…

展厅设计中需要人性化的地方

1、预留参观空间 展厅空间的布局设计必须尽可能的宽敞,以避免参观人数较多时可能会发生的拥堵,重点展品需要预留较大的展示空间或四面通畅的中心位置,更方便观众从不同角度与方位参观。因为是展厅,不仅代表着企业形象,…

安科瑞光伏并网电表ADL400N-CT双向计量防逆流自带互感器电表-安科瑞 蒋静

1 概述 ADL 系列导轨式多功能电能表,是主要针对于光伏并网系统、微逆系统、储能系统、交流耦合系统等新能源发电系统而设计的一款智能仪表,产品具有精度高、体积小、响应速度快、安装方便等特点。具有对电力参数进行采样计量和监测,逆变器或…

flask与vue实现通过websocket通信

在一些情况下,我们需要实现前后端之间的时刻监听,本文是一篇工具文档,用于解决前后端之间使用websocket交互。 一. Flask的相关配置 1. 下载相关依赖库 如果还没有配置flask的话,需要先安装flask,同时为解决跨域问题&#xff0…

Topaz Gigapixel AI图片无损放大软件下载安装,Topaz Gigapixel AI 高精度的图片无损放大

Topaz Gigapixel AI无疑是一款革命性的图片无损放大软件,它在图像处理领域开创了一种全新的可能性。 Topaz Gigapixel AI的核心功能在于能够将图片进行高精度的无损放大。虽然经过软件处理的图片严格意义上并不能算是完全无损,但相较于传统方法&#xf…

AI实战案例!如何运用SD完成运营设计海报?玩转Stable Diffusion必知的3大绝技

大家好我是安琪! Satble Diffusion 给视觉设计带来了前所未有的可能性和机会,它为设计师提供了更多选择和工具的同时,也改变了设计师的角色和设计流程。然而,设计师与人工智能软件的协作和创新能力仍然是不可或缺的。接下来我将从…

【语言模型】探索AI模型、AI大模型、大模型、大语言模型与大数据模型的关系与协同

一、引言 随着人工智能(AI)技术的飞速发展,各种AI模型如雨后春笋般涌现,其中AI模型、AI大模型、大模型、大语言模型以及大数据模型等概念在学术界和工业界引起了广泛关注。这些模型不仅各自具有独特的特点和应用场景,…

告别臭脚尴尬!安全鞋除臭秘籍大公开

你是否有过这样的烦恼,忙碌一天回到家,脱鞋的瞬间,那令人窒息的气味让人瞬间清醒?别担心,今天百华小编就与大家一起探讨下安全鞋除臭的秘籍,让你从此告别臭脚尴尬! 首先,我们要了解…

PHP 面向对象编程(OOP)入门指南

面向对象编程(Object-Oriented Programming,简称OOP)是一种编程范式,通过使用对象来设计和组织代码。PHP作为一种广泛使用的服务器端脚本语言,支持面向对象编程。本文将介绍PHP面向对象编程的基本概念和用法&#xff0…

SpringCloud Alibaba Seata2.0分布式事务AT模式实践总结

这里我们划分订单、库存与支付三个module来实践Seata的分布式事务。 依赖版本(jdk17)&#xff1a; <spring.boot.version>3.1.7</spring.boot.version> <spring.cloud.version>2022.0.4</spring.cloud.version> <spring.cloud.alibaba.version>…

美多多商城定义用户模型类遇见的问题

from django.db import models from django.contrib.auth.models import AbstractUser # Create your models here. class User(AbstractUser):mobile models.CharField(max_length11, uniqueTrue,verbose_name手机号)class Meta:db_tabletb_users #自定义表名verbose_name用户…

【动态内存】详解

Hi~&#xff01;这里是奋斗的小羊&#xff0c;很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~~ &#x1f4a5;&#x1f4a5;个人主页&#xff1a;奋斗的小羊 &#x1f4a5;&#x1f4a5;所属专栏&#xff1a;C语言 &#x1f680;本系列文章为个人学习…

深入浅出 langchain 1. Prompt 与 Model

示例 从代码入手来看原理 from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI prompt ChatPromptTemplate.from_template("tell me a short joke about…

B端页面:日志管理页面,简洁实用的设计法门

B端日志管理是指在企业级后台系统中对系统操作日志进行记录、查看和管理的功能。 它的作用主要有以下几点&#xff1a; 1. 安全审计&#xff1a;通过记录用户的操作日志&#xff0c;可以对系统的安全性进行审计和监控&#xff0c;及时发现异常操作和安全漏洞。 2. 故障排查&a…

Program LLMs,不只是Prompt LLMs

前言 随着大模型的使用和应用越来越频繁&#xff0c;也越来越广泛&#xff0c;大家有没有陷入到无限制的研究、调优自己的prompt。 随之&#xff0c;市面上也出现了提示词工程师&#xff0c;更有专门的提示工程一说。 现在网上搜一搜&#xff0c;有各种各样的写提示词的技巧…

Python多线程技巧心得详解

概要 多线程是一种能够并发执行代码的方法,可以提高程序的执行效率和响应速度。本文将详细介绍 Python 中多线程的概念、使用场景、基本用法以及实际应用,可以更好地掌握多线程编程。 什么是多线程? 多线程是一种在单个进程内并发执行多个线程的技术。每个线程共享相同的内…