Vitis AI 基本认知(Tiny-VGG 标签获取+预测后处理)

news2024/9/20 9:25:43

目录

1. 简介

2. 解析

2.1 获取标签

2.1.1 载入数据集

2.1.2 标签-Index

2.1.3 保存和读取类别标签

2.2 读取单个图片

2.3 载入模型并推理

2.3.1 tiny-vgg 模型结构

2.3.2 运行推理

 2.4 置信度柱状图

2.5 预测标签

3. 完整代码

4. 总结


1. 简介

本博文在《Vitis AI 基本认知(Tiny-VGG 项目代码详解)-CSDN博客》基础上,详细介绍如何使用TensorFlow框架进行单个图片的推理,从获取和处理数据集的标签开始,到模型的加载与推理,再到结果的可视化展示。关键信息如下:

  • 获取数据集的标签
  • 保存和读取类别标签
  • 加载模型并推理
  • 绘制图像
  • 使用中文标签
  • 置信度柱状图

2. 解析

2.1 获取标签

2.1.1 载入数据集

通过 image_dataset_from_directory 方法

vali_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    './dataset/class_10_val/val_images/',
    image_size=(64, 64),
    batch_size=32)

取出一个图片,并查看其标签:

for images, labels in vali_dataset.take(1):
    # 取出第一个图片和标签
    image = images[0].numpy().astype("uint8")
    label = labels[0].numpy()

    # 显示图片
    plt.figure(figsize=(2, 2))
    plt.imshow(image)
    plt.title(f"Label: {label}")
    plt.axis('off')
    plt.show()

2.1.2 标签-Index

查看类别标签及其 Index:

class_names = vali_dataset.class_names

for i, class_name in enumerate(class_names):
    print(f"Class name: {class_name:<4}, Index: {i}")
---
Class name: 咖啡   , Index: 0
Class name: 小熊猫 , Index: 1
Class name: 披萨   , Index: 2
Class name: 救生艇 , Index: 3
Class name: 校车   , Index: 4
Class name: 橙子   , Index: 5
Class name: 灯笼椒 , Index: 6
Class name: 瓢虫   , Index: 7
Class name: 考拉   , Index: 8
Class name: 跑车   , Index: 9

类别标签对应的 one-hot 标签:

for index, class_name in enumerate(class_names):
    one_hot = tf.one_hot(index, len(class_names)).numpy()
    print(f"Class: {class_name}, One-hot: {one_hot}")
---
Class: 咖啡  , One-hot: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Class: 小熊猫, One-hot: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
Class: 披萨  , One-hot: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
Class: 救生艇, One-hot: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
Class: 校车  , One-hot: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
Class: 橙子  , One-hot: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
Class: 灯笼椒, One-hot: [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
Class: 瓢虫  , One-hot: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
Class: 考拉  , One-hot: [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
Class: 跑车  , One-hot: [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]

2.1.3 保存和读取类别标签

将类别标签写入文本文档:

with open('tiny_VGG_class_names.txt', 'w') as file:
    for class_name in class_names:
        file.write(f"{class_name}\n")

从文本文档中读取类别标签: 

with open('tiny_VGG_class_names.txt', 'r') as file:
    class_names = [line.strip() for line in file]

print(class_names)
---
['咖啡', '小熊猫', '披萨', '救生艇', '校车', '橙子', '灯笼椒', '瓢虫', '考拉', '跑车']

2.2 读取单个图片

读取图片,并显示在 Jupyter Lab 中:

img = cv2.imread('./dataset/class_10_val/val_images/橙子/val_1067.JPEG')

plt.figure(figsize=(2, 2))
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()

 对图片归一化操作:

normalization_layer = tf.keras.layers.Rescaling(1./255)
img_norm = normalization_layer(img)
img_norm = np.expand_dims(img_norm, axis=0)
np.shape(img_norm)
---
(1, 64, 64, 3)

训练过程中,对数据集做了归一化处理,推理时也要做同样的处理。

2.3 载入模型并推理

2.3.1 tiny-vgg 模型结构

# Create an instance of the model
filters = 10
tiny_vgg = Sequential([
 
    Conv2D(filters, (3, 3), input_shape=(64, 64, 3), name='conv_1_1'),
    Activation('relu', name='relu_1_1'),
 
    Conv2D(filters, (3, 3), name='conv_1_2'),
    Activation('relu', name='relu_1_2'),
    MaxPool2D((2, 2), name='max_pool_1'),
 
    Conv2D(filters, (3, 3), name='conv_2_1'),
    Activation('relu', name='relu_2_1'),
 
    Conv2D(filters, (3, 3), name='conv_2_2'),
    Activation('relu', name='relu_2_2'),
    MaxPool2D((2, 2), name='max_pool_2'),
 
    Flatten(name='flatten'),
    Dense(NUM_CLASS, activation='softmax', name='output')
])

2.3.2 运行推理

tiny_vgg = tf.keras.models.load_model('trained_vgg_best.h5')
prediction = tiny_vgg.predict(img_norm)
prediction
---
array([[6.2276758e-02, 3.6967881e-03, 9.2534656e-06, 4.8701441e-01,
        3.6426269e-02, 2.9939638e-02, 7.1093095e-03, 2.9743392e-02,
        2.1278052e-02, 3.2250613e-01]], dtype=float32)

注意:模型的最后一层已经经过 softmax 计算,无需单独调用 softmax 计算概率:

sum = np.sum(prediction)
print(sum)
---
1.0

 2.4 置信度柱状图

fig = plt.figure(figsize=(18,6))

# 绘制左图-预测图,调整比例
ax1 = plt.subplot(1,6,1)
ax1.imshow(img)
ax1.axis('off')

# 绘制右图-柱状图,调整比例
ax2 = plt.subplot(1,6,(2,6))
y = prediction[0]
ax2.bar(class_names, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
ax2.set_xticks(x)
ax2.set_xticklabels(class_names, fontproperties=font)
plt.ylim([0, 1.0]) # y轴取值范围

# 显示置信度数值
for i in range(len(y)):
    plt.text(i, y[i] + 0.01, f'{y[i]:.2f}', ha='center', fontsize=15)

plt.xlabel('类别', fontsize=20, fontproperties=font)
plt.ylabel('置信度', fontsize=20, fontproperties=font)
ax2.tick_params(labelsize=16)

plt.tight_layout()

2.5 预测标签

predict_label = class_names[np.argmax(prediction)]
print("类别: {}".format(predict_label))

# 显示图片
plt.figure(figsize=(2, 2))
plt.imshow(img)
plt.axis('off')
plt.show()

3. 完整代码

import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2

font = matplotlib.font_manager.FontProperties(fname="./SimHei.ttf")

vali_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    './dataset/class_10_val/val_images/',
    image_size=(64, 64),
    batch_size=32)

class_names = vali_dataset.class_names

img = cv2.imread('./dataset/class_10_train/橙子/n07747607_0.JPEG')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

tiny_vgg = tf.keras.models.load_model('trained_vgg_best.h5')

prediction = tiny_vgg.predict(img_norm)

fig = plt.figure(figsize=(18,6))

# 绘制左图-预测图,调整比例
ax1 = plt.subplot(1,6,1)
ax1.imshow(img)
ax1.axis('off')

# 绘制右图-柱状图,调整比例
ax2 = plt.subplot(1,6,(2,6))
y = prediction[0]
ax2.bar(class_names, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
ax2.set_xticks(x)
ax2.set_xticklabels(class_names, fontproperties=font)
plt.ylim([0, 1.0]) # y轴取值范围

# 显示置信度数值
for i in range(len(y)):
    plt.text(i, y[i] + 0.01, f'{y[i]:.2f}', ha='center', fontsize=15)

plt.xlabel('类别', fontsize=20, fontproperties=font)
plt.ylabel('置信度', fontsize=20, fontproperties=font)
ax2.tick_params(labelsize=16)

plt.tight_layout()

4. 总结

本博文详继续介绍 Tiny-VGG 项目,对模型进行单张图片的推理,关键要点包括:

1). 数据处理与标签管理:通过 image_dataset_from_directory 方法加载数据,并提取类别名称作为标签,同时展示了如何保存和读取类别标签到/从文本文件。

2). 图片预处理:读取单个图片,并对其进行归一化处理,以匹配训练时的数据处理方式,确保模型能正确解读输入数据。

3). 模型加载与推理:加载预训练的Tiny-VGG模型,并对单张图片进行推理,获取预测结果。

4). 结果可视化:通过绘制图片和置信度柱状图来可视化模型的预测结果,使用中文标签和显示每个类别的置信度值。

5). 实用代码示例:提供了完整的代码示例,包括数据加载、模型推理和结果展示,方便读者理解和实际操作。
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2083017.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python酷库之旅-第三方库Pandas(105)

目录 一、用法精讲 456、pandas.DataFrame.rdiv方法 456-1、语法 456-2、参数 456-3、功能 456-4、返回值 456-5、说明 456-6、用法 456-6-1、数据准备 456-6-2、代码示例 456-6-3、结果输出 457、pandas.DataFrame.rtruediv方法 457-1、语法 457-2、参数 457-3…

云计算实训38——docker网络、跨主机容器之间的通讯

一、docker⽹络 1.桥接--bridge 所有容器连接到桥就可以使⽤外⽹&#xff0c;使⽤nat让容器可以访问外⽹ 使⽤ ip a s指令查看桥&#xff0c;所有容器连接到此桥&#xff0c;ip地址都是 172.17.0.0/16 ⽹段&#xff0c;桥是启动docker服务后出现&#xff0c;在centos使⽤ br…

centos安装mysql8.0版本,并且实现远程连接

一、 卸载mysql 查看mysql安装情况 rpm -qa | grep -i mysql 删除上图中所有信息 rpm -ev mysql-community-release-el7-5.noarch --nodeps 再次查询&#xff0c;没有数据&#xff0c;则为删除干净 find / -name mysql rm -rf /var/lib/mysql 将机器上的所有mysql相关文…

一篇文章带你真正了解接口测试(附视频教程+面试真题)

一、什么是接口测试&#xff1f; 所谓接口&#xff0c;是指同一个系统中模块与模块间的数据传递接口、前后端交互、跨系统跨平台跨数据库的对接。而接口测试&#xff0c;则是通过接口的不同情况下的输入&#xff0c;去对比输出&#xff0c;看看是否满足接口规范所规定的功能、…

79.位域

目录 一.位域的概念 二.语法格式 三.无名位域 四.视频教程 一.位域的概念 有些数据在存储的时候并不需要一个完整的字节。比如使用一个变量表示开关的状态&#xff0c;开关只有开和关俩个状态&#xff0c;所以只需要使用0和1表示&#xff0c;也就是一个二进制位。所以这时候…

前端提升之——chrome浏览器插件开发指南——chrome插件介绍及入门

前言 有一天突发奇想&#xff0c;想要自己写一个浏览器插件玩一玩&#xff0c;并不做用于商业或者其他方面&#xff0c;仅仅用于自我技术的练习和提升。 这里的浏览器我选择Chrome&#xff0c;当然chrome插件同样适用于微软自带的 Microsoft Edge 在当今发达的互联网环境下&…

云微客短视频矩阵如何打造多元化的视频内容呢?

随着抖音、快手等平台的兴起&#xff0c;短视频已经成为了人们日常生活的一部分&#xff0c;也有不少企业通过短视频赛道实现了品牌曝光和获客引流&#xff0c;但是单一的视频内容终究很难长久的吸引用户&#xff0c;所以如何打造多元化的视频内容呢&#xff1f; 在这个快节奏的…

【二叉树】OJ题目

&#x1f31f;个人主页&#xff1a;落叶 目录 单值⼆叉树 【单值二叉树】代码 相同的树 【相同二叉树】代码 对称⼆叉树 【对称二叉树】代码 另一颗树的子树 【另一颗树的子树】代码 二叉树的前序遍历 【二叉树前序遍历】代码 二叉树的中序遍历 【二叉树中序遍历】…

【数据结构】栈和队列相互实现

目录 栈实现队列 思路 入队列 出队列 获取队头元素 队列实现栈 思路 入栈 出栈 获取栈顶元素 完整代码 栈实现队列 队列实现栈 栈实现队列 思路 栈的特点是 先进后出&#xff0c; 队列的特点是 先进新出&#xff0c;这就意味着我们无法通过一个栈来实现队列&…

YOLOv5改进 | 融合改进 | C3融合Efficient Multi-Scale Conv Plus【完整代码】

秋招面试专栏推荐 &#xff1a;深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 &#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 专栏目录&#xff1a; 《YOLOv5入门 改…

生成式AI扩散模型-Diffusion Model【李宏毅2023】概念讲解、原理剖析笔记

目录 一、Diffusion的基本概念和运作方法 1.Diffusion Model是如何运作的&#xff1f; 2.Denoise模块内部正在做的事情 如何训练Noise predictor&#xff1f; 1&#xff09;Forward Process (Diffusion Process) 2&#xff09;noise predictor 3.Text-to-Image 4.两个A…

MySQL必会知识精华3(使用MySQL)

我们的目标是&#xff1a;按照这一套资料学习下来&#xff0c;大家可以完成数据库增删改查的实际操作。轻松应对面试或者笔试题中MySQL相关题目 上篇文章我们先做一下数据库的基础知识以及MySQL的简单介绍。本篇文章主要连接使用MySQL的相关知识。相对简单&#xff0c;争取做到…

Datawhle X 李宏毅苹果书AI夏令营深度学习笔记之——局部最小值与鞍点

深度学习中优化神经网络是一个重要的问题&#xff0c;我们经常沮丧地发现到了一个节点&#xff0c;不管参数怎么更新&#xff0c;训练的损失都不会下降&#xff0c;神经网络似乎训练不起来了。这可能和损失函数收敛在局部最小值与鞍点有关。 一、 局部最小值&#xff08;local…

‌蜘蛛的工作原理及蜘蛛池的搭建与优化

蜘蛛的工作原理主要包括跟踪网页链接、‌采用一定的爬行策略遍历互联网&#xff0c;‌以及将新内容添加到引擎的索引中。‌具体来说&#xff1a;‌ 跟踪网页链接‌&#xff1a;‌蜘蛛会从一个或多个初始URL开始&#xff0c;‌通过这些URL发现新的链接&#xff0c;‌并将这些链接…

数据的基本类型

数据的基本类型 字符串 切片 切片语法&#xff1a; strs "hello" strs[0:]整数型 浮点型 布尔类型

vscode c++和cuda开发环境配置

文章目录 1. vscode 插件安装2. 开发环境配置2.1 bear 安装2.2 代码的编译2.2.1 编写Makefile文件2.2.2 bear make和make命令2.3 debug环境配置2.1 函数跳转设置2.1.1 ` c_cpp_properties.json` 设置2.1.2 settings.json设置2.2 调试环境配置2.2.1 tasks.json2.2.2 launch.json…

【C语言进阶】C语言指针进阶实战:优化与难题解析

&#x1f4dd;个人主页&#x1f339;&#xff1a;Eternity._ ⏩收录专栏⏪&#xff1a;C语言 “ 登神长阶 ” &#x1f921;往期回顾&#x1f921;&#xff1a;C语言指针进阶 (上) &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; ❀C语言指针进阶 &#x…

Java常用API(BigInteger)

在Java中&#xff0c;整数有四种类型&#xff1a;byte&#xff0c;short&#xff0c;int&#xff0c;long 在底层占用字节个数&#xff1a;byte 1个字节&#xff0c;short2个字节&#xff0c;int 4个字节&#xff0c;long 8个字节 对象一旦创建&#xff0c;里面的值是不能改变…

Go wv(WebView2) GUI框架介绍和使用

说明 wv(webview2) 是Go语言基于LCL和WebView2基础上封装的框架&#xff0c;用于开发Windows GUI软件。 介绍 LCL(Lazarus Component Library) &#xff1a;跨平台原生UI组件库. wv(WebView2): Microsoft Edge WebView2 控件允许在本机应用中嵌入 web 技术(HTML、CSS 以及 …

俄罗斯应用本地化中需要考虑的不同格式的特点

在为俄罗斯市场本地化应用程序时&#xff0c;调整各种格式以符合当地惯例至关重要。这些格式&#xff0c;包括日期和时间、数字、货币、地址等&#xff0c;在确保应用程序对俄罗斯用户来说自然和用户友好方面发挥着重要作用。以下是本地化过程中应考虑的一些关键格式特征。 日…