Skip to content

Commit

Permalink
build out unclassified_subject
Browse files Browse the repository at this point in the history
  • Loading branch information
ebellm committed Jan 7, 2025
1 parent 8f6b366 commit a7d2a4f
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/tasso/handlers/external.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Handlers for the app's external root, ``/tasso/``."""

import random
from typing import Annotated

from fastapi import APIRouter, Depends
Expand All @@ -10,8 +11,11 @@

from ..config import config
from ..models.classification import Classification
from ..models.classification_run import ClassificationRun
from ..models.index import Index
from ..models.subject import Subject
from ..storage.classification import ClassificationStore
from ..storage.classification_run import ClassificationRunStore
from ..storage.subject import SubjectStore

__all__ = ["external_router", "get_index"]
Expand Down Expand Up @@ -108,3 +112,52 @@ async def put_classification(
await db_session_dependency.aclose()

return classification


@external_router.get(
"/unclassified_subject",
summary="Return a subject that needs classification.",
)
async def get_unclassified_subject(
logger: Annotated[BoundLogger, Depends(logger_dependency)],
) -> Subject | None:
await db_session_dependency.initialize(
config.database_url, config.database_password
)

async for db_session in db_session_dependency():
store = SubjectStore(db_session)
run_store = ClassificationRunStore(db_session)

runs = await run_store.get_active_runs()
if len(runs) == 0:
return None
elif len(runs) == 1:
run = runs[0]
else:
# choose a run at random.
run = random.choice(runs) # noqa: S311

# this needs to not be hardcoded.
user_id = "dfad48bd59404103ba9c668e47f4c700"

return await store.get_unclassified(
user_id, run.run_id, run.max_classifications
)


@external_router.get(
"/active_runs",
summary="Return runs that are active.",
)
async def get_active_runs(
logger: Annotated[BoundLogger, Depends(logger_dependency)],
) -> list[ClassificationRun] | None:
await db_session_dependency.initialize(
config.database_url, config.database_password
)

async for db_session in db_session_dependency():
run_store = ClassificationRunStore(db_session)

return await run_store.get_active_runs()
27 changes: 27 additions & 0 deletions src/tasso/storage/classification_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import Annotated

from fastapi import Depends
from safir.datetime import current_datetime
from safir.dependencies.db_session import db_session_dependency
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_scoped_session

from ..models.classification_run import ClassificationRun
Expand Down Expand Up @@ -34,3 +36,28 @@ def __init__(
storage=SQLClassificationRun,
primary_key="run_id",
)

async def get_active_runs(self) -> list[ClassificationRun]:
"""Return classification runs which are active.
Returns
-------
list of ClassificationRun
"""
time_now = current_datetime()

stmt = select(self.storage).where(
(
(SQLClassificationRun.time_start <= time_now)
| SQLClassificationRun.time_start.is_(None)
),
(
(SQLClassificationRun.time_stop > time_now)
| SQLClassificationRun.time_stop.is_(None)
),
)

print(stmt)
async with self._session.begin():
result = await self._session.execute(stmt)
return [self.model.model_validate(res[0]) for res in result.all()]
44 changes: 44 additions & 0 deletions src/tasso/storage/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from fastapi import Depends
from lsst.resources import ResourcePath
from safir.dependencies.db_session import db_session_dependency
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_scoped_session

from ..models.subject import Subject
from ..schema import Classification as SQLClassification
from ..schema import Subject as SQLSubject
from .base import BaseStore

Expand Down Expand Up @@ -47,3 +49,45 @@ def get_blob(self, subject: Subject) -> bytes:
file_uri = ResourcePath(subject.uri)

return file_uri.read()

async def get_unclassified(
self, user_id: str, run_id: str, max_classifications: int
) -> Subject | None:
"""Return a subject for classification.
Choose one at random from the specified run which has fewer than the
specified numbers of classifications and is not yet classified
by the user.
Parameters
----------
user_id
The user performing the classification.
run_id
The classification run to search for subjects.
max_classifiations
The number of classifications each subject should receive.
"""
stmt = (
select(SQLSubject)
.where(
SQLSubject.run_id == run_id,
SQLSubject.n_classifications < max_classifications,
)
.outerjoin(SQLClassification)
.where(
(SQLClassification.user_id != user_id)
| SQLClassification.user_id.is_(None)
)
)

print(stmt)

async with self._session.begin():
result = await self._session.execute(stmt)
value = result.one_or_none()
print(value)
if value is None:
return None
else:
return self.model.model_validate(value[0])

0 comments on commit a7d2a4f

Please sign in to comment.