全流程如下:
数据集
XFUND数据集是微软提出的一个用于KIE任务的多语言数据集,共包含七个数据集,每个数据集包含149张训练集和50张验证集分别为:
ZH(中文)、JA(日语)、ES(西班牙)、FR(法语)、IT(意大利)、DE(德语)、PT(葡萄牙),
选取中文数据集,链接如下
GitHub - doc-analysis/XFUND: XFUND: A Multilingual Form Understanding Benchmark
进行下面这两步,命令行如下:
! wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
! tar -xf XFUND.tar
目录结构
数据集标注格式
转换数据格式
转换脚本
按我这个来,他那个有些bug
import os
import cv2
import json
import shutil
from PIL import Image,ImageDraw
train_path = '/home/aistudio/data/data140302/XFUND_ori/zh.train/'
eval_path = '/home/aistudio/data/data140302/XFUND_ori/zh.val/'
drawImg = False
if drawImg:
os.makedirs('draw_imgs')
rec_save_path = '/home/aistudio/XFUND/rec_imgs/'
if not os.path.exists(rec_save_path):
os.makedirs(rec_save_path)
def transfer_xfun_data(json_path=None, det_output_file=None, rec_output_file=None, di=set()):
with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
json_info = json.loads(lines[0])
documents = json_info["documents"]
if 'train' in json_path:
path = '/zh.train/'
else:
path = '/zh.val/'
det_file = open(det_output_file, "w")
rec_file = open(rec_output_file, "w")
for idx, document in enumerate(documents):
img_info = document["img"]
document = document["document"]
image_path = img_info["fname"]
img = cv2.imread('data/data140302/XFUND_ori'+path+image_path)
# 保存信息到检测文件中
det_info = []
if drawImg:
img_pil = Image.fromarray(img)
draw = ImageDraw.Draw(img_pil)
num=0
for doc in document:
# 检测文件信息
det_info.append({"transcription":doc["text"], "points":[[doc["box"][0],doc["box"][1]],[doc["box"][2],doc["box"][1]],[doc["box"][2],doc["box"][3]],[doc["box"][0],doc["box"][3]]]})
# 保存识别图片
pic = img[doc["box"][1]:doc["box"][3], doc["box"][0]:doc["box"][2]]
rec_save_dir = rec_save_path + os.path.splitext(image_path)[0]+'_'+str(num).zfill(3)+".jpg"
cv2.imwrite(rec_save_dir,pic)
# 识别文件信息
rec_line = '/'.join(rec_save_dir.split('/')[-2:]) + '\t' + doc["text"] + '\n'
rec_file.write(rec_line)
# 字典
di = di | set(doc["text"])
num+=1
if drawImg:
draw.polygon([(doc["box"][0],doc["box"][1]), (doc["box"][2],doc["box"][1]), (doc["box"][2],doc["box"][3]), (doc["box"][0],doc["box"][3])], outline=(255,0,0))
if drawImg:
img_pil.save('./draw_imgs/'+image_path)
det_line = path+ image_path + '\t' + json.dumps(det_info,ensure_ascii=False) +'\n'
det_file.write(det_line)
det_file.close()
rec_file.close()
return di
# =================检测文件=================
det_train = '/home/aistudio/XFUND/det_gt_train.txt'
det_test = '/home/aistudio/XFUND/det_gt_val.txt'
# =================识别文件=================
rec_train = '/home/aistudio/XFUND/rec_gt_train.txt'
rec_test = '/home/aistudio/XFUND/rec_gt_val.txt'
di_xfund = set()
di_xfund = transfer_xfun_data("/home/aistudio/data/data140302/XFUND_ori/zh.train.json", det_train, rec_train, di_xfund)
di_xfund = transfer_xfun_data("/home/aistudio/data/data140302/XFUND_ori/zh.val.json", det_test, rec_test, di_xfund)
文本检测的标注格式
中间用’\t’分隔:
”图像文件名 json.dumps编码的图像标注信息”ch4_test_images/img_61.jpg [{“transcription”: “MASA”, “points”:[[310, 104], [416, 141], [418, 216], [312, 179]]}, {⋯}]
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 points 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。transcription 表示当前文本框的文字,当其内容为“###”时,表示该文本框无效,在训练时会跳过。
文本识别的标注格式
txt文件中默认请将图片路径和图片标签用’\t’分割,如用其他方式分割将造成训练报错
文本检测
方案1:• PP-OCRv2中英文超轻量检测预训练模型
下载模型
%mkdir /home/aistudio/PaddleOCR/pretrain/
%cd /home/aistudio/PaddleOCR/pretrain/
! wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar
! tar -xf ch_PP-OCRv2_det_distill_train.tar && rm -rf ch_PP-OCRv2_det_distill_train.tar
% cd ..
修改配置,主要是改路径
评估如下
%cd /home/aistudio/PaddleOCR
! python tools/eval.py \
-c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
精度达到77%
finetune
100epoch,我直接用训练好的权重文件了
%cd /home/aistudio/PaddleOCR/
! python tools/eval.py \
-c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml \
-o Global.checkpoints="pretrain/ch_db_mv3-student1600-finetune/best_accuracy"
未finetune的如下,效果略微差了一点
导出模型
在模型训练过程中保存的模型文件是包含前向预测和反向传播的过程,在实际的工业部署则不需要反向传播,因此需要将模型进行导成部署需要的模型格式。执行下面命令,即可导出模型。
%cd /home/aistudio/PaddleOCR/
! python tools/export_model.py \
-c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml \
-o Global.pretrained_model="pretrain/ch_db_mv3-student1600-finetune/best_accuracy" \
Global.save_inference_dir="./output/det_db_inference/"
会有这些文件
文本识别
预训练模型
%cd /home/aistudio/PaddleOCR/pretrain/
! wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
! tar -xf ch_PP-OCRv2_rec_train.tar && rm -rf ch_PP-OCRv2_rec_train.tar
% cd ..
下载预训练模型,评估结果如下
XFUND数据集+finetune
训练代码如下,喜欢自己去跑,太慢了,玩不起
%cd /home/aistudio/PaddleOCR/
! CUDA_VISIBLE_DEVICES=0 python tools/train.py \
-c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
用人训练好的结果如下
XFUND数据集+finetune+真实通用识别数据
这部分数据人家没公开,公开了26W我也跑不动,还行,权重文件给了
下面就跟检测那部分一样导出模型就行了
文档视觉问答
主要分为SER和RE两个任务,先下载权重
%cd pretrain
#下载SER模型
! wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar && tar -xvf ser_LayoutXLM_xfun_zh.tar
%rm -rf ser_LayoutXLM_xfun_zh.tar
#下载RE模型
! wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar && tar -xvf re_LayoutXLM_xfun_zh.tar
%rm -rf re_LayoutXLM_xfun_zh.tar
%cd ../
SER
SER: 语义实体识别 (Semantic Entity Recognition), 可以完成对图像中的文本识别与分类。
对于XFUND数据集,有QUESTION, ANSWER, HEADER 3种类别
用下面的命令下载数据集,不要用他环境提供的,会报错,数据格式有问题
! wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
! tar -xf XFUND.tar
进行训练,命令行如下
%cd /home/aistudio/PaddleOCR/
! CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
进行评估,代码如下
! CUDA_VISIBLE_DEVICES=0 python tools/eval.py \
-c configs/vqa/ser/layoutxlm.yml \
-o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
完成ocr+SER串联
! CUDA_VISIBLE_DEVICES=0 python tools/infer_vqa_token_ser.py \
-c configs/vqa/ser/layoutxlm.yml \
-o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ \
Global.infer_img=doc/vqa/input/zh_val_42.jpg
结果如下
import cv2
from matplotlib import pyplot as plt
# 在notebook中使用matplotlib.pyplot绘图时,需要添加该命令进行显示
%matplotlib inline
img = cv2.imread('output/ser/zh_val_42_ser.jpg')
plt.figure(figsize=(48,24))
plt.imshow(img)
RE
基于 RE 任务,可以完成对图象中的文本内容的关系提取,如判断问题对(pair)
问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
训练命令如下
! CUDA_VISIBLE_DEVICES=0 python tools/train.py \
-c configs/vqa/re/layoutxlm.yml
进行评估
! CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py \
-c configs/vqa/re/layoutxlm.yml \
-o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/
串联全部
%cd /home/aistudio/PaddleOCR
! CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser_re.py \
-c configs/vqa/re/layoutxlm.yml \
-o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ \
Global.infer_img=test_imgs/ \
-c_ser configs/vqa/ser/layoutxlm.yml \
-o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
结果大概这样
红色是问题,蓝色是答案,用绿线连接表示关系
修改文件,导出excel
vim /home/aistudio/PaddleOCR/tools/infer_vqa_token_ser_re.py
为了输出信息匹配对,我们修改tools/infer_vqa_token_ser_re.py
文件中的line 194-197
。
fout.write(img_path + "\t" + json.dumps(
{
"ser_resule": result,
}, ensure_ascii=False) + "\n")
更改为
result_key = {}
for ocr_info_head, ocr_info_tail in result:
result_key[ocr_info_head['text']] = ocr_info_tail['text']
fout.write(img_path + "\t" + json.dumps(
result_key, ensure_ascii=False) + "\n")
import json
import xlsxwriter as xw
workbook = xw.Workbook('output/re/infer_results.xlsx')
format1 = workbook.add_format({
'align': 'center',
'valign': 'vcenter',
'text_wrap': True,
})
worksheet1 = workbook.add_worksheet('sheet1')
worksheet1.activate()
title = ['姓名', '性别', '民族', '文化程度', '身份证号码', '联系电话', '通讯地址']
worksheet1.write_row('A1', title)
i = 2
with open('output/re/infer_results.txt', 'r', encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
img_path, result = line.strip().split('\t')
result_key = json.loads(result)
# 写入Excel
row_data = [result_key['姓名'], result_key['性别'], result_key['民族'], result_key['文化程度'], result_key['身份证号码'],
result_key['联系电话'], result_key['通讯地址']]
row = 'A' + str(i)
worksheet1.write_row(row, row_data, format1)
i+=1
workbook.close()