-
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
811 additions
and
365 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
import yaml | ||
|
||
from .download_model import DownloadModel | ||
from .infer_engine import OrtInferSession | ||
from .load_image import LoadImage | ||
from .logger import get_logger | ||
from .post_prepross import PicoDetPostProcess | ||
from .pre_procss import create_operators, transform | ||
from .vis_res import VisLayout | ||
|
||
|
||
def read_yaml(yaml_path): | ||
with open(yaml_path, "rb") as f: | ||
data = yaml.load(f, Loader=yaml.Loader) | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
import io | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
import requests | ||
from tqdm import tqdm | ||
|
||
from .logger import get_logger | ||
|
||
logger = get_logger("DownloadModel") | ||
CUR_DIR = Path(__file__).resolve() | ||
PROJECT_DIR = CUR_DIR.parent.parent | ||
|
||
|
||
class DownloadModel: | ||
cur_dir = PROJECT_DIR | ||
|
||
@classmethod | ||
def download(cls, model_full_url: Union[str, Path]) -> str: | ||
save_dir = cls.cur_dir / "models" | ||
save_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
model_name = Path(model_full_url).name | ||
save_file_path = save_dir / model_name | ||
if save_file_path.exists(): | ||
logger.info("%s already exists", save_file_path) | ||
return str(save_file_path) | ||
|
||
try: | ||
logger.info("Download %s to %s", model_full_url, save_dir) | ||
file = cls.download_as_bytes_with_progress(model_full_url, model_name) | ||
cls.save_file(save_file_path, file) | ||
except Exception as exc: | ||
raise DownloadModelError from exc | ||
return str(save_file_path) | ||
|
||
@staticmethod | ||
def download_as_bytes_with_progress(url: str, name: Optional[str] = None) -> bytes: | ||
resp = requests.get(url, stream=True, allow_redirects=True, timeout=180) | ||
total = int(resp.headers.get("content-length", 0)) | ||
bio = io.BytesIO() | ||
with tqdm( | ||
desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 | ||
) as pbar: | ||
for chunk in resp.iter_content(chunk_size=65536): | ||
pbar.update(len(chunk)) | ||
bio.write(chunk) | ||
return bio.getvalue() | ||
|
||
@staticmethod | ||
def save_file(save_path: Union[str, Path], file: bytes): | ||
with open(save_path, "wb") as f: | ||
f.write(file) | ||
|
||
|
||
class DownloadModelError(Exception): | ||
pass |
Oops, something went wrong.