政安晨:【Keras机器学习实践要点】(二十七)—— 使用感知器进行图像分类

news2024/11/28 0:42:56

目录

简介

设置

准备数据

配置超参数

使用数据增强

实施前馈网络(FFN)

将创建修补程序作为一个层

实施补丁编码层

建立感知器模型

变换器模块

感知器模型

编译、培训和评估模式


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:实施用于图像分类的感知器模型。

简介


本文实现了安德鲁-杰格(Andrew Jaegle)等人的 Perceiver

图像分类模型,在 CIFAR-100 数据集上进行演示。

Perceiver 模型利用非对称注意力机制,将输入信息迭代提炼成一个紧密的潜在瓶颈,使其能够扩展以处理非常大的输入信息。

换句话说:假设你的输入数据数组(如图像)有 M 个元素(即补丁),其中 M 很庞大。在标准变换器模型中,会对 M 个元素执行自注意操作。这一操作的复杂度为 O(M^2)。然而,感知器模型会创建一个大小为 N 个元素(其中 N << M)的潜在数组,并迭代执行两个操作:

1. 潜数组和数据数组之间的交叉注意变换器 - 此操作的复杂度为 O(M.N)。
2. 潜数组上的自注意变换器 - 此操作的复杂度为 O(N^2)。


本示例需要 Keras 3.0 或更高版本。

设置

import keras
from keras import layers, activations, ops

准备数据

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

演绎展示:

配置超参数

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 64
num_epochs = 2  # You should actually use 50 epochs!
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 2  # Size of the patches to be extract from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
latent_dim = 256  # Size of the latent array.
projection_dim = 256  # Embedding size of each element in the data and latent arrays.
num_heads = 8  # Number of Transformer heads.
ffn_units = [
    projection_dim,
    projection_dim,
]  # Size of the Transformer Feedforward network.
num_transformer_blocks = 4
num_iterations = 2  # Repetitions of the cross-attention and Transformer modules.
classifier_units = [
    projection_dim,
    num_classes,
]  # Size of the Feedforward network of the final classifier.

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
print(f"Latent array shape: {latent_dim} X {projection_dim}")
print(f"Data array shape: {num_patches} X {projection_dim}")

演绎展示:

请注意,为了将每个像素作为数据数组中的单独输入,请将 patch_size 设置为 1。

使用数据增强

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

实施前馈网络(FFN)

def create_ffn(hidden_units, dropout_rate):
    ffn_layers = []
    for units in hidden_units[:-1]:
        ffn_layers.append(layers.Dense(units, activation=activations.gelu))

    ffn_layers.append(layers.Dense(units=hidden_units[-1]))
    ffn_layers.append(layers.Dropout(dropout_rate))

    ffn = keras.Sequential(ffn_layers)
    return ffn

将创建修补程序作为一个层

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = ops.shape(images)[0]
        patches = ops.image.extract_patches(
            image=images,
            size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            dilation_rate=1,
            padding="valid",
        )
        patch_dims = patches.shape[-1]
        patches = ops.reshape(patches, [batch_size, -1, patch_dims])
        return patches

实施补丁编码层

PatchEncoder 层会通过投射到大小为 latent_dim 的矢量中对补丁进行线性变换。此外,它还会为投影向量添加可学习的位置嵌入。

请注意,最初的 Perceiver 论文使用的是傅立叶特征位置编码。

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        positions = ops.arange(start=0, stop=self.num_patches, step=1)
        encoded = self.projection(patches) + self.position_embedding(positions)
        return encoded

建立感知器模型

感知器由两个模块组成:一个交叉注意模块和一个自注意标准变换器。

交叉注意模块


交叉注意模块将(latent_dim, projection_dim)潜在数组和(data_dim, projection_dim)数据数组作为输入,以产生(latent_dim, projection_dim)潜在数组作为输出。为了应用交叉关注,查询向量由潜在数组生成,而键向量和值向量则由编码图像生成。

请注意,本例中的数据数组是图像,其中 data_dim 设置为 num_patches。

def create_cross_attention_module(
    latent_dim, data_dim, projection_dim, ffn_units, dropout_rate
):
    inputs = {
        # Recieve the latent array as an input of shape [1, latent_dim, projection_dim].
        "latent_array": layers.Input(
            shape=(latent_dim, projection_dim), name="latent_array"
        ),
        # Recieve the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
        "data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
    }

    # Apply layer norm to the inputs
    latent_array = layers.LayerNormalization(epsilon=1e-6)(inputs["latent_array"])
    data_array = layers.LayerNormalization(epsilon=1e-6)(inputs["data_array"])

    # Create query tensor: [1, latent_dim, projection_dim].
    query = layers.Dense(units=projection_dim)(latent_array)
    # Create key tensor: [batch_size, data_dim, projection_dim].
    key = layers.Dense(units=projection_dim)(data_array)
    # Create value tensor: [batch_size, data_dim, projection_dim].
    value = layers.Dense(units=projection_dim)(data_array)

    # Generate cross-attention outputs: [batch_size, latent_dim, projection_dim].
    attention_output = layers.Attention(use_scale=True, dropout=0.1)(
        [query, key, value], return_attention_scores=False
    )
    # Skip connection 1.
    attention_output = layers.Add()([attention_output, latent_array])

    # Apply layer norm.
    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)
    # Apply Feedforward network.
    ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
    outputs = ffn(attention_output)
    # Skip connection 2.
    outputs = layers.Add()([outputs, attention_output])

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

变换器模块

转换器将交叉注意模块输出的潜向量作为输入,对其 latent_dim 元素应用多头自注意,然后通过前馈网络,生成另一个(latent_dim,projection_dim)潜数组。

def create_transformer_module(
    latent_dim,
    projection_dim,
    num_heads,
    num_transformer_blocks,
    ffn_units,
    dropout_rate,
):
    # input_shape: [1, latent_dim, projection_dim]
    inputs = layers.Input(shape=(latent_dim, projection_dim))

    x0 = inputs
    # Create multiple layers of the Transformer block.
    for _ in range(num_transformer_blocks):
        # Apply layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x0)
        # Create a multi-head self-attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x0])
        # Apply layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # Apply Feedforward network.
        ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
        x3 = ffn(x3)
        # Skip connection 2.
        x0 = layers.Add()([x3, x2])

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=x0)
    return model

感知器模型


感知器模型重复交叉注意模块和变换器模块的迭代次数--通过共享权重和跳转连接--使潜在阵列能够根据需要从输入图像中迭代提取信息。

class Perceiver(keras.Model):
    def __init__(
        self,
        patch_size,
        data_dim,
        latent_dim,
        projection_dim,
        num_heads,
        num_transformer_blocks,
        ffn_units,
        dropout_rate,
        num_iterations,
        classifier_units,
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.num_transformer_blocks = num_transformer_blocks
        self.ffn_units = ffn_units
        self.dropout_rate = dropout_rate
        self.num_iterations = num_iterations
        self.classifier_units = classifier_units

    def build(self, input_shape):
        # Create latent array.
        self.latent_array = self.add_weight(
            shape=(self.latent_dim, self.projection_dim),
            initializer="random_normal",
            trainable=True,
        )

        # Create patching module.war
        self.patch_encoder = PatchEncoder(self.data_dim, self.projection_dim)

        # Create cross-attenion module.
        self.cross_attention = create_cross_attention_module(
            self.latent_dim,
            self.data_dim,
            self.projection_dim,
            self.ffn_units,
            self.dropout_rate,
        )

        # Create Transformer module.
        self.transformer = create_transformer_module(
            self.latent_dim,
            self.projection_dim,
            self.num_heads,
            self.num_transformer_blocks,
            self.ffn_units,
            self.dropout_rate,
        )

        # Create global average pooling layer.
        self.global_average_pooling = layers.GlobalAveragePooling1D()

        # Create a classification head.
        self.classification_head = create_ffn(
            hidden_units=self.classifier_units, dropout_rate=self.dropout_rate
        )

        super().build(input_shape)

    def call(self, inputs):
        # Augment data.
        augmented = data_augmentation(inputs)
        # Create patches.
        patches = self.patcher(augmented)
        # Encode patches.
        encoded_patches = self.patch_encoder(patches)
        # Prepare cross-attention inputs.
        cross_attention_inputs = {
            "latent_array": ops.expand_dims(self.latent_array, 0),
            "data_array": encoded_patches,
        }
        # Apply the cross-attention and the Transformer modules iteratively.
        for _ in range(self.num_iterations):
            # Apply cross-attention from the latent array to the data array.
            latent_array = self.cross_attention(cross_attention_inputs)
            # Apply self-attention Transformer to the latent array.
            latent_array = self.transformer(latent_array)
            # Set the latent array of the next iteration.
            cross_attention_inputs["latent_array"] = latent_array

        # Apply global average pooling to generate a [batch_size, projection_dim] repesentation tensor.
        representation = self.global_average_pooling(latent_array)
        # Generate logits.
        logits = self.classification_head(representation)
        return logits

编译、培训和评估模式

def run_experiment(model):
    # Create ADAM instead of LAMB optimizer with weight decay. (LAMB isn't supported yet)
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )

    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.2, patience=3
    )

    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=15, restore_best_weights=True
    )

    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history

请注意,在 V100 GPU 上以当前设置训练感知器模型大约需要 200 秒。

perceiver_classifier = Perceiver(
    patch_size,
    num_patches,
    latent_dim,
    projection_dim,
    num_heads,
    num_transformer_blocks,
    ffn_units,
    dropout_rate,
    num_iterations,
    classifier_units,
)


history = run_experiment(perceiver_classifier)

演绎展示: 

Test accuracy: 0.91%
Test top 5 accuracy: 5.2%

经过 40 次历时后,Perceiver 模型在测试数据上达到了约 53% 的准确率和 81% 的前五名准确率。

正如 Perceiver 论文的消融部分所述,通过增加潜在阵列大小、增加潜在阵列和数据阵列元素的(投影)维度、增加变换器模块中的块数以及增加交叉注意和潜在变换器模块的迭代次数,可以获得更好的结果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。

Perceiver 可以从增大模型尺寸中获益。

不过,更大的模型需要更大的加速器来适应和有效训练。这就是 Perceiver 论文中使用 32 个 TPU 内核进行实验的原因。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1590395.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RREA论文阅读

Relational Reflection Entity Alignment 关系反射实体对齐 ABSTRACT 实体对齐旨在识别来自不同知识图谱(KG)的等效实体对&#xff0c;这对于集成多源知识图谱至关重要。最近&#xff0c;随着 GNN 在实体对齐中的引入&#xff0c;近期模型的架构变得越来越复杂。我们甚至在这…

STM32—DMA直接存储器访问详解

DMA——直接存储器访问 DMA&#xff1a;Data Memory Access, 直接存储器访问。 DMA和我们之前学过的串口、GPIO都是类似的&#xff0c;都是STM32中的一个外设。串口是用来发送通信数据的&#xff0c;而DMA则是用来把数据从一个地方搬到另一个地方&#xff0c;而且不占用CPU。…

2024年经济发展、社会科学与贸易国际会议(ICEDSST2024)

2024年经济发展、社会科学与贸易国际会议(ICEDSST2024) 会议简介 2024年国际经济发展、社会科学与贸易会议&#xff08;ICEDSST2024&#xff09;将在中国深圳举行&#xff0c;主题为“经济发展、社科与贸易”。ICEDSST2024汇集了来自世界各地经济发展、社科与贸易领域的学者、…

Ubuntu无网络标识的解决方法

1.出现的情况的特点 2.解决办法 2.1 进入root并输入密码 sudo su 2.2 更新NetworkManager的配置 得先有gedit或者vim&#xff0c;两个随意一个&#xff0c;这里用的gedit&#xff0c;没有就先弄gedit&#xff0c;有的话直接下一步 apt-get install gedit 或者vim apt-get ins…

Vim:强大的文本编辑器

文章目录 Vim&#xff1a;强大的文本编辑器Vim的模式命令模式常用操作光标移动文本编辑查找和替换 底行命令模式常用操作Vim的多窗口操作批量注释与去注释Vim插件推荐&#xff1a;vimforcpp结论 Vim&#xff1a;强大的文本编辑器 Vim&#xff0c;代表 Vi IMproved&#xff0c;…

基于小程序实现的医院预约挂号系统

作者主页&#xff1a;Java码库 主营内容&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app等设计与开发。 收藏点赞不迷路 关注作者有好处 文末获取源码 技术选型 【后端】&#xff1a;Java 【框架】&#xff1a;spring…

Android网络抓包--Charles

一、Android抓包方式 对Https降级进行抓包&#xff0c;降级成Http使用抓包工具对Https进行抓包 二、常用的抓包工具 wireshark&#xff1a;侧重于TCP、UDP传输层&#xff0c;HTTP/HTTPS也能抓包&#xff0c;但不能解密HTTPS报文。比较复杂fiddler&#xff1a;支持HTTP/HTTPS…

Swift Zulian Tiger

Swift Zulian Tiger 迅捷祖利安猛虎 16万金&#xff08;游戏币&#xff09; 1万金大概就能兑换460元~600元之间&#xff0c;6400元-9600元&#xff0c;汗颜 故事的一天刚打完BWL&#xff0c;才125金&#xff08;游戏币&#xff09; 本来想下线的结果他们说你太黑了&…

工控 modbusTCP 报文

Tx 发送报文:00 C9 00 00 00 06 01 03 00 00 00 02 Rx 接收报文:00 C9 00 00 00 07 01 03 04 01 4D 00 01 Tx 发送报文:00 C9 00 00 00 06 01 03 00 00 00 02 00 C9 事务处理标识符 2字节 00 00 协议标识符 2字节 固定 00 00 00 06 长度 2字节 表示之后的字节总数 &#xff08;…

贪心算法|968.监控二叉树

力扣题目链接 class Solution { private:int result;int traversal(TreeNode* cur) {// 空节点&#xff0c;该节点有覆盖if (cur NULL) return 2;int left traversal(cur->left); // 左int right traversal(cur->right); // 右// 情况1// 左右节点都有覆盖if (le…

MariaDB介绍和安装

MariaDB介绍和安装 文章目录 MariaDB介绍和安装1.MariaDB介绍2.MariaDB安装2.1 主机初始化2.1.1 设置网卡名和ip地址2.1.2 配置镜像源2.1.3 关闭防火墙2.1.4 禁用SELinux2.1.5 设置时区 2.2 包安装2.2.1 Rocky和CentOS 安装 MariaDB2.2.2 Ubuntu 安装 MariaDB 2.3 源码安装2.3.…

紫光展锐携手中国联通智慧矿山军团(山西)完成RedCap现网环境测试

近日&#xff0c;紫光展锐与中国联通智慧矿山军团&#xff08;山西&#xff09;在现网环境下成功完成了RedCap技术测试。此次测试对搭载紫光展锐RedCap芯片平台V517的模组注网速度和信号情况、Iperf打流测试上下行情况、ping包延时情况以及模组拨号入网压测等项目进行了全面验证…

【性能测试】接口测试各知识第3篇:Jmeter 基本使用流程,学习目标【附代码文档】

接口测试完整教程&#xff08;附代码资料&#xff09;主要内容讲述&#xff1a;接口测试&#xff0c;学习目标学习目标,2. 接口测试课程大纲,3. 接口学完样品,4. 学完课程,学到什么,5. 参考:,1. 理解接口的概念。学习目标&#xff0c;RESTFUL1. 理解接口的概念,2.什么是接口测试…

# Contrastive Learning(对比学习)--CLIP笔记(一)

Contrastive Learning&#xff08;对比学习&#xff09;–CLIP笔记&#xff08;一&#xff09; 参考&#xff1a;CLIP 论文逐段精读【论文精读】_哔哩哔哩_bilibili CLIP简介 CLIP是一种多模态预训练模型&#xff0c;由OpenAI在2021年提出&#xff0c;论文标题&#xff1a;L…

STM32 DCMI 的带宽与性能介绍

1. 引言 随着市场对更高图像质量的需求不断增加&#xff0c;成像技术持续发展&#xff0c;各种新兴技术&#xff08;例如3D、计算、运动和红外线&#xff09;的不断涌现。如今的成像应用对高质量、易用性、能耗效率、高集成度、快速上市和成本效益提出了全面要求。为了满足这些…

【自然语言】使用词袋模型,TF-IDF模型和Word2Vec模型进行文本向量化

一、任务目标 python代码写将 HarryPorter 电子书作为语料库&#xff0c;分别使用词袋模型&#xff0c;TF-IDF模型和Word2Vec模型进行文本向量化。 1. 首先将数据预处理&#xff0c;Word2Vec 训练时要求考虑每个单词前后的五个词汇&#xff0c;地址为 作为其上下文 &#xf…

数据结构的魅力

数据结构这块越学越敬佩 博大精深 统计大文件中相同年龄的人的个数 public static void main(String[] args) throws Exception {String str "";String fileName "";InputStreamReader isr new InputStreamReader(new FileInputStream(fileName), Stan…

OSCP靶场--Banzai

OSCP靶场–Banzai 考点(ftp爆破 webshell上传web1访问403web2可以访问webshell反弹mysql udf提权) 1.nmap扫描 ## nmap扫描一定要使用 -p- 否则容易扫不全端口 ┌──(root㉿kali)-[~/Desktop] └─# nmap -sV -sC 192.168.158.56 -Pn -p- --min-rate 2500Starting Nmap 7.…

ArcGIS Pro 3D建模简明教程

在本文中&#xff0c;我讲述了我最近一直在探索的在 ArcGIS Pro 中设计 3D 模型的过程。 我的目标是尽可能避免与其他软件交互&#xff08;即使是专门用于 3D 建模的软件&#xff09;&#xff0c;并利用 Pro 可以提供的可能性。 这个短暂的旅程分为三个不同的阶段&#xff1a;…

AI绘本生成解决方案,快速生成高质量的AI绘本视频

美摄科技凭借其深厚的技术积累和前瞻性的市场洞察力&#xff0c;近日推出了一款面向企业的AI绘本生成解决方案&#xff0c;旨在通过智能化、自动化的方式&#xff0c;帮助企业快速将文字内容转化为生动有趣的绘本视频&#xff0c;从而提升内容传播效率&#xff0c;增强品牌影响…