深度学习-第T10周——数据增强

news2024/11/27 11:43:48

深度学习-第T10周——数据增强

  • 深度学习-第T10周——数据增强
    • 一、前言
    • 二、我的环境
    • 三、前期工作
      • 1、导入数据集
      • 2、查看图片数目
    • 四、数据预处理
      • 1、 加载数据
        • 1.1、设置图片格式
        • 1.2、划分训练集
        • 1.3、划分验证集
        • 1.4、查看标签
        • 1.5、再次检查数据
        • 1.6、配置数据集
      • 2、数据可视化
    • 五、数据增强
      • 1.数据增强
      • 2.增强方式
        • 法一:将其嵌入model中
        • 法二:在Dataset数据集中进行数据增强
    • 六、模型训练
    • 七、自定义增强函数
    • 八、总结

深度学习-第T10周——数据增强

一、前言

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

二、我的环境

  • 电脑系统:Windows 10
  • 语言环境:Python 3.8.5
  • 编译器:colab在线编译
  • 深度学习环境:Tensorflow

三、前期工作

数据增强:数据增强可以用少量数据达到非常棒的识别准确率
数据增强的两种方式:
1、将数据增强模块嵌入model中
2、在Dataset数据集中进行数据增强

1、导入数据集

导入数据集,这里使用k同学的数据集,共2个分类。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os, PIL, pathlib

#1、载入数据
data_dir = ("D:/DL_Camp/CNN/T8/365-7-data")
data_dir = pathlib.Path(data_dir)

这段代码将字符串类型的 data_dir 转换为了 pathlib.Path 类型的对象。pathlib 是 Python3.4 中新增的模块,用于处理文件路径。
通过 Path 对象,可以方便地操作文件和目录,如创建、删除、移动、复制等。
在这里,我们使用 pathlib.Path() 函数将 data_dir 转换为路径对象,这样可以更加方便地进行文件路径的操作和读写等操作。

2、查看图片数目

image_mount = len(list(data_dir.glob("*/*.jpg"))) 
print(image_mount)

获取指定目录下所有子文件夹中 jpg 格式的文件数量,并将其存储在变量 image_count 中。

data_dir 是一个路径变量,表示需要计算的目标文件夹的路径。
glob() 方法可以返回匹配指定模式(通配符)的文件列表,该方法的参数 “/.jpg” 表示匹配所有子文件夹下以 .jpg 结尾的文件。

list() 方法将 glob() 方法返回的生成器转换为列表,方便进行数量统计。最后,len() 方法计算列表中元素的数量,就得到了指定目录下 jpg 格式文件的总数。

所以,这行代码的作用就是计算指定目录下 jpg 格式文件的数量。
在这里插入图片描述

四、数据预处理

1、 加载数据

1.1、设置图片格式

batch_size = 32
img_height = 224
img_width = 224

1.2、划分训练集

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    
    data_dir,
    validation_split = 0.3,
    subset = 'training',
    seed = 12,
    image_size = (img_height, img_width),
    batch_size = batch_size
    
    )

这行代码使用 TensorFlow 读取指定路径下的图片文件,并生成一个 tf.data.Dataset 对象,用于模型的训练和评估。

具体来说,tf.keras.preprocessing.image_dataset_from_directory() 函数从指定目录中读取图像数据,并自动对其进行标准化和预处理。该函数有以下参数:

data_dir: 字符串,指定要读取的图片文件夹路径。
validation_split: 浮点数,指定验证集所占的比例。默认值为 0.2。
subset: 字符串,表示要读取哪个子集的数据。默认为 “training”,即读取训练集数据。
seed: 整型,用于设置随机种子以生成可重复的随机数,默认为 None。
image_size: 元组,表示所有图像的期望尺寸。例如 (150, 150) 表示将所有图像调整为 150x150 大小。
batch_size: 整型,表示每个批次的样本数。

通过这些参数,函数将指定目录中的图像按照指定大小预处理后,随机划分为训练集和验证集。最终,生成的 tf.data.Dataset 对象包含了划分好的数据集,可以用于后续的模型训练和验证。

需要注意的是,这里的 img_height 和 img_width 变量应该提前定义,并且应该与实际图像的尺寸相对应。同时,batch_size 也应该根据硬件设备的性能合理调整,以充分利用 GPU/CPU 的计算资源。
在这里插入图片描述

1.3、划分验证集

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    
    data_dir,
    validation_split = 0.3,
    subset = "validation",
    seed = 12,
    image_size = (img_height, img_width),
    batch_size = batch_size
    
    )

这段代码和上一段代码类似,使用 TensorFlow 的 keras.preprocessing.image_dataset_from_directory() 函数从指定的目录中读取图像数据集,并将其划分为训练集和验证集。

其中,data_dir 指定数据集目录的路径,validation_split 表示从数据集中划分出多少比例的数据作为验证集,subset 参数指定为 “validation” 则表示从数据集的 20% 中选择作为验证集,其余 80% 作为训练集。seed 是一个随机种子,用于生成可重复的随机数。image_size 参数指定输出图像的大小,batch_size 表示每批次加载的图像数量。

该函数返回一个 tf.data.Dataset 对象,代表了整个数据集(包含训练集和验证集)。可以使用 train_ds 和 val_ds 两个对象分别表示训练集和验证集。

不过两段代码的 subset 参数值不同,一个是 “training”,一个是 “validation”。

因此,在含有交叉验证或者验证集的深度学习训练过程中,需要定义两个数据集对象 train_ds 和 val_ds。我们已经定义了包含训练集和验证集的数据集对象 train_ds,可以省略这段代码,无需重复定义 val_ds 对象。只要确保最终的训练过程中,两个数据集对象都能够被正确地使用即可。

如果你没有定义 val_ds 对象,可以使用这段代码来创建一个验证数据集对象,用于模型训练和评估,从而提高模型性能。

在这里插入图片描述

1.4、查看标签

由于原始数据集不包含测试集,因此需要创建一个

#由于原始数据集不包含测试集,因此需要创建一个
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches // 5)
val_ds = val_ds.skip(val_batches // 5)

print('Number Of Val Batches: %d' % val_batches)
print('Number Of Validation Batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number Of test Batches: %d' % tf.data.experimental.cardinality(test_ds))

class_names = train_ds.class_names
class_names

在这里插入图片描述

1.5、再次检查数据

for images_batch, labels_batch in train_ds.take(1):
    print(images_batch.shape)
    print(labels_batch.shape)
    break

在这里插入图片描述

image_batch 是张量的形状(64, 224, 224,3)。这是一批形状2242243的8张图片,最后一维指的是彩色通道RGB
label_batch是形状为(64,)的张量,这些标签对应64张图片

1.6、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
"""
定义 AUTOTUNE 常量
这个常量的作用是指定 TensorFlow 数据管道读取数据时使用的线程个数,使得数据读取可以尽可能地并行化,提升数据读取效率。
具体来说,AUTOTUNE 的取值会根据系统资源和硬件配置等因素自动调节。"""
def preprocess_image(image, label):
    return (image / 255.0, label)
"""这个函数的作用是对输入的图像数据进行预处理操作,其中 image 表示输入的原始图像数据,label 表示对应的标签信息。
函数体内的操作是把原始图像数据除以 255,使其数值归一化到 0 和 1 之间。
函数返回一个元组 (image / 255.0, label),其中第一个元素是经过处理后的图像数据,第二个元素是对应的标签信息。
"""

#归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls = AUTOTUNE)
val_ds = val_ds.map(preprocess_image, num_parallel_calls = AUTOTUNE)
test_ds = test_ds.map(preprocess_image, num_parallel_calls = AUTOTUNE)

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size = AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size = AUTOTUNE)
"""
通过 map 方法对数据集中的每个元素应用 preprocess_image 函数进行预处理。
num_parallel_calls 参数指定了并行处理的个数,这里设为 AUTOTUNE,表示自动选择最优的并行个数。

接着,对经过预处理后的数据集,通过 cache 方法将其缓存到内存中,以提高读取效率。
然后,再利用 shuffle 方法和 prefetch 方法对训练数据集进行混洗和数据预取,增强训练稳定性和效率。
而验证数据集只需要进行缓存和数据预取操作即可。
"""

在 TensorFlow 中,map 是一种对数据集中的每个元素应用一个函数的方法,常用于数据预处理和数据增强等任务。其使用方式为:

dataset = dataset.map(map_func, num_parallel_calls=None)
其中,dataset 表示待处理的数据集对象,map_func 表示要应用的函数,num_parallel_calls 表示并行执行 map_func 的线程数。

具体来说,map_func 函数会被应用到数据集中的每个元素上,函数接受一个或多个张量作为输入,输出也可以是一个或多个张量。map_func 的定义方式应当符合 TensorFlow 的计算图模型,即是一组 TensorFlow 的计算操作(ops)。

使用 map 方法可以方便地对数据集进行预处理,例如图像数据的归一化、尺寸调整、数据增强等。同时,由于 map 方法本身支持并行处理,因此可以大大加速数据处理的速度。

在使用 map 方法时,应尽可能指定 num_parallel_calls 参数以充分利用计算资源,提高处理效率。

2、数据可视化

plt.figure(figsize = (15, 13))

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(5, 8, i + 1)
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        plt.axis("off")

train_ds.take(1) 是一个方法调用,它返回一个数据集对象 train_ds 中的子集,其中包含了 take() 方法参数指定的数量的样本。
在这个例子中,take(1) 意味着我们从 train_ds 数据集中获取一批包含一个样本的数据块。

因此,for images, labels in train_ds.take(1): 的作用是遍历这个包含一个样本的数据块,并将其中的图像张量和标签张量依次赋值给变量 images 和 labels。具体来说,
它的执行过程如下:

从 train_ds 数据集中获取一批大小为 1 的数据块。
遍历这个数据块,每次获取一个图像张量和一个标签张量。
将当前图像张量赋值给变量 images,将当前标签张量赋值给变量 labels。
执行 for 循环中的代码块,即对当前图像张量和标签张量进行处理。

plt.imshow() 函数是 Matplotlib 库中用于显示图像的函数,它接受一个数组或张量作为输入,并在窗口中显示对应的图像。
在这个代码中,images[i] 表示从训练集中获取的第 i 个图像张量。由于 images 是一个包含多个图像的张量列表,因此使用 images[i] 可以获取其中的一个图像。

plt.axis(“off”) 是 Matplotlib 库中的一个函数调用,它用于控制图像显示时的坐标轴是否可见。
具体来说,当参数为 “off” 时,图像的坐标轴会被关闭,不会显示在图像周围。这个函数通常在 plt.imshow() 函数之后调用,以便在显示图像时去掉多余的细节信息,仅仅显示图像本身。

在这里插入图片描述

五、数据增强

1.数据增强

data_augmentation = tf.keras.Sequential([
    
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
    
    ])
image = tf.expand_dims(images[i], 0)

plt.figure(figsize = (8,8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i+1)
    plt.imshow(augmented_image[0])
    plt.axis('off')

上述代码定义了一个数据增强模型data_augmentation,用于对图像数据进行随机翻转和旋转操作。具体来说,该模型使用tf.keras.Sequential容器将两个图像处理层拼接在一起:

tf.keras.layers.experimental.preprocessing.RandomFlip(“horizontal_and_vertical”):该层用于对图像进行随机水平和垂直翻转的操作。其中,"horizontal_and_vertical"参数表示对图像进行同时水平和垂直方向上的翻转。如果不需要进行某个方向上的翻转可以分别使用"horizontal"和"vertical"参数。
tf.keras.layers.experimental.preprocessing.RandomRotation(0.2):该层用于对图像进行随机旋转操作。其中,0.2参数表示旋转角度的幅度范围,即在[-20%, 20%]的范围内随机旋转图像。

在这里插入图片描述

2.增强方式

法一:将其嵌入model中

model = tf.keras.Sequential([
    
    data_augmentation,
    layers.Conv2D(16, 3, padding = "same", activation = 'relu'),
    layers.MaxPooling2D()
    
    ])

法二:在Dataset数据集中进行数据增强

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE()

def prepare(ds):
    ds = ds.map(lambda x, y : (data_augmentation(x, training = True), y), num_parallel_calls = AUTOTUNE)
    return ds

train_ds = prepare(train_ds)

六、模型训练

#四、训练模型
model = tf.keras.Sequential([
    data_augmentation,
    layers.Conv2D(16, 3, padding = 'same', activation = 'relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding = 'same', activation = 'relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding = 'same', activation = 'relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation = 'relu'),
    layers.Dense(len(class_names))
    ])

model.compile(
    
    optimizer = 'adam',
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=( True )),
    metrics = ['accuracy']
    )

epochs = 20
history = model.fit(
    
    train_ds,
    validation_data = val_ds,
    epochs = epochs
    
    )

loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

在这里插入图片描述

七、自定义增强函数

import random

def aug_img(image):
  seed = (random.randint(0, 9), 0)
  #随机改变图像对比度
  stateless_random_brightness = tf.image.stateless_random_contrast(image, lower = 0.1, upper = 1.0, seed = seed)
  return stateless_random_brightness


image = tf.expand_dims(images[3] * 255, 0)
print("Min And Max Pixel Values:", image.numpy().min(), image.numpy().max())

plt.figure(figsize = (8, 8))
for i in range(9):

  augmented_image = aug_img(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0].numpy().astype("uint8"))

  plt.axis("off")

在这里插入图片描述

八、总结

数据增强可以有效地提高模型的识别精度和泛化能力。在训练过程中对输入数据进行随机变换,可以使得模型更加鲁棒,避免过度拟合。对输入图像进行随机翻转和旋转,可以增加训练数据的多样性,从而提高模型对不同角度和方向的图像进行分类的能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/629069.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

硬件设计电源系列文章-DCDC转换器基础知识

文章目录 概要整体架构流程技术名词解释技术细节小结 概要 提示:这里可以添加技术概要 本文主要接着上篇,上篇文章主要讲述了LDO的相关基础知识,本节开始分享DCDC基础知识 整体架构流程 提示:这里可以添加技术整体架构 以下是…

ROS学习——通信机制(话题通信②—订阅方实现)

2.1 话题通信 Autolabor-ROS机器人入门课程《ROS理论与实践》零基础教程 042话题通信(C)3_订阅方实现_Chapter2-ROS通信机制_哔哩哔哩_bilibili 1.新建demo02_sub.cpp文件,搭建框架 2.包含头文件 3.初始化ROS节点 cuiHua——节点名称,具有唯一性 4.创…

一小时让你Get到面试套路:记一次Java初中级程序员面试流程梳理

视频教程传送门: 一小时让你Get到面试套路:记一次Java初中级程序员面试流程梳理_哔哩哔哩_bilibili听了N多个师兄师姐的面试录音,采访了N多个师兄时间的面试经历,才总结出来的java面试流程,非常适合正在准备面试的你。…

【7 微信小程序学习 - 小程序的系统API调用,网络请求封装】

1 网络请求 – 原生 请求数据,保存数据 1 原生请求 Page({data: {allCities: {},houselist: [],currentPage: 1},async onLoad() {// 1.网络请求基本使用wx.request({url: "http://codercba.com:1888/api/city/all",success: (res) > {//保存数据const data res…

企业级在线办公系统搭建开发环境

目录 介绍 搭建开发环境 安装MySQL数据库 安装Redis程序 安装MongoDB数据库 安装RabbitMQ 安装JDK 安装Maven环境 安装Node.js程序 安装HbuilderX工具 MacOS环境的程序安装 安装Docker环境 安装MySQL数据库 安装MongoDB数据库 安装Redis程序 安装RabbitMQ 学习…

【Python 匿名函数】零基础也能轻松掌握的学习路线与参考资料

Python 匿名函数是一种特殊的函数,也称为 lambda 函数。它允许我们创建一种简单的函数,一行代码就可以搞定。可能你会发现,在使用Python时,经常可以看到lambda关键字,以lambda开头的就是匿名函数。本文将详细介绍Pytho…

在一台电脑上配置多个Git账号,工作、生活两不误

文章目录 先 Unset global 配置生成 SSH Key 并配置到 GitHub多账号用 config 管理 先 Unset global 配置 任意文件夹下 Git Bash Here 然后输入如下命令来 unset git config --global --unset user.name git config --global --unset user.email git config --global --unset…

LabVIEW编程开发汽油中各种掺假物浓度⽔平的检测

LabVIEW编程开发汽油中各种掺假物浓度⽔平的检测 主要目的是使用LabVIEW中的密度法找到汽油中掺杂物的浓度。已经对许多技术进行了研究,以发现汽油中的掺杂物。例如蒸馏试验、化学制造商试验、蒸发试验、气相色谱法可以专门测量掺杂物。甚至有数字密度计来测量汽油…

深圳旧改投资_一秒读懂什么是确权?物业权利人核实的基础要点。

权利人核实常见问题 为什么要做历史违建物业权利人核实? 2021年3月1日正式实施的《深圳经济特区城市更新条例》,区城市更新部门应当在物业权利人更新意愿核实阶段组织区规划土地监察机构、辖区街道办事处和原农村集体经济组织继受单位对历史违建物业权…

【数据可视化】数据可视化Canvas

1、了解Canvas ◼什么是Canvas ---- Canvas 最初由Apple于2004 年引入,用于Mac OS X WebKit组件,为仪表板小部件和Safari浏览器等应用程序提供支持。后来,它被Gecko内核的浏览器(尤其是Mozilla Firefox)&#xff0c…

上位机Qt应用程序与MCU板子之间的串口数据传输算法,举例1字节、2字节、4字节正负数。再加qDebug的重定向显示打印数据。

串口之间的数据传输算法 前言【1】Qt界面设计图【2】串口char型举例串口收发正数举例串口收发负数举例 【3】串口short 型举例大端序和小端序 串口收发正数举例串口收发负数举例 【4】串口int型举例串口收发正数举例串口收发负数举例串口收发正负数(简洁版推荐&…

大数据分析的Python实战指南:数据处理、可视化与机器学习【上进小菜猪大数据】

上进小菜猪,沈工大软件工程专业,爱好敲代码,持续输出干货。 引言: 大数据分析是当今互联网时代的核心技术之一。通过有效地处理和分析大量的数据,企业可以从中获得有价值的洞察,以做出更明智的决策。本文将…

Mujoco210 Ubuntu 22.04配置安装(一)

目录 .1 下载 1.1 解压 1.2 许可问题 1.3 环境配置 1.4 测试mujoco .2 安装mujoco-py 2.1 conda激活虚拟环境\或新创建一个环境 2.2 下载mujoco-py ​编辑 2.3 配置环境变量 2.4 测试mujoco-py 2.5 测试时的一些报错处理 2.5.0 command /usr/bin/gcc failed with…

Linux操作系统——第三章 基础IO

目录 接口介绍 open 文件描述符fd 0 & 1 & 2 文件描述符的分配规则 重定向 FILE 理解文件系统 inode ​编辑 理解硬链接 软链接 动态库和静态库 静态库与动态库 生成静态库 库搜索路径 生成动态库 使用动态库 运行动态库 使用外部库 接口介绍 o…

(顶刊复现)配电网两阶段鲁棒故障恢复(matlab实现)

参考文献: X. Chen, W. Wu and B. Zhang, "Robust Restoration Method for Active Distribution Networks," in IEEE Transactions on Power Systems, vol. 31, no. 5, pp. 4005-4015, Sept. 2016, doi: 10.1109/TPWRS.2015.2503426. 1.研究背景 1.1摘…

2023 Navicat for Redis 与 Navicat Premium 16.2 现已正式发布 | 释放 Redis 全部潜能

🌷🍁 博主 libin9iOak带您 Go to New World.✨🍁 🦄 个人主页——libin9iOak的博客🎐 🐳 《面试题大全》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~&#x1f33…

【C++】 Lambda表达式详解

▒ 目录 ▒ 🛫 问题描述环境 1️⃣ 什么是Lambda表达式Lambda 表达式的各个部分 2️⃣ 优缺点优点缺点 3️⃣ 使用场景在线C工具STL算法库STL容器中需要传递比较函数(示例失败了)多线程示例 4️⃣ Lambda表达式与函数指针的比较5️⃣ 捕获列表…

KISS复盘法

KISS复盘法 KISS复盘法是一种科学的项目复盘方法,能够把过往经验转化为实践能力,以促进下一次活动更好地展开,从而不断提升个人和团队的能力! 模型介绍 【复盘】原是围棋术语,本意是对弈者在下完一盘棋之后&#xff0…

距离保护原理

距离保护是反映故障点至保护安装处的距离,并根据距离的远近确定动作时间的一种保护。故障点距保护安装处越近,保护的动作时间就越短,反之就越长,从而保证动作的选择性。测量故障点至保护安装处的距离,实际上就是用阻抗…

Spring Boot banner详解

Spring Boot 3.x系列文章 Spring Boot 2.7.8 中文参考指南(一)Spring Boot 2.7.8 中文参考指南(二)-WebSpring Boot 源码阅读初始化环境搭建Spring Boot 框架整体启动流程详解Spring Boot 系统初始化器详解Spring Boot 监听器详解Spring Boot banner详解 自定义banner Spring …