diff --git a/src/tasso/handlers/external.py b/src/tasso/handlers/external.py index 5b8bfdd..60a0f17 100644 --- a/src/tasso/handlers/external.py +++ b/src/tasso/handlers/external.py @@ -1,5 +1,6 @@ """Handlers for the app's external root, ``/tasso/``.""" +import random from typing import Annotated from fastapi import APIRouter, Depends @@ -10,8 +11,11 @@ from ..config import config from ..models.classification import Classification +from ..models.classification_run import ClassificationRun from ..models.index import Index +from ..models.subject import Subject from ..storage.classification import ClassificationStore +from ..storage.classification_run import ClassificationRunStore from ..storage.subject import SubjectStore __all__ = ["external_router", "get_index"] @@ -108,3 +112,52 @@ async def put_classification( await db_session_dependency.aclose() return classification + + +@external_router.get( + "/unclassified_subject", + summary="Return a subject that needs classification.", +) +async def get_unclassified_subject( + logger: Annotated[BoundLogger, Depends(logger_dependency)], +) -> Subject | None: + await db_session_dependency.initialize( + config.database_url, config.database_password + ) + + async for db_session in db_session_dependency(): + store = SubjectStore(db_session) + run_store = ClassificationRunStore(db_session) + + runs = await run_store.get_active_runs() + if len(runs) == 0: + return None + elif len(runs) == 1: + run = runs[0] + else: + # choose a run at random. + run = random.choice(runs) # noqa: S311 + + # this needs to not be hardcoded. + user_id = "dfad48bd59404103ba9c668e47f4c700" + + return await store.get_unclassified( + user_id, run.run_id, run.max_classifications + ) + + +@external_router.get( + "/active_runs", + summary="Return runs that are active.", +) +async def get_active_runs( + logger: Annotated[BoundLogger, Depends(logger_dependency)], +) -> list[ClassificationRun] | None: + await db_session_dependency.initialize( + config.database_url, config.database_password + ) + + async for db_session in db_session_dependency(): + run_store = ClassificationRunStore(db_session) + + return await run_store.get_active_runs() diff --git a/src/tasso/storage/classification_run.py b/src/tasso/storage/classification_run.py index d97c140..0c1a709 100644 --- a/src/tasso/storage/classification_run.py +++ b/src/tasso/storage/classification_run.py @@ -3,7 +3,9 @@ from typing import Annotated from fastapi import Depends +from safir.datetime import current_datetime from safir.dependencies.db_session import db_session_dependency +from sqlalchemy import select from sqlalchemy.ext.asyncio import async_scoped_session from ..models.classification_run import ClassificationRun @@ -34,3 +36,28 @@ def __init__( storage=SQLClassificationRun, primary_key="run_id", ) + + async def get_active_runs(self) -> list[ClassificationRun]: + """Return classification runs which are active. + + Returns + ------- + list of ClassificationRun + """ + time_now = current_datetime() + + stmt = select(self.storage).where( + ( + (SQLClassificationRun.time_start <= time_now) + | SQLClassificationRun.time_start.is_(None) + ), + ( + (SQLClassificationRun.time_stop > time_now) + | SQLClassificationRun.time_stop.is_(None) + ), + ) + + print(stmt) + async with self._session.begin(): + result = await self._session.execute(stmt) + return [self.model.model_validate(res[0]) for res in result.all()] diff --git a/src/tasso/storage/subject.py b/src/tasso/storage/subject.py index caf0929..2d9444a 100644 --- a/src/tasso/storage/subject.py +++ b/src/tasso/storage/subject.py @@ -5,9 +5,11 @@ from fastapi import Depends from lsst.resources import ResourcePath from safir.dependencies.db_session import db_session_dependency +from sqlalchemy import select from sqlalchemy.ext.asyncio import async_scoped_session from ..models.subject import Subject +from ..schema import Classification as SQLClassification from ..schema import Subject as SQLSubject from .base import BaseStore @@ -47,3 +49,45 @@ def get_blob(self, subject: Subject) -> bytes: file_uri = ResourcePath(subject.uri) return file_uri.read() + + async def get_unclassified( + self, user_id: str, run_id: str, max_classifications: int + ) -> Subject | None: + """Return a subject for classification. + + Choose one at random from the specified run which has fewer than the + specified numbers of classifications and is not yet classified + by the user. + + Parameters + ---------- + user_id + The user performing the classification. + run_id + The classification run to search for subjects. + max_classifiations + The number of classifications each subject should receive. + """ + stmt = ( + select(SQLSubject) + .where( + SQLSubject.run_id == run_id, + SQLSubject.n_classifications < max_classifications, + ) + .outerjoin(SQLClassification) + .where( + (SQLClassification.user_id != user_id) + | SQLClassification.user_id.is_(None) + ) + ) + + print(stmt) + + async with self._session.begin(): + result = await self._session.execute(stmt) + value = result.one_or_none() + print(value) + if value is None: + return None + else: + return self.model.model_validate(value[0])