From 6e5f7c41ab834a92f618f0e649bfa98622329a19 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Tue, 26 Nov 2024 11:34:26 +0800 Subject: [PATCH] fix: optim logic points decode --- wired_table_rec/table_recover.py | 88 +++++++++++++++----------------- 1 file changed, 40 insertions(+), 48 deletions(-) diff --git a/wired_table_rec/table_recover.py b/wired_table_rec/table_recover.py index 8e89672..5720d07 100644 --- a/wired_table_rec/table_recover.py +++ b/wired_table_rec/table_recover.py @@ -66,57 +66,49 @@ def get_benchmark_cols( ) -> Tuple[np.ndarray, List[float], int]: longest_col = max(rows.values(), key=lambda x: len(x)) longest_col_points = polygons[longest_col] - longest_x = longest_col_points[:, 0, 0] - + longest_x_start = list(longest_col_points[:, 0, 0]) + longest_x_end = list(longest_col_points[:, 2, 0]) + min_x = longest_x_start[0] + max_x = longest_x_end[-1] theta = 15 - for row_value in rows.values(): - cur_row = polygons[row_value][:, 0, 0] - - range_res = {} - for idx, cur_v in enumerate(cur_row): - start_idx, end_idx = None, None - for i, v in enumerate(longest_x): - if cur_v - theta <= v <= cur_v + theta: - break - if cur_v > v: - start_idx = i - continue + # 根据当前col的起始x坐标,更新col的边界 + def update_longest_col(col_x_list, cur_v, min_x_, max_x_): + for i, v in enumerate(col_x_list): + if cur_v - theta <= v <= cur_v + theta: + break + if cur_v > v: + continue + if cur_v < min_x_: + col_x_list.insert(0, cur_v) + min_x_ = cur_v + break + if cur_v > max_x_: + col_x_list.append(max_x_) + max_x_ = cur_v + if cur_v < v: + col_x_list.insert(i, cur_v) + break + return min_x_, max_x_ - if cur_v < v: - end_idx = i - break + for row_value in rows.values(): + cur_row_start = list(polygons[row_value][:, 0, 0]) + cur_row_end = list(polygons[row_value][:, 2, 0]) + for idx, (cur_v_start, cur_v_end) in enumerate( + zip(cur_row_start, cur_row_end) + ): + min_x, max_x = update_longest_col( + longest_x_start, cur_v_start, min_x, max_x + ) + min_x, max_x = update_longest_col( + longest_x_start, cur_v_end, min_x, max_x + ) - range_res[idx] = [start_idx, end_idx] - - sorted_res = dict( - sorted(range_res.items(), key=lambda x: x[0], reverse=True) - ) - for k, v in sorted_res.items(): - # bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55 - # 最长列不包含第一列和最后一列的场景需要兼容 - if all(v) or v[1] == 0: - longest_x = np.insert(longest_x, v[1], cur_row[k]) - longest_col_points = np.insert( - longest_col_points, v[1], polygons[row_value[k]], axis=0 - ) - elif v[0] and v[0] + 1 == len(longest_x): - longest_x = np.append(longest_x, cur_row[k]) - longest_col_points = np.append( - longest_col_points, - polygons[row_value[k]][np.newaxis, :, :], - axis=0, - ) - # 求出最右侧所有cell的宽,其中最小的作为最后一列宽度 - rightmost_idxs = [v[-1] for v in rows.values()] - rightmost_boxes = polygons[rightmost_idxs] - min_width = min([self.compute_L2(v[3, :], v[0, :]) for v in rightmost_boxes]) - - each_col_widths = (longest_x[1:] - longest_x[:-1]).tolist() - each_col_widths.append(min_width) - - col_nums = longest_x.shape[0] - return longest_col_points, each_col_widths, col_nums + longest_x_start = np.array(longest_x_start) + each_col_widths = (longest_x_start[1:] - longest_x_start[:-1]).tolist() + each_col_widths.append(max_x - longest_x_start[-1]) + col_nums = longest_x_start.shape[0] + return longest_x_start, each_col_widths, col_nums def get_benchmark_rows( self, rows: Dict[int, List], polygons: np.ndarray @@ -160,7 +152,7 @@ def get_merge_cells( box_width = self.compute_L2(box[3, :], box[0, :]) # 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置 - loc_col_idx = np.argmin(np.abs(longest_col[:, 0, 0] - box[0, 0])) + loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0])) col_start = max(sum(one_col_result.values()), loc_col_idx) # 计算合并多少个列方向单元格