-
-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from RapidAI/table_optim
feature: add table cls model
- Loading branch information
Showing
11 changed files
with
322 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
name: Push table_cls to pypi | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
paths: | ||
- 'table_cls/**' | ||
# tags: | ||
# - v* | ||
|
||
jobs: | ||
UnitTesting: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Pull latest code | ||
uses: actions/checkout@v3 | ||
|
||
- name: Set up Python 3.10 | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: '3.10' | ||
architecture: 'x64' | ||
|
||
- name: Display Python version | ||
run: python -c "import sys; print(sys.version)" | ||
|
||
- name: Unit testings | ||
run: | | ||
pip install -r requirements.txt | ||
pip install pytest beautifulsoup4 | ||
wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip | ||
unzip table_cls_models.zip | ||
mv table_cls_models/*.onnx table_cls/models/ | ||
pytest tests/test_table_cls.py | ||
GenerateWHL_PushPyPi: | ||
needs: UnitTesting | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
|
||
- name: Set up Python 3.7 | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: '3.7' | ||
architecture: 'x64' | ||
|
||
- name: Run setup.py | ||
run: | | ||
pip install -r requirements.txt | ||
python -m pip install --upgrade pip | ||
pip install wheel get_pypi_latest_version | ||
wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip | ||
unzip table_cls_models.zip | ||
mv table_cls_models/*.onnx table_cls/models/ | ||
python setup_table_cls.py bdist_wheel "${{ github.event.head_commit.message }}" | ||
- name: Publish distribution 📦 to PyPI | ||
uses: pypa/[email protected] | ||
with: | ||
password: ${{ secrets.PYPI_API_TOKEN }} | ||
packages_dir: dist/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# -*- encoding: utf-8 -*- | ||
from table_cls import TableCls | ||
|
||
table_cls = TableCls() | ||
output_dir = "outputs" | ||
img_path = "tests/test_files/table_cls/lineless_table.png" | ||
cls_str, elapse = table_cls(img_path) | ||
print(cls_str) | ||
print(elapse) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
import sys | ||
from typing import List, Union | ||
from pathlib import Path | ||
from get_pypi_latest_version import GetPyPiLatestVersion | ||
|
||
import setuptools | ||
|
||
|
||
def read_txt(txt_path: Union[Path, str]) -> List[str]: | ||
with open(txt_path, "r", encoding="utf-8") as f: | ||
data = [v.rstrip("\n") for v in f] | ||
return data | ||
|
||
|
||
MODULE_NAME = "table_cls" | ||
|
||
obtainer = GetPyPiLatestVersion() | ||
try: | ||
latest_version = obtainer(MODULE_NAME) | ||
except Exception: | ||
latest_version = "0.0.0" | ||
|
||
VERSION_NUM = obtainer.version_add_one(latest_version) | ||
|
||
if len(sys.argv) > 2: | ||
match_str = " ".join(sys.argv[2:]) | ||
matched_versions = obtainer.extract_version(match_str) | ||
if matched_versions: | ||
VERSION_NUM = matched_versions | ||
sys.argv = sys.argv[:2] | ||
|
||
setuptools.setup( | ||
name=MODULE_NAME, | ||
version=VERSION_NUM, | ||
platforms="Any", | ||
description="A table classifier for further table rec", | ||
long_description="A table classifier that distinguishes between wired and wireless tables", | ||
long_description_content_type="text/markdown", | ||
author="SWHL", | ||
author_email="[email protected]", | ||
url="https://github.com/RapidAI/TableStructureRec", | ||
license="Apache-2.0", | ||
install_requires=read_txt("requirements.txt"), | ||
include_package_data=True, | ||
packages=setuptools.find_packages(include=[MODULE_NAME, f"{MODULE_NAME}.*"]), | ||
package_data={ | ||
MODULE_NAME: ["*.onnx"], | ||
}, | ||
keywords=["table-classifier", "wired", "wireless", "table-recognition"], | ||
classifiers=[ | ||
"Programming Language :: Python :: 3.6", | ||
"Programming Language :: Python :: 3.7", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
], | ||
python_requires=">=3.6,<3.12", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .main import TableCls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import time | ||
|
||
from pathlib import Path | ||
import numpy as np | ||
import onnxruntime | ||
from PIL import Image | ||
|
||
from .utils import InputType, LoadImage | ||
|
||
cur_dir = Path(__file__).resolve().parent | ||
table_cls_model_path = cur_dir / "models" / "table_cls.onnx" | ||
|
||
|
||
class TableCls: | ||
def __init__(self, device="cpu"): | ||
providers = ( | ||
["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] | ||
) | ||
self.table_cls = onnxruntime.InferenceSession( | ||
table_cls_model_path, providers=providers | ||
) | ||
self.inp_h = 224 | ||
self.inp_w = 224 | ||
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) | ||
self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32) | ||
self.cls = {0: "wired", 1: "wireless"} | ||
self.load_img = LoadImage() | ||
|
||
def _preprocess(self, image): | ||
img = Image.fromarray(np.uint8(image)) | ||
img = img.resize((self.inp_h, self.inp_w)) | ||
img = np.array(img, dtype=np.float32) / 255.0 | ||
img -= self.mean | ||
img /= self.std | ||
img = img.transpose(2, 0, 1) # HWC to CHW | ||
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image | ||
return img | ||
|
||
def __call__(self, content: InputType): | ||
ss = time.perf_counter() | ||
img = self.load_img(content) | ||
img = self._preprocess(img) | ||
output = self.table_cls.run(None, {"input": img}) | ||
predict = np.exp(output[0] - np.max(output[0], axis=1, keepdims=True)) | ||
predict /= np.sum(predict, axis=1, keepdims=True) | ||
predict_cla = np.argmax(predict, axis=1)[0] | ||
table_elapse = time.perf_counter() - ss | ||
return self.cls[predict_cla], table_elapse |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from io import BytesIO | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
from PIL import UnidentifiedImageError | ||
from PIL import Image | ||
import numpy as np | ||
import cv2 | ||
|
||
InputType = Union[str, np.ndarray, bytes, Path, Image.Image] | ||
|
||
|
||
class LoadImageError(Exception): | ||
pass | ||
|
||
|
||
class LoadImage: | ||
def __init__( | ||
self, | ||
): | ||
pass | ||
|
||
def __call__(self, img: InputType) -> np.ndarray: | ||
if not isinstance(img, InputType.__args__): | ||
raise LoadImageError( | ||
f"The img type {type(img)} does not in {InputType.__args__}" | ||
) | ||
|
||
origin_img_type = type(img) | ||
img = self.load_img(img) | ||
img = self.convert_img(img, origin_img_type) | ||
return img | ||
|
||
def load_img(self, img: InputType) -> np.ndarray: | ||
if isinstance(img, (str, Path)): | ||
self.verify_exist(img) | ||
try: | ||
img = np.array(Image.open(img)) | ||
except UnidentifiedImageError as e: | ||
raise LoadImageError(f"cannot identify image file {img}") from e | ||
return img | ||
|
||
if isinstance(img, bytes): | ||
img = np.array(Image.open(BytesIO(img))) | ||
return img | ||
|
||
if isinstance(img, np.ndarray): | ||
return img | ||
|
||
if isinstance(img, Image.Image): | ||
return np.array(img) | ||
|
||
raise LoadImageError(f"{type(img)} is not supported!") | ||
|
||
def convert_img(self, img: np.ndarray, origin_img_type): | ||
if img.ndim == 2: | ||
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | ||
|
||
if img.ndim == 3: | ||
channel = img.shape[2] | ||
if channel == 1: | ||
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | ||
|
||
if channel == 2: | ||
return self.cvt_two_to_three(img) | ||
|
||
if channel == 3: | ||
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): | ||
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) | ||
return img | ||
|
||
if channel == 4: | ||
return self.cvt_four_to_three(img) | ||
|
||
raise LoadImageError( | ||
f"The channel({channel}) of the img is not in [1, 2, 3, 4]" | ||
) | ||
|
||
raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") | ||
|
||
@staticmethod | ||
def cvt_two_to_three(img: np.ndarray) -> np.ndarray: | ||
"""gray + alpha → BGR""" | ||
img_gray = img[..., 0] | ||
img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) | ||
|
||
img_alpha = img[..., 1] | ||
not_a = cv2.bitwise_not(img_alpha) | ||
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) | ||
|
||
new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) | ||
new_img = cv2.add(new_img, not_a) | ||
return new_img | ||
|
||
@staticmethod | ||
def cvt_four_to_three(img: np.ndarray) -> np.ndarray: | ||
"""RGBA → BGR""" | ||
r, g, b, a = cv2.split(img) | ||
new_img = cv2.merge((b, g, r)) | ||
|
||
not_a = cv2.bitwise_not(a) | ||
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) | ||
|
||
new_img = cv2.bitwise_and(new_img, new_img, mask=a) | ||
new_img = cv2.add(new_img, not_a) | ||
return new_img | ||
|
||
@staticmethod | ||
def verify_exist(file_path: Union[str, Path]): | ||
if not Path(file_path).exists(): | ||
raise LoadImageError(f"{file_path} does not exist.") |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import sys | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
from table_cls import TableCls | ||
|
||
cur_dir = Path(__file__).resolve().parent | ||
root_dir = cur_dir.parent | ||
|
||
sys.path.append(str(root_dir)) | ||
test_file_dir = cur_dir / "test_files" / "table_cls" | ||
table_cls = TableCls() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"img_path, expected", | ||
[("wired_table.png", "wired"), ("lineless_table.png", "wireless")], | ||
) | ||
def test_input_normal(img_path, expected): | ||
img_path = test_file_dir / img_path | ||
res, elasp = table_cls(img_path) | ||
assert res == expected |