diff --git a/demo.py b/demo.py index e063891..277e567 100644 --- a/demo.py +++ b/demo.py @@ -1,11 +1,9 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -import copy from pathlib import Path import cv2 -import numpy as np from rapid_layout import RapidLayout, VisLayout from rapid_orientation import RapidOrientation @@ -37,9 +35,9 @@ def demo_layout(): img = cv2.imread(img_path) boxes, scores, class_names, *elapse = layout_engine(img) - ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names) - cv2.imwrite("layout_res.png", ploted_img) + if ploted_img is not None: + cv2.imwrite("layout_res.png", ploted_img) def demo_table(): diff --git a/layout_res.png b/layout_res.png deleted file mode 100644 index 3631043..0000000 Binary files a/layout_res.png and /dev/null differ diff --git a/rapid_layout/main.py b/rapid_layout/main.py index 6aa6103..d5f292d 100644 --- a/rapid_layout/main.py +++ b/rapid_layout/main.py @@ -14,7 +14,7 @@ import argparse import time from pathlib import Path -from typing import Optional, Union +from typing import Optional, Tuple, Union import cv2 import numpy as np @@ -24,6 +24,7 @@ LoadImage, OrtInferSession, PicoDetPostProcess, + VisLayout, create_operators, get_logger, read_yaml, @@ -64,23 +65,16 @@ def __init__( self.postprocess_op = PicoDetPostProcess(labels, **config["post_process"]) self.load_img = LoadImage() - def get_model_path(self, model_type: Optional[str] = None) -> str: - model_url = KEY_TO_MODEL_URL.get(model_type, None) - if model_url: - model_path = DownloadModel.download(model_url) - return model_path - logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH) - return DEFAULT_MODEL_PATH - - def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]): + def __call__( + self, img_content: Union[str, np.ndarray, bytes, Path] + ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], float]: img = self.load_img(img_content) ori_im = img.copy() data = transform({"image": img}, self.preprocess_op) img = data[0] - if img is None: - return None, None, None, 0 + return None, None, None, 0.0 img = np.expand_dims(img, axis=0) img = img.copy() @@ -101,6 +95,16 @@ def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]): elapse = time.time() - starttime return boxes, scores, class_names, elapse + @staticmethod + def get_model_path(model_type: str) -> str: + model_url = KEY_TO_MODEL_URL.get(model_type, None) + if model_url: + model_path = DownloadModel.download(model_url) + return model_path + + logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH) + return DEFAULT_MODEL_PATH + def main(): parser = argparse.ArgumentParser() @@ -135,14 +139,17 @@ def main(): ) img = cv2.imread(args.img_path) - layout_res, elapse = layout_engine(img) - print(layout_res) + boxes, scores, class_names, *elapse = layout_engine(img) + print(boxes) + print(scores) + print(class_names) if args.vis: img_path = Path(args.img_path) - ploted_img = vis_layout(img, layout_res) - save_path = img_path.resolve().parent / f"vis_{img_path.name}" - cv2.imwrite(str(save_path), ploted_img) + ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names) + if ploted_img is not None: + save_path = img_path.resolve().parent / f"vis_{img_path.name}" + cv2.imwrite(str(save_path), ploted_img) if __name__ == "__main__": diff --git a/rapid_layout/utils/download_model.py b/rapid_layout/utils/download_model.py index c105d4d..a0d9d93 100644 --- a/rapid_layout/utils/download_model.py +++ b/rapid_layout/utils/download_model.py @@ -38,8 +38,10 @@ def download(cls, model_full_url: Union[str, Path]) -> str: return str(save_file_path) @staticmethod - def download_as_bytes_with_progress(url: str, name: Optional[str] = None) -> bytes: - resp = requests.get(url, stream=True, allow_redirects=True, timeout=180) + def download_as_bytes_with_progress( + url: Union[str, Path], name: Optional[str] = None + ) -> bytes: + resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180) total = int(resp.headers.get("content-length", 0)) bio = io.BytesIO() with tqdm( diff --git a/rapid_layout/utils/vis_res.py b/rapid_layout/utils/vis_res.py index e6597f5..f73ee79 100644 --- a/rapid_layout/utils/vis_res.py +++ b/rapid_layout/utils/vis_res.py @@ -1,47 +1,22 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -import copy -from typing import Any, Dict, List +from typing import Optional import cv2 import numpy as np -def vis_layout(img: np.ndarray, layout_res: List[Dict[str, Any]]) -> np.ndarray: - font = cv2.FONT_HERSHEY_COMPLEX - font_scale = 1 - font_color = (0, 0, 255) - font_thickness = 1 - - tmp_img = copy.deepcopy(img) - for v in layout_res: - bbox = np.round(v["bbox"]).astype(np.int32) - label = v["label"] - - start_point = (bbox[0], bbox[1]) - end_point = (bbox[2], bbox[3]) - - cv2.rectangle(tmp_img, start_point, end_point, (0, 255, 0), 2) - - (w, h), _ = cv2.getTextSize(label, font, font_scale, font_thickness) - put_point = start_point[0], start_point[1] + h - cv2.putText( - tmp_img, label, put_point, font, font_scale, font_color, font_thickness - ) - return tmp_img - - class VisLayout: @classmethod def draw_detections( cls, image: np.ndarray, - boxes: np.ndarray, - scores: np.ndarray, - class_names: np.ndarray, + boxes: Optional[np.ndarray], + scores: Optional[np.ndarray], + class_names: Optional[np.ndarray], mask_alpha=0.3, - ): + ) -> Optional[np.ndarray]: """_summary_ Args: @@ -52,8 +27,10 @@ def draw_detections( mask_alpha (float, optional): _description_. Defaults to 0.3. Returns: - _type_: _description_ + np.ndarray: _description_ """ + if boxes is None or scores is None or class_names is None: + return None det_img = image.copy() @@ -61,14 +38,12 @@ def draw_detections( font_size = min([img_height, img_width]) * 0.0006 text_thickness = int(min([img_height, img_width]) * 0.001) - det_img = cls.draw_masks(det_img, boxes, class_names, mask_alpha) + det_img = cls.draw_masks(det_img, boxes, mask_alpha) - # Draw bounding boxes and labels of detections for label, box, score in zip(class_names, boxes, scores): color = cls.get_color() cls.draw_box(det_img, box, color) - caption = f"{label} {int(score * 100)}%" cls.draw_text(det_img, caption, box, color, font_size, text_thickness) @@ -120,18 +95,12 @@ def draw_masks( cls, image: np.ndarray, boxes: np.ndarray, - classes: np.ndarray, mask_alpha: float = 0.3, ) -> np.ndarray: mask_img = image.copy() - - # Draw bounding boxes and labels of detections - for box, class_name in zip(boxes, classes): + for box in boxes: color = cls.get_color() - x1, y1, x2, y2 = box.astype(int) - - # Draw fill rectangle in mask image cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1) return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0) diff --git a/tests/test_layout.py b/tests/test_layout.py index 231f50b..1d3fa5c 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -26,5 +26,5 @@ "img_content", [img_path, str(img_path), open(img_path, "rb").read(), img] ) def test_multi_input(img_content): - layout_res, elapse = layout_engine(img_content) - assert len(layout_res) == 15 + boxes, scores, class_names, *elapse = layout_engine(img_content) + assert len(boxes) == 15