torch模型量化方法总结

news2025/1/16 13:53:49

0.概述

模型训练完成后的参数为float或double类型,而装机(比如车载)后推理预测时,通常都会预先定点(量化)为int类型参数,相应的推理的精度会有少量下降,但不构成明显性能下降,带来的结果是板端部署的可能性,推理的latency明显降低,本文对torch常用的量化方法进行总结作为记录。

1.模型量化的作用

量化是指将信号的连续取值近似为有限多个离散值的过程。可理解成一种信息压缩的方法。在计算机系统上考虑这个概念,一般用“低比特”来表示。也有人称量化为“定点化”,但是严格来讲所表示的范围是缩小的。定点化特指scale为2的幂次的线性量化,是一种更加实用的量化方法。

卷积神经网络具有很好的精度,甚至在一些任务上比如人脸识别、图像分类,已经超越了人类精度。但其缺点也比较明显,具有较大的参数量,计算量,以及内存占用。而模型量化可以缓解现有卷积神经网络参数量大、计算量大、内存占用多等问题,具有为神经网络压缩参数、提升速度、降低内存占用等“潜在”优势。为什么“潜在”是加引号的呢?因为想同时达到这三个特性并不容易,在实际应用过程中存在诸多限制和前提条件。

另外,由于模型量化是一种近似算法方法,精度损失是一个严峻的问题,大部分的研究都在关注这一问题。作为一个在公司支撑很多业务线的团队,我们会在关注精度的同时,注重部署最终的速度和资源占用情况。

1.1 压缩参数

1.2 提升速度

什么样的量化方法可以带来潜在、可落地的速度提升呢?我们总结需要满足两个条件:

1、量化数值的计算在部署硬件上的峰值性能更高 。

2、量化算法引入的额外计算(overhead)少 。

要准确理解上述条件,需要有一定的高性能计算基础知识,限于篇幅就不展开讨论了。现直接给出如下结论:已知提速概率较大的量化方法主要有如下三类,

1、二值化,其可以用简单的位运算来同时计算大量的数。对比从nvdia gpu到x86平台,1bit计算分别有5到128倍的理论性能提升。且其只会引入一个额外的量化操作,该操作可以享受到SIMD(单指令多数据流)的加速收益。

2、线性量化,又可细分为非对称,对称和ristretto几种。在nvdia gpu,x86和arm平台上,均支持8bit的计算,效率提升从1倍到16倍不等,其中tensor core甚至支持4bit计算,这也是非常有潜力的方向。由于线性量化引入的额外量化/反量化计算都是标准的向量操作,也可以使用SIMD进行加速,带来的额外计算耗时不大。

3、对数量化,一个比较特殊的量化方法。可以想象一下,两个同底的幂指数进行相乘,那么等价于其指数相加,降低了计算强度。同时加法也被转变为索引计算。但没有看到有在三大平台上实现对数量化的加速库,可能其实现的加速效果不明显。只有一些专用芯片上使用了对数量化。

1.3 降低内存

2. 量化的实现方法

pytorch有3种量化模式,包括Eager quantization mode、FX quantization mode以及PyTorch 2 Export Quantization(pytrch2.1新增),每种模式都支持多种量化方式,包括动态量化、静态量化以及量化感知训练。

2.1动态量化(Dynamic Quantization)

  • 概述: 动态量化是在推理过程中动态地量化激活(activations)。这种方法对权重进行静态量化,并在每次输入时对激活动态量化。这种方法主要应用于不易量化的模型,如包含 LSTM 的 RNN 模型。

  • 优点: 适用于无法在推理前获得输入数据分布的场景,且对模型的精度影响较小。

  • 缺点: 在每次推理时需要进行量化,可能会有一些计算开销。

  • 使用场景: NLP 模型(如 Transformer、BERT)中的全连接层。

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.fc1 = nn.Linear(1, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

model = MLPModel()

# 动态量化
model_quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

2.2 静态量化(Static Quantization)

  • 概述: 静态量化在推理前对模型的权重和激活进行量化。为了有效地进行量化,模型需要在校准数据集上运行,以估计激活的分布。

  • 优点: 可以带来显著的性能提升,适用于大部分 CNN 模型和传统的全连接网络。

  • 缺点: 需要校准数据,且量化后的精度可能会下降,特别是在小型数据集上。

  • 使用场景: 计算机视觉中的 CNN 模型,如 ResNet、MobileNet。

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(1, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.dequant(x)
        return x

model = MLPModel()

# 配置和准备模型
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# 模拟校准数据运行
model.eval()
with torch.no_grad():
    for _ in range(10):
        model(torch.randn(1, 1))

# 转换为量化模型
torch.quantization.convert(model, inplace=True)

2.3 量化感知训练(Quantization-Aware Training, QAT)

概述: QAT 在训练过程中模拟量化过程,使模型能够适应量化引入的噪声。这种方法在训练期间插入了量化和反量化操作。

优点: 精度损失最小,适用于对精度要求高的任务。

缺点: 训练时间增加,且需要重新训练模型。

使用场景: 高精度要求的任务,如语音识别、图像分类等。

import torch
import torch.nn as nn
import torch.quantization

# 定义模型
class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(1, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.dequant(x)
        return x

model = MLPModel()

# 配置 QAT 模型
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# 开始 QAT 训练
model.train()
for epoch in range(10):
    # 模拟训练过程
    inputs = torch.randn(1, 1)
    outputs = model(inputs)
    loss = torch.mean((outputs - inputs) ** 2)
    loss.backward()

# 转换为量化模型
torch.quantization.convert(model, inplace=True)

3. 量化的实现示例

3.1 实现代码

3.1.1 构建数据集与数据加载

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np

class SineDataset(Dataset):
    def __init__(self, n_samples=1000):
        self.x = np.linspace(0, 2 * np.pi, n_samples)
        self.y = np.sin(self.x)
        self.x = self.x.reshape(-1, 1).astype(np.float32)
        self.y = self.y.reshape(-1, 1).astype(np.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
        
 def dataset_loader():
    # 创建数据集和数据加载器
    dataset = SineDataset(n_samples=1000)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=100)
    
    return train_loader, val_loader

3.1.2 基础模型与量化模型

为了保持模型基础结构一致, QuantizationModel 继承于MLPModel

import torch.nn as nn
import pytorch_lightning as pl

class MLPModel(pl.LightningModule):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    
class QuantizationModel(MLPModel):
    def __init__(self):
        super(QuantizationModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x
        
#  模型保存
def save_normal_model(model_path: str):
    train_loader, val_loader = dataset_loader()
    model = MLPModel()
    trainer = pl.Trainer(max_epochs=100)
    trainer.fit(model, train_loader, val_loader)
    print("Normal model:\n",model)
    # model.to_torchscript().save()
    torch.save(model, model_path)
    return model

3.1.3 模型动态量化与保存


def save_dynamic_quantization(model, model_path: str):
    # model = MLPModel()
    # trainer = pl.Trainer(max_epochs=100)
    # train_loader, val_loader = dataset_loader()    
    # trainer.fit(model, train_loader, val_loader)    
    # 动态量化
    model_dynamic_quantized = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
    print('model_dynamic_quantized:\n',model_dynamic_quantized)
    # 保存整个动态量化后的模型
    torch.save(model_dynamic_quantized, model_path)    

3.1.4 模型静态量化与保存

def save_static_quantization(model_path: str):
    # 静态量化
    model = QuantizationModel()
    train_loader, val_loader = dataset_loader()
    trainer = pl.Trainer(max_epochs=100)
    trainer.fit(model, train_loader, val_loader)    
    # 定义量化配置
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # 准备模型进行量化
    model_prepared = torch.quantization.prepare(model)

    # 创建校准数据 
    # train_loader, val_loader = dataset_loader()
    # 对模型进行校准
    model_prepared.eval()
    with torch.no_grad():
        for data, _ in val_loader:
            model_prepared(data)   
    model_static_quantized = torch.quantization.convert(model_prepared)
             
    print("model_static_quantized:\n", model_static_quantized)
    # 保存整个静态量化后的模型
    torch.save(model_static_quantized, model_path)
    return model_static_quantized

 3.1.5模型qat量化与保存

def save_qat_quantization(model_path: str) -> None:
    qat_model = QuantizationModel()
    train_loader, val_loader = dataset_loader()
    trainer = pl.Trainer(max_epochs=100)
    trainer.fit(qat_model, train_loader, val_loader)        
    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(qat_model, inplace=True)
    
    trainer = pl.Trainer(max_epochs=100)
    train_loader, val_loader = dataset_loader()
    trainer.fit(qat_model, train_loader, val_loader)

    torch.quantization.convert(qat_model, inplace=True)
    print("qat_model:\n",qat_model)
    torch.save(qat_model, model_path)    

 3.1.6 模型加载与测试验证


import time
from functools import wraps
def timeit(message=""):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()  # 记录开始时间
            result = func(*args, **kwargs)  # 执行函数
            end_time = time.time()  # 记录结束时间
            elapsed_time = end_time - start_time  # 计算执行时间
            print(f"{message}: Function '{func.__name__}' executed in: {elapsed_time:.4f} seconds")
            return result
        return wrapper
    return decorator
def test_model(model_path, message="test_model"):
    @timeit(message)
    def inner_test(model_path):
        model = torch.load(model_path)
        # print(model)
        check_quantization_type(model)
        model.eval()    
        dataset = SineDataset(n_samples=1000000)
        x, y = dataset[:]    
        for i in range(10):
            # x = torch.tensor([0.1]).unsqueeze(1)
            with torch.no_grad():
                y_hat = model(torch.tensor(x))
    inner_test(model_path)

3.1.7  测试主流程

def save_model():
    model = save_normal_model("model/mlp_model.pt")
    save_dynamic_quantization(model, "model/mlp_model_dynamic_quantized.pt")
    save_static_quantization("model/mlp_model_static_quantized.pt")
    save_qat_quantization("model/mlp_model_qat_quantized.pt")

def test_models():
    # 测试原始模型
    test_model("model/mlp_model.pt", "Normal Model")
    # 测试动态量化模型
    test_model("model/mlp_model_dynamic_quantized.pt", "Dynamic Quantized Model")
    # 测试静态量化模型
    test_model("model/mlp_model_static_quantized.pt", "Static Quantized Model")
    # 测试QAT模型
    test_model("model/mlp_model_qat_quantized.pt", "Qat Quantized Model")    
    
def main():
    save_model()
    test_models()

3.2 效果对比

3.2.1 性能对比

性能/MSE

时间消耗

原始模型

0.05726085230708122

0.9883 seconds

动态量化

0.05744938179850578

1.0602 seconds

静态量化

0.05726085230708122

0.3772 seconds

QAT量化

0.01776209846138954

0.3706 seconds

3.2.2 推理结果对比

如图所示,精度有所损失

4.参考资料

一文搞懂模型量化算法

Quantization — PyTorch 2.4 documentation

https://arxiv.org/pdf/2205.07877

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2154555.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CO-锁存器(Latch)

1.描述 锁存器(Latch),是数字电路中的一种具有记忆功能的逻辑元件,是一种对脉冲电平敏感的存储单元电路,可以在特定输入脉冲电平作用下改变状态,利用电平控制数据的输入,包括不带使能控制的锁存器和带使能控制的锁存器…

sql执行流程经典案例分析

现在有联合索引(a,b),select* form tb where b xx group by a执行流程是什么样子的? CREATE TABLE IF NOT EXISTS test(id INT(10) NOT NULL AUTO_INCREMENT COMMENT主键,a INT(10) NULL,b INT(10) NULL,PRIMARY KEY(id),INDEX idx_a_b(a,b))ENGINE INNODB;INSERT INTO test…

【Unity-UGUI组件拓展】| Image 组件拓展,支持FIlled和Slice功能并存

🎬【Unity-UGUI组件拓展】| Image 组件拓展,支持FIlled和Slice功能并存一、组件介绍二、组件拓展方法三、完整代码💯总结🎬 博客主页:https://xiaoy.blog.csdn.net 🎥 本文由 呆呆敲代码的小Y 原创,首发于 CSDN🙉 🎄 学习专栏推荐:Unity系统学习专栏 🌲 游戏…

Linux:login shell和non-login shell以及其配置文件

相关阅读 Linuxhttps://blog.csdn.net/weixin_45791458/category_12234591.html?spm1001.2014.3001.5482 shell是Linux与外界交互的程序,登录shell有两种方式,login shell与non-login shell,它们的区别是读取的配置文件不同,本…

TypeScript入门 (三)数据类型

引言 大家好,我是GISer Liu😁,一名热爱AI技术的GIS开发者。本系列文章是我跟随DataWhale 2024年9月学习赛的TypeScript学习总结文档。本文旨在全面介绍 TypeScript 中的各种数据类型,帮助读者深入理解每种数据类型的用法、内置属性…

LabVIEW提高开发效率技巧----自动化测试和持续集成

在大型项目中,自动化测试和持续集成是提高开发效率和代码质量的关键手段。通过这些技术,开发者能够在开发的早期阶段快速发现问题,减少后期调试的工作量,并且能够确保代码的稳定性和可维护性。以下是这两个概念如何在LabVIEW开发中…

Docker Networking Tutorial (Bridge - None - Host - IPvlan - Macvlan )

In this article, We will talk about the network of docker. Therere have five types of docker network. 一、Bridge The default network of docker network type. You can use : docker network ls docker network create --driver bridge my_bridge_network ##The CID…

什么是 GPT?通过图形化的方式来理解 Transformer 架构

Predict, sample, repeat 预测、取样、重复 GPT 是 Generative Pre-trained Transformer 的缩写。首个单词较为直接,它们是用来生成新文本的机器人。“Pre-trained” 指的是模型经历了从大量数据中学习的过程,这个词暗示了该模型还有进一步在特定任务中…

移动技术开发:ListView水果列表

1 实验名称 ListView水果列表 2 实验目的 掌握自定义ListView控件的实现方法 3 实验源代码 布局文件代码&#xff1a; activity_main.xml: <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.androi…

Java 中Lock接口锁的使用

一. Lock接口下的实现类 在Java中&#xff0c;Lock 接口是 java.util.concurrent.locks 包中的一部分&#xff0c;它提供了比 synchronized 更丰富的锁操作。Lock 接口的实现类包括 ReentrantLock&#xff08;可重入锁&#xff09;、ReadWriteLock&#xff08;读写锁&#xff…

从零开始学习TinyWebServer

写在前面 项目参考&#xff1a;https://github.com/qinguoyi/TinyWebServer 写作框架/图参考&#xff1a;https://blog.csdn.net/qq_52313711/article/details/136356042?spm1001.2014.3001.5502 原本计划是&#xff0c;先将项目代码大概看一遍&#xff0c;然后再着手实现一下…

【hot100-java】【组合总和】

R8-回溯篇 印象题&#xff0c;很基本的回溯 class Solution {void backtrack(List<Integer> state,int target,int[] choices,int start,List<List<Integer>> ret){//子集和等于target&#xff0c;记录解if (target0){ret.add(new ArrayList<>(state)…

LeetCode讲解篇之1343. 大小为 K 且平均值大于等于阈值的子数组数目

文章目录 题目描述题解思路题解代码 题目描述 题解思路 题目让我们求长度为k的子数组并且该子数组的平均值大于threshold&#xff0c;对于这题&#xff0c;我们可以考虑维护一个长度为k的窗口&#xff0c;窗口不断向右滑动&#xff0c;遍历所有长度为k的子数组&#xff0c;我们…

低版本SqlSugar的where条件中使用可空类型报语法错误

SQLServer数据表中有两列可空列&#xff0c;均为数值类型&#xff0c;同时在数据库中录入测试数据&#xff0c;Age和Height列均部分有值。   使用SqlSugar的DbFirst功能生成数据库表类&#xff0c;其中Age、Height属性均为可空类型。   开始使用的SqlSugar版本较低&…

win11 wsl2安装ubuntu22最快捷方法

操作系统是win11&#xff0c;wsl版本是wsl2&#xff0c;wsl应该不用多介绍了&#xff0c;就是windows上的虚拟机&#xff0c;在wsl上可以很方便的运行Linux系统&#xff0c;性能棒棒的&#xff0c;而且wsl运行的系统和win11主机之间的文件移动是无缝的&#xff0c;就是两个系统…

力扣115-不同的子序列(Java详细题解)

题目链接&#xff1a;不同的子序列 前情提要&#xff1a; 因为本人最近都来刷dp类的题目所以该题就默认用dp方法来做。 dp五部曲。 1.确定dp数组和i下标的含义。 2.确定递推公式。 3.dp初始化。 4.确定dp的遍历顺序。 5.如果没有ac打印dp数组 利于debug。 每一个dp题目…

Spring IDEA 2024 安装Lombok插件

1.简介 Lombook插件的Data标签可以自动生成类的get和set以及toString方法。 2.安装步骤 在idead设置的插件中搜索lombok插件&#xff0c;安装。 在Spring项目的pom.xml中添加依赖项 <dependency><groupId>org.projectlombok</groupId><artifactId…

数据结构与算法——Java实现 7.习题——反转链表

当你穿过了暴风雨&#xff0c;你已不是原来那个人 —— 24.9.21 206. 反转链表 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输…

echarts标注legend的配置

代码&#xff1a; legend: [{top: bottom, //上下位置 top center bottom 还可以用百分比50%等orient: horizontal, // 竖立 vertical horizontal 水平的// right: 0, //靠右 还可以用百分比 50%等// left: 0,// 靠左 还可以用百分比 50%等// 左右位置 否则居中itemWidth: …

前端框架Vue、React、Angular、Svelte对比

在对比 React、Vue.js、Angular 和 Svelte 时&#xff0c;除了在高层次的特性上有显著差异&#xff0c;它们在核心设计理念和底层实现机制上也有明显的不同。为了清晰地理解这些框架&#xff0c;我们可以从以下几个方面来分析它们的核心不同点和底层不同点。 1. 框架类型和设计…