-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmain.py
43 lines (31 loc) · 1.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""Script for Fast API Endpoint."""
import base64
import io
import threading
import warnings

import numpy as np
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

from src.engine import DefaultEngine
from src.model import DefaultModel
# Suppress every Python warning for the lifetime of this process.
warnings.filterwarnings("ignore")

# FastAPI application that the route decorators in this module register on.
app = FastAPI()

# Paths to the config (YAML) and model (.pth) files for the text detector and
# the text recognizer. They are relative paths; presumably resolved against
# the working directory -- confirm in DefaultModel.
detector_cfg = "configs/craft_config.yaml"
detector_model = "models/text_detector/craft_mlt_25k.pth"
recognizer_cfg = "configs/star_config.yaml"
recognizer_model = "models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth"

# Built once at import time; the single engine instance is shared by every
# request handled by this module.
model = DefaultModel(detector_cfg, detector_model, recognizer_cfg, recognizer_model)
engine = DefaultEngine(model)
class Item(BaseModel):
    """Request body accepted by the ``/ocr/predict`` endpoint."""

    # Base64-encoded image data; ``predict`` decodes it with base64.b64decode.
    image: str
@app.get("/")
def read_root():
    """Return a static status message showing that the API is reachable."""
    status = {"message": "API is running..."}
    return status
# Serialises use of the shared engine. FastAPI runs plain ``def`` endpoints in
# a thread pool, and the prediction is read back from ``engine.result``, so
# without this lock concurrent requests could return each other's results.
_engine_lock = threading.Lock()


@app.post("/ocr/predict")
def predict(item: Item):
    """Run OCR on a base64-encoded image and return the engine's result.

    Args:
        item: Request body whose ``image`` field holds the base64-encoded
            bytes of an image file that Pillow can open.

    Returns:
        Whatever ``engine.result`` holds after ``engine.predict`` has run on
        the decoded image array.

    Raises:
        HTTPException: With status 400 when ``image`` is not valid base64 or
            the decoded bytes cannot be read as an image.
    """
    try:
        img_bytes = base64.b64decode(item.image.encode("utf-8"))
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400, detail="Field 'image' is not valid base64 data."
        ) from err

    try:
        # np.array() forces the pixel data to load, so it must run while the
        # image is still open; truncated files fail here rather than in open().
        with Image.open(io.BytesIO(img_bytes)) as image:
            pixels = np.array(image)
    except (OSError, ValueError) as err:  # UnidentifiedImageError is an OSError
        raise HTTPException(
            status_code=400, detail="Field 'image' could not be read as an image."
        ) from err

    # NOTE(review): the array keeps the image's native mode (e.g. RGBA, palette
    # or grayscale); confirm the engine accepts non-RGB input.
    with _engine_lock:
        # Predict and read the result under one lock so the pair is atomic.
        engine.predict(pixels)
        return engine.result