深度学习训练营之数据增强

news2025/1/13 15:38:06

深度学习训练营

  • 学习内容
  • 原文链接
  • 环境介绍
  • 前置工作
    • 设置GPU
    • 加载数据
    • 创建测试集
    • 数据类型查看以及数据归一化
  • 数据增强操作
    • 使用嵌入model的方法进行数据增强
  • 模型训练
  • 结果可视化
    • 自定义数据增强
  • 查看数据增强后的图片

学习内容

在深度学习当中,由于准备数据集本身是一件十分复杂的过程,很难保障每一张图片的学习能力都很高,所以对于同一种图片采用数据增强就显得十分重要了

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P10周:实现数据增强
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.13
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2
  • 数据链接:猫和狗数据

前置工作

设置GPU

import matplotlib.pyplot as plt
import numpy as np
#隐藏警告
import warnings
warnings.filterwarnings('ignore')

from tensorflow.keras import layers
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

# 打印显卡信息,确认GPU可用
print(gpus)

加载数据

将对应的数据按照不同种类放入到不同文件夹当中,再将数据整合为animal_data

data_dir   = "animal_data"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3400 files belonging to 2 classes.
Using 2380 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3400 files belonging to 2 classes.
Using 2380 files for training.

创建测试集

因为数据本身没有设置测试集,这里需要进行手动创建

val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))

运行结构如下

Number of validation batches: 60
Number of test batches: 15

预测的batches和测试batches分别为60和15

数据类型查看以及数据归一化

class_names = train_ds.class_names
print(class_names)
['cat', 'dog']

进行数据归一化操作

AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image,label):
    return (image/255.0,label)

# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

查看数据集

plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(5, 8, i + 1) 
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

数据增强操作

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
#进行随机的水平翻转和垂直翻转
# Add the image to a batch.
image = tf.expand_dims(images[i], 0)
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

请添加图片描述

使用嵌入model的方法进行数据增强

model = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
])
  • 这样的操作可以得到GPU的加速

模型训练

模型开始训练之前都需要进行这个模型的调整

model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

开始进行正式的训练

epochs=20
history = model.fit(
 train_ds,
 validation_data=val_ds,
 epochs=epochs
)
75/75 [==============================] - 35s 462ms/step - loss: 2.0268e-05 - accuracy: 1.0000 - val_loss: 1.8425e-05 - val_accuracy: 1.0000
Epoch 2/20
75/75 [==============================] - 34s 461ms/step - loss: 1.7937e-05 - accuracy: 1.0000 - val_loss: 1.6272e-05 - val_accuracy: 1.0000
Epoch 3/20
75/75 [==============================] - 35s 461ms/step - loss: 1.5871e-05 - accuracy: 1.0000 - val_loss: 1.4373e-05 - val_accuracy: 1.0000
Epoch 4/20
75/75 [==============================] - 34s 450ms/step - loss: 1.4039e-05 - accuracy: 1.0000 - val_loss: 1.2682e-05 - val_accuracy: 1.0000
Epoch 5/20
75/75 [==============================] - 34s 450ms/step - loss: 1.2429e-05 - accuracy: 1.0000 - val_loss: 1.1195e-05 - val_accuracy: 1.0000
Epoch 6/20
75/75 [==============================] - 35s 462ms/step - loss: 1.1014e-05 - accuracy: 1.0000 - val_loss: 9.8961e-06 - val_accuracy: 1.0000
Epoch 7/20
75/75 [==============================] - 34s 450ms/step - loss: 9.7220e-06 - accuracy: 1.0000 - val_loss: 8.6961e-06 - val_accuracy: 1.0000
Epoch 8/20
75/75 [==============================] - 34s 455ms/step - loss: 8.5416e-06 - accuracy: 1.0000 - val_loss: 7.6252e-06 - val_accuracy: 1.0000
Epoch 9/20
75/75 [==============================] - 34s 459ms/step - loss: 7.5130e-06 - accuracy: 1.0000 - val_loss: 6.7169e-06 - val_accuracy: 1.0000
Epoch 10/20
75/75 [==============================] - 34s 460ms/step - loss: 6.6338e-06 - accuracy: 1.0000 - val_loss: 5.9490e-06 - val_accuracy: 1.0000
Epoch 11/20
75/75 [==============================] - 34s 457ms/step - loss: 5.8835e-06 - accuracy: 1.0000 - val_loss: 5.2946e-06 - val_accuracy: 1.0000
Epoch 12/20
75/75 [==============================] - 34s 456ms/step - loss: 5.2507e-06 - accuracy: 1.0000 - val_loss: 4.7294e-06 - val_accuracy: 1.0000
Epoch 13/20
...
Epoch 19/20
75/75 [==============================] - 34s 449ms/step - loss: 2.5978e-06 - accuracy: 1.0000 - val_loss: 2.3737e-06 - val_accuracy: 1.0000
Epoch 20/20
75/75 [==============================] - 34s 449ms/step - loss: 2.3849e-06 - accuracy: 1.0000 - val_loss: 2.1841e-06 - val_accuracy: 1.0000

这里比较奇怪的是训练的结果准确性很高,loss的值都是很小很小的,和原本博主的相应的内容是不一样的,我觉得很大的可能应该是首先这个数据的内容很大,原本只有几百张图片,但是这里一共有3400张图片,再加上模型训练的增强方式比较简单,导致在结果上面训练看起来很好

loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
15/15 [==============================] - 1s 83ms/step - loss: 1.9960e-06 - accuracy: 1.0000
Accuracy 1.0

结果可视化

自定义数据增强

这里主要是可以更改随机数种子的大小

import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):
    seed = (random.randint(5,10), 0)
    #设立随机数种植,randint是指在0到9之间进行一个数据的增强
    # 随机改变图像对比度
    stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)
    return stateless_random_brightness

查看数据增强后的图片

image = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
Min and max pixel values: 2.4591687 241.47968
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/402607.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python 中 KeyError: 0 exception 错误

Python “KeyError: 0” 异常是在我们尝试访问不包含0 这个键的时候去访问该键而引起的。 要解决该错误,请在尝试访问字典之前在字典中设置键,或者如果键不存在,则使用 dict.get() 获取默认值。 下面是一个产生上述错误的示例 my_dict {1…

KDZD互感器二次负载测试仪

一、概述 电能计量综合误差过大是电能计量中普遍存在的一个关键问题。电压互感器二次回路压降引起的计量误差往往是影响电能计量综合误差的因素。所谓电压互感器二次压降引起的误差,就是指电压互感器二次端子和负载端子之间电压的幅值差相对于二次实际电压的百分数…

五分钟了解JumpServer V2.* 与 v3 的区别

一、升级注意项 1、梳理数据。JumpServer V3 去除了系统用户功能,将资产与资产直接绑定。当一个资产名下有多个同名账号,例如两个root用户时,升级后会自动合并最后一个root,不会同步其他root用户。升级前需保证每一个资产只拥有一…

即时通讯系列-N-客户端如何在推拉结合的模式下保证消息的可靠性展示

结论先行 原则: server拉取的消息一定是连续的原则: 端侧记录的消息的连续段有两个作用: 1. 记录消息的连续性, 即起始中间没有断层, 2. 消息连续, 同时意味着消息是最新的,消息不是过期的。同…

Java学习-MySQL-创建数据库表

Java学习-MySQL-创建数据库表 SHOW DATABASESUSE school CREATE TABLE IF NOT EXISTS student( id INT(10) NOT NULL AUTO_INCREMENT COMMENT 学号, name VARCHAR(30) NOT NULL DEFAULT 匿名 COMMENT 姓名, pws VARCHAR(20) NOT NULL DEFAULT 123456 COMMENT 密码, sex VARCHA…

算法题--二叉树(判断是不是平衡二叉树、二叉树的中序遍历、二叉树最大深度、对称二叉树、合并二叉树)

目录 二叉树 题目 判断是不是平衡二叉树 题链接 解析 核心思想 答案 二叉树的中序遍历 原题链接 解析 核心思想 答案 二叉树最大深度、对称二叉树、合并二叉树 二叉树 该类题目的解决一般是通过节点的遍历去实现,一般是分两种。 一是递归(…

【记录】日常|shandianchengzi的三周年创作纪念日

机缘 接触 CSDN 之前,我已经倒腾过 hexo 搭建 github 博客、本地博客、图床;   接触 CSDN 之后,我还倒腾过纸质笔记、gitee 博客、博客园、知乎、b站、Notion、腾讯文档、有道云笔记、XMind、飞书文档、简书等一系列创作平台,但…

SAPUI5开发01_01-Installing Eclipse

1.0 简要要求概述: 本节您将安装SAPUI 5,以及如何在Eclipse Juno中集成SAPUI 5工具。 1.1 安装JDK JDK 是一种用于构建在 Java 平台上发布的应用程序、Applet 和组件的开发环境,即编写 Java 程序必须使用 JDK,它提供了编译和运行 Java 程序的环境。 在安装 JDK 之前,首…

1635_fileno的简单使用

全部学习汇总: GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 在看MIT的OS课程的时候发现自己动不动就因为只是的缺少而卡住,而这个学习占据了我工作之余很多的时间。现在都有一点觉得通关不了的感觉了,…

1. Qt Designer Studio界面介绍

1. 说明: Qt当中的Qt Quick框架使用QML语言来快速搭建优美的界面,但是对于单纯做界面的设计人员并不是很友好,还要让界面设计人员去消耗时间成本学习QML语法。Qt Designer Studio软件就是为了解决这个问题而设计的,工作人员不需要…

【Blender】Stability AI插件 - AI生成图像和动画

Stability AI 的官方插件允许 Blender 艺术家使用现有的项目和文本描述来创建新的图像、纹理和动画。 推荐:用 NSDT场景设计器 快速搭建3D场景。 1、安装Stability for Blender插件 首先,从这里下载最新版本的 Blender,然后转到 Addon Relea…

论文阅读笔记|大规模多标签文本分类

多标签文本分类(Extreme Multi Label Classification, MLTC)是自然语言处理领域中一个十分重要的任务,其旨在从一个给定的标签集合中选取出与文本相关的若干个标签。MLTC可以广泛应用于网页标注,话题识别和情感分析等场景。大规模…

1636_isatty函数的功能

全部学习汇总: GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 前面刚刚看完了一个函数和三个文件指针,一行代码懂了半行。但是继续分析我之前看到的代码还是遇到了困难,因为之前自己对于UNIX的一些基础知…

网络协议(十四):WebSocket、WebService、RESTful、IPv6、网络爬虫、HTTP缓存

网络协议系列文章 网络协议(一):基本概念、计算机之间的连接方式 网络协议(二):MAC地址、IP地址、子网掩码、子网和超网 网络协议(三):路由器原理及数据包传输过程 网络协议(四):网络分类、ISP、上网方式、公网私网、NAT 网络…

Kubernetes(K8s)接入Prometheus示例、查看指标

Prometheus安装关联服务见:https://blog.csdn.net/lsc_2019/article/details/129445580?spm1001.2014.3001.5502 在Kubernetes中创建一个Deployment和一个Service apiVersion: apps/v1 kind: Deployment metadata:name: myapp spec:replicas: 3selector:matchLabe…

Jackson 返回前端的 Response结果字段大小问题

目录 1、问题产生的背景 2、出现的现象 3、解决方案 4、成果展现 5、总结 6、参考文章 1、问题产生的背景 因为本人最近工作相关的对接外部项目,在我们国内有很多程序员都是使用汉语拼音或者部分字母加上英文复合体定义返回实体VO,这样为了能够符合…

数据表(三) - 多语言的实现

前文介绍了关于数据表的几种形式,以及如何让数据表运用更加简单高效,这篇我们来讲讲多语言在数据表中的实现方式。游戏项目中文字显示本身就是件比较头疼的事,再加上多语言,更多的问题将待需解决。很多时候项目起初,文…

Golang-GMP模型

写在前面 Go 为了自身 goroutine 执行和调度的效率,自身在 runtime 中实现了一套 goroutine 的调度器,下面通过一段简单的代码展示一下 Go 应用程序在运行时的 goroutine,方便大家更好的理解。 The Go scheduler is part of the Go runtime,…

华为机试题:HJ92 在字符串中找出连续最长的数字串(python)

文章目录(1)题目描述(2)Python3实现(3)知识点详解1、input():获取控制台(任意形式)的输入。输出均为字符串类型。1.1、input() 与 list(input()) 的区别、及其相互转换方…

C++语法规则2(C++面向对象)

继承 面向对象程序设计中最重要的一个概念是继承。继承允许我们依据另一个类来定义一个类,这使得创建和维护一个应用程序变得更容易。这样做,也达到了重用代码功能和提高执行效率的效果。 当创建一个类时,您不需要重新编写新的数据成员和成…