Skip to content

Commit

Permalink
refactor: Refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jun 18, 2024
1 parent d171687 commit 12ba81c
Show file tree
Hide file tree
Showing 15 changed files with 811 additions and 365 deletions.
40 changes: 9 additions & 31 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,11 @@
import cv2
import numpy as np

from rapid_layout import RapidLayout
from rapid_layout import RapidLayout, VisLayout
from rapid_orientation import RapidOrientation
from rapid_table import RapidTable, VisTable


def vis_layout(img: np.ndarray, layout_res: list) -> None:
tmp_img = copy.deepcopy(img)
for v in layout_res:
bbox = np.round(v["bbox"]).astype(np.int32)
label = v["label"]

start_point = (bbox[0], bbox[1])
end_point = (bbox[2], bbox[3])

cv2.rectangle(tmp_img, start_point, end_point, (0, 255, 0), 2)
cv2.putText(
tmp_img, label, start_point, cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 2
)

draw_img_save = Path("./inference_results/")
if not draw_img_save.exists():
draw_img_save.mkdir(parents=True, exist_ok=True)

image_save = str(draw_img_save / "layout_result.jpg")
cv2.imwrite(image_save, tmp_img)
print(f"The infer result has saved in {image_save}")


def vis_table(table_res):
style_res = """<style>td {border-left: 1px solid;border-bottom:1px solid;}
table, th {border-top:1px solid;font-size: 10px;
Expand All @@ -54,14 +31,15 @@ def vis_table(table_res):


def demo_layout():
layout_engine = RapidLayout()
layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")

img = cv2.imread("test_images/layout.png")
img_path = "tests/test_files/layout.png"
img = cv2.imread(img_path)

layout_res, _ = layout_engine(img)
boxes, scores, class_names, *elapse = layout_engine(img)

vis_layout(img, layout_res)
print(layout_res)
ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names)
cv2.imwrite("layout_res.png", ploted_img)


def demo_table():
Expand Down Expand Up @@ -101,6 +79,6 @@ def demo_orientation():


if __name__ == "__main__":
# demo_layout()
demo_table()
demo_layout()
# demo_table()
# demo_orientation()
39 changes: 22 additions & 17 deletions docs/README_Layout.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,40 @@
- 具体来说,就是分析给定的文档类别图像(论文截图等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。
- 目前支持三种类别的版面分析模型:中文、英文和表格版面分析模型,具体可参见下面表格:

| 模型类型 | 模型名称 | 模型大小 | 支持种类 |
| :------: | :---------------------: | :------: | :---------------------------------------------------------------------------------------------: |
| 表格 | `layout_table.onnx` | 7.06M | `table` |
| 英文 | `layout_publaynet.onnx` | 7.06M | `text title list table figure` |
| 中文 | `layout_cdla.onnx` | 7.07M | `text title figure figure_caption table table_caption` <br> `header footer reference equation` |
|`model_type`| 版面类型 | 模型名称 | 支持类别|
| :------ | :----- | :------ | :----- |
|`pp_layout_table`| 表格 | `layout_table.onnx` |`table` |
| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`text title list table figure` |
| `pp_layout_table`| 中文 | `layout_cdla.onnx` | `text title figure figure_caption table table_caption` <br> `header footer reference equation` |

- 模型下载地址为:[百度网盘](https://pan.baidu.com/s/1PI9fksW6F6kQfJhwUkewWg?pwd=p29g) | [Google Drive](https://drive.google.com/drive/folders/1DAPWSN2zGQ-ED_Pz7RaJGTjfkN2-Mvsf?usp=sharing)

#### 安装
由于模型较小,预先将中文版面分析模型(`layout_cdla.onnx`)打包进了whl包内,如果做中文版面分析,可直接安装使用

```bash
$ pip install rapid-layout
```

#### 使用方式
1. pip安装
- 由于模型较小,预先将中文版面分析模型(`layout_cdla.onnx`)打包进了whl包内,如果做中文版面分析,可直接安装使用
```bash
$ pip install rapid-layout
```
2. python脚本运行
1. python脚本运行
```python
import cv2
from rapid_layout import RapidLayout
from rapid_layout import RapidLayout,vis_layout

# RapidLayout类提供model_path参数,可以自行指定上述3个模型,默认是layout_cdla.onnx
# layout_engine = RapidLayout(model_path='layout_publaynet.onnx')
layout_engine = RapidLayout()
# model_type类型参见上表。指定不同model_type时,会自动下载相应模型到安装目录下的。
layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")

img = cv2.imread('test_images/layout.png')

layout_res, elapse = layout_engine(img)

ploted_img = vis_layout(img, layout_res)
cv2.imwrite("layout_res.png", ploted_img)
print(layout_res)
```

3. 终端运行
2. 终端运行
- 用法:
```bash
$ rapid_layout -h
Expand All @@ -58,7 +63,7 @@
$ rapid_layout -v -img test_images/layout.png
```

4. 结果
3. 结果
- 返回结果
```python
# bbox: [左上角x0,左上角y0, 右下角x1, 右下角y1]
Expand Down
Binary file added layout_res.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion rapid_layout/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .rapid_layout import RapidLayout
from .main import RapidLayout
from .utils import VisLayout
92 changes: 64 additions & 28 deletions rapid_layout/rapid_layout.py → rapid_layout/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,99 +14,135 @@
import argparse
import time
from pathlib import Path
from typing import Union
from typing import Optional, Union

import cv2
import numpy as np

from .utils import (
DownloadModel,
LoadImage,
OrtInferSession,
PicoDetPostProcess,
create_operators,
get_logger,
read_yaml,
transform,
vis_layout,
LoadImage,
)

root_dir = Path(__file__).resolve().parent
ROOT_DIR = Path(__file__).resolve().parent
logger = get_logger("rapid_layout")

ROOT_URL = "https://github.com/RapidAI/RapidStructure/releases/download/v0.0.0/"
KEY_TO_MODEL_URL = {
"pp_layout_cdla": f"{ROOT_URL}/layout_cdla.onnx",
"pp_layout_publaynet": f"{ROOT_URL}/layout_publaynet.onnx",
"pp_layout_table": f"{ROOT_URL}/layout_table.onnx",
}
DEFAULT_MODEL_PATH = str(ROOT_DIR / "models" / "layout_cdla.onnx")


class RapidLayout:
def __init__(self, model_path: str = None):
config_path = str(root_dir / "config.yaml")
def __init__(
self,
model_type: str = "pp_layout_cdla",
box_threshold: float = 0.5,
use_cuda: bool = False,
):
config_path = str(ROOT_DIR / "config.yaml")
config = read_yaml(config_path)
if model_path is None:
model_path = str(root_dir / "models" / "layout_cdla.onnx")
config["model_path"] = model_path
config["model_path"] = self.get_model_path(model_type)
config["use_cuda"] = use_cuda

self.session = OrtInferSession(config)
labels = self.session.get_metadata()["character"].splitlines()
labels = self.session.get_character_list()
logger.info("%s contains %s", model_type, labels)

self.preprocess_op = create_operators(config["pre_process"])

config["post_process"]["score_threshold"] = box_threshold
self.postprocess_op = PicoDetPostProcess(labels, **config["post_process"])
self.load_img = LoadImage()

def get_model_path(self, model_type: Optional[str] = None) -> str:
model_url = KEY_TO_MODEL_URL.get(model_type, None)
if model_url:
model_path = DownloadModel.download(model_url)
return model_path
logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH)
return DEFAULT_MODEL_PATH

def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]):
img = self.load_img(img_content)

ori_im = img.copy()
data = {"image": img}
data = transform(data, self.preprocess_op)
data = transform({"image": img}, self.preprocess_op)
img = data[0]

if img is None:
return None, 0
return None, None, None, 0

img = np.expand_dims(img, axis=0)
img = img.copy()

preds, elapse = 0, 1
starttime = time.time()

preds = self.session(img)

score_list, boxes_list = [], []
num_outs = int(len(preds) / 2)
for out_idx in range(num_outs):
score_list.append(preds[out_idx])
boxes_list.append(preds[out_idx + num_outs])
preds = dict(boxes=score_list, boxes_num=boxes_list)
post_preds = self.postprocess_op(ori_im, img, preds)

boxes, scores, class_names = self.postprocess_op(
ori_im, img, {"boxes": score_list, "boxes_num": boxes_list}
)
elapse = time.time() - starttime
return post_preds, elapse
return boxes, scores, class_names, elapse


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-v",
"--vis",
action="store_true",
help="Wheter to visualize the layout results.",
)
parser.add_argument(
"-img", "--img_path", type=str, required=True, help="Path to image for layout."
)
parser.add_argument(
"-m",
"--model_path",
"--model_type",
type=str,
default=str(root_dir / "models" / "layout_cdla.onnx"),
help="The model path used for inference.",
default=DEFAULT_MODEL_PATH,
choices=list(KEY_TO_MODEL_URL.keys()),
help="Support model type",
)
parser.add_argument(
"--box_threshold",
type=float,
default=0.5,
choices=list(KEY_TO_MODEL_URL.keys()),
help="Box threshold, the range is [0, 1]",
)
parser.add_argument(
"-v",
"--vis",
action="store_true",
help="Wheter to visualize the layout results.",
)
args = parser.parse_args()

layout_engine = RapidLayout(args.model_path)
layout_engine = RapidLayout(
model_type=args.model_type, box_threshold=args.box_threshold
)

img = cv2.imread(args.img_path)
layout_res, elapse = layout_engine(img)
print(layout_res)

if args.vis:
img_path = Path(args.img_path)
ploted_img = vis_layout(img, layout_res)
save_path = img_path.resolve().parent / f"vis_{img_path.name}"
vis_layout(img, layout_res, str(save_path))
cv2.imwrite(str(save_path), ploted_img)


if __name__ == "__main__":
Expand Down
18 changes: 18 additions & 0 deletions rapid_layout/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import yaml

from .download_model import DownloadModel
from .infer_engine import OrtInferSession
from .load_image import LoadImage
from .logger import get_logger
from .post_prepross import PicoDetPostProcess
from .pre_procss import create_operators, transform
from .vis_res import VisLayout


def read_yaml(yaml_path):
with open(yaml_path, "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
60 changes: 60 additions & 0 deletions rapid_layout/utils/download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import io
from pathlib import Path
from typing import Optional, Union

import requests
from tqdm import tqdm

from .logger import get_logger

logger = get_logger("DownloadModel")
CUR_DIR = Path(__file__).resolve()
PROJECT_DIR = CUR_DIR.parent.parent


class DownloadModel:
cur_dir = PROJECT_DIR

@classmethod
def download(cls, model_full_url: Union[str, Path]) -> str:
save_dir = cls.cur_dir / "models"
save_dir.mkdir(parents=True, exist_ok=True)

model_name = Path(model_full_url).name
save_file_path = save_dir / model_name
if save_file_path.exists():
logger.info("%s already exists", save_file_path)
return str(save_file_path)

try:
logger.info("Download %s to %s", model_full_url, save_dir)
file = cls.download_as_bytes_with_progress(model_full_url, model_name)
cls.save_file(save_file_path, file)
except Exception as exc:
raise DownloadModelError from exc
return str(save_file_path)

@staticmethod
def download_as_bytes_with_progress(url: str, name: Optional[str] = None) -> bytes:
resp = requests.get(url, stream=True, allow_redirects=True, timeout=180)
total = int(resp.headers.get("content-length", 0))
bio = io.BytesIO()
with tqdm(
desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
) as pbar:
for chunk in resp.iter_content(chunk_size=65536):
pbar.update(len(chunk))
bio.write(chunk)
return bio.getvalue()

@staticmethod
def save_file(save_path: Union[str, Path], file: bytes):
with open(save_path, "wb") as f:
f.write(file)


class DownloadModelError(Exception):
pass
Loading

0 comments on commit 12ba81c

Please sign in to comment.