NLP(六十八)使用Optimum进行模型量化

news2025/3/3 18:26:48

  本文将会介绍如何使用HuggingFace的Optimum,来对微调后的BERT模型进行量化(Quantization)。
  在文章NLP(六十七)BERT模型训练后动态量化(PTDQ)中,我们使用PyTorch自带的PTDQ(Post Training Dynamic Quantization)量化策略对微调后的BERT模型进行量化,取得了模型推理性能的提升(大约1.5倍)。本文将尝试使用Optimum量化工具。

Optimum介绍

  OptimumTransformers 的扩展,它提供了一组性能优化工具,可以在目标硬件上以最高效率训练和运行模型。
  Optimum针对不同的硬件,提供了不同的优化方案,如下表:

硬件安装命令
ONNX runtimepython -m pip install optimum[onnxruntime]
Intel Neural Compressor (INC)python -m pip install optimum[neural-compressor]
Intel OpenVINOpython -m pip install optimum[openvino,nncf]
Graphcore IPUpython -m pip install optimum[graphcore]
Habana Gaudi Processor (HPU)python -m pip install optimum[habana]
GPUpython -m pip install optimum[onnxruntime-gpu]

  本文将会介绍基于ONNX的模型量化技术。ONNX(英语:Open Neural Network Exchange)是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch、MXNet)可以采用相同格式存储模型数据并交互。

模型量化

  我们使用的微调后的BERT模型采用文章NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调中给出的文本分类模型。
  首先,我们先加载PyTorch中的设备(CPU)。

# load device
import torch

device = torch.device("cpu")

  接着,我们使用optimum.onnxruntime模块加载模型和tokenizer,并将模型保存为onnx格式。

from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
import torch

model_id = "./sougou_test_trainer_256/checkpoint-96"
onnx_path = "./sougou_test_trainer_256/onnx_256"

# load vanilla transformers and convert to onnx
model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# save onnx checkpoint and tokenizer
model.save_pretrained(onnx_path)
tokenizer.save_pretrained(onnx_path)

此时,会多出onnx_256文件夹,保存模型为model.onnx。
保存为onnx模型
输出结果为:

('./sougou_test_trainer_256/onnx_256\\tokenizer_config.json',
 './sougou_test_trainer_256/onnx_256\\special_tokens_map.json',
 './sougou_test_trainer_256/onnx_256\\vocab.txt',
 './sougou_test_trainer_256/onnx_256\\added_tokens.json',
 './sougou_test_trainer_256/onnx_256\\tokenizer.json')

  使用transfomers中的pipeline对模型进行快速推理。

from transformers import pipeline

vanilla_clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
vanilla_clf("这期节目继续关注中国篮球的话题。众所周知,我们已经结束了男篮世界杯的所有赛程,一胜四负的一个成绩,甚至比上一届的世界杯成绩还要差。因为这一次我们连奥运会落选赛也都没有资格参加,所以,连续两次错过了巴黎奥运会的话,对于中国篮协,还有对于姚明来说,确实成为了他任职的一个最大的败笔。对于球迷非常关注的一个话题,乔尔杰维奇是否下课,可能对于这个悬念来说也都是暂时有答案了。")

输出结果如下:

[{'label': 'LABEL_0', 'score': 0.9963239431381226}]

  对ONNX模型进行优化。

from optimum.onnxruntime import ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

# create ORTOptimizer and define optimization configuration
optimizer = ORTOptimizer.from_pretrained(model)
optimization_config = OptimizationConfig(optimization_level=99) # enable all optimizations

# apply the optimization configuration to the model
optimizer.optimize(
    save_dir=onnx_path,
    optimization_config=optimization_config,
)

此时,优化后的模型为model_optimized.onnx。

  对优化后的模型进行推理。

from transformers import pipeline

# load optimized model
optimized_model = ORTModelForSequenceClassification.from_pretrained(onnx_path, file_name="model_optimized.onnx")

# create optimized pipeline
optimized_clf = pipeline("text-classification", model=optimized_model, tokenizer=tokenizer)
optimized_clf("今年7月,教育部等四部门联合印发了《关于在深化非学科类校外培训治理中加强艺考培训规范管理的通知》(以下简称《通知》)。《通知》针对近年来校外艺术培训的状况而发布,并从源头就校外艺术培训机构的“培训主体、从业人员、招生行为、安全底线”等方面进行严格规范。校外艺术培训之所以火热,主要在于高中阶段艺术教育发展迟滞于学生需求。分析教育部数据,2021年艺术学科在校生占比为9.84%,高于2020年的9.73%;2020至2021年艺术学科在校生的年增长率为5.04%,远高于4.28%的总在校生年增长率。增长的数据,是近年来艺考招生连年火热的缩影,在未来一段时间内,艺考或将在全国范围内继续保持高热度。")

输出结果为:

[{'label': 'LABEL_3', 'score': 0.9926980137825012}]

  对优化后的ONNX模型再进行量化,代码为:

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# create ORTQuantizer and define quantization configuration
dynamic_quantizer = ORTQuantizer.from_pretrained(optimized_model)
dqconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=False)

# apply the quantization configuration to the model
model_quantized_path = dynamic_quantizer.quantize(
    save_dir=onnx_path,
    quantization_config=dqconfig,
)

此时量化后的模型为model_optimized_quantized.onnx。比较量化前后的模型大小,代码为:

import os

# get model file size
size = os.path.getsize(os.path.join(onnx_path, "model_optimized.onnx"))/(1024*1024)
quantized_model = os.path.getsize(os.path.join(onnx_path, "model_optimized_quantized.onnx"))/(1024*1024)

print(f"Model file size: {size:.2f} MB")
print(f"Quantized Model file size: {quantized_model:.2f} MB")

输出结果为:

Model file size: 390.17 MB
Quantized Model file size: 97.98 MB

  最后,加载量化后的模型,代码为:

# load quantization model
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer

quantized_model = ORTModelForSequenceClassification.from_pretrained(onnx_path, file_name="model_optimized_quantized.onnx").to(device)
tokenizer = AutoTokenizer.from_pretrained(onnx_path)

推理实验

  在进行模型推理实验前,先加载测试数据集。

import pandas as pd

test_df = pd.read_csv("./data/sougou/test.csv")

  使用量化前的模型进行推理,记录推理时间,代码如下:

# original model evaluate
import numpy as np
import time

cost_time_list = []
s_time = time.time()
true_labels, pred_labels = [], [] 
for i, row in test_df.iterrows():
    row_s_time = time.time()
    true_labels.append(row["label"])
    encoded_text = tokenizer(row['text'], max_length=256, truncation=True, padding=True, return_tensors='pt')
    # print(encoded_text)
    logits = model(**encoded_text)
    label_id = np.argmax(logits[0].detach().numpy(), axis=1)[0]
    pred_labels.append(label_id)
    cost_time_list.append((time.time() - row_s_time) * 1000)
    if i % 100:
    	print(i, (time.time() - row_s_time) * 1000, label_id)

print("avg time:", (time.time() - s_time) * 1000 / test_df.shape[0])
print("P50 time:", np.percentile(np.array(cost_time_list), 50))
print("P95 time:", np.percentile(np.array(cost_time_list), 95))

输出结果为:

0 710.2577686309814 0
100 477.72765159606934 1
200 616.3530349731445 2
300 509.63783264160156 3
400 531.57639503479 4

avg time: 501.0757282526806
P50 time: 504.6522617340088
P95 time: 623.9353895187337

对输出结果进行指标评级,代码为:

from sklearn.metrics import classification_report

print(classification_report(true_labels, pred_labels, digits=4))

  重复上述代码,将模型替换为量化前ONNX模型(model.onnx),优化后ONNX模型(model_oprimized.onnx),量化后ONNX模型(model_optimized_quantized.onnx),进行推理时间(单位:ms)统计和推理指标评估,结果见下表:

模型平均推理时间P95推理时间weighted F1
量化前ONNX模型501.1623.90.9717
优化后ONNX模型484.6629.60.9717
量化后ONNX模型361.5426.90.9738

  对比文章NLP(六十七)BERT模型训练后动态量化(PTDQ)中的推理结果,原始模型的平均推理时间为666.6ms,weighted F1值为0.9717,我们有如下结论:

  • ONNX模型不影响推理效果,但在平均推理时间上提速约1.33倍
  • 优化ONNX模型不影响推理效果,但在平均推理时间上提速约1.38倍
  • 量化后的ONNX模型影响推理效果,一般会略有下降,本次实验结果为提升,但在平均推理时间上提速约1.84倍,由于PyTorch的PTDQ(模型训练后动态量化)

总结

  本文介绍了如何使用HuggingFace的Optimum,来对微调后的BERT模型进行量化(Quantization),在optimum.onnxruntime模块中,平均推理时间提速约1.8倍。
  本文已开源至Github,网址为:https://github.com/percent4/dynamic_quantization_on_bert 。
  本文已开通个人博客,欢迎大家访问:https://percent4.github.io/ 。
  欢迎你关注我的微信公众号,每周我都会在这里更新文章。

参考文献

  1. NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调:https://blog.csdn.net/jclian91/article/details/132644042
  2. NLP(六十七)BERT模型训练后动态量化(PTDQ):https://blog.csdn.net/jclian91/article/details/132644042
  3. Optimum: https://huggingface.co/docs/optimum/index
  4. Optimizing Transformers with Hugging Face Optimum: https://www.philschmid.de/optimizing-transformers-with-optimum

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/981066.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

李宏毅-机器学习hw4-self-attention结构-辨别600个speaker的身份

一、慢慢分析学习pytorch中的各个模块的参数含义、使用方法、功能: 1.encoder编码器中的nhead参数: self.encoder_layer nn.TransformerEncoderLayer( d_modeld_model, dim_feedforward256, nhead2) 所以说,这个nhead的意思,就…

使用Maven创建父子工程

📚目录 创建父工程创建子模块创建子模块示例创建认证模块(auth) 结束 创建父工程 选择空项目: 设置:项目名称,组件名称,版本号等 创建完成后的工程 因为我们需要设置这个工程为父工程所以不需要src下的所有文件 在pom…

WPF Flyout风格动画消息弹出消息提示框

WPF Flyout风格动画消息弹出消息提示框 效果如图&#xff1a; XAML: <Window x:Class"你的名称控件.FlyoutNotication"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xam…

java八股文面试[数据库]——索引覆盖

覆盖索引是一种避免回表查询的优化策略: 只需要在一棵索引树上就能获取SQL所需的所有列数据&#xff0c;无需回表&#xff0c;速度更快。 具体的实现方式: 将被查询的字段建立普通索引或者联合索引&#xff0c;这样的话就可以直接返回索引中的的数据&#xff0c;不需要再通过聚…

肖sir__设计测试用例方法之因果图07_(黑盒测试)

设计测试用例方法之因果图 一、定义&#xff1a;因果图提供了一个把规格转化为判定表的系统化方法&#xff0c;从该图中可以产生测试数据。其 中&#xff0c;原因是表示输入条件&#xff0c;结果是对输入执 行的一系列计算后得到的输出。 二、因果图方法最终生成的就是判定表。…

rhcsa4 进程和SSH

tree命令。用于以树状结构显示目录和文件。通过运行 “tree” 命令可视化地查看文件系统中的目录结构。 tree / systemd是第一个系统进程&#xff08;pid1&#xff09;不启动&#xff0c;其他进程也没法启动&#xff0c; 用pstree查看进程树 我们可以看到所有进程都是syste…

蓝桥杯打卡Day3

文章目录 吃糖果递推数列 一、吃糖果IO链接 本题思路:本题题意就是斐波那契数列&#xff01; #include <bits/stdc.h>typedef uint64_t i64;i64 f(i64 n) {if(n1) return 1;if(n2) return 2;return f(n-1)f(n-2); }signed main() {std::ios::sync_with_stdio(false);s…

GRU门控循环单元

GRU 视频链接 https://www.bilibili.com/video/BV1Pk4y177Xg?p23&spm_id_frompageDriver&vd_source3b42b36e44d271f58e90f86679d77db7Zt—更新门 Rt—重置门 控制保存之前一层信息多&#xff0c;还是保留当前神经元得到的隐藏层的信息多。 Bi-GRU GRU比LSTM参数少 …

服务器数据恢复-阵列崩溃导致LVM结构破坏的数据恢复案例

服务器数据恢复环境&#xff1a; 一台服务器中有两组分别由4块SAS硬盘组建的raid5阵列&#xff0c;两组阵列上层划分LUN组建LVM结构&#xff0c;并被格式化为EXT3文件系统。 服务器故障&检测&#xff1a; RIAD5阵列中有一块硬盘故障离线&#xff0c;热备盘激活上线顶替离线…

西门子PLC的优势在哪呢?

今日话题&#xff0c;西门子PLC有何优势以至于能够在竞争中超越三菱和欧姆龙&#xff1f;西门子PLC作为德国品牌&#xff0c;具有独特的优势。视频后方有学习资料免费发放&#xff0c;有兴趣的移步自取。首先&#xff0c;尽管其指令相对抽象&#xff0c;学习难度较高&#xff0…

2.k8s账号密码登录设置

文章目录 前言一、启动脚本二、配置账号密码登录2.1.在hadoop1&#xff0c;也就是集群主节点2.2.在master的apiserver启动文件添加一行配置2.3 绑定admin2.4 修改recommended.yaml2.5 重启dashboard2.6 登录dashboard 总结 前言 前面已经搭建好了k8s集群&#xff0c;现在设置下…

【Mycat1.6】缓存不生效问题处理

背景 系统做读写分离&#xff0c;有大量读需求&#xff0c;基本没有实时获取数据业务需要&#xff0c;所以可以启用缓存来减缓数据库压力&#xff0c;传统使用mybatis的缓存需要大量侵入式声明&#xff0c;所以结合需求使用Mycat中间件来满足 数据库结构 mysql-master&#…

直播系统源码部署,高效文件管理与传输的FTP协议

引言&#xff1a; 在直播系统源码部署的过程中&#xff0c;开发协议是支持直播系统源码功能技术搭建成功并发挥作用的关键之一&#xff0c;在直播系统源码的众多协议中&#xff0c;有一个协议可以帮助直播系统源码部署完成后用户进行媒体文件的上传、下载、管理等操作&#xff…

CMake生成Visual Studio工程

CMake – 生成Visual Studio工程 C/C项目经常使用CMake构建工具。CMake 项目文件&#xff08;例如 CMakeLists.txt&#xff09;可以直接由 Visual Studio 使用。本文要说明的是如何将CMake项目转换到Visual Studio解决方案(.sln)或项目(.vcxproj) 开发环境 为了生成Visual S…

mysql数据库通过拷贝目录实现迁移

在windows环境中&#xff0c;如果mysql已有数据目录&#xff0c;进行数据迁移&#xff0c;可以通过直接拷贝数据文件的方式实现。下面是详细步骤 1 下载安装一个同版本的mysql数据库 到mysql官网下载MySQL安装文件&#xff0c;以下是mysql官网地址: https://downloads.mysql.c…

基于3D扫描和3D打印的产品逆向工程实战【数字仪表】

逆向工程是一种从物理零件创建数字设计的强大方法&#xff0c;并且可以与 3D 扫描和 3D 打印等技术一起成为原型设计工具包中的宝贵工具。 推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 3D 扫描仪可以非常快速地测量复杂的物体&#xff0c;并且在涉及现实生活参考时可以…

自动化监控系统PrometheusGrafana

Prometheus 算是一个全能型选手&#xff0c;原生支持容器监控&#xff0c;当然监控传统应用也不是吃干饭的&#xff0c;所以就是容器和非容器他都支持&#xff0c;所有的监控系统都具备这个流程&#xff0c;数据采集→数据处理→数据存储→数据展示→告警 Prometheus 特点展开…

DAY-01--分布式微服务基础概念

一、项目简介 了解整体项目包含后端、前端、周边维护。整个项目的框架知识。 二、分布式基础概念 1、微服务 将应用程序 基于业务 拆分为 多个小服务&#xff0c;各小服务单独部署运行&#xff0c;采用http通信。 2、集群&分布式&节点 集群是个物理形态&#xff0c;…

02 CSS技巧

02 CSS技巧 clip-path 自定义形状&#xff0c;或者使用自带的属性画圆等circle HTML结构 <body><div class"container"></div> </body>CSS结构 使用*polygon*自定义形状 .container {width: 300px;height: 300px;background-color: re…

基于jeecg-boot的flowable流程历史记录显示修改

更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码&#xff1a; https://gitee.com/nbacheng/nbcio-boot 前端代码&#xff1a;https://gitee.com/nbacheng/nbcio-vue.git 在线演示&#xff08;包括H5&#xff09; &#xff1a; http://122.227.135.243:9888 历…