PyTorch2ONNX-分类模型:速度比较(固定维度、动态维度)、精度比较

news2025/1/11 2:24:11
图像分类模型部署: PyTorch -> ONNX

1. 模型部署介绍

1.1 人工智能开发部署全流程

step1
数据
数据采集
定义类别
标注
数据集
step2
模型
训练模型
测试集评估
调参优化
可解释分析
step3
部署
手机/平板
服务器
PC/浏览器
嵌入式开发板

1.2 模型部署平台和芯片介绍

  • 设备:PC、浏览器、APP、小程序、服务器、嵌入式开发板、无人车、无人机、Jetson Nano、树莓派、机械臂、物联网设备
  • 厂商
    • 英特尔(Intel):主要生产 CPU(中央处理器)和一些 FPGA(现场可编程门阵列)芯片。代表作品包括 Intel Core 系列 CPU 和 Xeon 系列服务器 CPU,以及 FPGA 产品如 Intel Stratix 系列。
    • 英伟达(NVIDIA):以 GPU(图形处理器)为主打产品,广泛应用于图形渲染、深度学习等领域。代表作品包括 NVIDIA GeForce 系列用于游戏图形处理,NVIDIA Tesla 和 NVIDIA A100 用于深度学习加速。
    • AMD:主要生产 CPU 和 GPU。代表作品包括 AMD Ryzen 系列 CPU 和 AMD EPYC 系列服务器 CPU,以及 AMD Radeon 系列 GPU 用于游戏和专业图形处理。
    • 苹果(Apple):生产自家设计的芯片,主要包括苹果 M 系列芯片。代表作品有 M1 芯片,广泛应用于苹果的 Mac 电脑、iPad 和一些其他设备。
    • 高通(Qualcomm):主要生产移动平台芯片,包括移动处理器和调制解调器。代表作品包括 Snapdragon 系列芯片,用于智能手机和移动设备。
    • 昇腾(Ascend):由华为生产,主要生产 NPU(神经网络处理器),用于深度学习任务。代表作品包括昇腾 910 和昇腾 310。
    • 麒麟(Kirin):同样由华为生产,主要生产手机芯片,包括 CPU 和 GPU。代表作品包括麒麟 9000 系列,用于华为旗舰手机。
    • 瑞芯微(Rockchip):主要生产 VPU(视觉处理器)和一些移动平台芯片。代表作品包括 RK3288 和 RK3399,广泛应用于智能显示、机器人等领域。
芯片名英文名中文名厂商主要任务是否训练是否推理算力速度
CPUCentral Processing Unit(CPU)中央处理器各大厂商通用计算中等
GPUGraphics Processing Unit(GPU)图形处理器NVIDIA、AMD等图形渲染、深度学习加速
TPUTensor Processing Unit(TPU)张量处理器谷歌机器学习中的张量运算
NPUNeural Processing Unit(NPU)神经网络处理器华为、联发科等深度学习模型的性能提升中等
VPUVision Processing Unit(VPU)视觉处理器英特尔、博通等图像和视频处理中等中等
DSPDigital Signal Processor(DSP)数字信号处理器德州仪器、高通等数字信号处理、音频信号处理中等中等
FPGAField-Programmable Gate Array(FPGA)现场可编程门阵列英特尔、赛灵思等可编程硬件加速器中等

1.3 模型部署的通用流程

转换
运行
PyTorch
TensorFlow
Caffe
PaddlePaddle
训练框架
ONNX/中间表示
推理框架/引擎/后端
TensorRT
ONNXRuntime
OpenVINO
NCNN/TNN
PPL

2. 使用 ONNX 的意义

从这两张图可以很明显的看到,当有了中间表示 ONNX 后,从原来的 M × N M \times N M×N 变为了 M + N M + N M+N,让模型部署的流程变得简单。

3. ONNX 的介绍

开源机器学习通用中间格式,由微软、Facebook(Meta)、亚马逊、IBM 共同发起的。它可以兼容各种深度学习框架,也可以兼容各种推理引擎和终端硬件、操作系统

4. ONNX 环境安装

pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple

5. PyTorch → ONNX

5.1 将一个分类模型转换为 ONNX

import torch
from torchvision import models


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"正在使用的设备: {device}")

# 创建一个训练好的模型
model = models.resnet18(pretrained=True)  # ImageNet 预训练权重
model = model.eval().to(device)

# 构建一个输入
dummy_input = torch.randn(size=[1, 3, 256, 256]).to(device)  # [N, B, H, W]

# 让模型推理
output = model(dummy_input)
print(f"output.shape: {output.shape}")

# 使用 PyTorch 自带的函数将模型转换为 ONNX 格式
onnx_save_path = 'ONNX/saves/resnet18_imagenet.onnx'  # 导出的ONNX模型路径 
with torch.no_grad():
    torch.onnx.export(
        model=model,                            # 要转换的模型
        args=dummy_input,                       # 模型的输入
        f=onnx_save_path,                       # 导出的ONNX模型路径 
        input_names=['input'],                  # ONNX模型输入的名字(自定义)
        output_names=['output'],                # ONNX模型输出的名字(自定义)
        opset_version=11,                       # Opset算子集合的版本(默认为17)
    )
    
print(f"ONNX 模型导出成功,路径为:{onnx_save_path}")
正在使用的设备: cpu
/home/leovin/anaconda3/envs/wsl/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/leovin/anaconda3/envs/wsl/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/leovin/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:03<00:00, 13.9MB/s]
output.shape: torch.Size([1, 1000])
ONNX 模型导出成功,路径为:ONNX/saves/resnet18_imagenet.onnx

💡 Tips:

  1. opset 算子集不同版本区别: Operators.md
  2. 虽然说 PyTorch 在提醒 pretrained=True 将会被弃用,可以使用 weights=weights=ResNet18_Weights.DEFAULTweights=ResNet18_Weights.IMAGENET1K_V1 来代替。但很明显前者比较方便,后者还需要查看对应的版本号,比较麻烦 😂

接下来我们使用 Netron 查看一下这个模型:

  1. 原图链接为:resnet18_imagenet.png

  2. ImageNet 数据集有 1000 个类别

5.2 检查一个模型导出是否正确

import onnx


# 读取导出的模型
onnx_path = 'ONNX/saves/resnet18_imagenet.onnx'  # 导出的ONNX模型路径
onnx_model = onnx.load(onnx_path)

# 检查模型是否正常
onnx.checker.check_model(onnx_model)

print(f"模型导出正常!")
模型导出正常!

我们在《onnx基础》中已经讲过 check_model() 这个函数,它可以检查 ONNX 模型,如果该函数发现模型错误,则会抛出异常,

5.3 修改动态维度

前面我们导出的 ONNX 模型中,输入的维度是固定的:[1, 3, 256, 256],那么此时这个 ONNX 的输入就被限制了:

  • 如果我们想要多 Batch 的输入 → 不行
  • 如果我们输入的图片是灰度图 → 不行
  • 如果我们输入的图片尺寸不是 256×256 → 不行

torch.onnx.export() 这个函数也帮我解决了这个问题,它有一个名为 dynamic_axis 的参数,我们看一下官网对该参数的描述:

dynamic_axes (dict[string, dict[int, string]] or dict[string, list(int)], default empty dict) –

By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To specify axes of tensors as dynamic (i.e. known only at run-time), set dynamic_axes to a dict with schema:

  • KEY (str): an input or output name. Each name must also be provided in input_names or output_names.
  • VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a list, each element is an axis index.

dynamic_axes(dict[string, dict[int, string]]或dict[string, list(int)],默认为空字典)–

默认情况下,导出的模型将使所有输入和输出张量的形状完全匹配args中给定的形状。要将张量的轴指定为动态(即仅在运行时知道),请将dynamic_axes设置为一个具有以下结构的字典:

  • KEY(str):输入或输出的名称。每个名称还必须在 input_namesoutput_names 中提供。
  • VALUE(dict或list):如果是字典,则键是轴索引,值是轴名称。如果是列表,则每个元素是轴索引。

下面我们用一下这个参数:

import torch
from torchvision import models
import onnx


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"正在使用的设备: {device}")

# 创建一个训练好的模型
model = models.resnet18(pretrained=True)  # ImageNet 预训练权重
model = model.eval().to(device)

# 构建一个输入
dummy_input = torch.randn(size=[1, 3, 256, 256]).to(device)  # [N, B, H, W]

# 让模型推理
output = model(dummy_input)
print(f"output.shape: {output.shape}\n")

# ------ 使用 PyTorch 自带的函数将模型转换为 ONNX 格式
onnx_save_path = 'ONNX/saves/resnet18_imagenet-with_dynamic_axis.onnx'  # 导出的ONNX模型路径 
with torch.no_grad():
    torch.onnx.export(
        model=model,                            # 要转换的模型
        args=dummy_input,                       # 模型的输入
        f=onnx_save_path,                       # 导出的ONNX模型路径 
        input_names=['input'],                  # ONNX模型输入的名字(自定义)
        output_names=['output'],                # ONNX模型输出的名字(自定义)
        opset_version=11,                       # Opset算子集合的版本(默认为17)
        dynamic_axes={                          # 修改某一个维度为动态
            'input': {0: 'B', 2: 'H', 3: 'W'}   # 将原本的 [1, 3, 256, 256] 修改为 [B, 3, H, W]
        }
    )
    
print(f"ONNX 模型导出成功,路径为:{onnx_save_path}\n")

# ------ 验证导出的模型是否正确
# 读取导出的模型
onnx_model = onnx.load(onnx_save_path)

# 检查模型是否正常
onnx.checker.check_model(onnx_model)

print(f"模型导出正常!")
正在使用的设备: cpu
output.shape: torch.Size([1, 1000])

ONNX 模型导出成功,路径为:ONNX/saves/resnet18_imagenet-with_dynamic_axis.onnx

模型导出正常!

此时我们再用 Netron 看一下这个模型:

可以看到,输入的 Batch、Height、Width 均变为了动态维度,即只有当模型运行的时候才知道输入的这三个维度具体的值

6. ONNX Runtime 部署:推理单张图片

import os
import random
import numpy as np
from PIL import Image
import onnxruntime
from torchvision import transforms
import torch
import torch.nn.functional as F
import pandas as pd


# ==================================== 加载 ONNX 模型,创建推理会话 ==================================== 
ort_session = onnxruntime.InferenceSession(path_or_bytes='ONNX/saves/resnet18_imagenet-fix_axis.onnx')  # ort -> onnxruntime

# ==================================== 模型冷启动 ==================================== 
dummy_input = np.random.randn(1, 3, 256, 256).astype(np.float32)
ort_inputs = {'input': dummy_input}
ort_output = ort_session.run(output_names=['output'], input_feed=ort_inputs)[0]  # 输出被[]包围了,所以需要取出来
print(f"模型冷启动完毕! 其推理结果的shape为: {ort_output.shape}")

# ==================================== 加载真正的图像 ==================================== 
images_folder = 'Datasets/Web/images'
images_list = [os.path.join(images_folder, img) for img in os.listdir(images_folder) if img.lower().endswith(('.jpg', '.png', '.webp'))]

img_path = images_list[random.randint(0, len(images_list)-1)]
img = Image.open(fp=img_path)

# ==================================== 图像预处理 ==================================== 
# 定义预处理函数
img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # imagenet专用
        std=[0.229, 0.224, 0.225]),  # imagenet专用
])

# 对图片进行预处理
input_img = img_transform(img)
print(f"input_img.type: {type(input_img)}")
print(f"input_img.shape: {input_img.shape}")

# 为图片添加batch维度
input_img = torch.unsqueeze(input_img, dim=0)

# ==================================== ONNX模型推理 ==================================== 
# 因为ONNXRuntime需要的是numpy而非torch的tensor, 所以将其转换为numpy
input_img = input_img.numpy()
print(f"input_img.type: {type(input_img)}")
print(f"input_img.shape: {input_img.shape}")

# 模型推理图片
ort_inputs = {'input': input_img, }
ort_results = ort_session.run(output_names=['output'], input_feed=ort_inputs)[0]  # 得到 1000 个类别的分数
print(f"模型推理完毕! 此时结果的shape为:{ort_results.shape}")

# ==================================== 后处理 ==================================== 
# 使用 softmax 函数将分数转换为概率
ort_results_softmax = F.softmax(input=torch.from_numpy(ort_results), dim=1)
print(f"经过softmax后的输出的shape为:{ort_results_softmax.shape}")

# 取概率最大的前 n 个结果
n = 3
top_n = torch.topk(input=ort_results_softmax, k=n)

probs = top_n.values.numpy()[0]
indices = top_n.indices.numpy()[0]

print(f"置信度最高的前{n}个结果为:\t{probs}\n"
      f"对应的类别索引为:\t\t{indices}")

# ==================================== 显示类别 ==================================== 
df = pd.read_csv('Datasets/imagenet_classes_indices.csv')

idx2labels = {}
for idx, row in df.iterrows():
    # idx2labels[row['ID']] = row['class']  # 英文标签
    idx2labels[row['ID']] = row['Chinese']  # 中文标签

print(f"=============== 推理结果 ===============\n"
      f"图片路径: {img_path}")
for i, (class_prob, idx) in enumerate(zip(probs, indices)):
    class_name = idx2labels[idx]
    text = f"\tNo.{i}: {class_name:<30} --> {class_prob:>.4f}"
    print(text)
模型冷启动完毕! 其推理结果的shape为: (1, 1000)
input_img.type: <class 'torch.Tensor'>
input_img.shape: torch.Size([3, 256, 256])
input_img.type: <class 'numpy.ndarray'>
input_img.shape: (1, 3, 256, 256)
模型推理完毕! 此时结果的shape为:(1, 1000)
经过softmax后的输出的shape为:torch.Size([1, 1000])
置信度最高的前3个结果为:       [9.9472505e-01 7.4335985e-04 5.2123831e-04]
对应的类别索引为:              [673 662 487]
=============== 推理结果 ===============
图片路径: Datasets/Web/images/mouse.jpg
        No.0: 鼠标,电脑鼠标                        --> 0.9947
        No.1: 调制解调器                          --> 0.0007
        No.2: 移动电话,手机                        --> 0.0005

💡 图片链接:Web/images

💡 ImageNet 类别文件链接:imagenet_classes_indices.csv

7. ONNX Runtime 和 PyTorch 速度对比

  1. 不同尺度下单张图片推理 --> 对比代码链接
  2. 不同尺度下多张图片推理 --> 对比代码链接

实验环境

  • CPU:Intel i5-10400F @ 2.90 GHz
  • Memory: 8 x 2 = 16GB
  • Disk: SSD
  • GPU: RTX 3070 O8G
  • OS: Windows 10 (WSL)
  • Device: CPU
  • 模型推理次数: 50

7.1 ResNet-18

实验结果

Input ShapeONNX(fix)ONNX(fix+sim)ONNX(dyn)ONNX(dyn+sim)PyTorch(CPU)PyTorch(GPU)
[1, 3, 32, 32]0.0577s0.0597s0.0592s0.0585s0.0688s0.0787s
[1, 3, 64, 64]0.0605s0.0593s0.0588s0.0621s0.0700s0.0723s
[1, 3, 128, 128]0.0705s0.0686s0.0699s0.0694s0.0762s0.0760s
[1, 3, 256, 256]0.0784s0.0811s0.0797s0.0789s0.0949s0.0813s
[1, 3, 512, 512]0.1249s0.1241s0.1251s0.1256s0.1686s0.0996s
[1, 3, 640, 640]0.1569s0.1525s0.1572s0.1579s0.2242s0.0863s
[1, 3, 768, 768]0.1986s0.1946s0.1985s0.2038s0.2933s0.0956s
[1, 3, 1024, 1024]0.2954s0.2957s0.3094s0.3045s0.4871s0.1047s
[16, 3, 32, 32]0.2540s0.2545s0.2558s0.2498s0.2570s0.2473s
[16, 3, 64, 64]0.2811s0.2745s0.2696s0.2655s0.2824s0.2553s
[16, 3, 128, 128]0.3595s0.3181s0.3143s0.3544s0.3817s0.3518s
[16, 3, 256, 256]0.7315s0.7112s0.6767s0.6122s0.7169s0.3469s
[16, 3, 512, 512]1.3042s1.2586s1.1813s1.1949s1.6609s0.4270s
[16, 3, 640, 640]1.6340s1.6429s1.6659s1.6693s2.3923s0.5292s
[16, 3, 768, 768]2.2843s2.2830s2.3325s2.3303s3.9278s1.7851s
[16, 3, 1024, 1024]3.9132s3.9742s3.9668s3.9104s6.7532s3.6507s

画图结果

⚠️ 在 [18, 3, 768, 768]、 时,PyTorch(CPU) 因为内存不足导致只能推理 1 次而非 50 次

⚠️ 在 [18, 3, 1024, 1024]、 时,PyTorch(CPU) 和 PyTorch(GPU) 因为内存不足导致只能推理 1 次而非 50 次

结论

  • 单 Batch
    • 静态维度和动态维度相差不大
    • 当图片尺寸在 [32, 32] ~ [256, 256] 之间时,ONNX 速度比 PyTorch-GPU 速度要快;当图片尺寸大于 [256, 256] 时,PyTorch-GPU 拥有绝对的优势
    • 当图片尺寸小于 [64, 64] 时,PyTorch-CPU 速度快于 PyTorch-GPU;当图片尺寸大于 [64, 64] 时,PyTorch-GPU 速度快于 PyTorch-CPU
    • 无论在什么时候,ONNX 速度均快于 PyTorch-CPU
  • 多 Batch
    • 静态维度和动态维度相差不大
    • 当图片尺寸小于 [128, 128] 时,ONNX、PyTorch-CPU、PyTorch-GPU 三者很难有区别(实际上 PyTorch-GPU 速度要慢一些,因为要将模型和输入放到 GPU 中,这部分会划分几秒钟的时间)
    • 当图片尺寸大于 [128, 128] 时,GPU 逐渐扩大优势(由于 OOM 的原因,[18, 3, 1024, 1024] 下 PyTorch-GPU 只推理了一次,因此速度被拉平了很多。在显存足够充裕的情况下,PyTorch-GPU 的速度是碾压其他方法的)
    • 当图片尺寸大于 [256, 256] 时,PyTorch-CPU 的速度远远慢于 ONNX
  • Sim 前后
    • 可以发现,在使用 python -m onnxsim 前后差距不大
  • 总结
    • 在使用 CPU 进行推理时,建议使用 ONNX 进行,因为不光速度有优势,而且对内存的占用也比 PyTorch-CPU 要小的多
    • 在进行多 Batch 推理时,如果有 GPU 还是使用 PyTorch-GPU,这样会缩减大量的时间(⚠️ GPU 在加载模型和输入时可能会比较耗时)
    • ⚠️ 在使用 python -m onnxsim 前后差距不大

7.2 MobileNetV3-Small

接下来我们在 MobileNetV3-Small 上也进行相同的实验。

⚠️ 因为 opset=11 不支持 hardsigmoid 算子,在官网上查询后,我们使用 opset=17

⚠️ 在使用 opset=17 时可能会报错,报错原因一般是当前 PyTorch 版本低导致的,可以创建一个新的环境,使用最新的 PyTorch(也可以不实验,直接看我得结论就行 😂)

Input ShapeONNX(fix)ONNX(dyn)PyTorch(CPU)PyTorch(GPU)
[1, 3, 32, 32]0.0575s0.0619s0.0636s0.0731s
[1, 3, 64, 64]0.0585s0.0591s0.0643s0.0701s
[1, 3, 128, 128]0.0611s0.0597s0.0629s0.0700s
[1, 3, 256, 256]0.0627s0.0622s0.0690s0.0731s
[1, 3, 512, 512]0.0714s0.0703s0.0841s0.0765s
[1, 3, 640, 640]0.0776s0.0785s0.0975s0.0823s
[1, 3, 768, 768]0.0867s0.0861s0.1138s0.0851s
[1, 3, 1024, 1024]0.1103s0.1126s0.1630s0.0958s
[16, 3, 32, 32]0.2410s0.2295s0.2538s0.2446s
[16, 3, 64, 64]0.2443s0.2421s0.2576s0.2481s
[16, 3, 128, 128]0.2618s0.2576s0.2804s0.2727s
[16, 3, 256, 256]0.3097s0.3131s0.3502s0.3043s
[16, 3, 512, 512]0.5556s0.5873s0.7655s0.3970s
[16, 3, 640, 640]0.7191s0.7130s0.8988s0.4877s
[16, 3, 768, 768]0.9293s0.9285s1.5091s0.5754s
[16, 3, 1024, 1024]1.4768s1.4945s3.3530s1.1316s

画图结果

⚠️ 在 [18, 3, 1024, 1024]、 时,PyTorch(CPU) 因为内存不足导致只能推理 1 次而非 50 次

其实可以发现,与 ResNet18 的结论是一致的。

7.3 为什么 python -m onnxsim 没有效果

我们看一下这个过程:

-------------- ResNet-18 --------------

python -m onnxsim ONNX/saves/resnet18-dynamic_dims.onnx ONNX/saves/resnet18-dynamic_dims-sim.onnx
Simplifying...
Finish! Here is the difference:
┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃                   ┃ Original Model ┃ Simplified Model ┃
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ Add               │ 8              │ 8                │
│ Constant          │ 42             │ 42               │
│ Conv              │ 20             │ 20               │
│ Flatten           │ 1              │ 1                │
│ Gemm              │ 1              │ 1                │
│ GlobalAveragePool │ 1              │ 1                │
│ MaxPool           │ 1              │ 1                │
│ Relu              │ 17             │ 17               │
│ Model Size        │ 44.6MiB        │ 44.6MiB          │
└───────────────────┴────────────────┴──────────────────┘

-------------- MobileNetV3-Small --------------

python -m onnxsim ONNX/saves/mobilenetv3small-dynamic_dims.onnx ONNX/saves/mobilenetv3small-dynamic_dims-sim.onnx
Simplifying...
Finish! Here is the difference:
┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃                   ┃ Original Model ┃ Simplified Model ┃
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ Add               │ 6              │ 6                │
│ Constant          │ 108            │ 108              │
│ Conv              │ 52             │ 52               │
│ Flatten           │ 1              │ 1                │
│ Gemm              │ 2              │ 2                │
│ GlobalAveragePool │ 10             │ 10               │
│ HardSigmoid       │ 9              │ 9                │
│ HardSwish         │ 19             │ 19               │
│ Mul               │ 9              │ 9                │
│ Relu              │ 14             │ 14               │
│ Model Size        │ 9.7MiB         │ 9.7MiB           │
└───────────────────┴────────────────┴──────────────────┘

可以看到,其实根本没有变化,所以速度也没有提升。

⚠️ ONNX 文件变大了可能是因为 onnxsim 放了一些东西在模型中,但对模型性能没有影响。

8. ONNX 与 PyTorch 精度对比

我们现在有如下的模型:

  • weights.pth: PyTorch 权重
  • weights.onnx: ONNX 权重
  • weights-sim.onnx: ONNX 精简后的权重

模型的关系如下:

torch.onnx.export
python -m onnxsim
weights.pth
weights-sim.onnx
weights.onnx

现在我们想要搞清楚,这样转换后的模型精度是怎么样的?

import os
import argparse
import numpy as np
import pandas as pd
from PIL import Image
import onnxruntime
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from rich.progress import track


# ==================================== 参数 ==================================== 
parser = argparse.ArgumentParser()
parser.add_argument('--image_folder_path', type=str, default='Datasets/Web/images', help='图片路径')
parser.add_argument('--input-shape', type=int, nargs=2, default=[256, 256])
parser.add_argument('--verbose', action='store_true', help='')
args = parser.parse_args()  # 解析命令行参数

onnx_weights = 'ONNX/saves/model-dynamic_dims.onnx'
onnx_weights_sim = 'ONNX/saves/model-dynamic_dims-sim.onnx'
# ==============================================================================

# 定义模型
onnx_model = onnxruntime.InferenceSession(path_or_bytes=onnx_weights)
onnx_model_sim = onnxruntime.InferenceSession(path_or_bytes=onnx_weights_sim)
pytorch_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()  # ⚠️ 一定要 .eval
# pytorch_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# 定义预处理函数
img_transform = transforms.Compose([
    transforms.Resize(args.input_shape[-1]),
    transforms.CenterCrop(args.input_shape[-1]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # imagenet专用
        std=[0.229, 0.224, 0.225]),  # imagenet专用
])

image_list = [os.path.join(args.image_folder_path, img) for img in os.listdir(args.image_folder_path) \
               if img.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]

for img_idx, image_path in track(enumerate(image_list), description='Precision Comparison'):
    # 读取图片
    img = Image.open(fp=image_path)  # 读取图片
    input_img = img_transform(img)
    input_img = input_img.unsqueeze(0)
    print(f"inputs.type: {type(input_img)}") if args.verbose else ...
    print(f"inputs.shape: {input_img.shape}") if args.verbose else ...

    model_ls = ['pt', 'onnx', 'onnx-sim']
    for model_name in model_ls:
        if model_name != 'pt':
            if not isinstance(input_img, np.ndarray):
                input_img = input_img.numpy()
            model_input = {'input': input_img, }
            model_result = onnx_model.run(output_names=['output'], input_feed=model_input)[0]
        else:
            model_result = pytorch_model(input_img)
        
        if not isinstance(model_result, torch.Tensor):
            model_result = torch.from_numpy(model_result)
        
        model_result_softmax = F.softmax(input=model_result, dim=1)  # [1, 1000]

        # 取概率最大的前 n 个结果
        n = 3
        top_n = torch.topk(input=model_result_softmax, k=n, dim=1)

        probs = top_n.values.detach().numpy()[0]  # torch.Size([18, 3])
        indices = top_n.indices.detach().numpy()[0]  # torch.Size([18, 3])
        print(f"probs: {probs}") if args.verbose else ...
        print(f"indices: {indices}") if args.verbose else ...

        df = pd.read_csv('Datasets/imagenet_classes_indices.csv')

        idx2labels = {}
        for _, row in df.iterrows():
            idx2labels[row['ID']] = row['Chinese']  # 中文标签

        print(f"============================== 推理结果-{model_name} ==============================")  if args.verbose else ...
        
        _results = []
        for i, (prob, idx) in enumerate(zip(probs, indices)):
            class_name = idx2labels[idx]
            text = f"No.{i}: {class_name:<30} --> {prob:>.5f}"  if args.verbose else ...
            _results.append(prob)
            print(text)
        print(f"=====================================================================")  if args.verbose else ...

        with open("ONNX/saves/Precision-comparison.txt", 'a') as f:
            if model_name == 'pt':
                f.write(f"|[{img_idx+1}] {os.path.basename(image_path)}"
                        f"|{_results[0]:>.5f}</br>{_results[1]:>.5f}</br>{_results[2]:>.5f}")
            elif model_name == 'onnx':
                f.write(f"|{_results[0]:>.5f}</br>{_results[1]:>.5f}</br>{_results[2]:>.5f}")
            else:
                f.write(f"|{_results[0]:>.5f}</br>{_results[1]:>.5f}</br>{_results[2]:>.5f}|\n")

实验结果

图片名称PyTorchONNXONNX-sim
[1] book.jpg0.739730.050490.023580.739730.050490.023580.739730.050490.02358
[2] butterfly.jpg0.897040.047720.015420.897040.047720.015420.897040.047720.01542
[3] camera.jpg0.276580.177090.109250.276580.177090.109250.276580.177090.10925
[4] cat.jpg0.277730.183930.172540.277730.183930.172540.277730.183930.17254
[5] dog.jpg0.517870.253840.059290.517870.253840.059290.517870.253840.05929
[6] dogs_orange.jpg0.352890.301140.077910.352890.301140.077910.352890.301140.07791
[7] female.jpg0.156000.080310.048080.156000.080310.048080.156000.080310.04808
[8] free-images.jpg0.455950.176260.084140.455950.176260.084140.455950.176260.08414
[9] gull.jpg0.647110.233240.044300.647110.233240.044300.647110.233240.04430
[10] laptop-phone.jpg0.493790.354050.060630.493790.354050.060630.493790.354050.06063
[11] monitor.jpg0.516780.441930.022320.516780.441930.022320.516780.441930.02232
[12] motorcycle.jpg0.317120.224350.156310.317120.224350.156310.317120.224350.15631
[13] mouse.jpg0.994730.000740.000520.994730.000740.000520.994730.000740.00052
[14] panda.jpg0.945590.031990.005610.945590.031990.005610.945590.031990.00561
[15] share_flower_fullsize.jpg0.788060.056910.024830.788060.056910.024830.788060.056910.02483
[16] tiger.jpeg0.617490.380010.000520.617490.380010.000520.617490.380010.00052

可以看到,转换前后模型并没有精度的丢失。

9. 〔拓展知识〕为什么 .pt 模型在推理时一定要 .eval()

在PyTorch中,.eval() 是一个用于将模型切换到评估(inference)模式的方法。在评估模式下,模型的行为会有所变化,主要体现在两个方面:DropoutBatch Normalization

  1. Dropout:

    • 在训练阶段,为了防止过拟合,通常会使用 dropout 策略,即在每个训练步骤中,以一定的概率随机丢弃某些神经元的输出。
    • 在推理阶段,我们希望获得模型的确定性输出,而不是在每次推理时都丢弃不同的神经元。因此,在推理时应该关闭 dropout。通过调用 .eval(),PyTorch 会将所有 dropout 层设置为评估模式,即不进行随机丢弃。
  2. Batch Normalization:

    • Batch Normalization(批标准化)在训练时通过对每个 mini-batch 进行标准化来加速训练,但在推理时,我们通常不是基于 mini-batch 进行预测,因此需要使用整个数据集的统计信息进行标准化。
    • .eval() 模式下,Batch Normalization 会使用训练时计算的移动平均和方差,而不是使用当前 mini-batch 的统计信息。

因此,为了确保在推理时得到一致和可靠的结果,需要在推理之前调用 .eval() 方法,以确保模型处于评估模式,关闭了 dropout,并使用适当的 Batch Normalization 统计信息。


举个例子,对于一张猫咪图片而言,如果我们的 .pt 模型没有开启 .eval() 就进行推理,那么得到的结果如下:

============================== 推理结果-pt ==========================
No.0: 桶                                --> 0.00780
No.1: 手压皮碗泵                        --> 0.00680
No.2: 钩爪                              --> 0.00601
====================================================================
probs: [0.27773306 0.18392678 0.17254312]
indices: [281 285 287]
============================== 推理结果-onnx ========================
No.0: 虎斑猫                            --> 0.27773
No.1: 埃及猫                            --> 0.18393
No.2: 猞猁,山猫                         --> 0.17254
====================================================================
probs: [0.27773306 0.18392678 0.17254312]
indices: [281 285 287]
============================== 推理结果-onnx-sim ====================
No.0: 虎斑猫                            --> 0.27773
No.1: 埃及猫                            --> 0.18393
No.2: 猞猁,山猫                         --> 0.17254
====================================================================

可以看到,对于 ONNX 模型而言,推理相对来说是比较正确的。但对于 PyTorch 模型,推理与猫无关了,所以 ⚠️ 在推理时开启 .eval() 是非常重要的!

参考

  1. 图像分类模型部署-Pytorch转ONNX
  2. Pytorch图像分类模型部署-ONNX Runtime本地终端推理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1419914.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5G安卓核心板开发板_MT6833天玑700规格参数

核心板采用沉金生产工艺&#xff0c;耐腐蚀抗干扰&#xff0c;支持-20℃-70℃环境下7x24小时稳定运行&#xff0c;尺寸仅为45mmx48mm x2.65mm&#xff0c;可嵌入到各种智能产品中&#xff0c;助力智能产品便携化及功能差异化。 联发科MT6833处理器采用台积电 7nm 制程的5G SoC…

基于YOLOv8深度学习的水稻叶片病害智能诊断系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…

USB-C显示器:未来显示技术的革新者

随着科技的不断发展&#xff0c;显示技术也在不断进步&#xff0c;而USB-C显示器作为最新的显示技术&#xff0c;正在引领着显示行业的发展潮流。USB-C显示器具有许多优点&#xff0c;如高速传输、便捷连接、节能环保等&#xff0c;使其成为未来显示技术的革新者。 一、USB-C显…

[Grafana]ES数据源Alert告警发送

简单的记录一下使用es作为数据源&#xff0c;如何在发送告警是带上相关字段 目录 前言 一、邮件配置 二、配置 1.Query 2.Alerts 总结 前言 ES作为数据源&#xff0c;算是Grafana中比较常见的&#xff0c;Alerts告警是我近期刚接触&#xff0c;有一个需求是当表空间大于…

Apache SeaTunnel (不含web) Window11 本机搭建(非源码)

启动环境 需要提前准备的(只提供作者试过且可行的方案) window11ubuntu20(wsl2) window11内置ubuntu的方式自行百度&#xff0c;此处不做陈述jdk8mysql8navicatvscode 环境准备不做过多陈述&#xff0c;以下是正式的安装启动步骤 SeaTunnel 2.3.3 资源准备 第一步: 创建文件…

基于Javaweb开发的二手图书零售系统详细设计【附源码】

基于Javaweb开发的二手图书零售系统详细设计【附源码】 &#x1f345; 作者主页 央顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; &#x1f345; 查看下方微信号获取联系方式 承接各种定制系统…

【webrtc】‘ninja.exe‘ 不是内部或外部命令,也不是可运行的程序及vs2019 重新构建m98

werbtc 就是用ninja.exe 来构建找到了自己以前构建的webrtc 原版 m98 【m98 】webrtc ninja 构建 、example、tests 及OWT- P2P 项目P2PMFC-E2E-m98G:\CDN\rtcCli\webrtc-checkout\src找到了自己的deptools的路径 deptools里确实没有ninja.exe D:\SOFT\depot_tools\third_party…

Nginx进阶篇【五】

Nginx进阶篇【五】 八、Nginx实现服务器端集群搭建8.1.Nginx与Tomcat部署8.1.1.环境准备(Tomcat)8.1.1.1.浏览器访问:8.1.1.2.获取动态资源的链接地址:8.1.1.3.在Centos上准备一个Tomcat作为后台web服务器8.1.1.4.准备一个web项目&#xff0c;将其打包为war8.1.1.5.启动tomcat进…

数据结构奇妙旅程之七大排序

꒰˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好&#xff0c;我是xiaoxie.希望你看完之后,有不足之处请多多谅解&#xff0c;让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN …

欧拉计划第816题:求大量点的最短距离

本次来解决欧拉计划的第816题: 解: 第一步:最原始的算法 先从简单的情况开始,即原题里的14个点的情况 import mathdef gen_points(n):s = [0] * (2*n)s[0] = 290797for i in range(1, 2*n):s[i] = (s[i - 1] * s[i - 1]) % 50515093p = [(s[2 * i], s[2 * i + 1]) for…

Android悬浮窗的实现

最近想做一个悬浮窗秒表的功能&#xff0c;所以看下悬浮窗具体的实现步骤 1、初识WindowManager 实现悬浮窗主要用到的是WindowManager SystemService(Context.WINDOW_SERVICE) public interface WindowManager extends ViewManager {... }WindowManager是接口类&#xff0c…

【HTML教程】跟着菜鸟学语言—HTML5个人笔记经验(五)完结

HTML学习第五天 PS&#xff1a;牛牛只是每天花了1.5-2小时左右来学习HTML。这也是最后一天&#xff0c;其实HTML只需要1-2天就可以学完&#xff01; 书接上回 HTML 脚本 JavaScript 使 HTML 页面具有更强的动态和交互性。 尝试一下&#x1f3f7; 插入一段脚本 <!DOCT…

C语言菜鸟入门·判断语句(if语句、if...else语句、嵌套if语句)详细介绍

目录 1. if语句 2. if...else语句 3. if...else if...else 语句 4. 嵌套if语句 C 语言把任何非零和非空的值假定为 true&#xff0c;把零或 null 假定为 false。 语句描述if语句一个 if 语句 由一个布尔表达式后跟一个或多个语句组成。if...else语句一个 if 语句 后可跟…

Flutter 应用服务:主题、暗黑、国际化、本地化-app_service库

Flutter应用服务 主题、暗黑、国际化、本地化-app_service库 作者&#xff1a;李俊才 &#xff08;jcLee95&#xff09;&#xff1a;https://blog.csdn.net/qq_28550263 邮箱 &#xff1a;291148484163.com 本文地址&#xff1a;https://blog.csdn.net/qq_28550263/article/det…

FullStack之Django(1)开发环境配置

FullStack之Django(1)开发环境配置 author: Once Day date&#xff1a;2022年2月11日/2024年1月27日 漫漫长路&#xff0c;才刚刚开始… 全系列文档请查看专栏: FullStack开发_Once_day的博客-CSDN博客Django开发_Once_day的博客-CSDN博客 具体参考文档: The web framewor…

从比亚迪的整车智能战略,看王传福的前瞻市场布局

众所周知&#xff0c;作为中国新能源汽车的代表企业&#xff0c;比亚迪在中国乃至全球的新能源汽车市场一直都扮演着引领者的角色。2024年新年伊始&#xff0c;比亚迪又为新能源汽车带来了一项重磅发布。 整车智能才是真智能 近日&#xff0c;在“2024比亚迪梦想日”上&#xf…

微服务-微服务Alibaba-Nacos 源码分析(上)

Nacos&Ribbon&Feign核心微服务架构图 架构原理 1、微服务系统在启动时将自己注册到服务注册中心&#xff0c;同时外发布 Http 接口供其它系统调用(一般都是基于Spring MVC) 2、服务消费者基于 Feign 调用服务提供者对外发布的接口&#xff0c;先对调用的本地接口加上…

Java强训day11(选择题编程题)

选择题 编程题 题目1 import java.util.Scanner;public class Main {public static String TentoTwo(int n) {StringBuffer sum new StringBuffer();while (n ! 0) {sum.append(n % 2);n / 2;}return sum.reverse().toString();}public static void main(String[] args) {S…

大模型日报-20240130

500行代码构建对话搜索引擎&#xff0c;贾扬清被内涵的Lepton Search真开源了 来了&#xff0c;贾扬清承诺的 Lepton Search 开源代码来了。前天&#xff0c;贾扬清在 Twitter 上公布了 Lepton Search 的开源项目链接&#xff0c;并表示任何人、任何公司都可以自由使用开源代码…

【STM32F103单片机】利用ST-LINK V2烧录程序 面包板的使用

1、ST‐LINK V2安装 参考&#xff1a; http://t.csdnimg.cn/Ulhhq 成功&#xff1a; 2、烧录器接线 背后有标识的引脚对应&#xff1a; 3、烧录成功 烧录成功后&#xff0c;按下核心板的RESET键复位&#xff01;&#xff01;&#xff01;即可成功&#xff01; 4、面包板的…