政安晨:【Keras机器学习示例演绎】(十七)—— 用于图像分类的 RandAugment 可提高鲁棒性

news2024/12/26 20:59:50

目录

导入与设置

加载 CIFAR10 数据集

定义超参数

初始化 RandAugment 对象

创建 TensorFlow 数据集对象

可视化使用 RandAugment 增强的数据集

可视化使用 simple_aug 增强的数据集

定义模型构建实用功能

使用 RandAugment 训练模型

用 simple_aug 训练模型

加载 CIFAR-10-C 数据集并评估性能


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:用于训练图像分类模型的 RandAugment,具有更强的鲁棒性。

数据增强是一种非常有用的技术,有助于提高卷积神经网络(CNN)的平移不变性。RandAugment 是一种用于视觉数据的随机数据增强程序,在 RandAugment 中提出:RandAugment: Practical automated data augmentation with a reduced search space》一书中提出的。它由色彩抖动、高斯模糊、饱和度等强增强变换和随机作物等更传统的增强变换组成。

这些参数可根据给定的数据集和网络结构进行调整。

最近,它已成为 "噪声学生训练 "和 "一致性训练的无监督数据增强 "等工作的关键组成部分。它也是 EfficientNets 取得成功的关键。

pip install keras-cv

导入与设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_cv
from keras import ops
from keras import layers
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
keras.utils.set_random_seed(42)

加载 CIFAR10 数据集


在本例中,我们将使用 CIFAR10 数据集。

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")

演绎展示:

Total training examples: 50000
Total test examples: 10000

定义超参数

AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 128
EPOCHS = 1
IMAGE_SIZE = 72

初始化 RandAugment 对象

现在,我们将使用 RandAugment 作者建议的参数,从 imgaug.augmenters 模块中初始化一个 RandAugment 对象。

rand_augment = keras_cv.layers.RandAugment(
    value_range=(0, 255), augmentations_per_image=3, magnitude=0.8
)

创建 TensorFlow 数据集对象

train_ds_rand = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    .map(
        lambda x, y: (rand_augment(tf.cast(x, tf.uint8)), y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

为了便于比较,我们还可以定义一个简单的增强管道,其中包括随机翻转、随机旋转和随机缩放。

simple_aug = keras.Sequential(
    [
        layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ]
)

# Now, map the augmentation pipeline to our training dataset
train_ds_simple = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(lambda x, y: (simple_aug(x), y), num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

可视化使用 RandAugment 增强的数据集

sample_images, _ = next(iter(train_ds_rand))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("int"))
    plt.axis("off")

建议您多运行几次上述代码块,以了解不同的变化。

可视化使用 simple_aug 增强的数据集

sample_images, _ = next(iter(train_ds_simple))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("int"))
    plt.axis("off")

定义模型构建实用功能


现在,我们定义一个基于 ResNet50V2 架构的 CNN 模型。此外,请注意该网络内部已经有一个重缩放层。这样,我们就无需对数据集进行任何单独的预处理,特别是在部署时非常有用。

def get_training_model():
    resnet50_v2 = keras.applications.ResNet50V2(
        weights=None,
        include_top=True,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
        classes=10,
    )
    model = keras.Sequential(
        [
            layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
            layers.Rescaling(scale=1.0 / 127.5, offset=-1),
            resnet50_v2,
        ]
    )
    return model


get_training_model().summary()

演绎展示:

Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ rescaling (Rescaling)           │ (None, 72, 72, 3)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ resnet50v2 (Functional)         │ (None, 10)                │ 23,585,290 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 23,585,290 (89.97 MB)
 Trainable params: 23,539,850 (89.80 MB)
 Non-trainable params: 45,440 (177.50 KB)

我们将在两个不同版本的数据集上对该网络进行训练:

一个使用 RandAugment 增强。
另一个使用 simple_aug 增强。

众所周知,RandAugment 可以增强模型对常见扰动和损坏的鲁棒性,因此我们还将在 CIFAR-10-C 数据集上评估我们的模型,该数据集是 Hendrycks 等人在《神经网络对常见损坏和扰动的鲁棒性基准测试》(Benchmarking Neural Network Robustness to Common Corruptions and Perturbations)一文中提出的。

在本示例中,我们将使用以下配置:cifar10_corrupted/saturate_5。该配置下的图像如下。

为了提高可重复性,我们将浅层网络的初始随机权重序列化。

initial_model = get_training_model()
initial_model.save_weights("initial.weights.h5")

使用 RandAugment 训练模型

rand_aug_model = get_training_model()
rand_aug_model.load_weights("initial.weights.h5")
rand_aug_model.compile(
    loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
rand_aug_model.fit(train_ds_rand, validation_data=test_ds, epochs=EPOCHS)
_, test_acc = rand_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))

演绎展示:

 391/391 ━━━━━━━━━━━━━━━━━━━━ 1146s 3s/step - accuracy: 0.1677 - loss: 2.3232 - val_accuracy: 0.2818 - val_loss: 1.9966
 79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 489ms/step - accuracy: 0.2803 - loss: 2.0073
Test accuracy: 28.18%

用 simple_aug 训练模型

simple_aug_model = get_training_model()
simple_aug_model.load_weights("initial.weights.h5")
simple_aug_model.compile(
    loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
simple_aug_model.fit(train_ds_simple, validation_data=test_ds, epochs=EPOCHS)
_, test_acc = simple_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))

演绎展示:
 

 391/391 ━━━━━━━━━━━━━━━━━━━━ 1132s 3s/step - accuracy: 0.3673 - loss: 1.7929 - val_accuracy: 0.4789 - val_loss: 1.4296
 79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 494ms/step - accuracy: 0.4762 - loss: 1.4368
Test accuracy: 47.89%

加载 CIFAR-10-C 数据集并评估性能

# Load and prepare the CIFAR-10-C dataset
# (If it's not already downloaded, it takes ~10 minutes of time to download)
cifar_10_c = tfds.load("cifar10_corrupted/saturate_5", split="test", as_supervised=True)
cifar_10_c = cifar_10_c.batch(BATCH_SIZE).map(
    lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
    num_parallel_calls=AUTO,
)

# Evaluate `rand_aug_model`
_, test_acc = rand_aug_model.evaluate(cifar_10_c, verbose=0)
print(
    "Accuracy with RandAugment on CIFAR-10-C (saturate_5): {:.2f}%".format(
        test_acc * 100
    )
)

# Evaluate `simple_aug_model`
_, test_acc = simple_aug_model.evaluate(cifar_10_c, verbose=0)
print(
    "Accuracy with simple_aug on CIFAR-10-C (saturate_5): {:.2f}%".format(
        test_acc * 100
    )
)

演绎展示:
 

 Downloading and preparing dataset 2.72 GiB (download: 2.72 GiB, generated: Unknown size, total: 2.72 GiB) to /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0...
 Dataset cifar10_corrupted downloaded and prepared to /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0. Subsequent calls will reuse this data.
Accuracy with RandAugment on CIFAR-10-C (saturate_5): 30.36%
Accuracy with simple_aug on CIFAR-10-C (saturate_5): 37.18%

在本例中,我们只对模型进行了单次训练。

在 CIFAR-10-C 数据集上,与使用 simple_aug 训练的模型(例如,64.80%)相比,使用 RandAugment 的模型表现更好,准确率更高(例如,在一次实验中为 76.64%)。RandAugment 还有助于稳定训练。

您可能会注意到,虽然使用 RandAugment 增加了训练时间,但我们在 CIFAR-10-C 数据集上的表现却要好得多。您可以在运行相同的 CIFAR-10-C 数据集时,尝试使用其他损坏和扰动设置,看看 RandAugment 是否有帮助。

您还可以尝试使用 RandAugment 对象中不同的 n 和 m 值。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1624802.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Rime 如何通过 iCloud 实现词库多端同步,Windows、iOS、macOS

Rime 如何通过 iCloud 实现词库多端同步,Windows、iOS、macOS 一、设备环境 最理想的输入环境就是在多端都使用同一个词库,这样能保持多端的输入习惯是一致的。 以我为例,手头每天都要用到的操作平台和对应的输入法: 操作系统设…

Spring Boot | Spring Boot “自定义“ Redis缓存 “序列化机制“

目录: Spring Boot "自定义" Redis缓存 "序列化机制" :一、基于 "注解" 的 "Redis缓存管理" 的 "默认序列化机制" 和 "自定义序列化机制"1.1 基于 "注解" 的 "Redis缓存管理" 的 …

抽象代理模式2.0版本

前言: 1.0版本的核心 代理的定义 A proxy, in its most general form, is a class functioning as an interface to something else. The proxy could interface to anything: a network connection, a large object in memory, a file, or some other resource t…

金融级国产化替代中间件有哪些?

过去,国内中间件市场一直由IBM、Oracle等国际大型企业所主导,这在一定程度上限制了对国内企业多样化和个性化需求的满足,尤其是在实现底层硬件与上层应用软件之间高效、精准匹配方面。面对日益复杂的国际局势,金融安全已成为国家整…

akSmart大带宽服务器基础配置科普

在数字化时代,服务器的性能和网络带宽成为业务发展的关键因素。RakSmart作为知名的服务器提供商,其大带宽服务器备受用户青睐。那么,RakSmart大带宽服务器的基础配置究竟有哪些呢?本文将为您揭开这一神秘面纱。 首先,我们来看看R…

【基于YOLOv8的教室人脸识别 附源码 数据集】

基于YOLOv8的教室人脸识别 附源码 数据集 在当今数字化迅速发展的教育领域中,人脸识别技术已成为提高校园安全和教学效率的关键工具。本文将详细介绍基于最新YOLOv8算法的教室人脸识别系统,这一系统不仅能够实时准确地识别学生和教职工的面部特征&#…

【QT】ROS2 Humble联合使用QT教程

【QT】ROS2 Humble联合使用QT教程 文章目录 【QT】ROS2 Humble联合使用QT教程1. 安装ROSProjectManager插件2. 创建ROS项目3.一个快速体验的demoReference 环境的具体信息如下: ubunt 22.04ros2 humbleQt Creator 13.0.0ROS ProjectManager 13.0.0 本文建立在已经…

【A-034】基于SSH的电影订票系统(含论文)

【A-034】基于SSH的电影订票系统(含论文) 开发环境: Jdk7(8)Tomcat7(8)MySQLIntelliJ IDEA(Eclipse) 数据库: MySQL 技术: SpringStruts2HiberanteJSPJquery 适用于: 课程设计,毕业设计&…

MacOS通过命令行开启关闭向日葵远程控制的后台服务

categories: [Tips] tags: MacOS Tips 写在前面 经常有小伙伴问我电脑相关的问题, 而解决问题的一个重要途径就是远程了. 关于免费的远程工具我试过向日葵和 todesk, 并且主要使用向日葵, 虽然 MacOS 下要设置很多权限, 但是也不影响其丝滑的控制. 虽然用着舒服, 但是向日葵…

【Elasticsearch<一>✈️✈️】简单安装使用以及各种踩坑

目录 🍸前言 🍻一、软件安装(Windows版) 1.1、Elasticsearch 下载 2.1 安装浏览器插件 3.1、安装可视化工具 Kibana 4.1、集成 IK 分词器 🍺二、安装问题 🍹三、测试 IK 分词器 ​🍷 四、章…

高端制造企业生产设备文件管理,怎样保证好用不丢失文件?

高端制造业在市场经济中占据重要角色,在高端制造业企业内部,生产设备又是最关键的一环环,它们不仅负责完成生产任务,同时也会产生大量的文件。这些数据反映了设备的运行状态、生产效率、能源消耗以及产品质量等多个方面&#xff0…

Delta模拟器:iOS上的复古游戏天堂

Delta模拟器:iOS上的复古游戏天堂 在数字时代,我们有时会怀念起那些早期的电子游戏,它们简单、纯粹,带给我们无尽的乐趣。虽然现在的游戏在画质和玩法上都有了巨大的提升,但那种复古的感觉却始终无法替代。幸运的是&a…

科技云报道:走入商业化拐点,大模型“开箱即用”或突破行业困局

科技云报道原创。 大模型加速狂飙,AI商业化却陷入重重困境。 一方面,传统企业不知道怎么将AI融入原始业务,另一方面,AI企业难以找到合适的商业化路径。 纵观海外AI玩家,已经有许多企业趟出了自己的商业化道路。 微…

C#从入门到精通:一场深入浅出的编程之旅【文末送书】

文章目录 C#从入门到精通入门篇进阶篇精通篇模式探索C#从入门到精通(第7版)(软件开发视频大讲堂)【文末送书】 C#从入门到精通 在当今数字化的时代,编程已经成为一项至关重要的技能。而在众多编程语言中,C…

人工智能|深度学习——多模态条件机制 Cross Attention 原理及实现

一、引入 虽然之前写过 Attention 的文章,但现在回头看之前写的一些文章,感觉都好啰嗦,正好下一篇要写的 Stable Diffusion 中有 cross-attention,索性就再单拎出来简单说一下 Attention 吧,那么这篇文章的作用有两个&…

微软在汉诺威工业博览会上推出新制造业Copilot人工智能功能,强化Dynamics 365工具集

在近日于德国汉诺威举行的盛大工业博览会上,微软向全球展示了其最新推出的制造业人工智能功能,这些功能以Dynamics 365工具集为核心,旨在通过先进的AI技术为制造业带来前所未有的变革。 此次推出的新功能中,最为亮眼的是支持AI的…

python 中使用 ESP8266 实现语音识别(或热词检测)

介绍 我的大部分家庭自动化都是通过对网络中的设备执行 HTTP 请求来控制的。 (例如:开灯、打开收音机、控制加热系统...... 这可以使用ESP8266轻松完成。我有一个控制器和一个触摸传感器,当我在床上时用它来控制灯光和音乐。 像 Amazon Echo 或 Google Homepod 一样添加语…

【Qt QML】TabBar的用法

Qt Quick中的TabBar提供了一个基于选项卡的导航模型。TabBar由TabButton控件填充,并且可以与任何提供currentIndex属性的布局或容器控件一起使用,例如StackLayout或SwipeView。 import QtQuick import QtQuick.Controls import QtQuick.LayoutsWindow …

【论文阅读】Self-DC:何时检索,何时生成?

对于RAG来说,什么时候利用外部检索,什么时候使用大模型产生已知的知识,以回答当前的问题?这是一个非常有趣的话题。 《Self-DC: When to retrieve and When to generate? Self Divide-and-Conquer for Compositional Unknown Questions》这…

Transformer step by step--Positional Embedding 和 Word Embedding

Transformer step by step往期文章: Transformer step by step--层归一化和批量归一化 要把Transformer中的Embedding说清楚,那就要说清楚Positional Embedding和Word Embedding。至于为什么有这两个Embedding,我们不妨看一眼Transformer的…