人工智能算法工程师(中级)课程19-模型的量化与部署之模型部署和存储方式与代码详解

news2024/9/22 1:02:52

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程19-模型的量化与部署之模型部署和存储方式与代码详解本文全面介绍了神经网络模型在实际应用中的部署与存储策略,重点覆盖了两大主流框架:LibTorch和TensorRT。LibTorch,作为PyTorch的C++部署工具,提供了将训练好的模型转换为可独立运行的静态库的能力,适用于各种设备上的高性能推理。TensorRT,则是NVIDIA推出的优化深度学习模型推理速度的工具,尤其擅长GPU加速,能够显著提升模型的执行效率。

文章目录

  • 一、引言
  • 二、Libtorch源生部署
    • 1. Libtorch核心原理
    • 2. 代码详解
  • 三、torch.jit.trace()转换
    • 1. 基本原理
    • 2. 代码详解
  • 四、转成ONNX部署
    • 1. 基本原理
    • 2. 代码详解
  • 五、TensorRt部署加速
    • 1. 基本原理
    • 2. 代码详解
  • 六、总结

一、引言

随着深度学习技术的不断发展,神经网络模型在各个领域取得了显著的成果。然而,如何将训练好的模型高效、便捷地部署到不同平台,成为了一个亟待解决的问题。本文将详细介绍神经网络模型的部署和存储方式,包括Libtorch源生部署、torch.jit.trace()转换、转成ONNX部署以及TensorRt部署加速,并附上数学原理及完整可运行的PyTorch代码。

二、Libtorch源生部署

1. Libtorch核心原理

Libtorch是PyTorch的C++接口,它使得我们可以直接在C++环境下运行PyTorch模型。Libtorch部署的核心原理是将Python代码中的模型结构和参数转换为C++可识别的格式。

2. 代码详解

首先,我们使用PyTorch搭建一个简单的神经网络模型:

import torch
import torch.nn as nn
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 50 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
model = SimpleNet()

接下来,我们将模型保存为.pt文件:

torch.save(model.state_dict(), 'simple_net.pt')

在C++环境下,使用Libtorch加载模型并预测:

#include <torch/script.h>
#include <iostream>
int main() {
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("simple_net.pt");
    }
    catch (const c10::Error& e) {
        std::cerr << "Error loading the model\n";
        return -1;
    }
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 1, 28, 28}));
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';
    return 0;
}

三、torch.jit.trace()转换

1. 基本原理

torch.jit.trace()方法可以将PyTorch模型转换为TorchScript格式,从而提高模型在C++环境下的运行效率。

2. 代码详解

首先,使用torch.jit.trace()对模型进行转换:

traced_model = torch.jit.trace(model, torch.randn(1, 1, 28, 28))
traced_model.save('traced_simple_net.pt')

在C++环境下,使用Libtorch加载TorchScript模型并预测:

#include <torch/script.h>
// 省略其他代码
int main() {
    // 省略加载模型代码
    at::Tensor output = module.forward(inputs).toTensor();
    // 省略输出代码
    return 0;
}

四、转成ONNX部署

1. 基本原理

ONNX(Open Neural Network Exchange)是一种开放的模型交换格式,可以将不同框架训练的模型转换为统一格式,便于在不同平台部署。
在这里插入图片描述

2. 代码详解

首先,将PyTorch模型转换为ONNX格式:

import torch.onnx
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, "simple_net.onnx")

在C++环境下,使用ONNX Runtime加载ONNX模型并预测:

#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
// 省略其他代码
int main() {
    Ort::Session session(env, L"simple_net.onnx", session_options);
    // 省略输入输出代码
    return 0;
}

五、TensorRt部署加速

1. 基本原理

TensorRt是NVIDIA推出的一款深度学习推理引擎,通过优化计算图、融合操作等方式,提高模型在GPU上的运行速度。
在这里插入图片描述

2. 代码详解

首先,将ONNX模型转换为TensorRt引擎:

import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def build_engine(onnx_file_path, engine_file_path):
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(common.EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        builder.max_workspace_size = 1 << 30
        builder.max_batch_size = 1
        builder.fp16_mode = True
        # Parse model file
        with open(onnx_file_path, 'rb') as model:
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        # Build an engine
        engine = builder.build_cuda_engine(network)
        with open(engine_file_path, 'wb') as f:
            f.write(engine.serialize())
        return engine
engine = build_engine('simple_net.onnx', 'simple_net.trt')

接下来,使用TensorRt引擎进行推理:

import pycuda.autoinit
import pycuda.driver as cuda
import numpy as np
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]
# Set up the input data
input_data = np.random.random(size=(1, 1, 28, 28)).astype(np.float32)
# Allocate buffers for input and output
inputs, outputs, bindings, stream = common.allocate_buffers(engine, batch_size=1)
# Set the input data
inputs[0].host = input_data
# Run inference
trt_outputs = do_inference(context, bindings, inputs, outputs, stream)
# Print the output
print(trt_outputs[0])

六、总结

本文详细介绍了神经网络模型的部署和存储方式,包括Libtorch源生部署、torch.jit.trace()转换、转成ONNX部署以及TensorRt部署加速。通过数学原理的阐述和完整可运行的PyTorch代码,希望读者能够更好地理解和掌握这些技术。在实际应用中,可以根据需求选择合适的部署方式,以实现高效、便捷的模型部署。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1951032.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript模拟滑动手势

双击回到顶部 左滑动 右滑动 代码展示 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>Gesture…

linux命令更新-文本处理awk

awk命令简介 awk是一种强大的文本处理工具&#xff0c;可以对文本文件进行格式化、统计、计算等操作。它逐行读取文本文件&#xff0c;并对每一行进行处理。awk的语法相对简单&#xff0c;但功能非常强大&#xff0c;是Linux系统中常用的文本处理工具之一。 awk命令基本语法 …

某数据泄露防护(DLP)系统NoticeAjax接口SQL注入漏洞复现 [附POC]

文章目录 某数据泄露防护(DLP)系统NoticeAjax接口SQL注入漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现某数据泄露防护(DLP)系统NoticeAjax接口SQL注入漏洞复现 [附POC] 0x01 前言 免责声明:请勿利用文章内…

60个常见的 Linux 指令

1.ssh 登录到计算机主机 ssh -p port usernamehostnameusername&#xff1a; 远程计算机上的用户账户名。 hostname&#xff1a; 远程计算机的 IP 地址或主机名。 -p 选项指定端口号。 2.ls 列出目录内容 ls ls -l # 显示详细列表 ls -a # 显示包括隐藏文件在内的所有内…

关于深度学习中的cuda编程,cuda相关介绍

深度学习中会涉及大量的、重复的矩阵运算、图形运算&#xff0c;而CPU对这种简单的加减法加速不够显著&#xff0c;可以使用GPU进行加速运算 CUDA是英伟达旗下的专门为深度学习加速运算的显卡&#xff0c;其对于简单的浮点运算、矩形运算相较于CPU加速了数倍不止 本文介绍CUD…

结合GB/T28181规范探讨Android平台设备接入模块心跳实现

技术背景 好多开发者在用我们Android平台GB28181设备接入模块的时候&#xff0c;更希望跟我们探讨一些协议规范方面&#xff0c;以便在现场对接时&#xff0c;可以知其然知其所以然。比如&#xff0c;有开发者提到&#xff0c;GB28181的状态消息报送这块到底要怎么实现&#x…

搭建Vue开发环境

一、下载Vue.js 进入官网教程安装 — Vue.js (vuejs.org) 下载开发版本到本地 二、安装 Vue Devtools 安装完成后

Spring Boot 与 MongoDB 整合指南

MongoDB MongoDB 是一种基于文档的NoSQL数据库&#xff0c;以其高性能、高可用性和易扩展性而著称。它使用 BSON&#xff08;类似 JSON 的二进制格式&#xff09;来存储数据&#xff0c;提供了灵活的数据模型&#xff0c;使得开发者可以更轻松地存储和查询复杂的数据结构。将M…

Navicat premium最新【16/17 版本】安装下载教程,图文步骤详解(超简单,一步到位,免费下载领取)

文章目录 软件介绍软件下载安装步骤激活步骤 软件介绍 Navicat是一款快速、可靠且功能全面的数据库管理工具&#xff0c;专为简化数据库的管理及降低系统管理成本而设计。以下是对Navicat的详细介绍&#xff1a; 一、产品概述 开发目的&#xff1a;Navicat旨在通过其直观和设计…

Linux:core文件无法生成排查步骤

1、进程的RLIMIT_CORE或RLIMIT_SIZE被设置为0。使用getrlimit和ulimit检查修改。 使用ulimit -a 命令检查是否开启core文件生成限制 如果发现-c后面的结果是0&#xff0c;就临时添加环境变量ulimit -c unlimited&#xff0c;之后在启动程序观察是否有core生成&#xff0c;如果…

Qt 学习第二天:创建第一个Qt程序

【最新QT从入门到实战完整版|传智教育】04 创建第一个Qt程序 一、命名规范&#xff08;驼峰命名法&#xff09; 类名&#xff1a; 首字母大写&#xff0c;单词和单词之间首字母大写 函数名和变量名&#xff1a; 首字母小写&#xff0c;单词和单词之间首字母大写 二、快捷…

零食商城管理系统

目录 一、项目背景与目标 1.1 项目背景 1.2 项目意义 1.3 国内外研究现状 1.4 开发工具介绍 二、项目内容与分工 三、 时间表与进度 1. 需求分析阶段&#xff1a; 2. 系统设计阶段&#xff1a; 3. 系统开发阶段&#xff1a; 4. 系统测试阶段&#xff1a; 5. 部署与上…

Selenium 的使用

selenium 是一个自动化测试工具&#xff0c;利用它可以驱动浏览器完成特定的操作&#xff0c;例如点击&#xff0c;下拉等&#xff0c;还可以获取浏览器当前呈现的页面的源代码&#xff0c;做到所见即所爬&#xff0c;对于一些 JavaScript 动态渲染的界面来说&#xff0c;这种爬…

WEB攻防-通用漏洞-SQL 读写注入-MYSQLMSSQLPostgreSQL

什么是高权限注入 高权限注入指的是攻击者通过SQL注入漏洞&#xff0c;利用具有高级权限的数据库账户&#xff08;如MYSQL的root用户、MSSQL的sa用户、PostgreSQL的dba用户&#xff09;执行恶意SQL语句。这些高级权限账户能够访问和修改数据库中的所有数据&#xff0c;甚至执行…

WEB集群-Tomact集群

linux云计算中小企业规模集群架构设计图----总结 在写今天内容前&#xff0c;小编绘制一个图&#xff1a;我设计了linux云计算中小企业规模集群架构设计图&#xff08;也可根据业务需求&#xff0c;增加业务变成大型企业架构设计图&#xff09; 知识补充–故障案例-https no s…

【Golang 面试基础题】每日 5 题(十)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…

Python爬虫(6) --深层爬取

深层爬取 在前面几篇的内容中&#xff0c;我们都是爬取网页表面的信息&#xff0c;这次我们通过表层内容&#xff0c;深度爬取内部数据。 接着按照之前的步骤&#xff0c;我们先访问表层页面&#xff1a; 指定url发送请求获取你想要的数据数据解析 我们试着将以下豆瓣读书页…

WPF代办事项应用

目录 一 设计原型 二 后台源码 一 设计原型 添加代办事项页面&#xff1a; 二 后台源码 Model&#xff1a; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace 待办事项应用.DataModel {pub…

数据结构(Java):Map集合Set集合哈希表

目录 1、介绍 1.1 Map和Set 1.2 模型 2、Map集合 2.1 Map集合说明 2.2 Map.Entry<K&#xff0c;V> 2.3 Map常用方法 2.4 Map注意事项及实现类 3、Set集合 3.1 Set集合说明 3.2 Set常用方法 3.3 Set注意事项及其实现类 4、TreeMap&TreeSet 4.1 集合类TreeM…

头歌最小生成树 ------习题

一、背包问题 1.理解&#xff1a;背包问题相当于最小生成树&#xff0c;也就是线性规划最优解 2.公式&#xff1a; M: 背包的总重量 w&#xff1a;物品 i 的重量 p: 物品 i 的价值 3.基本背包练习 4.完全背包问题&#xff1a;每种物品有无限件 >>> 开头加一个for…