yolov8实战之torchserve服务化:使用yolov8x来预打标

news2025/1/20 3:39:07

前言

最近在做一个目标检测的任务,部署在边缘侧,对于模型的速度要求比较严格(yolov8n这种),所以模型的大小不能弄太大,所以原模型的性能受限,更多的重点放在增加数据上。实测yolov8x在数据集上的效果比小模型要好不少,所以想法是用yolov8x来预打标,然后选择一些置信度高的样本加到训练集来训练yolov8n,减少标注的成本。原始数据是在ceph上,比较直观的方式就是一张张读,然后一张张推理。这样效率不高,毕竟GPU适合组batch去推理,所以为了效率就需要自己去组成batch然后推理,然后再把batch的结果再分开对应到单张图上,虽然并不难,还是挺繁琐的,这也是我以前的做法。其实可以不需要这么麻烦,这种batching操作很多的推理服务都会帮我们做掉,我们只需要并发去请求就好了,和做单个没什么区别,要加其他的模型进行组合逻辑也是非常方便。GroundingDINO(一种开集目标检测算法)服务化,根据文本生成检测框_CodingInCV的博客-CSDN博客这个里面我们使用torchserve来实现了算法的服务化,这里我们依旧还是使用torchserve。基础就不做介绍了,可以读上面这篇。与GroundingDINO不同的是,这里我们会启用batch操作,而GroudingDINO里没有支持batch。

导出onnx模型

为了方便起见,我们使用onnx模型,避免去处理yolov8的pytorch环境问题,官网提供了导出的方式:Detect - Ultralytics YOLOv8 Docs
为了支持动态的batch, 我们导出时要以dynamic的方式导出,我这里对导出做了一点修改,只让batch为动态,而输入尺寸固定, 修改engine/exporter.py:
image.png
为了支持我们新增的dynamic_batch参数,我们还需要再default.yaml中增加这个参数,具体可以参考:yolov8训练进阶:新增配置参数_CodingInCV的博客-CSDN博客
image.png
然后自行写脚本转换:

from ultralytics import YOLO

model = YOLO('yolov8x6404/weights/last.pt')  # initialize
model.export(format = "onnx", opset = 11, simplify = True,
             dynamic_batch=True, imgsz=640)  # export to onnx

导出的模型将和输入的模型在同一个路径。

自定义handler

handler的写法

在GroundingDINO(一种开集目标检测算法)服务化,根据文本生成检测框_CodingInCV的博客-CSDN博客我们没有提到怎么写自己的模型handler,所谓模型handler就是告诉torchserve我们的模型如何载入、前处理和后处理。官方教程:Custom Service — PyTorch/Serve master documentation
torchserve自身带了一些handler:
BaseHandler: handler的基类,我们可以继承这个,也可以不继承,如果不继承则至少要实现initializehandle方法。
image.png

我们可以继承他们来实现自己的,也可以不继承,这里以不继承来实现,通用性比较强,不管什么模型都可以搞定,主要就是实现一个类,这个类至少要实现initializehandle方法:
initialize 就是初始化模型,这个方法必须有一个输入参数context(serve/ts/context.py at master · pytorch/serve (github.com)), 从这个参数我们可以拿到比如模型的路径、显卡号等信息。
handle 是接收输入请求和返回处理结果的接口,具有2个参数,第一个参数是输入请求,第二个参数也是context。
对于每个模型我们可以将推理过程拆分为三个过程(方法):preprocess、inference、postprocess,即前处理、推理、后处理,我们的handler只要实现这三个方法,然后依次在handle中调用即可,最后把输出按要求组合起来,handle的返回值必须是list of list,也就是数组的数组,外层list的长度等于输入的batch数(torchserve可以自动组batch),内层的list是单个请求的输出,里面的元素可以是dict,完整代码如下:

import logging
import os,sys
import onnxruntime as ort
import base64
import numpy as np
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
    from common.common import resize_image
except:
    from common import resize_image
import cv2

logger = logging.getLogger(__name__)
console_logger = logging.StreamHandler(sys.stdout)
console_logger.setLevel(logging.DEBUG)
console_logger.setFormatter(logging.Formatter("%(asctime)s %(name)s [%(levelname)s] %(message)s"))
logger.addHandler(console_logger)

class YOLOV8Handler(object):
    def __init__(self):
        self.context = None
        self.initialized = False
        self.model = None
        self.input_name = None
        self.input_shape = None
        self.conf_thres = 0.45
        self.iou_thres = 0.45
        self.class2label = {
            0: "body",
            1: "head",
        }
        self.device = None

    def initialize(self, context):
        #  load the model
        logger.info("initialize grounding dino handler")
        self.context = context
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        # Read model serialize/pt file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model file")
        
        # get device

        available_providers =  ort.get_available_providers()
        provide_options = {}
        if "CUDAExecutionProvider" in available_providers:
            self.device = str(properties.get("gpu_id"))
            provide_options["device_id"] = self.device
            privider = "CUDAExecutionProvider"
            logger.info("using gpu {}".format(self.device))
        else:
            privider = "CPUExecutionProvider"
        
        self.model = ort.InferenceSession(model_pt_path, providers=[privider], provider_options=[provide_options])
        self.initialized = True
        # get input shape
        self.input_name = self.model.get_inputs()[0].name
        self.input_shape = self.model.get_inputs()[0].shape
        logger.info("model loaded successfully")

    def preprocess(self, data):
        logger.info("preprocess data")
        preprocessed_data = []
        preprocessed_params = []
        network_input_height = self.input_shape[2]
        network_input_width = self.input_shape[3]
        for row in data:
            input = row.get("data") or row.get("body")
            if isinstance(input, dict) and "image" in input:
                image = input["image"]
            else:
                logger.error("No image found in the request")
                assert False, "No  image found in the request"
            if isinstance(image, str):
                # if the image is a string of bytesarray.
                image = base64.b64decode(image)
            # If the image is sent as bytesarray
            if isinstance(image, (bytearray, bytes)):
                image = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_ANYCOLOR)
            else:
                logger.error("No caption or image found in the request")
                assert False, "No caption or image found in the request"
            
            image_h, image_w, _ = image.shape
            image, newh, neww, top, left  = resize_image(image, keep_ratio=True, dst_width=network_input_width, dst_height=network_input_height)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            preprocessed_data.append(image)
            preprocessed_params.append((newh, neww, top, left, image_h, image_w))
        logger.info("preprocess data done")
        preprocessed_data = np.array(preprocessed_data).astype(np.float32)
        preprocessed_data /= 255.0
        preprocessed_data = np.transpose(preprocessed_data, (0, 3, 1, 2))
        return preprocessed_data, preprocessed_params
    
    def inference(self, data, *args, **kwargs):
        logger.info("inference data")
        outputs = self.model.run(None, {self.input_name: data})
        return outputs[0]
    
    def postprocess_one(self, output, request_param):
        newh, neww, top, left, image_h, image_w = request_param
        x_factor = image_w / neww
        y_factor = image_h / newh
        outputs = output
        outputs = np.transpose(np.squeeze(outputs))
        boxes = {}
        for row in outputs:
            classes_scores = row[4:]
            max_score = np.max(classes_scores)
            if max_score < self.conf_thres:
                continue
            class_id = np.argmax(classes_scores)
            x, y, w, h = row[0], row[1], row[2], row[3]

            # Calculate the scaled coordinates of the bounding box
            x1 = x - w / 2
            y1 = y - h / 2
            x1 = x1-left
            y1 = y1-top
            x2 = x1 + w
            y2 = y1 + h

            # Scale the coordinates according to the original image
            x1 = x1 * x_factor
            y1 = y1 * y_factor
            x2 = x2 * x_factor
            y2 = y2 * y_factor
            if class_id not in boxes:
                boxes[class_id] = [[],[]]
            boxes[class_id][0].append([x1, y1, x2, y2])
            boxes[class_id][1].append(float(max_score))
        
        # NMS
        nms_boxes = []
        for class_id in boxes:
            candidate_boxes, scores = boxes[class_id]
            indices = cv2.dnn.NMSBoxes(candidate_boxes, scores, self.conf_thres, self.iou_thres)
            for index in indices:
                nms_boxes.append((candidate_boxes[index], scores[index], self.class2label[class_id]))
        return nms_boxes

    def postprocess(self, data):
        outputs, request_params = data
        boxes = []
        for i in range(len(outputs)):
            output = outputs[i]
            request_param = request_params[i]
            nms_boxes = self.postprocess_one(output, request_param)
            boxes.append(nms_boxes)

        return boxes
    
    def handle(self, data, context):
        self.context = context
        image, request_params = self.preprocess(data)
        outputs = self.inference(image)
        boxes_batch = self.postprocess((outputs, request_params))
        results = []
        for boxes in boxes_batch:
            ret = []
            for box, score, label in boxes:
                ret.append({"box": box, "score": score, "label": label})
            results.append(ret)
        return results

注意:为了实现batch操作,我们实现的接口都应该是对batch来的,而不是只对一张图。

调试handler

我们可以模仿context的内容来初始化handler, 然后调用handle方法来调试结果是否正常。

if __name__=="__main__":
    import addict
    context = addict.Dict()
    context.system_properties = {
        "gpu_id": 0,
        "model_dir": "./weights"

    }
    context.manifest = {
        "model": {
            "serializedFile": "yolov8x.onnx"
        }
        }
    handler = YOLOV8Handler()
    handler.initialize(context)
    image_path = "./body.png"
    with open(image_path, "rb") as f:
        image = f.read()

    data = [
        {
            "data": {
                "image": image
            }
        },
        {
            "data": {
                "image": image
            }
        }
    ]

    outputs = handler.handle(data, context)
    print(outputs)

镜像制作

在GroundingDINO(一种开集目标检测算法)服务化,根据文本生成检测框_CodingInCV的博客-CSDN博客中镜像的基础上安装onnxruntime-gpu, 或者在启动时安装

转换模型

这个操作和上一篇文章一样,只是权重文件和需要handler修改一下,不赘述:

docker run --rm -it -v $(pwd):/data -w /data torchserve:groundingdino bash -c "torch-model-archiver --model-name yolov8x --version 1.0 --serialized-file weights/yolov8x.onnx --handler yolov8/yolov8_handler.py --extra-files common/*.py"

启动服务

与上一篇服务化不同,我们启动时不载入所有模型,而是通过post接口去开启,方便设置模型的batch size, 其中端口号根据需要设置。

docker run -d --name groundingdino -v $(pwd)/model_store:/model_store -p 8080:8080 -p 8081:8081 -p 8082:8082 torchserve:groundingdino bash -c "pip install onnxruntime-gpu && torchserve --start --foreground --model-store /model_store

使用Management API载入模型

Management API — PyTorch/Serve master documentation
可以用curl也可以用postman, 如

curl -X POST "localhost:8081/models?url=yolov8x.mar&batch_size=8&max_batch_delay=50"

如果需要再修改batchsize, 要先调用卸载模型的接口写在,然后再调用上面的接口。
通过上面的操作,torchserve会帮我们组batch, 最大为8.

调用

import json
import base64
import requests
import threadpool

url = "http://localhost:8080/predictions/yolov8x"
headers = {"Content-Type": "application/json"}

def request_worker(arg):
    image_path = "./b03492798d5b44eeb70856b9253386df.jpeg"
    data = {
        "image": base64.b64encode(open(image_path, "rb").read()).decode("utf-8")
    }

    response = requests.post(url, headers=headers, json=data)
    print(response.text)

if __name__ == "__main__":
    pool = threadpool.ThreadPool(24)
    requests_task = threadpool.makeRequests(request_worker, range(100))
    [pool.putRequest(req) for req in requests_task]
    pool.wait()

这里,我们用多线程模仿了高并发的去调用模型,这样torchserve就可以自动的根据负载情况来组成batch了,提高模型的吞吐量。类似的,我们就可以方便的使用多线程去读取数据然后调用模型来得到预打标的结果,而不用去处理模型的依赖、组batch等逻辑,也可以很方便的提供给其他需要的同事来使用。

结语

本文简述了将yolov8服务化的过程,服务化后,我们可以方便的用模型来进行数据的预打标、分享模型给他人使用。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/925687.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

旺店通·企业版对接打通金蝶云星空订单查询接口与销售出库新增接口

旺店通企业版对接打通金蝶云星空订单查询接口与销售出库新增接口 数据源平台:旺店通企业版 旺店通是北京掌上先机网络科技有限公司旗下品牌&#xff0c;国内的零售云服务提供商&#xff0c;基于云计算SaaS服务模式&#xff0c;以体系化解决方案&#xff0c;助力零售企业数字化智…

聚水潭与金蝶云星空对接集成库存盘点查询打通其他出库单新增V2

聚水潭与金蝶云星空对接集成库存盘点查询打通其他出库单新增V2 来源系统:聚水潭 聚水潭是SaaS协同平台、电商ERP软件。聚水潭成立于2014年&#xff0c;创始人兼CEO骆海东拥有近三十年传统及电商ERP的研发和实施部署经验。聚水潭创建之初&#xff0c;以电商SaaSERP切入市场&…

机器学习算法示例的收集;MetaAI编码工具Code Llama;“天工AI搜索”首发实测

&#x1f989; AI新闻 &#x1f680; Meta推出新一代AI编码工具Code Llama&#xff0c;助力程序员提高开发效率 摘要&#xff1a;Meta推出Code Llama&#xff0c;这是一个基于Llama 2语言模型打造的AI编码工具&#xff0c;能够生成新的代码并调试人类编写的工作。Code Llama可…

【Go Web 篇】Go 语言进行 Web 开发:构建高性能网络应用

随着互联网的快速发展&#xff0c;Web 开发已经成为了软件开发领域中不可或缺的一部分。随之而来的是对于更高性能、更高效的网络应用的需求。在这个领域&#xff0c;Go 语言因其并发性能、简洁的语法以及丰富的标准库而备受关注。本篇博客将深入探讨如何使用 Go 语言进行 Web …

linux入门详解

文章目录 一、引言1.1 开发环境1.2 生产环境1.3 测试环境1.4 操作系统的选择 二、Linux介绍2.1 Linux介绍2.2 Linux的版本2.3 Linux和Windows区别 三、Linux安装3.1 安装VMware3.2 安装Xterm3.3 在VMware中安装Linux3.3.1 选择安装方式3.3.2 指定镜像方式3.3.3 选择操作系统类型…

springboot设置文件上传大小,默认是1mb

问题排查和解决过程 之前做了个项目&#xff0c;需要用到文件上传&#xff0c;启动项目正常&#xff0c;正常上传图片也正常&#xff0c;但这里图片刚好都小于1M&#xff0c;在代码配置文件里面也写了配置&#xff0c;限制大小为500M&#xff0c;想着就没问题&#xff08;测试…

基于NXP i.MX 6ULL核心板的物联网模块开发案例(1)

目录 前 言 1 SDIO WIFI模块测试 1.1 STA模式测试 1.2 AP模式测试 1.3 SDIO WIFI驱动编译 前言 本文主要介绍基于创龙科技TLIMX6U-EVM评估板的物联网模块开发案例&#xff0c;适用开发环境&#xff1a; Windows开发环境&#xff1a;Windows 7 64bit、Windows 10 64bit …

PDF怎么批量加密?掌握这招事半功倍

PDF文件是一种广泛使用的文档格式&#xff0c;而加密可以有效地保护PDF文件的安全性。当需要批量加密PDF文件时&#xff0c;以下是一些方法及注意事项。 PDF批量加密的方法 相信很多小伙伴平时都是直接在PDF阅读器中对文档进行加密&#xff0c;但是这样只能每次对当前打开的文…

当你在浏览器中输入了网址访问时产生了哪些技术步骤

当你在浏览器中输入了网址访问时产生了哪些技术步骤 前段时间在知乎上了看一些网络方面的知识&#xff0c;刚好小编自己也是从事这一块的相关工作由对网络方面有一定的了解。今天我们来讲讲&#xff0c;当你在浏览器中输入本站域名并回车后&#xff0c;这背后到底发生来什么事…

yolov3加上迁移学习和适度的数据增强形成的网络应用在输电线异物检测

Neural Detection of Foreign Objects for Transmission Lines in Power Systems Abstract. 输电线路为电能从一个地方输送到另一个地方提供了一条路径&#xff0c;确保输电线路的正常运行是向城市和企业供电的先决条件。主要威胁来自外来物&#xff0c;可能导致电力传输中断。…

【高阶数据结构】二叉树搜索树 {概念;实现:核心结构,增删查,默认成员函数;应用:K模型和KV模型;性能分析;相关练习}

二叉搜索树 一、二叉搜索树的概念 二叉搜索树又称二叉排序树&#xff0c;它可以是一棵空树&#xff0c;若果不为空则满足以下性质: 若它的左子树不为空&#xff0c;则左子树上所有节点的值都小于根节点的值若它的右子树不为空&#xff0c;则右子树上所有节点的值都大于根节点…

Cadence软件屏幕显示问题

问题 就是今天打开Cadence软件想导出网表看一下&#xff0c;发现没有显示确定按钮什么的&#xff0c;那个窗口也是无语&#xff0c;不能移动&#xff0c;缩放也只能左右缩放&#xff0c;还不能缩小什么的&#xff0c;真的醉了&#xff0c;后面就是调整窗口的分辨率。 因为我最…

windwos系统如何创建typecho个人博客并通过内网穿透实现无公网IP访问

文章目录 前言1. 环境安装2.安装Typecho3.安装cpolar内网穿透4. 固定公网地址5.配置Typecho 前言 Typecho是一款PHP语言编写的开源博客程序&#xff0c;它是一个轻量级的内容管理系统&#xff0c;专注于博客领域。支持多用户、多站点、多语言等功能&#xff0c;可以满足不同用…

Win10怎么关闭自动更新?简单4招为你解决烦恼!

“买了一部win10的电脑&#xff0c;每次电脑自动更新都会导致我一些文件丢失或者系统错误。怎么才能关闭win10自动更新的功能呢&#xff1f;” Win10自动更新有时候会很影响我们使用电脑。在目前电脑用户中&#xff0c;使用win10系统的用户占大多数。因此很多朋友都会反映win10…

LLM(大语言模型)解码时是怎么生成文本的?

Part1配置及参数 transformers4.28.1 源码地址&#xff1a;transformers/configuration_utils.py at v4.28.1 huggingface/transformers (github.com) 文档地址&#xff1a;Generation (huggingface.co) 对于生成任务而言&#xff1a;text-decoder, text-to-text, speech-…

华为质量管理:从产品质量到用户体验,Kano模型成为新方向

目录 前言 华为质量管理的四个阶段 基于 IPD 如何做质量管理呢&#xff1f; CSDN相关课程 作者简介 前言 今天继续来谈谈华为流程体系中的质量管理过程。 通常来说质量具体是指产品的质量&#xff0c;也就是产品的使用价值及其属性。 产品再细分的话可以分为三个层次&a…

沃尔玛、亚马逊、ozon卖家必看:如何为旺季做准备?

近二十年来&#xff0c;得益于国家外贸政策的大力扶持&#xff0c;再加上近几年国家对跨境电商行业发展的高度重视&#xff0c;国货出海机会明显增多。 在政策利好的情况下&#xff0c;生产制造业的蓬勃发展等各种有利的局面&#xff0c;可谓是天时地利人和&#xff0c;那么在…

JetBrains 2023.2全新发布!IDEA、PyCharm等支持AI辅助

日前JetBrains官方正式宣布旗下IDE系列今年第二个重要版本——v2023.2全新发布&#xff0c;涵盖了 IntelliJ IDEA、PyCharm、WebStorm等一众知名产品&#xff0c;接下来我们一起详细了解一下他们的更新重点吧~ IntelliJ IDEA v2023.2——引入AI辅助开发 IntelliJ IDEA 2023.2…

java-CyclicBarrier、CountDownLatch、Semaphore 的用法以及 volatile 关键字的作用

CyclicBarrier、CountDownLatch、Semaphore 的用法 1. CountDownLatch&#xff08;线程计数器 &#xff09; CountDownLatch 类位于 java.util.concurrent 包下&#xff0c;利用它可以实现类似计数器的功能。比如有一个任务 A&#xff0c;它要等待其他 4 个任务执行完毕之后才…

powerJob报错以及解决办法集锦

1. 本地测试成功新建任务并运行成功&#xff0c;但是部署到服务器时新建任务只要 “参数”有中文就无法报错 前台报错信息&#xff1a; ERROR&#xff1a;JpaSystemException: could not execute statement; nested exception is org.hibernate.exception.GenericJDBCException…