TensorRT学习笔记--基于FCN-ResNet101推理引擎实现语义分割

news2025/1/8 17:57:18

目录

前言

1--Pytorch模型转换为Onnx模型

2--Onnx模型可视化及测试

2-1--可视化Onnx模型

2-2--测试Onnx模型

3--Onnx模型转换为Tensor RT推理模型

4--基于Tensor RT使用推理引擎实现语义分割


前言

        基于Tensor RT的模型转换流程:Pytorch → Onnx → Tensor RT;本笔记基于 Tensor RT 官方 Github 仓库的语义分割 Demo(Tensor RT 官方Demo链接) 进行实现,首先将训练好的 Pytorch 模型转换为 Onnx 模型,之后基于Tensor RT将 Onnx 模型转换为推理引擎 engine,最后使用Tensor RT的推理引擎 engine 实现语义分割。

1--Pytorch模型转换为Onnx模型

        利用 torch.hub.load() 加载预训练的 FCN-ResNet101 模型,利用 torch.onnx.export()导出Onnx模型;

from PIL import Image
from io import BytesIO
import requests
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np

# 下载并保存网络输入的图片
download_image = "./input.ppm" # 保存的路径,保存为ppm格式
response = requests.get("https://pytorch.org/assets/images/deeplab1.png") # 下载图片
with Image.open(BytesIO(response.content)) as img:
    ppm = Image.new("RGB", img.size, (255, 255, 255))
    ppm.paste(img, mask=img.split()[3])
    ppm.save(download_image) # 保存图片
    
plt.imshow(Image.open(download_image)) 
plt.show()

# 创建并导出模型
output_onnx="./fcn-resnet101.onnx" # 导出的Onnx模型路径
class FCN_ResNet101(nn.Module): # 定义模型
    def __init__(self):
        super(FCN_ResNet101, self).__init__()
        # 下载并导入 fcn_resnet101 模型
        self.model = torch.hub.load('pytorch/vision:v0.6.0', 'fcn_resnet101', pretrained=True)

    def forward(self, inputs):
        x = self.model(inputs)['out']
        x = x.argmax(1, keepdims=True) # 增加 argmax 模块,组成最终的模型
        return x

model = FCN_ResNet101()
model.eval()

# 定义网络的输入
input_tensor = torch.rand(4, 3, 224, 224) 

# 导出Onnx模型
torch.onnx.export(model, input_tensor, output_onnx,
    opset_version=12,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
                  "output": {0: "batch", 2: "height", 3: "width"}},
    verbose=False)

        上面利用 torch.hub.load() 下载预训练模型时,因网速的原因可能会比较慢,可根据链接(红框)下载对应的模型放置到本地相应的路径(绿框)当中,以节省时间。

2--Onnx模型可视化及测试

2-1--可视化Onnx模型

        利用 netron 第三方库可视化导出的 Onnx 模型,以查看模型的输入输出维度;

# 终端依次执行
pip install netron

python

import netron

netron.start("./fcn-resnet101.onnx")

2-2--测试Onnx模型

        使用 Onnxruntime 测试导出的 Onnx 推理模型,参考Tensor RT官方 Demo 设计相应的前处理和后处理函数;

import numpy as np
from PIL import Image
import onnx
import matplotlib.pyplot as plt
import onnxruntime

# 前处理
def preprocess(image):
    # Mean normalization
    mean = np.array([0.485, 0.456, 0.406]).astype('float32')
    stddev = np.array([0.229, 0.224, 0.225]).astype('float32')
    data = (np.asarray(image).astype('float32') / float(255.0) - mean) / stddev
    # Switch from HWC to to CHW order
    return np.moveaxis(data, 2, 0)

# 后处理
def postprocess(data):
    num_classes = 21
    # create a color palette, selecting a color for each class
    palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = np.array([palette*i%255 for i in range(num_classes)]).astype("uint8")
    # plot the segmentation predictions for 21 classes in different colors
    img = Image.fromarray(data.astype('uint8'), mode='P')
    img.putpalette(colors)
    return img


if __name__ == "__main__":
    # 加载并可视化网络输入
    input_file = "./input.ppm" 
    with Image.open(input_file) as img:
        input_image = preprocess(img) # 前处理
        image_width = img.width
        image_height = img.height
    plt.imshow(Image.open(input_file)) 
    plt.show()
    
    # 调整输入图片的维度,以适配Onnx模型 
    input_data = input_image[np.newaxis, :]
    
    # 导入Onnx模型
    Onnx_file = "./fcn-resnet101.onnx"
    Model = onnx.load(Onnx_file)
    onnx.checker.check_model(Model) # 验证Onnx模型是否准确
    
    # 使用onnxruntime推理
    model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
    input_name = model.get_inputs()[0].name # 对应可视化onnx模型时,网络的输入名称: input
    output_name = model.get_outputs()[0].name # 对应可视化onnx模型时,网络的输出名称: output
    print(input_name)
    output = model.run([output_name], {input_name:input_data}) # onnxruntime的输入input_data需要为numpy类型,Tensor类型会报错
    
    # 后处理
    output = postprocess(np.reshape(output, (image_height, image_width)))
    
    # 保存并可视化推理结果
    output_file = "output_onnx.ppm"
    output.convert('RGB').save(output_file, "PPM")
    plt.imshow(Image.open(output_file))
    plt.show()

3--Onnx模型转换为Tensor RT推理模型

        基于 Tensor RT 将导出的 Onnx 模型转换为推理模型 Engine,这里博主基于 Tensor RT 8.2.5.1 提供的 trtexec 可执行文件;

./trtexec --onnx=/path/fcn-resnet101.onnx --fp16 --workspace=4096 --minShapes=input:1x3x256x256 --optShapes=input:1x3x1026x1282 --maxShapes=input:1x3x1440x2560 --buildOnly --saveEngine=/path/fcn-resnet101.engine

        --onnx 和 --saveEngine 需根据实际设置正确的模型路径;之前的Tensor RT 7.x 版本将--workspace设置为64,当使用Tensor RT 8.x 版本时,workspace的空间将不足会出现上图的错误,因此需将 --workspace=64 设置为 --workspace=4096;

4--基于Tensor RT使用推理引擎实现语义分割

        基于Tensor RT加载 Fcn-resnet101.engine 推理引擎,设计相应的前处理和后处理函数,实现语义分割;

import numpy as np
import os
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt

import matplotlib.pyplot as plt
from PIL import Image

# 前处理
def preprocess(image):
    # Mean normalization
    mean = np.array([0.485, 0.456, 0.406]).astype('float32')
    stddev = np.array([0.229, 0.224, 0.225]).astype('float32')
    data = (np.asarray(image).astype('float32') / float(255.0) - mean) / stddev
    # Switch from HWC to to CHW order
    return np.moveaxis(data, 2, 0)

# 后处理
def postprocess(data):
    num_classes = 21
    # create a color palette, selecting a color for each class
    palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = np.array([palette*i%255 for i in range(num_classes)]).astype("uint8")
    # plot the segmentation predictions for 21 classes in different colors
    img = Image.fromarray(data.astype('uint8'), mode='P')
    img.putpalette(colors)
    return img

# 导入推理引擎engine
def load_engine(engine_file_path):
    assert os.path.exists(engine_file_path)
    print("Reading engine from file {}".format(engine_file_path))
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())
    
def infer(engine, input_file, output_file):
    print("Reading input image from file {}".format(input_file))
    with Image.open(input_file) as img:
        input_image = preprocess(img)
        image_width = img.width
        image_height = img.height

    with engine.create_execution_context() as context:
        # Set input shape based on image dimensions for inference
        context.set_binding_shape(engine.get_binding_index("input"), (1, 3, image_height, image_width))
        # Allocate host and device buffers
        bindings = []
        for binding in engine:
            binding_idx = engine.get_binding_index(binding)
            size = trt.volume(context.get_binding_shape(binding_idx))
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            if engine.binding_is_input(binding):
                input_buffer = np.ascontiguousarray(input_image)
                input_memory = cuda.mem_alloc(input_image.nbytes)
                bindings.append(int(input_memory))
            else:
                output_buffer = cuda.pagelocked_empty(size, dtype)
                output_memory = cuda.mem_alloc(output_buffer.nbytes)
                bindings.append(int(output_memory))

        stream = cuda.Stream()
        # Transfer input data to the GPU.
        cuda.memcpy_htod_async(input_memory, input_buffer, stream)
        # Run inference
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        # Transfer prediction output from the GPU.
        cuda.memcpy_dtoh_async(output_buffer, output_memory, stream)
        # Synchronize the stream
        stream.synchronize()

    with postprocess(np.reshape(output_buffer, (image_height, image_width))) as img:
        print("Writing output image to file {}".format(output_file))
        img.convert('RGB').save(output_file, "PPM")

if __name__ == "__main__":
    
    TRT_LOGGER = trt.Logger()

    engine_file = "./fcn-resnet101.engine"
    input_file  = "./input.ppm"
    output_file = "./output_trt.ppm"
    
    plt.imshow(Image.open(input_file))
    plt.show()

    print("Running TensorRT inference for FCN-ResNet101")
    with load_engine(engine_file) as engine:
        infer(engine, input_file, output_file)
        
    plt.imshow(Image.open(output_file))
    plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/146553.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

通用vue组件化首页

一、首先先建立文件main.vue,构建主体 1.选择合适的模板element-plus,直接复制 2.编写相应的样式 <template><div class"main"><el-container class"main-content"><el-aside> aside </el-aside><el-container class&q…

2022年中职组网络安全竞赛D模块竞赛漏洞报告单总结

Windows加固 后门用户 漏洞发现过程 打开cmd使用net user 看到”hacker”用户,疑似存在后门用户 使用hacker/123456成功登录目标服务器,证明存在后门用户 漏洞加固过程 删除后门用户

HTML与CSS基础(一)—— HTML基础(web标准、开发工具、标签)

目标能够理解HTML的 基本语法 和标签的关系 能够使用 排版标签 实现网页中标题、段落等效果 能够使用 相对路径 选择不同目录下的文件 能够使用 媒体标签 在网页中显示图片、播放音频和视频 能够使用 链接标签 实现页面跳转功能一、基础认知目标&#xff1a;认识 网页组成 和 五…

【Linux】程序的翻译四个阶段(图示详解)

因为淋过雨&#xff0c;所以懂的为别人撑伞&#xff1b;因为迷茫过&#xff0c;所以懂得为别人指路。 我们都知道写好代码后&#xff0c;编译器会帮助我们把代码生成可执行程序&#xff0c;细加了解又会知道程序的生成又分为四步&#xff1a;预处理、编译、汇编、链接。那么这四…

JAVA语言基础语法——异常中的常见方法及抛出异常等练习

Throwable的成员方法定义在最顶级Throwable类中a.实例如下&#xff1a;e.printStackTrace(); 将异常的所有信息红色的字体打印在控制台&#xff0c;不会结束虚拟机&#xff0c;仅仅只是打印的操作。抛出处理throws注意&#xff1a;写在方法定义处&#xff0c;表示声明一个异常&…

DOM(三):鼠标、键盘事件对象

鼠标、键盘事件对象鼠标事件对象键盘事件对象鼠标事件对象 event对象代表事件的状态&#xff0c;和事件相关的一系列信息的集合。现阶段我们主要是用鼠标事件对象MouseEvent和键盘事件对象KeyboardEvent 例如&#xff1a; // 鼠标事件对象 MouseEventdocument.addEventListene…

Android正确的保活方案,不要掉进保活需求死循环陷进

在开始前&#xff0c;还是给大家简单介绍一下&#xff0c;以前出现过的一些黑科技&#xff1a; 大概在6年前Github中出现过一个叫MarsDaemon&#xff0c;这个库通过双进程守护的方式实现保活&#xff0c;一时间风头无两。好景不长&#xff0c;进入 Android 8.0时代之后&#x…

STM32系列单片机标准库移植FreeRTOS V10.4.6详解

文中所用到的资料下载地址 https://download.csdn.net/download/qq_20222919/87370679 最近看正点原子新录制了手把手教你学FreeRTOS的视频教程&#xff0c;看了一下教程发现视频里面讲的是使用HAL移植 FreeRTOS V10.4.6 版本&#xff0c;以前的标准库移植的是FreeRTOS V9.0 版…

关于PostgreSQL JIT Memory-Leak 问题 从 LLVM源码层面来分析

文章目录前言LLVM Types 在 JIT中的使用LLVM Types 设计导致的 PG JIT 内存问题分析解决&#xff1f;前言 之前介绍 PG 的 JIT 实现 时提到 为了性能开启JIT 之后有一个比较严重的内存泄漏问题。现象就是在一个backend 内持续跑大量的 sqllogic 随机复杂查询&#xff0c;能够看…

java 微服务 Nacos配置 feign 网关路由

Nacos配置管理 配置信息我们写有热更新需求的配置就可以了 1.引入Nacos的配置管理客户端依赖&#xff1a; <!--nacos配置管理依赖--> <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-nacos-config…

HBase基础_1

HBase 注&#xff1a;大家觉得博客好的话&#xff0c;别忘了点赞收藏呀&#xff0c;本人每周都会更新关于人工智能和大数据相关的内容&#xff0c;内容多为原创&#xff0c;Python Java Scala SQL 代码&#xff0c;CV NLP 推荐系统等&#xff0c;Spark Flink Kafka Hbase Hive…

学习笔记6:字符串库函数(下)

目录 一. strstr模拟实现 二. strtok模拟实现 三.关于strerror和perror的说明 一. strstr模拟实现 库函数strstr函数首部&#xff1a;char * strstr ( const char *str1, const char * str2); 函数的功能是在str1指向的主字符串中寻找子串str2&#xff0c;并且返回主字符串中…

JS数组对象——英文按照首字母进行排序sort()、localeCompare()

JS数组对象——英文按照首字母进行排序(sort、localeCompare&#xff09;上期回顾场景复现sort()方法与localeCompare实例应用上期回顾 文章内容文章链接JS数组对象——根据日期进行排序Date.parse()&#xff0c;按照时间进行升序或降序排序https://blog.csdn.net/XSL_HR/arti…

【CANN训练营第三季】AI目标属性编辑应用

文章目录1、参考样例进行运行stargan2、dvpp媒体数据处理结业考核题目1、题目2、题目31、参考样例进行运行stargan 下载stargan后&#xff0c;查看readme&#xff0c;进行复现。 # 为了方便下载&#xff0c;在这里直接给出原始模型下载及模型转换命令,可以直接拷贝执行。 cd …

Tic-Tac-Toe:基于Minimax算法的人机对弈程序(python实现)

目录 1. 前言 2. Minimax算法介绍 2.1 博弈树 2.2 估值函数 2.3 基本算法思想 2.4 实例1 ​​​​​​​2.5 实例2—棋类游戏 2.6 小结 3. Tic-Tac-Toe minimax AI实现 3.1 函数说明 3.2 处理流程 3.3 代码 4. 小结 1. 前言 在上一篇中实现一个简单的Tic-Tac-Toe人…

【07】概率图推断之信念传播

概率图推断之信念传播 文章目录将变量消除视为信息传递信息传递算法加总乘积信息传递因子树上的加总乘积信息传递最大乘积信息传递总结在《概率图推断之变量消除算法》中&#xff0c;我们讲了变量消除算法如何对有向图和无向图求P(Y∣Ee)P(Y \mid E e)P(Y∣Ee)的边缘概率。 …

java 微服务之MQ 异步通信

初识MQ 同步调用存在的问题 异步调用常见实现就是事件驱动模式 事件驱动模式优势&#xff1a; 优势1&#xff1a;服务解耦 一旦有新业务只需要订阅或者减少事件就行了 优势2&#xff1a;性能提升&#xff0c;吞吐量提高 优势3&#xff1a;服务没有强依赖&#xff0c;不用担…

【自学C++】C++注释

C注释 C注释教程 用于注解说明解释程序的文字就是注释&#xff0c;注释提高了代码的阅读性。同时&#xff0c;注释也是一个程序员必须要具有的良好编程习惯。我们应该首先将自己的思想通过注释先整理出来&#xff0c;再用代码去体现。 在 C 中&#xff0c;一旦程序中某部分内…

数据结构和算法-计数排序

1.算法描述 技术排序是一个基于比较的排序算法&#xff0c;该算法于1954由Harold H. Seward 提出。它的优势在于对 一定范围内的整数排序时&#xff0c;它的复杂度为O&#xff08;nk&#xff09;&#xff08;其中k是整数的范围&#xff09;&#xff0c;快于任何比较排序算 法…

JavaEE高阶---Spring事务和事务传播机制

一&#xff1a;什么是事务&#xff1f; 事务定义&#xff1a;将⼀组操作封装成⼀个执⾏单元&#xff08;封装到⼀起&#xff09;&#xff0c;要么全部成功&#xff0c;要么全部失败。 二&#xff1a;Spring中事务的实现 编程式事务&#xff08;⼿动写代码操作事务&#xff09…