Vitis AI 基本认知(Tiny-VGG 项目代码详解)

news2025/1/23 12:18:55

目录

1. 简介

1.1 Tiny-VGG

1.2 data 目录结构

2. 代码分析

2.1 Import packages

2.2 Dataset

2.3 Train step

2.4 Vali & Test step

2.5 Ceate model

2.6 Compile model

2.6.1 计算 loss

2.6.2 计算平均值

3.6.3 计算准确度

2.7 训练循环

2.7.1 自定义训练循环

2.7.2 使用 model.fit()

2.8 早停法

2.9 Save model

2.10 Test info

3. 通用功能

3.1 @tf.function

3.2 数据集分配

3.3 String formatting

3.4 数据集加载器

3.4.1 函数定义

3.4.2 加载单个目录

3.4.3 拆分训练/验证集

3.4.4 输入归一化

3.5 修改数据集

3.5.1 修正目录

3.5.2 移动图片文件

3.5.3 批量重命名

3.6 访问数据集

4. 总结


1. 简介

1.1 Tiny-VGG

本文分享 Tiny-VGG 项目的代码解析,对于冗余部分,进行了删减。主要内容包括:

  • 图像数据预处理
  • 手动计算:训练损失、训练准确度、验证损失、验证准确度
  • 自定义训练循环
  • 数据集加载器
  • 修正目录
  • 计算图

1.2 data 目录结构

├── class_10_train
│   ├── n07920052 # 咖啡
│   ├── n02509815 # 小熊猫
│   ├── n07873807 # 披萨
│   ├── n03662601 # 救生艇
│   ├── n04146614 # 校车
│   ├── n07747607 # 橙子
│   ├── n07720875 # 灯笼椒
│   ├── n02165456 # 瓢虫
│   ├── n01882714 # 考拉
│   └── n04285008 # 跑车
├── class_10_val
│   ├── test_images
│   └── val_images
├── class_dict_10.json
└── val_class_dict_10.json

class_10_train 目录下,有10个分类,每个分类包含500个图像,共计5000个图像。

test_images 目录下,测试集和验证集各有250个图像。

修改目录结构,以方便使用 tf.keras.preprocessing.image_dataset_from_directory 原生函数:

├── class_10_train
│   ├── 橙子
│   ├── 灯笼椒
│   ├── 救生艇
│   ├── 咖啡
│   ├── 考拉
│   ├── 跑车
│   ├── 披萨
│   ├── 瓢虫
│   ├── 小熊猫
│   └── 校车
└── class_10_val
    ├── test_images
    │   ├── 橙子
    │   ├── 灯笼椒
    │   ├── 救生艇
    │   ├── 咖啡
    │   ├── 考拉
    │   ├── 跑车
    │   ├── 披萨
    │   ├── 瓢虫
    │   ├── 小熊猫
    │   └── 校车
    └── val_images
        ├── 橙子
        ├── 灯笼椒
        ├── 救生艇
        ├── 咖啡
        ├── 考拉
        ├── 跑车
        ├── 披萨
        ├── 瓢虫
        ├── 小熊猫
        └── 校车

2. 代码分析

2.1 Import packages

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Activation
from tensorflow.keras import Model, Sequential
from time import time

 由于使用了简化的图像预处理,大量库不需要了,原库的简单介绍:

  • numpy (np): 提供高性能的多维数组处理及数学函数操作。
  • pandas (pd): 用于数据分析和操作的库,特别擅长处理表格数据。
  • re: 提供正则表达式的功能,用于字符串搜索和复杂的文本分析。
  • shutil: 文件操作工具,能进行文件复制、移动、删除等。
  • glob: 用于文件路径的模式匹配,可以找到符合特定规则的文件路径名。
  • json (load, dump): load用于读取JSON文件,dump用于将Python对象写入JSON文件。
  • os.path: 提供了一系列方法来处理和操作文件路径。
  • time: 提供时间相关功能,如获取当前时间和测量运行时间等。

2.2 Dataset

1). 加载数据:image_dataset_from_directory。

2). 定义预处理函数:prepare_ds(),完成归一化,和 one-hot 标签编码。

3). 应用预处理函数:使用 Dataset 的 .map() 方法来调用 prepare_ds()。

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    './dataset/class_10_train/',
    image_size=(64, 64),
    batch_size=32)

vali_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    './dataset/class_10_val/val_images/',
    image_size=(64, 64),
    batch_size=32)

test_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    './dataset/class_10_val/test_images/',
    image_size=(64, 64),
    batch_size=32)

normalization_layer = tf.keras.layers.Rescaling(1./255)

def prepare_ds(image, label):
    image = normalization_layer(image)
    label = tf.one_hot(label, 10)
    return image, label

train_dataset = train_dataset.map(prepare_ds)
vali_dataset = vali_dataset.map(prepare_ds)
test_dataset = test_dataset.map(prepare_ds)

2.3 Train step

@tf.function
def train_step(image_batch, label_batch):
    # 使用tf.GradientTape来记录梯度信息
    with tf.GradientTape() as tape:
        predictions = tiny_vgg(image_batch)

        loss = loss_object(label_batch, predictions)
        # 计算关于模型可训练参数的损失梯度
        gradients = tape.gradient(loss, tiny_vgg.trainable_variables)
        # 应用梯度更新到模型的可训练参数
        optimizer.apply_gradients(zip(gradients, tiny_vgg.trainable_variables))

        train_mean_loss(loss)
        train_accuracy(label_batch, predictions)

2.4 Vali & Test step

@tf.function
def vali_step(image_batch, label_batch):
    predictions = tiny_vgg(image_batch)
    # 计算损失值
    vali_loss = loss_object(label_batch, predictions)

    vali_mean_loss(vali_loss)
    vali_accuracy(label_batch, predictions)
@tf.function
def test_step(image_batch, label_batch):
    predictions = tiny_vgg(image_batch)
    # 计算损失值
    test_loss = loss_object(label_batch, predictions)

    test_mean_loss(test_loss)
    test_accuracy(label_batch, predictions)

2.5 Ceate model

# Create an instance of the model
filters = 10
tiny_vgg = Sequential([

    Conv2D(filters, (3, 3), input_shape=(64, 64, 3), name='conv_1_1'),
    Activation('relu', name='relu_1_1'),

    Conv2D(filters, (3, 3), name='conv_1_2'),
    Activation('relu', name='relu_1_2'),
    MaxPool2D((2, 2), name='max_pool_1'),

    Conv2D(filters, (3, 3), name='conv_2_1'),
    Activation('relu', name='relu_2_1'),

    Conv2D(filters, (3, 3), name='conv_2_2'),
    Activation('relu', name='relu_2_2'),
    MaxPool2D((2, 2), name='max_pool_2'),

    Flatten(name='flatten'),
    Dense(NUM_CLASS, activation='softmax', name='output')
])

模型信息: 

tiny_vgg.summary()
---
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv_1_1 (Conv2D)           (None, 62, 62, 10)        280       
 relu_1_1 (Activation)       (None, 62, 62, 10)        0         
 conv_1_2 (Conv2D)           (None, 60, 60, 10)        910       
 relu_1_2 (Activation)       (None, 60, 60, 10)        0         
 max_pool_1 (MaxPooling2D)   (None, 30, 30, 10)        0         
 conv_2_1 (Conv2D)           (None, 28, 28, 10)        910       
 relu_2_1 (Activation)       (None, 28, 28, 10)        0         
 conv_2_2 (Conv2D)           (None, 26, 26, 10)        910       
 relu_2_2 (Activation)       (None, 26, 26, 10)        0         
 max_pool_2 (MaxPooling2D)   (None, 13, 13, 10)        0         
 flatten (Flatten)           (None, 1690)              0         
 output (Dense)              (None, 10)                16910     
=================================================================
Total params: 19,920
Trainable params: 19,920
Non-trainable params: 0
_________________________________________________________________

2.6 Compile model

loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

tiny_vgg.compile(optimizer=optimizer, loss=loss_object)

train_mean_loss = tf.keras.metrics.Mean(name='train_mean_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

vali_mean_loss = tf.keras.metrics.Mean(name='vali_mean_loss')
vali_accuracy = tf.keras.metrics.CategoricalAccuracy(name='vali_accuracy')

手动计算:训练损失、训练准确度、验证损失、验证准确度。

2.6.1 计算 loss

有3个分类,进行两次推理预测(批量大小为2),预测的概率分布如下:

  • 真实标签 y_true: [0, 1, 0], [0, 0, 1]
  • 预测 y_pred: [0.05, 0.95, 0], [0.1, 0.8, 0.1]
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]

cce = tf.keras.losses.CategoricalCrossentropy()
print(cce(y_true[0], y_pred[0]))
print(cce(y_true[1], y_pred[1]))
print(cce(y_true, y_pred))
---
tf.Tensor(0.051293306, shape=(), dtype=float32)
tf.Tensor(2.3025851, shape=(), dtype=float32)
tf.Tensor(1.1769392, shape=(), dtype=float32) # 本批量的平均值

2.6.2 计算平均值

记录每个 batch 的 loss,在整个 epoch 结束后,可得 loss 平均值:

my_mean = tf.keras.metrics.Mean()
for i in range(4):
    print(i)
    my_mean(i)
print(my_mean.result())

train_mean_loss.reset_states()
print(my_mean.result())
---
0
1
2
3
tf.Tensor(1.5, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)

通过调用 train_mean_loss.reset_states(),在每个 epoch 后重置 my_mean 状态,以便在下一个 epoch 中重新计算平均值。

3.6.3 计算准确度

计算准确度和计算 loss 的输入参数完全一样:

y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]

accuracy_object = tf.keras.metrics.CategoricalAccuracy()
accuracy = accuracy_object(y_true, y_pred)

print("Categorical Acurracy:", accuracy.numpy())
---
Categorical Acurracy: 0.5

2.7 训练循环

本例使用自定义训练循环,即开发者自己编写训练过程的代码,而不是使用框架提供的高级接口如 Keras 的 model.fit()。

# Initialize early stopping parameters
no_improvement_epochs = 0
best_vali_loss = np.inf
best_epoch = 0
start_time = time()
print('Start training.\n')

for epoch in range(EPOCHS):
    # Train
    for image_batch, label_batch in train_dataset:
        train_step(image_batch, label_batch)

    # Predict on the vali dataset
    for image_batch, label_batch in vali_dataset:
        vali_step(image_batch, label_batch)

    template = 'epoch: {:>3}, train loss: {:.4f}, train accuracy: {:>8.4f}, '
    template += 'vali loss: {:.4f}, vali accuracy: {:>8.4f}'
    print(template.format(epoch + 1,
                          train_mean_loss.result(),
                          train_accuracy.result() * 100,
                          vali_mean_loss.result(),
                          vali_accuracy.result() * 100))

    # Early stopping
    if vali_mean_loss.result() < best_vali_loss:
        no_improvement_epochs = 0
        best_vali_loss = vali_mean_loss.result()
        # Save the best model
        best_epoch = epoch + 1
        tiny_vgg.save('trained_vgg_best.h5')
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= PATIENCE:
        print('Early stopping at epoch = {}'.format(epoch))
        print('Best epoch = {}'.format(best_epoch))
        break

    # Reset evaluation metrics
    train_mean_loss.reset_states()
    train_accuracy.reset_states()
    vali_mean_loss.reset_states()
    vali_accuracy.reset_states()

print('\nFinished training, used {:.4f} mins.'.format((time() - start_time) / 60))

自定义训练循环和使用 model.fit() 方法关键区别

2.7.1 自定义训练循环

1). 优点

  • 灵活性:可以完全控制训练过程,包括前向传播、反向传播、梯度更新等。
  • 自定义逻辑:可以轻松添加自定义的训练步骤、验证步骤、早停机制、学习率调度等。
  • 调试和监控:可以更细粒度地监控和调试训练过程中的各个部分。

2). 缺点

  • 复杂性:需要编写更多的代码,理解和实现训练过程中的每个细节。
  • 易错性:由于需要手动实现很多步骤,容易引入错误。

2.7.2 使用 model.fit()

1). 优点

  • 简洁性:只需几行代码即可完成训练,适合快速原型开发和实验。
  • 内置功能:包含了许多常用的功能,如早停、学习率调度、数据增强等。
  • 稳定性:由框架提供的标准方法,经过广泛测试和优化,减少了出错的可能性。

2). 缺点

  • 灵活性有限:对于一些复杂的自定义需求,可能无法满足。
  • 黑箱操作:内部实现细节对用户不可见,调试和优化可能受到限制。

示例:

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=50)
model.fit(train_dataset, epochs=EPOCHS, validation_data=vali_dataset, callbacks=[early_stopping])

2.8 早停法

早停法(Early Stopping),监控验证集的性能,当验证集的性能在若干个迭代中不再提升时,训练停止。

1). 训练循环

for epoch in range(EPOCHS):
    # Train
    for image_batch, label_batch in train_dataset:
        train_step(image_batch, label_batch)
    # Predict on the test dataset
    for image_batch, label_batch in vali_dataset:
        vali_step(image_batch, label_batch)
    # 早停机制
    ...
    train_mean_loss.reset_states()
    train_accuracy.reset_states()
    vali_mean_loss.reset_states()
    vali_accuracy.reset_states()
  • 外层循环遍历每个训练轮次(epoch)。
  • 内层循环遍历训练数据集的每个批次(batch),并调用 train_step 函数进行训练。 
  • 遍历验证数据集的每个批次,并调用 vali_step 函数进行验证。

2). 早停机制

初始化早停参数(在训练循环之前进行):

  • no_improvement_epochs:记录验证集损失没有改善的连续轮数。
  • best_vali_loss:记录验证集的最佳损失值,初始值为正无穷大。
  • start_time:记录训练开始的时间。
if vali_mean_loss.result() < best_vali_loss:
    no_improvement_epochs = 0
    best_vali_loss = vali_mean_loss.result()
    # Save the best model
    tiny_vgg.save('trained_vgg_best.h5')
else:
    no_improvement_epochs += 1

if no_improvement_epochs >= PATIENCE:
    print('Early stopping at epoch = {}'.format(epoch))
    break
  • 如果当前验证损更优(小于记录的最佳验证损失),则用当前验证损作为最佳验证损失,并将 no_improvement_epochs 置为 0,同时保存当前模型。
  • 否则,no_improvement_epochs 增加 1。
  • 如果 no_improvement_epochs 达到预设的耐心值(PATIENCE),则触发早停,停止训练。

3). 打印信息

训练循环中:打印当前轮次的训练损失、训练准确率、验证损失和验证准确率。

template = 'epoch: {:>3}, train loss: {:.4f}, train accuracy: {:>8.4f}, '
template += 'vali loss: {:.4f}, vali accuracy: {:>8.4f}'
print(template.format(epoch + 1,
                      train_mean_loss.result(),
                      train_accuracy.result() * 100,
                      vali_mean_loss.result(),
                      vali_accuracy.result() * 100))

训练结束时:打印用时信息时间。

print('\nFinished training, used {:.4f} mins.'.format((time() - start_time) / 60))

2.9 Save model

# Save trained model
tiny_vgg.save('trained_tiny_vgg.h5')
tiny_vgg = tf.keras.models.load_model('trained_vgg_best.h5')

1). trained_tiny_vgg.h5

  • 这是在训练结束时保存的模型文件。
  • 它包含了整个训练过程中的最终模型参数。
  • 不论模型在训练过程中表现如何,这个文件都会在训练结束时保存。

2). trained_vgg_best.h5

  • 这是在训练过程中验证集损失达到最小时保存的模型文件。
  • 它代表了训练过程中性能最好的模型参数。
  • 只有当验证集损失比之前的最小值更小时,才会更新并保存这个文件。

2.10 Test info

# Test on hold-out test images
test_mean_loss = tf.keras.metrics.Mean(name='test_mean_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

for image_batch, label_batch in test_dataset:
    test_step(image_batch, label_batch)

template = '\ntest loss: {:.4f}, test accuracy: {:.4f}'
print(template.format(test_mean_loss.result(), test_accuracy.result() * 100))

3. 通用功能

3.1 @tf.function

计算图的组成

1). 节点(Nodes):每个节点代表一个操作(例如加法、乘法等)或者一个函数(例如一个完整的神经网络层)。节点可以接收输入,并产生输出。

2). 边(Edges):图中的边表示数据流,即一个节点的输出可以成为另一个节点的输入。

计算图的优势

1). 优化:在执行前,TensorFlow 可以优化计算图,比如删除不必要的操作、合并一些操作等,这可以提高代码的运行效率。

2). 并行处理:计算图使 TensorFlow 能够自动并行地处理那些独立的操作。例如,如果图中有两个节点,它们不依赖于彼此的数据,那么这两个操作可以在不同的处理器上同时执行。

3). 跨平台部署:由于计算图是一种语言和平台无关的表达方式,它可以使得在不同的设备和平台上运行相同的模型成为可能,无论是在服务器上的高性能 GPU 还是在手机上的低功耗处理器。

3.2 数据集分配

训练集、验证集和测试集的数据量比例没有一个固定的标准,它可以根据具体的项目需求、数据的总量、以及模型的复杂性来调整。

有一些常见的比例可以作为起点,根据需要进行调整:

1). 常见的分割比例:

  • 60% / 20% / 20%:这是一个常见的比例,其中60%的数据用作训练集,20%作为验证集,另外20%作为测试集。
  • 70% / 15% / 15%:在这种情况下,70%的数据用于训练,而验证集和测试集各占15%。
  • 80% / 10% / 10%:对于数据量较大的情况,可能会将更大的比例分配给训练集,而将较小的比例用于验证和测试。

2). 调整因素:

  • 数据总量:如果数据集非常大,可能不需要将太多的数据用于验证和测试。例如,对于包含数百万样本的数据集,10%或更少的数据用于测试可能已经足够。
  • 模型复杂度:更复杂的模型可能需要更多的数据来训练,以避免过拟合。
  • 任务的特性:某些任务可能需要更频繁的模型验证,这可能意味着需要一个较大的验证集来调整模型参数。

3). 特殊情况:

  • k-折交叉验证:在数据量较少或者需要更可靠的模型评估时,可以使用交叉验证。这种方法不需要单独的验证集和测试集,而是将数据集分为k个较小的子集。模型训练k次,每次使用不同的子集作为测试集,其余作为训练集,然后平均结果以评估模型性能。

3.3 String formatting

字符串格式化是一种将变量的值插入到字符串中的方法。

template = 'epoch: {}, train loss: {:.4f}, train accuracy: {:.4f}, '
template += 'vali loss: {:.4f}, vali accuracy: {:.4f}'
print(template.format(epoch + 1,
                      train_mean_loss.result(),
                      train_accuracy.result() * 100,
                      vali_mean_loss.result(),
                      vali_accuracy.result() * 100))

效果如下:

epoch:   1, train loss: 2.3263, train accuracy:  11.0400, vali loss: 2.2943, vali accuracy:  10.4000
epoch:   2, train loss: 2.2955, train accuracy:  11.7600, vali loss: 2.2861, vali accuracy:  10.0000
epoch:   3, train loss: 2.2830, train accuracy:  13.0400, vali loss: 2.2731, vali accuracy:  10.4000
epoch:   4, train loss: 2.2706, train accuracy:  13.5200, vali loss: 2.2660, vali accuracy:  13.2000
epoch:   5, train loss: 2.2573, train accuracy:  14.7200, vali loss: 2.2530, vali accuracy:  14.0000
...

常用的字符串格式表示方式:

1. 整数:'{:d}'.format(42)  # 输出: '42'

2. 浮点数:'{:.2f}'.format(3.14159)  # 输出: '3.14'

3. 百分比:'{:.2%}'.format(0.75)  # 输出: '75.00%'

4. 科学计数法:'{:.2e}'.format(123456789)  # 输出: '1.23e+08'

5. 填充与对齐

  • '{:>10}'.format('test')  # 右对齐,输出: '      test'
  • '{:<10}'.format('test')  # 左对齐,输出: 'test      '
  • '{:^10}'.format('test')  # 居中对齐,输出: '   test   '

6. 千位分隔符:'{:,}'.format(1234567890)  # 输出: '1,234,567,890'

7. 进制表示

  • '{:b}'.format(42)  # 二进制,输出: '101010'
  • '{:o}'.format(42)  # 八进制,输出: '52'
  • '{:x}'.format(42)  # 十六进制,输出: '2a'

注意:{} 并不是默认表示整数,而是一个通用的占位符。它可以用于任何类型的数据,具体的格式取决于传入的值。例如:

'{}'.format(42)  # 输出: '42' (整数)
'{}'.format(3.14)  # 输出: '3.14' (浮点数)
'{}'.format('hello')  # 输出: 'hello' (字符串)

3.4 数据集加载器

3.4.1 函数定义

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory,   # 目录路径
    labels='inferred',   # 标签推断方式
    label_mode='int',    # 标签的形式,'int' 表示整数标签
    class_names=None,    # 类别的名字列表,通常自动推断
    color_mode='rgb',    # 图像颜色模式,'rgb' 或 'grayscale'
    batch_size=32,       # 批量大小
    image_size=(256, 256),  # 图像尺寸
    shuffle=True,         # 是否打乱数据
    seed=None,            # 随机种子
    validation_split=None, # 数据集划分比例(训练集/验证集)
    subset=None,          # 指定是训练集 'training' 还是验证集 'validation'
    interpolation='bilinear', # 插值方法
    follow_links=False,     # 是否跟踪符号链接
    crop_to_aspect_ratio=False # 是否将图像裁剪到原始宽高比
)

参数解释:

1). directory: 字符串,目标目录路径。

2). labels: 标签生成方式。默认为'inferred',从目录结构推断标签(每个子目录的名字是标签)。如果设置为 None,则不返回任何标签(用于无监督学习)。

3). label_mode:

  • 'int':返回整数标签(默认)。
  • 'categorical':返回独热编码标签;
  • 'binary':返回二进制标签。

4). class_names: 类别名称的显式列表,仅在 labels 为 'inferred' 时有效。

5). color_mode: 图像的颜色模式,可以是 'grayscale'、'rgb' 或 'rgba'。

6). batch_size: 每个批次的样本数。

7). image_size: 图像尺寸,以(height, width)格式表示。

8). shuffle: 是否打乱数据,默认值为 True。

9). seed: 随机种子,用于打乱数据和进行分割。

10). validation_split: 取值(0,1)的小数,用于创建一个验证集。指定这个参数后,必须通过 subset 参数指明是请求训练集('training')还是验证集('validation')。

11). subset: 指定 'training' 或 'validation' 来返回相应的数据集分割。

12). interpolation: 图像缩放时使用的插值方法,如 'bilinear' 或 'nearest'。

13). follow_links: 是否跟踪目录中的符号链接。

14). crop_to_aspect_ratio: 是否保持图像的原始宽高比。

3.4.2 加载单个目录

# 假设有一个目录结构如下:

/path/to/data/
    dogs/
        dog001.jpg
        dog002.jpg
        ...
    cats/
        cat001.jpg
        cat002.jpg
        ...

可以这样加载数据:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    '/path/to/data/',
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(64, 64),
    batch_size=32)

3.4.3 拆分训练/验证集

image_dataset_from_directory() 方法自身并不直接支持同时划分训练、验证和测试三个数据集,但可以通过设置 validation_split 和 subset 参数来从同一目录中创建训练集和验证集。(通过指定 'training' 或 'validation' 来返回相应的数据集分割)

创建训练集和验证集:

import tensorflow as tf

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'path_to_images',
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(256, 256),
    batch_size=32)

validation_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'path_to_images',
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(256, 256),
    batch_size=32)

测试集的创建:

test_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'path_to_test_images',
    seed=123,
    image_size=(256, 256),
    batch_size=32)

正确使用 validation_split, subset, 和 seed 参数,image_dataset_from_directory 函数会确保同一图片不会同时出现在训练集和验证集中。

3.4.4 输入归一化

使用 tf.keras.preprocessing.image_dataset_from_directory 方法加载的图像数据,是范围在[0, 255] 的 float32 格式的数据:

取一个图片,查看第一个像素,查看其数据类型:

for images, labels in vali_dataset.take(1):
    image = images[0]
    label = labels[0]

print(image[0][0])
print(label)
---
tf.Tensor([95. 88. 78.], shape=(3,), dtype=float32)
tf.Tensor(7, shape=(), dtype=int32)

 可以通过使用 tf.keras.layers.Rescaling 将值标准化为在 [0, 1] 范围内:

normalization_layer = tf.keras.layers.Rescaling(1./255)
dataset = dataset.map(lambda x, y: (normalization_layer(x), y))

归一化后,查看其数据类型:

tf.Tensor([0.454902   0.4431373  0.42352945], shape=(3,), dtype=float32)
tf.Tensor(9, shape=(), dtype=int32)

3.5 修改数据集

3.5.1 修正目录

在 data 目录下创建 ipython 文件,在单元格执行魔术指令:

%%bash

# 删除各个分类(n01882714等)中的 .txt 文件
find . -name "*.txt" -type f -delete

# 查找并删除所有 .ipynb_checkpoints 文件夹
find . -name ".ipynb_checkpoints" -type d -delete

# 将各个分类(n01882714/images/*等)图片上移一层,同时删除 images 目录
cd class_10_train/
for dir in */; do
    mv "${dir}images/"* "$dir"
    rmdir "${dir}images"
done

# 在 test_images 和 val_images 目录中创建分类目录
cd ../class_10_val/test_images/
mkdir n07920052 n02509815 n07873807 n03662601 n04146614 n07747607 n07720875 n02165456 n01882714 n04285008
cd ../val_images/
mkdir n07920052 n02509815 n07873807 n03662601 n04146614 n07747607 n07720875 n02165456 n01882714 n04285008

3.5.2 移动图片文件

在 data 目录下,另建立 ipython 单元格:

import json
import os
import shutil

# 读取JSON文件
json_file_path = './val_class_dict_10.json'
with open(json_file_path, 'r') as file:
    data = json.load(file)

# 遍历 class_10_val/test_images 并移动文件
for image_filename, attributes in data.items():
    source_path = os.path.join('./class_10_val/test_images', image_filename)
    target_dir = os.path.join('./class_10_val/test_images', attributes['class'])
    target_path = os.path.join(target_dir, image_filename)
    if os.path.exists(source_path):
        shutil.move(source_path, target_path)

print(f'Moved class_10_val/test_images')

# 遍历 class_10_val/val_images 并移动文件
for image_filename, attributes in data.items():
    source_path = os.path.join('./class_10_val/val_images', image_filename)
    target_dir = os.path.join('./class_10_val/val_images', attributes['class'])
    target_path = os.path.join(target_dir, image_filename)
    if os.path.exists(source_path):
        shutil.move(source_path, target_path)

print(f'Moved class_10_val/val_images')

3.5.3 批量重命名

在 data 目录下,另建立 ipython 单元格,在单元格执行魔术指令:

%%bash

# 定义文件夹名称映射
declare -A folder_mapping=(
    ["n07920052"]="咖啡"
    ["n02509815"]="小熊猫"
    ["n07873807"]="披萨"
    ["n03662601"]="救生艇"
    ["n04146614"]="校车"
    ["n07747607"]="橙子"
    ["n07720875"]="灯笼椒"
    ["n02165456"]="瓢虫"
    ["n01882714"]="考拉"
    ["n04285008"]="跑车"
)

# 遍历所有子目录并替换文件夹名称
for old_name in "${!folder_mapping[@]}"; do
    new_name=${folder_mapping[$old_name]}
    find . -type d -name "$old_name" -execdir mv {} "$new_name" \;
    echo "Renamed $old_name to $new_name"
done

3.6 访问数据集

在使用 tf.keras.preprocessing.image_dataset_from_directory 加载数据集后,数据集会被封装成 tf.data.Dataset 对象,其中每个元素都是一个批次,包含图像和对应的标签。如果想查看数据集中的一幅图片,可以按照以下步骤操作:

1). 迭代数据集:由于数据集是按批次组织的,你可以通过迭代数据集来访问第一个批次。

2). 选择一幅图片:从批次中选择一幅图片和其标签。

3). 使用 matplotlib 显示图片:利用 matplotlib.pyplot 显示图像。

import tensorflow as tf
import matplotlib.pyplot as plt

# 加载数据集
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'path/to/directory',
    image_size=(64, 64),
    batch_size=32)

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

查看每批次数据的形状:

for image_batch, label_batch in train_dataset:
    print(image_batch.shape, label_batch.shape)
---
(32, 64, 64, 3) (32, 10)
(32, 64, 64, 3) (32, 10)
(32, 64, 64, 3) (32, 10)
...
...
(32, 64, 64, 3) (32, 10)
(8, 64, 64, 3) (8, 10)

4. 总结

 CNN Explainer 是一个交互式可视化系统,旨在帮助非专业人士学习卷积神经网络(CNN)。这个项目由佐治亚理工学院和俄勒冈州立大学的研究人员合作开发,提供了一个直观的界面,让用户可以通过可视化的方式理解CNN的工作原理。

主要特点包括:

  • 交互式可视化:用户可以通过图形界面直观地观察和理解CNN的各个层次和操作。
  • 教育资源:该工具设计为教育用途,适合教学和自学。
  • 开源项目:代码和资源在GitHub上公开,用户可以自由下载和修改。

论文地址:CNN EXPLAINERicon-default.png?t=N7T8https://arxiv.org/pdf/2004.15004

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2079610.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BLE蓝牙协议详解

BLE蓝牙协议详解 1、BLE协议栈 1、协议栈结构 蓝牙LE协议栈按功能分为三个层&#xff1a;Controller、Host和Application Profiles and Services。 HCI event是按BLE Spec标准设计的&#xff0c;是BLE Controller和Host用来交互的事件&#xff1b;GAP event是BLE host定义的…

环境配置 --- miniconda安装torch报错OSError: [WinError 126] 找不到指定的模块

环境配置 — miniconda安装torch报错OSError: [WinError 126] 找不到指定的模块。 CSDN 原因&#xff1a;fbegmm.dll文件出现问题 解决方案&#xff1a; 使用依赖分析工具https://github.com/lucasg/Dependencies/releases/tag/v1.11.1 检测报错提示的那个dll文件发现哪个文…

Nuclei:开源漏洞扫描器

Nuclei 拥有灵活的模板系统&#xff0c;可以适应各种安全检查。 它可以使用可自定义的模板向多个目标发送请求&#xff0c;确保零误报并实现跨多台主机的快速扫描。 它支持多种协议&#xff0c;包括 TCP、DNS、HTTP、SSL、文件、Whois、Websocket 等。 特征 模板库&#xf…

Java中的定时器(Timer)

目录 一、什么是定时器? 二、标准库中的定时器 三、实现定时器 一、什么是定时器? 定时器就像一个"闹钟"&#xff0c;当它到达设定的时间后&#xff0c;就会执行预定的代码。 例如&#xff0c;我们在TCP的超时重传机制中讲过&#xff0c;如果服务器在规定的时间…

XDMA - AXI4 Memory Mapped

目录 1. What is SG DMA2. Descriptor3. Transfer for H2CStep 1. The host prepares stored data and creates descriptors in main memoryStep 2. The host enables DMA interruptsStep 2. The driver initializes DMA with descriptor start addressStep 3. The driver writ…

数据结构(邓俊辉)学习笔记】串 06——KMP算法:构造next[]表

文章目录 1. 递推2. 算法3. 实现 1. 递推 接下来的这节&#xff0c;我们就来讨论 next 查询表的构造算法。我们将会看到非常有意思是&#xff0c; next 表的构造过程与 KMP 主算法的流程在本质上是完全一样的。 在这里&#xff0c;我们不妨采用递推策略。我们只需回答这样一个…

带你深入浅出新面经:十六、十大排序之快速排序

此为面经第十六谈&#xff01;关注我&#xff0c;每日带你深入浅出一个新面经。 我们要了解面经要如何“说”&#xff01; 很重要&#xff01;很重要&#xff01;很重要&#xff01; 我们通常采取总-分-总方式来阐述&#xff01;&#xff08;有些知识点&#xff0c;你可以去…

python脚本请求数量达到上限,http请求重试问题例子解析

在使用Python的requests库进行HTTP请求时&#xff0c;可能会遇到请求数量达到上限&#xff0c;导致Max retries exceeded with URL的错误。这通常发生在网络连接不稳定、服务器限制请求次数、或请求参数设置错误的情况下。以下是一些解决该问题的策略&#xff1a; 增加重试次数…

【负载均衡式在线OJ】项目设计

文章目录 程序源码用到的技术项目宏观结构代码编写思路 程序源码 https://gitee.com/not-a-stupid-child/online-judge 用到的技术 C STL 标准库。Boost 准标准库(字符串切割)。cpp-httplib 第三方开源网络库。ctemplate 第三方开源前端网页渲染库。jsoncpp 第三方开源序列化…

栈和队列有何区别?

栈和队列是两种常见的数据结构&#xff0c;它们分别用于解决不同类型的问题。在程序设计中&#xff0c;栈和队列都是非常重要的数据结构&#xff0c;因为它们可以帮助我们解决很多实际的问题。 栈&#xff1a; 首先&#xff0c;让我们来讨论栈, 栈是一种后进先出&#xff08;…

学NLP不看这本书等于白学!一书弄懂NLP自然语言处理(附文档)

随着人工智能技术的飞速发展&#xff0c;自然语言处理成为了计算机科学与人工智能领域中不可或缺的关键技术之一。作为一名长期致力于人工智能和自然语言处理研究的学者&#xff0c;今天给大家推荐的这本《自然语言处理&#xff1a;大模型理论与实践》正是学NLP自然语言非常牛逼…

黑神话悟空用什么编程语言

《黑神话&#xff1a;悟空》作为一款备受瞩目的国产单机动作游戏&#xff0c;其背后的开发涉及了多种编程语言和技术。根据公开信息和游戏开发行业的普遍做法&#xff0c;可以推测该游戏主要使用了以下几种编程语言&#xff1a; C&#xff1a; 核心编程语言&#xff1a;作为《黑…

【C++ Primer Plus习题】5.7

问题: 解答: #include <iostream> #include <string> using namespace std;typedef struct _Car {string brand;int year; }Car;int main() {int count 0;cout << "请问你家有多少辆车呢?" << endl;cin >> count;cin.get();Car* ca…

Java 入门指南:Java IO流 —— 序列化与反序列化

序列化 序列化是指将对象转换为字节流的过程&#xff0c;以便能够将其存储到文件、内存、网络传输等介质中&#xff0c;或者在不同的进程、网络或机器之间进行数据交换。 序列化的逆过程称为反序列化&#xff0c;即将字节流转换为对象。过反序列化&#xff0c;可以从存储介质…

【mysql】mysql之索引学习

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…

面试搜狐大型模型算法工程师,感受非凡体验!

搜狐大模型算法工程师面试题 应聘岗位&#xff1a;搜狐大模型算法工程师 面试轮数&#xff1a; 整体面试感觉&#xff1a;偏简单 面试过程回顾 1. 自我介绍 在自我介绍环节&#xff0c;我清晰地阐述了个人基本信息、教育背景、工作经历和技能特长&#xff0c;展示了自信和沟通…

【Office】激活文件无法打开-DragonKMS--解决办法

【解决办法】右键 文件属性>>最下面勾选解除锁定即可打开。 【原因】&#xff1a;网络上下载的文件&#xff08;包括exe、zip等&#xff09;。

vue.js3+element-plus+typescript add,edit,del,search

vite.config.ts server: {cors: true, // 默认启用并允许任何源host: 0.0.0.0, // 这个用于启动port: 5110, // 指定启动端口open: true, //启动后是否自动打开浏览器 proxy: {/api: {target: http://localhost:8081/, //实际请求地址&#xff0c;数据库的rest APIschangeOr…

esp32 控制 st7735s 显示屏(spi)

Lcd初始化后全屏为花屏&#xff0c;必须再把整个屏幕转成全底白色消除花屏后再显示图片&#xff0c;字符。 我理解为什么是花屏&#xff0c;因为只是初始化各个参数&#xff0c;显示内存现在还是为空&#xff0c;还没有执行0x2c命令。 图片 #include "driver/spi_master…

统一 transformer 与 diffusion !Meta 融合新方法剑指下一代多模态王者

本文引入了 Transfusion&#xff0c;这是一种可以在离散和连续数据上训练多模态模型的方法。 来源丨机器之心 一般来说&#xff0c;多模态生成模型需要能够感知、处理和生成离散元素&#xff08;如文本或代码&#xff09;和连续元素&#xff08;如图像、音频和视频数据&#xf…