目录
1. 简介
2. 解析
2.1 获取标签
2.1.1 载入数据集
2.1.2 标签-Index
2.1.3 保存和读取类别标签
2.2 读取单个图片
2.3 载入模型并推理
2.3.1 tiny-vgg 模型结构
2.3.2 运行推理
2.4 置信度柱状图
2.5 预测标签
3. 完整代码
4. 总结
1. 简介
本博文在《Vitis AI 基本认知(Tiny-VGG 项目代码详解)-CSDN博客》基础上,详细介绍如何使用TensorFlow框架进行单个图片的推理,从获取和处理数据集的标签开始,到模型的加载与推理,再到结果的可视化展示。关键信息如下:
- 获取数据集的标签
- 保存和读取类别标签
- 加载模型并推理
- 绘制图像
- 使用中文标签
- 置信度柱状图
2. 解析
2.1 获取标签
2.1.1 载入数据集
通过 image_dataset_from_directory 方法
vali_dataset = tf.keras.preprocessing.image_dataset_from_directory(
'./dataset/class_10_val/val_images/',
image_size=(64, 64),
batch_size=32)
取出一个图片,并查看其标签:
for images, labels in vali_dataset.take(1):
# 取出第一个图片和标签
image = images[0].numpy().astype("uint8")
label = labels[0].numpy()
# 显示图片
plt.figure(figsize=(2, 2))
plt.imshow(image)
plt.title(f"Label: {label}")
plt.axis('off')
plt.show()
2.1.2 标签-Index
查看类别标签及其 Index:
class_names = vali_dataset.class_names
for i, class_name in enumerate(class_names):
print(f"Class name: {class_name:<4}, Index: {i}")
---
Class name: 咖啡 , Index: 0
Class name: 小熊猫 , Index: 1
Class name: 披萨 , Index: 2
Class name: 救生艇 , Index: 3
Class name: 校车 , Index: 4
Class name: 橙子 , Index: 5
Class name: 灯笼椒 , Index: 6
Class name: 瓢虫 , Index: 7
Class name: 考拉 , Index: 8
Class name: 跑车 , Index: 9
类别标签对应的 one-hot 标签:
for index, class_name in enumerate(class_names):
one_hot = tf.one_hot(index, len(class_names)).numpy()
print(f"Class: {class_name}, One-hot: {one_hot}")
---
Class: 咖啡 , One-hot: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Class: 小熊猫, One-hot: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
Class: 披萨 , One-hot: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
Class: 救生艇, One-hot: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
Class: 校车 , One-hot: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
Class: 橙子 , One-hot: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
Class: 灯笼椒, One-hot: [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
Class: 瓢虫 , One-hot: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
Class: 考拉 , One-hot: [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
Class: 跑车 , One-hot: [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
2.1.3 保存和读取类别标签
将类别标签写入文本文档:
with open('tiny_VGG_class_names.txt', 'w') as file:
for class_name in class_names:
file.write(f"{class_name}\n")
从文本文档中读取类别标签:
with open('tiny_VGG_class_names.txt', 'r') as file:
class_names = [line.strip() for line in file]
print(class_names)
---
['咖啡', '小熊猫', '披萨', '救生艇', '校车', '橙子', '灯笼椒', '瓢虫', '考拉', '跑车']
2.2 读取单个图片
读取图片,并显示在 Jupyter Lab 中:
img = cv2.imread('./dataset/class_10_val/val_images/橙子/val_1067.JPEG')
plt.figure(figsize=(2, 2))
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()
对图片归一化操作:
normalization_layer = tf.keras.layers.Rescaling(1./255)
img_norm = normalization_layer(img)
img_norm = np.expand_dims(img_norm, axis=0)
np.shape(img_norm)
---
(1, 64, 64, 3)
训练过程中,对数据集做了归一化处理,推理时也要做同样的处理。
2.3 载入模型并推理
2.3.1 tiny-vgg 模型结构
# Create an instance of the model
filters = 10
tiny_vgg = Sequential([
Conv2D(filters, (3, 3), input_shape=(64, 64, 3), name='conv_1_1'),
Activation('relu', name='relu_1_1'),
Conv2D(filters, (3, 3), name='conv_1_2'),
Activation('relu', name='relu_1_2'),
MaxPool2D((2, 2), name='max_pool_1'),
Conv2D(filters, (3, 3), name='conv_2_1'),
Activation('relu', name='relu_2_1'),
Conv2D(filters, (3, 3), name='conv_2_2'),
Activation('relu', name='relu_2_2'),
MaxPool2D((2, 2), name='max_pool_2'),
Flatten(name='flatten'),
Dense(NUM_CLASS, activation='softmax', name='output')
])
2.3.2 运行推理
tiny_vgg = tf.keras.models.load_model('trained_vgg_best.h5')
prediction = tiny_vgg.predict(img_norm)
prediction
---
array([[6.2276758e-02, 3.6967881e-03, 9.2534656e-06, 4.8701441e-01,
3.6426269e-02, 2.9939638e-02, 7.1093095e-03, 2.9743392e-02,
2.1278052e-02, 3.2250613e-01]], dtype=float32)
注意:模型的最后一层已经经过 softmax 计算,无需单独调用 softmax 计算概率:
sum = np.sum(prediction)
print(sum)
---
1.0
2.4 置信度柱状图
fig = plt.figure(figsize=(18,6))
# 绘制左图-预测图,调整比例
ax1 = plt.subplot(1,6,1)
ax1.imshow(img)
ax1.axis('off')
# 绘制右图-柱状图,调整比例
ax2 = plt.subplot(1,6,(2,6))
y = prediction[0]
ax2.bar(class_names, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
ax2.set_xticks(x)
ax2.set_xticklabels(class_names, fontproperties=font)
plt.ylim([0, 1.0]) # y轴取值范围
# 显示置信度数值
for i in range(len(y)):
plt.text(i, y[i] + 0.01, f'{y[i]:.2f}', ha='center', fontsize=15)
plt.xlabel('类别', fontsize=20, fontproperties=font)
plt.ylabel('置信度', fontsize=20, fontproperties=font)
ax2.tick_params(labelsize=16)
plt.tight_layout()
2.5 预测标签
predict_label = class_names[np.argmax(prediction)]
print("类别: {}".format(predict_label))
# 显示图片
plt.figure(figsize=(2, 2))
plt.imshow(img)
plt.axis('off')
plt.show()
3. 完整代码
import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2
font = matplotlib.font_manager.FontProperties(fname="./SimHei.ttf")
vali_dataset = tf.keras.preprocessing.image_dataset_from_directory(
'./dataset/class_10_val/val_images/',
image_size=(64, 64),
batch_size=32)
class_names = vali_dataset.class_names
img = cv2.imread('./dataset/class_10_train/橙子/n07747607_0.JPEG')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
tiny_vgg = tf.keras.models.load_model('trained_vgg_best.h5')
prediction = tiny_vgg.predict(img_norm)
fig = plt.figure(figsize=(18,6))
# 绘制左图-预测图,调整比例
ax1 = plt.subplot(1,6,1)
ax1.imshow(img)
ax1.axis('off')
# 绘制右图-柱状图,调整比例
ax2 = plt.subplot(1,6,(2,6))
y = prediction[0]
ax2.bar(class_names, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
ax2.set_xticks(x)
ax2.set_xticklabels(class_names, fontproperties=font)
plt.ylim([0, 1.0]) # y轴取值范围
# 显示置信度数值
for i in range(len(y)):
plt.text(i, y[i] + 0.01, f'{y[i]:.2f}', ha='center', fontsize=15)
plt.xlabel('类别', fontsize=20, fontproperties=font)
plt.ylabel('置信度', fontsize=20, fontproperties=font)
ax2.tick_params(labelsize=16)
plt.tight_layout()
4. 总结
本博文详继续介绍 Tiny-VGG 项目,对模型进行单张图片的推理,关键要点包括:
1). 数据处理与标签管理:通过 image_dataset_from_directory 方法加载数据,并提取类别名称作为标签,同时展示了如何保存和读取类别标签到/从文本文件。
2). 图片预处理:读取单个图片,并对其进行归一化处理,以匹配训练时的数据处理方式,确保模型能正确解读输入数据。
3). 模型加载与推理:加载预训练的Tiny-VGG模型,并对单张图片进行推理,获取预测结果。
4). 结果可视化:通过绘制图片和置信度柱状图来可视化模型的预测结果,使用中文标签和显示每个类别的置信度值。
5). 实用代码示例:提供了完整的代码示例,包括数据加载、模型推理和结果展示,方便读者理解和实际操作。