diff --git a/skore/src/skore/persistence/item/altair_chart_item.py b/skore/src/skore/persistence/item/altair_chart_item.py index 0b7fd570e..b36d4afd3 100644 --- a/skore/src/skore/persistence/item/altair_chart_item.py +++ b/skore/src/skore/persistence/item/altair_chart_item.py @@ -7,8 +7,9 @@ from typing import TYPE_CHECKING, Optional -from .item import Item, ItemTypeError -from .media_item import lazy_is_instance +from skore.persistence.item.item import Item, ItemTypeError +from skore.persistence.item.media_item import lazy_is_instance +from skore.utils import bytes_to_b64_str if TYPE_CHECKING: from altair.vegalite.v5.schema.core import TopLevelSpec as AltairChart @@ -71,10 +72,8 @@ def chart(self) -> AltairChart: def as_serializable_dict(self): """Convert item to a JSON-serializable dict to used by frontend.""" - import base64 - chart_bytes = self.chart_str.encode("utf-8") - chart_b64_str = base64.b64encode(chart_bytes).decode() + chart_b64_str = bytes_to_b64_str(chart_bytes) return super().as_serializable_dict() | { "media_type": "application/vnd.vega.v5+json;base64", diff --git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py index 6ecbffb7b..c0d78cb21 100644 --- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py +++ b/skore/src/skore/persistence/item/cross_validation_reporter_item.py @@ -20,9 +20,9 @@ import plotly.graph_objects import plotly.io +from skore.persistence.item.item import Item, ItemTypeError from skore.sklearn.cross_validation import CrossValidationReporter - -from .item import Item, ItemTypeError +from skore.utils import b64_str_to_bytes, bytes_to_b64_str if TYPE_CHECKING: import sklearn.base @@ -147,7 +147,7 @@ class CrossValidationReporterItem(Item): def __init__( self, - reporter_bytes: bytes, + reporter_b64_str: str, created_at: Optional[str] = None, updated_at: Optional[str] = None, note: Optional[str] = None, @@ -157,7 +157,7 @@ def __init__( Parameters ---------- - reporter_bytes : bytes + reporter_b64_str : str The raw bytes of the reporter pickled representation. created_at : str, optional The creation timestamp in ISO format. @@ -168,7 +168,7 @@ def __init__( """ super().__init__(created_at, updated_at, note) - self.reporter_bytes = reporter_bytes + self.reporter_b64_str = reporter_b64_str @classmethod def factory( @@ -195,12 +195,17 @@ def factory( with io.BytesIO() as stream: joblib.dump(reporter, stream) - return cls(stream.getvalue(), **kwargs) + reporter_bytes = stream.getvalue() + reporter_b64_str = bytes_to_b64_str(reporter_bytes) + + return cls(reporter_b64_str, **kwargs) @property def reporter(self) -> CrossValidationReporter: """The CrossValidationReporter from the persistence.""" - with io.BytesIO(self.reporter_bytes) as stream: + reporter_bytes = b64_str_to_bytes(self.reporter_b64_str) + + with io.BytesIO(reporter_bytes) as stream: return joblib.load(stream) def as_serializable_dict(self): diff --git a/skore/src/skore/persistence/item/matplotlib_figure_item.py b/skore/src/skore/persistence/item/matplotlib_figure_item.py index 2f0853dbf..2cae75fac 100644 --- a/skore/src/skore/persistence/item/matplotlib_figure_item.py +++ b/skore/src/skore/persistence/item/matplotlib_figure_item.py @@ -5,14 +5,14 @@ from __future__ import annotations -from base64 import b64encode from io import BytesIO from typing import TYPE_CHECKING, Optional import joblib -from .item import Item, ItemTypeError -from .media_item import lazy_is_instance +from skore.persistence.item.item import Item, ItemTypeError +from skore.persistence.item.media_item import lazy_is_instance +from skore.utils import b64_str_to_bytes, bytes_to_b64_str if TYPE_CHECKING: from matplotlib.figure import Figure @@ -23,7 +23,7 @@ class MatplotlibFigureItem(Item): def __init__( self, - figure_bytes: str, + figure_b64_str: str, created_at: Optional[str] = None, updated_at: Optional[str] = None, note: Optional[str] = None, @@ -33,7 +33,7 @@ def __init__( Parameters ---------- - figure_bytes : bytes + figure_b64_str : str The raw bytes of the Matplotlib figure pickled representation. created_at : str, optional The creation timestamp in ISO format. @@ -44,7 +44,7 @@ def __init__( """ super().__init__(created_at, updated_at, note) - self.figure_bytes = figure_bytes + self.figure_b64_str = figure_b64_str @classmethod def factory(cls, figure: Figure, /, **kwargs) -> MatplotlibFigureItem: @@ -67,12 +67,17 @@ def factory(cls, figure: Figure, /, **kwargs) -> MatplotlibFigureItem: with BytesIO() as stream: joblib.dump(figure, stream) - return cls(stream.getvalue(), **kwargs) + figure_bytes = stream.getvalue() + figure_b64_str = bytes_to_b64_str(figure_bytes) + + return cls(figure_b64_str, **kwargs) @property def figure(self) -> Figure: """The figure from the persistence.""" - with BytesIO(self.figure_bytes) as stream: + figure_bytes = b64_str_to_bytes(self.figure_b64_str) + + with BytesIO(figure_bytes) as stream: return joblib.load(stream) def as_serializable_dict(self) -> dict: @@ -81,7 +86,7 @@ def as_serializable_dict(self) -> dict: self.figure.savefig(stream, format="svg", bbox_inches="tight") figure_bytes = stream.getvalue() - figure_b64_str = b64encode(figure_bytes).decode() + figure_b64_str = bytes_to_b64_str(figure_bytes) return super().as_serializable_dict() | { "media_type": "image/svg+xml;base64", diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py index 21a69cfad..0881da925 100644 --- a/skore/src/skore/persistence/item/pickle_item.py +++ b/skore/src/skore/persistence/item/pickle_item.py @@ -12,7 +12,8 @@ import joblib -from .item import Item +from skore.persistence.item.item import Item +from skore.utils import b64_str_to_bytes, bytes_to_b64_str class PickleItem(Item): @@ -25,7 +26,7 @@ class PickleItem(Item): def __init__( self, - pickle_bytes: bytes, + pickle_b64_str: str, created_at: Optional[str] = None, updated_at: Optional[str] = None, note: Optional[str] = None, @@ -35,7 +36,7 @@ def __init__( Parameters ---------- - pickle_bytes : bytes + pickle_b64_str : str The raw bytes of the object pickled representation. created_at : str, optional The creation timestamp in ISO format. @@ -46,7 +47,7 @@ def __init__( """ super().__init__(created_at, updated_at, note) - self.pickle_bytes = pickle_bytes + self.pickle_b64_str = pickle_b64_str @classmethod def factory(cls, object: Any, /, **kwargs) -> PickleItem: @@ -66,12 +67,17 @@ def factory(cls, object: Any, /, **kwargs) -> PickleItem: with BytesIO() as stream: joblib.dump(object, stream) - return cls(stream.getvalue(), **kwargs) + pickle_bytes = stream.getvalue() + pickle_b64_str = bytes_to_b64_str(pickle_bytes) + + return cls(pickle_b64_str, **kwargs) @property def object(self) -> Any: """The object from the persistence.""" - with BytesIO(self.pickle_bytes) as stream: + pickle_bytes = b64_str_to_bytes(self.pickle_b64_str) + + with BytesIO(pickle_bytes) as stream: return joblib.load(stream) def as_serializable_dict(self): diff --git a/skore/src/skore/persistence/item/pillow_image_item.py b/skore/src/skore/persistence/item/pillow_image_item.py index 8d3e5266f..e0e5199d9 100644 --- a/skore/src/skore/persistence/item/pillow_image_item.py +++ b/skore/src/skore/persistence/item/pillow_image_item.py @@ -5,10 +5,12 @@ from __future__ import annotations +from io import BytesIO from typing import TYPE_CHECKING, Optional -from .item import Item, ItemTypeError -from .media_item import lazy_is_instance +from skore.persistence.item.item import Item, ItemTypeError +from skore.persistence.item.media_item import lazy_is_instance +from skore.utils import b64_str_to_bytes, bytes_to_b64_str if TYPE_CHECKING: import PIL.Image @@ -19,7 +21,7 @@ class PillowImageItem(Item): def __init__( self, - image_bytes: bytes, + image_b64_str: str, image_mode: str, image_size: tuple[int], created_at: Optional[str] = None, @@ -31,7 +33,7 @@ def __init__( Parameters ---------- - image_bytes : bytes + image_b64_str : str The raw bytes of the Pillow image. image_mode : str The image mode. @@ -46,7 +48,7 @@ def __init__( """ super().__init__(created_at, updated_at, note) - self.image_bytes = image_bytes + self.image_b64_str = image_b64_str self.image_mode = image_mode self.image_size = image_size @@ -68,8 +70,11 @@ def factory(cls, image: PIL.Image.Image, /, **kwargs) -> PillowImageItem: if not lazy_is_instance(image, "PIL.Image.Image"): raise ItemTypeError(f"Type '{image.__class__}' is not supported.") + image_bytes = image.tobytes() + image_b64_str = bytes_to_b64_str(image_bytes) + return cls( - image_bytes=image.tobytes(), + image_b64_str=image_b64_str, image_mode=image.mode, image_size=image.size, **kwargs, @@ -80,22 +85,21 @@ def image(self) -> PIL.Image.Image: """The image from the persistence.""" import PIL.Image + image_bytes = b64_str_to_bytes(self.image_b64_str) + return PIL.Image.frombytes( mode=self.image_mode, size=self.image_size, - data=self.image_bytes, + data=image_bytes, ) def as_serializable_dict(self): """Convert item to a JSON-serializable dict to used by frontend.""" - import base64 - import io - - with io.BytesIO() as stream: + with BytesIO() as stream: self.image.save(stream, format="png") png_bytes = stream.getvalue() - png_b64_str = base64.b64encode(png_bytes).decode() + png_b64_str = bytes_to_b64_str(png_bytes) return super().as_serializable_dict() | { "media_type": "image/png;base64", diff --git a/skore/src/skore/persistence/item/plotly_figure_item.py b/skore/src/skore/persistence/item/plotly_figure_item.py index 2aaae561f..332e7582b 100644 --- a/skore/src/skore/persistence/item/plotly_figure_item.py +++ b/skore/src/skore/persistence/item/plotly_figure_item.py @@ -7,8 +7,9 @@ from typing import TYPE_CHECKING, Optional -from .item import Item, ItemTypeError -from .media_item import lazy_is_instance +from skore.persistence.item.item import Item, ItemTypeError +from skore.persistence.item.media_item import lazy_is_instance +from skore.utils import bytes_to_b64_str if TYPE_CHECKING: import plotly.basedatatypes @@ -29,7 +30,7 @@ def __init__( Parameters ---------- - figure_str : bytes + figure_str : str The JSON str of the Plotly figure. created_at : str, optional The creation timestamp in ISO format. @@ -78,10 +79,8 @@ def figure(self) -> plotly.basedatatypes.BaseFigure: def as_serializable_dict(self): """Convert item to a JSON-serializable dict to used by frontend.""" - import base64 - figure_bytes = self.figure_str.encode("utf-8") - figure_b64_str = base64.b64encode(figure_bytes).decode() + figure_b64_str = bytes_to_b64_str(figure_bytes) return super().as_serializable_dict() | { "media_type": "application/vnd.plotly.v1+json;base64", diff --git a/skore/src/skore/persistence/item/sklearn_base_estimator_item.py b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py index aef7d05ff..d359d8b07 100644 --- a/skore/src/skore/persistence/item/sklearn_base_estimator_item.py +++ b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py @@ -8,7 +8,8 @@ from typing import TYPE_CHECKING, Union -from .item import Item, ItemTypeError +from skore.persistence.item.item import Item, ItemTypeError +from skore.utils import b64_str_to_bytes, bytes_to_b64_str if TYPE_CHECKING: import sklearn.base @@ -25,7 +26,7 @@ class SklearnBaseEstimatorItem(Item): def __init__( self, estimator_html_repr: str, - estimator_skops: bytes, + estimator_skops_b64_str: str, estimator_skops_untrusted_types: list[str], created_at: Union[str, None] = None, updated_at: Union[str, None] = None, @@ -38,7 +39,7 @@ def __init__( ---------- estimator_html_repr : str The HTML representation of the scikit-learn estimator. - estimator_skops : bytes + estimator_skops_b64_str : str The skops representation of the scikit-learn estimator. estimator_skops_untrusted_types : list[str] The list of untrusted types in the skops representation. @@ -52,7 +53,7 @@ def __init__( super().__init__(created_at, updated_at, note) self.estimator_html_repr = estimator_html_repr - self.estimator_skops = estimator_skops + self.estimator_skops_b64_str = estimator_skops_b64_str self.estimator_skops_untrusted_types = estimator_skops_untrusted_types @property @@ -67,8 +68,11 @@ def estimator(self) -> sklearn.base.BaseEstimator: """ import skops.io + estimator_skops_bytes = b64_str_to_bytes(self.estimator_skops_b64_str) + return skops.io.loads( - self.estimator_skops, trusted=self.estimator_skops_untrusted_types + data=estimator_skops_bytes, + trusted=self.estimator_skops_untrusted_types, ) @classmethod @@ -92,24 +96,25 @@ def factory( A new SklearnBaseEstimatorItem instance. """ import sklearn.base - import sklearn.utils if not isinstance(estimator, sklearn.base.BaseEstimator): raise ItemTypeError(f"Type '{estimator.__class__}' is not supported.") # This line is only needed if we know `estimator` has the right type, so we do # it after the type check + import sklearn.utils import skops.io estimator_html_repr = sklearn.utils.estimator_html_repr(estimator) - estimator_skops = skops.io.dumps(estimator) + estimator_skops_bytes = skops.io.dumps(estimator) + estimator_skops_b64_str = bytes_to_b64_str(estimator_skops_bytes) estimator_skops_untrusted_types = skops.io.get_untrusted_types( - data=estimator_skops + data=estimator_skops_bytes ) return cls( estimator_html_repr=estimator_html_repr, - estimator_skops=estimator_skops, + estimator_skops_b64_str=estimator_skops_b64_str, estimator_skops_untrusted_types=estimator_skops_untrusted_types, **kwargs, ) diff --git a/skore/src/skore/utils/__init__.py b/skore/src/skore/utils/__init__.py index 3ee290a4d..17d3e533b 100644 --- a/skore/src/skore/utils/__init__.py +++ b/skore/src/skore/utils/__init__.py @@ -1 +1,19 @@ """Various utilities to help with development.""" + +from base64 import b64decode, b64encode + + +def bytes_to_b64_str(literals: bytes) -> str: + """Encode the bytes-like object `literales` in a Base64 str.""" + return b64encode(literals).decode("utf-8") + + +def b64_str_to_bytes(literals: str) -> bytes: + """Decode the Base64 str object `literales` in a bytes.""" + return b64decode(literals.encode("utf-8")) + + +__all__ = [ + "bytes_to_b64_str", + "b64_str_to_bytes", +] diff --git a/skore/tests/unit/item/test_altair_chart_item.py b/skore/tests/unit/item/test_altair_chart_item.py index 9c5b01984..10b26d7fd 100644 --- a/skore/tests/unit/item/test_altair_chart_item.py +++ b/skore/tests/unit/item/test_altair_chart_item.py @@ -1,4 +1,5 @@ import base64 +import json import altair import pytest @@ -22,6 +23,14 @@ def test_factory_exception(self): with pytest.raises(ItemTypeError): AltairChartItem.factory(None) + def test_ensure_jsonable(self): + chart = altair.Chart().mark_point() + + item = AltairChartItem.factory(chart) + item_parameters = item.__parameters__ + + json.dumps(item_parameters) + def test_chart(self): chart = altair.Chart().mark_point() chart_str = chart.to_json() diff --git a/skore/tests/unit/item/test_cross_validation_reporter_item.py b/skore/tests/unit/item/test_cross_validation_reporter_item.py index ff4b8afdf..77479cb96 100644 --- a/skore/tests/unit/item/test_cross_validation_reporter_item.py +++ b/skore/tests/unit/item/test_cross_validation_reporter_item.py @@ -1,5 +1,6 @@ import dataclasses import io +import json import joblib import numpy @@ -15,6 +16,7 @@ from skore.sklearn.cross_validation.cross_validation_reporter import ( CrossValidationPlots, ) +from skore.utils import bytes_to_b64_str class FakeEstimator: @@ -89,11 +91,20 @@ def test_factory(self, mock_nowstr, reporter): joblib.dump(reporter, stream) reporter_bytes = stream.getvalue() + reporter_b64_str = bytes_to_b64_str(reporter_bytes) - assert item.reporter_bytes == reporter_bytes + assert item.reporter_b64_str == reporter_b64_str assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr + def test_ensure_jsonable(self): + reporter = FakeCrossValidationReporter() + + item = CrossValidationReporterItem.factory(reporter) + item_parameters = item.__parameters__ + + json.dumps(item_parameters) + def test_reporter(self, mock_nowstr): reporter = FakeCrossValidationReporter() @@ -101,10 +112,11 @@ def test_reporter(self, mock_nowstr): joblib.dump(reporter, stream) reporter_bytes = stream.getvalue() + reporter_b64_str = bytes_to_b64_str(reporter_bytes) item1 = CrossValidationReporterItem.factory(reporter) item2 = CrossValidationReporterItem( - reporter_bytes=reporter_bytes, + reporter_b64_str=reporter_b64_str, created_at=mock_nowstr, updated_at=mock_nowstr, ) diff --git a/skore/tests/unit/item/test_matplotlib_figure_item.py b/skore/tests/unit/item/test_matplotlib_figure_item.py index ccbd3ba02..bffd09b2b 100644 --- a/skore/tests/unit/item/test_matplotlib_figure_item.py +++ b/skore/tests/unit/item/test_matplotlib_figure_item.py @@ -1,5 +1,6 @@ import base64 import io +import json import joblib import pytest @@ -7,6 +8,7 @@ from matplotlib.pyplot import subplots from matplotlib.testing.compare import compare_images from skore.persistence.item import ItemTypeError, MatplotlibFigureItem +from skore.utils import b64_str_to_bytes, bytes_to_b64_str class FakeFigure(Figure): @@ -29,7 +31,7 @@ def test_factory(self, mock_nowstr, tmp_path): # we can't compare figure bytes directly figure.savefig(tmp_path / "figure.png") - with io.BytesIO(item.figure_bytes) as stream: + with io.BytesIO(b64_str_to_bytes(item.figure_b64_str)) as stream: joblib.load(stream).savefig(tmp_path / "item.png") assert compare_images(tmp_path / "figure.png", tmp_path / "item.png", 0) is None @@ -40,6 +42,15 @@ def test_factory_exception(self): with pytest.raises(ItemTypeError): MatplotlibFigureItem.factory(None) + def test_ensure_jsonable(self): + figure, ax = subplots() + ax.plot([1, 2, 3, 4], [1, 4, 2, 3]) + + item = MatplotlibFigureItem.factory(figure) + item_parameters = item.__parameters__ + + json.dumps(item_parameters) + def test_figure(self, tmp_path): figure, ax = subplots() ax.plot([1, 2, 3, 4], [1, 4, 2, 3]) @@ -48,9 +59,10 @@ def test_figure(self, tmp_path): joblib.dump(figure, stream) figure_bytes = stream.getvalue() + figure_b64_str = bytes_to_b64_str(figure_bytes) item1 = MatplotlibFigureItem.factory(figure) - item2 = MatplotlibFigureItem(figure_bytes) + item2 = MatplotlibFigureItem(figure_b64_str) figure.savefig(tmp_path / "figure.png") item1.figure.savefig(tmp_path / "item1.png") diff --git a/skore/tests/unit/item/test_pickle_item.py b/skore/tests/unit/item/test_pickle_item.py index b09dba9e1..c0a10adff 100644 --- a/skore/tests/unit/item/test_pickle_item.py +++ b/skore/tests/unit/item/test_pickle_item.py @@ -1,8 +1,10 @@ import io +import json import joblib import pytest from skore.persistence.item import PickleItem +from skore.utils import bytes_to_b64_str class TestPickleItem: @@ -17,21 +19,29 @@ def test_factory(self, mock_nowstr, object): with io.BytesIO() as stream: joblib.dump(object, stream) - object_bytes = stream.getvalue() + pickle_bytes = stream.getvalue() + pickle_b64_str = bytes_to_b64_str(pickle_bytes) - assert item.pickle_bytes == object_bytes + assert item.pickle_b64_str == pickle_b64_str assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr + def test_ensure_jsonable(self): + item = PickleItem.factory(object) + item_parameters = item.__parameters__ + + json.dumps(item_parameters) + def test_object(self, mock_nowstr): with io.BytesIO() as stream: joblib.dump(int, stream) - int_bytes = stream.getvalue() + pickle_bytes = stream.getvalue() + pickle_b64_str = bytes_to_b64_str(pickle_bytes) item1 = PickleItem.factory(int) item2 = PickleItem( - pickle_bytes=int_bytes, + pickle_b64_str=pickle_b64_str, created_at=mock_nowstr, updated_at=mock_nowstr, ) diff --git a/skore/tests/unit/item/test_pillow_image_item.py b/skore/tests/unit/item/test_pillow_image_item.py index 9db1a467f..a9e4d21f2 100644 --- a/skore/tests/unit/item/test_pillow_image_item.py +++ b/skore/tests/unit/item/test_pillow_image_item.py @@ -1,9 +1,10 @@ -import base64 import io +import json import PIL.Image import pytest from skore.persistence.item import ItemTypeError, PillowImageItem +from skore.utils import bytes_to_b64_str class TestPillowImageItem: @@ -13,9 +14,12 @@ def monkeypatch_datetime(self, monkeypatch, MockDatetime): def test_factory(self, mock_nowstr): image = PIL.Image.new("RGB", (100, 100), color="red") + image_bytes = image.tobytes() + image_b64_str = bytes_to_b64_str(image_bytes) + item = PillowImageItem.factory(image) - assert item.image_bytes == image.tobytes() + assert item.image_b64_str == image_b64_str assert item.image_mode == image.mode assert item.image_size == image.size assert item.created_at == mock_nowstr @@ -25,11 +29,22 @@ def test_factory_exception(self): with pytest.raises(ItemTypeError): PillowImageItem.factory(None) + def test_ensure_jsonable(self): + image = PIL.Image.new("RGB", (100, 100), color="red") + + item = PillowImageItem.factory(image) + item_parameters = item.__parameters__ + + json.dumps(item_parameters) + def test_image(self): image = PIL.Image.new("RGB", (100, 100), color="red") + image_bytes = image.tobytes() + image_b64_str = bytes_to_b64_str(image_bytes) + item1 = PillowImageItem.factory(image) item2 = PillowImageItem( - image_bytes=image.tobytes(), + image_b64_str=image_b64_str, image_mode=image.mode, image_size=image.size, ) @@ -45,7 +60,7 @@ def test_as_serializable_dict(self, mock_nowstr): image.save(stream, format="png") png_bytes = stream.getvalue() - png_b64_str = base64.b64encode(png_bytes).decode() + png_b64_str = bytes_to_b64_str(png_bytes) assert item.as_serializable_dict() == { "updated_at": mock_nowstr, diff --git a/skore/tests/unit/item/test_plotly_figure_item.py b/skore/tests/unit/item/test_plotly_figure_item.py index 46c9050c4..406c167b3 100644 --- a/skore/tests/unit/item/test_plotly_figure_item.py +++ b/skore/tests/unit/item/test_plotly_figure_item.py @@ -1,4 +1,5 @@ import base64 +import json import plotly.graph_objects import plotly.io @@ -24,6 +25,15 @@ def test_factory_exception(self): with pytest.raises(ItemTypeError): PlotlyFigureItem.factory(None) + def test_ensure_jsonable(self): + bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2]) + figure = plotly.graph_objects.Figure(data=[bar]) + + item = PlotlyFigureItem.factory(figure) + item_parameters = item.__parameters__ + + json.dumps(item_parameters) + def test_figure(self): bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2]) figure = plotly.graph_objects.Figure(data=[bar]) diff --git a/skore/tests/unit/item/test_sklearn_base_estimator_item.py b/skore/tests/unit/item/test_sklearn_base_estimator_item.py index 3ebe635d9..f2a2ebb3e 100644 --- a/skore/tests/unit/item/test_sklearn_base_estimator_item.py +++ b/skore/tests/unit/item/test_sklearn_base_estimator_item.py @@ -1,7 +1,10 @@ +import json + import pytest import sklearn.svm import skops.io from skore.persistence.item import ItemTypeError, SklearnBaseEstimatorItem +from skore.utils import bytes_to_b64_str class Estimator(sklearn.svm.SVC): @@ -21,14 +24,17 @@ def test_factory_exception(self): def test_factory(self, monkeypatch, mock_nowstr): estimator = sklearn.svm.SVC() estimator_html_repr = "" - estimator_skops = "" + estimator_skops_bytes = b"" + estimator_skops_b64_str = bytes_to_b64_str(b"") estimator_skops_untrusted_types = "" monkeypatch.setattr( "sklearn.utils.estimator_html_repr", lambda *args, **kwargs: estimator_html_repr, ) - monkeypatch.setattr("skops.io.dumps", lambda *args, **kwargs: estimator_skops) + monkeypatch.setattr( + "skops.io.dumps", lambda *args, **kwargs: estimator_skops_bytes + ) monkeypatch.setattr( "skops.io.get_untrusted_types", lambda *args, **kwargs: estimator_skops_untrusted_types, @@ -37,23 +43,32 @@ def test_factory(self, monkeypatch, mock_nowstr): item = SklearnBaseEstimatorItem.factory(estimator) assert item.estimator_html_repr == estimator_html_repr - assert item.estimator_skops == estimator_skops + assert item.estimator_skops_b64_str == estimator_skops_b64_str assert item.estimator_skops_untrusted_types == estimator_skops_untrusted_types assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr + def test_ensure_jsonable(self): + estimator = sklearn.svm.SVC() + + item = SklearnBaseEstimatorItem.factory(estimator) + item_parameters = item.__parameters__ + + json.dumps(item_parameters) + @pytest.mark.order(1) def test_estimator(self, mock_nowstr): estimator = sklearn.svm.SVC() - estimator_skops = skops.io.dumps(estimator) + estimator_skops_bytes = skops.io.dumps(estimator) + estimator_skops_b64_str = bytes_to_b64_str(estimator_skops_bytes) estimator_skops_untrusted_types = skops.io.get_untrusted_types( - data=estimator_skops + data=estimator_skops_bytes ) item1 = SklearnBaseEstimatorItem.factory(estimator) item2 = SklearnBaseEstimatorItem( estimator_html_repr=None, - estimator_skops=estimator_skops, + estimator_skops_b64_str=estimator_skops_b64_str, estimator_skops_untrusted_types=estimator_skops_untrusted_types, created_at=mock_nowstr, updated_at=mock_nowstr, @@ -65,9 +80,10 @@ def test_estimator(self, mock_nowstr): @pytest.mark.order(1) def test_estimator_untrusted(self, mock_nowstr): estimator = Estimator() - estimator_skops = skops.io.dumps(estimator) + estimator_skops_bytes = skops.io.dumps(estimator) + estimator_skops_b64_str = bytes_to_b64_str(estimator_skops_bytes) estimator_skops_untrusted_types = skops.io.get_untrusted_types( - data=estimator_skops + data=estimator_skops_bytes ) if not estimator_skops_untrusted_types: @@ -82,7 +98,7 @@ def test_estimator_untrusted(self, mock_nowstr): item1 = SklearnBaseEstimatorItem.factory(estimator) item2 = SklearnBaseEstimatorItem( estimator_html_repr=None, - estimator_skops=estimator_skops, + estimator_skops_b64_str=estimator_skops_b64_str, estimator_skops_untrusted_types=estimator_skops_untrusted_types, created_at=mock_nowstr, updated_at=mock_nowstr,