Skip to content

Commit

Permalink
added nested case ability
Browse files Browse the repository at this point in the history
  • Loading branch information
ilongin committed Jan 9, 2025
1 parent be065d4 commit ba5d315
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 12 deletions.
24 changes: 17 additions & 7 deletions src/datachain/func/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .func import ColT, Func

CaseT = Union[int, float, complex, bool, str]
CaseT = Union[int, float, complex, bool, str, Func]


def greatest(*args: Union[ColT, float]) -> Func:
Expand Down Expand Up @@ -88,7 +88,7 @@ def least(*args: Union[ColT, float]) -> Func:
)


def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
def case(*args: tuple, else_=None) -> Func:
"""
Returns the case function that produces case expression which has a list of
conditions and corresponding results. Results can only be python primitives
Expand All @@ -112,23 +112,33 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
"""
supported_types = [int, float, complex, str, bool]

type_ = type(else_) if else_ else None
def _get_type(val):
if isinstance(val, Func):
# nested functions
return val.result_type
return type(val)

if not args:
raise DataChainParamsError("Missing statements")

type_ = _get_type(else_) if else_ is not None else None

for arg in args:
if type_ and not isinstance(arg[1], type_):
raise DataChainParamsError("Statement values must be of the same type")
type_ = type(arg[1])
arg_type = _get_type(arg[1])
if type_ and arg_type != type_:
raise DataChainParamsError(
f"Statement values must be of the same type, got {type_} amd {arg_type}"
)
type_ = arg_type

if type_ not in supported_types:
raise DataChainParamsError(
f"Only python literals ({supported_types}) are supported for values"
)

kwargs = {"else_": else_}
return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)

return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_)


def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
Expand Down
11 changes: 7 additions & 4 deletions src/datachain/func/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .window import Window


ColT = Union[str, ColumnElement, "Func"]
ColT = Union[str, ColumnElement, "Func", tuple]


class Func(Function):
Expand Down Expand Up @@ -78,7 +78,7 @@ def _db_cols(self) -> Sequence[ColT]:
return (
[
col
if isinstance(col, (Func, BindParameter, Case, Comparator))
if isinstance(col, (Func, BindParameter, Case, Comparator, tuple))
else ColumnMeta.to_db_name(
col.name if isinstance(col, ColumnElement) else col
)
Expand Down Expand Up @@ -382,6 +382,8 @@ def get_column(
sql_type = python_to_sql(col_type)

def get_col(col: ColT) -> ColT:
if isinstance(col, tuple):
return tuple(get_col(x) for x in col)
if isinstance(col, Func):
return col.get_column(signals_schema, table=table)
if isinstance(col, str):
Expand All @@ -391,7 +393,8 @@ def get_col(col: ColT) -> ColT:
return col

cols = [get_col(col) for col in self._db_cols]
func_col = self.inner(*cols, *self.args, **self.kwargs)
kwargs = {k: get_col(v) for k, v in self.kwargs.items()}
func_col = self.inner(*cols, *self.args, **kwargs)

if self.is_window:
if not self.window:
Expand Down Expand Up @@ -423,7 +426,7 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
return sql_to_python(col)

return signals_schema.get_column_type(
col.name if isinstance(col, ColumnElement) else col
col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type]
)


Expand Down
4 changes: 3 additions & 1 deletion tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def test_case_not_same_result_types(warehouse):
val = 2
with pytest.raises(DataChainParamsError) as exc_info:
select(func.case(*[(val > 1, "A"), (2 < val < 4, 5)], else_="D"))
assert str(exc_info.value) == "Statement values must be of the same type"
assert str(exc_info.value) == (
"Statement values must be of the same type, got <class 'str'> amd <class 'int'>"
)


def test_case_wrong_result_type(warehouse):
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,59 @@ def test_case_mutate(dc, val, else_, type_):
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"val,else_,type_",
[
["A", "D", str],
[1, 2, int],
[1.5, 2.5, float],
[True, False, bool],
],
)
def test_nested_case_on_condition_mutate(dc, val, else_, type_):
res = dc.mutate(
test=case((case((C("num") < 2, True), else_=False), val), else_=else_)
)
assert list(res.order_by("test").collect("test")) == sorted(
[val, else_, else_, else_, else_]
)
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"v1,v2,v3,type_",
[
["A", "B", "C", str],
[1, 2, 3, int],
[1.5, 2.5, 3.5, float],
[False, True, True, bool],
],
)
def test_nested_case_on_value_mutate(dc, v1, v2, v3, type_):
res = dc.mutate(
test=case((C("num") < 4, case((C("num") < 2, v1), else_=v2)), else_=v3)
)
assert list(res.order_by("num").collect("test")) == sorted([v1, v2, v2, v3, v3])
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"v1,v2,v3,type_",
[
["A", "B", "C", str],
[1, 2, 3, int],
[1.5, 2.5, 3.5, float],
[False, True, True, bool],
],
)
def test_nested_case_on_else_mutate(dc, v1, v2, v3, type_):
res = dc.mutate(
test=case((C("num") < 3, v1), else_=case((C("num") < 4, v2), else_=v3))
)
assert list(res.order_by("num").collect("test")) == sorted([v1, v1, v2, v3, v3])
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"if_val,else_val,type_",
[
Expand Down Expand Up @@ -695,3 +748,16 @@ def test_isnone_mutate(col):
[False, False, False, True, True]
)
assert res.schema["test"] is bool


@pytest.mark.parametrize("col", [C("val"), "val"])
@skip_if_not_sqlite
def test_isnone_with_ifelse_mutate(col):
dc = DataChain.from_values(
num=list(range(1, 6)),
val=[None if i > 3 else "A" for i in range(1, 6)],
)

res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE"))
assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2
assert res.schema["test"] is str

0 comments on commit ba5d315

Please sign in to comment.