Skip to content

Commit

Permalink
Merge pull request #83 from RapidAI/optim_wired_logic_decode
Browse files Browse the repository at this point in the history
fix: optim logic points decode
  • Loading branch information
Joker1212 authored Nov 26, 2024
2 parents 574391d + 6e5f7c4 commit 639f6f7
Showing 1 changed file with 40 additions and 48 deletions.
88 changes: 40 additions & 48 deletions wired_table_rec/table_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,57 +66,49 @@ def get_benchmark_cols(
) -> Tuple[np.ndarray, List[float], int]:
longest_col = max(rows.values(), key=lambda x: len(x))
longest_col_points = polygons[longest_col]
longest_x = longest_col_points[:, 0, 0]

longest_x_start = list(longest_col_points[:, 0, 0])
longest_x_end = list(longest_col_points[:, 2, 0])
min_x = longest_x_start[0]
max_x = longest_x_end[-1]
theta = 15
for row_value in rows.values():
cur_row = polygons[row_value][:, 0, 0]

range_res = {}
for idx, cur_v in enumerate(cur_row):
start_idx, end_idx = None, None
for i, v in enumerate(longest_x):
if cur_v - theta <= v <= cur_v + theta:
break

if cur_v > v:
start_idx = i
continue
# Update the column boundaries based on the start x coordinate of the
# current column.
#
# Merges a single x coordinate `cur_v` into `col_x_list` (mutated in place)
# and returns the possibly-updated `(min_x_, max_x_)` pair.
#
# `theta` is not a parameter: it is a free variable resolved from the
# enclosing scope and is used as a +/- tolerance when comparing `cur_v`
# with an existing entry.
#
# NOTE(review): this listing has lost its indentation; the comments below
# describe the presumed nesting (every `if` at loop level, `return` after
# the loop) -- confirm against the repository file.
# NOTE(review): the early `continue` on `cur_v > v` implies the list is
# expected to be in ascending order -- presumably guaranteed by the caller;
# verify.
def update_longest_col(col_x_list, cur_v, min_x_, max_x_):
for i, v in enumerate(col_x_list):
# An existing entry lies within `theta` of cur_v: nothing to add.
if cur_v - theta <= v <= cur_v + theta:
break
# cur_v is larger than this entry: move on to the next one.
if cur_v > v:
continue
# cur_v is below the tracked minimum: prepend it and lower the minimum.
if cur_v < min_x_:
col_x_list.insert(0, cur_v)
min_x_ = cur_v
break
# NOTE(review): this branch is only reached when max_x_ < cur_v <= v.
# It appends the previous max_x_ (not cur_v) and has no `break`, so the
# list is extended while it is being enumerated and control continues
# into the check below -- confirm this is intended.
if cur_v > max_x_:
col_x_list.append(max_x_)
max_x_ = cur_v
# cur_v is smaller than the entry at index i: insert it before that entry.
if cur_v < v:
col_x_list.insert(i, cur_v)
break
# NOTE(review): if cur_v exceeds every entry by more than `theta`, the loop
# finishes via `continue` alone and cur_v is never added to the list.
return min_x_, max_x_

if cur_v < v:
end_idx = i
break
for row_value in rows.values():
cur_row_start = list(polygons[row_value][:, 0, 0])
cur_row_end = list(polygons[row_value][:, 2, 0])
for idx, (cur_v_start, cur_v_end) in enumerate(
zip(cur_row_start, cur_row_end)
):
min_x, max_x = update_longest_col(
longest_x_start, cur_v_start, min_x, max_x
)
min_x, max_x = update_longest_col(
longest_x_start, cur_v_end, min_x, max_x
)

range_res[idx] = [start_idx, end_idx]

sorted_res = dict(
sorted(range_res.items(), key=lambda x: x[0], reverse=True)
)
for k, v in sorted_res.items():
# bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55
# 最长列不包含第一列和最后一列的场景需要兼容
if all(v) or v[1] == 0:
longest_x = np.insert(longest_x, v[1], cur_row[k])
longest_col_points = np.insert(
longest_col_points, v[1], polygons[row_value[k]], axis=0
)
elif v[0] and v[0] + 1 == len(longest_x):
longest_x = np.append(longest_x, cur_row[k])
longest_col_points = np.append(
longest_col_points,
polygons[row_value[k]][np.newaxis, :, :],
axis=0,
)
# 求出最右侧所有cell的宽,其中最小的作为最后一列宽度
rightmost_idxs = [v[-1] for v in rows.values()]
rightmost_boxes = polygons[rightmost_idxs]
min_width = min([self.compute_L2(v[3, :], v[0, :]) for v in rightmost_boxes])

each_col_widths = (longest_x[1:] - longest_x[:-1]).tolist()
each_col_widths.append(min_width)

col_nums = longest_x.shape[0]
return longest_col_points, each_col_widths, col_nums
longest_x_start = np.array(longest_x_start)
each_col_widths = (longest_x_start[1:] - longest_x_start[:-1]).tolist()
each_col_widths.append(max_x - longest_x_start[-1])
col_nums = longest_x_start.shape[0]
return longest_x_start, each_col_widths, col_nums

def get_benchmark_rows(
self, rows: Dict[int, List], polygons: np.ndarray
Expand Down Expand Up @@ -160,7 +152,7 @@ def get_merge_cells(
box_width = self.compute_L2(box[3, :], box[0, :])

# 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置
loc_col_idx = np.argmin(np.abs(longest_col[:, 0, 0] - box[0, 0]))
loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0]))
col_start = max(sum(one_col_result.values()), loc_col_idx)

# 计算合并多少个列方向单元格
Expand Down

0 comments on commit 639f6f7

Please sign in to comment.