1. Project Introduction
This article uses table-transformer-detection, Microsoft's open-source table detection model, as a hands-on introduction to table detection and recognition.
The walkthrough is split into three parts:
- Table detection: locate the regions of an image or PDF file that contain tables
- Table structure recognition: within a detected table region, further recognize the table's internal structure, i.e. the positions of its rows, columns, and header, from which each cell's position can be derived
- Table data extraction: on top of the table structure, run OCR on each cell to obtain its text, yielding the full table data
2. Environment Setup
2.1 Server Configuration
2.2 Environment Creation
conda create -n Microsoft python=3.8 pip=21.1.1
conda activate Microsoft
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
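requirements.txt itself is not shown; judging from the imports used in the code below, it would need to cover at least the following packages (the list is reconstructed from this article's code, not taken from the original project, and pinned versions are left as an exercise):
transformers
torch
timm
Pillow
paddlepaddle
paddleocr
gradio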
Model download:
https://huggingface.co/microsoft/table-transformer-detection/tree/main
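After downloading, the model can be loaded by its Hub ID (fetched automatically) or from the local download directory. A minimal loading sketch; passing a local path instead of the model ID works the same way:
from transformers import AutoImageProcessor, TableTransformerForObjectDetection

# load the processor and detection model; replace the model ID with a local
# directory path if the weights were downloaded manually from the link above
image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")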
3. Table Detection
Detect the regions of an image or PDF file that contain tables.
Partial code:
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import torch
from PIL import Image
file_path = "./images/demo.jpg"
image = Image.open(file_path).convert("RGB")
file_name = file_path.split('/')[-1].split('.')[0]
......
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to COCO API
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]
i = 0
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(v, 2) for v in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )
    region = image.crop(box)  # crop the detected table region
    region.save(f'./images/{file_name}_{i}.png')
    i += 1
Output:
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge']` instead.
Detected table with confidence 0.998 at location [429.76, 217.22, 730.43, 441.51]
Detected table with confidence 0.997 at location [89.02, 215.52, 372.13, 440.98]
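The snippet above reads an image file directly. For a PDF, the pages must first be rendered to images; a minimal sketch using pdf2image (an assumed library choice; it needs poppler installed, and the file path is illustrative):
from pdf2image import convert_from_path

# render every page of the PDF to a PIL image, then run detection per page
pages = convert_from_path("./docs/demo.pdf", dpi=200)
for page_no, page_image in enumerate(pages):
    page_image = page_image.convert("RGB")
    # ... feed page_image into image_processor / model exactly as above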
4. Table Structure Recognition
Partial code:
from transformers import DetrFeatureExtractor, TableTransformerForObjectDetection
import torch
from PIL import Image
# use DetrFeatureExtractor to preprocess the image
feature_extractor = DetrFeatureExtractor()
file_path = "./images/demo_1.png"
image = Image.open(file_path).convert("RGB")
# encode the image
encoding = feature_extractor(images=image, return_tensors="pt")
......
# forward pass (no gradients needed for inference)
with torch.no_grad():
    outputs = model(**encoding)
target_sizes = [image.size[::-1]]
results = feature_extractor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
print(results)
# label id 3 = 'table column header' (see the id2label mapping printed below)
header_box_list = [results['boxes'][i].tolist() for i in range(len(results['boxes'])) if results['labels'][i].item() == 3]
for idx, box in enumerate(header_box_list):
    print(idx)
    crop_image = image.crop(box)
    crop_image.save(f'header_{idx}.png')
Output:
{0: 'table', 1: 'table column', 2: 'table row', 3: 'table column header', 4: 'table projected row header', 5: 'table spanning cell'}
{'scores': tensor([0.9938, 0.9916, 0.9990, 0.9951, 0.9988, 0.9951, 0.9914, 0.9138, 0.9976,
0.9996]), 'labels': tensor([2, 2, 1, 2, 1, 1, 2, 3, 1, 0]), 'boxes': tensor([[ 17.4817, 111.3828, 246.3970, 153.2920],
[ 17.5167, 66.2340, 246.3717, 105.3657],
[ 17.5555, 34.3023, 66.3797, 183.1057],
[ 17.4001, 33.6988, 246.5164, 65.5520],
[146.3985, 34.1393, 225.7526, 182.9028],
[226.7389, 34.1763, 247.0905, 183.1046],
[ 17.4087, 151.3451, 246.1476, 183.2057],
[ 17.1707, 33.7522, 246.4425, 64.2807],
[ 67.4367, 33.9697, 146.2301, 183.5033],
[ 17.5343, 34.2323, 246.6799, 183.0271]])}
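A quick way to sanity-check this output is to draw the predicted boxes back onto the table image, one color per structure type. A minimal sketch with PIL; the color scheme is purely for illustration:
from PIL import ImageDraw

debug_image = image.copy()
draw = ImageDraw.Draw(debug_image)
for label, box in zip(results['labels'], results['boxes']):
    # label 1 = table column, 2 = table row; everything else in green
    color = {1: 'blue', 2: 'red'}.get(label.item(), 'green')
    draw.rectangle(box.tolist(), outline=color, width=2)
debug_image.save('structure_debug.png')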
5. Table Data Extraction
Partial code:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
......
def paddle_ocr(image_path):
    result = ocr.ocr(image_path, cls=True)
    ocr_result = []
    for idx in range(len(result)):
        res = result[idx]
        if res:
            for line in res:
                print(line)
                ocr_result.append(line[1][0])
    return "".join(ocr_result)
def table_detect(image_box, image_url):
    if not image_url:
        file_name = str(uuid4())
        image = Image.fromarray(image_box).convert('RGB')
    else:
        image_path = f"./images/{uuid4()}.png"
        file_name = image_path.split('/')[-1].split('.')[0]
        urlretrieve(image_url, image_path)
        image = Image.open(image_path).convert('RGB')
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = detect_model(**inputs)
    # convert outputs (bounding boxes and class logits) to COCO API
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]
    i = 0
    output_images = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(v, 2) for v in box.tolist()]
        print(
            f"Detected {detect_model.config.id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )
        region = image.crop(box)  # crop the detected table region
        output_image_path = f'./images/{file_name}_{i}.jpg'
        region.save(output_image_path)
        output_images.append(output_image_path)
        i += 1
    print(f"output_images:{output_images}")
    return output_images
def table_ocr(output_images):
    # the Gradio gallery passes a list of (image_path, caption) tuples
    print(f"Type of output_images: {type(output_images)}, Contents: {output_images}")
    if len(output_images) > 0:
        first_image = output_images[0][0]
        print(f"Type of first_image: {type(first_image)}, Contents: {first_image}")
        image = Image.open(first_image).convert("RGB")
        encoding = feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = structure_model(**encoding)
        target_sizes = [image.size[::-1]]
        results = feature_extractor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
        print(f"results: {results}\n")
        # collect column and row boxes
        columns = []
        rows = []
        for i in range(len(results['boxes'])):
            _id = results['labels'][i].item()
            if _id == 1:    # table column
                columns.append(results['boxes'][i].tolist())
            elif _id == 2:  # table row
                rows.append(results['boxes'][i].tolist())
        # sort columns left-to-right and rows top-to-bottom
        sorted_columns = sorted(columns, key=lambda x: x[0])
        sorted_rows = sorted(rows, key=lambda x: x[1])
        # OCR cell by cell: each cell is the intersection of a row box and a column box
        ocr_results = []
        for row in sorted_rows:
            row_result = []
            for col in sorted_columns:
                rect = [col[0], row[1], col[2], row[3]]
                crop_image = image.crop(rect)
                image_path = 'cell.png'
                crop_image.save(image_path)
                row_result.append(paddle_ocr(image_path=image_path))
            print(f"row_result: {row_result}\n")
            ocr_results.append(row_result)
        print(f"ocr_results: {ocr_results}\n")
        return ocr_results
if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image_box = gr.Image()
                image_urls = gr.TextArea(lines=1, placeholder="Enter image url", label="Images")
            with gr.Column():
                gallery = gr.Gallery(label="Tables", show_label=False, elem_id="gallery",
                                     columns=[3], rows=[1], object_fit="contain", height="auto")
        detect = gr.Button("Table Detection")
        submit = gr.Button("Table OCR")
        ocr_outputs = gr.DataFrame(label='Table', interactive=True, wrap=True)
        detect.click(fn=table_detect, inputs=[image_box, image_urls], outputs=gallery)
        submit.click(fn=table_ocr, inputs=[gallery], outputs=ocr_outputs)
    demo.launch(server_name="0.0.0.0", server_port=7676, share=True)
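The "......" near the top of this script elides the imports and model setup. A minimal sketch of what that part has to provide, reconstructed from the names used above (the model IDs and PaddleOCR options are assumptions):
from uuid import uuid4
from urllib.request import urlretrieve

import gradio as gr
import torch
from PIL import Image
from paddleocr import PaddleOCR
from transformers import AutoImageProcessor, DetrFeatureExtractor, TableTransformerForObjectDetection

# OCR engine for cell text (the language setting is an assumption)
ocr = PaddleOCR(use_angle_cls=True, lang='ch')
# detection model (section 3) and structure model (section 4)
image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
detect_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
feature_extractor = DetrFeatureExtractor()
structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")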
Output: (screenshot of the Gradio demo, omitted here)
6. Summary
In the recognition results, the first and last columns of a table sometimes fail to be detected. The cause likely lies in the table structure recognition step; this will be optimized in follow-up work.
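One mitigation worth trying (a common practice with this model family, though not verified in this project) is to pad the cropped table image with a white margin before running structure recognition, so rows and columns touching the image border are less likely to be missed:
from PIL import Image, ImageOps

def pad_table_image(image: Image.Image, margin: int = 25) -> Image.Image:
    # add a white border so edge rows/columns are not clipped
    return ImageOps.expand(image, border=margin, fill='white')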