1. 流程说明
ts文件夹下,
从launcher.py进入,执行jar文件。
入口为model_server.py的start()函数。内容包含:
- 读取args,创建pid文件
- 找到java,启动model-server.jar程序,同时读取log-config文件,TEMP文件夹地址,TS_CONFIG_FILE文件
- 根据cpu核数、gpu个数,启动多进程。每个进程有一个socket_name和socket_type,执行model_service_worker.py,创建TorchModelServiceWorker类,并执行run_server方法。run_server不断执行handle_connection方法,handle_connection不断执行predict(cmd为I时)或者load_model(cmd为L时)任务。
- load_model可以返回service对象,而service可以执行predict函数。如果handler中间包含冒号,则用后面的function作为_entry_point,否则默认用handle函数作为_entry_point。
- service的定义如下。其中manifest是一个字典,记录在MAR包里面的MAR_INF/MANIFEST.json中,包含modelName,serializedFile,handler,modelVersion等信息。这些信息也是modelArchiver打包模型时需要的内容。
class Service(object):
"""
Wrapper for custom entry_point
"""
def __init__(
self,
model_name,
model_dir,
manifest,
entry_point,
gpu,
batch_size,
limit_max_image_pixels=True,
metrics_cache=None,
):
- 接下来看一下predict函数。首先是调用retrieve_data_for_inference方法获取input_batch,其格式为
{parameter["name"]: parameter["value"]}
。然后是调用ret = self._entry_point(input_batch, self.context),这里的_entry_point就是我们自己定义的handler.handle方法。默认的handle方法执行三步:
data_preprocess = self.preprocess(data)
output = self.inference(data_preprocess)
output = self.postprocess(output)
2. 运行
- 首先安装java,然后
pip install torchserve torch-model-archiver
- 接着将模型和参数打包:
torch-model-archiver --model-name test --version 1.0 --serialized-file test.torchscript.pt --handler handler_test.py --export-path model_store
- 启动服务
torchserve --start --ncs --model-store model_store --models test.mar --disable-token-auth --ts-config config.properties
- 停止服务
torchserve --stop
- 调用:
res = requests.post("http://127.0.0.1:8080/predictions/test",files = {"data":data})