第T10周:数据增强

news2024/9/24 7:20:52

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

在本教程中,你将学会如何进行数据增强,并通过数据增强用少量数据达到非常非常棒的识别准确率。

我将展示两种数据增强方式,以及如何自定义数据增强方式并将其放到我们代码当中,两种数据增强方式如下:

  • 将数据增强模块嵌入model中
  • 在Dataset数据集中进行数据增强

 

一、前期准备工作

1. 设置GPU

import matplotlib.pyplot as plt
import numpy as np
#隐藏警告
import warnings
warnings.filterwarnings('ignore')

from tensorflow.keras import layers
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

# 打印显卡信息,确认GPU可用
print(gpus)

2. 加载数据 

关于 tf.keras.preprocessing.image_dataset_from_directory 的介绍,我这里就不赘述了,不明白的同学直接看这里:tf.keras.preprocessing.image_dataset_from_directory() 简介_tf.python.keras preprocessing在哪里-CSDN博客文章浏览阅读1.1w次,点赞14次,收藏69次。函数原型tf.keras.preprocessing.image_dataset_from_directory( directory, labels="inferred", label_mode="int", class_names=None, color_mode="rgb", batch_size=32, image_size=(256, 256), shuffle=True, seed=None, validation__tf.python.keras preprocessing在哪里https://blog.csdn.net/qq_38251616/article/details/117018789

data_dir   = "./34-data/"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

由于原始数据集不包含测试集,因此需要创建一个。使用 tf.data.experimental.cardinality 确定验证集中有多少批次的数据,然后将其中的 20% 移至测试集。 

val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))

 一共有猫、狗两类

class_names = train_ds.class_names
print(class_names)
AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image,label):
    return (image/255.0,label)

# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(5, 8, i + 1) 
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

二、数据增强

我们可以使用 tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强

  • tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

 第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转。

 data_augmentation的定义,这是一个数据增强层的序列模型

# Add the image to a batch.
image = tf.expand_dims(images[i], 0)
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

 

三、增强方式

方法一:将其嵌入model中

model = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
])
这样做的好处是:
● 数据增强这块的工作可以得到GPU的加速(如果你使用了GPU训练的话)
注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。

 方法二:在Dataset数据集中进行数据增强

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds

 

 四、训练模型

model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])
在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:
● 损失函数(loss):用于衡量模型在训练期间的准确率。
● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
● 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

 开始训练

epochs=20
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

 

准确率

loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

 

五、自定义增强函数 

import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):
    seed = (random.randint(0,9), 0)
    # 随机改变图像对比度
    stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)
    return stateless_random_brightness

 

image = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

 那么如何将自定义增强函数应用到我们数据上呢?请参考上文的 preprocess_image 函数,将 aug_img 函数嵌入到 preprocess_image 函数中(函数在加载数据部分),在数据预处理时完成数据增强就OK啦

总结

数据增强有着关键的作用,本文讲述了两种方式,三种方法,方式有嵌入到模型中进行数据增强,好处是能获得GPU加速,但是只能在训练阶段增强,第二种方式就是可以单独拿出一个数据增强模块,在数据集中进行增强,设置一个序列模型sequential,里面存有各种数据增强方法。方法有随机翻转,垂直的水平的,还有固定角度翻转,和随机改变图像对比度。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2110613.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【NLP自然语言处理】文本处理的基本方法

目录 🍔什么是分词 🍔中文分词工具jieba 2.1 jieba的基本特点 2.2 jieba的功能 2.3 jieba的安装及使用 🍔什么是命名实体识别 🍔什么是词性标注 🍔小结 学习目标 🍀 了解什么是分词, 词性标注, 命名…

Java笔试面试题AI答之JDBC(3)

文章目录 13. 编写JDBC连Oracle的程序?14. 简述JDBC的主要组件有哪些 ?15. JDBC中如何防止SQL注入攻击?1. 使用预处理语句(PreparedStatement)2. 避免在SQL查询中直接拼接用户输入的数据总结 16. JDBC的脏读是什么?哪…

Windows下Python和PyCharm的应用(一)__第一个测试程序

1、下载Python安装包 直接从官网下载,百度里搜出来的Python下载,很多是别的公司的商业广告,千万要注意,不要乱点进去,免得浪费时间。 从官网下载,链接:Download Python | Python.org 2、安装Pyt…

网络编程day03(网络体系结构、调试命令、TCP/IP对比)

目录 1》网络的体系结构 1> OSI模型 2> TCP/IP模型 3> 常见网络协议 4> DNS域名解析协议 2》 网络调试命令 1> ping:测试网络连通性(ICMP) 2> netstat 3》Dos (拒绝式服务)攻击?…

怎么在mathtype中打空格 MathType空格键不能用

MathType是一款数学公式编辑器,可以帮助用户创建复杂的数学公式和方程式。它提供了一个用户友好的界面,使得编辑和排版数学公式变得更加容易和高效。用户可以直接在其界面中输入公式,也可以将已有的公式从其他文档中复制粘贴过来进行编辑。在…

【2024数模国赛赛题思路公开】国赛B题第二套思路丨附可运行代码丨无偿自提

2024年数模国赛B题解题思路 B 题 生产过程中的决策问题 一、问题1解析 问题1的任务是为企业设计一个合理的抽样检测方案,基于少量样本推断整批零配件的次品率,帮助企业决定是否接收供应商提供的这批零配件。具体来说,企业需要依据两个不同…

秋燥拜拜,中秋润起来,酒茶香中秋有“礼”

话说这初秋啊,真是个让人又爱又恨的季节! 爱它的秋高气爽,恨它的天干物燥。就像是我们刚刚结束了一个炎热的夏天,身体还没来得及适应,就被秋天的干燥给来了个“突然袭击”。鼻子干、嘴唇干、喉咙干,感觉整个…

hcip什么时候考试?一文带您了解hcip考试报名与预约流程

其他考试一般都会有固定的时间,但hcip不一样,它的考试时间并不固定,这就让考生很是疑惑:hcip什么时候考试呢?除了知道考试时间之外,还要了解hcip的报名条件、报名流程等相关内容。关于这些问题的答案,小编…

blender图像如何分层导出?blender动画云渲染

在blender渲染时产品会被其他物体影响,这时候就可以用到blender中的阻隔;分层导出图像到PS中进行校色等后期处理。 在分层前,我们需要先打开渲染属性-胶片-透明,这样导出的图像才是透明背景的,反之会变成黑色底。 第一…

使用GPU加速及配置

配置CUDA 英伟达 https://developer.nvidia.com/cuda-downloadsPython python要求3.8.x版本以上 python下载 https://www.python.org/getit/使用pytorch 查询地址: https://pytorch.org/index.html给出建议: 可以直接 pip3 install torch torchv…

如何实现思维导图简单漂亮?其实并不难

如何实现思维导图简单漂亮?思维导图是一种非常有效的思考和组织工具,它通过图形化的方式帮助我们梳理信息、激发创意。一个简单又漂亮的思维导图不仅能提高工作效率,还能让人赏心悦目。为了帮助你在学习和工作中更加得心应手,下面…

win10任务栏颜色怎么调?分享几个简单操作,附上详细图文教程

win10任务栏颜色怎么调?相信在大家的日常生活中,电脑是不可或缺的一部分。我们平时上班都需要使用到电脑,最近有个小伙伴问:Win10系统的任务栏颜色可以调整吗?答案当然是可以的。任务栏颜色调整可以帮助我们提高桌面美…

828华为云征文|部署电影收藏管理器 Radarr

828华为云征文|部署电影收藏管理器 Radarr 一、Flexus云服务器X实例介绍1.1 云服务器介绍1.2 应用场景1.3 性能模式 二、Flexus云服务器X实例配置2.1 重置密码2.2 服务器连接2.3 安全组配置 三、部署 Radarr3.1 Radarr 介绍3.2 Docker 环境搭建3.3 Radarr 部署3.4 R…

C++11: 智能指针(unique_ptr,shared_ptr和weak_ptr的使用及简单实现)

目录 1. 为何需要智能指针? 1.1 抛异常场景 1.2 什么是内存泄漏 2. 智能指针的原理 2.1 RAII技术 2.2 补充实现 3. auto_ptr 4. unique_ptr 4.1 使用及原理 4.2 定制删除器 5. shared_ptr 5.1 shared_ptr简介及使用 5.2 shared_ptr简单实现 5.2.1 基本…

CentOS7 部署 Zabbix 监控平台———监控网络设备,Linux 主机、Windows 主机

Node 有自己的配置文件和数据库,其要做的是将配置信息和监控数据向 Master 同步。 当 Master 发生故障或损坏, Node 可以保证架构的完整性。 3)Server-Prxoy-Client 架构 Proxy 是 Server、Client 之间沟通的桥梁,Proxy 本身没…

Cortex-M3架构学习:存储器系统

存储系统功能 CM3 的存储器系统与从传统 ARM 架构的相比,进行如下改革: 它的存储器映射是预定义的,并且还规定好了哪个位置使用哪条总线。 CM3 的存储器系统支持所谓的“位带”(bit-band)操作。通过它,实…

超详细!!!electron-vite-vue开发桌面应用之创建新窗口以及主进程和子进程的通信监听(十二)

云风网 云风笔记 云风知识库 一、新建打开窗口 1、在electron/main.ts中加入主进程打开窗口逻辑代码 import { ipcMain } from "electron"; ipcMain.handle("open-win", (_, arg) > {const childWindow new BrowserWindow({webPreferences: {preloa…

代码执行漏洞-Log4j2漏洞

1.执行以下命令启动靶场环境并在浏览器访问 cd log4j/CVE-2021-44228docker-compose up -ddocker ps 2.先在自己搭建的DNSLOG平台上获取⼀个域名来监控我们注⼊的效果 3.可以发现 /solr/admin/cores?action 这⾥有个参数可以传,可以按照上⾯的原理 先构造⼀个请求…

TomCat环境配置(实验报告)

实验 Tomcat 实验环境配置 一、实验目的 1、掌握Tomcat的安装和启动 2、掌握在IntelliJ IDEA中配置Tomcat的方法 二、实验环境 1、硬件:PC电脑一台,网络正常; 2、配置:Windows10系统,内存8G及以上,硬盘…

LAN变压器的DCR

在变压器技术中,DCR代表直流电阻(DC Resistance)。它是变压器线圈在直流条件下测得的电阻值,通常用来评估变压器的质量和效率。直流电阻是线圈材料和尺寸的一个函数,它与变压器线圈的发热量和功率损耗直接相关。在变压…