Pytorch量化之Post Train Static Quantization(训练后静态量化)

news2024/9/21 16:34:12

使用Pytorch训练出的模型权重为fp32,部署时,为了加快速度,一般会将模型量化至int8。与fp32相比,int8模型的大小为原来的1/4, 速度为2~4倍。
Pytorch支持三种量化方式:

  • 动态量化(Dynamic Quantization): 只量化权重,激活在推理过程中进行量化
  • 静态量化(Static Quantization): 量化权重和激活
  • 量化感知训练(Quantization Aware Training,QAT): 插入量化算子后进行训练,主要在静态量化精度不满足需求时进行。
    大多数情况下,我们只需要进行静态量化,少数情况下在量化感知训练不满足时使用QAT进行微调。所以本篇只重点讲静态量化,并且理论部分先略过(后面再专门总结),只关注实操。
    注:下面的代码是在pytorch1.10下,后面Pytorch对量化的接口有调整
    官方文档:Quantization — PyTorch 1.10 documentation

动态模式(Eager Mode)与静态模式(fx graph)

Pytorch支持用2种方式量化,一种是动态图模式,也是我们日常使用Pytorch训练所使用的方式,使用这种方式量化需要自己手动修改网络结构,在支持量化的算子前、后插入量化节点,优点是方便调试。静态模式则是由pytorch自动在计算图中插入量化节点,不需要手动修改网络。
网络上大部分的教程都是基于静态模式,这种方式比较大的问题就是需要手动修改网络结构,官方教程里的网络是属于demo型, 其中的QuantStub和DeQuantStub就分别是量化和反量化的节点:

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

Pytorch对于很多网络层是不支持量化的(比如很常用的Prelu),如果我们用这种方式,我们就必须在这些不支持的层前面插入DeQuantStub,然后在支持的层前面插入QuantStub。笔者体验下来,体验很差,个人觉得不太实用,会破坏原来的网络结构。
而静态图模式,我们只需要调用Pytorch提供的接口将原模型转换一下即可,不需要修改原来的网络结构文件,个人认为实用性更强。
image.png

静态模式量化

1. 载入fp32模型,并转成fx graph

其中量化参数有‘fbgemm’和‘qnnpack’两种,前者在x86运行,后者在arm运行。

model_fp32 = torch.load(xxx)
model_fp32_quantize = copy.deepcopy(model_fp32)
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
model_fp32_quantize.eval()
# prepare

model_prepared = quantize_fx.prepare_fx(model_fp32_quantize, qconfig_dict)
model_prepared.eval()

2.读取量化数据,标定(Calibration)量化参数

标定的过程就是使用模型推理量化图片,然后统计权重和激活分布,从而得到量化参数。量化图片一般来源于训练集(几百张左右,根据测试情况调整)。量化图片可以通过Pytorch的Dataloader读取,也可以直接自行实现读图片然后送入网络。

### 使用dataloader读取
for i, (data, label) in enumerate(train_loader):
    data = data.to(torch.device("cpu:0"))
    outputs = model_prepared(data)
    print("calibrating {}".format(i))
    if i > 1000:
        break

3. 转换为量化模型并保存

quantized_model = quantize_fx.convert_fx(model_prepared)
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")

速度测试

量化后的模型使用方法与fp32模型一样:

import torch
import cv2
import numpy as np
torch.set_num_threads(1)

fused_model = torch.jit.load("jit_model.pt")
fused_model.eval()
fused_model.to(torch.device("cpu:0"))

img = cv2.imread("./1.png")
img_fp32 = img.astype(np.float32)
img_fp32 = (img_fp32-127.5) / 127.5
input = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()

def speed_test(model, input):
    # warm up
    for i in range(10):
        model(input)

    import time
    start = time.time()
    for i in range(100):
        model(input)
    end = time.time()
    print("model time: ", (end-start)/100)
    time.sleep(10)

# quantized model
quantized_model= torch.jit.load("quantized_model.pt")
quantized_model.eval()
quantized_model.to(torch.device("cpu:0"))

speed_test(fused_model, input)
speed_test(quantized_model, input)

实测fp32模型单核运行120ms, 量化后47ms

结语

本文介绍了fx graph模式下的Pytorch的PTSQ方法,并实测了一个模型,效果还比较不错。
1_995567224_161_79_3_732056265_62005da0d7c1b531a6cf91ea587d312e.jpg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/855900.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最大异或对

如果你觉得这篇题解对你有用,可以点个赞或关注再走呗,谢谢你的关注~ 分析 最大异或对 (1)最大异或对是运用trie树存储十进制数对应的二进制数的每一位。 (2)再根据trie树的每一位进行搜索查找,严格满足不同的数异或为1,相同的异…

【业余小练习】交互式网格自定义增删改(进行中)

学习SQL和PLISQL数据类型的区别和应用场景 Oracle plsql 基础篇1 数据类型以及流程控制_bb_tarek的博客-CSDN博客https://blog.csdn.net/bb_tarek/article/details/17555713?ops_request_misc&request_id&biz_id102&utm_termplsql%E5%9F%BA%E6%9C%AC%E6%95%B0%E6…

Unlikely argument type for equals(): String seems to be unrelated to T

Unlikely argument type for equals(): String seems to be unrelated to Integer Unlikely argument type for equals(): String seems to be unrelated to Date 多余代码

java代码审计9之XXE

文章目录 1、简介2、 java XXE审计函数3、漏洞3.1、正常的业务3.2、有回显的情况3.3、无回显的情况3.4、修复 之前的文章, php代码审计9之XXE 1、简介 XXE(XML外部实体注⼊,XML External Entity) ,在应⽤程序解析XML输⼊时&…

【雕爷学编程】Arduino动手做(200)---WS2812B幻彩LED灯带4

37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&#x…

linux系统虚拟主机开启支持SourceGuardian(sg11)加密组件

注意:sg11我司只支持linux系统虚拟主机自主安装。支持php5.3及以上版本。 1、登陆主机控制面板,找到【远程文件下载】这个功能。 2、远程下载文件填写http://download.myhostadmin.net/vps/sg11_for_linux.zip 下载保存的路径填写/others/ 3、点击控制…

golang 自定义exporter - 端口连接数 portConnCount_exporter

需求: 1、计算当前6379 、3306 服务的连接数 2、可prometheus 语法查询 下面代码可直接使用: 注: 1、windows 与linux的区分 第38行代码 localAddr : fields[1] //windows为fields[1] , linux为fields[3] 2、如需求 增加/修改/删除…

PHP实现在线进制转换器,10进制,2、4、8、16、32进制转换

1.接口文档 2.laravel实现代码 /*** 进制转换计算器* return \Illuminate\Http\JsonResponse*/public function binaryConvertCal(){$ten $this->request(ten);$two $this->request(two);$four $this->request(four);$eight $this->request(eight);$sixteen …

JavaScript基础 第二天

1. 运算符 2. 语句 一.运算符 1.赋值运算符 2.一元运算符 3.比较运算符 4.逻辑运算符 5.运算符优先级 1.1 赋值运算符 概念:对变量进行赋值的运算符 赋值运算符: - * / % 1.2 一元运算符 可以根据表达式的个数,分为一…

数据结构【第4章】——栈与队列

队列是只允许在一端进行插入操作、而在另-端进行删除操作的线性表。 栈 栈与队列:栈是限定仅在表尾进行插入和删除操作的线性表。 我们把允许插入和删除的一端称为栈顶(top),另一端称为栈底(bottom)&…

提升客户满意度的创意项目管理软件推荐!

发现功能强大的工作管理软件,让创意大放异彩。将您团队的愿景变成引人注目的项目。 一、交付总是令人印象深刻的工作 Zoho Projects的创意项目管理软件可帮助您和您的团队在一个地方监督多个项目。使用我们的内置管理工具和模板,花更少的时间在管理上&a…

postman如何添加token

参考博客:https://blog.csdn.net/Mrbignose/article/details/107237581 1.添加token: 2.设置token: 3.发送时携带token:

【JavaEE】懒人的福音-MyBatis框架—介绍、搭建环境以及初步感受

【JavaEE】MyBatis框架要点总结(1) 文章目录 【JavaEE】MyBatis框架要点总结(1)1. MyBatis是什么?2. 搭建MyBatis的开发环境2.0 MySQL建库建表2.1 新项目添加MyBatis框架2.2 设置MyBatis的配置2.2.1 设置数据库的连接信…

图像的平移变换之c++实现(qt + 不调包)

1.基本原理 设dx为水平偏移量&#xff0c;dy为垂直偏移量&#xff0c;则平移变换的坐标映射关系为下公式&#xff0c;图像平移一般有两种方式。 1.不改变图像大小的平移&#xff08;一旦平移&#xff0c;相应内容被截掉&#xff09; 1&#xff09;当dx > width、dx < -wi…

【云原生】Kubernetes节点亲和性分配 Pod

目录 1 给节点添加标签 2 根据选择节点标签指派 pod 到指定节点[nodeSelector] 3 根据节点名称指派 pod 到指定节点[nodeName] 4 根据 亲和性和反亲和性 指派 pod 到指定节点 5 节点亲和性权重 6 pod 间亲和性和反亲和性及权重 7 污点和容忍度 8 Pod 拓扑分布约束 官方…

flinksql sink to sr often fail because of nullpoint

flinksql or DS sink to starrocks often fail because of nullpoint flink sql 和 flink ds sink starrocks 经常报NullpointException重新编译代码 并上传到flink 集群 验证&#xff0c;有效 flink sql 和 flink ds sink starrocks 经常报NullpointException 使用flink-sta…

【EI复现】售电市场环境下电力用户选择售电公司行为研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

AWS——04篇(AWS之Amazon S3(云中可扩展存储)-02——EC2访问S3存储桶)

AWS——04篇&#xff08;AWS之Amazon S3&#xff08;云中可扩展存储&#xff09;-02——EC2访问S3存储桶&#xff09; 1. 前言2. 创建EC2实例 S3存储桶3. 创建IAM角色4. 修改EC2的IAM 角色5. 连接EC2查看效果5.1 连接EC25.2 简单测试5.2.1 查看桶内存储情况5.2.2 复制本地文件…

docker中的jenkins之流水线构建

docker中的jenkins之流水线构建项目 1、用node这种方式&#xff08;因为我用pipeline方式一直不执行&#xff0c;不知道为什么&#xff09; 2、创建项目 创建两个参数&#xff0c;一个是宿主端口号&#xff0c;一个是docker中的端口号 3、使用git项目中的Jenkinsfile 4、编写…

Android安卓实战项目(11)—每个步骤带有动画演示功能的线上运动APP,可计算每日运动卡路里(源码在文末)

Android安卓实战项目&#xff08;11&#xff09;—每个步骤带有动画演示功能的线上运动APP&#xff0c;可计算每日运动卡路里&#xff08;源码在文末&#x1f415;&#x1f415;&#x1f415;&#xff09; 【bilibili演示】 https://www.bilibili.com/video/BV1bk4y1g7Wo/?sh…