NLP(六十七)BERT模型训练后动态量化(PTDQ)

news2025/1/23 17:45:40

  本文将会介绍BERT模型训练后动态量化(Post Training Dynamic Quantization,PTDQ)。

量化

  在深度学习中,量化(Quantization)指的是使用更少的bit来存储原本以浮点数存储的tensor,以及使用更少的bit来完成原本以浮点数完成的计算。这么做的好处主要有如下几点:

  • 更少的模型体积,接近4倍的减少
  • 可以更快地计算,由于更少的内存访问和更快的int8计算,可以快2~4倍

  PyTorch中的模型参数默认以FP32精度储存。对于量化后的模型,其部分或者全部的tensor操作会使用int类型来计算,而不是使用量化之前的float类型。当然,量化还需要底层硬件支持,x86 CPU(支持AVX2)、ARM CPU、Google TPU、Nvidia Volta/Turing/Ampere、Qualcomm DSP这些主流硬件都对量化提供了支持。

模型量化示例图片

PTDQ

  PyTorch对量化的支持目前有如下三种方式:

  • Post Training Dynamic Quantization:模型训练完毕后的动态量化
  • Post Training Static Quantization:模型训练完毕后的静态量化
  • QAT (Quantization Aware Training):模型训练中开启量化

  本文仅介绍Post Training Dynamic Quantization(PTDQ)
  对训练后的模型权重执行动态量化,将浮点模型转换为动态量化模型,仅对模型权重进行量化,偏置不会量化。默认情况下,仅对Linear和RNN变体量化 (因为这些layer的参数量很大,收益更高)。

torch.quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)

参数解释:

  • model:模型(默认为FP32)
  • qconfig_spec:
  1. 集合:比如: qconfig_spec={nn.LSTM, nn.Linear} 。列出要量化的神经网络模块。
  2. 字典: qconfig_spec = {nn.Linear: default_dynamic_qconfig, nn.LSTM: default_dynamic_qconfig}
  • dtype: float16 或 qint8
  • mapping:就地执行模型转换,原始模块发生变异
  • inplace:将子模块的类型映射到需要替换子模块的相应动态量化版本的类型

例子:

# -*- coding: utf-8 -*-
# 动态量化模型,只量化权重
import torch
from torch import nn


class DemoModel(torch.nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        self.relu = nn.ReLU()
        self.fc = torch.nn.Linear(2, 2)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.fc(x)
        return x


if __name__ == "__main__":
    model_fp32 = DemoModel()
    # 创建一个量化的模型实例
    model_int8 = torch.quantization.quantize_dynamic(model=model_fp32,  # 原始模型
                                                     qconfig_spec={torch.nn.Linear},  # 要动态量化的算子
                                                     dtype=torch.qint8)  # 将权重量化为:qint8

    print(model_fp32)
    print(model_int8)

    # 运行模型
    input_fp32 = torch.randn(1, 1, 2, 2)
    output_fp32 = model_fp32(input_fp32)
    print(output_fp32)

    output_int8 = model_int8(input_fp32)
    print(output_int8)

输出结果如下:

DemoModel(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (relu): ReLU()
  (fc): Linear(in_features=2, out_features=2, bias=True)
)
DemoModel(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (relu): ReLU()
  (fc): DynamicQuantizedLinear(in_features=2, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
tensor([[[[0.3120, 0.3042],
          [0.3120, 0.3042]]]], grad_fn=<AddBackward0>)
tensor([[[[0.3120, 0.3042],
          [0.3120, 0.3042]]]])

模型量化策略

  当前,由于量化算子的覆盖有限,因此,对于不同的深度学习模型,其量化策略不同,见下表:

模型量化策略原因
LSTM/RNNDynamic Quantization模型吞吐量由权重的计算/内存带宽决定
BERT/TransformerDynamic Quantization模型吞吐量由权重的计算/内存带宽决定
CNNStatic Quantization模型吞吐量由激活函数的内存带宽决定
CNNQuantization Aware Training模型准确率不能由Static Quantization获取的情况

   下面对BERT模型进行训练后动态量化,分析模型在量化前后,推理效果和推理性能的变化。

实验

   我们使用的训练后的模型为中文文本分类模型,其训练过程可以参考文章:NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调 。
   训练后的BERT模型动态量化实验的设置如下:

  1. base model: bert-base-chinese
  2. CPU info: x86-64, Intel® Core™ i5-10210U CPU @ 1.60GHz
  3. batch size: 1
  4. thread: 1

   具体的实验过程如下

  • 加载模型及tokenizer
import torch
from transformers import AutoModelForSequenceClassification

MAX_LENGTH = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = f"./sougou_test_trainer_{MAX_LENGTH}/checkpoint-96"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
  • 测试数据集
import pandas as pd

test_df = pd.read_csv("./data/sougou/test.csv")

test_df.head()
textlabel
0届数比赛时间比赛地点参加国家和地区冠军亚军决赛成绩第一届1956-1957英国11美国丹麦6...0
1商品属性材质软橡胶带加浮雕工艺+合金彩色队徽吊牌规格162mm数量这一系列产品不限量发行图案...0
2今天下午,沈阳金德和长春亚泰队将在五里河相遇。在这两支球队中沈阳籍球员居多,因此这场比赛实际...0
3本报讯中国足协准备好了与特鲁西埃谈判的合同文本,也在北京给他预订好了房间,但特鲁西埃爽约了!...0
4网友点击发表评论祝贺中国队夺得五连冠搜狐体育讯北京时间5月6日,2006年尤伯杯羽毛球赛在日...0
  • 量化前模型的推理时间及评估指标
import numpy as np
import time

s_time = time.time()
true_labels, pred_labels = [], [] 
for i, row in test_df.iterrows():
    row_s_time = time.time()
    true_labels.append(row["label"])
    encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(device)
    # print(encoded_text)
    logits = model(**encoded_text)
    label_id = np.argmax(logits[0].detach().cpu().numpy(), axis=1)[0]
    pred_labels.append(label_id)
    print(i, (time.time() - row_s_time)*1000, label_id)

print("avg time: ", (time.time() - s_time) * 1000 / test_df.shape[0])
0 229.3872833251953 0
100 362.0314598083496 1
200 311.16747856140137 2
300 324.13792610168457 3
400 406.9099426269531 4
avg time:  352.44047810332944
from sklearn.metrics import classification_report

print(classification_report(true_labels, pred_labels, digits=4))
              precision    recall  f1-score   support

           0     0.9900    1.0000    0.9950        99
           1     0.9691    0.9495    0.9592        99
           2     0.9900    1.0000    0.9950        99
           3     0.9320    0.9697    0.9505        99
           4     0.9895    0.9495    0.9691        99

    accuracy                         0.9737       495
   macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495
  • 设置量化后端
# 模型量化
cpu_device = torch.device("cpu")
torch.backends.quantized.supported_engines
['none', 'onednn', 'x86', 'fbgemm']
torch.backends.quantized.engine = 'x86'
  • 量化后模型的推理时间及评估指标
# 8-bit 量化
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
).to(cpu_device)
q_s_time = time.time()
q_true_labels, q_pred_labels = [], [] 

for i, row in test_df.iterrows():
    row_s_time = time.time()
    q_true_labels.append(row["label"])
    encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(cpu_device)
    logits = quantized_model(**encoded_text)
    label_id = np.argmax(logits[0].detach().numpy(), axis=1)[0]
    q_pred_labels.append(label_id)
    print(i, (time.time() - row_s_time) * 1000, label_id)
    
print("avg time: ", (time.time() - q_s_time) * 1000 / test_df.shape[0])
0 195.47462463378906 0
100 247.33805656433105 1
200 219.41304206848145 2
300 206.44831657409668 3
400 187.4992847442627 4
avg time:  217.63229466447928
from sklearn.metrics import classification_report

print(classification_report(q_true_labels, q_pred_labels, digits=4))
              precision    recall  f1-score   support

           0     0.9900    1.0000    0.9950        99
           1     0.9688    0.9394    0.9538        99
           2     0.9900    1.0000    0.9950        99
           3     0.9320    0.9697    0.9505        99
           4     0.9896    0.9596    0.9744        99

    accuracy                         0.9737       495
   macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495
  • 量化前后模型大小对比
import os

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print("Size (MB): ", os.path.getsize("temp.p")/1e6)
    os.remove("temp.p")

print_size_of_model(model)
print_size_of_model(quantized_model)
Size (MB):  409.155273
Size (MB):  152.627621

  量化后端(Quantization backend)取决于CPU架构,不同计算机的CPU架构不同,因此,默认的动态量化不一定在所有的CPU上都能生效,需根据自己计算机的CPU架构设置好对应的量化后端。另外,不同的量化后端也有些许差异。Linux服务器使用uname -a可查看CPU信息。
  重复上述实验过程,以模型的最大输入长度为变量,取值为128,256,384,每种情况各做3次实验,结果如下:

实验最大长度量化前平均推理时间(ms)量化前weighted F1值量化前平均推理时间(ms)量化前weighted F1值
实验138410660.97976860.9838
实验23841047.60.9899738.10.9879
实验33841020.90.9817714.00.9838
实验1256668.70.9717431.40.9718
实验2256675.10.9717449.90.9718
实验3256656.00.9717446.50.9718
实验1128335.80.9737200.50.9737
实验2128336.50.9737227.20.9737
实验3128352.40.9737217.60.9737

  综上所述,对于训练后的BERT模型(文本分类模型)进行动态量化,其结论如下:

  • 模型推理效果:量化前后基本相同,量化后略有下降
  • 模型推理时间:量化后平均提速约1.52倍

总结

  本文介绍了量化基本概念,PyTorch模型量化方式,以及对BERT模型训练后进行动态量化后在推理效果和推理性能上的实验。
  本文项目已开源至Github项目:https://github.com/percent4/dynamic_quantization_on_bert 。
  本人已开通个人博客网站,网址为:https://percent4.github.io/ ,欢迎大家访问~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/965827.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C#知识点、常见面试题

相关源码 https://github.com/JackYan666/CSharpCode/blob/main/CSharpCode.cs 0.简要概括 1.删除集合元素 1.For循环删除集合元素:从后面往前删除 从前往后删,有可能不能完全删除 #region 01.For循环删除集合元素void Test01_ForDelListElement(){//错误代码 虽然可以跑…

说说你了解的 Nginx

分析&回答 nginx性能数据 高并发连接: 官方称单节点支持5万并发连接数&#xff0c;实际生产环境能够承受2-3万并发。内存消耗少: 在3万并发连接下&#xff0c;开启10个nginx进程仅消耗150M内存 (15M10150M) 1. 正向、反向代理 所谓“代理”&#xff0c;是指在内网边缘 …

Dump文件的生成以及使用WinDbg静态分析

前言 本文章主要介绍了如何生成Dump文件&#xff0c;包括两种方式&#xff0c;通过代码生成和通过注册表生成。并且介绍了WinDbg工具的下载和使用&#xff0c;以及如何使用WinDbg工具去静态分析Dump文件&#xff0c;从而找到程序的崩溃位置。 生成Dump文件 通过调用WinAPI生成…

WGCNA分析教程 | 代码四

写在前面 WGCNA的教程&#xff0c;我们在前期的推文中已经退出好久了。今天在结合前期的教程的进行优化一下。只是在现有的教程基础上&#xff0c;进行修改。其他的其他并无改变。 前期WGCNA教程 WGCNA分析 | 全流程分析代码 | 代码一 WGCNA分析 | 全流程分析代码 | 代码二 …

论文阅读_扩散模型_DDPM

英文名称: Denoising Diffusion Probabilistic Models 中文名称: 去噪扩散概率模型 论文地址: http://arxiv.org/abs/2006.11239 代码地址1: https://github.com/hojonathanho/diffusion &#xff08;论文对应代码 tensorflow&#xff09; 代码地址2: https://github.com/AUTOM…

Linux图形栈入门概念

Mesa在图形栈中的位置 游戏引擎&#xff1a; 游戏引擎指的是一种软件框架&#xff0c;通过编程和各种工具&#xff0c;帮助开发者设计、构建和运行视频游戏。它相当于一个虚拟的世界创造工具&#xff0c;提供了各种功能模块和资源&#xff0c;如渲染引擎、物理引擎(碰撞检测、重…

跨模态可信感知

文章目录 跨模态可信感知综述摘要引言跨协议通信模式PCP网络架构 跨模态可信感知跨模态可信感知的概念跨模态可信感知的热点研究场景目前存在的挑战可能改进的方案 参考文献 跨模态可信感知综述 摘要 随着人工智能相关理论和技术的崛起&#xff0c;通信和感知领域的研究引入了…

【网络编程上】

目录 一.什么是互联网 1.计算机网络的定义与分类&#xff08;了解&#xff09; &#xff08;1&#xff09;计算机网络的定义 &#xff08;2&#xff09;计算机网络的分类 ① 按照网络的作用范围进行分类 ②按照网络的使用者进行分类 2.网络的网络 &#xff08;理解&#xf…

OpenCV模块介绍

其中core、highgui、imgproc是最基础的模块&#xff0c;该课程主要是围绕这几个模块展开的&#xff0c;分别介绍如下: core模块实现了最核心的数据结构及其基本运算&#xff0c;如绘图函数、数组操作相关函数。 highgui模块实现了视频与图像的读取、显示、存储等接口。 imgp…

Redis未授权访问漏洞复现

Redis 简单使用 Redis 未设置密码&#xff0c;客户端工具可以直接链接。 Redis 是非关系型数据库系统&#xff0c;没有库表列的逻辑结构&#xff0c;仅仅以键值对的方式存储数据。 先启动容器 Redis 未设置密码&#xff0c;客户端工具可以直接链接 https://github.com/xk11z/…

windows无法连接到无线网络怎么办 windows无线网络连接不上的解决方法

windows无法连接到无线网络怎么办&#xff1f;一般出现这种问题的都是笔记本电脑&#xff0c;笔记本找不到无线网络也就相当于不能上网&#xff0c;今天小编要为大家带来的就是windows无线网络连接不上的解决方法&#xff0c;一共有五种解决教程&#xff0c;有需要的可以来看看…

7.6 函数的递归调用

直接调用&#xff1a; ### 1. 直接递归调用 直接递归调用是指一个函数直接调用自己。例如&#xff0c;计算阶乘的函数&#xff0c;可以使用递归方法&#xff1a; int factorial(int n) {if (n < 1) {return 1;}return n * factorial(n - 1); } 在这个例子中&#xff0c;f…

2021年12月 C/C++(六级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C++编程(1~8级)全部真题・点这里 第1题:电话号码 给你一些电话号码,请判断它们是否是一致的,即是否有某个电话是另一个电话的前缀。比如: Emergency 911 Alice 97 625 999 Bob 91 12 54 26 在这个例子中,我们不可能拨通Bob的电话,因为Emergency的电话是它的前缀,当拨…

Java object类

一、JDK类库的根类:obiect 1、这个类中的方法都是所有子类通用的。任何一个类默认继承object。就算没有直接继承&#xff0c;最终也会间接继承。 2、obiect类当中有哪些常用的方法?我们去哪里找这些方法呢? 第一种方法:去源代码当中。(但是这种方式比较麻烦&#xff0c;源代…

重写 UGUI

重写Button using UnityEngine; using UnityEngine.UI; public class MyButton : Button {[SerializeField] private int _newNumber; }using UnityEditor;//编辑器类在UnityEditor命名空间下。所以当使用C#脚本时&#xff0c;你需要在脚本前面加上 "using UnityEditor&q…

Hamilton力学的辛算法简介

Hamilton力学的辛算法简介冯康我的熟人和我 都是曾经要死要活的人我的朋友和我 都是正在要死要活的人 外微分形式与辛几何 外微分形式 1-形式2-形式闭2-形式&#xff08;辛构造&#xff09; Euclid Space 符合如下内积定义的线性空间V称为Euclid空间 对称性 (a, b) (b, a)线…

大数据面试题:MapReduce压缩方式

面试题来源&#xff1a; 《大数据面试题 V4.0》 大数据面试题V3.0&#xff0c;523道题&#xff0c;679页&#xff0c;46w字 可回答&#xff1a;1&#xff09;Hadoop常见的压缩算法有哪些&#xff1f; 问过的一些公司&#xff1a;网易云音乐(2022.11)&#xff0c;阿里(2020.…

【GAMES202】Real-Time Global Illumination(screen space)1—实时全局光照(屏幕空间)1

一、Real-Time Global Illumination(in 3D cont.) 上篇只介绍了RSM&#xff0c;这里我们还会简要介绍另外两种在3D空间中做全局光照的方法&#xff0c;分别是LPV和VXGI。 1.Light Propagation Volumes (LPV) 首先我们知道Radiance在传播过程中是不会被改变的&#xff0c;这点…

9.3-day3-Don‘t let desire break through your will

你这个年龄 是站在阳光下都会发光的年纪 “岂能被欲望所控制”

Shell-AI:基于LLM实现自然语言理解的CLI工具

一、前言 随着AI技术的普及&#xff0c;部分技术领域的门槛逐步降低&#xff0c;比如非科班出身&#xff0c;非技术专业&#xff0c;甚至从未涉足技术领域&#xff0c;完全不懂服务器部署和运维&#xff0c;如今可以依托AI大模型非常轻松的掌握和使用相关技术&#xff0c;来解…