七、图像分类模型的部署(Datawhale组队学习)

news2024/11/17 16:00:07

文章目录

  • 前言
    • ONNX简介
    • 应用场景
  • 部署ImageNet预训练图像分类模型
    • 导出ONNX模型
    • 推理引擎ONNX Runtime部署-预测单张图像
      • 前期准备
      • ONNX Runtime预测
    • 推理引擎ONNX Runtime部署-ImageNet预训练图像分类模型预测摄像头实时画面
      • 前期准备
      • 预测摄像头的一帧画面
      • 预测摄像头实时画面
  • 部署自己训练的图像分类模型
    • 导出ONNX模型
    • 推理引擎ONNX Runtime部署-预测单张图像
      • ONNX Runtime预测
      • 解析预测结果
  • 总结

本文内容为 同济子豪兄图像分类系列视频的学习笔记, 项目参考代码。本文使用ONNX-ONNX Runtime部署我们的模型。
请添加图片描述

前言

ONNX简介

ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch, TensorFlow等)可以采用相同格式存储模型数据并交互。

应用场景

后面的代码在需要部署的硬件上运行,只需把onnx模型文件发到部署硬件上,并安装 ONNX Runtime 环境,用几行代码就可以运行模型了。

pip install onnx onnxruntime

部署ImageNet预训练图像分类模型

导出ONNX模型

import torch
from torchvision import models

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

载入ImageNet预训练图像分类模型

model = models.resnet18(pretrained=True)
model = model.eval().to(device)

Pytorch模型转ONNX模型

x = torch.randn(1, 3, 256, 256).to(device)

with torch.no_grad():
    torch.onnx.export(
        model,                  # 要转换的模型
        x,                      # 模型的任意一组输入
        'resnet18.onnx',        # 导出的 ONNX 文件名
        opset_version=11,       # ONNX 算子集版本
        input_names=['input'],  # 输入 Tensor 的名称(自己起名字)
        output_names=['output'] # 输出 Tensor 的名称(自己起名字)
    ) 

验证onnx模型导出成功

import onnx

# 读取 ONNX 模型
onnx_model = onnx.load('resnet18.onnx')

# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)
print('无报错,onnx模型载入成功')

#以可读形式打印计算图
#print(onnx.helper.printable_graph(onnx_model.graph))

无报错,onnx模型载入成功

使用Netron对onnx模型可视化
在这里插入图片描述

在这里插入图片描述
关于如何理解这个网络我们可以参考子豪兄的【精读AI论文】ResNet深度残差网络,一下为ResNet网络的简要介绍。
在这里插入图片描述
Resnet将一个模块的输入分为两条路。右边这条路称为短路连接,这个连接将输入原封不动的传入到输出。左边的这条路是两层的神经网络,这两层神经网络不用拟合复杂的底层映射,只用你和原来输入的基础上进行的偏移和修改(即残差)就可以了。最后将残差和恒等映射相加再用Relu激活函数处理。

推理引擎ONNX Runtime部署-预测单张图像

使用推理引擎 ONNX Runtime,读取 onnx 格式的模型文件,对单张图像文件进行预测。

import onnxruntime
import numpy as np
import torch

前期准备

载入 onnx 模型,获取 ONNX Runtime 推理器

ort_session = onnxruntime.InferenceSession('resnet18.onnx')

预处理

from torchvision import transforms

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

载入测试图像并进行预处理

img_path = 'banana1.jpg'
# 用 pillow 载入
from PIL import Image
img_pil = Image.open(img_path)
img_pil

在这里插入图片描述

input_img = test_transform(img_pil)
input_tensor = input_img.unsqueeze(0).numpy()
input_tensor.shape

(1, 3, 256, 256)

ONNX Runtime预测

注意,输入输出张量的名称需要和 torch.onnx.export 中设置的输入输出名对应

# ONNX Runtime 输入
ort_inputs = {'input': input_tensor}
# ONNX Runtime 输出
pred_logits = ort_session.run(['output'], ort_inputs)[0]
pred_logits = torch.tensor(pred_logits)
import torch.nn.functional as F
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
pred_softmax.shape

torch.Size([1, 1000])

对预测结果进行柱状图可视化

import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(8,4))

x = range(1000)
y = pred_softmax.cpu().detach().numpy()[0]

ax = plt.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
plt.ylim([0, 1.0]) # y轴取值范围
# plt.bar_label(ax, fmt='%.2f', fontsize=15) # 置信度数值

plt.xlabel('Class', fontsize=20)
plt.ylabel('Confidence', fontsize=20)
plt.tick_params(labelsize=16) # 坐标文字大小
plt.title(img_path, fontsize=25)

plt.show()

在这里插入图片描述

推理引擎ONNX Runtime部署-ImageNet预训练图像分类模型预测摄像头实时画面

前期准备

#导入包
import onnxruntime
import torch
import pandas as pd
import numpy as np
from PIL import Image, ImageFont, ImageDraw
import matplotlib.pyplot as plt
%matplotlib inline
import os
import numpy as np
import pandas as pd
import cv2 # opencv-python
from tqdm import tqdm # 进度条
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn.functional as F
from torchvision import models

# 导入中文字体,指定字号
font = ImageFont.truetype('SimHei.ttf', 32)

# 载入 onnx 模型,获取 ONNX Runtime 推理器
ort_session = onnxruntime.InferenceSession('resnet18.onnx')

# 载入ImageNet 1000图像分类标签
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
for idx, row in df.iterrows():
    idx_to_labels[row['ID']] = row['Chinese']

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
from torchvision import transforms
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

预测摄像头的一帧画面

获取摄像头的一帧画面

# 导入opencv-python
import cv2
import time
# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)
# 打开cap
cap.open(0)
time.sleep(1)
success, img_bgr = cap.read() 
# 关闭摄像头
cap.release()
# 关闭图像窗口
cv2.destroyAllWindows()

img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) # BGR转RGB
img_pil = Image.fromarray(img_rgb)
img_pil

在这里插入图片描述
使用onnx runtime进行预测

# 预处理
input_img = test_transform(img_pil)
input_tensor = input_img.unsqueeze(0).numpy()
# onnx runtime 预测

# onnx runtime 输入
ort_inputs = {'input': input_tensor}
# onnx runtime 输出
pred_logits = ort_session.run(['output'], ort_inputs)[0]
pred_logits = torch.tensor(pred_logits)
import torch.nn.functional as F
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

## 解析图像分类预测结果
n = 5
top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
draw = ImageDraw.Draw(img_pil) 
# 在图像上写字
for i in range(len(confs)):
    pred_class = idx_to_labels[pred_ids[i]]
    text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
    # 文字坐标,中文字符串,字体,rgba颜色
    draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
img = np.array(img_pil) # PIL 转 array
plt.imshow(img)
plt.show()

在这里插入图片描述

预测摄像头实时画面

处理单帧画面的函数(中文)

# 处理帧函数
def process_frame(img):
    
    # 记录该帧开始处理的时间
    start_time = time.time()
    
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR转RGB
    img_pil = Image.fromarray(img_rgb) # array 转 PIL

    ## 预处理
    input_img = test_transform(img_pil) # 预处理
    input_tensor = input_img.unsqueeze(0).numpy()
    ## onnx runtime 预测
    ort_inputs = {'input': input_tensor} # onnx runtime 输入
    pred_logits = ort_session.run(['output'], ort_inputs)[0] # onnx runtime 输出
    pred_logits = torch.tensor(pred_logits)
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
    
    ## 解析图像分类预测结果
    n = 5
    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
    
    draw = ImageDraw.Draw(img_pil) 
    # 在图像上写字
    for i in range(len(confs)):
        pred_class = idx_to_labels[pred_ids[i]]
        text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
        # 文字坐标,中文字符串,字体,rgba颜色
        draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
    img = np.array(img_pil) # PIL 转 array
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # RGB转BGR
    
    # 记录该帧处理完毕的时间
    end_time = time.time()
    # 计算每秒处理图像帧数FPS
    FPS = 1/(end_time - start_time)  
    # 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
    img = cv2.putText(img, 'FPS  '+str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4, cv2.LINE_AA)
    return img

调用摄像头获取每帧

# 调用摄像头逐帧实时处理模板
# 不需修改任何代码,只需修改process_frame函数即可
# 导入opencv-python
import cv2
import time
# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)
# 打开cap
cap.open(0)
# 无限循环,直到break被触发
while cap.isOpened():
    # 获取画面
    success, frame = cap.read()
    if not success:
        print('Error')
        break  
    ## !!!处理帧函数
    frame = process_frame(frame)
    # 展示处理后的三通道图像
    cv2.imshow('my_window',frame)
    if cv2.waitKey(1) in [ord('q'),27]: # 按键盘上的q或esc退出(在英文输入法下)
        break
# 关闭摄像头
cap.release()
# 关闭图像窗口
cv2.destroyAllWindows()

在这里插入图片描述

部署自己训练的图像分类模型

基本流程和上一小节是一样的,我们这里仅展示运用推理引擎ONNX Runtime部署自己训练的水果图像分类模型并预测单张图像。

导出ONNX模型

import torch
from torchvision import models

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

# 导入训练好的模型
model = torch.load('checkpoints/fruit30_pytorch_20230123.pth')
model = model.eval().to(device)

Pytorch模型转ONNX模型

x = torch.randn(1, 3, 256, 256).to(device)

with torch.no_grad():
    torch.onnx.export(
        model,                   # 要转换的模型
        x,                       # 模型的任意一组输入
        'fruit30_resnet18.onnx', # 导出的 ONNX 文件名
        opset_version=11,        # ONNX 算子集版本
        input_names=['input'],   # 输入 Tensor 的名称(自己起名字)
        output_names=['output']  # 输出 Tensor 的名称(自己起名字)
    ) 

验证onnx模型导出成功

import onnx

# 读取 ONNX 模型
onnx_model = onnx.load('fruit30_resnet18.onnx')

# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)
print('无报错,onnx模型载入成功')

#以可读的形式打印计算图
#print(onnx.helper.printable_graph(onnx_model.graph))

无报错,onnx模型载入成功

和上一小节一样也可以使用Netron对onnx模型可视化
在这里插入图片描述

推理引擎ONNX Runtime部署-预测单张图像

import onnxruntime
import numpy as np
import torch

载入 onnx 模型,获取 ONNX Runtime 推理器

ort_session = onnxruntime.InferenceSession('fruit30_resnet18.onnx')

预处理

from torchvision import transforms

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])
img_path = 'test_img/watermelon1.jpg'
# 用 pillow 载入
from PIL import Image
img_pil = Image.open(img_path)
img_pil

在这里插入图片描述

input_img = test_transform(img_pil)
input_tensor = input_img.unsqueeze(0).numpy()
input_tensor.shape

(1, 3, 256, 256)

ONNX Runtime预测

注意,输入输出张量的名称需要和 torch.onnx.export 中设置的输入输出名对应

# ONNX Runtime 输入
ort_inputs = {'input': input_tensor}
# ONNX Runtime 输出
pred_logits = ort_session.run(['output'], ort_inputs)[0]
pred_logits = torch.tensor(pred_logits)
import torch.nn.functional as F
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
pred_softmax.shape

torch.Size([1, 30])

解析预测结果

#设置matplotlib中文字体
# windows操作系统
plt.rcParams['font.sans-serif']=['SimHei']  # 用来正常显示中文标签 
plt.rcParams['axes.unicode_minus']=False  # 用来正常显示负号
#载入类别和对应 ID
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()

绘制预测结果柱状图

import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(22, 10))

x = idx_to_labels.values()
y = pred_softmax.cpu().detach().numpy()[0] * 100
width = 0.45 # 柱状图宽度

ax = plt.bar(x, y, width)

plt.bar_label(ax, fmt='%.2f', fontsize=15) # 置信度数值
plt.tick_params(labelsize=20) # 设置坐标文字大小

plt.title(img_path, fontsize=30)
plt.xticks(rotation=45) # 横轴文字旋转
plt.xlabel('类别', fontsize=20)
plt.ylabel('置信度', fontsize=20)
plt.show()

在这里插入图片描述

总结

本文主要讲述了ONNX-ONNX Runtime部署流程,首先将训练好的Pytorch模型转ONNX模型,这样我们就可以将ONNX模型在任何安装了ONNX Runtime环境的机器上进行运行,进行单张图片的预测、调用摄像头进行实时画面的预测等。使用ONNX我们可以让模型在不同框架之间进行迁移,方便我们低成本的将模型部署到移动设备中去。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/184302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript的基础知识

目录 一、初识JavaScript 二、JavaScript的基础 1、初步了解 2、代码位置 3、注释 4、变量 ①字符串 ②数组 ③对象 ④条件语句 ⑤函数 三、DOM模块 一、初识JavaScript JavaScript&#xff0c;是一门编程语言。浏览器就是JavaScript语言的解释器。DOM和BOM 相当于编…

Unity功能——宏定义的使用

声明&#xff1a;本文为个人笔记&#xff0c;用于学习研究使用非商用&#xff0c;内容为个人研究及综合整理所得&#xff0c;若有违规&#xff0c;请联系&#xff0c;违规必改。 Unity功能——宏定义的使用 文章目录Unity功能——宏定义的使用一.开发环境二.问题描述三.宏的使用…

拿下大厂Offer的关键——飞滴出行网约车项目全新完结

哈喽各位小伙伴&#xff0c;好久不见吖&#xff01;正月初八&#xff0c;你开工了吗&#xff1f;告别新春的氛围&#xff0c;又开始新一年的奋斗。尤其是年前的离职的小伙伴&#xff0c;马上又是金三银四&#xff0c;你的面试还要准备多久&#xff1f;今天给大家分享一份阿里大…

Leetcode力扣秋招刷题路-0098

从0开始的秋招刷题路&#xff0c;记录下所刷每道题的题解&#xff0c;帮助自己回顾总结 98. 验证二叉搜索树 给你一个二叉树的根节点 root &#xff0c;判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下&#xff1a; 节点的左子树只包含 小于 当前节点的数。 节点…

[cpp进阶]C++类型转换

文章目录C语言的类型转换为什么C需要四种类型转换C强制类型转换static_castreinterpret_castconst_castdynamic_castexplicitRTTIC语言的类型转换 在C语言中&#xff0c;如果赋值运算符左右两侧类型不同&#xff0c;或者形参与实参类型不匹配&#xff0c;或者返回值类型与接收…

使用lnmp与wordpress做1个外贸询盘网站

目录 lnmp安装 包安装 mysql元数据库 网路策略确认 iptables确认 mysql允许远程访问 wordpress下载安装 包安装 nginx配置 wordpress配置 初始化 astra&#xff0c;elementor和woocommerce插件 插件安装 模板选择 自定义网页 国内不兴建站&#xff0c;通常只有码…

openstack: nova : reset-state

https://github.com/openstack/python-novaclient 牵扯的两个project是&#xff1a;nova和python-novaclient&#xff1b; 这个命令从代码分析和实际使用上来看只是将nova数据库里的实例的状态更改&#xff1b;没有对实例做实质的操作。 https://docs.openstack.org/nova/pik…

01 C语言实现动态气泡碰撞和移动的效果,小球碰撞,Win7气泡壁纸,碰撞算法

C语言实现动态气泡碰撞和移动的效果 作者将狼才鲸创建日期2023-01-29 Git源码仓库地址&#xff1a;C语言实现动态气泡碰撞和移动的效果CSDN文章地址&#xff1a;01 C语言实现动态气泡碰撞和移动的效果 一、前言 想要实现多气泡相互碰撞的效果&#xff1b; 想着这种在Win7壁纸…

【Cloudera Manager】cdh集群ntp时钟同步问题

CM启动后集群界面出现时钟未同步问题在集群主机通过ntpstat命令查看&#xff0c;出现unsynchronised标识通过timedatectl命令&#xff0c;显示NTP synchronized: no以上说明确实没有同步时钟问题排查与解决首先查看ntp配置文件&#xff0c;cat /etc/ntp.confserver 172.X.X.X配…

万年历农历法定节假日数据查询工具

1.数据来源于百度搜索置顶日历&#xff1a; 2.代码&#xff1a; http调用及数据处理均采用了hutool, 也可以用别的工具。 hutool 依赖如下&#xff1a; <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><versio…

DNS 域名解析服务器详解以及相关实验

目录 一&#xff0c; 域名解析服务器的介绍 1&#xff0c; 因特网的域名结构 2&#xff0c; 域名服务器的类型划分 二&#xff0c;DNS 域名解析的过程 三&#xff0c;DNS 解析方式 四&#xff0c;搭建 DNS 服务器 1&#xff0c;使用命令yum install bind -y安装dns软件 2&#…

对于初学python的小白大佬们有什么建议吗?

我认为态度是一块重要的敲门砖。米卢说&#xff1a;“态度决定一切”。你对人生的态度是这个世界真正的试金石。对不同的事情要有不同的态度。而对待自学&#xff0c;认真就妥了。 首先要为自己设定一个目标&#xff0c;对于初学者&#xff0c;看书的话可以看《Head First Pyt…

几款考研必备软件 你还不知道吗?

几款考研必备软件 你还不知道吗? 英语单词软件推荐 背单词软件 墨墨背单词[推荐指数]⭐️⭐️⭐️⭐️⭐️ 科学高效抗遗忘方法,记录详细记忆行为数据,结合记忆反馈帮你记忆更加牢固 界面简洁舒适无广告,没有任何干扰,就算是强迫症也能使用的非常舒适 完美收集权威单词本,全…

哈佛大学庄小威团队破解衰老大脑的关键变化

“了解衰老是生物医学最重要的目标之一&#xff0c;同时这也是一个非常具有挑战性的问题。”哈佛大学终身教授庄小威说&#xff0c;“造成挑战的原因之一在于大脑非常复杂&#xff0c;细胞种类繁多&#xff0c;许多不同类型的神经元和非神经元细胞形成了复杂的相互作用网络。”…

一种二阶Biquad滤波器

一、首先给出biquad的Z变换函数为&#xff1a; 为了计算方便可对上式进行归一化处理&#xff0c;分子分母同时除以a0&#xff0c;则得出如下&#xff1a; 对应的差分方程为&#xff1a; 二、用户定义参数如下 #ifndef LN2 #define LN2 0.69314718055994530942 #endif #ifnde…

密码学技术导论篇

密码学技术前言基础术语不要使用保密的密码算法任何密码总有一天都会被破解对称密码&#xff08;共享秘钥密码&#xff09;AES总结公钥密码 --- 用公钥加密&#xff0c;私钥解密秘钥配送问题公钥密码中间人攻击认证单向散列函数--- 消息的指纹单向散列函数的实际应用单向散列函…

python接口自动化——unittest简介(详解)

简介 前边的随笔主要介绍的requests模块的有关知识个内容&#xff0c;接下来看一下python的单元测试框架unittest。熟悉 或者了解java 的小伙伴应该都清楚常见的单元测试框架 Junit 和 TestNG&#xff0c;这个招聘的需求上也是经常见到的。python 里面也有单元 测试框架-unitt…

Lua 垃圾回收

Lua 垃圾回收 参考至菜鸟教程。 Lua 采用了自动内存管理。 这意味着你不用操心新创建的对象需要的内存如何分配出来&#xff0c; 也不用考虑在对象不再被使用后怎样释放它们所占用的内存。 Lua运行了一个垃圾收集器来收集所有死对象&#xff08;即在Lua中不可能再访问到的对象&…

新能源——充电控制

一、交流充电——慢充 交流充电&#xff1a;电网输入给车辆的交流电&#xff0c;220V AC单向电或380V AC三相电。 车载充电机&#xff1a;交流电转化为直流电 二、直流充电——快充 三、充电模式 模式1——标准插座 模式2——带有交流电动汽车供电设备的标准插座 模式3——…

idea maven打包编译报错 java.lang.AssertionError: input.getType

今天使用idea打包编译maven项目&#xff0c;出现如下报错 构建报错时&#xff0c;最先显示的是这个报错。查了一圈下来&#xff0c;我的配置是没有问题的。 Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.1:compile (default-compile) on project…