yolov8-seg 分割推理流程

news2024/11/25 7:17:25

目录

一、分割+检测

二、图像预处理

二、推理

三、后处理与可视化

3.1、后处理

3.2、mask可视化

四、完整pytorch代码


一、分割+检测

注:本篇只是阐述推理流程,tensorrt实现后续跟进。

yolov8-pose的tensorrt部署代码稍后更新,还是在仓库:GitHub - FeiYull/TensorRT-Alpha: 🔥🔥🔥TensorRT-Alpha supports YOLOv8、YOLOv7、YOLOv6、YOLOv5、YOLOv4、v3、YOLOX、YOLOR...🚀🚀🚀CUDA IS ALL YOU NEED.🍎🍎🍎It also supports end2end CUDA C acceleration and multi-batch inference.

也可以关注:TensorRT系列教程-CSDN博客

以下是官方预测代码:

from ultralytics import YOLO
model = YOLO(model='yolov8n-pose.pt')
model.predict(source="d:/Data/1.jpg", save=True)

推理过程无非是:图像预处理 -> 推理 -> 后处理 + 可视化,这三个关键步骤在文件大概247行:D:\CodePython\ultralytics\ultralytics\engine\predictor.py,代码如下:

# Preprocess
with profilers[0]:
	im = self.preprocess(im0s) # 图像预处理

# Inference
with profilers[1]:
	preds = self.inference(im, *args, **kwargs) # 推理

# Postprocess
with profilers[2]:
	self.results = self.postprocess(preds, im, im0s) # 后处理

二、图像预处理

通过debug,进入上述self.preprocess函数,看到代码实现如下。处理流程大概是:padding(满足矩形推理),图像通道转换,即:BGR装RGB,检查图像数据是否连续,存储顺序有HWC转为CHW,然后归一化。需要注意,原始pytorch框架图像预处理的时候,会将图像缩放+padding为HxW的图像,其中H、W为32倍数,而导出tensorrt的时候,为了高效推理,H、W 固定为640x640。

def preprocess(self, im):
	"""Prepares input image before inference.

	Args:
		im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
	"""
	not_tensor = not isinstance(im, torch.Tensor)
	if not_tensor:
		im = np.stack(self.pre_transform(im))
		im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
		im = np.ascontiguousarray(im)  # contiguous
		im = torch.from_numpy(im)

	img = im.to(self.device)
	img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
	if not_tensor:
		img /= 255  # 0 - 255 to 0.0 - 1.0
	return img

二、推理

图像预处理之后,直接推理就行了,这里是基于pytorch推理。

def inference(self, im, *args, **kwargs):
	visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
							   mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
	return self.model(im, augment=self.args.augment, visualize=visualize)

三、后处理与可视化

3.1、后处理

640x640输入之后,有两个输出,其中

  • output1:尺寸为:116X8400,其中116=4+80+32,32为seg部分特征,经过NMS之后,输出为:N*38,其中38=4 + 2 + 32
  • output2:尺寸为32x160x160,拿上面NMS后的特征图后面,即:N*38矩阵后面部分N*32的特征图和output2作矩阵乘法,得到N*160*160的矩阵,接着执行sigmiod,然后拉平得到N*160*160 的mask。

然后将bbox缩放160*160的坐标系,如下代码,用于截断越界的mask,就是如下函数。最后,将所有mask上采样到640*640,然后用阀值0.5过一下。最后mask中只有0和1了,结束。

有关def crop_mask(masks, boxes):的理解:

def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box

    Args:
      masks (torch.Tensor): [n, h, w] tensor of masks
      boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form

    Returns:
      (torch.Tensor): The masks are being cropped to the bounding box.
    """
    n, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)

    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

上面代码最后一句return,如下图理解,mask中所有点,例如点(r,c)必须在bbox内部。做法就是将bbox缩放到和mask一样的坐标系(160x160)如下图,然后使用绿色的bbox将mask进行截断:

3.2、mask可视化

直接将mask从灰度图转为彩色图,然后将类别对应的颜色乘以0.4,最后加在彩色图上就行了。

四、完整pytorch代码

将以上流程合并起来,并加以修改,完整代码如下:

import torch
import cv2 as cv
import numpy as np
from ultralytics.data.augment import LetterBox
from ultralytics.utils import ops
from ultralytics.engine.results import Results
import copy

# path = 'd:/Data/1.jpg'
path = 'd:/Data/640640.jpg'
device = 'cuda:0'
conf = 0.25
iou = 0.7

# preprocess
im = cv.imread(path)
# letterbox
im = [im]
orig_imgs = copy.deepcopy(im)
im = [LetterBox([640, 640], auto=True, stride=32)(image=x) for x in im]
im = im[0][None] # im = np.stack(im)
im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im)  # contiguous
im = torch.from_numpy(im)
img = im.to(device)
img = img.float()
img /= 255
# load model pt
ckpt = torch.load('yolov8n-seg.pt', map_location='cpu')
model = ckpt['model'].to(device).float()  # FP32 model
model.eval()

# inference
preds = model(img)

# poseprocess
p = ops.non_max_suppression(preds[0], conf, iou, agnostic=False, max_det=300, nc=80, classes=None)
results = []
# 如果导出onnx,第二个输出维度是1,应该就是mask,需要后续上采样
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported???????
for i, pred in enumerate(p):
    orig_img = orig_imgs[i]
    if not len(pred):  # save empty boxes
        results.append(Results(orig_img=orig_img, path=path, names=model.names, boxes=pred[:, :6]))
        continue
    masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
    if not isinstance(orig_imgs, torch.Tensor):
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
    results.append(Results(orig_img=orig_img, path=path, names=model.names, boxes=pred[:, :6], masks=masks))

# show
plot_args = {'line_width': None,'boxes': True,'conf': True, 'labels': True}
plot_args['im_gpu'] = img[0]
result = results[0]
plotted_img = result.plot(**plot_args)
cv.imshow('plotted_img', plotted_img)
cv.waitKey(0)
cv.destroyAllWindows()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1265322.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何根据接口文档,轻松快速的模拟接口服务?

什么是WireMock? WireMock 是一个Http 模拟服务,其核心也是一个web服务,WireMock主要是为特定请求提供固定的返回值。 WireMock可以作为单独进程启动,模拟一个WEB服务器,提供一些API访问,并返回特定的返回值。也可以作为第三方库在项目中使用。 如何使用 standalone方…

csdn博客编写技巧

随便记录一下csdn博客编写时候用的到技巧&#xff0c;以作备忘。 1. 表格 1.1 Markdown-Table-Generator 这个是csdn编辑器中&#xff0c;工具栏自带的表格用法。主要优点是比较直观&#xff0c;缺点是无法设置表格中行列的宽高。 用法&#xff1a; | 表头一 | 表头二 | |-…

贪心算法(新坑)

贪心入门 概述&#xff1a; 贪心算法是一种在每一步选择中都采取当前最优解的策略&#xff0c;希望最终能够得到全局最优解的算法。简单来说&#xff0c;它会不断地做出局部最优的选择&#xff0c;相信通过这种选择最终能够达到全局最优。 举个例子来说明。假设你要从一个迷…

Vue基本使用(一)

&#x1f4d1;前言 本文主要是【Vue】——Vue基本使用的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304;每日一句&#x…

【CAD二次开发】标注箭头,获取修改标注箭头图块

常见的的标注箭头有以下种类 public static List<string> ArrowBlock = new List<string>(){" ","_CLOSEDBLANK&

淘宝API接口系列:连接商户与消费者的桥梁

一、引言 淘宝&#xff0c;作为中国最大的电商平台之一&#xff0c;拥有数以亿计的注册用户和海量的商品信息。淘宝API接口作为连接商户与消费者的重要桥梁&#xff0c;为开发者提供了丰富的电商资源&#xff0c;帮助他们创新和优化业务。本文将深入探讨淘宝API接口的相关知识…

软件设计开发规程文件

《软件设计开发规程文件》 目的&#xff1a;为需求设计、开发、实现解决方案。

LLM大语言模型

大语言模型的定义 大语言模型&#xff08;英文&#xff1a;Large Language Model&#xff0c;缩写LLM&#xff09;&#xff0c;也称大型语言模型&#xff0c;是一种人工智能模型&#xff0c;旨在理解和生成人类语言。它们在大量的文本数据上进行训练&#xff0c;可以执行广泛的…

怎么更新BI报表数据?问我就对了

BI大数据分析工具上有大量的BI报表模板&#xff0c;这些模板都是一个个完整的BI报表&#xff0c;只需将数据源更换&#xff0c;立即就能用来分析我们自己的数据。那&#xff0c;BI报表的数据怎么更新&#xff1f;接下来就来说说这事。 目的&#xff1a;更新BI报表数据 工具&a…

PPSSPP (PSP游戏模拟器)最新版安装使用教程

PPSSPP优势 1、目前唯一的也是最好的psp模拟器 可运行绝大多数psp游戏且运行高速&#xff0c;即使是低配手机也能游玩经典大作。 2、支持自定义调节虚拟手柄和实体手柄连接 ppsspp模拟器支持使用虚拟手柄或者连接实体手柄游玩&#xff0c;同时还可以自定义调节按键选项。 …

微信小程序+中草药分类+爬虫+keras

目录 1 介绍2 数据爬虫3 模型训练和验证3.1 模型训练3.2 导入一张图片进行验证 4 后台flask部署5 微信小程序 1 介绍 本项目使用深度学习模型&#xff0c;训练5种中药材数据集&#xff0c;然后将其集成到微信小程序&#xff0c;通过微信小程序拍照&#xff0c;将图片传输给后端…

蓝桥杯day02——Fizz Buzz

1、题目 给你一个整数 n &#xff0c;找出从 1 到 n 各个整数的 Fizz Buzz 表示&#xff0c;并用字符串数组 answer&#xff08;下标从 1 开始&#xff09;返回结果&#xff0c;其中&#xff1a; answer[i] "FizzBuzz" 如果 i 同时是 3 和 5 的倍数。answer[i] &…

GitLab 登录中,LDAP和 Standard 验证有什么区别

在 GitLab 中&#xff0c;LDAP&#xff08;Lightweight Directory Access Protocol&#xff09;和 Standard 验证是两种不同的身份验证方法&#xff0c;它们有以下区别&#xff1a; LDAP&#xff08;Lightweight Directory Access Protocol&#xff09;身份验证&#xff1a; L…

uniapp中uni.navigateBack返回后刷新页面数据

文章目录 一、前言1.1、[uni.navigateBack](https://uniapp.dcloud.net.cn/api/router.html#navigateback) 二、方法2.1、父页面设置钩子函数onBackPress2.2、uni.$emit和uni.$on监听通知数据变更2.2.1、子页面2.2.2、父页面 2.3、onShow钩子函数处理数据2.3.1、子页面2.3.2、父…

《深入理解计算机系统》学习笔记 - 第三课 - 位,字节和整型

Lecture 03 Bits,Bytes, and Integer count 位&#xff0c;字节&#xff0c;整型 文章目录 Lecture 03 Bits,Bytes, and Integer count 位&#xff0c;字节&#xff0c;整型运算&#xff1a;加&#xff0c;减&#xff0c;乘&#xff0c;除加法乘法取值范围乘法结果 使用无符号注…

微信小程序Vue+nodejs教室自习室座位预约系统68u2m

本文从管理员、用户的功能要求出发&#xff0c;教室预约系统小程序中的功能模块主要是实现管理端&#xff1b;首页、个人中心、教室信息管理、教室设备管理、用户管理、教室预约管理、管理员管理、系统管理&#xff0c;微信端&#xff1b;首页、教室信息、教室设备、教室预约、…

osgFX扩展库-异性光照、贴图、卡通特效(1)

本章将简单介绍 osgFX扩展库及osgSim 扩展库。osgFX库用得比较多,osgSim库不常用&#xff0c;因此&#xff0c;这里只对这个库作简单的说明。 osgFX扩展库 osgFX是一个OpenSceneGraph 的附加库&#xff0c;是一个用于实现一致、完备、可重用的特殊效果的构架工具&#xff0c;其…

C#:程序发布的大小控制

.net不讨喜有个大原因就是.net平台本身太大了&#xff0c;不同版本没有兼容性&#xff0c;程序依赖哪个版本用户就要安装哪个版本&#xff0c;除非你恰好用的是操作系统默认安装的版本——问题是不同版本操作系统默认安装的不一样。 所以打包程序就很头疼&#xff0c;不打包平台…

Python Selenium 图片资源自动搜索保存 项目实践

实现访问首页 from os.path import dirnamefrom selenium import webdriverclass ImageAutoSearchAndSave:"""图片自动搜索保存"""def __init__(self):"""初始化"""self.driver webdriver.Chrome(executable_pa…

Vue实现可拖拽边界布局

Vue实现可拖拽边界布局 在前端开发中&#xff0c;有时需要实现一种可拖拽边界的布局&#xff0c;通过拖动分隔线来调整不同区域大小。例如&#xff0c;下图是一个典型的可拖拽边界布局&#xff0c;它由左右两个区域组成&#xff0c;左边是一个树形菜单&#xff0c;右边是一个上…