nnunetv2系列：利用onnxruntime加载onnx模型进行推理
首先感谢 https://blog.csdn.net/chen_niansan/article/details/142328247
作者分享的示例。本文在此基础上做了修改，并增加了将预测结果转换为便于人眼查看的掩码图的功能。
import os
from time import time

import numpy as np
from cv2 import (
    COLOR_BGR2RGB,
    INTER_LINEAR,
    INTER_NEAREST,
    cvtColor,
    imread,
    imwrite,
    resize,
)
from onnxruntime import (
    InferenceSession,
    # get_device,
)
def preprocess_image(image, input_size, in_ch=3):
    """Resize and scale an image into a batched, channel-first float32 tensor.

    Args:
        image: Image array, either ``(H, W)`` (single channel) or
            ``(H, W, C)``.
        input_size: Target size in OpenCV order ``(width, height)``, as
            expected by ``cv2.resize``.
        in_ch: Number of channels the model expects. A single-channel image
            is replicated to this many channels.

    Returns:
        ``float32`` array of shape ``(1, in_ch, height, width)`` whose values
        are the resized pixel values divided by 255.

    Raises:
        ValueError: If the image's channel count does not match ``in_ch``.
    """
    image_resized = resize(image, input_size, interpolation=INTER_LINEAR)
    image_normalized = image_resized / 255.0
    if image_normalized.ndim == 2:
        # Single-channel image: replicate it to (H, W, in_ch).
        image_normalized = np.stack([image_normalized] * in_ch, axis=-1)
    # Previously a mismatch slipped through and only failed inside the ONNX
    # session with a much less helpful message.
    if image_normalized.shape[-1] != in_ch:
        raise ValueError(
            f"Image has {image_normalized.shape[-1]} channels, "
            f"but in_ch={in_ch} was requested."
        )
    # HWC -> CHW, then prepend the batch axis -> (1, C, H, W).
    image_normalized = np.transpose(image_normalized, (2, 0, 1))
    image_normalized = np.expand_dims(image_normalized, axis=0)
    return image_normalized.astype(np.float32)
def resolve_patch_size(input_shape, patch_size=None):
    """Return the network input size as ``(h, w)``.

    Args:
        input_shape: Shape of the model input as reported by onnxruntime,
            e.g. ``session.get_inputs()[0].shape``. The last two entries are
            taken as the spatial dims, matching the ``(1, C, H, W)`` tensor
            that ``preprocess_image`` builds. Dynamic axes show up as strings
            or ``None``.
        patch_size: Optional explicit ``(h, w)`` override (``patch_size`` in
            nnU-Net's plans.json). Takes precedence over ``input_shape``.

    Returns:
        Tuple ``(h, w)`` of ints.

    Raises:
        ValueError: If no override is given and the spatial dims of
            ``input_shape`` are not static positive integers.
    """
    if patch_size is not None:
        h, w = patch_size
        return int(h), int(w)
    spatial = list(input_shape)[-2:]
    if len(spatial) == 2 and all(
        isinstance(d, int) and d > 0 for d in spatial
    ):
        return spatial[0], spatial[1]
    raise ValueError(
        f"Cannot infer the patch size from model input shape {input_shape}; "
        "set patch_size = (h, w) from plans.json."
    )


def recover_label_values(label_map, value_map, dtype=np.uint8):
    """Map predicted class indices to custom output values in one pass.

    A lookup table is built from the untouched ``label_map``, so each pixel
    is remapped exactly once. Sequential in-place replacement would cascade
    when a target value equals a later source key.

    Args:
        label_map: Integer array of class indices (non-negative).
        value_map: Dict ``{class_index: output_value}``. Keys are assumed to
            be non-negative class indices. Indices not listed keep their
            original value.
        dtype: Output dtype; defaults to ``uint8`` so the result can be
            written as an 8-bit image.

    Returns:
        Array with the same shape as ``label_map`` and dtype ``dtype``.

    Raises:
        ValueError: If a mapped value does not fit into ``dtype``.
    """
    label_map = np.asarray(label_map)
    max_label = int(label_map.max()) if label_map.size else 0
    lut = np.arange(max([max_label, *value_map]) + 1, dtype=np.int64)
    for src, dst in value_map.items():
        lut[src] = dst
    recovered = lut[label_map]
    info = np.iinfo(dtype)
    if recovered.size and (
        recovered.min() < info.min or recovered.max() > info.max
    ):
        raise ValueError(
            f"Recovered label values do not fit into {np.dtype(dtype).name}."
        )
    return recovered.astype(dtype)


if __name__ == "__main__":
    tic = time()
    # Uncomment to expose only GPU 0 when running with CUDA.
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    image_path = "./img_2_0000.png"
    # Write the mask to a separate file; the old path was identical to
    # image_path and overwrote the input image.
    image_pred_path = "./img_2_0000_pred.png"
    seg_model_path = "/xxx/checkpoint_best_fold_0.onnx"
    # Network input size (h, w) = patch_size in plans.json.
    # None -> read it from the ONNX model's (static) input shape.
    patch_size = None

    if not os.path.exists(seg_model_path):
        raise FileNotFoundError(
            f"Model file {seg_model_path} does not exist."
        )

    # device = get_device()
    device = "CPU"
    providers = (
        ["CUDAExecutionProvider"]
        if device != "CPU"
        else ["CPUExecutionProvider"]
    )
    seg_model = InferenceSession(seg_model_path, providers=providers)

    image = imread(image_path)
    if image is None:  # cv2.imread signals failure by returning None
        raise OSError(f"Failed to read image {image_path}.")
    image = cvtColor(image, COLOR_BGR2RGB)
    orig_h, orig_w = image.shape[:2]

    # Use the names stored in the model instead of assuming 'input'/'output'.
    input_meta = seg_model.get_inputs()[0]
    output_name = seg_model.get_outputs()[0].name
    h, w = resolve_patch_size(input_meta.shape, patch_size)
    # cv2.resize takes (width, height).
    image_input = preprocess_image(image, (w, h), in_ch=3)

    # run() returns a list with one array per requested output.
    mask_pred_batch = seg_model.run(
        [output_name], {input_meta.name: image_input}
    )[0]
    # Assumes the output is (1, num_classes, h, w): drop the batch axis,
    # then take the per-pixel argmax over the class axis.
    mask_pred = np.argmax(np.squeeze(mask_pred_batch, axis=0), axis=0)

    # Map class indices to custom grey values so the mask is visible.
    predict_recover_value_dict = {
        0: 0,
        1: 128,
        2: 196,
        3: 255,
    }
    # uint8 output: cv2.imwrite cannot encode the int64 array argmax returns.
    mask_vis = recover_label_values(mask_pred, predict_recover_value_dict)
    # Back to the input resolution; nearest-neighbour keeps only values
    # that already exist in the mask.
    mask_vis = resize(
        mask_vis, (orig_w, orig_h), interpolation=INTER_NEAREST
    )
    if not imwrite(image_pred_path, mask_vis):
        raise OSError(f"Failed to write mask to {image_pred_path}.")
    print("time cost is: ", time() - tic)