
Commit

support if dataframe of instance pyspark.sql.connect.dataframe.DataFrame is passed as input (#110)

* Added support if dataframe of instance pyspark.sql.connect.dataframe.DataFrame is passed as input

* Added dependencies for Spark 3.5.0 and above

* reverted version change

---------

Co-authored-by: dgunt2 <[email protected]>
diviteja-g and dgunt2 authored Oct 11, 2024
1 parent d860ae5 commit d7db1bb
Showing 7 changed files with 845 additions and 506 deletions.
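Below is a short usage sketch, not part of the diff, of the situation this commit handles: on a Spark Connect client (for example a Databricks Runtime 14+ shared cluster, as the comment added in expectations.py notes), `spark.table()` returns a `pyspark.sql.connect.dataframe.DataFrame` rather than `pyspark.sql.DataFrame`, and with this change either type is accepted as `rules_df`. The table name `dq_rules` is an illustrative placeholder.

```python
# Illustrative sketch, not from the commit; "dq_rules" is a placeholder table name.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
rules_df = spark.table("dq_rules")

print(type(rules_df))
# pyspark.sql.dataframe.DataFrame          <- classic Spark session
# pyspark.sql.connect.dataframe.DataFrame  <- Spark Connect session (e.g. DBR 14+)

# Before this commit, SparkExpectations(rules_df=rules_df, ...) raised
# "Input rules_df is not of dataframe type" for the Connect variant;
# after it, both DataFrame flavours pass the isinstance check in __post_init__.
```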
2 changes: 1 addition & 1 deletion docs/delta.md
@@ -11,7 +11,7 @@ builder = (
SparkSession.builder.config(
"spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
)
.config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
.config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
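For context, the full builder from docs/delta.md after this change looks roughly like the sketch below (reconstructed from the surrounding context lines). Only the Maven coordinate changes, from `io.delta:delta-core_2.12:2.4.0` to `io.delta:delta-spark_2.12:3.0.0`, reflecting Delta Lake 3.x renaming the Spark artifact from delta-core to delta-spark.

```python
# Reconstructed sketch of the documented builder; only the
# spark.jars.packages coordinate changes in this commit.
from pyspark.sql import SparkSession

builder = (
    SparkSession.builder.config(
        "spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
    )
    .config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
    .config(
        "spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    )
)
spark = builder.getOrCreate()
```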
1,295 changes: 796 additions & 499 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions pyproject.toml
@@ -7,9 +7,9 @@ readme = "README.md"
packages = [{ include = "spark_expectations" }]

[tool.poetry.dependencies]
python = "^3.8.9"
python = "^3.9,<3.12"
pluggy = ">=1"
pyspark = "^3.0.0,<3.5"
pyspark = "^3.0.0"
requests = "^2.28.1"

[tool.poetry.group.dev.dependencies]
@@ -18,6 +18,12 @@ pytest = "7.3.1"
pytest-mock = "3.10.0"
coverage = "7.2.5"
pyspark = "^3.0.0"
pandas = "1.5.3"
numpy = "1.26.4"
pyarrow = "7.0.0"
grpcio = "1.48.1"
google = "3.0.0"
protobuf = "4.21.12"
mypy = "1.3.0"
mkdocs = "1.4.3"
prospector = "1.10.0"
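The dev dependencies added above (pandas, pyarrow, grpcio, protobuf) are roughly what the pyspark Spark Connect client needs at import time, which is why the try/except import in expectations.py can fail on environments without them. As a minimal sketch, assuming a Spark Connect server is reachable on the default port 15002, a test could obtain a genuine Connect DataFrame like this:

```python
# Sketch only: assumes pyspark >= 3.4 with the Connect client deps installed
# and a Spark Connect server listening at sc://localhost:15002.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
df = spark.range(5)

print(type(df))  # <class 'pyspark.sql.connect.dataframe.DataFrame'>
```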
2 changes: 1 addition & 1 deletion spark_expectations/core/__init__.py
@@ -13,7 +13,7 @@ def get_spark_session() -> SparkSession:
SparkSession.builder.config(
"spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
)
.config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
.config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
33 changes: 31 additions & 2 deletions spark_expectations/core/expectations.py
@@ -1,8 +1,15 @@
import functools
from dataclasses import dataclass
from typing import Dict, Optional, Any, Union
+import packaging.version as package_version
+from pyspark import version as spark_version
from pyspark import StorageLevel
from pyspark.sql import DataFrame, SparkSession

+try:
+from pyspark.sql.connect.dataframe import DataFrame as connectDataFrame
+except ImportError:
+pass
from spark_expectations import _log
from spark_expectations.config.user_config import Constants as user_config
from spark_expectations.core.context import SparkExpectationsContext
@@ -22,6 +29,14 @@
from spark_expectations.utils.regulate_flow import SparkExpectationsRegulateFlow


+min_spark_version_for_connect = "3.4.0"
+installed_spark_version = spark_version.__version__
+is_spark_connect_supported = bool(
+package_version.parse(installed_spark_version)
+>= package_version.parse(min_spark_version_for_connect)
+)


@dataclass
class SparkExpectations:
"""
@@ -45,7 +60,13 @@ class SparkExpectations:
stats_streaming_options: Optional[Dict[str, Union[str, bool]]] = None

def __post_init__(self) -> None:
-if isinstance(self.rules_df, DataFrame):
+# Databricks runtime 14 and above could pass either instance of a Dataframe depending on how data was read
+if (
+is_spark_connect_supported is True
+and isinstance(self.rules_df, (DataFrame, connectDataFrame))
+) or (
+is_spark_connect_supported is False and isinstance(self.rules_df, DataFrame)
+):
try:
self.spark: Optional[SparkSession] = self.rules_df.sparkSession
except AttributeError:
@@ -55,10 +76,12 @@ def __post_init__(self) -> None:
raise SparkExpectationsMiscException(
"Spark session is not available, please initialize a spark session before calling SE"
)

else:
raise SparkExpectationsMiscException(
"Input rules_df is not of dataframe type"
)

self.actions: SparkExpectationsActions = SparkExpectationsActions()
self._context: SparkExpectationsContext = SparkExpectationsContext(
product_id=self.product_id, spark=self.spark
@@ -353,7 +376,13 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame:
self._context.get_run_id,
)

-if isinstance(_df, DataFrame):
+if (
+is_spark_connect_supported is True
+and isinstance(_df, (DataFrame, connectDataFrame))
+) or (
+is_spark_connect_supported is False
+and isinstance(_df, DataFrame)
+):
_log.info("The function dataframe is created")
self._context.set_table_name(table_name)
if write_to_temp_table:
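Pulled out of the expectations.py diff above, the detection pattern is roughly the following standalone sketch; `accepts_dataframe` is a hypothetical helper name used here for illustration, not a function in the commit.

```python
# Standalone sketch of the pattern this commit introduces: version-gate Spark
# Connect support, tolerate a missing connect module, and accept either
# DataFrame flavour in isinstance checks.
import packaging.version as package_version
from pyspark import version as spark_version
from pyspark.sql import DataFrame

try:
    from pyspark.sql.connect.dataframe import DataFrame as connectDataFrame
except ImportError:
    connectDataFrame = None  # pyspark < 3.4, or Connect client deps missing

is_spark_connect_supported = package_version.parse(
    spark_version.__version__
) >= package_version.parse("3.4.0")


def accepts_dataframe(obj: object) -> bool:
    """Return True for a classic DataFrame or, where available, a Connect DataFrame."""
    if is_spark_connect_supported and connectDataFrame is not None:
        return isinstance(obj, (DataFrame, connectDataFrame))
    return isinstance(obj, DataFrame)
```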
2 changes: 1 addition & 1 deletion spark_expectations/examples/base_setup.py
@@ -126,7 +126,7 @@ def set_up_delta() -> SparkSession:
SparkSession.builder.config(
"spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
)
.config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
.config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
7 changes: 7 additions & 0 deletions tests/core/test_expectations.py
@@ -5,6 +5,13 @@
from unittest.mock import patch
import pytest
from pyspark.sql import DataFrame, SparkSession


+try:
+from pyspark.sql.connect.dataframe import DataFrame as connectDataFrame
+except ImportError:
+pass

from pyspark.sql.functions import lit, to_timestamp, col
from pyspark.sql.types import StringType, IntegerType, StructField, StructType

