Skip to content

Commit

Permalink
refactor(rapid_layout): Configure automatic download model
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jun 19, 2024
1 parent 12ba81c commit 34a878b
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 66 deletions.
6 changes: 2 additions & 4 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import copy
from pathlib import Path

import cv2
import numpy as np

from rapid_layout import RapidLayout, VisLayout
from rapid_orientation import RapidOrientation
Expand Down Expand Up @@ -37,9 +35,9 @@ def demo_layout():
img = cv2.imread(img_path)

boxes, scores, class_names, *elapse = layout_engine(img)

ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names)
cv2.imwrite("layout_res.png", ploted_img)
if ploted_img is not None:
cv2.imwrite("layout_res.png", ploted_img)


def demo_table():
Expand Down
Binary file removed layout_res.png
Binary file not shown.
41 changes: 24 additions & 17 deletions rapid_layout/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import argparse
import time
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Tuple, Union

import cv2
import numpy as np
Expand All @@ -24,6 +24,7 @@
LoadImage,
OrtInferSession,
PicoDetPostProcess,
VisLayout,
create_operators,
get_logger,
read_yaml,
Expand Down Expand Up @@ -64,23 +65,16 @@ def __init__(
self.postprocess_op = PicoDetPostProcess(labels, **config["post_process"])
self.load_img = LoadImage()

def get_model_path(self, model_type: Optional[str] = None) -> str:
    """Resolve the model file path for ``model_type``.

    ``model_type`` is looked up in ``KEY_TO_MODEL_URL``. When a non-empty
    URL is found it is passed to ``DownloadModel.download`` and that
    call's result is returned. Otherwise an info message is logged and
    ``DEFAULT_MODEL_PATH`` is returned.
    """
    url = KEY_TO_MODEL_URL.get(model_type, None)
    if not url:
        # Unknown key (or empty URL): fall back to the default model.
        logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH)
        return DEFAULT_MODEL_PATH
    return DownloadModel.download(url)

def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]):
def __call__(
self, img_content: Union[str, np.ndarray, bytes, Path]
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], float]:
img = self.load_img(img_content)

ori_im = img.copy()
data = transform({"image": img}, self.preprocess_op)
img = data[0]

if img is None:
return None, None, None, 0
return None, None, None, 0.0

img = np.expand_dims(img, axis=0)
img = img.copy()
Expand All @@ -101,6 +95,16 @@ def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]):
elapse = time.time() - starttime
return boxes, scores, class_names, elapse

@staticmethod
def get_model_path(model_type: str) -> str:
    """Resolve the model file path for ``model_type``.

    ``model_type`` is looked up in ``KEY_TO_MODEL_URL``. When a non-empty
    URL is found it is passed to ``DownloadModel.download`` and that
    call's result is returned. Otherwise an info message is logged and
    ``DEFAULT_MODEL_PATH`` is returned.
    """
    url = KEY_TO_MODEL_URL.get(model_type, None)
    if not url:
        # Unknown key (or empty URL): fall back to the default model.
        logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH)
        return DEFAULT_MODEL_PATH
    return DownloadModel.download(url)


def main():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -135,14 +139,17 @@ def main():
)

img = cv2.imread(args.img_path)
layout_res, elapse = layout_engine(img)
print(layout_res)
boxes, scores, class_names, *elapse = layout_engine(img)
print(boxes)
print(scores)
print(class_names)

if args.vis:
img_path = Path(args.img_path)
ploted_img = vis_layout(img, layout_res)
save_path = img_path.resolve().parent / f"vis_{img_path.name}"
cv2.imwrite(str(save_path), ploted_img)
ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names)
if ploted_img is not None:
save_path = img_path.resolve().parent / f"vis_{img_path.name}"
cv2.imwrite(str(save_path), ploted_img)


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions rapid_layout/utils/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def download(cls, model_full_url: Union[str, Path]) -> str:
return str(save_file_path)

@staticmethod
def download_as_bytes_with_progress(url: str, name: Optional[str] = None) -> bytes:
resp = requests.get(url, stream=True, allow_redirects=True, timeout=180)
def download_as_bytes_with_progress(
url: Union[str, Path], name: Optional[str] = None
) -> bytes:
resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180)
total = int(resp.headers.get("content-length", 0))
bio = io.BytesIO()
with tqdm(
Expand Down
51 changes: 10 additions & 41 deletions rapid_layout/utils/vis_res.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,22 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import copy
from typing import Any, Dict, List
from typing import Optional

import cv2
import numpy as np


def vis_layout(img: np.ndarray, layout_res: List[Dict[str, Any]]) -> np.ndarray:
    """Draw each layout entry's box and label onto a copy of ``img``.

    Args:
        img: Image to annotate; it is deep-copied, so the input is not
            modified.
        layout_res: Entries read through their ``"bbox"`` and ``"label"``
            keys. The first four rounded bbox values are used as the two
            rectangle corners (x1, y1) and (x2, y2).

    Returns:
        The annotated copy of ``img``.
    """
    text_font = cv2.FONT_HERSHEY_COMPLEX
    text_scale = 1
    text_color = (0, 0, 255)
    text_thickness = 1
    box_color = (0, 255, 0)
    box_thickness = 2

    canvas = copy.deepcopy(img)
    for entry in layout_res:
        coords = np.round(entry["bbox"]).astype(np.int32)
        caption = entry["label"]

        top_left = (coords[0], coords[1])
        bottom_right = (coords[2], coords[3])
        cv2.rectangle(canvas, top_left, bottom_right, box_color, box_thickness)

        # Shift the text origin down by the text height, measured from the
        # box's first corner.
        (_, text_height), _ = cv2.getTextSize(
            caption, text_font, text_scale, text_thickness
        )
        text_origin = top_left[0], top_left[1] + text_height
        cv2.putText(
            canvas,
            caption,
            text_origin,
            text_font,
            text_scale,
            text_color,
            text_thickness,
        )
    return canvas


class VisLayout:
@classmethod
def draw_detections(
cls,
image: np.ndarray,
boxes: np.ndarray,
scores: np.ndarray,
class_names: np.ndarray,
boxes: Optional[np.ndarray],
scores: Optional[np.ndarray],
class_names: Optional[np.ndarray],
mask_alpha=0.3,
):
) -> Optional[np.ndarray]:
"""_summary_
Args:
Expand All @@ -52,23 +27,23 @@ def draw_detections(
mask_alpha (float, optional): _description_. Defaults to 0.3.
Returns:
_type_: _description_
np.ndarray: _description_
"""
if boxes is None or scores is None or class_names is None:
return None

det_img = image.copy()

img_height, img_width = image.shape[:2]
font_size = min([img_height, img_width]) * 0.0006
text_thickness = int(min([img_height, img_width]) * 0.001)

det_img = cls.draw_masks(det_img, boxes, class_names, mask_alpha)
det_img = cls.draw_masks(det_img, boxes, mask_alpha)

# Draw bounding boxes and labels of detections
for label, box, score in zip(class_names, boxes, scores):
color = cls.get_color()

cls.draw_box(det_img, box, color)

caption = f"{label} {int(score * 100)}%"
cls.draw_text(det_img, caption, box, color, font_size, text_thickness)

Expand Down Expand Up @@ -120,18 +95,12 @@ def draw_masks(
cls,
image: np.ndarray,
boxes: np.ndarray,
classes: np.ndarray,
mask_alpha: float = 0.3,
) -> np.ndarray:
mask_img = image.copy()

# Draw bounding boxes and labels of detections
for box, class_name in zip(boxes, classes):
for box in boxes:
color = cls.get_color()

x1, y1, x2, y2 = box.astype(int)

# Draw fill rectangle in mask image
cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)

return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@
"img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
)
def test_multi_input(img_content):
layout_res, elapse = layout_engine(img_content)
assert len(layout_res) == 15
boxes, scores, class_names, *elapse = layout_engine(img_content)
assert len(boxes) == 15

0 comments on commit 34a878b

Please sign in to comment.