模型量化(二)—— 训练后量化PTQ(全代码)

news2025/2/23 7:09:38

训练后量化(Post-training Quantization,PTQ)是一种常见的模型量化技术,它在模型训练完成之后应用,旨在减少模型的大小和提高推理速度,同时尽量保持模型的性能。训练后量化对于部署到资源受限的设备上,如移动设备和嵌入式设备,特别有用。

在我们量化时,量化操作可以应用于模型的输入、权重 和 激活(即神经元输出值)上。

但我们发现,对于激活值,我们执行反量化时,并不知道这些激活值对应的浮点数矩阵的最大值和最小值,即我们执行非对称或对称量化里面的 𝛼, β 参数,所以我们拿到一个模型时,最多只能对它的权重W和输入X做量化,对于激活值Y的反量化,我们需要一组小的calibration set数据来初步计算对于Y的S和Z参数。

不熟悉非对称或对称量化的朋友可以康康这篇:《模型量化(一)—— 非对称量化、对称量化(全代码)》

在这里插入图片描述
 

目录

  • PTQ流程:
  • 全代码
    • 预训练模型
    • 加入Observer
    • 校准模型
      • 量化模型

 

PTQ流程:

在这里插入图片描述
Observer,顾名思义就是模型在正常inference的时候会被记录下正常的浮点激活值,用来算激活值对应的S和Z参数。

Calibrate后模型的W和Y都有对应的S和Z了,模型名义上量化完成。浮点的输入X也能off-line地实时算它对应的S和Z。

所以量化后的模型运行时,先对浮点输入进行量化,然后与整型的W矩阵相乘,得到整型的激活值,这时再反量化为浮点激活值,对应于下一个神经元的浮点输入,依次循环。
大家可能会想吗,这么麻烦,又是量化又是反量化,怎么还会压缩模型和加速模型呢?

压缩模型:原本所有的W都是浮点数存储,比如float32,现在转换为int8存储,模型尺寸减了大概4倍;再额外存一些神经元或网络层的S和Z参数(取决于量化的粗粒度),相对于W来说占内存很小(如果是很细粒度的量化可能这部分也得好好考虑,量化的粒度分为权重级量化、层级量化、通道级量化等)。

加速模型:主要的收益是使得模型中占大头的 W * X 操作变成了整型相乘,功耗和时延最低(浮点数相乘时功耗和时延最大)。3 * 100 * 100 * 10的全连接网络中,有213个神经元,但是有 3 * 100 * 100 * 10 = 300M个参数!这还是忽略了bias。量化相当于就是让这 300M 次乘法更轻量。而相对的 overhead 就是对开头的3个输入进行一下量化 和 对210和神经元的输出进行一下反量化,这部分开销随着网络层数与参数的增加几乎可以忽略不计。
一些专门的深度学习加速器和现代CPU/GPU提供了对低位宽整数(如int8)的优化支持,用这些硬件后可以更加体现模型量化的优势。

量化会带来一定的量化误差,即模型精度会受影响,这肯定的,但按经验来说几乎没什么影响,不要压到int4或int2这么极限就行。

 

全代码

预训练模型

import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

# Make torch deterministic
_ = torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cpu"


# Define the model
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = VerySimpleNet().to(device)


# Train the model
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)


# Define the testing loop
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')


# Print weights and size of the model before quantization

# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

print('Size of the model before quantization')
print_size_of_model(net)

print(f'Accuracy of the model before quantization: ')
test(net)

加入Observer

# Insert min-max observers in the model

class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

net_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from unquantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
net_quantized

校准模型

#用测试集再跑一次装了observer的模型
test(net_quantized)

print(f'Check statistics of the various layers')
net_quantized

在这里插入图片描述
这时看到激活层的 𝛼, β 都有了,good!

量化模型

# Quantize the model using the statistics collected

net_quantized = torch.ao.quantization.convert(net_quantized)

print(f'Check statistics of the various layers')
net_quantized

在这里插入图片描述

# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))


# Compare the dequantized weights and the original weights
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))
print('')

# Print size and accuracy of the quantized model
print('Size of the model after quantization')
print_size_of_model(net_quantized)
print('Testing the model after quantization')
test(net_quantized)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1513848.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【阿里云系列】-利用yaml文件部署NacosXxl-job到ACK

背景介绍 随着容器化的技术成熟落地,拥抱各种成熟的容器化集群平台是加速我们落地的必然之路,目前国内以阿里云、华为云、腾讯云为平台的供应商为主,国外则以AWS,Azure为主,让我们借助平台已有的优势进行快速落地提高…

指针【理论知识速成】(3)

一.指针的使用和传值调用&#xff1a; 在了解指针的传址调用前&#xff0c;先来额外了解一下 “传值调用” 1.传值调用&#xff1a; 对于来看这个帖子的你相信代码展示胜过千言万语 #include <stdio.h> #include<assert.h> int convert(int a, int b) {int c 0…

log4j2.xml介绍和使用

log4j2.xml是什么 log4j2.xml 是用于配置 Apache Log4j 2 的 XML 格式配置文件。Log4j 2 是一个用于 Java 应用的流行日志框架&#xff0c;提供灵活的日志管理和配置。在 log4j2.xml 文件中&#xff0c;可以配置日志记录的格式、级别、目的地等。 下面是一些主要节点和属性的…

内容管理平台原来对企业这么重要,看完收藏!

“内容为王”&#xff0c;这是当今数字化时代的一个重要真理。不论是创业新贵、还是行业巨头&#xff0c;纷纷开始深入理解和应用内容管理平台&#xff08;Content Management System&#xff0c;简称CMS&#xff09;&#xff0c;以便更好的管理其大量的内容和信息。 那么&…

网络安全从业人员何去何从

从2024年1月1日开始到今天&#xff0c;基本没有真正放下自己休息过一天。可能很多人会说是卷&#xff0c;其实真正的原因是压力。不仅仅是生活压力还有行业压力。 今年这个行业让很多人开始感到了迷茫&#xff0c;不仅是股市的低迷&#xff0c;更多的来自于各大公司不断的因为…

什么是架构?架构设计原则是哪些?什么是设计模式?设计模式有哪些?

什么是架构?架构设计原则是哪些?什么是设计模式?设计模式有哪些? 架构的本质 架构本身是一种抽象的、来自建筑学的体系结构,其在企业及IT系统中被广泛应用。 架构的本质是对事物复杂性的管理,是对一个企业、一个公司、一个系统复杂的内部关系进行结构化、体系化的抽象,…

Stable-Diffusion的WebUI部署实战

1、环境准备及安装 1.1、linux环境 # 首先&#xff0c;已经预先安装好了anaconda&#xff0c;在这里新建一个环境 conda create -n sdwebui python3.10 # 安装完毕后&#xff0c;激活该环境 conda activate sdwebui# 安装 # 下载stable-diffusion-webui代码 apt install wget…

String 底层是如何实现的?

1、典型回答 String 底层是基于数组实现的&#xff0c;并且数组使用了 final 修饰&#xff0c;不同版本中的数组类型也是不同的&#xff1a; JDK9 之前&#xff08;不含JDK9&#xff09; String 类是使用 char[ ]&#xff08;字符数组&#xff09;实现的但 JDK9 之后&#xf…

C#版开源免费的Bouncy Castle密码库

前言 今天大姚给大家分享一款C#版开源、免费的Bouncy Castle密码库&#xff1a;BouncyCastle。 项目介绍 BouncyCastle是一款C#版开源、免费的Bouncy Castle密码库&#xff0c;开发人员可以通过该项目在他们的 C# 应用程序中使用 Bouncy Castle 提供的各种密码学功能&#x…

如何使用 Langchain、Ollama 和 Streamlit 构建 RAG

一、先决条件&#xff1a;您需要了解什么 在深入讨论技术细节之前&#xff0c;我们先概述一下先决条件。Python 的基础知识至关重要&#xff0c;因为它是我们将使用的主要语言。熟悉机器学习和自然语言处理的基本概念将帮助您更轻松地掌握这些概念。此外&#xff0c;对 Langch…

瑞熙贝通实验室物联网管理平台新升级|支持远程开门视频监控与电源控制以及环境监测

瑞熙贝通实验室智能物联网管控平台&#xff1a;利用“互联网与物联网技术”有机融合&#xff0c;对实验室的用电安全监测、实验室环境异常监测&#xff08;颗粒物监测、明火监测、可燃气体、烟雾监测、温湿度传感器、红外人体感应&#xff09;、实验室人员安全准入、万物互联等…

16、技巧之九: 修改参数,如何让表格翻页滚动到底部?【Selenium+Python3网页自动化总结】

1、问题提出 在网页配置参数时&#xff0c;输入参数名称搜索&#xff0c;搜出来的同名参数结果有多个&#xff0c;分布在一个表格的不同行&#xff0c;表格是动态加载的&#xff0c;需要滚动鼠标才能把所出参数找出来。用selenium怎么实现这种参数修改&#xff1f; 2、网页元素…

数字工厂管理系统和ERP管理系统有什么区别

在制造业的数字化转型浪潮中&#xff0c;数字工厂管理系统和ERP管理系统作为两大核心系统&#xff0c;扮演者不可或缺的角色。虽然它们都是为了提高企业的运营效率和降低成本&#xff0c;但在功能与实施效果方面&#xff0c;二者却有着显著的区别。本文将从这两个方面对数字工厂…

Pytorch实战01——CIAR10数据集

目录 1、model.py文件 &#xff08;预训练的模型&#xff09; 2、train.py文件&#xff08;会产生训练好的.th文件&#xff09; 3、predict.py文件&#xff08;预测文件&#xff09; 4、结果展示&#xff1a; 1、model.py文件 &#xff08;预训练的模型&#xff09; impor…

day57 动态规划part17● 647. 回文子串 ● 516.最长回文子序列● 动态规划总结篇

如果大家做了很多这种子序列相关的题目&#xff0c;在定义dp数组的时候 很自然就会想题目求什么&#xff0c;我们就如何定义dp数组。 布尔类型的dp[i][j]&#xff1a;表示区间范围[i,j] &#xff08;注意是左闭右闭&#xff09;的子串是否是回文子串&#xff0c;如果是dp[i][j…

C++学习路线

C学习路线思维导图&#xff0c;肝了一个星期终于搞定&#xff0c;这么硬核求个赞不过分吧&#xff1f; 思维导图的内容&#xff0c;也是本文的内容框架&#xff0c;坐稳扶好&#xff0c; C 高速快车要发车了&#xff01; 内容我会持续更新&#xff0c;点赞收藏&#xff0c;…

Window系统下Vscode配置C/Cpp运行+调试环境

Window系统下Vscode配置C/Cpp运行调试环境 文章目录 Window系统下Vscode配置C/Cpp运行调试环境1.安装Vscode2.安装C/Cpp插件3.配置gcc编译器4.配置Cpp运行环境5.配置Cpp调试环境 1.安装Vscode 安装VScode很简单&#xff0c;直接到官网进行下载&#xff0c;然后傻瓜安装即可。 …

读CDO代码

前置任务 module PIL.Image has no attribute ANTIALIASImage.ANTIALIAS 替换为 Image.LANCZOS&#xff0c;参考https://blog.csdn.net/fovever_/article/details/134690657 OSError: science is not a valid package style, path of style file, URL of这是对应可视化里面生…

GraphView实现测量工具

效果演示&#xff1a; 主模块代码&#xff1a; MeasureGraphView::MeasureGraphView(QWidget *parent): QWidget(parent) {ui.setupUi(this);m_measureType NoType;m_bll BllData::getInstance();m_scene new GraphicsScene;connect(m_bll, &BllData::pressMeasurePos…

力扣106 从中序与后续遍历序列构造二叉树

文章目录 题目描述解题思路代码 题目描述 给定两个整数数组 inorder 和 postorder &#xff0c;其中 inorder 是二叉树的中序遍历&#xff0c; postorder 是同一棵树的后序遍历&#xff0c;请你构造并返回这颗 二叉树 。 示例 1: 输入&#xff1a;inorder [9,3,15,20,7], …