模型压缩——量化方法解读

news2025/1/11 4:11:06

1.引言

前面我们已经介绍了剪枝、蒸馏等通过减少模型参数量来进行压缩的方法。除这些方法以外,量化 (quantization) 是另一种能够压缩模型参数的方法。与前面方法不同的是,量化并不减少模型参数量,而是通过修改网络中每个参数占用的比特数,从而减少模型参数占用的空间。

使用量化压缩模型的基本理论在于:计算机中不同数据类型的加法和乘法操作耗时是不同的

在这里插入图片描述

上图可以看到,8位整型的加法运算比32位浮点型的加法运算速度要快30倍,8位整型的乘法运算比32位浮点型的乘法运算要快18倍。

模型量化有以下几个好处:

  • 减小模型大小:如 int8 量化可减少 75% 的模型大小,int8 量化模型大小一般为 32 位浮点模型大小的 1/4。
  • 加快推理速度:浮点型可以访问四次 int8 整型,整型运算比浮点型运算更快;CPU 用 int8 计算的速度更快。
  • 适配边缘计算设备:一些IoT和智能手机上的微处理器通常是低功耗的,硬件加速器NPU只支持int8,通常需要8位量化才能加速。

2.浮点运算和定点运算

整个模型的量化是围绕数据类型转换展开的,因此,了解数据类型是学习模型量化的基础。

与模型量化相关的所有数据类型可以分为三类:整数定点数浮点数,整数比较简单,我们这里重点介绍定点数和浮点数。

2.1 定点数

定点数的小数点位置固定,所有数字的小数位数相同,因此可以直接进行按位左移操作,变成整数直接进行加减运算。

定点数运算可以视为在整数运算的基础上加了小数点位置的管理,它的运算规则与整数运算规则完全相同,因此定点数运算也具备与整数一样高效的运算速度。
在这里插入图片描述

定点数优点:运算速度较快,适合于资源受限的系统,例如:嵌入式系统;缺点:只能表示一定范围内的数字,超出范围则表示准确表示。

2.2 浮点数

浮点数通常采用科学计数法表示,由符号位指数位尾数三部分组成,浮点数在计算时需要同时计算指数和尾数,并且需要对结果进行归一化。

举例:12300用浮点数可以表示为1.23 x 10^4, 在这个浮点数中,符号为0表示正数,尾数为1.23,指数为4表示尾数1.23需要向右移动4位。

常见的浮点数汇总如下图所示:

在这里插入图片描述

  • fp32:指数8位,尾数23位;
  • fp16:指数5位,尾数11位,是16位浮点数的标准表示,所有硬件设备都支持;
  • bf16:指数8位,尾数7位,是16位浮点数的优化表示,相比fp16扩大了数值可表示范围,降低了数值下溢和上溢风险,但并不是所有硬件设备都支持;
  • fp8(E4M3):指数4位,尾数3位,精度更高;
  • fp8(E5M2):指数5位,尾数2位,数值范围更大;

与定点数相比,浮点数可以表示的数据范围大得多,但是浮点运算比定点运算涉及更多步骤,导致计算速度较慢。

3.量化方法

量化基本方法有两种:基于k-means的量化线性量化

3.1 K-means量化

基于 k-means 的量化(K-means-based Quantization):存储方式为整型权重 + 浮点型的转换表,计算方式为浮点计算。

每个权重的位置只需要存储聚类的索引值。将权重聚类成4类(0,1,2,3),就可以实现2-bit的压缩。 存储占用从 32bit x 16 = 512 bit = 64 B => 2bit x 16 + 32 bit x 4 = 32 bit + 128 bit = 160 bit = 20 B,如下图所示:

在这里插入图片描述

注:推理时,我们读取转换表,根据索引值获取对应的值。训练时,我们将gradient按照weights的聚类方式进行聚类相加,反向传播到转换表,更新转换表的值。

我们以下面这个weights权重的数据为例,来解读k-means的运算过程。

import random
from fast_pytorch_kmeans import KMeans
from collections import namedtuple
import torch

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

bitwidth = 2   # 量化后的位数
weights = torch.Tensor([
    [2.09, 0.98, 1.48, 0.09],
    [0.05, -0.14, -1.08, 2.12], 
    [-0.91, 1.92, 0, -1.03], 
    [1.87, 0, 1.53, 1.49]])
第一步,确定聚类中心的数量,并以聚类数量准备kmeans算法。
> 注:聚类数量能告诉我们需要将数据分成多少个不同的类别,1个二进制位可以表示两个类别,2个二进制位可以表示2*2=4个类别,以此类推。
n_clusters = 2**bitwidth
kmeans = KMeans(n_clusters=n_clusters, mode='euclidean', verbose=0)
n_clusters
4

注:mode='euclidean’表示以欧几里得距离来进行聚类,对于一维空间,欧几里得距离就是两个点数值差的绝对值,即:$ d = |x_2 - x_1| $ 。

第二步,调整张量形状,让每个元素单独占一行,(4,4)的张量转换成(16,1), 以适应Kmeans算法的输入要求。

vectors = weights.view(-1, 1)
vectors
tensor([[ 2.0900],
        [ 0.9800],
        [ 1.4800],
        [ 0.0900],
        [ 0.0500],
        [-0.1400],
        [-1.0800],
        [ 2.1200],
        [-0.9100],
        [ 1.9200],
        [ 0.0000],
        [-1.0300],
        [ 1.8700],
        [ 0.0000],
        [ 1.5300],
        [ 1.4900]])

第三步,使用fit_predict方法将所有元素分配到聚类中,并返回每个元素的标签,表示它属于哪个聚类。

labels = kmeans.fit_predict(vectors).to(torch.long)
labels
tensor([0, 3, 3, 1, 1, 1, 2, 0, 2, 0, 1, 2, 0, 1, 3, 3])

在上面执行聚类的过程中,会自动计算出聚类中心(可以视为每个聚类的代表),我们可以访问kmeans.centroids直接获取并打印它的结果。

centroids = kmeans.centroids.to(torch.float).view(-1)
centroids = torch.from_numpy(np.around(centroids.numpy(), decimals=4))
centroids
tensor([ 2.0000,  0.0000, -1.0067,  1.3700])

聚类的过程是一个循环迭代不断寻找最优聚类中心的过程,初始的聚类中心是随机选择的,算法会计算数据点与每个聚类中心的距离,目的是将数据点分配到距离其最近的聚类中心。
当所有数据点都被分配后,算法会重新计算每个聚类中心点(可以理解为该聚类内所有点的均值),并基于新的聚类中心对所有数据点再重新分配,直到聚类中心不再变化。

使用前面得到的labelscentroids来量化张量。

由于每个标签(label)都指向它对应的聚类中心(centroids),因此,量化操作实际上是将所有元素替换为它对应标签的聚类中心,例如:2.09 替换为2.00。

quantized = centroids[labels]
quantized
tensor([ 2.0000,  1.3700,  1.3700,  0.0000,  0.0000,  0.0000, -1.0067,  2.0000,
        -1.0067,  2.0000,  0.0000, -1.0067,  2.0000,  0.0000,  1.3700,  1.3700])

重塑最终结果,将量化后的张量调整回与原始weights相同的形状。

quantized_weights = quantized.view_as(weights)
quantized_weights
tensor([[ 2.0000,  1.3700,  1.3700,  0.0000],
        [ 0.0000,  0.0000, -1.0067,  2.0000],
        [-1.0067,  2.0000,  0.0000, -1.0067],
        [ 2.0000,  0.0000,  1.3700,  1.3700]])

上面的过程封装为一个函数,方便调用。

def k_means_quantize(tensor: torch.Tensor, bitwidth=4):
    n_clusters = 2**bitwidth
    kmeans = KMeans(n_clusters=n_clusters, mode='euclidean', max_iter=100, tol=1e-4, verbose=0)
    
    vectors = tensor.view(-1, 1)
    labels = kmeans.fit_predict(vectors).to(torch.long)
    centroids = kmeans.centroids.to(torch.float).view(-1)
    centroids = torch.from_numpy(np.around(centroids.numpy(), decimals=4))
    
    quantized = centroids[labels].view_as(tensor)
    return quantized

quantized = k_means_quantize(weights, 2)
quantized
tensor([[ 2.0000,  1.3700,  1.3700,  0.0000],
        [ 0.0000,  0.0000, -1.0067,  2.0000],
        [-1.0067,  2.0000,  0.0000, -1.0067],
        [ 2.0000,  0.0000,  1.3700,  1.3700]])

这就是基于K-means进行量化的基本过程,下面我们再来看下线性量化。

3.2 线性量化

线性量化(也称为标准的8-bit量化)是将原始浮点数据和量化后的定点数据之间建立一个简单的线性变换关系,这样做可以减少存储和计算的复杂度,同时尽量保持原始数值的信息。

它的基本思想如下图示意:

在这里插入图片描述

实现步骤如下:

  • 确定量化区间:找到数据中的最小值rmin和最大值rmax,确定量化范围。
  • 计算缩放因子(scale)和偏移(z):用来将浮点数值映射到定点数值。
  • 映射值:将浮点值通过缩放因子和偏移映射到量化值。

我们依然以一组数据为例来介绍,假设有一组浮点数值[3.2, 4.5, 5.1, 6.0],我们想将它们量化到 8-bit 定点数(范围是 0 到 255)。

第一步:确定量化区间

import numpy as np

data = np.array([3.2, 4.5, 5.1, 6.0])
min_val = data.min()
max_val = data.max()
print(f"Min value: {min_val}, Max value: {max_val}")
Min value: 3.2, Max value: 6.0

第二步:计算缩放因子
将浮点数据映射到定点数值范围(例如 8-bit 范围 0 到 255)时,需要一个缩放因子(可以理解为地图缩放时的比例尺)来进行等比例缩放,以确保数据能在期望的范围内均匀分布。

缩放因子的计算方法是将浮点数值范围除以目标数值范围,如下代码所示。

qmin, qmax = 0, 255
scale = (max_val - min_val) / (qmax - qmin)
print("scale:", scale)
scale: 0.010980392156862744

第三步:计算偏移
我们希望将浮点数值映射到定点数值时,能够尽量保持数据的分布和关系,偏移(也称为零点)的作用是确保浮点数值的某个基准点(通常是最小值或零)能够准确地映射到定点数值中的某个点。

偏移是由于缩放因子和浮点数值中的最小值来确定的。

zero_point = qmin - min_val/scale
print("zero_point: ", zero_point)
zero_point:  -291.4285714285715

注:min_val/scale计算得到浮点数最小值缩放后的值min_val_scaled,量化最小值qmin与min_val_scaled的差值就是每个浮点数缩放后需要加的一个偏移量。

第四步:映射值
将浮点数映射到定点数,计算方法是将每个浮点数按照比例因子进行缩放,再加上偏移。

quantized_data = np.round(data/scale + zero_point)
quantized_data
array([  0., 118., 173., 255.])

上面就是线性量化的一个简单实现思路。

线性量化和k-means量化是目前两种主要的量化方法,具体到实际应用中,可能线性量化应用的更为普遍。其原因在于以几下点:

  • 简单性:线性量化通常将浮点数映射到固定的整数范围(如8-bit),通过简单的线性变换。这种简单性使得硬件实现非常容易。
  • 标准化:8-bit线性量化已经成为工业标准,许多硬件加速器和处理器(如ARM Cortex-M系列、NVIDIA TensorRT等)都专门优化了对8-bit整数操作的支持。
  • 性能:线性量化的计算效率高,对于乘加操作来说,整数运算比浮点运算快且能耗低。

而k-means量化由于涉及对权重进行聚类,这需要查找和映射操作,硬件实现起来更复杂,效率低于简单的线性变换。

小结:本文介绍了模型量化的基本原理,通过将高精度浮点数转换为低精度定点数,来降低模型中每个参数占用的比特数,同时借助定点数的快速运算,来提高模型整体的推理速度。此外,还介绍了线性量化和k-means两种量化方法,其中线性量化应用更普遍,其原因是实现简单、硬件友好且能够在大多数场景下提供足够的精度和性能提升。

参考阅读

  • 如何进行模型剪枝?
  • 如何进行知识蒸馏?
  • 如何进行神经网络架构搜索?
  • 什么是量化?
  • k-means算法介绍

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2250289.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Core 授权 认证 案例

利用 cookie 模式 》》 框架默认的 利用 cookie 模式 》》 策略授权

计算机网络常见面试题总结(上)

计算机网络基础 网络分层模型 OSI 七层模型是什么?每一层的作用是什么? OSI 七层模型 是国际标准化组织提出的一个网络分层模型,其大体结构以及每一层提供的功能如下图所示: 每一层都专注做一件事情,并且每一层都需…

Macos远程连接Linux桌面教程;Ubuntu配置远程桌面;Mac端远程登陆Linux桌面;可能出现的问题

文章目录 1. Ubuntu配置远程桌面2. Mac端远程登陆Linux桌面3. 可能出现的问题1.您用来登录计算机的密码与登录密钥环里的密码不再匹配2. 找不到org->gnome->desktop->remote-access 1. Ubuntu配置远程桌面 打开设置->共享->屏幕共享。勾选允许连接控制屏幕&…

【C语言】结构体、联合体、枚举类型的字节大小详解

在C语言中,结构体(struct)和联合体(union) 是常用的复合数据类型,它们的内存布局和字节大小直接影响程序的性能和内存使用。下面为大家详细解释它们的字节大小计算方法,包括对齐规则、内存分配方…

免交互运用

免交互的概念 文本免交互 免交互的格式 变量配置 expect expect的格式 在脚本外传参 嵌套 练习 免交互ssh远程连接

物联网客户端在线服务中心(客服功能/私聊/群聊/下发指令等功能)

一、界面 私聊功能(下发通知类,一对多)群聊(点对点)发送指令(配合使用客户端,基于cefsharp做的物联网浏览器客户端)修改远程参数配置(直接保存到本地)&#…

使用C#开发VTK笔记(一)-开发环境搭建

一.使用C#开发VTK的背景 因为C#开发的友好性,一直都比较习惯于从C#开发程序。而长期以来,都希望有一个稳定可靠的三位工程数模的开发演示平台,经过多次对比之后,感觉VTK和OpenCasCade这两个开源项目是比较好的,但它们都是用C++编写的,我用C#形式开发,只能找到发布的C#组…

力扣96:不同的二叉搜索树

给你一个整数 n ,求恰由 n 个节点组成且节点值从 1 到 n 互不相同的 二叉搜索树 有多少种?返回满足题意的二叉搜索树的种数。 示例 1: 输入:n 3 输出:5示例 2: 输入:n 1 输出:1 卡…

k8s Init:ImagePullBackOff 的解决方法

kubectl describe po (pod名字) -n kube-system 可查看pod所在的节点信息 例如&#xff1a; kubectl describe po calico-node-2lcxx -n kube-system 执行拉取前先把用到的节点的源换了 sudo mkdir -p /etc/docker sudo tee /etc/docker/daemon.json <<-EOF {"re…

人工智能如何改变你的生活?

在我们所处的这个快节奏的世界里&#xff0c;科技融入日常生活已然成为司空见惯的事&#xff0c;并且切实成为了我们生活的一部分。在这场科技变革中&#xff0c;最具变革性的角色之一便是人工智能&#xff08;AI&#xff09;。从我们清晨醒来直至夜晚入睡&#xff0c;人工智能…

道路机器人识别交通灯,马路,左右转,黄线,人行道,机器人等路面导航标志识别-使用YOLO标记

数据集分割 train组66% 268图片 validation集22% 91图片 test集12&#xff05; 48图片 预处理 没有采用任何预处理步骤。 增强 未应用任何增强。 数据集图片&#xff1a; 交通灯 马路 右转 向右掉头 机器人识别 人行横道 黄线 直行或右转 数据集下载&#xff1a; 道路…

【四轴】利用PWM捕获解析接收机信号

在学习这部分之间&#xff0c;建议大家先看之前这篇博客&#xff0c;里面包含对PWM一些重要概念的基本介绍。 【四轴】利用PWM输出驱动无刷电机-CSDN博客 1. 基本原理 1.1 PWM是什么 这一部分可以看我之前的博客&#xff0c;已经对PWM有了基本的介绍。 1.2 什么叫捕获PWM波&…

洛谷 P1162 填涂颜色 C语言 bfs

题目&#xff1a; https://www.luogu.com.cn/problem/P1162 由数字 0 组成的方阵中&#xff0c;有一任意形状的由数字 1 构成的闭合圈。现要求把闭合圈内的所有空间都填写成 22。例如&#xff1a;66的方阵&#xff08;n6&#xff09;&#xff0c;涂色前和涂色后的方阵如下&am…

38 基于单片机的宠物喂食(ESP8266、红外、电机)

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于STC89C52单片机&#xff0c;采用L298N驱动连接P2.3和P2.4口进行电机驱动&#xff0c; 然后串口连接P3.0和P3.1模拟ESP8266&#xff0c; 红外传感器连接ADC0832数模转换器连接单片机的P1.0~P1.…

霍夫变换:原理剖析与 OpenCV 应用实例

简介&#xff1a;本文围绕霍夫变换相关内容展开&#xff0c;先是讲解霍夫变换基本原理&#xff0c;包含从 xy 坐标系到 kb 坐标系及极坐标系的映射等。接着介绍了 cv2.HoughLines、cv2.HoughLinesP 概率霍夫变换、cv2.HoughCircles 霍夫圆变换的函数用法、参数含义、与常规霍夫…

【Debug】hexo-github令牌认证 Support for password authentication was removed

title: 【Debug】hexo-github令牌认证 date: 2024-07-19 14:40:54 categories: bug解决日记 description: “Support for password authentication was removed on August 13, 2021.” cover: https://pic.imgdb.cn/item/669b38ebd9c307b7e9f3e5e0.jpg 第一章 第一篇博客记录一…

JVM 性能调优 -- JVM常用调优工具【jps、jstack、jmap、jstats 命令】

前言&#xff1a; 前面我们分析怎么去预估系统资源&#xff0c;怎么去设置 JVM 参数以及怎么去看 GC 日志&#xff0c;本篇我们分享一些常用的 JVM 调优工具&#xff0c;我们在进行 JVM 调优的时候&#xff0c;通常需要借助一些工具来对系统的进行相关分析&#xff0c;从而确定…

net9 abp vnext 多语言通过数据库动态管理

通过数据库加载实现动态管理&#xff0c;用户可以自己修改界面显示的文本&#xff0c;满足国际化需求 如图所示,前端使用tdesign vnext 新建表TSYS_Localization与TSYS_LocalizationDetail 国旗图标下载网址flag-icons: Free Country Flags in SVG 在Shared下创建下图3个文件 …

Vue:使用 KeepAlive 缓存切换掉的 component

一、内置特殊元素 不是组件 <component>、<slot> 和 <template> 具有类似组件的特性&#xff0c;也是模板语法的一部分。但它们并非真正的组件&#xff0c;同时在模板编译期间会被编译掉。因此&#xff0c;它们通常在模板中用小写字母书写。 1.1 <compone…

Spring中每次访问数据库都要创建SqlSession吗?

一、SqlSession是什么二、源码分析1&#xff09;mybatis获取Mapper流程2&#xff09;Spring创建Mapper接口的代理对象流程3&#xff09;MapperFactoryBean#getObject调用时机4&#xff09;SqlSessionTemplate创建流程5&#xff09;SqlSessionInterceptor拦截逻辑6&#xff09;开…