PaddleOCR文字识别模型的FineTune

news2024/12/29 9:52:03

一、paddleOCR

paddle框架为百度开发的深度学习框架,其中对于文字检测、识别具有较为便利的开发条件。同时PaddleOCR文字识别工具较为轻量化,并可按照任务需求进行model的finetune,满足实际的业务需求。
源码来源:githubOCR
在github中对该框架进行较为完整的描述,对于pretrainmodel、parameter进行明确的标注与说明,同时PP飞桨提供每日免费算力进行大模型的finetune,可以用于模型框架的初步认识与训练,满足实际的业务需求。

二、源码解读

1、源码下载

在这里插入图片描述

2、移步开发流程

在该文档中对于模型的开始,安装、预训练、效果展示等都进行明确的描述。
在这里插入图片描述
在这里插入图片描述

3、以CCPD数据集为例进行PaddleOCRmodel的预训练,提高文字识别的精度

(1)首先对下载好的文件夹进行解压,才VScode中进行打开。//任意编辑器
先找到其中的configs 配置文件,其中det表示文字检测model yaml配置文件,rec表示文字识别model yaml配置文件。
在这里插入图片描述
本博客以该配置文件为例,对配置文件中的描述进行解析。
在这里插入图片描述

模型配置文件
Global:
  debug: false      # 是否debug
  use_gpu: true     # 训练过程是否使用GPU
  epoch_num: 200      # 训练的轮次
  log_smooth_window: 20    # 训练日志平滑窗口
  print_batch_step: 10   # 训练时每多少步打印一次日志
  save_model_dir: ./output/rec_ppocr_v4    # 保存模型的地址
  save_epoch_step: 10   # 保存模型的间隔
  eval_batch_step: [0, 2000]   # 验证数据的间隔
  cal_metric_during_train: true   # 是否计算验证集的指标
  pretrained_model:                   # 预训练模型的地址  // 可以在这里pretrain的地址,也可以使用命令行输入
  checkpoints:                  # 
  save_inference_dir:            # 保存预测结果的地址
  use_visualdl: false             # 是否使用visualdl
  infer_img: docs\images\00006737.jpg    # 验证图片的地址        
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3.txt
  d2s_train_image_shape: [3, 48, 320]


Optimizer:    # 优化器
  name: Adam  
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05


Architecture:      # 模型的配置
  model_type: rec   # 模型的类型 rec model
  algorithm: SVTR_LCNet     
  Transform:
  Backbone:    # 主干网络
    name: PPLCNetV3    
    scale: 0.95
  Head:             # 输出头网络 
    name: MultiHead   
    head_list:
      - CTCHead:
          Neck:
            name: svtr
            dims: 120    # 维度
            depth: 2     # 深度
            hidden_dims: 120
            kernel_size: [1, 3]
            use_guide: True
          Head:
            fc_decay: 0.00001
      - NRTRHead:
          nrtr_dim: 384
          max_text_length: *max_text_length

Loss:        # 损失函数
  name: MultiLoss
  loss_config_list:
    - CTCLoss:
    - NRTRLoss:

PostProcess:  
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:               # 训练数据集的配置
  dataset:                
    name: MultiScaleDataSet  # 多尺度训练
    ds_width: false    
    data_dir: ./train_data/    # 训练数据集地址
    ext_op_transform_idx: 1
    label_file_list:     # 训练数据集索引标签
    - ./train_data/train_list.txt
    transforms:             # 数据增强的处理方式 
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
        max_text_length: *max_text_length
    - RecAug:
    - MultiLabelEncode:
        gtc_encode: NRTRLabelEncode
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_gtc
        - length
        - valid_ratio
  sampler:                # 采样器
    name: MultiScaleSampler
    scales: [[320, 32], [320, 48], [320, 64]]
    first_bs: &bs 192
    fix_bs: false
    divided_factor: [8, 16] # w, h
    is_training: True
  loader:
    shuffle: true
    batch_size_per_card: *bs
    drop_last: true
    num_workers: 8    # 并行线路的最大线程
Eval: 
  dataset: 
    name: SimpleDataSet
    data_dir: ./train_data   # 验证数据集地址
    label_file_list:
    - ./train_data/val_list.txt    # 验证数据标签   
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
        gtc_encode: NRTRLabelEncode
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_gtc
        - length
        - valid_ratio
  loader:         # 验证集中数据加载器 dataloder 
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4

在这里插入图片描述
在快速开始文档中可以查找具体的启动步骤与方式 // 以现有的det为例
在这里插入图片描述

---
comments: true
---

# PP-OCRv3 文本检测模型训练

## 1. 简介

PP-OCRv3在PP-OCRv2的基础上进一步升级。本节介绍PP-OCRv3检测模型的训练步骤。有关PP-OCRv3策略介绍参考[文档](../blog/PP-OCRv3_introduction.md)## 2. 检测训练

PP-OCRv3检测模型是对PP-OCRv2中的[CML](https://arxiv.org/pdf/2109.03144.pdf)(Collaborative Mutual Learning) 协同互学习文本检测蒸馏策略进行了升级。PP-OCRv3分别针对检测教师模型和学生模型进行进一步效果优化。其中,在对教师模型优化时,提出了大感受野的PAN结构LK-PAN和引入了DML(Deep Mutual Learning)蒸馏策略;在对学生模型优化时,提出了残差注意力机制的FPN结构RSE-FPN。

PP-OCRv3检测训练包括两个步骤:

- 步骤1:采用DML蒸馏方法训练检测教师模型
- 步骤2:使用步骤1得到的教师模型采用CML方法训练出轻量学生模型

### 2.1 准备数据和运行环境

训练数据采用icdar2015数据,准备训练集步骤参考[ocr_dataset](./dataset/ocr_datasets.md).

运行环境准备参考[文档](./installation.md)### 2.2 训练教师模型

教师模型训练的配置文件是[ch_PP-OCRv3_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.5/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml)。教师模型模型结构的Backbone、Neck、Head分别为Resnet50, LKPAN, DBHead,采用DML的蒸馏方法训练。有关配置文件的详细介绍参考[文档](./knowledge_distillation.md)。

下载ImageNet预训练模型:

```bash linenums="1"
# 下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams

启动训练

# 单卡训练
python3 tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml \
    -o Architecture.Models.Student.pretrained=./pretrain_models/ResNet50_vd_ssld_pretrained \
       Architecture.Models.Student2.pretrained=./pretrain_models/ResNet50_vd_ssld_pretrained \
       Global.save_model_dir=./output/
# 如果要使用多GPU分布式训练,请使用如下命令:
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml \
    -o Architecture.Models.Student.pretrained=./pretrain_models/ResNet50_vd_ssld_pretrained \
       Architecture.Models.Student2.pretrained=./pretrain_models/ResNet50_vd_ssld_pretrained \
       Global.save_model_dir=./output/

训练过程中保存的模型在output目录下,包含以下文件:

best_accuracy.states
best_accuracy.pdparams  # 默认保存最优精度的模型参数
best_accuracy.pdopt     # 默认保存最优精度的优化器相关参数
latest.states
latest.pdparams  # 默认保存的最新模型参数
latest.pdopt     # 默认保存的最新模型的优化器相关参数

其中,best_accuracy是保存的精度最高的模型参数,可以直接使用该模型评估。

模型评估命令如下:

python3 tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml -o Global.checkpoints=./output/best_accuracy

训练的教师模型结构更大,精度更高,用于提升学生模型的精度。

提取教师模型参数
best_accuracy包含两个模型的参数,分别对应配置文件中的Student,Student2。提取Student的参数方法如下:

import paddle
# 加载预训练模型
all_params = paddle.load("output/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 模型的权重提取
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# 查看模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "./pretrain_models/dml_teacher.pdparams")

提取出来的模型参数可以用于模型进一步的finetune训练或者蒸馏训练。

2.3 训练学生模型

训练学生模型的配置文件是ch_PP-OCRv3_det_cml.yml
上一节训练得到的教师模型作为监督,采用CML方式训练得到轻量的学生模型。

下载学生模型的ImageNet预训练模型:

# 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams

启动训练

# 单卡训练
python3 tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \
    -o Architecture.Models.Student.pretrained=./pretrain_models/MobileNetV3_large_x0_5_pretrained \
       Architecture.Models.Student2.pretrained=./pretrain_models/MobileNetV3_large_x0_5_pretrained \
       Architecture.Models.Teacher.pretrained=./pretrain_models/dml_teacher \
       Global.save_model_dir=./output/
# 如果要使用多GPU分布式训练,请使用如下命令:
python3  -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \
    -o Architecture.Models.Student.pretrained=./pretrain_models/MobileNetV3_large_x0_5_pretrained \
       Architecture.Models.Student2.pretrained=./pretrain_models/MobileNetV3_large_x0_5_pretrained \
       Architecture.Models.Teacher.pretrained=./pretrain_models/dml_teacher \
       Global.save_model_dir=./output/

训练过程中保存的模型在output目录下,
模型评估命令如下:

python3 tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.checkpoints=./output/best_accuracy

best_accuracy包含三个模型的参数,分别对应配置文件中的Student,Student2,Teacher。提取Student参数的方法如下:

import paddle
# 加载预训练模型
all_params = paddle.load("output/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 模型的权重提取
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# 查看模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "./pretrain_models/cml_student.pdparams")

提取出来的Student的参数可用于模型部署或者做进一步的finetune训练。

3. 基于PP-OCRv3检测finetune训练

本节介绍如何使用PP-OCRv3检测模型在其他场景上的finetune训练。

finetune训练适用于三种场景:

  • 基于CML蒸馏方法的finetune训练,适用于教师模型在使用场景上精度高于PP-OCRv3检测模型,且希望得到一个轻量检测模型。
  • 基于PP-OCRv3轻量检测模型的finetune训练,无需训练教师模型,希望在PP-OCRv3检测模型基础上提升使用场景上的精度。
  • 基于DML蒸馏方法的finetune训练,适用于采用DML方法进一步提升精度的场景。

基于CML蒸馏方法的finetune训练

下载PP-OCRv3训练模型:

wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf ch_PP-OCRv3_det_distill_train.tar

ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams包含CML配置文件中Student、Student2、Teacher模型的参数。

启动训练:

# 单卡训练
python3 tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \
    -o Global.pretrained_model=./ch_PP-OCRv3_det_distill_train/best_accuracy \
       Global.save_model_dir=./output/
# 如果要使用多GPU分布式训练,请使用如下命令:
python3  -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \
    -o Global.pretrained_model=./ch_PP-OCRv3_det_distill_train/best_accuracy  \
       Global.save_model_dir=./output/

基于PP-OCRv3轻量检测模型的finetune训练

下载PP-OCRv3训练模型,并提取Student结构的模型参数:

wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf ch_PP-OCRv3_det_distill_train.tar

提取Student参数的方法如下:

import paddle
# 加载预训练模型
all_params = paddle.load("output/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 模型的权重提取
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# 查看模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "./student.pdparams")

使用配置文件ch_PP-OCRv3_det_student.yml训练。

启动训练

# 单卡训练
python3 tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml \
    -o Global.pretrained_model=./student \
       Global.save_model_dir=./output/
# 如果要使用多GPU分布式训练,请使用如下命令:
python3  -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml \
    -o Global.pretrained_model=./student \
       Global.save_model_dir=./output/

基于DML蒸馏方法的finetune训练

以ch_PP-OCRv3_det_distill_train中的Teacher模型为例,首先提取Teacher结构的参数,方法如下:

import paddle
# 加载预训练模型
all_params = paddle.load("ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 模型的权重提取
s_params = {key[len("Teacher."):]: all_params[key] for key in all_params if "Teacher." in key}
# 查看模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "./teacher.pdparams")

启动训练

# 单卡训练
python3 tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml \
    -o Architecture.Models.Student.pretrained=./teacher \
       Architecture.Models.Student2.pretrained=./teacher \
       Global.save_model_dir=./output/
# 如果要使用多GPU分布式训练,请使用如下命令:
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml \
    -o Architecture.Models.Student.pretrained=./teacher \
       Architecture.Models.Student2.pretrained=./teacher \
       Global.save_model_dir=./output/

三、数据集准备方式

rec文字识别图像数据集准备方式:
在这里插入图片描述
txt数据集准备方式:图像路径 + ” “(空格) + 文字
在这里插入图片描述
dict创建:
创建所有的 出现文字的字典集 == 可以采用默认。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2267388.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据库初阶】Ubuntu 环境安装 MySQL

🎉博主首页: 有趣的中国人 🎉专栏首页: 数据库初阶 🎉其它专栏: C初阶 | C进阶 | 初阶数据结构 小伙伴们大家好,本片文章将会讲解 Ubuntu 系统安装 MySQL 的相关内容。 如果看到最后您觉得这篇…

MoH:将多头注意力(Multi-Head Attention)作为头注意力混合(Mixture-of-Head Attention)

摘要 https://arxiv.org/pdf/2410.11842? 在本文中,我们对Transformer模型的核心——多头注意力机制进行了升级,旨在提高效率的同时保持或超越先前的准确度水平。我们表明,多头注意力可以表示为求和形式。鉴于并非所有注意力头都具有同等重…

AI助力古诗视频制作全流程化教程

AI助力古诗视频制作全流程化教程 目录 1. 制作视频的原材料(全自动) 2.文生图:图像生成(手动) 3.文生音频:TTS技术(全自动) 4.视频编辑(手动) 5.自动发…

基于SSM的“快递管理系统”的设计与实现(源码+数据库+文档+PPT)

基于SSM的“快递管理系统”的设计与实现(源码数据库文档PPT) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 登陆页面 注册页面 快递员页面 派单员订单管理页面 派单员订单添…

AWTK 在全志 tina linux 上支持 2D 图形加速

全志 tina linux 2D 图形加速插件。 开发环境为 全志 Tina Linux 虚拟机。 1. 准备 下载 awtk git clone https://github.com/zlgopen/awtk.git下载 awtk-linux-fb git clone https://github.com/zlgopen/awtk-linux-fb.git下载 awtk-tina-g2d git clone https://github.co…

Unity游戏环境交互系统

概述 交互功能使用同一个按钮或按钮列表,在不同情况下显示不同的内容,按下执行不同的操作。 按选项个数分类 环境交互系统可分为两种,单选项交互,一般使用射线检测;多选项交互,一般使用范围检测。第一人…

线性直流电流

电阻网络的等效 等效是指被化简的电阻网络与等效电阻具有相同的 u-i 关系 (即端口方程),从而用等效电阻代替电阻网络之后,不 改变其余部分的电压和电流。 串联等效: 并联等效: 星角变换 若这两个三端网络是等效的,从任…

攻防世界web第二题unseping

这是题目 <?php highlight_file(__FILE__);class ease{private $method;private $args;function __construct($method, $args) {$this->method $method;$this->args $args;}function __destruct(){if (in_array($this->method, array("ping"))) {cal…

[文献阅读]ReAct: Synergizing Reasoning and Acting in Language Models

文章目录 摘要Abstract:思考与行为协同化Reason(Chain of thought)ReAct ReAct如何协同推理 响应Action&#xff08;动作空间&#xff09;协同推理 结果总结 摘要 ReAct: Synergizing Reasoning and Acting in Language Models [2210.03629] ReAct: Synergizing Reasoning an…

ISDP010_基于DDD架构实现收银用例主成功场景

信息系统开发实践 &#xff5c; 系列文章传送门 ISDP001_课程概述 ISDP002_Maven上_创建Maven项目 ISDP003_Maven下_Maven项目依赖配置 ISDP004_创建SpringBoot3项目 ISDP005_Spring组件与自动装配 ISDP006_逻辑架构设计 ISDP007_Springboot日志配置与单元测试 ISDP008_SpringB…

Linux -- 从抢票逻辑理解线程互斥

目录 抢票逻辑代码&#xff1a; thread.hpp thread.cc 运行结果&#xff1a; 为什么票会抢为负数&#xff1f; 概念前言 临界资源 临界区 原子性 数据不一致 为什么数据不一致&#xff1f; 互斥 概念 pthread_mutex_init&#xff08;初始化互斥锁&#xff09; p…

1.微服务灰度发布落地实践(方案设计)

前言 微服务架构中的灰度发布&#xff08;也称为金丝雀发布或渐进式发布&#xff09;是一种在不影响现有用户的情况下&#xff0c;逐步将新版本的服务部署到生产环境的策略。通过灰度发布&#xff0c;你可以先将新版本的服务暴露给一小部分用户或特定的流量&#xff0c;观察其…

从 Coding (Jenkinsfile) 到 Docker:全流程自动化部署 Spring Boot 实战指南(简化篇)

前言 本文记录使用 Coding (以 Jenkinsfile 为核心) 和 Docker 部署 Springboot 项目的过程&#xff0c;分享设置细节和一些注意问题。 1. 配置服务器环境 在实施此过程前&#xff0c;确保服务器已配置好 Docker、MySQL 和 Redis&#xff0c;可参考下列链接进行操作&#xff1…

丢失的MD5

丢失的MD5 源代码&#xff1a; import hashlib for i in range(32,127):for j in range(32,127):for k in range(32,127):mhashlib.md5()m.update(TASCchr(i)O3RJMVchr(j)WDJKXchr(k)ZM)desm.hexdigest()if e9032 in des and da in des and 911513 in des:print des 发现给…

基于51单片机的交通灯外部中断proteus仿真

地址&#xff1a; https://pan.baidu.com/s/1WSlta_7pz5HdWsyIGoviHg 提取码&#xff1a;1234 仿真图&#xff1a; 芯片/模块的特点&#xff1a; AT89C52/AT89C51简介&#xff1a; AT89C52/AT89C51是一款经典的8位单片机&#xff0c;是意法半导体&#xff08;STMicroelectro…

JavaWeb(一) | 基本概念(web服务器、Tomcat、HTTP、Maven)、Servlet 简介

1. 基本概念 1.1、前言 web开发&#xff1a; web&#xff0c;网页的意思&#xff0c;www.baidu.com静态 web html,css提供给所有人看的数据始终不会发生变化&#xff01; 动态 web 淘宝&#xff0c;几乎是所有的网站&#xff1b;提供给所有人看的数据始终会发生变化&#xf…

C语言性能优化:从基础到高级的全面指南

引言 C 语言以其高效、灵活和功能强大而著称&#xff0c;被广泛应用于系统编程、嵌入式开发、游戏开发等领域。然而&#xff0c;要写出高性能的 C 语言代码&#xff0c;需要对 C 语言的特性和底层硬件有深入的了解。本文将详细介绍 C 语言性能优化的背后技术&#xff0c;并通过…

C语言-数据结构-查找

目录 一,查找的概念 二,线性查找 1,顺序查找 2,折半查找 3,分块查找 三,树表的查找 1,二叉排序树 (1)查找方式: (2)、二叉排序树的插入和生成 (3)、二叉排序树的删除 2,平衡二叉树 (1)、什么是平衡二叉树 (2)、平衡二叉树的插入调整 &#xff08;1&#xff09;L…

[江科大编程技巧] 第1期 定时器实现非阻塞式程序 按键控制LED闪烁模式——笔记

提前声明——我只是写的详细其实非常简单&#xff0c;不要看着多就放弃学习&#xff01; 阻塞&#xff1a;执行某段程序时&#xff0c;CPU因为需要等待延时或者等待某个信号而被迫处于暂停状态一段时间&#xff0c;程序执行时间较长或者时间不定 非阻塞&#xff1a;执行某段程…

如何理解:产品线经营管理的战略、组织、业务、项目、流程、绩效之间的逻辑关系?-中小企实战运营和营销工作室博客

如何理解&#xff1a;产品线经营管理的战略、组织、业务、项目、流程、绩效之间的逻辑关系&#xff1f;-中小企实战运营和营销工作室博客 产品线经营管理中&#xff0c;战略、组织、业务、项目、流程、绩效之间存在着紧密的逻辑关系&#xff0c;它们相互影响、相互作用&#xf…