Support input DataFrames of type pyspark.sql.connect.dataframe.DataFrame #110
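This change teaches SparkExpectations to accept either DataFrame implementation: the classic pyspark.sql.DataFrame and, on PySpark 3.4+, the Spark Connect pyspark.sql.connect.dataframe.DataFrame. A minimal sketch of the pattern, condensed from the diff below (the helper is_supported_dataframe and the None fallback are illustrative and not part of the PR):

```python
import packaging.version as package_version
from pyspark import version as spark_version
from pyspark.sql import DataFrame

try:
    # The Spark Connect DataFrame class only exists on PySpark >= 3.4.
    from pyspark.sql.connect.dataframe import DataFrame as connectDataFrame
except ImportError:
    connectDataFrame = None

# Gate the extra isinstance check on the installed PySpark version.
is_spark_connect_supported = package_version.parse(
    spark_version.__version__
) >= package_version.parse("3.4.0")


def is_supported_dataframe(df: object) -> bool:
    """True for a classic DataFrame, or a Connect DataFrame when PySpark supports it."""
    if is_spark_connect_supported and connectDataFrame is not None:
        return isinstance(df, (DataFrame, connectDataFrame))
    return isinstance(df, DataFrame)
```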

docs/delta.md (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ builder = (
SparkSession.builder.config(
"spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
)
.config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
.config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
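For background: the Maven coordinate changes because Delta Lake 3.x renamed the Spark artifact from delta-core to delta-spark. As an aside (not part of this PR), a sketch that lets the delta-spark Python package pick the matching coordinate instead of hard-coding it, assuming that package is installed:

```python
from delta import configure_spark_with_delta_pip
from pyspark.sql import SparkSession

builder = SparkSession.builder.config(
    "spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
).config(
    "spark.sql.catalog.spark_catalog",
    "org.apache.spark.sql.delta.catalog.DeltaCatalog",
)
# configure_spark_with_delta_pip adds the io.delta:delta-spark_* coordinate that
# matches the installed delta-spark version to spark.jars.packages.
spark = configure_spark_with_delta_pip(builder).getOrCreate()
```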
poetry.lock (1,295 changes: 796 additions & 499 deletions)

Large diffs are not rendered by default.

pyproject.toml (10 changes: 8 additions & 2 deletions)
@@ -7,9 +7,9 @@ readme = "README.md"
packages = [{ include = "spark_expectations" }]

[tool.poetry.dependencies]
python = "^3.8.9"
python = "^3.9,<3.12"
pluggy = ">=1"
pyspark = "^3.0.0,<3.5"
pyspark = "^3.0.0"
requests = "^2.28.1"

[tool.poetry.group.dev.dependencies]
@@ -18,6 +18,12 @@ pytest = "7.3.1"
pytest-mock = "3.10.0"
coverage = "7.2.5"
pyspark = "^3.0.0"
pandas = "1.5.3"
numpy = "1.26.4"
pyarrow = "7.0.0"
grpcio = "1.48.1"
google = "3.0.0"
protobuf = "4.21.12"
mypy = "1.3.0"
mkdocs = "1.4.3"
prospector = "1.10.0"
spark_expectations/core/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@ def get_spark_session() -> SparkSession:
SparkSession.builder.config(
"spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
)
.config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
.config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
spark_expectations/core/expectations.py (33 changes: 31 additions & 2 deletions)
@@ -1,8 +1,15 @@
import functools
from dataclasses import dataclass
from typing import Dict, Optional, Any, Union
+ import packaging.version as package_version
+ from pyspark import version as spark_version
from pyspark import StorageLevel
from pyspark.sql import DataFrame, SparkSession

+ try:
+     from pyspark.sql.connect.dataframe import DataFrame as connectDataFrame
+ except ImportError:
+     pass
from spark_expectations import _log
from spark_expectations.config.user_config import Constants as user_config
from spark_expectations.core.context import SparkExpectationsContext
@@ -22,6 +29,14 @@
from spark_expectations.utils.regulate_flow import SparkExpectationsRegulateFlow


+ min_spark_version_for_connect = "3.4.0"
+ installed_spark_version = spark_version.__version__
+ is_spark_connect_supported = bool(
+     package_version.parse(installed_spark_version)
+     >= package_version.parse(min_spark_version_for_connect)
+ )


@dataclass
class SparkExpectations:
"""
@@ -45,7 +60,13 @@ class SparkExpectations:
stats_streaming_options: Optional[Dict[str, Union[str, bool]]] = None

def __post_init__(self) -> None:
- if isinstance(self.rules_df, DataFrame):
+ # Databricks runtime 14 and above can pass either DataFrame implementation, depending on how the data was read
+ if (
+     is_spark_connect_supported is True
+     and isinstance(self.rules_df, (DataFrame, connectDataFrame))
+ ) or (
+     is_spark_connect_supported is False and isinstance(self.rules_df, DataFrame)
+ ):
try:
self.spark: Optional[SparkSession] = self.rules_df.sparkSession
except AttributeError:
@@ -55,10 +76,12 @@ def __post_init__(self) -> None:
raise SparkExpectationsMiscException(
"Spark session is not available, please initialize a spark session before calling SE"
)

else:
raise SparkExpectationsMiscException(
"Input rules_df is not of dataframe type"
)

self.actions: SparkExpectationsActions = SparkExpectationsActions()
self._context: SparkExpectationsContext = SparkExpectationsContext(
product_id=self.product_id, spark=self.spark
@@ -353,7 +376,13 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame:
self._context.get_run_id,
)

- if isinstance(_df, DataFrame):
+ if (
+     is_spark_connect_supported is True
+     and isinstance(_df, (DataFrame, connectDataFrame))
+ ) or (
+     is_spark_connect_supported is False
+     and isinstance(_df, DataFrame)
+ ):
_log.info("The function dataframe is created")
self._context.set_table_name(table_name)
if write_to_temp_table:
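To see why the dual isinstance checks above are needed: under a Spark Connect session (for example Databricks runtime 14 and above, as noted in the inline comment), a read returns the Connect DataFrame class rather than the classic one. A small illustration, assuming a reachable Connect endpoint (the URL and sample data here are hypothetical):

```python
from pyspark.sql import SparkSession

# Connect to a (hypothetical) Spark Connect endpoint; requires PySpark >= 3.4.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

rules_df = spark.createDataFrame(
    [("your_product", "query_dq")], ["product_id", "rule_type"]  # hypothetical rules data
)
print(type(rules_df))
# <class 'pyspark.sql.connect.dataframe.DataFrame'> under Spark Connect,
# <class 'pyspark.sql.dataframe.DataFrame'> on a classic SparkSession.
```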
spark_expectations/examples/base_setup.py (2 changes: 1 addition & 1 deletion)
@@ -126,7 +126,7 @@ def set_up_delta() -> SparkSession:
SparkSession.builder.config(
"spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
)
.config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
.config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
tests/core/test_expectations.py (7 changes: 7 additions & 0 deletions)
@@ -5,6 +5,13 @@
from unittest.mock import patch
import pytest
from pyspark.sql import DataFrame, SparkSession


+ try:
+     from pyspark.sql.connect.dataframe import DataFrame as connectDataFrame
+ except ImportError:
+     pass

from pyspark.sql.functions import lit, to_timestamp, col
from pyspark.sql.types import StringType, IntegerType, StructField, StructType
