Skip to content

Commit

Permalink
feat: expose dynamic projection to pylance (#2003)
Browse files Browse the repository at this point in the history
expose dynamic projection to pylance.

TL;DR:
```python
dataset.to_table(columns={
    "name": "func(my_col)"
})
```
is equal to
```SQL
SELECT func(my_col) AS name
FROM dataset
```

TODO:
- [x] update function signatures to `Union[list[str], dict[str, str]]`
- [x] update doc strings
- [ ] update doc to explain how dynamic projection works
  • Loading branch information
chebbyChefNEQ authored Feb 27, 2024
1 parent 934b0fa commit 4a3f8fd
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 49 deletions.
12 changes: 7 additions & 5 deletions python/python/lance/_dataset/sharded_batch_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Generator, List, Literal, Union
from typing import TYPE_CHECKING, Dict, Generator, List, Literal, Optional, Union

import lance
from lance.dataset import LanceDataset
Expand All @@ -40,8 +40,10 @@ class ShardedBatchIterator:
The rank (id) of the shard in total `world_rank` shards.
world_rank: int
Total number of shards.
columns: list of strs, optional
Select columns to scan.
    columns: list of str, or dict of str to str, default None
List of column names to be fetched.
Or a dictionary of column names to SQL expressions.
All columns are fetched if None or unspecified.
batch_size: int, optional
The batch size of each shard.
granularity: str, optional
Expand Down Expand Up @@ -69,7 +71,7 @@ def __init__(
rank: int,
world_size: int,
*,
columns: List[str] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
batch_size: int = 1024 * 10,
granularity: Literal["fragment", "batch"] = "fragment",
batch_readahead: int = 8,
Expand All @@ -94,7 +96,7 @@ def __init__(
def from_torch(
data: Union[str, Path, LanceDataset],
*,
columns: List[str] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
batch_size: int = 1024 * 10,
granularity: Literal["fragment", "batch"] = "fragment",
batch_readahead: int = 8,
Expand Down
60 changes: 39 additions & 21 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def has_index(self):

def scanner(
self,
columns: Optional[list[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[Union[str, pa.compute.Expression]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
Expand All @@ -232,9 +232,10 @@ def scanner(
Parameters
----------
columns: list of str, default None
        columns: list of str, or dict of str to str, default None
List of column names to be fetched.
All columns if None or unspecified.
Or a dictionary of column names to SQL expressions.
All columns are fetched if None or unspecified.
filter: pa.compute.Expression or str
Expression or str that is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
Expand Down Expand Up @@ -328,7 +329,7 @@ def schema(self) -> pa.Schema:

def to_table(
self,
columns: Optional[list[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[Union[str, pa.compute.Expression]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
Expand All @@ -346,9 +347,10 @@ def to_table(
Parameters
----------
columns: list of str, default None
        columns: list of str, or dict of str to str, default None
List of column names to be fetched.
All columns if None or unspecified.
Or a dictionary of column names to SQL expressions.
All columns are fetched if None or unspecified.
filter : pa.compute.Expression or str
Expression or str that is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
Expand Down Expand Up @@ -443,7 +445,7 @@ def get_fragment(self, fragment_id: int) -> Optional[LanceFragment]:

def to_batches(
self,
columns: Optional[list[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[Union[str, pa.compute.Expression]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
Expand Down Expand Up @@ -487,7 +489,7 @@ def to_batches(
def sample(
self,
num_rows: int,
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
randomize_order: bool = True,
**kwargs,
) -> pa.Table:
Expand All @@ -497,9 +499,10 @@ def sample(
----------
num_rows: int
number of rows to retrieve
columns: list of strings, optional
list of column names to be fetched. All columns are fetched
if not specified.
        columns: list of str, or dict of str to str, default None
List of column names to be fetched.
Or a dictionary of column names to SQL expressions.
All columns are fetched if None or unspecified.
**kwargs : dict, optional
see scanner() method for full parameter description.
Expand All @@ -518,7 +521,7 @@ def sample(
def take(
self,
indices: Union[List[int], pa.Array],
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
**kwargs,
) -> pa.Table:
"""Select rows of data by index.
Expand All @@ -527,9 +530,10 @@ def take(
----------
indices : Array or array-like
indices of rows to select in the dataset.
columns: list of strings, optional
List of column names to be fetched. All columns are fetched
if not specified.
        columns: list of str, or dict of str to str, default None
List of column names to be fetched.
Or a dictionary of column names to SQL expressions.
All columns are fetched if None or unspecified.
**kwargs : dict, optional
See scanner() method for full parameter description.
Expand All @@ -542,7 +546,7 @@ def take(
def _take_rows(
self,
row_ids: Union[List[int], pa.Array],
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
**kargs,
) -> pa.Table:
"""Select rows by row_ids.
Expand All @@ -553,9 +557,10 @@ def _take_rows(
----------
row_ids : List Array or array-like
row IDs to select in the dataset.
columns: list of strings, optional
List of column names to be fetched. All columns are fetched
if not specified.
        columns: list of str, or dict of str to str, default None
List of column names to be fetched.
Or a dictionary of column names to SQL expressions.
All columns are fetched if None or unspecified.
**kwargs : dict, optional
See scanner() method for full parameter description.
Expand Down Expand Up @@ -1865,6 +1870,7 @@ def __init__(self, ds: LanceDataset):
self._prefilter = None
self._offset = None
self._columns = None
self._columns_with_transform = None
self._nearest = None
self._batch_size: Optional[int] = None
self._batch_readahead: Optional[int] = None
Expand Down Expand Up @@ -1914,8 +1920,19 @@ def offset(self, n: Optional[int] = None) -> ScannerBuilder:
self._offset = n
return self

def columns(self, cols: Optional[list[str]] = None) -> ScannerBuilder:
self._columns = cols
def columns(
self, cols: Optional[Union[List[str], Dict[str, str]]] = None
) -> ScannerBuilder:
if cols is None:
self._columns = None
elif isinstance(cols, dict):
self._columns_with_transform = list(cols.items())
elif isinstance(cols, list):
self._columns = cols
else:
raise TypeError(
f"columns must be a list or dict[name, expression], got {type(cols)}"
)
return self

def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder:
Expand Down Expand Up @@ -2040,6 +2057,7 @@ def nearest(
def to_scanner(self) -> LanceScanner:
scanner = self.ds._ds.scanner(
self._columns,
self._columns_with_transform,
self._filter,
self._prefilter,
self._limit,
Expand Down
23 changes: 18 additions & 5 deletions python/python/lance/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,16 @@

import json
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Union,
)

import pyarrow as pa

Expand Down Expand Up @@ -242,7 +251,7 @@ def head(self, num_rows: int) -> pa.Table:
def scanner(
self,
*,
columns: Optional[list[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
batch_size: Optional[int] = None,
filter: Optional[Union[str, pa.compute.Expression]] = None,
limit: int = 0,
Expand All @@ -265,13 +274,17 @@ def scanner(

return LanceScanner(s, self._ds)

def take(self, indices, columns: Optional[list[str]] = None) -> pa.Table:
def take(
self,
indices,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
) -> pa.Table:
return pa.Table.from_batches([self._fragment.take(indices, columns=columns)])

def to_batches(
self,
*,
columns: Optional[list[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
batch_size: Optional[int] = None,
filter: Optional[Union[str, pa.compute.Expression]] = None,
limit: int = 0,
Expand All @@ -291,7 +304,7 @@ def to_batches(

def to_table(
self,
columns: Optional[list[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[Union[str, pa.compute.Expression]] = None,
limit: int = 0,
offset: Optional[int] = None,
Expand Down
20 changes: 10 additions & 10 deletions python/python/lance/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from dataclasses import dataclass, field
from heapq import heappush, heappushpop
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar, Union

import pyarrow as pa

Expand All @@ -47,7 +47,7 @@
def _efficient_sample(
dataset: lance.LanceDataset,
n: int,
columns: list[str],
columns: Optional[Union[List[str], Dict[str, str]]],
batch_size: int,
max_takes: int,
) -> Generator[pa.RecordBatch, None, None]:
Expand Down Expand Up @@ -120,7 +120,7 @@ def _efficient_sample(
def maybe_sample(
dataset: Union[str, Path, lance.LanceDataset],
n: int,
columns: Union[list[str], str],
columns: Union[list[str], dict[str, str], str],
batch_size: int = 10240,
max_takes: int = 2048,
) -> Generator[pa.RecordBatch, None, None]:
Expand All @@ -132,7 +132,7 @@ def maybe_sample(
The dataset to sample from.
n : int
The number of records to sample.
columns : Union[list[str], str]
columns : Union[list[str], dict[str, str], str]
The columns to load.
batch_size : int, optional
The batch size to use when loading the data, by default 10240.
Expand Down Expand Up @@ -209,7 +209,7 @@ def __call__(
ds: lance.LanceDataset,
*args,
batch_size: int = 128,
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[str] = None,
batch_readahead: int = 16,
with_row_id: bool = False,
Expand All @@ -231,7 +231,7 @@ def __call__(
dataset: lance.LanceDataset,
*args,
batch_size: int = 128,
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[str] = None,
batch_readahead: int = 16,
with_row_id: bool = False,
Expand Down Expand Up @@ -364,7 +364,7 @@ def _shard_scan(
self,
dataset: lance.LanceDataset,
batch_size: int,
columns: Optional[List[str]],
columns: Optional[Union[List[str], Dict[str, str]]],
batch_readahead: int,
filter: str,
) -> Generator[lance.RecordBatch, None, None]:
Expand Down Expand Up @@ -415,7 +415,7 @@ def _sample_filtered(
self,
dataset: lance.LanceDataset,
batch_size: int,
columns: Optional[List[str]],
columns: Optional[Union[List[str], Dict[str, str]]],
batch_readahead: int,
filter: str,
) -> Generator[lance.RecordBatch, None, None]:
Expand Down Expand Up @@ -455,7 +455,7 @@ def _sample_all(
self,
dataset: lance.LanceDataset,
batch_size: int,
columns: Optional[List[str]],
columns: Optional[Union[List[str], Dict[str, str]]],
batch_readahead: int,
) -> Generator[lance.RecordBatch, None, None]:
total = dataset.count_rows()
Expand Down Expand Up @@ -484,7 +484,7 @@ def __call__(
dataset: lance.LanceDataset,
*args,
batch_size: int = 128,
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[str] = None,
batch_readahead: int = 16,
with_row_id: Optional[bool] = None,
Expand Down
4 changes: 2 additions & 2 deletions python/python/lance/tf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def column_to_tensor(array: pa.Array, tensor_spec: tf.TensorSpec) -> tf.Tensor:
def from_lance(
dataset: Union[str, Path, LanceDataset],
*,
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
batch_size: int = 256,
filter: Optional[str] = None,
fragments: Union[Iterable[int], Iterable[LanceFragment], tf.data.Dataset] = None,
Expand Down Expand Up @@ -326,7 +326,7 @@ def lance_take_batches(
dataset: Union[str, Path, LanceDataset],
batch_ranges: Iterable[Tuple[int, int]],
*,
columns: Optional[List[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
output_signature: Optional[Dict[str, tf.TypeSpec]] = None,
batch_readahead: int = 10,
) -> tf.data.Dataset:
Expand Down
4 changes: 2 additions & 2 deletions python/python/lance/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import math
import warnings
from pathlib import Path
from typing import Iterable, Literal, Optional, Union
from typing import Dict, Iterable, List, Literal, Optional, Union

import pyarrow as pa

Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(
dataset: Union[torch.utils.data.Dataset, str, Path],
batch_size: int,
*args,
columns: Optional[list[str]] = None,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
filter: Optional[str] = None,
samples: Optional[int] = 0,
cache: Optional[Union[str, bool]] = None,
Expand Down
Loading

0 comments on commit 4a3f8fd

Please sign in to comment.