T10:数据增强

news2024/11/27 16:47:05

T10周:数据增强

      • **一、前期工作**
        • 1.设置GPU,导入库
        • 2.加载数据
      • **二、数据增强**
      • **三、增强方式**
        • 方法一:将其嵌入model中
        • 方法二:在Dataset数据集中进行数据增强
      • **四、训练模型**
      • **五、自定义增强函数**
      • **六、总结**

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

🍺 要求:

  1. 学会在代码中使用数据增强手段来提高acc
  2. 请探索更多的数据增强手段并记录

在本教程中,你将学会如何进行数据增强,并通过数据增强用少量数据达到非常非常棒的识别准确率。我将展示两种数据增强方式,以及如何自定义数据增强方式并将其放到我们代码当中,两种数据增强方式如下:

  • 将数据增强模块嵌入model中
  • 在Dataset数据集中进行数据增强

⌛ 我的环境:

  • 语言环境:Python3.6.5
  • 编译器:Google Colab
  • 深度学习环境:TensorFlow2.17.0

一、前期工作

1.设置GPU,导入库
import matplotlib.pyplot as plt
import numpy as np
#隐藏警告
import warnings
warnings.filterwarnings('ignore')

from tensorflow.keras import layers
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

# 打印显卡信息,确认GPU可用
print(gpus)
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
2.加载数据
from google.colab import drive
drive.mount("/content/drive/")
%cd "/content/drive/Othercomputers/My laptop/jupyter notebook/data/"
Mounted at /content/drive/
/content/drive/Othercomputers/My laptop/jupyter notebook/data
data_dir   = "./10/"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 600 files belonging to 2 classes.
Using 420 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 600 files belonging to 2 classes.
Using 180 files for validation.

由于原始数据集不包含测试集,因此需要创建一个。使用 tf.data.experimental.cardinality 确定验证集中有多少批次的数据,然后将其中的 20% 移至测试集。

val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))
Number of validation batches: 5
Number of test batches: 1
class_names = train_ds.class_names
print(class_names)
['cat', 'dog']
AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image,label):
    return (image/255.0,label)

# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])

        plt.axis("off")

在这里插入图片描述

二、数据增强

我们可以使用 tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强

  • tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip("horizontal_and_vertical"), #进行随机水平和垂直翻转
  tf.keras.layers.RandomRotation(0.2), #按照0.2的弧度值进行随机旋转
  tf.keras.layers.RandomBrightness(factor=0.5, value_range=(0.0, 1.0), seed=123), #按照0.1比例调整亮度
  tf.keras.layers.RandomContrast(factor=0.1)
])

# 第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转。
#其余增强方式:https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomRotation
# Add the image to a batch.
image = tf.expand_dims(images[i], 0)

plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

在这里插入图片描述

三、增强方式

方法一:将其嵌入model中
model_base = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
])

这样做的好处是:

  • 数据增强这块的工作可以得到GPU的加速(如果你使用了GPU训练的话)

    注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。
方法二:在Dataset数据集中进行数据增强
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE
#定义处理函数,只有training数据集会处理
def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds
train_ds = prepare(train_ds)

四、训练模型

model = tf.keras.Sequential([
  model_base,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])
epochs=20
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 116ms/step - accuracy: 0.5252 - loss: 0.6915 - val_accuracy: 0.4257 - val_loss: 0.7044
Epoch 2/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.5461 - loss: 0.6897 - val_accuracy: 0.5068 - val_loss: 0.6871
Epoch 3/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - accuracy: 0.6517 - loss: 0.6707 - val_accuracy: 0.7500 - val_loss: 0.6056
Epoch 4/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 45ms/step - accuracy: 0.6466 - loss: 0.6242 - val_accuracy: 0.7568 - val_loss: 0.5237
Epoch 5/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 41ms/step - accuracy: 0.6958 - loss: 0.5998 - val_accuracy: 0.7365 - val_loss: 0.6128
Epoch 6/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 37ms/step - accuracy: 0.7187 - loss: 0.6327 - val_accuracy: 0.7230 - val_loss: 0.5495
Epoch 7/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 37ms/step - accuracy: 0.7182 - loss: 0.5746 - val_accuracy: 0.7770 - val_loss: 0.5527
Epoch 8/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.7288 - loss: 0.5921 - val_accuracy: 0.8176 - val_loss: 0.4373
Epoch 9/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7544 - loss: 0.4827 - val_accuracy: 0.8243 - val_loss: 0.3807
Epoch 10/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - accuracy: 0.7535 - loss: 0.5227 - val_accuracy: 0.8446 - val_loss: 0.4185
Epoch 11/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7958 - loss: 0.4514 - val_accuracy: 0.8649 - val_loss: 0.3466
Epoch 12/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8288 - loss: 0.3707 - val_accuracy: 0.8581 - val_loss: 0.3243
Epoch 13/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8466 - loss: 0.3749 - val_accuracy: 0.7973 - val_loss: 0.4575
Epoch 14/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - accuracy: 0.8277 - loss: 0.4484 - val_accuracy: 0.8784 - val_loss: 0.3096
Epoch 15/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8629 - loss: 0.2800 - val_accuracy: 0.8716 - val_loss: 0.2454
Epoch 16/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8496 - loss: 0.3098 - val_accuracy: 0.9054 - val_loss: 0.1984
Epoch 17/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - accuracy: 0.8630 - loss: 0.2889 - val_accuracy: 0.9257 - val_loss: 0.2030
Epoch 18/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8820 - loss: 0.2658 - val_accuracy: 0.9257 - val_loss: 0.2008
Epoch 19/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8940 - loss: 0.2895 - val_accuracy: 0.9324 - val_loss: 0.1662
Epoch 20/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 31ms/step - accuracy: 0.8923 - loss: 0.2681 - val_accuracy: 0.8919 - val_loss: 0.2506
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 409ms/step - accuracy: 0.8438 - loss: 0.2309
Accuracy 0.84375

五、自定义增强函数

import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):
    seed = (random.randint(0,9), 0)
    # 随机改变图像对比度
    stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)
    return stateless_random_brightness
image = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
Min and max pixel values: 14.000048 253.28577
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

在这里插入图片描述

参考上文的 preprocess_image 函数,将 aug_img 函数嵌入到 preprocess_image 函数中,在数据预处理时完成数据增强就OK啦。

六、总结

学习了不同的数据增强方法并且了解如何用增强方法预处理数据和引入模型训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187055.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

idea使用ant源码运行tomcat8.5

1 安装ant 下载ant 下载地址 使用apache-ant-1.10.15版本 将压缩包放到/Library/Java/ant解压 设置ant环境变量 打开finder到用户根目录 按下shiftcmd.显示隐藏文件 打开隐藏文件.zprofile 按照以下格式设置环境变量 #ant export ANT_HOME/Library/Java/ant/apache-a…

数据丢失怎么办?2024四款恢复工具帮你忙!

数据是我们日常生活和工作中不可或缺的一部分,然而,由于误操作、硬件故障或其他原因导致的数据丢失却是常有的事情。这时候,一款可靠的数据恢复工具就显得尤为重要。 福昕数据恢复 直达链接:www.pdf365.cn/foxit-restore/ 假设…

Linux命令:用于显示 Linux 发行版信息的命令行工具lsb_release详解

目录 一、概述 二、用法 1、基本用法 2、选项 3、获取帮助 三、示例 1. 显示所有信息 2. 只显示发行版名称 3. 只显示发行版版本号 4. 只显示发行版代号 5. 只显示发行版描述 6. 只显示值,不显示标签 四、使用场景 1、自动化脚本 2、诊断问题 3、环…

【KVM】虚拟化技术及实战

一,KVM简介 KVM全称为QEMU-KVM。 KVM可以模拟内存,cpu的虚拟化,不能模拟其他设备虚拟化。 QEMU可以模拟I/O设备(网卡,磁盘等) 两者结合,实现真正意义上的虚拟化。 从rhel6版本开始&#xff0c…

Elasticsearch——数据聚合、数据同步与集群搭建

目录 1.数据聚合1.1.聚合的种类1.2.DSL实现聚合1.2.1.Bucket 聚合语法1.2.2.聚合结果排序1.2.3.限定聚合范围1.2.4.Metric 聚合语法1.2.5.小结 1.3.RestAPI 实现聚合1.3.1.API 语法1.3.2.业务需求1.3.3.业务实现 2.自动补全2.1.拼音分词器2.2.自定义分词器2.3.自动补全查询2.4.…

YOLOv8 结合设计硬件感知神经网络设计的高效 Repvgg的ConvNet 网络结构 ,改进EfficientRep结构

一、理论部分 摘要—我们提出了一种硬件高效的卷积神经网络架构,它具有类似 repvgg 的架构。Flops 或参数是评估网络效率的传统指标,这些网络对硬件(包括计算能力和内存带宽)不敏感。因此,如何设计神经网络以有效利用硬件的计算能力和内存带宽是一个关键问题。本文提出了一…

Chromium 使用安全 DNS功能源码分析c++

一、选项页安全dns选项如下图: 二、那么如何自定义安全dns功能呢? 1、先看前端部分代码调用 shared.rollup.jsclass PrivacyPageBrowserProxyImpl {.................................................................getSecureDnsResolverList() {re…

25货拉拉校园招聘面试经验 面试最常见问题总结

货拉拉校园招聘面试经验 目录 【面试经历】 问题+详细答案 面试全流程 【面试经历】 发面经,攒人品。 项目问题: 1.AOP日志落库到数据库,为什么不用一些现成的方案? 2.邀请链接的id怎么用redis生成的? 3.乐观锁保证了奖励的正确发放,请你说说乐观锁的原理。 4.奖…

TCN模型实现电力数据预测

关于深度实战社区 我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万粉丝,拥有2篇国家级人工智能发明专利。 社区特色&a…

AI少女/HS2甜心选择2 仿天刀人物卡全合集打包

内含AI少女/甜心选择2 仿天刀角色卡全合集打包共21张 下载地址:https://www.51888w.com/408.html 部分演示图:

python全栈学习记录(二十一)类的继承、派生、组合

类的继承、派生、组合 文章目录 类的继承、派生、组合一、类的继承二、派生三、组合 一、类的继承 继承是一种新建类的方式,新建的类称为子类,被继承的类称为父类。 继承的特性是:子类会遗传父类的属性(继承是类与类之间的关系&a…

基于STM32的数字温度传感器设计与实现

引言 STM32 是由意法半导体(STMicroelectronics)开发的基于 ARM Cortex-M 内核的微控制器系列,以其强大的处理能力、丰富的外设接口和低功耗著称,广泛应用于嵌入式系统设计中。在这篇文章中,我们将介绍如何基于 STM32…

C++《string》

在之前的C语言学习当中我们已经了解了一系列的字符以及字符串函数,虽然这些函数也能实现对字符串进行求长度、拷贝、追加等操作,但是C语言当中的这些函数是与字符串分离的,并且最主要的是在使用这些函数时原字符串的底层空间是需要我们自己来…

微知-Intel芯片中的QPI是什么?本质是什么?以及其他几个高速总线的速率问题(快速通道互联,CPU之间互联总线)

基础信息 CPU与CPU之间通过QPI总线进行通信,类似CPU与PCI-E设备通过PCIE总线进行通信。 The Intel QuickPath Interconnect (QPI):快速通道互联,快路径内部互联总线。是Inter-connect,内部互联的。是英特尔开发的一种高速点对点…

SpringBoot精华:打造高效美容院管理系统

第一章 绪论 1.1 选题背景 如今的信息时代,对信息的共享性,信息的流通性有着较高要求,尽管身边每时每刻都在产生大量信息,这些信息也都会在短时间内得到处理,并迅速传播。因为很多时候,管理层决策需要大量信…

BiLSTM模型实现电力数据预测

基础模型见:A020-LSTM模型实现电力数据预测 1. 引言 时间序列预测在电力系统管理、负荷预测和能源优化等领域具有重要意义。传统的单向长短期记忆网络(LSTM)因其在处理时间序列数据中的优势,广泛应用于此类任务。然而&#xff0…

用友NC service接口信息泄露漏洞

漏洞描述 用友NC service接口信息泄露漏洞,攻击者可通过构造恶意链接获取所有接口链接 公网上大部分服务器都没有修复此漏洞,可刷SRC 用友nc有个接口可以获取数据库账户密码,不过是老版本了 漏洞复现 app"用友-UFIDA-NC" POC …

哪家宠物空气净化器可以高效去除浮毛?希喂、IAM、有哈怎么样

在现代养宠家庭中,随着生活节奏的加快,清理浮毛也是很多家庭周末必须要做的事情。但是如何选择一款吸毛好、还不增加清理负担的宠物空气净化器,在寸土寸金的租房里为全家老小的健康生活保障?又如何通过强大的吸毛、除臭技术和除菌…

【学习笔记】手写一个简单的 Spring IOC

目录 一、什么是 Spring IOC? 二、IOC 的作用 1. IOC 怎么知道要创建哪些对象呢? 2. 创建出来的对象放在哪儿? 3. 创建出来的对象如果有属性,如何给属性赋值? 三、实现步骤 1. 创建自定义注解 2. 创建 IOC 容器…

IO模型介绍

一、理解IO 网络通信的本质就是进程间通信,进程间通信本质就是IO TCP中的IO接口:read / write / send / recv,本质都是:等 拷贝 所以IO的本质就是:等 拷贝 那么如何高效的IO? 减少“等”在单位时间的…