tensorflow QAT

news2025/1/19 3:35:05

tensorflow qat

https://www.wpgdadatong.com/tw/blog/detail/70672

在边缘运算的重点技术之中,除了简化复杂的模块构架,来简化参数量以提高运算速度的这项模块轻量化网络构架技术之外。另一项技术就是各家神经网络框架(TensorFlow、Pytorch etc…)的模块优化能力,主要探讨TensorFlow Lite的训练后之量化方式(Post-training quantization)与感知量化训练(Quantization-aware Training),依序分为上与下两篇幅,本篇将介绍后者信息为主。所谓的量化就是将以最小精度的方式,来进行模块推理,使模块应用至各种Edge Device之中,并达到足够成本效益,如下图所示。顺带一提,恩智浦NXP i.MX8M Plus的NPU(Neural Processing Unit)神经处理单元,属于纯整数的AI加速器,就仅适用于8位的整数运算才能获得最佳效益!!此系列的后续章节,也会利用NPU来实现算法加速之目的。

TensorFlow模型应用概念之示意图
利用TensorFlow Lite量化方式所构成的模块,就是将训练完成的轻量化模块,透过量化与优化的方式来提升推理速度!!如下模型运作概念图所示,储存模型完成后,即可依序执行冻结模型、优化模型到最后轻量化模型(.tflite),让模型运行在移动式装置时可达到最佳化的效果。
※MobileNet模块是一种轻量化模块的构架,而此篇重点是如何透过模块量化转换为轻量化模块(tflite)
TensorFlow 模型運作概念之示意圖
何谓量化?在此文章是泛指数值程度上的量化,亦指有限范围的数值表示方式。其作用是为了降低数值数据量与模块大小,来提升传输与执行(推理)速度!!而所谓的训练后之量化(Post-training quantization)就是利用训练完成的模块,再次进行量化的一种优化方式。主要特色就是仅须要储存后的模块(SaveModel / .h5 /ckpt),且不需要训练时的数据库即可量化。
举例来说,如下图所示,是须将原本数值分布为-3e38到+3e38的浮点数型态float,量化为数值分布-231到231的整数型态int,并以原本数据的最大值与最小值来找出有效的数值范围,将有一定概率大幅度减少数据量。
在这里插入图片描述
量化优势?劣势?对于TensorFlow Lite轻量化的应用而言
优势:
-减少模块尺寸:最多能缩减75%的大小
-加快推理速度:使用整数计算大幅度提升速度
-支持硬件较佳:能使处理八位元的处理器进行推理
-传输速度提升:因模块尺寸缩小,能更获得更好的传输质量
缺点:
-精度损失:因为数值的表示范围缩减,故模块的准确度将会大幅度的降低

三.感知量化训练(Quantization-aware Training)
感知量化训练(Quantization-aware Training,QAT)亦是一种量化手段,其原理与上一小节所介绍的量化方式雷同,目的也是以降低精度的方式,缩小模块所需计算的数据量,来提升模块运算速度,且保持一定准确度的一种优化手段。相较于上一小节所介绍的训练后之量化方式(Post-training quantization)最大的不同,就是需要利用原生模块与训练集(DataSets)来作重新训练,而感知量化训练会于训练时,去模拟低精度的运算,来保持最佳的模块准确度。
故理论上,『感知量化训练量化』的准确度会来得比『训练后之量化方式』来的准确,如下图所示;感知量化训练的模块(QAT Model)准确度能够逼近于原生模块(Baseline Model),反之,训练后之量化方式(Post-Training full quantized Model)则降低了约莫0.05的准确度。但其实该量化方式存在比较多不稳定因素,像是各机器学习框架的转换、版本不匹配、缺乏正确训练数据集或是不能微调等等因素,故官方比较推荐使用『感知量化训练量化』来获得更稳定的量化体验。

使用方式 :

大致上 Quantization Aware Training 的應用核心 可以分為四個步驟,分別為建立原生模組、建立感知量化模組、進行感知量化訓練、進行轉換等步驟。

必要套件 :

$ pip install -q tensorflow
$ pip install -q tensorflow-model-optimization

Step 1:建立原生模块

im

port tempfile
import os
import tensorflow as tf
from tensorflow import keras

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.fit( train_images, train_labels,epochs=1, validation_split=0.1, batch_size=32 )

感知量化模型是需要搭配原生模型与原生模型所训练的数据集。这里利用MNIST手写识别的示例来进行演示,如下代码所示;包含模块构架建立,以及利用MNIST DataSets进行训练。
其原生模块构架,如下图所示。
在这里插入图片描述
Step 2:建立感知量化模块(普通用法)
最简单的感知量化训练方式,就是直接利用Tensorflow Model Optimization的量化套件进行应用。
如同下代码与结果所示,将原生模块代入至quantize_model量化套件中,即构成感知量化的模块构架;若仔细观察的话,则会发现构架层的名称皆冠上quant的字眼,也就是程序会去模拟低精度的运算,亦可称Fake Quantization。

import tensorflow_model_optimization as tfmot
quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for for quantization aware.
q_aware_model = quantize_model(model)
q_aware_model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])

q_aware_model.summary()

在这里插入图片描述
Step 2:建立感知量化模块(进阶用法)
进阶的感知量化训练方式,就是直接利用Tensorflow Model Optimization的量化套件进行微调,选择适当的构架层进行量化。
如下代码所示,将利用clone_model复制原生模块(baseline model)并透过apply_quantization_to_dense来选择须量化的构架层,而此示例仅量化Dense构架层。

def apply_quantization_to_dense(layer):
if isinstance(layer,tf.keras.layers.Dense):
return tfmot.quantization.keras.quantize_annotate_layer(layer)
return layer
annotated_model = tf.keras.models.clone_model(
base_model,
clone_function=apply_quantization_to_dense,
)
qat_model = tfmot.quantization.keras.quantize_apply(annotated_model)
qat_model.compile(optimizer='adam',loss=tf.keras.losses.SpareCategoricalCrossentropy(from_logits=True),metrics=['accuracy']

其感知量化模块构架,如下图所示,请读者仔细观察与普通用法的不同,就能够发现前面几构架层,以无quant的字眼,表示这是原生的构架层。
在这里插入图片描述
若欲理解更多进阶用法,可以参考官方示例以及查看可量化的构架层(请搜寻Default8BitQuantizeRegistry的字眼,查阅所描述的构架层),同理,若欲尝试量化没有支持的构架层,可以透过QuantizeConfig方式进行量化,可以参考Medium网志的解说!!

Step 3:进行感知量化训练
接着,需要对感知量化模块(q_ware_model)再次进行训练,得以模拟低精度运算,如下代码所示。

train_images_subset = train_images[0:1000]
train_labels_subset = train_labels[0:1000]
q_aware_model.fit(train_images_subset,train_labels_subset,batch_size=500,epochs=1,validation_split=0.1)
q_aware_annotate_model.fit(train_images_subset,train_labels_subset,batch_size=500,epochs=1,validation_split=0.1

训练完成后,即可验证模块的准确度的表现状况,如下代码与结果所示;Baseline test accuary为原始模块所呈现的准确度,反之Quant test accuracy为感知量化的准确度!!而感知量化的方式略高于0.02的准确度,故更好的准确度表现!!

_,baseline_model_accuracy = model.evaluate(
test_images,test_labels,verbose=0)
_,q_aware_model_accuracy = q_aware_model.evaluate(
test_images,test_labels,verbose=0)
_,q_aware_annotate_model_accuracy = q_aware_annotate_model.evaluate(
test_images,test_labels,verbose=0print'Baseline test accuracy:',baseline_model_accuracy)
print'Quant test accuracy:',q_aware_model_accuracy)
print'Quant test accuracy(annotate):',q_aware_annotate_model_accuracy)

Step 4:进行转换
基本上,前面步骤已经完成大致感知量化训练的操作,但因为最终目标是须应用于移动装置之中。故需要于感知量化训练且生成新的模块之后,利用上一小节的训练后之量化方式(Post-training quantization)的方式进行量化转换,如下代码所示!!然而,细心的读者应该可以发现一些小细节,就是这个代码又做一次优化,这里先卖个关子,原因将置于结论再向读者探讨!!

# quantized_aware_tranning_model(Dynamic)
run_model = tf.function(lambda x: q_aware_model(x))
concrete_func = run_model.get_concrete_function(tf.TensorSpec([1,28,28],model.inputs[0].dtype))
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_aware_tranning_model_dynamic = converter.convert()
with open(“quantized_aware_tranning_model_dynamic.tflite”,'wb'as f:
f.write(quantized_aware_tranning_model_dynamic)
print(“quantized_aware_tranning_model_dynamic done!!”)

完成后,会生成对应的.tflite档案,即可直接应用!!

量化使用分析:
这里以MNIST手写数字识别模块为基准,将测试原生模块与感知量化训练(普通用法与进阶用法)所生成的模块,来搭配训练后之量化等转换方式来验证准确度为何!!同时,也测试经过训练后之量化的方式,是否对于模块准确度或是应用有何影响?其测试代码就如同上一章所介绍的方式,或可以直接查看以及执行运行Colab代码。
实验测试数据结果:
其测试准确度数据结果如下:
在这里插入图片描述
四.结语
依目前实验结果而论,感知量化训练(Quantization Aware Training)能够尽可能去逼近原始模块的准确度,甚至还可能有些许的小幅度提升。而就推理速度来看,感知量化训练的普通用法其实相当于训练后之量化的全整数量化之结果,也表示模块内的参数已转换为低精度的表示方式!!换句话说,若实现过感知量化训练的优化则不必多此一举再做一次训练后之量化的优化。但碍于硬件加速器或处理器的不同,可能会出现各式各样的问题,就如同上述表格中的X,这就代表感知量化训练所生成的模块,仍有一定机率不能顺利运行于NPU上。故重新结合训练后之量化方式来达到更加的应用体验!!此外,读者不仿思考一下,为何感知量化训练的普通用法与进阶用法,所呈现的推理速度表现会有落差?其实此结果是完全符合,上述所向各位描述的一致,也就是进阶用法仅量化了完全连接层(Dense)一个构架层而已,自然表现就会比普通用法来得慢!!最后,探讨一下是否推荐使用『感知量化训练』来进行优化?以作者角度而言,若是开发者刚好使用Keras框架来开发模块的话,是个很棒的选择!!而事实上,每个神经网络框架都有拥护者,不可能所有人都用此框架开发。故对于活用度而言,感知量化训练略显于不足,故仍推荐先活用训练后之量化的方式进行优化!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/979974.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Day7:浅谈useEffect

「目标」: 持续输出!每日分享关于web前端常见知识、面试题、性能优化、新技术等方面的内容。 Day7-今日话题 useEffect 是 React 中一个非常重要的 Hook,用于处理副作用和订阅外部数据源的变化。它可以在函数式组件中执行各种操作,例如数据获…

小程序如何上传微信聊天记录的文件

wx.chooseMessageFile({count: 10,type: image,success (res) {// tempFilePath可以作为img标签的src属性显示图片const tempFilePaths res.tempFiles} })参数说明 回调函数说明

数据库实现学生管理系统

1.QT将数据库分为三个层次: 1> 数据库驱动层:QSqlDriver、QSqlDriverCreator、QSqlDriverCreatorBase、QSqlDriverPlugin 2> sql接口层:QSqlDatabase、QSqlQuery、QSqlRecord、QSqlError 3> 用户接口层:提供一些模型QSql…

开眼“观天”,从墨迹天气服贸会之旅看气象服务新未来

今年夏天,天气焦人。先是高温早早上线,然后台风来势汹汹,北京高温、河北暴雨,杜苏芮、苏拉、海葵轮番“奔袭”,极端气象事件频繁登上热搜,其险象环生的过程,让大众对气候问题的关注度节节走高。…

架构师 软件测试

架构师 软件测试 目录概述需求: 设计思路实现思路分析1.软件测试方法 软件测试工具 参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for c…

cudnn-windows-x86_64-8.6.0.163_cuda11-archive 下载

网址不太好访问的话,请从下面我提供的分享下载 Download cuDNN v8.6.0 (October 3rd, 2022), for CUDA 11.x 此资源适配 cuda11.x 将bin和include文件夹里的文件,分别复制到C盘安装CUDA目录的对应文件夹里 安装cuda时自动设置了 CUDA_PATH_V11_8 及path C:\Progra…

1123. 最深叶节点的最近公共祖先

文章目录 Tag题目来源题目解读解题思路方法一:递归 写在最后 Tag 【递归】【最近公共祖先】【二叉树】 题目来源 1123. 最深叶节点的最近公共祖先,865. 具有所有最深节点的最小子树 此二题系重复的题目。 题目解读 题目意思很明确,找出二叉…

String类的常用方法(Java)

目录 一、字符串构造二、String对象的比较1、比较是否引用同一个对象2、boolean equals(Object anObject) 方法:按照字典序比较3、int compareTo(String s) 方法: 按照字典序进行比较4、int compareToIgnoreCase(String str) 方法:与compareTo方式相同&a…

SpringMVC相对路径和绝对路径

1.相对地址与绝对地址定义 在jsp,html中使用的地址,都是在前端页面中的地址,都是相对地址 地址分类:(1),绝对地址,带有协议名称的是绝对地址,http://www.baidu.com&…

c语言练习41:深入理解字符串函数strlen strcpy strcat

深入理解字符串函数strlen strcpy strcat 模拟实现&#xff1a;”strlen strcpy strcat strlen strcat: #define _CRT_SECURE_NO_WARNINGS #include<stdio.h> #include<assert.h> strlen 1.通过指针移动模拟 //int my_strlen(char* str) { // size_t c…

Unity之创建第一个游戏

一 Unity环境配置 1.1 Untity资源官网下载&#xff1a;https://unity.cn/releases 1.2 Unity Hub集成环境&#xff0c;包含工具和项目的管理 1.3 Unity Editor编辑器 1.4 Visual Studio 2022脚本编辑器 1.5 AndroidSKD&#xff0c;JDK&#xff0c;NDK工具&#xff0c;用于and…

Python之写文件操作(二十九)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

长风破浪会有时,直挂云帆济沧海!(工作室年会总结)

前言 我也是有段时间没写过总结性的博客了。最近是很忙的&#xff0c;尤其是年会那两天&#xff0c;我甚至可以说这是我这辈子目前最忙的两天。但这段经历还是很值得我记录下来的&#xff0c;也是给后面有需要的人提供的一些建议。我个人也是第一次筹办这种大型些的活动&#x…

SpringBoot_第六章(知识点总结)

目录 1&#xff1a;拦截器(Interceptor) 1.1&#xff1a;拦截器代码实现 1.2&#xff1a;拦截器源码分析和流程总结 2&#xff1a;过滤器(Filter)、自定义(Servlet)、监听器(Listener) 3&#xff1a;文件上传 3.1&#xff1a;文件上传代码实现 3.2&#xff1a;文件上传源…

部署Redis集群

文章目录 部署Redis集群1. 准备集群主机2. 启用集群功能3. 配置管理主机并创建集群3.1 配置管理主机 192.168.88.573.2 创建集群创建集群命令创建集群失败解决办法 3.3 查看集群信息查看集群统计信息查看集群详细信息 4. **测试集群及集群工作原理**4.1. 访问集群存取数据4.2 *…

Jmeter进阶使用指南-使用参数化

Apache JMeter是一个广泛使用的开源负载和性能测试工具。在进行性能测试时&#xff0c;我们经常需要模拟不同的用户行为和数据&#xff0c;这时候&#xff0c;参数化就显得尤为重要。此文主要介绍如何在JMeter中使用参数化。 什么是参数化&#xff1f; 参数化是一种将静态值替…

OpenCV之ellipse函数

ellipse函数用来在图片中绘制椭圆、扇形&#xff0c;有两个重载函数。 函数原型1&#xff1a; void cv::ellipse( InputOutputArray img,Point center,Size axes,double angle,double startAngle,double …

ORB-SLAM2算法14之局部建图线程Local Mapping

文章目录 0 引言1 概述2 处理队列中的关键帧3 剔除坏的地图点4 创建新地图点5 融合当前关键帧和其共视帧的地图点6 局部BA优化7 剔除冗余关键帧 0 引言 ORB-SLAM2算法7详细了解了System主类和多线程、ORB-SLAM2学习笔记8详细了解了图像特征点提取和描述子的生成、ORB-SLAM2算法…

如何使用SpringCloud Eureka 创建单机Eureka Server-注册中心

&#x1f600;前言 本篇博文是关于使用SpringCloud Eureka 创建单机Eureka Server-注册中心&#xff0c;希望你能够喜欢 &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可以帮助到大家&…

Jenkins实现基础CD操作

操作截图 在Jenkins里面设置通过标签进行构建 在Jenkins中进入项目&#xff0c;配置以下 将execute shell换到invoke top-level maven targets之前 在gitlab中配置标签 代码迭代新的版本 项目代码迭代 修改docker-compose.yml 提交新版本的代码 在Jenkins中追加新…