AI模型部署 | onnxruntime部署YOLOv8分割模型详细教程

news2024/11/23 22:05:46

本文首发于公众号【DeepDriving】,欢迎关注。

0. 引言

我之前写的文章《基于YOLOv8分割模型实现垃圾识别》介绍了如何使用YOLOv8分割模型来实现垃圾识别,主要是介绍如何用自定义的数据集来训练YOLOv8分割模型。那么训练好的模型该如何部署呢?YOLOv8分割模型相比检测模型多了一个实例分割的分支,部署的时候还需要做一些后处理操作才能得到分割结果。
本文将详细介绍如何使用onnxruntime框架来部署YOLOv8分割模型,为了方便理解,代码采用Python实现。

1. 准备工作

  • 安装onnxruntime

    onnxruntime分为GPU版本和CPU版本,均可以通过pip直接安装:

    pip install onnxruntime-gpu  #安装GPU版本
    
    pip install onnxruntime  #安装CPU版本
    

    注意: GPU版本和CPU版本建议只选其中一个安装,否则默认会使用CPU版本

  • 下载YOLOv8分割模型权重

    Ultralytics官方提供了用COCO数据集训练的模型权重,我们可以直接从官方网站https://docs.ultralytics.com/tasks/segment/下载使用,本文使用的模型为yolov8m-seg.pt

  • 转换onnx模型

    调用下面的命令可以把YOLOv8m-seg.pt模型转换为onnx格式的模型:

    yolo task=segment mode=export model=yolov8m-seg.pt format=onnx
    

    转换成功后得到的模型为yolov8m-seg.onnx

2. 模型部署

2.1 加载onnx模型

首先导入onnxruntime包,然后调用其API加载模型即可:

import onnxruntime as ort

session = ort.InferenceSession("yolov8m-seg.onnx", providers=["CUDAExecutionProvider"])

因为我使用的是GPU版本的onnxruntime,所以providers参数设置的是"CUDAExecutionProvider";如果是CPU版本,则需设置为"CPUExecutionProvider"

模型加载成功后,我们可以查看一下模型的输入、输出层的属性:

for input in session.get_inputs():
    print("input name: ", input.name)
    print("input shape: ", input.shape)
    print("input type: ", input.type)

for output in session.get_outputs():
    print("output name: ", output.name)
    print("output shape: ", output.shape)
    print("output type: ", output.type)

结果如下:

input name:  images
input shape:  [1, 3, 640, 640]
input type:  tensor(float)
output name:  output0
output shape:  [1, 116, 8400]
output type:  tensor(float)
output name:  output1
output shape:  [1, 32, 160, 160]
output type:  tensor(float)

从上面的打印信息可以知道,模型有一个尺寸为[1, 3, 640, 640]的输入层和两个尺寸分别为[1, 116, 8400][1, 32, 160, 160]的输出层。

2.2 数据预处理

数据预处理采用OpenCVNumpy实现,首先导入这两个包

import cv2
import numpy as np

OpenCV读取图片后,把数据按照YOLOv8的要求做预处理

image = cv2.imread("soccer.jpg")
image_height, image_width, _ = image.shape
input_tensor = prepare_input(image, model_width, model_height)
print("input_tensor shape: ", input_tensor.shape)

其中预处理函数prepare_input的实现如下:

def prepare_input(bgr_image, width, height):
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (width, height)).astype(np.float32)
    image = image / 255.0
    image = np.transpose(image, (2, 0, 1))
    input_tensor = np.expand_dims(image, axis=0)
    return input_tensor

处理流程如下:

1. 把OpenCV读取的BGR格式的图片转换为RGB格式;
2. 把图片resize到模型输入尺寸640x640;
3. 对像素值除以255做归一化操作;
4. 把图像数据的通道顺序由HWC调整为CHW;
5. 扩展数据维度,将数据的维度调整为NCHW。

经过预处理后,输入数据input_tensor的维度变为[1, 3, 640, 640],与模型的输入尺寸一致。

2.3 模型推理

输入数据准备好以后,就可以送入模型进行推理:

outputs = session.run(None, {session.get_inputs()[0].name: input_tensor})

前面我们打印了模型的输入输出属性,可以知道模型有两个输出分支,其中一个output0是目标检测分支,另一个output1则是实例分割分支,这里打印一下它们的尺寸看一下

#squeeze函数是用于删除shape中为1的维度,对output0做transpose操作是为了方便后续操作
output0 = np.squeeze(outputs[0]).transpose()
output1 = np.squeeze(outputs[1])
print("output0 shape:", output0.shape)
print("output1 shape:", output1.shape)

结果如下:

output0 shape: (8400, 116)
output1 shape: (32, 160, 160)

处理后目标检测分支的维度为[8400, 116],表示模型总共可以检测出8400个目标(大部分是无效的目标),每个目标包含116个参数。刚接触YOLOv8分割模型的时候可能会对116这个数字感到困惑,这里有必要解释一下:每个目标的参数包含4个坐标属性(x,y,w,h)、80个类别置信度和32个实例分割参数,所以总共是116个参数。实例分割分支的维度为[32, 160, 160],其中第一个维度32与目标检测分支中的32个实例分割参数对应,后面两个维度则由模型输入的宽和高除以4得到,本文所用的模型输入宽和高都是640,所以这两个维度都是160

2.4 后处理

首先把目标检测分支输出的数据分为两个部分,把实例分割相关的参数从中剥离。

boxes = output0[:, 0:84]
masks = output0[:, 84:]
print("boxes shape:", boxes.shape)
print("masks shape:", masks.shape)
boxes shape: (8400, 84)
masks shape: (8400, 32)

然后实例分割这部分数据masks要与模型的另外一个分支输出的数据output1做矩阵乘法操作,在这之前要把output1的维度变换为二维。

output1 = output1.reshape(output1.shape[0], -1)
masks = masks @ output1
print("masks shape:", masks.shape)
masks shape: (8400, 25600)

做完矩阵乘法后,就得到了8400个目标对应的实例分割掩码数据masks,可以把它与目标检测的结果boxes拼接到一起。

detections = np.hstack([boxes, masks])
print("detections shape:", detections.shape)
detections shape: (8400, 25684)

到这里读者应该就能理解清楚了,YOLOv8模型总共可以检测出8400个目标,每个目标的参数包含4个坐标属性(x,y,w,h)、80个类别置信度和一个160x160=25600大小的实例分割掩码。

由于YOLOv8模型检测出的8400个目标中有大量的无效目标,所以先要通过置信度过滤去除置信度低于阈值的目标,对于满足置信度满足要求的目标还需要通过非极大值抑制(NMS)操作去除重复的目标。

objects = []
for row in detections:
    prob = row[4:84].max()
    if prob < 0.5:
        continue
    class_id = row[4:84].argmax()
    label = COCO_CLASSES[class_id]
    xc, yc, w, h = row[:4]
    // 把x1, y1, x2, y2的坐标恢复到原始图像坐标
    x1 = (xc - w / 2) / model_width * image_width
    y1 = (yc - h / 2) / model_height * image_height
    x2 = (xc + w / 2) / model_width * image_width
    y2 = (yc + h / 2) / model_height * image_height
    // 获取实例分割mask
    mask = get_mask(row[84:25684], (x1, y1, x2, y2), image_width, image_height)
    // 从mask中提取轮廓
    polygon = get_polygon(mask, x1, y1)
    objects.append([x1, y1, x2, y2, label, prob, polygon, mask])

// NMS
objects.sort(key=lambda x: x[5], reverse=True)
results = []
while len(objects) > 0:
    results.append(objects[0])
    objects = [object for object in objects if iou(object, objects[0]) < 0.5]

这里重点讲一下获取实例分割掩码的过程。

前面说了每个目标对应的实例分割掩码数据大小为160x160,但是这个尺寸是对应整幅图的掩码。对于单个目标来说,还要从这个160x160的掩码中去截取属于自己的掩码,截取的范围由目标的box决定。上面的代码得到的box是相对于原始图像大小,截取掩码的时候需要把box的坐标转换到相对于160x160的大小,截取完后再把这个掩码的尺寸调整回相对于原始图像大小。截取到box大小的数据后,还需要对数据做sigmoid操作把数值变换到01的范围内,也就是求这个box范围内的每个像素属于这个目标的置信度。最后通过阈值操作,置信度大于0.5的像素被当做目标,否则被认为是背景。

具体实现的代码如下:

def get_mask(row, box, img_width, img_height):
  mask = row.reshape(160, 160)
  x1, y1, x2, y2 = box
  // box坐标是相对于原始图像大小,需转换到相对于160*160的大小
  mask_x1 = round(x1 / img_width * 160)
  mask_y1 = round(y1 / img_height * 160)
  mask_x2 = round(x2 / img_width * 160)
  mask_y2 = round(y2 / img_height * 160)
  mask = mask[mask_y1:mask_y2, mask_x1:mask_x2]
  mask = sigmoid(mask)
  // 把mask的尺寸调整到相对于原始图像大小
  mask = cv2.resize(mask, (round(x2 - x1), round(y2 - y1)))
  mask = (mask > 0.5).astype("uint8") * 255
  return mask

这里需要注意的是,160x160是相对于模型输入尺寸为640x640来的,如果模型输入是其他尺寸,那么上面的代码需要做相应的调整。

如果需要检测的是下面这个图片:

通过上面的代码可以得到最左边那个人的分割掩码为

但是我们需要的并不是这样一张图片,而是需要用于表示这个目标的轮廓,这可以通过OpenCVfindContours函数来实现。findContours函数返回的是一个用于表示该目标的点集,然后我们可以在原始图像中用fillPoly函数画出该目标的分割结果。

全部目标的检测与分割结果如下:

3. 一点其他的想法

从前面的部署过程可以知道,做后处理的时候需要对实例分割的数据做矩阵乘法、sigmoid激活、维度变换等操作,实际上这些操作也可以在导出模型的时候集成到onnx模型中去,这样就可以简化后处理操作。

首先需要修改ultralytics代码仓库中ultralytics/nn/modules/head.py文件的代码,把SegmentForward函数最后的代码修改为:

if self.export:
    output1 = p.reshape(p.shape[0], p.shape[1], -1)
    boxes = x.permute(0, 2, 1)
    masks = torch.sigmoid(mc.permute(0, 2, 1) @ output1)
    out = torch.cat([boxes, masks], dim=2)
    return out
else:
    return (torch.cat([x[0], mc], 1), (x[1], mc, p))

然后修改ultralytics/engine/exporter.py文件中torch.onnx.export的参数,把模型的输出数量改为1个。

在这里插入图片描述

代码修改完成后,执行命令pip install -e '.[dev]'使之生效,然后再重新用yolo命令导出模型。用netron工具可以看到模型只有一个shape[1,8400,25684]的输出。

这样在后处理的时候就可以直接去解析boxmask了,并且mask的数据不需要进行sigmoid激活。

4. 参考资料

  • How to implement instance segmentation using YOLOv8 neural network
  • https://github.com/AndreyGermanov/yolov8_segmentation_python

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1288261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小航助学题库白名单竞赛考级蓝桥杯等考scratch(10级)(含题库教师学生账号)

需要在线模拟训练的题库账号请点击 小航助学编程在线模拟试卷系统&#xff08;含题库答题软件账号&#xff09; 需要在线模拟训练的题库账号请点击 小航助学编程在线模拟试卷系统&#xff08;含题库答题软件账号&#xff09;

chineseocr项目不使用web推理-docker容器化

整个流程介绍 拉取 ufoym/deepo 镜像 -- 因为包含了主流深度学习框架&#xff0c;镜像4G出头。拉取 chineseocr 项目代码。修改代码&#xff0c;不使用web&#xff0c;增加命令行传入图片路径的功能打包成docker镜像。 开始 拉取 ufoym/deepo 镜像 &#xff1a;cpu版本为例 do…

封装ThreadLocal

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 为什么要封装ThreadLoc…

geemap学习笔记020:如何搜索Earth Engine Python脚本

前言 本节内容比较简单&#xff0c;但是对于自主学习比较重要&#xff0c;JavaScript提供了很多的示例代码&#xff0c;为了便于学习&#xff0c;geemap将其转为了Python代码。 Earth Engine Python脚本 import ee import geemapee.Initialize()geemap.ee_search() #搜索Ear…

Vue3网站用户引导功能【Intro.js】

一、介绍 Intro.js 是一个用于创建网站用户引导、功能介绍和教程的 JavaScript 库。它允许开发者通过步骤和提示突出显示网站上的特定元素&#xff0c;以帮助用户更好地了解和使用网站的功能。以下是 Intro.js 的一些关键特点和用法介绍&#xff1a; 更多Intro.js 功能网址&a…

图扑数字孪生压缩空气储能管控平台

压缩空气储能在解决可再生能源不稳定性和提供可靠能源供应方面具有重要的优势。压缩空气储能&#xff0c;是指在电网负荷低谷期将电能用于压缩空气&#xff0c;在电网负荷高峰期释放压缩空气推动汽轮机发电的储能方式。通过提高能量转换效率、增加储能密度、快速启动和调节能力…

电子编曲软件FL Studio2024汉化中文免费版下载

电子编曲需要什么软件&#xff1f;市面上的宿主软件都可以完成电子编曲的工作&#xff0c;主要适用电子音乐风格编曲的宿主软件有FL Studio、Ableton Live等。电子编曲需要什么基础&#xff1f;需要对于电子音乐足够熟悉、掌握基础乐理知识以及宿主软件的使用方法。 就我个人的…

Linux cgroup技术

cgroup 全称是 control group&#xff0c;顾名思义&#xff0c;它是用来做“控制”的。控制什么东西呢&#xff1f;当然是资源的使用了。 cgroup 定义了下面的一系列子系统&#xff0c;每个子系统用于控制某一类资源。 CPU 子系统&#xff0c;主要限制进程的 CPU 使用率。cpu…

王道数据结构课后代码题p175 06.已知一棵树的层次序列及每个结点的度,编写算法构造此树的孩子-兄弟链表。(c语言代码实现)

/* 此树为 A B C D E F G 孩子-兄弟链表为 A B E C F G D */ 本题代码如下 void createtree(tree* t, char a[], int degree[], int n) {// 为B数组分配内存tree* B (tree*)malloc(sizeof(tree) * n);int i 0;i…

CENTOS 7 添加黑名单禁止IP访问服务器

一、通过 firewall 添加单个黑名单 只需要把ip添加到 /etc/hosts.deny 文件即可&#xff0c;格式 sshd:$IP:deny vim /etc/hosts.deny# 禁止访问sshd:*.*.*.*:deny# 允许的访问sshd:.*.*.*:allowsshd:.*.*.*:allow 二、多次失败登录即封掉IP&#xff0c;防止暴力破解的脚本…

搞程序权益系统v1.1

继1.0出来后我就把antdui换成elem 新增号卡功能现在只支持对接号氪系统 大家问我这个程序到底有什么用&#xff0c;我这边已经在写和WordPress对接文件&#xff0c;到时候在WordPress网站打开该程序就可以把订单同步到你的程序里面去&#xff0c;当然自己有集成能力也可以到小…

FairGuard无缝兼容小米澎湃OS、ColorOS 14 、鸿蒙4!

随着移动互联网时代的发展&#xff0c;各大手机厂商为打造生态系统、构建自身的技术壁垒&#xff0c;纷纷投身自研操作系统。 而对于一款游戏安全产品&#xff0c;在不同操作系统下&#xff0c;是否能够无缝兼容并且提供稳定的、高强度的加密保护&#xff0c;成了行业的一大痛…

python笔记:dtaidistance

1 介绍 用于DTW的库纯Python实现和更快的C语言实现 2 DTW举例 2.1 绘制warping 路径 from dtaidistance import dtw from dtaidistance import dtw_visualisation as dtwvis import numpy as np import matplotlib.pyplot as plts1 np.array([0., 0, 1, 2, 1, 0, 1, 0, 0…

Redis 命令全解析之 String类型

文章目录 ⛄String 介绍⛄命令⛄对应 RedisTemplate API⛄应用场景 ⛄String 介绍 String 类型&#xff0c;也就是字符串类型&#xff0c;是Redis中最简单的存储类型。 其value是字符串&#xff0c;不过根据字符串的格式不同&#xff0c;又可以分为3类&#xff1a; ● string&…

javaScript(四):函数和常用对象

文章目录 1、函数介绍2、函数的作用3、函数语法4、常用对象&#xff1a;数组5、常用对象&#xff1a;String6、常用对象&#xff1a;自定义对象 1、函数介绍 函数是一段可重复使用的代码块&#xff0c;用于执行特定任务或计算并返回结果。 函数由以下几个要素组成&#xff1a; …

2024最新电脑系统清理软件哪个好用?

基本上&#xff0c;不管是win版还是Mac版的电脑&#xff0c;其装机必备就是一款电脑系统清理软件&#xff0c;就比如Mac&#xff0c;目前在市面上&#xff0c;电脑系统清理软件是非常多的。 对于不熟悉系统的用户来说&#xff0c;使用一些小众工具&#xff0c;往往很多用户都不…

Flask项目Day1,Flask常见第三方拓展包

拉项目 git clone https://gitee.com/hahaguai007/python-flask-mysql.git git clone 项目地址运行后即可获取项目 2.创建数据库 在MySQL中创建一个数据库&#xff0c;名字自己定&#xff0c;然后修改RealProject\settings.py里的SQLALCHEMY_DATABASE_URI&#xff0c;格式为 …

一部,即全部,十年超越之作一加12售价4299元起

2023 年 12 月 5 日&#xff0c;一加正式发布十年旗舰一加 12。作为一加十年超越之作&#xff0c;一加 12 秉持「产品力优先」理念&#xff0c;带来多项领先行业的首创技术。一加 12 全球首发拥有医疗级护眼方案和行业第一 4500nit 峰值亮度的 2K 东方屏&#xff0c;完整搭载 F…

【Intel/Altera】 全系列FPGA最新汇总说明,持续更新中

前言 2023年11月14日英特尔 FPGA中国技术日&#xff0c;Intel刚发布了新的FPGA系列&#xff0c;官网信息太多&#xff0c;我这里结合以前的信息&#xff0c;简单汇总更新一下&#xff0c;方便大家快速了解Intel/Altera FPGA家族。 目录 前言 Altera和Intel 型号汇总 1. Agi…

【五分钟】学会利用cv2.resize()函数实现图像缩放

引言 在numpy知识库&#xff1a;深入理解numpy.resize函数和数组的resize方法中&#xff0c;小编较为详细地探讨了numpy的resize函数背后的机理。从结果来看&#xff0c;numpy.resize函数并不适合对图像进行缩放操作。而opencv中的resize函数虽然和numpy的resize函数同名&…