diff --git a/test.ipynb b/test.ipynb
index 15e90ee..1ad3508 100644
--- a/test.ipynb
+++ b/test.ipynb
@@ -11,28 +11,7 @@
      "text": [
       "Setting default log level to \"WARN\".\n",
       "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
-      "23/07/03 18:53:46 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "+----+\n",
-      "| age|\n",
-      "+----+\n",
-      "|null|\n",
-      "|null|\n",
-      "|null|\n",
-      "+----+\n",
-      "\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      " \r"
+      "23/08/14 20:09:46 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
      ]
     }
    ],
@@ -40,76 +19,79 @@
     "from typing import Protocol, TypeVar\n",
     "from pyspark.sql import SparkSession\n",
     "from pyspark.sql.types import LongType, StringType\n",
-    "from typedspark import Schema, DataSet, Column, PartialDataSet, transform_to_schema, create_empty_dataset\n",
+    "from typedspark import (\n",
+    "    Schema,\n",
+    "    DataSet,\n",
+    "    Column,\n",
+    "    DataSetExtends,\n",
+    "    DataSetImplements,\n",
+    "    transform_to_schema,\n",
+    "    create_empty_dataset,\n",
+    ")\n",
+    "\n",
     "\n",
     "class Person(Schema):\n",
     "    name: Column[StringType]\n",
     "    age: Column[LongType]\n",
     "\n",
+    "\n",
+    "class Job(Schema):\n",
+    "    role: Column[StringType]\n",
+    "\n",
+    "\n",
     "class Age(Schema, Protocol):\n",
-    "    type: Column[StringType]\n",
+    "    age: Column[LongType]\n",
     "\n",
-    "def get_age(df: PartialDataSet[Age]) -> DataSet[Age]:\n",
-    "    return transform_to_schema(df, Age)\n",
     "\n",
     "spark = SparkSession.builder.getOrCreate()\n",
     "\n",
-    "df = create_empty_dataset(spark, Person)\n",
-    "get_age(df).show()"
+    "person = create_empty_dataset(spark, Person)\n",
+    "job = create_empty_dataset(spark, Job)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 10,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "+----+----+\n",
-      "|name| age|\n",
-      "+----+----+\n",
-      "|null|null|\n",
-      "|null|null|\n",
-      "|null|null|\n",
-      "+----+----+\n",
-      "\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "T = TypeVar(\"T\", bound=Schema)\n",
+    "def get_age(df: DataSetExtends[Age]) -> DataSet[Age]:\n",
+    "    return transform_to_schema(df, Age)\n",
     "\n",
-    "def birthday(df: PartialDataSet[Age], schema: T) -> DataSet[T]:\n",
-    "    return transform_to_schema(\n",
-    "        df, \n",
-    "        schema,  # type: ignore\n",
-    "        {Age.age: Age.age + 1}\n",
-    "    )\n",
     "\n",
-    "res: DataSet[Person] = birthday(df, Person)\n",
-    "res.show()"
+    "get_age(person)\n",
+    "try:\n",
+    "    get_age(job)  # linting error: Job is not a subtype of Age\n",
+    "except:\n",
+    "    pass"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 11,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'Base class for protocol classes.\\n\\nProtocol classes are defined as::\\n\\n    class Proto(Protocol):\\n        def meth(self) -> int:\\n            ...\\n\\nSuch classes are primarily used with static type checkers that recognize\\nstructural subtyping (static duck-typing), for example::\\n\\n    class C:\\n        def meth(self) -> int:\\n            return 0\\n\\n    def func(x: Proto) -> int:\\n        return x.meth()\\n\\n    func(C())  # Passes static type check\\n\\nSee PEP 544 for details. Protocol classes decorated with\\n@typing.runtime_checkable act as simple-minded runtime protocols that check\\nonly the presence of given attributes, ignoring their type signatures.\\nProtocol classes can be generic, they are defined as::\\n\\n    class GenProto(Protocol[T]):\\n        def meth(self) -> T:\\n            ...'"
-      ]
-     },
-     "execution_count": 4,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "Age.get_docstring()"
+    "T = TypeVar(\"T\", bound=Schema)\n",
+    "\n",
+    "\n",
+    "def birthday(\n",
+    "    df: DataSetImplements[Age, T],\n",
+    ") -> DataSet[T]:\n",
+    "    return transform_to_schema(\n",
+    "        df,\n",
+    "        df.typedspark_schema,  # type: ignore\n",
+    "        {\n",
+    "            Age.age: Age.age + 1,\n",
+    "        },\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "res = birthday(person)\n",
+    "try:\n",
+    "    res = birthday(job)  # linting error: Job is not a subtype of Age\n",
+    "except:\n",
+    "    pass"
    ]
   },
   {
diff --git a/tests/_schema/test_schema.py b/tests/_schema/test_schema.py
index 670ec99..79c988c 100644
--- a/tests/_schema/test_schema.py
+++ b/tests/_schema/test_schema.py
@@ -106,7 +106,7 @@ def test_get_snake_case():
 
 
 def test_get_docstring():
-    assert A.get_docstring() is None
+    assert A.get_docstring() == ""
     assert PascalCase.get_docstring() == "Schema docstring."
 
 
@@ -125,7 +125,7 @@ def test_get_structtype():
 def test_get_dlt_kwargs():
     assert A.get_dlt_kwargs() == DltKwargs(
         name="a",
-        comment=None,
+        comment="",
         schema=StructType(
             [StructField("a", LongType(), True), StructField("b", StringType(), True)]
         ),
diff --git a/typedspark/__init__.py b/typedspark/__init__.py
index 681f4c1..1944219 100644
--- a/typedspark/__init__.py
+++ b/typedspark/__init__.py
@@ -2,7 +2,7 @@
 
 from typedspark._core.column import Column
 from typedspark._core.column_meta import ColumnMeta
-from typedspark._core.dataset import DataSet, PartialDataSet
+from typedspark._core.dataset import DataSet, DataSetExtends, DataSetImplements
 from typedspark._core.datatypes import (
     ArrayType,
     DayTimeIntervalType,
@@ -36,7 +36,8 @@
     "IntervalType",
     "MapType",
     "MetaSchema",
-    "PartialDataSet",
+    "DataSetExtends",
+    "DataSetImplements",
     "Schema",
     "StructType",
     "create_empty_dataset",
diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py
index 19efd66..30b8aeb 100644
--- a/typedspark/_core/dataset.py
+++ b/typedspark/_core/dataset.py
@@ -21,18 +21,23 @@
 from typedspark._core.validate_schema import validate_schema
 from typedspark._schema.schema import Schema
 
-T = TypeVar("T", bound=Schema)
+_Schema = TypeVar("_Schema", bound=Schema)
 _ReturnType = TypeVar("_ReturnType", bound=DataFrame)  # pylint: disable=C0103
 P = ParamSpec("P")
 
-V = TypeVar("V", bound=Schema, covariant=True)
+_Protocol = TypeVar("_Protocol", bound=Schema, covariant=True)
+_Implementation = TypeVar("_Implementation", bound=Schema, covariant=True)
 
 
-class PartialDataSet(DataFrame, Generic[V]):
-    pass
+class DataSetImplements(DataFrame, Generic[_Protocol, _Implementation]):
+    """``DataSet`` whose schema ``_Implementation`` implements the protocol ``_Protocol``."""
 
 
-class DataSet(PartialDataSet, Generic[T]):
+class DataSetExtends(DataSetImplements[_Protocol, _Protocol], Generic[_Protocol]):
+    """``DataSet`` whose schema extends the protocol ``_Protocol``."""
+
+
+class DataSet(DataSetExtends[_Schema]):
     """``DataSet`` subclasses pyspark ``DataFrame`` and hence has all the same
     functionality, with in addition the possibility to define a schema.
 
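Reviewer note (not part of the patch): the hunk above replaces PartialDataSet with a small hierarchy, and the notebook cells earlier in this diff show how the pieces are meant to compose. The sketch below restates that usage in plain Python; it reuses the Person/Age schemas and the get_age/birthday functions from the notebook and should be read as an illustration under those assumptions, not as the library's documented API.

    from typing import Protocol, TypeVar

    from pyspark.sql.types import LongType, StringType
    from typedspark import (
        Column,
        DataSet,
        DataSetExtends,
        DataSetImplements,
        Schema,
        transform_to_schema,
    )

    T = TypeVar("T", bound=Schema)


    class Age(Schema, Protocol):
        """Protocol schema: any schema that has an age column."""

        age: Column[LongType]


    class Person(Schema):
        """Concrete schema that structurally implements Age."""

        name: Column[StringType]
        age: Column[LongType]


    def get_age(df: DataSetExtends[Age]) -> DataSet[Age]:
        # Accepts any DataSet whose schema provides the columns of Age;
        # the result is typed by the protocol itself.
        return transform_to_schema(df, Age)


    def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
        # Same constraint on the input, but the concrete schema T is carried
        # through to the return type: DataSet[Person] in, DataSet[Person] out.
        return transform_to_schema(
            df,
            df.typedspark_schema,  # type: ignore
            {Age.age: Age.age + 1},
        )

Because _Protocol is covariant and DataSet now derives from DataSetExtends, a DataSet[Person] type-checks against both annotations, while a schema without an age column (the notebook's Job) is rejected by the linter.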
@@ -47,7 +52,7 @@ def foo(df: DataSet[Person]) -> DataSet[Person]:
             return df
     """
 
-    def __new__(cls, dataframe: DataFrame) -> "DataSet[T]":
+    def __new__(cls, dataframe: DataFrame) -> "DataSet[_Schema]":
         """``__new__()`` instantiates the object (prior to ``__init__()``).
 
         Here, we simply take the provided ``df`` and cast it to a
@@ -93,7 +98,7 @@ def _add_schema_metadata(self) -> None:
             self.schema[field.name].metadata = field.metadata
 
     @property
-    def typedspark_schema(self) -> Type[T]:
+    def typedspark_schema(self) -> Type[_Schema]:
         """Returns the ``Schema`` of the ``DataSet``."""
         return self._schema_annotations  # type: ignore
 
@@ -101,10 +106,10 @@
     don't affect the ``Schema``, we can add type annotations here. We're omitting
     docstrings, such that the docstring from the parent will appear."""
 
-    def distinct(self) -> "DataSet[T]":  # pylint: disable=C0116
+    def distinct(self) -> "DataSet[_Schema]":  # pylint: disable=C0116
         return DataSet[self._schema_annotations](super().distinct())  # type: ignore
 
-    def filter(self, condition) -> "DataSet[T]":  # pylint: disable=C0116
+    def filter(self, condition) -> "DataSet[_Schema]":  # pylint: disable=C0116
         return DataSet[self._schema_annotations](super().filter(condition))  # type: ignore
 
     @overload
@@ -126,7 +131,7 @@ def join(
         self,
         other: DataFrame,
         on: Optional[
             Union[str, List[str], SparkColumn, List[SparkColumn]]
         ] = ...,
         how: Literal["semi"] = ...,
-    ) -> "DataSet[T]":
+    ) -> "DataSet[_Schema]":
         ...  # pragma: no cover
 
     @overload
@@ -150,13 +155,13 @@ def join(  # pylint: disable=C0116
     ) -> DataFrame:
         return super().join(other, on, how)  # type: ignore
 
-    def orderBy(self, *args, **kwargs) -> "DataSet[T]":  # type: ignore  # noqa: N802, E501  # pylint: disable=C0116, C0103
+    def orderBy(self, *args, **kwargs) -> "DataSet[_Schema]":  # type: ignore  # noqa: N802, E501  # pylint: disable=C0116, C0103
         return DataSet[self._schema_annotations](super().orderBy(*args, **kwargs))  # type: ignore
 
     @overload
     def transform(
         self,
-        func: Callable[Concatenate["DataSet[T]", P], _ReturnType],
+        func: Callable[Concatenate["DataSet[_Schema]", P], _ReturnType],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> _ReturnType:
@@ -174,9 +179,9 @@ def transform(  # pylint: disable=C0116
     @overload
     def unionByName(  # noqa: N802  # pylint: disable=C0116, C0103
         self,
-        other: "DataSet[T]",
+        other: "DataSet[_Schema]",
         allowMissingColumns: Literal[False] = ...,  # noqa: N803
-    ) -> "DataSet[T]":
+    ) -> "DataSet[_Schema]":
         ...  # pragma: no cover
 
     @overload
diff --git a/typedspark/_schema/get_schema_definition.py b/typedspark/_schema/get_schema_definition.py
index 237b722..ac17d7e 100644
--- a/typedspark/_schema/get_schema_definition.py
+++ b/typedspark/_schema/get_schema_definition.py
@@ -1,5 +1,4 @@
-"""Module to output a string with the ``Schema`` definition of a given
-``DataFrame``."""
+"""Module to output a string with the ``Schema`` definition of a given ``DataFrame``."""
 from __future__ import annotations
 
 import re
@@ -57,7 +56,7 @@ def _build_schema_definition_string(
 
 def _create_docstring(schema: Type[Schema]) -> str:
     """Create the docstring for a given ``Schema``."""
-    if schema.get_docstring() is not None:
+    if schema.get_docstring() != "":
         docstring = f'    """{schema.get_docstring()}"""\n\n'
     else:
         docstring = '    """Add documentation here."""\n\n'
@@ -140,8 +139,8 @@ def _replace_literal(
 
 
 def _add_subschemas(schema: Type[Schema], add_subschemas: bool, include_documentation: bool) -> str:
-    """Identifies whether any ``Column`` are of the ``StructType`` type and
-    generates their schema recursively."""
+    """Identifies whether any ``Column`` are of the ``StructType`` type and generates
+    their schema recursively."""
     lines = ""
     for val in get_type_hints(schema).values():
         args = get_args(val)
diff --git a/typedspark/_schema/schema.py b/typedspark/_schema/schema.py
index bb158d4..c063c3e 100644
--- a/typedspark/_schema/schema.py
+++ b/typedspark/_schema/schema.py
@@ -12,7 +12,6 @@
     _ProtocolMeta,
     get_args,
     get_type_hints,
-    runtime_checkable,
 )
 
 from pyspark.sql import DataFrame
@@ -116,8 +115,8 @@ def all_column_names(cls) -> List[str]:
         return list(get_type_hints(cls).keys())
 
     def all_column_names_except_for(cls, except_for: List[str]) -> List[str]:
-        """Returns all column names for a given schema except for the columns
-        specified in the ``except_for`` parameter."""
+        """Returns all column names for a given schema except for the columns specified
+        in the ``except_for`` parameter."""
         return list(name for name in get_type_hints(cls).keys() if name not in except_for)
 
     def get_snake_case(cls) -> str:
@@ -177,8 +176,7 @@ def get_structtype(cls) -> StructType:
         )
 
     def get_dlt_kwargs(cls, name: Optional[str] = None) -> DltKwargs:
-        """Creates a representation of the ``Schema`` to be used by Delta Live
-        Tables.
+        """Creates a representation of the ``Schema`` to be used by Delta Live Tables.
 
         .. code-block:: python
 
@@ -198,10 +196,9 @@ def get_schema_name(cls):
 
 
 class Schema(Protocol, metaclass=MetaSchema):
-    # pylint: disable=missing-class-docstring
+    # pylint: disable=empty-docstring
     # Since docstrings are inherrited, and since we use docstrings to
     # annotate tables (see MetaSchema.get_dlt_kwargs()), we have chosen
     # to add an empty docstring to the Schema class (otherwise the Schema
     # docstring would be added to any schema without a docstring).
-    __doc__ = None
-    pass
+    """"""
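To close the loop on the docstring changes above, here is a minimal sketch of the behaviour the tests in this diff describe. The schemas are stand-ins that mirror the A and PascalCase fixtures referenced in tests/_schema/test_schema.py (A's columns are taken from the expected StructType in that test; PascalCase's column is assumed): an undocumented schema now reports an empty docstring rather than None, or rather than the inherited typing.Protocol docstring that the deleted notebook cell used to print, and that empty string is what get_dlt_kwargs() and _create_docstring() consume.

    from pyspark.sql.types import LongType, StringType
    from typedspark import Column, Schema


    class A(Schema):
        # Deliberately undocumented schema, mirroring the A fixture in the tests.
        a: Column[LongType]
        b: Column[StringType]


    class PascalCase(Schema):
        """Schema docstring."""

        a: Column[LongType]


    # An undocumented schema now yields "" rather than None, so downstream
    # consumers can treat the docstring as a plain string; documented schemas
    # are unchanged.
    assert A.get_docstring() == ""
    assert A.get_dlt_kwargs().comment == ""
    assert PascalCase.get_docstring() == "Schema docstring."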