30分钟吃掉 Pytorch 转 onnx

news2024/11/27 16:39:04

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学.

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

汇总合集:

《大模型面试宝典》(2024版) 发布!

圈粉无数!《PyTorch 实战宝典》火了!!!


PyTorch 是一个用于机器学习的开源深度学习框架,而ONNX(Open Neural Network Exchange)是一个用于表示深度学习模型的开放式格式。

将 PyTorch 模型转换为ONNX格式有几个原因和优势:

  1. 跨平台部署: ONNX是一个跨平台的格式,支持多种深度学习框架,包括PyTorch、TensorFlow等。将模型转换为ONNX格式可以使模型在不同框架和设备上进行部署和运行。

  2. 性能优化: ONNX格式可以在不同框架之间实现性能优化。例如,可以在PyTorch中训练模型,然后转换为ONNX格式,并在性能更高的框架(如TensorRT)中进行推理。

  3. 模型压缩: ONNX格式可以实现模型的压缩和优化,从而减小模型的体积并提高推理速度。这对于在资源受限的设备上部署模型尤为重要。

pytorch 模型线上部署最常见的方式是转换成onnx,然后再转成tensorRT 在cuda上进行部署推理。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了Pytorch 技术与面试交流群, 想要获取最新面试题、了解最新面试动态的、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:技术交流

本文介绍将pytorch模型转换成onnx模型并进行推理的方法。

#!pip install onnx 
#!pip install onnxruntime
#!pip install torchvision

一,准备pytorch模型

我们先导入torchvision中的resnet18模型,演示它的推理效果。

以便和onnx的结果进行对比。

import torch
import torchvision.models as models
import numpy as np
import torchvision
import torchvision.transforms as T

from PIL import Image

def create_net():
    net = models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    return net 

net = create_net()

torch.save(net.state_dict(),'resnet18.pt')
net.eval();
def get_test_transform():
    return T.Compose([
        T.Resize([320, 320]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

image = Image.open("dog.png") # 289
img = get_test_transform()(image)
img = img.unsqueeze_(0) 
output = net(img)
score, indice = torch.max(torch.softmax(output,axis=-1),1)
info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}

def show_image(image, title):
    import matplotlib.pyplot as plt 
    ax=plt.subplot()
    ax.imshow(image)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([]) 
    plt.show()

show_image(image, title = info)

图片

二,pytorch模型转换成onnx模型

1, 简化版本

import onnxruntime
import onnx

batch_size = 1  
input_shape = (3, 320, 320)   

x = torch.randn(batch_size, *input_shape)
onnx_file = "resnet18.onnx"
torch.onnx.export(net,x,onnx_file,
                opset_version=10,
                do_constant_folding=True,  # 是否执行常量折叠优化
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={
                    "input":{0:"batch_size"},  
                     "output":{0:"batch_size"}})

!du -s -h resnet18.pt
 45M	resnet18.pt
!du -s -h resnet18.onnx 

 45M	resnet18.onnx

可以在 https://netron.app/ 中拖入 resnet18.onnx 文件查看模型结构

2,全面版本

下面的代码包括了设置输入输出尺寸,以及动态可以变batch等等。

import argparse
from argparse import Namespace
import time
import sys
import os
import torch
import torch.nn as nn
import torchvision.models as models
import onnx
import onnxruntime

from io import BytesIO


ROOT = os.getcwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

params = Namespace(weights='resnet18.pt',
                   img_size=[320,320],
                   batch_size=1,
                   half=False,
                   dynamic_batch=True
                  )

parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='checkpoint.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size')  # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--dynamic-batch', action='store_true', help='export dynamic batch onnx model')
parser.add_argument('--trt-version', type=int, default=8, help='tensorrt version')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

args = parser.parse_args(args='',namespace=params)


args.img_size *= 2 if len(args.img_size) == 1 else 1  # expand
print(args)

t = time.time()

# Check device
cuda = args.device != 'cpu' and torch.cuda.is_available()
device = torch.device(f'cuda:{args.device}' if cuda else 'cpu')
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'

# Load PyTorch model
model = create_net()
model.to(device)
model.load_state_dict(torch.load(args.weights)) # pytorch模型加载

# Input
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device)  # image size(1,3,320,192) iDetection

# Update model
if args.half:
    img, model = img.half(), model.half()  # to FP16
model.eval()

prediction = model(img)  # dry run

# ONNX export
print('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx')  # filename
with BytesIO() as f:
    dynamic_axes = {"input":{0:"batch_size"}, "output":{0:"batch_size"} } if args.dynamic_batch else None
    torch.onnx.export(model, img, f, verbose=False, opset_version=13,
                      training=torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes=dynamic_axes)
    f.seek(0)
    # Checks
    onnx_model = onnx.load(f)  # load onnx model
    onnx.checker.check_model(onnx_model)  # check onnx model
    
if args.simplify:
    try:
        import onnxsim
        print('\nStarting to simplify ONNX...')
        onnx_model, check = onnxsim.simplify(onnx_model)
        assert check, 'assert check failed'
    except Exception as e:
        print(f'Simplifier failure: {e}')

onnx.save(onnx_model, export_file)

print(f'ONNX export success, saved as {export_file}')

# Finish
print('\nExport complete (%.2fs)' % (time.time() - t))
Namespace(weights='resnet18.pt', img_size=[320, 320], batch_size=1, half=False, dynamic_batch=True, inplace=False, simplify=False, trt_version=8, device='cpu')

Starting to export ONNX...
ONNX export success, saved as resnet18.onnx

Export complete (0.57s)

三,使用onnx模型进行推理

1,函数风格

onnx_sesstion = onnxruntime.InferenceSession(export_file)
def pipe(img_path,
         onnx_sesstion = onnx_sesstion):
    image = Image.open(img_path) 
    img = get_test_transform()(image)
    img = img.unsqueeze_(0) 

    to_numpy = lambda tensor: tensor.data.cpu().numpy()
    
    inputs = {onnx_sesstion.get_inputs()[0].name: to_numpy(img)}
    outs = onnx_sesstion.run(None, inputs)[0]

    score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
    info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}
    return info
img_path = 'dog.png'image = Image.open(img_path)info = pipe(img_path)show_image(image,info)

图片

2,对象风格

import os, sys

import onnxruntime
import onnx
    
class ONNXModel():
    def __init__(self, onnx_path):
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        self.input_names = [node.name for node in self.onnx_session.get_inputs()]
        self.output_names = [node.name for node in self.onnx_session.get_outputs()]
        print("input_name:{}".format(self.input_names))
        print("output_name:{}".format(self.output_names))
 
    def forward(self, x):
        if isinstance(x,np.ndarray):
            assert len(self.input_names)==1
            input_feed = {self.input_names[0]:x}
        elif isinstance(x,(tuple,list)):
            assert len(self.input_names)==len(x)
            input_feed = {k:v for k,v in zip(self.input_names,x)}
        else:
            assert isinstance(x,dict)
            input_feed = x
        outs = self.onnx_session.run(self.output_names, input_feed=input_feed)
        return outs
    
    def predict(self,img_path):
        image = Image.open(img_path) 
        img = get_test_transform()(image)
        img = img.unsqueeze_(0) 
        to_numpy = lambda tensor: tensor.data.cpu().numpy()
        outs = self.forward(to_numpy(img))[0]
        score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
        return {'score':score[0].data.numpy().tolist(),
            'indice':indice[0].data.numpy().tolist()}
onnx_model = ONNXModel(export_file)
info = onnx_model.predict(img_path)
show_image(image, title = info)
input_name:['input']
output_name:['output']

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1805113.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文浅尝 | THINK-ON-GRAPH:基于知识图谱的深层次且可靠的大语言模型推理方法...

笔记整理:刘佳俊,东南大学硕士,研究方向为知识图谱 链接:https://arxiv.org/pdf/2307.07697.pdf 1. 动机 本文是IDEA研究院的工作,这篇工作将知识图谱的和大语言模型推理进行了结合,在每一步图推理中利用大…

Nvidia/算能 +FPGA+AI大算力边缘计算盒子:电力巡检智能机器人

聚焦数字经济与双碳经济赛道,专注于提供集中式新能源场站与分布式综合能源数智化整体解决方案,坚持以场站数字化、综合能源数字化双轮驱动发展。依靠专业化人才队伍与丰富的实证基地研究经验,打造成熟、先进的数智新能源研发平台。 在集中式新…

【Git】远程操作 -- 详解

一、理解分布式版本控制系统 我们目前所说的所有内容(工作区、暂存区、版本库等等)都是在本地,也就是在我们的笔记本或者计算机上。而我们的 Git 其实是分布式版本控制系统。 上面这段话是什么意思呢? 可以简单理解为&#xff1…

模型 SCAMPER创新法则

说明:系列文章 分享 模型,了解更多👉 模型_思维模型目录。激发创新的七步思维法。 1 SCAMPER创新法则的应用 1.1 SCAMPER应用之改进自行车设计 一家自行车制造商希望改进其自行车设计,以吸引更多的消费者并提高市场份额。他们决…

【小白专用24.6.8】C# 异步任务Task和异步方法async/await详解

一、什么是异步 同步和异步主要用于修饰方法。当一个方法被调用时,调用者需要等待该方法执行完毕并返回才能继续执行,我们称这个方法是同步方法;当一个方法被调用时立即返回,并获取一个线程执行该方法内部的业务,调用…

最值,反转数组——跟之前的差不多

文章目录 数组最值感悟改进 反转数组问题 代码改进 数组最值 package com.zhang; /*求数组最大最小值*/ public class test_arr1 {public static void main(String[] args) {int[] arr {10,66,42,8,999,1};max(arr);min(arr);}public static int max(int[] arr){int max arr…

【云原生_K8S系列】Kubernetes 控制器之 Deployment

在 Kubernetes 中,Deployment 是一种高级控制器,负责管理应用的部署和生命周期。它提供了一种声明性的方式来定义应用的期望状态,并确保实际状态与期望状态保持一致。Deployment 可以自动处理应用的滚动更新、扩展和回滚等任务,是…

GPU风扇不旋转:为什么会发生这种情况以及如何修复

GPU在处理数百万像素时往往会发热,因此冷却风扇静音可能会令人担忧,这是可以理解的!如果你注意到你的GPU风扇没有旋转,下面是如何评估是否存在真正的问题,以及如何解决问题。 风扇停止旋转可能是一个功能,而不是一个Bug 如果GPU没有用于密集任务或没有达到高温,则可以…

Java 环境配置 -- Java 语言的安装、配置、编译与运行

大家好,我是栗筝i,这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 002 篇文章,在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验,并希望进…

vs - vs2015编译gtest-v1.12.1

文章目录 vs - vs2015编译gtest-v1.12.1概述点评笔记将工程迁出到本地后,如果已经编译过工程,将工程Revert, Clean up 干净。编译用的CMake, 优先用VS2019自带的打开VS2015X64本地命令行编译gtest工程测试安装自己写个测试工程,看看编译出来的…

《精通ChatGPT:从入门到大师的Prompt指南》第10章:案例分析

第10章:案例分析 10.1 优秀Prompt案例解析 在深入探讨如何精通ChatGPT的使用之前,理解并分析一些优秀的Prompt案例是至关重要的。这不仅有助于更好地掌握Prompt的构建技巧,还能提高与AI交互的效果。在这一节中,我们将详细解析一…

STM32F103C8T6 HAL库 printf重定向 USART1 DMA方式发送数据

前言: 在上一篇文章里,我采用printf重定向为usart1,但是这样发送,对于MPU的负载比较大,所以本篇文章采用DMA方式,解放MPU资源,去做其他的事情,这里仅做为自己的记录。 正文开始&…

LDA初步了解

LDA简析 最明显的特征是能够将若干文档自动编码分类为一定数量的主题(注意:主题的数量需要人为指定)。设定好主题数量之后,运行LDA模型就会得到每个主题下边词语的发布概率以及文档对应的主题概率。 LDA原理 LDA的工作原理 可把…

【刷题篇】分治-归并排序

文章目录 1、排序数组2、交易逆序对的总数3、计算右侧小于当前元素的个数4、翻转对 1、排序数组 给你一个整数数组 nums&#xff0c;请你将该数组升序排列。 class Solution { public:vector<int> tmp;void mergeSort(vector<int>& nums,int left,int right){…

Web学习_sqli-labs_1~10关

less1-GET-Error based - Single quotes - String &#xff08;基于错误的GET单引号字符型注入&#xff09; 我每次操作都会在Hackbar中&#xff0c;代码都在Hackbar框中&#xff0c;可放大看 有题目知道了是字符型注入&#xff0c;我们先判断表格有几列&#xff0c;可以发现…

详细分析Mysql临时变量的基本知识(附Demo)

目录 前言1. 用户变量2. 会话变量 前言 临时变量主要分为用户变量和会话变量 1. 用户变量 用户变量是特定于会话的&#xff0c;在单个会话内可以在多个语句中共享 以 符号开头在 SQL 语句中使用 SET 语句或直接在查询中赋值 声明和赋值 SET var_name value; -- 或者 SE…

如何在恢复出厂设置后从 Android 恢复照片

在某些情况下&#xff0c;您可能会考虑将 Android 设备恢复出厂设置。需要注意的是&#xff0c;恢复出厂设置后&#xff0c;所有设置、用户数据甚至应用程序数据都将被清除。因此&#xff0c;如果您将 Android 设备恢复出厂设置&#xff0c;甚至在里面留下了一些珍贵的照片&…

宝塔面板和 LNMP 环境下反代 HFish 蜜罐平台的正确方法

最近明月在热心站长好友的支持下搭建了安全、简单、有效并永久免费的蜜罐平台 HFish,因为 HFish 默认是以 https://IP:端口 的 Web 链接形式提供访问的,这会暴露蜜罐平台的真实服务器 IP 不说,还非常不便于快速的访问(反正明月是记不住 IP 的),所以就需要给部署好的 HFis…

扫地机器人:卷价格,不如卷技术

扫地机器人内卷的终点是技术和价值&#xff0c;价格只是附属品。 一路上涨的价格&#xff0c;一路下跌的销量 从价格飙升&#xff0c;到重新卷回价格&#xff0c;尴尬的背后是扫地机器人在骨感现实下的无奈抉择。 根据数据显示&#xff0c;2020中国扫地机器人线上市场零售均价…

经典文献阅读之--P2O-Calib(利用点对空间遮挡关系的相机-激光雷达标定)

Tip: 如果你在进行深度学习、自动驾驶、模型推理、微调或AI绘画出图等任务&#xff0c;并且需要GPU资源&#xff0c;可以考虑使用UCloud云计算旗下的Compshare的GPU算力云平台。他们提供高性价比的4090 GPU&#xff0c;按时收费每卡2.6元&#xff0c;月卡只需要1.7元每小时&…