Skip to content

Commit

Permalink
do not pre-load dataset version preview from string (#642)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon authored Nov 29, 2024
1 parent 049f718 commit 08f4625
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/datachain/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
from dataclasses import dataclass, fields
from datetime import datetime
from functools import cached_property
from typing import (
Any,
NewType,
Expand All @@ -11,6 +12,8 @@
)
from urllib.parse import urlparse

import orjson

from datachain.error import DatasetVersionNotFoundError
from datachain.sql.types import NAME_TYPES_MAPPING, SQLType

Expand Down Expand Up @@ -178,7 +181,7 @@ class DatasetVersion:
schema: dict[str, Union[SQLType, type[SQLType]]]
num_objects: Optional[int]
size: Optional[int]
preview: Optional[list[dict]]
_preview_data: Optional[Union[str, list[dict]]]
sources: str = ""
query_script: str = ""
job_id: Optional[str] = None
Expand All @@ -199,7 +202,7 @@ def parse( # noqa: PLR0913
script_output: str,
num_objects: Optional[int],
size: Optional[int],
preview: Optional[str],
preview: Optional[Union[str, list[dict]]],
schema: dict[str, Union[SQLType, type[SQLType]]],
sources: str = "",
query_script: str = "",
Expand All @@ -220,7 +223,7 @@ def parse( # noqa: PLR0913
schema,
num_objects,
size,
json.loads(preview) if preview else None,
preview,
sources,
query_script,
job_id,
Expand Down Expand Up @@ -260,9 +263,17 @@ def serialized_schema(self) -> dict[str, Any]:
for c_name, c_type in self.schema.items()
}

@cached_property
def preview(self) -> Optional[list[dict]]:
    """Return the preview rows, decoding them lazily on first access.

    ``_preview_data`` may hold either a JSON string or an already-decoded
    list of dicts. A string is decoded with ``orjson.loads``; because this
    is a ``cached_property`` the decoded value is stored on the instance
    after the first access.

    Returns:
        The preview rows, or ``None`` when no preview data is stored
        (``None``, an empty string or an empty list).
    """
    raw = self._preview_data
    # Handle every empty value up front: an empty string is not valid JSON,
    # so passing it to orjson.loads would raise instead of yielding None.
    if not raw:
        return None
    if isinstance(raw, str):
        return orjson.loads(raw)
    return raw

@classmethod
def from_dict(cls, d: dict[str, Any]) -> "DatasetVersion":
    """Build an instance from a plain mapping.

    Only keys of ``d`` that match a dataclass field name are passed to the
    constructor. The preview payload may be supplied either under the field
    name ``_preview_data`` or under the key ``preview``; it is stored as-is
    (string or list) and not decoded here.

    Args:
        d: Mapping of field names to values; may additionally contain
            a ``preview`` key.

    Returns:
        A new instance of ``cls``.
    """
    kwargs = {f.name: d[f.name] for f in fields(cls) if f.name in d}
    # kwargs is a dict, so presence must be tested by key membership;
    # hasattr() would always be False and clobber an explicit value.
    if "_preview_data" not in kwargs:
        kwargs["_preview_data"] = d.get("preview")
    return cls(**kwargs)


Expand Down
34 changes: 33 additions & 1 deletion tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from datetime import datetime, timezone

import pytest
Expand All @@ -6,7 +7,7 @@
from sqlalchemy.schema import CreateTable

from datachain.data_storage.schema import DataTable
from datachain.dataset import DatasetDependency, DatasetDependencyType
from datachain.dataset import DatasetDependency, DatasetDependencyType, DatasetVersion
from datachain.sql.types import (
JSON,
Array,
Expand Down Expand Up @@ -106,3 +107,34 @@ def test_dataset_dependency_dataset_name(dep_name, dep_type, expected):
)

assert dep.dataset_name == expected


@pytest.mark.parametrize("use_string", [True, False])
def test_dataset_version_from_dict(use_string):
    """from_dict exposes the same preview rows for JSON-string and list input."""
    expected_rows = [{"id": 1, "thing": "a"}, {"id": 2, "thing": "b"}]

    # Feed the preview either JSON-encoded or as the already-decoded list.
    raw_preview = expected_rows
    if use_string:
        raw_preview = json.dumps(expected_rows)

    record = dict(
        id=1,
        uuid="98928be4-b6e8-4b7b-a7c5-2ce3b33130d8",
        dataset_id=40,
        version=2,
        status=1,
        feature_schema={},
        created_at=datetime.fromisoformat("2023-10-01T12:00:00"),
        finished_at=None,
        error_message="",
        error_stack="",
        script_output="",
        schema={},
        num_objects=100,
        size=1000000,
        preview=raw_preview,
    )

    assert DatasetVersion.from_dict(record).preview == expected_rows

0 comments on commit 08f4625

Please sign in to comment.