模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

news2024/11/17 3:57:47

模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

  • 前言
  • 量化
    • Post-Training-Quantization(PTQ)
    • Quantization-Aware-Training(QAT)
  • 参考文献

前言

随着人工智能的不断发展,深度学习网络被广泛应用于图像处理、自然语言处理等实际场景,将其部署至多种不同设备的需求也日益增加。然而,常见的深度学习网络模型通常包含大量参数和数百万的浮点数运算(例如ResNet50具有95MB的参数以及38亿浮点数运算),实时地运行这些模型需要消耗大量内存和算力,这使得它们难以部署到资源受限且需要满足实时性、低功耗等要求的边缘设备。为了进一步推动深度学习网络模型在移动端或边缘设备中的快速部署,深度学习领域提出了一系列的模型压缩与加速方法:

  • 知识蒸馏(Knowledge distillation):使用教师-学生网络结构,让小型的学生网络模仿大型教师网络的行为,以使得准确率尽可能高的同时,能够获得一个轻量化的网络。
  • 剪枝(Parameter pruning):删除不必要的网络参数,以减少模型的规模和计算复杂度。
  • 低秩分解(Low-rank factorization):将模型的参数矩阵分解为较低秩的小矩阵,以减少模型的复杂度和计算成本。
  • 参数共享(Parameter sharing):将多个层共用一组参数,以减少模型的参数数量。
  • 量化(Quantization):将模型的参数和运算转化为更小的数据类型,以减少内存占用和计算时间。

量化

模型量化(Quantization)是一种将浮点计算转化为定点计算的技术,例如从FP32降低至INT8,主要用于减少模型的计算强度、参数大小以及内存消耗,以提高模型在设备上的推理计算效率,但是也有可能会带来一定的精度损失。

模型量化精度损失的主要原因为量化-反量化(Quantization-Dequantization)过程中取整引起的误差。这里简单介绍一下量化的计算方法,以FP32到INT8的量化为例,量化的核心思想就是将浮点数区间的参数映射到INT8的离散区间中。
量化公式:
q = r s + Z q = \frac{r}{s} + Z q=sr+Z反量化公式:
r = S ( q − Z ) r = S(q-Z) r=S(qZ)其中, r r r 为FP32的浮点数(real value), q q q 为INT8的量化值(quantization value),
S S S Z Z Z 分别为缩放因子(Scale-factor)和零点(Zero-Point)。

量化最重要的便是确定 S S S Z Z Z 的值, S S S Z Z Z 的计算公式如下:
S = r m a x − r m i n q m a x − q m i n S = \frac{r_{max}-r_{min}}{q_{max}-q_{min}} S=qmaxqminrmaxrmin Z = − r m i n S + q m i n Z = -\frac{r_{min}}{S} + q_{min} Z=Srmin+qmin其中, r m a x r_{max} rmax r m i n r_{min} rmin 分别为FP32网络参数最大、最小值, q m a x q_{max} qmax q m i n q_{min} qmin 分别为INT8网络参数最大、最小值。

为了减少量化所带来的精度损失,学者提出了Quantization-Aware-Training(QAT)方法,再介绍此之前,由于Post-Training-Quantization(PTQ)方法也经常在文献中出现,此篇博客将着重介绍这两个方法的含义与区别。
在这里插入图片描述

Post-Training-Quantization(PTQ)

Post-Training-Quantization(PTQ)是目前常用的模型量化方法之一。以INT8量化为例,PTQ方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的模型;
  2. 使用小部分数据对FP32模型进行采样(Calibration),主要是为了得到网络各层参数的数据分布特性(比如统计最大最小值);
  3. 根据步骤2中的数据分布特性,计算出网络各层 S 和 Z 量化参数;
  4. 使用步骤3中的量化参数对FP32模型进行量化得到INT8模型,并将其部署至推理框架进行推理。

PTQ方法会使用小部分数据集来估计网络各层参数的数据分布,找到合适的S和Z的取值,从而一定程度上降低模型精度损失。然而,论文中指出PTQ方式虽然在大模型上效果较好(例如ResNet101),但是在小模型上经常会有较大的精度损失(例如MobileNet) 不同通道的输出范围相差可能会非常大(大于100x), 对异常值较为敏感。

Quantization-Aware-Training(QAT)

由上文可知PTQ方法中模型的训练和量化是分开的,而Quantization-Aware-Training(QAT)方法则是在模型训练时加入了伪量化节点,用于模拟模型量化时引起的误差,并通过微调使得模型在量化后尽可能减少精度损失。以INT8量化为例,QAT方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的FP32模型;
  2. 在FP32模型中插入伪量化节点,得到QAT模型,并且在数据集上对QAT模型进行微调(Fine-tuning);
  3. 同PTQ方法中的采样(Calibration),并计算量化参数 S 和 Z ;
  4. 使用步骤3中得到的量化参数对QAT模型进行量化得到INT8模型,并部署至推理框架中进行推理。

在PyTorch中,可以使用 torch.quantization.quantize_dynamic() 方法来执行 QAT。这是一个基本的 QAT 代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import quantize_dynamic, QuantStub, DeQuantStub

# 定义简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dequant(x)
        return x

# 数据加载
# 这里使用 MNIST 数据集作为示例
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform),
                                           batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=False, download=True, transform=transform),
                                          batch_size=64, shuffle=False)

# 定义损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 定义 QAT 训练函数
def train(model, train_loader, criterion, optimizer, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data.view(data.shape[0], -1))
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

# 训练模型
train(model, train_loader, criterion, optimizer, num_epochs=5)

# 在训练完成后执行动态量化
quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# 评估量化模型
def test(model, test_loader, criterion):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.view(data.shape[0], -1))
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = correct / total
    print(f'Accuracy of the network on the test images: {accuracy * 100:.2f}%')

# 测试量化模型
test(quantized_model, test_loader, criterion)

上述代码示例中,我使用了一个简单的全连接神经网络,并在训练完成后使用torch.quantization.quantize_dynamic()对模型进行动态量化。在量化之前,我们通过QuantStub()DeQuantStub()添加了量化和反量化的辅助模块。这个示例使用了MNIST数据集,你可以根据你的实际需求替换成其他数据集和模型。

参考文献

量化感知训练(Quantization-aware-training)探索-从原理到实践

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1200751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Proteus仿真】【51单片机】多路温度控制系统

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器,使用按键、LED、蜂鸣器、LCD1602、DS18B20温度传感器、HC05蓝牙模块等。 主要功能: 系统运行后,默认LCD1602显示前4路采集的温…

JavaScript_动态表格_删除功能

1、动态表格_删除功能 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>动态表格_添加和删除功能</title><style>table{border: 1px solid;margin: auto;width: 100%;}td,th{text-align: …

总结:利用JDK原生命令,制作可执行jar包与依赖jar包

总结&#xff1a;利用JDK原生命令&#xff0c;制作可执行jar包与依赖jar包 一什么是jar包&#xff1f;二制作jar包的工具&#xff1a;JDK原生自带的jar命令&#xff08;1&#xff09;jar命令注意事项&#xff1a;&#xff08;2&#xff09;jar包清单文件创建示例&#xff1a;&a…

20231112_DNS详解

DNS是实现域名与IP地址的映射。 1.映射图2.DNS查找顺序图3.DNS分类和地址4.如何清除缓存 1.映射图 图片来源于http://egonlin.com/。林海峰老师课件 2.DNS查找顺序图 3.DNS分类和地址 4.如何清除缓存

winform+access数据库增删查改报表导出demo源码

C#winformaccess数据库增删查改报表导出demo源码设备管理的一个简单程序使用access数据库增删查改导出报表功能 OleDbConnection conn new OleDbConnection("Data Source" System.Windows.Forms.Application.StartupPath "\\config\\cinfor.mdb;ProviderMicr…

Java图像编程之:Graphics

一、概念介绍 1、Java图像编程的核心类 Java图像编程的核心类包括&#xff1a; BufferedImage&#xff1a;用于表示图像的类&#xff0c;可以进行像素级的操作。Image&#xff1a;表示图像的抽象类&#xff0c;是所有图像类的基类。ImageIcon&#xff1a;用于显示图像的类&a…

.net在使用存储过程中IN参数的拼接方案,使用Join()方法

有时候拼接SQL语句时&#xff0c;可能会需要将list中的元素都加上单引号&#xff0c;并以逗号分开&#xff0c;但是Join只能简单的分开&#xff0c;没有有单引号&#xff01; 1.第一种拼接方案 List<string> arrIds new List<string>(); arrIds.Add("aa&qu…

微软近日限制员工访问ChatGPT!

作者 | 撒鸿宇 据CNBC报道&#xff0c;在这周四的短时间内&#xff0c;微软的员工被禁止使用ChatGPT。 微软在其内部网站的更新中表示&#xff1a;“由于安全和数据问题&#xff0c;一些AI工具不再对员工开放。”据CNBC查证&#xff0c;他们看到了一张截图&#xff0c;该截图显…

[Go语言]SSTI从0到1

[Go语言]SSTI从0到1 1.Go-web基础及示例2.参数处理3.模版引擎3.1 text/template3.2 SSTI 4.[LineCTF2022]gotm1.题目源码2.WP 1.Go-web基础及示例 package main import ("fmt""net/http" ) func sayHello(w http.ResponseWriter, r *http.Request) { // 定…

发布订阅者模式(观察者模式)

目录 应用场景 1.结构 2.效果 3.代码 3.1.Main方法的类【ObserverPatternExample】 3.2.主题&#xff08;接口&#xff09;【Subject】 3.3.观察者&#xff08;接口&#xff09;【Observer】 3.4.主题&#xff08;实现类&#xff09;【ConcreteSubject】 3.5.观察者&a…

[工业自动化-16]:西门子S7-15xxx编程 - 软件编程 - 西门子仿真软件PLCSIM

目录 前言&#xff1a; 一、PLCSIM仿真软件 1.1 PLCSIM仿真软件基础版&#xff08;内嵌&#xff09; 1.2 PLCSIM仿真软件与PLCSIM仿真软件高级版的区别&#xff1f; 1.3 PLCSIM使用 前言&#xff1a; PLC集成开发环境是运行在Host主机上&#xff0c;Host主机与PLC可以通过…

外星人笔记本键盘USB协议逆向

前言 我朋友一台 dell g16 购买时直接安装了linux系统&#xff0c;但是linux上没有官方的键盘控制中心&#xff0c;所以无法控制键盘灯光&#xff0c;于是我就想着能不能逆向一下键盘的协议&#xff0c;然后自己写一个控制键盘灯光的程序。我自己的外星人笔记本是m16&#xff…

基恩士软件的基本指令(二)

目录 基础指令 输入输出常开常闭指令 “A软元件名称--装入快捷键” “O软元件名称--输出快捷键” “ALT回车--连线快捷键” “B软元件--常闭接点” “软元件“/”--切换常开/常闭接点状态” 上升沿下降沿指令 “P-软元件回车--上升沿输入方法” “F-软元件回车--下降沿输入…

logback异步日志打印阻塞工作线程

前言 最新做项目&#xff0c;发现一些历史遗留问题&#xff0c;典型的是日志打印的配置问题&#xff0c;其实都是些简单问题&#xff0c;但是往往简单问题引起严重的事故&#xff0c;比如日志打印阻塞工作线程&#xff0c;以logback和log4j2为例。logback实际上是springboot的…

通过SD卡给某摄像头植入可控程序

0x01. 摄像头卡刷初体验 最近研究了手上一台摄像头的sd卡刷机功能&#xff0c;该摄像头只支持fat32格式的sd卡&#xff0c;所以需要先把sd卡格式化为fat32&#xff0c;另外微软把fat32限制了最大容量32G&#xff0c;所以也只能用不大于32G的sd卡来刷机。 这里使用32G的sd卡来…

flutter逆向 ACTF native app

前言 算了一下好长时间没打过CTF了,前两天看到ACTF逆向有道flutter逆向题就过来玩玩啦,花了一个下午做完了.说来也巧,我给DASCTF十月赛出的逆向题其中一道也是flutter,不过那题我难度降的相当之低啦,不知道有多少人做出来了呢~ 还原函数名 flutter逆向的一大难点就是不知道l…

RGMII回环:IDDR+ODDR+差分接口

目录 一、实验内容二、原理解释三、程序1、顶层文件&#xff1a;2、子模块2.1 oddr模块2.2、iddr顶层模块2.3、iddr子模块 3、仿真4、注意5、下载工程及仿真 一、实验内容 1、通过IDDR和ODDR的方式完成RGMII协议&#xff1b; 2、外部接口使用OBUFDS、IBUFDS转换成差分接口&…

C++语言的广泛应用领域

目录 1. 系统级编程 2. 游戏开发 3. 嵌入式系统 4. 大数据处理 5. 金融和量化分析 6. 人工智能和机器学习 7. 网络和通信 结语 C是一种多范式编程语言&#xff0c;具有高性能、中级抽象能力和面向对象的特性。由Bjarne Stroustrup于1979年首次设计并实现&#xff0c;C在…

如何确定线程栈的基址?

起 很早之前&#xff0c;我遇到过几个与栈相关的问题&#xff0c;当时总结过几篇关于线程栈的文章&#xff0c;分别是 《栈大小可以怎么改&#xff1f;》、《栈局部变量优化探究&#xff0c;意外发现了 vs 的一个 bug &#xff1f;》、《栈又溢出了》、《有趣的异常》。在这几…

【fast2021论文导读】 Learning Cache Replacement with Cacheus

文章:Learning Cache Replacement with Cacheus 导读摘要: 机器学习的最新进展为解决计算系统中的经典问题开辟了新的、有吸引力的方法。对于存储系统,缓存替换是一个这样的问题,因为它对性能有巨大的影响。 本文第一个贡献,确定了与缓存相关的特征,特别是,四种工作负载…