PyTorch模型转ONNX量化模型

news2024/9/25 1:23:20

你是否发现模型太大,无法部署在你想要的云服务上?或者你是否发现 TensorFlow 和 PyTorch 等框架对于你的云服务来说太臃肿了?ONNX Runtime 可能是你的救星。

如果你的模型在 PyTorch 中,你可以轻松地在 Python 中将其转换为 ONNX,然后根据需要量化模型(对于 TensorFlow 模型,你可以使用 tf2onnx)。ONNX Runtime 是轻量级的,量化可以减小模型大小。

让我们尝试将 PyTorch 中预训练的 ResNet-18 模型转换为 ONNX,然后量化。我们将使用 ImageNet 数据集的子集比较准确率。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - AI模型在线查看 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

0、先决条件

首先下载 ImageNet-100 验证数据集并将其解压到一个目录,从现在开始我们将该目录称为  {VAL}。 {VAL} 应如下所示。

{VAL}/
  |--n01440764/
        |--ILSVRC2012_val_00000293.JPEG
        |--...
  |--...
  

换句话说, {VAL}/{synset}/{image_name}.JPEG

接下来下载ImageNet的同义词集(synset)。

如果你想知道“同义词集”是什么,ImageNet 网站是这样描述的:

ImageNet 是根据 WordNet 层次结构组织的图像数据集。WordNet 中的每个有意义的概念(可能由多个单词或词组描述)称为“同义词集”或“同义词集”。

现在,下载此 synset_words.txt 文件,由 J.D. Salinger 的《麦田里的守望者》的狂热粉丝提供。你也应该阅读它。😃

1、软件包

我们需要安装和导入以下软件包。你可以使用 pip 来完成此操作。如果你有受支持的 GPU,可能能够使用为 GPU 构建的软件包版本。(例如 --onnxruntime-gpu

from tqdm import tqdm
from PIL import Image
import glob
import numpy as np
import torch
import torchvision as tv
import onnx
import onnxruntime as ort
from onnxruntime import quantization

TQDM 仅用于美观的进度条。 😄

2、PyTorch环节

对于输入图像,模型输出一个向量,其中包含 1000 个元素,每个元素代表一个同义词集。因此,我们需要使用 synset_words.txt 将数据集中的同义词集与索引中的模型输出向量进行匹配。

synset_to_target = {}
f = open("synset_words.txt", "r")
index = 0
for line in f:
    parts = line.split(" ")
    synset_to_target[parts[0]] = index
    index = index + 1
f.close()

2.1 数据加载器

创建一个可在 DataLoader 中使用的 dataset 类:

preprocess = tv.transforms.Compose([
    tv.transforms.Resize(256),
    tv.transforms.CenterCrop(224),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def tar_transform(synset):
  return synset_to_target[synset]

class ImageNetValDataset(torch.utils.data.Dataset):
  def __init__(self, img_dir, transform=None, target_transform=None):
      self.img_dir = img_dir
      self.img_paths = sorted(glob.glob(img_dir + "*/*.JPEG"), key=lambda x: int(x.split("_")[-1].split(".")[0]))
      self.transform = transform
      self.target_transform = target_transform

  def __len__(self):
      return len(self.img_paths)

  def __getitem__(self, idx):
      img_path = self.img_paths[idx]
      image = Image.open(img_path)
      synset = img_path.split("/")[-2]
      label = synset
      if self.transform:
          image = self.transform(image)
      if self.target_transform:
          label = self.target_transform(label)
      return image, label

ds = ImageNetValDataset("{VAL}/", transform=preprocess, target_transform=tar_transform)

如果需要,拆分或切片数据集,并保留数据集的未触及部分进行量化。

offset = 500
calib_ds = torch.utils.data.Subset(ds, list(range(offset)))
val_ds = torch.utils.data.Subset(ds, list(range(offset, offset * 2)))

calib_ds 保留用于量化。

创建具有所需批量大小的 DataLoader。

batch_size = 64
dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)

可以关闭 shuffle,因为按图像名称排序会按预定顺序混合图像。

2.2 PyTorch 模型

从 Torch Hub 下载 ResNet-18。

model_pt = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', weights=tv.models.ResNet18_Weights.DEFAULT)
model_pt.eval()

eval() 将模型设置为推理模式。

使用虚拟输入执行一次推理:

dummy_in = torch.randn(1, 3, 224, 224, requires_grad=True)

dummy_out = model_pt(dummy_in)

2.3 转换为 ONNX

ONNX 模型将保存到给定的路径:

# export fp32 model to onnx
model_fp32_path = 'resnet18_fp32.onnx'

torch.onnx.export(model_pt,                                         # model
                  dummy_in,                                         # model input
                  model_fp32_path,                                  # path
                  export_params=True,                               # store the trained parameter weights inside the model file
                  opset_version=14,                                 # the ONNX version to export the model to
                  do_constant_folding=True,                         # constant folding for optimization
                  input_names = ['input'],                          # input names
                  output_names = ['output'],                        # output names
                  dynamic_axes={'input' : {0 : 'batch_size'},       # variable length axes
                                'output' : {0 : 'batch_size'}})

常量折叠(constant folding)将用预先计算的常量节点替换一些具有所有常量输入的 op。

验证模型的结构并确认模型具有有效的架构。通过检查模型的版本、图形的结构以及节点及其输入和输出来验证 ONNX 图的有效性。

model_onnx = onnx.load(model_fp32_path)
onnx.checker.check_model(model_onnx)

如果测试失败,则会引发异常。

2.4 PyTorch vs. ONNX

定义一个将PyTorch张量转换为NumPy数组的函数:

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

此函数将允许我们将相同的 PyTorch DataLoader 与 ONNX 一起使用。

准备模型:

ort_provider = ['CPUExecutionProvider']
if torch.cuda.is_available():
    model_pt.to('cuda')
    ort_provider = ['CUDAExecutionProvider']

ort_sess = ort.InferenceSession(model_fp32_path, providers=ort_provider)

使用 GPU(如果可用)。

测试模型:

correct_pt = 0
correct_onnx = 0
tot_abs_error = 0

for img_batch, label_batch in tqdm(dl, ascii=True, unit="batches"):

  ort_inputs = {ort_sess.get_inputs()[0].name: to_numpy(img_batch)}
  ort_outs = ort_sess.run(None, ort_inputs)[0]

  ort_preds = np.argmax(ort_outs, axis=1)
  correct_onnx += np.sum(np.equal(ort_preds, to_numpy(label_batch)))

  if torch.cuda.is_available():
    img_batch = img_batch.to('cuda')
    label_batch = label_batch.to('cuda')

  with torch.no_grad():
    pt_outs = model_pt(img_batch)

  pt_preds = torch.argmax(pt_outs, dim=1)
  correct_pt += torch.sum(pt_preds == label_batch)

  tot_abs_error += np.sum(np.abs(to_numpy(pt_outs) - ort_outs))

print("\n")

print(f"pt top-1 acc = {100.0 * correct_pt/len(val_ds)} with {correct_pt} correct samples")
print(f"onnx top-1 acc = {100.0 * correct_onnx/len(val_ds)} with {correct_onnx} correct samples")

mae = tot_abs_error/(1000*len(val_ds))
print(f"mean abs error = {mae} with total abs error {tot_abs_error}")

你可能会得到一些这样的结果:

# CPU
# pt top-1 acc = 79.0 with 395 correct samples
# onnx top-1 acc = 79.0 with 395 correct samples
# mean abs error = 1.7788757681846619e-06 with total abs error 0.8894378840923309

# GPU
# pt top-1 acc = 79.0 with 395 correct samples
# onnx top-1 acc = 79.0 with 395 correct samples
# mean abs error = 4.85603129863739e-06 with total abs error 2.428015649318695

已知 CPU 和 GPU 产生的结果略有不同,具体取决于操作的实现方式和轻微的位错误。

4、ONNX模型量化

根据 ONNX 运行时文档,建议在量化之前执行此预处理步骤,其中包括优化。

model_prep_path = 'resnet18_prep.onnx'

quantization.shape_inference.quant_pre_process(model_fp32_path, model_prep_path, skip_symbolic_shape=False)

预处理后的模型将保存到给定的路径。

4.1 校准数据读取器

根据 ONNX 运行时文档,

通常,建议对 RNN 和基于 transformer 的模型使用动态量化,对 CNN 模型使用静态量化。

由于 ResNet-18 主要是 CNN,我们应该进行静态量化。但是,它需要一个数据集来校准量化的模型参数。(幸好我们把 alib_ds留下了! 😉

class QuntizationDataReader(quantization.CalibrationDataReader):
    def __init__(self, torch_ds, batch_size, input_name):

        self.torch_dl = torch.utils.data.DataLoader(torch_ds, batch_size=batch_size, shuffle=False)

        self.input_name = input_name
        self.datasize = len(self.torch_dl)

        self.enum_data = iter(self.torch_dl)

    def to_numpy(self, pt_tensor):
        return pt_tensor.detach().cpu().numpy() if pt_tensor.requires_grad else pt_tensor.cpu().numpy()

    def get_next(self):
        batch = next(self.enum_data, None)
        if batch is not None:
          return {self.input_name: self.to_numpy(batch[0])}
        else:
          return None

    def rewind(self):
        self.enum_data = iter(self.torch_dl)

qdr = QuntizationDataReader(calib_ds, batch_size=64, input_name=ort_sess.get_inputs()[0].name)

量化模型将保存到给定的路径:

q_static_opts = {"ActivationSymmetric":False,
                 "WeightSymmetric":True}
if torch.cuda.is_available():
  q_static_opts = {"ActivationSymmetric":True,
                  "WeightSymmetric":True}

model_int8_path = 'resnet18_int8.onnx'
quantized_model = quantization.quantize_static(model_input=model_prep_path,
                                               model_output=model_int8_path,
                                               calibration_data_reader=qdr,
                                               extra_options=q_static_opts)

根据 ONNX 运行时存储库,

如果模型以 GPU/TRT 为目标,则需要对称激活和权重。如果模型面向 CPU,建议使用非对称激活和对称权重,以平衡性能和准确性。

你可以从这个ResearchGate 页面了解有关对称/非对称量化的更多信息。

4.2 ONNX FP32 vs. INT8

加载 量化的onnx 模型:

ort_int8_sess = ort.InferenceSession(model_int8_path, providers=ort_provider)

测试模型:

correct_int8 = 0
correct_onnx = 0
tot_abs_error = 0

for img_batch, label_batch in tqdm(dl, ascii=True, unit="batches"):

  ort_inputs = {ort_sess.get_inputs()[0].name: to_numpy(img_batch)}
  ort_outs = ort_sess.run(None, ort_inputs)[0]

  ort_preds = np.argmax(ort_outs, axis=1)
  correct_onnx += np.sum(np.equal(ort_preds, to_numpy(label_batch)))


  ort_int8_outs = ort_int8_sess.run(None, ort_inputs)[0]

  ort_int8_preds = np.argmax(ort_int8_outs, axis=1)
  correct_int8 += np.sum(np.equal(ort_int8_preds, to_numpy(label_batch)))

  tot_abs_error += np.sum(np.abs(ort_int8_outs - ort_outs))


print("\n")

print(f"onnx top-1 acc = {100.0 * correct_onnx/len(val_ds)} with {correct_onnx} correct samples")
print(f"onnx int8 top-1 acc = {100.0 * correct_int8/len(val_ds)} with {correct_int8} correct samples")

mae = tot_abs_error/(1000*len(val_ds))
print(f"mean abs error = {mae} with total abs error {tot_abs_error}")

可能得到类似如下的结果:

# CPU
# onnx top-1 acc = 79.0 with 395 correct samples
# onnx int8 top-1 acc = 77.8 with 389 correct samples
# mean abs error = 0.265933556640625 with total abs error 132966.7783203125

# GPU
# onnx top-1 acc = 79.0 with 395 correct samples
# onnx int8 top-1 acc = 77.4 with 387 correct samples
# mean abs error = 0.44179485546875 with total abs error 220897.427734375

CPU 与 GPU 的结果也可能不同,因为选择了对称与非对称量化方法。


原文链接:PyTorch转ONNX量化模型 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2162014.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关于YOLOX的一些优势

YOLOX 是旷视开源的高性能检测器。旷视的研究者将解耦头、数据增强、无锚点以及标签分类等目 标检测领域的优秀进展与 YOLO 进行了巧妙的集成组合,提出了 YOLOX,不仅实现了超越 YOLOv3、 YOLOv4 和 YOLOv5 的 AP,而且取得了极具竞争力的推理速…

FME学习笔记

读取数据 方法一:add reader 通过读模块来进行数据的读取 方法二:FeatureReader Parameters 通过转换器来进行数据的读取 可以通过空间范围进行筛选 在FME中,所有数据处理都要用到的,绝对的重点:转换器&#xff…

深圳某局联想SR850服务器黄灯 不开机维修

深圳 福田区1台Lenovo Thinksystem SR850 四路服务器黄灯问题现场处理。 1:型号:联想SR850 机架式2U服务器 2:故障:能通电,开机按钮快闪,随后叹号警告灯常亮 3:用户自行折腾无果后找到我们tech …

【推文制作】秀米简明教程 1.0

【前言】本文内容主要是针对一些常用的秀米操作进行介绍,并说明一些往年的经验要求。但是,最重要的是,请发挥你的艺术创造力,相信你一定可以做出更好看的推文。 一、秀米页面介绍 在使用秀米之前,我们会有一个通过浏览…

Maya学习笔记:项目设置和快捷键

文章目录 项目设置工程文件夹 快捷键 项目设置 工程文件夹 maya需要一个文件夹存放自己的工程内容 先指定一个文件夹 文件/项目窗口 选择一个文件夹,然后选择创建默认工作区 然后生成文件目录 在项目窗口里,选择要生成的子文件夹(保持默认…

【ASE】第二课_溶解效果

今天我们一起来学习ASE插件,希望各位点个关注,一起跟随我的步伐 今天我们来学习溶解效果,通过渐变纹理达到好像燃烧效果的溶解效果 今天我们的效果很简单,但是其中包含没有学习的节点,所以还是要拿出来学习一下 最终…

ESP32异常报错2

出现这种情况 一般是缺少";"分号. 或者缺少, 仔细查找代码.查看是哪儿缺少了这些代码

【2024W35】肖恩技术周刊(第 13 期):肉,好次!

周刊内容: 对一周内阅读的资讯或技术内容精品(个人向)进行总结,分类大致包含“业界资讯”、“技术博客”、“开源项目”和“工具分享”等。为减少阅读负担提高记忆留存率,每类下内容数一般不超过3条。 更新时间: 星期天 历史收录:…

docker快速部署zabbix

两台主机 一台作为server 一台作为agent 安装好docker 并保证服务正常运行,镜像正常pull 分析: 部署 Zabbix 容器环境,通常会涉及几个主要组件: MySQL(或 MariaDB 数据库)、Zabbix Server 和 Zabbix Web I…

c++ 继承 和 组合

目录 一. 继承 1.1 继承的概念 1.2 继承定义 1.3 继承类模板 1.4. 继承中的作用域 二. 派生类(子类)的默认成员函数 2.1 概念: 2.2 实现⼀个不能被继承的类 2.3 继承与友元 2.4继承与静态成员 三.多继承及其菱形继承问题 3.1继承方…

物联网实践教程:微信小程序结合OneNET平台MQTT实现STM32单片机远程智能控制 远程上报和接收数据——汇总

物联网实践教程:微信小程序结合OneNET平台MQTT实现STM32单片机远程智能控制 远程上报和接收数据——汇总 前言 之前在学校获得了一个新玩意:ESP-01sWIFI模块,去搜了一下这个小东西很有玩点,远程控制LED啥的,然后我就想…

CUDA编程三、C++和cuda实现矩阵乘法SGEMM

目录 一、矩阵SGEMM 二、SGEMM的各种实现 1、cpu版本的实现 2、GPU并行计算最初始的版本 GPU中数据的移动 3、矩阵分块Shared Memory优化 4、LDS.128 float4* 优化 5、__syncthreads()位置优化 6、blank conflict优化 bank概念 bank conflict bank conflict危害和处…

IO其他流

1. 缓冲流 昨天学习了基本的一些流,作为IO流的入门,今天我们要见识一些更强大的流。比如能够高效读写的缓冲流,能够转换编码的转换流,能够持久化存储对象的序列化流等等。这些功能更为强大的流,都是在基本的流对象基础…

yum库 docker的小白安装教程(附部分问题及其解决方案)

yum库 首先我们安装yum 首先在控制台执行下列语句 首先切换到root用户,假如已经是了就不用打下面的语句 su root #使用国内的镜像,不执行直接安装yum是国外的,那个有问题 curl -o /etc/yum.repos.d/CentOS-Base.repo https://mirrors.al…

大模型框架 LangChain 介绍

文章目录 langchain介绍安装依赖大模型类别千帆大模型案例常见问题 langchain介绍 是一个开源大语言模型框架,本身不提供大模型算法,只提供对接大模型算法平台的接口(模型包裹器);langchain官网v0.2,内部涉…

POI获取模板文件,替换数据横纵动态表格、折线图、饼状图、折线饼状组合图

先说几个关键的点 pom.xml依赖 <dependency><groupId>commons-io</groupId><artifactId>commons-io</artifactId><version>2.11.0</version> </dependency> <dependency><groupId>com.deepoove</groupId>&…

现代桌面UI框架科普及WPF入门1

现代桌面UI框架科普及WPF入门 文章目录 现代桌面UI框架科普及WPF入门桌面应用程序框架介绍过时的UI框架MFC (Microsoft Foundation Class)缺点 经典的UI框架**WinForms****QT****WPF** 未来的UI框架**MAUI****AvaloniaUI** WPF相对于Winform&#xff0c;QT&#xff0c;MFC的独立…

【深度学习】(5)--搭建卷积神经网络

文章目录 搭建卷积神经网络一、数据预处理1. 下载数据集2. 创建DataLoader&#xff08;数据加载器&#xff09; 二、搭建神经网络三、训练数据四、优化模型 总结 搭建卷积神经网络 一、数据预处理 1. 下载数据集 在PyTorch中&#xff0c;有许多封装了很多与图像相关的模型、…

二阶滤波算法总结(对RC滤波算法整理的部分修正和完善)

文章目录 1、一阶低通滤波2、一阶高通滤波3、二阶低通滤波器3.1 二阶RC低通滤波器的连续域数学模型3.2 二阶RC低通滤波器的算法推导3.3 matlab仿真 4、二阶高通滤波器4.1 二阶RC高通滤波器的连续域数学模型4.2 二阶RC高通滤波器的算法推导4.3 matlab仿真 5、陷波滤波6、带通滤波…

要大爆发的AI Agent是什么?(软件测试人员需要掌握)

什么是AI Agent&#xff1f; AI Agent 是一种软件程序&#xff0c;可以与环境交互&#xff0c;收集数据&#xff0c;并使用数据执行自主任务以实现预定目标。即人类设定目标&#xff0c;AI Agent 独立选择实现这些目标所需的最佳行动。 简单来说&#xff0c;AI Agent是一个能够…