diff --git a/docs/references/sql.md b/docs/references/sql.md index 67cf83e91..ddf663993 100644 --- a/docs/references/sql.md +++ b/docs/references/sql.md @@ -5,14 +5,14 @@ for operations like [`DataChain.filter`](datachain.md#datachain.lib.dc.DataChain and [`DataChain.mutate`](datachain.md#datachain.lib.dc.DataChain.mutate). Import these functions from `datachain.sql.functions`. -::: datachain.sql.functions.avg -::: datachain.sql.functions.count -::: datachain.sql.functions.greatest -::: datachain.sql.functions.least -::: datachain.sql.functions.max -::: datachain.sql.functions.min -::: datachain.sql.functions.rand -::: datachain.sql.functions.sum -::: datachain.sql.functions.array -::: datachain.sql.functions.path -::: datachain.sql.functions.string +::: datachain.func.avg +::: datachain.func.count +::: datachain.func.greatest +::: datachain.func.least +::: datachain.func.max +::: datachain.func.min +::: datachain.func.rand +::: datachain.func.sum +::: datachain.func.array +::: datachain.func.path +::: datachain.func.string diff --git a/examples/computer_vision/openimage-detect.py b/examples/computer_vision/openimage-detect.py index fc73f21f5..de1a88d39 100644 --- a/examples/computer_vision/openimage-detect.py +++ b/examples/computer_vision/openimage-detect.py @@ -3,7 +3,7 @@ from PIL import Image from datachain import C, DataChain, File, model -from datachain.sql.functions import path +from datachain.func import path def openimage_detect(args): @@ -48,7 +48,7 @@ def openimage_detect(args): .filter(C("file.path").glob("*.jpg") | C("file.path").glob("*.json")) .agg( openimage_detect, - partition_by=path.file_stem(C("file.path")), + partition_by=path.file_stem("file.path"), params=["file"], output={"file": File, "bbox": model.BBox}, ) diff --git a/examples/get_started/common_sql_functions.py b/examples/get_started/common_sql_functions.py index bb96f1f99..7f6b90ef8 100644 --- a/examples/get_started/common_sql_functions.py +++ b/examples/get_started/common_sql_functions.py @@ -1,6 +1,5 @@ from datachain import C, DataChain -from datachain.sql import literal -from datachain.sql.functions import array, greatest, least, path, string +from datachain.func import array, greatest, least, path, string def num_chars_udf(file): @@ -18,7 +17,7 @@ def num_chars_udf(file): ( dc.mutate( length=string.length(path.name(C("file.path"))), - parts=string.split(path.name(C("file.path")), literal(".")), + parts=string.split(path.name(C("file.path")), "."), ) .select("file.path", "length", "parts") .show(5) @@ -35,8 +34,8 @@ def num_chars_udf(file): chain = dc.mutate( - a=array.length(string.split(C("file.path"), literal("/"))), - b=array.length(string.split(path.name(C("file.path")), literal("0"))), + a=array.length(string.split("file.path", "/")), + b=array.length(string.split(path.name("file.path"), "0")), ) ( diff --git a/examples/multimodal/clip_inference.py b/examples/multimodal/clip_inference.py index b6a37dcf7..183b73045 100644 --- a/examples/multimodal/clip_inference.py +++ b/examples/multimodal/clip_inference.py @@ -3,8 +3,7 @@ from torch.nn.functional import cosine_similarity from torch.utils.data import DataLoader -from datachain import C, DataChain -from datachain.sql.functions import path +from datachain import C, DataChain, func source = "gs://datachain-demo/50k-laion-files/000000/00000000*" @@ -18,8 +17,8 @@ def create_dataset(): ) return imgs.merge( captions, - on=path.file_stem(imgs.c("file.path")), - right_on=path.file_stem(captions.c("file.path")), + on=func.path.file_stem(imgs.c("file.path")), + right_on=func.path.file_stem(captions.c("file.path")), ) diff --git a/examples/multimodal/wds.py b/examples/multimodal/wds.py index 6d016dbc6..0ed5ea5ce 100644 --- a/examples/multimodal/wds.py +++ b/examples/multimodal/wds.py @@ -1,9 +1,9 @@ import os from datachain import DataChain +from datachain.func import path from datachain.lib.webdataset import process_webdataset from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta -from datachain.sql.functions import path IMAGE_TARS = os.getenv( "IMAGE_TARS", "gs://datachain-demo/datacomp-small/shards/000000[0-5]*.tar" diff --git a/examples/multimodal/wds_filtered.py b/examples/multimodal/wds_filtered.py index a06b27657..02822b5f1 100644 --- a/examples/multimodal/wds_filtered.py +++ b/examples/multimodal/wds_filtered.py @@ -1,9 +1,7 @@ import datachain.error -from datachain import C, DataChain +from datachain import C, DataChain, func from datachain.lib.webdataset import process_webdataset from datachain.lib.webdataset_laion import WDSLaion -from datachain.sql import literal -from datachain.sql.functions import array, greatest, least, string name = "wds" try: @@ -20,14 +18,12 @@ wds.print_schema() filtered = ( - wds.filter(string.length(C("laion.txt")) > 5) - .filter(array.length(string.split(C("laion.txt"), literal(" "))) > 2) + wds.filter(func.string.length("laion.txt") > 5) + .filter(func.array.length(func.string.split("laion.txt", " ")) > 2) + .filter(func.least("laion.json.original_width", "laion.json.original_height") > 200) .filter( - least(C("laion.json.original_width"), C("laion.json.original_height")) > 200 - ) - .filter( - greatest(C("laion.json.original_width"), C("laion.json.original_height")) - / least(C("laion.json.original_width"), C("laion.json.original_height")) + func.greatest("laion.json.original_width", "laion.json.original_height") + / func.least("laion.json.original_width", "laion.json.original_height") < 3.0 ) .save() diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py index 98c7b8641..e8bbc00bf 100644 --- a/src/datachain/__init__.py +++ b/src/datachain/__init__.py @@ -1,4 +1,3 @@ -from datachain.lib import func from datachain.lib.data_model import DataModel, DataType, is_chain_type from datachain.lib.dc import C, Column, DataChain, Sys from datachain.lib.file import ( @@ -35,7 +34,6 @@ "Sys", "TarVFile", "TextFile", - "func", "is_chain_type", "metrics", "param", diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 4bc6f75b1..010caaf4f 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -54,7 +54,6 @@ QueryScriptCancelError, QueryScriptRunError, ) -from datachain.listing import Listing from datachain.node import DirType, Node, NodeWithPath from datachain.nodes_thread_pool import NodesThreadPool from datachain.remote.studio import StudioClient @@ -76,6 +75,7 @@ from datachain.dataset import DatasetVersion from datachain.job import Job from datachain.lib.file import File + from datachain.listing import Listing logger = logging.getLogger("datachain") @@ -236,7 +236,7 @@ def do_task(self, urls): class NodeGroup: """Class for a group of nodes from the same source""" - listing: Listing + listing: "Listing" sources: list[DataSource] # The source path within the bucket @@ -591,8 +591,9 @@ def enlist_source( client_config=None, object_name="file", skip_indexing=False, - ) -> tuple[Listing, str]: + ) -> tuple["Listing", str]: from datachain.lib.dc import DataChain + from datachain.listing import Listing DataChain.from_storage( source, session=self.session, update=update, object_name=object_name @@ -660,7 +661,8 @@ def enlist_sources_grouped( no_glob: bool = False, client_config=None, ) -> list[NodeGroup]: - from datachain.query import DatasetQuery + from datachain.listing import Listing + from datachain.query.dataset import DatasetQuery def _row_to_node(d: dict[str, Any]) -> Node: del d["file__source"] @@ -876,7 +878,7 @@ def create_new_dataset_version( def update_dataset_version_with_warehouse_info( self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs ) -> None: - from datachain.query import DatasetQuery + from datachain.query.dataset import DatasetQuery dataset_version = dataset.get_version(version) @@ -1177,7 +1179,7 @@ def listings(self): def ls_dataset_rows( self, name: str, version: int, offset=None, limit=None ) -> list[dict]: - from datachain.query import DatasetQuery + from datachain.query.dataset import DatasetQuery dataset = self.get_dataset(name) diff --git a/src/datachain/cli.py b/src/datachain/cli.py index 9859c1aff..e7f178856 100644 --- a/src/datachain/cli.py +++ b/src/datachain/cli.py @@ -957,7 +957,7 @@ def show( schema: bool = False, ) -> None: from datachain.lib.dc import DataChain - from datachain.query import DatasetQuery + from datachain.query.dataset import DatasetQuery from datachain.utils import show_records dataset = catalog.get_dataset(name) diff --git a/src/datachain/client/fsspec.py b/src/datachain/client/fsspec.py index 968458a8d..91774e381 100644 --- a/src/datachain/client/fsspec.py +++ b/src/datachain/client/fsspec.py @@ -28,7 +28,6 @@ from datachain.cache import DataChainCache from datachain.client.fileslice import FileWrapper from datachain.error import ClientError as DataChainClientError -from datachain.lib.file import File from datachain.nodes_fetcher import NodesFetcher from datachain.nodes_thread_pool import NodeChunk @@ -36,6 +35,7 @@ from fsspec.spec import AbstractFileSystem from datachain.dataset import StorageURI + from datachain.lib.file import File logger = logging.getLogger("datachain") @@ -45,7 +45,7 @@ DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$") -ResultQueue = asyncio.Queue[Optional[Sequence[File]]] +ResultQueue = asyncio.Queue[Optional[Sequence["File"]]] def _is_win_local_path(uri: str) -> bool: @@ -212,7 +212,7 @@ async def get_file(self, lpath, rpath, callback): async def scandir( self, start_prefix: str, method: str = "default" - ) -> AsyncIterator[Sequence[File]]: + ) -> AsyncIterator[Sequence["File"]]: try: impl = getattr(self, f"_fetch_{method}") except AttributeError: @@ -317,7 +317,7 @@ def get_full_path(self, rel_path: str) -> str: return f"{self.PREFIX}{self.name}/{rel_path}" @abstractmethod - def info_to_file(self, v: dict[str, Any], parent: str) -> File: ... + def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ... def fetch_nodes( self, @@ -354,7 +354,7 @@ def do_instantiate_object(self, file: "File", dst: str) -> None: copy2(src, dst) def open_object( - self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK + self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK ) -> BinaryIO: """Open a file, including files in tar archives.""" if use_cache and (cache_path := self.cache.get_path(file)): @@ -362,19 +362,19 @@ def open_object( assert not file.location return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value] - def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None: + def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None: sync(get_loop(), functools.partial(self._download, file, callback=callback)) - async def _download(self, file: File, *, callback: "Callback" = None) -> None: + async def _download(self, file: "File", *, callback: "Callback" = None) -> None: if self.cache.contains(file): # Already in cache, so there's nothing to do. return await self._put_in_cache(file, callback=callback) - def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None: + def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None: sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback)) - async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None: + async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None: assert not file.location if file.etag: etag = await self.get_current_etag(file) diff --git a/src/datachain/data_storage/schema.py b/src/datachain/data_storage/schema.py index 0c002e5dd..23cc52a2f 100644 --- a/src/datachain/data_storage/schema.py +++ b/src/datachain/data_storage/schema.py @@ -12,7 +12,7 @@ from sqlalchemy.sql import func as f from sqlalchemy.sql.expression import false, null, true -from datachain.sql.functions import path +from datachain.sql.functions import path as pathfunc from datachain.sql.types import Int, SQLType, UInt64 if TYPE_CHECKING: @@ -130,7 +130,7 @@ def apply_group_by(self, q): def query(self, q): q = self.base_select(q).cte(recursive=True) - parent = path.parent(self.c(q, "path")) + parent = pathfunc.parent(self.c(q, "path")) q = q.union_all( sa.select( sa.literal(-1).label("sys__id"), diff --git a/src/datachain/func/__init__.py b/src/datachain/func/__init__.py new file mode 100644 index 000000000..534f852b5 --- /dev/null +++ b/src/datachain/func/__init__.py @@ -0,0 +1,49 @@ +from sqlalchemy import literal + +from . import array, path, random, string +from .aggregate import ( + any_value, + avg, + collect, + concat, + count, + dense_rank, + first, + max, + min, + rank, + row_number, + sum, +) +from .array import cosine_distance, euclidean_distance, length, sip_hash_64 +from .conditional import greatest, least +from .random import rand +from .window import window + +__all__ = [ + "any_value", + "array", + "avg", + "collect", + "concat", + "cosine_distance", + "count", + "dense_rank", + "euclidean_distance", + "first", + "greatest", + "least", + "length", + "literal", + "max", + "min", + "path", + "rand", + "random", + "rank", + "row_number", + "sip_hash_64", + "string", + "sum", + "window", +] diff --git a/src/datachain/lib/func/aggregate.py b/src/datachain/func/aggregate.py similarity index 93% rename from src/datachain/lib/func/aggregate.py rename to src/datachain/func/aggregate.py index 00ae0077a..1492b34d9 100644 --- a/src/datachain/lib/func/aggregate.py +++ b/src/datachain/func/aggregate.py @@ -2,7 +2,7 @@ from sqlalchemy import func as sa_func -from datachain.sql import functions as dc_func +from datachain.sql.functions import aggregate from .func import Func @@ -31,7 +31,9 @@ def count(col: Optional[str] = None) -> Func: Notes: - Result column will always be of type int. """ - return Func("count", inner=sa_func.count, col=col, result_type=int) + return Func( + "count", inner=sa_func.count, cols=[col] if col else None, result_type=int + ) def sum(col: str) -> Func: @@ -59,7 +61,7 @@ def sum(col: str) -> Func: - The `sum` function should be used on numeric columns. - Result column type will be the same as the input column type. """ - return Func("sum", inner=sa_func.sum, col=col) + return Func("sum", inner=sa_func.sum, cols=[col]) def avg(col: str) -> Func: @@ -87,7 +89,7 @@ def avg(col: str) -> Func: - The `avg` function should be used on numeric columns. - Result column will always be of type float. """ - return Func("avg", inner=dc_func.aggregate.avg, col=col, result_type=float) + return Func("avg", inner=aggregate.avg, cols=[col], result_type=float) def min(col: str) -> Func: @@ -115,7 +117,7 @@ def min(col: str) -> Func: - The `min` function can be used with numeric, date, and string columns. - Result column will have the same type as the input column. """ - return Func("min", inner=sa_func.min, col=col) + return Func("min", inner=sa_func.min, cols=[col]) def max(col: str) -> Func: @@ -143,7 +145,7 @@ def max(col: str) -> Func: - The `max` function can be used with numeric, date, and string columns. - Result column will have the same type as the input column. """ - return Func("max", inner=sa_func.max, col=col) + return Func("max", inner=sa_func.max, cols=[col]) def any_value(col: str) -> Func: @@ -174,7 +176,7 @@ def any_value(col: str) -> Func: - The result of `any_value` is non-deterministic, meaning it may return different values for different executions. """ - return Func("any_value", inner=dc_func.aggregate.any_value, col=col) + return Func("any_value", inner=aggregate.any_value, cols=[col]) def collect(col: str) -> Func: @@ -203,7 +205,7 @@ def collect(col: str) -> Func: - The `collect` function can be used with numeric and string columns. - Result column will have an array type. """ - return Func("collect", inner=dc_func.aggregate.collect, col=col, is_array=True) + return Func("collect", inner=aggregate.collect, cols=[col], is_array=True) def concat(col: str, separator="") -> Func: @@ -236,9 +238,9 @@ def concat(col: str, separator="") -> Func: """ def inner(arg): - return dc_func.aggregate.group_concat(arg, separator) + return aggregate.group_concat(arg, separator) - return Func("concat", inner=inner, col=col, result_type=str) + return Func("concat", inner=inner, cols=[col], result_type=str) def row_number() -> Func: @@ -350,4 +352,4 @@ def first(col: str) -> Func: in the specified order. - The result column will have the same type as the input column. """ - return Func("first", inner=sa_func.first_value, col=col, is_window=True) + return Func("first", inner=sa_func.first_value, cols=[col], is_window=True) diff --git a/src/datachain/func/array.py b/src/datachain/func/array.py new file mode 100644 index 000000000..ae3614fb9 --- /dev/null +++ b/src/datachain/func/array.py @@ -0,0 +1,176 @@ +from collections.abc import Sequence +from typing import Union + +from datachain.sql.functions import array + +from .func import Func + + +def cosine_distance(*args: Union[str, Sequence]) -> Func: + """ + Computes the cosine distance between two vectors. + + The cosine distance is derived from the cosine similarity, which measures the angle + between two vectors. This function returns the dissimilarity between the vectors, + where 0 indicates identical vectors and values closer to 1 + indicate higher dissimilarity. + + Args: + args (str | Sequence): Two vectors to compute the cosine distance between. + If a string is provided, it is assumed to be the name of the column vector. + If a sequence is provided, it is assumed to be a vector of values. + + Returns: + Func: A Func object that represents the cosine_distance function. + + Example: + ```py + target_embedding = [0.1, 0.2, 0.3] + dc.mutate( + cos_dist1=func.cosine_distance("embedding", target_embedding), + cos_dist2=func.cosine_distance(target_embedding, [0.4, 0.5, 0.6]), + ) + ``` + + Notes: + - Ensure both vectors have the same number of elements. + - Result column will always be of type float. + """ + cols, func_args = [], [] + for arg in args: + if isinstance(arg, str): + cols.append(arg) + else: + func_args.append(list(arg)) + + if len(cols) + len(func_args) != 2: + raise ValueError("cosine_distance() requires exactly two arguments") + if not cols and len(func_args[0]) != len(func_args[1]): + raise ValueError("cosine_distance() requires vectors of the same length") + + return Func( + "cosine_distance", + inner=array.cosine_distance, + cols=cols, + args=func_args, + result_type=float, + ) + + +def euclidean_distance(*args: Union[str, Sequence]) -> Func: + """ + Computes the Euclidean distance between two vectors. + + The Euclidean distance is the straight-line distance between two points + in Euclidean space. This function returns the distance between the two vectors. + + Args: + args (str | Sequence): Two vectors to compute the Euclidean distance between. + If a string is provided, it is assumed to be the name of the column vector. + If a sequence is provided, it is assumed to be a vector of values. + + Returns: + Func: A Func object that represents the euclidean_distance function. + + Example: + ```py + target_embedding = [0.1, 0.2, 0.3] + dc.mutate( + eu_dist1=func.euclidean_distance("embedding", target_embedding), + eu_dist2=func.euclidean_distance(target_embedding, [0.4, 0.5, 0.6]), + ) + ``` + + Notes: + - Ensure both vectors have the same number of elements. + - Result column will always be of type float. + """ + cols, func_args = [], [] + for arg in args: + if isinstance(arg, str): + cols.append(arg) + else: + func_args.append(list(arg)) + + if len(cols) + len(func_args) != 2: + raise ValueError("euclidean_distance() requires exactly two arguments") + if not cols and len(func_args[0]) != len(func_args[1]): + raise ValueError("euclidean_distance() requires vectors of the same length") + + return Func( + "euclidean_distance", + inner=array.euclidean_distance, + cols=cols, + args=func_args, + result_type=float, + ) + + +def length(arg: Union[str, Sequence, Func]) -> Func: + """ + Returns the length of the array. + + Args: + arg (str | Sequence | Func): Array to compute the length of. + If a string is provided, it is assumed to be the name of the array column. + If a sequence is provided, it is assumed to be an array of values. + If a Func is provided, it is assumed to be a function returning an array. + + Returns: + Func: A Func object that represents the array length function. + + Example: + ```py + dc.mutate( + len1=func.array.length("signal.values"), + len2=func.array.length([1, 2, 3, 4, 5]), + ) + ``` + + Note: + - Result column will always be of type int. + """ + if isinstance(arg, (str, Func)): + cols = [arg] + args = None + else: + cols = None + args = [arg] + + return Func("length", inner=array.length, cols=cols, args=args, result_type=int) + + +def sip_hash_64(arg: Union[str, Sequence]) -> Func: + """ + Computes the SipHash-64 hash of the array. + + Args: + arg (str | Sequence): Array to compute the SipHash-64 hash of. + If a string is provided, it is assumed to be the name of the array column. + If a sequence is provided, it is assumed to be an array of values. + + Returns: + Func: A Func object that represents the sip_hash_64 function. + + Example: + ```py + dc.mutate( + hash1=func.sip_hash_64("signal.values"), + hash2=func.sip_hash_64([1, 2, 3, 4, 5]), + ) + ``` + + Note: + - This function is only available for the ClickHouse warehouse. + - Result column will always be of type int. + """ + if isinstance(arg, str): + cols = [arg] + args = None + else: + cols = None + args = [arg] + + return Func( + "sip_hash_64", inner=array.sip_hash_64, cols=cols, args=args, result_type=int + ) diff --git a/src/datachain/func/base.py b/src/datachain/func/base.py new file mode 100644 index 000000000..ff65dc202 --- /dev/null +++ b/src/datachain/func/base.py @@ -0,0 +1,23 @@ +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from sqlalchemy import TableClause + + from datachain.lib.signal_schema import SignalSchema + from datachain.query.schema import Column + + +class Function: + __metaclass__ = ABCMeta + + name: str + + @abstractmethod + def get_column( + self, + signals_schema: Optional["SignalSchema"] = None, + label: Optional[str] = None, + table: Optional["TableClause"] = None, + ) -> "Column": + pass diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py new file mode 100644 index 000000000..a95865d63 --- /dev/null +++ b/src/datachain/func/conditional.py @@ -0,0 +1,81 @@ +from typing import Union + +from datachain.sql.functions import conditional + +from .func import ColT, Func + + +def greatest(*args: Union[ColT, float]) -> Func: + """ + Returns the greatest (largest) value from the given input values. + + Args: + args (ColT | str | int | float | Sequence): The values to compare. + If a string is provided, it is assumed to be the name of the column. + If a Func is provided, it is assumed to be a function returning a value. + If an int, float, or Sequence is provided, it is assumed to be a literal. + + Returns: + Func: A Func object that represents the greatest function. + + Example: + ```py + dc.mutate( + greatest=func.greatest("signal.value", 0), + ) + ``` + + Note: + - Result column will always be of the same type as the input columns. + """ + cols, func_args = [], [] + + for arg in args: + if isinstance(arg, (str, Func)): + cols.append(arg) + else: + func_args.append(arg) + + return Func( + "greatest", + inner=conditional.greatest, + cols=cols, + args=func_args, + result_type=int, + ) + + +def least(*args: Union[ColT, float]) -> Func: + """ + Returns the least (smallest) value from the given input values. + + Args: + args (ColT | str | int | float | Sequence): The values to compare. + If a string is provided, it is assumed to be the name of the column. + If a Func is provided, it is assumed to be a function returning a value. + If an int, float, or Sequence is provided, it is assumed to be a literal. + + Returns: + Func: A Func object that represents the least function. + + Example: + ```py + dc.mutate( + least=func.least("signal.value", 0), + ) + ``` + + Note: + - Result column will always be of the same type as the input columns. + """ + cols, func_args = [], [] + + for arg in args: + if isinstance(arg, (str, Func)): + cols.append(arg) + else: + func_args.append(arg) + + return Func( + "least", inner=conditional.least, cols=cols, args=func_args, result_type=int + ) diff --git a/src/datachain/func/func.py b/src/datachain/func/func.py new file mode 100644 index 000000000..3f2352eba --- /dev/null +++ b/src/datachain/func/func.py @@ -0,0 +1,384 @@ +import inspect +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from sqlalchemy import BindParameter, ColumnElement, desc + +from datachain.lib.convert.python_to_sql import python_to_sql +from datachain.lib.utils import DataChainColumnError, DataChainParamsError +from datachain.query.schema import Column, ColumnMeta + +from .base import Function + +if TYPE_CHECKING: + from sqlalchemy import TableClause + + from datachain import DataType + from datachain.lib.signal_schema import SignalSchema + + from .window import Window + + +ColT = Union[str, ColumnElement, "Func"] + + +class Func(Function): + """Represents a function to be applied to a column in a SQL query.""" + + def __init__( + self, + name: str, + inner: Callable, + cols: Optional[Sequence[ColT]] = None, + args: Optional[Sequence[Any]] = None, + result_type: Optional["DataType"] = None, + is_array: bool = False, + is_window: bool = False, + window: Optional["Window"] = None, + label: Optional[str] = None, + ) -> None: + self.name = name + self.inner = inner + self.cols = cols or [] + self.args = args or [] + self.result_type = result_type + self.is_array = is_array + self.is_window = is_window + self.window = window + self.col_label = label + + def __str__(self) -> str: + return self.name + "()" + + def over(self, window: "Window") -> "Func": + if not self.is_window: + raise DataChainParamsError(f"{self} doesn't support window (over())") + + return Func( + "over", + self.inner, + self.cols, + self.args, + self.result_type, + self.is_array, + self.is_window, + window, + self.col_label, + ) + + @property + def _db_cols(self) -> Sequence[ColT]: + return ( + [ + col + if isinstance(col, (Func, BindParameter)) + else ColumnMeta.to_db_name( + col.name if isinstance(col, ColumnElement) else col + ) + for col in self.cols + ] + if self.cols + else [] + ) + + def _db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]: + if not self._db_cols: + return None + + col_type: type = get_db_col_type(signals_schema, self._db_cols[0]) + for col in self._db_cols[1:]: + if get_db_col_type(signals_schema, col) != col_type: + raise DataChainColumnError( + str(self), + "Columns must have the same type to infer result type", + ) + + return list[col_type] if self.is_array else col_type # type: ignore[valid-type] + + def __add__(self, other: Union[ColT, float]) -> "Func": + return math_add(self, other) + + def __radd__(self, other: Union[ColT, float]) -> "Func": + return math_add(other, self) + + def __sub__(self, other: Union[ColT, float]) -> "Func": + return math_sub(self, other) + + def __rsub__(self, other: Union[ColT, float]) -> "Func": + return math_sub(other, self) + + def __mul__(self, other: Union[ColT, float]) -> "Func": + return math_mul(self, other) + + def __rmul__(self, other: Union[ColT, float]) -> "Func": + return math_mul(other, self) + + def __truediv__(self, other: Union[ColT, float]) -> "Func": + return math_truediv(self, other) + + def __rtruediv__(self, other: Union[ColT, float]) -> "Func": + return math_truediv(other, self) + + def __floordiv__(self, other: Union[ColT, float]) -> "Func": + return math_floordiv(self, other) + + def __rfloordiv__(self, other: Union[ColT, float]) -> "Func": + return math_floordiv(other, self) + + def __mod__(self, other: Union[ColT, float]) -> "Func": + return math_mod(self, other) + + def __rmod__(self, other: Union[ColT, float]) -> "Func": + return math_mod(other, self) + + def __pow__(self, other: Union[ColT, float]) -> "Func": + return math_pow(self, other) + + def __rpow__(self, other: Union[ColT, float]) -> "Func": + return math_pow(other, self) + + def __lshift__(self, other: Union[ColT, float]) -> "Func": + return math_lshift(self, other) + + def __rlshift__(self, other: Union[ColT, float]) -> "Func": + return math_lshift(other, self) + + def __rshift__(self, other: Union[ColT, float]) -> "Func": + return math_rshift(self, other) + + def __rrshift__(self, other: Union[ColT, float]) -> "Func": + return math_rshift(other, self) + + def __and__(self, other: Union[ColT, float]) -> "Func": + return math_and(self, other) + + def __rand__(self, other: Union[ColT, float]) -> "Func": + return math_and(other, self) + + def __or__(self, other: Union[ColT, float]) -> "Func": + return math_or(self, other) + + def __ror__(self, other: Union[ColT, float]) -> "Func": + return math_or(other, self) + + def __xor__(self, other: Union[ColT, float]) -> "Func": + return math_xor(self, other) + + def __rxor__(self, other: Union[ColT, float]) -> "Func": + return math_xor(other, self) + + def __lt__(self, other: Union[ColT, float]) -> "Func": + return math_lt(self, other) + + def __le__(self, other: Union[ColT, float]) -> "Func": + return math_le(self, other) + + def __eq__(self, other): + return math_eq(self, other) + + def __ne__(self, other): + return math_ne(self, other) + + def __gt__(self, other: Union[ColT, float]) -> "Func": + return math_gt(self, other) + + def __ge__(self, other: Union[ColT, float]) -> "Func": + return math_ge(self, other) + + def label(self, label: str) -> "Func": + return Func( + self.name, + self.inner, + self.cols, + self.args, + self.result_type, + self.is_array, + self.is_window, + self.window, + label, + ) + + def get_col_name(self, label: Optional[str] = None) -> str: + if label: + return label + if self.col_label: + return self.col_label + if (db_cols := self._db_cols) and len(db_cols) == 1: + if isinstance(db_cols[0], str): + return db_cols[0] + if isinstance(db_cols[0], Column): + return db_cols[0].name + if isinstance(db_cols[0], Func): + return db_cols[0].get_col_name() + return self.name + + def get_result_type( + self, signals_schema: Optional["SignalSchema"] = None + ) -> "DataType": + if self.result_type: + return self.result_type + + if signals_schema and (col_type := self._db_col_type(signals_schema)): + return col_type + + raise DataChainColumnError( + str(self), + "Column name is required to infer result type", + ) + + def get_column( + self, + signals_schema: Optional["SignalSchema"] = None, + label: Optional[str] = None, + table: Optional["TableClause"] = None, + ) -> Column: + col_type = self.get_result_type(signals_schema) + sql_type = python_to_sql(col_type) + + def get_col(col: ColT) -> ColT: + if isinstance(col, Func): + return col.get_column(signals_schema, table=table) + if isinstance(col, str): + column = Column(col, sql_type) + column.table = table + return column + return col + + cols = [get_col(col) for col in self._db_cols] + func_col = self.inner(*cols, *self.args) + + if self.is_window: + if not self.window: + raise DataChainParamsError( + f"Window function {self} requires over() clause with a window spec", + ) + func_col = func_col.over( + partition_by=self.window.partition_by, + order_by=( + desc(self.window.order_by) + if self.window.desc + else self.window.order_by + ), + ) + + func_col.type = sql_type() if inspect.isclass(sql_type) else sql_type + + if col_name := self.get_col_name(label): + func_col = func_col.label(col_name) + + return func_col + + +def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType": + if isinstance(col, Func): + return col.get_result_type(signals_schema) + + return signals_schema.get_column_type( + col.name if isinstance(col, ColumnElement) else col + ) + + +def math_func( + name: str, + inner: Callable, + params: Sequence[Union[ColT, float]], + result_type: Optional["DataType"] = None, +) -> Func: + """Returns math function from the columns.""" + cols, args = [], [] + for arg in params: + if isinstance(arg, (int, float)): + args.append(arg) + else: + cols.append(arg) + return Func(name, inner, cols=cols, args=args, result_type=result_type) + + +def math_add(*args: Union[ColT, float]) -> Func: + """Computes the sum of the column.""" + return math_func("add", lambda a1, a2: a1 + a2, args) + + +def math_sub(*args: Union[ColT, float]) -> Func: + """Computes the diff of the column.""" + return math_func("sub", lambda a1, a2: a1 - a2, args) + + +def math_mul(*args: Union[ColT, float]) -> Func: + """Computes the product of the column.""" + return math_func("mul", lambda a1, a2: a1 * a2, args) + + +def math_truediv(*args: Union[ColT, float]) -> Func: + """Computes the division of the column.""" + return math_func("div", lambda a1, a2: a1 / a2, args, result_type=float) + + +def math_floordiv(*args: Union[ColT, float]) -> Func: + """Computes the floor division of the column.""" + return math_func("floordiv", lambda a1, a2: a1 // a2, args, result_type=float) + + +def math_mod(*args: Union[ColT, float]) -> Func: + """Computes the modulo of the column.""" + return math_func("mod", lambda a1, a2: a1 % a2, args, result_type=float) + + +def math_pow(*args: Union[ColT, float]) -> Func: + """Computes the power of the column.""" + return math_func("pow", lambda a1, a2: a1**a2, args, result_type=float) + + +def math_lshift(*args: Union[ColT, float]) -> Func: + """Computes the left shift of the column.""" + return math_func("lshift", lambda a1, a2: a1 << a2, args, result_type=int) + + +def math_rshift(*args: Union[ColT, float]) -> Func: + """Computes the right shift of the column.""" + return math_func("rshift", lambda a1, a2: a1 >> a2, args, result_type=int) + + +def math_and(*args: Union[ColT, float]) -> Func: + """Computes the logical AND of the column.""" + return math_func("and", lambda a1, a2: a1 & a2, args, result_type=bool) + + +def math_or(*args: Union[ColT, float]) -> Func: + """Computes the logical OR of the column.""" + return math_func("or", lambda a1, a2: a1 | a2, args, result_type=bool) + + +def math_xor(*args: Union[ColT, float]) -> Func: + """Computes the logical XOR of the column.""" + return math_func("xor", lambda a1, a2: a1 ^ a2, args, result_type=bool) + + +def math_lt(*args: Union[ColT, float]) -> Func: + """Computes the less than comparison of the column.""" + return math_func("lt", lambda a1, a2: a1 < a2, args, result_type=bool) + + +def math_le(*args: Union[ColT, float]) -> Func: + """Computes the less than or equal comparison of the column.""" + return math_func("le", lambda a1, a2: a1 <= a2, args, result_type=bool) + + +def math_eq(*args: Union[ColT, float]) -> Func: + """Computes the equality comparison of the column.""" + return math_func("eq", lambda a1, a2: a1 == a2, args, result_type=bool) + + +def math_ne(*args: Union[ColT, float]) -> Func: + """Computes the inequality comparison of the column.""" + return math_func("ne", lambda a1, a2: a1 != a2, args, result_type=bool) + + +def math_gt(*args: Union[ColT, float]) -> Func: + """Computes the greater than comparison of the column.""" + return math_func("gt", lambda a1, a2: a1 > a2, args, result_type=bool) + + +def math_ge(*args: Union[ColT, float]) -> Func: + """Computes the greater than or equal comparison of the column.""" + return math_func("ge", lambda a1, a2: a1 >= a2, args, result_type=bool) diff --git a/src/datachain/func/path.py b/src/datachain/func/path.py new file mode 100644 index 000000000..78b418bfa --- /dev/null +++ b/src/datachain/func/path.py @@ -0,0 +1,110 @@ +from datachain.sql.functions import path + +from .func import ColT, Func + + +def parent(col: ColT) -> Func: + """ + Returns the directory component of a posix-style path. + + Args: + col (str | literal | Func): String to compute the path parent of. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + + Returns: + Func: A Func object that represents the path parent function. + + Example: + ```py + dc.mutate( + parent=func.path.parent("file.path"), + ) + ``` + + Note: + - Result column will always be of type string. + """ + return Func("parent", inner=path.parent, cols=[col], result_type=str) + + +def name(col: ColT) -> Func: + """ + Returns the final component of a posix-style path. + + Args: + col (str | literal): String to compute the path name of. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + + Returns: + Func: A Func object that represents the path name function. + + Example: + ```py + dc.mutate( + file_name=func.path.name("file.path"), + ) + ``` + + Note: + - Result column will always be of type string. + """ + + return Func("name", inner=path.name, cols=[col], result_type=str) + + +def file_stem(col: ColT) -> Func: + """ + Returns the path without the extension. + + Args: + col (str | literal): String to compute the file stem of. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + + Returns: + Func: A Func object that represents the file stem function. + + Example: + ```py + dc.mutate( + file_stem=func.path.file_stem("file.path"), + ) + ``` + + Note: + - Result column will always be of type string. + """ + + return Func("file_stem", inner=path.file_stem, cols=[col], result_type=str) + + +def file_ext(col: ColT) -> Func: + """ + Returns the extension of the given path. + + Args: + col (str | literal): String to compute the file extension of. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + + Returns: + Func: A Func object that represents the file extension function. + + Example: + ```py + dc.mutate( + file_stem=func.path.file_ext("file.path"), + ) + ``` + + Note: + - Result column will always be of type string. + """ + + return Func("file_ext", inner=path.file_ext, cols=[col], result_type=str) diff --git a/src/datachain/func/random.py b/src/datachain/func/random.py new file mode 100644 index 000000000..c9c359737 --- /dev/null +++ b/src/datachain/func/random.py @@ -0,0 +1,23 @@ +from datachain.sql.functions import random + +from .func import Func + + +def rand() -> Func: + """ + Returns the random integer value. + + Returns: + Func: A Func object that represents the rand function. + + Example: + ```py + dc.mutate( + rnd=func.random.rand(), + ) + ``` + + Note: + - Result column will always be of type integer. + """ + return Func("rand", inner=random.rand, result_type=int) diff --git a/src/datachain/func/string.py b/src/datachain/func/string.py new file mode 100644 index 000000000..33aa5a22d --- /dev/null +++ b/src/datachain/func/string.py @@ -0,0 +1,154 @@ +from typing import Optional, Union, get_origin + +from sqlalchemy import literal + +from datachain.sql.functions import string + +from .func import Func + + +def length(col: Union[str, Func]) -> Func: + """ + Returns the length of the string. + + Args: + col (str | literal | Func): String to compute the length of. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + + Returns: + Func: A Func object that represents the string length function. + + Example: + ```py + dc.mutate( + len1=func.string.length("file.path"), + len2=func.string.length("Random string"), + ) + ``` + + Note: + - Result column will always be of type int. + """ + return Func("length", inner=string.length, cols=[col], result_type=int) + + +def split(col: Union[str, Func], sep: str, limit: Optional[int] = None) -> Func: + """ + Takes a column and split character and returns an array of the parts. + + Args: + col (str | literal): Column to split. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + sep (str): Separator to split the string. + limit (int, optional): Maximum number of splits to perform. + + Returns: + Func: A Func object that represents the split function. + + Example: + ```py + dc.mutate( + path_parts=func.string.split("file.path", "/"), + str_words=func.string.length("Random string", " "), + ) + ``` + + Note: + - Result column will always be of type array of strings. + """ + + def inner(arg): + if limit is not None: + return string.split(arg, sep, limit) + return string.split(arg, sep) + + if get_origin(col) is literal: + cols = None + args = [col] + else: + cols = [col] + args = None + + return Func("split", inner=inner, cols=cols, args=args, result_type=list[str]) + + +def replace(col: Union[str, Func], pattern: str, replacement: str) -> Func: + """ + Replaces substring with another string. + + Args: + col (str | literal): Column to split. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + pattern (str): Pattern to replace. + replacement (str): Replacement string. + + Returns: + Func: A Func object that represents the replace function. + + Example: + ```py + dc.mutate( + signal=func.string.replace("signal.name", "pattern", "replacement), + ) + ``` + + Note: + - Result column will always be of type string. + """ + + def inner(arg): + return string.replace(arg, pattern, replacement) + + if get_origin(col) is literal: + cols = None + args = [col] + else: + cols = [col] + args = None + + return Func("replace", inner=inner, cols=cols, args=args, result_type=str) + + +def regexp_replace(col: Union[str, Func], regex: str, replacement: str) -> Func: + r""" + Replaces substring that match a regular expression. + + Args: + col (str | literal): Column to split. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + regex (str): Regular expression pattern to replace. + replacement (str): Replacement string. + + Returns: + Func: A Func object that represents the regexp_replace function. + + Example: + ```py + dc.mutate( + signal=func.string.regexp_replace("signal.name", r"\d+", "X"), + ) + ``` + + Note: + - Result column will always be of type string. + """ + + def inner(arg): + return string.regexp_replace(arg, regex, replacement) + + if get_origin(col) is literal: + cols = None + args = [col] + else: + cols = [col] + args = None + + return Func("regexp_replace", inner=inner, cols=cols, args=args, result_type=str) diff --git a/src/datachain/func/window.py b/src/datachain/func/window.py new file mode 100644 index 000000000..ea653017f --- /dev/null +++ b/src/datachain/func/window.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass + +from datachain.query.schema import ColumnMeta + + +@dataclass +class Window: + """Represents a window specification for SQL window functions.""" + + partition_by: str + order_by: str + desc: bool = False + + +def window(partition_by: str, order_by: str, desc: bool = False) -> Window: + """ + Defines a window specification for SQL window functions. + + The `window` function specifies how to partition and order the result set + for the associated window function. It is used to define the scope of the rows + that the window function will operate on. + + Args: + partition_by (str): The column name by which to partition the result set. + Rows with the same value in the partition column + will be grouped together for the window function. + order_by (str): The column name by which to order the rows + within each partition. This determines the sequence in which + the window function is applied. + desc (bool, optional): If True, the rows will be ordered in descending order. + Defaults to False, which orders the rows + in ascending order. + + Returns: + Window: A Window object representing the window specification. + + Example: + ```py + window = func.window(partition_by="signal.category", order_by="created_at") + dc.mutate( + row_number=func.row_number().over(window), + ) + ``` + """ + return Window( + ColumnMeta.to_db_name(partition_by), + ColumnMeta.to_db_name(order_by), + desc, + ) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index d0c7d19fe..94f3782db 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -28,13 +28,14 @@ from datachain.client import Client from datachain.client.local import FileClient from datachain.dataset import DatasetRecord +from datachain.func.base import Function +from datachain.func.func import Func from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.convert.values_to_tuples import values_to_tuples from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model from datachain.lib.dataset_info import DatasetInfo from datachain.lib.file import ArrowRow, File, get_file_type from datachain.lib.file import ExportPlacement as FileExportPlacement -from datachain.lib.func import Func from datachain.lib.listing import ( list_bucket, ls, @@ -112,9 +113,29 @@ def __init__(self, name, msg): # noqa: D107 super().__init__(f"Dataset{name} from values error: {msg}") -def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str: +MergeColType = Union[str, Function, sqlalchemy.ColumnElement] + + +def _validate_merge_on( + on: Union[MergeColType, Sequence[MergeColType]], + ds: "DataChain", +) -> Sequence[MergeColType]: + if isinstance(on, (str, sqlalchemy.ColumnElement)): + return [on] + if isinstance(on, Function): + return [on.get_column(table=ds._query.table)] + if isinstance(on, Sequence): + return [ + c.get_column(table=ds._query.table) if isinstance(c, Function) else c + for c in on + ] + + +def _get_merge_error_str(col: MergeColType) -> str: if isinstance(col, str): return col + if isinstance(col, Function): + return f"{col.name}()" if isinstance(col, sqlalchemy.Column): return col.name.replace(DEFAULT_DELIMITER, ".") if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"): @@ -125,11 +146,13 @@ def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str: class DatasetMergeError(DataChainParamsError): # noqa: D101 def __init__( # noqa: D107 self, - on: Sequence[Union[str, sqlalchemy.ColumnElement]], - right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]], + on: Union[MergeColType, Sequence[MergeColType]], + right_on: Optional[Union[MergeColType, Sequence[MergeColType]]], msg: str, ): - def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str: + def _get_str( + on: Union[MergeColType, Sequence[MergeColType]], + ) -> str: if not isinstance(on, Sequence): return str(on) # type: ignore[unreachable] return ", ".join([_get_merge_error_str(col) for col in on]) @@ -1127,7 +1150,7 @@ def select_except(self, *args: str) -> "Self": def group_by( self, *, - partition_by: Union[str, Sequence[str]], + partition_by: Union[str, Func, Sequence[Union[str, Func]]], **kwargs: Func, ) -> "Self": """Group rows by specified set of signals and return new signals @@ -1144,36 +1167,47 @@ def group_by( ) ``` """ - if isinstance(partition_by, str): + if isinstance(partition_by, (str, Func)): partition_by = [partition_by] if not partition_by: raise ValueError("At least one column should be provided for partition_by") - if not kwargs: - raise ValueError("At least one column should be provided for group_by") - for col_name, func in kwargs.items(): - if not isinstance(func, Func): - raise DataChainColumnError( - col_name, - f"Column {col_name} has type {type(func)} but expected Func object", - ) - partition_by_columns: list[Column] = [] signal_columns: list[Column] = [] schema_fields: dict[str, DataType] = {} # validate partition_by columns and add them to the schema - for col_name in partition_by: - col_db_name = ColumnMeta.to_db_name(col_name) - col_type = self.signals_schema.get_column_type(col_db_name) - col = Column(col_db_name, python_to_sql(col_type)) - partition_by_columns.append(col) + for col in partition_by: + if isinstance(col, str): + col_db_name = ColumnMeta.to_db_name(col) + col_type = self.signals_schema.get_column_type(col_db_name) + column = Column(col_db_name, python_to_sql(col_type)) + elif isinstance(col, Function): + column = col.get_column(self.signals_schema) + col_db_name = column.name + col_type = column.type.python_type + else: + raise DataChainColumnError( + col, + ( + f"partition_by column {col} has type {type(col)}" + " but expected str or Function" + ), + ) + partition_by_columns.append(column) schema_fields[col_db_name] = col_type # validate signal columns and add them to the schema + if not kwargs: + raise ValueError("At least one column should be provided for group_by") for col_name, func in kwargs.items(): - col = func.get_column(self.signals_schema, label=col_name) - signal_columns.append(col) + if not isinstance(func, Func): + raise DataChainColumnError( + col_name, + f"Column {col_name} has type {type(func)} but expected Func object", + ) + column = func.get_column(self.signals_schema, label=col_name) + signal_columns.append(column) schema_fields[col_name] = func.get_result_type(self.signals_schema) return self._evolve( @@ -1421,25 +1455,16 @@ def remove_file_signals(self) -> "Self": # noqa: D102 def merge( self, right_ds: "DataChain", - on: Union[ - str, - sqlalchemy.ColumnElement, - Sequence[Union[str, sqlalchemy.ColumnElement]], - ], - right_on: Union[ - str, - sqlalchemy.ColumnElement, - Sequence[Union[str, sqlalchemy.ColumnElement]], - None, - ] = None, + on: Union[MergeColType, Sequence[MergeColType]], + right_on: Optional[Union[MergeColType, Sequence[MergeColType]]] = None, inner=False, rname="right_", ) -> "Self": """Merge two chains based on the specified criteria. Parameters: - right_ds : Chain to join with. - on : Predicate or list of Predicates to join on. If both chains have the + right_ds: Chain to join with. + on: Predicate or list of Predicates to join on. If both chains have the same predicates then this predicate is enough for the join. Otherwise, `right_on` parameter has to specify the predicates for the other chain. right_on: Optional predicate or list of Predicates @@ -1456,23 +1481,24 @@ def merge( if on is None: raise DatasetMergeError(["None"], None, "'on' must be specified") - if isinstance(on, (str, sqlalchemy.ColumnElement)): - on = [on] - elif not isinstance(on, Sequence): + on = _validate_merge_on(on, self) + if not on: raise DatasetMergeError( on, right_on, - f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'", + ( + "'on' must be 'str', 'Func' or 'Sequence' object " + f"but got type '{type(on)}'" + ), ) if right_on is not None: - if isinstance(right_on, (str, sqlalchemy.ColumnElement)): - right_on = [right_on] - elif not isinstance(right_on, Sequence): + right_on = _validate_merge_on(right_on, right_ds) + if not right_on: raise DatasetMergeError( on, right_on, - "'right_on' must be 'str' or 'Sequence' object" + "'right_on' must be 'str', 'Func' or 'Sequence' object" f" but got type '{type(right_on)}'", ) @@ -1488,10 +1514,12 @@ def merge( def _resolve( ds: DataChain, - col: Union[str, sqlalchemy.ColumnElement], + col: Union[str, Function, sqlalchemy.ColumnElement], side: Union[str, None], ): try: + if isinstance(col, Function): + return ds.c(col.get_column()) return ds.c(col) if isinstance(col, (str, C)) else col except ValueError: if side: @@ -2399,9 +2427,9 @@ def filter(self, *args: Any) -> "Self": dc.filter(C("file.name").glob("*.jpg")) ``` - Using `datachain.sql.functions` + Using `datachain.func` ```py - from datachain.sql.functions import string + from datachain.func import string dc.filter(string.length(C("file.name")) > 5) ``` diff --git a/src/datachain/lib/func/__init__.py b/src/datachain/lib/func/__init__.py deleted file mode 100644 index ba6f08027..000000000 --- a/src/datachain/lib/func/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -from .aggregate import ( - any_value, - avg, - collect, - concat, - count, - dense_rank, - first, - max, - min, - rank, - row_number, - sum, -) -from .func import Func, window - -__all__ = [ - "Func", - "any_value", - "avg", - "collect", - "concat", - "count", - "dense_rank", - "first", - "max", - "min", - "rank", - "row_number", - "sum", - "window", -] diff --git a/src/datachain/lib/func/func.py b/src/datachain/lib/func/func.py deleted file mode 100644 index 3e7373d52..000000000 --- a/src/datachain/lib/func/func.py +++ /dev/null @@ -1,152 +0,0 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Optional - -from sqlalchemy import desc - -from datachain.lib.convert.python_to_sql import python_to_sql -from datachain.lib.utils import DataChainColumnError, DataChainParamsError -from datachain.query.schema import Column, ColumnMeta - -if TYPE_CHECKING: - from datachain import DataType - from datachain.lib.signal_schema import SignalSchema - - -@dataclass -class Window: - """Represents a window specification for SQL window functions.""" - - partition_by: str - order_by: str - desc: bool = False - - -def window(partition_by: str, order_by: str, desc: bool = False) -> Window: - """ - Defines a window specification for SQL window functions. - - The `window` function specifies how to partition and order the result set - for the associated window function. It is used to define the scope of the rows - that the window function will operate on. - - Args: - partition_by (str): The column name by which to partition the result set. - Rows with the same value in the partition column - will be grouped together for the window function. - order_by (str): The column name by which to order the rows - within each partition. This determines the sequence in which - the window function is applied. - desc (bool, optional): If True, the rows will be ordered in descending order. - Defaults to False, which orders the rows - in ascending order. - - Returns: - Window: A Window object representing the window specification. - - Example: - ```py - window = func.window(partition_by="signal.category", order_by="created_at") - dc.mutate( - row_number=func.row_number().over(window), - ) - ``` - """ - return Window( - ColumnMeta.to_db_name(partition_by), - ColumnMeta.to_db_name(order_by), - desc, - ) - - -class Func: - """Represents a function to be applied to a column in a SQL query.""" - - def __init__( - self, - name: str, - inner: Callable, - col: Optional[str] = None, - result_type: Optional["DataType"] = None, - is_array: bool = False, - is_window: bool = False, - window: Optional[Window] = None, - ) -> None: - self.name = name - self.inner = inner - self.col = col - self.result_type = result_type - self.is_array = is_array - self.is_window = is_window - self.window = window - - def __str__(self) -> str: - return self.name + "()" - - def over(self, window: Window) -> "Func": - if not self.is_window: - raise DataChainParamsError(f"{self} doesn't support window (over())") - - return Func( - "over", - self.inner, - self.col, - self.result_type, - self.is_array, - self.is_window, - window, - ) - - @property - def db_col(self) -> Optional[str]: - return ColumnMeta.to_db_name(self.col) if self.col else None - - def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]: - if not self.db_col: - return None - col_type: type = signals_schema.get_column_type(self.db_col) - return list[col_type] if self.is_array else col_type # type: ignore[valid-type] - - def get_result_type(self, signals_schema: "SignalSchema") -> "DataType": - if self.result_type: - return self.result_type - - if col_type := self.db_col_type(signals_schema): - return col_type - - raise DataChainColumnError( - str(self), - "Column name is required to infer result type", - ) - - def get_column( - self, signals_schema: "SignalSchema", label: Optional[str] = None - ) -> Column: - col_type = self.get_result_type(signals_schema) - sql_type = python_to_sql(col_type) - - if self.col: - col = Column(self.db_col, sql_type) - func_col = self.inner(col) - else: - func_col = self.inner() - - if self.is_window: - if not self.window: - raise DataChainParamsError( - f"Window function {self} requires over() clause with a window spec", - ) - func_col = func_col.over( - partition_by=self.window.partition_by, - order_by=( - desc(self.window.order_by) - if self.window.desc - else self.window.order_by - ), - ) - - func_col.type = sql_type - - if label: - func_col = func_col.label(label) - - return func_col diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 29cf202cd..c851d761e 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -23,12 +23,12 @@ from sqlalchemy import ColumnElement from typing_extensions import Literal as LiteralEx +from datachain.func.func import Func from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.convert.sql_to_python import sql_to_python from datachain.lib.convert.unflatten import unflatten_to_json_pos from datachain.lib.data_model import DataModel, DataType, DataValue from datachain.lib.file import File -from datachain.lib.func import Func from datachain.lib.model_store import ModelStore from datachain.lib.utils import DataChainParamsError from datachain.query.schema import DEFAULT_DELIMITER, Column diff --git a/src/datachain/nodes_fetcher.py b/src/datachain/nodes_fetcher.py index b57f21542..5aa75dc3c 100644 --- a/src/datachain/nodes_fetcher.py +++ b/src/datachain/nodes_fetcher.py @@ -2,12 +2,12 @@ from collections.abc import Iterable from typing import TYPE_CHECKING -from datachain.node import Node from datachain.nodes_thread_pool import NodesThreadPool if TYPE_CHECKING: from datachain.cache import DataChainCache from datachain.client.fsspec import Client + from datachain.node import Node logger = logging.getLogger("datachain") @@ -22,7 +22,7 @@ def done_task(self, done): for task in done: task.result() - def do_task(self, chunk: Iterable[Node]) -> None: + def do_task(self, chunk: Iterable["Node"]) -> None: from fsspec import Callback class _CB(Callback): diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 3998400ad..3f812bf54 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -43,9 +43,10 @@ ) from datachain.dataset import DatasetStatus, RowDict from datachain.error import DatasetNotFoundError, QueryScriptCancelError +from datachain.func.base import Function from datachain.lib.udf import UDFAdapter from datachain.progress import CombinedDownloadCallback -from datachain.sql.functions import rand +from datachain.sql.functions.random import rand from datachain.utils import ( batched, determine_processes, @@ -65,15 +66,16 @@ from datachain.catalog import Catalog from datachain.data_storage import AbstractWarehouse from datachain.dataset import DatasetRecord - - from .udf import UDFResult + from datachain.lib.udf import UDFResult P = ParamSpec("P") INSERT_BATCH_SIZE = 10000 -PartitionByType = Union[ColumnElement, Sequence[ColumnElement]] +PartitionByType = Union[ + Function, ColumnElement, Sequence[Union[Function, ColumnElement]] +] JoinPredicateType = Union[str, ColumnClause, ColumnElement] DatasetDependencyType = tuple[str, int] @@ -517,13 +519,17 @@ def create_partitions_table(self, query: Select) -> "Table": else: list_partition_by = [self.partition_by] + partition_by = [ + p.get_column() if isinstance(p, Function) else p for p in list_partition_by + ] + # create table with partitions tbl = self.catalog.warehouse.create_udf_table(partition_columns()) # fill table with partitions cols = [ query.selected_columns.sys__id, - f.dense_rank().over(order_by=list_partition_by).label(PARTITION_COLUMN_ID), + f.dense_rank().over(order_by=partition_by).label(PARTITION_COLUMN_ID), ] self.catalog.warehouse.db.execute( tbl.insert().from_select(cols, query.with_only_columns(*cols)) @@ -680,6 +686,12 @@ def q(*columns): return step_result(q, new_query.selected_columns) + def parse_cols( + self, + cols: Sequence[Union[Function, ColumnElement]], + ) -> tuple[ColumnElement, ...]: + return tuple(c.get_column() if isinstance(c, Function) else c for c in cols) + @abstractmethod def apply_sql_clause(self, query): pass @@ -687,12 +699,14 @@ def apply_sql_clause(self, query): @frozen class SQLSelect(SQLClause): - args: tuple[Union[str, ColumnElement], ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query) -> Select: subquery = query.subquery() - - args = [subquery.c[str(c)] if isinstance(c, (str, C)) else c for c in self.args] + args = [ + subquery.c[str(c)] if isinstance(c, (str, C)) else c + for c in self.parse_cols(self.args) + ] if not args: args = subquery.c @@ -701,22 +715,25 @@ def apply_sql_clause(self, query) -> Select: @frozen class SQLSelectExcept(SQLClause): - args: tuple[str, ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query: Select) -> Select: subquery = query.subquery() - names = set(self.args) - args = [c for c in subquery.c if c.name not in names] + args = [c for c in subquery.c if c.name not in set(self.parse_cols(self.args))] return sqlalchemy.select(*args).select_from(subquery) @frozen class SQLMutate(SQLClause): - args: tuple[ColumnElement, ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query: Select) -> Select: original_subquery = query.subquery() - to_mutate = {c.name for c in self.args} + args = [ + original_subquery.c[str(c)] if isinstance(c, (str, C)) else c + for c in self.parse_cols(self.args) + ] + to_mutate = {c.name for c in args} prefix = f"mutate{token_hex(8)}_" cols = [ @@ -726,9 +743,7 @@ def apply_sql_clause(self, query: Select) -> Select: # this is needed for new column to be used in clauses # like ORDER BY, otherwise new column is not recognized subquery = ( - sqlalchemy.select(*cols, *self.args) - .select_from(original_subquery) - .subquery() + sqlalchemy.select(*cols, *args).select_from(original_subquery).subquery() ) return sqlalchemy.select(*subquery.c).select_from(subquery) @@ -736,21 +751,24 @@ def apply_sql_clause(self, query: Select) -> Select: @frozen class SQLFilter(SQLClause): - expressions: tuple[ColumnElement, ...] + expressions: tuple[Union[Function, ColumnElement], ...] def __and__(self, other): - return self.__class__(self.expressions + other) + expressions = self.parse_cols(self.expressions) + return self.__class__(expressions + other) def apply_sql_clause(self, query: Select) -> Select: - return query.filter(*self.expressions) + expressions = self.parse_cols(self.expressions) + return query.filter(*expressions) @frozen class SQLOrderBy(SQLClause): - args: tuple[ColumnElement, ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query: Select) -> Select: - return query.order_by(*self.args) + args = self.parse_cols(self.args) + return query.order_by(*args) @frozen @@ -945,8 +963,8 @@ def q(*columns): @frozen class SQLGroupBy(SQLClause): - cols: Sequence[Union[str, ColumnElement]] - group_by: Sequence[Union[str, ColumnElement]] + cols: Sequence[Union[str, Function, ColumnElement]] + group_by: Sequence[Union[str, Function, ColumnElement]] def apply_sql_clause(self, query) -> Select: if not self.cols: @@ -956,12 +974,20 @@ def apply_sql_clause(self, query) -> Select: subquery = query.subquery() + group_by = [ + c.get_column() if isinstance(c, Function) else c for c in self.group_by + ] + cols = [ - subquery.c[str(c)] if isinstance(c, (str, C)) else c - for c in [*self.group_by, *self.cols] + c.get_column() + if isinstance(c, Function) + else subquery.c[str(c)] + if isinstance(c, (str, C)) + else c + for c in (*group_by, *self.cols) ] - return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by) + return sqlalchemy.select(*cols).select_from(subquery).group_by(*group_by) def _validate_columns( diff --git a/src/datachain/sql/__init__.py b/src/datachain/sql/__init__.py index 4fc757e4c..4d812300e 100644 --- a/src/datachain/sql/__init__.py +++ b/src/datachain/sql/__init__.py @@ -1,13 +1,11 @@ from sqlalchemy.sql.elements import literal from sqlalchemy.sql.expression import column -from . import functions from .default import setup as default_setup from .selectable import select, values __all__ = [ "column", - "functions", "literal", "select", "values", diff --git a/src/datachain/sql/functions/__init__.py b/src/datachain/sql/functions/__init__.py index c8d4ef0de..e69de29bb 100644 --- a/src/datachain/sql/functions/__init__.py +++ b/src/datachain/sql/functions/__init__.py @@ -1,26 +0,0 @@ -from sqlalchemy.sql.expression import func - -from . import array, path, string -from .aggregate import avg -from .conditional import greatest, least -from .random import rand - -count = func.count -sum = func.sum -min = func.min -max = func.max - -__all__ = [ - "array", - "avg", - "count", - "func", - "greatest", - "least", - "max", - "min", - "path", - "rand", - "string", - "sum", -] diff --git a/src/datachain/sql/selectable.py b/src/datachain/sql/selectable.py index 20c0b5f0c..539520b25 100644 --- a/src/datachain/sql/selectable.py +++ b/src/datachain/sql/selectable.py @@ -9,7 +9,9 @@ def __init__(self, data, columns=None, **kwargs): columns = [expression.column(f"c{i}") for i in range(1, num_columns + 1)] else: columns = [ - expression.column(c) if isinstance(c, str) else c for c in columns + process_column_expression(c) + for c in columns + # expression.column(c) if isinstance(c, str) else c for c in columns ] super().__init__(*columns, **kwargs) self._data += tuple(data) @@ -19,13 +21,17 @@ def values(data, columns=None, **kwargs) -> Values: return Values(data, columns=columns, **kwargs) -def process_column_expressions(columns): - return [expression.column(c) if isinstance(c, str) else c for c in columns] +def process_column_expression(col): + if hasattr(col, "get_column"): + return col.get_column() + if isinstance(col, str): + return expression.column(col) + return col def select(*columns, **kwargs) -> "expression.Select": - columns = process_column_expressions(columns) - return expression.select(*columns, **kwargs) + columns_processed = [process_column_expression(c) for c in columns] + return expression.select(*columns_processed, **kwargs) def base_values_compiler(column_name_func, element, compiler, **kwargs): diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index fb11cad6d..41355c774 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -15,10 +15,11 @@ from PIL import Image from sqlalchemy import Column -from datachain import DataModel +from datachain import DataModel, func from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE from datachain.data_storage.sqlite import SQLiteWarehouse from datachain.dataset import DatasetDependencyType, DatasetStats +from datachain.func import path as pathfunc from datachain.lib.dc import C, DataChain from datachain.lib.file import File, ImageFile from datachain.lib.listing import LISTING_TTL, is_listing_dataset, parse_listing_uri @@ -26,8 +27,6 @@ from datachain.lib.udf import Mapper from datachain.lib.utils import DataChainError from datachain.query.dataset import QueryStep -from datachain.sql.functions import path as pathfunc -from datachain.sql.functions.array import cosine_distance, euclidean_distance from tests.utils import ( ANY_VALUE, NUM_TREE, @@ -954,7 +953,7 @@ def get_result(chain): expected = [(f"{i:06d}", i) for i in range(100)] dc = ( DataChain.from_storage(ctc.src_uri, session=ctc.session) - .mutate(name=pathfunc.name(C("file.path"))) + .mutate(name=pathfunc.name("file.path")) .save() ) # We test a few different orderings here, because we've had strange @@ -1289,8 +1288,8 @@ def calc_emb(file): DataChain.from_storage(src_uri, session=session) .map(embedding=calc_emb, output={"embedding": list[float]}) .mutate( - cos_dist=cosine_distance(C("embedding"), target_embedding), - eucl_dist=euclidean_distance(C("embedding"), target_embedding), + cos_dist=func.cosine_distance("embedding", target_embedding), + eucl_dist=func.euclidean_distance("embedding", target_embedding), ) .order_by("file.path") ) @@ -1389,6 +1388,38 @@ def file_info(file: File) -> FileInfo: ) +def test_group_by_func(cloud_test_catalog): + from datachain import func + + session = cloud_test_catalog.session + src_uri = cloud_test_catalog.src_uri + + ds = ( + DataChain.from_storage(src_uri, session=session) + .group_by( + cnt=func.count(), + sum=func.sum("file.size"), + partition_by=func.path.parent("file.path").label("file_dir"), + ) + .save("my-ds") + ) + + assert ds.signals_schema.serialize() == { + "file_dir": "str", + "cnt": "int", + "sum": "int", + } + assert sorted_dicts(ds.to_records(), "file_dir") == sorted_dicts( + [ + {"file_dir": "", "cnt": 1, "sum": 13}, + {"file_dir": "cats", "cnt": 2, "sum": 8}, + {"file_dir": "dogs", "cnt": 3, "sum": 11}, + {"file_dir": "dogs/others", "cnt": 1, "sum": 4}, + ], + "file_dir", + ) + + @pytest.mark.parametrize("partition_by", ["file_info.path", "file_info__path"]) @pytest.mark.parametrize("order_by", ["file_info.name", "file_info__name"]) def test_window_signals(cloud_test_catalog, partition_by, order_by): diff --git a/tests/func/test_datasets.py b/tests/func/test_datasets.py index 03b90ecc5..e7b990807 100644 --- a/tests/func/test_datasets.py +++ b/tests/func/test_datasets.py @@ -15,7 +15,7 @@ from datachain.lib.dc import DataChain from datachain.lib.file import File from datachain.lib.listing import parse_listing_uri -from datachain.query import DatasetQuery +from datachain.query.dataset import DatasetQuery from datachain.sql.types import Float32, Int, Int64 from tests.utils import assert_row_names, dataset_dependency_asdict diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index 24ce20bfe..9717aa772 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -110,7 +110,7 @@ def remote_dataset_version(schema, dataset_rows): "schema": schema, "sources": "", "query_script": ( - 'from datachain.query import DatasetQuery\nDatasetQuery(path="s3://ldb-public")', + 'from datachain.query.dataset import DatasetQuery\nDatasetQuery(path="s3://ldb-public")', ), "created_by_id": 1, } diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index ccde1f4cf..031a87313 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -1868,7 +1868,7 @@ def test_order_by_with_nested_columns(test_session, with_function): file=[File(path=name) for name in names], session=test_session ) if with_function: - from datachain.sql.functions import rand + from datachain.sql.functions.random import rand dc = dc.order_by("file.path", rand()) else: @@ -1917,7 +1917,7 @@ def test_order_by_descending(test_session, with_function): file=[File(path=name) for name in names], session=test_session ) if with_function: - from datachain.sql.functions import rand + from datachain.sql.functions.random import rand dc = dc.order_by("file.path", rand(), descending=True) else: @@ -2272,22 +2272,20 @@ def test_mutate_with_multiplication(test_session): def test_mutate_with_sql_func(test_session): - from datachain.sql import functions as func + from datachain import func ds = DataChain.from_values(id=[1, 2], session=test_session) - assert ( - ds.mutate(new=func.avg(ds.column("id"))).signals_schema.values["new"] is float - ) + assert ds.mutate(new=func.avg("id")).signals_schema.values["new"] is float def test_mutate_with_complex_expression(test_session): - from datachain.sql import functions as func + from datachain import func ds = DataChain.from_values(id=[1, 2], name=["Jim", "Jon"], session=test_session) assert ( - ds.mutate( - new=(func.sum(ds.column("id"))) * (5 - func.min(ds.column("id"))) - ).signals_schema.values["new"] + ds.mutate(new=func.sum("id") * (5 - func.min("id"))).signals_schema.values[ + "new" + ] is int ) diff --git a/tests/unit/lib/test_sql_to_python.py b/tests/unit/lib/test_sql_to_python.py index 85c973ac9..80565ef22 100644 --- a/tests/unit/lib/test_sql_to_python.py +++ b/tests/unit/lib/test_sql_to_python.py @@ -3,7 +3,6 @@ from datachain import Column from datachain.lib.convert.sql_to_python import sql_to_python -from datachain.sql import functions as func from datachain.sql.types import Float, Int64, String @@ -15,8 +14,6 @@ (Column("score", Float), float), # SQL expression (Column("age", Int64) - 2, int), - # SQL function - (func.avg(Column("age", Int64)), float), # Default type (Column("null", NullType), str), ], diff --git a/tests/unit/sql/test_array.py b/tests/unit/sql/test_array.py index 5448c82c6..9238a2c15 100644 --- a/tests/unit/sql/test_array.py +++ b/tests/unit/sql/test_array.py @@ -1,12 +1,65 @@ -from datachain.sql import literal, select -from datachain.sql.functions import array, string +import math + +import pytest + +from datachain import func +from datachain.sql import select + + +def test_cosine_distance(warehouse): + query = select( + func.cosine_distance((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 5, 6)).label("cos1"), + func.cosine_distance([3.0, 5.0, 1.0], (3.0, 5.0, 1.0)).label("cos2"), + func.cosine_distance((1, 0), [0, 10]).label("cos3"), + func.cosine_distance([0.0, 10.0], [1.0, 0.0]).label("cos4"), + ) + result = tuple(warehouse.db.execute(query)) + assert result == ((0.0, 0.0, 1.0, 1.0),) + + +def test_euclidean_distance(warehouse): + query = select( + func.euclidean_distance((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 5, 6)).label("eu1"), + func.euclidean_distance([3.0, 5.0, 1.0], (3.0, 5.0, 1.0)).label("eu2"), + func.euclidean_distance((1, 0), [0, 1]).label("eu3"), + func.euclidean_distance([1.0, 1.0, 1.0], [2.0, 2.0, 2.0]).label("eu4"), + ) + result = tuple(warehouse.db.execute(query)) + assert result == ((0.0, 0.0, math.sqrt(2), math.sqrt(3)),) + + +@pytest.mark.parametrize( + "args", + [ + [], + ["signal"], + [[1, 2]], + [[1, 2], [1, 2], [1, 2]], + ["signal1", "signal2", "signal3"], + ["signal1", "signal2", [1, 2]], + ], +) +def test_cosine_euclidean_distance_error_args(warehouse, args): + with pytest.raises(ValueError, match="requires exactly two arguments"): + func.cosine_distance(*args) + + with pytest.raises(ValueError, match="requires exactly two arguments"): + func.euclidean_distance(*args) + + +def test_cosine_euclidean_distance_error_vectors_length(warehouse): + with pytest.raises(ValueError, match="requires vectors of the same length"): + func.cosine_distance([1], [1, 2]) + + with pytest.raises(ValueError, match="requires vectors of the same length"): + func.euclidean_distance([1], [1, 2]) def test_length(warehouse): query = select( - array.length(["abc", "def", "g", "hi"]), - array.length([3.0, 5.0, 1.0, 6.0, 1.0]), - array.length([[1, 2, 3], [4, 5, 6]]), + func.length(["abc", "def", "g", "hi"]).label("len1"), + func.length([3.0, 5.0, 1.0, 6.0, 1.0]).label("len2"), + func.length([[1, 2, 3], [4, 5, 6]]).label("len3"), ) result = tuple(warehouse.db.execute(query)) assert result == ((4, 5, 2),) @@ -14,7 +67,7 @@ def test_length(warehouse): def test_length_on_split(warehouse): query = select( - array.length(string.split(literal("abc/def/g/hi"), literal("/"))), + func.array.length(func.string.split(func.literal("abc/def/g/hi"), "/")), ) result = tuple(warehouse.db.execute(query)) assert result == ((4,),) diff --git a/tests/unit/sql/test_conditional.py b/tests/unit/sql/test_conditional.py index cae3f2433..db64511fd 100644 --- a/tests/unit/sql/test_conditional.py +++ b/tests/unit/sql/test_conditional.py @@ -1,20 +1,27 @@ import pytest -from datachain.sql import column, select, values -from datachain.sql import literal as lit -from datachain.sql.functions import greatest, least +from datachain import func +from datachain.sql import select, values @pytest.mark.parametrize( "args,expected", [ - ([lit("abc"), lit("bcd"), lit("Abc"), lit("cd")], "cd"), + ( + [ + func.literal("abc"), + func.literal("bcd"), + func.literal("Abc"), + func.literal("cd"), + ], + "cd", + ), ([3, 1, 2.0, 3.1, 2.5, -1], 3.1), ([4], 4), ], ) def test_greatest(warehouse, args, expected): - query = select(greatest(*args)) + query = select(func.greatest(*args)) result = tuple(warehouse.db.execute(query)) assert result == ((expected,),) @@ -22,13 +29,21 @@ def test_greatest(warehouse, args, expected): @pytest.mark.parametrize( "args,expected", [ - ([lit("abc"), lit("bcd"), lit("Abc"), lit("cd")], "Abc"), + ( + [ + func.literal("abc"), + func.literal("bcd"), + func.literal("Abc"), + func.literal("cd"), + ], + "Abc", + ), ([3, 1, 2.0, 3.1, 2.5, -1], -1), ([4], 4), ], ) def test_least(warehouse, args, expected): - query = select(least(*args)) + query = select(func.least(*args)) result = tuple(warehouse.db.execute(query)) assert result == ((expected,),) @@ -36,9 +51,9 @@ def test_least(warehouse, args, expected): @pytest.mark.parametrize( "expr,expected", [ - (greatest(column("a")), [(3,), (8,), (9,)]), - (least(column("a")), [(3,), (8,), (9,)]), - (least(column("a"), column("b")), [(3,), (7,), (1,)]), + (func.greatest("a"), [(3,), (8,), (9,)]), + (func.least("a"), [(3,), (8,), (9,)]), + (func.least("a", "b"), [(3,), (7,), (1,)]), ], ) def test_conditionals_with_multiple_rows(warehouse, expr, expected): diff --git a/tests/unit/sql/test_path.py b/tests/unit/sql/test_path.py index 7f138d333..0e0bea60e 100644 --- a/tests/unit/sql/test_path.py +++ b/tests/unit/sql/test_path.py @@ -2,10 +2,11 @@ import re import pytest -from sqlalchemy import literal, select -from sqlalchemy.sql import func as f +from sqlalchemy import func as sa_func -from datachain.sql.functions import path as sql_path +from datachain import func +from datachain.sql import select +from datachain.sql.functions import path as pathfunc PATHS = ["", "/", "name", "/name", "name/", "some/long/path"] EXT_PATHS = [ @@ -33,7 +34,7 @@ def file_ext(path): return pp.splitext(path)[1].lstrip(".") -@pytest.mark.parametrize("func_base", [f.path, sql_path]) +@pytest.mark.parametrize("func_base", [sa_func.path, pathfunc]) @pytest.mark.parametrize("func_name", ["parent", "name"]) def test_default_not_implement(func_base, func_name): """ @@ -42,34 +43,34 @@ def test_default_not_implement(func_base, func_name): SQLAlchemy dialect. """ fn = getattr(func_base, func_name) - expr = fn(literal("file:///some/file/path")) + expr = fn(func.literal("file:///some/file/path")) with pytest.raises(NotImplementedError, match=re.escape(f"path.{func_name}")): expr.compile() @pytest.mark.parametrize("path", PATHS) def test_parent(warehouse, path): - query = select(f.path.parent(literal(path))) + query = select(func.path.parent(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((split_parent(path)[0],),) @pytest.mark.parametrize("path", PATHS) def test_name(warehouse, path): - query = select(f.path.name(literal(path))) + query = select(func.path.name(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((split_parent(path)[1],),) @pytest.mark.parametrize("path", EXT_PATHS) def test_file_stem(warehouse, path): - query = select(sql_path.file_stem(literal(path))) + query = select(func.path.file_stem(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((file_stem(path),),) @pytest.mark.parametrize("path", EXT_PATHS) def test_file_ext(warehouse, path): - query = select(sql_path.file_ext(literal(path))) + query = select(func.path.file_ext(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((file_ext(path),),) diff --git a/tests/unit/sql/test_random.py b/tests/unit/sql/test_random.py index 6e486fbc7..46ca8493d 100644 --- a/tests/unit/sql/test_random.py +++ b/tests/unit/sql/test_random.py @@ -1,8 +1,8 @@ +from datachain import func from datachain.sql import select -from datachain.sql.functions import rand def test_rand(warehouse): - query = select(rand()) + query = select(func.random.rand()) result = tuple(warehouse.db.execute(query)) assert isinstance(result[0][0], int) diff --git a/tests/unit/sql/test_string.py b/tests/unit/sql/test_string.py index 2f67d181d..dc43fbb91 100644 --- a/tests/unit/sql/test_string.py +++ b/tests/unit/sql/test_string.py @@ -1,7 +1,7 @@ import pytest -from datachain.sql import literal, select -from datachain.sql.functions import string +from datachain.func import literal, string +from datachain.sql import select def test_length(warehouse): diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py new file mode 100644 index 000000000..3ca009e12 --- /dev/null +++ b/tests/unit/test_func.py @@ -0,0 +1,256 @@ +import pytest +from sqlalchemy import Label + +from datachain import func +from datachain.lib.signal_schema import SignalSchema + + +@pytest.fixture +def rnd(): + return func.random.rand() + + +def test_db_cols(rnd): + assert rnd._db_cols == [] + assert rnd._db_col_type(SignalSchema({})) is None + + +def test_label(rnd): + assert rnd.col_label is None + assert rnd.label("test2") == "test2" + + f = rnd.label("test") + assert f.col_label == "test" + assert f.label("test2") == "test2" + + +def test_col_name(rnd): + assert rnd.get_col_name() == "rand" + assert rnd.label("test").get_col_name() == "test" + assert rnd.get_col_name("test2") == "test2" + + +def test_result_type(rnd): + assert rnd.get_result_type(SignalSchema({})) is int + + +def test_get_column(rnd): + col = rnd.get_column(SignalSchema({})) + assert isinstance(col, Label) + assert col.name == "rand" + + +def test_add(rnd): + f = rnd + 1 + assert str(f) == "add()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 + rnd + assert str(f) == "add()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_sub(rnd): + f = rnd - 1 + assert str(f) == "sub()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 - rnd + assert str(f) == "sub()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_mul(rnd): + f = rnd * 1 + assert str(f) == "mul()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 * rnd + assert str(f) == "mul()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_realdiv(rnd): + f = rnd / 1 + assert str(f) == "div()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 / rnd + assert str(f) == "div()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_floordiv(rnd): + f = rnd // 1 + assert str(f) == "floordiv()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 // rnd + assert str(f) == "floordiv()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_mod(rnd): + f = rnd % 1 + assert str(f) == "mod()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 % rnd + assert str(f) == "mod()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_pow(rnd): + f = rnd**1 + assert str(f) == "pow()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1**rnd + assert str(f) == "pow()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_lshift(rnd): + f = rnd << 1 + assert str(f) == "lshift()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 << rnd + assert str(f) == "lshift()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_rshift(rnd): + f = rnd >> 1 + assert str(f) == "rshift()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 >> rnd + assert str(f) == "rshift()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_and(rnd): + f = rnd & 1 + assert str(f) == "and()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 & rnd + assert str(f) == "and()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_or(rnd): + f = rnd | 1 + assert str(f) == "or()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 | rnd + assert str(f) == "or()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_xor(rnd): + f = rnd ^ 1 + assert str(f) == "xor()" + assert f.cols == [rnd] + assert f.args == [1] + + f = 1 ^ rnd + assert str(f) == "xor()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_lt(rnd): + f = rnd < 1 + assert str(f) == "lt()" + assert f.cols == [rnd] + assert f.args == [1] + + f = rnd > 1 + assert str(f) == "gt()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_le(rnd): + f = rnd <= 1 + assert str(f) == "le()" + assert f.cols == [rnd] + assert f.args == [1] + + f = rnd >= 1 + assert str(f) == "ge()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_eq(rnd): + f = rnd == 1 + assert str(f) == "eq()" + assert f.cols == [rnd] + assert f.args == [1] + + f = rnd == 1 + assert str(f) == "eq()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_ne(rnd): + f = rnd != 1 + assert str(f) == "ne()" + assert f.cols == [rnd] + assert f.args == [1] + + f = rnd != 1 + assert str(f) == "ne()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_gt(rnd): + f = rnd > 1 + assert str(f) == "gt()" + assert f.cols == [rnd] + assert f.args == [1] + + f = rnd < 1 + assert str(f) == "lt()" + assert f.cols == [rnd] + assert f.args == [1] + + +def test_ge(rnd): + f = rnd >= 1 + assert str(f) == "ge()" + assert f.cols == [rnd] + assert f.args == [1] + + f = rnd <= 1 + assert str(f) == "le()" + assert f.cols == [rnd] + assert f.args == [1] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 33c680ee7..8d9f9fbdf 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4,7 +4,8 @@ import sqlalchemy as sa from datachain.error import DatasetNotFoundError -from datachain.query import DatasetQuery, Session +from datachain.query.dataset import DatasetQuery +from datachain.query.session import Session from datachain.sql.types import String