Skip to content

Commit

Permalink
Finish SQL functions refactoring (#543)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour authored Nov 27, 2024
1 parent 3671039 commit 6244161
Show file tree
Hide file tree
Showing 42 changed files with 1,653 additions and 408 deletions.
22 changes: 11 additions & 11 deletions docs/references/sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ for operations like [`DataChain.filter`](datachain.md#datachain.lib.dc.DataChain
and [`DataChain.mutate`](datachain.md#datachain.lib.dc.DataChain.mutate). Import
these functions from `datachain.sql.functions`.

::: datachain.sql.functions.avg
::: datachain.sql.functions.count
::: datachain.sql.functions.greatest
::: datachain.sql.functions.least
::: datachain.sql.functions.max
::: datachain.sql.functions.min
::: datachain.sql.functions.rand
::: datachain.sql.functions.sum
::: datachain.sql.functions.array
::: datachain.sql.functions.path
::: datachain.sql.functions.string
::: datachain.func.avg
::: datachain.func.count
::: datachain.func.greatest
::: datachain.func.least
::: datachain.func.max
::: datachain.func.min
::: datachain.func.rand
::: datachain.func.sum
::: datachain.func.array
::: datachain.func.path
::: datachain.func.string
4 changes: 2 additions & 2 deletions examples/computer_vision/openimage-detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from PIL import Image

from datachain import C, DataChain, File, model
from datachain.sql.functions import path
from datachain.func import path


def openimage_detect(args):
Expand Down Expand Up @@ -48,7 +48,7 @@ def openimage_detect(args):
.filter(C("file.path").glob("*.jpg") | C("file.path").glob("*.json"))
.agg(
openimage_detect,
partition_by=path.file_stem(C("file.path")),
partition_by=path.file_stem("file.path"),
params=["file"],
output={"file": File, "bbox": model.BBox},
)
Expand Down
9 changes: 4 additions & 5 deletions examples/get_started/common_sql_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datachain import C, DataChain
from datachain.sql import literal
from datachain.sql.functions import array, greatest, least, path, string
from datachain.func import array, greatest, least, path, string


def num_chars_udf(file):
Expand All @@ -18,7 +17,7 @@ def num_chars_udf(file):
(
dc.mutate(
length=string.length(path.name(C("file.path"))),
parts=string.split(path.name(C("file.path")), literal(".")),
parts=string.split(path.name(C("file.path")), "."),
)
.select("file.path", "length", "parts")
.show(5)
Expand All @@ -35,8 +34,8 @@ def num_chars_udf(file):


chain = dc.mutate(
a=array.length(string.split(C("file.path"), literal("/"))),
b=array.length(string.split(path.name(C("file.path")), literal("0"))),
a=array.length(string.split("file.path", "/")),
b=array.length(string.split(path.name("file.path"), "0")),
)

(
Expand Down
7 changes: 3 additions & 4 deletions examples/multimodal/clip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from torch.nn.functional import cosine_similarity
from torch.utils.data import DataLoader

from datachain import C, DataChain
from datachain.sql.functions import path
from datachain import C, DataChain, func

source = "gs://datachain-demo/50k-laion-files/000000/00000000*"

Expand All @@ -18,8 +17,8 @@ def create_dataset():
)
return imgs.merge(
captions,
on=path.file_stem(imgs.c("file.path")),
right_on=path.file_stem(captions.c("file.path")),
on=func.path.file_stem(imgs.c("file.path")),
right_on=func.path.file_stem(captions.c("file.path")),
)


Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal/wds.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

from datachain import DataChain
from datachain.func import path
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta
from datachain.sql.functions import path

IMAGE_TARS = os.getenv(
"IMAGE_TARS", "gs://datachain-demo/datacomp-small/shards/000000[0-5]*.tar"
Expand Down
16 changes: 6 additions & 10 deletions examples/multimodal/wds_filtered.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import datachain.error
from datachain import C, DataChain
from datachain import C, DataChain, func
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion
from datachain.sql import literal
from datachain.sql.functions import array, greatest, least, string

name = "wds"
try:
Expand All @@ -20,14 +18,12 @@
wds.print_schema()

filtered = (
wds.filter(string.length(C("laion.txt")) > 5)
.filter(array.length(string.split(C("laion.txt"), literal(" "))) > 2)
wds.filter(func.string.length("laion.txt") > 5)
.filter(func.array.length(func.string.split("laion.txt", " ")) > 2)
.filter(func.least("laion.json.original_width", "laion.json.original_height") > 200)
.filter(
least(C("laion.json.original_width"), C("laion.json.original_height")) > 200
)
.filter(
greatest(C("laion.json.original_width"), C("laion.json.original_height"))
/ least(C("laion.json.original_width"), C("laion.json.original_height"))
func.greatest("laion.json.original_width", "laion.json.original_height")
/ func.least("laion.json.original_width", "laion.json.original_height")
< 3.0
)
.save()
Expand Down
2 changes: 0 additions & 2 deletions src/datachain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datachain.lib import func
from datachain.lib.data_model import DataModel, DataType, is_chain_type
from datachain.lib.dc import C, Column, DataChain, Sys
from datachain.lib.file import (
Expand Down Expand Up @@ -35,7 +34,6 @@
"Sys",
"TarVFile",
"TextFile",
"func",
"is_chain_type",
"metrics",
"param",
Expand Down
14 changes: 8 additions & 6 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
QueryScriptCancelError,
QueryScriptRunError,
)
from datachain.listing import Listing
from datachain.node import DirType, Node, NodeWithPath
from datachain.nodes_thread_pool import NodesThreadPool
from datachain.remote.studio import StudioClient
Expand All @@ -76,6 +75,7 @@
from datachain.dataset import DatasetVersion
from datachain.job import Job
from datachain.lib.file import File
from datachain.listing import Listing

logger = logging.getLogger("datachain")

Expand Down Expand Up @@ -236,7 +236,7 @@ def do_task(self, urls):
class NodeGroup:
"""Class for a group of nodes from the same source"""

listing: Listing
listing: "Listing"
sources: list[DataSource]

# The source path within the bucket
Expand Down Expand Up @@ -591,8 +591,9 @@ def enlist_source(
client_config=None,
object_name="file",
skip_indexing=False,
) -> tuple[Listing, str]:
) -> tuple["Listing", str]:
from datachain.lib.dc import DataChain
from datachain.listing import Listing

DataChain.from_storage(
source, session=self.session, update=update, object_name=object_name
Expand Down Expand Up @@ -660,7 +661,8 @@ def enlist_sources_grouped(
no_glob: bool = False,
client_config=None,
) -> list[NodeGroup]:
from datachain.query import DatasetQuery
from datachain.listing import Listing
from datachain.query.dataset import DatasetQuery

def _row_to_node(d: dict[str, Any]) -> Node:
del d["file__source"]
Expand Down Expand Up @@ -876,7 +878,7 @@ def create_new_dataset_version(
def update_dataset_version_with_warehouse_info(
self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
) -> None:
from datachain.query import DatasetQuery
from datachain.query.dataset import DatasetQuery

dataset_version = dataset.get_version(version)

Expand Down Expand Up @@ -1177,7 +1179,7 @@ def listings(self):
def ls_dataset_rows(
self, name: str, version: int, offset=None, limit=None
) -> list[dict]:
from datachain.query import DatasetQuery
from datachain.query.dataset import DatasetQuery

dataset = self.get_dataset(name)

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def show(
schema: bool = False,
) -> None:
from datachain.lib.dc import DataChain
from datachain.query import DatasetQuery
from datachain.query.dataset import DatasetQuery
from datachain.utils import show_records

dataset = catalog.get_dataset(name)
Expand Down
18 changes: 9 additions & 9 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
from datachain.cache import DataChainCache
from datachain.client.fileslice import FileWrapper
from datachain.error import ClientError as DataChainClientError
from datachain.lib.file import File
from datachain.nodes_fetcher import NodesFetcher
from datachain.nodes_thread_pool import NodeChunk

if TYPE_CHECKING:
from fsspec.spec import AbstractFileSystem

from datachain.dataset import StorageURI
from datachain.lib.file import File


logger = logging.getLogger("datachain")
Expand All @@ -45,7 +45,7 @@

DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")

ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]


def _is_win_local_path(uri: str) -> bool:
Expand Down Expand Up @@ -212,7 +212,7 @@ async def get_file(self, lpath, rpath, callback):

async def scandir(
self, start_prefix: str, method: str = "default"
) -> AsyncIterator[Sequence[File]]:
) -> AsyncIterator[Sequence["File"]]:
try:
impl = getattr(self, f"_fetch_{method}")
except AttributeError:
Expand Down Expand Up @@ -317,7 +317,7 @@ def get_full_path(self, rel_path: str) -> str:
return f"{self.PREFIX}{self.name}/{rel_path}"

@abstractmethod
def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...
def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...

def fetch_nodes(
self,
Expand Down Expand Up @@ -354,27 +354,27 @@ def do_instantiate_object(self, file: "File", dst: str) -> None:
copy2(src, dst)

def open_object(
self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
) -> BinaryIO:
"""Open a file, including files in tar archives."""
if use_cache and (cache_path := self.cache.get_path(file)):
return open(cache_path, mode="rb")
assert not file.location
return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]

def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
sync(get_loop(), functools.partial(self._download, file, callback=callback))

async def _download(self, file: File, *, callback: "Callback" = None) -> None:
async def _download(self, file: "File", *, callback: "Callback" = None) -> None:
if self.cache.contains(file):
# Already in cache, so there's nothing to do.
return
await self._put_in_cache(file, callback=callback)

def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))

async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
assert not file.location
if file.etag:
etag = await self.get_current_etag(file)
Expand Down
4 changes: 2 additions & 2 deletions src/datachain/data_storage/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy.sql import func as f
from sqlalchemy.sql.expression import false, null, true

from datachain.sql.functions import path
from datachain.sql.functions import path as pathfunc
from datachain.sql.types import Int, SQLType, UInt64

if TYPE_CHECKING:
Expand Down Expand Up @@ -130,7 +130,7 @@ def apply_group_by(self, q):

def query(self, q):
q = self.base_select(q).cte(recursive=True)
parent = path.parent(self.c(q, "path"))
parent = pathfunc.parent(self.c(q, "path"))
q = q.union_all(
sa.select(
sa.literal(-1).label("sys__id"),
Expand Down
49 changes: 49 additions & 0 deletions src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from sqlalchemy import literal

from . import array, path, random, string
from .aggregate import (
any_value,
avg,
collect,
concat,
count,
dense_rank,
first,
max,
min,
rank,
row_number,
sum,
)
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import greatest, least
from .random import rand
from .window import window

__all__ = [
"any_value",
"array",
"avg",
"collect",
"concat",
"cosine_distance",
"count",
"dense_rank",
"euclidean_distance",
"first",
"greatest",
"least",
"length",
"literal",
"max",
"min",
"path",
"rand",
"random",
"rank",
"row_number",
"sip_hash_64",
"string",
"sum",
"window",
]
Loading

0 comments on commit 6244161

Please sign in to comment.