文章目录
- 前言
- 一、基类 OCRExecutorBase
- 二、EasyOCR
- 1.安装
- 2.模型下载
- 3.DEMO
- 三、Tesseract
- 1.安装
- 2.使用问题
- 3.DEMO
- 四、PaddleOCR
- 1.安装
- 2.DEMO
- 五、PaddleOCR(PyTorch移植版)
- 1.代码整理
- 2.DEMO
- 六、TrOCR
- 1.安装
- 2.模型下载
- 3.DEMO
- 七、GOT
- 1.安装
- 2.模型下载
- 3.DEMO
- 总结
前言
OCR
(Optical Character Recognition
,光学字符识别)是指对包含文本内容的图像或视频进行处理和识别,并提取其中所包含的的文字及排版信息的过程(摘自维基百科)。根据其应用场景可分为印刷文本识别、手写文本识别、公式文本识别、场景文本识别以及古籍文本识别。
举一个实用的例子:想阅读一本电子书,但该书是扫描版的 PDF 文档,具有文件体积大、文字不可选、无法编辑和可读性差的缺点;我们可以借助OCR
将文档识别并转换成轻量的 EPUB 格式,并提升阅读体验。有意义的应用场景还有很多,此处不一一列举。
最近由于实际需求,对之前和时下流行的OCR
工具进行了一些货比三家式的接触和使用,尤其是近期(2024年9月)刚出端到端的GOT-OCR2.0,效果惊艳。遂决定在此记录,内容包括EasyOCR、Tesseract、PaddleOCR及其PyTorch移植版、单行手写文本识别TrOCR,以及GOT,对应的repo地址为https://github.com/DaiHaoguang3151/ocr_fusion。
一、基类 OCRExecutorBase
为了统一所有这些OCR
的使用,我在repo的src/ocr_executor/ocr_executor_base.py
文件中定义了基类OCRExecutorBase
,代码如下所示。
1)在初始化self.__init__()
部分,主要是用来完成模型加载等工作;
2)对图片进行批量的OCR
识别则是通过self.execute()
方法,其输入paths_images
则是批量(image_path, opencv_image)
对,这么设计是因为有些OCR
工具的输入既可以是图片的路径,也可以是图片本身,比如OpenCV
或者PIL
读取的图片;
3)self._generate()
是self.execute()
的核心部分,就是直接调用OCR
的地方。
# src/ocr_executor/ocr_executor_base.py
import math
from typing import List
import numpy as np
class OCRExecutorBase:
def __init__(self):
pass
def _generate(self, images: List[np.ndarray]) -> List:
"""
批量生成
"""
raise NotImplementedError
def execute(self, paths_images: List, batch_size=16):
"""
执行ocr
"""
results = []
num = len(paths_images)
paths = [ele[0] for ele in paths_images]
# print("IMAGE PATHS: ", paths)
images = [ele[1] for ele in paths_images]
iterations = math.ceil(num / batch_size)
for iter in range(iterations):
batch_images = images[iter * batch_size: min((iter + 1) * batch_size, num)]
batch_paths = paths[iter * batch_size: min((iter + 1) * batch_size, num)]
batch_results = self._generate(batch_images)
results += batch_results
print("GENERATED: ", results)
return results
二、EasyOCR
1.安装
EasyOCR
是个流行的轻量化OCR
工具,底层依赖于PyTorch
,所以需要一起安装。
# eg: CUDA 11.7
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install easyocr
2.模型下载
EasyOCR
会使用到文本检测模型
和文本识别模型
,在首次使用时会自动下载相应的模型,但是由于网络原因,很可能报错,需要手动下载模型,参考这篇博客,具体步骤如下:
1)首先找到模型下载存放的路径,默认是~/.EasyOCR
;
2)去modelhub中下载权重文件,解压后放置到上述模型存放路径;文本检测模型是CRAFT
,文本识别模型是2rd Generation Models
下方的english_g2
(因为我们选择的语言是英文)。
3.DEMO
EasyOCRExecutor
部分代码如下,1)在初始化阶段使用easyocr.Reader
加载模型,["en"]
表示选择英文,download_enabled=False
则表示使用已经下载到本地的模型,而不再去下载;2)重写self._generate()
方法,通过self.reader.readtext(image, detail=1)
即可获取识别的文本、文本包围盒以及得分。
# src/ocr_executor/easyocr_executor.py
import math
from typing import List
import numpy as np
import easyocr
from ocr_executor.ocr_executor_base import OCRExecutorBase
from util.util import save_detection
class EasyOCRExecutor(OCRExecutorBase):
def __init__(self):
super(EasyOCRExecutor, self).__init__()
self.reader = easyocr.Reader(["en"], download_enabled=False)
def _generate(self, images: List[np.ndarray]) -> List:
"""
批量生成
"""
results = []
for image in images:
# 可以传图片或者文件路径,detail=1返回检测结果
result = self.reader.readtext(image, detail=1)
# for detection in result:
# bbox, text, score = detection
# print(f"Text: {text}, BBox: {bbox}, Score: {score}")
results.append(result)
return results
我在src/main.py
中完成了所有OCR
识别样例图片的脚本,现在就来看一下使用EasyOCRExecutor
的demo:
import os
import numpy as np
import cv2
from ocr_executor.easyocr_executor import EasyOCRExecutor
from util.util import save_detection # 用于绘制检测和识别结果的
# 构造输入
image_path = "/home/ubuntu/Projects_ubuntu/ocr_fusion/src/images/handwriting.png"
image = cv2.imread(image_path)
paths_images = [(image_path, image)]
output_dir = "./images_output"
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# easyocr
easyocr_executor = EasyOCRExecutor()
result = easyocr_executor.execute(paths_images)[0] # 取第一个结果,对应第一张图片
# 简单的输出处理,方便喂给save_detection绘制图片
bboxes = []
texts = []
for idx, detection in enumerate(result):
bbox, text, score = detection
bbox = [[int(pt[0]), int(pt[1])] for pt in bbox]
bbox = np.array(bbox).astype(np.int32).reshape(-1, 2)
bboxes.append(bbox)
texts.append(text)
save_detection(os.path.join(output_dir, "easyocr.png"), image, bboxes, texts=texts, poly=True)
结果如下,左边是原图,蓝色框是EasyOCR
的文本检测框,右边是每个框中识别的文字,效果一般,对于这种简单场景还是会有一些识别错误。
三、Tesseract
1.安装
想要使用Tesseract
,需要先安装tesseract-ocr
,然后安装pytesseract
这个Python
包。我尝试了两种安装方式,针对使用conda
虚拟环境的情况,我推荐第一种。
- 方式1:使用
conda
命令同时安装tesseract
和pytesseract
,比如我使用环境是python3.7;
# 安装参考链接为:
# https://anaconda.org/conda-forge/tesseract
# https://anaconda.org/conda-forge/pytesseract
# 1) 安装:使用conda-forge源,有点慢
conda install conda-forge::tesseract pytesseract
# 2) 测试tesseract
# conda activate your_env
tesseract --version
# tesseract显示的版本有可能是5.3.0,也有可能是4.1.1,问题不大
# tesseract 4.1.1
# leptonica-1.80.0
# libgif 5.2.1 : libjpeg 9e : libpng 1.6.39 : libtiff 4.4.0 : zlib 1.2.13 : libwebp 1.2.4 : libopenjp2 2.4.0
# Found AVX2
# Found AVX
# Found FMA
# Found SSE
# Found libarchive 3.4.0 zlib/1.2.11 liblzma/5.2.4 bz2lib/1.0.8 liblz4/1.9.2 libzstd/1.4.4
# 查找可执行文件路径
which tesseract
# /home/ubuntu/anaconda3/envs/ocr_env/bin/tesseract
# 3) 测试pytesseract
# conda list找到了pytesseract-0.3.10
# 检查是否安装成功
python
# >>> import pytesseract
# >>> print(pytesseract.get_tesseract_version()) # 4.1.1
- 方式2:先在系统层级安装
tesseract-ocr
,然后在虚拟环境中安装pytesseract
;这种安装方式的缺点是没有进行完全的环境隔离,在有些使用场景下,可能会遇到加载动态链接库报错的问题。
# 1) 安装tesseract-ocr
sudo apt-get update
sudo apt-get install tesseract-ocr
# 配置环境变量
vim ~/.bashrc
# 在文件最后添加:
export PATH=$PATH:/usr/local/bin
# 重新加载文件
source ~/.bashrc
# 验证配置
echo $PATH
# /home/ubuntu/anaconda3/bin:/home/ubuntu/anaconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/usr/local/cuda/bin:/usr/local/bin
# 2) 验证 Tesseract 安装
tesseract --version
# tesseract 4.1.1
# leptonica-1.79.0
# libgif 5.1.4 : libjpeg 8d (libjpeg-turbo 2.0.3) : libpng 1.6.37 : libtiff 4.1.0 : zlib 1.2.11 : libwebp 0.6.1 : libopenjp2 2.3.1
# Found AVX2
# Found AVX
# Found FMA
# Found SSE
# Found libarchive 3.4.0 zlib/1.2.11 liblzma/5.2.4 bz2lib/1.0.8 liblz4/1.9.2 libzstd/1.4.4
# 查找可执行文件路径
which tesseract
# /usr/bin/tesseract --> 有点奇怪,PATH=$PATH:/usr/local/bin这个设置好像是无效的,实际路径在/usr/bin/
# 3) 安装pytesseract
pip install pytesseract # 0.3.10
2.使用问题
当我们开始使用pytesseract
时有可能遇到如下报错:
from PIL import Image
image = Image.open("/home/ubuntu/Projects_ubuntu/TrOCR/images/cropped_image/17.jpg")
printed_text = pytesseract.image_to_string(image, config="--psm 7")
# 报错:
# Traceback (most recent call last):
# File "/home/ubuntu/Projects_ubuntu/TrOCR/tesseract_ocr.py", line 24, in <module>
# printed_text = pytesseract.image_to_string(image, config="--psm 7")
# File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 427, in image_to_string
# }[output_type]()
# File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 426, in <lambda>
# Output.STRING: lambda: run_and_get_output(*args),
# File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 288, in run_and_get_output
# run_tesseract(**kwargs)
# File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 264, in run_tesseract
# raise TesseractError(proc.returncode, get_errors(error_string))
# pytesseract.pytesseract.TesseractError: (1, 'Error opening data file /home/ubuntu/anaconda3/envs/ocr_env/share/tessdata Please make sure the TESSDATA_PREFIX environment variable is set to your "tessdata" directory. Failed loading language \'eng\' Tesseract couldn\'t load any languages! Could not initialize tesseract.')
查看最后一行报错,实际上是找不到数据,奇怪的是:
1)在/home/ubuntu/anaconda3/envs/ocr_env/share/tessdata
文件夹下能找到eng.traineddata
;
2)/home/ubuntu/anaconda3/envs/ocr_env/share/tessdata
并不是报错中所说的文件,而是文件夹。
解决方式是按照提示修改环境变量:
# python中
import os
os.environ['TESSDATA_PREFIX'] = '/home/ubuntu/anaconda3/envs/fasttext/share/tessdata/'
# 或者bash中
export TESSDATA_PREFIX=/home/ubuntu/anaconda3/envs/fasttext/share/tessdata/
3.DEMO
Tesseract
是支持多种形式的OCR
,1)比如pytesseract.image_to_string(image, config="--psm 7")
是只识别文字,参数--psm
可以指定识别模式,比如"--psm 7"
表示单行文本识别,可以自己查一下;2)pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
则以字典的形式输出所有信息,包括识别的层次结构级别level
(比如level=5
可以过滤出单词级别的结果)、包围盒以及置信度conf
等等。
# src/ocr_executor/tesseractocr_executor_base.py
from typing import List
import numpy as np
import pytesseract
from ocr_executor.ocr_executor_base import OCRExecutorBase
class TesseractOCRExecutor(OCRExecutorBase):
def __init__(self):
super(TesseractOCRExecutor, self).__init__()
def _generate(self, images: np.ndarray) -> List:
"""
根据输入图片批量生成文字
"""
generated = []
# for image in images:
# result = pytesseract.image_to_string(image, config="--psm 7") # --psm 7 表示单行识别
# generated.append(result)
for image in images:
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
result = []
num = len(data["level"])
for i in range(num):
if not data["level"][i] == 5:
continue
(x, y, w, h) = (data['left'][i], data['top'][i], data['width'][i], data['height'][i])
box = [[x, y], [x + w, y + h]]
text = data["text"][i]
conf = data["conf"][i]
result.append((box, text, conf))
generated.append(result)
return generated
TesseractOCRExecutor
识别demo如下:
from ocr_executor.tesseractocr_executor import TesseractOCRExecutor
# 数据准备和前面是一样,省略
# tesseractocr
tesseractocr_executor = TesseractOCRExecutor()
result = tesseractocr_executor.execute(paths_images)[0]
bboxes = []
texts = []
for idx, detection in enumerate(result):
bbox, text, conf = detection
bboxes.append(bbox)
texts.append(text)
save_detection(os.path.join(output_dir, "tesseractocr.png"), image, bboxes, texts=texts, poly=False)
结果如下,这里只展示了单词级别的包围盒,你可以根据需求选择你需要的level
:
# level 字段可能的值及其含义
# 1: 表示整个页面(Page)。
# 2: 表示块(Block),通常是一个独立的区域,如段落或图像。
# 3: 表示段落(Paragraph)。
# 4: 表示行(Line),即文本行。
# 5: 表示单词(Word),即单个单词。
# 6: 表示字符(Symbol),即单个字符。
四、PaddleOCR
PaddleOCR
底层依赖于PaddlePaddle
,但是它的安装可能会有点麻烦,并且对于CUDA
版本的支持和其他框架比如PyTorch
不是很同步。如果对这方面有些头疼的朋友,可以选择PaddleOCR
的PyTorch
移植版,以便快速体验。
1.安装
代码如下(示例):
- PaddlePaddle安装(我参考了这篇文章):
1)使用conda
安装paddlepaddle-gpu==2.5.2
conda install paddlepaddle-gpu==2.5.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge
在安装过程中,你会看到提示信息,会下载和安装cudatoolkit
和cudnn
,我们来查看一下具体安装结果:
(ocr_env) ubuntu@ubuntu:~$ conda list cudatoolkit
# packages in environment at /home/ubuntu/anaconda3/envs/ocr_env:
#
# Name Version Build Channel
cudatoolkit 11.7.1 h4bc3d14_13 conda-forge
(ocr_env) ubuntu@ubuntu:~$ conda list cudnn
# packages in environment at /home/ubuntu/anaconda3/envs/ocr_env:
#
# Name Version Build Channel
cudnn 8.4.1.50 hed8a83a_0 conda-forge
2)按照paddlepaddle
官方文档,检查是否安装成功:
ubuntu@ubuntu:~/anaconda3/envs/ocr_env/lib$ python
>>> import paddle
>>> paddle.utils.run_check()
# 报错日志:
# PreconditionNotMetError: Cannot load cudnn shared library. Cannot invoke method cudnnGetVersion.
# [Hint: cudnn_dso_handle should not be null.] (at ../paddle/phi/backends/dynload/cudnn.cc:64)
# [operator < fill_constant > error]
报错显示,找不到cudnn
相关的shared library
;
3)查看shared library
中有没有libcudnn.so
和libcublas.so
:
- 使用命令
ls /usr/lib |grep lib
查看,发现上述文件不存在;
(ocr_env) ubuntu@ubuntu:/usr/lib$ ls /usr/lib |grep lib
klibc
klibc-abS-oVB3xeRN8SFypUWbQvR33nc.so
libGL.so.1
libpsm1
libreoffice
- 使用命令
find $CONDA_PREFIX/lib -name "libcublas.so"
手动查找cudatoolkit
库文件,这是在当前conda
环境中查找
(ocr_env) ubuntu@ubuntu:~$ find $CONDA_PREFIX/lib -name "libcublas.so"
/home/ubuntu/anaconda3/envs/ocr_env/lib/libcublas.so # 找到了文件路径
4)根据这篇博客,创建软链接:
(ocr_env) ubuntu@ubuntu:~/anaconda3/envs/ocr_env/lib$ cd /usr/lib
(ocr_env) ubuntu@ubuntu:/usr/lib$ sudo ln -s /home/ubuntu/anaconda3/envs/ocr_env/lib/libcudnn.so.8.4.1 libcudnn.so
(ocr_env) ubuntu@ubuntu:/usr/lib$ sudo ln -s /home/ubuntu/anaconda3/envs/ocr_env/lib/libcublas.so.11.10.3.66 libcublas.so
然后检测相关的lib
是否存在:发现已存在
(ocr_env) ubuntu@ubuntu:/usr/lib$ ls /usr/lib |grep lib
klibc
klibc-abS-oVB3xeRN8SFypUWbQvR33nc.so
libcublas.so
libcudnn.so
libGL.so.1
libpsm1
libreoffice
5)重复2),再次确认是否安装成功,这次报错如下:
(ocr_env) ubuntu@ubuntu:~/anaconda3/envs/ocr_env/lib$ python
>>> import paddle
>>> paddle.utils.run_check()
# 报错
# Running verify PaddlePaddle program ...
# I0813 17:02:59.401835 6514 interpretercore.cc:237] New Executor is Running.
# W0813 17:02:59.401952 6514 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.6, Driver API Version: 12.2, Runtime API Version: 11.7
# W0813 17:02:59.421336 6514 gpu_resources.cc:149] device: 0, cuDNN Version: 8.4.
# python: symbol lookup error: /usr/local/cuda-11.0/targets/x86_64-linux/lib/libcublas.so: undefined symbol: runGemmShortApi, version libcublasLt.so.11
发现找的是/usr/local/cuda-11.0
,这应该是我之前安装的系统级的cudatoolkit
,说明paddlepaddle-gpu
没有去找虚拟环境的cudatoolkit
;
6)设置环境变量:export LD_LIBRARY_PATH=/home/ubuntu/anaconda3/envs/ocr_env/lib:$LD_LIBRARY_PATH
;
(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ echo $LD_LIBRARY_PATH
(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ conda list paddlepaddle-gpu
# packages in environment at /home/ubuntu/anaconda3/envs/ocr_env:
#
# Name Version Build Channel
paddlepaddle-gpu 2.5.2.post117 pypi_0 pypi
(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ export LD_LIBRARY_PATH=/home/ubuntu/anaconda3/envs/ocr_env/lib:$LD_LIBRARY_PATH
(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ echo $LD_LIBRARY_PATH
/home/ubuntu/anaconda3/envs/ocr_env/lib:
# 方法2:永久设置(对所有终端会话有效,如果有其它虚拟环境,不是很推荐):
# echo 'export LD_LIBRARY_PATH=/path/to/library:$LD_LIBRARY_PATH' >> ~/.bashrc
# source ~/.bashrc
7)重复2),这次检查是安装成功的:
>>> import paddle
>>> paddle.utils.run_check()
# Running verify PaddlePaddle program ...
# I0813 17:19:18.161489 8275 interpretercore.cc:237] New Executor is Running.
# W0813 17:19:18.161605 8275 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.6, Driver API Version: 12.2, Runtime API Version: 11.7
# W0813 17:19:18.163659 8275 gpu_resources.cc:149] device: 0, cuDNN Version: 8.4.
# I0813 17:19:18.836459 8275 interpreter_util.cc:518] Standalone Executor is Used.
# PaddlePaddle works well on 1 GPU.
- PaddleOCR安装:
pip install paddleocr # version = 2.7.0.2
2.DEMO
PaddleOCRExecutor
代码如下,1)初始化阶段设置了语言,以及是否识别(可以只检测不识别);2)PaddleOCR
自带的draw_ocr
可以直接绘制识别结果。
# src/ocr_executor/paddleocr_executor_base.py
from typing import Dict, List
import numpy as np
import cv2
from PIL import ImageFont
from paddleocr import PaddleOCR, draw_ocr
from ocr_executor.ocr_executor_base import OCRExecutorBase
class PaddleOCRExecutor(OCRExecutorBase):
def __init__(self, lang: str = "en", rec: bool = True):
super(PaddleOCRExecutor, self).__init__()
# 识别语言
self.lang: str = lang
# 是否识别
self.rec = rec
self._init()
def _init(self):
# 初始化
self.model = PaddleOCR(use_angle_cls=True,
lang=self.lang,
det=True,
rec=self.rec)
def _generate(self, images: np.ndarray) -> List:
"""
根据输入图片生成文字(可选)和返回相应的box
images可以传入路径
"""
# 暂时不能批量处理
generated = []
for image in images:
# result = self.model.ocr(image_path, det=True, rec=self.rec, cls=True)[0]
if not isinstance(image, np.ndarray):
image_ = np.array(image) # TODO: 最好看一下通道顺序
print("image_.shape: ", image_.shape)
else:
image_ = image
result = self.model.ocr(image_, det=True, rec=self.rec, cls=True)[0]
# self._draw_single_result(image_, result)
generated.append(result)
return generated
def _draw_single_result(self, image, result):
"""
绘制一张图片的检测和识别结果
"""
# for line in result:
# print(line)
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
# font = ImageFont.load_default()
im_show = draw_ocr(image, boxes, txts, scores, font_path="/usr/share/fonts/truetype/ttf-khmeros-core/KhmerOS.ttf")
cv2.imwrite("./result.jpg", im_show)
PaddleOCRExecutor
识别demo
如下:
from ocr_executor.paddleocr_executor import PaddleOCRExecutor
# paddleocr
paddleocr_executor = PaddleOCRExecutor()
result = paddleocr_executor.execute(paths_images)[0]
bboxes = []
texts = []
for idx, detection in enumerate(result):
bbox, (text, score) = detection
bbox = np.array(bbox).astype(np.int32).reshape(-1, 2)
bboxes.append(bbox)
texts.append(text)
save_detection(os.path.join(output_dir, "paddleocr.png"), image, bboxes, texts=texts, poly=True)
结果如下,检测和识别效果都是不错的,具体来说就是检测包围盒刚好是我们所期望的样子,颗粒度刚好,同时识别也没有出错;个人经验是其检测模型可靠性挺强,可以和其他识别模型结合使用,识别模型适用于简单场景或者要求不是很高的场景;模型都不大,赞一个。
五、PaddleOCR(PyTorch移植版)
上一小节说了,如果你不想安装PaddlePaddle
,那么也可以使用好心人提供的PyTorch移植版。小缺点是如果你需要在项目中使用,需要copy和整理一下代码;当然也可以转成onnx
格式进行推理。
1.代码整理
由于我当前只是更想用PaddleOCR
的检测模型,所以在repo中整理出了ch_ptocr_v3_det_infer
以及ch_ptocr_v4_det_infer
这两个版本的检测模型及其代码,存放于src/paddle_ocr
。这两个模型都可以检测中英等多种语言文本。如果你另有需求或者想多探索一下其他模型,可以去原作者仓库把玩。
2.DEMO
我将检测模型封装成了类TextDetector
,具体网络模型我们不在此展开,来说一下ch_ptocr_v3_det_infer
以及ch_ptocr_v4_det_infer
对应的参数传递。如下所示,如果使用ch_ptocr_v4_det_infer
,你需要传入模型路径det_model_path
以及配置文件路径det_yaml_path
;而如果使用ch_ptocr_v3_det_infer
,只需传入det_model_path
,因为它的配置并不是以文件形式存在。我这边都写的绝对路径,需要你改成自己的路径;文件都已存在,不需要你自行下载。
# src/paddle_ocr/text_detector.py
from paddle_ocr.model_args import parse_args
_args = parse_args()
# det v4:
# _args.det_model_path = "/home/ubuntu/Projects_ubuntu/torchocr/src/paddle_ocr/pretrained_models/det_v4/ch_ptocr_v4_det_infer.pth"
# _args.det_yaml_path = "/home/ubuntu/Projects_ubuntu/torchocr/src/paddle_ocr/pretrained_models/det_v4/ch_PP-OCRv4_det_student.yml"
# det v3:
_args.det_model_path = "/home/ubuntu/Projects_ubuntu/torchocr/src/paddle_ocr/pretrained_models/det_v3/ch_ptocr_v3_det_infer.pth"
class TextDetector:
def __init__(self, args=_args, **kwargs):
# ...
TextDetector
检测demo如下:
from paddle_ocr.text_detector import TextDetector
# paddleocr (pytorch det model)
text_detector = TextDetector()
bboxes, _ = text_detector(image)
bboxes = [np.array(bbox).astype(np.int32).reshape((-1, 2)) for bbox in bboxes.tolist()]
save_detection(os.path.join(output_dir, "paddleocr_pytorch_det.png"), image, bboxes, texts=None, poly=True)
结果如下,你会发现检测结果和原生的PaddleOCR
是完全一致的。
六、TrOCR
TrOCR
是微软出品,主要用于单行手写文字识别,效果不错;缺点是1)只能支持单行识别,一般需要结合文本检测模型使用;2)该模型使用Transformer
序列生成的方式来生成识别的文字的,速度略慢;3)由于是序列生成,没有很好的方式提供识别结果整个的置信度(github上有人提过这样的问题,但没有解决)。
1.安装
pip install torch # 选择合适gpu版本进行安装
pip install transformers
2.模型下载
Hugging Face
上下载trocr-base-handwritten模型到本地,如下图所示。
3.DEMO
TrOCRExecutor
代码如下,主要组件是处理器TrOCRProcessor
(可以同时处理图像和文本)和模型VisionEncoderDecoderModel
。
# src/ocr_exeutor/trocr_executor.py
from typing import List
import numpy as np
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ocr_executor.ocr_executor_base import OCRExecutorBase
DAFAULT_MODEL_PATH = "/home/ubuntu/Projects_ubuntu/TrOCR/trocr_base_handwritten"
class TrOCRExecutor(OCRExecutorBase):
def __init__(self, model_path: str = DAFAULT_MODEL_PATH):
super(TrOCRExecutor, self).__init__()
# ocr模型
self.model_path = model_path
self._init()
def _init(self):
# 初始化
if "trocr" in self.model_path:
self.processor = TrOCRProcessor.from_pretrained(self.model_path)
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_path)
else:
raise NotImplementedError
def _generate(self, images: np.ndarray) -> List:
"""
根据输入图片批量生成文字
"""
pixel_values = self.processor(images=images, return_tensors="pt").pixel_values
# generated_ids = self.model.generate(pixel_values, output_scores=True, return_dict_in_generate=True)
generated_ids = self.model.generate(pixel_values) # TODO
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
return generated_text
TrOCRExecutor
识别demo如下:
from ocr_executor.trocr_executor import TrOCRExecutor
# 注意,和上面不同,这边换成了单行手写文本的图片
image_path = "/home/ubuntu/Projects_ubuntu/ocr_fusion/src/images/handwriting_single_line.png" # trocr只能识别单行手写文本
image = cv2.imread(image_path)
paths_images = [(image_path, image)]
# trocr
trocr_executor = TrOCRExecutor()
text = trocr_executor.execute(paths_images)[0]
save_detection(os.path.join(output_dir, "trocr.png"), image, [], texts=[text], poly=True)
结果如下,个人体验是效果还可以,但是鲁棒性有待商榷,有时候单行文字截取多一些少一些会影响识别效果。
七、GOT
GOT
是阶跃星辰、旷世、中国科学院以及清华的作品,效果确实不错,同时它支持多种形式的OCR
,比如markdown、音符、分子式等等,通用性强;该模型是端到端的,直接输出文本,但是应该不能输出文本检测框。
1.安装
安装可以参考源仓库(其中提到的Flash-Attention
不是必须的),此处不赘述。
2.模型下载
Hugging Face
上下载GOT-OCR2_0模型到本地,如下图所示。
3.DEMO
我把源代码中必要的部分抽出来,放到了src/got_ocr
中,主体部分封装成了类GOTTextGenerator
,为GOTOCRExecutor
所用,GOTOCRExecutor
代码如下:
# src/ocr_exeutor/gotocr_executor.py
class GOTOCRExecutor(OCRExecutorBase):
def __init__(self, model_name: str = _MODEL_NAME):
super(GOTOCRExecutor, self).__init__()
# 模型
self.text_generator = GOTTextGenerator(model_name)
def _generate(self, images: List[np.ndarray]) -> List:
"""
批量生成
"""
results = []
for image in images:
# 构建模型输入
input_dict = {
"image": image,
"type": "ocr"
}
result = self.text_generator.generate(input_dict)
results.append(result)
return results
GOTOCRExecutor
识别demo如下:
from ocr_executor.gotocr_executor import GOTOCRExecutor
# gotocr
gotocr_executor = GOTOCRExecutor()
result = gotocr_executor.execute(paths_images)[0]
save_detection(os.path.join(output_dir, "gotocr.png"), image, [], texts=[result], poly=True)
结果如下,我们发现直接让它识别多行文本差强人意,单词之间缺少空格,个人猜测有两个原因:1)图片比较大,模型处理时对图片的压缩比较厉害;2)训练数据的分布上可能有问题,因为论文中隐约可以看出他们收集的手写数据是小片段的拼接出来的,所以可能数据上不是很到位。
需要说明的是,其实中文都正确识别出来了,只是OpenCV
绘制文本时默认不支持中文,所以显示的都是问号。
针对上述问题,我们可以换一种思路,就是让GOT
负责单行级别的识别,在复杂场景下只要在前面加一个文本检测模型即可。因此,我是用单行手写文本又测试了一次,结果如下,完美解决。
总结
本篇配合repo讲述了一些流行的OCR
工具的使用方法,比如从好用的paddleOCR
到OCR2.0
的端到端模型GOT
。希望能为需要使用OCR
工具的朋友提供便利。