Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SparkLikeNamespace methods #1779

Merged
merged 7 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ def concat_str(
self,
exprs: Iterable[IntoDaskExpr],
*more_exprs: IntoDaskExpr,
separator: str = "",
ignore_nulls: bool = False,
separator: str,
ignore_nulls: bool,
) -> DaskExpr:
parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ def concat_str(
self,
exprs: Iterable[IntoPandasLikeExpr],
*more_exprs: IntoPandasLikeExpr,
separator: str = "",
ignore_nulls: bool = False,
separator: str,
ignore_nulls: bool,
) -> PandasLikeExpr:
parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
Expand Down
14 changes: 14 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,17 @@ def skew(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.skewness, "skew", returns_scalar=True)

def n_unique(self: Self) -> Self:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import IntegerType

def _n_unique(_input: Column) -> Column:
return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Highly inspired by duckdb implementation πŸ˜‚πŸ˜‰


return self._from_call(_n_unique, "n_unique", returns_scalar=True)

def is_null(self: Self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.isnull, "is_null", returns_scalar=self._returns_scalar)
6 changes: 1 addition & 5 deletions narwhals/_spark_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,7 @@ def agg_pyspark(
if expr._output_names is None: # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
agg_func = get_spark_function(function_name, **expr._kwargs)
agg_func = get_spark_function(expr._function_name, **expr._kwargs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not manage to make nw.len() work in the group_by context

simple_aggregations.update(
{output_name: agg_func(keys[0]) for output_name in expr._output_names}
)
Expand Down
260 changes: 239 additions & 21 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
import operator
from functools import reduce
from typing import TYPE_CHECKING
from typing import Iterable
from typing import Literal

from narwhals._expression_parsing import combine_root_names
from narwhals._expression_parsing import parse_into_exprs
from narwhals._expression_parsing import reduce_output_names
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.utils import get_column_name
from narwhals.typing import CompliantNamespace

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import DataFrame

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import IntoSparkLikeExpr
from narwhals.dtypes import DType
from narwhals.utils import Version
Expand Down Expand Up @@ -43,26 +46,6 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.and_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="all_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def col(self, *column_names: str) -> SparkLikeExpr:
return SparkLikeExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
Expand Down Expand Up @@ -90,6 +73,64 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

def len(self) -> SparkLikeExpr:
def func(_: SparkLikeLazyFrame) -> list[Column]:
import pyspark.sql.functions as F # noqa: N812

return [F.count("*").alias("len")]

return SparkLikeExpr( # type: ignore[abstract]
func,
depth=0,
function_name="len",
root_names=None,
output_names=["len"],
returns_scalar=True,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)

def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.and_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="all_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def any_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.or_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="any_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def sum_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

Expand All @@ -116,3 +157,180 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
version=self._version,
kwargs={"exprs": exprs},
)

def mean_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import IntegerType

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [
(
reduce(operator.add, (F.coalesce(col, F.lit(0)) for col in cols))
/ reduce(
operator.add,
(col.isNotNull().cast(IntegerType()) for col in cols),
)
).alias(col_name)
]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="mean_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def max_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [F.greatest(*cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="max_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def min_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [F.least(*cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="min_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def concat(
self,
items: Iterable[SparkLikeLazyFrame],
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> SparkLikeLazyFrame:
dfs: list[DataFrame] = [item._native_frame for item in items]
if how == "horizontal":
msg = (
"Horizontal concatenation is not supported for LazyFrame backed by "
"a PySpark DataFrame."
)
raise NotImplementedError(msg)

if how == "vertical":
cols_0 = dfs[0].columns
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)

return SparkLikeLazyFrame(
native_dataframe=reduce(lambda x, y: x.union(y), dfs),
backend_version=self._backend_version,
version=self._version,
)

if how == "diagonal":
return SparkLikeLazyFrame(
native_dataframe=reduce(
lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
),
backend_version=self._backend_version,
version=self._version,
)
raise NotImplementedError

def concat_str(
self,
exprs: Iterable[IntoSparkLikeExpr],
*more_exprs: IntoSparkLikeExpr,
separator: str,
ignore_nulls: bool,
) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import StringType

parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
*parse_into_exprs(*more_exprs, namespace=self),
]

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = (s.cast(StringType()) for _expr in parsed_exprs for s in _expr(df))
null_mask = [F.isnull(s) for _expr in parsed_exprs for s in _expr(df)]

if not ignore_nulls:
null_mask_result = reduce(lambda x, y: x | y, null_mask)
result = F.when(
~null_mask_result,
reduce(lambda x, y: F.format_string(f"%s{separator}%s", x, y), cols),
).otherwise(F.lit(None))
else:
init_value, *values = [
F.when(~nm, col).otherwise(F.lit(""))
for col, nm in zip(cols, null_mask)
]

separators = (
F.when(nm, F.lit("")).otherwise(F.lit(separator))
for nm in null_mask[:-1]
)
result = reduce(
lambda x, y: F.format_string("%s%s", x, y),
(F.format_string("%s%s", s, v) for s, v in zip(separators, values)),
init_value,
)

return [result]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="concat_str",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={
"exprs": exprs,
"more_exprs": more_exprs,
"separator": separator,
"ignore_nulls": ignore_nulls,
},
)
10 changes: 2 additions & 8 deletions tests/expr_and_series/any_horizontal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@

@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_anyh(
request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
Expand All @@ -27,9 +23,7 @@ def test_anyh(
assert_equal_data(result, expected)


def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_anyh_all(constructor: Constructor) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
Expand Down
2 changes: 1 addition & 1 deletion tests/expr_and_series/concat_str_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_concat_str(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = (
Expand Down
Loading
Loading