摘要
在人工智能领域,光学字符识别(OCR)技术已经取得了显著的进展。随着技术的不断进步,我们正迈向OCR 2.0时代。本文将介绍由Vary团队开发的通用端到端模型GOT,这一模型在OCR领域具有革命性的潜力。
论文概览
- 论文标题:GOT: Towards OCR-2.0
- 发布平台:arXiv
- 链接:arXiv.org
模型特点
GOT模型是首个迈向OCR 2.0时代的通用端到端模型,它在多个方面展现了其先进性:
- 多任务支持:GOT模型支持多种OCR任务,包括场景文本OCR、文档OCR、细粒度OCR以及更通用的OCR任务。
- 输出格式多样:除了支持纯文本输出,GOT还能输出格式化文本,如Markdown格式,增强了文本的可读性和可编辑性。
- 结构优化:采用vision encoder + input embedding layer + decoder的架构,其中encoder部分采用带local attention的VITDet架构,有效管理显存使用。
训练方法
GOT模型的训练分为三个阶段:
- 第一阶段:高效预训练encoder,使用小型OPT-125M作为decoder,快速引入大量数据。
- 第二阶段:联合训练encoder-decoder,使用Qwen团队预训练的Qwen0.5B,适当增大decoder以适应OCR-2.0的知识需求。
- 第三阶段:锁定encoder,加强decoder以适配更多OCR应用场景,如支持坐标或颜色引导的细粒度OCR,动态分辨率OCR技术,多页OCR技术。
数据工程
研究团队在数据工程方面投入巨大,学习并使用了多种数据渲染工具,包括Latex、Mathpix-markdown-it、Matplotlib、Tikz、Verovio、Pyecharts等,以构造多样化的数据。
项目地址
对GOT模型感兴趣的研究者和开发者可以通过以下链接访问项目代码:
GitHub - Ucas-HaoranWei/GOT-OCR2.0
安装
基础环境cuda11.8+torch2.0.1
克隆仓库并导航到GOT文件夹
git clone https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git
cd 'the GOT folder'
安装包
conda create -n got python=3.10 -y
conda activate got
pip install -e .
安装 Flash-Attention
pip install ninja
pip install flash-attn --no-build-isolation
GOT 权重
- Huggingface
- Google Drive
- 百度云 密码: OCR2
演示
- 普通文本OCR:
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type ocr
- 格式文本OCR:
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format
- 细粒度OCR:
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format/ocr --box [x1,y1,x2,y2]
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format/ocr --color red/green/blue
- 多裁剪OCR:
python3 GOT/demo/run_ocr_2.0_crop.py --model-name /GOT_weights/ --image-file /an/image/file.png
- 多页OCR (图像路径包含多个.png文件):
python3 GOT/demo/run_ocr_2.0_crop.py --model-name /GOT_weights/ --image-file /images/path/ --multi-page
- 渲染格式化OCR结果:
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format --render
注意:
渲染结果可以在/results/demo.html中找到。请打开demo.html查看结果。
训练
- 训练样本可以在此链接找到。注意,在’conversations’-‘human’-‘value’中的’<image>'是必要的!
- 此代码库仅支持在我们GOT权重上的后训练(第二/第三阶段)。
- 如果你想从我们论文中描述的第一阶段训练,你需要这个仓库。
deepspeed /GOT-OCR-2.0-master/GOT/train/train_GOT.py \
--deepspeed /GOT-OCR-2.0-master/zero_config/zero2.json --model_name_or_path /GOT_weights/ \
--use_im_start_end True \
--bf16 True \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200 \
--save_total_limit 1 \
--weight_decay 0. \
--warmup_ratio 0.001 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 8 \
--report_to none \
--per_device_train_batch_size 2 \
--num_train_epochs 1 \
--learning_rate 2e-5 \
--datasets pdf-ocr+scence \
--output_dir /your/output.path
注意:
- 更改constant.py中相应的数据信息。
- 更改conversation_dataset_qwen.py中第37行为你的数据名称。
评估
- 使用Fox和OneChart基准,其他基准可以在权重下载链接中找到。
- 评估代码可以在GOT/eval中找到。
- 你可以使用evaluate_GOT.py运行评估。如果你有8个GPU,–num-chunks可以设置为8。
python3 GOT/eval/evaluate_GOT.py --model-name /GOT_weights/ --gtfile_path xxxx.json --image_path /image/path/ --out_path /data/eval_results/GOT_mathpix_test/ --num-chunks 8 --datatype OCR