对paddleOCR中的字符识别模型转ONNX

news2024/11/15 21:09:01

对paddle OCR中的模型转换成ONNX。

转换代码:



import os
import sys
import yaml
import numpy as np
import cv2
import argparse
import paddle
from paddle import nn

from argparse import ArgumentParser, RawDescriptionHelpFormatter
import paddle.distributed as dist
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.modeling.architectures import build_model


class AttrDict(dict):
    """Single level attribute dict, NOT recursive"""

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__()
        super(AttrDict, self).update(kwargs)

    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError("object has no attribute '{}'".format(key))

global_config = AttrDict()
default_config = {'Global': {'debug': False, }}

class ArgsParser(ArgumentParser):
    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        # self.add_argument("-c", "--config", default='./configs/ch_PP-OCRv2_rec_idcard.yml',
        #                   help="configuration file to use")

        self.add_argument("-c", "--config", default='./configs/ch_PP-OCRv2_rec.yml',
                          help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, v = s.split('=')
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config

def merge_config(config):
    """
    Merge config into global config.
    Args:
        config (dict): Config to be merged.
    Returns: global config
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                if idx == len(sub_keys) - 2:
                    cur[sub_key] = value
                else:
                    cur = cur[sub_key]

def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    """
    merge_config(default_config)
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
    return global_config

def check_device(use_gpu, use_xpu=False):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be ture.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
    except Exception as e:
        pass

def getArgs(is_train=False):
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']

    use_xpu = False

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner'
    ]

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    return config, device


class CRNN(nn.Layer):
    def __init__(self, config, device):
        super(CRNN, self).__init__()
        # 定义预处理参数
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        self.mean = paddle.to_tensor(mean).reshape([1, 3, 1, 1])
        self.std = paddle.to_tensor(std).reshape([1, 3, 1, 1])

        self.config = config
        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     config['Global'])
        # build model
        if hasattr(self.post_process_class, 'character'):
            char_num = len(getattr(self.post_process_class, 'character'))
            if self.config['Architecture']["algorithm"] in ["Distillation",
                                                            ]:  # distillation model
                for key in self.config['Architecture']["Models"]:
                    if self.config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                        out_channels_list = {}
                        if self.config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                            char_num = char_num - 2
                        out_channels_list['CTCLabelDecode'] = char_num
                        out_channels_list['SARLabelDecode'] = char_num + 2
                        self.config['Architecture']['Models'][key]['Head'][
                            'out_channels_list'] = out_channels_list
                    else:
                        self.config['Architecture']["Models"][key]["Head"][
                            'out_channels'] = char_num
            elif self.config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
                out_channels_list = {}
                if self.config['PostProcess']['name'] == 'SARLabelDecode':
                    char_num = char_num - 2
                out_channels_list['CTCLabelDecode'] = char_num
                out_channels_list['SARLabelDecode'] = char_num + 2
                self.config['Architecture']['Head'][
                    'out_channels_list'] = out_channels_list
            else:  # base rec model
                self.config['Architecture']["Head"]['out_channels'] = char_num

        # 加载模型
        self.model = build_model(config['Architecture'])
        # load_model(config, self.model)
        init_model(self.config, self.model)
        self.model.eval()

    def forward(self, x):
        # x = paddle.transpose(x, [0,3,1,2])
        # x = x / 255.0
        # x = (x - self.mean) / self.std

        model_out = self.model(x)

        # return model_out
        preds_idx = model_out.argmax(axis=2, name='class').astype('float32')
        # preds_idx = model_out.argmax(axis=2, name='class')
        preds_prob = model_out.max(axis=2, name='score').astype('float32')
        return preds_idx, preds_prob

EXPORT_ONNX = True
DYNAMIC = False

if __name__ == '__main__':
    config, device = getArgs()
    model_crnn = CRNN(config, device=device)

    # 构建输入数据images:
    image_path = "1.jpg"
    img = cv2.imread(image_path)
    img = cv2.resize(img, (320, 32))
    print('input data:', img.shape)
    img = img.astype(np.float32)
    img = img.transpose((2, 0, 1)) / 255
    input_data = img[np.newaxis, :]
    print('input data:', input_data.shape)
    x = paddle.to_tensor(input_data)
    print('input data:', x.shape)

    output_idx, output_prob = model_crnn(x)
    print('output_idx: ', output_idx)
    print('output_prob: ', output_prob)

    input_spec = paddle.static.InputSpec.from_tensor(x,  name='input')
    onnx_save_path = "./export_onnx"
    if EXPORT_ONNX:
        onnx_model_name = onnx_save_path + "/char_recognize_20230526_v1"
        if DYNAMIC:
            input_spec = paddle.static.InputSpec(
                shape=[None, 32, 320, 3], dtype='float32',  name='input')

        # ONNX模型导出
        paddle.onnx.export(model_crnn, onnx_model_name, input_spec=[input_spec], opset_version=11,
                           enable_onnx_checker=True, output_spec=[output_idx, output_prob])

转换后的网络结构绘制出来,绘制使用的工具Netron

 绘制出的起始和末尾的网络结构:

测试ONNX的代码:

'''
测试转出的onnx模型
'''
import cv2
import numpy as np

import torch
import onnxruntime as rt
import math
import os

class TestOnnx:
    def __init__(self, onnx_file, character_dict_path, use_space_char=True):
        self.sess = rt.InferenceSession(onnx_file)
        # 获取输入节点名称
        self.input_names = [input.name for input in self.sess.get_inputs()]
        # 获取输出节点名称
        self.output_names = [output.name for output in self.sess.get_outputs()]

        self.character = []
        self.character.append("blank")
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode('utf-8').strip("\n").strip("\r\n")
                self.character.append(line)
        if use_space_char:
            self.character.append(" ")

    def resize_norm_img(self, img, image_shape=[3, 32, 320]):
        imgC, imgH, imgW = image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    # # 准备模型运行的feed_dict
    def process(self, input_names, image):
        feed_dict = dict()
        for input_name in input_names:
            feed_dict[input_name] = image

        return feed_dict

    def get_ignored_tokens(self):
        return [0]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[
                                                                 batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            char_list = [
                self.character[int(text_id)].replace('\n', '')
                for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]

            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))

        return result_list

    def test(self, image_path):
        img_onnx = cv2.imread(image_path)
        # img_onnx = cv2.resize(img_onnx, (320, 32))
        # img_onnx = img_onnx.transpose((2, 0, 1)) / 255
        img_onnx = self.resize_norm_img(img_onnx)
        onnx_indata = img_onnx[np.newaxis, :, :, :]
        onnx_indata = torch.from_numpy(onnx_indata)
        # print('diff:', onnx_indata - input_data)
        print('image shape: ', onnx_indata.shape)
        onnx_indata = np.array(onnx_indata, dtype=np.float32)
        feed_dict = self.process(self.input_names, onnx_indata)

        output_onnx = self.sess.run(self.output_names, feed_dict)
        # print('output1 shape: ', output_onnx[0].shape)
        # print('output1: ', output_onnx[0])
        # print('output2 shape: ', output_onnx[1].shape)
        # print('output2: ', output_onnx[1])

        preds_idx = output_onnx[0]
        preds_prob = output_onnx[1]
        post_result = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)

        if isinstance(post_result, dict):
            rec_info = dict()
            for key in post_result:
                if len(post_result[key][0]) >= 2:
                    rec_info[key] = {
                        "label": post_result[key][0][0],
                        "score": float(post_result[key][0][1]),
                    }
            print(image_path, rec_info)
        else:
            if len(post_result[0]) >= 2:
                # info = post_result[0][0] + "\t" + str(post_result[0][1])
                info = post_result[0][0]
            print(image_path, info)




if __name__=='__main__':
    image_dir = "./sample/img"
    onnx_file = './export_onnx/char_recognize_20230526_v1.onnx'
    character_dict_path = './all_label_num_20230517.txt'

    testobj = TestOnnx(onnx_file, character_dict_path)

    files = os.listdir(image_dir)
    for file in files:
        image_path = os.path.join(image_dir, file)
        result = testobj.test(image_path)




模型转换结束。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1083784.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

没有执行力,一切都是空谈!如何提高执行力

决定人生高度的并非空谈,而是实干,没有执行力一切都是零。 执行力对于达成目标至关重要。即使将目标细分拆解得再细致,若无法切实执行,一切仍然是徒劳。 一旦制定目标,必须进行层层细分的拆解,包括每日的…

信钰证券:汇金增持提振市场情绪 保险、银行等板块集体拉升

12日,两市股指盘中全线走高,沪指一度克复3100点,上证50指数涨超1%。 稳妥、银行、券商板块团体拉升,到发稿,银行板块方面,瑞丰银行涨约6%,盘中一度涨停;紫金银行、渝农银行、西安银…

Python获取本机IP地址的几种方式~转

Python获取本机IP地址的几种方式 目录 1、使用专用网址 2、使用自带socket库 3、使用第三方netifaces库 1、使用专用网站 获取的是公网IP。 网址: http://myip.ipip.net 代码: import requests res requests.get(https://myip.ipip.net, timeout5)…

【C++】C++11 —— 右值引用和移动语义

​ ​📝个人主页:Sherry的成长之路 🏠学习社区:Sherry的成长之路(个人社区) 📖专栏链接:C学习 🎯长路漫漫浩浩,万事皆有期待 上一篇博客:【C】C11…

Java网络编程1

Java网络编程1 网络相关概念 把java网络编程的基础知识学习完之后,我们才会更加了解那些高性能的网络框架像neety它为什么要这样设计?才能把知识掌握的更加清晰。 网络通信 1)概念:两台设备之间,通过网络&#xff0c…

2023年中国车用磁传感器市场发展趋势分析:未来市场规模将保持较高速增长趋势[图]

磁传感器是把磁场、电流、应力应变、温度、光等外界因素引起敏感元件磁性能变化转换成电信号,以这种方式来检测相应物理量的器件。磁传感器广泛用于现代工业和电子产品中以感应磁场强度来测量电流、位置、方向等物理参数。在现有技术中,有许多不同类型的…

最新科技喜报!统一图像和文字生成的MiniGPT-5来了!

原创 | 文 BFT机器人 当前视觉和语言模型的应用非常广泛,包括多模态对话代理、先进的内容创作工具等。这些模型的多模态特征集成不仅是一种发展趋势,更是一项关键的进步,正在塑造着各种应用程序。 那如何在视觉和语言之间建立有效的联系&…

Matlab地理信息绘图—数据诊断

文章目录 数据诊断分析(均值方差)Matlab代码实现结果展示 数据诊断分析(均值方差) 均值方差检测是一种简单但有效的异常检测方法,主要基于样本的均值和方差的统计信息。该方法的核心思想是假设正常的样本点应该聚集在…

用Cmake快速生成vs工程

文章目录 1 安装cmake2 生成vs工程 1 安装cmake 官方网址: https://cmake.org/download/ 打开官网,根据自己需求下载所需文件。(本人是安装在Windows10-x64平台上,所以下文步骤均基于此平台) 下载好后,双…

大数据之Hudi数据湖_基本概念_时间轴_TimeLine---大数据之Hudi数据湖工作笔记0005

然后看一下hudi的,时间轴概念,很简单了,就是之前说的时间旅行,其实就是 比如在某个时间点,记录,这个时间点做了什么,就是这个意思 然后像回去看看的时候,可以找到这个时间点做了什么 一个时间点就是一个Instant (时刻 瞬间的意思) 可以看到时刻的解释 instant 时刻instant包…

【信创】 JED on 鲲鹏(ARM) 调优步骤与成果 | 京东云技术团队

项目背景 基于国家对信创项目的大力推进,为了自主可控的技术发展,基础组件将逐步由国产组件替代,因此从数据库入手,将弹性库JED部署在 国产华为鲲鹏机器上(基于ARM架构)进行调优,与Intel (X86)进行性能对比。 物理机…

基于全息感知的智慧高速IT设施监控运维方案

作为智能交通的重要细分领域,建设智慧高速是实施交通强国战略的重要基础。在信息化时代,交通行业已经依托信息化建设取得了显著的成果,其中以收费网络、办公网络、监控网络和通讯网络为基础的网络架构已经形成,并且正在逐步完善网…

Nginx proxy_set_header参数设置

一、不设置 proxy_set_header Host 不设置 proxy_set_header Host 时,浏览器直接访问 nginx,获取到的 Host 是 proxy_pass 后面的值,即 $proxy_host 的值,参考Module ngx_http_proxy_module 1 2 3 4 5 6 7 8 # cat ngx_header.c…

NIO基础-ByteBuffer,Channel

文章目录 1. 三大组件1.1 Channel1.2 Buffer1.2 Selector 2.ByteBuffer2.1 ByteBuffer 正确使用姿势2.2 ByteBuffer 结构2.3 ByteBuffer 常见方法分配空间向 buffer 写入数据从 buffer 读取数据mark 和 reset字符串与 ByteBuffer 互转分散度集中写byteBuffer黏包半包 3. 文件编…

简历石层大海,为何今年秋招那么难?技术面考官想听啥?

上个月发完关于《2023年的IC求职究竟有多难?》文章,后台就出现很多私信,大家都在频繁的问秋招的事情,今年的秋招提前批让很多人直接破防,感觉书读了那么久,学校也还不错,但是为什么企业招聘的简…

单车模型:横向动力学

文章目录 1 模型推导2 参考资料 较高车速下,不能再假设车轮朝向和车轮速度一致。因此运动学模型在这里的误差就会比较大,必须要考虑动力学模型。 现考虑2自由度单车模型,如下图所示。2自由度表示为: 车辆横线位置 y y y&#xff…

2023-2024-1 高级语言程序设计实验一: 选择结构

7-1 古时年龄称谓知多少? 输入一个人的年龄(岁),判断出他属于哪个年龄段 ? 0-9 :垂髫之年; 10-19: 志学之年; 20-29 :弱冠之年; 30-39 &#…

Docker开启远程访问+idea配置docker+dockerfile发布java项目

一、docker开启远程访问 1.编辑docker服务文件 vim /usr/lib/systemd/system/docker.servicedocker.service原文件如下: [Unit] DescriptionDocker Application Container Engine Documentationhttps://docs.docker.com Afternetwork-online.target docker.socke…

【深蓝学院】手写VIO第7章--VINS初始化和VIO系统--笔记

0. 内容 1. VIO回顾 整个视觉前端pipeline回顾: 两帧图像,可提取特征点,特征匹配(描述子暴力匹配或者光流)已知特征点匹配关系,利用几何约束计算relative pose([R|t]),translation只有方向&…

2023年中国睡眠检测仪产量、销量及市场规模分析[图]

睡眠检测仪行业是指生产和销售用于监测和评估人类睡眠质量和睡眠相关指标的设备和工具的行业。睡眠检测仪可以通过监测人体的脑电图、心率、呼吸、体动等生理信号,来评估睡眠的深度、时长、睡眠阶段的分布等信息,帮助人们了解自己的睡眠状况,…