将 PaddleOCR 中的模型转换成 ONNX 格式。
转换代码：
import os
import sys
import yaml
import numpy as np
import cv2
import argparse
import paddle
from paddle import nn
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import paddle.distributed as dist
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.modeling.architectures import build_model
class AttrDict(dict):
    """A dict whose keys can also be read as attributes (``d.key``).

    Only attribute *reads* are redirected, and only at the top level:
    nested dict values are returned as ordinary dicts (NOT recursive).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getattr__(self, key):
        # __getattr__ only runs when normal attribute lookup fails, so real
        # dict methods (keys, items, ...) still win over same-named keys.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(
                "object has no attribute '{}'".format(key)) from None


# Process-wide configuration populated by merge_config() / load_config().
global_config = AttrDict()
# Baseline merged into global_config before the YAML file is read.
default_config = {'Global': {'debug': False}}
class ArgsParser(ArgumentParser):
    """Command-line parser for the export script.

    Options:
        -c/--config: path to a YAML config file.
        -o/--opt: one or more ``key=value`` overrides; each value is parsed
            as YAML, and dotted keys (``Global.use_gpu``) are allowed.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        # Alternative config used previously:
        # ./configs/ch_PP-OCRv2_rec_idcard.yml
        self.add_argument("-c", "--config", default='./configs/ch_PP-OCRv2_rec.yml',
                          help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` (default: ``sys.argv[1:]``) and turn ``opt`` into a dict.

        Returns:
            argparse.Namespace with ``config`` (str) and ``opt`` (dict).
        """
        args = super(ArgsParser, self).parse_args(argv)
        # Report through argparse (usage + exit code 2) instead of an assert,
        # which would be stripped under ``python -O``.
        if args.config is None:
            self.error("Please specify --config=configure_file_path.")
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``["a.b=1", ...]`` into ``{"a.b": 1, ...}``.

        Args:
            opts (list[str] | None): Raw ``key=value`` strings from ``-o``.

        Returns:
            dict: Mapping of key to YAML-parsed value (empty when no options).
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            # Split on the FIRST '=' only so values may themselves contain '='.
            k, sep, v = s.strip().partition('=')
            if not sep:
                self.error(
                    "option '{}' must be of the form key=value".format(s))
            # NOTE: yaml.Loader can construct arbitrary Python objects; this is
            # acceptable only because values come from our own command line.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
def merge_config(config):
    """Merge ``config`` into the module-level ``global_config`` in place.

    Plain keys name top-level sections. If the value is a dict and the
    section already exists, the section is shallow-updated with it;
    otherwise the value replaces (or creates) the section. Dotted keys such
    as ``"Global.use_gpu"`` walk the existing nested dicts and assign the
    final component.

    Args:
        config (dict): Config to be merged.

    Returns:
        None. The result is visible through ``global_config``.
    """
    for name, setting in config.items():
        if "." in name:
            path = name.split('.')
            head, tail = path[0], path[1:]
            assert (
                head in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), head)
            node = global_config[head]
            # Descend through every intermediate component, then set the leaf.
            for part in tail[:-1]:
                node = node[part]
            node[tail[-1]] = setting
            continue
        if isinstance(setting, dict) and name in global_config:
            global_config[name].update(setting)
        else:
            global_config[name] = setting
def load_config(file_path):
    """
    Load config from a yml/yaml file into the global config.

    ``default_config`` is merged first, then the file's contents.

    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    """
    import copy

    # Merge a copy so later ``.update`` calls on global_config['Global'] do
    # not mutate the module-level defaults through a shared dict object.
    merge_config(copy.deepcopy(default_config))
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is closed deterministically.
    with open(file_path, 'rb') as f:
        # NOTE: yaml.Loader can construct arbitrary Python objects; only load
        # trusted config files.
        file_config = yaml.load(f, Loader=yaml.Loader)
    # An empty YAML file parses to None; treat it as "nothing to merge".
    if file_config is not None:
        merge_config(file_config)
    return global_config
def check_device(use_gpu, use_xpu=False):
    """
    Log an error and exit when a device is requested that this paddle
    build was not compiled for.

    Args:
        use_gpu (bool): Whether the config asks for a CUDA GPU.
        use_xpu (bool): Whether the config asks for an XPU.

    Exits the process with status 1 via ``sys.exit`` when use_gpu (use_xpu)
    is set but paddle is not compiled with CUDA (XPU). Other failures while
    probing are reported and ignored (best effort).
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"
    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be ture.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
    except Exception as e:
        # sys.exit raises SystemExit, which is not an Exception subclass, so
        # the exits above still propagate. Anything else (presumably a probe
        # API missing in some paddle versions — TODO confirm) is reported
        # rather than silently swallowed, and the check is skipped.
        print("check_device skipped: {}".format(e))
def getArgs(is_train=False):
    """Parse CLI arguments, load and override the config, and pick the device.

    Args:
        is_train (bool): Unused; kept for interface compatibility.

    Returns:
        tuple: ``(config, device)``. ``config`` is the merged global config,
        with ``config['Global']['distributed']`` set. ``device`` is the
        value returned by ``paddle.set_device``.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    # A config without Global.use_gpu falls back to CPU instead of KeyError.
    use_gpu = config['Global'].get('use_gpu', False)
    use_xpu = False
    alg = config['Architecture']['algorithm']
    supported_algorithms = [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner'
    ]
    assert alg in supported_algorithms, \
        "Unsupported algorithm: {}".format(alg)
    # Check use_gpu=True against a CPU-only paddle build (exits with a clear
    # message) before querying the GPU environment.
    check_device(use_gpu, use_xpu)
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)
    config['Global']['distributed'] = dist.get_world_size() != 1
    return config, device
class CRNN(nn.Layer):
    """Export wrapper around a PaddleOCR text-recognition model.

    ``__init__`` does the following:
    - builds the post-processor, which supplies the character list;
    - sizes the recognition head(s) from that list;
    - builds the model and loads its weights with ``init_model``;
    - switches the model to eval mode.

    ``forward`` returns the argmax index and max value over axis 2 of the
    model output, so the exported graph already performs the greedy
    selection step of decoding.
    """

    def __init__(self, config, device):
        """
        Args:
            config (dict): Full PaddleOCR config. ``self.config`` is the same
                object, so the head sizing below modifies the caller's
                ``config['Architecture']`` in place.
            device: Unused here; kept for interface compatibility.
        """
        super(CRNN, self).__init__()
        # Normalization constants for (x - 0.5) / 0.5. forward() does NOT
        # apply them, so inputs must already be normalized by the caller.
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        self.mean = paddle.to_tensor(mean).reshape([1, 3, 1, 1])
        self.std = paddle.to_tensor(std).reshape([1, 3, 1, 1])
        self.config = config
        # build post process; its character list determines the head size
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     config['Global'])
        if hasattr(self.post_process_class, 'character'):
            char_num = len(getattr(self.post_process_class, 'character'))
            self._set_head_out_channels(char_num)
        # build the network, load weights, switch to inference mode
        self.model = build_model(config['Architecture'])
        init_model(self.config, self.model)
        self.model.eval()

    def _set_head_out_channels(self, char_num):
        """Write head output sizes derived from ``char_num`` into self.config."""
        arch = self.config['Architecture']
        if arch["algorithm"] in ["Distillation"]:  # distillation model
            sar_adjusted = False
            for key in arch["Models"]:
                head = arch['Models'][key]['Head']
                if head['name'] == 'MultiHead':  # for multi head
                    if (self.config['PostProcess']['name']
                            == 'DistillationSARLabelDecode'
                            and not sar_adjusted):
                        # Subtract the 2 SAR-only entries ONCE. Previously
                        # this ran for every MultiHead model, so a second
                        # MultiHead model got char_num - 4.
                        char_num = char_num - 2
                        sar_adjusted = True
                    head['out_channels_list'] = {
                        'CTCLabelDecode': char_num,
                        'SARLabelDecode': char_num + 2,
                    }
                else:
                    head['out_channels'] = char_num
        elif arch['Head']['name'] == 'MultiHead':  # for multi head
            if self.config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            arch['Head']['out_channels_list'] = {
                'CTCLabelDecode': char_num,
                'SARLabelDecode': char_num + 2,
            }
        else:  # base rec model
            arch['Head']['out_channels'] = char_num

    def forward(self, x):
        """Run the model and reduce its output over axis 2.

        Args:
            x: Input batch, already resized and normalized. No transpose or
                normalization happens in the graph. Assumes NCHW — TODO confirm
                for non-CRNN backbones.

        Returns:
            tuple: ``(preds_idx, preds_prob)``, the argmax index and the max
            value over axis 2. Both are cast to float32 (the index is
            converted back with ``int()`` when decoding). Assumes the model
            output is (batch, time, classes) — TODO confirm.
        """
        model_out = self.model(x)
        preds_idx = model_out.argmax(axis=2, name='class').astype('float32')
        preds_prob = model_out.max(axis=2, name='score').astype('float32')
        return preds_idx, preds_prob
EXPORT_ONNX = True
DYNAMIC = False
if __name__ == '__main__':
config, device = getArgs()
model_crnn = CRNN(config, device=device)
# 构建输入数据images:
image_path = "1.jpg"
img = cv2.imread(image_path)
img = cv2.resize(img, (320, 32))
print('input data:', img.shape)
img = img.astype(np.float32)
img = img.transpose((2, 0, 1)) / 255
input_data = img[np.newaxis, :]
print('input data:', input_data.shape)
x = paddle.to_tensor(input_data)
print('input data:', x.shape)
output_idx, output_prob = model_crnn(x)
print('output_idx: ', output_idx)
print('output_prob: ', output_prob)
input_spec = paddle.static.InputSpec.from_tensor(x, name='input')
onnx_save_path = "./export_onnx"
if EXPORT_ONNX:
onnx_model_name = onnx_save_path + "/char_recognize_20230526_v1"
if DYNAMIC:
input_spec = paddle.static.InputSpec(
shape=[None, 32, 320, 3], dtype='float32', name='input')
# ONNX模型导出
paddle.onnx.export(model_crnn, onnx_model_name, input_spec=[input_spec], opset_version=11,
enable_onnx_checker=True, output_spec=[output_idx, output_prob])
使用 Netron 绘制转换后的网络结构。
绘制出的网络起始和末尾部分结构如下：
测试 ONNX 模型的代码：
'''
测试转出的onnx模型
'''
import cv2
import numpy as np
import torch
import onnxruntime as rt
import math
import os
class TestOnnx:
    """Run an exported recognition ONNX model and greedily decode its output.

    The session's first output is used as per-step class indices and the
    second as per-step scores (see ``test``). Indices are mapped through a
    character table whose index 0 is ``"blank"``.
    """

    def __init__(self, onnx_file, character_dict_path, use_space_char=True,
                 providers=None):
        """
        Args:
            onnx_file (str): Path to the .onnx model.
            character_dict_path (str): UTF-8 file with one character per line.
            use_space_char (bool): Append " " as the last character.
            providers (list[str] | None): onnxruntime execution providers.
                Defaults to all providers available in this build. ORT >= 1.9
                requires this to be explicit when several providers exist.
        """
        if providers is None:
            providers = rt.get_available_providers()
        self.sess = rt.InferenceSession(onnx_file, providers=providers)
        # input / output node names of the graph
        self.input_names = [node.name for node in self.sess.get_inputs()]
        self.output_names = [node.name for node in self.sess.get_outputs()]
        # Index 0 is the blank; dictionary entries follow in file order.
        self.character = ["blank"]
        with open(character_dict_path, "rb") as fin:
            for line in fin:
                self.character.append(
                    line.decode('utf-8').strip("\n").strip("\r\n"))
        if use_space_char:
            self.character.append(" ")

    def resize_norm_img(self, img, image_shape=(3, 32, 320)):
        """Resize to height imgH keeping aspect ratio, normalize, right-pad.

        Args:
            img: HWC image array (as returned by ``cv2.imread``).
            image_shape: ``(C, H, W)`` of the output. A tuple avoids the shared
                mutable-default pitfall; any sequence is accepted.

        Returns:
            np.ndarray: float32 ``(C, H, W)``. Pixels are scaled to [-1, 1]
            via ``(x / 255 - 0.5) / 0.5``, and columns past the resized width
            are zero.
        """
        imgC, imgH, imgW = image_shape
        h, w = img.shape[0], img.shape[1]
        ratio = w / float(h)
        # Clamp the aspect-preserving width to the target width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def process(self, input_names, image):
        """Build the onnxruntime feed dict, feeding ``image`` to every input."""
        return {input_name: image for input_name in input_names}

    def get_ignored_tokens(self):
        """Return indices dropped during decoding (the blank at index 0)."""
        return [0]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert per-step indices into ``[(text, mean_confidence), ...]``.

        Args:
            text_index: 2-D array-like, one row of step indices per sample.
            text_prob: Matching per-step scores, or None.
            is_remove_duplicate: Collapse consecutive repeats first (CTC
                greedy decoding).

        Returns:
            list[tuple[str, float]]: One entry per row. Confidence is the mean
            score of the kept steps, or 0 when no step is kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for row_idx in range(len(text_index)):
            row = text_index[row_idx]
            selection = np.ones(len(row), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = row[1:] != row[:-1]
            for ignored_token in ignored_tokens:
                selection &= row != ignored_token
            char_list = [
                self.character[int(text_id)].replace('\n', '')
                for text_id in row[selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[row_idx][selection]
            else:
                # Count only kept steps so empty text gets confidence 0, the
                # same as when probabilities are supplied.
                conf_list = [1] * int(selection.sum())
            if len(conf_list) == 0:
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def test(self, image_path):
        """Recognize one image, print ``<path> <text>``, return ``(text, score)``.

        Returns None (after printing a note) if the file is not a readable
        image, so a directory scan does not abort on stray files.
        """
        img_onnx = cv2.imread(image_path)
        if img_onnx is None:
            print(image_path, "skipped: not a readable image")
            return None
        onnx_indata = np.asarray(
            self.resize_norm_img(img_onnx)[np.newaxis, :, :, :],
            dtype=np.float32)
        print('image shape: ', onnx_indata.shape)
        feed_dict = self.process(self.input_names, onnx_indata)
        output_onnx = self.sess.run(self.output_names, feed_dict)
        preds_idx = output_onnx[0]
        preds_prob = output_onnx[1]
        post_result = self.decode(preds_idx, preds_prob,
                                  is_remove_duplicate=True)
        text, score = post_result[0]
        print(image_path, text)
        return text, score
if __name__ == '__main__':
    image_dir = "./sample/img"
    onnx_file = './export_onnx/char_recognize_20230526_v1.onnx'
    character_dict_path = './all_label_num_20230517.txt'
    testobj = TestOnnx(onnx_file, character_dict_path)
    # Sorted for a reproducible order (os.listdir order is arbitrary);
    # subdirectories are skipped because they are not images.
    for name in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, name)
        if not os.path.isfile(image_path):
            continue
        testobj.test(image_path)
模型转换结束。