onnxruntime推理

news2024/9/30 11:32:23

pytorch模型训练

这里以pytorch平台和mobilenet v2网络为例,给出模型的训练过程。具体代码如下所示:

import os
import torchvision.transforms as transforms
from torchvision import datasets
import torch.utils.data as data
import torch
import numpy as np
import torchvision.models as models
import torchvision.datasets as datasets
from torch.utils.data import random_split
#模型加载
model = models.mobilenet_v2(pretrained=True)
model.classifier = torch.nn.Sequential(torch.nn.Dropout(p=0.5),torch.nn.Linear(1280, 5))
print("model:")
print(model)
#参数
BATCH_SIZE = 32
DEVICE = 'cuda'
epoch_n = 10
#数据集加载
image_path = 'E:/MobileNets-V2-master/flower_photos'
flower_class = ['daisy','dandelion','roses','sunflowers','tulips']

transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}
#
full_data = datasets.ImageFolder(root=image_path,transform=transform['train']) 
train_size = int(len(full_data)*0.8)  
test_size = len(full_data) - train_size
train_dataset, test_dataset =random_split(full_data, [train_size, test_size])
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=False)


print("Training data size: {}".format(len(train_dataset)))
print("Testing data size: {}".format(len(test_dataset)))


#损失函数和优化器
loss_f = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# 模型训练和参数优化
torch.cuda.empty_cache()
model=model.to(DEVICE)
best_acc=0
#
for epoch in range(epoch_n):
    print("Epoch {}/{}".format(epoch + 1, epoch_n))
    print("-" * 10)
    # 设置为True,会进行Dropout并使用batch mean和batch var
    print("Training...")
    model.train(True)
    running_loss = 0.0
    running_corrects = 0
    # enuerate(),返回的是索引和元素
    for batch, data in enumerate(train_loader):
        X, y = data
        X=X.to(DEVICE)
        y=y.to(DEVICE)
        y_pred = model(X)
        # pred,概率较大值对应的索引值,可看做预测结果
        _, pred = torch.max(y_pred.data, 1)
        # 梯度归零
        optimizer.zero_grad()
        # 计算损失
        loss = loss_f(y_pred, y)
        loss.backward()
        optimizer.step()
        # 计算损失和
        running_loss += float(loss)
        # 统计预测正确的图片数
        running_corrects += torch.sum(pred == y.data)
        if batch%10==9:
            print("loss=",running_loss/(BATCH_SIZE*10))
            print("acc is {}%".format(running_corrects.item()/(BATCH_SIZE*10)*100.0))
            running_loss=0
            running_corrects=0
    #
    print("validating...")
    model.eval()
    val_loss=0.0
    correct=0
    total=0
    with torch.no_grad():
        for batch_idx,(inputs,targets) in enumerate(test_loader):
            inputs,targets=inputs.to(DEVICE),targets.to(DEVICE)
            outputs=model(inputs)
            loss=loss_f(outputs,targets.long())
            _,preds=outputs.max(1)
            val_loss+=loss.item()
            total+=targets.size(0)
            correct+=preds.eq(targets).sum().item()
    acc=100.0*correct/total
    print("Epoch={},val loss={}".format(epoch,val_loss/total))
    print("Epoch={},val acc={}%".format(epoch,acc))
    #
    if acc>best_acc:
        #
        print("current accuracy={},saving...".format(acc))
        torch.save(model,"model.pth")
        best_acc=acc

导出为ONNX格式

ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch, MXNet)可以采用相同格式存储模型数据并交互。 ONNX的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在Github上。目前官方支持加载ONNX模型并进行推理的深度学习框架有:Caffe2, PyTorch,MXNet,ML.NET,TensorRT和Microsoft CNTK,并且TensorFlow也非官方的支持ONNX。
在Pytorch中,我们可以使用官方自带的torch.onnx.export函数将模型转换成ONNX的函数:

from turtle import mode
import onnx
import torch
#from SpectralCirC3D import SpectralCirC3D
#from mobilenetv2 import model

def export():   
    model = torch.load("model.pth")
    print(model)

    batch_size = 1  
    input_shape = (3, 224, 224)   #input data shape

    # #set the model to inference mode
    model.eval()

    x = torch.randn(batch_size, *input_shape).cuda()	# 生成张量
    y = model(x)
    print(x.size())
    print(y.size())
    export_onnx_file = "mobilenetv2.onnx"			# 目的ONNX文件名
    torch.onnx.export(model,
                  x,
                  export_onnx_file,
                  opset_version=14,
                  example_outputs=y,
                  do_constant_folding=True,	# 是否执行常量折叠优化
                  input_names=["input"],	# 输入名
                  output_names=["output"],	# 输出名
                  dynamic_axes={"input":{0:"batch_size"},  # 批处理变量
                                "output":{0:"batch_size"}})

def check_onnx():
    # Load the ONNX model
    model = onnx.load("mobilenetv2.onnx")
    # Check that the IR is well formed
    onnx.checker.check_model(model)
    # Print a human readable representation of the graph
    print(onnx.helper.printable_graph(model.graph))

if __name__=='__main__':
    export()
    check_onnx()

如代码所示,export函数用于将pytorch模型导出为onnx格式,在导出前,我们需要显式地指定输入数据的尺寸,批大小,在导出过程中,还可以进行常量折叠优化等。
check_onnx函数则用于检查导出后的onnx文件是否符合规范。

onnxruntime推理

ONNXRuntime是微软推出的一款推理框架,用户可以非常便利的用其运行一个onnx模型。ONNXRuntime支持多种运行后端,包括CPU,GPU,TensorRT,DML等。可以说ONNXRuntime是对ONNX模型最原生的支持。

import argparse
import numpy as np
import onnxruntime
import time
import torchvision.datasets as datasets
from torch.utils.data import random_split
import torch.utils.data as data
import torchvision.transforms as transforms
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

def load_data():
    #数据集加载
    image_path = 'E:/MobileNets-V2-master/flower_photos'

    transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }
    #
    full_data = datasets.ImageFolder(root=image_path,transform=transform['val']) 
    train_size = int(len(full_data)*0.8)  
    test_size = len(full_data) - train_size
    train_dataset, test_dataset =random_split(full_data, [train_size, test_size])
    train_loader = data.DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=True)
    test_loader = data.DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)
    return train_loader,test_loader

def benchmark(model_path,device):
    if device=='cpu':
        print("using CPUExecutionProvider")
        session = onnxruntime.InferenceSession(model_path,providers=['CPUExecutionProvider'])
    else:
        print("using CUDAExecutionProvider")
        session = onnxruntime.InferenceSession(model_path,providers=['CUDAExecutionProvider'])
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    print("input name:{}".format(input_name))
    print("output name:{}".format(output_name))
    total = 0.0
    runs = 10
    input_data = np.zeros((1, 3, 224, 224), np.float32)
    # Warming up
    output = session.run([output_name], {input_name: input_data})
    print(output[0].shape)
    for i in range(runs):
        start = time.perf_counter()
        _ = session.run([], {input_name: input_data})
        end = (time.perf_counter() - start) * 1000
        total += end
        print(f"{end:.2f}ms")
    total /= runs
    print(f"Avg: {total:.2f}ms")

def infer_test(model_path,data_loader,device):
    if device=='cpu':
        print("using CPUExecutionProvider")
        session = onnxruntime.InferenceSession(model_path,providers=['CPUExecutionProvider'])
    else:
        print("using CUDAExecutionProvider")
        session = onnxruntime.InferenceSession(model_path,providers=['CUDAExecutionProvider'])
    #
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    #
    total = 0.0
    correct = 0
    for batch,data in enumerate(data_loader):
        X, y = data
        X = X.numpy()
        y = y.numpy()
        #
        output = session.run([output_name], {input_name: X})[0]
        y_pred = np.argmax(output,axis=1)
        #
        if y[0]==y_pred[0]:
            correct+=1
        total+=1
    #
    print("accuracy is {}%".format(correct/total*100.0))

def main():
    input_model_path = "mobilenetv2.onnx"
    device=input("cpu or gpu?")
    #test latency
    benchmark(input_model_path,device)

    train_loader,test_loader = load_data()
    print(len(train_loader))
    print(len(test_loader))
    #test accuracy
    infer_test(input_model_path,test_loader,device)
    
if __name__ == "__main__":
    main()

  • 实验结果
    在这里插入图片描述
    如上图所示,CPU平台上单张图片的推理时间约为3.29ms,而GPU平台上单张图片的推理时间约为2.91ms

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/670456.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java面试题Class类的理解?创建类的对象的方式?

1.Class类的理解 1.类的加载过程: 程序经过javac.exe命令以后,会生成一个或多个字节码文件(.class结尾)。 接着我们使用java.exe命令对某个字节码文件进行解释运行。相当于将某个字节码文件 加载到内存中。此过程就称为类的加载。加载到内存中的类&…

一起来了解多领域自动采样器的功能特点

多领域自动采样器体积小,便携式设计,功能丰富,操作简便可用于海洋、河流、船舶、沟渠、深井、排污口等多种场景的水样采集,尤其适用于窨井、下水道、沟渠 等空间狭小、现场条件恶劣的工作场合,可以在环保、科研、污水验…

【计算机组成原理】辅助存储器

目录 一、磁盘存储器 二、固态硬盘SSD 三、虚拟存储系统 一、磁盘存储器 大多数计算机外存储器采用磁盘记录,如今正在逐渐被SSD固态硬盘取代 磁表面存储:磁性材料薄层涂在金属或塑料表面做磁载体存储信息 硬磁盘存储器:基底(磁…

【深度学习】近万字解读深度学习领域有哪些瓶颈?

文章目录 一、导读二、深度学习缺乏理论支撑三、领域内越来越工程师化思维四、对抗样本是深度学习的问题,但不是深度学习的瓶颈五、知乎网友的回答5.1 作者:Giant5.2 作者:知乎用户5.3 作者:何之源 一、导读 虽然深度学习在图像、…

Java去掉 txt 文件中的空格空行【代码记录】

文章目录 1、需求2、代码3、结果 1、需求 2、代码 package com.zibo.main;import java.io.BufferedReader; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.util.regex.Matcher; import java.util.regex.Pattern;public cla…

外卖商城平台微信小程序 后端ssm

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 外卖商城平台微信小程序 前言一、组织结构二、使用步骤1.后端登录代码2.运行截图 源代码 前言 提示:这里可以添加本文要记录的大概内容: 本外卖商城…

windows 快速删除node_modules文件夹

rmdir /Q /S 目录 删除文件夹(非空) /S 除目录本身外,还将删除指定目录下的所有子目录 /Q 安静模式,带 /S 删除目录树时不要求确认

救援模式 单用户模式

救援模式 救援模式是一种在 Linux 操作系统中用于故障排除和修复的特殊启动模式。它可以提供一些基本的系统功能,以便在出现问题时可以对系统进行诊断和修复。 救援模式通常会加载最小的系统资源和驱动程序,以确保在系统出现故障的情况下仍然可以正常启…

异常—javaSE

文章目录 1.概念和结构体系1.1概念1.2结构体系 2.常见异常类型2.1空指针异常2.2数组越界异常2.3算数异常 3.异常的分类3.1编译时异常3.2运行时异常 4.异常的处理4.1防御式编程4.2异常的抛出4.3异常的捕获4.3.1异常申明throws4.3.2try-catch捕获并处理异常4.3.3finally 4.4异常的…

【ubuntu】【vmware tools】设置共享目录

1、现象 ubuntu 22 vmware 16,安装后会发现 “Reinstall VMware Tools…” 灰色不可用。如图: 2、原因分析 ubuntu 22 ISO 内不再提供 VMware Tools 的安装包,未检测到所以灰色不可用 在 Ubuntu 22 上挂载 Windows HGFS 共享目录&#xff…

Linux系统之部署Teleport堡垒机系统

Linux系统之部署Teleport堡垒机系统 一、Teleport介绍1.1 Teleport简介1.2 Teleport特点1.3 支持操作系统 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本 四、部署teleport服务端4.1 创建部署目录4.2 下载t…

Sqoop初认识及安装

Sqoop初认识及安装 文章目录 Sqoop初认识及安装Sqoop简介Sqoop原理安装前置条件镜像地址上传安装包解压修改配置文件重命名配置文件 拷贝JDBC驱动验证Sqoop测试Sqoop是否能够成功连接数据库 Sqoop简介 Sqoop是一款开源的工具,主要用于在Hadoop(Hive)与传统的数据库…

redis高可用集群搭建

redis高可用集群搭建 redis的安装配置允许远程访问重启服务检查服务是否启动架构图开始搭建集群安装ruby创建集群高可用测试redis集群的扩展将7号机添加为新的master节点添加从节点删掉一个slave节点删除master节点 redis的安装 sudo apt-get install redis-server配置允许远程…

引进吸收再消化,可借鉴的产业超车模式探索

近期,C919大型客机顺利开启商业首航,这也标志着坐国产大飞机出行的时代来了!C919是我国首次按照国际适航标准自行研制、具有自主知识产权的喷气式干线客机,它的商用飞行也象征着我国对波音、空中客车等大型客机企业垄断地位的一次…

【unity每日一记】unity中常见的特性大全

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:uni…

【每日一题】LCP 41. 黑白翻转棋

【每日一题】LCP 41. 黑白翻转棋 LCP 41. 黑白翻转棋题目描述解题思路 LCP 41. 黑白翻转棋 题目描述 在 n*m 大小的棋盘中,有黑白两种棋子,黑棋记作字母 “X”, 白棋记作字母 “O”,空余位置记作 “.”。当落下的棋子与其他相同颜色的棋子在…

JMeter根据负载量计算并发用户数实例

目录 前言: 业务需求 分析需求 测试模型构建 & 用例设计 一、场景构建:登录业务操作流程、考勤打卡操作流程; 二、场景用例设计 三、测试脚本用例设计: 模型构建 登录打卡-操作流程: 场景设计 常用测试场景的类型:…

nx安装llvmlite与numba

文参考 Python安装llvmlite、numba报错解决方案_ClearLon的博客-CSDN博客 llvmlite与numba你可以理解为用于数据处理的加速包 我的python版本为3.6.9,llvmlite版本为0.32.1,numba版本为0.49.1 目录 1 安装 llvmlite 2 安装numba 1 安装 llvmlite…

软件测试技能,JMeter压力测试教程,登录参数化CSV 数据文件设置(五)

目录 前言 一、场景案例 二、登录接口 三、测试数据准备 四、CSV数据文件设置 五、查看结果 前言 我们在压测登录接口的时候,如果只用一个账号去设置并发压测,这样的结果很显然是不合理的,一个用户并发无法模拟真实的情况 如果要压测…

Python的特点和优势

Python的优特点 简单易学: Python语言相对于其他编程语言来说,属于比较容易学习的一门编程语言,它注重的是如何解决问题而不是编程语言的语法和结构。正是因为Python语言简单易学,所以,已经有越来越多的初学者选择Pyth…