Skip to content

Commit

Permalink
fix(dc): collect function return types (#589)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Nov 13, 2024
1 parent 02e8efe commit 3f5e9f8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
10 changes: 5 additions & 5 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from datachain.dataset import DatasetRecord
from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.convert.values_to_tuples import values_to_tuples
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.file import ArrowRow, File, get_file_type
from datachain.lib.file import ExportPlacement as FileExportPlacement
Expand Down Expand Up @@ -1262,15 +1262,15 @@ def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
return self.results(row_factory=to_dict)

@overload
def collect(self) -> Iterator[tuple[DataType, ...]]: ...
def collect(self) -> Iterator[tuple[DataValue, ...]]: ...

@overload
def collect(self, col: str) -> Iterator[DataType]: ... # type: ignore[overload-overlap]
def collect(self, col: str) -> Iterator[DataValue]: ...

@overload
def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ...
def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...

def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]: # type: ignore[overload-overlap,misc]
def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]]]: # type: ignore[overload-overlap,misc]
"""Yields rows of values, optionally limited to the specified columns.
Args:
Expand Down
1 change: 1 addition & 0 deletions src/datachain/lib/meta_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def read_meta( # noqa: C901
)
)
(model_output,) = chain.collect("meta_schema")
assert isinstance(model_output, str)
if print_schema:
print(f"{model_output}")
# Below 'spec' should be a dynamically converted DataModel from Pydantic
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def slice(

def row_to_features(
self, row: Sequence, catalog: "Catalog", cache: bool = False
) -> list[DataType]:
) -> list[DataValue]:
res = []
pos = 0
for fr_cls in self.values.values():
Expand Down

0 comments on commit 3f5e9f8

Please sign in to comment.