From 3f5e9f8365cad74757b1a9db14958eed20d032db Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Tue, 12 Nov 2024 20:38:35 -0800 Subject: [PATCH] fix(dc): collect function return types (#589) --- src/datachain/lib/dc.py | 10 +++++----- src/datachain/lib/meta_formats.py | 1 + src/datachain/lib/signal_schema.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index f3e18a3b1..4491dbe4c 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -30,7 +30,7 @@ from datachain.dataset import DatasetRecord from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.convert.values_to_tuples import values_to_tuples -from datachain.lib.data_model import DataModel, DataType, dict_to_data_model +from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model from datachain.lib.dataset_info import DatasetInfo from datachain.lib.file import ArrowRow, File, get_file_type from datachain.lib.file import ExportPlacement as FileExportPlacement @@ -1262,15 +1262,15 @@ def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]: return self.results(row_factory=to_dict) @overload - def collect(self) -> Iterator[tuple[DataType, ...]]: ... + def collect(self) -> Iterator[tuple[DataValue, ...]]: ... @overload - def collect(self, col: str) -> Iterator[DataType]: ... # type: ignore[overload-overlap] + def collect(self, col: str) -> Iterator[DataValue]: ... @overload - def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ... + def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ... - def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]: # type: ignore[overload-overlap,misc] + def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]]]: # type: ignore[overload-overlap,misc] """Yields rows of values, optionally limited to the specified columns. Args: diff --git a/src/datachain/lib/meta_formats.py b/src/datachain/lib/meta_formats.py index 9747e901a..70473557a 100644 --- a/src/datachain/lib/meta_formats.py +++ b/src/datachain/lib/meta_formats.py @@ -114,6 +114,7 @@ def read_meta( # noqa: C901 ) ) (model_output,) = chain.collect("meta_schema") + assert isinstance(model_output, str) if print_schema: print(f"{model_output}") # Below 'spec' should be a dynamically converted DataModel from Pydantic diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 10bd36996..29cf202cd 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -378,7 +378,7 @@ def slice( def row_to_features( self, row: Sequence, catalog: "Catalog", cache: bool = False - ) -> list[DataType]: + ) -> list[DataValue]: res = [] pos = 0 for fr_cls in self.values.values():