pytorch量化库使用(2)

news2024/11/17 23:28:27

FX Graph Mode量化模式

训练后量化有多种量化类型(仅权重、动态和静态),配置通过qconfig_mapping ( prepare_fx函数的参数)完成。

FXPTQ API 示例:

import torch
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

model_fp = UserModel()

#
# post training dynamic/weight_only quantization
#

# we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs are needed to trace the model
example_inputs = (input_fp32)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# post training static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# quantization aware training for static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)

量化堆栈

量化是将浮点模型转换为量化模型的过程。因此,在高层次上,量化堆栈可以分为两部分:1)。量化模型的构建块或抽象 2)。将浮点模型转换为量化模型的量化流程的构建块或抽象

量化模型

量化张量

为了在 PyTorch 中进行量化,我们需要能够用张量表示量化数据。量化张量允许存储量化数据(表示为 int8/uint8/int32)以及量化参数(如比例和 Zero_point)。除了允许以量化格式序列化数据之外,量化张量还允许许多有用的操作,使量化算术变得容易。

PyTorch 支持每张量和每通道的对称和非对称量化。每个张量意味着张量内的所有值都使用相同的量化参数以相同的方式量化。每个通道意味着对于每个维度(通常是张量的通道维度),张量中的值使用不同的量化参数进行量化。这可以减少将张量转换为量化值时的错误,因为异常值只会影响其所在的通道,而不是整个张量。

映射是通过使用转换浮点张量来执行的

 

 

 

请注意,我们确保浮点中的零在量化后表示没有错误,从而确保诸如填充之类的操作不会导致额外的量化误差。

以下是量化张量的几个关键属性:

  • QScheme (torch.qscheme):一个枚举,指定我们量化张量的方式

    • torch.per_tensor_affine

    • torch.per_tensor_对称

    • torch.per_channel_affine

    • torch.per_channel_symmetry

  • dtype (torch.dtype):量化张量的数据类型

    • 火炬.quint8

    • 火炬.qint8

    • 火炬.qint32

    • 火炬.float16

  • 量化参数(根据 QScheme 的不同而变化):所选量化方式的参数

    • torch.per_tensor_affine 的量化参数为

      • 刻度(浮动)

      • 零点(整数)

    • torch.per_channel_affine 的量化参数为

      • per_channel_scales(浮点数列表)

      • per_channel_zero_points(整数列表)

      • 轴(整数)

量化和反量化

模型的输入和输出都是浮点张量,但量化模型中的激活是量化的,因此我们需要运算符在浮点和量化张量之间进行转换。

  • 量化(浮点 -> 量化)

    • torch.quantize_per_tensor(x, 尺度, 零点, dtype)

    • torch.quantize_per_channel(x, 尺度, Zero_points, 轴, dtype)

    • torch.quantize_per_tensor_dynamic(x,dtype,reduce_range)

    • 到(火炬.float16)

  • 反量化(量化 -> 浮点)

    • quantized_tensor.dequantize() - 在 torch.float16 张量上调用 dequantize 会将张量转换回 torch.float

    • 火炬.反量化(x)

量化运算符/模块

  • 量化算子是以量化Tensor为输入,输出量化Tensor的算子。

  • 量化模块是执行量化操作的 PyTorch 模块。它们通常是为线性和卷积等加权运算定义的。

量化引擎

当执行量化模型时,qengine (torch.backends.quantized.engine) 指定使用哪个后端来执行。重要的是要确保qengine在量化激活和权重的取值范围方面与量化模型兼容。

量化流程

观察者和 FakeQuantize

  • 观察者是 PyTorch 模块,用于:

    • 收集张量统计信息,例如通过观察者的张量的最小值和最大值

    • 并根据收集的张量统计数据计算量化参数

  • FakeQuantize 是 PyTorch 模块,用于:

    • 模拟网络中张量的量化(执行量化/反量化)

    • 它可以根据观察者收集的统计数据计算量化参数,也可以学习量化参数

查询配置

  • QConfig 是 Observer 或 FakeQuantize Module 类的命名元组,可以使用 qscheme、dtype 等进行配置。它用于配置应如何观察操作员

    • 算子/模块的量化配置

      • 不同类型的 Observer/FakeQuantize

      • 数据类型

      • q方案

      • quant_min/quant_max:可用于模拟较低精度的张量

    • 目前支持激活和权重的配置

    • 我们根据为给定运算符或模块配置的 qconfig 插入输入/权重/输出观察器

一般量化流程

一般来说,流程如下

  • 准备

    • 根据用户指定的 qconfig 插入 Observer/FakeQuantize 模块

  • 校准/训练(取决于训练后量化或量化感知训练)

    • 允许观察者收集统计数据或 FakeQuantize 模块来学习量化参数

  • 转变

    • 将校准/训练模型转换为量化模型

量化有不同的模式,它们可以分为两种方式:

就我们应用量化流程的位置而言,我们有:

  1. Post Training Quantization(训练后应用量化,量化参数根据样本校准数据计算)

  2. 量化感知训练(在训练过程中模拟量化,以便使用训练数据与模型一起学习量化参数)

就我们如何量化运算符而言,我们可以:

  • 仅权重量化(仅权重静态量化)

  • 动态量化(权重静态量化,激活动态量化)

  • 静态量化(权重和激活都是静态量化的)

我们可以在同一量化流程中混合不同的量化运算符方式。例如,我们可以进行具有静态和动态量化运算符的训练后量化。

量化支持矩阵

 

量化定制

虽然提供了观察者根据观察到的张量数据选择比例因子和偏差的默认实现,但开发人员可以提供自己的量化函数。量化可以选择性地应用于模型的不同部分,或者针对模型的不同部分进行不同的配置。

我们还为conv1d()conv2d()、 conv3d()Linear()的每通道量化提供支持。

量化工作流程通过在模型的模块层次结构中添加(例如,将观察者添加为 .observer子模块)或替换(例如,转换nn.Conv2d为 nn.quantized.Conv2d)子模块来工作。这意味着该模型nn.Module在整个过程中保持基于常规的实例,因此可以与 PyTorch API 的其余部分一起使用。

量化自定义模块 API

Eager 模式和 FX 图形模式量化 API 都为用户提供了一个钩子,以指定以自定义方式量化的模块,并使用用户定义的逻辑进行观察和量化。用户需要指定:

  1. 源 fp32 模块的 Python 类型(模型中存在)

  2. 被观察模块的Python类型(由用户提供)。该模块需要定义一个from_float函数,该函数定义如何从原始 fp32 模块创建观察到的模块。

  3. 量化模块的Python类型(由用户提供)。该模块需要定义一个from_observed函数,该函数定义如何从观察到的模块创建量化模块。

  4. 描述上述 (1)、(2)、(3) 的配置,传递给量化 API。

然后框架将执行以下操作:

  1. 在准备模块交换期间,它将使用 (2) 中类的from_float函数将 (1) 中指定类型的每个模块转换为 (2) 中指定的类型。

  2. 在转换模块交换期间,它将使用 (3) 中类的from_observed函数将 (2) 中指定类型的每个模块转换为(3) 中指定的类型。

目前,要求ObservedCustomModule将具有单个 Tensor 输出,并且框架(而不是用户)将在该输出上添加观察者。观察者将作为自定义模块实例的属性存储在activation_post_process键下。未来可能会放宽这些限制。

自定义 API 示例:

import torch
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import QConfigMapping
import torch.ao.quantization.quantize_fx

# original fp32 module to replace
class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

# custom observed module, provided by user
class ObservedCustomModule(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def from_float(cls, float_module):
        assert hasattr(float_module, 'qconfig')
        observed = cls(float_module.linear)
        observed.qconfig = float_module.qconfig
        return observed

# custom quantized module, provided by user
class StaticQuantCustomModule(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def from_observed(cls, observed_module):
        assert hasattr(observed_module, 'qconfig')
        assert hasattr(observed_module, 'activation_post_process')
        observed_module.linear.activation_post_process = \
            observed_module.activation_post_process
        quantized = cls(nnq.Linear.from_float(observed_module.linear))
        return quantized

#
# example API call (Eager mode quantization)
#

m = torch.nn.Sequential(CustomModule()).eval()
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        CustomModule: ObservedCustomModule
    }
}
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        ObservedCustomModule: StaticQuantCustomModule
    }
}
m.qconfig = torch.ao.quantization.default_qconfig
mp = torch.ao.quantization.prepare(
    m, prepare_custom_config_dict=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.convert(
    mp, convert_custom_config_dict=convert_custom_config_dict)
#
# example API call (FX graph mode quantization)
#
m = torch.nn.Sequential(CustomModule()).eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        "static": {
            CustomModule: ObservedCustomModule,
        }
    }
}
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "static": {
            ObservedCustomModule: StaticQuantCustomModule,
        }
    }
}
mp = torch.ao.quantization.quantize_fx.prepare_fx(
    m, qconfig_mapping, torch.randn(3,3), prepare_custom_config=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.quantize_fx.convert_fx(
    mp, convert_custom_config=convert_custom_config_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/700224.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ACL2023 | 赔了?引入GPT-3大模型到智能客服,竟要赔钱?

作者 | 小戏、Python 关于大模型的商业落地,一个非常容易想到的场景就是智能客服,作为不止是大模型也是 NLP 领域的一个最主要的应用场景,由于人工客服的高昂成本,AI 客服伴随着模型技术的发展也逐步走进我们的生活,在…

vue iview table Tooltip内容过多闪屏解决

vue的项目,框架是iview 客户反应,指令描述的内容有几百个字,鼠标悬浮,浏览器开始闪烁 解决思路是加宽度限制,滚动, 后面发现像是transfer属性的bug,触碰浏览器底部,距离不够造成 …

重新理解 RocketMQ Commit Log 存储协议

最近突然感觉:很多软件、硬件在设计上是有 root reason 的,不是 by desgin 如此,而是解决了那时、那个场景的那个需求。一旦了解后,就会感觉在和设计者对话,了解他们的思路,学习他们的方法,思维…

C#可视化 国产热剧信息查询(具体做法及全部代码)

目录 题目: 做法: 代码部分: DBHelper类 From1主窗体代码 题目: 1. 首次打开页面,展示所有汽车信息列表,如图 1 所示。 2.双击第二行右边内容全部发生改变 数据库设计及内容 做法: 首先设置d…

React hooks文档笔记(五)useEffect——解决异步操作竞争问题

1.开发环境下组件加载两次? 非bug,重新安装组件仅在开发过程中发生,帮助找到需要清理的效果。在生产环境中只会加载一次。 React 将在 Effect 下次运行之前以及卸载期间调用您的清理函数。return () > {}; 2. 🌰订阅事件情况…

Python连接MySQL数据库(简单便捷)

🐒,本文中,使用到的工具有:Pycharm,Anaconda,MySQL 5.5,spyder(Anaconda) 什么是 PyMySQL? PyMySQL 是在 Python3.x 版本中用于连接 MySQL 服务器的一个库,Python2 中则…

Java 语言基础练习题

Java 语言基础练习题 Key Point ●包的基本语法 ●Java 语言中的标识符,命名规范 ●八种基本类型 ●基本操作符 ●if 语句和switch 语句 练习 1.(标识符命名)下面几个变量中,那些是对的?那些是错的?错的请…

C++学习 程序控制结构

程序控制结构 以某种顺序执行的一系列动作,用于解决某个问题。包括 顺序结构、选择结构、循环结构。 顺序结构 按照顺序正常执行。前几篇文章的代码都是顺序结构的体现。 选择结构 执行满足条件的语句。 if 结构:if (表达式){} 表达式为真则执行&…

Linux历史及环境搭建(VMware搭建CentOS7环境)

Linux历史及环境搭建 1.Linux历史1.1 UNIX发展的历史1.2 Linux发展历史1.2.1 开源1.2.2 官网1.2.3 发行版本 2.VMware配置CentOS7环境2.1 CentOS下载2.2 配置环境2.3 切换国内阿里源2.4 无图形化界面开机 结语 1.Linux历史 在这里简要介绍Linux的发展史。要说 Linux&#xff0…

机器学习李宏毅学习笔记36

文章目录 前言Meta learning应用总结 前言 Meta learning(二)应用方向 Meta learning应用 回顾gradient descen Θ0(initial的参数)是可以训练的,一个好的初始化参数和普通的是有很大差距的。可以通过一些训练的任务…

Python通过私信消息提取博主的赠书活动地址

文章目录 前言背景设计开发1.引入模块2.获取私信内容3.根据文本提取url的方法4.获取包含‘书’的url5.程序入口 效果总结最后 前言 博主空空star主页空空star的主页 大家好,我是空空star,本篇给大家分享一下《通过私信消息提取博主的赠书活动地址》。 背…

通用策略04丨ORB魔改框架+自适应动量过滤模板

量化策略开发,高质量社群,交易思路分享等相关内容 大家好,今天我们分享2023年度第4期通用策略——ORB魔改框架自适应动量过滤模板。 本期策略是2023年通用系列第4篇。本期主要内容有对ORB原版的逻辑魔改,其次我们将跨日周期均线过…

现在有一个未分库分表的系统,未来要分库分表,如何设计才可以让系统从未分库分表动态切换到分库分表上?

停机迁移方案 最 low 的方案,就是很简单,大家伙儿凌晨 12 点开始运维,网站或者 app 挂 个公告,说 0 点到早上 6 点进行运维,无法访问。 接着到 0 点停机,系统停掉,没有流量写入了,…

设计一个高流量高并发的系统需要关注哪些点

1、设计原则 1.1、系统设计原则 在设计一个系统之前,我们先要有一个统一且清晰的认知:不要想着一下就能设计出完美的系统,好的系统是迭代出来的。不要复杂化,要先解决核心问题。但是要有先行的规划,对现有的问题有方…

字符与代表数据的转化

目的 在与设备交互当中,大都以十六进制的数进行交互。 而显示给用户时,是以字符的形式显示。 这中间就需要字符与其所代表的数值的转化,比如: ‘0F’---->0x0F 这怎么实现呢,一个是字符,另一个是数字&a…

Apache seatunnel集群部署

跳转到安装目录 cd /opt/soft/seatunnel 1.设置环境变量 export SEATUNNEL_HOME/opt/soft/seatunnel export PATH$PATH:$SEATUNNEL_HOME/bin 启动服务端 ./bin/seatunnel-cluster.sh -d 启动客户端 ./bin/seatunnel.sh --config ./config/kafka2gbase_udf.conf 这样就启…

Vue3 数字滚动插件 vue-countup-v3

文章目录 介绍效果安装属性事件配置项完整样例 介绍 vue-countup-v3 插件是一个基于 Vue3 的数字动画插件,用于在网站或应用程序中创建带有数字动画效果的计数器。通过该插件,我们可以轻松地实现数字的递增或递减动画,并自定义其样式和动画效…

软件测试职业发展的7个阶段,哪个都吃香!

首先谈谈我在软件测试行业的亲身经历:我的一位同事曾经很认真地问过我一个问题,他说他现在从事软件测试工作已经4年了,但是他不知道现在的工作和自己在工作3年时有什么不同,他想旁观者清,也许我能回答他的问题。此外他…

手写vue-diff算法(一)

Vue初始化流程 1.Vue流程图 Vue流程图: Vue的初始化流程,默认会创建一个Vue实例,执行初始化、挂载、模板编译操作,模板被编译成为render函数;在render函数初始化时会执行取值操作,从而进入getter方法对当…

【科研入门】会议、期刊、出版社、文献数据库、引文数据库、SCI分区、影响因子等基础科研必备知识

大家好,我是洲洲,欢迎关注,一个爱听周杰伦的程序员。关注公众号【程序员洲洲】即可获得10G学习资料、面试笔记、大厂独家学习体系路线等…还可以加入技术交流群欢迎大家在CSDN后台私信我! 本文目录 一、会议与期刊二、如何辨别是否…