From ed8cd154136509336fe6b24d97c65d84c8fc1be7 Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Wed, 8 Nov 2023 13:13:57 +0100 Subject: [PATCH] fix DataSetImplements such that we can use functools.reduce again (#221) --- tests/_core/test_dataset.py | 9 ++++ typedspark/_core/dataset.py | 103 ++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py index 445a32a..1495b44 100644 --- a/tests/_core/test_dataset.py +++ b/tests/_core/test_dataset.py @@ -1,3 +1,5 @@ +import functools + import pandas as pd import pytest from pyspark.sql import SparkSession @@ -94,3 +96,10 @@ def test_schema_property_of_dataset(spark: SparkSession): def test_initialize_dataset_implements(spark: SparkSession): with pytest.raises(NotImplementedError): DataSetImplements() + + +def test_reduce(spark: SparkSession): + functools.reduce( + DataSet.unionByName, + [create_empty_dataset(spark, A), create_empty_dataset(spark, A)], + ) diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py index 85ff72e..cc38a04 100644 --- a/typedspark/_core/dataset.py +++ b/typedspark/_core/dataset.py @@ -227,3 +227,106 @@ def _add_schema_metadata(self) -> None: """ for field in self._schema_annotations.get_structtype().fields: self.schema[field.name].metadata = field.metadata + + """The following functions are equivalent to their parents in ``DataSetImplements``. However, + to support functions like ``functools.reduce(DataSet.unionByName, datasets)``, we also add them + here. 
Unfortunately, this leads to some code redundancy, but we accept that trade-off.""" + + def alias(self, alias: str) -> DataSet[_Schema]: + return DataSet[self._schema_annotations](super().alias(alias)) # type: ignore + + def distinct(self) -> DataSet[_Schema]: # pylint: disable=C0116 + return DataSet[self._schema_annotations](super().distinct()) # type: ignore + + def filter(self, condition) -> DataSet[_Schema]: # pylint: disable=C0116 + return DataSet[self._schema_annotations](super().filter(condition)) # type: ignore + + @overload + def join( # type: ignore + self, + other: DataFrame, + on: Optional[ # pylint: disable=C0103 + Union[str, List[str], SparkColumn, List[SparkColumn]] + ] = ..., + how: None = ..., + ) -> DataFrame: + ... # pragma: no cover + + @overload + def join( + self, + other: DataFrame, + on: Optional[ # pylint: disable=C0103 + Union[str, List[str], SparkColumn, List[SparkColumn]] + ] = ..., + how: Literal["semi"] = ..., + ) -> DataSet[_Schema]: + ... # pragma: no cover + + @overload + def join( + self, + other: DataFrame, + on: Optional[ # pylint: disable=C0103 + Union[str, List[str], SparkColumn, List[SparkColumn]] + ] = ..., + how: Optional[str] = ..., + ) -> DataFrame: + ... # pragma: no cover + + def join( # pylint: disable=C0116 + self, + other: DataFrame, + on: Optional[ # pylint: disable=C0103 + Union[str, List[str], SparkColumn, List[SparkColumn]] + ] = None, + how: Optional[str] = None, + ) -> DataFrame: + return super().join(other, on, how) # type: ignore + + def orderBy(self, *args, **kwargs) -> DataSet[_Schema]: # type: ignore # noqa: N802, E501 # pylint: disable=C0116, C0103 + return DataSet[self._schema_annotations](super().orderBy(*args, **kwargs)) # type: ignore + + @overload + def transform( + self, + func: Callable[Concatenate[DataSet[_Schema], P], _ReturnType], + *args: P.args, + **kwargs: P.kwargs, + ) -> _ReturnType: + ... 
# pragma: no cover + + @overload + def transform(self, func: Callable[..., DataFrame], *args: Any, **kwargs: Any) -> DataFrame: + ... # pragma: no cover + + def transform( # pylint: disable=C0116 + self, func: Callable[..., DataFrame], *args: Any, **kwargs: Any + ) -> DataFrame: + return super().transform(func, *args, **kwargs) + + @overload + def unionByName( # noqa: N802 # pylint: disable=C0116, C0103 + self, + other: DataSet[_Schema], + allowMissingColumns: Literal[False] = ..., # noqa: N803 + ) -> DataSet[_Schema]: + ... # pragma: no cover + + @overload + def unionByName( # noqa: N802 # pylint: disable=C0116, C0103 + self, + other: DataFrame, + allowMissingColumns: bool = ..., # noqa: N803 + ) -> DataFrame: + ... # pragma: no cover + + def unionByName( # noqa: N802 # pylint: disable=C0116, C0103 + self, + other: DataFrame, + allowMissingColumns: bool = False, # noqa: N803 + ) -> DataFrame: + res = super().unionByName(other, allowMissingColumns) + if isinstance(other, DataSet) and other._schema_annotations == self._schema_annotations: + return DataSet[self._schema_annotations](res) # type: ignore + return res # pragma: no cover