Skip to content

Commit

Permalink
fix DataSetImplements such that we can use functools.reduce again (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
nanne-aben authored Nov 8, 2023
1 parent 5e734df commit ed8cd15
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/_core/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

import pandas as pd
import pytest
from pyspark.sql import SparkSession
Expand Down Expand Up @@ -94,3 +96,10 @@ def test_schema_property_of_dataset(spark: SparkSession):
def test_initialize_dataset_implements(spark: SparkSession):
    """``DataSetImplements`` is abstract: constructing it directly must raise."""
    pytest.raises(NotImplementedError, DataSetImplements)


def test_reduce(spark: SparkSession):
    """Regression test for #221: ``DataSet.unionByName`` must be usable as the
    binary function of ``functools.reduce``.

    The original test only verified that no exception is raised; we additionally
    assert that the reduction yields a ``DataSet`` and that unioning two empty
    datasets produces an empty result.
    """
    datasets = [create_empty_dataset(spark, A), create_empty_dataset(spark, A)]
    result = functools.reduce(DataSet.unionByName, datasets)
    # Both inputs share the schema A, so unionByName should keep the typed wrapper.
    assert isinstance(result, DataSet)
    assert result.count() == 0
103 changes: 103 additions & 0 deletions typedspark/_core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,106 @@ def _add_schema_metadata(self) -> None:
"""
for field in self._schema_annotations.get_structtype().fields:
self.schema[field.name].metadata = field.metadata

"""The following functions are equivalent to their parents in ``DataSetImplements``. However,
to support functions like ``functools.reduce(DataSet.unionByName, datasets)``, we also add them
here. Unfortunately, this leads to some code redundancy, but we'll take that for granted."""

def alias(self, alias: str) -> DataSet[_Schema]:
    """Alias the underlying DataFrame, keeping the typedspark schema annotations."""
    aliased = super().alias(alias)
    return DataSet[self._schema_annotations](aliased)  # type: ignore

def distinct(self) -> DataSet[_Schema]:  # pylint: disable=C0116
    """Drop duplicate rows; deduplication leaves the schema unchanged."""
    deduplicated = super().distinct()
    return DataSet[self._schema_annotations](deduplicated)  # type: ignore

def filter(self, condition) -> DataSet[_Schema]:  # pylint: disable=C0116
    """Filter rows on *condition*; row filtering never alters the schema."""
    filtered = super().filter(condition)
    return DataSet[self._schema_annotations](filtered)  # type: ignore

# ``join`` overloads: only a "semi" join is guaranteed to keep exactly the
# left-hand columns, so only that overload preserves ``DataSet[_Schema]``;
# every other ``how`` degrades the static type to a plain ``DataFrame``.
@overload
def join(  # type: ignore
    self,
    other: DataFrame,
    on: Optional[  # pylint: disable=C0103
        Union[str, List[str], SparkColumn, List[SparkColumn]]
    ] = ...,
    how: None = ...,
) -> DataFrame:
    ...  # pragma: no cover

# ``how="semi"``: result contains only columns from ``self``, schema preserved.
@overload
def join(
    self,
    other: DataFrame,
    on: Optional[  # pylint: disable=C0103
        Union[str, List[str], SparkColumn, List[SparkColumn]]
    ] = ...,
    how: Literal["semi"] = ...,
) -> DataSet[_Schema]:
    ...  # pragma: no cover

# Any other join type: columns from both sides may appear, so the schema is unknown.
@overload
def join(
    self,
    other: DataFrame,
    on: Optional[  # pylint: disable=C0103
        Union[str, List[str], SparkColumn, List[SparkColumn]]
    ] = ...,
    how: Optional[str] = ...,
) -> DataFrame:
    ...  # pragma: no cover

def join(  # pylint: disable=C0116
    self,
    other: DataFrame,
    on: Optional[  # pylint: disable=C0103
        Union[str, List[str], SparkColumn, List[SparkColumn]]
    ] = None,
    how: Optional[str] = None,
) -> DataFrame:
    # Plain delegation to pyspark; the typing value of this method lives in the
    # overloads above, not in the runtime behavior.
    return super().join(other, on, how)  # type: ignore

def orderBy(self, *args, **kwargs) -> DataSet[_Schema]:  # type: ignore # noqa: N802, E501 # pylint: disable=C0116, C0103
    """Sort the rows; reordering rows leaves the schema annotations intact."""
    ordered = super().orderBy(*args, **kwargs)
    return DataSet[self._schema_annotations](ordered)  # type: ignore

# ``transform`` overloads: when the callable's first parameter is annotated as
# ``DataSet[_Schema]``, we can propagate its precise return type via ParamSpec;
# otherwise we fall back to a plain ``DataFrame`` result.
@overload
def transform(
    self,
    func: Callable[Concatenate[DataSet[_Schema], P], _ReturnType],
    *args: P.args,
    **kwargs: P.kwargs,
) -> _ReturnType:
    ...  # pragma: no cover

@overload
def transform(self, func: Callable[..., DataFrame], *args: Any, **kwargs: Any) -> DataFrame:
    ...  # pragma: no cover

def transform(  # pylint: disable=C0116
    self, func: Callable[..., DataFrame], *args: Any, **kwargs: Any
) -> DataFrame:
    # Runtime behavior is identical to pyspark's transform; only typing differs.
    return super().transform(func, *args, **kwargs)

# ``unionByName`` overloads: unioning with another ``DataSet`` of the same
# schema (and without allowMissingColumns) keeps ``DataSet[_Schema]``;
# anything else degrades to a plain ``DataFrame``.
@overload
def unionByName(  # noqa: N802 # pylint: disable=C0116, C0103
    self,
    other: DataSet[_Schema],
    allowMissingColumns: Literal[False] = ...,  # noqa: N803
) -> DataSet[_Schema]:
    ...  # pragma: no cover

@overload
def unionByName(  # noqa: N802 # pylint: disable=C0116, C0103
    self,
    other: DataFrame,
    allowMissingColumns: bool = ...,  # noqa: N803
) -> DataFrame:
    ...  # pragma: no cover

def unionByName(  # noqa: N802 # pylint: disable=C0116, C0103
    self,
    other: DataFrame,
    allowMissingColumns: bool = False,  # noqa: N803
) -> DataFrame:
    res = super().unionByName(other, allowMissingColumns)
    # Re-wrap only when both operands carry the same schema annotations, so the
    # result can honestly claim to be a DataSet of that schema.
    if isinstance(other, DataSet) and other._schema_annotations == self._schema_annotations:
        return DataSet[self._schema_annotations](res)  # type: ignore
    return res  # pragma: no cover

0 comments on commit ed8cd15

Please sign in to comment.