1、Pytorch-Quantization简介
PyTorch Quantization是一个工具包,用于训练和评估具有模拟量化的PyTorch模型。PyTorch Quantization API支持将 PyTorch 模块自动转换为其量化版本。转换也可以使用 API 手动完成,这允许在不想量化所有模块的情况下进行部分量化。例如,一些层可能对量化比较敏感,对其不进行量化可提高任务精度。
PyTorch Quantization的量化模型可以直接导出到ONNX,并由TensorRT 8.0或者更高版本导入进行转换Engine。
1.1 量化函数
tensor_quant和fake_tensor_ quant是量化张量的2个基本函数:
- fake_tensor_quant 返回伪量化张量(浮点值)。
- tensor_quant 返回量化后的张量(整数值)以及其对应的缩放值Scale。
1.2 描述符和量化器
QuantDescriptor是用来定义张量应如何量化;PyTorch Quantization提供了一些预定义的QuantDescriptor,例如:
1)QUANT_DESC_8BIT_PER_TENSOR
2)QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL
TensorQuantizer 可以量化、伪量化或收集张量的统计信息。它与 QuantDescriptor 一起使用,后者描述了如何量化张量。
如下图所示,