TensorRT量化实战课YOLOv7量化:pytorch_quantization介绍

news2025/1/22 14:50:33

目录

    • 前言
    • 1. 课程介绍
    • 2. pytorch_quantization
      • 2.1 initialize函数
      • 2.2 tensor_quant模块
      • 2.3 TensorQuantizer类
      • 2.4 QuantDescriptor类
      • 2.5 calib模块
    • 总结

前言

手写 AI 推出的全新 TensorRT 模型量化实战课程,链接。记录下个人学习笔记,仅供自己参考。

该实战课程主要基于手写 AI 的 Latte 老师所出的 TensorRT下的模型量化,在其课程的基础上,所整理出的一些实战应用。

本次课程为 YOLOv7 量化实战第一课,主要介绍 TensorRT 量化工具箱 pytorch_quantization。

课程大纲可看下面的思维导图

在这里插入图片描述

1. 课程介绍

什么是模型量化呢?那我们都知道模型训练的时候是使用的 float32 或 float16 的浮点数进行运算,这样模型能保持一个比较好的效果,但浮点数在提升计算精度的同时也导致了更多的计算量以及存储空间的占用。

由于在模型推理阶段我们并不需要进行梯度反向传播,因此我们不需要那么高的计算精度,这时可以将高精度的模型参数映射到低精度上,可以降低运算量提高推理速度。

将模型从高精度运算转换到低精度运算的过程就叫做模型量化

量化的过程与数据的分布有关,当数据分布比较均匀的时候,高精度 float 向低精度 int 进行映射时就会将空间利用得比较充分,如果数据分布不均匀就会浪费很大的表示空间。

量化又分为饱和量化和非饱和量化,如果直接将量化阈值设置为 ∣ x max ∣ |x_{\text{max}}| xmax,此时 INT8 的表示空间没有被充分的利用,这是非饱和量化

如果选择了一个比较合适的阈值,舍弃那些超出范围的数值,再进行量化,那这种量化因为充分利用 INT8 的表示空间因此也被称为饱和量化。

模型量化及其意义可以总结为:

  • 模型量化是指将神经网络的浮点转换为定点
  • 模型量化主要意义就是加快模型端侧的推理速度,并降低设备功耗和减少存储空间,工业界一般只使用 INT8 量化模型。

本系列实战课程需要大家具备一些基本的量化知识,如果对模型量化知识模糊的看官的可以先观看 TensorRT下的模型量化 课程。

2. pytorch_quantization

我们先对 TensorRT 的量化工具箱 pytorch_quantization 做一个简单的介绍

它的安装指令如下:

pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com

要求:torch >= 1.9.1,Python >= 3.7, GCC >= 5.4

在博主之前学习的过程中,发现 pytorch 的版本和 pytorch_quantization 的版本如果不适配可能会导致一些问题。

目前博主的软件版本是:pytorch==2.0.1pytorch_quantization==2.1.3

我们下面介绍下 pytorch_quantization 工具库中的一些函数、类和模块

2.1 initialize函数

首先是 quant_modules 模块中的 initialize() 函数,它的使用如下:

import torchvision
from pytorch_quantization import quant_modules

quant_modules.initialize()  # quant_modules 初始化,自动为模型插入量化节点
model = torchvision.models.resnet50()   # 加载 resnet50 模型
# model 是带有量化节点的模型

它的作用是初始化量化相关的设置和一些参数,因此我们需要在量化之前调用它。因为不同类型的神经网络层如 Conv、Linear、Pool 等等,它们所需要的量化方法是不同的,例如某个网络层当中的校准方法可能用的是 Max,也有可能用的是直方图,那这都可以在我们量化之前通过 initialize 来进行一个设置。

initialize 还有一个作用,那就是将模型中的 torch 网络层替换为相应的 quant 量化层,如下所示:

torch.nn.Conv2d     ->  quant_modules.quant_nn.Conv2d
torch.nn.Linear     ->  quant_modules.quant_nn.Linear
torch.nn.MaxPool2d  ->  quant_modules.quant_nn.MaxPool2d

也就是会把 torch 中对应的算子转换为相应的量化版本。

总的来说,initialize 用于在量化模型之前,对量化过程进行必要的配置和准备工作以确保量化操作时按照我们所需要的方式进行,这样的话有助于提高量化模型的性能。

在我们调用 initialize 之后,我们的模型结构会插入 FQ 节点,也就是 fake 算子,如下图所示:

在这里插入图片描述

那在之后的代码讲解部分我们会清晰的观察到在调用 initialize 前后模型结构的一些变化。

2.2 tensor_quant模块

然后是 tensor_quant 模块,它的使用如下:

from pytorch_quantization import tensor_quant

tensor_quant.fake_tensor_quant()
tensor_quant.tensor_quant()

tensor_quant 模块负责进行张量数据的量化操作。那在模型量化过程中我们有两种量化方式:

  • 模型 weights 的量化:对于权重的量化我们是对权重的每个通道进行量化,比如一个 Conv 层的通道数是 32,这意味着 32 个通道数的每个通道都有一个对应的 scale 值去进行量化。
  • 模型 inputs/activate 的量化:而对于输入或者激活函数数值而言,它们的量化是对每个张量进行量化,也就是说整个 Tensor 数据都是用同一个 scale 值进行量化

具体见下图:

在这里插入图片描述

在上面的图中我们可以清楚的看到右边是我们的输入量化,inputs 的量化 scale 只有一个,而左边是我们的权重量化,weights 的量化 scale 有 32 个,这是因为 Conv 的通道数是 32,它有 32 个 channel,每个 channel 对应一个 scale。

下面的代码使用了 tensor_quant 模块中的函数对张量进行量化:

fake_quant_x   = tensor_quant.fake_tensor_quant(x, x.abs().max) # Q 和 DQ 节点组成了 Fake 算子
quant_x, scale = tensor_quant.tensor_quant(x, x.abs().max())    # Q 节点的输出和 scale 值

我们先来看看 tensor_quant 中的两个函数

  • tensor_quant.fake_tensor_quant
    • 这个函数通常用于模拟量化的过程,而不是实际上执行量化,也就是我们通常说的伪量化
    • 伪量化(Fake Quantization)是一种在训练过程中模拟量化效果的技术,但在内部仍然保持使用浮点数
    • 这样做的目的是使模型适应量化带来的精度损失,从而在实际进行量化时能够保持较好的性能。
  • tensor_quant.tensor_quant
    • 这个函数用于实际对张量进行量化,它将输入的浮点数张量转换为定点数的表示(比如从 floa32 转换为 int8)
    • 这个过程涉及确定量化的比例因子 scale 和零点 zero-point,然后应用这些参数将浮点数映射到量化的整数范围内。

在上面的代码中,x 是我们的输入数据,x.abs().Max 代表我们使用基于 Max 的对称量化方法进行量化,函数的输出 fake_quant_x 是经过伪量化处理的张量,它看起来像是被量化了,但实际上仍然是浮点数。

tensor_quant 函数的输出 quant_x 是我们经过实际 int 量化处理后得到的 int 类型的张量,scale 则是我们用于量化过程中的比例因子。

2.3 TensorQuantizer类

下面我们来看看将量化后的模型导出要做哪些操作,实际上我们需要使用到 nn 模块中的 TensorQuantizer,它的使用如下:

from pytorch_quantization import nn as quant_nn

quant_nn.TensorQuantizer.use_fb_fake_quant = True   # 模型导出时将一个 QDQ 算子导出两个 op

其中 pytorch_quantizaiton 的 nn 模块提供了量化相关的神经网络层和工具,大家可以类比于 pytorch 中的 nn 模块。而 TensorQuantizer 是一个用于张量量化的工具类,use_fb_fake_quant 是它的一个类属性,用于控制量化过程中伪量化的行为。

我们将 use_fb_fake_quant 设置为 True 表明我们在导出量化模型时,希望将量化和反量化操作过程作为两个单独的 op 算子来导出,如下图所示:

在这里插入图片描述

可以看到上图中的红色框部分,导出的量化模型中包含 QuantizeLinear 和 DequantizeLinear 两个模块,对应我们的量化和反量化两个 op。

在我们将 use_fb_fake_quant 设置为 True 的时候,它会调用的是 pytorch 模块中的两个函数,如下:

torch.fake_quantize_per_tensor_affine
torch.fake_quantize_per_channel_affine

这两个函数会导出我们之前量化的操作,值得注意的是,在模型导出和模型前向阶段的量化操作并不是使用 tensor_quant 模块中的函数来实现的,而是使用 torch 中上述两个函数来实现,这样做是因为更容易转化成相应 的 tensorRT 的一个操作符,以便我们后续的部署。在模型训练阶段,我们则是调用 tensor_quant 函数插入 fake 算子来进行量化的,大家需要了解到在模型训练和前向阶段调用的函数的不同。

在 Torch-TesorRT 内部,fake_quantize_per_*_affine 会被转换为 QuantizeLayer 和 DequantizerLayer,也就是我们上面导出 ONNX 模型的两个 op 算子。

在这里插入图片描述
在这里插入图片描述

从上图中我们能清晰的看出在模型训练的时候和模型导出的时候 Q/DQ 节点所发生的一个变化,在模型训练的时候,我们是通过 tensor_quant 来插入 fake 算子来实现量化的,而在模型训练完成后导出 ONNX 时,我们是需要将 use_fb_fake_quant 置为 True,它会调用 torch 中的函数将 fake 算子的节点导出成 Q 和 DQ 两个模块。

2.4 QuantDescriptor类

接下来我们再来看下 QuantDescriptor 类,它的使用如下:

import torch
import pytorch_quantization.nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# 自定义层的量化
class QuantMultiAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._input_quantizer  = quant_nn.TensorQuantizer(QuantDescriptor(
                                 num_bits=8, calib_method="histogram"))
        self._weight_quantizer = quant_nn.TensorQuantizer(QuantDescriptor(
                                 num_bits=8, axis=(1), calib_method="histogram"))
    def forward(self, w, x, y):
        return self._weight_quantizer(w) * self._input_quantizer(x) + self._input_quantizer(y)

QuantDescriptor 类主要是用于配置量化的描述符,包括量化的位数,量化的方法等等。在上面的代码中,我们创建了一个自定义的量化层,该层对权重和输入进行量化,并执行加权乘法和加法操作

  • 我们先创建了两个 TensorQuantizer 实例,一个是 _input_quantizer 用于输入量化,另一个是 _weight_quantizer 用于权重量化
  • 我们使用 QuantDescriptor 来描述量化的参数,对于这两个量化器,都使用了 8bit 量化,量化的校准方法都设置为直方图校准

也就是说,我们使用 QuantDescriptor 可以实现自定义层的量化操作,在后续代码介绍的时候会使用到这个类。

2.5 calib模块

我们再来看下 pytorch_quantization 中的校准模块 calib,它的使用如下:

from pytorch_quantization import calib

if isinstance(module._calibrator, calib.MaxCalibrator):
    module.load_calib_amax()

calib 校准模块包含 MaxCalibrator 和 HistogramCalibrator 两个校准类,其中 MaxCalibrator 用于执行最大值校准,在我们的量化训练中,我们通常会确定每个张量的一个动态范围,也就是它们的最大值和最小值,Max 方法通过跟踪张量的最大值来执行标定工作,以便在量化推理时能将其映射到 int 整数范围之内。

而对于 Histogram 直方图校准方法则是通过收集和分析张量值的直方图来确定我们的动态范围,这种方法可以更准确地估计张量值的一个分布,并且更好地适应不同数据分布的情况。

这两种校准方法在模型量化中都有它们各自的优势,具体选择哪种校准方法主要取决于我们具体的应用场景和数据分布的情况,我们通常是根据数据分布和量化的需求来选择合适的校准方法,以确保量化后的模型在推理时能保持一个比较好的准确性。

以上就是关于 pytorch_quantization 中的函数、类和模块的简单介绍。

总结

本次课程介绍了 pytorch_quantization 量化工具以及其中的一些函数、类和模块。在我们量化之前需要调用 initialize 函数来初始化量化相关的一些设置和参数。接着我们会使用 tensor_quant 模块来对张量数据进行实际的量化,而在量化完成后导出时我们需要将 TensorQuantizer 类中的属性 usb_fb_fake_quant 设置为 True,使得导出的量化模型包含 Q、DQ 两个模块。这是因为在模型训练阶段和前向、导出阶段的量化操作调用的函数是不同的,训练阶段是通过 tensor_quant 函数插入 fake 算子来量化的,而导出阶段是 torch 中的两个函数来实现的。

在量化过程中我们还会使用 QuantDescriptor 来配置量化的一些参数,包括量化位数、量化方法等等,最后我们简单介绍了 Calib 校准模块,它包含 Max 和 Histogram 两种校准方法。

下节我们正式进入 YOLOv7-PTQ 量化的学习😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1148325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分类预测 | Matlab实现KOA-CNN-BiLSTM-selfAttention多特征分类预测(自注意力机制)

分类预测 | Matlab实现KOA-CNN-BiLSTM-selfAttention多特征分类预测(自注意力机制) 目录 分类预测 | Matlab实现KOA-CNN-BiLSTM-selfAttention多特征分类预测(自注意力机制)分类效果基本描述程序设计参考资料 分类效果 基本描述 1…

【MySQL--->内外连接】

文章目录 [TOC](文章目录) 一、内连接二、左外连接三、右外连接 一、内连接 内连接就是将两个表连接进行笛卡尔积查询 显示SMITH的名字和部门名称 二、左外连接 左外连接就是以左面的表为主,即便是右边的表没有而左边表项中有的,依然显示 查询所有学…

HTML基础总结——速通知识点

一、基础知识点 Web标准构成&#xff1a; HTML页面的固定结构 <html><head><title>网页的标题</title> </head> <body>网页的主体内容 </body> </html>二、语法 2.1注释 在vscode中&#xff1a;将光标置于需要注释的行&a…

引入个性化标签的协同过滤推荐算法研究_邢瑜航

第3章 引入个性化标签的I-CF推荐算法 3.2.2 相似性度量方法 3.2.3 改进后的算法步骤与流程

IntelliJ IDEA 把package包展开和压缩

想要展开就把对勾取消&#xff0c;想要压缩就勾上

【多线程面试题十二】、阻塞线程的方式有哪些?

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a;阻塞线程的方式有哪些&a…

【LeetCode力扣】42. 接雨水

目录 1、题目介绍 2、解题思路 2.1、暴力破解法 2.2、双指针法 1、题目介绍 原题链接&#xff1a; 42. 接雨水 - 力扣&#xff08;LeetCode&#xff09; 示例 1&#xff1a; 输入&#xff1a;height [0,1,0,2,1,0,1,3,2,1,2,1]输出&#xff1a;6解释&#xff1a;上面是由…

Python打包成.exe文件直接运行

文章目录 前言pyinstaller.exe文件具体步骤第一步&#xff1a;安装pyinstaller第二步&#xff1a;进入要打包文件的目录第三步&#xff1a;执行文件第四步&#xff1a;发给好友 拓展尾声 前言 很多小伙伴在阅读了博主的文章后都积极与博主交流&#xff0c;在这里博主很感谢大家…

scrapy-redis分布式爬虫(分布式爬虫简述+分布式爬虫实战)

一、分布式爬虫简述 &#xff08;一&#xff09;分布式爬虫优势 1.充分利用多台机器的带宽速度 2.充分利用多台机器的ip地址 &#xff08;二&#xff09;Redis数据库 1.Redis是一个高性能的nosql数据库 2.Redis的所有操作都是原子性的 3.Redis的数据类型都是基于基本数据…

LED数码管的静态显示与动态显示(Keil+Proteus)

前言 就是今天看了一下书上的单片机实验&#xff0c;发现很多的器件在Proteus中都不知道怎么去查找&#xff0c;然后想做一下这个实验&#xff0c;尝试能不能实现&#xff0c;LED数码管的两个还可以实现&#xff0c;但是用LED点阵显示器的时候他那个网络标号不知道是什么情况&…

GZ035 5G组网与运维赛题第7套

2023年全国职业院校技能大赛 GZ035 5G组网与运维赛项&#xff08;高职组&#xff09; 赛题第7套 一、竞赛须知 1.竞赛内容分布 竞赛模块1--5G公共网络规划部署与开通&#xff08;35分&#xff09; 子任务1&#xff1a;5G公共网络部署与调试&#xff08;15分&#xff09; 子…

python自动化测试(六):唯品会商品搜索-练习

目录 一、配置代码 二、操作 2.1 输入框“运动鞋” 2.2 点击搜索按钮 2.3 选择品牌 2.4 选择主款 2.5 适用性别 2.6 选择尺码 2.7 选择商品&#xff1a;&#xff08;通过css的属性去匹配&#xff09; 2.8 点击配送地址选项框 一、配置代码 # codingutf-8 from selen…

基于萤火虫算法的无人机航迹规划-附代码

基于萤火虫算法的无人机航迹规划 文章目录 基于萤火虫算法的无人机航迹规划1.萤火虫搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要&#xff1a;本文主要介绍利用萤火虫算法来优化无人机航迹规划。 1.萤火虫…

Mysql数据库基本概念和Sql语言

一、数据库基本概念 1.1 数据库概述 数&#xff1a;数字信息 据&#xff1a;属性 数据&#xff1a;对一系列对象的具体属性的描述的集合 数据库&#xff1a;数据库就是用来组织(各个数据之间是有关联的&#xff0c;按照规则组织起来的)、存储和管理(对数据的增、删、改、查)的…

JavaEE-博客系统1(数据库和后端的交互)

本部分内容包括网站设计总述&#xff0c;数据库和后端的交互&#xff1b; 数据库操作代码如下&#xff1a; -- 编写SQL完成建库建表操作 create database if not exists java_blog_system charset utf8; use java_blog_system; -- 建立两张表&#xff0c;一个存储博客信息&am…

数据结构—线性实习题目(二)5迷宫问题(栈)

迷宫问题&#xff08;栈&#xff09; #include <iostream>​ #include <assert.h> using namespace std;int qi1, qi2; int n; int m1, p1; int** Maze NULL; int** mark NULL;struct items {int x, y, dir; };struct offsets {int a, b;char* dir; };const int…

Java SE 学习笔记(十八)—— 注解、动态代理

目录 1 注解1.1 注解概述1.2 自定义注解1.3 元注解1.4 注解解析1.5 注解应用于 junit 框架 2 动态代理2.1 问题引入2.2 动态代理实现 1 注解 1.1 注解概述 Java 注解&#xff08;Annotation&#xff09;又称Java标注&#xff0c;是JDK 5.0引入的一种注释机制&#xff0c;Java语…

Java 基于微信小程序的汉堡点餐系统的研究与实现

文章目录 1 简介2 相关技术介绍3 系统需求分析4 系统功能分析5 系统的详细设计与实现5.1 系统登录页面5.2 点餐系统后台首页页面5.3 商品信息管理页面5.4 会员管理页面5.5 购买信息管理页面5.6 小程序首页信息页面5.7 商品信息页面5.8 在线下单页面 6 推荐阅读 1 简介 基于微信…

笔记本电脑搜索不到wifi6 无线路由器信号

路由器更换成wifi6 无线路由器后&#xff0c;手机能搜索到这个无线信号&#xff0c;但是笔记本搜索不到这个无线信号&#xff0c;后网上搜索后发现是无线网卡驱动问题&#xff0c;很多无线网卡使用的是Intel芯片&#xff0c;Intel就此发布了公告&#xff0c;升级驱动就可以彻底…

【C】C语言文件操作

1.为什么使用文件 我们前面学习结构体时&#xff0c;写通讯录的程序&#xff0c;当通讯录运行起来的时候&#xff0c;可以给通讯录中增加、删除数据&#xff0c;此时数据是存放在内存中&#xff0c;当程序退出的时候&#xff0c;通讯录中的数据自然就不存在了&#xff0c;等下…