【TensorFlow2 之013】TensorFlow-Lite

news2024/11/15 19:54:46

一、说明

        在这篇文章中,我们将展示如何构建计算机视觉模型并准备将其部署在移动和嵌入式设备上。有了这些知识,您就可以真正将脚本部署到日常使用或移动应用程序中。

        教程概述:

  1. 介绍
  2. 在 TensorFlow 中构建模型
  3. 将模型转换为 TensorFlow Lite
  4. 训练后量化

二、前提知识

      上次,我们展示了如何使用迁移学习来提高模型性能。但是,当我们可以在智能手机或其他嵌入式设备上使用我们的模型时,为什么我们只使用我们的模型来预测计算机上猫或狗的图像呢?

        TensorFlow Lite 是 TensorFlow 针对移动和嵌入式设备的轻量级解决方案。它使我们能够在移动设备上以低延迟快速运行机器学习模型,而无需访问服务器。

        它可以通过 C++ API 在 Android 和 iOS 设备上使用,也可以与 Android 开发人员的 Java 包装器一起使用。

三、 在 TensorFlow 中构建模型

        在开始使用 TensorFlow Lite 之前,我们需要训练一个模型。我们将使用台式计算机或云平台等更强大的机器从一组数据训练模型。然后,可以导出该模型并在移动设备上使用。

        让我们首先准备训练数据集。我们将使用 wget .download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile 

wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")
100% [........................................................................] 68606236 / 68606236

Out[2]:

'cats_and_dogs_filtered.zip'
with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:
    zip_ref.extractall()

base_dir = 'cats_and_dogs_filtered'

train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        

        接下来,让我们导入所有必需的库并构建我们的模型。我们将使用一个名为MobileNetV2的预训练网络 ,该网络是在 ImageNet 数据集上进行训练的。之后,我们需要冻结预训练层并添加一个名为GlobalAveragePooling2D 的新层,然后是具有 sigmoid 激活函数的密集层。我们将使用图像数据生成器来迭代图像。

base_model = MobileNetV2(input_shape=(224, 224, 3),
                                      include_top=False,
                                      weights='imagenet')

base_model.trainable = False
model = Sequential([base_model,
                    GlobalAveragePooling2D(),
                    Dense(1, activation='sigmoid')])
model.compile(loss='binary_crossentropy',
              optimizer=RMSprop(lr=0.001),
              metrics=['accuracy'])

train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary')

validation_generator = val_datagen.flow_from_directory(
        validation_dir,
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary')
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
history = model.fit(
      train_generator,
      epochs=6,
      validation_data=validation_generator,
      verbose=2)
Train for 63 steps, validate for 32 steps
Epoch 1/6
63/63 - 470s - loss: 0.3610 - accuracy: 0.8515 - val_loss: 0.1664 - val_accuracy: 0.9460
Epoch 2/6
63/63 - 139s - loss: 0.2055 - accuracy: 0.9210 - val_loss: 0.2065 - val_accuracy: 0.9150
Epoch 3/6
63/63 - 89s - loss: 0.1507 - accuracy: 0.9470 - val_loss: 0.1294 - val_accuracy: 0.9560
Epoch 4/6
63/63 - 81s - loss: 0.1342 - accuracy: 0.9510 - val_loss: 0.1733 - val_accuracy: 0.9350
Epoch 5/6
63/63 - 80s - loss: 0.1205 - accuracy: 0.9555 - val_loss: 0.1684 - val_accuracy: 0.9390
Epoch 6/6
63/63 - 82s - loss: 0.1079 - accuracy: 0.9620 - val_loss: 0.1418 - val_accuracy: 0.9450

四、 将模型转换为 TensorFlow Lite

        在 TensorFlow 中训练深度神经网络后,我们可以将其转换为 TensorFlow Lite。

        这里的想法是使用提供的 TensorFlow Lite 转换器来转换模型并生成 TensorFlow Lite FlatBuffer文件。此类文件的扩展名是.tflite

        可以直接从经过训练的 tf.keras 模型、保存的模型或具体的 TensorFlow 函数进行转换,这也是可能的。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

open("converted_model.tflite", "wb").write(tflite_model)

        转换后,FlatBuffer 文件可以部署到客户端设备。在客户端设备上,它可以使用 TensorFlow Lite 解释器在本地运行。

https://www.tensorflow.org/lite/convert

        让我们看看如何使用TensorFlow Lite Interpreter直接从 Python 运行该模型。

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

        我们可以通过运行get_input_detailsget_output_details函数来获取模型的输入和输出详细信息。我们还可以调整输入和输出的大小,这样我们就可以对整批图像进行预测。

input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()

print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [  1 224 224   3]
type: <class 'numpy.float32'>

== Output details ==
shape: [1 1]
type: <class 'numpy.float32'>
interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()

print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [ 32 224 224   3]
type: <class 'numpy.float32'>

== Output details ==
shape: [32  1]
type: <class 'numpy.float32'>

        接下来,我们可以生成一批图像,在我们的例子中为 32 个,因为这是之前图像数据生成器中使用的数量。

        使用 TensorFlow Lite 进行预测的第一步是设置模型的输入,在我们的示例中是一批 32 张图像。设置输入张量后,我们可以调用invoke 命令来调用解释器。输出预测将通过调用get_tensor命令获得。

        我们来对比一下 TensorFlow Lite 模型和 TensorFlow 模型的结果。它们应该是相同的。我们将使用np.testing.assert_almost_equal函数来完成此操作。如果结果在一定的小数位数内不相同,则会引发错误。

mage_batch, label_batch = next(iter(validation_generator))
interpreter.set_tensor(input_details[0]['index'], image_batch)
interpreter.invoke()
tflite_results = interpreter.get_tensor(output_details[0]['index'])

tf_results = model.predict(image_batch)

for tf_result, tflite_result in zip(tf_results, tflite_results):
    np.testing.assert_almost_equal(tf_result, tflite_result, decimal=5)

        到目前为止没有错误,所以它向我们表明此转换已成功完成。但这是有阶级概率的。让我们编写一个函数将其转换为类。该函数将执行与 TensorFlow 中内置函数相同的操作。此外,我们还将展示这些预测。红色将显示与主 TensorFlow 模型不同的预测,蓝色图像将显示两个模型做出相同的预测。作为标签,我们将显示 TensorFlow Lite 模型的预测。

def predict_classes(tf_model, tf_lite_interpreter, x):
    tf_lite_interpreter.set_tensor(input_details[0]['index'], x)
    tf_lite_interpreter.invoke()
    tflite_results = tf_lite_interpreter.get_tensor(output_details[0]['index'])

    tf_results = model.predict_classes(x)

    if tflite_results.shape[-1] > 1:
        tflite_results = tflite_results.argmax(axis=-1)
    else:
        tflite_results = (tflite_results > 0.5).astype('int32')

    plt.figure(figsize=(12,10))
    for n in range(32):
        plt.subplot(6,6,n+1)
        plt.subplots_adjust(hspace = 0.3)
        plt.imshow(image_batch[n])
        color = "blue" if tf_results[n] == tflite_results[n] else "red"
        plt.title(tflite_results[n], color=color)
        plt.axis('off')
    _ = plt.suptitle("Model predictions (blue: same, red: different)")

predict_classes(model, interpreter, image_batch)

模型预测 - 猫与狗

五、训练后量化

        这非常简单,但某些模型并未针对在低功耗设备上运行进行优化。因此,为了解决这个问题,我们需要进行量化。

        训练后量化是一种转换技术,可以减小模型大小,同时改善 CPU 和硬件加速器延迟,而模型精度几乎没有下降。

        下表显示了三种技术及其优点,以及可以运行该技术的硬件。

技术好处硬件
权重量化缩小 4 倍,加速 2-3 倍,准确度提高中央处理器
全整数量化缩小 4 倍,加速 3 倍以上CPU、边缘TPU等
Float16 量化体积缩小 2 倍,具有潜在的 GPU 加速能力中央处理器/图形处理器

        让我们看看它是如何进行权重量化的,这是默认的。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

open("converted_model.tflite", "wb").write(tflite_quant_model)

interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()

interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()
print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [ 32 224 224   3]
type: <class 'numpy.float32'>

== Output details ==
shape: [32  1]
type: <class 'numpy.float32'>

以下决策树可以帮助确定哪种训练后量化方法最适合您的用例。

决策树

        在我们的例子中,我们使用名为 MobileNetV2 的模型,该模型已经针对低功耗设备进行了优化。所以本例中的权重量化会降低模型精度。

        和之前一样,让我们​​看看这次模型的表现如何。红色将显示与主 TensorFlow 模型不同的预测,蓝色图像将显示两个模型做出相同的预测。作为标签,我们将显示 TensorFlow Lite 模型的预测。

predict_classes(model, interpreter, image_batch)

模型预测 2 - 猫与狗

六、总结

        总而言之,在这篇文章中,我们讨论了从 TensorFlow 到 TensorFlow Lite 的转换以及低功耗设备的优化。在下一篇文章中,我们将展示如何在 TensorFlow 2.0 中实现 LeNet-5。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1080035.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第九章-线程

初始时&#xff0c;CPU的执行流为进程&#xff1b;当产生了线程概念后&#xff0c;CPU执行流变为了线程&#xff0c;大大增大了一个周期以内进程的执行速度。 线程产生的作用就是为了提速&#xff0c;利用线程提速&#xff0c;原理就是实现多个执行流的伪并行&#xff0c;让处…

vue3前端开发系列 - electron开发桌面程序(2023-10月最新版)

文章目录 1. 说明2. 创建项目3. 创建文件夹electron3.1 编写脚本electron.js3.2 编写脚本proload.js 4. 修改package.json4.1 删除type4.2 修改scripts4.3 完整的配置如下 5. 修改App.vue6. 修改vite.config.ts7. 启动8. 打包安装9. 项目公开地址 1. 说明 本次安装使用的环境版…

提取log文件中的数据,画图

要提取的log格式如下&#xff1a; 代码如下&#xff1a; import reimport matplotlib.pyplot as plt import numpy as npimport argparse from os import path from re import searchclass DataExtractor(object): DataExtrator class def __init__(self, infile, keyword, out…

电脑上播放4K视频需要具备哪些条件?

在电视上播放 4K&#xff08; 4096 2160 像素&#xff09;视频是很简单的&#xff0c;但在电脑设备上播放 4K 视频并不容易。相反&#xff0c;它们有自己必须满足的硬件要求。 如果不满足要求&#xff0c;在电脑上打开 4K 分辨率文件或大型视频文件会导致卡顿、音频滞后以及更…

ROS中的命名空间

ROS中的节点、参数、话题和服务统称为计算图源&#xff0c;其命名方式采用灵活的分层结构&#xff0c;便于在复杂的系统中集成和复用。以下是一些命名的示例&#xff1a; /foo /stanford/robot/name /wg/node1计算图源命名是ROS封装的一种重要机制。每个资源都定义在一个命名空…

微信小程序wxml使用过滤器

微信小程序wxml使用过滤器 1. 新建wxs2. 引用和使用 如何在微信小程序wxml使用过滤器&#xff1f; 犹如Angular使用pipe管道这样子方便&#xff0c;用的最多就是时间格式化。 下面是实现时间格式化的方法和步骤&#xff1a; 1. 新建wxs 插入代码&#xff1a; /*** 管道过滤工…

泡泡玛特,难成“迪士尼”

作者 | 艺馨 豆乳拿铁 排版 | Cathy 监制 | Yoda 出品 | 不二研究 新增长难寻&#xff0c;新故事难讲。泡泡玛特(06682.HK)业绩增长承压的困局&#xff0c;都写在最新的半年报里。 曾经潮玩领域的王者、“潮玩第一股”泡泡玛特&#xff0c;主题城市乐园于9月26日在北京朝阳…

centos下安装配置redis7

1、找个目录下载安装包 sudo wget https://download.redis.io/release/redis-7.0.0.tar.gz 2、将tar.gz包解压至指定目录下 sudo mkdir /home/redis sudo tar -zxvf redis-7.0.0.tar.gz -C /home/redis 3、安装gcc-c yum install gcc-c 4、切换到redis-7.0.0目录下 5、修改…

2023年中国医学影像信息系统市场规模、竞争格局及行业趋势分析[图]

医学影像信息系统简称PACS&#xff0c;与临床信息系统、放射学信息系统、医院信息系统、实验室信息系统同属医院信息系统。医学影像信息系统是处理各种医学影像信息的采集、存储、报告、输出、管理、查询的计算机应用程序。主要包括&#xff1a;预约管理、数据接收、影像处理、…

[读博随笔] 系统安全和论文写作的那些事——不忘初心,江湖再见

很难想象读博这四年的时光意味着什么&#xff0c;是对妻子和儿子深切的思念。我在珞珈山下挑灯夜读&#xff0c;你在贵阳家中独自照顾幼子。怕的不是孑然一身&#xff0c;而是明明已经习惯两个人&#xff0c;又必须各自前行&#xff0c;像单打独斗的勇士。想到千里之外还有一个…

【软考设计师】【计算机系统】E01 计算机硬件组成与CPU

【计算机系统】E01 计算机硬件组成与CPU 硬件组成概述中央处理单元 CPUCPU 组成运算器控制器寄存器组 多核 CPU 硬件组成概述 运算器&#xff1a; 数据加工处理部件&#xff0c;用于完成计算机的各种算术和逻辑运算。控制器&#xff1a; 顾名思义&#xff0c;控制整个CPU的工作…

2023年中国牙线市场规模、竞争现状及行业需求前景分析[图]

牙线是由合成纤维或其他材料制成&#xff0c;或添加香料、色素、活性成分等&#xff0c;用来清洁牙齿邻面附着物的线。能够有效包裹牙齿&#xff0c;对于清洁平面/凸起牙面和牙齿邻接面的牙菌斑效果很好&#xff0c;还可以实现对于牙缝间食物/异物的剔除&#xff0c;有效清洁口…

水库大坝除险加固安全监测系统解决方案

一、系统背景 为贯彻落实《办公厅关于切实加强水库除险加固和运行管护工作的通知》&#xff08;〔2021〕8号&#xff09;要求&#xff0c;完成“十四五”小型病险水库除险加固、雨水情测报和大坝安全监测设施建设任务&#xff0c;规范项目管理&#xff0c;消除安全隐患&#xf…

nodejs+vue 教学辅助管理系统

通过科技手段提高自身的优势&#xff1b;对于驾校预约管理系统当然也不能排除在外&#xff0c;随着网络技术的不断成熟&#xff0c;带动了驾校预约管理系统&#xff0c;它彻底改变了过去传统的管理方式&#xff0c;不仅使服务管理难度变低了&#xff0c;还提升了管理的灵活性。…

云里雾里?云方案没有统一标准,业务结合实际情况才是应该着重考虑的

公共云正在迅速成为IT领导者及其业务线同行的首选平台。根据研究机构IDC的数据&#xff0c;2018年全球公共云服务和基础设施支出预计将达到1600亿美元&#xff0c;比2017年的投资水平增长23%。 然而&#xff0c;在向按需IT转变的步伐不断加快的同时&#xff0c;首席信息官们面…

文旅如何以数字人三维动画宣传片,实现文化资源数字化转化?

近日&#xff0c;山西文旅推出虚拟星推官——晋依依&#xff0c;通过数字人三维动画宣传片的形式&#xff0c;以数字人晋依依的第一视角&#xff0c;对山西的历史文化进行了回顾、展望与想象&#xff0c;利用数字人将山西原有的人文、环境、地貌等进一步宣传&#xff0c;带动当…

乐优商城(一)介绍和项目搭建

1. 乐优商城介绍 1.1 项目介绍 乐优商城是一个全品类的电商购物网站&#xff08;B2C&#xff09;用户可以在线购买商品、加入购物车、下单可以评论已购买商品管理员可以在后台管理商品的上下架、促销活动管理员可以监控商品销售状况客服可以在后台处理退款操作希望未来 3 到 …

【机器学习】集成学习(以随机森林为例)

文章目录 集成学习随机森林随机森林回归填补缺失值实例&#xff1a;随机森林在乳腺癌数据上的调参附录参数 集成学习 集成学习&#xff08;ensemble learning&#xff09;是时下非常流行的机器学习算法&#xff0c;它本身不是一个单独的机器学习算法&#xff0c;而是通过在数据…

Rust入门基础

文章目录 Rust相关介绍为什么要用Rust&#xff1f;Rust的用户和案例 开发环境准备安装Rust更新与卸载Rust开发工具 Hello World程序编写Rust程序编译与运行Rust程序 Cargo工具Cargo创建项目Cargo构建项目Cargo构建并运行项目Cargo检查项目Cargo为发布构建项目 Rust相关介绍 为…

关于串口服务器及转接线的一些基础知识笔记

1.普通个人计算机9针串口为232接口&#xff0c;部分特殊工业计算机为485接口。接线方式差异较大&#xff0c;容易区分。 2.串口服务器的作用&#xff1a;带串口的设备&#xff08;支持常见232、485/422接口方式&#xff09;&#xff0c;将其串口数据信号通过串口服务器转为网络…