Google的MLP-MIXer的复现(pytorch实现)

news2024/11/18 19:40:29

Google的MLP-MIXer的复现(pytorch实现)

该模型原论文实现用的jax框架实现,先贴出原论文的代码实现:

# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp


class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) #  (32, 512, 196)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    y = nn.LayerNorm()(x)
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)


class MlpMixer(nn.Module):
  """Mixer architecture."""
  patches: Any
  num_classes: int
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int
  model_name: Optional[str] = None

  @nn.compact
  def __call__(self, inputs, *, train):
    del train
    x = nn.Conv(self.hidden_dim, self.patches.size,
                strides=self.patches.size, name='stem')(inputs)
    x = einops.rearrange(x, 'n h w c -> n (h w) c')  # 从(32,512,14,14)变成了(32,196,512)
    for _ in range(self.num_blocks):
      x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    x = jnp.mean(x, axis=1)
    if self.num_classes:
      x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                   name='head')(x)
    return x


model_params = {
    'patches': {'size': (16, 16), 'stride': (16, 16)}, # 这里需要一个描述patch大小和步长的对象,例如Flax的stem模块初始化参数
    'num_classes': 10,  # 分类任务的类别数
    'num_blocks': 8,  # Mixer Block的重复次数
    'hidden_dim': 512,  # 隐藏层维度
    'tokens_mlp_dim': 256,  # token mixing的MLP维度
    'channels_mlp_dim': 2048,  # channel mixing的MLP维度
}

# 准备输入数据,例如一批32张图片,每张图片尺寸为512x14x14(假设已经按要求预处理)

# 初始化模型
seed=0
key = jax.random.PRNGKey(seed)
model = MlpMixer.apply(key, **model_params)

input_data = jnp.ones((4096, 224, 224, 3))  # 示例输入数据
# 调用模型进行前向传播
output = model(input_data)

print("Output shape:", output)  # 打印输出形状,预期是(32, 10)如果num_classes=10

该模型的总体框架图如下所示:

在这里插入图片描述

对该框架的讲解,网上已经很多了,就不在此赘述。

实现的pytorch代码如下所示:

class MlpBlock(nn.Module):
    def __init__(self, in_mlp_dim=196, out_mlp_dim=256):
        super(MlpBlock, self).__init__()
        self.mlp_dim = out_mlp_dim
        self.dense1 = nn.Linear(in_mlp_dim, out_mlp_dim)  # 若输入的向量为[32,196, 512]则输入的也应该是512,输出可以自己定
        self.gelu = nn.GELU()
        self.dense2 = nn.Linear(out_mlp_dim, in_mlp_dim)

    def forward(self, x):
        y = self.dense1(x)
        y = self.gelu(y)
        y = self.dense2(y)
        return y


class MixerBlock(nn.Module):
    def __init__(self, tokens_mlp_dim=256, channels_mlp_dim=2048, batch_size=32):
        super(MixerBlock, self).__init__()
        self.batch_size = batch_size
        self.norm1 = nn.LayerNorm(512)  # 对512维的做归一化,默认给最后一个维度做归一化
        self.token_Mixing = MlpBlock(out_mlp_dim=tokens_mlp_dim)
        self.norm2 = nn.LayerNorm(512)      # 对512维的做归一化
        self.channel_mixing = MlpBlock(in_mlp_dim=512, out_mlp_dim=channels_mlp_dim)

    def forward(self, x):
        y = self.norm1(x)
        y = y.permute(0, 2, 1)
        y = self.token_Mixing(y)
        y = y.permute(0, 2, 1)
        x = x + y
        y = self.norm2(x)
        return x + self.channel_mixing(y)


class MlpMixer(nn.Module):
    def __init__(self, patches, num_classes, num_blocks, hidden_dim, tokens_mlp_dim, channels_mlp_dim):
        super(MlpMixer, self).__init__()
        self.stem = nn.Conv2d(3, hidden_dim, kernel_size=patches, stride=patches)
        self.mixer_block_1 = MixerBlock()
        self.mixer_blocks = nn.ModuleList([MixerBlock(tokens_mlp_dim, channels_mlp_dim) for _ in range(num_blocks)])
        self.pre_head_norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        x = self.stem(x)
        b, c, h, w = x.shape
        x = x.view(b, c, -1).permute(0, 2, 1)
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)
        x = self.pre_head_norm(x)
        x = x.mean(dim=1)
        x = self.head(x)
        return x


# model = MlpMixer(16, 10, 6, 512, 256, 2048)
# input_tensor = torch.randn(32, 3, 224, 224)  # (batch size, num_patches, input_dim)
# output = model(input_tensor)
# print(output)

在将flax框架的代码改为pytorch实现的时候,还是踩了不少的坑,在此讲一下,希望后面做的人,可以避免。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

在nn.linear那儿的in_channel与第三个维度保持一致,就可以不必将其三维的转换为二维的。同时在对layernorm那儿转换的时候,默认也是对最后一个维度进行正则化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1699767.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springboot + Vue前后端项目(第十一记)

项目实战第十一记 1.写在前面2. 文件上传和下载后端2.1 数据库编写2.2 工具类CodeGenerator生成代码2.2.1 FileController2.2.2 application.yml2.2.3 拦截器InterceptorConfig 放行 3 文件上传和下载前端3.1 File.vue页面编写3.2 路由配置3.3 Aside.vue 最终效果图总结写在最后…

【NumPy】关于numpy.clip()函数,看这一篇文章就够了

🧑 博主简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向…

H.机房【蓝桥杯】/数组链式前向星建图+堆优化版dijkstra

机房 数组链式前向星建图堆优化版dijkstra #include<iostream> #include<queue> #include<cstring> #include<vector> using namespace std; typedef pair<int,int> pii; //无向图开两倍 int e[200005],ne[200005],v[200005],h[200005],du[1000…

前端 JS 经典:Web 性能指标

什么是性能指标&#xff1a;Web Performance Metrics 翻译成 Web 性能指标&#xff0c;一般和时间有关系&#xff0c;在短时间内做更多有意义的事情。 一个站点表现得好与不好&#xff0c;标准在于用户体验&#xff0c;而用户体验好不好&#xff0c;有一套 RAIL 模型来衡量。这…

2024年上半年系统架构设计师真题-复原程度90%

前言 此次考试监考特别严格&#xff0c;草稿纸不允许带出考场&#xff0c;并且准考证上不允许任何写画&#xff0c;甚至连笔都允许带一支&#xff0c;所以下面的相关题目都是参考一些群友的提供&#xff0c;加上自己的记忆回顾&#xff0c;得到的结果。 其中综合知识部分的题…

修复谷歌 AdSense 的 Ads.Txt 无效的有收益损失风险提示

明月的 AdSense 账号后台一直都有“有收益损失风险 - 您需要纠正 ads.txt 文件存在的一些问题&#xff0c;以免严重影响您的收入。”的提示长达一年多了&#xff0c;这次重新开始投放谷歌 AdSense 广告后感觉需要解决掉这个问题了&#xff0c;因为已经全站使用了 CloudFlare&am…

《Ai学习笔记》-模型集成部署

后续大多数模型提升速度和精度&#xff1a; 提升速度&#xff1a; -知识蒸馏&#xff0c;以distillBert和tinyBert为代表 -神经网络优化技巧。prune来剪裁多余的网络节点&#xff0c;混合精度&#xff08;fp32和fp26混合来降低计算精度从从而实现速度的提升&#xff09; 提…

驾驭数字前沿--欧盟商会网络安全大会活动

本次安策参加由欧盟商会组织举办的--超越 2024 年网络安全大会&#xff1a;驾驭数字前沿大会(上海)&#xff0c;安策在大会上做了《2024数据威胁报告主题报告》并希望携手各行业伙伴&#xff0c;共同驾驭数字前沿的波涛&#xff0c;共创安全、合规、高效的数字未来。 【安策活动…

操作系统入门系列-MIT6.828(操作系统工程)学习笔记(二)----课程实验环境搭建(wsl2+ubuntu+quem+xv6)

MIT6.S081&#xff08;操作系统&#xff09;学习笔记 操作系统入门系列-MIT6.828&#xff08;操作系统&#xff09;学习笔记&#xff08;一&#xff09;---- 操作系统介绍与接口示例 操作系统入门系列-MIT6.828&#xff08;操作系统工程&#xff09;学习笔记&#xff08;二&am…

大模型的实践应用24-LLaMA-Factory微调通义千问qwen1.5-1.8B模型的实例

大家好,我是微学AI,今天给大家介绍一下大模型的实践应用24-LLaMA-Factory微调通义千问qwen1.5-1.8B模型的实例, LLaMA-Factory是一个专门用于大语言模型微调的框架,它支持多种微调方法,如LoRA、QLoRA等,并提供了丰富的数据集和预训练模型,便于用户进行模型微调。通义千问…

谷歌Google广告投放优势和注意事项!

谷歌Google作为全球最大的搜索引擎&#xff0c;谷歌不仅拥有庞大的用户基础&#xff0c;还提供了高度精准的广告投放平台&#xff0c;让广告主能够高效触达目标受众&#xff0c;实现品牌曝光、流量增长乃至销售转化的多重目标&#xff0c;云衔科技以专业服务助力您谷歌Google广…

C++笔记:红黑树与哈希表

1.容器rb_tree 按正常规则it遍历&#xff0c;便能得到排序状态不能使用rb_tree的iterators改变元素值两种插入操作&#xff1a;insert_unique()和insert_equal() template <class Key, class Value, class KeyOfValue, class Compare, class Allocalloc> class rb_tree…

基于Zookeeper的分布式锁

分布式锁的介绍 在Java的多线程部分&#xff0c;我们知道如果在单个jvm进程中&#xff0c;多个线程之间同时访问一个资源&#xff0c;此时会有多线程的安全问题。为了解决这个线程安全的问题&#xff0c;我们可以使⽤“锁”来实现。但是&#xff0c;多个jvm进程之间如果同时访问…

计算机毕业设计 | SpringBoot社区物业管理系统 小区管理(附源码)

1&#xff0c; 概述 1.1 课题背景 近几年来&#xff0c;随着物业相关的各种信息越来越多&#xff0c;比如报修维修、缴费、车位、访客等信息&#xff0c;对物业管理方面的需求越来越高&#xff0c;我们在工作中越来越多方面需要利用网页端管理系统来进行管理&#xff0c;我们…

【源码】java + uniapp交易所源代码/带搭建教程java交易所/完整源代码

java uniapp交易所源代码/带搭建教程java交易所/完整源代码 带简洁教程&#xff0c;未测 java uniapp交易所源代码/带搭建教程java交易所/完整源代码 - 吾爱资源网

软件需求开发管理规程-Word原件(配套软件全资料文档)

1. 目的 2. 适用范围 3. 参考文件 4. 术语和缩写 5. 需求获取的方式 5.1. 与用户交谈向用户提问题 5.1.1. 访谈重点注意事项 5.1.2. 访谈指南 5.2. 参观用户的工作流程 5.3. 向用户群体发调查问卷 5.4. 已有软件系统调研 5.5. 资料收集 5.6. 原型系统调研 5.6.1. …

Android11热点启动和关闭

Android官方关于Wi-Fi Hotspot (Soft AP) 的文章&#xff1a;https://source.android.com/docs/core/connect/wifi-softap?hlzh-cn 在 Android 11 的WifiManager类中有一套系统 API 可以控制热点的开和关&#xff0c;代码如下&#xff1a; 开启热点&#xff1a; // SoftApC…

Java进阶学习笔记27——StringBuilder、StringBuffer

StringBuilder&#xff1a; StringBuilder代表可变字符串对象&#xff0c;相当于一个容器&#xff0c;它里面装的字符串是可以改变的&#xff0c;就是用来操作字符串的。 好处&#xff1a; StringBuilder比String更适合做字符串的修改操作&#xff0c;效率会更高&#xff0c;…

基于Ruoyi-Cloud-Plus重构黑马项目-学成在线

文章目录 一、系统介绍二、系统架构图三、参考教程四、演示图例机构端运营端用户端开发端 一、系统介绍 毕设&#xff1a;基于主流微服务技术栈的在线教育系统的设计与实现 前端仓库&#xff1a;https://github.com/Xiamu-ssr/Dragon-Edu-Vue3 后端仓库&#xff1a;https://g…

Nodejs(文件操作,构建服务器,express,npm)

文章目录 文件操作1.读取文件1&#xff09;步骤2&#xff09;范例 2.写文件1&#xff09;步骤2&#xff09;范例 3.删除文件4.重命名文件夹5删除文件夹 Url1.url.parse()2.url.fomat() Query1.query.parse()2.query.stringfy()3.编码和解码 第三方模块1.nodemailer2.body-parse…