pytorch量化库使用(1)

news2024/11/18 9:21:43

量化简介

量化是指以低于浮点精度的位宽执行计算和存储张量的技术。量化模型以降低的精度而不是全精度(浮点)值对张量执行部分或全部运算。这允许更紧凑的模型表示以及在许多硬件平台上使用高性能矢量化操作。与典型的 FP32 模型相比,PyTorch 支持 INT8 量化,从而使模型大小减少 4 倍,内存带宽要求减少 4 倍。与 FP32 计算相比,对 INT8 计算的硬件支持通常快 2 到 4 倍。量化主要是一种加速推理的技术,量化运算符仅支持前向传递。

PyTorch 支持多种量化深度学习模型的方法。大多数情况下,模型在 FP32 中训练,然后模型转换为 INT8。此外,PyTorch 还支持量化感知训练,它使用假量化模块对前向和后向传递中的量化误差进行建模。请注意,整个计算都是以浮点形式进行的。在量化感知训练结束时,PyTorch 提供转换函数,将训练后的模型转换为较低精度。

在较低级别,PyTorch 提供了一种表示量化张量并使用它们执行操作的方法。它们可用于直接构建以较低精度执行全部或部分计算的模型。提供了更高级别的 API,其中包含将 FP32 模型转换为较低精度的典型工作流程,并且精度损失最小。

量化 API 总结

PyTorch 提供两种不同的量化模式:Eager 模式量化和 FX Graph 模式量化。

Eager 模式量化是测试版功能。用户需要进行融合并手动指定量化和反量化发生的位置,而且它仅支持模块而不支持函数。

FX 图形模式量化是 PyTorch 中的一个新的自动量化框架,目前它是一个原型功能。它通过添加对泛函的支持和自动化量化过程来改进 Eager 模式量化,尽管人们可能需要重构模型以使模型与 FX 图形模式量化兼容(通过 进行符号追踪)torch.fx。请注意,FX 图形模式量化预计不适用于任意模型,因为该模型可能无法符号追踪,我们会将其集成到 torchvision 等域库中,并且用户将能够使用 FX 量化与支持的域库中的模型类似的模型图模式量化。对于任意模型,我们将提供一般指南,但要真正使其发挥作用,用户可能需要熟悉torch.fx,特别是如何使模型具有符号可追溯性。

我们鼓励量化的新用户首先尝试 FX 图形模式量化,如果不起作用,用户可以尝试遵循使用 FX 图形模式量化的指南或回退到 eager 模式量化。

下表比较了 Eager 模式量化和 FX Graph 模式量化之间的差异:

 

支持三种类型的量化:

  1. 动态量化(通过以浮点形式读取/存储的激活进行量化的权重并进行量化以进行计算)

  2. 静态量化(权重量化、激活量化、训练后需要校准)

  3. 静态量化感知训练(权重量化、激活量化、训练期间建模的量化数值)

请参阅我们的PyTorch 量化简介博客文章,以更全面地概述这些量化类型之间的权衡。

动态和静态量化之间的运算符覆盖范围有所不同,如下表所示。请注意,对于 FX 量化,还支持相应的泛函。

 Eager Mode 量化

有关量化流程的一般介绍,包括不同类型的量化,请查看一般量化流程。

训练后动态量化

这是最简单的量化应用形式,其中权重提前量化,但激活在推理过程中动态量化。这用于模型执行时间主要由从内存加载权重而不是计算矩阵乘法的情况。对于小批量的 LSTM 和 Transformer 类型模型来说确实如此。

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                 /
linear_weight_fp32

# dynamically quantized model
# linear and LSTM weights are in int8
previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
                     /
   linear_weight_int8

动态量化 PTDQ API 示例:

import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)

训练后静态量化

训练后静态量化(PTQ static)量化模型的权重和激活。它尽可能将激活融合到前面的层中。它需要使用代表性数据集进行校准,以确定激活的最佳量化参数。当内存带宽和计算节省都很重要且 CNN 是典型用例时,通常会使用训练后静态量化。

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                    /
    linear_weight_fp32

# statically quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                    /
  linear_weight_int8

静态量化 PTSQ API 示例:

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

静态量化的量化感知训练

量化感知训练 (QAT) 对训练过程中的量化效果进行建模,与其他量化方法相比,具有更高的准确性。我们可以对静态、动态或仅权值量化进行 QAT。在训练期间,所有计算均以浮点形式完成,fake_quant 模块通过钳位和舍入对量化效果进行建模,以模拟 INT8 的效果。模型转换后,权重和激活被量化,并且激活尽可能融合到前一层中。它通常与 CNN 一起使用,并且与静态量化相比具有更高的精度。

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                      /
    linear_weight_fp32

# model with fake_quants for modeling quantization numerics during training
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                           /
   linear_weight_fp32 -- fq

# quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                     /
   linear_weight_int8

QAT API 示例:

import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval for fusion to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
    [['conv', 'bn', 'relu']])

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model needs to be set to train for QAT logic to work
# the model that will observe weight and activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop (not shown)
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

静态量化的模型准备

目前有必要在 Eager 模式量化之前对模型定义进行一些修改。这是因为当前量化是逐个模块进行的。具体来说,对于所有量化技术,用户需要:

  1. 将任何需要输出重新量化(因此具有附加参数)的操作从泛函转换为模块形式(例如,使用torch.nn.ReLU代替torch.nn.functional.relu)。

  2. .qconfig通过在子模块上分配属性或指定 来指定模型的哪些部分需要量化 qconfig_mapping。例如,设置意味着该 图层不会被量化,设置 意味着将使用的量化设置而不是全局 qconfig。model.conv1.qconfig = Nonemodel.convmodel.linear1.qconfig = custom_qconfigmodel.linear1custom_qconfig

对于量化激活的静态量化技术,用户还需要执行以下操作:

  1. 指定激活的量化和反量化位置。这是使用 QuantStub和 DeQuantStub模块完成的。

  2. 用于FloatFunctional将需要特殊处理量化的张量运算包装到模块中。例如,诸如addcat之类的操作需要特殊处理来确定输出量化参数。

  3. 熔断模块:将操作/模块组合成单个模块以获得更高的精度和性能。这是使用 fuse_modules()API 完成的,它接收要融合的模块列表。我们目前支持以下融合:[Conv, Relu]、[Conv, BatchNorm]、[Conv, BatchNorm, Relu]、[Linear, Relu]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/700079.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OnlyOffice7.3.2.8 和 7.3.3.40 Docker魔改版搭建教程

写在前面: 搭建项目所需环境只需查看此篇文档即可,中间有不清楚的地方评论即可,看到会回复! ⼀ 、安装docker: 具体步骤请参照docker官网: https://docs.docker.com/get-docker/ 注: 不同操作系统安装方…

今天分享:配音软件哪个好

在数字化时代,视频内容的需求愈发增长。越来越多的人们创作和分享各种类型的视频,而其中一个重要的元素是声音。然而,有时候我们可能面临着一种情况:我们拍摄了一段令人惊艳的视频,但缺乏适合的配音或原声录音。这时&a…

人工智能时代,前端如何抓住机会

自从 2022 年底 OpenAI 推出了 ChatGPT3.5 后,GPT 的活跃用户数快速突破一亿,打破了互联网应用发展的历史记录。ChatGPT是一种基于人工智能技术的聊天机器人,它可以理解人类的自然语言,模拟人类的语言和思维方式,与人类…

介绍几款在线编程工具(Python)

阅读原文 有时候个人电脑不在身边,又需要处理一些工作,这时候可能需要在朋友的电脑或者公用电脑上操作数据。又或者要将自己写的代码以 notebook 的形式分享给 co-worker,这时就需要用到以下总结的几个直接在浏览器里进行 Python 编程的工具。…

怎么学习数据库连接与操作? - 易智编译EaseEditing

学习数据库连接与操作可以按照以下步骤进行: 理解数据库基础知识: 在学习数据库连接与操作之前,首先要了解数据库的基本概念、组成部分和工作原理。 学习关系型数据库和非关系型数据库的区别,了解常见的数据库管理系统&#xff…

为何及如何使用数据结构提升算法效率和问题解决能力?

数据结构是计算机科学中的一个重要概念,它是一种组织和存储数据的方式。数据结构提供了一种在计算机程序中有效地组织和操作数据的方法。 数据结构的主要目的是解决问题和优化算法。它们帮助我们在计算机内存中存储和组织数据,以便能够高效地访问和操作这…

Flink 自定义源算子之 读取MySQL

1、功能说明: 在Flink 自定义源算子中封装jdbc来读取MySQL中的数据 2、代码示例 Flink版本说明:flink_1.13.0、scala_2.12 自定义Source算子,这里我们继承RichParallelSourceFunction,因为要使用open方法来初始化数据库连接对…

Egg.js阿里JS后端框架,可以放心用。

目录 一、快速开始 二、尝试创建一个controll,修改路由,然后检查测试单元。 一、快速开始 npm install -g yarn yarn create egg --typesimple cd egg yarn install yarn devhttp://127.0.0.1:7001 二、尝试创建一个controll,修改路由,然后检查测试单…

PDF怎么转图片?PDF转图片的方法分享!​

PDF怎么转图片呢?相信很多人都会觉得PDF的非常的好用,小编也是被身边很多朋友推荐过后用了这个软件。但很多人在使用的时候有疑问,比如说PDF如何转图片?这难倒不少人,那么今天这篇文章就带你解析PDF怎么转图片&#xf…

LiangGaRy-学习笔记-Day25

1、Apache web相关 1.1、curl命令 作用:用来与服务器之间传输数据的工具 官网:https://curl.se支持很多种协议 语法:curl选项网址 选项: -A:设置代理给服务器-I(大写i):输出返…

Spring Boot 中的 WebMvc 是什么,原理,如何使用

Spring Boot 中的 WebMvc 是什么,原理,如何使用 介绍 在 Spring Boot 中,WebMvc 是非常重要的一个模块。它提供了一系列用于处理 Web 请求的组件和工具。在本文中,我们将介绍 Spring Boot 中的 WebMvc 是什么,其原理…

Python+ddt+Excel实现接口自动化测试生成完美测试报告

接口自动化测试是指通过编写代码或使用工具,模拟用户发送请求,验证接口是否符合设计规范和功能需求的过程。” 如何用 python ddtexcel 实现接口自动化测试 接口自动化测试可以提高测试效率和质量,节省测试成本和时间,保证测试覆…

一步一步学OAK之八:通过OAK相机实现视频帧拼接

帧拼接在有些场景下非常有用,比如将一个较大的帧输入到尺寸较小的神经网络中时。可以将较大的帧拆分成多个较小的帧,并将这些较小的帧输入到神经网络中。 这里我们使用 2 个 ImageManip 将原始预览帧拆分为两个帧。 这里写目录标题 涉及到的节点内容Co…

STM32实战项目—停车计费系统

文章目录 一、任务要求1.1 概述1.2 串口收发1.2.1 串口输出内容1.2.2 串口接收内容 1.3 说明 二、实现思路2.1 指令判别2.1 车辆进入2.2 车辆驶出2.3 费率调整 三、程序设计3.1 串口接收消息处理3.2 车辆驶入处理函数3.3 车辆驶出处理函数3.4 费率调整处理函数 题目原型是第十二…

4-Python如何创建等比数列?【视频版】

目录 问题视频解答 问题 视频解答 点击观看: 4-如何创建等比数列?

windows无法启动RemoteDesktopServices服务(位于本地计算机上)。错误126:找不到指定的模块。

win10的搜索栏输入 注册表编辑器。打开,找到如下路径 计算机\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\TermService\Parameters 将指定数值项ServiceDll的数值数据改成默认值: %SystemRoot%\System32\termsrv.dll 再重新尝试就好了。 …

Java(七):项目部署

项目部署 运行容器解决Centos8中yum命令遇到的问题打包项目拷贝.jar到容器中安装jdk后台运行.jar后台运行.jar并输入日志实时查看日志查看/杀死运行程序目录结构日志配置 运行容器 $ docker run -d -p 8001:8001 -p 8081:8081 -p 8082:8082 --namelocal_centos --privilegedtr…

DigiCert SSL证书有什么优势?

DigiCert是全球领先的SSL证书颁发机构,2017年收购赛门铁克数字证书业务后,成为全球市场占有率领先的SSL证书提供商,提供高保证TLS/SSL证书和自动化解决方案,也是沃通CA在全球信任数字证书业务方面的重要合作伙伴。 与全球其他品牌…

AI聊天对话工具,让沟通更简单轻松

人工智能技术的发展不断为我们带来新的惊喜和变革,其中之一就是ai聊天对话应用。这种应用利用自然语言处理、机器学习和对话管理等技术,在智能手机、电脑等设备上实现了人机对话,让人们更轻松地与计算机之间进行交流和互动。随着移动互联网的…

基于SpringBoot的电脑商城项目

基于SpringBoot的电脑商城项目 — 参考B站袁庭新老师项目 Maven多聚合项目 技术栈: Spring boot、MyBtis、Bootstrap、Layui,Redis,内网穿透等 功能: 后台:管理员登录、商品信息管理、订单管理、用户信息管理、统计图…