diff --git a/cougar.py b/cougar.py index a171056..bea7d7d 100644 --- a/cougar.py +++ b/cougar.py @@ -2,7 +2,7 @@ from starlette.responses import JSONResponse, HTMLResponse, RedirectResponse from fastai.vision import ( ImageDataBunch, - ConvLearner, + create_cnn, open_image, get_transforms, models, @@ -47,7 +47,7 @@ async def get_bytes(url): ds_tfms=get_transforms(), size=224, ) -cat_learner = ConvLearner(cat_data, models.resnet34) +cat_learner = create_cnn(cat_data, models.resnet34) cat_learner.model.load_state_dict( torch.load("usa-inaturalist-cats.pth", map_location="cpu") ) @@ -68,7 +68,7 @@ async def classify_url(request): def predict_image_from_bytes(bytes): img = open_image(BytesIO(bytes)) - losses = img.predict(cat_learner) + _, _, losses = cat_learner.predict(img) return JSONResponse({ "predictions": sorted( zip(cat_learner.data.classes, map(float, losses)),