政安晨:【Keras机器学习实践要点】(二十)—— 使用现代 MLP 模型进行图像分类

news2025/2/24 20:25:01

目录

简介

设置

准备数据

配置超参数

建立分类模型

定义实验

使用数据增强

将补丁提取作为一个图层来实施

将位置嵌入作为一个图层来实施

MLP 混频器模型

FNet 模式

gMLP 模式

实施 gMLP 模块


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:为 CIFAR-100 图像分类实施 MLP-Mixer、FNet 和 gMLP 模型。

简介


本示例实现了三种基于多层感知器(MLP)的现代无注意力图像分类模型,并在 CIFAR-100 数据集上进行了演示:

1. Ilya Tolstikhin 等人基于两种类型 MLP 的 MLP-Mixer 模型。
2. James Lee-Thorp 等人基于非参数化傅立叶变换的 FNet 模型。
3. gMLP 模型,由 Hanxiao Liu 等人提出,基于带门控的 MLP。

本示例的目的并不是要比较这些模型,因为它们在不同数据集上的表现可能不同,而且超参数都经过了很好的调整。相反,它是为了展示这些模型主要构建模块的简单实现。

设置

import numpy as np
import keras
from keras import layers

准备数据

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

演绎如下:

配置超参数

weight_decay = 0.0001
batch_size = 128
num_epochs = 1  # Recommended num_epochs = 50
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 8  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
embedding_dim = 256  # Number of hidden units.
num_blocks = 4  # Number of blocks.

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")

演绎如下:

建立分类模型

我们采用一种方法,根据处理模块构建分类器。

def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        x = x + PositionEmbedding(sequence_length=num_patches)(x)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

定义实验

我们实现了一个实用功能,用于编译、训练和评估给定模型。

def run_experiment(model):
    # Create Adam optimizer with weight decay.
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
        verbose=0,
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history

使用数据增强

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

将补丁提取作为一个图层来实施

class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        patches = keras.ops.image.extract_patches(x, self.patch_size)
        batch_size = keras.ops.shape(patches)[0]
        num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]
        patch_dim = keras.ops.shape(patches)[3]
        out = keras.ops.reshape(patches, (batch_size, num_patches, patch_dim))
        return out

将位置嵌入作为一个图层来实施

class PositionEmbedding(keras.layers.Layer):
    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = keras.ops.shape(inputs)
        feature_length = shape[-1]
        sequence_length = shape[-2]
        # trim to match the length of the input sequence, which might be less
        # than the sequence_length of the layer.
        position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
        position_embeddings = keras.ops.slice(
            position_embeddings,
            (start_index, 0),
            (sequence_length, feature_length),
        )
        return keras.ops.broadcast_to(position_embeddings, shape)

    def compute_output_shape(self, input_shape):
        return input_shape

MLP 混频器模型

MLP 混频器是一种完全基于多层感知器(MLP)的架构,包含两种类型的 MLP 层

1. 一种是独立应用于图像斑块,混合每个位置的特征。
2. 另一种应用于跨斑块(沿通道),混合空间信息。

这类似于基于深度可分离卷积的模型,如 Xception 模型,但有两个链式密集变换,没有最大池化,以及层归一化而不是批归一化。

实施 MLP 混频器模块

class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        return super().build(input_shape)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = keras.ops.transpose(x, axes=(0, 2, 1))
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [num_batches, hidden_dim, num_patches] to [num_batches, num_patches, hidden_units].
        mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independtenly.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x

构建、训练和评估 MLP-Mixer 模型

请注意,在 V100 GPU 上以当前设置训练模型,每个轮次大约需要 8 秒钟。

mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)

演绎结果如下:

与卷积模型和基于变换器的模型相比,MLP-Mixer 模型的参数数量要少得多,这就减少了训练和计算成本。

正如 MLP-Mixer 论文中提到的,当在大型数据集上进行预训练或使用现代正则化方案时,MLP-Mixer 可获得与最先进模型相当的分数。您可以通过增加嵌入维度、增加混合块数量和延长模型训练时间来获得更好的结果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。

FNet 模式

FNet 使用与 Transformer 模块类似的模块。不过,FNet 用一个无参数的二维傅立叶变换层取代了 Transformer 模块中的自注意层:

1. 一个一维傅里叶变换沿斑块应用。
2. 沿通道进行一次一维傅里叶变换。

实施 FNet 模块

class FNetLayer(layers.Layer):
    def __init__(self, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.ffn = keras.Sequential(
            [
                layers.Dense(units=embedding_dim, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
                layers.Dense(units=embedding_dim),
            ]
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply fourier transformations.
        real_part = inputs
        im_part = keras.ops.zeros_like(inputs)
        x = keras.ops.fft2((real_part, im_part))[0]
        # Add skip connection.
        x = x + inputs
        # Apply layer normalization.
        x = self.normalize1(x)
        # Apply Feedfowrad network.
        x_ffn = self.ffn(x)
        # Add skip connection.
        x = x + x_ffn
        # Apply layer normalization.
        return self.normalize2(x)

构建、训练和评估 FNet 模型

请注意,在 V100 GPU 上以当前设置训练模型,每个轮次大约需要 8 秒钟。

演绎如下;

如 FNet 论文所示,通过增加嵌入维度、增加 FNet 块数和延长模型训练时间,可以获得更好的结果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。FNet 可以非常高效地扩展到较长的输入,运行速度比基于注意力的 Transformer 模型快得多,并能产生具有竞争力的准确性结果。

gMLP 模式

gMLP 是一种以空间门控单元(SGU)为特色的 MLP 架构。空间门控单元(SGU)可通过以下方式实现跨空间(通道)维度的跨门控互动:

1. 通过跨补丁(沿通道)线性投影,对输入进行空间转换。
2. 对输入及其空间变换进行元素乘法运算。

实施 gMLP 模块

class gMLPLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        self.channel_projection2 = layers.Dense(units=embedding_dim)

        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimensions.
        # Tensors u and v will in the shape of [batch_size, num_patchs, embedding_dim].
        u, v = keras.ops.split(x, indices_or_sections=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
        # Apply spatial projection.
        v_channels = keras.ops.transpose(v, axes=(0, 2, 1))
        v_projected = self.spatial_projection(v_channels)
        v_projected = keras.ops.transpose(v_projected, axes=(0, 2, 1))
        # Apply element-wise multiplication.
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
        # Add skip connection.
        return x + x_projected

建立、训练和评估 gMLP 模型

请注意,在 V100 GPU 上以当前设置训练模型,每个轮次大约需要 9 秒钟。

gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)
history = run_experiment(gmlp_classifier)

演绎如下:

如 gMLP 论文所示,通过增加嵌入维度、增加 gMLP 块数和延长模型训练时间,可以获得更好的效果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。请注意,该论文使用了高级正则化策略,如 MixUp 和 CutMix,以及 AutoAugment。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1576125.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Leetcode每日一题】 动态规划 - LCR 166. 珠宝的最高价值(难度⭐⭐)(52)

1. 题目解析 题目链接:LCR 166. 珠宝的最高价值 这个问题的理解其实相当简单,只需看一下示例,基本就能明白其含义了 2.算法原理 想象一下,你正在玩一个寻宝游戏,游戏地图是一个二维网格,每个格子都藏有一…

神经网络中的超参数调整

背景 在深度神经网络学习和优化中,超参数调整一项必备技能,通过观察在训练过程中的监测指标如损失loss和准确率来判断当前模型处于什么样的训练状态,及时调整超参数以更科学地训练模型能够提高资源利用率。在本研究中使用了以下超参数&#x…

Golang 开发实战day08 - Multiple Return values

Golang 教程08 - Multiple Return values 1. Multiple return values 1.1 如何理解多个返回值? Go语言中的多返回值,就像你听了一首歌曲yellow,可以从歌曲里反馈出忧郁和害羞!Goland的多个返回值就类似于如此,设定一…

vue实现验证码验证登录

先看效果&#xff1a; 代码如下&#xff1a; <template><div class"container"><div style"width: 400px; padding: 30px; background-color: white; border-radius: 5px;"><div style"text-align: center; font-size: 20px; m…

如何用Python编写简单的网络爬虫(页面代码简单分析过程)

一、什么是网络爬虫 在当今信息爆炸的时代&#xff0c;网络上蕴藏着大量宝贵的信息&#xff0c;如何高效地从中获取所需信息成为了一个重要课题。网络爬虫&#xff08;Web crawler&#xff09;作为一种自动化工具&#xff0c;可以帮助我们实现这一目标&#xff0c;用于数据分析…

Vscode连接WSL2当中的jupyter

主要解决办法参考自这篇博客 1. 在WSL当中安装jupyter 这个随便找一篇博客即可&#xff0c;比如这篇&#xff0c;也可以根据现有的环境参考其它博客内容 2. 使用jupyter创建一个虚拟环境 首先激活想要添加的虚拟环境后&#xff0c;输入命令安装库: pip install ipykernel …

免费全开源,功能强大的多连接数据库管理工具:DbGate

DbGate&#xff1a;您的全能数据库指挥中心&#xff0c;一站式免费开源解决方案&#xff0c;无缝连接并管理多款主流数据库&#xff0c;让复杂的数据世界变得轻松易控! - 精选真开源&#xff0c;释放新价值。 概览 DbGate 是跨平台的数据库管理器。支持 MySQL、PostgreSQL、SQ…

gin框架底层

gin框架底层 gin的背景和使用 这里蓝色的是gin增强的内容&#xff0c;红色的是为了支持增强的内容添加的东西&#xff0c;黄色的是原来的net/http库Gin框架是基于Go语言的net/http标准库构建的&#xff0c;它提供了一个gin.Engine对象&#xff0c;这个对象实现了http.Handler接…

零代码编程:用kimichat打造一个最简单的window程序

用kimichat可以非常方便的自动生成程序代码&#xff0c;有些小程序可能会频繁使用&#xff0c;如果每次都在vscode中执行就会很麻烦。常用的Python代码&#xff0c;可以直接做成一个window程序&#xff0c;点击就可以打开使用&#xff0c;方便很多。 首先&#xff0c;把kimich…

VGA显示器字符显示

1.原理 64*64256 2.1 Vga_pic.v module Vga_pic(input wire Vga_clk ,input wire sys_rst_n ,input wire [9:0] pix_x ,input wire [9:0] pix_y ,output reg [15:0] pix_data );parameter CHAR_B_H10d192,CHAR_B_V10d208;parameter CHAR_W10d256,CHAR_H10d64;paramet…

Linux从入门到精通 --- 4(上).快捷键、软件安装、systemctl、软链接、日期和时区、IP地址

文章目录 第四章(上)&#xff1a;4.1 快捷键4.1.1 ctrl c 强制停止4.1.2 ctrl d 退出4.1.3 history4.1.4 历史命令搜索4.1.5 光速移动快捷键4.1.6 清屏 4.2 软件安装4.2.1 yum4.2.2 apt 4.3 systemctl4.4 软链接4.4.1 ln 4.5 日期和时区4.5.1 date命令4.5.2 date进行日期加减…

阿里云服务器可以干嘛?阿里云服务器八大用途介绍

阿里云服务器可以干嘛&#xff1f;能干啥你还不知道么&#xff01;简单来讲可用来搭建网站、个人博客、企业官网、论坛、电子商务、AI、LLM大语言模型、测试环境等&#xff0c;阿里云百科aliyunbaike.com整理阿里云服务器的用途&#xff1a; 阿里云服务器活动 aliyunbaike.com…

怎么把学浪的视频保存到手机

越来越多的人在学浪app里面购买了课程并且想要下载下来&#xff0c;但是苦于没有方法或者工具&#xff0c;所以本文将教大家如何把学浪的视频保存到手机随时随地的观看&#xff0c;再也不用担心课程过期的问题。 本文将介绍工具来下载&#xff0c;因为下载方法太复杂&#xff…

Django检测到会话cookie中缺少HttpOnly属性手工复现

一、漏洞复现 会话cookie中缺少HttpOnly属性会导致攻击者可以通过程序(JS脚本等)获取到用户的cookie信息&#xff0c;造成用户cookie信息泄露&#xff0c;增加攻击者的跨站脚本攻击威胁。 第一步&#xff1a;复制URL&#xff1a;http://192.168.43.219在浏览器打开&#xff0c;…

Switch摇杆模块超好手感超小体积-适用于Arduino创客

Mini摇杆模块 1.模块照片 2.接线 摇杆模块的 G、V、X、Y 、SW分别连接 UNO 的G、V、A0、A1、D2引脚。 3.程序 /*rocker test- 摇杆测试This example code is in the public domain.Author : YFROBOT ZLWebsite : www.yfrobot.com.cnCreate Time: 2024 */#define XP…

自定义gitlog格式

git log命令非常强大而好用&#xff0c;在复杂系统的版本管理中扮演着重要的角色&#xff0c;但默认的git log命令显示出的东西实在太丑&#xff0c;不好好打扮一下根本没法见人&#xff0c;打扮好了用alias命令拍个照片&#xff0c;就正式出道了&#xff01; 在使用git查看lo…

CentOS7.9.2009安装elasticsearch7.11.1(单节点)

本文章使用CentOS7.9.2009服务器安装elasticsearch7.11.1软件 1.服务器信息 [root@elasticsearch ~]# cat /etc/redhat-release CentOS Linux release 7.9.2009 (Core) [root@elasticsearch ~]# [root@elasticsearch ~]# cat /etc/hosts | grep elasticsearch 192.168.10.24…

消息队列MQ(面试题:为什么使用MQ)

一、什么是mq? MQ全称 Message Queue&#xff08;消息队列&#xff09;&#xff0c;是在消息的传输过程中保存消息的容器。多用于分布式系统之间进行通信&#xff0c;解耦。 二、常见的mq产品 RabbitMQ、RocketMQ、ActiveMQ、Kafka、ZeroMQ、MetaMq RabbitMQ: One broker …

34470A是德科技34470A数字万用表

181/2461/8938产品概述&#xff1a; Truevolt数字万用表&#xff08;34460A、34461A、34465A、34470A&#xff09;利用是德科技的新专利技术&#xff0c;使您能够快速获得见解、测量低功耗设备并保持校准的测量结果。Truevolt提供全方位的测量能力&#xff0c;具有更高的精度、…

Centos7源码方式安装Elasticsearch 7.10.2单机版

下载 任选一种方式下载 官网7.10.2版本下载地址&#xff1a; https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.10.2-linux-x86_64.tar.gz 网盘下载链接 链接&#xff1a;https://pan.baidu.com/s/1EJvUPGVOkosRO2PUaKibaA?pwdbnqi 提取码&#x…