1. PDF processing
If the PDF is text-based (editable), it can be processed directly with pdfplumber:
import pdfplumber
import pandas as pd

with pdfplumber.open("中新科技:2015年年度报告摘要.PDF") as pdf:
    page = pdf.pages[1]  # pages are zero-indexed, so this is the second page
    text = page.extract_text()
    print(text)
    tables = page.extract_tables()
    for t in tables:
        # each table is a nested list; converting it to a DataFrame makes it easier to inspect and analyze
        df = pd.DataFrame(t[1:], columns=t[0])
        print(df)
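Inside the same with block, extract_tables can also be pointed at the ruling lines of bordered tables via table_settings; this is an optional sketch using standard pdfplumber strategy names, not part of the original workflow:
    # optional: rely on the drawn ruling lines when detecting bordered tables
    settings = {"vertical_strategy": "lines", "horizontal_strategy": "lines"}
    tables = page.extract_tables(table_settings=settings)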
If the PDF is image-based (scanned), the pdf2image library can convert it to images first, and the rest of the pipeline then continues on the images:
import numpy as np
from pdf2image import convert_from_path

# pdf2image requires poppler to be installed; take the first page as an image array
img = np.array(convert_from_path(path, dpi=800, use_cropbox=True)[0])
2. Table position detection
2.1 Using PPStructure
The PPStructure module in the paddleocr library makes it easy to obtain table positions; reference code:
from paddleocr import PPStructure

table_engine = PPStructure(show_log=False)
structure = table_engine(source_img)
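The returned structure is a list of layout regions, each a dict with a type label and a pixel bounding box; here is a minimal self-contained sketch for keeping only the table regions (the image path is a placeholder):
import cv2
from paddleocr import PPStructure

table_engine = PPStructure(show_log=False)
source_img = cv2.imread("report_page.jpg")  # placeholder path
regions = table_engine(source_img)

# keep only the regions classified as tables; each 'bbox' is [x1, y1, x2, y2] in pixels
table_bboxes = [r["bbox"] for r in regions if r["type"] == "table"]
print(table_bboxes)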
2.2 Using tabledetector
import tabledetector as td
result = td.detect(pdf_path="pdf_path", type="bordered", rotation=False, method='detect')
2.3 Using OpenCV morphological operations
This approach is straightforward to debug. The steps are as follows (a compact sketch follows the list; the full implementation, getBboxDtls, appears in section 3):
- Binarize the image to remove watermarks
- Use getStructuringElement to extract the vertical and horizontal ruling lines
- Merge the two line images, then use findContours to get the table's outer border and its inner cells
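A minimal sketch of these three steps; the thresholds and kernel scales are illustrative and need tuning per document, and the image path is a placeholder:
import cv2

img = cv2.imread("table_page.jpg")  # placeholder path
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 1. binarize (inverted) so that light watermarks drop out and lines become white
binary = 255 - cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)[1]

rows, cols = binary.shape
# 2. erode + dilate with long thin kernels to keep only horizontal / vertical ruling lines
h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (cols // 30, 1))
v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, rows // 20))
h_lines = cv2.dilate(cv2.erode(binary, h_kernel), h_kernel)
v_lines = cv2.dilate(cv2.erode(binary, v_kernel), v_kernel)

# 3. merge the two line images and read off the table border and cells as contours
grid = cv2.add(h_lines, v_lines)
contours, _ = cv2.findContours(grid, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)  # OpenCV 4.x
cells = [cv2.boundingRect(c) for c in contours]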
3. Position confirmation
After all the cells have been detected, the following class recovers their relative row/column relationships:
from typing import Dict, List, Tuple

import numpy as np


class TableRecover:
    def __init__(self):
        pass

    def __call__(self, polygons: np.ndarray) -> Dict[int, Dict]:
        rows = self.get_rows(polygons)
        longest_col, each_col_widths, col_nums = self.get_benchmark_cols(rows, polygons)
        each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)
        table_res = self.get_merge_cells(
            polygons,
            rows,
            row_nums,
            col_nums,
            longest_col,
            each_col_widths,
            each_row_heights,
        )
        return table_res

    @staticmethod
    def get_rows(polygons: np.ndarray) -> Dict[int, List[int]]:
        """Group the boxes into rows: decide which boxes belong to the same row."""
        y_axis = polygons[:, 0, 1]
        if y_axis.size == 1:
            return {0: [0]}

        concat_y = np.array(list(zip(y_axis, y_axis[1:])))
        minus_res = concat_y[:, 1] - concat_y[:, 0]

        result = {}
        thresh = 5.0
        split_idxs = np.argwhere(minus_res > thresh).squeeze()
        if split_idxs.ndim == 0:
            split_idxs = split_idxs[None, ...]
        if max(split_idxs) != len(minus_res):
            split_idxs = np.append(split_idxs, len(minus_res))

        start_idx = 0
        for row_num, idx in enumerate(split_idxs):
            if row_num != 0:
                start_idx = split_idxs[row_num - 1] + 1
            result.setdefault(row_num, []).extend(range(start_idx, idx + 1))

        # TODO: compute the IoU of adjacent cells within each row and merge
        # cells whose IoU is greater than 0.2 into a single cell
        return result
    def get_benchmark_cols(
        self, rows: Dict[int, List], polygons: np.ndarray
    ) -> Tuple[np.ndarray, List[float], int]:
        longest_col = max(rows.values(), key=lambda x: len(x))
        longest_col_points = polygons[longest_col]
        longest_x = longest_col_points[:, 0, 0]
        theta = 10
        for row_value in rows.values():
            cur_row = polygons[row_value][:, 0, 0]

            range_res = {}
            for idx, cur_v in enumerate(cur_row):
                start_idx, end_idx = None, None
                for i, v in enumerate(longest_x):
                    if cur_v - theta <= v <= cur_v + theta:
                        break

                    if cur_v > v:
                        start_idx = i
                        continue

                    if cur_v < v:
                        end_idx = i
                        break

                range_res[idx] = [start_idx, end_idx]

            sorted_res = dict(
                sorted(range_res.items(), key=lambda x: x[0], reverse=True)
            )
            for k, v in sorted_res.items():
                if v[0] is None or v[1] is None:
                    continue

                longest_x = np.insert(longest_x, v[1], cur_row[k])
                longest_col_points = np.insert(
                    longest_col_points, v[1], polygons[row_value[k]], axis=0
                )

        # Use the smallest width among the rightmost cells as the width of the last column
        rightmost_idxs = [v[-1] for v in rows.values()]
        rightmost_boxes = polygons[rightmost_idxs]
        min_width = min([self.compute_L2(v[3, :], v[0, :]) for v in rightmost_boxes])

        each_col_widths = (longest_x[1:] - longest_x[:-1]).tolist()
        each_col_widths.append(min_width)

        col_nums = longest_x.shape[0]
        return longest_col_points, each_col_widths, col_nums
    def get_benchmark_rows(
        self, rows: Dict[int, List], polygons: np.ndarray
    ) -> Tuple[List[float], int]:
        leftmost_cell_idxs = [v[0] for v in rows.values()]
        benchmark_y = polygons[leftmost_cell_idxs][:, 0, 1]
        theta = 10

        # Iterate over the remaining boxes and slot them into intervals along the y-axis
        range_res = {}
        for cur_idx, cur_box in enumerate(polygons):
            if cur_idx in leftmost_cell_idxs:
                continue

            cur_y = cur_box[0, 1]

            start_idx, end_idx = None, None
            for i, v in enumerate(benchmark_y):
                if cur_y - theta <= v <= cur_y + theta:
                    break

                if cur_y > v:
                    start_idx = i
                    continue

                if cur_y < v:
                    end_idx = i
                    break

            range_res[cur_idx] = [start_idx, end_idx]

        sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True))
        for k, v in sorted_res.items():
            if v[0] is None or v[1] is None:
                continue

            benchmark_y = np.insert(benchmark_y, v[1], polygons[k][0, 1])

        each_row_heights = (benchmark_y[1:] - benchmark_y[:-1]).tolist()

        # Use the largest height among the cells of the last row as the height of the last row
        bottommost_idxs = list(rows.values())[-1]
        bottommost_boxes = polygons[bottommost_idxs]
        max_height = max([self.compute_L2(v[1, :], v[0, :]) for v in bottommost_boxes])
        each_row_heights.append(max_height)

        row_nums = benchmark_y.shape[0]
        return each_row_heights, row_nums

    @staticmethod
    def compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:
        return np.linalg.norm(a2 - a1)
    def get_merge_cells(
        self,
        polygons: np.ndarray,
        rows: Dict,
        row_nums: int,
        col_nums: int,
        longest_col: np.ndarray,
        each_col_widths: List[float],
        each_row_heights: List[float],
    ) -> Dict[int, Dict[int, List[int]]]:
        col_res_merge, row_res_merge = {}, {}
        merge_thresh = 20
        for cur_row, col_list in rows.items():
            one_col_result, one_row_result = {}, {}
            for one_col in col_list:
                box = polygons[one_col]
                box_width = self.compute_L2(box[3, :], box[0, :])

                # The starting column is not necessarily 0; combine the cells already placed
                # in this row with the box's x coordinate to pick the starting column
                loc_col_idx = np.argmin(np.abs(longest_col[:, 0, 0] - box[0, 0]))
                merge_col_cell = max(sum(one_col_result.values()), loc_col_idx)

                # Work out how many cells this box spans in the column direction
                for i in range(merge_col_cell, col_nums):
                    col_cum_sum = sum(each_col_widths[merge_col_cell : i + 1])
                    if i == merge_col_cell and col_cum_sum > box_width:
                        one_col_result[one_col] = 1
                        break
                    elif abs(col_cum_sum - box_width) <= merge_thresh:
                        one_col_result[one_col] = i + 1 - merge_col_cell
                        break
                    else:
                        one_col_result[one_col] = i + 1 - merge_col_cell + 1

                box_height = self.compute_L2(box[1, :], box[0, :])
                merge_row_cell = cur_row
                for j in range(merge_row_cell, row_nums):
                    row_cum_sum = sum(each_row_heights[merge_row_cell : j + 1])

                    # box_height may cover several rows, so try row counts one by one and
                    # keep the closest cumulative height. If the very first row_cum_sum is
                    # already larger than box_height, a row has most likely been missed.
                    if j == merge_row_cell and row_cum_sum > box_height:
                        one_row_result[one_col] = 1
                        break
                    elif abs(box_height - row_cum_sum) <= merge_thresh:
                        one_row_result[one_col] = j + 1 - merge_row_cell
                        break
                    else:
                        one_row_result[one_col] = j + 1 - merge_row_cell + 1

            col_res_merge[cur_row] = one_col_result
            row_res_merge[cur_row] = one_row_result

        res = {}
        for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())):
            res[i] = {k: [cc, r[k]] for k, cc in c.items()}
        return res
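To make the return format concrete, here is a small synthetic example with hypothetical cell coordinates (not taken from the report above): a two-row table whose first row is a single cell spanning both columns. The result maps each row index to {cell_index: [colspan, rowspan]}:
import numpy as np

# hypothetical axis-aligned cells in reading order: (x, y, w, h)
cells = [
    (0, 0, 200, 50),     # header cell spanning the full table width
    (0, 50, 100, 50),
    (100, 50, 100, 50),
]
# same point order as the boxes built in getResult below: top-left, bottom-left, bottom-right, top-right
polygons = np.array([[[x, y], [x, y + h], [x + w, y + h], [x + w, y]] for x, y, w, h in cells])

print(TableRecover()(polygons))
# {0: {0: [2, 1]}, 1: {1: [1, 1], 2: [1, 1]}}  -> cell 0 has colspan 2, the rest are plain 1x1 cells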
The calling code is as follows:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage import io

h_min = 10
h_max = 5000


def sortContours(cnts, method='left-to-right'):
    reverse = False
    i = 0
    if method == "right-to-left" or method == "bottom-to-top":
        reverse = True
    if method == "top-to-bottom" or method == "bottom-to-top":
        i = 1
    boundingBoxes = [cv2.boundingRect(c) for c in cnts]
    (cnts, boundingBoxes) = zip(*sorted(zip(cnts, boundingBoxes), key=lambda b: b[1][i], reverse=reverse))
    return (cnts, boundingBoxes)


def sorted_boxes(dt_boxes):
    # sort boxes top-to-bottom, then left-to-right within rows that share (almost) the same y
    num_boxes = dt_boxes.shape[0]
    dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(dt_boxes)
    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if (
                abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
                and _boxes[j + 1][0][0] < _boxes[j][0][0]
            ):
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes
def getBboxDtls(raw):
    # 1. Extract the table ruling lines; make sure `merge` reproduces the table borders in the image
    gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
    binary = 255 - cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)[1]
    rows, cols = binary.shape

    scale = 30
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (cols // scale, 1))
    eroded = cv2.erode(binary, kernel, iterations=1)
    dilated_col = cv2.dilate(eroded, kernel, iterations=1)

    scale = 20
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, rows // scale))
    eroded = cv2.erode(binary, kernel, iterations=1)
    dilated_row = cv2.dilate(eroded, kernel, iterations=1)

    merge = cv2.add(dilated_col, dilated_row)
    kernel = np.ones((3, 3), np.uint8)
    merge = cv2.erode(cv2.dilate(merge, kernel, iterations=3), kernel, iterations=3)

    # debug visualization of part of the merged line image
    plt.figure(figsize=(60, 30))
    io.imshow(merge[1500:2500])

    # 2. Get the table coordinates
    tableData = []
    contours = cv2.findContours(merge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = contours[0] if len(contours) == 2 else contours[1]
    contours, boundingBoxes = sortContours(contours, method='top-to-bottom')

    # outer borders of the tables
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if h > h_min:
            tableData.append((x, y, w, h))

    # cells inside the tables
    contours, hierarchy = cv2.findContours(merge, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    contours, boundingBoxes = sortContours(contours, method="top-to-bottom")
    boxes = []
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if (h > h_min) and (h < h_max):
            boxes.append([x, y, w, h])

    # 3. Group each cell under the table whose outer border contains it
    bboxDtls = {}
    for tableBox1 in tableData:
        key = tableBox1
        values = []
        for tableBox2 in boxes:
            x2, y2, w2, h2 = tableBox2
            if tableBox1[0] <= x2 <= tableBox1[0] + tableBox1[2] and tableBox1[1] <= y2 <= tableBox1[1] + tableBox1[3]:
                values.append(tableBox2)
        bboxDtls[key] = values

    for key, values in bboxDtls.items():
        x_tab, y_tab, w_tab, h_tab = key
        for box in values:
            x_box, y_box, w_box, h_box = box  # placeholder: unpack for optional drawing/debugging

    return bboxDtls
4. Table text recognition
4.1 wired_table_rec
You can try wired_table_rec for the recognition step:
import cv2
from IPython.display import HTML
from wired_table_rec import WiredTableRecognition

table_rec = WiredTableRecognition()
table_str = table_rec(cv2.imread(img_path))[0]
HTML(table_str)
4.2 rapidocr_onnxruntime or pytesseract
Alternatively, a lower-level OCR service can be used to recognize each cell individually. The complete code is as follows:
"""
First install the pdf2image and rapidocr_onnxruntime libraries.
The image-processing parameters and code below can be adjusted as needed:
1. The pad parameter crops away the image border.
2. Converting the PDF at 800 dpi sometimes fails, so a try/except falls back to a lower dpi.
3. OCR works poorly on small images, so the image is resized; the 3000 threshold can be tuned.
4. Large images are binarized to suppress watermark interference; the 180 threshold can also be tuned.
5. Only the first table with more than 50 cells is processed; modify the code to recognize every table in the image.
6. The return value is an HTML table, which can be converted to a DataFrame with pd.read_html.
"""
from pdf2image import convert_from_path
from rapidocr_onnxruntime import RapidOCR
from tqdm import tqdm

rocr = RapidOCR()
# DetPreProcess is rapidocr's text-detection pre-processing class; import it from the
# corresponding module of the installed rapidocr_onnxruntime version
rocr.text_det.preprocess_op = DetPreProcess(736, 'max')
def getResult(path, pad=20, resize_thresh=3000, binary_thresh=180):
    if 'pdf' in path:
        try:
            source_img = np.array(convert_from_path(path, dpi=800, use_cropbox=True)[0])[pad:-pad, pad:-pad]
        except Exception:
            source_img = np.array(convert_from_path(path, dpi=300, use_cropbox=True)[0])[pad:-pad, pad:-pad]
    else:
        source_img = cv2.imread(path)[pad:-pad, pad:-pad]

    if source_img.shape[1] < resize_thresh:
        source_img = cv2.resize(source_img, (resize_thresh, int(source_img.shape[0] / source_img.shape[1] * resize_thresh)))

    img = cv2.threshold(cv2.cvtColor(source_img, cv2.COLOR_BGR2GRAY), binary_thresh, 255, cv2.THRESH_BINARY)[1]
    bboxDtls = getBboxDtls(source_img)

    boxes = []
    table = None
    # stop after the first table with more than 50 cells has been found
    for k, v in bboxDtls.items():
        if len(v) > 50:
            table = k
            for r in tqdm(v[1:]):  # skip the first box, which is normally the table's outer border
                res = rocr(img[r[1]: r[1] + r[3], r[0]: r[0] + r[2]])[0]
                if res is not None:
                    # sort the OCR fragments inside the cell top-to-bottom, then left-to-right
                    res.sort(key=lambda x: (x[0][0][1] // (img.shape[0] // 20), x[0][0][0] // (img.shape[1] // 20)))
                    boxes.append([[r[0], r[1]], [r[0], r[1] + r[3]], [r[0] + r[2], r[1] + r[3]], [r[0] + r[2], r[1]], ''.join([t[1].replace('\n', '').replace(' ', '') for t in res])])
                else:
                    boxes.append([[r[0], r[1]], [r[0], r[1] + r[3]], [r[0] + r[2], r[1] + r[3]], [r[0] + r[2], r[1]], ''])
            break

    polygons = sorted_boxes(np.array(boxes, dtype=object))
    texts = [p[4] for p in polygons]

    tr = TableRecover()
    table_res = tr(np.array([[np.array(p[0]), np.array(p[1]), np.array(p[2]), np.array(p[3])] for p in polygons]))

    table_html = """<table border="1" cellspacing="0">"""
    for vs in table_res.values():
        table_html += "<tr>"
        for i, v in vs.items():
            table_html += f"""<td colspan="{v[0]}" rowspan="{v[1]}">{texts[i]}</td>"""
        table_html += "</tr>"
    table_html += """</table>"""
    return table_html
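As noted in point 6 above, the returned HTML can be loaded into a DataFrame; a short usage sketch (the input path is a placeholder for a local copy of the image referenced below):
from io import StringIO

import pandas as pd

html = getResult("99107281818076039603801539578309.jpg")  # placeholder local path
df = pd.read_html(StringIO(html))[0]  # the single table contained in the HTML string
print(df)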
The original image: https://www.95598.cn/omg-static/99107281818076039603801539578309.jpg
The final recognition result is as follows: