深度学习Week16——数据增强

news2025/1/19 11:15:56

文章目录
深度学习Week16——数据增强
一、前言
二、我的环境
三、前期工作
1、配置环境
2、导入数据
2.1 加载数据
2.2 配置数据集
2.3 数据可视化
四、数据增强
五、增强方式
1、将其嵌入model中
2、在Dataset数据集中进行数据增强
六、训练模型
七、自定义增强函数

一、前言

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本篇内容分为两个部分,前面部分是学习K同学给的算法知识点以及复现,后半部分是自己的拓展与未解决的问题

本期学习了数据增强函数并自己实现一个增强函数,使用的数据集仍然是猫狗数据集。

二、我的环境

  • 电脑系统:Windows 10
  • 语言环境:Python 3.8.0
  • 编译器:Pycharm2023.2.3
    深度学习环境:TensorFlow
    显卡及显存:RTX 3060 8G

三、前期工作

1、配置环境

import matplotlib.pyplot as plt
import numpy as np
#隐藏警告
import warnings
warnings.filterwarnings('ignore')

from tensorflow.keras import layers
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

# 打印显卡信息,确认GPU可用
print(gpus)

输出:

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

这一步与pytorch第一步类似,我们在写神经网络程序前无论是选择pytorch还是tensorflow都应该配置好gpu环境(如果有gpu的话)

2、 导入数据

导入所有猫狗图片数据,依次分别为训练集图片(train_images)、训练集标签(train_labels)、测试集图片(test_images)、测试集标签(test_labels),数据集来源于K同学啊

2.1 加载数据
data_dir   = "/home/mw/input/dogcat3675/365-7-data"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset
tf.keras.preprocessing.image_dataset_from_directory()会将文件夹中的数据加载到tf.data.Dataset中,且加载的同时会打乱数据。

  • class_names
  • validation_split: 0和1之间的可选浮点数,可保留一部分数据用于验证。
  • subset: training或validation之一。仅在设置validation_split时使用。
  • seed: 用于shuffle和转换的可选随机种子。
  • batch_size: 数据批次的大小。默认值:32
  • image_size: 从磁盘读取数据后将其重新调整大小。默认:(256,256)。由于管道处理的图像批次必须具有相同的大小,因此该参数必须提供。

输出:

Found 3400 files belonging to 2 classes.
Using 2380 files for training.

由于原始的数据集里不包含测试集,所以我们需要自己创建一个

val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))
Number of validation batches: 60
Number of test batches: 15

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)

[‘cat’, ‘dog’]

2.2 配置数据集
AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image,label):
    return (image/255.0,label)

# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
2.3 数据可视化
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(5, 8, i + 1) 
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

四 、数据增强

使用下面两个函数来进行数据增强:

  • tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.3),
])

第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照0.3的弧度值进行随机旋转。

# Add the image to a batch.
image = tf.expand_dims(images[i], 0)

plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

五、增强方式

1. 将其嵌入model中

model = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])

Epoch 1/20
43/43 [==============================] - 18s 103ms/step - loss: 1.2824 - accuracy: 0.5495 - val_loss: 0.4272 - val_accuracy: 0.8941
Epoch 2/20
43/43 [==============================] - 3s 55ms/step - loss: 0.3326 - accuracy: 0.8815 - val_loss: 0.1882 - val_accuracy: 0.9309
Epoch 3/20
43/43 [==============================] - 3s 54ms/step - loss: 0.1614 - accuracy: 0.9488 - val_loss: 0.1493 - val_accuracy: 0.9412
Epoch 4/20
43/43 [==============================] - 2s 54ms/step - loss: 0.1215 - accuracy: 0.9557 - val_loss: 0.0950 - val_accuracy: 0.9721
Epoch 5/20
43/43 [==============================] - 3s 54ms/step - loss: 0.0906 - accuracy: 0.9666 - val_loss: 0.0791 - val_accuracy: 0.9691
Epoch 6/20
43/43 [==============================] - 3s 56ms/step - loss: 0.0614 - accuracy: 0.9768 - val_loss: 0.1131 - val_accuracy: 0.9559
Epoch 7/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0603 - accuracy: 0.9807 - val_loss: 0.0692 - val_accuracy: 0.9794
Epoch 8/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0577 - accuracy: 0.9793 - val_loss: 0.0609 - val_accuracy: 0.9779
Epoch 9/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0511 - accuracy: 0.9825 - val_loss: 0.0546 - val_accuracy: 0.9779
Epoch 10/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0462 - accuracy: 0.9871 - val_loss: 0.0628 - val_accuracy: 0.9765
Epoch 11/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0327 - accuracy: 0.9895 - val_loss: 0.0790 - val_accuracy: 0.9721
Epoch 12/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0242 - accuracy: 0.9938 - val_loss: 0.0580 - val_accuracy: 0.9794
Epoch 13/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0354 - accuracy: 0.9907 - val_loss: 0.0797 - val_accuracy: 0.9735
Epoch 14/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0276 - accuracy: 0.9900 - val_loss: 0.0810 - val_accuracy: 0.9691
Epoch 15/20
43/43 [==============================] - 3s 56ms/step - loss: 0.0243 - accuracy: 0.9931 - val_loss: 0.1063 - val_accuracy: 0.9676
Epoch 16/20
43/43 [==============================] - 3s 56ms/step - loss: 0.0253 - accuracy: 0.9914 - val_loss: 0.1142 - val_accuracy: 0.9721
Epoch 17/20
43/43 [==============================] - 3s 56ms/step - loss: 0.0205 - accuracy: 0.9937 - val_loss: 0.0726 - val_accuracy: 0.9706
Epoch 18/20
43/43 [==============================] - 3s 56ms/step - loss: 0.0154 - accuracy: 0.9948 - val_loss: 0.0741 - val_accuracy: 0.9765
Epoch 19/20
43/43 [==============================] - 3s 56ms/step - loss: 0.0155 - accuracy: 0.9966 - val_loss: 0.0870 - val_accuracy: 0.9721
Epoch 20/20
43/43 [==============================] - 3s 55ms/step - loss: 0.0259 - accuracy: 0.9907 - val_loss: 0.1194 - val_accuracy: 0.9721

这样做的好处是:
数据增强这块的工作可以得到GPU的加速(如果你使用了GPU训练的话)
注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。

2. 在Dataset数据集中进行数据增强

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds
model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])
Epoch 1/20
75/75 [==============================] - 11s 133ms/step - loss: 0.8828 - accuracy: 0.7113 - val_loss: 0.1488 - val_accuracy: 0.9447
Epoch 2/20
75/75 [==============================] - 2s 33ms/step - loss: 0.1796 - accuracy: 0.9317 - val_loss: 0.0969 - val_accuracy: 0.9658
Epoch 3/20
75/75 [==============================] - 2s 33ms/step - loss: 0.0999 - accuracy: 0.9655 - val_loss: 0.0362 - val_accuracy: 0.9879
Epoch 4/20
75/75 [==============================] - 2s 33ms/step - loss: 0.0566 - accuracy: 0.9810 - val_loss: 0.0448 - val_accuracy: 0.9853
Epoch 5/20
75/75 [==============================] - 2s 33ms/step - loss: 0.0426 - accuracy: 0.9807 - val_loss: 0.0142 - val_accuracy: 0.9937
Epoch 6/20
75/75 [==============================] - 2s 33ms/step - loss: 0.0149 - accuracy: 0.9944 - val_loss: 0.0052 - val_accuracy: 0.9989
Epoch 7/20
75/75 [==============================] - 2s 33ms/step - loss: 0.0068 - accuracy: 0.9974 - val_loss: 7.9693e-04 - val_accuracy: 1.0000
Epoch 8/20
75/75 [==============================] - 2s 33ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 4.8532e-04 - val_accuracy: 1.0000
Epoch 9/20
75/75 [==============================] - 2s 33ms/step - loss: 4.5804e-04 - accuracy: 1.0000 - val_loss: 1.9160e-04 - val_accuracy: 1.0000
Epoch 10/20
75/75 [==============================] - 2s 33ms/step - loss: 1.7624e-04 - accuracy: 1.0000 - val_loss: 1.1390e-04 - val_accuracy: 1.0000
Epoch 11/20
75/75 [==============================] - 2s 33ms/step - loss: 1.1646e-04 - accuracy: 1.0000 - val_loss: 8.7005e-05 - val_accuracy: 1.0000
Epoch 12/20
75/75 [==============================] - 2s 33ms/step - loss: 9.0645e-05 - accuracy: 1.0000 - val_loss: 7.1111e-05 - val_accuracy: 1.0000
Epoch 13/20
75/75 [==============================] - 2s 33ms/step - loss: 7.4695e-05 - accuracy: 1.0000 - val_loss: 5.9888e-05 - val_accuracy: 1.0000
Epoch 14/20
75/75 [==============================] - 2s 33ms/step - loss: 6.3227e-05 - accuracy: 1.0000 - val_loss: 5.1448e-05 - val_accuracy: 1.0000
Epoch 15/20
75/75 [==============================] - 2s 33ms/step - loss: 5.4484e-05 - accuracy: 1.0000 - val_loss: 4.4721e-05 - val_accuracy: 1.0000
Epoch 16/20
75/75 [==============================] - 2s 33ms/step - loss: 4.7525e-05 - accuracy: 1.0000 - val_loss: 3.9201e-05 - val_accuracy: 1.0000
Epoch 17/20
75/75 [==============================] - 2s 33ms/step - loss: 4.1816e-05 - accuracy: 1.0000 - val_loss: 3.4528e-05 - val_accuracy: 1.0000
Epoch 18/20
75/75 [==============================] - 2s 33ms/step - loss: 3.7006e-05 - accuracy: 1.0000 - val_loss: 3.0541e-05 - val_accuracy: 1.0000
Epoch 19/20
75/75 [==============================] - 2s 33ms/step - loss: 3.2878e-05 - accuracy: 1.0000 - val_loss: 2.7116e-05 - val_accuracy: 1.0000
Epoch 20/20
75/75 [==============================] - 2s 33ms/step - loss: 2.9274e-05 - accuracy: 1.0000 - val_loss: 2.4160e-05 - val_accuracy: 1.0000

六、训练模型

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

epochs=20
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

使用方法一:

15/15 [==============================] - 1s 58ms/step - loss: 0.0984 - accuracy: 0.9646
Accuracy 0.9645833373069763

使用方法二:


15/15 [==============================] - 1s 58ms/step - loss: 2.7453e-05 - accuracy: 1.0000
Accuracy 1.0

七、自定义增强函数

import random
def aug_img(image):
    seed = random.randint(0, 10000)  # 随机种子

    # 随机亮度
    image = tf.image.stateless_random_brightness(image, max_delta=0.2, seed=[seed, 0])

    # 随机对比度
    image = tf.image.stateless_random_contrast(image, lower=0.8, upper=1.2, seed=[seed, 1])

    # 随机饱和度
    image = tf.image.stateless_random_saturation(image, lower=0.8, upper=1.2, seed=[seed, 2])

    # 随机色调
    image = tf.image.stateless_random_hue(image, max_delta=0.2, seed=[seed, 3])

    # 随机翻转水平和垂直
    image = tf.image.stateless_random_flip_left_right(image, seed=[seed, 4])
    image = tf.image.stateless_random_flip_up_down(image, seed=[seed, 5])

    # 随机旋转
    image = tf.image.rot90(image, k=random.randint(0, 3))  # 旋转0, 90, 180, 270度

    return image
image = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
Min and max pixel values: 2.4591687 241.47968
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

在这里插入图片描述
然后我们使用了第二种增强方法,以下为他的结果:

15/15 [==============================] - 1s 57ms/step - loss: 0.1294 - accuracy: 0.9604
Accuracy 0.9604166746139526

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1801885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么时候用C而不用C++?

做接口只用C,千万别要C。C是编译器敏感的,一旦导出的接口里有 std::string这些东西,以及类,注定了要为各个编译器的各个版本准备独立的库。 刚好我有一些资料,是我根据网友给的问题精心整理了一份「C的资料从专业入门…

[AIGC] SpringBoot的自动配置解析

下面是一篇关于SpringBoot自动配置的文章,里面包含了一个简单的示例来解释自动配置的原理。 SpringBoot的自动配置解析 Spring Boot是Spring的一个子项目,用于快速开发应用程序。它主要是简化新Spring应用的初始建立以及开发过程。其中,自动…

传统工科硕士想转嵌入式,时间够吗?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 零基础开始学&#xff0…

帕友饮食改善的小建议!

一、增加膳食纤维的摄入 帕金森病患者应增加膳食纤维的摄入量,以帮助调节肠道功能,预防便秘。膳食纤维丰富的食物包括蔬菜、水果、全谷类食物等。患者应确保每天摄入足够的膳食纤维,以保持肠道通畅,缓解帕金森病可能带来的消化不…

Huawei 大型 WLAN 组网 AC 间漫游

AC1配置命令 <AC6005>display current-configuration # vlan batch 100 # interface Vlanif100description to_S3_CAPWAPip address 10.0.100.254 255.255.255.0 # interface GigabitEthernet0/0/1port link-type trunkport trunk allow-pass vlan 100# ip route-stati…

Python 机器学习 基础 之 【实战案例】轮船人员获救预测实战

Python 机器学习 基础 之 【实战案例】轮船人员获救预测实战 目录 Python 机器学习 基础 之 【实战案例】轮船人员获救预测实战 一、简单介绍 二、轮船人员获救预测实战 三、数据处理 1、导入数据 2、对缺失数据的列进行填充 3、属性转换&#xff0c;把某些列的字符串值…

基于统一二维电子气密度表达式的通用MIS-HEMT紧凑模型

来源&#xff1a;A Compact Model for Generic MIS-HEMTs Based on the Unified 2DEG Density Expression&#xff08;TED 14年&#xff09; 摘要 本文提出了一种针对二维电子气&#xff08;ns&#xff09;密度和费米能级&#xff08;E_f&#xff09;的解析表达式&#xff0c…

c++使用_beginthreadex创建线程

记录使用_beginthreadex()&#xff0c;来创建线程。方便后期的使用。 创建一个线程 相关函数介绍 unsigned long _beginthreadex( void *security, // 安全属性&#xff0c; 为NULL时表示默认安全性 unsigned stack_size, // 线程的堆栈大小&#xff0c; 一般默认为0 u…

大型语言模型智能体(LLM Agent)在实际使用的五大问题

在这篇文章中&#xff0c;我将讨论人们在将代理系统投入生产过程中经常遇到的五个主要问题。我将尽量保持框架中立&#xff0c;尽管某些问题在特定框架中更加常见。 1. 可靠性问题 可靠性是所有代理系统面临的最大问题。很多公司对代理系统的复杂任务持谨慎态度&#xff0c;因…

SMS-GSM

SMS-GSM 短信模块&#xff0c;不想通过第三方的接口&#xff0c;自己搭建短信模块&#xff0c;提高信息安全。 /**/ package sms;import com.diagcn.smslib.CMessage; import com.diagcn.smslib.COutgoingMessage; import com.diagcn.smslib.SZHTOCService;/*** 短信模块** au…

用于认知负荷评估的集成时空深度聚类(ISTDC)

Integrated Spatio-Temporal Deep Clustering (ISTDC) for cognitive workload assessment 摘要&#xff1a; 本文提出了一种新型的集成时空深度聚类&#xff08;ISTDC&#xff09;模型&#xff0c;用于评估认知负荷。该模型首先利用深度表示学习&#xff08;DRL&#xff09;…

css3 都有哪些新属性

1. css3 都有哪些新属性 1.1. 圆角边框 (border-radius)1.2. 盒子阴影 (box-shadow)1.3. 文本阴影 (text-shadow)1.4. 响应式设计相关属性1.5. 渐变背景 (gradient backgrounds)1.6. 透明度 (opacity 和 rgba/hsla)1.7. 多列布局 (column-count, column-gap, etc.)1.8. 变换 (t…

设置电脑定时关机

1.使用快捷键winR 打开运行界面 2.输入cmd &#xff0c;点击确认&#xff0c;打开命令行窗口&#xff0c;输入 shutdown -s -t 100&#xff0c;回车执行命令&#xff0c;自动关机设置成功 shutdown: 这是主命令&#xff0c;用于执行关闭或重启操作。-s: 这个参数用于指定执行关…

超详解——识别None——小白篇

目录 1. 内建类型的布尔值 2. 对象身份的比较 3. 对象类型比较 4. 类型工厂函数 5. Python不支持的类型 总结&#xff1a; 1. 内建类型的布尔值 在Python中&#xff0c;布尔值的计算遵循如下规则&#xff1a; None、False、空序列&#xff08;如空列表 []&#xff0c;空…

【启明智显分享】基于工业级芯片Model3A的7寸彩色触摸屏应用于智慧电子桌牌方案

一场大型会议的布置&#xff0c;往往少不了制作安放参会人物的桌牌。制作、打印、裁剪&#xff0c;若有临时参与人员变更&#xff0c;会务方免不了手忙脚乱更新桌牌。由此&#xff0c;智能电子桌牌应运而生&#xff0c;工作人员通过系统操作更新桌牌信息&#xff0c;解决了传统…

第一个小爬虫_爬取 股票数据

前言 爬取 雪球网的股票数据 [环境使用]&#xff1a;python 3.12 解释器pycharm 编辑器 【模块使用】&#xff1a;import requests -->数据请求模块 要安装 命令 pip install requestsimport csv -->将数据保存到CSV表格中import pandas -->也可以将数据保…

react的自定义组件

// 自定义组件(首字母必须大写) function Button() {return <button>click me</button>; } const Button1()>{return <button>click me1</button>; }// 使用组件 function App() {return (<div className"App">{/* // 自闭和引用自…

【全部更新完毕】2024全国大学生数据统计与分析竞赛B题思路代码文章教学数学建模-电信银行卡诈骗的数据分析

电信银行卡诈骗的数据分析 摘要 电信银行卡诈骗是当前社会中严重的犯罪问题&#xff0c;分析电信银行卡交易数据&#xff0c;找出高风险交易特征&#xff0c;建立预测模型&#xff0c;将有助于公安部门和金融机构更好地防范诈骗行为&#xff0c;保障用户的财产安全。 针对问…

Golang | Leetcode Golang题解之第131题分割回文串

题目&#xff1a; 题解&#xff1a; func partition(s string) (ans [][]string) {n : len(s)f : make([][]int8, n)for i : range f {f[i] make([]int8, n)}// 0 表示尚未搜索&#xff0c;1 表示是回文串&#xff0c;-1 表示不是回文串var isPalindrome func(i, j int) int8…

【Python】常见的第三方库及实例

各位大佬好 &#xff0c;这里是阿川的博客 &#xff0c; 祝您变得更强 个人主页&#xff1a;在线OJ的阿川 大佬的支持和鼓励&#xff0c;将是我成长路上最大的动力 阿川水平有限&#xff0c;如有错误&#xff0c;欢迎大佬指正 库介绍 Python是通过模块来体现库&#xff0…