政安晨:【Keras机器学习实践要点】(十六)—— 图像分类从零开始

news2025/1/22 21:35:20

目录

简介

设置

加载数据:猫与狗数据集

原始数据下载

滤除损坏的图像

生成数据集

将数据可视化

使用图像数据增强

数据标准化

预处理数据的两个选项

配置数据集以提高性能

建立模型

训练模型

对新数据进行推理


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

在 Kaggle Cats vs Dogs 数据集上从头开始训练图像分类器。

简介


本示例展示了如何从零开始进行图像分类,从磁盘上的 JPEG 图像文件开始,而无需利用预训练的权重或预制的 Keras 应用程序模型。

我们在 Kaggle Cats vs Dogs 二进制分类数据集上演示了该工作流程。

我们使用 image_dataset_from_directory 实用程序生成数据集,并使用 Keras 图像预处理层进行图像标准化和数据扩增。

设置

接下来的演绎咱们在Kaggle上进行。

import os
import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data
import matplotlib.pyplot as plt

演绎:

加载数据:猫与狗数据集

原始数据下载

首先,让我们下载 786M 的原始数据 ZIP 压缩包:

!curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
!unzip -q kagglecatsanddogs_5340.zip
!ls

演绎如下:

现在,我们有一个 PetImages 文件夹,其中包含两个子文件夹:猫和狗。每个子文件夹都包含每个类别的图片文件。

!ls PetImages

滤除损坏的图像

在处理大量真实世界的图像数据时,损坏图像是常有的事。让我们来过滤掉标题中没有 "JFIF "字符串的编码不良图像。

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        try:
            fobj = open(fpath, "rb")
            is_jfif = b"JFIF" in fobj.peek(10)
        finally:
            fobj.close()

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print(f"Deleted {num_skipped} images.")

演绎:

生成数据集

image_size = (180, 180)
batch_size = 128

train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="both",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

演绎如下:

将数据可视化

以下是训练数据集中的前 9 幅图像。

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")

使用图像数据增强

在没有大型图像数据集的情况下,通过对训练图像进行随机但真实的变换,例如随机水平翻转或小幅随机旋转,人为地引入样本多样性是一种很好的做法。这有助于让模型接触到训练数据的不同方面,同时减缓过度拟合。

data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(images):
    for layer in data_augmentation_layers:
        images = layer(images)
    return images

让我们对数据集中的前几张图片反复应用 data_augmentation,直观地看看增强后的样本是什么样的:

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(augmented_images[0]).astype("uint8"))
        plt.axis("off")

演绎如下:

数据标准化

我们的图像已经是标准尺寸(180x180),因为它们是由我们的数据集生成的连续 float32 批次。不过,它们的 RGB 通道值在 [0, 255] 范围内。

这对于神经网络来说并不理想;

一般来说,你应该尽量使输入值小一些。

在此,我们将在模型开始时使用重缩放层(Rescaling layer),将值标准化为 [0, 1] 范围内的值。

预处理数据的两个选项

使用 data_augmentation 预处理器有两种方式:

方案 1: 让它成为模型的一部分,就像这样:

inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.Rescaling(1./255)(x)
...  # Rest of the model

使用此选项,数据增强将在设备上进行,与模型执行的其他部分同步,这意味着它将受益于 GPU 加速。

请注意,数据增强在测试时处于非活动状态,因此输入样本只会在拟合(fit)时增强,而不会在调用评估(evaluate)或预测(predict)时增强。

如果使用 GPU 进行训练,这可能是一个不错的选择。

方案 2:将其应用于数据集,从而获得一个数据集,生成成批的增强图像,就像这样:

augmented_train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y))

使用该选项,数据增强将在 CPU 上异步进行,并在进入模型前进行缓冲。
如果在 CPU 上进行训练,这是更好的选择,因为它能使数据增强异步且无阻塞。
在我们的例子中,我们会选择第二种方案。如果你不确定该选哪一个,那么第二个选项(异步预处理)始终是一个可靠的选择。

配置数据集以提高性能

让我们对训练数据集进行数据增强,并确保使用缓冲预取,这样我们就能从磁盘获取数据,而不会出现 I/O 阻塞:

# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching samples in GPU memory helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)

建立模型

我们将构建一个小版本的 Xception 网络。我们没有特别尝试优化架构;如果你想系统地搜索最佳模型配置,可以考虑使用 KerasTuner。

请注意

我们用 data_augmentation 预处理器启动模型,然后是 Rescaling 层。
我们在最终分类层之前加入了一个 Dropout 层。

def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # We specify activation=None so as to return logits
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)

演绎如下:

训练模型

epochs = 25

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)

演绎如下:

epochs = 25

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)

演绎如下: 

Epoch 1/25
...
Epoch 25/25
 147/147 ━━━━━━━━━━━━━━━━━━━━ 53s 354ms/step - acc: 0.9638 - loss: 0.0903 - val_acc: 0.9382 - val_loss: 0.1542

<keras.src.callbacks.history.History at 0x7f41003c24a0>

过程:

我们在完整数据集上训练了 25 个历元后,验证准确率达到了 90% 以上(实际上,在验证性能开始下降之前,可以训练 50 多个历元)。

对新数据进行推理

需要注意的是,在推理时,数据扩增和数据剔除都处于非活动状态。

img = keras.utils.load_img("PetImages/Cat/6779.jpg", target_size=image_size)
plt.imshow(img)

img_array = keras.utils.img_to_array(img)
img_array = keras.ops.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
score = float(keras.ops.sigmoid(predictions[0][0]))
print(f"This image is {100 * (1 - score):.2f}% cat and {100 * score:.2f}% dog.")

实验如下:


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1571343.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【快捷部署】011_PostgreSQL(16)

&#x1f4e3;【快捷部署系列】011期信息 编号选型版本操作系统部署形式部署模式复检时间011PostgreSQL16Ubuntu 20.04Docker单机2024-03-28 一、快捷部署 #!/bin/bash ################################################################################# # 作者&#xff1…

【二分查找】Leetcode 二分查找

题目解析 二分查找在数组有序可以使用&#xff0c;也可以在数组无序的时候使用&#xff08;只要数组中的一些规律适用于二分即可&#xff09; 704. 二分查找 算法讲解 当left > right的时候&#xff0c;我们循环结束&#xff0c;但是当left和right缩成一个点的时候&#x…

DDR3接口

mig IP核的配置 首先添加mig IP核   然后确认以下工程信息&#xff0c;主要是芯片型号以及编译环境&#xff0c;没什么问题后点击next.   如下图所示&#xff0c;这一页选择"Create Design"&#xff0c;在"Component Name"一栏设置该IP元件的名称&…

Redis数据库——群集(主从、哨兵)

目录 前言 一、主从复制 1.基本原理 2.作用 3.流程 4.搭建主动复制 4.1环境准备 4.2修改主服务器配置 4.3从服务器配置&#xff08;Slave1和Slave2&#xff09; 4.4查看主从复制信息 4.5验证主从复制 二、哨兵模式——Sentinel 1.定义 2.原理 3.作用 4.组成 5.…

59 使用 uqrcodejs 生成二维码

前言 这是一个最近的一个来自于朋友的需求, 然后做了一个 基于 uqrcodejs 来生成 二维码的一个 demo package.json 中增加以依赖 "uqrcodejs": "^4.0.7", 测试用例 <template><div class"hello"><canvas id"qrcode&qu…

代码随想录-算法训练营day02【滑动窗口、螺旋矩阵】

专栏笔记&#xff1a;https://blog.csdn.net/weixin_44949135/category_10335122.html https://docs.qq.com/doc/DUGRwWXNOVEpyaVpG?uc71ed002e4554fee8c262b2a4a4935d8977.有序数组的平方 &#xff0c;209.长度最小的子数组 &#xff0c;59.螺旋矩阵II &#xff0c;总结 建议…

[中级]软考_软件设计_计算机组成与体系结构_08_输入输出技术

输入输出技术 前言控制方式考点往年真题 前言 输入输出技术就是IO技术 控制方式 程序控制(查询)方式&#xff1a;分为无条件传送和程序查询方式两种。 方法简单&#xff0c;硬件开销小&#xff0c;但I/O能力不高&#xff0c;严重影响CPU的利用率。 程序中断方式&#xff1…

LeetCode---127双周赛

题目列表 3095. 或值至少 K 的最短子数组 I 3096. 得到更多分数的最少关卡数目 3097. 或值至少为 K 的最短子数组 II 3098. 求出所有子序列的能量和 一、或值至少k的最短子数组I&II 暴力的做法大家都会&#xff0c;这里就不说了&#xff0c;下面我们来看看如何进行优化…

Python云计算技术库之libcloud使用详解

概要 随着云计算技术的发展,越来越多的应用和服务迁移到了云端。然而,不同云服务商的API和接口千差万别,给开发者带来了不小的挑战。Python的libcloud库应运而生,它提供了一个统一的接口,让开发者可以轻松地管理不同云服务商的资源。本文将深入探讨libcloud库的特性、安装…

SecureCRT通过私钥连接跳板机,再连接到目标服务器

文章目录 1. 配置第一个session&#xff08;跳板机&#xff09;2. 设置本地端口3. 设置全局firewall4. 配置第二个session&#xff08;目标服务器&#xff09; 服务器那边给了一个私钥&#xff0c;现在需要通过私钥连接跳板机&#xff0c;再连接到目标服务器上 &#x1f349; …

vue3和vue2项目中如何根据不同的环境配置基地址?

在不同环境下取出的变量的值是不同的, 像这样的变量称为环境变量 为什么要使用环境变量呢? 开发环境生产环境下的接口地址有可能是不一样的&#xff0c;所以我们需要根据环境去配置不同的接口基地址 1、vue2环境变量配置 在根目录创建&#xff1a;.env.development和.env.p…

getc(),putc(),getchar(),putchar()之间的区别

getc&#xff08;&#xff09; putc() 与函数 getchar() putchar()类似&#xff0c;但是不同点在于&#xff1a;你要告诉getc和putc函数使用哪一个文件 1.从标准输入中获取一个字符&#xff1a; ch getchar(); //在处理器上输入字符 2.//从fp指定的文件中获取以一个字符 ch …

全面解析找不到msvcr110.dll,无法继续执行代码的解决方法

MSVCR110.dll的丢失可能导致某些应用程序无法启动。当用户试图打开依赖于该特定版本DLL文件的软件时&#xff0c;可能会遭遇“找不到指定模块”的错误提示&#xff0c;使得程序启动进程戛然而止。这种突如其来的故障不仅打断了用户的正常工作流程&#xff0c;也可能导致重要数据…

分库分表 ——12 种分片算法

目录 前言 分片策略 标准分片策略 行表达式分片策略 复合分片策略 Hint分片策略 不分片策略 分片算法 准备工作 自动分片算法 1、MOD 2、HASH_MOD 3、VOLUME_RANGE 4、BOUNDARY_RANGE 5、AUTO_INTERVAL 标准分片算法 6、INLINE 7、INTERVAL COSID 类型算法 …

004 CSS介绍2

文章目录 css最常用属性link元素进制css颜色表示浏览器的渲染流程(不含js) css最常用属性 font-size 文字大小 color:前景色(文字颜色) background-color:背景色 width:宽度 height:高度 link元素 也可以用来创建站点图标 link元素常见属性 href:指定被链接资源的URL rel:指…

【Linux篇】认识冯诺依曼体系结构

文章目录 一、冯诺依曼体系结构是什么二、冯诺依曼为什么要这么设计&#xff1f;三、内存是怎么提高效率的呢&#xff1f;解释&#xff1a;程序要运行&#xff0c;必须加载到内存四、和QQ好友聊天的时候&#xff0c;数据是怎么流向的&#xff1f; 一、冯诺依曼体系结构是什么 冯…

抖音运营技巧

1、视频时长 抖音的作品是否能够继续被推荐&#xff0c;取决于综合数据&#xff0c;包括完播率、点赞率、评论率、转发率和收藏率等。其中&#xff0c;完播率是最容易控制的因素。对于新号来说&#xff0c;在没有粉丝的初期&#xff0c;发布过长的视频可能会导致无人观看。因此…

vue 打包 插槽 inject reactive draggable 动画 foreach pinia状态管理

在Vue项目中&#xff0c;当涉及到打包、插槽&#xff08;Slots&#xff09;、inject/reactive、draggable、transition、foreach以及pinia时&#xff0c;这些都是Vue框架的不同特性和库&#xff0c;它们各自在Vue应用中有不同的用途。下面我将逐一解释这些概念&#xff0c;并说…

vue给input密码框设置眼睛睁开闭合对于密码显示与隐藏

<template><div class"login-container"><el-inputv-model"pwd":type"type"class"pwd-input"placeholder"请输入密码"><islot"suffix"class"icon-style":class"elIcon"…

springboot项目引入swagger

1.引入依赖 创建项目后&#xff0c;在 pom.xml 文件中引入 Swagger3 的相关依赖。回忆一下&#xff0c;我们集成 Swagger2 时&#xff0c;引入的依赖如下&#xff1a; <dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger2&…