Skip to content

Commit

Permalink
fix: Persist b64 string instead of bytes to ensure items are JSONable
Browse files Browse the repository at this point in the history
  • Loading branch information
thomass-dev committed Jan 28, 2025
1 parent cf21491 commit f1d8b9b
Show file tree
Hide file tree
Showing 15 changed files with 200 additions and 75 deletions.
9 changes: 4 additions & 5 deletions skore/src/skore/persistence/item/altair_chart_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from typing import TYPE_CHECKING, Optional

from .item import Item, ItemTypeError
from .media_item import lazy_is_instance
from skore.persistence.item.item import Item, ItemTypeError
from skore.persistence.item.media_item import lazy_is_instance
from skore.utils import bytes_to_b64_str

if TYPE_CHECKING:
from altair.vegalite.v5.schema.core import TopLevelSpec as AltairChart
Expand Down Expand Up @@ -71,10 +72,8 @@ def chart(self) -> AltairChart:

def as_serializable_dict(self):
"""Convert item to a JSON-serializable dict to used by frontend."""
import base64

chart_bytes = self.chart_str.encode("utf-8")
chart_b64_str = base64.b64encode(chart_bytes).decode()
chart_b64_str = bytes_to_b64_str(chart_bytes)

return super().as_serializable_dict() | {
"media_type": "application/vnd.vega.v5+json;base64",
Expand Down
19 changes: 12 additions & 7 deletions skore/src/skore/persistence/item/cross_validation_reporter_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import plotly.graph_objects
import plotly.io

from skore.persistence.item.item import Item, ItemTypeError
from skore.sklearn.cross_validation import CrossValidationReporter

from .item import Item, ItemTypeError
from skore.utils import b64_str_to_bytes, bytes_to_b64_str

if TYPE_CHECKING:
import sklearn.base
Expand Down Expand Up @@ -147,7 +147,7 @@ class CrossValidationReporterItem(Item):

def __init__(
self,
reporter_bytes: bytes,
reporter_b64_str: str,
created_at: Optional[str] = None,
updated_at: Optional[str] = None,
note: Optional[str] = None,
Expand All @@ -157,7 +157,7 @@ def __init__(
Parameters
----------
reporter_bytes : bytes
reporter_b64_str : str
The pickled representation of the reporter, encoded as a base64 string.
created_at : str, optional
The creation timestamp in ISO format.
Expand All @@ -168,7 +168,7 @@ def __init__(
"""
super().__init__(created_at, updated_at, note)

self.reporter_bytes = reporter_bytes
self.reporter_b64_str = reporter_b64_str

@classmethod
def factory(
Expand All @@ -195,12 +195,17 @@ def factory(
with io.BytesIO() as stream:
joblib.dump(reporter, stream)

return cls(stream.getvalue(), **kwargs)
reporter_bytes = stream.getvalue()
reporter_b64_str = bytes_to_b64_str(reporter_bytes)

return cls(reporter_b64_str, **kwargs)

@property
def reporter(self) -> CrossValidationReporter:
"""The CrossValidationReporter from the persistence."""
with io.BytesIO(self.reporter_bytes) as stream:
reporter_bytes = b64_str_to_bytes(self.reporter_b64_str)

with io.BytesIO(reporter_bytes) as stream:
return joblib.load(stream)

def as_serializable_dict(self):
Expand Down
23 changes: 14 additions & 9 deletions skore/src/skore/persistence/item/matplotlib_figure_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from __future__ import annotations

from base64 import b64encode
from io import BytesIO
from typing import TYPE_CHECKING, Optional

import joblib

from .item import Item, ItemTypeError
from .media_item import lazy_is_instance
from skore.persistence.item.item import Item, ItemTypeError
from skore.persistence.item.media_item import lazy_is_instance
from skore.utils import b64_str_to_bytes, bytes_to_b64_str

if TYPE_CHECKING:
from matplotlib.figure import Figure
Expand All @@ -23,7 +23,7 @@ class MatplotlibFigureItem(Item):

def __init__(
self,
figure_bytes: str,
figure_b64_str: str,
created_at: Optional[str] = None,
updated_at: Optional[str] = None,
note: Optional[str] = None,
Expand All @@ -33,7 +33,7 @@ def __init__(
Parameters
----------
figure_bytes : bytes
figure_b64_str : str
The pickled representation of the Matplotlib figure, encoded as a base64 string.
created_at : str, optional
The creation timestamp in ISO format.
Expand All @@ -44,7 +44,7 @@ def __init__(
"""
super().__init__(created_at, updated_at, note)

self.figure_bytes = figure_bytes
self.figure_b64_str = figure_b64_str

@classmethod
def factory(cls, figure: Figure, /, **kwargs) -> MatplotlibFigureItem:
Expand All @@ -67,12 +67,17 @@ def factory(cls, figure: Figure, /, **kwargs) -> MatplotlibFigureItem:
with BytesIO() as stream:
joblib.dump(figure, stream)

return cls(stream.getvalue(), **kwargs)
figure_bytes = stream.getvalue()
figure_b64_str = bytes_to_b64_str(figure_bytes)

return cls(figure_b64_str, **kwargs)

@property
def figure(self) -> Figure:
"""The figure from the persistence."""
with BytesIO(self.figure_bytes) as stream:
figure_bytes = b64_str_to_bytes(self.figure_b64_str)

with BytesIO(figure_bytes) as stream:
return joblib.load(stream)

def as_serializable_dict(self) -> dict:
Expand All @@ -81,7 +86,7 @@ def as_serializable_dict(self) -> dict:
self.figure.savefig(stream, format="svg", bbox_inches="tight")

figure_bytes = stream.getvalue()
figure_b64_str = b64encode(figure_bytes).decode()
figure_b64_str = bytes_to_b64_str(figure_bytes)

return super().as_serializable_dict() | {
"media_type": "image/svg+xml;base64",
Expand Down
18 changes: 12 additions & 6 deletions skore/src/skore/persistence/item/pickle_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import joblib

from .item import Item
from skore.persistence.item.item import Item
from skore.utils import b64_str_to_bytes, bytes_to_b64_str


class PickleItem(Item):
Expand All @@ -25,7 +26,7 @@ class PickleItem(Item):

def __init__(
self,
pickle_bytes: bytes,
pickle_b64_str: str,
created_at: Optional[str] = None,
updated_at: Optional[str] = None,
note: Optional[str] = None,
Expand All @@ -35,7 +36,7 @@ def __init__(
Parameters
----------
pickle_bytes : bytes
pickle_b64_str : str
The pickled representation of the object, encoded as a base64 string.
created_at : str, optional
The creation timestamp in ISO format.
Expand All @@ -46,7 +47,7 @@ def __init__(
"""
super().__init__(created_at, updated_at, note)

self.pickle_bytes = pickle_bytes
self.pickle_b64_str = pickle_b64_str

@classmethod
def factory(cls, object: Any, /, **kwargs) -> PickleItem:
Expand All @@ -66,12 +67,17 @@ def factory(cls, object: Any, /, **kwargs) -> PickleItem:
with BytesIO() as stream:
joblib.dump(object, stream)

return cls(stream.getvalue(), **kwargs)
pickle_bytes = stream.getvalue()
pickle_b64_str = bytes_to_b64_str(pickle_bytes)

return cls(pickle_b64_str, **kwargs)

@property
def object(self) -> Any:
"""The object from the persistence."""
with BytesIO(self.pickle_bytes) as stream:
pickle_bytes = b64_str_to_bytes(self.pickle_b64_str)

with BytesIO(pickle_bytes) as stream:
return joblib.load(stream)

def as_serializable_dict(self):
Expand Down
28 changes: 16 additions & 12 deletions skore/src/skore/persistence/item/pillow_image_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from __future__ import annotations

from io import BytesIO
from typing import TYPE_CHECKING, Optional

from .item import Item, ItemTypeError
from .media_item import lazy_is_instance
from skore.persistence.item.item import Item, ItemTypeError
from skore.persistence.item.media_item import lazy_is_instance
from skore.utils import b64_str_to_bytes, bytes_to_b64_str

if TYPE_CHECKING:
import PIL.Image
Expand All @@ -19,7 +21,7 @@ class PillowImageItem(Item):

def __init__(
self,
image_bytes: bytes,
image_b64_str: str,
image_mode: str,
image_size: tuple[int],
created_at: Optional[str] = None,
Expand All @@ -31,7 +33,7 @@ def __init__(
Parameters
----------
image_bytes : bytes
image_b64_str : str
The raw bytes of the Pillow image, encoded as a base64 string.
image_mode : str
The image mode.
Expand All @@ -46,7 +48,7 @@ def __init__(
"""
super().__init__(created_at, updated_at, note)

self.image_bytes = image_bytes
self.image_b64_str = image_b64_str
self.image_mode = image_mode
self.image_size = image_size

Expand All @@ -68,8 +70,11 @@ def factory(cls, image: PIL.Image.Image, /, **kwargs) -> PillowImageItem:
if not lazy_is_instance(image, "PIL.Image.Image"):
raise ItemTypeError(f"Type '{image.__class__}' is not supported.")

image_bytes = image.tobytes()
image_b64_str = bytes_to_b64_str(image_bytes)

return cls(
image_bytes=image.tobytes(),
image_b64_str=image_b64_str,
image_mode=image.mode,
image_size=image.size,
**kwargs,
Expand All @@ -80,22 +85,21 @@ def image(self) -> PIL.Image.Image:
"""The image from the persistence."""
import PIL.Image

image_bytes = b64_str_to_bytes(self.image_b64_str)

return PIL.Image.frombytes(
mode=self.image_mode,
size=self.image_size,
data=self.image_bytes,
data=image_bytes,
)

def as_serializable_dict(self):
"""Convert item to a JSON-serializable dict to used by frontend."""
import base64
import io

with io.BytesIO() as stream:
with BytesIO() as stream:
self.image.save(stream, format="png")

png_bytes = stream.getvalue()
png_b64_str = base64.b64encode(png_bytes).decode()
png_b64_str = bytes_to_b64_str(png_bytes)

return super().as_serializable_dict() | {
"media_type": "image/png;base64",
Expand Down
11 changes: 5 additions & 6 deletions skore/src/skore/persistence/item/plotly_figure_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from typing import TYPE_CHECKING, Optional

from .item import Item, ItemTypeError
from .media_item import lazy_is_instance
from skore.persistence.item.item import Item, ItemTypeError
from skore.persistence.item.media_item import lazy_is_instance
from skore.utils import bytes_to_b64_str

if TYPE_CHECKING:
import plotly.basedatatypes
Expand All @@ -29,7 +30,7 @@ def __init__(
Parameters
----------
figure_str : bytes
figure_str : str
The JSON str of the Plotly figure.
created_at : str, optional
The creation timestamp in ISO format.
Expand Down Expand Up @@ -78,10 +79,8 @@ def figure(self) -> plotly.basedatatypes.BaseFigure:

def as_serializable_dict(self):
"""Convert item to a JSON-serializable dict to used by frontend."""
import base64

figure_bytes = self.figure_str.encode("utf-8")
figure_b64_str = base64.b64encode(figure_bytes).decode()
figure_b64_str = bytes_to_b64_str(figure_bytes)

return super().as_serializable_dict() | {
"media_type": "application/vnd.plotly.v1+json;base64",
Expand Down
Loading

0 comments on commit f1d8b9b

Please sign in to comment.