Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(API): Be resilient to item serialization failures #909

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 83 additions & 71 deletions skore/src/skore/ui/project_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from __future__ import annotations

import base64
import operator
from dataclasses import dataclass
import json
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from itertools import starmap
from traceback import format_exc

from fastapi import APIRouter, HTTPException, Request, status
from fastapi import APIRouter, HTTPException, Request, Response, status

from skore.item import (
CrossValidationAggregationItem,
Expand All @@ -23,53 +23,40 @@
PrimitiveItem,
SklearnBaseEstimatorItem,
)
from skore.project import Project
from skore.view.view import Layout, View

if TYPE_CHECKING:
import pandas # type: ignore

router = APIRouter(prefix="/project")


@dataclass
class SerializableItem:
"""Serialized item."""

name: str
media_type: str
value: Any
updated_at: str
created_at: str

class JSONError(Exception):
    """Exception for objects that can't be serialized.

    Raised by the `default` hook that this module passes to `json.dumps`, so
    that a failed encoding can be caught under one dedicated exception type.
    """

@dataclass
class SerializableProject:
"""Serialized project, to be sent to the skore-ui."""

items: dict[str, list[SerializableItem]]
views: dict[str, Layout]
def default(_):
    """Raise `JSONError` for objects that can't be serialized.

    Meant to be passed as the `default=` hook of `json.dumps`. The argument
    (the offending object) is ignored; the function never returns.
    """
    raise JSONError


def __pandas_dataframe_as_serializable(df: pandas.DataFrame):
def pandas_dataframe_as_serializable(df):
    """Convert a pandas DataFrame to a plain dict in "tight" orientation.

    Missing values are replaced with the string "NaN" before the conversion
    (presumably because a bare NaN is not valid JSON -- confirm with the
    consumer of this payload).
    """
    return df.fillna("NaN").to_dict(orient="tight")


def __item_as_serializable(name: str, item: Item) -> SerializableItem:
def serialize_item_to_json(name: str, item: Item) -> str:
"""Serialize Item to JSON."""
if isinstance(item, PrimitiveItem):
value = item.primitive
media_type = "text/markdown"
elif isinstance(item, NumpyArrayItem):
value = item.array.tolist()
media_type = "text/markdown"
elif isinstance(item, PandasDataFrameItem):
value = __pandas_dataframe_as_serializable(item.dataframe)
value = pandas_dataframe_as_serializable(item.dataframe)
media_type = "application/vnd.dataframe"
elif isinstance(item, PandasSeriesItem):
value = item.series.fillna("NaN").to_list()
media_type = "text/markdown"
elif isinstance(item, PolarsDataFrameItem):
value = __pandas_dataframe_as_serializable(item.dataframe.to_pandas())
value = pandas_dataframe_as_serializable(item.dataframe.to_pandas())
media_type = "application/vnd.dataframe"
elif isinstance(item, PolarsSeriesItem):
value = item.series.to_list()
Expand All @@ -90,65 +77,82 @@ def __item_as_serializable(name: str, item: Item) -> SerializableItem:
else:
raise ValueError(f"Item {item} is not a known item type.")

return SerializableItem(
name=name,
media_type=media_type,
value=value,
updated_at=item.updated_at,
created_at=item.created_at,
try:
value = json.dumps(value, default=default)
except JSONError:
value = "null"
error = "true"
traceback = f'"{format_exc()}"'
else:
error = "false"
traceback = "null"

return (
"{"
f'"name": "{name}",'
f'"media_type": "{media_type}",'
f'"value": {value},'
f'"error": {error},'
f'"traceback": {traceback},'
f'"created_at": "{item.created_at}",'
f'"updated_at": "{item.updated_at}"'
"}"
)


def __project_as_serializable(project: Project) -> SerializableProject:
items = {
key: [
__item_as_serializable(key, item) for item in project.get_item_versions(key)
]
for key in project.list_item_keys()
}
@router.put("/views", status_code=status.HTTP_201_CREATED)
async def put_view(request: Request, key: str, layout: Layout):
    """Set the layout of the view corresponding to `key`.

    If the view corresponding to `key` does not exist, it will be created.
    """
    project = request.app.state.project
    project.put_view(key, View(layout=layout))

views = {key: project.get_view(key).layout for key in project.list_view_keys()}

return SerializableProject(
items=items,
views=views,
)
@router.delete("/views", status_code=status.HTTP_202_ACCEPTED)
async def delete_view(request: Request, key: str):
    """Delete the view corresponding to `key`.

    Raises an `HTTPException` with status 404 when the project raises
    `KeyError` for `key`.
    """
    project = request.app.state.project

    try:
        project.delete_view(key)
    except KeyError:
        # `from None` drops the internal KeyError from the exception chain:
        # the 404 is the complete, intended outcome, not a secondary failure.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="View not found",
        ) from None


@router.get("/items")
async def get_items(request: Request):
    """Serialize Project to JSON.

    The response body has the shape
    `{"items": {<key>: [<item version>, ...]}, "views": {<key>: <layout>}}`
    and is sent with the `application/json` media type.
    """
    project = request.app.state.project

    # Each item version comes back from `serialize_item_to_json` as a JSON
    # string, so the items object is assembled by concatenation rather than a
    # second `json.dumps` pass. Keys still go through `json.dumps` so that
    # quotes, backslashes or control characters in a key cannot produce an
    # invalid document.
    items_by_key = []
    for key in project.list_item_keys():
        versions = ", ".join(
            serialize_item_to_json(key, item)
            for item in project.get_item_versions(key)
        )
        items_by_key.append(f"{json.dumps(key)}: [{versions}]")
    items_json = "{" + ", ".join(items_by_key) + "}"

    # Layouts are plain data, so the whole mapping can be encoded in one call.
    views_json = json.dumps(
        {key: project.get_view(key).layout for key in project.list_view_keys()}
    )

    body = f'{{"items": {items_json}, "views": {views_json}}}'.encode("utf-8")

    return Response(content=body, media_type="application/json")


@router.get("/activity")
Expand All @@ -162,13 +166,21 @@ async def get_activity(
datetime `after`, sorted from newest to oldest.
"""
project = request.app.state.project
return sorted(
versions = sorted(
(
__item_as_serializable(key, version)
(key, version)
for key in project.list_item_keys()
for version in project.get_item_versions(key)
if datetime.fromisoformat(version.updated_at) > after
),
key=operator.attrgetter("updated_at"),
key=lambda x: x[1].updated_at,
reverse=True,
)

# serialize activity
activity = str.encode(
f'[{", ".join(starmap(serialize_item_to_json, versions))}]',
"utf-8",
)

return Response(content=activity, media_type="application/json")
36 changes: 33 additions & 3 deletions skore/tests/integration/ui/test_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.testclient import TestClient
from PIL import Image
from sklearn.linear_model import Lasso
from skore.item.media_item import MediaItem
from skore.item import MediaItem, PrimitiveItem
from skore.ui.app import create_app
from skore.view.view import View

Expand All @@ -31,8 +31,8 @@ def test_get_items(client, in_memory_project):
in_memory_project.put("test", "version_2")

items = in_memory_project.get_item_versions("test")

response = client.get("/api/project/items")

assert response.status_code == 200
assert response.json() == {
"views": {},
Expand All @@ -42,6 +42,8 @@ def test_get_items(client, in_memory_project):
"name": "test",
"media_type": "text/markdown",
"value": item.primitive,
"error": False,
"traceback": None,
"created_at": item.created_at,
"updated_at": item.updated_at,
}
Expand All @@ -51,6 +53,34 @@ def test_get_items(client, in_memory_project):
}


def test_get_items_with_unserializable_object(monkeypatch, client, in_memory_project):
    """An item whose value can't be JSON-encoded is reported, not fatal.

    The endpoint must still answer 200 and flag the item with
    `"error": True`, a null value and the traceback text.
    """
    # Pin the traceback text so the expected payload is deterministic.
    monkeypatch.setattr("skore.ui.project_routes.format_exc", lambda: "<traceback>")

    # The `object` class itself serves as the unserializable value.
    in_memory_project.put_item("test", PrimitiveItem(object))

    item = in_memory_project.get_item("test")
    response = client.get("/api/project/items")

    assert item.primitive is object
    assert response.status_code == 200
    assert response.json() == {
        "views": {},
        "items": {
            "test": [
                {
                    "name": "test",
                    "media_type": "text/markdown",
                    "value": None,
                    "error": True,
                    "traceback": "<traceback>",
                    "updated_at": item.updated_at,
                    "created_at": item.created_at,
                },
            ]
        },
    }


def test_put_view_layout(client):
response = client.put("/api/project/views?key=hello", json=["test"])
assert response.status_code == 201
Expand Down Expand Up @@ -146,7 +176,7 @@ def test_serialize_media_item(client, in_memory_project):
assert project["items"]["media html"][0]["value"] == html


def test_activity_feed(monkeypatch, client, in_memory_project):
def test_activity(monkeypatch, client, in_memory_project):
class MockDatetime:
NOW = datetime.datetime.now(tz=datetime.timezone.utc)
TIMEDELTA = datetime.timedelta(days=1)
Expand Down
Loading