tensorflow案例7--数据增强与测试集, 训练集, 验证集的构建

news2024/11/23 11:53:27
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

前言

  • 这次主要是学习数据增强, 训练集 验证集 测试集的构建等等的基本方法, 数据集还是用的上一篇的猫狗识别;
  • 基础篇还剩下几个, 后面的难度会逐步提升;
  • 欢迎收藏 + 关注, 本人会持续更新.

文章目录

  • 1. 简介
    • 数据增强
    • 训练集划分
  • 2. 案例测试
    • 1. 数据处理
      • 1. 导入库
      • 2. 导入数据(训练集 测试集 验证集)
      • 3. 数据部分展示
      • 4. 数据归一化与内存加速
      • 5. 数据增强
      • 6. 将增强数据融合到原始数据中
    • 2. 模型创建
    • 3. 模型训练
      • 1. 超参数设置
      • 2. 模型训练
      • 3. 模型测试
    • 4. 其他方法增强数据

1. 简介

数据增强

💙 有时候数据很好, **就可以通过在原有的基础上做一些操作, ** 从而增加数据的数量, 使训练模型更加有效.

📶 对于基础的的增强, 一般就是旋转, 在pytorch中一般是用transforms.Compose进行处理, 在tensorflow中,一般用的是tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强, 👁 具体做法请看案例

当然还有其他的方法进行增强, 比如说添加噪音, 👓 详情请看第四节, 4. 其他方法数据增强

数据增强加入模型中

一般有两个方法:

  1. 加入数据集(本文用的方法)
  2. 加入到模型中, 让模型训练的时候, 开始进行数据增强, 这个本文不介绍

注意: tensorflow和numpy版本问题不同, 可能会出现比较多数据方面的错误, 本人这个案例最后也是在云平台上跑通的.

训练集划分

简单说一下训练集, 测试集, 验证集的区别:

  • 训练集: 用来训练模型的, 确定神经网络的各种参数, 相当于我们学习一样
  • 验证集: 在训练集中, 通过验证模型效果, 来调整模型参数, 这个就相当于我们月考一样
  • 测试集: 这个就是验证模型是都具有效果, 适用于其他数据, 这个就相当于我们大考

👀 在tensorflow中, 我们可以通过tf.keras.preprocessing.image_dataset_from_directory创建训练集和验证集, 但是不能创建测试集, 创建测试集的方法, 需要我们后面对数据进行分类, 如下:

val_batches = tf.data.experimental.cardinality(val_ds)
# 创建测试集,  方法: 将验证集合拆成 5 分, 测试集占一份, 验证集占 4 份
test_ds = val_ds.take(val_batches // 2)    # 取前 * 批次
val_ds = val_ds.skip(val_batches // 2)     # 除了前 * 批次

解释:

  • tf.data.experimental.cardinality获取数据批次大小
  • .take : 取前n批数据
  • .skip : 取除了前n批次数据

2. 案例测试

本次案例是对猫狗图像进行分类, 和上一期很像, 但是这个模型使用比较简单.

注意: 不同池化层, 效果有时候天差地别, 比如说: 这个案例用的是最大池化, 但是用平均池化的话, 效果极差

1. 数据处理

1. 导入库

import tensorflow as tf 
from tensorflow.keras import layers, models, datasets 
import numpy as np 

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0, True)   # 输出存储在GPU
    tf.config.set_visible_devices([gpu0], "GPU")          # 选择第一块GPU
    
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2. 导入数据(训练集 测试集 验证集)

# 查看数据目录
import os, pathlib

data_dir = "./data/"
data_dir = pathlib.Path(data_dir)

classnames = [str(path) for path in os.listdir(data_dir)]
classnames
['cat', 'dog']
# 创建训练集和验证集

batch_size = 32
image_width, image_height = 224, 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    './data/',
    subset='training',
    validation_split=0.3,
    batch_size=batch_size,
    image_size=(image_width, image_height),
    shuffle=True,
    seed=42
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    './data',
    subset='validation',
    validation_split=0.3,
    batch_size=batch_size,
    image_size=(image_width, image_height),
    shuffle=True,
    seed=42
)
Found 600 files belonging to 2 classes.
Using 420 files for training.
Found 600 files belonging to 2 classes.
Using 180 files for validation.

在tensorflow没有提供直接分割测试集的函数,但是可以通过分割验证集的方法进行创建测试集

val_batches = tf.data.experimental.cardinality(val_ds)
# 创建测试集,  方法: 将验证集合拆成 5 分, 测试集占一份, 验证集占 4 份
test_ds = val_ds.take(val_batches // 2)    # 取前 * 批次
val_ds = val_ds.skip(val_batches // 2)     # 取除了前 * 批次

print("test batches: %d"%tf.data.experimental.cardinality(test_ds))
print("val batches: %d"%tf.data.experimental.cardinality(val_ds))
test batches: 3
val batches: 3

训练集: 验证集: 测试集 = 0.7 : 0.15 : 0.15

3. 数据部分展示

# 数据规格展示
for images, labels in train_ds.take(1):
    print("image: [N, W, H, C] ", images.shape)
    print("labels: ", labels)
    break
image: [N, W, H, C]  (32, 224, 224, 3)
labels:  tf.Tensor([0 1 1 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1 1 1 1 0 1 1 0 1 1 0 1 0 1 0], shape=(32,), dtype=int32)
# 部分图片数据展示
import matplotlib.pyplot as plt

train_one_batch = next(iter(train_ds))

plt.figure(figsize=(20, 10))

images, labels = train_one_batch

for i in range(20):
    plt.subplot(5, 10, i + 1)
    
    plt.title(classnames[labels[i]])
    
    plt.imshow(images[i].numpy().astype('uint8'))
    
    plt.axis('off')
        
plt.show()


在这里插入图片描述

4. 数据归一化与内存加速

from tensorflow.data.experimental import AUTOTUNE 

# 像素归一化, ---> [0, 1]
normalization_layer = layers.experimental.preprocessing.Rescaling(1.0 / 255)

# 训练集、测试集像素归一化
train_ds = train_ds.map(lambda x, y : (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y : (normalization_layer(x), y))
test_ds = test_ds.map(lambda x, y : (normalization_layer(x), y))

# 设置内存加速
AUTOTUNE = tf.data.experimental.AUTOTUNE 

# 打乱顺序加速, 测试集就不必了哈
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

5. 数据增强

我们可以使用 tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强.

  • tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像.
  • tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像.
# 封装整合
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),   # 垂直和水平反转
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)                      # 随机翻转
])

test_datas = next(iter(train_ds))

test_images, test_labels = test_datas

# 随机选取一个
test_image = tf.expand_dims(test_images[i], 0)

plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(test_image)   # 旋转
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])                 
    plt.axis("off")


在这里插入图片描述

6. 将增强数据融合到原始数据中

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds

# 增强
train_ds = prepare(train_ds)

2. 模型创建

model = models.Sequential([
    # 第一层要输入维度
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(image_width, image_height, 3)),
    layers.MaxPooling2D((2,2)),
    
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Dropout(0.3),
    
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Dropout(0.3),
    
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(len(classnames))
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 222, 222, 16)      448       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 111, 111, 16)     0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 109, 109, 32)      4640      
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 54, 54, 32)       0         
 2D)                                                             
                                                                 
 dropout (Dropout)           (None, 54, 54, 32)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 52, 52, 32)        9248      
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 26, 26, 32)       0         
 2D)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 26, 26, 32)        0         
                                                                 
 flatten (Flatten)           (None, 21632)             0         
                                                                 
 dense (Dense)               (None, 128)               2769024   
                                                                 
 dense_1 (Dense)             (None, 2)                 258       
                                                                 
=================================================================
Total params: 2,783,618
Trainable params: 2,783,618
Non-trainable params: 0
_________________________________________________________________

3. 模型训练

1. 超参数设置

opt = tf.keras.optimizers.Adam(learning_rate=0.001)  # 学习率:0.001

model.compile(
    optimizer = opt,
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['accuracy']
)

2. 模型训练

epochs=20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    verbose=1
)
Epoch 1/20
2024-11-22 18:03:21.866630: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8101
2024-11-22 18:03:23.553540: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
14/14 [==============================] - 4s 36ms/step - loss: 0.7059 - accuracy: 0.5643 - val_loss: 0.6646 - val_accuracy: 0.6667
Epoch 2/20
14/14 [==============================] - 0s 27ms/step - loss: 0.6125 - accuracy: 0.6381 - val_loss: 0.6096 - val_accuracy: 0.7143
Epoch 3/20
14/14 [==============================] - 0s 15ms/step - loss: 0.5027 - accuracy: 0.7714 - val_loss: 0.5646 - val_accuracy: 0.7500
Epoch 4/20
14/14 [==============================] - 0s 14ms/step - loss: 0.4723 - accuracy: 0.7952 - val_loss: 0.5496 - val_accuracy: 0.7500
Epoch 5/20
14/14 [==============================] - 0s 14ms/step - loss: 0.4395 - accuracy: 0.7857 - val_loss: 0.6267 - val_accuracy: 0.7024
Epoch 6/20
14/14 [==============================] - 0s 13ms/step - loss: 0.3721 - accuracy: 0.8262 - val_loss: 0.5001 - val_accuracy: 0.7619
Epoch 7/20
14/14 [==============================] - 0s 14ms/step - loss: 0.4041 - accuracy: 0.8238 - val_loss: 0.4595 - val_accuracy: 0.7857
Epoch 8/20
14/14 [==============================] - 0s 13ms/step - loss: 0.3195 - accuracy: 0.8643 - val_loss: 0.4247 - val_accuracy: 0.8095
Epoch 9/20
14/14 [==============================] - 0s 13ms/step - loss: 0.3010 - accuracy: 0.8738 - val_loss: 0.3674 - val_accuracy: 0.8452
Epoch 10/20
14/14 [==============================] - 0s 14ms/step - loss: 0.3190 - accuracy: 0.8762 - val_loss: 0.3660 - val_accuracy: 0.8452
Epoch 11/20
14/14 [==============================] - 0s 15ms/step - loss: 0.2864 - accuracy: 0.8690 - val_loss: 0.3529 - val_accuracy: 0.8333
Epoch 12/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2532 - accuracy: 0.8762 - val_loss: 0.2737 - val_accuracy: 0.8929
Epoch 13/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2374 - accuracy: 0.9000 - val_loss: 0.2939 - val_accuracy: 0.8810
Epoch 14/20
14/14 [==============================] - 0s 15ms/step - loss: 0.2216 - accuracy: 0.8976 - val_loss: 0.2952 - val_accuracy: 0.8810
Epoch 15/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2365 - accuracy: 0.9095 - val_loss: 0.2559 - val_accuracy: 0.9167
Epoch 16/20
14/14 [==============================] - 0s 13ms/step - loss: 0.2114 - accuracy: 0.9071 - val_loss: 0.2702 - val_accuracy: 0.8929
Epoch 17/20
14/14 [==============================] - 0s 15ms/step - loss: 0.2075 - accuracy: 0.9024 - val_loss: 0.2353 - val_accuracy: 0.9286
Epoch 18/20
14/14 [==============================] - 0s 13ms/step - loss: 0.1850 - accuracy: 0.9262 - val_loss: 0.1927 - val_accuracy: 0.9524
Epoch 19/20
14/14 [==============================] - 0s 13ms/step - loss: 0.1318 - accuracy: 0.9524 - val_loss: 0.1837 - val_accuracy: 0.9286
Epoch 20/20
14/14 [==============================] - 0s 15ms/step - loss: 0.1561 - accuracy: 0.9476 - val_loss: 0.1951 - val_accuracy: 0.9643

3. 模型测试

loss, acc = model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: ", acc)
3/3 [==============================] - 0s 8ms/step - loss: 0.2495 - accuracy: 0.9062
Loss:  0.24952644109725952
Accuracy:  0.90625

测试集准确率高, 模型效果良好

4. 其他方法增强数据

这里是使数据变得模糊

import random 

def aug_img(image):
    seed = (random.randint(0, 9), 0)
    stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)
    return stateless_random_brightness
# 随机选取一张照片
image = tf.expand_dims(test_images[i] * 255, 0)   # 注意: 不乘255, 会出现黑色, 因为 像素在0 - 1中

plt.figure(figsize=(8,8))
for i in range(9):
    image_show = aug_img(image)
    plt.subplot(3, 3, i + 1)
    plt.imshow(image_show[0].numpy().astype("uint8"))


在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2245978.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ssm面向品牌会员的在线商城小程序

摘要 随着Internet的发展,人们的日常生活已经离不开网络。未来人们的生活与工作将变得越来越数字化,网络化和电子化。它将是直接管理面向品牌会员的在线商城小程序的最新形式。本小程序是以面向品牌会员的在线商城管理为目标,使用 java技术制…

国土安全部发布关键基础设施安全人工智能框架

美国国土安全部 (DHS) 发布建议,概述如何在关键基础设施中安全开发和部署人工智能 (AI)。 https://www.dhs.gov/news/2024/11/14/groundbreaking-framework-safe-and-secure-deployment-ai-critical-infrastructure 关键基础设施中人工智能的角色和职责框架 https:/…

五天SpringCloud计划——DAY2之单体架构和微服务架构的选择和转换原则

一、引言 选择合适的架构模式是一个至关重要的决策,尤其是在单体架构和微服务架构之间的选择,本文将带大家认识什么是单体架构,什么是微服务架构,以及两者如何去选择,如何去转换。 二、什么是单体架构 单体架构&a…

【网络协议】【TCP】精讲TCP数据包传递的地址解析(含三次握手四次挥手图文并茂精华版)

目录 前言 1.TCP定义 1.1 什么是面向连接? 1.2 什么是可靠的通信协议? 1.3 什么是面向字节流的? 2. 数据包传递的地址解析 3. 三次握手过程详解 3.1 第一次握手 3.2 第二次握手 3.3 第三次握手 4. 四次挥手 4.1 第一次挥手 4.2 第二次挥手 4.3 第三次挥手 4.…

Win11 24H2新BUG或影响30%CPU性能,修复方法在这里

原文转载修改自(更多互联网新闻/搞机小知识): 一招提升Win11 24H2 CPU 30%性能,小BUG大影响 就在刚刚,小江在网上冲浪的时候突然发现了这么一则帖子,标题如下:基准测试(特别是 Time…

人工智能的核心思想-神经网络

神经网络原理 引言 在理解ChatGPT之前,我们需要从神经网络开始,了解最简单的“鹦鹉学舌”是如何实现的。神经网络是人工智能领域的基础,它模仿了人脑神经元的结构和功能,通过学习和训练来解决复杂的任务。本文将详细介绍神经网络…

socket连接封装

效果: class websocketMessage {constructor(params) {this.params params; // 传入的参数this.socket null;this.lockReconnect false; // 重连的锁this.socketTimer null; // 心跳this.lockTimer null; // 重连this.timeout 3000; // 发送消息this.callbac…

蓝桥杯每日真题 - 第17天

题目:(最大数字) 题目描述(13届 C&C B组D题) 题目分析: 操作规则: 1号操作:将数字加1(如果该数字为9,变为0)。 2号操作:将数字…

探索免费的Figma中文版:开启高效设计之旅

在当今数字化设计的浪潮中,Figma以其强大的云端协作功能和出色的设计能力,成为了众多设计师的心头好。而对于国内的设计师来说,能够免费使用Figma中文版更是一大福音,下面就来一起探索一下吧。 一、Figma中文版的获取途径 虽然F…

leetcode:112. 路径总和

给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径,这条路径上所有节点值相加等于目标和 targetSum 。如果存在,返回 true ;否则,返回 false 。 叶子节点 是指没有子节点…

新160个crackme - 100-E-crackme

运行分析 需根据机器码,填写正确注册码 PE分析 C程序,32位,无壳 静态分析&动态调试 ida无法搜到字符串,使用暂停法找关键函数 首先启动ida动态调试,点击注册来到错误弹窗 点击Debugger -> Pause process 发现断…

VSCode 间距太小

setting->font family 使用:Consolas, Courier New, monospace 字体

七、电机三环控制

电机三环控制指的是,直流有刷电机三环(电流环速度环位置环)PID 控制。 1、三环PID控制原理 三环 PID 控制就是将三个 PID 控制系统(例如:电流环、速度环以及位置环)串联起来,然后对前一个系统…

【快讯】亚马逊(AMZN.US)关联方拟出售7.08万股股份,价值约1,407.69万美元

根据美国证券交易委员会(SEC)美东时间11月21日披露的文件,亚马逊(AMZN.US)关联方BEZOS EARTH FUND FOUNDATION拟于11月21日出售7.08万股普通股股份,总市值约1,407.69万美元。此外,BEZOS EARTH FUND FOUNDATION自2024年…

影响电阻可靠性的因素

一、影响电阻可靠性的因素: 影响电阻可靠性的因素有温度系数、额定功率,最大工作电压、固有噪声和电压系数 (一)温度系数 电阻的温度系数表示当温度改变1摄氏度时,电阻阻值的相对变化,单位为ppm/C.电阻温度…

51c大模型~合集76

我自己的原文哦~ https://blog.51cto.com/whaosoft/12617524 #诺奖得主哈萨比斯新作登Nature,AlphaQubit解码出更可靠量子计算机 谷歌「Alpha」家族又壮大了,这次瞄准了量子计算领域。 今天凌晨,新晋诺贝尔化学奖得主、DeepMind 创始人哈萨…

深入了解 Linux htop 命令:功能、用法与示例

文章目录 深入了解 Linux htop 命令:功能、用法与示例什么是 htop?htop 的安装htop的基本功能A区:系统资源使用情况B区:系统概览信息C区:进程列表D区:功能键快捷方式 与 top 的对比常见用法与示例实际场景应…

XML文件(超详细):XML文件概念、作用、写法、如何用程序解析XML、写入XML、dom4j框架、DTD文档、schema文档

目录 1、什么是XML文件?和properties属性文件有什么区别?和txt文本文件有什么区别? 2、XML文件的用途 3、XML的格式 4、如何解析XML文件 5、如何写入XML文件 6、约束XML的书写格式 6.1 DTD文档-约束书写格式,但是不能约束具…

通过端口测试验证网络安全策略

基于网络安全需求,项目中的主机间可能会有不同的网络安全策略,这当然是好的,但很多时候,在解决网络安全问题的时候,同时引入了新的问题,如k8s集群必须在主机间开放udp端口,否则集群不能正常的运…

国产光耦合器的竞争优势与市场发展前景

国产光耦合器近年来在技术研发和市场表现上取得了显著进步,逐渐在国际市场中占据了一席之地。作为实现电气隔离和信号传输的核心器件,光耦合器在工业控制、通信设备、消费电子等领域中有着广泛的应用。国产光耦合器凭借其独特的成本、技术和市场优势&…