引言
Label Studio ML 后端是一个 SDK,用于包装您的机器学习代码并将其转换为 Web 服务器。Web 服务器可以连接到正在运行的 Label Studio 实例,以自动执行标记任务。我们提供了一个示例模型库,您可以在自己的工作流程中使用这些模型,也可以根据需要进行扩展和自定义。
如果您想改为编写自己的模型,请参阅编写自己的 ML 后端。
1、创建后端服务
地址:GitHub - HumanSignal/label-studio-ml-backend: Configs and boilerplates for Label Studio's Machine Learning backend
终端导航至本地仓库目录 :
#用清华的源会快一点
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
#创建自己的后端服务
label-studio-ml create Stopsign_ml_backend
1.1、环境变量设置
增加环境变量:LABEL_STUDIO_URL,
LABEL_STUDIO_API_KEY
LABEL_STUDIO_URL: LS的IP端口号,如:127.0.0.1:8080
LABEL_STUDIO_API_KEY:LS中个人账户的秘钥
1.2、修改model.py文件
实现predict函数,对于目标检测模型:
from typing import List, Dict, Optional
from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.response import ModelResponse
from label_studio_ml.utils import get_single_tag_keys, get_local_path
import requests, os
from ultralytics import YOLO
from PIL import Image
from io import BytesIO
LS_URL = os.environ['LABEL_STUDIO_URL']
LS_API_TOKEN = os.environ['LABEL_STUDIO_API_KEY']
class YOLOv8Model(LabelStudioMLBase):
"""Custom ML Backend model
"""
def setup(self):
"""Configure any parameters of your model here
"""
self.set("model_version", "0.0.1")
self.from_name, self.to_name, self.value, self.classes = get_single_tag_keys(
self.parsed_label_config, 'RectangleLabels', 'Image')
self.model = YOLO("D:\\Label-stutio-ml-backend\\Stopsign_ml_backend\\best.pt")
self.labels = self.model.names
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
task = tasks[0]
# header = {
# "Authorization": "Token " + LS_API_TOKEN}
# image = Image.open(BytesIO(requests.get(
# LS_URL + task['data']['image'], headers=header).content))
url = tasks[0]['data']['image']
print(f'url is: {url}')
image_path = self.get_local_path(url=url,ls_host=LS_URL,task_id=tasks[0]['id'])
print(f'image_path: {image_path}')
image = Image.open(image_path)
original_width, original_height = image.size
predictions = []
score = 0
i = 0
results = self.model.predict(image,conf=0.5)
for result in results:
for i, prediction in enumerate(result.boxes):
xyxy = prediction.xyxy[0].tolist()
predictions.append({
"id": str(i),
"from_name": self.from_name,
"to_name": self.to_name,
"type": "rectanglelabels",
"score": prediction.conf.item(),
"original_width": original_width,
"original_height": original_height,
"image_rotation": 0,
"value": {
"rotation": 0,
"x": xyxy[0] / original_width * 100,
"y": xyxy[1] / original_height * 100,
"width": (xyxy[2] - xyxy[0]) / original_width * 100,
"height": (xyxy[3] - xyxy[1]) / original_height * 100,
"rectanglelabels": [self.labels[int(prediction.cls.item())]]
}
})
score += prediction.conf.item()
print(f"Prediction Score is {score:.3f}.")
final_prediction = [{
"result": predictions,
"score": score / (i + 1),
"model_version": "v8n"
}]
return ModelResponse(predictions=final_prediction)
def fit(self, event, data, **kwargs):
"""
This method is called each time an annotation is created or updated
You can run your logic here to update the model and persist it to the cache
It is not recommended to perform long-running operations here, as it will block the main thread
Instead, consider running a separate process or a thread (like RQ worker) to perform the training
:param event: event type can be ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING')
:param data: the payload received from the event (check [Webhook event reference](https://labelstud.io/guide/webhook_reference.html))
"""
# use cache to retrieve the data from the previous fit() runs
old_data = self.get('my_data')
old_model_version = self.get('model_version')
print(f'Old data: {old_data}')
print(f'Old model version: {old_model_version}')
# store new data to the cache
self.set('my_data', 'my_new_data_value')
self.set('model_version', 'my_new_model_version')
print(f'New data: {self.get("my_data")}')
print(f'New model version: {self.get("model_version")}')
print('fit() completed successfully.')
1.3、启动服务
label-studio-ml start Stopsign_ml_backend -p 9091
2、LS前端配置
在项目设置页面设置模型,打开交互预标注
在标注页面打开新的图片,出现缓冲条表示在向后台请求预测数据
预测成功如下图所示,会多出一个标注,如果没有则是请求数据错误,请检查后端服务配置
这里用的是一个yoloV8-OBB模型,带方向的矩形框,它的Model.py参考这里
https://download.csdn.net/download/weixin_42253874/89820948